diff --git a/contrib/.gitignore b/contrib/.gitignore index 037c95ed0..33b4d21bd 100644 --- a/contrib/.gitignore +++ b/contrib/.gitignore @@ -1 +1,5 @@ # Put contrib-related gitignore files here. + +# recipes/simulation related +recipes/simulation/agl_envs/ +recipes/simulation/wandb/ diff --git a/contrib/agentlightning/contrib/adapter/triplet_group.py b/contrib/agentlightning/contrib/adapter/triplet_group.py new file mode 100644 index 000000000..864c0bb39 --- /dev/null +++ b/contrib/agentlightning/contrib/adapter/triplet_group.py @@ -0,0 +1,134 @@ +# Copyright (c) Microsoft. All rights reserved. + +from __future__ import annotations + +from typing import Dict, List, Optional + +from agentlightning.adapter.triplet import TracerTraceToTriplet +from agentlightning.types import Span, Triplet + + +class TracerTraceToTripletGroup(TracerTraceToTriplet): + """Convert tracer-emitted spans into triplet trajectories. + + Attributes: + repair_hierarchy: When `True`, repair the span tree using + [`TraceTree.repair_hierarchy()`][agentlightning.adapter.triplet.TraceTree.repair_hierarchy] + before matching calls and rewards. + llm_call_match: Regular expression pattern that selects LLM call span names. + agent_match: Optional regular expression pattern for agent span names. When omitted, spans + from any agent are considered. + exclude_llm_call_in_reward: When `True`, ignore matches under reward spans while searching + for rewards. + reward_match: Strategy used to associate rewards with LLM calls. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def _extract_span_groups(self, spans): + def resolve_step_count(span, next_span, spans, index): + """ + Determine step_count for a given span using next_span or fallback search. + """ + # CASE A: If next_span exists and parent_id matches + if next_span and span.parent_id == next_span.span_id: + return next_span.attributes.get("step_count") + + # CASE B: Fallback — search forward for agentlightning.operation + for s in spans[index + 1 :]: + if s.name == "agentlightning.operation" and span.parent_id == s.span_id: + return s.attributes.get("step_count") + + return None + + def extract_step_count_from_links(span): + """ + Extract step_count from agentlightning.link.* attributes. 
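+
+            Expected attribute shape (illustrative values only):
+                {"agentlightning.link.0.key_match": "step_count",
+                 "agentlightning.link.0.value_match": "3"}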
+ """ + key = span.attributes.get("agentlightning.link.0.key_match") + if key == "step_count": + return span.attributes.get("agentlightning.link.0.value_match") + return None + + span_groups = {} + + for i, span in enumerate(spans): + next_span = spans[i + 1] if i + 1 < len(spans) else None + step_count = None + + if span.name == "openai.chat.completion": + step_count = resolve_step_count(span, next_span, spans, i) + if step_count is None: + continue + + step_count = str(step_count) + span_groups.setdefault(step_count, {}) + span_groups[step_count]["call_span"] = span + + elif span.name == "agentlightning.object": + step_count = extract_step_count_from_links(span) + if step_count is None: + continue + + step_count = str(step_count) + span_groups.setdefault(step_count, {}) + span_groups[step_count]["object_span"] = span + + elif span.name == "agentlightning.annotation": + step_count = extract_step_count_from_links(span) + if step_count is None: + continue + + step_count = str(step_count) + span_groups.setdefault(step_count, {}) + span_groups[step_count]["annotation_span"] = span + + return span_groups + + def adapt_group(self, source: Sequence[Span], /) -> List[Triplet]: + span_groups = self._extract_span_groups(source) + + def token_ids(span: Optional[Span], key: str) -> list: + return span.attributes.get(key, []) if span else [] + + def reward0(span: Optional[Span]) -> float: + if not span: + return 0.0 + return float(span.attributes.get("agentlightning.reward.0.value", 0.0)) + + def reward1(span: Optional[Span]) -> Optional[float]: + if not span: + return 0.0 + return float(span.attributes.get("agentlightning.reward.1.value", 0.0)) + + def message(span: Optional[Span]) -> Optional[str]: + if not span: + return "" + return span.attributes.get("agentlightning.object.literal", "") + + triplets: List[Triplet] = [] + + for group in span_groups.values(): + call_span = group.get("call_span") + if not token_ids(call_span, "prompt_token_ids") and not token_ids(call_span, "response_token_ids"): + continue + + object_span = group.get("object_span") + annotation_span = group.get("annotation_span") + request_id = group.get("request_id") + + triplets.append( + Triplet( + prompt={"token_ids": token_ids(call_span, "prompt_token_ids")}, + response={"token_ids": token_ids(call_span, "response_token_ids")}, + reward=reward0(annotation_span), + metadata={ + "response_id": request_id, + "intrinsic_reward": reward1(annotation_span), + "message": message(object_span), + }, + ) + ) + + return triplets diff --git a/contrib/agentlightning/contrib/agent/simulation_agent.py b/contrib/agentlightning/contrib/agent/simulation_agent.py new file mode 100644 index 000000000..0ce73f0d8 --- /dev/null +++ b/contrib/agentlightning/contrib/agent/simulation_agent.py @@ -0,0 +1,157 @@ +# Copyright (c) Microsoft. All rights reserved. 
+ +from __future__ import annotations + +import logging +import os +from typing import Any, Dict + +import numpy as np +from add_instruction import add_chat_instruction, add_single_instruction +from agl_envs.simulation import make_env_manager +from autogen_agentchat.agents import AssistantAgent +from autogen_core.models import ModelFamily +from autogen_ext.models.openai import OpenAIChatCompletionClient + +from agentlightning import LLM, LitAgent, NamedResources, Rollout, configure_logger, emit_object, emit_reward, operation +from agentlightning.utils.otel import make_link_attributes +from contrib.recipes.simulation.prompt_builder import HistoryPromptBuilder + +logger = configure_logger(name=__name__, level=logging.ERROR) + + +class SimulationAgent(LitAgent): + def __init__(self, config, trained_agents: str | None = None) -> None: + super().__init__(trained_agents=trained_agents) + self.config = config + self.env = None + + def _build_agent(self, llm: LLM, temperature: float): + model_client = OpenAIChatCompletionClient( + model=llm.model, + base_url=llm.endpoint, + api_key=os.environ.get("OPENAI_API_KEY", "token-abc123"), + model_info={ + "vision": False, + "function_calling": True, + "json_output": False, + "family": ModelFamily.UNKNOWN, + "structured_output": False, + }, + temperature=temperature, + ) + + return AssistantAgent( + name="simulation", + model_client=model_client, + ) + + def _get_instructed_prompt(self, prompt, sep="\n\n"): + """Return instructed observation based on prompt_type and captioner type.""" + prompt_type = self.config.captioner.prompt_type + cap_type = self.config.captioner.type + + if prompt_type == "chat": + if cap_type == "cot": + return add_chat_instruction(prompt, "cot", sep, self.config.env_name) + elif cap_type == "naive": + return add_chat_instruction(prompt, "naive", sep) + + elif prompt_type == "single": + if cap_type == "cot": + return add_single_instruction(prompt, "cot", sep, self.config.env_name) + elif cap_type == "naive": + return add_single_instruction(prompt, "naive", sep, self.config.env_name) + + raise ValueError(f"Unsupported prompt_type={prompt_type}, type={cap_type}") + + async def rollout_async( + self, + task: Dict[str, Any], + resources: NamedResources, + rollout: Rollout, + ) -> float | None: + rollout_id = rollout.rollout_id + logger.info(f"[Rollout {rollout_id}] Task: {task}") + + format_penalty = float(self.config["format_penalty"]) + reward_scale = float(self.config["reawrd_scale"]) + + # Setup agent + llm: LLM = resources.get("main_llm") + print("Training with model:", llm.model, "on endpoint:", llm.endpoint) + self.agent = self._build_agent(llm, 1.0 if rollout.mode == "train" else 0.4) + if "max_tokens" in self.config and self.config["max_tokens"] > -1: + self.agent._model_client.max_tokens = self.config["max_tokens"] + + try: + # Setup environment + prompt_builder = HistoryPromptBuilder( + max_history=self.config.captioner.max_history, prompt_type=self.config.captioner.prompt_type + ) + + self.env = make_env_manager(self.config.env_name, task, self.config) + env_obs, infos, available_actions_hint = self.env.reset() + + prompt_builder.init(self.env) + prompt_builder.update_observation(env_obs) + prompt_builder.update_admissible_actions(available_actions_hint) + + prompt = prompt_builder.get_prompt() + + episode_reward, done = 0.0, False + + step_count = 0 + while not done: + try: + instructed_prompt = self._get_instructed_prompt(prompt) + + # Main agent step + with operation(step_count=step_count): + result = await 
self.agent._model_client.create(instructed_prompt) + output = result.content + logger.info(f"[LLM output]: {output}") + + except Exception as e: + logger.error(f"[Rollout {rollout_id}] Error during training rollout: {e}", exc_info=True) + break + + if self.config.log_env_obs: + emit_object(env_obs, attributes=make_link_attributes({"step_count": str(step_count)})) + + env_obs, executed_action, is_valid, step_reward, terminated, truncated, info, available_actions_hint = ( + self.env.step(output, use_reasoning=self.config.captioner.type == "cot", use_success_rate=self.config.use_success_rate) + ) + + prompt_builder.update_step_count() + prompt_builder.update_action(executed_action) + prompt_builder.update_observation(env_obs) + prompt_builder.update_admissible_actions(available_actions_hint) + + prompt = prompt_builder.get_prompt() + + if rollout.mode == "train": + step_reward *= reward_scale + + if format_penalty != 0.0: + emit_reward( + { + "extrinsic_reward": step_reward, + "intrinsic_reward": 0.0 if is_valid else -1.0 * format_penalty, + }, + primary_key="extrinsic_reward", + attributes=make_link_attributes({"step_count": str(step_count)}), + ) + else: + emit_reward(step_reward, attributes=make_link_attributes({"step_count": str(step_count)})) + + episode_reward += float(step_reward) + done = np.logical_or(terminated, truncated) + + step_count += 1 + + return episode_reward + + finally: + if self.env is not None: + self.env.close() diff --git a/contrib/agentlightning/contrib/algorithm/simulation_verl/daemon.py b/contrib/agentlightning/contrib/algorithm/simulation_verl/daemon.py new file mode 100644 index 000000000..2d414ec80 --- /dev/null +++ b/contrib/agentlightning/contrib/algorithm/simulation_verl/daemon.py @@ -0,0 +1,873 @@ +# Copyright (c) Microsoft. All rights reserved. + +import asyncio +import json +import random +import socket +import threading +import time +import uuid +from collections import defaultdict +from collections.abc import Mapping +from typing import Any, Dict, List, Literal, Optional, Tuple, cast + +import numpy as np +import requests +import torch +from flask import Flask, Response, abort, request +from tensordict import TensorDict +from verl import DataProto + +from agentlightning import LLM, AgentLightningServer, NamedResources, RolloutLegacy +from agentlightning.adapter.triplet import TraceToTripletBase +from agentlightning.llm_proxy import LLMProxy, ModelConfig +from agentlightning.store.base import LightningStore +from agentlightning.types import EnqueueRolloutRequest, Rollout, RolloutConfig, Task +from agentlightning.reward import find_final_reward + +from contrib.agentlightning.contrib.adapter.triplet_group import TracerTraceToTripletGroup + +__all__ = [ + "AgentModeDaemon", + "get_left_padded_ids_and_attention_mask", + "get_right_padded_ids_and_attention_mask", +] + + +def get_left_padded_ids_and_attention_mask( + ids: List[int], max_length: int, pad_token_id: int +) -> Tuple[List[int], List[int]]: + """ + Left-pad (or truncate) a sequence of token IDs to a fixed length, + and build the corresponding attention mask. + + Args: + ids: the original list of token IDs. + max_length: desired total length after padding/truncation. + pad_token_id: ID to use for padding. + + Returns: + padded_ids (any): list of length == max_length. + attention_mask (any): list of same length: 1 for non-pad tokens, 0 for pads. 
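+
+    Example (illustrative values):
+        >>> get_left_padded_ids_and_attention_mask([5, 6], max_length=4, pad_token_id=0)
+        ([0, 0, 5, 6], [0, 0, 1, 1])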
+ """ + seq_len = len(ids) + + if seq_len >= max_length: + # too long → truncate from the left, keep the last max_length tokens + trimmed = ids[-max_length:] + attention_mask = [1] * max_length + return trimmed, attention_mask + + # too short → pad on the left + pad_len = max_length - seq_len + padded_ids = [pad_token_id] * pad_len + ids + attention_mask = [0] * pad_len + [1] * seq_len + return padded_ids, attention_mask + + +def get_right_padded_ids_and_attention_mask( + ids: List[int], max_length: int, pad_token_id: int +) -> Tuple[List[int], List[int]]: + """ + Right-pad (or truncate) a sequence of token IDs to a fixed length, + and build the corresponding attention mask. + + Args: + ids: the original list of token IDs. + max_length: desired total length after padding/truncation. + pad_token_id: ID to use for padding. + + Returns: + padded_ids (any): list of length == max_length. + attention_mask (any): list of same length: 1 for non-pad tokens, 0 for pads. + """ + seq_len = len(ids) + + if seq_len >= max_length: + # too long → truncate to the first max_length tokens + trimmed = ids[:max_length] + attention_mask = [1] * max_length + return trimmed, attention_mask + + # too short → pad on the right + pad_len = max_length - seq_len + padded_ids = ids + [pad_token_id] * pad_len + attention_mask = [1] * seq_len + [0] * pad_len + return padded_ids, attention_mask + + +def _find_available_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +def _to_native(obj: Any) -> Any: + """Convert data retrieved from Parquet to data usable in AGL server.""" + # 1) Arrays -> list (then recurse) + if isinstance(obj, np.ndarray): + return _to_native(obj.tolist()) + + # 2) NumPy scalar types -> Python scalars + if isinstance(obj, np.generic): + return _to_native(obj.item()) + + # 3) Dict-like -> dict + if isinstance(obj, Mapping): + return {_to_native(k): _to_native(v) for k, v in obj.items()} # type: ignore + + # 4) Lists/Tuples/Sets -> list + if isinstance(obj, (list, tuple, set)): + return [_to_native(x) for x in obj] # type: ignore + + # 5) Anything else: leave as-is + return obj + + +class SimulationAgentModeDaemon: + """ + AgentModeDaemon using the AgentLightningServer SDK. + + This class manages the server lifecycle, task queueing, and results + retrieval, while also running a proxy server for LLM requests. It maintains + the original interface for compatibility with the RayPPOTrainer. 
+ """ + + def __init__( + self, + port: Optional[int], + train_rollout_n: int, + train_information: Dict[str, Any], + tokenizer: Any, + mini_batch_size: int, + pad_token_id: int, + reward_fillna_value: float = 0.0, + llm_timeout_seconds: float = 1200.0, + mode: Literal["v0", "v1"] = "v1", + llm_proxy: LLMProxy | None = None, + store: LightningStore | None = None, + adapter: TraceToTripletBase | None = None, + ): + self.mode = mode + self.llm_timeout_seconds = llm_timeout_seconds + + # Server and Task Configuration + if mode == "v0": + assert port is not None + self.server_port = port + self.server = AgentLightningServer( + host="0.0.0.0", port=self.server_port, task_timeout_seconds=self.llm_timeout_seconds + ) + self.proxy_port = _find_available_port() # Run proxy on a different port + else: + assert store is not None + self.store = store + if llm_proxy is None: + self.llm_proxy = LLMProxy( + port=_find_available_port(), + model_list=[], + store=store, + ) + else: + # Reuse the existing LLM proxy (probably configured by user) + self.llm_proxy = llm_proxy + + # if adapter is None: + # self.adapter = TracerTraceToTripletGroup() + # else: + # # Reuse the one from trainer + # self.adapter = adapter + self.adapter = TracerTraceToTripletGroup() + + self._internal_loop: Optional[asyncio.AbstractEventLoop] = None + self._internal_loop_thread = threading.Thread(target=self._internal_loop_runner, daemon=True) + self._internal_loop_thread.start() + + # Training and Data Configuration + self.train_rollout_n = train_rollout_n + self.train_information = train_information + self.mini_batch_size = mini_batch_size + self.pad_token_id = pad_token_id + self.tokenizer = tokenizer + self.reward_fillna_value = reward_fillna_value + + # Internal State + self.backend_llm_server_addresses: List[str] = [] + self._total_tasks_queued = 0 + self._completed_rollouts_v0: Dict[str, RolloutLegacy] = {} + self._task_id_to_original_sample: Dict[str, Dict[str, Any]] = {} + self._server_thread: Optional[threading.Thread] = None + self._proxy_thread: Optional[threading.Thread] = None + self.is_train = True + + def _internal_loop_runner(self): + """Run the internal loop.""" + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + self._internal_loop = loop + loop.run_forever() + loop.close() + + def _start_proxy_server_v0(self): + """ + Initializes and runs a Flask-based proxy server in a separate thread. + This proxy load-balances requests to the actual backend LLM servers. + """ + app = Flask(__name__) + + num_requests = 0 + last_request_time = 0 + + @app.route("/v1/", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"]) + def proxy(path: str): # type: ignore + if not self.backend_llm_server_addresses: + abort(503, description="No backend LLM servers available.") + + # Randomly choose a backend server for load balancing + target_server = random.choice(self.backend_llm_server_addresses) + target_url = f"http://{target_server}/v1/{path}" + + # Copy client request headers, removing the Host header + headers = {key: value for key, value in request.headers if key.lower() != "host"} + + # Log the request for debugging + nonlocal num_requests, last_request_time + current_time = time.time() + num_requests += 1 + if current_time - last_request_time > 60 or num_requests == 1 or num_requests % 100 == 0: + print(f"Proxying {request.method} request to {target_server}. 
Request data: {request.get_data()}") + last_request_time = current_time + + try: + # Forward the request to the target backend + resp = requests.request( + method=request.method, + url=target_url, + headers=headers, + params=request.args, # type: ignore + data=request.get_data(), + cookies=request.cookies, + allow_redirects=False, + timeout=self.llm_timeout_seconds, + ) + # Filter out hop-by-hop headers before returning the response + excluded_headers = [ + "content-encoding", + "content-length", + "transfer-encoding", + "connection", + "keep-alive", + "proxy-authenticate", + "proxy-authorization", + "te", + "trailers", + "upgrade", + ] + response_headers = [ + (name, value) for name, value in resp.raw.headers.items() if name.lower() not in excluded_headers + ] + if resp.status_code == 200: + # NOTE: from Zhiyuan's code. + # https://github.com/hzy46/verl_agent_mode/blob/2db65ea9858f645a914120357412a7540f8bd82d/verl/trainer/ppo/ray_trainer.py#L692-L711 + # request_json = json.loads(request.get_data().decode("utf-8")) + response_json = json.loads(resp.content.decode("utf-8")) + # response_message = ChatCompletion(**response_json).choices[0].message.model_dump(exclude_unset=True, exclude_none=True) + # tool_schemas = request_json.get("tools", None) + # prompt_ids = self.tokenizer.apply_chat_template(request_json["messages"], tools=tool_schemas, add_generation_prompt=True, tokenize=True) + # full_ids = self.tokenizer.apply_chat_template(request_json["messages"] + [response_message], tools=tool_schemas, add_generation_prompt=False, tokenize=True) + # TBD: response_ids sometimes ends with "\n", shall we keep the extra "\n"? + # sometimes it has some differences with the hacky method in the end, but this should align with ToolCompletionCallback + # response_ids = full_ids[len(prompt_ids):] + + # NOTE (yuge): They are different. Don't know why. 
+ # assert response_json['prompt_token_ids'] == prompt_ids + # patched_response_ids = response_json['response_token_ids'][0] + # assert patched_response_ids == response_ids[:len(patched_response_ids)], f"{patched_response_ids} != {response_ids[:len(patched_response_ids)]}" + # response_json['prompt_token_ids'] = prompt_ids + # response_json['response_token_ids'] = [response_ids] + replaced_return_content = json.dumps(response_json).encode("utf-8") + return Response(replaced_return_content, status=resp.status_code, headers=response_headers) + return Response(resp.content, resp.status_code, response_headers) + except requests.exceptions.RequestException as e: + abort(500, description=f"Error proxying request: {e}") + + def run_app(): + app.run(host="0.0.0.0", port=self.proxy_port, threaded=True, debug=False) + + self._proxy_thread = threading.Thread(target=run_app, daemon=True) + self._proxy_thread.start() + print(f"Proxy server running on port {self.proxy_port}") + + async def _update_proxy_server_v1(self): + model_name = self.train_information.get("model") + if not model_name: + raise ValueError("Model name is not set.") + self.llm_proxy.update_model_list( + [ + ModelConfig( + { + "model_name": model_name, + "litellm_params": { + "model": "hosted_vllm/" + model_name, + "api_base": f"http://{address}/v1/", + }, + } + ) + for address in self.backend_llm_server_addresses + ], + ) + + await self.llm_proxy.restart() + + def start(self): + """Starts the main AgentLightningServer and the proxy server.""" + + if self.mode == "v0": + + def run_server(): + """Run the AgentLightningServer in a separate thread.""" + asyncio.run(self.server.run_forever()) + + self._server_thread = threading.Thread(target=run_server, daemon=True) + self._server_thread.start() + + # Wait for the server's internal startup event to be set. + print("Waiting for AgentLightningServer to start...") + is_ready = self.server.startup_event.wait(timeout=20.0) # Wait up to 20s + if not is_ready: + raise RuntimeError("AgentLightningServer failed to start within the timeout period.") + + print(f"AgentLightningServer control plane running on port {self.server_port}") + + self._start_proxy_server_v0() + else: + # Agent lightning server is no longer needed; + # Start proxy server in _async_set_up + pass + + async def _async_set_up(self, data: Dict[str, Any], server_addresses: List[str], is_train: bool = True): + """Async helper to set up data and resources on the server.""" + self.clear_data_and_server() + if server_addresses != self.backend_llm_server_addresses: + self.backend_llm_server_addresses = server_addresses + if self.mode == "v1" and not self.llm_proxy.is_running(): + await self._update_proxy_server_v1() + self.is_train = is_train + + # 1. Update resources on the server for clients to use + if self.mode == "v0": + llm_resource = LLM( + endpoint=f"http://127.0.0.1:{self.proxy_port}/v1", + model=self.train_information.get("model", "default-model"), + sampling_parameters={ + "temperature": self.train_information.get("temperature", 0.7 if is_train else 0.0) + }, + ) + else: + llm_resource = self.llm_proxy.as_resource( + sampling_parameters={ + "temperature": self.train_information.get("temperature", 0.7 if is_train else 0.0) + }, + ) + + resources: NamedResources = {"main_llm": llm_resource} + + if self.mode == "v0": + resources_id = await self.server.update_resources(resources) + else: + resources_update = await self.store.add_resources(resources) + resources_id = resources_update.resources_id + + # 2. 
Queue tasks for agents to process + keys = list(data.keys()) + num_samples = len(data[keys[0]]) + rollouts_per_sample = self.train_rollout_n if is_train else 1 + + enqueue_rollout_requests: List[EnqueueRolloutRequest] = [] + data_id_to_original_sample: Dict[str, Dict[str, Any]] = {} + + for i in range(num_samples): + data_id = str(uuid.uuid4()) + original_sample = {key: data[key][i] for key in keys} + original_sample["data_id"] = data_id + data_id_to_original_sample[data_id] = original_sample + + # For training, each sample is rolled out multiple times + # Data ID is different from Rollout ID, as one data can have multiple rollouts. + for _ in range(rollouts_per_sample): + task_metadata = {"data_id": data_id, "is_train": is_train} + if self.mode == "v0": + # Queue immediately + rollout_id = await self.server.queue_task( + sample=_to_native(original_sample), + mode="train" if is_train else "val", + resources_id=resources_id, + metadata=task_metadata, + ) + + # Store original sample data to reconstruct batch information later + self._task_id_to_original_sample[rollout_id] = original_sample + self._total_tasks_queued += 1 + else: + # Collect tasks to enqueue in batch and queue them later + enqueue_rollout_requests.append( + EnqueueRolloutRequest( + input=_to_native(original_sample), + mode="train" if is_train else "val", + resources_id=resources_id, + config=RolloutConfig( + unresponsive_seconds=self.llm_timeout_seconds, + timeout_seconds=self.llm_timeout_seconds, + ), + metadata=task_metadata, + ) + ) + + if self.mode == "v1": + # Enqueue all the tasks in a single batch + rollouts = await self.store.enqueue_many_rollouts(enqueue_rollout_requests) + self._task_id_to_original_sample.update( + { + # Recover the original data and store it for later use. + rollout.rollout_id: data_id_to_original_sample[cast(Dict[str, Any], rollout.metadata)["data_id"]] + for rollout in rollouts + } + ) + self._total_tasks_queued += len(rollouts) + + def set_up_data_and_server(self, data: Dict[str, Any], server_addresses: List[str], is_train: bool = True): + """Synchronous wrapper for setting up data and server resources.""" + coro = self._async_set_up(data, server_addresses, is_train) + + if self.mode == "v0": + if not self.server.loop or not self.server.startup_event.is_set(): + raise RuntimeError("Server is not running or ready.") + + future = asyncio.run_coroutine_threadsafe(coro, self.server.loop) + + else: + if self._internal_loop is None: + raise RuntimeError("Internal loop is not running.") + future = asyncio.run_coroutine_threadsafe(coro, self._internal_loop) + try: + future.result(timeout=60) # Wait for completion with a timeout + except Exception as e: + print(f"Failed to set up data on server: {e}") + raise + + def _validate_data(self, rollout: RolloutLegacy): + if rollout.final_reward is None: + print( + f"Warning: Reward is None for rollout {rollout.rollout_id}, will be auto-set to {self.reward_fillna_value}." 
+ ) + if rollout.triplets is None: + print(f"Warning: Triplet is None for rollout {rollout.rollout_id}.") + elif len(rollout.triplets) == 0: + print(f"Warning: Length of triplets is 0 for rollout {rollout.rollout_id}.") + elif any(not r.response.get("token_ids", []) for r in rollout.triplets): + print(f"Warning: Rollout {rollout.rollout_id} contains empty response: {rollout.triplets}") + elif any(not r.prompt.get("token_ids", []) for r in rollout.triplets): + print(f"Warning: Rollout {rollout.rollout_id} contains empty prompt: {rollout.triplets}") + + async def _validate_data_v1(self, rollout: Rollout) -> RolloutLegacy: + """Convert Rollout to RolloutLegacy and validate. + + 1. Task: construct from Rollout + 2. Triplets: obtained by querying spans and feeding into the adapter + 3. Final reward: extracted from last triplet's reward, searching backwards if not found + """ + # Query spans for this rollout (latest attempt) + spans = await self.store.query_spans(rollout.rollout_id, attempt_id="latest") + final_reward = find_final_reward(spans) + + # Convert spans to triplets using the adapter + if not spans: + # No triplets found, will emit a warning later. + triplets = [] + else: + # triplets = self.adapter.adapt(spans) + triplets = self.adapter.adapt_group(spans) + + # # Extract final reward from triplets + # final_reward: Optional[float] = None + # if triplets: + # # Search backwards through triplets for the first non-None reward + # for triplet in reversed(triplets): + # if triplet.reward is not None: + # final_reward = triplet.reward + # break + + # Construct the Task object from Rollout + task = Task( + rollout_id=rollout.rollout_id, + input=rollout.input, + mode=rollout.mode, + resources_id=rollout.resources_id, + metadata=rollout.metadata or {}, + ) + + # Create the Rollout object (without trace and logs as per user's note) + result_rollout = RolloutLegacy( + rollout_id=rollout.rollout_id, + task=task, + final_reward=final_reward, + triplets=triplets, + metadata=rollout.metadata or {}, + ) + + # Run the same validation as v0 + self._validate_data(result_rollout) + + return result_rollout + + async def _async_run_until_finished(self, verbose: bool = True): + """Async helper to wait for all tasks to complete.""" + while len(self._completed_rollouts_v0) < self._total_tasks_queued: + if self.mode == "v0": + completed_batch = await self.server.retrieve_completed_rollouts() + else: + completed_batch = await self.store.wait_for_rollouts( + rollout_ids=list(self._task_id_to_original_sample.keys()), timeout=0 + ) + for rollout in completed_batch: + if rollout.rollout_id in self._completed_rollouts_v0: + # Already processed, skip + continue + if isinstance(rollout, Rollout): + rollout = await self._validate_data_v1(rollout) + else: + self._validate_data(rollout) + if rollout.rollout_id not in self._task_id_to_original_sample: + print(f"Warning: Received unknown rollout ID {rollout.rollout_id}, skipping.") + else: + self._completed_rollouts_v0[rollout.rollout_id] = rollout + if verbose: + print(f"Completed {len(self._completed_rollouts_v0)}/{self._total_tasks_queued} tasks...") + await asyncio.sleep(5) + + print("All tasks finished.") + + def run_until_all_finished(self, verbose: bool = True): + """Synchronously waits for all queued tasks to be completed and reported.""" + if self._total_tasks_queued == 0: + print("Warning: No tasks were queued.") + return + + if self.mode == "v0": + if not self.server.loop or not self.server.startup_event.is_set(): + raise RuntimeError("Server is not running or 
ready.") + loop = self.server.loop + else: + loop = self._internal_loop + assert loop is not None + + coro = self._async_run_until_finished(verbose) + future = asyncio.run_coroutine_threadsafe(coro, loop) + try: + future.result() # Wait indefinitely for all tasks to complete + except Exception as e: + print(f"Error while waiting for tasks to finish: {e}") + raise + + def get_test_metrics(self): + """Calculates and returns metrics for a validation run.""" + assert not self.is_train, "This method should only be called during validation." + assert len(self._completed_rollouts_v0) == self._total_tasks_queued + + sample_stat_list: List[Dict[str, Any]] = [] + sample_stat_list_by_source: Dict[str, List[Dict[str, Any]]] = defaultdict( + list + ) # FIXME: Evaluate whether grouping stats by source is actually needed. + + for rollout_id, rollout in self._completed_rollouts_v0.items(): + final_reward_raw: Optional[float] = rollout.final_reward + final_reward = self._fillna_reward(rollout) + if not rollout.triplets: + print(f"Warning: No triplets found for test rollout {rollout.rollout_id}.") + sample_stat_list.append({"reward": final_reward, "has_reward": final_reward_raw is not None}) + continue + response_length_list = [len(triplet.response.get("token_ids", [])) for triplet in rollout.triplets] + + if "data_source" in self._task_id_to_original_sample[rollout_id]: + # When a test sample includes a 'data_source' field, record per-source statistics for test results. + # TODO: This is a flawed design. We should have a better way to handle this. + data_source = self._task_id_to_original_sample[rollout_id]["data_source"] + sample_stat_list_by_source[data_source].append( + { + "sum_response_length": np.sum(response_length_list), + "mean_response_length": np.mean(response_length_list) if response_length_list else 0, + "turn_count": len(rollout.triplets), + "reward": final_reward, + "has_reward": final_reward_raw is not None, + } + ) + sample_stat_list.append( + { + "sum_response_length": np.sum(response_length_list), + "mean_response_length": np.mean(response_length_list) if response_length_list else 0, + "turn_count": len(rollout.triplets), + "reward": final_reward, + "has_reward": final_reward_raw is not None, + } + ) + metric_dict: Dict[str, Any] = {} + + stats_w_trace = [stat for stat in sample_stat_list if "sum_response_length" in stat] + stats_w_trace_by_source = { + data_source: [stat for stat in sample_stats if "sum_response_length" in stat] + for data_source, sample_stats in sample_stat_list_by_source.items() + } + for data_source, sample_stats in sample_stat_list_by_source.items(): + metric_dict.update( + { + f"val/{data_source}/n_rollouts": len(sample_stats), + f"val/{data_source}/n_rollouts_w_trace": len(stats_w_trace_by_source[data_source]), + f"val/{data_source}/n_rollouts_w_reward": len( + [stat for stat in sample_stats if stat["has_reward"]] + ), + f"val/{data_source}/reward": np.mean( + [stat["reward"] for stat in sample_stats] + ), # each rollout must have a reward (fillna if missing) + f"val/{data_source}/mean_response_length": np.mean( + [stat["mean_response_length"] for stat in stats_w_trace_by_source[data_source]] + ), + f"val/{data_source}/sum_response_length": np.mean( + [stat["sum_response_length"] for stat in stats_w_trace_by_source[data_source]] + ), + f"val/{data_source}/turn_count": np.mean( + [stat["turn_count"] for stat in stats_w_trace_by_source[data_source]] + ), + } + ) + metric_dict.update( + { + "val/n_rollouts": len(sample_stat_list), + "val/n_rollouts_w_trace": 
len(stats_w_trace), + "val/n_rollouts_w_reward": len([stat for stat in sample_stat_list if stat["has_reward"]]), + "val/reward": np.mean( + [stat["reward"] for stat in sample_stat_list] + ), # each rollout must have a reward (fillna if missing) + "val/mean_response_length": np.mean([stat["mean_response_length"] for stat in stats_w_trace]), + "val/sum_response_length": np.mean([stat["sum_response_length"] for stat in stats_w_trace]), + "val/turn_count": np.mean([stat["turn_count"] for stat in stats_w_trace]), + } + ) + return metric_dict + + def get_train_data_batch( + self, + max_prompt_length: int, + max_response_length: int, + device: torch.device, + use_final_reward_as_step_reward: bool = True, + use_intrinsic_reward: bool = False, + is_gigpo: bool = False, + ): + """ + Processes completed rollouts to generate a training data batch. + + This function reconstructs the logic from the original AgentModeDaemon, + using data retrieved from the new server architecture. It handles padding, + truncation, and tensor creation for the PPO training loop. + """ + assert self.is_train, "This method should only be called during training." + assert len(self._completed_rollouts_v0) == self._total_tasks_queued + + # 1. Reconstruct the `finished_id_to_sample_info` structure from completed rollouts + finished_id_to_sample_info: Dict[str, Dict[str, Any]] = {} + finished_id_to_final_reward: Dict[str, float] = {} + sample_with_reward_count = 0 + for rollout_id, rollout in self._completed_rollouts_v0.items(): + original_sample = self._task_id_to_original_sample[rollout_id] + sample_with_reward_count += int(rollout.final_reward is not None) + final_reward = self._fillna_reward(rollout) + + if not rollout.triplets: + finished_id_to_final_reward[rollout_id] = final_reward + print(f"Warning: No triplets found for training rollout {rollout.rollout_id}, skipping.") + continue + + # The client should report triplets that contain prompt_ids and response_ids. + # Example triplet.prompt: {"token_ids": [...]} + # Example triplet.response: {"token_ids": [...]} + # trace_list = [ + # {"prompt_ids": t.prompt.get("token_ids", []), "response_ids": t.response.get("token_ids", [])} + # for t in rollout.triplets + # ] + trace_list = [] + for t in rollout.triplets: + trace_dict = { + "prompt_ids": t.prompt.get("token_ids", []), + "response_ids": t.response.get("token_ids", []), + "step_reward": t.reward, + "step_intrinsic_reward": t.metadata.get("intrinsic_reward", 0.0), + "message": t.metadata.get("message", ""), + } + + trace_list.append(trace_dict) + + info = { + "final_reward": final_reward, + "trace_list": trace_list, + "data_id": original_sample["data_id"], + } + finished_id_to_sample_info[rollout_id] = info + finished_id_to_final_reward[rollout_id] = final_reward + # + # --- Data processing and tensor creation logic --- + # Get all the reported data. + # prompt_ids are left-padded. + # response_ids are right-padded. + # They are concatenated in the middle. + # Discard handling: + # - Those exceeding max_prompt_length will be marked for discard, but not + # discarded here. They are only truncated and marked, to be discarded later. + # This is for the correctness of the advantage calculation. + # - The discard for the PPO mini-batch should also be handled this way. 
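+        # Illustrative layout (hypothetical token IDs, pad_token_id=0,
+        # max_prompt_length=4, max_response_length=3):
+        #   prompt   [5, 6] -> left-padded  to [0, 0, 5, 6], mask [0, 0, 1, 1]
+        #   response [7]    -> right-padded to [7, 0, 0],    mask [1, 0, 0]
+        #   concatenated    -> [0, 0, 5, 6, 7, 0, 0]; the prompt/response boundary
+        #   always sits at index max_prompt_length.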
+ input_ids_list: List[List[int]] = [] + input_attention_mask_list: List[List[int]] = [] + response_ids_list: List[List[int]] = [] + response_attention_mask_list: List[List[int]] = [] + final_reward_list: List[float] = [] + step_reward_list: List[float] = [] + data_id_list: List[str] = [] + rollout_id_list: List[str] = [] + turn_index_list: List[int] = [] + is_drop_list: List[bool] = [] + n_trunc_sample_because_of_response = 0 + + # optional fields + step_intrinsic_reward_list: List[float] = [] + message_list: List[str] = [] + + for rollout_id, sample_info in finished_id_to_sample_info.items(): + for turn_index, trace in enumerate(sample_info["trace_list"]): + + final_reward_list.append(sample_info["final_reward"]) + step_reward_list.append(trace["step_reward"]) + step_intrinsic_reward_list.append(trace["step_intrinsic_reward"]) + message_list.append(trace["message"]) + + prompt_ids, response_ids = trace["prompt_ids"], trace["response_ids"] + + # Mark samples with prompts exceeding max_prompt_length to be dropped later + if len(prompt_ids) > max_prompt_length: + prompt_ids = prompt_ids[:max_prompt_length] + is_drop_list.append(True) + else: + is_drop_list.append(False) + + # Truncate responses that exceed max_response_length + if len(response_ids) > max_response_length: + response_ids = response_ids[:max_response_length] + n_trunc_sample_because_of_response += 1 + + # Pad prompts to the left and responses to the right + one_input_ids, one_input_attention_mask = get_left_padded_ids_and_attention_mask( + prompt_ids, max_prompt_length, self.pad_token_id + ) + one_response_ids, one_response_attention_mask = get_right_padded_ids_and_attention_mask( + response_ids, max_response_length, self.pad_token_id + ) + + input_ids_list.append(one_input_ids) + input_attention_mask_list.append(one_input_attention_mask) + response_ids_list.append(one_response_ids) + response_attention_mask_list.append(one_response_attention_mask) + data_id_list.append(sample_info["data_id"]) + rollout_id_list.append(rollout_id) + turn_index_list.append(turn_index) + + n_transition = len(input_ids_list) + batch_input_ids = torch.LongTensor(input_ids_list).to(device) + input_attention_mask = torch.LongTensor(input_attention_mask_list).to(device) + batch_response_ids = torch.LongTensor(response_ids_list).to(device) + response_attention_mask = torch.LongTensor(response_attention_mask_list).to(device) + + # Concatenate prompts and responses to form the full sequence + batch_seq = torch.cat([batch_input_ids, batch_response_ids], dim=-1) + attention_mask = torch.cat([input_attention_mask, response_attention_mask], dim=-1) + position_ids = torch.clamp(torch.cumsum(attention_mask, dim=-1) - 1, min=0) + is_drop_mask = torch.BoolTensor(is_drop_list).to(device) + if use_final_reward_as_step_reward: + scores = torch.tensor(final_reward_list, dtype=torch.float32).to(device) + else: + scores = torch.tensor(step_reward_list, dtype=torch.float32).to(device) + + # Create token-level scores by placing the final reward at the last token position + token_level_scores = torch.zeros_like(attention_mask, dtype=scores.dtype) + # At the eos_mask_idx position of each sample, fill in the corresponding scores. + # torch.arange(n_transition) generates [0,1,2,...,bsz-1] as indices for the batch dimension. 
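+        # Worked example (hypothetical row): attention_mask [0, 1, 1, 1, 0, 0] gives
+        # position_ids [0, 0, 1, 2, 2, 2]; position_ids * attention_mask is
+        # [0, 0, 1, 2, 0, 0], so argmax picks index 3 -- the last non-pad token,
+        # where the scalar score is placed.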
+ eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1) # (bsz,) + token_level_scores[torch.arange(n_transition), eos_mask_idx] = scores + # Only take the last response_length part of the sequence to get the token-level scores for the model's response part. + token_level_scores = token_level_scores[:, -max_response_length:] + + # Create token-level intrinsic rewards + token_level_intrinsic_rewards = None + if use_intrinsic_reward: + step_intrinsic_reward_list = [0.0 if reward is None else reward for reward in step_intrinsic_reward_list] + intrinsic_rewards = torch.tensor(step_intrinsic_reward_list, dtype=torch.float32).to(device) + token_level_intrinsic_rewards = torch.zeros_like(attention_mask, dtype=intrinsic_rewards.dtype) + token_level_intrinsic_rewards[torch.arange(n_transition), eos_mask_idx] = intrinsic_rewards + token_level_intrinsic_rewards = token_level_intrinsic_rewards[:, -max_response_length:] + + # Form the final batch using TensorDict + batch_dict = { + "prompts": batch_input_ids, + "responses": batch_response_ids, + "input_ids": batch_seq, # here input_ids become the whole sentences + "attention_mask": attention_mask, + "position_ids": position_ids, + "is_drop_mask": is_drop_mask, + "token_level_scores": token_level_scores.contiguous(), + } + batch_dict["step_rewards"] = torch.tensor(np.array(step_reward_list), dtype=torch.float32).to(device) + if use_intrinsic_reward: + batch_dict["step_intrinsic_rewards"] = torch.tensor( + np.array(step_intrinsic_reward_list), dtype=torch.float32 + ).to(device) + batch_dict["token_level_intrinsic_rewards"] = token_level_intrinsic_rewards.contiguous() + + batch = TensorDict(batch_dict, batch_size=n_transition) + data_proto = DataProto(batch=batch) + + data_metrics = { + "training/reward": np.mean(list(finished_id_to_final_reward.values())), + "training/n_rollouts": len(finished_id_to_final_reward), + "training/n_rollouts_w_trace": len(finished_id_to_sample_info), + "training/n_rollouts_w_reward": sample_with_reward_count, + "training/n_truncated_triplets": n_trunc_sample_because_of_response, + "training/n_triplets": n_transition, + } + + # Add non-tensor data for advantage calculation and logging + data_proto.non_tensor_batch["data_id_list"] = np.array(data_id_list) # type: ignore + data_proto.non_tensor_batch["rollout_id_list"] = np.array(rollout_id_list) # type: ignore + data_proto.non_tensor_batch["turn_index_list"] = np.array(turn_index_list) # type: ignore + + data_proto.non_tensor_batch["step_rewards"] = np.array(step_reward_list) + if is_gigpo: + data_proto.non_tensor_batch["anchor_obs"] = np.array(message_list) + + return data_proto, data_metrics + + def clear_data_and_server(self): + """Resets the internal state of the daemon for the next run.""" + self.backend_llm_server_addresses = [] + self._completed_rollouts_v0.clear() + self._task_id_to_original_sample.clear() + self._total_tasks_queued = 0 + # For a true reset, the server's internal queues would also need clearing. + # This implementation assumes that `set_up_data_and_server` is called + # for each new run, effectively starting a fresh batch. 
+ + def _fillna_reward(self, rollout: RolloutLegacy): + if rollout.final_reward is None: + if self.reward_fillna_value is not None: # type: ignore + final_reward = self.reward_fillna_value + else: + raise ValueError(f"Reward is None for rollout {rollout.rollout_id}, please check the reward function.") + else: + final_reward = rollout.final_reward + return final_reward diff --git a/contrib/agentlightning/contrib/algorithm/simulation_verl/trainer.py b/contrib/agentlightning/contrib/algorithm/simulation_verl/trainer.py new file mode 100644 index 000000000..7cac4769e --- /dev/null +++ b/contrib/agentlightning/contrib/algorithm/simulation_verl/trainer.py @@ -0,0 +1,543 @@ +# Copyright (c) Microsoft. All rights reserved. + +# type: ignore + +from __future__ import annotations + +import random +from contextlib import contextmanager +from copy import deepcopy +from pprint import pprint +from typing import Dict, Tuple, Type + +import numpy as np +import torch +import verl +from codetiming import Timer +from omegaconf import OmegaConf +from tqdm import tqdm +from verl import DataProto +from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto +from verl.trainer.ppo.core_algos import agg_loss +from verl.trainer.ppo.metric_utils import ( + _compute_response_info, + compute_throughout_metrics, + compute_timing_metrics, +) +from verl.trainer.ppo.ray_trainer import ( + AdvantageEstimator, + RayPPOTrainer, + apply_kl_penalty, + compute_advantage, + compute_response_mask, +) +from verl.utils.metric import reduce_metrics +from verl.utils.tracking import Tracking + +from agentlightning.adapter import TraceAdapter, TraceToTripletBase +from agentlightning.llm_proxy import LLMProxy +from agentlightning.store.base import LightningStore + +from .daemon import SimulationAgentModeDaemon + +__all__ = [ + "SimulationAgentLightningTrainer", +] + + +@contextmanager +def _timer(name: str, timing_raw: Dict[str, float]): + with Timer(name=name, logger=None) as timer: + yield + if name not in timing_raw: + timing_raw[name] = 0 + timing_raw[name] += timer.last + + +# This function is adapted from verl. +# We introduce a new parameter `suffix` to distinguish between metrics computed +# before and after AgentLightning’s post-processing. +# - "Before" refers to raw reward and advantage values. +# - "After" refers to values computed following post-processing, which involves: +# (1) Dropping prompts that exceed the maximum allowed length. +# (2) Adjusting the batch size to be a multiple of the mini PPO size. +# Different suffixes are used to label these two stages accordingly. +def compute_data_metrics(batch: DataProto, use_critic: bool = True, suffix: str = "") -> Dict[str, Any]: + """ + Computes various metrics from a batch of data for PPO training. + + This function calculates metrics related to scores, rewards, advantages, returns, values, + and sequence lengths from a batch of data. It provides statistical information (mean, max, min) + for each metric category. + + Args: + batch: A DataProto object containing batch data with token-level scores, rewards, advantages, etc. + use_critic: Whether to include critic-specific metrics. Defaults to True. 
+ + Returns: + A dictionary of metrics including: + - critic/score/mean, max, min: Statistics about sequence scores + - critic/rewards/mean, max, min: Statistics about sequence rewards + - critic/advantages/mean, max, min: Statistics about advantages + - critic/returns/mean, max, min: Statistics about returns + - critic/values/mean, max, min: Statistics about critic values (if use_critic=True) + - critic/vf_explained_var: Explained variance of the value function (if use_critic=True) + - response_length/mean, max, min, clip_ratio: Statistics about response lengths + - prompt_length/mean, max, min, clip_ratio: Statistics about prompt lengths + """ + sequence_score = batch.batch["token_level_scores"].sum(-1) + sequence_reward = batch.batch["token_level_rewards"].sum(-1) + + advantages = batch.batch["advantages"] + returns = batch.batch["returns"] + + max_response_length = batch.batch["responses"].shape[-1] + + prompt_mask = batch.batch["attention_mask"][:, :-max_response_length].bool() + response_mask = batch.batch["attention_mask"][:, -max_response_length:].bool() + + max_prompt_length = prompt_mask.size(-1) + + response_info = _compute_response_info(batch) + prompt_length = response_info["prompt_length"] + response_length = response_info["response_length"] + + valid_adv = torch.masked_select(advantages, response_mask) + valid_returns = torch.masked_select(returns, response_mask) + + if use_critic: + values = batch.batch["values"] + valid_values = torch.masked_select(values, response_mask) + return_diff_var = torch.var(valid_returns - valid_values) + return_var = torch.var(valid_returns) + + metrics = { + # score + "critic/score/mean" + suffix: torch.mean(sequence_score).detach().item(), + "critic/score/max" + suffix: torch.max(sequence_score).detach().item(), + "critic/score/min" + suffix: torch.min(sequence_score).detach().item(), + # reward + "critic/rewards/mean" + suffix: torch.mean(sequence_reward).detach().item(), + "critic/rewards/max" + suffix: torch.max(sequence_reward).detach().item(), + "critic/rewards/min" + suffix: torch.min(sequence_reward).detach().item(), + # adv + "critic/advantages/mean" + suffix: torch.mean(valid_adv).detach().item(), + "critic/advantages/max" + suffix: torch.max(valid_adv).detach().item(), + "critic/advantages/min" + suffix: torch.min(valid_adv).detach().item(), + # returns + "critic/returns/mean" + suffix: torch.mean(valid_returns).detach().item(), + "critic/returns/max" + suffix: torch.max(valid_returns).detach().item(), + "critic/returns/min" + suffix: torch.min(valid_returns).detach().item(), + **( + { + # values + "critic/values/mean" + suffix: torch.mean(valid_values).detach().item(), + "critic/values/max" + suffix: torch.max(valid_values).detach().item(), + "critic/values/min" + suffix: torch.min(valid_values).detach().item(), + # vf explained var + "critic/vf_explained_var" + suffix: (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(), + } + if use_critic + else {} + ), + # response length + "response_length/mean" + suffix: torch.mean(response_length).detach().item(), + "response_length/max" + suffix: torch.max(response_length).detach().item(), + "response_length/min" + suffix: torch.min(response_length).detach().item(), + "response_length/clip_ratio" + + suffix: torch.mean(torch.eq(response_length, max_response_length).float()).detach().item(), + # prompt length + "prompt_length/mean" + suffix: torch.mean(prompt_length).detach().item(), + "prompt_length/max" + suffix: torch.max(prompt_length).detach().item(), + "prompt_length/min" + 
suffix: torch.min(prompt_length).detach().item(), + "prompt_length/clip_ratio" + + suffix: torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(), + } + return metrics + + +class SimulationAgentLightningTrainer(RayPPOTrainer): + """ + Specialized PPO trainer for agent-based reinforcement learning. + + This trainer is designed specifically for scenarios where the model interacts with + external environments, tools, or APIs through an AgentLightningServer. It simplifies + the training loop by removing the complex conditional logic present in the original + RayPPOTrainer and focusing on the agent mode workflow. + + Key differences from RayPPOTrainer: + + 1. Uses AgentModeDaemon for server communication + 2. Simplified data flow without pop/union operations + 3. Direct batch processing through agent daemon + 4. Streamlined validation using agent_mode validation + """ + + def __init__( + self, + store: LightningStore | None, + llm_proxy: LLMProxy | None, + adapter: TraceAdapter | None, + daemon_cls: Type[SimulationAgentModeDaemon], + **kwargs, + ): + super().__init__(**kwargs) + self.store = store + self.llm_proxy = llm_proxy + self.adapter = adapter + self.daemon_cls = daemon_cls + + def _validate(self): + assert len(self.val_dataloader) == 1, "Please set val_batch_size to None for better throughput." + + test_data = next(iter(self.val_dataloader)) + test_batch = DataProto.from_single_dict(test_data) + + self.async_rollout_manager.wake_up() + self.agent_mode_daemon.set_up_data_and_server( + test_batch.non_tensor_batch, + self.async_rollout_manager.server_addresses, + is_train=False, + ) + self.agent_mode_daemon.run_until_all_finished() + test_metrics = self.agent_mode_daemon.get_test_metrics() + self.agent_mode_daemon.clear_data_and_server() + self.async_rollout_manager.sleep() + return test_metrics + + def _compute_reference_log_prob(self, batch: DataProto) -> DataProto: + """Compute reference log probability using the correct worker based on LoRA configuration. + + In verl 0.6.0+, when LoRA is detected (indicated by ref_in_actor=True), + the reference policy is computed by the actor rollout worker instead of a separate + ref policy worker. This method handles both scenarios by checking the ref_in_actor flag. + Note: verl sets ref_in_actor=True when it detects LoRA configuration (e.g., lora_rank > 0 or lora_adapter_path is set). + + Args: + batch: The data batch to compute reference log probabilities for. + + Returns: + DataProto with reference log probabilities added. + + Raises: + RuntimeError: If the required worker is not available. + """ + if getattr(self, "ref_in_actor", False): + actor_worker = getattr(self, "actor_rollout_wg", None) + if actor_worker is None: + raise RuntimeError("actor_rollout_wg is required when ref_in_actor is True.") + return actor_worker.compute_ref_log_prob(batch) + + ref_worker = getattr(self, "ref_policy_wg", None) + if ref_worker is None: + raise RuntimeError( + "Reference policy worker was not initialized. " + "Ensure `use_reference_policy` is enabled and the VERL config exposes the ref worker." + ) + return ref_worker.compute_ref_log_prob(batch) + + def _train_step(self, batch_dict: dict) -> dict: + # Isolate in a separate method to automatically recycle the variables before validation. + batch: DataProto = DataProto.from_single_dict(batch_dict) + metrics = {} + timing_raw = {} + + with _timer("step", timing_raw): + + # When agent mode is enabled, we read the batch as it is. 
+ gen_batch = batch + + # generate a batch + with _timer("gen", timing_raw): + self.async_rollout_manager.wake_up() + self.agent_mode_daemon.set_up_data_and_server( + gen_batch.non_tensor_batch, self.async_rollout_manager.server_addresses + ) + self.agent_mode_daemon.run_until_all_finished() + batch, agent_metrics = self.agent_mode_daemon.get_train_data_batch( + max_prompt_length=self.config.data.max_prompt_length, + max_response_length=self.config.data.max_response_length, + device=gen_batch.batch["fake_ids"].device, + use_final_reward_as_step_reward=self.config.algorithm.use_final_reward_as_step_reward, + use_intrinsic_reward=self.config.algorithm.use_intrinsic_reward, + ) + metrics.update(agent_metrics) + self.agent_mode_daemon.clear_data_and_server() + self.async_rollout_manager.sleep() + + if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX: + with _timer("gen_max", timing_raw): + gen_baseline_batch = deepcopy(gen_batch) + gen_baseline_batch.meta_info["do_sample"] = False + gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_batch) + + batch = batch.union(gen_baseline_output) + reward_baseline_tensor = self.reward_fn(batch) + reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1) + + batch.pop(batch_keys=list(gen_baseline_output.batch.keys())) + + batch.batch["reward_baselines"] = reward_baseline_tensor + + del gen_baseline_batch, gen_baseline_output + + # uid is used for algorithm like GRPO, should be aligned to data id + batch.non_tensor_batch["uid"] = batch.non_tensor_batch["data_id_list"] + + batch.batch["response_mask"] = compute_response_mask(batch) + + # compute global_valid tokens + batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() + + with _timer("reward", timing_raw): + # compute reward model score + if self.use_rm: + reward_tensor = self.rm_wg.compute_rm_score(batch) + batch = batch.union(reward_tensor) + + reward_extra_infos_dict = {} + + # for agent mode, pad the lengths to calculate old log prob, ref, and values + batch, pad_size = pad_dataproto_to_divisor(batch, self.actor_rollout_wg.world_size) + + # recompute old_log_probs + with _timer("old_log_prob", timing_raw): + old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) + entropys = old_log_prob.batch["entropys"] + response_masks = batch.batch["response_mask"] + loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode + entropy_loss = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode) + old_log_prob_metrics = {"actor/entropy_loss": entropy_loss.detach().item()} + metrics.update(old_log_prob_metrics) + old_log_prob.batch.pop("entropys") + batch = batch.union(old_log_prob) + + if self.use_reference_policy: + # compute reference log_prob + with _timer("ref", timing_raw): + ref_log_prob = self._compute_reference_log_prob(batch) + batch = batch.union(ref_log_prob) + + # compute values + if self.use_critic: + with _timer("values", timing_raw): + values = self.critic_wg.compute_values(batch) + batch = batch.union(values) + + # for agent mode, unpad to calculate adv + # it is important, as adv should be based on the raw traces + batch = unpad_dataproto(batch, pad_size=pad_size) + + with _timer("adv", timing_raw): + # if agent_mode is enabled, there is already token_level_scores + # token_level_scores is not needed to compute here + + # compute rewards. 
apply_kl_penalty if available + if self.config.algorithm.use_kl_in_reward: + batch, kl_metrics = apply_kl_penalty( + batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty + ) + metrics.update(kl_metrics) + else: + if self.config.algorithm.use_intrinsic_reward: + batch.batch["token_level_rewards"] = ( + batch.batch["token_level_scores"] + batch.batch["token_level_intrinsic_rewards"] + ) # (bs, seq_len) + else: + batch.batch["token_level_rewards"] = batch.batch["token_level_scores"] + + # compute advantages, executed on the driver process + + norm_adv_by_std_in_grpo = self.config.algorithm.get( + "norm_adv_by_std_in_grpo", True + ) # GRPO adv normalization factor + + batch = compute_advantage( + batch, + adv_estimator=self.config.algorithm.adv_estimator, + gamma=self.config.algorithm.gamma, + lam=self.config.algorithm.lam, + num_repeat=self.config.actor_rollout_ref.rollout.n, + norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, + config=self.config.algorithm, + ) + + # Calculate the metrics before processing. Refer to the comments of function `compute_data_metrics` for details. + metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic, suffix="_before_processing")) + + # after advantages are assinged, we begin to drop (1) long prompt (2) floor to ppo minisize + keep_indices = (~batch.batch["is_drop_mask"]).nonzero(as_tuple=True)[0] + metrics["training/n_triplets_prompt_too_long"] = ( + batch.batch["is_drop_mask"].shape[0] - keep_indices.shape[0] + ) + batch = batch[keep_indices] + # next, round to minibatch size + mini_batch_size = self.config.actor_rollout_ref.actor.ppo_mini_batch_size + n_transition = len(batch) + random_indices = list(range(n_transition)) + random.shuffle(random_indices) + batch.reorder(torch.tensor(random_indices).type(torch.int32)) + n_remained_transition = n_transition // mini_batch_size * mini_batch_size + batch = batch[list(range(n_remained_transition))] + metrics["training/n_triplets_dropped_remainder"] = n_transition - n_remained_transition + + # Agent mode note: Change the order of balance batch; + # 1. first calculate advantage + # 2. then drop the samples (too long prompt & floor to ppo minisize) + # 3. balance + # balance the number of valid tokens on each dp rank. + # Note that this breaks the order of data inside the batch. 
+ # Please take care when you implement group based adv computation such as GRPO and rloo + if self.config.trainer.balance_batch: + self._balance_batch(batch, metrics=metrics) + + # update critic + if self.use_critic: + with _timer("update_critic", timing_raw): + critic_output = self.critic_wg.update_critic(batch) + critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"]) + metrics.update(critic_output_metrics) + + # implement critic warmup + if self.config.trainer.critic_warmup <= self.global_steps: + # update actor + with _timer("update_actor", timing_raw): + batch.meta_info["multi_turn"] = self.config.actor_rollout_ref.rollout.multi_turn.enable + actor_output = self.actor_rollout_wg.update_actor(batch) + actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) + metrics.update(actor_output_metrics) + + # Log rollout generations if enabled + rollout_data_dir = self.config.trainer.get("rollout_data_dir", None) + if rollout_data_dir: + with _timer("dump_rollout_generations", timing_raw): + print(batch.batch.keys()) + inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True) + outputs = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True) + scores = batch.batch["token_level_scores"].sum(-1).cpu().tolist() + self._dump_generations( + inputs=inputs, + outputs=outputs, + scores=scores, + reward_extra_infos_dict=reward_extra_infos_dict, + dump_path=rollout_data_dir, + ) + + # compute training metrics + metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic, suffix="_after_processing")) + metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) + # TODO: implement actual tflpo and theoretical tflpo + n_gpus = self.resource_pool_manager.get_n_gpus() + metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus)) + + return metrics + + def fit(self): + logger = Tracking( + project_name=self.config.trainer.project_name, + experiment_name=self.config.trainer.experiment_name, + default_backend=self.config.trainer.logger, + config=OmegaConf.to_container(self.config, resolve=True), + ) + + self.global_steps = 0 + + # load checkpoint before doing anything + self._load_checkpoint() + + assert self.async_rollout_mode, "If agent mode is enabled, async server must be enabled" + if self.adapter is not None and not isinstance(self.adapter, TraceToTripletBase): + raise ValueError("Adapter must be a TraceToTripletBase for currently VERL implementation.") + verl_version = verl.__version__ + if verl_version == "0.5.0": + # Note (Zhiyuan): To avoid further patch into vllm async server, using the same sentence to get the naming here. + # However, it is possible that verl updates the naming and causes incompatibility. + # Reference: https://github.com/volcengine/verl/blob/5b5e09d9cc20625e436d01f69d9cc739ff681c54/verl/workers/rollout/vllm_rollout/vllm_async_server.py#L217 + model = "/".join(self.config.actor_rollout_ref.model.path.split("/")[-2:]) + else: + # For other versions (e.g., 0.6.0), we use the full path to the model. 
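+            # In either branch, the resolved name is passed to the agent-mode daemon below via
+            # train_information["model"], so it should line up with the name under which the
+            # async rollout server serves the model (see the note above).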
+ model = self.config.actor_rollout_ref.model.path + self.agent_mode_daemon = self.daemon_cls( + self.config.agentlightning.port, + self.config.actor_rollout_ref.rollout.n, + train_information={ + "model": model, + "temperature": self.config.actor_rollout_ref.rollout.temperature, + }, + tokenizer=self.tokenizer, + mini_batch_size=self.config.actor_rollout_ref.actor.ppo_mini_batch_size, + pad_token_id=self.tokenizer.pad_token_id, + mode="v1" if self.store is not None else "v0", + store=self.store, + llm_proxy=self.llm_proxy, + adapter=self.adapter, + ) + self.agent_mode_daemon.start() + + # perform validation before training + # currently, we only support validation using the reward_function. + if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): + val_metrics = self._validate() + assert val_metrics, f"{val_metrics=}" + pprint(f"Initial validation metrics: {val_metrics}") + logger.log(data=val_metrics, step=self.global_steps) + if self.config.trainer.get("val_only", False): + return + + # add tqdm + progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress") + + # we start from step 1 + self.global_steps += 1 + last_val_metrics = None + + for epoch in range(self.config.trainer.total_epochs): + for batch_dict in self.train_dataloader: + metrics = {} + timing_raw = {} + is_last_step = self.global_steps >= self.total_training_steps + + # train step + metrics = self._train_step(batch_dict) + + # validate + if ( + self.val_reward_fn is not None + and self.config.trainer.test_freq > 0 + and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0) + ): + with _timer("validate", timing_raw): + val_metrics: dict = self._validate() + if is_last_step: + last_val_metrics = val_metrics + metrics.update(val_metrics) + + if self.config.trainer.save_freq > 0 and ( + is_last_step or self.global_steps % self.config.trainer.save_freq == 0 + ): + with _timer("save_checkpoint", timing_raw): + self._save_checkpoint() + + # step metrics + metrics.update( + { + "training/global_step": self.global_steps, + "training/epoch": epoch, + } + ) + + # TODO: make a canonical logger that supports various backend + logger.log(data=metrics, step=self.global_steps) + + if is_last_step: + pprint(f"Final validation metrics: {last_val_metrics}") + progress_bar.close() + + # This exit logic is to ensure a robust CI. + pprint(f"Flush the logger...") + del logger # Make sure the loggers are flushed and closed properly + pprint(f"Training finished at step {self.global_steps}.") + return + + progress_bar.update(1) + self.global_steps += 1 diff --git a/contrib/recipes/simulation/README.md b/contrib/recipes/simulation/README.md new file mode 100644 index 000000000..ae6a05c9c --- /dev/null +++ b/contrib/recipes/simulation/README.md @@ -0,0 +1,97 @@ + + + +# Simulation Example + +## Overview + +This example implements agents across various simulation environments within Agent Lightning. +The example is designed to run on a single node with 8 GPUs, each having at least 40 GB of memory. + +This example depends on the simulation environments (e.g., ALFWorld, ScienceWorld) provided in the [agl-envs repository](https://github.com/agent-lightning/agl-envs). +For more information about the supported simulation environments, please refer to the [simulation README](https://github.com/agent-lightning/agl-envs/simulation/README.md). 
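+
+At a high level, `train_simulation_agent.py` wires these pieces together roughly as sketched below
+(a condensed sketch only: the real launcher also restarts Ray, exports the `TRIAL`/`TASK_NUM`
+environment variables, and supports the `--debug` configs):
+
+```python
+from datasets import Dataset
+from omegaconf import OmegaConf
+
+from agentlightning import Trainer
+from agentlightning.algorithm.verl import VERL
+from contrib.agentlightning.contrib.agent.simulation_agent import SimulationAgent
+from contrib.agentlightning.contrib.algorithm.simulation_verl.daemon import SimulationAgentModeDaemon
+from contrib.agentlightning.contrib.algorithm.simulation_verl.trainer import SimulationAgentLightningTrainer
+
+# Environment behaviour (captioner type, prompt type, reward shaping) comes from config_env/;
+# RL hyperparameters come from config_verl/.
+agent_config = OmegaConf.load("config_env/alfworld.yaml")
+rl_config = OmegaConf.load("config_verl/alfworld/grpo.yaml")
+OmegaConf.resolve(rl_config)
+del rl_config["variables"]  # helper block used only for interpolation
+
+train_data = Dataset.from_parquet(rl_config["data"]["train_files"])
+val_data = Dataset.from_parquet(rl_config["data"]["val_files"])
+
+agent = SimulationAgent(agent_config)
+trainer = Trainer(
+    algorithm=VERL(
+        config=rl_config,
+        trainer_cls=SimulationAgentLightningTrainer,  # agent-mode VERL trainer
+        daemon_cls=SimulationAgentModeDaemon,         # coordinates rollout collection
+    ),
+    n_workers=64,
+)
+trainer.fit(agent, train_data, val_dataset=val_data)
+```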
+
+---
+
+## Included Files
+
+| File/Directory | Description |
+|----------------|-------------|
+| `config_env/` | Configuration for environment settings. For more information, please refer to the "Configure Your Environment Settings" section. |
+| `config_verl/` | Configuration for RL training with VERL |
+| `add_instruction.py` | Adds instructions to the agent’s input prompt to constrain the format of its response |
+| `prompt_builder.py` | Manages conversation history and builds input prompts in multi-turn scenarios |
+| `train_simulation_agent.py` | RL training entry script |
+
+---
+
+## Install Simulation Environments
+
+Run the following commands once to install the simulation environments and the related AGL dependencies:
+
+```bash
+cd contrib/recipes/simulation
+
+git clone https://github.com/agent-lightning/agl-envs
+mv agl-envs agl_envs
+
+# Install the ALFWorld dependencies
+bash agl_envs/simulation/setup/setup_alfworld.sh
+conda activate alfworld
+
+# Install the ScienceWorld dependencies
+bash agl_envs/simulation/setup/setup_sciworld.sh
+conda activate sciworld
+
+# Install the AGL dependencies
+bash install_agl.sh
+```
+
+> If you plan to use WandB for experiment tracking, log in to WandB before training.
+
+---
+
+## Configure Your Environment Settings
+
+### Captioner type (cot or naive)
+
+- `cot`: guides the agent to output its reasoning first, then take an action
+- `naive`: guides the agent to take an action directly, without outputting any reasoning
+
+### Prompt type (chat or single)
+
+When performing multi-turn rollouts, the unit of the input prompt can be defined in two different ways.
+
+(1) **Trajectory-wise unit**:
+All interaction history up to the current step is accumulated in chat format and used directly to construct the next input prompt.
+
+(2) **Turn-wise unit**:
+Only a subset of the interaction history is included in each turn. The prompt is reconstructed by combining the current turn’s state with selected past information, rather than using the full trajectory.
+
+![prompt_type](./assets/prompt_type.png)
+
+You can select the trajectory-wise unit by setting `prompt_type` to `chat`, and the turn-wise unit by setting `prompt_type` to `single`. Currently, ALFWorld supports only the `single` mode, while ScienceWorld supports both `chat` and `single`.
+
+We follow the single-mode prompt for ALFWorld from [verl-agent](https://github.com/langfengQ/verl-agent) and the single-mode prompt for ScienceWorld from [RLVMR](https://github.com/Tencent/digitalhuman/tree/main/RLVMR). Thank you to the authors of verl-agent and RLVMR for their valuable work.
+
+---
+
+## Run RL Training (GRPO)
+
+```bash
+# Run ALFWorld
+python3 train_simulation_agent.py --algorithm grpo --env alfworld
+
+# Run ScienceWorld on a single task (task_num 0)
+python3 train_simulation_agent.py --algorithm grpo --env scienceworld --task_num 0
+
+# Run ScienceWorld in the multi-task setting
+python3 train_simulation_agent.py --algorithm grpo --env scienceworld --task_num -1
+```
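+
+The `TRIAL` and `TASK_NUM` values referenced in the `config_verl` YAML files are read from
+environment variables through OmegaConf's `oc.env` resolver; `train_simulation_agent.py` sets them
+from `--trial` and `--task_num` before loading the config. A minimal sketch of that resolution
+step (paths and values are illustrative):
+
+```python
+import os
+
+from omegaconf import OmegaConf
+
+os.environ["TRIAL"] = "1"     # normally set by the launcher from --trial
+os.environ["TASK_NUM"] = "0"  # only used for ScienceWorld
+
+cfg = OmegaConf.load("config_verl/scienceworld/grpo.yaml")
+OmegaConf.resolve(cfg)  # expands ${oc.env:...} and ${variables....} references in place
+del cfg["variables"]    # the launcher drops the helper block once everything is resolved
+
+print(cfg.trainer.experiment_name)  # -> grpo-sciworld-1
+```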
diff --git a/contrib/recipes/simulation/add_instruction.py b/contrib/recipes/simulation/add_instruction.py
new file mode 100644
index 000000000..627500153
--- /dev/null
+++ b/contrib/recipes/simulation/add_instruction.py
@@ -0,0 +1,101 @@
+# Copyright (c) Microsoft. All rights reserved.
+
+import copy
+
+from autogen_core.models import UserMessage
+
+# Instruction text definitions
+COT_INSTRUCTION = """
+Now it's your turn to take an action. You should first reason step-by-step about the current situation. This reasoning process MUST be enclosed within tags.
+Once you've finished your reasoning, you should choose an appropriate action for the current step and present it within tags.
+""".strip()
+
+NAIVE_INSTRUCTION = """
+Please respond with only one line containing one sentence, following the possible action format shown above. No extra words are allowed.
+""".strip()
+
+# Mapping for instruction text types
+INSTRUCTION_MAP = {
+    "cot": COT_INSTRUCTION,
+    "naive": NAIVE_INSTRUCTION,
+}
+
+
+def _get_instruction(type: str, env_name: str = None):
+    """
+    Retrieve an instruction string from INSTRUCTION_MAP based on the given type.
+
+    Args:
+        type (str): Instruction type key (e.g., "cot", "naive").
+        env_name (str, optional): Currently unused. Reserved for future
+            environment-specific instruction handling.
+
+    Returns:
+        str: The corresponding instruction text.
+
+    Raises:
+        ValueError: If the given instruction type is not found in INSTRUCTION_MAP.
+    """
+    if type in INSTRUCTION_MAP:
+        return INSTRUCTION_MAP[type]
+    else:
+        raise ValueError(f"Unknown instruction type: {type}")
+
+
+def add_chat_instruction(prompt, type: str, sep: str = "\n\n", env_name: str = None):
+    """
+    Append an instruction to the content of the last message in a chat-style prompt.
+
+    This function does not modify the original prompt. Instead, it returns a
+    deep-copied prompt list with the instruction appended.
+
+    Args:
+        prompt (list): A conversation history represented as a list of objects.
+            Each object must have a `.content` attribute.
+        type (str): Instruction type key (e.g., "cot", "naive").
+        sep (str, optional): Separator inserted between the existing content
+            and the instruction.
+        env_name (str, optional): Currently unused. Reserved for future use.
+
+    Returns:
+        list: A new prompt list with the instruction appended to the last message.
+    """
+    new_prompt = copy.deepcopy(prompt)
+    instruction = _get_instruction(type, env_name)
+    new_prompt[-1].content += sep + instruction
+
+    return new_prompt
+
+
+def add_single_instruction(prompt, type: str, sep: str = "\n\n", env_name: str = None):
+    """
+    Append an instruction to a single prompt or a chat-style prompt.
+
+    - If `prompt` is a string, the instruction is appended to the string.
+    - If `prompt` is a list, the instruction is appended to the `.content`
+      of the last message.
+
+    Args:
+        prompt (str or list): Either a single prompt string or a conversation
+            history list whose elements have a `.content` attribute.
+        type (str): Instruction type key (e.g., "cot", "naive").
+        sep (str, optional): Separator inserted between the existing content
+            and the instruction.
+        env_name (str, optional): Currently unused. Reserved for future use.
+
+    Returns:
+        str or list: The updated prompt with the instruction appended.
+
+    Raises:
+        TypeError: If `prompt` is neither a string nor a list.
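+
+    Example (values are illustrative):
+        prompt = "You see a cabinet 1 and a countertop 1."
+        prompt = add_single_instruction(prompt, type="naive")
+        # `prompt` now ends with NAIVE_INSTRUCTION, separated by a blank line.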
+ """ + instruction = _get_instruction(type, env_name) + + if isinstance(prompt, str): + return prompt + sep + instruction + elif isinstance(prompt, list): + new_prompt = copy.deepcopy(prompt) + new_prompt[-1].content += sep + instruction + return new_prompt + else: + raise TypeError("Prompt must be a string or a list of strings") diff --git a/contrib/recipes/simulation/assets/agl_simulation.png b/contrib/recipes/simulation/assets/agl_simulation.png new file mode 100644 index 000000000..2c2f00f00 Binary files /dev/null and b/contrib/recipes/simulation/assets/agl_simulation.png differ diff --git a/contrib/recipes/simulation/assets/prompt_type.png b/contrib/recipes/simulation/assets/prompt_type.png new file mode 100644 index 000000000..da6349ddc Binary files /dev/null and b/contrib/recipes/simulation/assets/prompt_type.png differ diff --git a/contrib/recipes/simulation/config_env/alfworld.yaml b/contrib/recipes/simulation/config_env/alfworld.yaml new file mode 100644 index 000000000..1d1020993 --- /dev/null +++ b/contrib/recipes/simulation/config_env/alfworld.yaml @@ -0,0 +1,16 @@ +env_name: alfworld +seed: 0 +format_penalty: 0.1 +binary_reward: False +save_rollout: False +log_env_obs: False +reawrd_scale: 10.0 +use_success_rate: False + +captioner: + type: cot # naive or cot + prompt_type: single # chat or single + max_history: 2 + +alfworld_kwargs: + max_steps: 50 diff --git a/contrib/recipes/simulation/config_env/scienceworld.yaml b/contrib/recipes/simulation/config_env/scienceworld.yaml new file mode 100644 index 000000000..08b9cad9a --- /dev/null +++ b/contrib/recipes/simulation/config_env/scienceworld.yaml @@ -0,0 +1,16 @@ +env_name: scienceworld +seed: 0 +format_penalty: 0.1 +binary_reward: False +save_rollout: False +log_env_obs: False # True for GiGPO +reawrd_scale: 10.0 +use_success_rate: True + +# only for scienceworld +use_action_correction: False + +captioner: + type: cot # naive or cot + prompt_type: single # chat or single + max_history: 2 diff --git a/contrib/recipes/simulation/config_verl/alfworld/grpo.yaml b/contrib/recipes/simulation/config_verl/alfworld/grpo.yaml new file mode 100644 index 000000000..71a7d3c91 --- /dev/null +++ b/contrib/recipes/simulation/config_verl/alfworld/grpo.yaml @@ -0,0 +1,86 @@ +# ========================== +# Variable definitions +# ========================== +variables: + NUM_GPUS: 2 + MINI_BATCH_SIZE: 32 + PER_GPU_BATCH_SIZE: 16 + TENSOR_MODEL_PARALLEL_SIZE: 2 + NUM_ROLLOUTS: 8 + BASE_MODEL: Qwen/Qwen2.5-1.5B-Instruct + PROJECT_NAME: AGL-Simulation-ALFWorld + TRIAL: ${oc.env:TRIAL,0} + EXPERIMENT_NAME: grpo-alfworld-${variables.TRIAL} + DATA_DIR: agl_envs/simulation/task_data/alfworld + +# ========================== +# Main Config +# ========================== +agentlightning: + port: 9999 + +algorithm: + adv_estimator: grpo + use_kl_in_reward: false + use_final_reward_as_step_reward: true + use_intrinsic_reward: true + +data: + train_files: ${variables.DATA_DIR}/train.parquet + val_files: ${variables.DATA_DIR}/test.parquet + train_batch_size: 32 + val_batch_size: 140 + max_prompt_length: 2048 + max_response_length: 512 + truncation: error + return_raw_chat: true + +actor_rollout_ref: + rollout: + tensor_model_parallel_size: ${variables.TENSOR_MODEL_PARALLEL_SIZE} + n: ${variables.NUM_ROLLOUTS} + log_prob_micro_batch_size_per_gpu: ${variables.PER_GPU_BATCH_SIZE} + multi_turn: + format: hermes + name: vllm + gpu_memory_utilization: 0.6 + enable_chunked_prefill: false + enforce_eager: false + free_cache_engine: true + val_kwargs: + temperature: 
0.4 + do_sample: true + actor: + ppo_mini_batch_size: ${variables.MINI_BATCH_SIZE} + ppo_micro_batch_size_per_gpu: ${variables.PER_GPU_BATCH_SIZE} + optim: + lr: 1.0e-6 + use_kl_loss: true + kl_loss_coef: 0.01 + kl_loss_type: low_var_kl + entropy_coeff: 0.001 + fsdp_config: + param_offload: false + optimizer_offload: false + ref: + log_prob_micro_batch_size_per_gpu: ${variables.PER_GPU_BATCH_SIZE} + fsdp_config: + param_offload: true + model: + path: ${variables.BASE_MODEL} + use_remove_padding: true + enable_gradient_checkpointing: true + +trainer: + n_gpus_per_node: ${variables.NUM_GPUS} + val_before_train: false + critic_warmup: 0 + logger: + - console + - wandb + project_name: ${variables.PROJECT_NAME} + experiment_name: ${variables.EXPERIMENT_NAME} + nnodes: 1 + save_freq: 100 + test_freq: 5 + total_epochs: 200 diff --git a/contrib/recipes/simulation/config_verl/scienceworld/grpo.yaml b/contrib/recipes/simulation/config_verl/scienceworld/grpo.yaml new file mode 100644 index 000000000..d6c968604 --- /dev/null +++ b/contrib/recipes/simulation/config_verl/scienceworld/grpo.yaml @@ -0,0 +1,87 @@ +# ========================== +# Variable definitions +# ========================== +variables: + NUM_GPUS: 2 + MINI_BATCH_SIZE: 32 + PER_GPU_BATCH_SIZE: 16 + TENSOR_MODEL_PARALLEL_SIZE: 2 + NUM_ROLLOUTS: 8 + BASE_MODEL: Qwen/Qwen2.5-1.5B-Instruct + PROJECT_NAME: AGL-Simulation-ScienceWorld + TASK_NUM: ${oc.env:TASK_NUM,-1} + TRIAL: ${oc.env:TRIAL,0} + EXPERIMENT_NAME: grpo-sciworld-${variables.TRIAL} + DATA_DIR: agl_envs/simulation/task_data/scienceworld/multi_data + +# ========================== +# Main Config +# ========================== +agentlightning: + port: 9999 + +algorithm: + adv_estimator: grpo + use_kl_in_reward: false + use_final_reward_as_step_reward: true + use_intrinsic_reward: true + +data: + train_files: ${variables.DATA_DIR}/train.parquet + val_files: ${variables.DATA_DIR}/test.parquet + train_batch_size: 32 + val_batch_size: 144 + max_prompt_length: 6000 + max_response_length: 1024 + truncation: error + return_raw_chat: true + +actor_rollout_ref: + rollout: + tensor_model_parallel_size: ${variables.TENSOR_MODEL_PARALLEL_SIZE} + n: ${variables.NUM_ROLLOUTS} + log_prob_micro_batch_size_per_gpu: ${variables.PER_GPU_BATCH_SIZE} + multi_turn: + format: hermes + name: vllm + gpu_memory_utilization: 0.6 + enable_chunked_prefill: false + enforce_eager: false + free_cache_engine: true + val_kwargs: + temperature: 0.4 + do_sample: true + actor: + ppo_mini_batch_size: ${variables.MINI_BATCH_SIZE} + ppo_micro_batch_size_per_gpu: ${variables.PER_GPU_BATCH_SIZE} + optim: + lr: 1.0e-6 + use_kl_loss: true + kl_loss_coef: 0.01 + kl_loss_type: low_var_kl + entropy_coeff: 0.001 + fsdp_config: + param_offload: false + optimizer_offload: false + ref: + log_prob_micro_batch_size_per_gpu: ${variables.PER_GPU_BATCH_SIZE} + fsdp_config: + param_offload: true + model: + path: ${variables.BASE_MODEL} + use_remove_padding: true + enable_gradient_checkpointing: true + +trainer: + n_gpus_per_node: ${variables.NUM_GPUS} + val_before_train: true + critic_warmup: 0 + logger: + - console + - wandb + project_name: ${variables.PROJECT_NAME} + experiment_name: ${variables.EXPERIMENT_NAME} + nnodes: 1 + save_freq: 100 + test_freq: 5 + total_epochs: 500 diff --git a/contrib/recipes/simulation/install_agl.sh b/contrib/recipes/simulation/install_agl.sh new file mode 100644 index 000000000..4def754ec --- /dev/null +++ b/contrib/recipes/simulation/install_agl.sh @@ -0,0 +1,13 @@ +# This setup is based on CUDA 12.6 +pip 
install torch torchvision --index-url https://download.pytorch.org/whl/cu126 +pip install transformers==4.56.1 +pip install wandb +pip install vllm==0.10.2 +pip install verl==0.5.0 +pip install click==8.2.1 +pip install --extra-index-url https://miropsota.github.io/torch_packages_builder flash_attn==2.8.3+pt2.8.0cu126 +pip install 'openai-agents[litellm]'==0.2.9 +pip install -U "autogen-agentchat" "autogen-ext[openai]" + +(cd ../../../ && pip install -e .[dev]) + diff --git a/contrib/recipes/simulation/prompt_builder.py b/contrib/recipes/simulation/prompt_builder.py new file mode 100644 index 000000000..905285d55 --- /dev/null +++ b/contrib/recipes/simulation/prompt_builder.py @@ -0,0 +1,183 @@ +# Copyright (c) Microsoft. All rights reserved. + +from typing import Optional + +from autogen_core.models import AssistantMessage, UserMessage + + +class HistoryPromptBuilder: + """ + Builds prompts using a history of observations and actions. + + Supports two prompt styles: + - chat: multi-turn user/assistant messages + - single: a single formatted prompt with optional history + """ + + def __init__(self, max_history: int = -1, prompt_type: str = "chat"): + """ + Args: + max_history (int): Maximum number of past steps to include + (-1 means unlimited). + prompt_type (str): Prompt style ("chat" or "single"). + """ + self.max_history = max_history + self.prompt_type = prompt_type + + self._events = [] + self.admissible_actions = None + + self.step_count = 1 + + def update_step_count(self): + """Increment the current step counter.""" + self.step_count += 1 + + def update_instruction_prompt(self, instruction: str): + """Set the instruction/system prompt used in chat mode.""" + self.instruction = instruction + + def update_single_obs_template(self, single_obs_template_wo_his: str, single_obs_template: str): + """ + Set templates for single-prompt mode. + + Args: + single_obs_template_wo_his (str): Template without history. + single_obs_template (str): Template with history. + """ + self.single_obs_template_wo_his = single_obs_template_wo_his + self.single_obs_template = single_obs_template + + def update_observation(self, obs: dict): + """Append an observation to the event history.""" + self._events.append( + { + "type": "observation", + "text": obs, + } + ) + + def update_action(self, action: str): + """Append an action to the event history.""" + self._events.append( + { + "type": "action", + "action": action, + } + ) + + def update_admissible_actions(self, admissible_actions): + """Update the list of admissible actions for the current step.""" + self.admissible_actions = admissible_actions + + def init(self, env): + """ + Initialize the prompt builder at the beginning of an episode. + + - Clears the event history + - Loads prompt instructions or templates from the environment + """ + self._events.clear() + + if self.prompt_type == "chat": + inst_prompt = env.get_instruction_prompt(info) + self.update_instruction_prompt(inst_prompt) + elif self.prompt_type == "single": + template_wo_his, template = env.get_single_prompt_template() + self.update_single_obs_template(template_wo_his, template) + else: + raise ValueError(f"Unsupported prompt_type: {self.prompt_type}") + + def get_chat_prompt(self): + """ + Construct a chat-style prompt from the event history. + + Returns: + List[Message]: A sequence of User and Assistant messages. 
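+
+        Example (values are illustrative):
+            builder = HistoryPromptBuilder(max_history=1, prompt_type="chat")
+            builder.update_instruction_prompt("Your task is to find a mug.")
+            builder.update_observation("You see a table 1.")
+            builder.update_action("go to table 1")
+            builder.update_observation("You arrive at table 1.")
+            messages = builder.get_chat_prompt()
+            # -> [UserMessage, AssistantMessage, UserMessage]; the instruction text is
+            #    appended to the first observation kept in the history window.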
+ """ + if self.max_history != -1: + events = self._events[-(self.max_history * 2 + 1) :] + else: + events = self._events + + messages = [] + + for idx, event in enumerate(events): + event_type = event.get("type") + message = None + + if event_type == "observation": + content = event.get("text", "") + # Attach instruction prompt to the first observation + if idx == 0 and self.instruction: + content += "\n" + self.instruction + message = UserMessage(source="user", content=content) + + elif event_type == "action": + content = event.get("action", "") + message = AssistantMessage(source="assistant", content=content) + + if message: + messages.append(message) + + return messages + + def get_single_prompt(self): + """ + Construct a single formatted prompt using templates. + + Returns: + List[Message]: A single User message. + """ + if self.max_history != -1: + events = self._events[-(self.max_history * 2 + 1) :] + else: + events = self._events + + current_obs = events[-1]["text"] + + # Case 1: No history available + if len(events) == 1: + template = self.single_obs_template_wo_his + kwargs = {"current_observation": current_obs} + if "{admissible_actions}" in template: + kwargs["admissible_actions"] = self.admissible_actions + single_prompt = template.format(**kwargs) + + # Case 2: History exists + else: + template = self.single_obs_template + history = "" + obs_count = 0 + for idx, event in enumerate(events): + if events[idx]["type"] == "observation" and idx != len(events) - 1: + next_event = events[idx + 1] + history += f"[Observation {max(self.step_count-self.max_history+obs_count, 1)}: '{event['text']}', " + history += ( + f"Action {max(self.step_count-self.max_history+obs_count, 1)}: '{next_event['action']}']\n " + ) + obs_count += 1 + + kwargs = { + "step_count": self.step_count - 1, + "history_length": min(self.step_count - 1, self.max_history), + "history": history, + "current_step": self.step_count, + "current_observation": current_obs, + } + if "{admissible_actions}" in template: + kwargs["admissible_actions"] = self.admissible_actions + single_prompt = template.format(**kwargs) + + return [UserMessage(source="user", content=single_prompt)] + + def get_prompt(self): + """ + Return the final prompt based on the configured prompt type. + """ + if self.prompt_type == "chat": + prompt = self.get_chat_prompt() + elif self.prompt_type == "single": + prompt = self.get_single_prompt() + + return prompt diff --git a/contrib/recipes/simulation/train_simulation_agent.py b/contrib/recipes/simulation/train_simulation_agent.py new file mode 100644 index 000000000..a8469051b --- /dev/null +++ b/contrib/recipes/simulation/train_simulation_agent.py @@ -0,0 +1,103 @@ +# Copyright (c) Microsoft. All rights reserved. 
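+"""Entry point for RL training of simulation agents (ALFWorld / ScienceWorld).
+
+Typical invocations (see the README in this directory for environment setup):
+
+    python3 train_simulation_agent.py --algorithm grpo --env alfworld
+    python3 train_simulation_agent.py --algorithm grpo --env scienceworld --task_num 0
+    python3 train_simulation_agent.py --algorithm grpo --env scienceworld --task_num -1
+"""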
+ +import argparse +import os +import subprocess + +from omegaconf import OmegaConf + +from agentlightning import Trainer +from agentlightning.algorithm.verl import VERL +from contrib.agentlightning.contrib.algorithm.simulation_verl.daemon import SimulationAgentModeDaemon +from contrib.agentlightning.contrib.algorithm.simulation_verl.trainer import SimulationAgentLightningTrainer + + +def run_cmd(cmd): + """Execute a shell command and print its output""" + print(f"👉 Running: {cmd}") + result = subprocess.run(cmd, shell=True, text=True, capture_output=True) + if result.stdout: + print(result.stdout) + if result.stderr: + print(result.stderr) + return result + + +def kill_process_on_port(port): + result = subprocess.run(f"sudo lsof -t -i :{port}", shell=True, capture_output=True, text=True) + pids = result.stdout.strip().split("\n") + for pid in pids: + if pid: + print(f"🔪 Killing process {pid} on port {port}") + subprocess.run(f"sudo kill -9 {pid}", shell=True) + + +def train_val_dataset(cfg): + """Load training and validation datasets from parquet files.""" + from datasets import Dataset + + train_data = Dataset.from_parquet(cfg["data"]["train_files"]) + val_data = Dataset.from_parquet(cfg["data"]["val_files"]) + return train_data, val_data + + +def get_config(path): + cfg = OmegaConf.load(path) + OmegaConf.resolve(cfg) + if "variables" in cfg: + del cfg["variables"] + return cfg + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--env", type=str, default="scienceworld") + parser.add_argument("--algorithm", type=str, default="grpo") + parser.add_argument("--debug", action="store_true") + parser.add_argument("--n_workers", type=int, default=64, help="Number of workers for training") + parser.add_argument("--trial", type=int, default=0, help="Number of trials") + parser.add_argument("--task_num", type=int, default=25, help="ScienceWorld Task number to inject as env var") + parser.add_argument("--_background", action="store_true", help=argparse.SUPPRESS) + args = parser.parse_args() + + # Restart Ray cluster cleanly + kill_process_on_port(4747) + run_cmd("pkill -f AgentLightning") + run_cmd("ray stop") + run_cmd("env RAY_DEBUG=legacy HYDRA_FULL_ERROR=1 VLLM_USE_V1=1 ray start --head --dashboard-host=0.0.0.0") + + # set environment variable before loading configs + os.environ["TRIAL"] = str(args.trial) + if args.env == "scienceworld": + os.environ["TASK_NUM"] = str(args.task_num) + + # Load configs + agent_config_path = f"config_env/{args.env}.yaml" + if args.debug: + trainer_config_path = f"config_verl/{args.env}/debug/{args.algorithm}.yaml" + else: + trainer_config_path = f"config_verl/{args.env}/{args.algorithm}.yaml" + agent_config = get_config(agent_config_path) + + if "gigpo" in args.algorithm: + agent_config.log_env_obs = True + rl_training_config = get_config(trainer_config_path) + + # Load datasets + train_dataset, val_dataset = train_val_dataset(rl_training_config) + + # Initialize agent + from contrib.agentlightning.contrib.agent.simulation_agent import SimulationAgent + + agent = SimulationAgent(agent_config) + + # Initialize trainer and start training + trainer = Trainer( + algorithm=VERL( + config=rl_training_config, + trainer_cls=SimulationAgentLightningTrainer, + daemon_cls=SimulationAgentModeDaemon, + ), + n_workers=args.n_workers, + ) + trainer.fit(agent, train_dataset, val_dataset=val_dataset)
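+
+# Note: there is no CLI switch for evaluation-only runs; setting `trainer.val_before_train: true`
+# and `trainer.val_only: true` in the selected config_verl YAML makes the underlying
+# SimulationAgentLightningTrainer.fit() return right after the initial validation pass.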