Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 41 additions & 4 deletions agents/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,30 @@

logger = logging.getLogger(__name__)

# Read timeout (seconds) for calls to the RITS inference service;
# overridable via the RITS_REQUEST_TIMEOUT_SECONDS environment variable.
REQUEST_TIMEOUT = float(os.getenv("RITS_REQUEST_TIMEOUT_SECONDS", "60"))

# Retries allowed after the initial attempt; override via RITS_MAX_RETRIES.
MAX_RETRIES = int(os.getenv("RITS_MAX_RETRIES", "2"))

# Shared httpx timeout configuration: a generous read window for slow model
# generations, tighter bounds on connect, write, and pool acquisition.
timeout = httpx.Timeout(connect=10.0, read=REQUEST_TIMEOUT, write=30.0, pool=10.0)

class RITSChatModel(BaseChatModel):
"""LangChain-compatible chat model using httpx for internal RITS inference service."""

# Mapping from endpoint name (short) to payload model name (full)
MODEL_NAME_MAPPING: Dict[str, str] = {
"llama-3-3-70b-instruct": "meta-llama/llama-3-3-70b-instruct",
# Open Source Models
"qwen3-5-397b-a17b-fp8": "Qwen/Qwen3.5-397B-A17B-FP8",
"mistral-large-3-675b-2512-fp4": "mistralai/Mistral-Large-3-675B-Instruct-2512-NVFP4",
"glm-5-1": "",
"moonshotai-kimi-k2-5":"moonshotai/Kimi-K2.5",
"gpt-oss-120b": "openai/gpt-oss-120b",
"qwen3-5-397b-a17b-fp8": "qwen/qwen3.5-397B-A17B-FP8",
"mistral-large-3-675b-2512-fp4": "mistralai/Mistral-Large-3-675B-Instruct-2512-NVFP4"
# smaller models
"llama-3-3-70b-instruct": "meta-llama/llama-3-3-70b-instruct",
"qwen2-5-72b-instruct": "Qwen/Qwen2.5-72B-Instruct",
}

model_name: str
Expand Down Expand Up @@ -125,12 +139,35 @@ async def _agenerate(
if self.bound_tools:
payload["tools"] = self.bound_tools

# Add MAX_RETRIES and timeout handling
# async with httpx.AsyncClient(timeout=timeout) as client:
# for attempt in range(MAX_RETRIES + 1):
# try:
# resp = await client.post(
# url,
# json=payload,
# headers=headers,
# )
# resp.raise_for_status()
# break

# except httpx.ReadTimeout:
# if attempt == MAX_RETRIES:
# raise
# await asyncio.sleep(2 ** attempt)

# except httpx.HTTPError:
# if attempt == MAX_RETRIES:
# raise
# await asyncio.sleep(2 ** attempt)
# data = resp.json()

async with httpx.AsyncClient() as client:
resp = await client.post(
url,
headers=headers,
json=payload,
timeout=60.0
timeout=float(os.environ.get("RITS_REQUEST_TIMEOUT_SECONDS", "60"))
)
resp.raise_for_status()
data = resp.json()
Expand Down
15 changes: 12 additions & 3 deletions benchmark_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
Results saved to: output/capability_{id}_{timestamp}/<domain>.json
e.g. output/capability_2_feb_18_11_21am/hockey.json
"""
import os
import asyncio
from contextlib import AsyncExitStack
import json
Expand Down Expand Up @@ -145,7 +146,7 @@ def _setup_phoenix(endpoint: str, project_name: str = "enterprise-benchmark") ->
Path(__file__).parent / "benchmark" / "mcp_connection_config.yaml"
)
# Timeout for agent execution (seconds)
AGENT_TIMEOUT_SECONDS = 300
AGENT_TIMEOUT_SECONDS = float(os.environ.get("AGENT_TIMEOUT_SECONDS", "300"))


async def run_benchmark_for_domain(
Expand Down Expand Up @@ -316,7 +317,7 @@ async def run_benchmark_for_domain(
except Exception as e:
import traceback
result.status = "error"
result.error = str(e)
result.error = f"{type(e).__name__} "+str(e)
tlog(f" Status: error | {type(e).__name__}: {str(e)[:200]}")
tlog(f" Traceback: {traceback.format_exc()}")

Expand Down Expand Up @@ -357,6 +358,7 @@ async def run_capability(
top_k_tools: int = 0,
max_iterations: Optional[int] = None,
restart: bool = False,
temperature: float = 0.0,
) -> List[BenchmarkResult]:
"""Run benchmark for a given capability_id, iterating over all domain files."""

Expand Down Expand Up @@ -397,7 +399,7 @@ async def run_capability(
tlog(f"Restart mode: skipping {len(completed)} already-completed domain(s): {sorted(completed)}")
domain_list = [d for d in domain_list if d not in completed]

llm = create_llm(provider=provider, model=model)
llm = create_llm(provider=provider, model=model, temperature=temperature)

# Process each domain, writing output incrementally
all_results: List[BenchmarkResult] = []
Expand Down Expand Up @@ -553,6 +555,12 @@ def main():
default="enterprise-benchmark",
help="Phoenix project name for grouping traces (default: enterprise-benchmark)",
)
parser.add_argument(
"--temperature",
type=float,
default=0.0,
help="LLM temperature (default: 0.0)"
)

args = parser.parse_args()
capability_ids = args.capability_id # list of ints now
Expand Down Expand Up @@ -588,6 +596,7 @@ def _make_run_task_coro(tid: int):
top_k_tools=args.top_k_tools,
max_iterations=args.max_iterations,
restart=args.restart,
temperature=args.temperature
)

def _make_list_tools_coro(tid: int):
Expand Down
2 changes: 1 addition & 1 deletion evaluator/judge.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import re
import os
import json
import deepcopy
from copy import deepcopy
from prompt import GroundednessPrompt, CorrectnessPrompt
from utils import JudgeInput, JudgeOutput
from langchain_openai import ChatOpenAI
Expand Down
Loading