Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 47 additions & 33 deletions platoon/agents/actions/subagent.py
Original file line number Diff line number Diff line change
@@ -1,57 +1,71 @@
import asyncio
from typing import cast
from typing import Any, cast

from platoon.agents.base import ForkableAgent
from platoon.envs.base import ForkableEnv
from platoon.episode.context import budget_tracker, current_agent, current_env, episode_step_timeout
from platoon.episode.context import budget_tracker, current_agent, current_env, current_trajectory, episode_step_timeout
from platoon.episode.loop import run_episode
from platoon.episode.trajectory import BudgetExceededError
from platoon.utils.span_profile import profile_span


async def launch_subagent(goal: str, max_steps: int = 15, task_misc: dict | None = None) -> str:
async def launch_subagent(goal: str, max_steps: int = 15, task_misc: dict | None = None, verbose: bool = True) -> Any:
"""Launch a subagent to solve a task.

Args:
goal: The goal of the subagent.
max_steps: The maximum number of steps the subagent can take.

Returns:
Returns the answer or finish message for the goal.
Returns the result of the subagent's execution.
"""
# Cast is safe here: launch_subagent only works in contexts with forkable agents/envs
agent = cast(ForkableAgent, current_agent.get())
env = cast(ForkableEnv, current_env.get())
task = env.task

subtask = task.fork(goal, max_steps, task_misc=task_misc)
forked_agent = await agent.fork(subtask)
forked_env = await env.fork(subtask)

try:
budget_tracker.get().reserve_budget(max_steps + 1, raise_on_failure=True)
except (BudgetExceededError, ValueError) as e:
guidance = getattr(e, "guidance", "")
msg = f"Not enough budget to launch subagent for goal {goal}. {e}"
if guidance:
msg += " " + guidance
return msg

traj = await asyncio.create_task(
run_episode(
forked_agent,
forked_env,
timeout=episode_step_timeout.get(),
)
)
parent_traj = current_trajectory.get()
async with profile_span(
"launch_subagent",
metadata={
"goal_len": len(goal),
"max_steps": max_steps,
"parent_task_id": getattr(task, "id", None),
"parent_trajectory_id": parent_traj.id,
},
):
subtask = task.fork(goal, max_steps, task_misc=task_misc)
forked_agent = await agent.fork(subtask)
forked_env = await env.fork(subtask)

try:
budget_tracker.get().reserve_budget(max_steps + 1, raise_on_failure=True)
except (BudgetExceededError, ValueError) as e:
guidance = getattr(e, "guidance", "")
msg = f"Not enough budget to launch subagent for goal {goal}. {e}"
if guidance:
msg += " " + guidance
return msg

budget_tracker.get().release_budget(max_steps + 1)
try:
traj = await asyncio.create_task(
run_episode(
forked_agent,
forked_env,
timeout=episode_step_timeout.get(),
)
)
finally:
budget_tracker.get().release_budget(max_steps + 1)

used_recursive = int(budget_tracker.get().used_budget_for(traj.id))
remaining_total = int(budget_tracker.get().remaining_budget())
used_recursive = int(budget_tracker.get().used_budget_for(traj.id))
remaining_total = int(budget_tracker.get().remaining_budget())

budget_message = (
f"\n\nBudget used by subagent: {used_recursive}/{max_steps} steps. "
f"Total remaining budget for the current task is {remaining_total} steps.\n"
)
budget_message = (
f"\n\nBudget used by subagent: {used_recursive}/{max_steps} steps. "
f"Total remaining budget for the current task is {remaining_total} steps.\n"
)

return (traj.finish_message or traj.error_message or "") + budget_message
if verbose:
return (traj.finish_message or traj.error_message or "") + budget_message
else:
return traj.finish_message or traj.error_message or ""
49 changes: 45 additions & 4 deletions platoon/agents/codeact/agent.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import re
import uuid
from typing import cast

