diff --git a/docs/MULTI_TURN_QUICKSTART.md b/docs/MULTI_TURN_QUICKSTART.md new file mode 100644 index 00000000..f3d5e082 --- /dev/null +++ b/docs/MULTI_TURN_QUICKSTART.md @@ -0,0 +1,315 @@ +# Multi-Turn Conversation Benchmarking - Quick Start Guide + +## Quick Start in 5 Minutes + +### 1. Prepare Your Dataset + +Create a JSONL file with your conversations. All rows for a given `conversation_id` must appear +**consecutively** in the file (no interleaving with other conversations): + +```jsonl +{"conversation_id": "c1", "turn": 1, "role": "user", "content": "Hello!", "system": "You are a helpful assistant"} +{"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "Hi! How can I help?"} +{"conversation_id": "c1", "turn": 3, "role": "user", "content": "What's 2+2?"} +{"conversation_id": "c1", "turn": 4, "role": "assistant", "content": "2+2 equals 4."} +``` + +**Rules**: + +- Alternate between "user" and "assistant" roles +- Start with "user" role +- Sequential turn numbers (1, 2, 3, ...) +- Same `conversation_id` for all turns in a conversation +- All rows for the same `conversation_id` must be grouped together + +### 2. 
Create Configuration File + +Save as `multi_turn_config.yaml`: + +```yaml +name: "my-multi-turn-benchmark" +version: "1.0" +type: "online" + +model_params: + name: "your-model-name" + temperature: 0.7 + max_new_tokens: 256 + +datasets: + - name: my_conversations + type: performance + path: path/to/your/conversations.jsonl + multi_turn: # ← Presence of this block enables multi-turn mode + mode: independent # ← Per-conv pipelines; no cross-conv turn barrier + turn_timeout_s: 300 # ← Max wait for prev turn + +settings: + load_pattern: + type: multi_turn # ← Use multi-turn scheduler + target_concurrency: 32 # ← Required: max simultaneous conversations + + client: + workers: 4 + +endpoint_config: + endpoints: + - "http://your-endpoint:8000" + api_type: openai + +report_dir: logs/my_multi_turn_benchmark +``` + +Results are written to `report_dir` (here: `logs/my_multi_turn_benchmark/`). + +### 3. Run Benchmark + +```bash +inference-endpoint benchmark from-config --config multi_turn_config.yaml +``` + +That's it! 
Your benchmark will now: + +- ✅ Enforce turn ordering (turn N+1 waits for turn N) +- ✅ Include conversation history in each request +- ✅ Track per-turn and per-conversation metrics +- ✅ Log all turns with conversation metadata + +--- + +## Understanding Results + +After the benchmark completes, check the directory configured via `report_dir`: + +### Events Log + +The `events.jsonl` file contains one JSON record per line: + +- Standard fields: `sample_uuid`, `event_type`, `timestamp_ns` +- **New fields**: `conversation_id`, `turn_number` + +Query examples: + +```bash +# All events for a specific conversation +grep '"conversation_id": "c1"' logs/my_multi_turn_benchmark/events.jsonl + +# With jq for structured output +jq 'select(.conversation_id == "c1") | {conversation_id, turn_number, event_type, timestamp_ns}' \ + logs/my_multi_turn_benchmark/events.jsonl +``` + +### Metrics + +Currently available: + +- **Per-turn metrics**: Latency, TTFT, TPOT for each turn +- **Conversation tracking**: All events tagged with conversation_id + +_Note: Per-conversation aggregation (e.g., "conversations/sec") is coming in a future update._ + +--- + +## Conversation Modes Explained + +### Independent Mode (Default) + +```yaml +mode: independent +``` + +**Behavior**: + +- Up to `target_concurrency` conversations are active simultaneously +- Turns within each conversation are strictly sequenced (turn N+1 waits for turn N) +- Conversations run independently of each other — a short conversation can finish while a long one is still on turn 2 + +**Use for**: Realistic production load simulation. For single-conversation debugging, set `target_concurrency: 1`. + +**Example timeline** (target_concurrency: 3, 4 conversations total): + +``` +t=0: conv1-turn1, conv2-turn1, conv3-turn1 ← 3 conversations start +t=0.5: conv1-turn2 (after conv1-turn1 completes) +t=0.7: conv2 finishes → worker picks up conv4-turn1 +t=0.8: conv1-turn3 (after conv1-turn2 completes) +... 
+``` + +--- + +## Concurrency Control + +`target_concurrency` is **required** for the `multi_turn` load pattern. It controls how many +conversations are active simultaneously. Each active conversation has exactly one in-flight turn +at a time — a worker issues turn N, waits for the response, then issues turn N+1. A new +conversation starts only after a worker finishes all turns of its current one. + +```yaml +settings: + load_pattern: + type: multi_turn + target_concurrency: 32 # ← 32 conversations active simultaneously +``` + +--- + +## Troubleshooting + +### Validate Your Dataset Before Running + +Use the bundled validation script to check your JSONL file for schema errors before benchmarking: + +```bash +python scripts/validate_jsonl_schema.py path/to/your/conversations.jsonl +``` + +This catches missing required fields, invalid role sequences, non-consecutive turn numbers, and +interleaved conversations — all errors that would otherwise surface at benchmark startup. + +### "Conversation has invalid role sequence" + +**Problem**: Your dataset doesn't follow a valid role sequence. + +**Fix**: Check your JSONL. Valid sequences: + +- Plain chat: `user → assistant → user → assistant → ...` +- Agentic (tool-use): `user → assistant → tool → assistant → tool → ... → user` + +Conversations may also end with a `tool` row (the model's response to the final tool call is the benchmark target). + +### "Rows for conversation X are not consecutive" + +**Problem**: Rows for the same `conversation_id` are interleaved with rows from other conversations. + +**Fix**: Sort your JSONL so all rows for each conversation appear together. + +### "Turn timed out waiting for prev turn" + +**Problem**: Previous turn took longer than `turn_timeout_s`. + +**Fixes**: + +1. Increase `turn_timeout_s` in config +2. Check if your endpoint is slow or unresponsive +3. Look for errors in the endpoint logs + +### Dataset not loading + +**Problem**: MultiTurnDataset not recognized. 
+ +**Fix**: Ensure `multi_turn:` block is present in the dataset config. The file format +is auto-detected from the `.jsonl` extension — no `format` field is needed: + +```yaml +datasets: + - path: your_file.jsonl + multi_turn: + mode: independent +``` + +--- + +## Example Datasets + +### Simple 2-Turn Conversation + +```jsonl +{"conversation_id": "c1", "turn": 1, "role": "user", "content": "Hi"} +{"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "Hello!"} +``` + +### With System Prompt + +```jsonl +{"conversation_id": "c1", "turn": 1, "role": "user", "content": "Who won?", "system": "You are a sports expert"} +{"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "The Lakers won."} +``` + +### Multiple Conversations + +```jsonl +{"conversation_id": "c1", "turn": 1, "role": "user", "content": "Hi"} +{"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "Hello!"} +{"conversation_id": "c2", "turn": 1, "role": "user", "content": "Hey"} +{"conversation_id": "c2", "turn": 2, "role": "assistant", "content": "Hi there!"} +``` + +### With Model Override + +```jsonl +{"conversation_id": "c1", "turn": 1, "role": "user", "content": "Summarize this", "model": "gpt-4"} +{"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "Here's the summary..."} +``` + +--- + +## Testing Your Setup + +### 1. Use the Example Dataset + +```bash +cd examples/09_MultiTurn +inference-endpoint benchmark from-config --config multi_turn_benchmark.yaml +``` + +### 2. Check the Logs + +```bash +cat logs/multi_turn_test/benchmark.log +# Look for: "Turn X of conversation_id issued" +``` + +### 3. 
Verify Event Recording + +```bash +# List all unique conversation IDs in the events log +jq -r '.conversation_id' logs/multi_turn_test/events.jsonl | sort -u +# Should show your conversation IDs +``` + +--- + +## Tips & Best Practices + +### Dataset Design + +- **Keep conversations realistic**: 2-10 turns typical +- **Test edge cases**: 1-turn conversations, very long conversations +- **Include system prompts**: Helps model understand context + +### Performance + +- **Workers**: `client.workers` controls HTTP worker processes, independent of `target_concurrency`. The default (`-1`) auto-tunes based on NUMA topology. +- **Timeout**: Set `turn_timeout_s` = 2x your longest expected turn latency +- **Memory**: ~1KB per turn, plan accordingly for large datasets + +### Debugging + +- **Start small**: Test with 1-2 conversations first +- **Single conversation**: Use `mode: independent` with `target_concurrency: 1` +- **Check events.jsonl**: Verify turn ordering with `jq` + +--- + +## More Information + +- **Full Documentation**: See `examples/09_MultiTurn/README.md` +- **Architecture**: See `AGENTS.md` (Multi-Turn section) + +--- + +## Checklist + +Before running your first multi-turn benchmark: + +- [ ] Dataset follows format (user/assistant alternation, or agentic user→assistant→tool sequences) +- [ ] All rows for each conversation_id are grouped together +- [ ] Config has `multi_turn:` block in the dataset section +- [ ] Config has `load_pattern.type: multi_turn` +- [ ] Endpoint is running and reachable +- [ ] File uses `.jsonl` extension (format is auto-detected) +- [ ] Conversation IDs are unique per conversation +- [ ] Turn numbers are sequential (1, 2, 3, ...) + +Happy benchmarking! 
diff --git a/examples/09_MultiTurn/README.md b/examples/09_MultiTurn/README.md new file mode 100644 index 00000000..e7f9505a --- /dev/null +++ b/examples/09_MultiTurn/README.md @@ -0,0 +1,312 @@ +# Multi-Turn Conversation Benchmarking Examples + +This directory contains examples for benchmarking conversational AI workloads with multi-turn conversation support. + +## Overview + +Multi-turn conversation benchmarking enables testing realistic conversational AI scenarios where each turn depends on previous responses. The system maintains conversation history and enforces turn sequencing to simulate real-world multi-turn interactions. + +## Dataset Format + +Multi-turn datasets use JSONL format with the following structure: + +```jsonl +{"conversation_id": "c1", "turn": 1, "role": "user", "content": "...", "system": "..."} +{"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "..."} +{"conversation_id": "c1", "turn": 3, "role": "user", "content": "..."} +``` + +### Required Fields + +- `conversation_id`: Unique identifier for each conversation +- `turn`: Turn number within conversation (1-indexed) +- `role`: Speaker role ("user", "assistant", or "tool" in agentic datasets) +- `content`: Message content (may be omitted on assistant rows that carry only `tool_calls`) + +### Optional Fields + +- `system`: System prompt (typically only on first user turn) +- `model`: Model name override for this turn +- `max_new_tokens`: Maximum tokens to generate for this turn + +### Validation Rules + +1. All rows for a given `conversation_id` must appear **consecutively** in the file (no interleaving + with rows from other conversations). Turns within a conversation must be in order. + The flat-row format is intentional: it enables row-by-row streaming without loading entire + conversations into memory first. +2. Conversations must follow a valid role sequence: + - Plain chat: `user → assistant → user → ...` + - Agentic: `user → assistant (with tool_calls) → tool → [tool | assistant (with tool_calls)]* → assistant → user → ...` +3. 
First turn must be "user" role +4. Turn numbers must be sequential (1, 2, 3, ...) +5. Each conversation must have at least one turn + +## Agentic (Tool-Sequence) Datasets + +For agentic workloads where the model dispatches tools, the dataset must include tool-call +metadata. The source format for these datasets is a **snapshot JSONL** — each line contains the +full conversation history at a particular checkpoint. The benchmarker requires **flat-row JSONL** +(one row per message), so a conversion step is needed first. + +### Source snapshot format + +Each line in the source file represents one snapshot of a conversation: + +```json +{ + "conversation_id": "sim_001", + "conversation_idx": 5, + "messages": [ + {"role": "system", "content": "..."}, + {"role": "user", "content": "..."}, + {"role": "assistant", "tool_calls": [{"id": "...", "type": "function", "function": {"name": "bash", "arguments": "{\"cmd\": \"ls\"}"}}]}, + {"role": "tool", "tool_call_id": "...", "content": "file1.txt\nfile2.txt"}, + {"role": "assistant", "content": "Done."} + ], + "tools": [...], + "metadata": {} +} +``` + +Multiple snapshots may exist per `conversation_id` (one per `conversation_idx`); only the +highest-indexed snapshot per conversation is used. + +### Converting to flat-row format + +The following commands convert each source snapshot file to the flat-row format required by the benchmarker. +Run from the repo root: + +```bash +# First argument: input snapshot JSONL; second argument: output flat-row JSONL +python scripts/convert_agentic_snapshot.py \ + /path/to/agentic_coding_dataset.jsonl \ + examples/09_MultiTurn/datasets/agentic_coding_flat.jsonl \ + --verify + +python scripts/convert_agentic_snapshot.py \ + /path/to/agentic_workflow_dataset.jsonl \ + examples/09_MultiTurn/datasets/agentic_workflow_flat.jsonl \ + --verify +``` + +The `datasets/` directory under `examples/09_MultiTurn/` is a placeholder; run the conversion +commands above to populate it before benchmarking. 
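For orientation, the core of the snapshot→flat-row expansion described above can be sketched in a few lines of Python. This is a simplified model only, not the bundled script: `scripts/convert_agentic_snapshot.py` additionally collapses consecutive user messages, merges parallel tool results into `tool_results`, and carries `tool_calls`/`tools` metadata, all omitted here.

```python
import json


def flatten_final_snapshots(snapshot_lines):
    """Sketch: keep only the highest conversation_idx snapshot per conversation,
    then emit one flat row per non-system message with 1-indexed turn numbers
    and the system prompt attached to the first user row."""
    final = {}
    for line in snapshot_lines:
        rec = json.loads(line)
        cid = rec["conversation_id"]
        # Keep only the highest-indexed snapshot per conversation.
        if cid not in final or rec["conversation_idx"] > final[cid]["conversation_idx"]:
            final[cid] = rec

    rows = []
    for cid, rec in sorted(final.items()):
        system = next(
            (m.get("content") for m in rec["messages"] if m["role"] == "system"), None
        )
        first_user_seen = False
        turn = 0
        for msg in rec["messages"]:
            if msg["role"] == "system":
                continue  # the system prompt rides on the first user row instead
            turn += 1
            row = {"conversation_id": cid, "turn": turn, "role": msg["role"]}
            if msg.get("content") is not None:
                row["content"] = msg["content"]
            if msg["role"] == "user" and not first_user_seen:
                if system is not None:
                    row["system"] = system
                first_user_seen = True
            rows.append(row)
    return rows
```

Feeding two snapshots of the same conversation (idx 0 and idx 1) through this sketch yields flat rows for the idx-1 snapshot only, matching the "only the highest-indexed snapshot per conversation is used" rule.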
+ +The `--verify` flag cross-checks every client turn's message history against the source snapshot +and exits with code 1 if any mismatch is found. The script also: + +- Collapses consecutive `user` messages into one (keeps turn sequencing clean) +- Merges consecutive `tool` messages for the same assistant dispatch into a single row with a + `tool_results` list (so all parallel results are sent together in one API call) + +### Flat-row format after conversion + +The extra fields supported beyond plain user/assistant: + +| Row role | Extra fields | +| -------------------------------- | ------------------------------------------------------------------ | +| `assistant` with tool calls | `tool_calls: [{id, type, function: {name, arguments}}]` | +| `tool` single result | `tool_call_id: `, `content: ` | +| `tool` parallel results (merged) | `tool_results: [{tool_call_id, content}, ...]` | +| `user` or `tool` turns | `tools: [...]` (OpenAI tool definitions forwarded to the endpoint) | + +Example rows from a converted agentic dataset: + +```jsonl +{"conversation_id": "sim_001", "turn": 1, "role": "user", "content": "Fix the bug in foo.py", "system": "You are a coding agent.", "tools": [...]} +{"conversation_id": "sim_001", "turn": 2, "role": "assistant", "tool_calls": [{"id": "functions.bash:0", "type": "function", "function": {"name": "bash", "arguments": "{\"cmd\": \"cat foo.py\"}"}}]} +{"conversation_id": "sim_001", "turn": 3, "role": "tool", "tool_call_id": "functions.bash:0", "content": "def foo():\n return 1/0", "tools": [...]} +{"conversation_id": "sim_001", "turn": 4, "role": "assistant", "content": "The bug is a ZeroDivisionError. 
Here is the fix: ..."} +``` + +### Running agentic benchmarks + +After converting the datasets, update the `path` field in the config files and run: + +```bash +inference-endpoint benchmark from-config \ + --config examples/09_MultiTurn/agentic_coding_benchmark.yaml + +inference-endpoint benchmark from-config \ + --config examples/09_MultiTurn/agentic_workflow_benchmark.yaml +``` + +--- + +## Configuration + +### Basic Configuration + +```yaml +datasets: + - name: customer_support + type: performance + path: examples/09_MultiTurn/customer_support_conversations.jsonl + multi_turn: + mode: independent + turn_timeout_s: 300.0 + +settings: + load_pattern: + type: multi_turn + target_concurrency: 32 # ← Required for multi_turn load pattern +``` + +### Concurrency Control + +The `target_concurrency` field is **required** for the `multi_turn` load pattern and controls how many conversations are active simultaneously (each active conversation has at most one in-flight request at a time): + +```yaml +settings: + load_pattern: + type: multi_turn + target_concurrency: 32 # ← Limit to 32 concurrent conversations +``` + +**Behavior**: + +- With `target_concurrency`: Caps the number of simultaneously active conversations (and therefore in-flight requests) +- Combines with turn sequencing: Turn N+1 still waits for turn N, AND a new conversation waits for a free slot + +**Use cases**: + +- **Prevent endpoint overload**: Control request rate to busy endpoints +- **Large-scale testing**: Benchmark 1000+ conversations without overwhelming system +- **Resource management**: Stay within port limits, memory constraints + +**Example**: 100 conversations with `target_concurrency: 32` + +``` +t=0: Conversations 1-32 issue their turn 1 (concurrency limit reached) +t=0.5: A turn completes → that conversation issues its next turn +t=1.0: A conversation finishes its last turn → its slot goes to conversation 33 +... Maintains ~32 active conversations (~32 in-flight requests) +``` + +### Conversation Modes + +The default mode is `independent`. 
+ +#### Independent Mode (Default) + +Issues turns for each conversation independently — no cross-conversation turn barrier. + +```yaml +multi_turn: + mode: independent + +settings: + load_pattern: + type: multi_turn + target_concurrency: 32 +``` + +**Use case**: Realistic production load where short conversations finish while long ones are +still running. Turn 1 of one conversation and turn 100 of another can be in-flight simultaneously. + +For single-conversation debugging, use `mode: independent` with `target_concurrency: 1`. + +### Turn Timeout + +Configure maximum wait time for previous turn completion: + +```yaml +multi_turn: + turn_timeout_s: 300.0 # 5 minutes +``` + +If a turn times out waiting for the previous turn, it will be skipped and logged as a warning. + +## Running Multi-Turn Benchmarks + +### Using Configuration File + +```bash +inference-endpoint benchmark from-config \ + --config examples/09_MultiTurn/multi_turn_benchmark.yaml +``` + +### Viewing Results + +Multi-turn benchmarks currently report per-turn metrics; per-conversation aggregation is planned (see Future Enhancements): + +- **Per-turn metrics**: Latency, TTFT, TPOT for each individual turn +- **Per-conversation metrics** (planned): Total conversation latency, conversations per second + +Results are stored in the configured `report_dir` with conversation metadata included in the events log (`events.jsonl`). + +## Example Datasets + +### customer_support_conversations.jsonl + +Simple customer support conversations demonstrating basic multi-turn interactions: + +- 3 conversations +- 2-4 turns per conversation +- Customer support agent system prompt + +## Architecture Notes + +### Key Components + +- **ConversationManager**: Tracks conversation state and message history +- **MultiTurnStrategy**: Enforces turn sequencing within conversations +- **MultiTurnDataset**: Validates and structures multi-turn data + +### Turn Sequencing + +The system ensures that: + +1. Turn N+1 cannot be issued until turn N completes +2. 
Message history is included in subsequent requests +3. Concurrent conversations are supported (in independent mode) + +### Memory Considerations + +Each conversation maintains message history in memory. For large-scale benchmarks with long conversations: + +- Memory usage: ~1KB per turn (approximate) +- 1000 conversations × 10 turns = ~10MB + +## Troubleshooting + +### "Conversation has invalid role sequence" + +**Cause**: Conversation doesn't follow a valid role sequence. + +**Fix**: For plain chat, ensure the dataset alternates between user and assistant: + +``` +user -> assistant -> user -> assistant -> ... +``` + +For agentic datasets, use the conversion script (`scripts/convert_agentic_snapshot.py`) to +produce a properly sequenced flat-row file. The valid agentic sequence is: + +``` +user -> assistant (tool_calls) -> tool -> [tool | assistant (tool_calls)]* -> assistant -> user -> ... +``` + +### "Turn timed out waiting for prev turn" + +**Cause**: Previous turn took longer than `turn_timeout_s` to complete. + +**Fixes**: + +- Increase `turn_timeout_s` in configuration +- Check endpoint performance +- Verify endpoint is responding + +### Single-turn benchmarks unaffected + +Multi-turn logic is only activated when a `multi_turn:` block is present in the dataset configuration. Existing single-turn benchmarks continue to work unchanged with zero performance overhead. 
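The turn-sequencing and concurrency rules described above can be sketched with asyncio. This is an illustrative model only, not the benchmarker's implementation; `send_turn` is a hypothetical stand-in for the real HTTP client:

```python
import asyncio


async def run_conversation(user_turns, send_turn):
    # Turns within one conversation are strictly sequential:
    # turn N+1 is not issued until turn N's response has arrived.
    history = []
    for user_msg in user_turns:
        history.append({"role": "user", "content": user_msg})
        reply = await send_turn(list(history))  # full history in every request
        history.append({"role": "assistant", "content": reply})
    return history


async def run_benchmark(conversations, send_turn, target_concurrency):
    # Independent mode: up to target_concurrency conversations active at once,
    # each with exactly one in-flight request; a slot frees only when a
    # conversation finishes all of its turns.
    slots = asyncio.Semaphore(target_concurrency)

    async def guarded(user_turns):
        async with slots:
            return await run_conversation(user_turns, send_turn)

    return await asyncio.gather(*(guarded(c) for c in conversations))
```

With a fake `send_turn` you can observe both guarantees: per-conversation turn order is preserved, and the number of concurrent requests never exceeds `target_concurrency`.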
 + +## Future Enhancements + +Planned features: + +- [ ] Poisson conversation arrival mode implementation +- [ ] Per-conversation metrics in reporting +- [ ] Conversation-level latency percentiles +- [ ] Dynamic conversation branching diff --git a/examples/09_MultiTurn/agentic_coding_benchmark.yaml b/examples/09_MultiTurn/agentic_coding_benchmark.yaml new file mode 100644 index 00000000..f3abc3cf --- /dev/null +++ b/examples/09_MultiTurn/agentic_coding_benchmark.yaml @@ -0,0 +1,33 @@ +name: "agentic-coding-benchmark" +version: "1.0" +type: "online" + +model_params: + name: "your-model-name" # Replace with your actual model name + max_new_tokens: 1024 + +datasets: + - name: agentic_coding + type: performance + # Run: python scripts/convert_agentic_snapshot.py /path/to/agentic_coding_dataset.jsonl examples/09_MultiTurn/datasets/agentic_coding_flat.jsonl --verify + # The datasets/ directory is a placeholder; populate it with the conversion script above. + path: examples/09_MultiTurn/datasets/agentic_coding_flat.jsonl + multi_turn: + mode: independent + turn_timeout_s: 600.0 + +settings: + runtime: + min_duration_ms: 0 + max_duration_ms: 3600000 + + load_pattern: + type: multi_turn + target_concurrency: 4096 + +endpoint_config: + endpoints: + - "http://localhost:8868" + api_type: openai + +report_dir: logs/agentic_coding diff --git a/examples/09_MultiTurn/agentic_workflow_benchmark.yaml b/examples/09_MultiTurn/agentic_workflow_benchmark.yaml new file mode 100644 index 00000000..239e9374 --- /dev/null +++ b/examples/09_MultiTurn/agentic_workflow_benchmark.yaml @@ -0,0 +1,33 @@ +name: "agentic-workflow-benchmark" +version: "1.0" +type: "online" + +model_params: + name: "your-model-name" # Replace with your actual model name + max_new_tokens: 512 + +datasets: + - name: agentic_workflow + type: performance + # Run: python scripts/convert_agentic_snapshot.py /path/to/agentic_workflow_dataset.jsonl examples/09_MultiTurn/datasets/agentic_workflow_flat.jsonl --verify + # The datasets/ directory is a placeholder; populate it with the conversion script 
above. + path: examples/09_MultiTurn/datasets/agentic_workflow_flat.jsonl + multi_turn: + mode: independent + turn_timeout_s: 600.0 + +settings: + runtime: + min_duration_ms: 0 + max_duration_ms: 3600000 + + load_pattern: + type: multi_turn + target_concurrency: 96 + +endpoint_config: + endpoints: + - "http://localhost:8868" + api_type: openai + +report_dir: logs/agentic_workflow diff --git a/examples/09_MultiTurn/customer_support_conversations.jsonl b/examples/09_MultiTurn/customer_support_conversations.jsonl new file mode 100644 index 00000000..ac19e907 --- /dev/null +++ b/examples/09_MultiTurn/customer_support_conversations.jsonl @@ -0,0 +1,10 @@ +{"conversation_id": "conv_001", "turn": 1, "role": "user", "content": "I need help resetting my password", "system": "You are a helpful customer support agent"} +{"conversation_id": "conv_001", "turn": 2, "role": "assistant", "content": "I'd be happy to help you reset your password. Can you provide your email address?"} +{"conversation_id": "conv_001", "turn": 3, "role": "user", "content": "It's user@example.com"} +{"conversation_id": "conv_001", "turn": 4, "role": "assistant", "content": "Thank you. I've sent a password reset link to user@example.com. Please check your inbox and follow the instructions."} +{"conversation_id": "conv_002", "turn": 1, "role": "user", "content": "What are your business hours?", "system": "You are a helpful customer support agent"} +{"conversation_id": "conv_002", "turn": 2, "role": "assistant", "content": "We're open Monday-Friday, 9 AM to 5 PM EST. How can I assist you today?"} +{"conversation_id": "conv_002", "turn": 3, "role": "user", "content": "Do you offer weekend support?"} +{"conversation_id": "conv_002", "turn": 4, "role": "assistant", "content": "For urgent issues, we offer limited support on weekends from 10 AM to 2 PM EST. 
For non-urgent matters, please contact us during our regular business hours."} +{"conversation_id": "conv_003", "turn": 1, "role": "user", "content": "Can I cancel my subscription?", "system": "You are a helpful customer support agent"} +{"conversation_id": "conv_003", "turn": 2, "role": "assistant", "content": "Yes, you can cancel your subscription at any time. Would you like me to guide you through the cancellation process?"} diff --git a/examples/09_MultiTurn/datasets/.gitkeep b/examples/09_MultiTurn/datasets/.gitkeep new file mode 100644 index 00000000..e69de29b diff --git a/examples/09_MultiTurn/multi_turn_benchmark.yaml b/examples/09_MultiTurn/multi_turn_benchmark.yaml new file mode 100644 index 00000000..36066aa3 --- /dev/null +++ b/examples/09_MultiTurn/multi_turn_benchmark.yaml @@ -0,0 +1,36 @@ +name: "multi-turn-customer-support" +version: "1.0" +type: "online" + +model_params: + name: "meta-llama/Llama-3.2-1B-Instruct" + temperature: 0.7 + max_new_tokens: 256 + +datasets: + - name: customer_support_conversations + type: performance + path: examples/09_MultiTurn/customer_support_conversations.jsonl + samples: 10 + multi_turn: + mode: independent + turn_timeout_s: 300.0 + +settings: + runtime: + min_duration_ms: 60000 + max_duration_ms: 300000 + + load_pattern: + type: multi_turn + target_concurrency: 32 + + client: + warmup_connections: 0 + +endpoint_config: + endpoints: + - "http://localhost:8868" + api_type: openai + +report_dir: logs/multi_turn_test diff --git a/examples/09_MultiTurn/multi_turn_with_concurrency.yaml b/examples/09_MultiTurn/multi_turn_with_concurrency.yaml new file mode 100644 index 00000000..e1d5f37c --- /dev/null +++ b/examples/09_MultiTurn/multi_turn_with_concurrency.yaml @@ -0,0 +1,36 @@ +name: "multi-turn-with-concurrency-control" +version: "1.0" +type: "online" + +model_params: + name: "meta-llama/Llama-3.2-1B-Instruct" + temperature: 0.7 + max_new_tokens: 256 + +datasets: + - name: customer_support_conversations + type: 
performance + path: examples/09_MultiTurn/customer_support_conversations.jsonl + samples: 10 + multi_turn: + mode: independent # All conv turn-1 start together + turn_timeout_s: 300.0 + +settings: + runtime: + min_duration_ms: 60000 + max_duration_ms: 300000 + + load_pattern: + type: multi_turn + target_concurrency: 32 # ← NEW: Limit to 32 concurrent requests + + client: + warmup_connections: 0 + +endpoint_config: + endpoints: + - "http://localhost:8868" + api_type: openai + +report_dir: logs/multi_turn_with_concurrency diff --git a/scripts/convert_agentic_snapshot.py b/scripts/convert_agentic_snapshot.py new file mode 100644 index 00000000..fe217b9b --- /dev/null +++ b/scripts/convert_agentic_snapshot.py @@ -0,0 +1,356 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Convert agentic snapshot datasets to the flat-row JSONL format expected by MultiTurnDataset. + +Each snapshot record contains the full conversation history up to a checkpoint: + {"conversation_id": "sim_000001", "conversation_idx": 0, + "messages": [{"role": "system", ...}, ...], "tools": [...], "metadata": {}} + +For each conversation only the final snapshot (highest conversation_idx) is used. +Its messages array is expanded into individual flat rows, one per message. 
+ +Usage: + python scripts/convert_agentic_snapshot.py INPUT.jsonl OUTPUT.jsonl + python scripts/convert_agentic_snapshot.py INPUT.jsonl OUTPUT.jsonl --verify +""" + +import argparse +import json +import sys +from pathlib import Path + +# --------------------------------------------------------------------------- +# Helpers shared between convert() and verify() +# --------------------------------------------------------------------------- + + +def _load_final_snapshots(input_path: Path) -> dict[str, dict]: + """Return {conv_id: record} keeping only the highest conversation_idx per conv.""" + final: dict[str, dict] = {} + with input_path.open() as fh: + for line in fh: + line = line.strip() + if not line: + continue + record = json.loads(line) + conv_id = record["conversation_id"] + if ( + conv_id not in final + or record["conversation_idx"] > final[conv_id]["conversation_idx"] + ): + final[conv_id] = record + return final + + +def _apply_collapses(non_system: list[dict]) -> list[tuple[dict, int]]: + """Apply user-collapse and tool-merge passes, tracking the last source index each + output row covers. + + Returns list of (output_msg, last_source_idx) pairs where last_source_idx is the + 0-based index within non_system of the final source message folded into this row. + """ + # Pass 1: collapse consecutive user messages + collapsed: list[tuple[dict, int]] = [] # (msg, last_source_idx) + for src_idx, msg in enumerate(non_system): + if collapsed and collapsed[-1][0]["role"] == "user" and msg["role"] == "user": + prev_msg, _ = collapsed[-1] + prev_text = prev_msg.get("content") or "" + cur_text = msg.get("content") or "" + collapsed[-1] = ( + {**prev_msg, "content": f"{prev_text}\n\n{cur_text}".strip()}, + src_idx, + ) + else: + collapsed.append((msg, src_idx)) + + # Pass 2: merge consecutive tool messages + # Input messages are raw snapshot wire-format (tool_call_id + content on each msg). 
+ # On merge, upgrade the first message to a tool_results list so the output always + # uses the tool_results array form regardless of how many results there are. + merged: list[tuple[dict, int]] = [] + for msg, last_src in collapsed: + if merged and merged[-1][0]["role"] == "tool" and msg["role"] == "tool": + prev_msg, _ = merged[-1] + tool_results = prev_msg.get("tool_results") + if tool_results is None: + tool_results = [ + { + "tool_call_id": prev_msg.get("tool_call_id"), + "content": prev_msg.get("content"), + } + ] + prev_msg = {"role": "tool", "tool_results": tool_results} + tool_results.append( + { + "tool_call_id": msg.get("tool_call_id"), + "content": msg.get("content"), + } + ) + merged[-1] = (prev_msg, last_src) + else: + merged.append((msg, last_src)) + + return merged + + +def _normalize_msg(msg: dict) -> dict: + """Drop None values for comparison.""" + return {k: v for k, v in msg.items() if v is not None} + + +def _expand_row_to_wire_msgs(row: dict) -> list[dict]: + """Expand a single flat row into one or more OpenAI wire-format messages. + + Handles two tool row forms: + - Output flat rows: tool_results array (always used after conversion) + - Raw snapshot messages passed through verify(): tool_call_id + content directly + """ + if isinstance(row.get("tool_results"), list): + return [ + { + "role": "tool", + "tool_call_id": r.get("tool_call_id"), + "content": r.get("content"), + } + for r in row["tool_results"] + ] + msg: dict = {"role": row["role"], "content": row.get("content")} + if row.get("tool_calls"): + msg["tool_calls"] = row["tool_calls"] + if row.get("tool_call_id"): + msg["tool_call_id"] = row["tool_call_id"] + return [msg] + + +def verify(input_path: Path, output_path: Path) -> bool: + """Cross-check every client-turn's pre_built_messages against the source snapshot. 
+ + For each output client turn, reconstruct the pre_built_messages that + MultiTurnDataset would build from the flat rows and compare it against the + ground-truth messages built directly from the source snapshot up to the same + point (accounting for user-collapse and tool-merge). + + Returns: + True if all checks pass, False if any mismatch found. + """ + final = _load_final_snapshots(input_path) + + # Load converted rows grouped by conversation_id + conv_rows: dict[str, list[dict]] = {} + with output_path.open() as fh: + for line in fh: + line = line.strip() + if not line: + continue + row = json.loads(line) + cid = row["conversation_id"] + conv_rows.setdefault(cid, []).append(row) + for cid in conv_rows: + conv_rows[cid].sort(key=lambda r: r["turn"]) + + errors: list[str] = [] + total_checked = 0 + + for conv_id in sorted(final): + record = final[conv_id] + system_content: str | None = None + non_system: list[dict] = [] + for msg in record["messages"]: + if msg["role"] == "system": + system_content = msg.get("content") + else: + non_system.append(msg) + + # Re-apply the same collapses the converter applies, tracking source coverage + processed = _apply_collapses(non_system) # [(output_msg, last_source_idx), ...] + flat_rows = conv_rows.get(conv_id, []) + + if len(processed) != len(flat_rows): + errors.append( + f"{conv_id}: expected {len(processed)} flat rows after collapses, " + f"got {len(flat_rows)} in output" + ) + continue + + client_turn_pairs = [ + (out_pos, flat_row) + for out_pos, (flat_row, _) in enumerate( + zip(flat_rows, processed, strict=True) + ) + if flat_row["role"] in ("user", "tool") + ] + + for ct_idx, (out_pos, flat_row) in enumerate(client_turn_pairs): + # Ground truth: apply the same collapses the converter applies, then + # build the message list from the processed (collapsed/merged) rows up to + # and including this client turn. 
This correctly reflects what the + # converter produces — consecutive user/tool merges mean history is + # shorter than the raw source but content-equivalent. + expected: list[dict] = [] + if system_content: + expected.append({"role": "system", "content": system_content}) + for proc_msg, _ in processed[: out_pos + 1]: + expected.extend(_expand_row_to_wire_msgs(proc_msg)) + + # Reconstructed output: system + expand all flat rows up to this turn + got: list[dict] = [] + if system_content: + got.append({"role": "system", "content": system_content}) + for row in flat_rows[: out_pos + 1]: + got.extend(_expand_row_to_wire_msgs(row)) + + exp_norm = [_normalize_msg(m) for m in expected] + got_norm = [_normalize_msg(m) for m in got] + + if exp_norm != got_norm: + errors.append( + f"{conv_id} client-turn {ct_idx + 1} (flat turn {flat_row['turn']}):\n" + f" expected {len(exp_norm)} msgs, got {len(got_norm)}\n" + f" EXPECTED: {json.dumps(exp_norm, ensure_ascii=False)[:400]}\n" + f" GOT: {json.dumps(got_norm, ensure_ascii=False)[:400]}" + ) + total_checked += 1 + + if errors: + print( + f"FAIL: {len(errors)} mismatches out of {total_checked} client turns checked.", + file=sys.stderr, + ) + for err in errors[:20]: + print(err, file=sys.stderr) + if len(errors) > 20: + print(f" ... and {len(errors) - 20} more", file=sys.stderr) + return False + + print( + f"OK: all {total_checked} client turns verified against source.", + file=sys.stderr, + ) + return True + + +def convert(input_path: Path, output_path: Path) -> None: + # Group records by conversation_id, keep only the final snapshot per conversation. 
+ final: dict[str, dict] = {} + with input_path.open() as fh: + for line in fh: + line = line.strip() + if not line: + continue + record = json.loads(line) + conv_id = record["conversation_id"] + if ( + conv_id not in final + or record["conversation_idx"] > final[conv_id]["conversation_idx"] + ): + final[conv_id] = record + + print(f"Found {len(final)} conversations in {input_path.name}", file=sys.stderr) + + rows_written = 0 + with output_path.open("w") as out: + for conv_id, record in sorted(final.items()): + messages = record["messages"] + tools = record.get("tools") or [] + + # Extract system message (always first if present). + system_content: str | None = None + non_system: list[dict] = [] + for msg in messages: + if msg["role"] == "system": + system_content = msg.get("content") + else: + non_system.append(msg) + + # Apply the same user-collapse and tool-merge passes used by verify(). + # _apply_collapses returns [(msg, last_source_idx), ...]; strip the indices. + non_system = [msg for msg, _ in _apply_collapses(non_system)] + + first_user_seen = False + for position, msg in enumerate(non_system): + role = msg["role"] + turn = position + 1 # 1-indexed + + row: dict = {"conversation_id": conv_id, "turn": turn, "role": role} + + # System prompt on the first user row only. + if role == "user" and not first_user_seen: + if system_content is not None: + row["system"] = system_content + first_user_seen = True + + # tool_calls for assistant messages that dispatch tools. + if msg.get("tool_calls"): + row["tool_calls"] = msg["tool_calls"] + + if role == "tool": + # All tool rows use tool_results array (single results have one entry). 
+ if msg.get("tool_results"): + row["tool_results"] = msg["tool_results"] + else: + row["tool_results"] = [ + { + "tool_call_id": msg.get("tool_call_id"), + "content": msg.get("content"), + } + ] + else: + # content field (may be None for tool-dispatching assistant messages) + row["content"] = msg.get("content") + + # Attach tool definitions to client-turn rows only (user + tool). + # This avoids duplicating the large tools array on every assistant row + # while still making them available via load_sample(). + if role in ("user", "tool") and tools: + row["tools"] = tools + + out.write(json.dumps(row, ensure_ascii=False) + "\n") + rows_written += 1 + + print(f"Wrote {rows_written} rows to {output_path}", file=sys.stderr) + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Convert agentic snapshot JSONL to MultiTurnDataset flat-row JSONL." + ) + parser.add_argument("input", type=Path, help="Input snapshot JSONL file") + parser.add_argument("output", type=Path, help="Output flat-row JSONL file") + parser.add_argument( + "--verify", + action="store_true", + help=( + "After converting, cross-check every client-turn's pre_built_messages " + "against the source snapshot. Exits with code 1 if any mismatch found." 
+ ), + ) + args = parser.parse_args() + + if not args.input.exists(): + print(f"Error: input file not found: {args.input}", file=sys.stderr) + sys.exit(1) + + convert(args.input, args.output) + + if args.verify: + ok = verify(args.input, args.output) + if not ok: + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/scripts/multi_turn_dataset_schema.json b/scripts/multi_turn_dataset_schema.json new file mode 100644 index 00000000..b1b7ca13 --- /dev/null +++ b/scripts/multi_turn_dataset_schema.json @@ -0,0 +1,557 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema#", + "title": "Multi-Turn Conversation Dataset Schema", + "description": "JSON schema describing the structure and requirements for multi-turn conversation datasets in the MLPerf Inference Endpoint Benchmarking System", + "version": "1.0.0", + + "definitions": { + "basicMessageTypes": { + "title": "Basic Message Types", + "description": "Plain conversational messages without tool calls", + "oneOf": [ + { + "title": "User Message", + "type": "object", + "properties": { + "conversation_id": { + "type": "string", + "description": "Unique identifier for the conversation" + }, + "turn": { + "type": "integer", + "minimum": 1, + "description": "Turn number within conversation (1-indexed)" + }, + "role": { + "const": "user", + "description": "Message role - user initiates turns" + }, + "content": { + "type": "string", + "description": "Message content from the user" + }, + "system": { + "type": "string", + "description": "System prompt (typically only on first user turn)" + } + }, + "required": ["conversation_id", "turn", "role", "content"], + "additionalProperties": true + }, + { + "title": "Assistant Message", + "type": "object", + "properties": { + "conversation_id": { + "type": "string", + "description": "Unique identifier for the conversation" + }, + "turn": { + "type": "integer", + "minimum": 1, + "description": "Turn number within conversation" + }, + "role": { + "const": "assistant", + 
"description": "Message role - assistant responds to user" + }, + "content": { + "type": "string", + "description": "Message content from the assistant" + } + }, + "required": ["conversation_id", "turn", "role", "content"], + "not": { "required": ["tool_calls"] }, + "additionalProperties": true + } + ] + }, + + "toolCallMessage": { + "title": "Assistant Message with Tool Calls", + "description": "Assistant message that dispatches one or more tool calls. Role must be 'assistant' with a non-empty tool_calls array (OpenAI wire format).", + "type": "object", + "properties": { + "conversation_id": { + "type": "string", + "description": "Unique identifier for the conversation" + }, + "turn": { + "type": "integer", + "minimum": 1, + "description": "Turn number within conversation" + }, + "role": { + "const": "assistant", + "description": "Role for a tool-dispatching assistant message." + }, + "content": { + "type": ["string", "null"], + "description": "Optional textual prefix alongside tool dispatch (e.g., 'I will investigate this with bash'). Typically null/absent." 
+ }, + "tool_calls": { + "type": "array", + "minItems": 1, + "description": "List of tool calls dispatched by the assistant", + "items": { + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "Unique identifier for this tool call (e.g., 'functions.bash:0')" + }, + "type": { + "type": "string", + "const": "function", + "description": "Tool type (currently only 'function' supported)" + }, + "function": { + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "Name of the tool/function to invoke" + }, + "arguments": { + "type": "string", + "description": "JSON string containing function arguments" + } + }, + "required": ["name", "arguments"] + } + }, + "required": ["id", "type", "function"] + } + } + }, + "required": ["conversation_id", "turn", "role", "tool_calls"], + "additionalProperties": true + }, + + "toolMessage": { + "title": "Tool Result Message", + "description": "Tool execution results as a list. Single results have one entry; parallel results have multiple entries.", + "type": "object", + "properties": { + "conversation_id": { + "type": "string", + "description": "Unique identifier for the conversation" + }, + "turn": { + "type": "integer", + "minimum": 1, + "description": "Turn number within conversation" + }, + "role": { + "const": "tool", + "description": "Tool result role" + }, + "tool_results": { + "type": "array", + "minItems": 1, + "description": "List of tool execution results. 
Single tool calls have one entry; parallel tool calls have multiple entries.", + "items": { + "type": "object", + "properties": { + "tool_call_id": { + "type": "string", + "description": "ID of the tool call this result corresponds to" + }, + "content": { + "type": "string", + "description": "Output/result content from the tool execution" + } + }, + "required": ["tool_call_id", "content"] + } + } + }, + "required": ["conversation_id", "turn", "role", "tool_results"], + "additionalProperties": true + }, + + "generationParameters": { + "title": "Generation Parameters", + "description": "Optional parameters controlling the model's behavior for generation", + "type": "object", + "properties": { + "model": { + "type": "string", + "description": "Model name override for this turn" + }, + "max_new_tokens": { + "type": "integer", + "minimum": 1, + "description": "Maximum number of tokens to generate" + }, + "max_completion_tokens": { + "type": "integer", + "minimum": 1, + "description": "OpenAI API compatible max tokens parameter" + }, + "stream": { + "type": "boolean", + "description": "Whether to use streaming for this turn" + }, + "temperature": { + "type": "number", + "minimum": 0, + "maximum": 2, + "description": "Sampling temperature (0 = deterministic, higher = more random)" + }, + "top_p": { + "type": "number", + "minimum": 0, + "maximum": 1, + "description": "Nucleus sampling parameter" + }, + "top_k": { + "type": "integer", + "minimum": 1, + "description": "Top-k sampling parameter" + }, + "seed": { + "type": "integer", + "description": "Random seed for reproducibility" + }, + "repetition_penalty": { + "type": "number", + "minimum": 0, + "description": "Penalty for repeating tokens" + }, + "frequency_penalty": { + "type": "number", + "minimum": -2, + "maximum": 2, + "description": "Frequency penalty for tokens" + }, + "presence_penalty": { + "type": "number", + "minimum": -2, + "maximum": 2, + "description": "Presence penalty for tokens" + }, + "stop": { + 
"oneOf": [ + { "type": "string" }, + { "type": "array", "items": { "type": "string" } } + ], + "description": "Stop sequences for generation" + }, + "n": { + "type": "integer", + "minimum": 1, + "description": "Number of completions to generate" + }, + "logit_bias": { + "type": "object", + "description": "Token probability adjustments (token_id -> bias)" + }, + "name": { + "type": "string", + "description": "Entity name for role tracking (e.g., 'Bob')" + }, + "user": { + "type": "string", + "description": "End-user identifier for monitoring/abuse detection" + }, + "chat_template": { + "type": "string", + "description": "Custom chat formatting template" + }, + "tools": { + "type": "array", + "description": "OpenAI tool definitions for tool-calling models", + "items": { "type": "object" } + } + } + } + }, + + "type": "object", + "oneOf": [ + { + "title": "Plain Conversation Row", + "description": "A single row representing a plain user or assistant message", + "allOf": [ + { "$ref": "#/definitions/basicMessageTypes" }, + { "$ref": "#/definitions/generationParameters" } + ] + }, + { + "title": "Tool Call Row", + "description": "A single row representing an assistant dispatch of tool calls", + "allOf": [ + { "$ref": "#/definitions/toolCallMessage" }, + { "$ref": "#/definitions/generationParameters" } + ] + }, + { + "title": "Tool Result Row", + "description": "A single row representing one or more tool results", + "allOf": [ + { "$ref": "#/definitions/toolMessage" }, + { "$ref": "#/definitions/generationParameters" } + ] + } + ], + + "examples": [ + { + "title": "Basic user message", + "data": { + "conversation_id": "conv_001", + "turn": 1, + "role": "user", + "content": "I need help resetting my password", + "system": "You are a helpful customer support agent" + } + }, + { + "title": "Assistant response", + "data": { + "conversation_id": "conv_001", + "turn": 2, + "role": "assistant", + "content": "I'd be happy to help. Can you provide your email address?" 
+ } + }, + { + "title": "Assistant with tool calls (converter/OpenAI wire format)", + "data": { + "conversation_id": "sim_001", + "turn": 2, + "role": "assistant", + "tool_calls": [ + { + "id": "functions.bash:0", + "type": "function", + "function": { + "name": "bash", + "arguments": "{\"cmd\": \"cat foo.py\"}" + } + } + ] + } + }, + { + "title": "Tool result from execution", + "data": { + "conversation_id": "sim_001", + "turn": 3, + "role": "tool", + "tool_results": [ + { + "tool_call_id": "functions.bash:0", + "content": "def foo():\n return 1/0" + } + ] + } + }, + { + "title": "Merged parallel tool results", + "data": { + "conversation_id": "sim_002", + "turn": 3, + "role": "tool", + "tool_results": [ + { + "tool_call_id": "functions.bash:0", + "content": "file1.txt" + }, + { + "tool_call_id": "functions.bash:1", + "content": "file2.txt" + } + ] + } + }, + { + "title": "User turn with generation parameters", + "data": { + "conversation_id": "conv_002", + "turn": 5, + "role": "user", + "content": "What's the best way to optimize this code?", + "temperature": 0.7, + "max_new_tokens": 256, + "top_p": 0.9 + } + } + ], + + "documentation": { + "overview": "Multi-turn conversation datasets enable benchmarking of realistic conversational AI workloads where each turn depends on previous responses. The system maintains conversation history and enforces turn sequencing.", + + "requiredFields": [ + { + "field": "conversation_id", + "type": "string", + "description": "Unique identifier for each conversation. All rows belonging to the same conversation must share the same conversation_id." + }, + { + "field": "turn", + "type": "integer", + "description": "Turn number within conversation (1-indexed). Must be consecutive starting at 1 (i.e., 1, 2, 3, …, N with no gaps or duplicates)." + }, + { + "field": "role", + "type": "string", + "enum": ["user", "assistant", "tool"], + "description": "Speaker role. 'user' or 'tool' are client-initiated turns. 
'assistant' is the server response — either a terminal reply or a tool dispatch (with tool_calls field)." + }, + { + "field": "content", + "type": "string", + "description": "Message content. Required for 'user' role and plain 'assistant' rows. For tool-dispatching assistant rows, content may be omitted (null/absent). For 'tool' rows using tool_results (merged parallel results), top-level content is absent — results are in the tool_results array instead." + } + ], + + "optionalFields": [ + { + "field": "system", + "type": "string", + "description": "System prompt (typically only on first user turn). Applied to all messages in the conversation." + }, + { + "field": "tool_calls", + "type": "array", + "description": "Tool calls dispatched by assistant (for tool-dispatching 'assistant' rows). Each element has {id, type, function: {name, arguments}}." + }, + { + "field": "tool_results", + "type": "array", + "description": "Tool execution results (required for all 'tool' role rows). Each element has {tool_call_id, content}. Single results have one entry; parallel results have multiple entries." + }, + { + "field": "model", + "type": "string", + "description": "Model name override for this turn." + }, + { + "field": "max_new_tokens", + "type": "integer", + "description": "Maximum tokens to generate for this turn." + }, + { + "field": "temperature", + "type": "number", + "description": "Sampling temperature (0 to 2)." + }, + { + "field": "top_p", + "type": "number", + "description": "Nucleus sampling parameter (0 to 1)." + }, + { + "field": "tools", + "type": "array", + "description": "OpenAI tool definitions forwarded to the endpoint for tool-calling models. The converter attaches this only to client-turn rows (user and tool) to avoid duplicating the large array on every assistant row. Hand-authored datasets typically place it on the first user turn." 
+ } + ], + + "validRoleSequences": [ + { + "name": "Plain conversation", + "sequence": "user → assistant → user → assistant → ...", + "description": "Standard alternating conversation without tool use." + }, + { + "name": "Agentic with tools", + "sequence": "user → assistant → tool → [assistant → tool]* → assistant → user", + "description": "Agent dispatches tools (assistant with tool_calls), executes them, and returns results before final response. 'tool → user' is also valid when no terminal assistant response is needed before the next user turn." + } + ], + "stateMachine": { + "description": "Complete valid-next-state table from _validate_conversation_structure()", + "transitions": { + "start": ["user"], + "user": ["assistant"], + "assistant": ["tool", "user"], + "tool": ["assistant", "user"] + } + }, + + "validationRules": [ + { + "rule": "Turn numbers must be consecutive starting at 1", + "violation": "Turn sequence is not exactly 1, 2, 3, …, N (missing, duplicate, or out-of-range turns)" + }, + { + "rule": "Role sequences must follow the state machine", + "violation": "Invalid transition (e.g., 'user' directly followed by 'user', consecutive 'assistant' rows). Note: 'tool → user' IS a valid transition. The state machine also implicitly enforces that the first row must be a user turn." + } + ], + "notValidated": [ + "tool_results[*].tool_call_id pairing: the validator does NOT verify that tool_call_id values inside tool_results items reference a prior assistant tool_calls entry. Correct pairing is the dataset author's or converter's responsibility." 
+ ], + + "dataTypes": [ + { + "name": "Basic Message", + "roles": ["user", "assistant"], + "fields": { + "required": ["conversation_id", "turn", "role", "content"], + "optional": ["system", "...generation parameters"] + } + }, + { + "name": "Tool Call Dispatch", + "roles": ["assistant"], + "fields": { + "required": ["conversation_id", "turn", "role", "tool_calls"], + "optional": ["content", "...generation parameters"] + } + }, + { + "name": "Tool Result", + "roles": ["tool"], + "fields": { + "required": ["conversation_id", "turn", "role", "tool_results"], + "optional": ["...generation parameters"], + "note": "Expands to one OpenAI tool message per result entry at the wire layer" + } + } + ], + + "conversionFromSnapshot": { + "description": "Agentic datasets are often stored as full-conversation snapshots. Use scripts/convert_agentic_snapshot.py to convert.", + "sourceFormat": "Each JSONL line is a complete conversation snapshot with 'messages' array", + "targetFormat": "Each JSONL line is a single message row with conversation metadata", + "process": [ + "Extract conversation_id, conversation_idx, and messages array", + "Use highest-indexed snapshot per conversation_id", + "Collapse consecutive user messages into a single user row (newline-joined content)", + "Emit all tool result rows as tool_results arrays (single results have one entry; consecutive tool results from parallel dispatch are merged into one row with multiple entries)", + "Flatten messages into individual rows, numbering turns sequentially from 1", + "Attach system prompt to first user row only", + "Attach tools array to client-turn rows (user and tool) only — not to assistant rows", + "Tool-dispatching assistant messages are written as role 'assistant' with tool_calls (OpenAI wire format)" + ] + }, + + "performanceNotes": [ + "Pre-built messages: The system pre-computes complete message lists during dataset load() for efficient turn serving", + "Memory efficiency: ~1KB per turn average; 1000 
conversations × 10 turns = ~10MB", + "Hot path: Only client turns (user/tool) are issued; assistant turns remain in backing store for history" + ], + + "commonErrors": [ + { + "error": "Invalid role sequence", + "cause": "Violates state machine (e.g., user→user, user→tool, consecutive assistant rows)", + "fix": "Verify alternation or use conversion script for agentic data" + }, + { + "error": "Turn numbers not consecutive", + "cause": "Turn sequence has gaps, duplicates, or doesn't start at 1", + "fix": "Ensure turns are numbered 1, 2, 3, …, N with no missing or duplicate values" + }, + { + "error": "Turn timeout", + "cause": "Previous turn took too long to complete", + "fix": "Increase turn_timeout_s in configuration or check endpoint performance" + } + ] + } +} diff --git a/scripts/validate_jsonl_schema.py b/scripts/validate_jsonl_schema.py new file mode 100644 index 00000000..1be81dd2 --- /dev/null +++ b/scripts/validate_jsonl_schema.py @@ -0,0 +1,126 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Validate multi-turn JSONL dataset files against scripts/multi_turn_dataset_schema.json. + +Checks each row's structure against the JSON schema (field types, required fields, +tool_results shape, etc.). 
Does NOT check cross-row invariants such as turn +numbering or role sequences — those are enforced by MultiTurnDataset at load time. + +Usage: + python scripts/validate_jsonl_schema.py FILE [FILE ...] + python scripts/validate_jsonl_schema.py examples/09_MultiTurn/datasets/agentic_coding_flat.jsonl +""" + +import argparse +import json +import sys +from pathlib import Path + +try: + import jsonschema +except ImportError: + print( + "Error: jsonschema not installed. Run: pip install jsonschema", file=sys.stderr + ) + sys.exit(1) + + +def validate_file(path: Path, schema: dict, max_errors: int = 50) -> int: + """Validate every row in a JSONL file against the schema. + + Returns the number of validation errors found. + """ + errors: list[str] = [] + validator = jsonschema.Draft7Validator(schema) + + with path.open() as fh: + for lineno, line in enumerate(fh, 1): + line = line.strip() + if not line: + continue + try: + row = json.loads(line) + except json.JSONDecodeError as e: + errors.append(f" line {lineno}: JSON parse error: {e}") + if len(errors) >= max_errors: + break + continue + + conv_id = row.get("conversation_id", "") + turn = row.get("turn", "?") + role = row.get("role", "?") + + row_errors = list(validator.iter_errors(row)) + for err in row_errors: + path_str = " -> ".join(str(p) for p in err.absolute_path) or "(root)" + errors.append( + f" line {lineno} [{conv_id} turn={turn} role={role}] " + f"@ {path_str}: {err.message}" + ) + + if len(errors) >= max_errors: + errors.append(f" ... stopping after {max_errors} errors") + break + + if errors: + print(f"FAIL {path.name}: {len(errors)} error(s)") + for msg in errors: + print(msg) + else: + print(f"OK {path.name}") + + return len(errors) + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Validate multi-turn JSONL files against scripts/multi_turn_dataset_schema.json." 
+ ) + parser.add_argument("files", nargs="+", type=Path, help="JSONL files to validate") + parser.add_argument( + "--schema", + type=Path, + default=Path(__file__).parent / "multi_turn_dataset_schema.json", + help="Path to the JSON schema file (default: scripts/multi_turn_dataset_schema.json)", + ) + parser.add_argument( + "--max-errors", + type=int, + default=50, + help="Stop reporting after this many errors per file (default: 50)", + ) + args = parser.parse_args() + + if not args.schema.exists(): + print(f"Error: schema not found: {args.schema}", file=sys.stderr) + sys.exit(1) + + schema = json.load(args.schema.open()) + + total_errors = 0 + for path in args.files: + if not path.exists(): + print(f"Error: file not found: {path}", file=sys.stderr) + total_errors += 1 + continue + total_errors += validate_file(path, schema, max_errors=args.max_errors) + + sys.exit(1 if total_errors > 0 else 0) + + +if __name__ == "__main__": + main() diff --git a/src/inference_endpoint/commands/benchmark/execute.py b/src/inference_endpoint/commands/benchmark/execute.py index 73c3427f..30411af0 100644 --- a/src/inference_endpoint/commands/benchmark/execute.py +++ b/src/inference_endpoint/commands/benchmark/execute.py @@ -72,6 +72,7 @@ from inference_endpoint.core.types import QueryResult from inference_endpoint.dataset_manager.dataset import Dataset from inference_endpoint.dataset_manager.factory import DataLoaderFactory +from inference_endpoint.dataset_manager.multi_turn_dataset import MultiTurnDataset from inference_endpoint.endpoint_client.cpu_affinity import AffinityPlan, pin_loadgen from inference_endpoint.endpoint_client.http_client import HTTPEndpointClient from inference_endpoint.endpoint_client.http_sample_issuer import HttpClientSampleIssuer @@ -82,6 +83,8 @@ InputValidationError, SetupError, ) +from inference_endpoint.load_generator.conversation_manager import ConversationManager +from inference_endpoint.load_generator.multi_turn_strategy import MultiTurnStrategy from 
inference_endpoint.load_generator.session import ( BenchmarkSession, PhaseConfig, @@ -343,14 +346,21 @@ def setup_benchmark(config: BenchmarkConfig, test_mode: TestMode) -> BenchmarkCo ) -def _build_phases(ctx: BenchmarkContext) -> list[PhaseConfig]: +def _build_phases( + ctx: BenchmarkContext, + perf_strategy: MultiTurnStrategy | None = None, +) -> list[PhaseConfig]: """Build the phase list from BenchmarkContext.""" phases: list[PhaseConfig] = [] # Performance phase phases.append( PhaseConfig( - "performance", ctx.rt_settings, ctx.dataloader, PhaseType.PERFORMANCE + "performance", + ctx.rt_settings, + ctx.dataloader, + PhaseType.PERFORMANCE, + strategy=perf_strategy, ) ) @@ -513,16 +523,43 @@ async def _run_benchmark_async( launcher.kill_all() raise SetupError(f"Failed to connect to endpoint: {e}") from e + # Build multi-turn strategy if the performance dataset is a MultiTurnDataset. + multi_turn_strategy: MultiTurnStrategy | None = None + if isinstance(ctx.dataloader, MultiTurnDataset): + mt_cfg = None + if ctx.config.datasets: + perf_ds_cfg = next( + ( + d + for d in ctx.config.datasets + if d.type == DatasetType.PERFORMANCE + ), + None, + ) + if perf_ds_cfg is not None: + mt_cfg = perf_ds_cfg.multi_turn + multi_turn_strategy = MultiTurnStrategy( + conversation_manager=ConversationManager(), + dataset_metadata=ctx.dataloader.conversation_metadata, + multi_turn_config=mt_cfg, + target_concurrency=ctx.config.settings.load_pattern.target_concurrency, + ) + + def _on_sample_complete(result: QueryResult) -> None: + if multi_turn_strategy is not None: + multi_turn_strategy.on_sample_complete(result) + collector.on_complete_hook(result) + # Create session session = BenchmarkSession( issuer=issuer, event_publisher=publisher, loop=loop, - on_sample_complete=collector.on_complete_hook, + on_sample_complete=_on_sample_complete, session_id=session_id, ) - phases = _build_phases(ctx) + phases = _build_phases(ctx, perf_strategy=multi_turn_strategy) report: Report | None = 
 None
         loop.add_signal_handler(signal.SIGINT, session.stop)
diff --git a/src/inference_endpoint/config/runtime_settings.py b/src/inference_endpoint/config/runtime_settings.py
index fb349a02..eac1aa47 100644
--- a/src/inference_endpoint/config/runtime_settings.py
+++ b/src/inference_endpoint/config/runtime_settings.py
@@ -32,6 +32,7 @@
 from typing import TYPE_CHECKING
 
 from .. import metrics
+from .schema import LoadPatternType
 
 logger = logging.getLogger(__name__)
 
@@ -194,6 +195,17 @@ def total_samples_to_issue(
             )
             return self.n_samples_to_issue
 
+        # Multi-turn must issue exactly all client turns — QPS-based formulas are meaningless.
+        if (
+            self.load_pattern is not None
+            and self.load_pattern.type == LoadPatternType.MULTI_TURN
+        ):
+            result = max(self.min_sample_count, self.n_samples_from_dataset)
+            logger.debug(
+                f"Sample count: {result} (multi-turn: issuing all {self.n_samples_from_dataset} client turns)"
+            )
+            return result
+
         # If min_duration is 0, use all dataset samples (new CLI default behavior)
         if self.min_duration_ms == 0:
             result = max(self.min_sample_count, self.n_samples_from_dataset)
diff --git a/src/inference_endpoint/config/schema.py b/src/inference_endpoint/config/schema.py
index 6a1884b4..2846268c 100644
--- a/src/inference_endpoint/config/schema.py
+++ b/src/inference_endpoint/config/schema.py
@@ -60,10 +60,17 @@ class LoadPatternType(str, Enum):
     MAX_THROUGHPUT = "max_throughput"  # Offline: all queries at t=0
     POISSON = "poisson"  # Online: fixed QPS with Poisson distribution
     CONCURRENCY = "concurrency"  # Online: fixed concurrent requests
+    MULTI_TURN = "multi_turn"  # Multi-turn conversations with turn sequencing
     BURST = "burst"  # Burst pattern (TODO)
     STEP = "step"  # Step pattern (TODO)
 
 
+class ConversationMode(str, Enum):
+    """Multi-turn conversation scheduling modes."""
+
+    INDEPENDENT = "independent"  # Per-conv pipelines; no cross-conv turn barrier
+
+
 class OSLDistributionType(str, Enum):
     """Output Sequence Length distribution types."""
 
@@ -230,6 +237,26 @@ def get_ruleset_instance(self) -> BenchmarkSuiteRuleset:
         return get_ruleset(self.ruleset)
 
 
+class MultiTurnConfig(BaseModel):
+    """Multi-turn conversation configuration.
+
+    Configuration for benchmarking conversational AI workloads with turn sequencing.
+    Enables testing multi-turn conversations where each turn depends on previous responses.
+    Presence of this block in the dataset config enables multi-turn mode.
+
+    Attributes:
+        mode: Conversation scheduling strategy (currently only independent).
+        turn_timeout_s: Maximum seconds to wait for previous turn completion.
+        use_dataset_history: If True, use pre-built message history from dataset.
+    """
+
+    model_config = ConfigDict(extra="forbid", frozen=True)
+
+    mode: ConversationMode = ConversationMode.INDEPENDENT
+    turn_timeout_s: float = 300.0
+    use_dataset_history: bool = True
+
+
 class Dataset(BaseModel):
     """Dataset configuration.
 
@@ -260,6 +287,9 @@ class Dataset(BaseModel):
     accuracy_config: AccuracyConfig | None = Field(
         None, description="Accuracy evaluation settings"
     )
+    multi_turn: MultiTurnConfig | None = Field(
+        None, description="Multi-turn conversation configuration"
+    )
 
     @model_validator(mode="after")
     def _auto_derive_name(self) -> Self:
@@ -389,6 +419,12 @@ def _validate_completeness(self) -> Self:
             raise ValueError(
                 "Concurrency requires --concurrency (e.g., --concurrency 10)"
             )
+        if self.type == LoadPatternType.MULTI_TURN and (
+            not self.target_concurrency or self.target_concurrency <= 0
+        ):
+            raise ValueError(
+                "Multi-turn requires --concurrency (e.g., --concurrency 96)"
+            )
         return self
 
@@ -584,11 +620,28 @@ def _resolve_and_validate(self) -> Self:
                 f"Offline benchmarks must use 'max_throughput', got '{lp.type}'"
             )
         elif effective_mode == TestType.ONLINE:
-            if lp.type not in (LoadPatternType.POISSON, LoadPatternType.CONCURRENCY):
+            if lp.type not in (
+                LoadPatternType.POISSON,
+                LoadPatternType.CONCURRENCY,
+                LoadPatternType.MULTI_TURN,
+            ):
                 raise ValueError(
-                    "Online mode requires --load-pattern (poisson or concurrency)"
+                    "Online mode requires --load-pattern (poisson, concurrency, or multi_turn)"
                 )
 
+        # Cross-validate load_pattern.type=multi_turn ↔ dataset.multi_turn config
+        has_multi_turn_dataset = any(
+            d.multi_turn is not None for d in (self.datasets or [])
+        )
+        if lp.type == LoadPatternType.MULTI_TURN and not has_multi_turn_dataset:
+            raise ValueError(
+                "load_pattern.type=multi_turn requires at least one dataset with multi_turn config"
+            )
+        if has_multi_turn_dataset and lp.type != LoadPatternType.MULTI_TURN:
+            raise ValueError(
+                f"Datasets with multi_turn config require load_pattern.type=multi_turn, got '{lp.type}'"
+            )
+
         return self
 
     @model_validator(mode="after")
diff --git a/src/inference_endpoint/config/templates/concurrency_template.yaml b/src/inference_endpoint/config/templates/concurrency_template.yaml
index 7b560ed7..c44295b4 100644
--- a/src/inference_endpoint/config/templates/concurrency_template.yaml
+++ b/src/inference_endpoint/config/templates/concurrency_template.yaml
@@ -14,7 +14,7 @@ settings:
   max_duration_ms: 0 # Maximum test duration in ms (0 for no limit)
   n_samples_to_issue: null # Sample count override
   load_pattern:
-    type: concurrency # Load pattern type | options: max_throughput, poisson, concurrency, burst, step
+    type: concurrency # Load pattern type | options: max_throughput, poisson, concurrency, multi_turn, burst, step
     target_concurrency: 32 # Concurrent requests
 endpoint_config:
   endpoints: # Endpoint URL(s). Must include scheme, e.g. 'http://host:port'.
diff --git a/src/inference_endpoint/config/templates/concurrency_template_full.yaml b/src/inference_endpoint/config/templates/concurrency_template_full.yaml
index 3a8e004f..2c0c24ae 100644
--- a/src/inference_endpoint/config/templates/concurrency_template_full.yaml
+++ b/src/inference_endpoint/config/templates/concurrency_template_full.yaml
@@ -22,6 +22,7 @@ datasets: # Dataset configs
     parser: # Column remapping: {prompt: , system: }
       prompt: text_input
     accuracy_config: null # Accuracy evaluation settings
+    multi_turn: null # Multi-turn conversation configuration
   - name: accuracy
     type: accuracy # Dataset purpose: performance or accuracy | options: performance, accuracy
     path: '' # Dataset file path
@@ -36,6 +37,7 @@ datasets: # Dataset configs
       ground_truth: ground_truth # Ground truth column name
       extractor: boxed_math_extractor # Answer extractor (abcd_extractor, boxed_math_extractor, identity_extractor, python_code_extractor)
     num_repeats: 1 # Repeat dataset N times for evaluation
+    multi_turn: null # Multi-turn conversation configuration
 settings:
   runtime:
     min_duration_ms: 600000 # Min duration (ms, or with suffix: 600s, 10m)
@@ -44,7 +46,7 @@ settings:
     scheduler_random_seed: 42 # Scheduler RNG seed
     dataloader_random_seed: 42 # Dataloader RNG seed
   load_pattern:
-    type: concurrency # Load pattern type | options: max_throughput, poisson, concurrency, burst, step
+    type: concurrency # Load pattern type | options: max_throughput, poisson, concurrency, multi_turn, burst, step
     target_qps: null # Target QPS
     target_concurrency: 32 # Concurrent requests
   client:
diff --git a/src/inference_endpoint/config/templates/offline_template_full.yaml b/src/inference_endpoint/config/templates/offline_template_full.yaml
index faabffde..72182198 100644
--- a/src/inference_endpoint/config/templates/offline_template_full.yaml
+++ b/src/inference_endpoint/config/templates/offline_template_full.yaml
@@ -22,6 +22,7 @@ datasets: # Dataset configs
     parser: # Column remapping: {prompt: , system: }
       prompt: text_input
     accuracy_config: null # Accuracy evaluation settings
+    multi_turn: null # Multi-turn conversation configuration
   - name: accuracy
     type: accuracy # Dataset purpose: performance or accuracy | options: performance, accuracy
     path: '' # Dataset file path
@@ -36,6 +37,7 @@ datasets: # Dataset configs
       ground_truth: ground_truth # Ground truth column name
       extractor: boxed_math_extractor # Answer extractor (abcd_extractor, boxed_math_extractor, identity_extractor, python_code_extractor)
     num_repeats: 1 # Repeat dataset N times for evaluation
+    multi_turn: null # Multi-turn conversation configuration
 settings:
   runtime:
     min_duration_ms: 600000 # Min duration (ms, or with suffix: 600s, 10m)
@@ -44,7 +46,7 @@ settings:
     scheduler_random_seed: 42 # Scheduler RNG seed
     dataloader_random_seed: 42 # Dataloader RNG seed
   load_pattern:
-    type: max_throughput # Load pattern type | options: max_throughput, poisson, concurrency, burst, step
+    type: max_throughput # Load pattern type | options: max_throughput, poisson, concurrency, multi_turn, burst, step
     target_qps: null # Target QPS
     target_concurrency: null # Concurrent requests
   client:
diff --git a/src/inference_endpoint/config/templates/online_template.yaml b/src/inference_endpoint/config/templates/online_template.yaml
index d33c1fd5..a56dc9b0 100644
--- a/src/inference_endpoint/config/templates/online_template.yaml
+++ b/src/inference_endpoint/config/templates/online_template.yaml
@@ -14,7 +14,7 @@ settings:
   max_duration_ms: 0 # Maximum test duration in ms (0 for no limit)
   n_samples_to_issue: null # Sample count override
   load_pattern:
-    type: poisson # Load pattern type | options: max_throughput, poisson, concurrency, burst, step
+    type: poisson # Load pattern type | options: max_throughput, poisson, concurrency, multi_turn, burst, step
     target_qps: 10.0 # Target QPS
 endpoint_config:
   endpoints: # Endpoint URL(s). Must include scheme, e.g. 'http://host:port'.
diff --git a/src/inference_endpoint/config/templates/online_template_full.yaml b/src/inference_endpoint/config/templates/online_template_full.yaml
index e9b7a673..b36c41a2 100644
--- a/src/inference_endpoint/config/templates/online_template_full.yaml
+++ b/src/inference_endpoint/config/templates/online_template_full.yaml
@@ -22,6 +22,7 @@ datasets: # Dataset configs
     parser: # Column remapping: {prompt: , system: }
       prompt: text_input
     accuracy_config: null # Accuracy evaluation settings
+    multi_turn: null # Multi-turn conversation configuration
   - name: accuracy
     type: accuracy # Dataset purpose: performance or accuracy | options: performance, accuracy
     path: '' # Dataset file path
@@ -36,6 +37,7 @@ datasets: # Dataset configs
      ground_truth: ground_truth # Ground truth column name
       extractor: boxed_math_extractor # Answer extractor (abcd_extractor, boxed_math_extractor, identity_extractor, python_code_extractor)
     num_repeats: 1 # Repeat dataset N times for evaluation
+    multi_turn: null # Multi-turn conversation configuration
 settings:
   runtime:
     min_duration_ms: 600000 # Min duration (ms, or with suffix: 600s, 10m)
@@ -44,7 +46,7 @@ settings:
     scheduler_random_seed: 42 # Scheduler RNG seed
     dataloader_random_seed: 42 # Dataloader RNG seed
   load_pattern:
-    type: poisson # Load pattern type | options: max_throughput, poisson, concurrency, burst, step
+    type: poisson # Load pattern type | options: max_throughput, poisson, concurrency, multi_turn, burst, step
     target_qps: 10.0 # Target QPS
     target_concurrency: null # Concurrent requests
   client:
diff --git a/src/inference_endpoint/core/types.py b/src/inference_endpoint/core/types.py
index accd2ca8..dc7f4faf 100644
--- a/src/inference_endpoint/core/types.py
+++ b/src/inference_endpoint/core/types.py
@@ -226,6 +226,7 @@ class Query(
     Attributes:
         id: Unique identifier for this query (auto-generated UUID).
         data: Request payload as a dictionary (typically contains prompt, model, etc.).
+        metadata: Internal metadata that round-trips through transport (e.g., conversation_id).
         headers: HTTP headers to include in the request (e.g., authorization).
         created_at: Timestamp when query was created (seconds since epoch).
@@ -249,6 +250,7 @@ class Query(
 
     id: str = msgspec.field(default_factory=lambda: str(uuid.uuid4()))
     data: dict[str, Any] = msgspec.field(default_factory=dict)
+    metadata: dict[str, Any] = msgspec.field(default_factory=dict)
     headers: dict[str, str] = msgspec.field(default_factory=dict)
     created_at: float = msgspec.field(default_factory=time.time)
 
@@ -331,6 +333,32 @@ def get_response_output_string(self) -> str:
         else:
             return ""
 
+    def with_metadata(
+        self, additional_metadata: dict[str, Any] | None
+    ) -> "QueryResult":
+        """Return a new QueryResult with merged metadata.
+
+        Args:
+            additional_metadata: Metadata to merge into existing metadata.
+                Values in additional_metadata override existing keys.
+
+        Returns:
+            New QueryResult with merged metadata (existing + additional).
+            If additional_metadata is None or empty, returns self unchanged.
+ """ + if not additional_metadata: + return self + + merged = dict(self.metadata) + merged.update(additional_metadata) + + return QueryResult( + id=self.id, + response_output=self.response_output, + metadata=merged, + error=self.error, + ) + class StreamChunk( msgspec.Struct, diff --git a/src/inference_endpoint/dataset_manager/__init__.py b/src/inference_endpoint/dataset_manager/__init__.py index 4bb6c575..403b8730 100644 --- a/src/inference_endpoint/dataset_manager/__init__.py +++ b/src/inference_endpoint/dataset_manager/__init__.py @@ -21,6 +21,7 @@ from .dataset import Dataset, EmptyDataset from .factory import DataLoaderFactory +from .multi_turn_dataset import MultiTurnDataset from .predefined.aime25 import AIME25 from .predefined.cnndailymail import CNNDailyMail from .predefined.gpqa import GPQA @@ -29,6 +30,7 @@ from .predefined.random import RandomDataset from .predefined.shopify_product_catalogue import ShopifyProductCatalogue from .transforms import ( + AddDefaultColumns, AddStaticColumns, ColumnFilter, ColumnRemap, @@ -45,6 +47,7 @@ "DataLoaderFactory", "ColumnFilter", "ColumnRemap", + "AddDefaultColumns", "AddStaticColumns", "UserPromptFormatter", "FusedRowProcessor", @@ -58,4 +61,5 @@ "CNNDailyMail", "RandomDataset", "ShopifyProductCatalogue", + "MultiTurnDataset", ] diff --git a/src/inference_endpoint/dataset_manager/factory.py b/src/inference_endpoint/dataset_manager/factory.py index 6ed1674a..8c1226c6 100644 --- a/src/inference_endpoint/dataset_manager/factory.py +++ b/src/inference_endpoint/dataset_manager/factory.py @@ -24,6 +24,7 @@ from inference_endpoint.config.schema import Dataset as DatasetConfig from inference_endpoint.dataset_manager.dataset import Dataset, DatasetFormat +from .multi_turn_dataset import MultiTurnDataset from .transforms import ColumnRemap, MakeAdapterCompatible, Transform logger = logging.getLogger(__name__) @@ -95,18 +96,24 @@ def create_loader(config: DatasetConfig, num_repeats: int = 1, **kwargs) -> Data if file_format 
is not None: format_enum = DatasetFormat(file_format) + dataset_id = None + if config.multi_turn is not None: + dataset_id = MultiTurnDataset.DATASET_ID + transforms: list[Transform] = [] if remap is not None: # Parser convention is {target: source} (e.g. {prompt: article}). # ColumnRemap expects {source: target} — flip it. flipped = {src: dst for dst, src in remap.items()} transforms.append(ColumnRemap(flipped)) # type: ignore[arg-type] - transforms.append(MakeAdapterCompatible()) + if dataset_id != MultiTurnDataset.DATASET_ID: + transforms.append(MakeAdapterCompatible()) assert dataset_path is not None return Dataset.load_from_file( Path(dataset_path), transforms=transforms, format=format_enum, + dataset_id=dataset_id, num_repeats=num_repeats, ) diff --git a/src/inference_endpoint/dataset_manager/multi_turn_dataset.py b/src/inference_endpoint/dataset_manager/multi_turn_dataset.py new file mode 100644 index 00000000..d2f21695 --- /dev/null +++ b/src/inference_endpoint/dataset_manager/multi_turn_dataset.py @@ -0,0 +1,415 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Multi-turn conversation dataset for conversational AI benchmarking.""" + +import logging +from typing import Any + +import pandas as pd + +from ..config.schema import APIType, ModelParams +from ..exceptions import InputValidationError +from .dataset import Dataset +from .transforms import ( + AddDefaultColumns, + AddStaticColumns, + apply_transforms, + get_transforms_for_api_type, +) + +logger = logging.getLogger(__name__) + + +def _expand_tool_results(row: dict) -> list[dict]: + """Expand a tool row into one OpenAI tool message per result. + + All ``role: "tool"`` rows carry a ``tool_results`` array. Each entry expands to + one OpenAI tool message with ``tool_call_id`` and ``content``. + + Returns an empty list if ``tool_results`` is absent or not a list (non-tool rows). + """ + tool_results = row.get("tool_results") + if not isinstance(tool_results, list): + return [] + if not tool_results: + logger.warning( + "Row has empty tool_results list (conversation_id=%s, turn=%s)", + row.get("conversation_id"), + row.get("turn"), + ) + return [] + messages = [] + for i, result in enumerate(tool_results): + tool_call_id = result.get("tool_call_id") + content = result.get("content") + if tool_call_id is None: + raise InputValidationError( + f"tool_results[{i}] in conversation {row.get('conversation_id')!r} " + f"turn {row.get('turn')} is missing required field 'tool_call_id'" + ) + if content is None: + raise InputValidationError( + f"tool_results[{i}] in conversation {row.get('conversation_id')!r} " + f"turn {row.get('turn')} is missing required field 'content'" + ) + messages.append( + {"role": "tool", "tool_call_id": tool_call_id, "content": content} + ) + return messages + + +class MultiTurnDataset(Dataset, dataset_id="multi_turn_conversations"): + """Dataset for multi-turn conversations. + + Supports conversational AI benchmarking with turn sequencing and conversation history. 
+    Validates that conversations have proper structure (alternating user/assistant roles)
+    and builds metadata for the scheduler to enforce turn ordering.
+
+    Dataset format (JSONL):
+        {"conversation_id": "c1", "turn": 1, "role": "user", "content": "...", "system": "..."}
+        {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "..."}
+        {"conversation_id": "c1", "turn": 3, "role": "user", "content": "..."}
+
+    Required columns:
+        - conversation_id: Unique identifier for each conversation
+        - turn: Turn number within conversation (1-indexed)
+        - role: Speaker role ("user", "assistant", or "tool")
+        - content: Message content
+
+    Optional columns:
+        - system: System prompt associated with the conversation (typically set on the first user turn)
+        - model: Model name override
+        - max_new_tokens / max_completion_tokens: Max tokens for this turn (max_new_tokens is an alias, mapped to max_completion_tokens)
+
+    Attributes:
+        conversation_metadata: Metadata dict containing:
+            - samples: List of user turn metadata (index, conversation_id, turn, system)
+            - num_conversations: Total number of unique conversations
+            - max_turns_per_conv: Maximum turns in any conversation
+    """
+
+    COLUMN_NAMES = ["conversation_id", "turn", "role", "content"]
+
+    def __init__(self, dataframe: pd.DataFrame, **kwargs):
+        """Initialize multi-turn dataset.
+
+        Args:
+            dataframe: DataFrame with conversation data.
+            **kwargs: Additional arguments passed to Dataset.__init__.
+
+        Raises:
+            ValueError: If conversation structure is invalid.
+ """ + super().__init__(dataframe, **kwargs) + assert self.dataframe is not None, "Dataframe must be initialized" + self._conv_groups = dict( + list(self.dataframe.groupby("conversation_id", sort=False)) + ) + self._validate_conversation_grouping() + self._validate_conversation_structure() + self._validate_turn_numbering() + self.conversation_metadata = self._build_metadata() + + def _validate_conversation_grouping(self) -> None: + """Validate that all rows for each conversation_id appear consecutively in file order. + + Raises: + InputValidationError: If rows for a conversation_id are interleaved with other conversations. + """ + assert self.dataframe is not None, "Dataframe must be initialized" + seen: set[str] = set() + last_conv: str | None = None + for row in self.dataframe.to_dict(orient="records"): + conv_id = str(row["conversation_id"]) + if conv_id != last_conv: + if conv_id in seen: + raise InputValidationError( + f"Rows for conversation '{conv_id}' are not consecutive. " + "All rows for a conversation must appear together in the file." + ) + seen.add(conv_id) + last_conv = conv_id + + def _validate_conversation_structure(self): + """Validate conversations are well-formed. + + Accepts plain user/assistant alternation as well as tool sequences: + user → assistant → tool → [assistant → tool]* → assistant → user + + Raises: + ValueError: If any conversation has invalid role sequence. 
+ """ + VALID_NEXT: dict[str, set[str]] = { + "start": {"user"}, + "user": {"assistant"}, + "assistant": {"tool", "user"}, + "tool": {"assistant", "user"}, + } + + for conv_id, group in self._conv_groups.items(): + sorted_group = group.sort_values("turn") + state = "start" + + for _, row in sorted_group.iterrows(): + role = row["role"] + + if role not in VALID_NEXT.get(state, set()): + raise ValueError( + f"Conversation {conv_id} has invalid role sequence at turn " + f"{row['turn']}: got '{role}' after state '{state}'" + ) + state = role + + def _validate_turn_numbering(self): + """Validate turn numbers are consecutive starting at 1. + + Raises: + ValueError: If turn numbers are not exactly 1, 2, 3, …, N. + """ + for conv_id, group in self._conv_groups.items(): + turns = sorted(group["turn"].tolist()) + expected = list(range(1, len(turns) + 1)) + if turns != expected: + raise ValueError( + f"Conversation {conv_id}: Turn numbers must be consecutive starting at 1, " + f"got {turns}" + ) + + def _build_metadata(self) -> dict[str, Any]: + """Build metadata for scheduler (maps sample index to conversation context). + + Pre-computes the complete message list for each client turn so that + conversation history does not need to be accumulated at runtime. + + Returns: + Metadata dict with samples list, num_conversations, max_turns_per_conv, + client_turns_per_conversation, and pre_built_messages_by_key. + """ + samples = [] + + # Count client turns (user + tool) per conversation for completion tracking + client_turns_per_conv = { + str(conv_id): int(group["role"].isin(["user", "tool"]).sum()) + for conv_id, group in self._conv_groups.items() + } + + # Map (conversation_id, turn) → complete message list ready to send to endpoint. + # Each entry is: [system (optional)] + all prior rows formatted as messages + # + the current client turn message. + # This includes assistant rows (tool dispatches or terminal responses) + # so no runtime injection is required. 
+        pre_built_messages_by_key: dict[tuple, list[dict]] = {}
+        current_turn_messages_by_key: dict[tuple, list[dict]] = {}
+        system_prompts_by_conv: dict[str, str | None] = {}
+
+        assert self.dataframe is not None, "Dataframe must be initialized"
+        for conv_id, group in self._conv_groups.items():
+            sorted_group = group.sort_values("turn")
+            client_rows = sorted_group[sorted_group["role"].isin(["user", "tool"])]
+
+            # Extract system prompt from the first row that has it (typically turn 1)
+            system_content: str | None = None
+            for _, srow in sorted_group.iterrows():
+                val = srow.get("system")
+                if val and isinstance(val, str):
+                    system_content = val
+                    break
+            system_prompts_by_conv[str(conv_id)] = system_content
+
+            for _, row in client_rows.iterrows():
+                t_n = int(row["turn"])
+
+                messages: list[dict] = []
+                if system_content:
+                    messages.append({"role": "system", "content": system_content})
+
+                # All dataset rows strictly before this client turn (includes
+                # assistant rows and prior tool results).
+                prior_rows = sorted_group[sorted_group["turn"] < t_n]
+                for _, prior_row in prior_rows.iterrows():
+                    msg: dict[str, Any] = {}
+                    for key in ("role", "content", "tool_calls", "tool_results"):
+                        val = prior_row.get(key)
+                        if val is not None and not (
+                            isinstance(val, float) and pd.isna(val)
+                        ):
+                            msg[key] = val
+                    if (
+                        msg.get("role") == "assistant"
+                        and "tool_calls" in msg
+                        and "content" not in msg
+                    ):
+                        msg["content"] = None
+                    if msg.get("role"):
+                        # Expand merged parallel tool results: a single row with
+                        # tool_results: [{tool_call_id, content}, ...] expands into
+                        # one OpenAI tool message per result entry.
+                        expanded = _expand_tool_results(msg)
+                        if expanded:
+                            messages.extend(expanded)
+                        else:
+                            messages.append(msg)
+
+                # Append the current client turn message.
+                # A merged parallel-tool row carries tool_results instead of a
+                # single tool_call_id/content pair; expand to one message per result.
+                current_turn_msgs: list[dict] = []
+                expanded = _expand_tool_results(row)
+                if expanded:
+                    current_turn_msgs = expanded
+                else:
+                    cur: dict[str, Any] = {}
+                    for key in ("role", "content"):
+                        val = row.get(key)
+                        if val is not None and not (
+                            isinstance(val, float) and pd.isna(val)
+                        ):
+                            cur[key] = val
+                    current_turn_msgs = [cur]
+                messages.extend(current_turn_msgs)
+
+                str_conv_id = str(conv_id)
+                pre_built_messages_by_key[(str_conv_id, t_n)] = messages
+                current_turn_messages_by_key[(str_conv_id, t_n)] = current_turn_msgs
+
+                samples.append(
+                    {
+                        "conversation_id": str_conv_id,
+                        "turn": t_n,
+                    }
+                )
+
+        return {
+            "samples": samples,
+            "num_conversations": len(self._conv_groups),
+            "max_turns_per_conv": max(
+                g["turn"].max() for g in self._conv_groups.values()
+            ),
+            "client_turns_per_conversation": client_turns_per_conv,
+            "pre_built_messages_by_key": pre_built_messages_by_key,
+            "current_turn_messages_by_key": current_turn_messages_by_key,
+            "system_prompts_by_conv": system_prompts_by_conv,
+        }
+
+    def load(
+        self,
+        adapter=None,
+        api_type: APIType | None = None,
+        model_params: ModelParams | None = None,
+        force: bool = False,
+    ):
+        """Load dataset, apply adapter defaults, and pre-bake client-turn samples.
+
+        Unlike single-turn datasets, multi-turn rows do not have a `prompt` column,
+        so ColumnFilter (which requires prompt) is skipped. AddStaticColumns entries
+        from the adapter are applied via AddDefaultColumns (fill-missing-only) so that
+        per-row dataset overrides are preserved.
+
+        After transforms, only client turns (user + tool) are stored in self.data as
+        fully assembled sample dicts (with messages attached).
+        load_sample() and num_samples() are inherited from the base class.
+ """ + if not force and self.data is not None: + return + + df = self.dataframe + if df is None: + raise ValueError( + f"Cannot load dataset {self.__class__.__name__}: dataframe is None" + ) + + transforms = [] + if self.transforms is not None: + transforms.extend(self.transforms) + + if transforms: + df = apply_transforms(df, transforms) + + # Extract AddStaticColumns defaults from adapter transforms and apply as + # fill-missing-only (preserves per-row dataset values). + if api_type is not None and model_params is not None: + adapter_transforms = get_transforms_for_api_type(api_type, model_params) + defaults: dict[str, Any] = {} + for t in adapter_transforms: + if isinstance(t, AddStaticColumns): + defaults.update(t.data) + if defaults: + df = AddDefaultColumns(defaults)(df) + + all_rows = df.to_dict(orient="records") + + # Pre-bake: assemble one complete sample dict per client turn. + # NaN filtering replaces the GENERATION_PARAMS allowlist — any key whose + # value is float NaN was absent in the original dataset row. + pre_built = self.conversation_metadata.get("pre_built_messages_by_key", {}) + client_turn_samples: list[dict[str, Any]] = [] + + # Collect per-conversation defaults from the first user row so that + # fields like model/max_completion_tokens propagate to tool rows. + _PROPAGATED_KEYS = { + "model", + "max_completion_tokens", + "max_new_tokens", + "stream", + } + conv_defaults: dict[str, dict[str, Any]] = {} + for row in all_rows: + cid = row.get("conversation_id") + if cid not in conv_defaults and row.get("role") == "user": + conv_defaults[cid] = { + k: row[k] + for k in _PROPAGATED_KEYS + if k in row + and row[k] is not None + and not (isinstance(row[k], float) and pd.isna(row[k])) + } + + for row in all_rows: + if row.get("role") not in ("user", "tool"): + continue + + # Filter NaN values; keep all meaningful fields (extra keys are harmless + # since adapters consume only what they recognize). 
+            sample: dict[str, Any] = {
+                k: v
+                for k, v in row.items()
+                if v is not None and not (isinstance(v, float) and pd.isna(v))
+            }
+            # Strip dataset-internal fields that must not reach the endpoint.
+            sample.pop("tool_results", None)
+            sample.pop("tool_calls", None)
+
+            # Fill missing propagated fields from the first user row of this conversation.
+            for k, v in conv_defaults.get(row.get("conversation_id"), {}).items():
+                if k not in sample:
+                    sample[k] = v
+
+            # max_new_tokens → max_completion_tokens alias
+            if "max_completion_tokens" not in sample and "max_new_tokens" in sample:
+                sample["max_completion_tokens"] = sample.pop("max_new_tokens")
+            if "max_completion_tokens" not in sample:
+                sample["max_completion_tokens"] = 128
+            if "stream" not in sample:
+                sample["stream"] = False
+
+            # Attach pre-built message list (system + history + current turn).
+            key = (str(row["conversation_id"]), int(row["turn"]))
+            messages = pre_built.get(key, [])
+            sample["messages"] = messages
+
+            client_turn_samples.append(sample)
+
+        self.data = client_turn_samples
diff --git a/src/inference_endpoint/dataset_manager/transforms.py b/src/inference_endpoint/dataset_manager/transforms.py
index 79133796..a288da6d 100644
--- a/src/inference_endpoint/dataset_manager/transforms.py
+++ b/src/inference_endpoint/dataset_manager/transforms.py
@@ -127,6 +127,30 @@ def __call__(self, df: pd.DataFrame) -> pd.DataFrame:
         return df
 
 
+class AddDefaultColumns(Transform):
+    """Add columns only where values are missing (NaN or absent).
+
+    Unlike AddStaticColumns which unconditionally overwrites, this preserves
+    existing non-null values — dataset per-row overrides take precedence over
+    the supplied defaults.
+ """ + + def __init__(self, data: dict[str, Any]): + """Initialize the AddDefaultColumns transform.""" + self.data = data + + def __call__(self, df: pd.DataFrame) -> pd.DataFrame: + """Fill missing columns with defaults without overwriting existing values.""" + for key, value in self.data.items(): + if value is None: + continue + if key in df.columns: + df[key] = df[key].where(pd.notna(df[key]), value) + else: + df[key] = value + return df + + class Harmonize(RowProcessor): """Transform to convert a user prompt to an OpenAI Harmony-compatible format.""" diff --git a/src/inference_endpoint/endpoint_client/adapter_protocol.py b/src/inference_endpoint/endpoint_client/adapter_protocol.py index c0239d56..feb590a4 100644 --- a/src/inference_endpoint/endpoint_client/adapter_protocol.py +++ b/src/inference_endpoint/endpoint_client/adapter_protocol.py @@ -19,7 +19,7 @@ import re from abc import ABC, abstractmethod -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from inference_endpoint.core.types import Query, QueryResult @@ -93,24 +93,24 @@ def decode_response(cls, response_bytes: bytes, query_id: str) -> QueryResult: @classmethod @abstractmethod - def decode_sse_message(cls, json_bytes: bytes) -> str: + def decode_sse_message(cls, json_bytes: bytes) -> Any: """ - Decode SSE message and extract content string. + Decode SSE message and return adapter-specific chunk object. Args: json_bytes: Raw JSON bytes from SSE stream Returns: - Content string from the SSE message + Adapter-specific chunk object passed to accumulator.add_chunk() """ raise NotImplementedError("decode_sse_message not implemented") @classmethod - def parse_sse_chunk(cls, buffer: bytes, end_pos: int) -> list[str]: + def parse_sse_chunk(cls, buffer: bytes, end_pos: int) -> list[Any]: """ - Parse SSE chunk and extract all content strings. + Parse SSE chunk and extract all chunk objects. - Extracts JSON documents from SSE stream and decodes them to content strings. 
+        Extracts JSON documents from SSE stream and decodes them to chunk objects.
         Silently ignores non-content SSE messages (role, finish_reason, etc).
 
         Args:
@@ -118,7 +118,7 @@ def parse_sse_chunk(cls, buffer: bytes, end_pos: int) -> list[str]:
             end_pos: End position in buffer to parse up to
 
         Returns:
-            List of content strings extracted from the SSE chunk
+            List of chunk objects extracted from the SSE chunk
         """
         json_docs = cls.SSE_DATA_PATTERN.findall(buffer[:end_pos])
         parsed_contents = []
diff --git a/src/inference_endpoint/endpoint_client/http.py b/src/inference_endpoint/endpoint_client/http.py
index d9047301..1e67a023 100644
--- a/src/inference_endpoint/endpoint_client/http.py
+++ b/src/inference_endpoint/endpoint_client/http.py
@@ -792,10 +792,12 @@ class InFlightRequest:
         query_id: Correlates response back to original Query.
         http_bytes: Serialized HTTP request for socket.write().
         is_streaming: Whether this is a streaming (SSE) request or not.
+        query_metadata: Internal metadata carried alongside the request.
         connection: PooledConnection assigned to this request (set once request is fired).
""" query_id: str http_bytes: bytes is_streaming: bool + query_metadata: dict[str, object] = field(default_factory=dict) connection: PooledConnection = field(default=None, repr=False) # type: ignore[assignment] diff --git a/src/inference_endpoint/endpoint_client/worker.py b/src/inference_endpoint/endpoint_client/worker.py index 8e0e560e..8fb69fce 100644 --- a/src/inference_endpoint/endpoint_client/worker.py +++ b/src/inference_endpoint/endpoint_client/worker.py @@ -341,6 +341,7 @@ def _prepare_request(self, query: Query) -> InFlightRequest: query_id=query.id, http_bytes=http_bytes, is_streaming=is_streaming, + query_metadata=query.metadata, ) return req @@ -429,7 +430,9 @@ async def _handle_streaming_body(self, req: InFlightRequest) -> None: self._pool.release(conn) # Send final complete back to main rank - self._responses.send(accumulator.get_final_output()) + self._responses.send( + accumulator.get_final_output().with_metadata(req.query_metadata) + ) @profile async def _handle_non_streaming_body(self, req: InFlightRequest) -> None: @@ -447,7 +450,7 @@ async def _handle_non_streaming_body(self, req: InFlightRequest) -> None: result = self._adapter.decode_response(response_bytes, query_id) # Send result back to main rank - self._responses.send(result) + self._responses.send(result.with_metadata(req.query_metadata)) async def _handle_error(self, query_id: str, error: Exception | str) -> None: """Send error response for a query.""" diff --git a/src/inference_endpoint/load_generator/conversation_manager.py b/src/inference_endpoint/load_generator/conversation_manager.py new file mode 100644 index 00000000..1b0834bb --- /dev/null +++ b/src/inference_endpoint/load_generator/conversation_manager.py @@ -0,0 +1,172 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Conversation state management for multi-turn benchmarking."""
+
+import logging
+from dataclasses import dataclass, field
+from typing import Any
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class ConversationState:
+    """Per-conversation state for multi-turn benchmarking.
+
+    Attributes:
+        conversation_id: Unique identifier for this conversation.
+        message_history: Accumulated message list (populated only when
+            use_dataset_history=False; empty otherwise).
+        completed_turns: Turns with responses (success or failure) — observability only.
+        failed_turns: Turns that failed — observability only.
+        expected_client_turns: Expected total client turns (for completion detection).
+    """
+
+    conversation_id: str
+    message_history: list[dict[str, Any]] = field(default_factory=list)
+    completed_turns: int = 0
+    failed_turns: int = 0
+    expected_client_turns: int | None = None
+
+    def is_complete(self) -> bool:
+        """Return True when all expected turns have a response."""
+        if self.expected_client_turns is None:
+            return False
+        return self.completed_turns >= self.expected_client_turns
+
+
+class ConversationManager:
+    """Manages per-conversation state for multi-turn benchmarking.
+
+    All methods are synchronous. Turn sequencing is driven by MultiTurnStrategy
+    which calls on_sample_complete() → _issue_next_turn() directly.
+ + All states are pre-created by MultiTurnStrategy.execute() before any turns + are issued, so get_or_create() requires no locking. + """ + + def __init__(self): + """Initialize with empty state.""" + self._conversations: dict[str, ConversationState] = {} + + def get_state(self, conversation_id: str) -> ConversationState | None: + """Return existing state without creating (read-only access).""" + return self._conversations.get(conversation_id) + + def get_or_create( + self, + conversation_id: str, + expected_client_turns: int | None = None, + system_message: dict[str, Any] | None = None, + ) -> ConversationState: + """Return existing state or create a new one. + + Args: + conversation_id: Unique identifier for conversation. + expected_client_turns: Expected number of client turns. + system_message: System message to prepend to message_history + (only used when use_dataset_history=False and state is new). + + Returns: + ConversationState for this conversation. + """ + if conversation_id not in self._conversations: + initial_history: list[dict[str, Any]] = ( + [system_message] if system_message is not None else [] + ) + self._conversations[conversation_id] = ConversationState( + conversation_id=conversation_id, + expected_client_turns=expected_client_turns, + message_history=initial_history, + ) + return self._conversations[conversation_id] + + def _log_if_complete(self, state: ConversationState, conversation_id: str) -> None: + """Log completion status once all expected turns have a response.""" + if not state.is_complete(): + return + if state.failed_turns > 0: + logger.info( + f"Conversation {conversation_id} completed with failures: " + f"{state.completed_turns - state.failed_turns}/" + f"{state.expected_client_turns} successful, " + f"{state.failed_turns} failed" + ) + else: + logger.debug( + f"Conversation {conversation_id} completed: " + f"{state.completed_turns}/{state.expected_client_turns} turns" + ) + + def mark_turn_complete( + self, + conversation_id: 
str, + response: str, + store_in_history: bool = False, + metadata: dict[str, Any] | None = None, + ) -> None: + """Record a successful response. + + Args: + conversation_id: Conversation ID. + response: Model output (appended to history when store_in_history=True). + store_in_history: When True, append response to message_history. + metadata: Optional response metadata; tool_calls are preserved in history + when present (only used when store_in_history=True). + + Raises: + KeyError: If conversation_id not found. + """ + state = self._conversations.get(conversation_id) + if state is None: + raise KeyError(f"Conversation {conversation_id} not initialized") + if store_in_history: + tool_calls = metadata.get("tool_calls") if metadata else None + if response or tool_calls: + msg: dict[str, Any] = {"role": "assistant", "content": response or None} + if tool_calls: + msg["tool_calls"] = tool_calls + state.message_history.append(msg) + state.completed_turns += 1 + self._log_if_complete(state, conversation_id) + + def mark_turn_failed( + self, + conversation_id: str, + store_in_history: bool = False, + ) -> None: + """Record a failed response. + + Failed turns count toward completion so sequencing progresses under errors. + + Args: + conversation_id: Conversation ID. + store_in_history: When True, append error placeholder to message_history. + + Raises: + KeyError: If conversation_id not found. 
+ """ + state = self._conversations.get(conversation_id) + if state is None: + raise KeyError(f"Conversation {conversation_id} not initialized") + if store_in_history: + state.message_history.append( + {"role": "assistant", "content": "[ERROR: Turn failed or timed out]"} + ) + state.completed_turns += 1 + state.failed_turns += 1 + logger.warning(f"Turn failed for conversation {conversation_id}") + self._log_if_complete(state, conversation_id) diff --git a/src/inference_endpoint/load_generator/multi_turn_strategy.py b/src/inference_endpoint/load_generator/multi_turn_strategy.py new file mode 100644 index 00000000..d3f432d7 --- /dev/null +++ b/src/inference_endpoint/load_generator/multi_turn_strategy.py @@ -0,0 +1,314 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Async multi-turn load strategy implementing the LoadStrategy protocol.""" + +import asyncio +import logging +from collections import defaultdict, deque +from collections.abc import Iterator +from typing import Any + +from ..config.schema import MultiTurnConfig +from ..core.types import QueryResult +from .conversation_manager import ConversationManager, ConversationState +from .strategy import PhaseIssuerProtocol + +logger = logging.getLogger(__name__) + +# Default turn timeout when no MultiTurnConfig is provided. 
+_DEFAULT_TURN_TIMEOUT_S = 300.0 + + +class MultiTurnStrategy: + """Event-driven multi-turn strategy. Completion of each turn triggers the next. + + execute() seeds the first N conversations (issues turn 1 for each), then + awaits _all_done. on_sample_complete() is called synchronously from the + receive coroutine for each response — it issues the next turn immediately + (zero event-loop iterations between response and next issuance), or starts + a new conversation when the current one finishes all turns. + + At most target_concurrency conversations are active simultaneously. When + target_concurrency is None, all conversations start at once. + + Integration with BenchmarkSession: + - execute(): seeds conversations, awaits completion + - on_query_complete(): no-op (required by LoadStrategy protocol) + - on_sample_complete(): routes completed QueryResult, issues next turn + + The response routing path: + 1. _issue_next_turn issues turn N via phase_issuer.issue(idx) → query_id + 2. _issue_next_turn stores conv_id in _inflight[query_id] + 3. BenchmarkSession calls on_sample_complete(result) with the QueryResult + 4. on_sample_complete looks up conv_id from _inflight, calls mark_turn_complete + 5. on_sample_complete calls _issue_next_turn for turn N+1 (synchronously) + """ + + def __init__( + self, + conversation_manager: ConversationManager, + dataset_metadata: dict[str, Any], + multi_turn_config: MultiTurnConfig | None = None, + target_concurrency: int | None = None, + ): + """Initialize multi-turn strategy. + + Args: + conversation_manager: Manages conversation sequencing state. + dataset_metadata: Metadata from MultiTurnDataset (samples list). + multi_turn_config: Multi-turn conversation configuration. + target_concurrency: Maximum number of simultaneously active conversations. + None means all conversations run concurrently. 
+ """ + self._conv_manager = conversation_manager + self._dataset_metadata = dataset_metadata + self._multi_turn_config = multi_turn_config + self._turn_timeout_s = ( + multi_turn_config.turn_timeout_s + if multi_turn_config is not None + else _DEFAULT_TURN_TIMEOUT_S + ) + self._target_concurrency = target_concurrency + self._store_in_history = ( + not multi_turn_config.use_dataset_history + if multi_turn_config is not None + else False + ) + + # Maps query_id -> conversation_id for routing completions. + self._inflight: dict[str, str] = {} + # Cached ConversationState refs for O(1) lookup in on_sample_complete. + self._conv_states: dict[str, ConversationState] = {} + + # Event-driven state — populated in execute(). + self._pending_convs: deque[tuple[str, list[tuple[int, int]]]] = deque() + self._active_iters: dict[str, Iterator[tuple[int, int]]] = {} + self._timeout_handles: dict[str, asyncio.TimerHandle] = {} + self._error: BaseException | None = None + self._all_done: asyncio.Event | None = None + self._loop: asyncio.AbstractEventLoop | None = None + self._phase_issuer: PhaseIssuerProtocol | None = None + + async def execute(self, phase_issuer: PhaseIssuerProtocol) -> int: + """Drive multi-turn sample issuance. + + Args: + phase_issuer: Interface for issuing samples to the endpoint. + + Returns: + Total count of samples issued. + """ + self._phase_issuer = phase_issuer + self._loop = asyncio.get_running_loop() + self._all_done = asyncio.Event() + self._error = None + + conv_samples: dict[str, list[tuple[int, int]]] = defaultdict(list) + for sample_index, sample_meta in enumerate(self._dataset_metadata["samples"]): + conv_id = sample_meta["conversation_id"] + conv_samples[conv_id].append((sample_index, sample_meta["turn"])) + + # Pre-create all conversation states before issuing any turns (no locking needed). 
+ sys_prompts = self._dataset_metadata.get("system_prompts_by_conv", {}) + for conv_id, turns in conv_samples.items(): + sys_content = sys_prompts.get(conv_id) if self._store_in_history else None + system_message = ( + {"role": "system", "content": sys_content} + if sys_content is not None + else None + ) + state = self._conv_manager.get_or_create( + conv_id, + expected_client_turns=len(turns), + system_message=system_message, + ) + self._conv_states[conv_id] = state + + # Build pending queue (sorted turns per conversation). + for conv_id, turns in conv_samples.items(): + self._pending_convs.append((conv_id, sorted(turns, key=lambda x: x[1]))) + + n_to_start = ( + min(self._target_concurrency, len(self._pending_convs)) + if self._target_concurrency is not None and self._target_concurrency > 0 + else len(self._pending_convs) + ) + for _ in range(n_to_start): + self._start_conversation() + + if not self._active_iters and not self._inflight: + return phase_issuer.issued_count + + await self._all_done.wait() + + for handle in self._timeout_handles.values(): + handle.cancel() + self._timeout_handles.clear() + + if self._inflight: + logger.warning( + "%d query(ies) never received a response (session stop or transport failure): %s", + len(self._inflight), + list(self._inflight.keys()), + ) + self._inflight.clear() + + if self._error is not None: + raise self._error + return phase_issuer.issued_count + + def _start_conversation(self) -> None: + """Pop the next conversation from the pending queue and issue its first turn.""" + conv_id, turns = self._pending_convs.popleft() + self._active_iters[conv_id] = iter(turns) + self._issue_next_turn(conv_id) + + def _issue_next_turn(self, conv_id: str) -> None: + """Issue the next turn for conv_id, or mark the conversation done.""" + it = self._active_iters.get(conv_id) + if it is None: + return + + pair = next(it, None) + if pair is None: + del self._active_iters[conv_id] + self._fill_slot() + return + + idx, turn = pair + state = 
self._conv_states[conv_id] + + data_override: dict[str, Any] | None = None + current_turn_messages: list[dict[str, Any]] | None = None + if self._store_in_history: + current_turn_messages = self._dataset_metadata.get( + "current_turn_messages_by_key", {} + ).get((conv_id, turn)) + if current_turn_messages: + has_tool_msg = any( + m.get("role") == "tool" for m in current_turn_messages + ) + if has_tool_msg: + logger.warning( + "Live-history mode with tool messages uses dataset " + "tool_call_ids; real endpoint IDs will differ " + "(conv=%s, turn=%d)", + conv_id, + turn, + ) + live_messages = state.message_history.copy() + current_turn_messages + data_override = {"messages": live_messages} + + assert self._phase_issuer is not None + query_id = self._phase_issuer.issue(idx, data_override=data_override) + if query_id is None: + # Session stopping — signal done. + assert self._all_done is not None + self._all_done.set() + return + + self._inflight[query_id] = conv_id + + if self._store_in_history and current_turn_messages: + state.message_history.extend(current_turn_messages) + + assert self._loop is not None + handle = self._loop.call_later( + self._turn_timeout_s, self._handle_timeout, query_id, conv_id + ) + self._timeout_handles[query_id] = handle + + def _fill_slot(self) -> None: + """Start a new conversation from the pending queue, or signal all done.""" + if self._pending_convs: + self._start_conversation() + elif not self._active_iters: + assert self._all_done is not None + self._all_done.set() + + def _handle_timeout(self, query_id: str, conv_id: str) -> None: + """Called by the event loop when a turn response does not arrive in time.""" + if self._inflight.pop(query_id, None) is None: + return + self._timeout_handles.pop(query_id, None) + + logger.warning( + "Turn timed out for conversation %s (query=%s)", conv_id, query_id + ) + + self._conv_manager.mark_turn_failed( + conv_id, store_in_history=self._store_in_history + ) + it = self._active_iters.pop(conv_id, 
None) + if it is not None: + for _ in it: + self._conv_manager.mark_turn_failed( + conv_id, store_in_history=self._store_in_history + ) + + self._fill_slot() + + def on_query_complete(self, query_id: str) -> None: + """No-op. Required by LoadStrategy protocol; called by BenchmarkSession.""" + pass + + def on_sample_complete(self, result: QueryResult) -> None: + """Route completed QueryResult to ConversationManager and issue next turn. + + Called synchronously from BenchmarkSession._handle_response(). Issues the + next turn immediately (zero event-loop delay) or starts a new conversation + when this one finishes all turns. + + Args: + result: Completed QueryResult from the endpoint. + """ + conv_id = self._inflight.pop(result.id, None) + if conv_id is None: + return + + handle = self._timeout_handles.pop(result.id, None) + if handle is not None: + handle.cancel() + + response_text = result.get_response_output_string() + + try: + if result.error is not None: + self._conv_manager.mark_turn_failed( + conv_id, store_in_history=self._store_in_history + ) + else: + self._conv_manager.mark_turn_complete( + conv_id, + response_text, + store_in_history=self._store_in_history, + metadata=result.metadata, + ) + except KeyError: + logger.warning( + "on_sample_complete: conversation %s not found in manager (result=%s)", + conv_id, + result.id, + ) + return + + try: + self._issue_next_turn(conv_id) + except Exception as exc: + logger.error("Error issuing next turn for %s: %s", conv_id, exc) + self._error = exc + if self._all_done is not None: + self._all_done.set() diff --git a/src/inference_endpoint/load_generator/session.py b/src/inference_endpoint/load_generator/session.py index 1c8ad992..2b1c39b5 100644 --- a/src/inference_endpoint/load_generator/session.py +++ b/src/inference_endpoint/load_generator/session.py @@ -26,9 +26,9 @@ import time import uuid from collections.abc import Callable -from dataclasses import dataclass +from dataclasses import dataclass, field from enum 
import Enum -from typing import Protocol +from typing import Any, Protocol from ..config.runtime_settings import RuntimeSettings from ..core.record import ( @@ -68,6 +68,7 @@ class PhaseConfig: runtime_settings: RuntimeSettings dataset: Dataset phase_type: PhaseType = PhaseType.PERFORMANCE + strategy: LoadStrategy | None = field(default=None, compare=False) # --------------------------------------------------------------------------- @@ -172,11 +173,19 @@ def __init__( self.inflight: int = 0 self.issued_count: int = 0 - def issue(self, sample_index: int) -> str | None: + def issue( + self, sample_index: int, data_override: dict[str, Any] | None = None + ) -> str | None: """Load data, build Query, publish ISSUED, send to endpoint. Returns query_id on success, None if session is stopping. + Args: + sample_index: Index into the dataset. + data_override: If provided, merged over the loaded sample data. + Keys in data_override take precedence. Used by MultiTurnStrategy + to substitute live-accumulated message history. + Note: load_sample() runs synchronously before the ISSUED timestamp. For accurate timing, datasets MUST be pre-loaded into memory. Disk-backed datasets will inflate timing and delay subsequent issues. 
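The `data_override` merge and the prompt-text fallback described in this docstring are plain dict/list operations; a minimal standalone sketch of the intended semantics (the sample values below are illustrative, not from the real dataset):

```python
# A multi-turn sample may carry no "prompt"; MultiTurnStrategy supplies the
# live messages via data_override, which is spread over the loaded sample so
# its keys take precedence (hypothetical values for illustration only).
sample = {"model": "m", "max_completion_tokens": 100}
override = {
    "messages": [{"role": "user", "content": "What's 2+2?"}],
    "max_completion_tokens": 256,
}
data = {**sample, **override}  # override keys win; other sample keys survive

# Prompt-text fallback for the PromptData record: when there is no "prompt"
# key but a "messages" list, join the non-empty string contents.
parts = [
    m["content"]
    for m in data["messages"]
    if isinstance(m, dict) and m.get("content")
]
prompt_text = "\n".join(parts) if parts else None
```

This mirrors why multi-turn samples can omit `prompt` entirely: the issuer reconstructs a representative prompt string from the message contents only when needed for timing records.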
@@ -185,14 +194,24 @@ def issue(self, sample_index: int) -> str | None: return None query_id = uuid.uuid4().hex data = self._dataset.load_sample(sample_index) + if data_override is not None: + data = {**data, **data_override} query = Query(id=query_id, data=data) self.uuid_to_index[query_id] = sample_index ts = time.monotonic_ns() prompt_data: PromptData if isinstance(data, dict): token_ids = data.get("input_tokens") or data.get("token_ids") + prompt_text = data.get("prompt") + if prompt_text is None and "messages" in data: + parts: list[str] = [ + m["content"] + for m in data["messages"] + if isinstance(m, dict) and m.get("content") + ] + prompt_text = "\n".join(parts) if parts else None prompt_data = PromptData( - text=data.get("prompt"), + text=prompt_text, token_ids=tuple(token_ids) if token_ids is not None else None, ) else: @@ -306,10 +325,13 @@ async def _run_phase(self, phase: PhaseConfig) -> PhaseResult | None: phase_start = time.monotonic_ns() # Create per-phase state - sample_order = create_sample_order(phase.runtime_settings) - strategy = create_load_strategy( - phase.runtime_settings, self._loop, sample_order - ) + if phase.strategy is not None: + strategy = phase.strategy + else: + sample_order = create_sample_order(phase.runtime_settings) + strategy = create_load_strategy( + phase.runtime_settings, self._loop, sample_order + ) phase_issuer = PhaseIssuer( dataset=phase.dataset, issuer=self._issuer, diff --git a/src/inference_endpoint/load_generator/strategy.py b/src/inference_endpoint/load_generator/strategy.py index dd311f10..8ee13722 100644 --- a/src/inference_endpoint/load_generator/strategy.py +++ b/src/inference_endpoint/load_generator/strategy.py @@ -29,7 +29,7 @@ import logging from collections.abc import Callable, Iterator from time import monotonic_ns -from typing import Protocol +from typing import Any, Protocol from ..config.runtime_settings import RuntimeSettings from ..config.schema import LoadPatternType @@ -47,8 +47,17 @@ class 
PhaseIssuerProtocol(Protocol): """Minimal interface that strategies see for issuing samples.""" - def issue(self, sample_index: int) -> str | None: - """Issue a sample. Returns query_id, or None if the session is stopping.""" + def issue( + self, sample_index: int, data_override: dict[str, Any] | None = None + ) -> str | None: + """Issue a sample. Returns query_id, or None if the session is stopping. + + Args: + sample_index: Index into the dataset. + data_override: If provided, merged over the loaded sample data; + keys in data_override take precedence. Used by MultiTurnStrategy for + live-history mode where the messages array is built at runtime. + """ ... issued_count: int @@ -297,5 +306,11 @@ def create_load_strategy( ) return ConcurrencyStrategy(lp.target_concurrency, sample_order) + case LoadPatternType.MULTI_TURN: + raise ValueError( + "MULTI_TURN load pattern requires a MultiTurnDataset — " + "use 'inference-endpoint benchmark from-config' with a multi-turn dataset" + ) + case _: raise ValueError(f"Unsupported load pattern type: {lp.type}") diff --git a/src/inference_endpoint/openai/accumulator.py b/src/inference_endpoint/openai/accumulator.py index 6cb23ed8..a01b7b44 100644 --- a/src/inference_endpoint/openai/accumulator.py +++ b/src/inference_endpoint/openai/accumulator.py @@ -15,15 +15,13 @@ """OpenAI SSE stream accumulator implementation.""" -import logging +from typing import Any from inference_endpoint.core.types import QueryResult, StreamChunk, TextModelOutput from inference_endpoint.endpoint_client.accumulator_protocol import ( SSEAccumulatorProtocol, ) -from inference_endpoint.openai.types import SSEDelta as OpenAISSEDelta - -logger = logging.getLogger(__name__) +from inference_endpoint.openai.types import SSEChoice class OpenAISSEAccumulator(SSEAccumulatorProtocol): @@ -32,15 +30,41 @@ class OpenAISSEAccumulator(SSEAccumulatorProtocol): def __init__(self, query_id: str, stream_all_chunks: bool): self.output_chunks: list[str] = [] self.reasoning_chunks: list[str] 
= [] + self._tool_calls: dict[int, dict[str, Any]] = {} + self._finish_reason: str | None = None self.first_chunk_sent = False self.query_id = query_id self.stream_all_chunks = stream_all_chunks - def add_chunk(self, delta: OpenAISSEDelta) -> StreamChunk | None: - if not isinstance(delta, OpenAISSEDelta): + def add_chunk(self, choice: SSEChoice | None) -> StreamChunk | None: + if not isinstance(choice, SSEChoice): + return None + + if choice.finish_reason: + self._finish_reason = choice.finish_reason + + delta = choice.delta + if delta is None: return None + # Accumulate tool_calls partials (streamed as incremental JSON fragments) + if delta.tool_calls: + for partial in delta.tool_calls: + idx = partial.get("index", 0) + tc = self._tool_calls.setdefault( + idx, {"type": "function", "function": {"arguments": ""}} + ) + if partial.get("id"): + tc["id"] = partial["id"] + if partial.get("type"): + tc["type"] = partial["type"] + fn = partial.get("function") or {} + if fn.get("name"): + tc["function"]["name"] = fn["name"] + if fn.get("arguments"): + tc["function"]["arguments"] += fn["arguments"] + content = None if delta.content: self.output_chunks.append(delta.content) @@ -68,9 +92,6 @@ def add_chunk(self, delta: OpenAISSEDelta) -> StreamChunk | None: def get_final_output(self) -> QueryResult: if self.reasoning_chunks: - # If there are reasoning chunks, then the first chunk received - # is the first reasoning chunk. The rest of the reasoning chunks, - # as well as the output chunks can be joined together. resp_reasoning: list[str] = [self.reasoning_chunks[0]] if len(self.reasoning_chunks) > 1: resp_reasoning.append("".join(self.reasoning_chunks[1:])) @@ -79,19 +100,26 @@ def get_final_output(self) -> QueryResult: reasoning=resp_reasoning, ) elif self.output_chunks: - # If there are only output chunks, the first chunk is used for - # TTFT calculations. The rest are joined together. 
resp_output: list[str] = [self.output_chunks[0]] if len(self.output_chunks) > 1: resp_output.append("".join(self.output_chunks[1:])) text_output = TextModelOutput(output=resp_output, reasoning=None) else: text_output = TextModelOutput(output=[], reasoning=None) + + metadata: dict[str, Any] = { + "first_chunk": not self.first_chunk_sent, + "final_chunk": True, + } + if self._finish_reason: + metadata["finish_reason"] = self._finish_reason + if self._tool_calls: + metadata["tool_calls"] = [ + self._tool_calls[i] for i in sorted(self._tool_calls) + ] + return QueryResult( id=self.query_id, response_output=text_output, - metadata={ - "first_chunk": not self.first_chunk_sent, - "final_chunk": True, - }, + metadata=metadata, ) diff --git a/src/inference_endpoint/openai/openai_adapter.py b/src/inference_endpoint/openai/openai_adapter.py index 5834d6b0..a458688c 100644 --- a/src/inference_endpoint/openai/openai_adapter.py +++ b/src/inference_endpoint/openai/openai_adapter.py @@ -14,6 +14,7 @@ # limitations under the License. 
import time +from typing import Any import msgspec from inference_endpoint.core.types import Query, QueryResult, TextModelOutput @@ -36,7 +37,7 @@ Role6, ServiceTier, ) -from .types import SSEMessage +from .types import SSEChoice, SSEMessage class OpenAIAdapter(HttpRequestAdapter): @@ -75,10 +76,12 @@ def decode_response(cls, response_bytes: bytes, query_id: str) -> QueryResult: return cls.from_endpoint_response(openai_response, result_id=query_id) @classmethod - def decode_sse_message(cls, json_bytes: bytes) -> str: - """Decode SSE message and extract content string.""" + def decode_sse_message(cls, json_bytes: bytes) -> SSEChoice | None: + """Decode SSE message and return SSEChoice (delta + finish_reason).""" msg = msgspec.json.decode(json_bytes, type=SSEMessage) - return msg.choices[0].delta + if not msg.choices: + return None + return msg.choices[0] # ======================================================================== # Internal APIs @@ -86,15 +89,21 @@ def decode_sse_message(cls, json_bytes: bytes) -> str: @classmethod def to_endpoint_request(cls, query: Query) -> CreateChatCompletionRequest: - """Convert a Query to an OpenAI request.""" - if "prompt" not in query.data: - raise ValueError("prompt not found in query.data") - - messages = [{"role": Role5.user.value, "content": query.data["prompt"]}] - if "system" in query.data: - messages.insert( - 0, {"role": Role3.system.value, "content": query.data["system"]} - ) + """Convert a Query to an OpenAI request. + + Supports both single-turn (prompt/system) and multi-turn (messages array) formats. 
+ """ + if "messages" in query.data and isinstance(query.data["messages"], list): + messages = query.data["messages"] + else: + if "prompt" not in query.data: + raise ValueError("prompt not found in query.data") + + messages = [{"role": Role5.user.value, "content": query.data["prompt"]}] + if "system" in query.data: + messages.insert( + 0, {"role": Role3.system.value, "content": query.data["system"]} + ) request = CreateChatCompletionRequest( model=ModelIdsShared(query.data.get("model", "no-model-name")), @@ -103,6 +112,7 @@ def to_endpoint_request(cls, query: Query) -> CreateChatCompletionRequest: stream=query.data.get("stream", False), max_completion_tokens=query.data.get("max_completion_tokens", 100), temperature=query.data.get("temperature", 0.7), + tools=query.data.get("tools"), ) return request @@ -119,9 +129,19 @@ def from_endpoint_response( if result_id is None: result_id = response.id + choice = response.choices[0] + metadata: dict[str, Any] = {} + if choice.finish_reason: + metadata["finish_reason"] = choice.finish_reason.value + if choice.message.tool_calls: + metadata["tool_calls"] = [ + tc.model_dump(mode="json") for tc in choice.message.tool_calls + ] + return QueryResult( id=result_id, - response_output=TextModelOutput(output=response.choices[0].message.content), + response_output=TextModelOutput(output=choice.message.content), + metadata=metadata, ) @classmethod @@ -160,11 +180,11 @@ def decode_endpoint_response( response_dict = msgspec.json.decode(response_bytes) # Set default values for optional fields if missing - response_dict["choices"][0]["message"]["refusal"] = "None" + response_dict["choices"][0]["message"]["refusal"] = "" response_dict["choices"][0]["logprobs"] = {"content": [], "refusal": []} if ( "content" not in response_dict["choices"][0]["message"] or response_dict["choices"][0]["message"]["content"] is None ): - response_dict["choices"][0]["message"]["content"] = "None" + response_dict["choices"][0]["message"]["content"] = "" return 
CreateChatCompletionResponse(**response_dict, ignore_extra=True) diff --git a/src/inference_endpoint/openai/openai_msgspec_adapter.py b/src/inference_endpoint/openai/openai_msgspec_adapter.py index 6106e1bd..e512e22b 100644 --- a/src/inference_endpoint/openai/openai_msgspec_adapter.py +++ b/src/inference_endpoint/openai/openai_msgspec_adapter.py @@ -18,6 +18,7 @@ """ import time +from typing import Any import msgspec from inference_endpoint.config.schema import ModelParams, StreamingMode @@ -37,6 +38,7 @@ ChatCompletionResponse, ChatCompletionResponseMessage, ChatMessage, + SSEChoice, SSEMessage, ) @@ -45,6 +47,17 @@ # ============================================================================ +def _chat_message_from_dict(msg: dict) -> "ChatMessage": + """Build a ChatMessage from a dict, forwarding all supported fields.""" + return ChatMessage( + role=msg["role"], + content=msg.get("content"), + name=msg.get("name"), + tool_calls=msg.get("tool_calls"), + tool_call_id=msg.get("tool_call_id"), + ) + + class OpenAIMsgspecAdapter(HttpRequestAdapter): """OpenAI adapter using msgspec for serialization/deserialization.""" @@ -105,10 +118,12 @@ def decode_response(cls, response_bytes: bytes, query_id: str) -> QueryResult: return cls.from_endpoint_response(openai_response, result_id=query_id) @classmethod - def decode_sse_message(cls, json_bytes: bytes) -> str: - """Decode SSE message and extract content string.""" + def decode_sse_message(cls, json_bytes: bytes) -> SSEChoice | None: + """Decode SSE message and return the SSEChoice (delta + finish_reason).""" msg = cls._sse_decoder.decode(json_bytes) - return msg.choices[0].delta + if not msg.choices: + return None + return msg.choices[0] # ======================================================================== # Internal APIs @@ -129,24 +144,31 @@ def to_endpoint_request(cls, query: Query) -> ChatCompletionRequest: Returns: msgspec.Struct ChatCompletionRequest """ - if "prompt" not in query.data: - raise 
ValueError("prompt not found in query.data") - - messages = [ - ChatMessage( - role="user", - content=query.data["prompt"], - name=query.data.get("name"), - ), - ] - if "system" in query.data: - messages.insert( - 0, + if "messages" in query.data and isinstance(query.data["messages"], list): + messages = [] + for message in query.data["messages"]: + if not isinstance(message, dict): + raise ValueError("messages entries must be dicts") + messages.append(_chat_message_from_dict(message)) + else: + if "prompt" not in query.data: + raise ValueError("prompt not found in query.data") + + messages = [ ChatMessage( - role="system", - content=query.data["system"], + role="user", + content=query.data["prompt"], + name=query.data.get("name"), ), - ) + ] + if "system" in query.data: + messages.insert( + 0, + ChatMessage( + role="system", + content=query.data["system"], + ), + ) return ChatCompletionRequest( model=query.data.get("model", "no-model-name"), @@ -164,6 +186,7 @@ def to_endpoint_request(cls, query: Query) -> ChatCompletionRequest: logit_bias=query.data.get("logit_bias"), user=query.data.get("user"), chat_template=query.data.get("chat_template"), + tools=query.data.get("tools"), ) @classmethod @@ -184,9 +207,19 @@ def from_endpoint_response( if not response.choices: raise ValueError("Response must contain at least one choice") + choice = response.choices[0] + metadata: dict[str, Any] = {} + if choice.finish_reason: + metadata["finish_reason"] = choice.finish_reason + if choice.message.tool_calls: + metadata["tool_calls"] = choice.message.tool_calls + if choice.message.reasoning_content: + metadata["reasoning_content"] = choice.message.reasoning_content + return QueryResult( id=result_id or response.id, - response_output=TextModelOutput(output=response.choices[0].message.content), + response_output=TextModelOutput(output=choice.message.content or ""), + metadata=metadata, ) @classmethod diff --git a/src/inference_endpoint/openai/types.py 
b/src/inference_endpoint/openai/types.py index 036dd172..5296b2ee 100644 --- a/src/inference_endpoint/openai/types.py +++ b/src/inference_endpoint/openai/types.py @@ -46,6 +46,7 @@ class SSEDelta(msgspec.Struct, frozen=True, kw_only=True, omit_defaults=True, gc content: str = "" reasoning: str = "" + tool_calls: list[dict[str, Any]] | None = None class SSEChoice( @@ -75,12 +76,17 @@ class ChatMessage( ): # type: ignore[call-arg] """Chat message in OpenAI format. - content: str for text-only messages; list[dict] for multimodal (vision). + content: str for text-only messages; list[dict] for multimodal (vision); + None for tool-dispatching assistant messages. + tool_calls: list of tool call objects for assistant messages that invoke tools. + tool_call_id: correlates a tool result message to its tool call. """ role: str - content: ChatMessageContent + content: ChatMessageContent | None = None name: str | None = None + tool_calls: list[dict[str, Any]] | None = None + tool_call_id: str | None = None class ChatCompletionRequest( @@ -103,6 +109,7 @@ class ChatCompletionRequest( logit_bias: dict[str, float] | None = None user: str | None = None chat_template: str | None = None + tools: list[dict[str, Any]] | None = None class ChatCompletionResponseMessage( @@ -113,6 +120,8 @@ class ChatCompletionResponseMessage( role: str content: str | None refusal: str | None + tool_calls: list[dict[str, Any]] | None = None + reasoning_content: str | None = None class ChatCompletionChoice( diff --git a/tests/integration/test_multi_turn.py b/tests/integration/test_multi_turn.py new file mode 100644 index 00000000..cfe8a68c --- /dev/null +++ b/tests/integration/test_multi_turn.py @@ -0,0 +1,684 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration tests for multi-turn benchmarking end-to-end. + +Validates that MultiTurnDataset + MultiTurnStrategy + BenchmarkSession work +correctly together against a real HTTP echo server. + +Tests cover: + 1. Dataset-history mode (use_dataset_history=True): pre-built messages are + issued as-is; each turn is issued sequentially per conversation. + 2. Live-history mode (use_dataset_history=False): messages are built at + runtime from ConversationManager.message_history; the injected messages + grow with each turn. + 3. Multiple concurrent conversations complete successfully. + 4. Turn ordering: turn N+1 is never issued before turn N completes. 
+""" + +import asyncio +import random +import time +from urllib.parse import urljoin + +import pandas as pd +import pytest +from inference_endpoint import metrics +from inference_endpoint.config.runtime_settings import RuntimeSettings +from inference_endpoint.config.schema import ( + LoadPattern, + LoadPatternType, + MultiTurnConfig, +) +from inference_endpoint.core.record import EventRecord +from inference_endpoint.core.types import QueryResult +from inference_endpoint.dataset_manager.multi_turn_dataset import MultiTurnDataset +from inference_endpoint.endpoint_client.config import HTTPClientConfig +from inference_endpoint.endpoint_client.http_client import HTTPEndpointClient +from inference_endpoint.endpoint_client.http_sample_issuer import HttpClientSampleIssuer +from inference_endpoint.load_generator.conversation_manager import ConversationManager +from inference_endpoint.load_generator.multi_turn_strategy import MultiTurnStrategy +from inference_endpoint.load_generator.session import ( + BenchmarkSession, + PhaseConfig, + PhaseType, +) +from inference_endpoint.testing.echo_server import EchoServer + + +class _NoOpPublisher: + def publish(self, event_record: EventRecord) -> None: + pass + + def flush(self) -> None: + pass + + +def _make_dataset(rows: list[dict]) -> MultiTurnDataset: + """Build a loaded MultiTurnDataset from a list of row dicts.""" + df = pd.DataFrame(rows) + ds = MultiTurnDataset(dataframe=df) + ds.load() + return ds + + +def _make_strategy( + ds: MultiTurnDataset, + use_dataset_history: bool = True, +) -> MultiTurnStrategy: + mt_cfg = MultiTurnConfig( + turn_timeout_s=10.0, + use_dataset_history=use_dataset_history, + ) + return MultiTurnStrategy( + conversation_manager=ConversationManager(), + dataset_metadata=ds.conversation_metadata, + multi_turn_config=mt_cfg, + ) + + +async def _run_session( + server_url: str, + ds: MultiTurnDataset, + strategy: MultiTurnStrategy, + responses_out: dict, +) -> int: + """Wire up HTTPEndpointClient + 
BenchmarkSession and run one phase. + + Populates responses_out[query_id] = response_text for every completed turn. + Returns issued_count. + """ + loop = asyncio.get_running_loop() + + def on_complete(result: QueryResult) -> None: + strategy.on_sample_complete(result) + responses_out[result.id] = result.get_response_output_string() + + http_config = HTTPClientConfig( + endpoint_urls=[urljoin(server_url, "/v1/chat/completions")], + warmup_connections=0, + num_workers=2, + ) + http_client = await HTTPEndpointClient.create(http_config, loop) + issuer = HttpClientSampleIssuer(http_client) + + try: + session = BenchmarkSession( + issuer=issuer, + event_publisher=_NoOpPublisher(), + loop=loop, + on_sample_complete=on_complete, + ) + rt = RuntimeSettings( + metrics.Throughput(1000), + [metrics.Throughput(1000)], + min_duration_ms=0, + max_duration_ms=30_000, + n_samples_from_dataset=ds.num_samples(), + n_samples_to_issue=ds.num_samples(), + min_sample_count=1, + rng_sched=random.Random(42), + rng_sample_index=random.Random(42), + load_pattern=LoadPattern(type=LoadPatternType.MAX_THROUGHPUT), + ) + phase = PhaseConfig( + "perf", + rt, + ds, + PhaseType.PERFORMANCE, + strategy=strategy, + ) + result = await asyncio.wait_for(session.run([phase]), timeout=30.0) + return result.perf_results[0].issued_count + finally: + await http_client.shutdown_async() + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def echo_server(): + server = EchoServer(port=0) + server.start() + try: + yield server + finally: + server.stop() + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_single_conversation_all_turns_issued(echo_server): + """All turns of a single 
conversation are issued and completed."""
+    rows = [
+        {"conversation_id": "c1", "turn": 1, "role": "user", "content": "Hello"},
+        {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "Hi"},
+        {"conversation_id": "c1", "turn": 3, "role": "user", "content": "Bye"},
+    ]
+    ds = _make_dataset(rows)
+    strategy = _make_strategy(ds)
+    responses: dict = {}
+
+    count = await _run_session(echo_server.url, ds, strategy, responses)
+
+    # Two user turns (turns 1 and 3); turn 2 is assistant so not a client turn
+    assert count == 2
+    assert len(responses) == 2
+
+
+@pytest.mark.integration
+@pytest.mark.asyncio
+async def test_multiple_conversations_all_issued(echo_server):
+    """Multiple conversations complete independently and concurrently."""
+    rows = []
+    for conv_idx in range(3):
+        conv_id = f"conv_{conv_idx}"
+        rows.append(
+            {
+                "conversation_id": conv_id,
+                "turn": 1,
+                "role": "user",
+                "content": f"Q1 {conv_idx}",
+            }
+        )
+        rows.append(
+            {
+                "conversation_id": conv_id,
+                "turn": 2,
+                "role": "assistant",
+                "content": f"A1 {conv_idx}",
+            }
+        )
+        rows.append(
+            {
+                "conversation_id": conv_id,
+                "turn": 3,
+                "role": "user",
+                "content": f"Q2 {conv_idx}",
+            }
+        )
+    ds = _make_dataset(rows)
+    strategy = _make_strategy(ds)
+    responses: dict = {}
+
+    count = await _run_session(echo_server.url, ds, strategy, responses)
+
+    # 3 conversations × 2 user turns each = 6
+    assert count == 6
+    assert len(responses) == 6
+
+
+@pytest.mark.integration
+@pytest.mark.asyncio
+async def test_dataset_history_messages_present(echo_server):
+    """Dataset-history mode: each request contains the messages array from the dataset."""
+    received_payloads: list[dict] = []
+
+    # EchoServer.get_response only sees the first user message's content,
+    # so it cannot be used to inspect the full request body.
+ # Instead, use a custom echo server that logs the full payload. + class CapturingEchoServer(EchoServer): + async def _handle_echo_chat_completions_request(self, request): + try: + payload = await request.json() + received_payloads.append(payload) + except Exception: + pass + return await super()._handle_echo_chat_completions_request(request) + + server = CapturingEchoServer(port=0) + server.start() + try: + rows = [ + { + "conversation_id": "c1", + "turn": 1, + "role": "user", + "content": "First question", + }, + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant", + "content": "First answer", + }, + { + "conversation_id": "c1", + "turn": 3, + "role": "user", + "content": "Second question", + }, + ] + ds = _make_dataset(rows) + strategy = _make_strategy(ds, use_dataset_history=True) + responses: dict = {} + + count = await _run_session(server.url, ds, strategy, responses) + assert count == 2 + + # Both requests must include a "messages" array + assert len(received_payloads) == 2 + for payload in received_payloads: + assert "messages" in payload + assert len(payload["messages"]) >= 1 + + # Turn 1 should have 1 user message; turn 3 should have 3 messages + # (system? 
no system here — user, assistant, user) + msg_counts = sorted(len(p["messages"]) for p in received_payloads) + assert msg_counts == [1, 3] + finally: + server.stop() + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_live_history_messages_grow_each_turn(echo_server): + """Live-history mode: messages array grows with each completed turn.""" + received_payloads: list[dict] = [] + + class CapturingEchoServer(EchoServer): + async def _handle_echo_chat_completions_request(self, request): + try: + payload = await request.json() + received_payloads.append(payload) + except Exception: + pass + return await super()._handle_echo_chat_completions_request(request) + + server = CapturingEchoServer(port=0) + server.start() + try: + rows = [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "Turn one"}, + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant", + "content": "Answer one", + }, + {"conversation_id": "c1", "turn": 3, "role": "user", "content": "Turn two"}, + ] + ds = _make_dataset(rows) + strategy = _make_strategy(ds, use_dataset_history=False) + responses: dict = {} + + count = await _run_session(server.url, ds, strategy, responses) + assert count == 2 + + assert len(received_payloads) == 2 + msg_counts = sorted(len(p["messages"]) for p in received_payloads) + # Turn 1: [user msg] = 1; Turn 3: [user, assistant, user] = 3 + assert msg_counts == [1, 3] + finally: + server.stop() + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_turn_ordering_enforced_end_to_end(echo_server): + """Turn N+1 is issued after Turn N's response arrives, verified by timestamps.""" + complete_times: dict[str, float] = {} + + rows = [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "First"}, + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant", + "content": "Response", + }, + {"conversation_id": "c1", "turn": 3, "role": "user", "content": "Second"}, + ] + ds = _make_dataset(rows) + mt_cfg = 
MultiTurnConfig(turn_timeout_s=10.0, use_dataset_history=True)
+    conv_manager = ConversationManager()
+    strategy = MultiTurnStrategy(
+        conversation_manager=conv_manager,
+        dataset_metadata=ds.conversation_metadata,
+        multi_turn_config=mt_cfg,
+    )
+
+    # Wrap on_sample_complete to record completion timestamps
+    orig_on_sample_complete = strategy.on_sample_complete
+
+    def tracked_on_sample_complete(result: QueryResult) -> None:
+        # Record completion time keyed by query_id; correlated with
+        # sample_index via uuid_to_index after the session runs.
+        complete_times[result.id] = time.monotonic()
+        orig_on_sample_complete(result)
+
+    strategy.on_sample_complete = tracked_on_sample_complete
+
+    loop = asyncio.get_running_loop()
+    responses: dict[str, str] = {}
+
+    http_config = HTTPClientConfig(
+        endpoint_urls=[urljoin(echo_server.url, "/v1/chat/completions")],
+        warmup_connections=0,
+        num_workers=1,
+    )
+    http_client = await HTTPEndpointClient.create(http_config, loop)
+    issuer = HttpClientSampleIssuer(http_client)
+
+    rt = RuntimeSettings(
+        metrics.Throughput(1000),
+        [metrics.Throughput(1000)],
+        min_duration_ms=0,
+        max_duration_ms=30_000,
+        n_samples_from_dataset=ds.num_samples(),
+        n_samples_to_issue=ds.num_samples(),
+        min_sample_count=1,
+        rng_sched=random.Random(42),
+        rng_sample_index=random.Random(42),
+        load_pattern=LoadPattern(type=LoadPatternType.MAX_THROUGHPUT),
+    )
+
+    try:
+
+        def on_complete(result: QueryResult) -> None:
+            tracked_on_sample_complete(result)
+            responses[result.id] = result.get_response_output_string()
+
+        session = BenchmarkSession(
+            issuer=issuer,
+            event_publisher=_NoOpPublisher(),
+            loop=loop,
+            on_sample_complete=on_complete,
+        )
+        phase = PhaseConfig("perf", rt, ds, PhaseType.PERFORMANCE, strategy=strategy)
+        result = await asyncio.wait_for(session.run([phase]), timeout=30.0)
+    finally:
+        await http_client.shutdown_async()
+
+    assert result.perf_results[0].issued_count == 2
+
+    # Build query_id → sample_index from session result
+    uuid_to_index = 
result.perf_results[0].uuid_to_index + index_to_query = {v: k for k, v in uuid_to_index.items()} + + # Sample 0 = turn 1, sample 1 = turn 3 + q_turn1 = index_to_query[0] + q_turn3 = index_to_query[1] + + # Turn 3 must complete after turn 1 completes + assert complete_times[q_turn3] >= complete_times[q_turn1] + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_tool_use_conversation_all_turns_issued(echo_server): + """Tool-use conversation: all client turns (user + tool) are issued and completed.""" + tool_calls = [ + { + "id": "call_1", + "type": "function", + "function": {"name": "search", "arguments": '{"q": "test"}'}, + } + ] + tool_results = [{"tool_call_id": "call_1", "content": "search result"}] + tools = [ + { + "type": "function", + "function": { + "name": "search", + "description": "Search", + "parameters": { + "type": "object", + "properties": {"q": {"type": "string"}}, + "required": ["q"], + }, + }, + } + ] + + rows = [ + { + "conversation_id": "c1", + "turn": 1, + "role": "user", + "content": "Find something", + "tools": tools, + }, + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant", + "content": None, + "tool_calls": tool_calls, + }, + { + "conversation_id": "c1", + "turn": 3, + "role": "tool", + "tool_results": tool_results, + "tools": tools, + }, + { + "conversation_id": "c1", + "turn": 4, + "role": "assistant", + "content": "Here is the result", + }, + {"conversation_id": "c1", "turn": 5, "role": "user", "content": "Thanks"}, + ] + ds = _make_dataset(rows) + strategy = _make_strategy(ds) + responses: dict = {} + + count = await _run_session(echo_server.url, ds, strategy, responses) + + # Client turns: turn 1 (user) + turn 3 (tool) + turn 5 (user) = 3 + assert count == 3 + assert len(responses) == 3 + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_conversation_ending_with_tool_row(echo_server): + """Conversation ending with a tool row completes normally (matches agentic_coding dataset pattern).""" + 
tool_calls = [ + { + "id": "call_1", + "type": "function", + "function": {"name": "write_file", "arguments": '{"path": "out.py"}'}, + } + ] + tool_results = [{"tool_call_id": "call_1", "content": "file written"}] + tools = [ + { + "type": "function", + "function": { + "name": "write_file", + "description": "Write a file", + "parameters": { + "type": "object", + "properties": {"path": {"type": "string"}}, + "required": ["path"], + }, + }, + } + ] + + rows = [ + { + "conversation_id": "c1", + "turn": 1, + "role": "user", + "content": "Write a file", + "tools": tools, + }, + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant", + "content": None, + "tool_calls": tool_calls, + }, + { + "conversation_id": "c1", + "turn": 3, + "role": "tool", + "tool_results": tool_results, + "tools": tools, + }, + ] + ds = _make_dataset(rows) + strategy = _make_strategy(ds) + responses: dict = {} + + count = await _run_session(echo_server.url, ds, strategy, responses) + + # Client turns: turn 1 (user) + turn 3 (tool) = 2 + assert count == 2 + assert len(responses) == 2 + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_concurrent_conversations_stress(echo_server): + """12 conversations × 3 turns each complete with correct counts.""" + num_convs = 12 + turns_per_conv = 3 # 2 user turns + 1 assistant turn each + rows = [] + for i in range(num_convs): + conv_id = f"stress_conv_{i}" + rows.append( + { + "conversation_id": conv_id, + "turn": 1, + "role": "user", + "content": f"Q1-{i}", + } + ) + rows.append( + { + "conversation_id": conv_id, + "turn": 2, + "role": "assistant", + "content": f"A1-{i}", + } + ) + rows.append( + { + "conversation_id": conv_id, + "turn": 3, + "role": "user", + "content": f"Q2-{i}", + } + ) + + ds = _make_dataset(rows) + strategy = _make_strategy(ds) + responses: dict = {} + + count = await _run_session(echo_server.url, ds, strategy, responses) + + # 12 conversations × 2 client turns each = 24 + expected_client_turns = num_convs * 
(turns_per_conv - 1) # 24 + assert count == expected_client_turns + assert len(responses) == expected_client_turns + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_tools_field_forwarded_to_endpoint(echo_server): + """The 'tools' array from the dataset reaches the endpoint in every request payload. + + TODO: Add a tool-call-aware server that returns dynamic tool_call_ids to + validate live-history mode with real tool_call_id round-tripping. + """ + received_payloads: list[dict] = [] + + class CapturingEchoServer(EchoServer): + async def _handle_echo_chat_completions_request(self, request): + try: + payload = await request.json() + received_payloads.append(payload) + except Exception: + pass + return await super()._handle_echo_chat_completions_request(request) + + server = CapturingEchoServer(port=0) + server.start() + try: + tool_calls = [ + { + "id": "call_1", + "type": "function", + "function": {"name": "search", "arguments": '{"q": "hello"}'}, + } + ] + tool_results = [{"tool_call_id": "call_1", "content": "result"}] + tools = [ + { + "type": "function", + "function": { + "name": "search", + "description": "Search", + "parameters": { + "type": "object", + "properties": {"q": {"type": "string"}}, + "required": ["q"], + }, + }, + } + ] + + rows = [ + { + "conversation_id": "c1", + "turn": 1, + "role": "user", + "content": "Search for hello", + "tools": tools, + }, + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant", + "content": None, + "tool_calls": tool_calls, + }, + { + "conversation_id": "c1", + "turn": 3, + "role": "tool", + "tool_results": tool_results, + "tools": tools, + }, + ] + ds = _make_dataset(rows) + strategy = _make_strategy(ds, use_dataset_history=True) + responses: dict = {} + + count = await _run_session(server.url, ds, strategy, responses) + assert count == 2 + + assert len(received_payloads) == 2 + for payload in received_payloads: + assert "tools" in payload + assert len(payload["tools"]) == 1 + assert 
payload["tools"][0]["function"]["name"] == "search" + finally: + server.stop() diff --git a/tests/unit/config/test_schema.py b/tests/unit/config/test_schema.py index 73e3363f..64987b57 100644 --- a/tests/unit/config/test_schema.py +++ b/tests/unit/config/test_schema.py @@ -450,3 +450,123 @@ def test_explicit_adapter_override_at_construction_survives(self): client = HTTPClientConfig(api_type=APIType.OPENAI, adapter=OpenAIAdapter) assert client.adapter is OpenAIAdapter assert client.adapter is not OpenAIMsgspecAdapter + + +class TestMultiTurnValidation: + """Tests for multi-turn config validation and cross-validation.""" + + def _make_online_multi_turn(self, concurrency: int | None = 4, **ds_kwargs): + lp: dict = {"type": "multi_turn"} + if concurrency is not None: + lp["target_concurrency"] = concurrency + return { + "type": TestType.ONLINE, + "model_params": {"name": "M"}, + "endpoint_config": {"endpoints": ["http://x"]}, + "datasets": [{"path": "D", "multi_turn": {}, **ds_kwargs}], + "settings": {"load_pattern": lp}, + } + + @pytest.mark.unit + def test_multi_turn_valid_config(self): + config = BenchmarkConfig(**self._make_online_multi_turn(concurrency=16)) + from inference_endpoint.config.schema import LoadPatternType + + assert config.settings.load_pattern.type == LoadPatternType.MULTI_TURN + assert config.settings.load_pattern.target_concurrency == 16 + + @pytest.mark.unit + def test_multi_turn_requires_target_concurrency(self): + with pytest.raises(ValueError, match="Multi-turn requires --concurrency"): + BenchmarkConfig(**self._make_online_multi_turn(concurrency=None)) + + @pytest.mark.unit + def test_multi_turn_without_multi_turn_dataset_rejected(self): + with pytest.raises(ValueError, match="requires at least one dataset"): + BenchmarkConfig( + type=TestType.ONLINE, + model_params={"name": "M"}, + endpoint_config={"endpoints": ["http://x"]}, + datasets=[{"path": "D"}], + settings={ + "load_pattern": {"type": "multi_turn", "target_concurrency": 4} + }, + ) + 
+ @pytest.mark.unit + def test_multi_turn_dataset_without_multi_turn_load_pattern_rejected(self): + with pytest.raises(ValueError, match="require load_pattern.type=multi_turn"): + BenchmarkConfig( + type=TestType.ONLINE, + model_params={"name": "M"}, + endpoint_config={"endpoints": ["http://x"]}, + datasets=[{"path": "D", "multi_turn": {}}], + settings={"load_pattern": {"type": "poisson", "target_qps": 10}}, + ) + + +class TestMultiTurnTotalSamples: + """Tests for total_samples_to_issue() with multi_turn load pattern.""" + + @pytest.mark.unit + def test_multi_turn_uses_dataset_size_ignoring_duration(self): + from inference_endpoint.config.runtime_settings import RuntimeSettings + + config = BenchmarkConfig( + type=TestType.ONLINE, + model_params={"name": "M"}, + endpoint_config={"endpoints": ["http://x"]}, + datasets=[{"path": "D", "multi_turn": {}}], + settings={ + "load_pattern": {"type": "multi_turn", "target_concurrency": 4}, + "runtime": {"min_duration_ms": 600000}, + }, + ) + rt = RuntimeSettings.from_config(config, dataloader_num_samples=4316) + assert rt.total_samples_to_issue() == 4316 + + @pytest.mark.unit + def test_multi_turn_respects_min_sample_count(self): + import random + + from inference_endpoint import metrics + from inference_endpoint.config.runtime_settings import RuntimeSettings + from inference_endpoint.config.schema import LoadPattern, LoadPatternType + + lp = LoadPattern(type=LoadPatternType.MULTI_TURN, target_concurrency=4) + rt = RuntimeSettings( + metric_target=metrics.Throughput(10.0), + reported_metrics=[metrics.Throughput(10.0)], + min_duration_ms=600000, + max_duration_ms=None, + n_samples_from_dataset=5, + n_samples_to_issue=None, + min_sample_count=100, + rng_sched=random.Random(0), + rng_sample_index=random.Random(0), + load_pattern=lp, + ) + assert rt.total_samples_to_issue() == 100 + + @pytest.mark.unit + def test_multi_turn_explicit_n_samples_takes_precedence(self): + import random + + from inference_endpoint import metrics + 
from inference_endpoint.config.runtime_settings import RuntimeSettings + from inference_endpoint.config.schema import LoadPattern, LoadPatternType + + lp = LoadPattern(type=LoadPatternType.MULTI_TURN, target_concurrency=4) + rt = RuntimeSettings( + metric_target=metrics.Throughput(10.0), + reported_metrics=[metrics.Throughput(10.0)], + min_duration_ms=600000, + max_duration_ms=None, + n_samples_from_dataset=4316, + n_samples_to_issue=200, + min_sample_count=1, + rng_sched=random.Random(0), + rng_sample_index=random.Random(0), + load_pattern=lp, + ) + assert rt.total_samples_to_issue() == 200 diff --git a/tests/unit/core/test_types.py b/tests/unit/core/test_types.py index 52bdbe77..9c3bec15 100644 --- a/tests/unit/core/test_types.py +++ b/tests/unit/core/test_types.py @@ -891,3 +891,57 @@ def test_numeric_types_in_metadata(self): assert decoded.metadata["large_int"] == 9999999999999999 assert decoded.metadata["negative"] == -123.456 assert decoded.metadata["zero"] == 0 + + +@pytest.mark.unit +class TestQueryResultWithMetadata: + """Test QueryResult.with_metadata() method for metadata merging.""" + + def test_with_metadata_merge_behavior(self): + """Test that with_metadata adds new keys and overwrites existing ones.""" + result = QueryResult( + id="test", + response_output=TextModelOutput(output="hello"), + metadata={"key1": "old_value", "key2": "keep_me"}, + ) + + updated = result.with_metadata({"key1": "new_value", "key3": "added"}) + + assert updated.metadata == { + "key1": "new_value", + "key2": "keep_me", + "key3": "added", + } + assert updated.id == "test" + assert updated.response_output == TextModelOutput(output="hello") + + def test_with_metadata_none_returns_self(self): + """Test that with_metadata(None) returns self unchanged.""" + result = QueryResult( + id="test", + response_output=TextModelOutput(output="hello"), + metadata={"key1": "value"}, + ) + assert result.with_metadata(None) is result + + def test_with_metadata_empty_returns_self(self): + """Test 
that with_metadata({}) returns self unchanged.""" + result = QueryResult( + id="test", + response_output=TextModelOutput(output="hello"), + metadata={"key1": "value"}, + ) + assert result.with_metadata({}) is result + + def test_query_metadata_field_roundtrips(self): + """Test that Query.metadata round-trips through msgspec encoding.""" + query = Query( + data={"prompt": "Hello"}, + metadata={"conversation_id": "conv-1", "turn": 2}, + ) + + encoded = msgspec.json.encode(query) + decoded = msgspec.json.decode(encoded, type=Query) + + assert decoded.metadata["conversation_id"] == "conv-1" + assert decoded.metadata["turn"] == 2 diff --git a/tests/unit/dataset_manager/test_multi_turn_dataset.py b/tests/unit/dataset_manager/test_multi_turn_dataset.py new file mode 100644 index 00000000..62301940 --- /dev/null +++ b/tests/unit/dataset_manager/test_multi_turn_dataset.py @@ -0,0 +1,1425 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json +import tempfile +from collections.abc import Generator +from pathlib import Path + +import pandas as pd +import pytest +from inference_endpoint.dataset_manager.dataset import DatasetFormat +from inference_endpoint.dataset_manager.multi_turn_dataset import MultiTurnDataset + + +@pytest.fixture +def valid_multi_turn_jsonl() -> Generator[str, None, None]: + """Create valid multi-turn conversation JSONL data.""" + data = [ + { + "conversation_id": "conv_001", + "turn": 1, + "role": "user", + "content": "Hello, how are you?", + "system": "You are a helpful assistant", + }, + { + "conversation_id": "conv_001", + "turn": 2, + "role": "assistant", + "content": "I'm doing well, thank you!", + }, + { + "conversation_id": "conv_001", + "turn": 3, + "role": "user", + "content": "What can you help me with?", + }, + { + "conversation_id": "conv_002", + "turn": 1, + "role": "user", + "content": "What's the weather?", + }, + { + "conversation_id": "conv_002", + "turn": 2, + "role": "assistant", + "content": "I don't have access to real-time weather data.", + }, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + yield temp_path + Path(temp_path).unlink() + + +@pytest.fixture +def invalid_role_sequence_jsonl() -> Generator[str, None, None]: + """Create JSONL with invalid role sequence (not alternating).""" + data = [ + {"conversation_id": "conv_001", "turn": 1, "role": "user", "content": "Hello"}, + { + "conversation_id": "conv_001", + "turn": 2, + "role": "user", + "content": "Another user message", + }, # Invalid - consecutive user + { + "conversation_id": "conv_001", + "turn": 3, + "role": "assistant", + "content": "Response", + }, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + yield temp_path + Path(temp_path).unlink() + + 
+@pytest.fixture +def missing_fields_jsonl() -> Generator[str, None, None]: + """Create JSONL with missing required fields.""" + data = [ + {"conversation_id": "conv_001", "turn": 1, "role": "user"}, # Missing content + { + "conversation_id": "conv_001", + "turn": 2, + "role": "assistant", + "content": "Response", + }, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + yield temp_path + Path(temp_path).unlink() + + +@pytest.mark.unit +def test_multi_turn_dataset_load_valid_data(valid_multi_turn_jsonl): + """Test loading valid multi-turn conversation data.""" + dataset = MultiTurnDataset.load_from_file( + valid_multi_turn_jsonl, format=DatasetFormat.JSONL + ) + dataset.load() + + # data contains only client turns (3 user turns), not all rows + assert len(dataset.data) == 3 + + # Should have 3 user turns (samples) - only user turns are indexed + assert dataset.num_samples() == 3 + + +@pytest.mark.unit +def test_multi_turn_dataset_user_turn_indexing(valid_multi_turn_jsonl): + """Test that only client turns (user + tool) are stored as samples.""" + dataset = MultiTurnDataset.load_from_file( + valid_multi_turn_jsonl, format=DatasetFormat.JSONL + ) + dataset.load() + + # data contains only client turns (fixture has only user turns) + assert dataset.num_samples() == 3 + + # Every sample in data is a client turn + for i in range(dataset.num_samples()): + assert dataset.load_sample(i)["role"] in ("user", "tool") + + +@pytest.mark.unit +def test_multi_turn_dataset_load_sample(valid_multi_turn_jsonl): + """Test load_sample returns correct user turns with dense indexing.""" + dataset = MultiTurnDataset.load_from_file( + valid_multi_turn_jsonl, format=DatasetFormat.JSONL + ) + dataset.load() + + # Sample 0 should be first user turn + sample_0 = dataset.load_sample(0) + assert sample_0["conversation_id"] == "conv_001" + assert sample_0["turn"] == 1 + assert 
sample_0["role"] == "user" + assert sample_0["content"] == "Hello, how are you?" + # System prompt is the first message in the messages array + assert sample_0["messages"][0]["role"] == "system" + assert sample_0["messages"][0]["content"] == "You are a helpful assistant" + + # Sample 1 should be second user turn (conv_001 turn 3) + sample_1 = dataset.load_sample(1) + assert sample_1["conversation_id"] == "conv_001" + assert sample_1["turn"] == 3 + assert sample_1["role"] == "user" + assert sample_1["content"] == "What can you help me with?" + + # Sample 2 should be third user turn (conv_002 turn 1) + sample_2 = dataset.load_sample(2) + assert sample_2["conversation_id"] == "conv_002" + assert sample_2["turn"] == 1 + assert sample_2["role"] == "user" + assert sample_2["content"] == "What's the weather?" + + +@pytest.mark.unit +def test_multi_turn_dataset_conversation_metadata(valid_multi_turn_jsonl): + """Test conversation metadata generation.""" + dataset = MultiTurnDataset.load_from_file( + valid_multi_turn_jsonl, format=DatasetFormat.JSONL + ) + dataset.load() + + metadata = dataset.conversation_metadata + + # Check metadata structure + assert "samples" in metadata + assert "num_conversations" in metadata + assert "max_turns_per_conv" in metadata + assert "client_turns_per_conversation" in metadata + + # Should have 3 client turn samples (fixture has only user turns, no tool turns) + assert len(metadata["samples"]) == 3 + + # Should have 2 conversations + assert metadata["num_conversations"] == 2 + + # Max turns per conversation should be 3 (conv_001 has 3 turns) + assert metadata["max_turns_per_conv"] == 3 + + # Check sample metadata structure + sample_meta = metadata["samples"][0] + assert "conversation_id" in sample_meta + assert "turn" in sample_meta + + +@pytest.mark.unit +def test_multi_turn_dataset_validation_invalid_role_sequence( + invalid_role_sequence_jsonl, +): + """Test validation rejects invalid role sequences.""" + # Validation happens during 
load_from_file (in __init__), not during load()
+    with pytest.raises(ValueError, match="invalid role sequence"):
+        MultiTurnDataset.load_from_file(
+            invalid_role_sequence_jsonl, format=DatasetFormat.JSONL
+        )
+
+
+@pytest.mark.unit
+def test_multi_turn_dataset_validation_missing_fields(missing_fields_jsonl):
+    """Missing content field is omitted from the loaded sample dict."""
+    dataset = MultiTurnDataset.load_from_file(
+        missing_fields_jsonl, format=DatasetFormat.JSONL
+    )
+    dataset.load()
+
+    sample = dataset.load_sample(0)
+    # Missing content is no longer propagated to the sample dict
+    assert "content" not in sample
+
+
+@pytest.mark.unit
+def test_multi_turn_dataset_multiple_conversations():
+    """Test dataset with multiple conversations of varying lengths."""
+    data = [
+        # Conversation 1: 3 turns (user-assistant-user, missing final assistant)
+        {"conversation_id": "c1", "turn": 1, "role": "user", "content": "msg1"},
+        {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "resp1"},
+        {"conversation_id": "c1", "turn": 3, "role": "user", "content": "msg1b"},
+        # Conversation 2: 4 turns (complete user-assistant alternation)
+        {"conversation_id": "c2", "turn": 1, "role": "user", "content": "msg2"},
+        {"conversation_id": "c2", "turn": 2, "role": "assistant", "content": "resp2"},
+        {"conversation_id": "c2", "turn": 3, "role": "user", "content": "msg3"},
+        {"conversation_id": "c2", "turn": 4, "role": "assistant", "content": "resp3"},
+        # Conversation 3: 2 turns (complete user-assistant)
+        {"conversation_id": "c3", "turn": 1, "role": "user", "content": "msg4"},
+        {"conversation_id": "c3", "turn": 2, "role": "assistant", "content": "resp4"},
+    ]
+
+    with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
+        for item in data:
+            f.write(json.dumps(item) + "\n")
+        temp_path = f.name
+
+    try:
+        dataset = MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL)
+        dataset.load()
+
+        # data contains only client turns: 5 
user turns (c1:t1, c1:t3, c2:t1, c2:t3, c3:t1) + assert len(dataset.data) == 5 + assert dataset.num_samples() == 5 + + # Metadata checks + metadata = dataset.conversation_metadata + assert metadata["num_conversations"] == 3 + assert metadata["max_turns_per_conv"] == 4 # c2 has 4 turns + + # Verify user turns are correctly indexed + samples = [dataset.load_sample(i) for i in range(5)] + + # Check we got all the user turns + user_turns = [(s["conversation_id"], s["turn"]) for s in samples] + expected_turns = [("c1", 1), ("c1", 3), ("c2", 1), ("c2", 3), ("c3", 1)] + assert sorted(user_turns) == sorted(expected_turns) + + finally: + Path(temp_path).unlink() + + +@pytest.mark.unit +def test_multi_turn_dataset_system_prompt_handling(valid_multi_turn_jsonl): + """Test system prompt is included as the first message in the messages array. + + The system prompt is pre-baked into every client turn's message list so the + conversation manager no longer needs to track it separately. + """ + dataset = MultiTurnDataset.load_from_file( + valid_multi_turn_jsonl, format=DatasetFormat.JSONL + ) + dataset.load() + + # First sample: messages starts with system message + sample_0 = dataset.load_sample(0) + assert "messages" in sample_0 + msgs = sample_0["messages"] + assert msgs[0]["role"] == "system" + assert msgs[0]["content"] == "You are a helpful assistant" + + # Second sample (same conversation, turn 3): system message still first + sample_1 = dataset.load_sample(1) + msgs_1 = sample_1["messages"] + assert msgs_1[0]["role"] == "system" + assert msgs_1[0]["content"] == "You are a helpful assistant" + + +@pytest.mark.unit +def test_multi_turn_dataset_single_turn_conversations(): + """Test conversations with only one turn.""" + data = [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "Single turn"}, + # No assistant response + { + "conversation_id": "c2", + "turn": 1, + "role": "user", + "content": "Another single", + }, + ] + + with 
tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + try: + dataset = MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL) + dataset.load() + + # 2 rows, 2 user turns + assert len(dataset.data) == 2 + assert dataset.num_samples() == 2 + + # Both samples should be user turns + assert dataset.load_sample(0)["role"] == "user" + assert dataset.load_sample(1)["role"] == "user" + + finally: + Path(temp_path).unlink() + + +@pytest.mark.unit +def test_multi_turn_dataset_empty_conversation(): + """Empty JSONL file raises ValueError (no columns to validate against).""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + temp_path = f.name + + try: + with pytest.raises(ValueError): + MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL) + finally: + Path(temp_path).unlink() + + +@pytest.mark.unit +def test_multi_turn_dataset_conversation_grouping(): + """Test that properly grouped conversations load correctly.""" + data = [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "c1t1"}, + {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "c1t2"}, + {"conversation_id": "c1", "turn": 3, "role": "user", "content": "c1t3"}, + {"conversation_id": "c2", "turn": 1, "role": "user", "content": "c2t1"}, + {"conversation_id": "c2", "turn": 2, "role": "assistant", "content": "c2t2"}, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + try: + dataset = MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL) + dataset.load() + + # data contains only client turns: 3 user turns (c1t1, c1t3, c2t1) + assert len(dataset.data) == 3 + assert dataset.num_samples() == 3 + + # Load samples to verify conversation grouping + samples = [dataset.load_sample(i) for i in range(3)] 
+ + # Verify conversation IDs + conv_ids = [s["conversation_id"] for s in samples] + assert conv_ids == ["c1", "c1", "c2"] + + finally: + Path(temp_path).unlink() + + +@pytest.mark.unit +def test_multi_turn_dataset_interleaved_conversations_rejected(): + """Test that interleaved conversation rows raise InputValidationError.""" + from inference_endpoint.exceptions import InputValidationError + + data = [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "c1t1"}, + {"conversation_id": "c2", "turn": 1, "role": "user", "content": "c2t1"}, + {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "c1t2"}, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + try: + with pytest.raises(InputValidationError, match="not consecutive"): + MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL) + finally: + Path(temp_path).unlink() + + +@pytest.mark.unit +@pytest.mark.parametrize( + "rows", + [ + # assistant-first + [ + {"conversation_id": "c1", "turn": 1, "role": "assistant", "content": "A"}, + {"conversation_id": "c1", "turn": 2, "role": "user", "content": "B"}, + ], + # consecutive assistants + [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "A"}, + {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "B"}, + {"conversation_id": "c1", "turn": 3, "role": "assistant", "content": "C"}, + ], + # tool directly after user (tool-before-assistant) + [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "A"}, + { + "conversation_id": "c1", + "turn": 2, + "role": "tool", + "tool_results": [{"tool_call_id": "x", "content": "r"}], + }, + ], + # consecutive users + [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "A"}, + {"conversation_id": "c1", "turn": 2, "role": "user", "content": "B"}, + ], + ], +) +def test_validation_rejects_invalid_role_sequence(rows): + 
"""Invalid role sequences raise ValueError regardless of turn numbering.""" + with pytest.raises(ValueError, match="invalid role sequence"): + MultiTurnDataset(pd.DataFrame(rows)) + + +@pytest.mark.unit +def test_multi_turn_dataset_additional_fields(): + """Test that additional fields (model, max_new_tokens, etc.) are preserved.""" + data = [ + { + "conversation_id": "c1", + "turn": 1, + "role": "user", + "content": "Hello", + "model": "gpt-4", + "max_new_tokens": 256, + "temperature": 0.7, + }, + {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "Hi"}, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + try: + dataset = MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL) + dataset.load() + + sample = dataset.load_sample(0) + assert sample["model"] == "gpt-4" + assert sample["max_completion_tokens"] == 256 + assert sample["temperature"] == pytest.approx(0.7) + + finally: + Path(temp_path).unlink() + + +@pytest.mark.unit +def test_multi_turn_dataset_openai_field_forwarding(): + """Test that OpenAI-specific fields are preserved and forwarded.""" + data = [ + { + "conversation_id": "c1", + "turn": 1, + "role": "user", + "content": "Hello", + # OpenAI fields that should be forwarded + "n": 3, + "name": "Alice", + "user": "user_12345", + "logit_bias": {"50256": -100}, + "chat_template": "custom_template", + }, + {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "Hi"}, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + try: + dataset = MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL) + dataset.load() + + sample = dataset.load_sample(0) + + # Verify OpenAI fields are present + assert sample.get("n") == 3 + assert sample.get("name") == "Alice" + assert 
sample.get("user") == "user_12345" + assert sample.get("logit_bias") == {"50256": -100} + assert sample.get("chat_template") == "custom_template" + finally: + Path(temp_path).unlink() + + +@pytest.mark.unit +def test_multi_turn_dataset_all_generation_params(): + """Test that dataset-supplied generation parameters are forwarded to the sample.""" + # Create dataset with a representative set of generation params + row_params = { + "model": "test-model", + "max_completion_tokens": 100, + "stream": True, + "temperature": 0.8, + "top_p": 0.95, + "top_k": 50, + "seed": 42, + "repetition_penalty": 1.1, + "frequency_penalty": 0.5, + "presence_penalty": 0.3, + "stop": ["END"], + "n": 2, + "logit_bias": {"100": 10}, + "name": "TestEntity", + "user": "test_user_001", + "chat_template": "test_template", + } + data = [ + { + "conversation_id": "c1", + "turn": 1, + "role": "user", + "content": "Test", + **row_params, + }, + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant", + "content": "Response", + }, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + try: + dataset = MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL) + dataset.load() + + sample = dataset.load_sample(0) + + # All non-NaN row fields must appear in the pre-baked sample + for param in row_params: + assert param in sample, f"Parameter '{param}' not forwarded to sample" + finally: + Path(temp_path).unlink() + + +@pytest.mark.unit +def test_validation_rejects_non_contiguous_turns(): + """Turn numbers must be consecutive; gaps are rejected.""" + rows = [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "a"}, + {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "b"}, + {"conversation_id": "c1", "turn": 5, "role": "user", "content": "c"}, + {"conversation_id": "c1", "turn": 6, "role": "assistant", "content": "d"}, + ] + with 
pytest.raises(ValueError, match="consecutive"): + MultiTurnDataset(pd.DataFrame(rows)) + + +@pytest.mark.unit +def test_validation_rejects_turns_not_starting_at_one(): + """Validation should reject conversations whose turns don't start at 1.""" + data = [ + {"conversation_id": "c1", "turn": 3, "role": "user", "content": "msg"}, + {"conversation_id": "c1", "turn": 4, "role": "assistant", "content": "resp"}, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + try: + with pytest.raises(ValueError, match="consecutive"): + MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL) + finally: + Path(temp_path).unlink() + + +@pytest.mark.unit +def test_validation_accepts_valid_contiguous_turns(): + """Validation should accept contiguous turn sequences.""" + data = [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "msg1"}, + {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "resp1"}, + {"conversation_id": "c1", "turn": 3, "role": "user", "content": "msg2"}, + {"conversation_id": "c1", "turn": 4, "role": "assistant", "content": "resp2"}, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + try: + dataset = MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL) + dataset.load() + assert dataset.num_samples() == 2 + finally: + Path(temp_path).unlink() + + +@pytest.mark.unit +def test_validation_rejects_turn_starting_at_zero(): + """Validation should reject conversations starting at turn 0.""" + data = [ + {"conversation_id": "c1", "turn": 0, "role": "user", "content": "msg"}, + {"conversation_id": "c1", "turn": 1, "role": "assistant", "content": "resp"}, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + 
f.write(json.dumps(item) + "\n")
+        temp_path = f.name
+
+    try:
+        with pytest.raises(ValueError, match="consecutive"):
+            MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL)
+    finally:
+        Path(temp_path).unlink()
+
+
+@pytest.mark.unit
+def test_validation_rejects_duplicate_turn_numbers():
+    """Duplicate turn numbers within a conversation are rejected."""
+    data = [
+        {"conversation_id": "c1", "turn": 1, "role": "user", "content": "msg1"},
+        {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "resp1"},
+        {"conversation_id": "c2", "turn": 1, "role": "user", "content": "msg2"},
+        {"conversation_id": "c2", "turn": 2, "role": "assistant", "content": "resp2"},
+        # c2 has a duplicate turn 2: a user row reusing the same turn number
+        {"conversation_id": "c2", "turn": 2, "role": "user", "content": "dup"},
+    ]
+
+    with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
+        for item in data:
+            f.write(json.dumps(item) + "\n")
+        temp_path = f.name
+
+    try:
+        with pytest.raises(ValueError, match="consecutive"):
+            MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL)
+    finally:
+        Path(temp_path).unlink()
+
+
+@pytest.mark.unit
+def test_validation_rejects_assistant_tc_role_literal():
+    """role='assistant_tc' literal in dataset is rejected; only 'assistant' is valid."""
+    rows = [
+        {"conversation_id": "c1", "turn": 1, "role": "user", "content": "Q"},
+        {
+            "conversation_id": "c1",
+            "turn": 2,
+            "role": "assistant_tc",
+            "tool_calls": [
+                {
+                    "id": "c0",
+                    "type": "function",
+                    "function": {"name": "f", "arguments": "{}"},
+                }
+            ],
+        },
+        {
+            "conversation_id": "c1",
+            "turn": 3,
+            "role": "tool",
+            "tool_results": [{"tool_call_id": "c0", "content": "r"}],
+        },
+        {"conversation_id": "c1", "turn": 4, "role": "assistant", "content": "A"},
+    ]
+    with pytest.raises(ValueError, match="invalid role sequence"):
+        MultiTurnDataset(pd.DataFrame(rows))
+
+
+# 
============================================================================ +# Tool sequence tests +# ============================================================================ + + +def _make_tool_sequence_df(): + """Return a DataFrame with a tool sequence embedded between user turns.""" + return pd.DataFrame( + [ + { + "conversation_id": "c1", + "turn": 1, + "role": "user", + "content": "What is the weather?", + "system": "Be helpful", + }, + # assistant (with tool_calls): dispatches a tool call + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_c1_0", + "type": "function", + "function": {"name": "get_weather", "arguments": "{}"}, + } + ], + }, + # tool result + { + "conversation_id": "c1", + "turn": 3, + "role": "tool", + "tool_results": [ + {"tool_call_id": "call_c1_0", "content": '{"temp": 22}'} + ], + }, + # terminal assistant + { + "conversation_id": "c1", + "turn": 4, + "role": "assistant", + "content": "The weather is 22°C.", + }, + # second user turn + { + "conversation_id": "c1", + "turn": 5, + "role": "user", + "content": "Thanks!", + }, + ] + ) + + +@pytest.mark.unit +def test_validation_accepts_tool_sequence(): + """user → assistant → tool → assistant → user passes validation.""" + df = _make_tool_sequence_df() + ds = MultiTurnDataset(df) + ds.load() + assert ds.num_samples() == 3 # user(1), tool(3), user(5) are all client turns + + +@pytest.mark.unit +def test_validation_accepts_parallel_tool_calls(): + """Assistant with two tool_calls + merged tool_results row passes.""" + df = pd.DataFrame( + [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "Hi"}, + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "c_0", + "type": "function", + "function": {"name": "f1", "arguments": "{}"}, + }, + { + "id": "c_1", + "type": "function", + "function": {"name": "f2", "arguments": "{}"}, + }, + ], + }, + { + 
"conversation_id": "c1", + "turn": 3, + "role": "tool", + "tool_results": [ + {"tool_call_id": "c_0", "content": "r1"}, + {"tool_call_id": "c_1", "content": "r2"}, + ], + }, + { + "conversation_id": "c1", + "turn": 4, + "role": "assistant", + "content": "Done", + }, + ] + ) + ds = MultiTurnDataset(df) + ds.load() + assert ds.num_samples() == 2 # user(1), tool(3) are client turns + + +@pytest.mark.unit +def test_load_sample_merged_tool_row_has_no_content_key(): + """load_sample for a merged tool_results row must not emit content: NaN.""" + df = pd.DataFrame( + [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "Go"}, + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "c_0", + "type": "function", + "function": {"name": "f1", "arguments": "{}"}, + }, + { + "id": "c_1", + "type": "function", + "function": {"name": "f2", "arguments": "{}"}, + }, + ], + }, + { + "conversation_id": "c1", + "turn": 3, + "role": "tool", + "tool_results": [ + {"tool_call_id": "c_0", "content": "r1"}, + {"tool_call_id": "c_1", "content": "r2"}, + ], + }, + { + "conversation_id": "c1", + "turn": 4, + "role": "assistant", + "content": "Done", + }, + ] + ) + ds = MultiTurnDataset(df) + ds.load() + + # Sample 1 is the merged tool row (turn 3) + s1 = ds.load_sample(1) + assert s1["role"] == "tool" + assert "content" not in s1 # must NOT emit NaN + assert "messages" in s1 + + +@pytest.mark.unit +def test_build_metadata_pre_built_messages(): + """pre_built_messages_by_key contains complete message arrays for each client turn. 
+ + Dataset: + turn 1: user ← client turn 1 + turn 2: asst_tc ← scripted (assistant with tool_calls) + turn 3: tool ← client turn 2 + turn 4: assistant ← terminal assistant + turn 5: user ← client turn 3 + + Expected pre_built_messages: + client turn 1 (t=1): [system, user(1)] + client turn 2 (t=3): [system, user(1), asst_tc(2), tool(3)] + client turn 3 (t=5): [system, user(1), asst_tc(2), tool(3), asst(4), user(5)] + """ + df = _make_tool_sequence_df() + ds = MultiTurnDataset(df) + + pbm = ds.conversation_metadata["pre_built_messages_by_key"] + + # Client turn 1 (user, t=1): [system, user(1)] + msgs_t1 = pbm[("c1", 1)] + assert len(msgs_t1) == 2 + assert msgs_t1[0] == {"role": "system", "content": "Be helpful"} + assert msgs_t1[1] == {"role": "user", "content": "What is the weather?"} + + # Client turn 2 (tool, t=3): [system, user(1), asst_tc(2), tool(3)] + msgs_t3 = pbm[("c1", 3)] + assert len(msgs_t3) == 4 + assert msgs_t3[0]["role"] == "system" + assert msgs_t3[1]["role"] == "user" + assert msgs_t3[2]["role"] == "assistant" + assert "tool_calls" in msgs_t3[2] + assert msgs_t3[3]["role"] == "tool" + assert msgs_t3[3]["content"] == '{"temp": 22}' + assert msgs_t3[3]["tool_call_id"] == "call_c1_0" + + # Client turn 3 (user, t=5): [system, user(1), asst_tc(2), tool(3), asst(4), user(5)] + msgs_t5 = pbm[("c1", 5)] + assert len(msgs_t5) == 6 + assert msgs_t5[4] == {"role": "assistant", "content": "The weather is 22°C."} + assert msgs_t5[5] == {"role": "user", "content": "Thanks!"} + + +@pytest.mark.unit +def test_build_metadata_pre_built_messages_no_tools(): + """Plain user/assistant alternation produces correct pre_built_messages.""" + df = pd.DataFrame( + [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "A"}, + {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "B"}, + {"conversation_id": "c1", "turn": 3, "role": "user", "content": "C"}, + ] + ) + ds = MultiTurnDataset(df) + pbm = 
ds.conversation_metadata["pre_built_messages_by_key"] + + # Turn 1: just the user message (no system, no prior rows) + assert pbm[("c1", 1)] == [{"role": "user", "content": "A"}] + + # Turn 3: user(1) + assistant(2) + user(3) + msgs = pbm[("c1", 3)] + assert len(msgs) == 3 + assert msgs[0] == {"role": "user", "content": "A"} + assert msgs[1] == {"role": "assistant", "content": "B"} + assert msgs[2] == {"role": "user", "content": "C"} + + +@pytest.mark.unit +def test_load_sample_includes_messages(): + """load_sample returns messages with the complete message list.""" + df = _make_tool_sequence_df() + ds = MultiTurnDataset(df) + ds.load() + + s0 = ds.load_sample(0) # user turn 1 + assert "messages" in s0 + msgs = s0["messages"] + assert msgs[0]["role"] == "system" + assert msgs[-1] == {"role": "user", "content": "What is the weather?"} + + s1 = ds.load_sample(1) # tool turn 3 + assert s1["role"] == "tool" + msgs_t3 = s1["messages"] + # system + user(1) + asst_tc(2) + tool(3) = 4 messages + assert len(msgs_t3) == 4 + assert msgs_t3[-1]["role"] == "tool" + + s2 = ds.load_sample(2) # user turn 5 + msgs_t5 = s2["messages"] + # system + user(1) + asst_tc(2) + tool(3) + asst(4) + user(5) = 6 messages + assert len(msgs_t5) == 6 + + +@pytest.mark.unit +def test_client_turns_include_tool_rows(): + """Tool rows are counted in num_samples() as client turns.""" + df = _make_tool_sequence_df() + ds = MultiTurnDataset(df) + ds.load() + # 5 rows total: user(1), assistant(2), tool(3), assistant(4), user(5) + # Client turns: user(1), tool(3), user(5) → 3 + assert ds.num_samples() == 3 + + +# ============================================================================ +# Pre-built messages content correctness +# ============================================================================ + + +@pytest.mark.unit +def test_messages_include_prior_assistant_response(valid_multi_turn_jsonl): + """The terminal assistant response before each user turn is included in messages.""" + dataset = 
MultiTurnDataset.load_from_file( + valid_multi_turn_jsonl, format=DatasetFormat.JSONL + ) + dataset.load() + + # Sample 0: turn 1 (first user) → just [system, user(1)] + s0 = dataset.load_sample(0) + msgs_0 = s0["messages"] + assert msgs_0[0]["role"] == "system" + assert msgs_0[-1]["role"] == "user" + + # Sample 1: turn 3 (second user) → [system, user(1), assistant(2), user(3)] + s1 = dataset.load_sample(1) + msgs_1 = s1["messages"] + assert len(msgs_1) == 4 + assert msgs_1[2] == {"role": "assistant", "content": "I'm doing well, thank you!"} + assert msgs_1[3]["role"] == "user" + + # Sample 2: turn 1 of conv_002 → no prior assistant row + s2 = dataset.load_sample(2) + msgs_2 = s2["messages"] + assert all(m["role"] != "assistant" for m in msgs_2) + + +@pytest.mark.unit +def test_messages_no_cross_conversation_bleed(): + """Messages for conv_001 must not appear in conv_002's messages array.""" + data = [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "c1 user"}, + {"conversation_id": "c2", "turn": 1, "role": "user", "content": "c2 user"}, + {"conversation_id": "c2", "turn": 2, "role": "assistant", "content": "c2 resp"}, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + try: + dataset = MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL) + dataset.load() + + # c1: only its own user message + s_c1 = dataset.load_sample(0) + assert s_c1["messages"] == [{"role": "user", "content": "c1 user"}] + + # c2: only c2 messages (no c1 content) + s_c2 = dataset.load_sample(1) + contents = [m.get("content") for m in s_c2["messages"]] + assert "c1 user" not in contents + finally: + Path(temp_path).unlink() + + +@pytest.mark.unit +def test_messages_with_tool_sequence_terminal_assistant(): + """Terminal assistant response (turn 4) appears in messages for user(5).""" + df = _make_tool_sequence_df() + ds = MultiTurnDataset(df) + 
ds.load() + + s2 = ds.load_sample(2) # user turn 5 + msgs = s2["messages"] + # The terminal assistant at turn 4 should be included + assistant_msgs = [m for m in msgs if m["role"] == "assistant" and m.get("content")] + assert any(m["content"] == "The weather is 22°C." for m in assistant_msgs) + + +# ============================================================================ +# Tool-use flat dataset regression tests (BUG 1, BUG 2, BUG 3) +# ============================================================================ + + +@pytest.mark.unit +def test_prior_tool_row_expanded_with_tool_call_id(): + """Prior tool rows must expand to messages with tool_call_id and content (BUG 1).""" + df = _make_tool_sequence_df() + ds = MultiTurnDataset(df) + pbm = ds.conversation_metadata["pre_built_messages_by_key"] + + # Client turn 3 (user, t=5) has a prior tool row at t=3. + # msgs_t5[3] should be the expanded tool message with proper fields. + msgs_t5 = pbm[("c1", 5)] + tool_msgs = [m for m in msgs_t5 if m.get("role") == "tool"] + assert len(tool_msgs) == 1 + assert tool_msgs[0]["tool_call_id"] == "call_c1_0" + assert tool_msgs[0]["content"] == '{"temp": 22}' + + +@pytest.mark.unit +def test_prior_parallel_tool_results_expand_to_multiple_messages(): + """Prior turn with 2 parallel tool_results expands to 2 tool messages.""" + df = pd.DataFrame( + [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "Hi"}, + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "c_0", + "type": "function", + "function": {"name": "f1", "arguments": "{}"}, + }, + { + "id": "c_1", + "type": "function", + "function": {"name": "f2", "arguments": "{}"}, + }, + ], + }, + { + "conversation_id": "c1", + "turn": 3, + "role": "tool", + "tool_results": [ + {"tool_call_id": "c_0", "content": "r1"}, + {"tool_call_id": "c_1", "content": "r2"}, + ], + }, + { + "conversation_id": "c1", + "turn": 4, + "role": "assistant", + "content": 
"Done", + }, + {"conversation_id": "c1", "turn": 5, "role": "user", "content": "Ok"}, + ] + ) + ds = MultiTurnDataset(df) + pbm = ds.conversation_metadata["pre_built_messages_by_key"] + + # user(5) sees prior rows: user(1), assistant(2), tool(3)x2, assistant(4) + msgs_t5 = pbm[("c1", 5)] + tool_msgs = [m for m in msgs_t5 if m.get("role") == "tool"] + assert len(tool_msgs) == 2 + assert tool_msgs[0]["tool_call_id"] == "c_0" + assert tool_msgs[0]["content"] == "r1" + assert tool_msgs[1]["tool_call_id"] == "c_1" + assert tool_msgs[1]["content"] == "r2" + + +@pytest.mark.unit +def test_assistant_content_null_preserved_in_history(): + """Assistant messages with tool_calls and content:null include content key (BUG 2).""" + df = _make_tool_sequence_df() + ds = MultiTurnDataset(df) + pbm = ds.conversation_metadata["pre_built_messages_by_key"] + + # Client turn 2 (tool, t=3): prior includes assistant(2) with tool_calls + content: null + msgs_t3 = pbm[("c1", 3)] + asst_msg = msgs_t3[2] + assert asst_msg["role"] == "assistant" + assert "tool_calls" in asst_msg + assert "content" in asst_msg + assert asst_msg["content"] is None + + # Also verify in user(5)'s history + msgs_t5 = pbm[("c1", 5)] + asst_tc_msg = msgs_t5[2] + assert asst_tc_msg["role"] == "assistant" + assert "tool_calls" in asst_tc_msg + assert "content" in asst_tc_msg + assert asst_tc_msg["content"] is None + + +@pytest.mark.unit +def test_jsonl_round_trip_with_tools_field(): + """Load from JSONL tmpfile with tools field; verify tools survives to sample dict.""" + data = [ + { + "conversation_id": "c1", + "turn": 1, + "role": "user", + "content": "Run the test", + "tools": [ + { + "type": "function", + "function": { + "name": "bash", + "description": "Run a bash command", + "parameters": {"type": "object", "properties": {}}, + }, + } + ], + }, + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "tc_0", + "type": "function", + "function": {"name": 
"bash", "arguments": '{"cmd": "ls"}'}, + } + ], + }, + { + "conversation_id": "c1", + "turn": 3, + "role": "tool", + "tool_results": [{"tool_call_id": "tc_0", "content": "file1.py"}], + "tools": [ + { + "type": "function", + "function": { + "name": "bash", + "description": "Run a bash command", + "parameters": {"type": "object", "properties": {}}, + }, + } + ], + }, + { + "conversation_id": "c1", + "turn": 4, + "role": "assistant", + "content": "The directory contains file1.py", + }, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + try: + dataset = MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL) + dataset.load() + + # user(1) has tools + s0 = dataset.load_sample(0) + assert "tools" in s0 + assert len(s0["tools"]) == 1 + assert s0["tools"][0]["function"]["name"] == "bash" + + # tool(3) also has tools + s1 = dataset.load_sample(1) + assert "tools" in s1 + assert s1["tools"][0]["function"]["name"] == "bash" + finally: + Path(temp_path).unlink() + + +@pytest.mark.unit +def test_current_turn_messages_by_key_parallel_tools(): + """current_turn_messages_by_key stores all expanded messages for a tool turn.""" + df = pd.DataFrame( + [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "Go"}, + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "c_0", + "type": "function", + "function": {"name": "f1", "arguments": "{}"}, + }, + { + "id": "c_1", + "type": "function", + "function": {"name": "f2", "arguments": "{}"}, + }, + ], + }, + { + "conversation_id": "c1", + "turn": 3, + "role": "tool", + "tool_results": [ + {"tool_call_id": "c_0", "content": "r1"}, + {"tool_call_id": "c_1", "content": "r2"}, + ], + }, + { + "conversation_id": "c1", + "turn": 4, + "role": "assistant", + "content": "Done", + }, + ] + ) + ds = MultiTurnDataset(df) + ctm = 
ds.conversation_metadata["current_turn_messages_by_key"] + + # user(1) current turn is 1 message + assert len(ctm[("c1", 1)]) == 1 + assert ctm[("c1", 1)][0] == {"role": "user", "content": "Go"} + + # tool(3) current turn has 2 expanded messages (parallel tool_results) + assert len(ctm[("c1", 3)]) == 2 + assert ctm[("c1", 3)][0]["tool_call_id"] == "c_0" + assert ctm[("c1", 3)][1]["tool_call_id"] == "c_1" + + +# ============================================================================ +# Fix 1: system_prompts_by_conv in metadata (live-history mode) +# ============================================================================ + + +@pytest.mark.unit +def test_metadata_contains_system_prompts_by_conv(): + """_build_metadata exposes system_prompts_by_conv keyed by conversation_id.""" + data = [ + { + "conversation_id": "c1", + "turn": 1, + "role": "user", + "content": "Hi", + "system": "Be concise", + }, + {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "Ok"}, + # c2 has no system prompt + {"conversation_id": "c2", "turn": 1, "role": "user", "content": "Hello"}, + ] + df = pd.DataFrame(data) + ds = MultiTurnDataset(df) + + spc = ds.conversation_metadata["system_prompts_by_conv"] + assert spc["c1"] == "Be concise" + assert spc["c2"] is None + + +@pytest.mark.unit +def test_metadata_system_prompts_multiple_convs(): + """Each conversation gets its own system prompt entry.""" + data = [ + { + "conversation_id": "c1", + "turn": 1, + "role": "user", + "content": "A", + "system": "Sys1", + }, + {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "B"}, + { + "conversation_id": "c2", + "turn": 1, + "role": "user", + "content": "C", + "system": "Sys2", + }, + {"conversation_id": "c2", "turn": 2, "role": "assistant", "content": "D"}, + ] + df = pd.DataFrame(data) + ds = MultiTurnDataset(df) + + spc = ds.conversation_metadata["system_prompts_by_conv"] + assert spc["c1"] == "Sys1" + assert spc["c2"] == "Sys2" + + +# 
============================================================================ +# Fix 2: tool_results / tool_calls stripped from sample dicts +# ============================================================================ + + +@pytest.mark.unit +def test_tool_results_not_in_sample_dict(): + """tool_results must not appear in the pre-baked sample dict for tool turns.""" + df = _make_tool_sequence_df() + ds = MultiTurnDataset(df) + ds.load() + + # Sample 1 is the tool turn (turn 3) + s1 = ds.load_sample(1) + assert s1["role"] == "tool" + assert "tool_results" not in s1 + + +@pytest.mark.unit +def test_tool_calls_not_in_sample_dict(): + """tool_calls must not appear in sample dicts (only relevant on assistant rows).""" + data = [ + { + "conversation_id": "c1", + "turn": 1, + "role": "user", + "content": "Go", + "tool_calls": [ + {"id": "bad", "type": "function", "function": {"name": "f"}} + ], + }, + {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "Done"}, + ] + df = pd.DataFrame(data) + ds = MultiTurnDataset(df) + ds.load() + + s0 = ds.load_sample(0) + assert "tool_calls" not in s0 + + +# ============================================================================ +# Fix 3: no dead current_turn_message / system_content fields in sample dicts +# ============================================================================ + + +@pytest.mark.unit +def test_no_dead_current_turn_message_field(): + """current_turn_message must not appear in pre-baked sample dicts.""" + df = _make_tool_sequence_df() + ds = MultiTurnDataset(df) + ds.load() + + for i in range(ds.num_samples()): + s = ds.load_sample(i) + assert ( + "current_turn_message" not in s + ), f"Sample {i} has dead field current_turn_message" + + +@pytest.mark.unit +def test_no_dead_system_content_field(): + """system_content must not appear in pre-baked sample dicts.""" + df = _make_tool_sequence_df() + ds = MultiTurnDataset(df) + ds.load() + + for i in range(ds.num_samples()): + s = 
ds.load_sample(i) + assert "system_content" not in s, f"Sample {i} has dead field system_content" diff --git a/tests/unit/dataset_manager/test_transforms.py b/tests/unit/dataset_manager/test_transforms.py index ab342204..5eca41b4 100644 --- a/tests/unit/dataset_manager/test_transforms.py +++ b/tests/unit/dataset_manager/test_transforms.py @@ -23,6 +23,7 @@ import pandas as pd import pytest from inference_endpoint.dataset_manager.transforms import ( + AddDefaultColumns, AddStaticColumns, ColumnFilter, ColumnRemap, @@ -824,4 +825,51 @@ def test_no_matching_columns(self): # Should not raise error or create prompt column assert "prompt" not in result.columns - assert "unrelated" in result.columns + + +class TestAddDefaultColumns: + """Unit tests for AddDefaultColumns transform.""" + + @pytest.mark.unit + def test_fills_missing_columns(self): + """New columns are added when absent.""" + df = pd.DataFrame({"a": [1, 2]}) + result = AddDefaultColumns({"b": 10, "c": "x"})(df) + assert list(result["b"]) == [10, 10] + assert list(result["c"]) == ["x", "x"] + + @pytest.mark.unit + def test_preserves_existing_non_null_values(self): + """Existing non-null values are not overwritten.""" + df = pd.DataFrame({"a": [1, 2]}) + result = AddDefaultColumns({"a": 99})(df) + assert list(result["a"]) == [1, 2] + + @pytest.mark.unit + def test_fills_nan_values_in_existing_column(self): + """NaN cells in an existing column are replaced with the default.""" + + df = pd.DataFrame({"a": [1.0, float("nan"), 3.0]}) + result = AddDefaultColumns({"a": 99})(df) + assert result["a"].tolist()[0] == 1.0 + assert result["a"].tolist()[1] == 99 + assert result["a"].tolist()[2] == 3.0 + + @pytest.mark.unit + def test_skips_none_default_values(self): + """A None default value is ignored; the column is not modified.""" + df = pd.DataFrame({"a": [1]}) + original_a = df["a"].copy() + result = AddDefaultColumns({"a": None, "b": None})(df) + assert list(result["a"]) == list(original_a) + assert "b" not in 
result.columns + + @pytest.mark.unit + def test_mixed_nan_and_real_values(self): + """Only NaN cells are filled; real values in the same column are preserved.""" + + df = pd.DataFrame({"temp": [0.9, float("nan"), 0.5]}) + result = AddDefaultColumns({"temp": 0.7})(df) + assert result["temp"].tolist()[0] == pytest.approx(0.9) + assert result["temp"].tolist()[1] == pytest.approx(0.7) + assert result["temp"].tolist()[2] == pytest.approx(0.5) diff --git a/tests/unit/load_generator/test_multi_turn_conversation_manager.py b/tests/unit/load_generator/test_multi_turn_conversation_manager.py new file mode 100644 index 00000000..c389fb5f --- /dev/null +++ b/tests/unit/load_generator/test_multi_turn_conversation_manager.py @@ -0,0 +1,220 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import asyncio +import inspect + +import pytest +from inference_endpoint.load_generator.conversation_manager import ( + ConversationManager, + ConversationState, +) + + +@pytest.mark.unit +def test_conversation_state_initialization(): + """ConversationState initializes with correct defaults.""" + state = ConversationState(conversation_id="conv_001") + + assert state.conversation_id == "conv_001" + assert state.message_history == [] + assert state.completed_turns == 0 + assert state.failed_turns == 0 + assert state.expected_client_turns is None + + +@pytest.mark.unit +def test_conversation_state_is_complete_without_expected(): + """is_complete() returns False when expected_client_turns is None.""" + state = ConversationState(conversation_id="conv_001") + assert not state.is_complete() + state.completed_turns = 5 + assert not state.is_complete() + + +@pytest.mark.unit +def test_conversation_state_is_complete_with_expected(): + """is_complete() returns True once completed_turns >= expected.""" + state = ConversationState(conversation_id="conv_001", expected_client_turns=2) + assert not state.is_complete() + state.completed_turns = 1 + assert not state.is_complete() + state.completed_turns = 2 + assert state.is_complete() + + +@pytest.mark.unit +def test_create_is_synchronous(): + """get_or_create() must be a plain function, not a coroutine.""" + manager = ConversationManager() + result = manager.get_or_create("conv_001") + assert not inspect.iscoroutine(result), "get_or_create returned a coroutine" + assert isinstance(result, ConversationState) + + +@pytest.mark.unit +def test_conversation_manager_get_or_create(): + """get_or_create returns the same state for the same conversation_id.""" + manager = ConversationManager() + + state1 = manager.get_or_create("conv_001") + state2 = manager.get_or_create("conv_001") + + assert state1 is state2 + assert state1.conversation_id == "conv_001" + + +@pytest.mark.unit +def test_conversation_manager_multiple_conversations(): + 
"""Manager tracks multiple conversations independently.""" + manager = ConversationManager() + + state1 = manager.get_or_create("conv_001") + state2 = manager.get_or_create("conv_002") + + assert state1 is not state2 + + manager.mark_turn_complete("conv_001", "Response to conv_001") + + assert state1.completed_turns == 1 + assert state2.completed_turns == 0 + + +@pytest.mark.unit +def test_conversation_manager_mark_turn_complete(): + """mark_turn_complete increments counter and appends history.""" + manager = ConversationManager() + state = manager.get_or_create("conv_001") + + manager.mark_turn_complete("conv_001", "Assistant response") + + assert state.completed_turns == 1 + assert state.failed_turns == 0 + assert state.message_history == [] # store_in_history=False by default + + +@pytest.mark.unit +def test_conversation_manager_mark_turn_complete_stores_history(): + """mark_turn_complete appends to history when store_in_history=True.""" + manager = ConversationManager() + state = manager.get_or_create("conv_001") + + manager.mark_turn_complete("conv_001", "Hello", store_in_history=True) + + assert state.message_history == [{"role": "assistant", "content": "Hello"}] + + +@pytest.mark.unit +def test_conversation_manager_mark_turn_failed(): + """mark_turn_failed increments both counters.""" + manager = ConversationManager() + state = manager.get_or_create("conv_001", expected_client_turns=2) + + manager.mark_turn_failed("conv_001") + + assert state.completed_turns == 1 + assert state.failed_turns == 1 + + +@pytest.mark.unit +def test_conversation_completion_tracking(): + """is_complete() returns True after all expected turns receive responses.""" + manager = ConversationManager() + state = manager.get_or_create("conv_001", expected_client_turns=2) + + assert not state.is_complete() + manager.mark_turn_complete("conv_001", "r1") + assert not state.is_complete() + manager.mark_turn_complete("conv_001", "r2") + assert state.is_complete() + + +@pytest.mark.unit +def 
test_conversation_completion_without_expected_turns(): + """Completion is never True when expected_client_turns is None.""" + manager = ConversationManager() + state = manager.get_or_create("conv_001", expected_client_turns=None) + + manager.mark_turn_complete("conv_001", "r1") + + assert not state.is_complete() + + +@pytest.mark.unit +def test_conversation_completion_with_failures(): + """Conversations complete even when some turns fail.""" + manager = ConversationManager() + state = manager.get_or_create("conv1", expected_client_turns=3) + + manager.mark_turn_complete("conv1", "Hi") + assert not state.is_complete() + + manager.mark_turn_failed("conv1") + assert not state.is_complete() + + manager.mark_turn_complete("conv1", "Bye") + assert state.is_complete() + assert state.failed_turns == 1 + assert state.completed_turns == 3 + + +@pytest.mark.unit +def test_all_turns_fail(): + """Conversation completes when all turns fail.""" + manager = ConversationManager() + state = manager.get_or_create("conv1", expected_client_turns=2) + + manager.mark_turn_failed("conv1") + manager.mark_turn_failed("conv1") + + assert state.is_complete() + assert state.completed_turns == 2 + assert state.failed_turns == 2 + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_conversation_manager_concurrent_access(): + """Concurrent pipeline tasks on independent conversations complete without errors.""" + manager = ConversationManager() + num_conversations = 10 + turns_per_conv = 5 + + for i in range(num_conversations): + manager.get_or_create(f"conv_{i:03d}", expected_client_turns=turns_per_conv) + + errors = [] + + async def process_conversation(conv_id: str): + try: + state = manager.get_state(conv_id) + assert state is not None + for _ in range(turns_per_conv): + manager.mark_turn_complete(conv_id, "response") + await asyncio.sleep(0.001) + except Exception as e: + errors.append(f"{conv_id} error: {e}") + + tasks = [ + asyncio.create_task(process_conversation(f"conv_{i:03d}")) + 
for i in range(num_conversations) + ] + await asyncio.gather(*tasks) + + assert not errors + for i in range(num_conversations): + state = manager._conversations[f"conv_{i:03d}"] + assert state.completed_turns == turns_per_conv + assert state.is_complete() diff --git a/tests/unit/load_generator/test_multi_turn_strategy.py b/tests/unit/load_generator/test_multi_turn_strategy.py new file mode 100644 index 00000000..d3c9a22a --- /dev/null +++ b/tests/unit/load_generator/test_multi_turn_strategy.py @@ -0,0 +1,567 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for MultiTurnStrategy.""" + +import asyncio + +import pytest +from inference_endpoint.core.types import QueryResult, TextModelOutput +from inference_endpoint.load_generator.conversation_manager import ConversationManager +from inference_endpoint.load_generator.multi_turn_strategy import MultiTurnStrategy + + +class FakePhaseIssuer: + """Minimal PhaseIssuerProtocol stub.""" + + def __init__(self, stop_after: int | None = None): + self._count = 0 + self._stop_after = stop_after + self.issued: list[int] = [] + self.issued_count = 0 + + def issue(self, sample_index: int, data_override: dict | None = None) -> str | None: + if self._stop_after is not None and self._count >= self._stop_after: + return None + self._count += 1 + self.issued_count += 1 + query_id = f"q{sample_index:04d}" + self.issued.append(sample_index) + return query_id + + +def _make_dataset_metadata(conversations: dict[str, list[int]]) -> dict: + """Build dataset_metadata dict from {conv_id: [turn_numbers]} mapping.""" + samples = [] + sample_index = 0 + for conv_id, turns in conversations.items(): + for turn in turns: + samples.append( + { + "conversation_id": conv_id, + "turn": turn, + "sample_index": sample_index, + } + ) + sample_index += 1 + return {"samples": samples} + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_single_conversation_single_turn(): + """Single conversation, single turn — should issue exactly one sample.""" + conv_manager = ConversationManager() + metadata = _make_dataset_metadata({"conv1": [1]}) + strategy = MultiTurnStrategy(conv_manager, metadata) + issuer = FakePhaseIssuer() + + async def complete_turns(): + await asyncio.sleep(0.01) + result = QueryResult( + id="q0000", response_output=TextModelOutput(output="response 1") + ) + strategy.on_sample_complete(result) + + asyncio.create_task(complete_turns()) + count = await strategy.execute(issuer) + + assert count == 1 + assert issuer.issued == [0] + + +@pytest.mark.unit +@pytest.mark.asyncio +async 
def test_single_conversation_multi_turn():
+    """Single conversation, 3 turns — turns must be issued sequentially."""
+    conv_manager = ConversationManager()
+    metadata = _make_dataset_metadata({"conv1": [1, 3, 5]})
+    strategy = MultiTurnStrategy(conv_manager, metadata)
+    issuer = FakePhaseIssuer()
+
+    issued_order: list[str] = []
+    original_issue = issuer.issue
+
+    def tracked_issue(idx, data_override=None):
+        q = original_issue(idx, data_override=data_override)
+        if q:
+            issued_order.append(q)
+        return q
+
+    issuer.issue = tracked_issue
+
+    async def simulate_responses():
+        await asyncio.sleep(0.01)
+        for turn_q, resp in [("q0000", "r1"), ("q0001", "r2"), ("q0002", "r3")]:
+            result = QueryResult(
+                id=turn_q, response_output=TextModelOutput(output=resp)
+            )
+            strategy.on_sample_complete(result)
+            await asyncio.sleep(0.01)
+
+    asyncio.create_task(simulate_responses())
+    count = await strategy.execute(issuer)
+
+    assert count == 3
+    assert issuer.issued == [0, 1, 2]
+    assert issued_order == ["q0000", "q0001", "q0002"]
+
+
+@pytest.mark.unit
+@pytest.mark.asyncio
+async def test_multiple_conversations_concurrent():
+    """Two conversations run concurrently, each with 2 turns."""
+    conv_manager = ConversationManager()
+    metadata = _make_dataset_metadata({"conv1": [1, 3], "conv2": [1, 3]})
+    strategy = MultiTurnStrategy(conv_manager, metadata)
+    issuer = FakePhaseIssuer()
+
+    async def simulate_responses():
+        await asyncio.sleep(0.02)
+        for q_prefix in range(4):
+            q = f"q{q_prefix:04d}"
+            result = QueryResult(id=q, response_output=TextModelOutput(output="resp"))
+            strategy.on_sample_complete(result)
+            await asyncio.sleep(0.01)
+
+    asyncio.create_task(simulate_responses())
+    count = await strategy.execute(issuer)
+
+    assert count == 4
+
+
+@pytest.mark.unit
+@pytest.mark.asyncio
+async def test_turn_ordering_enforced():
+    """Turn 2 must not be issued before Turn 1 completes."""
+    conv_manager = ConversationManager()
+    metadata = _make_dataset_metadata({"conv1": [1, 3]})
+    strategy = MultiTurnStrategy(conv_manager, metadata)
+ + issue_timestamps: dict[int, float] = {} + complete_timestamps: dict[int, float] = {} + + class TimedIssuer: + issued_count = 0 + issued: list[int] = [] + + def issue(self, idx: int, data_override: dict | None = None) -> str | None: + import time + + issue_timestamps[idx] = time.monotonic() + self.issued.append(idx) + self.issued_count += 1 + return f"q{idx:04d}" + + issuer = TimedIssuer() + + async def simulate_responses(): + import time + + await asyncio.sleep(0.02) + complete_timestamps[0] = time.monotonic() + result = QueryResult(id="q0000", response_output=TextModelOutput(output="r1")) + strategy.on_sample_complete(result) + await asyncio.sleep(0.05) + complete_timestamps[1] = time.monotonic() + result = QueryResult(id="q0001", response_output=TextModelOutput(output="r2")) + strategy.on_sample_complete(result) + + asyncio.create_task(simulate_responses()) + count = await strategy.execute(issuer) + + assert count == 2 + # Turn 2 (sample index 1) must be issued AFTER turn 1 completes + assert issue_timestamps[1] >= complete_timestamps[0] + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_turn_timeout_triggers_failure(): + """A turn that never completes should timeout and abort remaining turns.""" + conv_manager = ConversationManager() + metadata = _make_dataset_metadata({"conv1": [1, 3]}) + strategy = MultiTurnStrategy(conv_manager, metadata, target_concurrency=None) + strategy._turn_timeout_s = 0.1 # Very short timeout for testing + issuer = FakePhaseIssuer() + + # Do NOT simulate any response — turn 1 will timeout + await strategy.execute(issuer) + + # Only turn 1 should be issued (turn 2 never gets to run) + assert issuer.issued_count == 1 + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_on_sample_complete_routes_to_manager(): + """on_sample_complete marks the turn complete in the ConversationManager.""" + conv_manager = ConversationManager() + conv_manager.get_or_create("conv1", expected_client_turns=1) + metadata = 
_make_dataset_metadata({"conv1": [1]}) + strategy = MultiTurnStrategy(conv_manager, metadata) + + # Simulate issuer registering conv_id in _inflight + strategy._inflight["q0001"] = "conv1" + + result = QueryResult(id="q0001", response_output=TextModelOutput(output="hello")) + strategy.on_sample_complete(result) + + state = conv_manager.get_state("conv1") + assert state is not None + assert state.completed_turns == 1 + assert state.is_complete() + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_error_response_marks_turn_failed(): + """on_sample_complete marks failed when result.error is set.""" + from inference_endpoint.core.types import ErrorData + + conv_manager = ConversationManager() + conv_manager.get_or_create("conv1", expected_client_turns=1) + metadata = _make_dataset_metadata({"conv1": [1]}) + strategy = MultiTurnStrategy(conv_manager, metadata) + + strategy._inflight["q0001"] = "conv1" + + result = QueryResult( + id="q0001", + response_output=None, + error=ErrorData(error_type="timeout", error_message="timed out"), + ) + strategy.on_sample_complete(result) + + state = conv_manager.get_state("conv1") + assert state is not None + assert state.failed_turns == 1 + + +def _make_metadata_with_system( + conversations: dict[str, list[int]], + system_prompts: dict[str, str | None] | None = None, +) -> dict: + """Build metadata dict including system_prompts_by_conv.""" + samples = [] + sample_index = 0 + for conv_id, turns in conversations.items(): + for turn in turns: + samples.append( + { + "conversation_id": conv_id, + "turn": turn, + "sample_index": sample_index, + } + ) + sample_index += 1 + return { + "samples": samples, + "system_prompts_by_conv": system_prompts or {}, + } + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_live_history_initializes_system_prompt(): + """In live-history mode, ConversationManager.message_history starts with system message.""" + from inference_endpoint.config.schema import MultiTurnConfig + + conv_manager = 
ConversationManager() + metadata = _make_metadata_with_system( + {"conv1": [1]}, + system_prompts={"conv1": "Be helpful"}, + ) + mt_cfg = MultiTurnConfig(use_dataset_history=False, turn_timeout_s=10.0) + strategy = MultiTurnStrategy(conv_manager, metadata, multi_turn_config=mt_cfg) + issuer = FakePhaseIssuer() + + async def complete_turn(): + await asyncio.sleep(0.01) + result = QueryResult( + id="q0000", response_output=TextModelOutput(output="response") + ) + strategy.on_sample_complete(result) + + asyncio.create_task(complete_turn()) + await strategy.execute(issuer) + + state = conv_manager.get_state("conv1") + assert state is not None + # message_history[0] must be the system message + assert len(state.message_history) >= 1 + assert state.message_history[0] == {"role": "system", "content": "Be helpful"} + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_live_history_no_system_prompt_when_none(): + """In live-history mode, no system message is prepended when system_prompt is None.""" + from inference_endpoint.config.schema import MultiTurnConfig + + conv_manager = ConversationManager() + metadata = _make_metadata_with_system( + {"conv1": [1]}, + system_prompts={"conv1": None}, + ) + mt_cfg = MultiTurnConfig(use_dataset_history=False, turn_timeout_s=10.0) + strategy = MultiTurnStrategy(conv_manager, metadata, multi_turn_config=mt_cfg) + issuer = FakePhaseIssuer() + + async def complete_turn(): + await asyncio.sleep(0.01) + result = QueryResult( + id="q0000", response_output=TextModelOutput(output="response") + ) + strategy.on_sample_complete(result) + + asyncio.create_task(complete_turn()) + await strategy.execute(issuer) + + state = conv_manager.get_state("conv1") + assert state is not None + # No system message should be in history + system_msgs = [m for m in state.message_history if m.get("role") == "system"] + assert len(system_msgs) == 0 + + +@pytest.mark.unit +@pytest.mark.asyncio +async def 
test_dataset_history_mode_does_not_inject_system_prompt(): + """In dataset-history mode (use_dataset_history=True), system_message is not passed.""" + conv_manager = ConversationManager() + metadata = _make_metadata_with_system( + {"conv1": [1]}, + system_prompts={"conv1": "Some system"}, + ) + # Default: use_dataset_history=True → _store_in_history=False + strategy = MultiTurnStrategy(conv_manager, metadata) + issuer = FakePhaseIssuer() + + async def complete_turn(): + await asyncio.sleep(0.01) + result = QueryResult( + id="q0000", response_output=TextModelOutput(output="response") + ) + strategy.on_sample_complete(result) + + asyncio.create_task(complete_turn()) + await strategy.execute(issuer) + + state = conv_manager.get_state("conv1") + assert state is not None + # message_history should be empty (dataset-history mode doesn't accumulate) + assert len(state.message_history) == 0 + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_pipeline_error_propagated(): + """execute() re-raises when _issue_next_turn raises an exception.""" + conv_manager = ConversationManager() + metadata = _make_dataset_metadata({"conv1": [1]}) + strategy = MultiTurnStrategy(conv_manager, metadata) + + class ErrorIssuer: + issued_count = 0 + issued: list[int] = [] + + def issue(self, idx: int, data_override: dict | None = None) -> str | None: + raise RuntimeError("simulated pipeline error") + + with pytest.raises(RuntimeError, match="simulated pipeline error"): + await strategy.execute(ErrorIssuer()) + + +@pytest.mark.unit +def test_mark_turn_complete_preserves_tool_calls(): + """mark_turn_complete stores tool_calls in history when metadata contains them.""" + conv_manager = ConversationManager() + conv_manager.get_or_create("conv1", expected_client_turns=1) + + tool_calls = [ + { + "id": "call_1", + "type": "function", + "function": {"name": "bash", "arguments": '{"cmd": "ls"}'}, + } + ] + conv_manager.mark_turn_complete( + "conv1", + response="", + store_in_history=True, + 
metadata={"tool_calls": tool_calls}, + ) + + state = conv_manager.get_state("conv1") + assert state is not None + assert len(state.message_history) == 1 + msg = state.message_history[0] + assert msg["role"] == "assistant" + assert msg["content"] is None + assert msg["tool_calls"] == tool_calls + + +@pytest.mark.unit +def test_mark_turn_complete_with_response_and_tool_calls(): + """mark_turn_complete stores both content and tool_calls when both are present.""" + conv_manager = ConversationManager() + conv_manager.get_or_create("conv1", expected_client_turns=1) + + tool_calls = [ + { + "id": "call_1", + "type": "function", + "function": {"name": "search", "arguments": "{}"}, + } + ] + conv_manager.mark_turn_complete( + "conv1", + response="Calling search...", + store_in_history=True, + metadata={"tool_calls": tool_calls}, + ) + + state = conv_manager.get_state("conv1") + assert state is not None + msg = state.message_history[0] + assert msg["content"] == "Calling search..." + assert msg["tool_calls"] == tool_calls + + +@pytest.mark.unit +def test_mark_turn_complete_no_history_when_empty(): + """mark_turn_complete does not append when response is empty and no tool_calls.""" + conv_manager = ConversationManager() + conv_manager.get_or_create("conv1", expected_client_turns=1) + + conv_manager.mark_turn_complete("conv1", response="", store_in_history=True) + + state = conv_manager.get_state("conv1") + assert state is not None + assert len(state.message_history) == 0 + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_on_sample_complete_passes_metadata(): + """on_sample_complete forwards result.metadata (including tool_calls) to ConversationManager.""" + from inference_endpoint.config.schema import MultiTurnConfig + + conv_manager = ConversationManager() + metadata_dict = _make_metadata_with_system({"conv1": [1]}) + mt_cfg = MultiTurnConfig(use_dataset_history=False, turn_timeout_s=10.0) + strategy = MultiTurnStrategy(conv_manager, metadata_dict, 
multi_turn_config=mt_cfg) + + conv_manager.get_or_create("conv1", expected_client_turns=1) + strategy._inflight["q0001"] = "conv1" + + tool_calls = [ + { + "id": "call_1", + "type": "function", + "function": {"name": "bash", "arguments": "{}"}, + } + ] + result = QueryResult( + id="q0001", + response_output=TextModelOutput(output=""), + metadata={"tool_calls": tool_calls}, + ) + strategy.on_sample_complete(result) + + state = conv_manager.get_state("conv1") + assert state is not None + assert state.completed_turns == 1 + assert len(state.message_history) == 1 + assert state.message_history[0]["tool_calls"] == tool_calls + assert state.message_history[0]["content"] is None + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_concurrency_limits_active_conversations(): + """target_concurrency=2 starts at most 2 conversation pipelines simultaneously. + + Uses 2-turn conversations so each pipeline has an await point between turns. + With 4 conversations and 2 workers, the 3rd and 4th conversations cannot start + until a worker finishes its current conversation. 
+ """ + conv_manager = ConversationManager() + # 4 two-turn conversations; pipeline awaits turn-1 response before issuing turn-2 + metadata = _make_dataset_metadata( + {"conv1": [1, 2], "conv2": [1, 2], "conv3": [1, 2], "conv4": [1, 2]} + ) + strategy = MultiTurnStrategy(conv_manager, metadata, target_concurrency=2) + issuer = FakePhaseIssuer() + + async def auto_respond(): + already_done = 0 + while True: + while already_done < len(issuer.issued): + idx = issuer.issued[already_done] + q = f"q{idx:04d}" + strategy.on_sample_complete( + QueryResult(id=q, response_output=TextModelOutput(output="r")) + ) + already_done += 1 + await asyncio.sleep(0.02) + + responder_task = asyncio.create_task(auto_respond()) + execute_task = asyncio.create_task(strategy.execute(issuer)) + + # Let both seed turns get issued before auto_respond fires + await asyncio.sleep(0.01) + + # Only 2 workers → exactly 2 turn-1 queries issued (conv3/conv4 not started yet) + assert issuer.issued_count == 2 + + await asyncio.wait_for(execute_task, timeout=5.0) + responder_task.cancel() + + assert issuer.issued_count == 8 # 4 conversations × 2 turns + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_conversation_slot_reuse(): + """With target_concurrency=1, worker completes conv1 before starting conv2. + + The single slot must process both turns of conv1 before conv2's turn 1 is issued. 
+ """ + conv_manager = ConversationManager() + # 2 two-turn conversations; sample indices: conv1→[0,1], conv2→[2,3] + metadata = _make_dataset_metadata({"conv1": [1, 2], "conv2": [1, 2]}) + strategy = MultiTurnStrategy(conv_manager, metadata, target_concurrency=1) + issuer = FakePhaseIssuer() + + async def auto_respond(): + already_done = 0 + while True: + while already_done < len(issuer.issued): + idx = issuer.issued[already_done] + q = f"q{idx:04d}" + strategy.on_sample_complete( + QueryResult(id=q, response_output=TextModelOutput(output="r")) + ) + already_done += 1 + await asyncio.sleep(0.02) + + responder_task = asyncio.create_task(auto_respond()) + await strategy.execute(issuer) + responder_task.cancel() + + # Single slot: conv1 turns (samples 0,1) must be issued before conv2 turns (2,3) + assert issuer.issued[:2] == [0, 1], "Conv1 turns should be issued before conv2" + assert issuer.issued[2:] == [2, 3], "Conv2 turns should follow conv1" diff --git a/tests/unit/openai/test_msgspec_adapter.py b/tests/unit/openai/test_msgspec_adapter.py new file mode 100644 index 00000000..8127d199 --- /dev/null +++ b/tests/unit/openai/test_msgspec_adapter.py @@ -0,0 +1,146 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+"""Unit tests for OpenAIMsgspecAdapter with tool call fields."""
+
+import json
+
+import msgspec
+import pytest
+from inference_endpoint.core.types import Query
+from inference_endpoint.openai.openai_msgspec_adapter import (
+    OpenAIMsgspecAdapter,
+    _chat_message_from_dict,
+)
+from inference_endpoint.openai.types import ChatMessage
+
+
+@pytest.mark.unit
+def test_chat_message_tool_calls_serialised():
+    """tool_calls field is included in the JSON output when non-None."""
+    tool_calls = [
+        {
+            "id": "call_0",
+            "type": "function",
+            "function": {"name": "get_weather", "arguments": "{}"},
+        }
+    ]
+    msg = ChatMessage(role="assistant", tool_calls=tool_calls)
+    encoded = msgspec.json.encode(msg)
+    decoded = json.loads(encoded)
+    assert decoded["role"] == "assistant"
+    assert decoded["tool_calls"] == tool_calls
+    assert "content" not in decoded  # omit_defaults=True, None omitted
+
+
+@pytest.mark.unit
+def test_chat_message_tool_call_id_serialised():
+    """tool_call_id field is included in the JSON output when non-None."""
+    msg = ChatMessage(role="tool", content="result", tool_call_id="call_0")
+    encoded = msgspec.json.encode(msg)
+    decoded = json.loads(encoded)
+    assert decoded["role"] == "tool"
+    assert decoded["content"] == "result"
+    assert decoded["tool_call_id"] == "call_0"
+
+
+@pytest.mark.unit
+def test_to_endpoint_request_preserves_tool_calls():
+    """to_endpoint_request forwards tool_calls in the messages array."""
+    tool_calls = [
+        {
+            "id": "call_0",
+            "type": "function",
+            "function": {"name": "lookup", "arguments": '{"q": "test"}'},
+        }
+    ]
+    messages = [
+        {"role": "user", "content": "Hello"},
+        {"role": "assistant", "content": None, "tool_calls": tool_calls},
+        {"role": "tool", "content": "answer", "tool_call_id": "call_0"},
+        {"role": "assistant", "content": "Done"},
+    ]
+    query = Query(
+        id="q1",
+        data={
+            "model": "test-model",
+            "messages": messages,
+        },
+    )
+    request = OpenAIMsgspecAdapter.to_endpoint_request(query)
+    encoded = msgspec.json.encode(request)
+    payload = json.loads(encoded)
+
+    msgs = payload["messages"]
+    # assistant tool-dispatch row
+    assert msgs[1]["role"] == "assistant"
+    assert msgs[1]["tool_calls"] == tool_calls
+    assert "content" not in msgs[1]
+    # tool result row
+    assert msgs[2]["role"] == "tool"
+    assert msgs[2]["tool_call_id"] == "call_0"
+    assert msgs[2]["content"] == "answer"
+    # terminal assistant row
+    assert msgs[3]["content"] == "Done"
+
+
+@pytest.mark.unit
+def test_backward_compat_plain_messages_unchanged():
+    """Plain user/assistant messages encode identically to before the change."""
+    messages = [
+        {"role": "system", "content": "You are helpful"},
+        {"role": "user", "content": "Hi"},
+        {"role": "assistant", "content": "Hello!"},
+    ]
+    query = Query(
+        id="q2",
+        data={"model": "m", "messages": messages},
+    )
+    request = OpenAIMsgspecAdapter.to_endpoint_request(query)
+    encoded = msgspec.json.encode(request)
+    payload = json.loads(encoded)
+
+    for i, msg in enumerate(payload["messages"]):
+        assert msg["role"] == messages[i]["role"]
+        assert msg["content"] == messages[i]["content"]
+        assert "tool_calls" not in msg
+        assert "tool_call_id" not in msg
+
+
+@pytest.mark.unit
+def test_chat_message_from_dict_all_fields():
+    """_chat_message_from_dict forwards all four optional fields."""
+    tool_calls = [
+        {"id": "x", "type": "function", "function": {"name": "f", "arguments": "{}"}}
+    ]
+    msg = _chat_message_from_dict(
+        {
+            "role": "assistant",
+            "content": None,
+            "tool_calls": tool_calls,
+            "tool_call_id": None,
+        }
+    )
+    assert msg.role == "assistant"
+    assert msg.content is None
+    assert msg.tool_calls == tool_calls
+    assert msg.tool_call_id is None
+
+
+@pytest.mark.unit
+def test_chat_message_content_optional():
+    """ChatMessage accepts content=None for tool-dispatching assistant turns."""
+    msg = ChatMessage(role="assistant", tool_calls=[])
+    assert msg.content is None
diff --git a/tests/unit/openai/test_openai_adapter.py b/tests/unit/openai/test_openai_adapter.py
new file mode 100644
index 00000000..506ec3fe
--- /dev/null
+++ b/tests/unit/openai/test_openai_adapter.py
@@ -0,0 +1,147 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for OpenAIAdapter tool serialization."""
+
+import json
+
+import msgspec
+import pytest
+from inference_endpoint.core.types import Query
+from inference_endpoint.openai.openai_adapter import OpenAIAdapter
+
+_TOOL_DEF = {
+    "type": "function",
+    "function": {
+        "name": "search",
+        "description": "Search the web",
+        "parameters": {
+            "type": "object",
+            "properties": {"q": {"type": "string"}},
+            "required": ["q"],
+        },
+    },
+}
+
+_TOOL_CALLS = [
+    {
+        "id": "call_1",
+        "type": "function",
+        "function": {"name": "search", "arguments": '{"q": "test"}'},
+    }
+]
+
+
+@pytest.mark.unit
+def test_tool_definitions_forwarded():
+    """tools array in query.data is present in the encoded request."""
+    messages = [
+        {"role": "user", "content": "Find something"},
+    ]
+    query = Query(
+        id="q1",
+        data={
+            "model": "test-model",
+            "messages": messages,
+            "tools": [_TOOL_DEF],
+            "max_completion_tokens": 128,
+            "stream": False,
+        },
+    )
+    encoded = OpenAIAdapter.encode_query(query)
+    payload = json.loads(encoded)
+
+    assert "tools" in payload
+    assert len(payload["tools"]) == 1
+    assert payload["tools"][0]["function"]["name"] == "search"
+
+
+@pytest.mark.unit
+def test_tool_use_messages_roundtrip():
+    """Full tool-use message sequence encodes and decodes without data loss."""
+    messages = [
+        {"role": "system", "content": "You are helpful"},
+        {"role": "user", "content": "Find something"},
+        {"role": "assistant", "content": None, "tool_calls": _TOOL_CALLS},
+        {"role": "tool", "content": "search result", "tool_call_id": "call_1"},
+        {"role": "assistant", "content": "Here is the answer"},
+    ]
+    query = Query(
+        id="q1",
+        data={
+            "model": "test-model",
+            "messages": messages,
+            "tools": [_TOOL_DEF],
+            "max_completion_tokens": 128,
+            "stream": False,
+        },
+    )
+    encoded = OpenAIAdapter.encode_query(query)
+    payload = json.loads(encoded)
+
+    msgs = payload["messages"]
+    assert msgs[0]["role"] == "system"
+    assert msgs[1]["role"] == "user"
+    # assistant tool-dispatch: content is None (Pydantic model_dump includes None fields)
+    assert msgs[2]["role"] == "assistant"
+    assert msgs[2]["tool_calls"] == _TOOL_CALLS
+    assert msgs[2].get("content") is None
+    # tool result
+    assert msgs[3]["role"] == "tool"
+    assert msgs[3]["tool_call_id"] == "call_1"
+    assert msgs[3]["content"] == "search result"
+    # terminal assistant
+    assert msgs[4]["content"] == "Here is the answer"
+
+
+@pytest.mark.unit
+def test_encode_request_produces_valid_json_bytes():
+    """encode_request returns bytes that msgspec can decode back."""
+    messages = [{"role": "user", "content": "Hello"}]
+    query = Query(
+        id="q2",
+        data={
+            "model": "m",
+            "messages": messages,
+            "max_completion_tokens": 64,
+            "stream": False,
+        },
+    )
+    request = OpenAIAdapter.to_endpoint_request(query)
+    encoded = OpenAIAdapter.encode_request(request)
+
+    assert isinstance(encoded, bytes)
+    decoded = msgspec.json.decode(encoded)
+    assert decoded["messages"][0]["role"] == "user"
+
+
+@pytest.mark.unit
+def test_no_tools_key_when_absent():
+    """When query.data has no 'tools', the encoded payload has tools=None."""
+    messages = [{"role": "user", "content": "Hello"}]
+    query = Query(
+        id="q3",
+        data={
+            "model": "m",
+            "messages": messages,
+            "max_completion_tokens": 64,
+            "stream": False,
+        },
+    )
+    encoded = OpenAIAdapter.encode_query(query)
+    payload = json.loads(encoded)
+
+    # Pydantic model_dump includes None fields; tools must be None when not supplied
+    assert payload.get("tools") is None