30 commits
9fae8b2
add simulation envs
beanie00 Dec 2, 2025
bfada88
add simulation agent and trainer
beanie00 Dec 3, 2025
ae3c490
add run scripts
beanie00 Dec 3, 2025
275d2d1
add span grouping and modify the adapter
beanie00 Dec 4, 2025
d33946c
apply step reward and message to get_train_data_batch in demon
beanie00 Dec 6, 2025
8818137
error fixed
beanie00 Dec 6, 2025
d512d8a
Merge branch 'microsoft:main' into feature/agl-simulation
beanie00 Dec 6, 2025
30109a2
clean simulation/envs
beanie00 Dec 10, 2025
76b56ff
clean task_data
beanie00 Dec 10, 2025
34b5c89
change task data path
beanie00 Dec 10, 2025
d9c069d
Merge branch 'microsoft:main' into feature/agl-simulation
beanie00 Dec 11, 2025
49bfcab
Merge branch 'microsoft:main' into feature/agl-simulation
beanie00 Dec 12, 2025
da50a7a
move examples/simulation/* to contrib folder
beanie00 Dec 12, 2025
c9312e3
rollback prev agentlightning triplet and daemon
beanie00 Dec 12, 2025
e119c68
link with PR #407
beanie00 Dec 12, 2025
966e571
clean files
beanie00 Dec 15, 2025
140ecb1
clean files
beanie00 Dec 15, 2025
cee8188
clean prompt builder and script files
beanie00 Dec 15, 2025
59e2c31
move prompt_builder to agent
beanie00 Dec 15, 2025
7340bc6
Merge branch 'microsoft:main' into feature/agl-simulation
beanie00 Dec 18, 2025
a7bde15
update readme and clean files
beanie00 Dec 19, 2025
d647fc5
update readme
beanie00 Dec 20, 2025
70df42c
refactor triplet group
beanie00 Dec 22, 2025
cbbc3e6
fix intrinsic list length mismatch
beanie00 Dec 23, 2025
22e67ca
clean code
beanie00 Dec 23, 2025
d98c061
fix formatting via pre-commit
beanie00 Dec 29, 2025
0eaf16b
Merge branch 'microsoft:main' into feature/agl-simulation
beanie00 Jan 8, 2026
181940a
add missing copyright headers
beanie00 Jan 12, 2026
5e08b3b
Merge branch 'microsoft:main' into feature/agl-simulation
beanie00 Feb 2, 2026
afd2cf5
fix missing link
beanie00 Feb 2, 2026
4 changes: 4 additions & 0 deletions contrib/.gitignore
@@ -1 +1,5 @@
# Put contrib-related gitignore files here.

# recipes/simulation related
recipes/simulation/agl_envs/
recipes/simulation/wandb/
134 changes: 134 additions & 0 deletions contrib/agentlightning/contrib/adapter/triplet_group.py
@@ -0,0 +1,134 @@
# Copyright (c) Microsoft. All rights reserved.

from __future__ import annotations

from typing import Dict, List, Optional, Sequence

from agentlightning.adapter.triplet import TracerTraceToTriplet
from agentlightning.types import Span, Triplet


class TracerTraceToTripletGroup(TracerTraceToTriplet):
"""Convert tracer-emitted spans into triplet trajectories.

Attributes:
repair_hierarchy: When `True`, repair the span tree using
[`TraceTree.repair_hierarchy()`][agentlightning.adapter.triplet.TraceTree.repair_hierarchy]
before matching calls and rewards.
llm_call_match: Regular expression pattern that selects LLM call span names.
agent_match: Optional regular expression pattern for agent span names. When omitted, spans
from any agent are considered.
exclude_llm_call_in_reward: When `True`, ignore matches under reward spans while searching
for rewards.
reward_match: Strategy used to associate rewards with LLM calls.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

    def _extract_span_groups(self, spans: Sequence[Span]) -> Dict[str, Dict[str, Span]]:
def resolve_step_count(span, next_span, spans, index):
"""
Determine step_count for a given span using next_span or fallback search.
"""
# CASE A: If next_span exists and parent_id matches
if next_span and span.parent_id == next_span.span_id:
return next_span.attributes.get("step_count")

# CASE B: Fallback — search forward for agentlightning.operation
for s in spans[index + 1 :]:
if s.name == "agentlightning.operation" and span.parent_id == s.span_id:
return s.attributes.get("step_count")

return None

def extract_step_count_from_links(span):
"""
Extract step_count from agentlightning.link.* attributes.
"""
key = span.attributes.get("agentlightning.link.0.key_match")
if key == "step_count":
return span.attributes.get("agentlightning.link.0.value_match")
return None

span_groups = {}

for i, span in enumerate(spans):
next_span = spans[i + 1] if i + 1 < len(spans) else None
step_count = None

if span.name == "openai.chat.completion":
step_count = resolve_step_count(span, next_span, spans, i)
if step_count is None:
continue

step_count = str(step_count)
span_groups.setdefault(step_count, {})
span_groups[step_count]["call_span"] = span

elif span.name == "agentlightning.object":
step_count = extract_step_count_from_links(span)
if step_count is None:
continue

step_count = str(step_count)
span_groups.setdefault(step_count, {})
span_groups[step_count]["object_span"] = span

elif span.name == "agentlightning.annotation":
step_count = extract_step_count_from_links(span)
if step_count is None:
continue

step_count = str(step_count)
span_groups.setdefault(step_count, {})
span_groups[step_count]["annotation_span"] = span

return span_groups

def adapt_group(self, source: Sequence[Span], /) -> List[Triplet]:
span_groups = self._extract_span_groups(source)

def token_ids(span: Optional[Span], key: str) -> list:
return span.attributes.get(key, []) if span else []

def reward0(span: Optional[Span]) -> float:
if not span:
return 0.0
return float(span.attributes.get("agentlightning.reward.0.value", 0.0))

        def reward1(span: Optional[Span]) -> float:
            if not span:
                return 0.0
            return float(span.attributes.get("agentlightning.reward.1.value", 0.0))

        def message(span: Optional[Span]) -> str:
            if not span:
                return ""
            return span.attributes.get("agentlightning.object.literal", "")

triplets: List[Triplet] = []

for group in span_groups.values():
call_span = group.get("call_span")
if not token_ids(call_span, "prompt_token_ids") and not token_ids(call_span, "response_token_ids"):
continue

object_span = group.get("object_span")
annotation_span = group.get("annotation_span")
            # Note: _extract_span_groups never populates "request_id", so this stays None.
            request_id = group.get("request_id")

triplets.append(
Triplet(
prompt={"token_ids": token_ids(call_span, "prompt_token_ids")},
response={"token_ids": token_ids(call_span, "response_token_ids")},
reward=reward0(annotation_span),
metadata={
"response_id": request_id,
"intrinsic_reward": reward1(annotation_span),
"message": message(object_span),
},
)
)

return triplets
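
Reviewer note: the grouping above keys each span by step_count so that the call, object, and annotation spans from one environment step land in the same bucket before being zipped into a Triplet. A minimal, self-contained sketch of that idea, using plain dicts in place of real Span objects (every name below is illustrative, not part of the agentlightning API):

from collections import defaultdict

# Plain dicts stand in for tracer spans; only the grouping idea is shown.
spans = [
    {"name": "openai.chat.completion", "step_count": "0", "prompt_token_ids": [1, 2]},
    {"name": "agentlightning.object", "step_count": "0", "literal": "obs at step 0"},
    {"name": "agentlightning.annotation", "step_count": "0", "reward": 1.0},
    {"name": "openai.chat.completion", "step_count": "1", "prompt_token_ids": [3, 4]},
]

role_by_name = {
    "openai.chat.completion": "call_span",
    "agentlightning.object": "object_span",
    "agentlightning.annotation": "annotation_span",
}

groups = defaultdict(dict)
for span in spans:
    groups[span["step_count"]][role_by_name[span["name"]]] = span

assert set(groups["0"]) == {"call_span", "object_span", "annotation_span"}
assert set(groups["1"]) == {"call_span"}
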
157 changes: 157 additions & 0 deletions contrib/agentlightning/contrib/agent/simulation_agent.py
@@ -0,0 +1,157 @@
# Copyright (c) Microsoft. All rights reserved.

from __future__ import annotations

import logging
import os
from typing import Any, Dict

import numpy as np
from add_instruction import add_chat_instruction, add_single_instruction
from agl_envs.simulation import make_env_manager
from autogen_agentchat.agents import AssistantAgent
from autogen_core.models import ModelFamily
from autogen_ext.models.openai import OpenAIChatCompletionClient

from agentlightning import LLM, LitAgent, NamedResources, Rollout, configure_logger, emit_object, emit_reward, operation
from agentlightning.utils.otel import make_link_attributes
from contrib.recipes.simulation.prompt_builder import HistoryPromptBuilder

logger = configure_logger(name=__name__, level=logging.INFO)


class SimulationAgent(LitAgent):
def __init__(self, config, trained_agents: str | None = None) -> None:
super().__init__(trained_agents=trained_agents)
self.config = config
self.env = None

def _build_agent(self, llm: LLM, temperature: float):
model_client = OpenAIChatCompletionClient(
model=llm.model,
base_url=llm.endpoint,
api_key=os.environ.get("OPENAI_API_KEY", "token-abc123"),
model_info={
"vision": False,
"function_calling": True,
"json_output": False,
"family": ModelFamily.UNKNOWN,
"structured_output": False,
},
temperature=temperature,
)

return AssistantAgent(
name="simulation",
model_client=model_client,
)

def _get_instructed_prompt(self, prompt, sep="\n\n"):
"""Return instructed observation based on prompt_type and captioner type."""
prompt_type = self.config.captioner.prompt_type
cap_type = self.config.captioner.type

if prompt_type == "chat":
if cap_type == "cot":
return add_chat_instruction(prompt, "cot", sep, self.config.env_name)
elif cap_type == "naive":
return add_chat_instruction(prompt, "naive", sep)

elif prompt_type == "single":
if cap_type == "cot":
return add_single_instruction(prompt, "cot", sep, self.config.env_name)
elif cap_type == "naive":
return add_single_instruction(prompt, "naive", sep, self.config.env_name)

raise ValueError(f"Unsupported prompt_type={prompt_type}, type={cap_type}")

async def rollout_async(
self,
task: Dict[str, Any],
resources: NamedResources,
rollout: Rollout,
) -> float | None:
rollout_id = rollout.rollout_id
logger.info(f"[Rollout {rollout_id}] Task: {task}")

format_penalty = float(self.config["format_penalty"])
        reward_scale = float(self.config["reward_scale"])

# Setup agent
llm: LLM = resources.get("main_llm")
        logger.info(f"Training with model: {llm.model} on endpoint: {llm.endpoint}")
self.agent = self._build_agent(llm, 1.0 if rollout.mode == "train" else 0.4)
if "max_tokens" in self.config and self.config["max_tokens"] > -1:
self.agent._model_client.max_tokens = self.config["max_tokens"]

try:
# Setup environment
prompt_builder = HistoryPromptBuilder(
max_history=self.config.captioner.max_history, prompt_type=self.config.captioner.prompt_type
)

self.env = make_env_manager(self.config.env_name, task, self.config)
env_obs, infos, available_actions_hint = self.env.reset()

prompt_builder.init(self.env)
prompt_builder.update_observation(env_obs)
prompt_builder.update_admissible_actions(available_actions_hint)

prompt = prompt_builder.get_prompt()

episode_reward, done = 0.0, False

step_count = 0
while not done:
try:
instructed_prompt = self._get_instructed_prompt(prompt)

# Main agent step
with operation(step_count=step_count):
result = await self.agent._model_client.create(instructed_prompt)
output = result.content
logger.info(f"[LLM output]: {output}")

except Exception as e:
logger.error(f"[Rollout {rollout_id}] Error during training rollout: {e}", exc_info=True)
break

if self.config.log_env_obs:
emit_object(env_obs, attributes=make_link_attributes({"step_count": str(step_count)}))

                env_obs, executed_action, is_valid, step_reward, terminated, truncated, info, available_actions_hint = self.env.step(
                    output,
                    use_reasoning=self.config.captioner.type == "cot",
                    use_success_rate=self.config.use_success_rate,
                )

prompt_builder.update_step_count()
prompt_builder.update_action(executed_action)
prompt_builder.update_observation(env_obs)
prompt_builder.update_admissible_actions(available_actions_hint)

prompt = prompt_builder.get_prompt()

if rollout.mode == "train":
step_reward *= reward_scale

if format_penalty != 0.0:
emit_reward(
{
"extrinsic_reward": step_reward,
"intrinsic_reward": 0.0 if is_valid else -1.0 * format_penalty,
},
primary_key="extrinsic_reward",
attributes=make_link_attributes({"step_count": str(step_count)}),
)
else:
emit_reward(step_reward, attributes=make_link_attributes({"step_count": str(step_count)}))

episode_reward += float(step_reward)
done = np.logical_or(terminated, truncated)

step_count += 1

return episode_reward

finally:
if self.env is not None:
self.env.close()
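
Reviewer note: the per-step reward wiring is what ties this agent to TracerTraceToTripletGroup above: every emit_object / emit_reward call carries a step_count link attribute, and the adapter later joins spans on that key. A condensed sketch of just that contract, reusing only the calls that appear in this diff (env and agent here are hypothetical placeholders, not the real interfaces):

from agentlightning import emit_object, emit_reward, operation
from agentlightning.utils.otel import make_link_attributes

def run_episode(env, agent, format_penalty: float = 0.1) -> float:
    obs, done, step_count, total = env.reset(), False, 0, 0.0
    while not done:
        # operation() tags the LLM call span with the current step_count.
        with operation(step_count=step_count):
            action = agent.act(obs)
        # Link the observation span to the same step.
        emit_object(obs, attributes=make_link_attributes({"step_count": str(step_count)}))
        obs, step_reward, done, is_valid = env.step(action)
        # Extrinsic vs. intrinsic reward, keyed to the step as in the agent above.
        emit_reward(
            {
                "extrinsic_reward": step_reward,
                "intrinsic_reward": 0.0 if is_valid else -format_penalty,
            },
            primary_key="extrinsic_reward",
            attributes=make_link_attributes({"step_count": str(step_count)}),
        )
        total += float(step_reward)
        step_count += 1
    return total
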