from openai.types.chat import ChatCompletionMessageParam
Expand All @@ -9,7 +10,10 @@
from platoon.agents.codeact.prompt_builder import CodeActPromptBuilder, PromptMode
from platoon.envs.base import Task
from platoon.envs.codeact import CodeActAction, CodeActObservation
from platoon.episode.context import current_trajectory
from platoon.utils import async_hang_debug
from platoon.utils.llm_client import LLMClient
from platoon.utils.span_profile import profile_span


def extract_code_and_thought(raw_action: str) -> tuple[str, str]:
Expand Down Expand Up @@ -125,6 +129,8 @@ async def act(self, obs: CodeActObservation) -> CodeActAction:
return self._stuck_in_loop_action()

prompt = cast(list[ChatCompletionMessageParam], self.prompt_builder.build_messages(obs))
current_traj = current_trajectory.get(None)
request_id = str(uuid.uuid4())
request_kwargs = {
"stop": ["</python>"],
"max_completion_tokens": self.inference_params.max_completion_tokens,
Expand All @@ -135,11 +141,46 @@ async def act(self, obs: CodeActObservation) -> CodeActAction:
if self.inference_params.top_p is not None:
request_kwargs["top_p"] = self.inference_params.top_p

response = await self.llm_client.async_chat_completion(
prompt,
# Stop sequence is agent-level behavior, not a global rollout knob.
**request_kwargs,
async_hang_debug.track_current_task(
request_id=request_id,
kind="agent_llm",
metadata={
"task_id": obs.task.id,
"trajectory_id": current_traj.id if current_traj is not None else None,
"parent_trajectory_id": (
current_traj.parent_info.id
if current_traj is not None and current_traj.parent_info is not None
else None
),
"step_index": len(obs.history),
"message_count": len(prompt),
"model": getattr(self.llm_client, "model", None),
"timeout": request_kwargs["timeout"],
},
)
try:
async with profile_span(
"agent_act",
metadata={
"task_id": obs.task.id,
"trajectory_id": current_traj.id if current_traj is not None else None,
"parent_trajectory_id": (
current_traj.parent_info.id
if current_traj is not None and current_traj.parent_info is not None
else None
),
"step_index": len(obs.history),
"message_count": len(prompt),
"model": getattr(self.llm_client, "model", None),
},
):
response = await self.llm_client.async_chat_completion(
prompt,
# Stop sequence is agent-level behavior, not a global rollout knob.
**request_kwargs,
)
finally:
async_hang_debug.untrack(request_id)
response_text = response.choices[0].message.content or ""
# NOTE: We only do this conditionally, because with Areal, stop words are not supported.
# And so we might already have the stop word in the response.
Expand Down
133 changes: 107 additions & 26 deletions platoon/envs/codeact/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
finish_message,
)
from platoon.utils.ipython_shell import ShellCapture, strip_ansi_escape_sequences
from platoon.utils.span_profile import profile_span

from .types import (
CodeActAction,
Expand Down Expand Up @@ -76,24 +77,34 @@ async def evaluate(self) -> tuple[float, dict]:
return 0.0, {}

async def step(self, action: CodeActAction) -> CodeActObservation:
step = await self._code_executor.run(action.parsed_code)

if finish_message.get(None) is not None or error_message.get(None) is not None:
self._state.finished = True
self._state.misc["finish_message"] = finish_message.get()

step.thought = action.parsed_thought
step.reward, reward_info = await self.evaluate()
step.misc["action_misc"] = action.misc
step.misc["reward_misc"] = reward_info
self._state.reward += step.reward
self._state.history.append(step)

traj_collection = current_trajectory_collection.get()
traj = current_trajectory.get()
traj_collection.add_trajectory_step(traj.id, self._state.history[-1])
if self._state.finished:
traj.reward = self._state.reward
async with profile_span(
"env_step",
metadata={
"task_id": self._state.task.id,
"trajectory_id": traj.id,
"parent_trajectory_id": traj.parent_info.id if traj.parent_info is not None else None,
"step_index": len(self._state.history),
"code_len": len(action.parsed_code or ""),
},
):
step = await self._code_executor.run(action.parsed_code)

if finish_message.get(None) is not None or error_message.get(None) is not None:
self._state.finished = True
self._state.misc["finish_message"] = finish_message.get()

step.thought = action.parsed_thought
step.reward, reward_info = await self.evaluate()
step.misc["action_misc"] = action.misc
step.misc["reward_misc"] = reward_info
self._state.reward += step.reward
self._state.history.append(step)

traj_collection = current_trajectory_collection.get()
traj_collection.add_trajectory_step(traj.id, self._state.history[-1])
if self._state.finished:
traj.reward = self._state.reward

return await self.observe()

Expand Down Expand Up @@ -264,19 +275,67 @@ def visit_Call(self, node: ast.Call) -> None:
self.generic_visit(node)


class WhileLoopDetector(ast.NodeVisitor):
    """AST visitor to detect while loops before execution.

    While loops are easy for the model to misuse in ways that block the single
    in-process event loop forever. Prefer bounded `for` loops or iteration over
    concrete collections/ranges instead.
    """

    def __init__(self):
        # One human-readable message per `while` statement encountered.
        self.errors: list[str] = []

    def visit_While(self, node: ast.While) -> None:
        """Record an error for this `while` statement, then keep walking."""
        # `ast.unparse` has existed since Python 3.9, and this codebase already
        # relies on runtime `X | None` annotations (3.10+), so the previous
        # `hasattr(ast, "unparse")` fallback was dead code.
        condition_src = ast.unparse(node.test)
        self.errors.append(
            "ERROR: `while` loops are not allowed in this environment because they can hang the shared "
            "Python executor. Consider rewriting this as a bounded `for` loop (for example, `for i in range(...)`) "
            f"or iterate over a concrete collection instead. Detected condition: `{condition_src}`"
        )
        # Continue into the loop body so nested `while` loops are reported too.
        self.generic_visit(node)


class InteractiveInputDetector(ast.NodeVisitor):
    """AST visitor to detect blocking interactive input calls before execution."""

    # Callable names that would block the shared executor waiting on stdin.
    INTERACTIVE_CALLS = {"input"}

    def __init__(self):
        # One message is recorded for every offending call site found.
        self.errors: list[str] = []

    def visit_Call(self, node: ast.Call) -> None:
        """Flag any call whose target name matches a known interactive builtin."""
        callee = node.func
        if isinstance(callee, ast.Name):
            # Bare call, e.g. `input(...)`.
            target = callee.id
        else:
            # Attribute call, e.g. `builtins.input(...)`; anything else has no name.
            target = callee.attr if isinstance(callee, ast.Attribute) else None

        if target in self.INTERACTIVE_CALLS:
            self.errors.append(
                "ERROR: Interactive input functions like `input()` are not allowed in this environment because "
                "they block the shared Python executor waiting for terminal input that will never arrive."
            )
        # Keep walking so nested calls in arguments are also inspected.
        self.generic_visit(node)


class IPythonCodeExecutor(CodeExecutor):
# TODO: Separate actions and modules? Use this info to build action space description?
def __init__(
    self,
    task: Task,
    actions: tuple[Callable[..., object], ...] | Sequence[Callable[..., object]] = (finish, safe_asyncio),
    detect_unawaited_async_calls: bool = True,
    detect_while_loops: bool = False,
    detect_interactive_input: bool = False,
):
    """Create an executor bound to *task* with a fresh IPython shell.

    Args:
        task: The task this executor runs code for.
        actions: Callables exposed to the executed code as the action space
            (defaults to `finish` and `safe_asyncio`).
        detect_unawaited_async_calls: If True, statically reject code with
            un-awaited async calls before running it.
        detect_while_loops: If True, statically reject code containing
            `while` loops (they can hang the shared executor).
        detect_interactive_input: If True, statically reject code calling
            blocking interactive functions such as `input()`.
    """
    self.task = task
    self.actions = actions
    # NOTE(review): _create_shell builds the embedded IPython shell; its body
    # is defined elsewhere in this class — confirm it has no side effects
    # beyond shell setup before reusing this constructor in new contexts.
    self.shell = self._create_shell()
    # self.timeout_seconds = timeout_seconds
    self.detect_unawaited_async_calls = detect_unawaited_async_calls
    self.detect_while_loops = detect_while_loops
    self.detect_interactive_input = detect_interactive_input

def _create_shell(self) -> InteractiveShellEmbed:
original_excepthook = sys.excepthook
Expand Down Expand Up @@ -331,17 +390,39 @@ async def run(self, code: str) -> CodeActStep:
if detector.errors:
return CodeActStep(code=code, error="\n".join(detector.errors))

with ShellCapture() as capture:
await self.shell.run_cell_async(code)
if self.detect_while_loops:
detector = WhileLoopDetector()
detector.visit(tree)
if detector.errors:
return CodeActStep(code=code, error="\n".join(detector.errors))

cap_stdout = strip_ansi_escape_sequences(capture.pop_stdout())
cap_stderr = strip_ansi_escape_sequences(capture.pop_stderr())
if self.detect_interactive_input:
detector = InteractiveInputDetector()
detector.visit(tree)
if detector.errors:
return CodeActStep(code=code, error="\n".join(detector.errors))

# TODO: This might cause unexpected filtering of outputs.
# Guard against empty stdout before indexing first line
first_line = cap_stdout.splitlines()[0] if cap_stdout.splitlines() else ""
if cap_stdout.startswith("Out[") or ("[?7hOut[1]:" in first_line):
cap_stdout = "".join(cap_stdout.split(":")[1:])
traj = current_trajectory.get()
async with profile_span(
"code_executor_run",
metadata={
"task_id": self.task.id,
"trajectory_id": traj.id,
"parent_trajectory_id": traj.parent_info.id if traj.parent_info is not None else None,
"code_len": len(code),
},
):
with ShellCapture() as capture:
await self.shell.run_cell_async(code)

cap_stdout = strip_ansi_escape_sequences(capture.pop_stdout())
cap_stderr = strip_ansi_escape_sequences(capture.pop_stderr())

# TODO: This might cause unexpected filtering of outputs.
# Guard against empty stdout before indexing first line
first_line = cap_stdout.splitlines()[0] if cap_stdout.splitlines() else ""
if cap_stdout.startswith("Out[") or ("[?7hOut[1]:" in first_line):
cap_stdout = "".join(cap_stdout.split(":")[1:])

return CodeActStep(
code=code,
Expand Down
23 changes: 18 additions & 5 deletions platoon/episode/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,31 @@
finish_message,
)
from platoon.episode.trajectory import StepBudgetTracker, Trajectory, TrajectoryCollection
from platoon.utils.span_profile import profile_span


# NOTE: Call using asyncio.create_task() to make sure edits to contextvars do not leak to parent context
async def run_episode(agent: Agent, env: Env, verbose: bool = False, timeout: int | None = 300) -> Trajectory:
try:
step_count = 0
set_context_vars(agent, env, timeout=timeout)
obs = await env.reset()
while not halt_episode(obs):
action = await asyncio.wait_for(agent.act(obs), timeout=timeout)
obs = await asyncio.wait_for(env.step(action), timeout=timeout)
step_count += 1
traj = current_trajectory.get()
async with profile_span(
"run_episode",
metadata={
"agent_type": type(agent).__name__,
"env_type": type(env).__name__,
"task_id": getattr(env.task, "id", None),
"timeout": timeout,
"trajectory_id": traj.id,
"parent_trajectory_id": traj.parent_info.id if traj.parent_info is not None else None,
},
):
obs = await env.reset()
while not halt_episode(obs):
action = await asyncio.wait_for(agent.act(obs), timeout=timeout)
obs = await asyncio.wait_for(env.step(action), timeout=timeout)
step_count += 1
except (Exception, asyncio.CancelledError) as e:
tb_summary = traceback.extract_tb(e.__traceback__)
origin = ""
Expand Down
Loading
Loading