From 132730212982c98fd0fd5bd4c07a39e66b5f3f02 Mon Sep 17 00:00:00 2001 From: "Li, Tianmu" Date: Thu, 23 Apr 2026 14:34:52 -0700 Subject: [PATCH 01/13] feat: add multi-turn dataset manager with flat JSONL support Add MultiTurnDataset, MultiTurnConfig schema, tool-calling types, Query.metadata transport field, adapter tools= kwarg, and multi-turn factory routing. --- docs/MULTI_TURN_QUICKSTART.md | 345 ++++++ examples/09_MultiTurn/README.md | 311 +++++ .../agentic_coding_benchmark.yaml | 33 + .../agentic_workflow_benchmark.yaml | 33 + .../customer_support_conversations.jsonl | 10 + .../09_MultiTurn/multi_turn_benchmark.yaml | 44 + .../multi_turn_with_concurrency.yaml | 44 + src/inference_endpoint/config/schema.py | 38 +- .../templates/concurrency_template.yaml | 2 +- .../templates/concurrency_template_full.yaml | 4 +- .../templates/offline_template_full.yaml | 4 +- .../config/templates/online_template.yaml | 2 +- .../templates/online_template_full.yaml | 4 +- src/inference_endpoint/core/types.py | 28 + .../dataset_manager/__init__.py | 2 + .../dataset_manager/factory.py | 9 +- .../dataset_manager/multi_turn_dataset.py | 428 +++++++ .../endpoint_client/adapter_protocol.py | 16 +- .../endpoint_client/http.py | 2 + .../endpoint_client/worker.py | 7 +- src/inference_endpoint/openai/accumulator.py | 58 +- .../openai/openai_adapter.py | 34 +- .../openai/openai_msgspec_adapter.py | 73 +- src/inference_endpoint/openai/types.py | 13 +- tests/unit/core/test_types.py | 54 + .../test_multi_turn_dataset.py | 1073 +++++++++++++++++ tests/unit/openai/test_msgspec_adapter.py | 146 +++ 27 files changed, 2749 insertions(+), 68 deletions(-) create mode 100644 docs/MULTI_TURN_QUICKSTART.md create mode 100644 examples/09_MultiTurn/README.md create mode 100644 examples/09_MultiTurn/agentic_coding_benchmark.yaml create mode 100644 examples/09_MultiTurn/agentic_workflow_benchmark.yaml create mode 100644 examples/09_MultiTurn/customer_support_conversations.jsonl create mode 100644 
examples/09_MultiTurn/multi_turn_benchmark.yaml create mode 100644 examples/09_MultiTurn/multi_turn_with_concurrency.yaml create mode 100644 src/inference_endpoint/dataset_manager/multi_turn_dataset.py create mode 100644 tests/unit/dataset_manager/test_multi_turn_dataset.py create mode 100644 tests/unit/openai/test_msgspec_adapter.py diff --git a/docs/MULTI_TURN_QUICKSTART.md b/docs/MULTI_TURN_QUICKSTART.md new file mode 100644 index 00000000..73ed6678 --- /dev/null +++ b/docs/MULTI_TURN_QUICKSTART.md @@ -0,0 +1,345 @@ +# Multi-Turn Conversation Benchmarking - Quick Start Guide + +## πŸš€ Quick Start in 5 Minutes + +### 1. Prepare Your Dataset + +Create a JSONL file with your conversations. All rows for a given `conversation_id` must appear +**consecutively** in the file (no interleaving with other conversations): + +```jsonl +{"conversation_id": "c1", "turn": 1, "role": "user", "content": "Hello!", "system": "You are a helpful assistant"} +{"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "Hi! How can I help?"} +{"conversation_id": "c1", "turn": 3, "role": "user", "content": "What's 2+2?"} +{"conversation_id": "c1", "turn": 4, "role": "assistant", "content": "2+2 equals 4."} +``` + +**Rules**: + +- Alternate between "user" and "assistant" roles +- Start with "user" role +- Sequential turn numbers (1, 2, 3, ...) +- Same `conversation_id` for all turns in a conversation +- All rows for the same `conversation_id` must be grouped together + +### 2. 
Create Configuration File + +Save as `multi_turn_config.yaml`: + +```yaml +name: "my-multi-turn-benchmark" +version: "1.0" +type: "online" + +model_params: + name: "your-model-name" + temperature: 0.7 + max_new_tokens: 256 + +datasets: + - name: my_conversations + type: performance + path: path/to/your/conversations.jsonl + format: ".jsonl" + multi_turn: # ← Presence of this block enables multi-turn mode + mode: independent # ← Per-conv pipelines; no cross-conv turn barrier + turn_timeout_s: 300 # ← Max seconds to wait for the previous turn + +settings: + load_pattern: + type: multi_turn # ← Use multi-turn scheduler + target_concurrency: 32 # ← OPTIONAL: limit concurrent requests + + client: + workers: 4 + +endpoint_config: + endpoints: + - "http://your-endpoint:8000" + api_type: openai + +report_dir: logs/my_multi_turn_benchmark +``` + +Results are written to `report_dir` (here: `logs/my_multi_turn_benchmark/`). + +### 3. Run Benchmark + +```bash +inference-endpoint benchmark from-config --config multi_turn_config.yaml +``` + +That's it! 
Your benchmark will now: + +- ✅ Enforce turn ordering (turn N+1 waits for turn N) +- ✅ Include conversation history in each request +- ✅ Track per-turn metrics (per-conversation aggregation is planned) +- ✅ Log all turns with conversation metadata + +--- + +## 📊 Understanding Results + +After the benchmark completes, check the directory configured via `report_dir`: + +### Events Database + +The `events.db` SQLite database includes: + +- Standard fields: sample_uuid, event_type, timestamp_ns +- **New fields**: conversation_id, turn_number + +Query example: + +```sql +SELECT conversation_id, turn_number, event_type, timestamp_ns +FROM events +WHERE conversation_id = 'c1' +ORDER BY turn_number; +``` + +### Metrics + +Currently available: + +- **Per-turn metrics**: Latency, TTFT, TPOT for each turn +- **Conversation tracking**: All events tagged with conversation_id + +_Note: Per-conversation aggregation (e.g., "conversations/sec") is coming in a future update._ + +--- + +## 🎯 Conversation Modes Explained + +### Independent Mode (Default) + +```yaml +mode: independent +``` + +**Behavior**: + +- Issues turn-1 of ALL conversations at t=0 (up to `target_concurrency`, if set) +- Then sequences turns within each conversation independently +- Maximum parallelism and throughput + +**Use for**: Realistic production load where short conversations finish while long ones are still running. +For single-conversation debugging, use `mode: independent` with `target_concurrency: 1`. +Note: unlike the plain `ConcurrencyScheduler`, multi-turn + `target_concurrency: 1` still enforces +per-conversation turn ordering — turn N+1 waits for turn N even at concurrency 1. + +**Example timeline**: + +``` +t=0: conv1-turn1, conv2-turn1, conv3-turn1 (all at once) +t=0.5: conv1-turn2 (after conv1-turn1 completes) +t=0.7: conv2-turn2 (after conv2-turn1 completes) +t=0.8: conv1-turn3 (after conv1-turn2 completes) +... +``` + +--- + +## 🎛️ Concurrency Control (NEW!) 
+ +For benchmarks with **> 50 conversations**, use `target_concurrency` to prevent endpoint overload: + +```yaml +settings: + load_pattern: + type: multi_turn + target_concurrency: 32 # ← Limit to 32 concurrent requests +``` + +**Why?** Without this, independent mode issues ALL turn-1s at once (could be 100+), overwhelming your endpoint. + +**Rule of thumb**: + +- Small (< 50 convs): No limit needed +- Medium (50-500 convs): `target_concurrency: 32` +- Large (500+ convs): `target_concurrency: 64` + +--- + +## πŸ”§ Common Configurations + +### Recommended: With Concurrency Control + +```yaml +multi_turn: + mode: independent + +settings: + load_pattern: + type: multi_turn + target_concurrency: 32 # ← Prevents overload + client: + workers: 8 + +datasets: + - samples: 100 +``` + +### High Throughput Testing + +```yaml +multi_turn: + mode: independent + turn_timeout_s: 600 + +settings: + client: + workers: 16 # More workers for parallel conversations +``` + +### Long Conversations + +```yaml +multi_turn: + mode: independent + turn_timeout_s: 1800 # 30 minutes for slow responses +``` + +--- + +## ❓ Troubleshooting + +### "Conversation has invalid role sequence" + +**Problem**: Your dataset doesn't alternate between user/assistant. + +**Fix**: Check your JSONL - must be: user, assistant, user, assistant, ... + +### "Rows for conversation X are not consecutive" + +**Problem**: Rows for the same `conversation_id` are interleaved with rows from other conversations. + +**Fix**: Sort your JSONL so all rows for each conversation appear together. + +### "Turn timed out waiting for prev turn" + +**Problem**: Previous turn took longer than `turn_timeout_s`. + +**Fixes**: + +1. Increase `turn_timeout_s` in config +2. Check if your endpoint is slow or unresponsive +3. Look for errors in the endpoint logs + +### Dataset not loading + +**Problem**: MultiTurnDataset not recognized. 
+ +**Fix**: Ensure `format: ".jsonl"` is specified in config: + +```yaml +datasets: + - path: your_file.jsonl + format: ".jsonl" # ← Required for JSONL +``` + +--- + +## 📝 Example Datasets + +### Simple 2-Turn Conversation + +```jsonl +{"conversation_id": "c1", "turn": 1, "role": "user", "content": "Hi"} +{"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "Hello!"} +``` + +### With System Prompt + +```jsonl +{"conversation_id": "c1", "turn": 1, "role": "user", "content": "Who won?", "system": "You are a sports expert"} +{"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "The Lakers won."} +``` + +### Multiple Conversations + +```jsonl +{"conversation_id": "c1", "turn": 1, "role": "user", "content": "Hi"} +{"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "Hello!"} +{"conversation_id": "c2", "turn": 1, "role": "user", "content": "Hey"} +{"conversation_id": "c2", "turn": 2, "role": "assistant", "content": "Hi there!"} +``` + +### With Model Override + +```jsonl +{"conversation_id": "c1", "turn": 1, "role": "user", "content": "Summarize this", "model": "gpt-4"} +{"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "Here's the summary..."} +``` + +--- + +## 🧪 Testing Your Setup + +### 1. Use the Example Dataset + +```bash +# Run from the repo root: the dataset path in the example config is written relative to it +inference-endpoint benchmark from-config --config examples/09_MultiTurn/multi_turn_benchmark.yaml +``` + +### 2. Check the Logs + +```bash +cat logs/multi_turn_test/benchmark.log +# Look for: "Turn X of conversation_id issued" +``` + +### 3. 
Verify Event Recording + +```bash +sqlite3 logs/multi_turn_test/events.db +sqlite> SELECT DISTINCT conversation_id FROM events; +# Should show your conversation IDs +``` + +--- + +## 💡 Tips & Best Practices + +### Dataset Design + +- **Keep conversations realistic**: 2-10 turns typical +- **Test edge cases**: 1-turn conversations, very long conversations +- **Include system prompts**: Helps the model understand context + +### Performance + +- **Workers**: Set `workers` = number of concurrent conversations +- **Timeout**: Set `turn_timeout_s` = 2x your longest expected turn latency +- **Memory**: ~1KB per turn, plan accordingly for large datasets + +### Debugging + +- **Start small**: Test with 1-2 conversations first +- **Single conversation**: Use `mode: independent` with `target_concurrency: 1` +- **Check events.db**: Verify turn ordering in the database + +--- + +## 🔗 More Information + +- **Full Documentation**: See `examples/09_MultiTurn/README.md` +- **Architecture**: See `AGENTS.md` (Multi-Turn section) + +--- + +## ✅ Checklist + +Before running your first multi-turn benchmark: + +- [ ] Dataset follows format (alternating user/assistant roles) +- [ ] All rows for each conversation_id are grouped together +- [ ] Config has `multi_turn:` block in the dataset section +- [ ] Config has `load_pattern.type: multi_turn` +- [ ] Endpoint is running and reachable +- [ ] `format: ".jsonl"` specified for JSONL datasets +- [ ] Conversation IDs are unique per conversation +- [ ] Turn numbers are sequential (1, 2, 3, ...) + +Happy benchmarking! 🚀 diff --git a/examples/09_MultiTurn/README.md b/examples/09_MultiTurn/README.md new file mode 100644 index 00000000..0d3348c7 --- /dev/null +++ b/examples/09_MultiTurn/README.md @@ -0,0 +1,311 @@ +# Multi-Turn Conversation Benchmarking Examples + +This directory contains examples for benchmarking conversational AI workloads with multi-turn conversation support. 
+ +## Overview + +Multi-turn conversation benchmarking enables testing realistic conversational AI scenarios where each turn depends on previous responses. The system maintains conversation history and enforces turn sequencing to simulate real-world multi-turn interactions. + +## Dataset Format + +Multi-turn datasets use JSONL format with the following structure: + +```jsonl +{"conversation_id": "c1", "turn": 1, "role": "user", "content": "...", "system": "..."} +{"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "..."} +{"conversation_id": "c1", "turn": 3, "role": "user", "content": "..."} +``` + +### Required Fields + +- `conversation_id`: Unique identifier for each conversation +- `turn`: Turn number within conversation (1-indexed) +- `role`: Speaker role ("user" or "assistant"; agentic datasets also use "tool") +- `content`: Message content + +### Optional Fields + +- `system`: System prompt (typically only on first user turn) +- `model`: Model name override for this turn +- `max_new_tokens`: Maximum tokens to generate for this turn + +### Validation Rules + +1. All rows for a given `conversation_id` must appear **consecutively** in the file (no interleaving + with rows from other conversations). Turns within a conversation must be in order. + The flat-row format is intentional: it enables row-by-row streaming without loading entire + conversations into memory first. +2. Conversations must follow a valid role sequence: + - Plain chat: `user → assistant → user → ...` + - Agentic: `user → assistant (with tool_calls) → tool → [tool | assistant (with tool_calls)]* → assistant → user → ...` +3. First turn must be "user" role +4. Turn numbers must be sequential (1, 2, 3, ...) +5. Each conversation must have at least one turn + +## Agentic (Tool-Sequence) Datasets + +For agentic workloads where the model dispatches tools, the dataset must include tool-call +metadata. 
The source format for these datasets is a **snapshot JSONL** — each line contains the +full conversation history at a particular checkpoint. The benchmarker requires **flat-row JSONL** +(one row per message), so a conversion step is needed first. + +### Source snapshot format + +Each line in the source file represents one snapshot of a conversation: + +```json +{ + "conversation_id": "sim_001", + "conversation_idx": 5, + "messages": [ + {"role": "system", "content": "..."}, + {"role": "user", "content": "..."}, + {"role": "assistant", "tool_calls": [{"id": "...", "type": "function", "function": {"name": "bash", "arguments": "{\"cmd\": \"ls\"}"}}]}, + {"role": "tool", "tool_call_id": "...", "content": "file1.txt\nfile2.txt"}, + {"role": "assistant", "content": "Done."} + ], + "tools": [...], + "metadata": {} +} +``` + +Multiple snapshots may exist per `conversation_id` (one per `conversation_idx`); only the +highest-indexed snapshot per conversation is used. + +### Converting to flat-row format + +The following commands convert each source snapshot file to the flat-row format required by the benchmarker. +Run from the repo root; the first argument is the input snapshot JSONL and the second is the output flat-row JSONL: + +```bash +python scripts/convert_agentic_snapshot.py \ + /path/to/agentic_coding_dataset.jsonl \ + examples/09_MultiTurn/datasets/agentic_coding_flat.jsonl \ + --verify + +python scripts/convert_agentic_snapshot.py \ + /path/to/agentic_workflow_dataset.jsonl \ + examples/09_MultiTurn/datasets/agentic_workflow_flat.jsonl \ + --verify +``` + +The `--verify` flag cross-checks every client turn's message history against the source snapshot +and exits with code 1 if any mismatch is found. 
The script also: + +- Collapses consecutive `user` messages into one (keeps turn sequencing clean) +- Merges consecutive `tool` messages for the same assistant dispatch into a single row with a + `tool_results` list (so all parallel results are sent together in one API call) + +### Flat-row format after conversion + +The extra fields supported beyond plain user/assistant: + +| Row role | Extra fields | +| -------------------------------- | ------------------------------------------------------------------ | +| `assistant` with tool calls | `tool_calls: [{id, type, function: {name, arguments}}]` | +| `tool` single result | `tool_call_id: <id>`, `content: <result>` | +| `tool` parallel results (merged) | `tool_results: [{tool_call_id, content}, ...]` | +| `user` or `tool` turns | `tools: [...]` (OpenAI tool definitions forwarded to the endpoint) | + +Example rows from a converted agentic dataset: + +```jsonl +{"conversation_id": "sim_001", "turn": 1, "role": "user", "content": "Fix the bug in foo.py", "system": "You are a coding agent.", "tools": [...]} +{"conversation_id": "sim_001", "turn": 2, "role": "assistant", "tool_calls": [{"id": "functions.bash:0", "type": "function", "function": {"name": "bash", "arguments": "{\"cmd\": \"cat foo.py\"}"}}]} +{"conversation_id": "sim_001", "turn": 3, "role": "tool", "tool_call_id": "functions.bash:0", "content": "def foo():\n return 1/0", "tools": [...]} +{"conversation_id": "sim_001", "turn": 4, "role": "assistant", "content": "The bug is a ZeroDivisionError. 
Here is the fix: ..."} +``` + +### Running agentic benchmarks + +After converting the datasets, update the `path` field in the config files and run: + +```bash +inference-endpoint benchmark from-config \ + --config examples/09_MultiTurn/agentic_coding_benchmark.yaml + +inference-endpoint benchmark from-config \ + --config examples/09_MultiTurn/agentic_workflow_benchmark.yaml +``` + +--- + +## Configuration + +### Basic Configuration + +```yaml +datasets: + - name: customer_support + type: performance + path: examples/09_MultiTurn/customer_support_conversations.jsonl + format: ".jsonl" + multi_turn: + mode: independent + turn_timeout_s: 300.0 + +settings: + load_pattern: + type: multi_turn +``` + +### Concurrency Control (Optional) + +The multi-turn scheduler supports **optional concurrency limiting** to control the maximum number of in-flight requests across all conversations: + +```yaml +settings: + load_pattern: + type: multi_turn + target_concurrency: 32 # ← Limit to 32 concurrent requests +``` + +**Behavior**: + +- Without `target_concurrency`: Unlimited concurrency (all turn-1s issue at t=0 in INDEPENDENT mode) +- With `target_concurrency`: Limits total in-flight requests across all conversations +- Combines with turn sequencing: Turn N+1 still waits for turn N, AND waits for an available slot + +**Use cases**: + +- 🎯 **Prevent endpoint overload**: Cap the number of in-flight requests to busy endpoints +- 🎯 **Large-scale testing**: Benchmark 1000+ conversations without overwhelming the system +- 🎯 **Resource management**: Stay within port limits, memory constraints + +**Example**: 100 conversations with `target_concurrency: 32` + +``` +t=0: Issue first 32 turn-1s (concurrency limit reached) +t=0.5: Turn-1 completes → issue next turn-1 (slot filled) +t=1.0: Turn-1 completes → issue turn-2 of that conversation (slot filled) +... Maintains ~32 in-flight across all conversations +``` + +### Conversation Modes + +The default mode is `independent`. 
+ +#### Independent Mode (Default) + +Issues turns for each conversation independently — no cross-conversation turn barrier. + +```yaml +multi_turn: + mode: independent + +settings: + load_pattern: + type: multi_turn + target_concurrency: 32 +``` + +**Use case**: Realistic production load where short conversations finish while long ones are +still running. Turn 1 of one conversation and turn 100 of another can be in-flight simultaneously. + +For single-conversation debugging, use `mode: independent` with `target_concurrency: 1`. + +### Turn Timeout + +Configure maximum wait time for previous turn completion: + +```yaml +multi_turn: + turn_timeout_s: 300.0 # 5 minutes +``` + +If a turn times out waiting for the previous turn, it will be skipped and logged as a warning. + +## Running Multi-Turn Benchmarks + +### Using Configuration File + +```bash +inference-endpoint benchmark from-config \ + --config examples/09_MultiTurn/multi_turn_benchmark.yaml +``` + +### Viewing Results + +Multi-turn benchmarks currently report per-turn metrics; per-conversation aggregation is planned: + +- **Per-turn metrics**: Latency, TTFT, TPOT for each individual turn +- **Per-conversation metrics** (planned): Total conversation latency, conversations per second + +Results are stored in the configured `report_dir` with conversation metadata included in the events database. + +## Example Datasets + +### customer_support_conversations.jsonl + +Simple customer support conversations demonstrating basic multi-turn interactions: + +- 3 conversations +- 2-4 turns per conversation +- Customer support agent system prompt + +## Architecture Notes + +### Key Components + +- **ConversationManager**: Tracks conversation state and message history +- **MultiTurnScheduler**: Enforces turn sequencing within conversations +- **ConversationSample**: Sample with conversation metadata +- **MultiTurnDataset**: Validates and structures multi-turn data + +### Turn Sequencing + +The system ensures that: + +1. 
Turn N+1 cannot be issued until turn N completes +2. Message history is included in subsequent requests +3. Concurrent conversations are supported (in independent mode) + +### Memory Considerations + +Each conversation maintains message history in memory. For large-scale benchmarks with long conversations: + +- Memory usage: ~1KB per turn (approximate) +- 1000 conversations Γ— 10 turns = ~10MB + +## Troubleshooting + +### "Conversation has invalid role sequence" + +**Cause**: Conversation doesn't follow a valid role sequence. + +**Fix**: For plain chat, ensure the dataset alternates between user and assistant: + +``` +user -> assistant -> user -> assistant -> ... +``` + +For agentic datasets, use the conversion script (`scripts/convert_agentic_snapshot.py`) to +produce a properly sequenced flat-row file. The valid agentic sequence is: + +``` +user -> assistant (tool_calls) -> tool -> [tool | assistant (tool_calls)]* -> assistant -> user -> ... +``` + +### "Turn timed out waiting for prev turn" + +**Cause**: Previous turn took longer than `turn_timeout_s` to complete. + +**Fixes**: + +- Increase `turn_timeout_s` in configuration +- Check endpoint performance +- Verify endpoint is responding + +### Single-turn benchmarks unaffected + +Multi-turn logic is only activated when a `multi_turn:` block is present in the dataset configuration. Existing single-turn benchmarks continue to work unchanged with zero performance overhead. 
+ +## Future Enhancements + +Planned features: + +- [ ] Poisson conversation arrival mode implementation +- [ ] Per-conversation metrics in reporting +- [ ] Conversation-level latency percentiles +- [x] Support for tool/function calls in conversations (see "Agentic (Tool-Sequence) Datasets") +- [ ] Dynamic conversation branching diff --git a/examples/09_MultiTurn/agentic_coding_benchmark.yaml b/examples/09_MultiTurn/agentic_coding_benchmark.yaml new file mode 100644 index 00000000..5a1036a7 --- /dev/null +++ b/examples/09_MultiTurn/agentic_coding_benchmark.yaml @@ -0,0 +1,33 @@ +name: "agentic-coding-benchmark" +version: "1.0" +type: "online" + +model_params: + name: "your-model-name" # Replace with your actual model name + max_new_tokens: 1024 + +datasets: + - name: agentic_coding + type: performance + # Generate with: python scripts/convert_agentic_snapshot.py <input_snapshot.jsonl> examples/09_MultiTurn/datasets/agentic_coding_flat.jsonl --verify + path: examples/09_MultiTurn/datasets/agentic_coding_flat.jsonl + format: ".jsonl" + multi_turn: + mode: independent + turn_timeout_s: 600.0 + +settings: + runtime: + min_duration_ms: 0 + max_duration_ms: 3600000 + + load_pattern: + type: multi_turn + target_concurrency: 4096 + +endpoint_config: + endpoints: + - "http://localhost:8868" + api_type: openai + +report_dir: logs/agentic_coding diff --git a/examples/09_MultiTurn/agentic_workflow_benchmark.yaml b/examples/09_MultiTurn/agentic_workflow_benchmark.yaml new file mode 100644 index 00000000..e8885465 --- /dev/null +++ b/examples/09_MultiTurn/agentic_workflow_benchmark.yaml @@ -0,0 +1,33 @@ +name: "agentic-workflow-benchmark" +version: "1.0" +type: "online" + +model_params: + name: "your-model-name" # Replace with your actual model name + max_new_tokens: 512 + +datasets: + - name: agentic_workflow + type: performance + # Generate with: python scripts/convert_agentic_snapshot.py <input_snapshot.jsonl> examples/09_MultiTurn/datasets/agentic_workflow_flat.jsonl --verify + path: examples/09_MultiTurn/datasets/agentic_workflow_flat.jsonl + format: ".jsonl" + multi_turn: + 
mode: independent + turn_timeout_s: 600.0 + +settings: + runtime: + min_duration_ms: 0 + max_duration_ms: 3600000 + + load_pattern: + type: multi_turn + target_concurrency: 96 + +endpoint_config: + endpoints: + - "http://localhost:8868" + api_type: openai + +report_dir: logs/agentic_workflow diff --git a/examples/09_MultiTurn/customer_support_conversations.jsonl b/examples/09_MultiTurn/customer_support_conversations.jsonl new file mode 100644 index 00000000..ac19e907 --- /dev/null +++ b/examples/09_MultiTurn/customer_support_conversations.jsonl @@ -0,0 +1,10 @@ +{"conversation_id": "conv_001", "turn": 1, "role": "user", "content": "I need help resetting my password", "system": "You are a helpful customer support agent"} +{"conversation_id": "conv_001", "turn": 2, "role": "assistant", "content": "I'd be happy to help you reset your password. Can you provide your email address?"} +{"conversation_id": "conv_001", "turn": 3, "role": "user", "content": "It's user@example.com"} +{"conversation_id": "conv_001", "turn": 4, "role": "assistant", "content": "Thank you. I've sent a password reset link to user@example.com. Please check your inbox and follow the instructions."} +{"conversation_id": "conv_002", "turn": 1, "role": "user", "content": "What are your business hours?", "system": "You are a helpful customer support agent"} +{"conversation_id": "conv_002", "turn": 2, "role": "assistant", "content": "We're open Monday-Friday, 9 AM to 5 PM EST. How can I assist you today?"} +{"conversation_id": "conv_002", "turn": 3, "role": "user", "content": "Do you offer weekend support?"} +{"conversation_id": "conv_002", "turn": 4, "role": "assistant", "content": "For urgent issues, we offer limited support on weekends from 10 AM to 2 PM EST. 
For non-urgent matters, please contact us during our regular business hours."} +{"conversation_id": "conv_003", "turn": 1, "role": "user", "content": "Can I cancel my subscription?", "system": "You are a helpful customer support agent"} +{"conversation_id": "conv_003", "turn": 2, "role": "assistant", "content": "Yes, you can cancel your subscription at any time. Would you like me to guide you through the cancellation process?"} diff --git a/examples/09_MultiTurn/multi_turn_benchmark.yaml b/examples/09_MultiTurn/multi_turn_benchmark.yaml new file mode 100644 index 00000000..9ed6c9f1 --- /dev/null +++ b/examples/09_MultiTurn/multi_turn_benchmark.yaml @@ -0,0 +1,44 @@ +name: "multi-turn-customer-support" +version: "1.0" +type: "online" + +model_params: + name: "meta-llama/Llama-3.2-1B-Instruct" + temperature: 0.7 + max_new_tokens: 256 + +datasets: + - name: customer_support_conversations + type: performance + path: examples/09_MultiTurn/customer_support_conversations.jsonl + format: ".jsonl" + samples: 10 + multi_turn: + mode: independent + turn_timeout_s: 300.0 + +settings: + runtime: + min_duration_ms: 60000 + max_duration_ms: 300000 + + load_pattern: + type: multi_turn + # target_concurrency: 32 # Optional: limit concurrent requests across all conversations + + client: + warmup_connections: 0 + +metrics: + collect: + - throughput + - latency + - ttft + - tpot + +endpoint_config: + endpoints: + - "http://localhost:8868" + api_type: openai + +report_dir: logs/multi_turn_test diff --git a/examples/09_MultiTurn/multi_turn_with_concurrency.yaml b/examples/09_MultiTurn/multi_turn_with_concurrency.yaml new file mode 100644 index 00000000..491e6b4b --- /dev/null +++ b/examples/09_MultiTurn/multi_turn_with_concurrency.yaml @@ -0,0 +1,44 @@ +name: "multi-turn-with-concurrency-control" +version: "1.0" +type: "online" + +model_params: + name: "meta-llama/Llama-3.2-1B-Instruct" + temperature: 0.7 + max_new_tokens: 256 + +datasets: + - name: customer_support_conversations + 
type: performance + path: examples/09_MultiTurn/customer_support_conversations.jsonl + format: ".jsonl" + samples: 10 + multi_turn: + mode: independent # All conv turn-1 start together + turn_timeout_s: 300.0 + +settings: + runtime: + min_duration_ms: 60000 + max_duration_ms: 300000 + + load_pattern: + type: multi_turn + target_concurrency: 32 # ← NEW: Limit to 32 concurrent requests + + client: + warmup_connections: 0 + +metrics: + collect: + - throughput + - latency + - ttft + - tpot + +endpoint_config: + endpoints: + - "http://localhost:8868" + api_type: openai + +report_dir: logs/multi_turn_with_concurrency diff --git a/src/inference_endpoint/config/schema.py b/src/inference_endpoint/config/schema.py index 6a1884b4..5b122c46 100644 --- a/src/inference_endpoint/config/schema.py +++ b/src/inference_endpoint/config/schema.py @@ -60,10 +60,17 @@ class LoadPatternType(str, Enum): MAX_THROUGHPUT = "max_throughput" # Offline: all queries at t=0 POISSON = "poisson" # Online: fixed QPS with Poisson distribution CONCURRENCY = "concurrency" # Online: fixed concurrent requests + MULTI_TURN = "multi_turn" # Multi-turn conversations with turn sequencing BURST = "burst" # Burst pattern (TODO) STEP = "step" # Step pattern (TODO) +class ConversationMode(str, Enum): + """Multi-turn conversation scheduling modes.""" + + INDEPENDENT = "independent" # Per-conv pipelines; no cross-conv turn barrier + + class OSLDistributionType(str, Enum): """Output Sequence Length distribution types.""" @@ -230,6 +237,26 @@ def get_ruleset_instance(self) -> BenchmarkSuiteRuleset: return get_ruleset(self.ruleset) +class MultiTurnConfig(BaseModel): + """Multi-turn conversation configuration. + + Configuration for benchmarking conversational AI workloads with turn sequencing. + Enables testing multi-turn conversations where each turn depends on previous responses. + Presence of this block in the dataset config enables multi-turn mode. 
+ + Attributes: + mode: Conversation scheduling strategy (currently only independent). + turn_timeout_s: Maximum seconds to wait for previous turn completion. + use_dataset_history: If True, use pre-built message history from dataset. + """ + + model_config = {"extra": "forbid"} + + mode: ConversationMode = ConversationMode.INDEPENDENT + turn_timeout_s: float = 300.0 + use_dataset_history: bool = True + + class Dataset(BaseModel): """Dataset configuration. @@ -260,6 +287,9 @@ class Dataset(BaseModel): accuracy_config: AccuracyConfig | None = Field( None, description="Accuracy evaluation settings" ) + multi_turn: MultiTurnConfig | None = Field( + None, description="Multi-turn conversation configuration" + ) @model_validator(mode="after") def _auto_derive_name(self) -> Self: @@ -584,9 +614,13 @@ def _resolve_and_validate(self) -> Self: f"Offline benchmarks must use 'max_throughput', got '{lp.type}'" ) elif effective_mode == TestType.ONLINE: - if lp.type not in (LoadPatternType.POISSON, LoadPatternType.CONCURRENCY): + if lp.type not in ( + LoadPatternType.POISSON, + LoadPatternType.CONCURRENCY, + LoadPatternType.MULTI_TURN, + ): raise ValueError( - "Online mode requires --load-pattern (poisson or concurrency)" + "Online mode requires --load-pattern (poisson, concurrency, or multi_turn)" ) return self diff --git a/src/inference_endpoint/config/templates/concurrency_template.yaml b/src/inference_endpoint/config/templates/concurrency_template.yaml index 7b560ed7..c44295b4 100644 --- a/src/inference_endpoint/config/templates/concurrency_template.yaml +++ b/src/inference_endpoint/config/templates/concurrency_template.yaml @@ -14,7 +14,7 @@ settings: max_duration_ms: 0 # Maximum test duration in ms (0 for no limit) n_samples_to_issue: null # Sample count override load_pattern: - type: concurrency # Load pattern type | options: max_throughput, poisson, concurrency, burst, step + type: concurrency # Load pattern type | options: max_throughput, poisson, concurrency, 
multi_turn, burst, step target_concurrency: 32 # Concurrent requests endpoint_config: endpoints: # Endpoint URL(s). Must include scheme, e.g. 'http://host:port'. diff --git a/src/inference_endpoint/config/templates/concurrency_template_full.yaml b/src/inference_endpoint/config/templates/concurrency_template_full.yaml index 3a8e004f..2c0c24ae 100644 --- a/src/inference_endpoint/config/templates/concurrency_template_full.yaml +++ b/src/inference_endpoint/config/templates/concurrency_template_full.yaml @@ -22,6 +22,7 @@ datasets: # Dataset configs parser: # Column remapping: {prompt: , system: } prompt: text_input accuracy_config: null # Accuracy evaluation settings + multi_turn: null # Multi-turn conversation configuration - name: accuracy type: accuracy # Dataset purpose: performance or accuracy | options: performance, accuracy path: '' # Dataset file path @@ -36,6 +37,7 @@ datasets: # Dataset configs ground_truth: ground_truth # Ground truth column name extractor: boxed_math_extractor # Answer extractor (abcd_extractor, boxed_math_extractor, identity_extractor, python_code_extractor) num_repeats: 1 # Repeat dataset N times for evaluation + multi_turn: null # Multi-turn conversation configuration settings: runtime: min_duration_ms: 600000 # Min duration (ms, or with suffix: 600s, 10m) @@ -44,7 +46,7 @@ settings: scheduler_random_seed: 42 # Scheduler RNG seed dataloader_random_seed: 42 # Dataloader RNG seed load_pattern: - type: concurrency # Load pattern type | options: max_throughput, poisson, concurrency, burst, step + type: concurrency # Load pattern type | options: max_throughput, poisson, concurrency, multi_turn, burst, step target_qps: null # Target QPS target_concurrency: 32 # Concurrent requests client: diff --git a/src/inference_endpoint/config/templates/offline_template_full.yaml b/src/inference_endpoint/config/templates/offline_template_full.yaml index faabffde..72182198 100644 --- a/src/inference_endpoint/config/templates/offline_template_full.yaml +++ 
b/src/inference_endpoint/config/templates/offline_template_full.yaml @@ -22,6 +22,7 @@ datasets: # Dataset configs parser: # Column remapping: {prompt: , system: } prompt: text_input accuracy_config: null # Accuracy evaluation settings + multi_turn: null # Multi-turn conversation configuration - name: accuracy type: accuracy # Dataset purpose: performance or accuracy | options: performance, accuracy path: '' # Dataset file path @@ -36,6 +37,7 @@ datasets: # Dataset configs ground_truth: ground_truth # Ground truth column name extractor: boxed_math_extractor # Answer extractor (abcd_extractor, boxed_math_extractor, identity_extractor, python_code_extractor) num_repeats: 1 # Repeat dataset N times for evaluation + multi_turn: null # Multi-turn conversation configuration settings: runtime: min_duration_ms: 600000 # Min duration (ms, or with suffix: 600s, 10m) @@ -44,7 +46,7 @@ settings: scheduler_random_seed: 42 # Scheduler RNG seed dataloader_random_seed: 42 # Dataloader RNG seed load_pattern: - type: max_throughput # Load pattern type | options: max_throughput, poisson, concurrency, burst, step + type: max_throughput # Load pattern type | options: max_throughput, poisson, concurrency, multi_turn, burst, step target_qps: null # Target QPS target_concurrency: null # Concurrent requests client: diff --git a/src/inference_endpoint/config/templates/online_template.yaml b/src/inference_endpoint/config/templates/online_template.yaml index d33c1fd5..a56dc9b0 100644 --- a/src/inference_endpoint/config/templates/online_template.yaml +++ b/src/inference_endpoint/config/templates/online_template.yaml @@ -14,7 +14,7 @@ settings: max_duration_ms: 0 # Maximum test duration in ms (0 for no limit) n_samples_to_issue: null # Sample count override load_pattern: - type: poisson # Load pattern type | options: max_throughput, poisson, concurrency, burst, step + type: poisson # Load pattern type | options: max_throughput, poisson, concurrency, multi_turn, burst, step target_qps: 10.0 # 
Target QPS endpoint_config: endpoints: # Endpoint URL(s). Must include scheme, e.g. 'http://host:port'. diff --git a/src/inference_endpoint/config/templates/online_template_full.yaml b/src/inference_endpoint/config/templates/online_template_full.yaml index e9b7a673..b36c41a2 100644 --- a/src/inference_endpoint/config/templates/online_template_full.yaml +++ b/src/inference_endpoint/config/templates/online_template_full.yaml @@ -22,6 +22,7 @@ datasets: # Dataset configs parser: # Column remapping: {prompt: , system: } prompt: text_input accuracy_config: null # Accuracy evaluation settings + multi_turn: null # Multi-turn conversation configuration - name: accuracy type: accuracy # Dataset purpose: performance or accuracy | options: performance, accuracy path: '' # Dataset file path @@ -36,6 +37,7 @@ datasets: # Dataset configs ground_truth: ground_truth # Ground truth column name extractor: boxed_math_extractor # Answer extractor (abcd_extractor, boxed_math_extractor, identity_extractor, python_code_extractor) num_repeats: 1 # Repeat dataset N times for evaluation + multi_turn: null # Multi-turn conversation configuration settings: runtime: min_duration_ms: 600000 # Min duration (ms, or with suffix: 600s, 10m) @@ -44,7 +46,7 @@ settings: scheduler_random_seed: 42 # Scheduler RNG seed dataloader_random_seed: 42 # Dataloader RNG seed load_pattern: - type: poisson # Load pattern type | options: max_throughput, poisson, concurrency, burst, step + type: poisson # Load pattern type | options: max_throughput, poisson, concurrency, multi_turn, burst, step target_qps: 10.0 # Target QPS target_concurrency: null # Concurrent requests client: diff --git a/src/inference_endpoint/core/types.py b/src/inference_endpoint/core/types.py index accd2ca8..dc7f4faf 100644 --- a/src/inference_endpoint/core/types.py +++ b/src/inference_endpoint/core/types.py @@ -226,6 +226,7 @@ class Query( Attributes: id: Unique identifier for this query (auto-generated UUID). 
data: Request payload as a dictionary (typically contains prompt, model, etc.). + metadata: Internal metadata that round-trips through transport (e.g., conversation_id). headers: HTTP headers to include in the request (e.g., authorization). created_at: Timestamp when query was created (seconds since epoch). @@ -249,6 +250,7 @@ class Query( id: str = msgspec.field(default_factory=lambda: str(uuid.uuid4())) data: dict[str, Any] = msgspec.field(default_factory=dict) + metadata: dict[str, Any] = msgspec.field(default_factory=dict) headers: dict[str, str] = msgspec.field(default_factory=dict) created_at: float = msgspec.field(default_factory=time.time) @@ -331,6 +333,32 @@ def get_response_output_string(self) -> str: else: return "" + def with_metadata( + self, additional_metadata: dict[str, Any] | None + ) -> "QueryResult": + """Return a new QueryResult with merged metadata. + + Args: + additional_metadata: Metadata to merge into existing metadata. + Values in additional_metadata override existing keys. + + Returns: + New QueryResult with merged metadata (existing + additional). + If additional_metadata is None or empty, returns self unchanged. 
+ """ + if not additional_metadata: + return self + + merged = dict(self.metadata) + merged.update(additional_metadata) + + return QueryResult( + id=self.id, + response_output=self.response_output, + metadata=merged, + error=self.error, + ) + class StreamChunk( msgspec.Struct, diff --git a/src/inference_endpoint/dataset_manager/__init__.py b/src/inference_endpoint/dataset_manager/__init__.py index 4bb6c575..12938f8e 100644 --- a/src/inference_endpoint/dataset_manager/__init__.py +++ b/src/inference_endpoint/dataset_manager/__init__.py @@ -21,6 +21,7 @@ from .dataset import Dataset, EmptyDataset from .factory import DataLoaderFactory +from .multi_turn_dataset import MultiTurnDataset from .predefined.aime25 import AIME25 from .predefined.cnndailymail import CNNDailyMail from .predefined.gpqa import GPQA @@ -58,4 +59,5 @@ "CNNDailyMail", "RandomDataset", "ShopifyProductCatalogue", + "MultiTurnDataset", ] diff --git a/src/inference_endpoint/dataset_manager/factory.py b/src/inference_endpoint/dataset_manager/factory.py index 6ed1674a..8c1226c6 100644 --- a/src/inference_endpoint/dataset_manager/factory.py +++ b/src/inference_endpoint/dataset_manager/factory.py @@ -24,6 +24,7 @@ from inference_endpoint.config.schema import Dataset as DatasetConfig from inference_endpoint.dataset_manager.dataset import Dataset, DatasetFormat +from .multi_turn_dataset import MultiTurnDataset from .transforms import ColumnRemap, MakeAdapterCompatible, Transform logger = logging.getLogger(__name__) @@ -95,18 +96,24 @@ def create_loader(config: DatasetConfig, num_repeats: int = 1, **kwargs) -> Data if file_format is not None: format_enum = DatasetFormat(file_format) + dataset_id = None + if config.multi_turn is not None: + dataset_id = MultiTurnDataset.DATASET_ID + transforms: list[Transform] = [] if remap is not None: # Parser convention is {target: source} (e.g. {prompt: article}). # ColumnRemap expects {source: target} — flip it.
flipped = {src: dst for dst, src in remap.items()} transforms.append(ColumnRemap(flipped)) # type: ignore[arg-type] - transforms.append(MakeAdapterCompatible()) + if dataset_id != MultiTurnDataset.DATASET_ID: + transforms.append(MakeAdapterCompatible()) assert dataset_path is not None return Dataset.load_from_file( Path(dataset_path), transforms=transforms, format=format_enum, + dataset_id=dataset_id, num_repeats=num_repeats, ) diff --git a/src/inference_endpoint/dataset_manager/multi_turn_dataset.py b/src/inference_endpoint/dataset_manager/multi_turn_dataset.py new file mode 100644 index 00000000..bff8c011 --- /dev/null +++ b/src/inference_endpoint/dataset_manager/multi_turn_dataset.py @@ -0,0 +1,428 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Multi-turn conversation dataset for conversational AI benchmarking.""" + +from typing import Any + +import pandas as pd + +from ..config.schema import APIType, ModelParams, StreamingMode +from ..exceptions import InputValidationError +from .dataset import Dataset +from .transforms import apply_transforms + +# Known generation parameter fields to forward from dataset to API requests. +# Aligned with OpenAI API specification and openai_msgspec_adapter.py implementation. +# These parameters work in both single-turn and multi-turn modes. 
+GENERATION_PARAMS = { + "model", + "max_new_tokens", + "max_completion_tokens", + "stream", + "temperature", + "top_p", + "top_k", + "seed", + "repetition_penalty", + "frequency_penalty", + "presence_penalty", + "stop", + "n", + "logit_bias", # Token probability adjustments + "name", # Entity name for role (NOT model name, e.g., 'Bob' for tracking) + "user", # End-user identifier for monitoring/abuse detection + "chat_template", # Custom chat formatting template + "tools", # OpenAI tool definitions (list[dict]) for tool-calling models +} + + +def _model_param_defaults(model_params: ModelParams | None) -> dict[str, Any]: + """Build per-request defaults for multi-turn rows from model params. + + Multi-turn datasets use `content` and conversation metadata rather than the + single-turn `prompt` field expected by adapter dataset transforms. Applying + those transforms would drop the conversation schema before load_sample() can + construct the messages array. Instead, we inject the request defaults here. + """ + if model_params is None: + return {} + + return { + "model": model_params.name, + "stream": model_params.streaming == StreamingMode.ON, + "max_completion_tokens": model_params.max_new_tokens, + "temperature": model_params.temperature, + "top_p": model_params.top_p, + "top_k": model_params.top_k, + "repetition_penalty": model_params.repetition_penalty, + } + + +def _expand_tool_results(row: dict) -> list[dict]: + """Expand a tool row into one OpenAI tool message per result. + + All ``role: "tool"`` rows carry a ``tool_results`` array. Each entry expands to + one OpenAI tool message with ``tool_call_id`` and ``content``. + + Returns an empty list if ``tool_results`` is absent or not a list (non-tool rows). 
+ """ + tool_results = row.get("tool_results") + if not isinstance(tool_results, list): + return [] + return [ + { + "role": "tool", + "tool_call_id": result.get("tool_call_id"), + "content": result.get("content"), + } + for result in tool_results + ] + + +class MultiTurnDataset(Dataset, dataset_id="multi_turn_conversations"): + """Dataset for multi-turn conversations. + + Supports conversational AI benchmarking with turn sequencing and conversation history. + Validates that conversations have proper structure (alternating user/assistant roles) + and builds metadata for the scheduler to enforce turn ordering. + + Dataset format (JSONL): + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "...", "system": "..."} + {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "..."} + {"conversation_id": "c1", "turn": 3, "role": "user", "content": "..."} + + Required columns: + - conversation_id: Unique identifier for each conversation + - turn: Turn number within conversation (1-indexed) + - role: Speaker role ("user" or "assistant") + - content: Message content + + Optional columns: + - system: System prompt associated with the conversation (typically set on the first user turn) + - model: Model name override + - max_new_tokens: Max tokens for this turn + + Attributes: + conversation_metadata: Metadata dict containing: + - samples: List of user turn metadata (index, conversation_id, turn, system) + - num_conversations: Total number of unique conversations + - max_turns_per_conv: Maximum turns in any conversation + """ + + COLUMN_NAMES = ["conversation_id", "turn", "role", "content"] + + def __init__(self, dataframe: pd.DataFrame, **kwargs): + """Initialize multi-turn dataset. + + Args: + dataframe: DataFrame with conversation data. + **kwargs: Additional arguments passed to Dataset.__init__. + + Raises: + ValueError: If conversation structure is invalid. 
+ """ + super().__init__(dataframe, **kwargs) + self._validate_conversation_grouping() + self._validate_conversation_structure() + self._validate_turn_numbering() + self.conversation_metadata = self._build_metadata() + self._client_turn_indices: list[int] | None = None + + def _validate_conversation_grouping(self) -> None: + """Validate that all rows for each conversation_id appear consecutively in file order. + + Raises: + InputValidationError: If rows for a conversation_id are interleaved with other conversations. + """ + assert self.dataframe is not None, "Dataframe must be initialized" + seen: set[str] = set() + last_conv: str | None = None + for row in self.dataframe.to_dict(orient="records"): + conv_id = str(row["conversation_id"]) + if conv_id != last_conv: + if conv_id in seen: + raise InputValidationError( + f"Rows for conversation '{conv_id}' are not consecutive. " + "All rows for a conversation must appear together in the file." + ) + seen.add(conv_id) + last_conv = conv_id + + def _validate_conversation_structure(self): + """Validate conversations are well-formed. + + Accepts plain user/assistant alternation as well as tool sequences: + user → assistant → tool → [assistant → tool]* → assistant → user + + Raises: + ValueError: If any conversation has invalid role sequence.
+ """ + assert self.dataframe is not None, "Dataframe must be initialized" + + # Valid state transitions (flat 4-state machine — no assistant_tc node, + # no tool→tool; converter always merges consecutive tool rows into tool_results) + VALID_NEXT: dict[str, set[str]] = { + "start": {"user"}, + "user": {"assistant"}, + "assistant": {"tool", "user"}, + "tool": {"assistant", "user"}, + } + + for conv_id, group in self.dataframe.groupby("conversation_id"): + sorted_group = group.sort_values("turn") + state = "start" + + for _, row in sorted_group.iterrows(): + role = row["role"] + + if role not in VALID_NEXT.get(state, set()): + raise ValueError( + f"Conversation {conv_id} has invalid role sequence at turn " + f"{row['turn']}: got '{role}' after state '{state}'" + ) + state = role + + def _validate_turn_numbering(self): + """Validate turn numbers are consecutive starting at 1. + + Raises: + ValueError: If turn numbers are not exactly 1, 2, 3, …, N. + """ + assert self.dataframe is not None, "Dataframe must be initialized" + + for conv_id, group in self.dataframe.groupby("conversation_id"): + turns = sorted(group["turn"].tolist()) + expected = list(range(1, len(turns) + 1)) + if turns != expected: + raise ValueError( + f"Conversation {conv_id}: Turn numbers must be consecutive starting at 1, " + f"got {turns}" + ) + + def _build_metadata(self) -> dict[str, Any]: + """Build metadata for scheduler (maps sample index to conversation context). + + Pre-computes the complete message list for each client turn so that + conversation history does not need to be accumulated at runtime. + + Returns: + Metadata dict with samples list, num_conversations, max_turns_per_conv, + client_turns_per_conversation, and pre_built_messages_by_key.
+ """ + assert self.dataframe is not None, "Dataframe must be initialized" + samples = [] + client_turns_df = self.dataframe[self.dataframe["role"].isin(["user", "tool"])] + + # Count client turns (user + tool) per conversation for completion tracking + client_turns_per_conv = ( + client_turns_df.groupby("conversation_id").size().to_dict() + ) + + # Map (conversation_id, turn) → complete message list ready to send to endpoint. + # Each entry is: [system (optional)] + all prior rows formatted as messages + # + the current client turn message. + # This includes assistant rows (tool dispatches or terminal responses) + # so no runtime injection is required. + pre_built_messages_by_key: dict[tuple, list[dict]] = {} + + for conv_id, group in self.dataframe.groupby("conversation_id"): + sorted_group = group.sort_values("turn") + client_rows = sorted_group[sorted_group["role"].isin(["user", "tool"])] + + # Extract system prompt from the first row that has it (typically turn 1) + system_content: str | None = None + for _, srow in sorted_group.iterrows(): + val = srow.get("system") + if val and isinstance(val, str): + system_content = val + break + + for idx, row in client_rows.iterrows(): + t_n = int(row["turn"]) + + messages: list[dict] = [] + if system_content: + messages.append({"role": "system", "content": system_content}) + + # All dataset rows strictly before this client turn (includes + # assistant rows and prior tool results). + prior_rows = sorted_group[sorted_group["turn"] < t_n] + for _, prior_row in prior_rows.iterrows(): + msg: dict[str, Any] = {} + for key in ("role", "content", "tool_calls"): + val = prior_row.get(key) + if val is not None and not ( + isinstance(val, float) and pd.isna(val) + ): + msg[key] = val + if msg.get("role"): + # Expand merged parallel tool results: a single row with + # tool_results: [{tool_call_id, content}, ...] expands into + # one OpenAI tool message per result entry.
+ expanded = _expand_tool_results(msg) + if expanded: + messages.extend(expanded) + else: + messages.append(msg) + + # Append the current client turn message. + # A merged parallel-tool row carries tool_results instead of a + # single tool_call_id/content pair; expand to one message per result. + expanded = _expand_tool_results(row) + if expanded: + messages.extend(expanded) + else: + cur: dict[str, Any] = {} + for key in ("role", "content"): + val = row.get(key) + if val is not None and not ( + isinstance(val, float) and pd.isna(val) + ): + cur[key] = val + messages.append(cur) + + pre_built_messages_by_key[(conv_id, t_n)] = messages + + samples.append( + { + "index": idx, + "conversation_id": conv_id, + "turn": t_n, + } + ) + + return { + "samples": samples, + "num_conversations": self.dataframe["conversation_id"].nunique(), + "max_turns_per_conv": self.dataframe.groupby("conversation_id")["turn"] + .max() + .max(), + "client_turns_per_conversation": client_turns_per_conv, + "pre_built_messages_by_key": pre_built_messages_by_key, + } + + def load( + self, + adapter=None, + api_type: APIType | None = None, + model_params: ModelParams | None = None, + force: bool = False, + ): + """Load dataset and build a dense user-turn index. + + Multi-turn benchmarks only issue user turns. Assistant turns remain in the + backing data so the conversation structure can still be validated. + + Unlike single-turn datasets, multi-turn rows do not have a `prompt` + column, so adapter dataset transforms are intentionally skipped here. + They would apply a single-turn ColumnFilter and strip the conversation + fields required by load_sample(). Request defaults from model_params are + merged directly into the conversation rows instead. 
+ """ + if not force and self.data is not None: + self._client_turn_indices = [ + index + for index, row in enumerate(self.data) + if row["role"] in ("user", "tool") + ] + return + + df = self.dataframe + if df is None: + raise ValueError( + f"Cannot load dataset {self.__class__.__name__}: dataframe is None" + ) + + transforms = [] + if self.transforms is not None: + transforms.extend(self.transforms) + + if transforms: + df = apply_transforms(df, transforms) + + defaults = _model_param_defaults(model_params) + for key, value in defaults.items(): + if value is None: + continue + if key in df.columns: + df[key] = df[key].where(pd.notna(df[key]), value) + else: + df[key] = value + + self.data = df.to_dict(orient="records") + assert self.data is not None, "Failed to convert DataFrame to records" + + self._client_turn_indices = [ + index + for index, row in enumerate(self.data) + if row["role"] in ("user", "tool") + ] + + def load_sample(self, index: int) -> dict[str, Any]: + """Load the Nth client turn (user or tool) as a benchmark sample.""" + assert self.data is not None, "Dataset not loaded. Call load() first." + assert ( + self._client_turn_indices is not None + ), "Dataset not loaded. Call load() first." 
+ row = self.data[self._client_turn_indices[index]] + + content_val = row.get("content") + sample: dict[str, Any] = { + "conversation_id": row["conversation_id"], + "turn": row["turn"], + "role": row["role"], + } + if content_val is not None and not ( + isinstance(content_val, float) and pd.isna(content_val) + ): + sample["content"] = content_val + + for param in GENERATION_PARAMS: + if param in row: + value = row[param] + # Skip pandas NaN/None values + if value is not None and ( + not isinstance(value, float) or not pd.isna(value) + ): + sample[param] = value + + # Set defaults for critical params if not present + if "max_new_tokens" not in sample and "max_completion_tokens" not in sample: + sample["max_new_tokens"] = 128 + if "stream" not in sample: + sample["stream"] = False + + # Attach pre-built message list (system + history + current turn). + key = (row["conversation_id"], int(row["turn"])) + pre_built = self.conversation_metadata.get("pre_built_messages_by_key", {}).get( + key, [] + ) + sample["pre_built_messages"] = pre_built + + # Fields for use_dataset_history=False path (live history accumulation). + sample["current_turn_message"] = pre_built[-1] if pre_built else {} + first = pre_built[0] if pre_built else {} + sample["system_content"] = ( + first.get("content") if first.get("role") == "system" else None + ) + + return sample + + def num_samples(self) -> int: + assert ( + self._client_turn_indices is not None + ), "Dataset not loaded. Call load() first." 
+ return len(self._client_turn_indices) diff --git a/src/inference_endpoint/endpoint_client/adapter_protocol.py b/src/inference_endpoint/endpoint_client/adapter_protocol.py index c0239d56..feb590a4 100644 --- a/src/inference_endpoint/endpoint_client/adapter_protocol.py +++ b/src/inference_endpoint/endpoint_client/adapter_protocol.py @@ -19,7 +19,7 @@ import re from abc import ABC, abstractmethod -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from inference_endpoint.core.types import Query, QueryResult @@ -93,24 +93,24 @@ def decode_response(cls, response_bytes: bytes, query_id: str) -> QueryResult: @classmethod @abstractmethod - def decode_sse_message(cls, json_bytes: bytes) -> str: + def decode_sse_message(cls, json_bytes: bytes) -> Any: """ - Decode SSE message and extract content string. + Decode SSE message and return adapter-specific chunk object. Args: json_bytes: Raw JSON bytes from SSE stream Returns: - Content string from the SSE message + Adapter-specific chunk object passed to accumulator.add_chunk() """ raise NotImplementedError("decode_sse_message not implemented") @classmethod - def parse_sse_chunk(cls, buffer: bytes, end_pos: int) -> list[str]: + def parse_sse_chunk(cls, buffer: bytes, end_pos: int) -> list[Any]: """ - Parse SSE chunk and extract all content strings. + Parse SSE chunk and extract all chunk objects. - Extracts JSON documents from SSE stream and decodes them to content strings. + Extracts JSON documents from SSE stream and decodes them to chunk objects. Silently ignores non-content SSE messages (role, finish_reason, etc). 
Args: @@ -118,7 +118,7 @@ def parse_sse_chunk(cls, buffer: bytes, end_pos: int) -> list[str]: end_pos: End position in buffer to parse up to Returns: - List of content strings extracted from the SSE chunk + List of chunk objects extracted from the SSE chunk """ json_docs = cls.SSE_DATA_PATTERN.findall(buffer[:end_pos]) parsed_contents = [] diff --git a/src/inference_endpoint/endpoint_client/http.py b/src/inference_endpoint/endpoint_client/http.py index d9047301..1e67a023 100644 --- a/src/inference_endpoint/endpoint_client/http.py +++ b/src/inference_endpoint/endpoint_client/http.py @@ -792,10 +792,12 @@ class InFlightRequest: query_id: Correlates response back to original Query. http_bytes: Serialized HTTP request for socket.write(). is_streaming: Whether this is a streaming (SSE) request or not. + query_metadata: Internal metadata carried alongside the request. connection: PooledConnection assigned to this request (set once request is fired). """ query_id: str http_bytes: bytes is_streaming: bool + query_metadata: dict[str, object] = field(default_factory=dict) connection: PooledConnection = field(default=None, repr=False) # type: ignore[assignment] diff --git a/src/inference_endpoint/endpoint_client/worker.py b/src/inference_endpoint/endpoint_client/worker.py index 8e0e560e..8fb69fce 100644 --- a/src/inference_endpoint/endpoint_client/worker.py +++ b/src/inference_endpoint/endpoint_client/worker.py @@ -341,6 +341,7 @@ def _prepare_request(self, query: Query) -> InFlightRequest: query_id=query.id, http_bytes=http_bytes, is_streaming=is_streaming, + query_metadata=query.metadata, ) return req @@ -429,7 +430,9 @@ async def _handle_streaming_body(self, req: InFlightRequest) -> None: self._pool.release(conn) # Send final complete back to main rank - self._responses.send(accumulator.get_final_output()) + self._responses.send( + accumulator.get_final_output().with_metadata(req.query_metadata) + ) @profile async def _handle_non_streaming_body(self, req: InFlightRequest) 
-> None: @@ -447,7 +450,7 @@ async def _handle_non_streaming_body(self, req: InFlightRequest) -> None: result = self._adapter.decode_response(response_bytes, query_id) # Send result back to main rank - self._responses.send(result) + self._responses.send(result.with_metadata(req.query_metadata)) async def _handle_error(self, query_id: str, error: Exception | str) -> None: """Send error response for a query.""" diff --git a/src/inference_endpoint/openai/accumulator.py b/src/inference_endpoint/openai/accumulator.py index 6cb23ed8..a01b7b44 100644 --- a/src/inference_endpoint/openai/accumulator.py +++ b/src/inference_endpoint/openai/accumulator.py @@ -15,15 +15,13 @@ """OpenAI SSE stream accumulator implementation.""" -import logging +from typing import Any from inference_endpoint.core.types import QueryResult, StreamChunk, TextModelOutput from inference_endpoint.endpoint_client.accumulator_protocol import ( SSEAccumulatorProtocol, ) -from inference_endpoint.openai.types import SSEDelta as OpenAISSEDelta - -logger = logging.getLogger(__name__) +from inference_endpoint.openai.types import SSEChoice class OpenAISSEAccumulator(SSEAccumulatorProtocol): @@ -32,15 +30,41 @@ class OpenAISSEAccumulator(SSEAccumulatorProtocol): def __init__(self, query_id: str, stream_all_chunks: bool): self.output_chunks: list[str] = [] self.reasoning_chunks: list[str] = [] + self._tool_calls: dict[int, dict[str, Any]] = {} + self._finish_reason: str | None = None self.first_chunk_sent = False self.query_id = query_id self.stream_all_chunks = stream_all_chunks - def add_chunk(self, delta: OpenAISSEDelta) -> StreamChunk | None: - if not isinstance(delta, OpenAISSEDelta): + def add_chunk(self, choice: SSEChoice | None) -> StreamChunk | None: + if not isinstance(choice, SSEChoice): + return None + + if choice.finish_reason: + self._finish_reason = choice.finish_reason + + delta = choice.delta + if delta is None: return None + # Accumulate tool_calls partials (streamed as incremental JSON 
fragments) + if delta.tool_calls: + for partial in delta.tool_calls: + idx = partial.get("index", 0) + tc = self._tool_calls.setdefault( + idx, {"type": "function", "function": {"arguments": ""}} + ) + if partial.get("id"): + tc["id"] = partial["id"] + if partial.get("type"): + tc["type"] = partial["type"] + fn = partial.get("function") or {} + if fn.get("name"): + tc["function"]["name"] = fn["name"] + if fn.get("arguments"): + tc["function"]["arguments"] += fn["arguments"] + content = None if delta.content: self.output_chunks.append(delta.content) @@ -68,9 +92,6 @@ def add_chunk(self, delta: OpenAISSEDelta) -> StreamChunk | None: def get_final_output(self) -> QueryResult: if self.reasoning_chunks: - # If there are reasoning chunks, then the first chunk received - # is the first reasoning chunk. The rest of the reasoning chunks, - # as well as the output chunks can be joined together. resp_reasoning: list[str] = [self.reasoning_chunks[0]] if len(self.reasoning_chunks) > 1: resp_reasoning.append("".join(self.reasoning_chunks[1:])) @@ -79,19 +100,26 @@ def get_final_output(self) -> QueryResult: reasoning=resp_reasoning, ) elif self.output_chunks: - # If there are only output chunks, the first chunk is used for - # TTFT calculations. The rest are joined together. 
resp_output: list[str] = [self.output_chunks[0]] if len(self.output_chunks) > 1: resp_output.append("".join(self.output_chunks[1:])) text_output = TextModelOutput(output=resp_output, reasoning=None) else: text_output = TextModelOutput(output=[], reasoning=None) + + metadata: dict[str, Any] = { + "first_chunk": not self.first_chunk_sent, + "final_chunk": True, + } + if self._finish_reason: + metadata["finish_reason"] = self._finish_reason + if self._tool_calls: + metadata["tool_calls"] = [ + self._tool_calls[i] for i in sorted(self._tool_calls) + ] + return QueryResult( id=self.query_id, response_output=text_output, - metadata={ - "first_chunk": not self.first_chunk_sent, - "final_chunk": True, - }, + metadata=metadata, ) diff --git a/src/inference_endpoint/openai/openai_adapter.py b/src/inference_endpoint/openai/openai_adapter.py index 5834d6b0..4830c682 100644 --- a/src/inference_endpoint/openai/openai_adapter.py +++ b/src/inference_endpoint/openai/openai_adapter.py @@ -36,7 +36,7 @@ Role6, ServiceTier, ) -from .types import SSEMessage +from .types import SSEChoice, SSEMessage class OpenAIAdapter(HttpRequestAdapter): @@ -75,10 +75,12 @@ def decode_response(cls, response_bytes: bytes, query_id: str) -> QueryResult: return cls.from_endpoint_response(openai_response, result_id=query_id) @classmethod - def decode_sse_message(cls, json_bytes: bytes) -> str: - """Decode SSE message and extract content string.""" + def decode_sse_message(cls, json_bytes: bytes) -> SSEChoice | None: + """Decode SSE message and return SSEChoice (delta + finish_reason).""" msg = msgspec.json.decode(json_bytes, type=SSEMessage) - return msg.choices[0].delta + if not msg.choices: + return None + return msg.choices[0] # ======================================================================== # Internal APIs @@ -86,15 +88,21 @@ def decode_sse_message(cls, json_bytes: bytes) -> str: @classmethod def to_endpoint_request(cls, query: Query) -> CreateChatCompletionRequest: - """Convert a Query to an 
OpenAI request.""" - if "prompt" not in query.data: - raise ValueError("prompt not found in query.data") - - messages = [{"role": Role5.user.value, "content": query.data["prompt"]}] - if "system" in query.data: - messages.insert( - 0, {"role": Role3.system.value, "content": query.data["system"]} - ) + """Convert a Query to an OpenAI request. + + Supports both single-turn (prompt/system) and multi-turn (messages array) formats. + """ + if "messages" in query.data and isinstance(query.data["messages"], list): + messages = query.data["messages"] + else: + if "prompt" not in query.data: + raise ValueError("prompt not found in query.data") + + messages = [{"role": Role5.user.value, "content": query.data["prompt"]}] + if "system" in query.data: + messages.insert( + 0, {"role": Role3.system.value, "content": query.data["system"]} + ) request = CreateChatCompletionRequest( model=ModelIdsShared(query.data.get("model", "no-model-name")), diff --git a/src/inference_endpoint/openai/openai_msgspec_adapter.py b/src/inference_endpoint/openai/openai_msgspec_adapter.py index 6106e1bd..e8f15ce6 100644 --- a/src/inference_endpoint/openai/openai_msgspec_adapter.py +++ b/src/inference_endpoint/openai/openai_msgspec_adapter.py @@ -18,6 +18,7 @@ """ import time +from typing import Any import msgspec from inference_endpoint.config.schema import ModelParams, StreamingMode @@ -37,6 +38,7 @@ ChatCompletionResponse, ChatCompletionResponseMessage, ChatMessage, + SSEChoice, SSEMessage, ) @@ -45,6 +47,17 @@ # ============================================================================ +def _chat_message_from_dict(msg: dict) -> "ChatMessage": + """Build a ChatMessage from a dict, forwarding all supported fields.""" + return ChatMessage( + role=msg["role"], + content=msg.get("content"), + name=msg.get("name"), + tool_calls=msg.get("tool_calls"), + tool_call_id=msg.get("tool_call_id"), + ) + + class OpenAIMsgspecAdapter(HttpRequestAdapter): """OpenAI adapter using msgspec for 
serialization/deserialization.""" @@ -105,10 +118,12 @@ def decode_response(cls, response_bytes: bytes, query_id: str) -> QueryResult: return cls.from_endpoint_response(openai_response, result_id=query_id) @classmethod - def decode_sse_message(cls, json_bytes: bytes) -> str: - """Decode SSE message and extract content string.""" + def decode_sse_message(cls, json_bytes: bytes) -> SSEChoice | None: + """Decode SSE message and return the SSEChoice (delta + finish_reason).""" msg = cls._sse_decoder.decode(json_bytes) - return msg.choices[0].delta + if not msg.choices: + return None + return msg.choices[0] # ======================================================================== # Internal APIs @@ -129,24 +144,31 @@ def to_endpoint_request(cls, query: Query) -> ChatCompletionRequest: Returns: msgspec.Struct ChatCompletionRequest """ - if "prompt" not in query.data: - raise ValueError("prompt not found in query.data") - - messages = [ - ChatMessage( - role="user", - content=query.data["prompt"], - name=query.data.get("name"), - ), - ] - if "system" in query.data: - messages.insert( - 0, + if "messages" in query.data and isinstance(query.data["messages"], list): + messages = [] + for message in query.data["messages"]: + if not isinstance(message, dict): + raise ValueError("messages entries must be dicts") + messages.append(_chat_message_from_dict(message)) + else: + if "prompt" not in query.data: + raise ValueError("prompt not found in query.data") + + messages = [ ChatMessage( - role="system", - content=query.data["system"], + role="user", + content=query.data["prompt"], + name=query.data.get("name"), ), - ) + ] + if "system" in query.data: + messages.insert( + 0, + ChatMessage( + role="system", + content=query.data["system"], + ), + ) return ChatCompletionRequest( model=query.data.get("model", "no-model-name"), @@ -164,6 +186,7 @@ def to_endpoint_request(cls, query: Query) -> ChatCompletionRequest: logit_bias=query.data.get("logit_bias"), user=query.data.get("user"), 
chat_template=query.data.get("chat_template"), + tools=query.data.get("tools"), ) @classmethod @@ -184,9 +207,19 @@ def from_endpoint_response( if not response.choices: raise ValueError("Response must contain at least one choice") + choice = response.choices[0] + metadata: dict[str, Any] = {} + if choice.finish_reason: + metadata["finish_reason"] = choice.finish_reason + if choice.message.tool_calls: + metadata["tool_calls"] = choice.message.tool_calls + if choice.message.reasoning_content: + metadata["reasoning_content"] = choice.message.reasoning_content + return QueryResult( id=result_id or response.id, - response_output=TextModelOutput(output=response.choices[0].message.content), + response_output=TextModelOutput(output=choice.message.content or ""), + metadata=metadata if metadata else None, ) @classmethod diff --git a/src/inference_endpoint/openai/types.py b/src/inference_endpoint/openai/types.py index 036dd172..5296b2ee 100644 --- a/src/inference_endpoint/openai/types.py +++ b/src/inference_endpoint/openai/types.py @@ -46,6 +46,7 @@ class SSEDelta(msgspec.Struct, frozen=True, kw_only=True, omit_defaults=True, gc content: str = "" reasoning: str = "" + tool_calls: list[dict[str, Any]] | None = None class SSEChoice( @@ -75,12 +76,17 @@ class ChatMessage( ): # type: ignore[call-arg] """Chat message in OpenAI format. - content: str for text-only messages; list[dict] for multimodal (vision). + content: str for text-only messages; list[dict] for multimodal (vision); + None for tool-dispatching assistant messages. + tool_calls: list of tool call objects for assistant messages that invoke tools. + tool_call_id: correlates a tool result message to its tool call. 
""" role: str - content: ChatMessageContent + content: ChatMessageContent | None = None name: str | None = None + tool_calls: list[dict[str, Any]] | None = None + tool_call_id: str | None = None class ChatCompletionRequest( @@ -103,6 +109,7 @@ class ChatCompletionRequest( logit_bias: dict[str, float] | None = None user: str | None = None chat_template: str | None = None + tools: list[dict[str, Any]] | None = None class ChatCompletionResponseMessage( @@ -113,6 +120,8 @@ class ChatCompletionResponseMessage( role: str content: str | None refusal: str | None + tool_calls: list[dict[str, Any]] | None = None + reasoning_content: str | None = None class ChatCompletionChoice( diff --git a/tests/unit/core/test_types.py b/tests/unit/core/test_types.py index 52bdbe77..9c3bec15 100644 --- a/tests/unit/core/test_types.py +++ b/tests/unit/core/test_types.py @@ -891,3 +891,57 @@ def test_numeric_types_in_metadata(self): assert decoded.metadata["large_int"] == 9999999999999999 assert decoded.metadata["negative"] == -123.456 assert decoded.metadata["zero"] == 0 + + +@pytest.mark.unit +class TestQueryResultWithMetadata: + """Test QueryResult.with_metadata() method for metadata merging.""" + + def test_with_metadata_merge_behavior(self): + """Test that with_metadata adds new keys and overwrites existing ones.""" + result = QueryResult( + id="test", + response_output=TextModelOutput(output="hello"), + metadata={"key1": "old_value", "key2": "keep_me"}, + ) + + updated = result.with_metadata({"key1": "new_value", "key3": "added"}) + + assert updated.metadata == { + "key1": "new_value", + "key2": "keep_me", + "key3": "added", + } + assert updated.id == "test" + assert updated.response_output == TextModelOutput(output="hello") + + def test_with_metadata_none_returns_self(self): + """Test that with_metadata(None) returns self unchanged.""" + result = QueryResult( + id="test", + response_output=TextModelOutput(output="hello"), + metadata={"key1": "value"}, + ) + assert 
result.with_metadata(None) is result + + def test_with_metadata_empty_returns_self(self): + """Test that with_metadata({}) returns self unchanged.""" + result = QueryResult( + id="test", + response_output=TextModelOutput(output="hello"), + metadata={"key1": "value"}, + ) + assert result.with_metadata({}) is result + + def test_query_metadata_field_roundtrips(self): + """Test that Query.metadata round-trips through msgspec encoding.""" + query = Query( + data={"prompt": "Hello"}, + metadata={"conversation_id": "conv-1", "turn": 2}, + ) + + encoded = msgspec.json.encode(query) + decoded = msgspec.json.decode(encoded, type=Query) + + assert decoded.metadata["conversation_id"] == "conv-1" + assert decoded.metadata["turn"] == 2 diff --git a/tests/unit/dataset_manager/test_multi_turn_dataset.py b/tests/unit/dataset_manager/test_multi_turn_dataset.py new file mode 100644 index 00000000..09ccd224 --- /dev/null +++ b/tests/unit/dataset_manager/test_multi_turn_dataset.py @@ -0,0 +1,1073 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json +import tempfile +from collections.abc import Generator +from pathlib import Path + +import pandas as pd +import pytest +from inference_endpoint.dataset_manager.dataset import DatasetFormat +from inference_endpoint.dataset_manager.multi_turn_dataset import MultiTurnDataset + + +@pytest.fixture +def valid_multi_turn_jsonl() -> Generator[str, None, None]: + """Create valid multi-turn conversation JSONL data.""" + data = [ + { + "conversation_id": "conv_001", + "turn": 1, + "role": "user", + "content": "Hello, how are you?", + "system": "You are a helpful assistant", + }, + { + "conversation_id": "conv_001", + "turn": 2, + "role": "assistant", + "content": "I'm doing well, thank you!", + }, + { + "conversation_id": "conv_001", + "turn": 3, + "role": "user", + "content": "What can you help me with?", + }, + { + "conversation_id": "conv_002", + "turn": 1, + "role": "user", + "content": "What's the weather?", + }, + { + "conversation_id": "conv_002", + "turn": 2, + "role": "assistant", + "content": "I don't have access to real-time weather data.", + }, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + yield temp_path + Path(temp_path).unlink() + + +@pytest.fixture +def invalid_role_sequence_jsonl() -> Generator[str, None, None]: + """Create JSONL with invalid role sequence (not alternating).""" + data = [ + {"conversation_id": "conv_001", "turn": 1, "role": "user", "content": "Hello"}, + { + "conversation_id": "conv_001", + "turn": 2, + "role": "user", + "content": "Another user message", + }, # Invalid - consecutive user + { + "conversation_id": "conv_001", + "turn": 3, + "role": "assistant", + "content": "Response", + }, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + yield temp_path + Path(temp_path).unlink() + + 
+@pytest.fixture +def missing_fields_jsonl() -> Generator[str, None, None]: + """Create JSONL with missing required fields.""" + data = [ + {"conversation_id": "conv_001", "turn": 1, "role": "user"}, # Missing content + { + "conversation_id": "conv_001", + "turn": 2, + "role": "assistant", + "content": "Response", + }, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + yield temp_path + Path(temp_path).unlink() + + +@pytest.mark.unit +def test_multi_turn_dataset_load_valid_data(valid_multi_turn_jsonl): + """Test loading valid multi-turn conversation data.""" + dataset = MultiTurnDataset.load_from_file( + valid_multi_turn_jsonl, format=DatasetFormat.JSONL + ) + dataset.load() + + # Should have 5 rows total (3 for conv_001, 2 for conv_002) + assert len(dataset.data) == 5 + + # Should have 3 user turns (samples) - only user turns are indexed + assert dataset.num_samples() == 3 + + +@pytest.mark.unit +def test_multi_turn_dataset_user_turn_indexing(valid_multi_turn_jsonl): + """Test that only client turns (user + tool) are indexed as samples.""" + dataset = MultiTurnDataset.load_from_file( + valid_multi_turn_jsonl, format=DatasetFormat.JSONL + ) + dataset.load() + + # Verify client turn indices are correct (fixture has only user turns) + assert len(dataset._client_turn_indices) == 3 + + # Check that indices point to client turns + for idx in dataset._client_turn_indices: + assert dataset.data[idx]["role"] in ("user", "tool") + + +@pytest.mark.unit +def test_multi_turn_dataset_load_sample(valid_multi_turn_jsonl): + """Test load_sample returns correct user turns with dense indexing.""" + dataset = MultiTurnDataset.load_from_file( + valid_multi_turn_jsonl, format=DatasetFormat.JSONL + ) + dataset.load() + + # Sample 0 should be first user turn + sample_0 = dataset.load_sample(0) + assert sample_0["conversation_id"] == "conv_001" + assert sample_0["turn"] == 
1 + assert sample_0["role"] == "user" + assert sample_0["content"] == "Hello, how are you?" + # System prompt is in pre_built_messages, not as a separate field + assert sample_0["pre_built_messages"][0]["role"] == "system" + assert sample_0["pre_built_messages"][0]["content"] == "You are a helpful assistant" + + # Sample 1 should be second user turn (conv_001 turn 3) + sample_1 = dataset.load_sample(1) + assert sample_1["conversation_id"] == "conv_001" + assert sample_1["turn"] == 3 + assert sample_1["role"] == "user" + assert sample_1["content"] == "What can you help me with?" + + # Sample 2 should be third user turn (conv_002 turn 1) + sample_2 = dataset.load_sample(2) + assert sample_2["conversation_id"] == "conv_002" + assert sample_2["turn"] == 1 + assert sample_2["role"] == "user" + assert sample_2["content"] == "What's the weather?" + + +@pytest.mark.unit +def test_multi_turn_dataset_conversation_metadata(valid_multi_turn_jsonl): + """Test conversation metadata generation.""" + dataset = MultiTurnDataset.load_from_file( + valid_multi_turn_jsonl, format=DatasetFormat.JSONL + ) + dataset.load() + + metadata = dataset.conversation_metadata + + # Check metadata structure + assert "samples" in metadata + assert "num_conversations" in metadata + assert "max_turns_per_conv" in metadata + assert "client_turns_per_conversation" in metadata + + # Should have 3 client turn samples (fixture has only user turns, no tool turns) + assert len(metadata["samples"]) == 3 + + # Should have 2 conversations + assert metadata["num_conversations"] == 2 + + # Max turns per conversation should be 3 (conv_001 has 3 turns) + assert metadata["max_turns_per_conv"] == 3 + + # Check sample metadata structure + sample_meta = metadata["samples"][0] + assert "index" in sample_meta + assert "conversation_id" in sample_meta + assert "turn" in sample_meta + + +@pytest.mark.unit +def test_multi_turn_dataset_validation_invalid_role_sequence( + invalid_role_sequence_jsonl, +): + """Test validation 
rejects invalid role sequences.""" + # Validation happens during load_from_file (in __init__), not during load() + with pytest.raises(ValueError, match="invalid role sequence"): + MultiTurnDataset.load_from_file( + invalid_role_sequence_jsonl, format=DatasetFormat.JSONL + ) + + +@pytest.mark.unit +def test_multi_turn_dataset_validation_missing_fields(missing_fields_jsonl): + """A row without a content field yields a sample that has no content key.""" + dataset = MultiTurnDataset.load_from_file( + missing_fields_jsonl, format=DatasetFormat.JSONL + ) + dataset.load() + + sample = dataset.load_sample(0) + # Missing content is no longer propagated to the sample dict + assert "content" not in sample + + +@pytest.mark.unit +def test_multi_turn_dataset_multiple_conversations(): + """Test dataset with multiple conversations of varying lengths.""" + data = [ + # Conversation 1: 3 turns (user-assistant-user, missing final assistant) + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "msg1"}, + {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "resp1"}, + {"conversation_id": "c1", "turn": 3, "role": "user", "content": "msg1b"}, + # Conversation 2: 4 turns (complete user-assistant alternation) + {"conversation_id": "c2", "turn": 1, "role": "user", "content": "msg2"}, + {"conversation_id": "c2", "turn": 2, "role": "assistant", "content": "resp2"}, + {"conversation_id": "c2", "turn": 3, "role": "user", "content": "msg3"}, + {"conversation_id": "c2", "turn": 4, "role": "assistant", "content": "resp3"}, + # Conversation 3: 2 turns (complete user-assistant) + {"conversation_id": "c3", "turn": 1, "role": "user", "content": "msg4"}, + {"conversation_id": "c3", "turn": 2, "role": "assistant", "content": "resp4"}, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + try: + dataset = MultiTurnDataset.load_from_file(temp_path, 
format=DatasetFormat.JSONL) + dataset.load() + + # 9 total rows, 5 user turns (c1:t1, c1:t3, c2:t1, c2:t3, c3:t1) + assert len(dataset.data) == 9 + assert dataset.num_samples() == 5 + + # Metadata checks + metadata = dataset.conversation_metadata + assert metadata["num_conversations"] == 3 + assert metadata["max_turns_per_conv"] == 4 # c2 has 4 turns + + # Verify user turns are correctly indexed + samples = [dataset.load_sample(i) for i in range(5)] + + # Check we got all the user turns + user_turns = [(s["conversation_id"], s["turn"]) for s in samples] + expected_turns = [("c1", 1), ("c1", 3), ("c2", 1), ("c2", 3), ("c3", 1)] + assert sorted(user_turns) == sorted(expected_turns) + + finally: + Path(temp_path).unlink() + + +@pytest.mark.unit +def test_multi_turn_dataset_system_prompt_handling(valid_multi_turn_jsonl): + """Test system prompt is included as the first message in pre_built_messages. + + The system prompt is pre-baked into every client turn's message list so the + conversation manager no longer needs to track it separately. 
+ """ + dataset = MultiTurnDataset.load_from_file( + valid_multi_turn_jsonl, format=DatasetFormat.JSONL + ) + dataset.load() + + # First sample: pre_built_messages starts with system message + sample_0 = dataset.load_sample(0) + assert "pre_built_messages" in sample_0 + msgs = sample_0["pre_built_messages"] + assert msgs[0]["role"] == "system" + assert msgs[0]["content"] == "You are a helpful assistant" + + # Second sample (same conversation, turn 3): system message still first + sample_1 = dataset.load_sample(1) + msgs_1 = sample_1["pre_built_messages"] + assert msgs_1[0]["role"] == "system" + assert msgs_1[0]["content"] == "You are a helpful assistant" + + +@pytest.mark.unit +def test_multi_turn_dataset_single_turn_conversations(): + """Test conversations with only one turn.""" + data = [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "Single turn"}, + # No assistant response + { + "conversation_id": "c2", + "turn": 1, + "role": "user", + "content": "Another single", + }, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + try: + dataset = MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL) + dataset.load() + + # 2 rows, 2 user turns + assert len(dataset.data) == 2 + assert dataset.num_samples() == 2 + + # Both samples should be user turns + assert dataset.load_sample(0)["role"] == "user" + assert dataset.load_sample(1)["role"] == "user" + + finally: + Path(temp_path).unlink() + + +@pytest.mark.unit +def test_multi_turn_dataset_empty_conversation(): + """Empty JSONL file raises ValueError (no columns to validate against).""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + temp_path = f.name + + try: + with pytest.raises(ValueError): + MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL) + finally: + Path(temp_path).unlink() + + +@pytest.mark.unit +def 
test_multi_turn_dataset_conversation_grouping(): + """Test that properly grouped conversations load correctly.""" + data = [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "c1t1"}, + {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "c1t2"}, + {"conversation_id": "c1", "turn": 3, "role": "user", "content": "c1t3"}, + {"conversation_id": "c2", "turn": 1, "role": "user", "content": "c2t1"}, + {"conversation_id": "c2", "turn": 2, "role": "assistant", "content": "c2t2"}, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + try: + dataset = MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL) + dataset.load() + + # 5 total rows, 3 user turns (c1t1, c1t3, c2t1) + assert len(dataset.data) == 5 + assert dataset.num_samples() == 3 + + # Load samples to verify conversation grouping + samples = [dataset.load_sample(i) for i in range(3)] + + # Verify conversation IDs + conv_ids = [s["conversation_id"] for s in samples] + assert conv_ids == ["c1", "c1", "c2"] + + finally: + Path(temp_path).unlink() + + +@pytest.mark.unit +def test_multi_turn_dataset_interleaved_conversations_rejected(): + """Test that interleaved conversation rows raise InputValidationError.""" + from inference_endpoint.exceptions import InputValidationError + + data = [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "c1t1"}, + {"conversation_id": "c2", "turn": 1, "role": "user", "content": "c2t1"}, + {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "c1t2"}, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + try: + with pytest.raises(InputValidationError, match="not consecutive"): + MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL) + finally: + Path(temp_path).unlink() 
+ + +@pytest.mark.unit +@pytest.mark.parametrize( + "rows", + [ + # assistant-first + [ + {"conversation_id": "c1", "turn": 1, "role": "assistant", "content": "A"}, + {"conversation_id": "c1", "turn": 2, "role": "user", "content": "B"}, + ], + # consecutive assistants + [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "A"}, + {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "B"}, + {"conversation_id": "c1", "turn": 3, "role": "assistant", "content": "C"}, + ], + # tool directly after user (tool-before-assistant) + [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "A"}, + { + "conversation_id": "c1", + "turn": 2, + "role": "tool", + "tool_results": [{"tool_call_id": "x", "content": "r"}], + }, + ], + # consecutive users + [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "A"}, + {"conversation_id": "c1", "turn": 2, "role": "user", "content": "B"}, + ], + ], +) +def test_validation_rejects_invalid_role_sequence(rows): + """Invalid role sequences raise ValueError regardless of turn numbering.""" + with pytest.raises(ValueError, match="invalid role sequence"): + MultiTurnDataset(pd.DataFrame(rows)) + + +@pytest.mark.unit +def test_multi_turn_dataset_additional_fields(): + """Test that additional fields (model, max_new_tokens, etc.) 
are preserved.""" + data = [ + { + "conversation_id": "c1", + "turn": 1, + "role": "user", + "content": "Hello", + "model": "gpt-4", + "max_new_tokens": 256, + "temperature": 0.7, + }, + {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "Hi"}, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + try: + dataset = MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL) + dataset.load() + + sample = dataset.load_sample(0) + # Fields may or may not be present depending on how dataframe handles them + # Just check they're accessible if present + if "model" in sample: + assert sample["model"] == "gpt-4" + if "max_new_tokens" in sample: + assert sample["max_new_tokens"] == 256 + if "temperature" in sample: + assert sample["temperature"] == pytest.approx(0.7) + + finally: + Path(temp_path).unlink() + + +@pytest.mark.unit +def test_multi_turn_dataset_openai_field_forwarding(): + """Test that OpenAI-specific fields are preserved and forwarded.""" + data = [ + { + "conversation_id": "c1", + "turn": 1, + "role": "user", + "content": "Hello", + # OpenAI fields that should be forwarded + "n": 3, + "name": "Alice", + "user": "user_12345", + "logit_bias": {"50256": -100}, + "chat_template": "custom_template", + }, + {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "Hi"}, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + try: + dataset = MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL) + dataset.load() + + sample = dataset.load_sample(0) + + # Verify OpenAI fields are present + assert sample.get("n") == 3 + assert sample.get("name") == "Alice" + assert sample.get("user") == "user_12345" + assert sample.get("logit_bias") == {"50256": -100} + assert sample.get("chat_template") == 
"custom_template" + finally: + Path(temp_path).unlink() + + +@pytest.mark.unit +def test_multi_turn_dataset_all_generation_params(): + """Test that all generation parameters in GENERATION_PARAMS are forwarded.""" + from inference_endpoint.dataset_manager.multi_turn_dataset import GENERATION_PARAMS + + # Create dataset with all possible generation params + data = [ + { + "conversation_id": "c1", + "turn": 1, + "role": "user", + "content": "Test", + # Include all params from GENERATION_PARAMS + "model": "test-model", + "max_new_tokens": 100, + "max_completion_tokens": 100, + "stream": True, + "temperature": 0.8, + "top_p": 0.95, + "top_k": 50, + "seed": 42, + "repetition_penalty": 1.1, + "frequency_penalty": 0.5, + "presence_penalty": 0.3, + "stop": ["END"], + "n": 2, + "logit_bias": {"100": 10}, + "name": "TestEntity", + "user": "test_user_001", + "chat_template": "test_template", + }, + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant", + "content": "Response", + }, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + try: + dataset = MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL) + dataset.load() + + sample = dataset.load_sample(0) + + # Verify all GENERATION_PARAMS fields are forwarded + # (excluding conversational fields like conversation_id, turn, role, content, system) + for param in GENERATION_PARAMS: + if param in data[0]: + assert ( + param in sample + ), f"Generation parameter '{param}' not forwarded to sample" + finally: + Path(temp_path).unlink() + + +@pytest.mark.unit +def test_validation_rejects_non_contiguous_turns(): + """Turn numbers must be consecutive; gaps are rejected.""" + rows = [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "a"}, + {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "b"}, + {"conversation_id": "c1", "turn": 5, "role": "user", "content": 
"c"}, + {"conversation_id": "c1", "turn": 6, "role": "assistant", "content": "d"}, + ] + with pytest.raises(ValueError, match="consecutive"): + MultiTurnDataset(pd.DataFrame(rows)) + + +@pytest.mark.unit +def test_validation_rejects_turns_not_starting_at_one(): + """Validation should reject conversations whose turns don't start at 1.""" + data = [ + {"conversation_id": "c1", "turn": 3, "role": "user", "content": "msg"}, + {"conversation_id": "c1", "turn": 4, "role": "assistant", "content": "resp"}, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + try: + with pytest.raises(ValueError, match="consecutive"): + MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL) + finally: + Path(temp_path).unlink() + + +@pytest.mark.unit +def test_validation_accepts_valid_contiguous_turns(): + """Validation should accept contiguous turn sequences.""" + data = [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "msg1"}, + {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "resp1"}, + {"conversation_id": "c1", "turn": 3, "role": "user", "content": "msg2"}, + {"conversation_id": "c1", "turn": 4, "role": "assistant", "content": "resp2"}, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + try: + dataset = MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL) + dataset.load() + assert dataset.num_samples() == 2 + finally: + Path(temp_path).unlink() + + +@pytest.mark.unit +def test_validation_rejects_turn_starting_at_zero(): + """Validation should reject conversations starting at turn 0.""" + data = [ + {"conversation_id": "c1", "turn": 0, "role": "user", "content": "msg"}, + {"conversation_id": "c1", "turn": 1, "role": "assistant", "content": "resp"}, + ] + + with 
tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + try: + with pytest.raises(ValueError, match="consecutive"): + MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL) + finally: + Path(temp_path).unlink() + + +@pytest.mark.unit +def test_validation_rejects_duplicate_turn_numbers(): + """Duplicate turn numbers within a conversation are rejected.""" + data = [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "msg1"}, + {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "resp1"}, + {"conversation_id": "c2", "turn": 1, "role": "user", "content": "msg2"}, + {"conversation_id": "c2", "turn": 2, "role": "assistant", "content": "resp2"}, + # c2 has duplicate turn 2 — a user row that reuses turn number 2 + {"conversation_id": "c2", "turn": 2, "role": "user", "content": "dup"}, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + try: + with pytest.raises(ValueError, match="consecutive"): + MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL) + finally: + Path(temp_path).unlink() + + +@pytest.mark.unit +def test_validation_rejects_assistant_tc_role_literal(): + """role='assistant_tc' literal in dataset is rejected; only 'assistant' is valid.""" + rows = [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "Q"}, + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant_tc", + "tool_calls": [ + { + "id": "c0", + "type": "function", + "function": {"name": "f", "arguments": "{}"}, + } + ], + }, + { + "conversation_id": "c1", + "turn": 3, + "role": "tool", + "tool_results": [{"tool_call_id": "c0", "content": "r"}], + }, + {"conversation_id": "c1", "turn": 4, "role": "assistant", "content": "A"}, + ] + with pytest.raises(ValueError, match="invalid role sequence"): + 
MultiTurnDataset(pd.DataFrame(rows)) + + +# ============================================================================ +# Tool sequence tests +# ============================================================================ + + +def _make_tool_sequence_df(): + """Return a DataFrame with a tool sequence embedded between user turns.""" + return pd.DataFrame( + [ + { + "conversation_id": "c1", + "turn": 1, + "role": "user", + "content": "What is the weather?", + "system": "Be helpful", + }, + # assistant (with tool_calls): dispatches a tool call + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_c1_0", + "type": "function", + "function": {"name": "get_weather", "arguments": "{}"}, + } + ], + }, + # tool result + { + "conversation_id": "c1", + "turn": 3, + "role": "tool", + "tool_results": [ + {"tool_call_id": "call_c1_0", "content": '{"temp": 22}'} + ], + }, + # terminal assistant + { + "conversation_id": "c1", + "turn": 4, + "role": "assistant", + "content": "The weather is 22Β°C.", + }, + # second user turn + { + "conversation_id": "c1", + "turn": 5, + "role": "user", + "content": "Thanks!", + }, + ] + ) + + +@pytest.mark.unit +def test_validation_accepts_tool_sequence(): + """user → assistant → tool → assistant → user passes validation.""" + df = _make_tool_sequence_df() + ds = MultiTurnDataset(df) + ds.load() + assert ds.num_samples() == 3 # user(1), tool(3), user(5) are all client turns + + +@pytest.mark.unit +def test_validation_accepts_parallel_tool_calls(): + """Assistant with two tool_calls + merged tool_results row passes.""" + df = pd.DataFrame( + [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "Hi"}, + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "c_0", + "type": "function", + "function": {"name": "f1", "arguments": "{}"}, + }, + { + "id": "c_1", + "type": "function", + "function": {"name": 
"f2", "arguments": "{}"}, + }, + ], + }, + { + "conversation_id": "c1", + "turn": 3, + "role": "tool", + "tool_results": [ + {"tool_call_id": "c_0", "content": "r1"}, + {"tool_call_id": "c_1", "content": "r2"}, + ], + }, + { + "conversation_id": "c1", + "turn": 4, + "role": "assistant", + "content": "Done", + }, + ] + ) + ds = MultiTurnDataset(df) + ds.load() + assert ds.num_samples() == 2 # user(1), tool(3) are client turns + + +@pytest.mark.unit +def test_load_sample_merged_tool_row_has_no_content_key(): + """load_sample for a merged tool_results row must not emit content: NaN.""" + df = pd.DataFrame( + [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "Go"}, + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "c_0", + "type": "function", + "function": {"name": "f1", "arguments": "{}"}, + }, + { + "id": "c_1", + "type": "function", + "function": {"name": "f2", "arguments": "{}"}, + }, + ], + }, + { + "conversation_id": "c1", + "turn": 3, + "role": "tool", + "tool_results": [ + {"tool_call_id": "c_0", "content": "r1"}, + {"tool_call_id": "c_1", "content": "r2"}, + ], + }, + { + "conversation_id": "c1", + "turn": 4, + "role": "assistant", + "content": "Done", + }, + ] + ) + ds = MultiTurnDataset(df) + ds.load() + + # Sample 1 is the merged tool row (turn 3) + s1 = ds.load_sample(1) + assert s1["role"] == "tool" + assert "content" not in s1 # must NOT emit NaN + assert "pre_built_messages" in s1 + + +@pytest.mark.unit +def test_build_metadata_pre_built_messages(): + """pre_built_messages_by_key contains complete message arrays for each client turn. 
+ + Dataset: + turn 1: user ← client turn 1 + turn 2: asst_tc ← scripted (assistant with tool_calls) + turn 3: tool ← client turn 2 + turn 4: assistant ← terminal assistant + turn 5: user ← client turn 3 + + Expected pre_built_messages: + client turn 1 (t=1): [system, user(1)] + client turn 2 (t=3): [system, user(1), asst_tc(2), tool(3)] + client turn 3 (t=5): [system, user(1), asst_tc(2), tool(3), asst(4), user(5)] + """ + df = _make_tool_sequence_df() + ds = MultiTurnDataset(df) + + pbm = ds.conversation_metadata["pre_built_messages_by_key"] + + # Client turn 1 (user, t=1): [system, user(1)] + msgs_t1 = pbm[("c1", 1)] + assert len(msgs_t1) == 2 + assert msgs_t1[0] == {"role": "system", "content": "Be helpful"} + assert msgs_t1[1] == {"role": "user", "content": "What is the weather?"} + + # Client turn 2 (tool, t=3): [system, user(1), asst_tc(2), tool(3)] + msgs_t3 = pbm[("c1", 3)] + assert len(msgs_t3) == 4 + assert msgs_t3[0]["role"] == "system" + assert msgs_t3[1]["role"] == "user" + assert msgs_t3[2]["role"] == "assistant" + assert "tool_calls" in msgs_t3[2] + assert msgs_t3[3]["role"] == "tool" + assert msgs_t3[3]["content"] == '{"temp": 22}' + assert msgs_t3[3]["tool_call_id"] == "call_c1_0" + + # Client turn 3 (user, t=5): [system, user(1), asst_tc(2), tool(3), asst(4), user(5)] + msgs_t5 = pbm[("c1", 5)] + assert len(msgs_t5) == 6 + assert msgs_t5[4] == {"role": "assistant", "content": "The weather is 22Β°C."} + assert msgs_t5[5] == {"role": "user", "content": "Thanks!"} + + +@pytest.mark.unit +def test_build_metadata_pre_built_messages_no_tools(): + """Plain user/assistant alternation produces correct pre_built_messages.""" + df = pd.DataFrame( + [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "A"}, + {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "B"}, + {"conversation_id": "c1", "turn": 3, "role": "user", "content": "C"}, + ] + ) + ds = MultiTurnDataset(df) + pbm = 
ds.conversation_metadata["pre_built_messages_by_key"] + + # Turn 1: just the user message (no system, no prior rows) + assert pbm[("c1", 1)] == [{"role": "user", "content": "A"}] + + # Turn 3: user(1) + assistant(2) + user(3) + msgs = pbm[("c1", 3)] + assert len(msgs) == 3 + assert msgs[0] == {"role": "user", "content": "A"} + assert msgs[1] == {"role": "assistant", "content": "B"} + assert msgs[2] == {"role": "user", "content": "C"} + + +@pytest.mark.unit +def test_load_sample_includes_pre_built_messages(): + """load_sample returns pre_built_messages with the complete message list.""" + df = _make_tool_sequence_df() + ds = MultiTurnDataset(df) + ds.load() + + s0 = ds.load_sample(0) # user turn 1 + assert "pre_built_messages" in s0 + msgs = s0["pre_built_messages"] + assert msgs[0]["role"] == "system" + assert msgs[-1] == {"role": "user", "content": "What is the weather?"} + + s1 = ds.load_sample(1) # tool turn 3 + assert s1["role"] == "tool" + msgs_t3 = s1["pre_built_messages"] + # system + user(1) + asst_tc(2) + tool(3) = 4 messages + assert len(msgs_t3) == 4 + assert msgs_t3[-1]["role"] == "tool" + + s2 = ds.load_sample(2) # user turn 5 + msgs_t5 = s2["pre_built_messages"] + # system + user(1) + asst_tc(2) + tool(3) + asst(4) + user(5) = 6 messages + assert len(msgs_t5) == 6 + + +@pytest.mark.unit +def test_client_turns_include_tool_rows(): + """Tool rows are counted in num_samples() as client turns.""" + df = _make_tool_sequence_df() + ds = MultiTurnDataset(df) + ds.load() + # 5 rows total: user(1), assistant(2), tool(3), assistant(4), user(5) + # Client turns: user(1), tool(3), user(5) β†’ 3 + assert ds.num_samples() == 3 + + +# ============================================================================ +# Pre-built messages content correctness +# ============================================================================ + + +@pytest.mark.unit +def test_pre_built_messages_include_prior_assistant_response(valid_multi_turn_jsonl): + """The terminal assistant 
response before each user turn is included in pre_built_messages.""" + dataset = MultiTurnDataset.load_from_file( + valid_multi_turn_jsonl, format=DatasetFormat.JSONL + ) + dataset.load() + + # Sample 0: turn 1 (first user) β†’ just [system, user(1)] + s0 = dataset.load_sample(0) + msgs_0 = s0["pre_built_messages"] + assert msgs_0[0]["role"] == "system" + assert msgs_0[-1]["role"] == "user" + + # Sample 1: turn 3 (second user) β†’ [system, user(1), assistant(2), user(3)] + s1 = dataset.load_sample(1) + msgs_1 = s1["pre_built_messages"] + assert len(msgs_1) == 4 + assert msgs_1[2] == {"role": "assistant", "content": "I'm doing well, thank you!"} + assert msgs_1[3]["role"] == "user" + + # Sample 2: turn 1 of conv_002 β†’ no prior assistant row + s2 = dataset.load_sample(2) + msgs_2 = s2["pre_built_messages"] + assert all(m["role"] != "assistant" for m in msgs_2) + + +@pytest.mark.unit +def test_pre_built_messages_no_cross_conversation_bleed(): + """Messages for conv_001 must not appear in conv_002's pre_built_messages.""" + data = [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "c1 user"}, + {"conversation_id": "c2", "turn": 1, "role": "user", "content": "c2 user"}, + {"conversation_id": "c2", "turn": 2, "role": "assistant", "content": "c2 resp"}, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + try: + dataset = MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL) + dataset.load() + + # c1: only its own user message + s_c1 = dataset.load_sample(0) + assert s_c1["pre_built_messages"] == [{"role": "user", "content": "c1 user"}] + + # c2: only c2 messages (no c1 content) + s_c2 = dataset.load_sample(1) + contents = [m.get("content") for m in s_c2["pre_built_messages"]] + assert "c1 user" not in contents + finally: + Path(temp_path).unlink() + + +@pytest.mark.unit +def 
test_pre_built_messages_with_tool_sequence_terminal_assistant(): + """Terminal assistant response (turn 4) appears in pre_built_messages for user(5).""" + df = _make_tool_sequence_df() + ds = MultiTurnDataset(df) + ds.load() + + s2 = ds.load_sample(2) # user turn 5 + msgs = s2["pre_built_messages"] + # The terminal assistant at turn 4 should be included + assistant_msgs = [m for m in msgs if m["role"] == "assistant" and m.get("content")] + assert any(m["content"] == "The weather is 22Β°C." for m in assistant_msgs) diff --git a/tests/unit/openai/test_msgspec_adapter.py b/tests/unit/openai/test_msgspec_adapter.py new file mode 100644 index 00000000..8127d199 --- /dev/null +++ b/tests/unit/openai/test_msgspec_adapter.py @@ -0,0 +1,146 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for OpenAIMsgspecAdapter with tool call fields.""" + +import json + +import msgspec +import pytest +from inference_endpoint.core.types import Query +from inference_endpoint.openai.openai_msgspec_adapter import ( + OpenAIMsgspecAdapter, + _chat_message_from_dict, +) +from inference_endpoint.openai.types import ChatMessage + + +@pytest.mark.unit +def test_chat_message_tool_calls_serialised(): + """tool_calls field is included in the JSON output when non-None.""" + tool_calls = [ + { + "id": "call_0", + "type": "function", + "function": {"name": "get_weather", "arguments": "{}"}, + } + ] + msg = ChatMessage(role="assistant", tool_calls=tool_calls) + encoded = msgspec.json.encode(msg) + decoded = json.loads(encoded) + assert decoded["role"] == "assistant" + assert decoded["tool_calls"] == tool_calls + assert "content" not in decoded # omit_defaults=True, None omitted + + +@pytest.mark.unit +def test_chat_message_tool_call_id_serialised(): + """tool_call_id field is included in the JSON output when non-None.""" + msg = ChatMessage(role="tool", content="result", tool_call_id="call_0") + encoded = msgspec.json.encode(msg) + decoded = json.loads(encoded) + assert decoded["role"] == "tool" + assert decoded["content"] == "result" + assert decoded["tool_call_id"] == "call_0" + + +@pytest.mark.unit +def test_to_endpoint_request_preserves_tool_calls(): + """to_endpoint_request forwards tool_calls in the messages array.""" + tool_calls = [ + { + "id": "call_0", + "type": "function", + "function": {"name": "lookup", "arguments": '{"q": "test"}'}, + } + ] + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": None, "tool_calls": tool_calls}, + {"role": "tool", "content": "answer", "tool_call_id": "call_0"}, + {"role": "assistant", "content": "Done"}, + ] + query = Query( + id="q1", + data={ + "model": "test-model", + "messages": messages, + }, + ) + request = OpenAIMsgspecAdapter.to_endpoint_request(query) + encoded = 
msgspec.json.encode(request) + payload = json.loads(encoded) + + msgs = payload["messages"] + # assistant tool-dispatch row + assert msgs[1]["role"] == "assistant" + assert msgs[1]["tool_calls"] == tool_calls + assert "content" not in msgs[1] + # tool result row + assert msgs[2]["role"] == "tool" + assert msgs[2]["tool_call_id"] == "call_0" + assert msgs[2]["content"] == "answer" + # terminal assistant row + assert msgs[3]["content"] == "Done" + + +@pytest.mark.unit +def test_backward_compat_plain_messages_unchanged(): + """Plain user/assistant messages encode identically to before the change.""" + messages = [ + {"role": "system", "content": "You are helpful"}, + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello!"}, + ] + query = Query( + id="q2", + data={"model": "m", "messages": messages}, + ) + request = OpenAIMsgspecAdapter.to_endpoint_request(query) + encoded = msgspec.json.encode(request) + payload = json.loads(encoded) + + for i, msg in enumerate(payload["messages"]): + assert msg["role"] == messages[i]["role"] + assert msg["content"] == messages[i]["content"] + assert "tool_calls" not in msg + assert "tool_call_id" not in msg + + +@pytest.mark.unit +def test_chat_message_from_dict_all_fields(): + """_chat_message_from_dict forwards all four optional fields.""" + tool_calls = [ + {"id": "x", "type": "function", "function": {"name": "f", "arguments": "{}"}} + ] + msg = _chat_message_from_dict( + { + "role": "assistant", + "content": None, + "tool_calls": tool_calls, + "tool_call_id": None, + } + ) + assert msg.role == "assistant" + assert msg.content is None + assert msg.tool_calls == tool_calls + assert msg.tool_call_id is None + + +@pytest.mark.unit +def test_chat_message_content_optional(): + """ChatMessage accepts content=None for tool-dispatching assistant turns.""" + msg = ChatMessage(role="assistant", tool_calls=[]) + assert msg.content is None From ba1cce8aa7b150f83b15762b274c52f3b545ffd5 Mon Sep 17 00:00:00 2001 From: "Li, 
Tianmu" Date: Thu, 23 Apr 2026 15:01:35 -0700 Subject: [PATCH 02/13] feat: add ConversationManager and MultiTurnStrategy Add per-conversation asyncio.Event sequencing (ConversationManager), async turn pipeline (MultiTurnStrategy), and benchmark execution wiring (execute.py, session.py PhaseIssuer data_override). --- .../commands/benchmark/execute.py | 44 +- .../dataset_manager/__init__.py | 2 + .../dataset_manager/multi_turn_dataset.py | 198 +++----- .../dataset_manager/transforms.py | 24 + .../load_generator/conversation_manager.py | 356 +++++++++++++++ .../load_generator/multi_turn_strategy.py | 229 ++++++++++ .../load_generator/session.py | 30 +- .../load_generator/strategy.py | 21 +- tests/integration/test_multi_turn.py | 425 ++++++++++++++++++ .../test_multi_turn_dataset.py | 142 +++--- tests/unit/dataset_manager/test_transforms.py | 50 ++- .../test_multi_turn_conversation_manager.py | 396 ++++++++++++++++ .../test_multi_turn_strategy.py | 279 ++++++++++++ 13 files changed, 1972 insertions(+), 224 deletions(-) create mode 100644 src/inference_endpoint/load_generator/conversation_manager.py create mode 100644 src/inference_endpoint/load_generator/multi_turn_strategy.py create mode 100644 tests/integration/test_multi_turn.py create mode 100644 tests/unit/load_generator/test_multi_turn_conversation_manager.py create mode 100644 tests/unit/load_generator/test_multi_turn_strategy.py diff --git a/src/inference_endpoint/commands/benchmark/execute.py b/src/inference_endpoint/commands/benchmark/execute.py index 73c3427f..e107e611 100644 --- a/src/inference_endpoint/commands/benchmark/execute.py +++ b/src/inference_endpoint/commands/benchmark/execute.py @@ -72,6 +72,7 @@ from inference_endpoint.core.types import QueryResult from inference_endpoint.dataset_manager.dataset import Dataset from inference_endpoint.dataset_manager.factory import DataLoaderFactory +from inference_endpoint.dataset_manager.multi_turn_dataset import MultiTurnDataset from 
inference_endpoint.endpoint_client.cpu_affinity import AffinityPlan, pin_loadgen from inference_endpoint.endpoint_client.http_client import HTTPEndpointClient from inference_endpoint.endpoint_client.http_sample_issuer import HttpClientSampleIssuer @@ -82,6 +83,8 @@ InputValidationError, SetupError, ) +from inference_endpoint.load_generator.conversation_manager import ConversationManager +from inference_endpoint.load_generator.multi_turn_strategy import MultiTurnStrategy from inference_endpoint.load_generator.session import ( BenchmarkSession, PhaseConfig, @@ -343,14 +346,21 @@ def setup_benchmark(config: BenchmarkConfig, test_mode: TestMode) -> BenchmarkCo ) -def _build_phases(ctx: BenchmarkContext) -> list[PhaseConfig]: +def _build_phases( + ctx: BenchmarkContext, + perf_strategy: MultiTurnStrategy | None = None, +) -> list[PhaseConfig]: """Build the phase list from BenchmarkContext.""" phases: list[PhaseConfig] = [] # Performance phase phases.append( PhaseConfig( - "performance", ctx.rt_settings, ctx.dataloader, PhaseType.PERFORMANCE + "performance", + ctx.rt_settings, + ctx.dataloader, + PhaseType.PERFORMANCE, + strategy=perf_strategy, ) ) @@ -513,16 +523,42 @@ async def _run_benchmark_async( launcher.kill_all() raise SetupError(f"Failed to connect to endpoint: {e}") from e + # Build multi-turn strategy if the performance dataset is a MultiTurnDataset. 
+ multi_turn_strategy: MultiTurnStrategy | None = None + if isinstance(ctx.dataloader, MultiTurnDataset): + mt_cfg = None + if ctx.config.datasets: + perf_ds_cfg = next( + ( + d + for d in ctx.config.datasets + if d.type == DatasetType.PERFORMANCE + ), + None, + ) + if perf_ds_cfg is not None: + mt_cfg = perf_ds_cfg.multi_turn + multi_turn_strategy = MultiTurnStrategy( + conversation_manager=ConversationManager(), + dataset_metadata=ctx.dataloader.conversation_metadata, + multi_turn_config=mt_cfg, + ) + + def _on_sample_complete(result: QueryResult) -> None: + if multi_turn_strategy is not None: + multi_turn_strategy.on_sample_complete(result) + collector.on_complete_hook(result) + # Create session session = BenchmarkSession( issuer=issuer, event_publisher=publisher, loop=loop, - on_sample_complete=collector.on_complete_hook, + on_sample_complete=_on_sample_complete, session_id=session_id, ) - phases = _build_phases(ctx) + phases = _build_phases(ctx, perf_strategy=multi_turn_strategy) report: Report | None = None loop.add_signal_handler(signal.SIGINT, session.stop) diff --git a/src/inference_endpoint/dataset_manager/__init__.py b/src/inference_endpoint/dataset_manager/__init__.py index 12938f8e..403b8730 100644 --- a/src/inference_endpoint/dataset_manager/__init__.py +++ b/src/inference_endpoint/dataset_manager/__init__.py @@ -30,6 +30,7 @@ from .predefined.random import RandomDataset from .predefined.shopify_product_catalogue import ShopifyProductCatalogue from .transforms import ( + AddDefaultColumns, AddStaticColumns, ColumnFilter, ColumnRemap, @@ -46,6 +47,7 @@ "DataLoaderFactory", "ColumnFilter", "ColumnRemap", + "AddDefaultColumns", "AddStaticColumns", "UserPromptFormatter", "FusedRowProcessor", diff --git a/src/inference_endpoint/dataset_manager/multi_turn_dataset.py b/src/inference_endpoint/dataset_manager/multi_turn_dataset.py index bff8c011..f26c3d3e 100644 --- a/src/inference_endpoint/dataset_manager/multi_turn_dataset.py +++ 
b/src/inference_endpoint/dataset_manager/multi_turn_dataset.py @@ -19,56 +19,15 @@ import pandas as pd -from ..config.schema import APIType, ModelParams, StreamingMode +from ..config.schema import APIType, ModelParams from ..exceptions import InputValidationError from .dataset import Dataset -from .transforms import apply_transforms - -# Known generation parameter fields to forward from dataset to API requests. -# Aligned with OpenAI API specification and openai_msgspec_adapter.py implementation. -# These parameters work in both single-turn and multi-turn modes. -GENERATION_PARAMS = { - "model", - "max_new_tokens", - "max_completion_tokens", - "stream", - "temperature", - "top_p", - "top_k", - "seed", - "repetition_penalty", - "frequency_penalty", - "presence_penalty", - "stop", - "n", - "logit_bias", # Token probability adjustments - "name", # Entity name for role (NOT model name, e.g., 'Bob' for tracking) - "user", # End-user identifier for monitoring/abuse detection - "chat_template", # Custom chat formatting template - "tools", # OpenAI tool definitions (list[dict]) for tool-calling models -} - - -def _model_param_defaults(model_params: ModelParams | None) -> dict[str, Any]: - """Build per-request defaults for multi-turn rows from model params. - - Multi-turn datasets use `content` and conversation metadata rather than the - single-turn `prompt` field expected by adapter dataset transforms. Applying - those transforms would drop the conversation schema before load_sample() can - construct the messages array. Instead, we inject the request defaults here. 
- """ - if model_params is None: - return {} - - return { - "model": model_params.name, - "stream": model_params.streaming == StreamingMode.ON, - "max_completion_tokens": model_params.max_new_tokens, - "temperature": model_params.temperature, - "top_p": model_params.top_p, - "top_k": model_params.top_k, - "repetition_penalty": model_params.repetition_penalty, - } +from .transforms import ( + AddDefaultColumns, + AddStaticColumns, + apply_transforms, + get_transforms_for_api_type, +) def _expand_tool_results(row: dict) -> list[dict]: @@ -113,7 +72,7 @@ class MultiTurnDataset(Dataset, dataset_id="multi_turn_conversations"): Optional columns: - system: System prompt associated with the conversation (typically set on the first user turn) - model: Model name override - - max_new_tokens: Max tokens for this turn + - max_new_tokens / max_completion_tokens: Max tokens for this turn (alias; mapped to max_completion_tokens) Attributes: conversation_metadata: Metadata dict containing: @@ -139,7 +98,6 @@ def __init__(self, dataframe: pd.DataFrame, **kwargs): self._validate_conversation_structure() self._validate_turn_numbering() self.conversation_metadata = self._build_metadata() - self._client_turn_indices: list[int] | None = None def _validate_conversation_grouping(self) -> None: """Validate that all rows for each conversation_id appear consecutively in file order. @@ -321,23 +279,18 @@ def load( model_params: ModelParams | None = None, force: bool = False, ): - """Load dataset and build a dense user-turn index. + """Load dataset, apply adapter defaults, and pre-bake client-turn samples. - Multi-turn benchmarks only issue user turns. Assistant turns remain in the - backing data so the conversation structure can still be validated. + Unlike single-turn datasets, multi-turn rows do not have a `prompt` column, + so ColumnFilter (which requires prompt) is skipped. 
AddStaticColumns entries + from the adapter are applied via AddDefaultColumns (fill-missing-only) so that + per-row dataset overrides are preserved. - Unlike single-turn datasets, multi-turn rows do not have a `prompt` - column, so adapter dataset transforms are intentionally skipped here. - They would apply a single-turn ColumnFilter and strip the conversation - fields required by load_sample(). Request defaults from model_params are - merged directly into the conversation rows instead. + After transforms, only client turns (user + tool) are stored in self.data as + fully assembled sample dicts (with messages, current_turn_message, system_content + attached). load_sample() and num_samples() are inherited from the base class. """ if not force and self.data is not None: - self._client_turn_indices = [ - index - for index, row in enumerate(self.data) - if row["role"] in ("user", "tool") - ] return df = self.dataframe @@ -353,76 +306,57 @@ def load( if transforms: df = apply_transforms(df, transforms) - defaults = _model_param_defaults(model_params) - for key, value in defaults.items(): - if value is None: + # Extract AddStaticColumns defaults from adapter transforms and apply as + # fill-missing-only (preserves per-row dataset values). + if api_type is not None and model_params is not None: + adapter_transforms = get_transforms_for_api_type(api_type, model_params) + defaults: dict[str, Any] = {} + for t in adapter_transforms: + if isinstance(t, AddStaticColumns): + defaults.update(t.data) + if defaults: + df = AddDefaultColumns(defaults)(df) + + all_rows = df.to_dict(orient="records") + + # Pre-bake: assemble one complete sample dict per client turn. + # NaN filtering replaces the GENERATION_PARAMS allowlist β€” any key whose + # value is float NaN was absent in the original dataset row. 
+ pre_built = self.conversation_metadata.get("pre_built_messages_by_key", {}) + client_turn_samples: list[dict[str, Any]] = [] + + for row in all_rows: + if row.get("role") not in ("user", "tool"): continue - if key in df.columns: - df[key] = df[key].where(pd.notna(df[key]), value) - else: - df[key] = value - - self.data = df.to_dict(orient="records") - assert self.data is not None, "Failed to convert DataFrame to records" - - self._client_turn_indices = [ - index - for index, row in enumerate(self.data) - if row["role"] in ("user", "tool") - ] - - def load_sample(self, index: int) -> dict[str, Any]: - """Load the Nth client turn (user or tool) as a benchmark sample.""" - assert self.data is not None, "Dataset not loaded. Call load() first." - assert ( - self._client_turn_indices is not None - ), "Dataset not loaded. Call load() first." - row = self.data[self._client_turn_indices[index]] - - content_val = row.get("content") - sample: dict[str, Any] = { - "conversation_id": row["conversation_id"], - "turn": row["turn"], - "role": row["role"], - } - if content_val is not None and not ( - isinstance(content_val, float) and pd.isna(content_val) - ): - sample["content"] = content_val - - for param in GENERATION_PARAMS: - if param in row: - value = row[param] - # Skip pandas NaN/None values - if value is not None and ( - not isinstance(value, float) or not pd.isna(value) - ): - sample[param] = value - - # Set defaults for critical params if not present - if "max_new_tokens" not in sample and "max_completion_tokens" not in sample: - sample["max_new_tokens"] = 128 - if "stream" not in sample: - sample["stream"] = False - - # Attach pre-built message list (system + history + current turn). - key = (row["conversation_id"], int(row["turn"])) - pre_built = self.conversation_metadata.get("pre_built_messages_by_key", {}).get( - key, [] - ) - sample["pre_built_messages"] = pre_built - # Fields for use_dataset_history=False path (live history accumulation). 
- sample["current_turn_message"] = pre_built[-1] if pre_built else {} - first = pre_built[0] if pre_built else {} - sample["system_content"] = ( - first.get("content") if first.get("role") == "system" else None - ) + # Filter NaN values; keep all meaningful fields (extra keys are harmless + # since adapters consume only what they recognize). + sample: dict[str, Any] = { + k: v + for k, v in row.items() + if v is not None and not (isinstance(v, float) and pd.isna(v)) + } + + # max_new_tokens β†’ max_completion_tokens alias + if "max_completion_tokens" not in sample and "max_new_tokens" in sample: + sample["max_completion_tokens"] = sample.pop("max_new_tokens") + if "max_completion_tokens" not in sample: + sample["max_completion_tokens"] = 128 + if "stream" not in sample: + sample["stream"] = False + + # Attach pre-built message list (system + history + current turn). + key = (row["conversation_id"], int(row["turn"])) + messages = pre_built.get(key, []) + sample["messages"] = messages + + # Fields for use_dataset_history=False path (live history accumulation). + sample["current_turn_message"] = messages[-1] if messages else {} + first = messages[0] if messages else {} + sample["system_content"] = ( + first.get("content") if first.get("role") == "system" else None + ) - return sample + client_turn_samples.append(sample) - def num_samples(self) -> int: - assert ( - self._client_turn_indices is not None - ), "Dataset not loaded. Call load() first." - return len(self._client_turn_indices) + self.data = client_turn_samples diff --git a/src/inference_endpoint/dataset_manager/transforms.py b/src/inference_endpoint/dataset_manager/transforms.py index 79133796..a288da6d 100644 --- a/src/inference_endpoint/dataset_manager/transforms.py +++ b/src/inference_endpoint/dataset_manager/transforms.py @@ -127,6 +127,30 @@ def __call__(self, df: pd.DataFrame) -> pd.DataFrame: return df +class AddDefaultColumns(Transform): + """Add columns only where values are missing (NaN or absent). 
+ + Unlike AddStaticColumns which unconditionally overwrites, this preserves + existing non-null values β€” dataset per-row overrides take precedence over + the supplied defaults. + """ + + def __init__(self, data: dict[str, Any]): + """Initialize the AddDefaultColumns transform.""" + self.data = data + + def __call__(self, df: pd.DataFrame) -> pd.DataFrame: + """Fill missing columns with defaults without overwriting existing values.""" + for key, value in self.data.items(): + if value is None: + continue + if key in df.columns: + df[key] = df[key].where(pd.notna(df[key]), value) + else: + df[key] = value + return df + + class Harmonize(RowProcessor): """Transform to convert a user prompt to an OpenAI Harmony-compatible format.""" diff --git a/src/inference_endpoint/load_generator/conversation_manager.py b/src/inference_endpoint/load_generator/conversation_manager.py new file mode 100644 index 00000000..86741711 --- /dev/null +++ b/src/inference_endpoint/load_generator/conversation_manager.py @@ -0,0 +1,356 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Async conversation state management for multi-turn benchmarking.""" + +import asyncio +import logging +from dataclasses import dataclass, field +from typing import Any + +logger = logging.getLogger(__name__) + + +@dataclass +class ConversationState: + """Tracks conversation sequencing for multi-turn benchmarking. + + Maintains turn counters and asyncio conditions so the strategy can enforce + sequential turn ordering within a conversation. Message history is NOT stored + here β€” it is pre-computed in MultiTurnDataset and served via load_sample(). + + Attributes: + conversation_id: Unique identifier for this conversation. + current_turn: Last completed turn number (0 = not started). + pending_client_turn: Turn number of in-flight client turn (None if idle). + expected_client_turns: Expected number of client turns (for completion tracking). + issued_client_turns: Count of client turns issued. + completed_client_turns: Count of client turns with responses. + failed_client_turns: Count of client turns that failed (error/timeout). + message_history: Accumulated message list (only populated when + use_dataset_history=False; empty otherwise). + condition: Per-conversation asyncio.Condition for turn-ready and turn-issued waits. + Scoped to this conversation so that state changes only wake the single + pipeline task waiting on this conversation, not all pipeline tasks. + """ + + conversation_id: str + current_turn: int = 0 + pending_client_turn: int | None = None + expected_client_turns: int | None = None + issued_client_turns: int = 0 + completed_client_turns: int = 0 + failed_client_turns: int = 0 + message_history: list[dict[str, Any]] = field(default_factory=list) + condition: asyncio.Condition = field(default_factory=asyncio.Condition) + + def add_client_turn(self, turn: int, message: dict[str, Any] | None = None): + """Record that a client turn has been issued (updates sequencing counters). + + Args: + turn: Turn number for this client message. 
+ message: Message dict to append to message_history (only used when + use_dataset_history=False). + """ + self.pending_client_turn = turn + self.issued_client_turns += 1 + if message is not None: + self.message_history.append(message) + + def add_assistant_turn(self, content: str | None = None): + """Record assistant response and mark turn complete (success). + + Args: + content: Response content to append to message_history. Only + used when use_dataset_history=False; None means no history + update (pre-built messages path). + """ + if content is not None: + self.message_history.append({"role": "assistant", "content": content}) + if self.pending_client_turn is not None: + self.current_turn = self.pending_client_turn + 1 + self.pending_client_turn = None + self.completed_client_turns += 1 + elif self.is_complete(): + pass + else: + logger.warning( + f"Received assistant response for {self.conversation_id} " + f"with no pending client turn (duplicate or out-of-order response)" + ) + self.current_turn = self.current_turn + 1 if self.current_turn > 0 else 1 + self.completed_client_turns += 1 + + if self.is_complete(): + if self.failed_client_turns > 0: + logger.info( + f"Conversation {self.conversation_id} completed with failures: " + f"{self.completed_client_turns - self.failed_client_turns}/" + f"{self.expected_client_turns} successful, " + f"{self.failed_client_turns} failed" + ) + else: + logger.debug( + f"Conversation {self.conversation_id} completed: " + f"{self.completed_client_turns}/{self.expected_client_turns} turns" + ) + + def mark_turn_failed(self, store_in_history: bool = False): + """Mark turn as failed (error/timeout) - still counts as completed for sequencing.""" + if self.pending_client_turn is not None: + self.current_turn = self.pending_client_turn + 1 + self.pending_client_turn = None + self.completed_client_turns += 1 + self.failed_client_turns += 1 + + if store_in_history: + self.message_history.append( + { + "role": "assistant", + "content": 
"[ERROR: Turn failed or timed out]", + } + ) + + logger.warning( + f"Turn {self.current_turn - 1} failed for conversation {self.conversation_id}" + ) + else: + logger.warning( + f"Attempted to mark failed turn for {self.conversation_id} " + f"with no pending client turn" + ) + + if self.is_complete(): + logger.info( + f"Conversation {self.conversation_id} completed with failures: " + f"{self.completed_client_turns - self.failed_client_turns}/" + f"{self.expected_client_turns} successful, " + f"{self.failed_client_turns} failed" + ) + + def is_complete(self) -> bool: + """Check if conversation is complete (all turns issued and responses received).""" + if self.expected_client_turns is None: + return False + return self.completed_client_turns >= self.expected_client_turns + + def is_ready_for_turn(self) -> bool: + """Check if the previous turn has completed and the next may be issued.""" + return ( + self.pending_client_turn is None + and self.issued_client_turns == self.completed_client_turns + and self.issued_client_turns > 0 + ) + + +class ConversationManager: + """Manages conversation sequencing for multi-turn benchmarking. + + Async manager that tracks multiple conversations and enforces turn ordering. + Conversations are identified by unique IDs. Message history is NOT maintained here + β€” it is pre-computed in MultiTurnDataset and passed directly to each request. + + The manager ensures that: + - Turn N+1 cannot be issued until turn N completes + - Concurrent access to conversation state is async-safe + + Each ConversationState carries its own asyncio.Condition so that state changes + (turn issued / turn complete) only wake the single pipeline task waiting + on that conversation, not all pipeline tasks across all conversations. + All conversation states are pre-created by the strategy before pipeline + tasks start, so wait_for_turn_issued never races against get_or_create. 
+ """ + + def __init__(self): + """Initialize conversation manager with empty state.""" + self._conversations: dict[str, ConversationState] = {} + self._lock = asyncio.Lock() + + def get_state(self, conversation_id: str) -> ConversationState | None: + """Get conversation state without creating (for read-only access).""" + return self._conversations.get(conversation_id) + + async def get_or_create( + self, + conversation_id: str, + expected_client_turns: int | None = None, + system_message: dict[str, Any] | None = None, + ) -> ConversationState: + """Get existing or create new conversation state. + + Args: + conversation_id: Unique identifier for conversation. + expected_client_turns: Expected number of client turns (for completion tracking). + system_message: System message dict to pre-populate message_history with. + Only used when use_dataset_history=False and conversation is new. + + Returns: + ConversationState for this conversation. + """ + async with self._lock: + if conversation_id not in self._conversations: + initial_history: list[dict[str, Any]] = ( + [system_message] if system_message is not None else [] + ) + state = ConversationState( + conversation_id=conversation_id, + current_turn=0, + pending_client_turn=None, + expected_client_turns=expected_client_turns, + issued_client_turns=0, + completed_client_turns=0, + failed_client_turns=0, + message_history=initial_history, + ) + self._conversations[conversation_id] = state + return self._conversations[conversation_id] + + async def wait_for_turn_ready( + self, conversation_id: str, turn: int, timeout: float | None = None + ) -> bool: + """Block until conversation is ready for this turn. + + Uses the per-conversation asyncio.Condition so only this conversation's pipeline + task is woken on state changes, not all pipeline tasks. + + Args: + conversation_id: Conversation to wait for. + turn: Turn number to wait for (unused in readiness check; kept for + call-site compatibility). 
+ timeout: Maximum seconds to wait (None = infinite). + + Returns: + True if ready, False if timeout. + + Raises: + KeyError: If conversation_id not found in manager. + """ + state = self._conversations.get(conversation_id) + if state is None: + logger.error(f"Conversation {conversation_id} not found in manager") + raise KeyError(f"Conversation {conversation_id} not initialized") + + async with state.condition: + if timeout is None: + await state.condition.wait_for(state.is_ready_for_turn) + return True + try: + async with asyncio.timeout(timeout): + await state.condition.wait_for(state.is_ready_for_turn) + return True + except TimeoutError: + return state.is_ready_for_turn() + + async def wait_for_turn_issued( + self, + conversation_id: str, + min_issued: int, + timeout: float | None = None, + ) -> bool: + """Block until at least min_issued client turns have been issued. + + Args: + conversation_id: Conversation to wait for. + min_issued: Minimum number of issued turns to wait for. + timeout: Maximum seconds to wait (None = infinite). + + Returns: + True if condition met, False if timeout. + + Raises: + KeyError: If conversation_id not found (programming error β€” state must be + pre-created by the strategy before pipeline tasks are spawned). + """ + state = self._conversations[conversation_id] + predicate = lambda: state.issued_client_turns >= min_issued # noqa: E731 + async with state.condition: + if timeout is None: + await state.condition.wait_for(predicate) + return True + try: + async with asyncio.timeout(timeout): + await state.condition.wait_for(predicate) + return True + except TimeoutError: + return state.issued_client_turns >= min_issued + + async def mark_turn_issued( + self, + conversation_id: str, + turn: int, + message: dict[str, Any] | None = None, + ): + """Mark that a client turn has been issued (updates sequencing counters). + + Args: + conversation_id: Conversation ID. + turn: Turn number being issued. 
+ message: Message dict to append to history (used when + use_dataset_history=False). + + Raises: + KeyError: If conversation_id not found in manager. + """ + state = self._conversations.get(conversation_id) + if state is None: + raise KeyError(f"Conversation {conversation_id} not initialized") + async with state.condition: + state.add_client_turn(turn, message) + state.condition.notify_all() + + async def mark_turn_complete( + self, + conversation_id: str, + response: str, + store_in_history: bool = False, + ): + """Mark that assistant response has arrived. + + Args: + conversation_id: Conversation ID. + response: Model output (stored in history when store_in_history=True). + store_in_history: When True, append response to message_history. + + Raises: + KeyError: If conversation_id not found in manager. + """ + state = self._conversations.get(conversation_id) + if state is None: + raise KeyError(f"Conversation {conversation_id} not initialized") + async with state.condition: + state.add_assistant_turn(response if store_in_history else None) + state.condition.notify_all() + + async def mark_turn_failed( + self, conversation_id: str, store_in_history: bool = False + ): + """Mark that assistant response failed (error/timeout). + + Failed turns still count toward conversation completion to ensure + turn sequencing progresses even under errors. + + Args: + conversation_id: Conversation ID. + store_in_history: When True, append error placeholder to message_history. + + Raises: + KeyError: If conversation_id not found in manager. 
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Async multi-turn load strategy implementing the LoadStrategy protocol."""

import asyncio
import logging
from collections import defaultdict
from typing import Any

from ..config.schema import MultiTurnConfig
from ..core.types import QueryResult
from .conversation_manager import ConversationManager
from .strategy import PhaseIssuerProtocol

logger = logging.getLogger(__name__)

# Default turn timeout when no MultiTurnConfig is provided.
_DEFAULT_TURN_TIMEOUT_S = 300.0


class MultiTurnStrategy:
    """Async multi-turn strategy. Spawns one asyncio.Task per conversation.

    Each conversation task enforces sequential turn ordering: turn N+1 is not
    issued until turn N has completed (or failed). Conversations run
    concurrently with no cross-conversation synchronization.

    Optional target_concurrency limits total in-flight requests across all
    conversations using an asyncio.Semaphore.

    Hooks:
        - execute(): spawns conversation tasks and awaits all of them.
        - on_query_complete(): releases a semaphore slot (concurrency only).
        - on_sample_complete(): routes a QueryResult to ConversationManager.

    Response routing path:
        1. _conv_pipeline issues a turn via phase_issuer.issue(idx) -> query_id
        2. _conv_pipeline stores conv_id in _inflight[query_id]
        3. on_sample_complete(result) pops conv_id from _inflight and schedules
           mark_turn_complete / mark_turn_failed
        4. that notifies the pipeline task blocked in wait_for_turn_ready
        5. _conv_pipeline proceeds to issue the next turn
    """

    def __init__(
        self,
        conversation_manager: ConversationManager,
        dataset_metadata: dict[str, Any],
        multi_turn_config: MultiTurnConfig | None = None,
        target_concurrency: int | None = None,
    ):
        """Initialize multi-turn strategy.

        Args:
            conversation_manager: Manages conversation sequencing state.
            dataset_metadata: Dataset metadata; must contain a "samples" list
                whose entries have "conversation_id" and "turn" keys.
            multi_turn_config: Multi-turn conversation configuration.
            target_concurrency: Optional maximum concurrent in-flight requests;
                None or a non-positive value disables the limit.
        """
        self._conv_manager = conversation_manager
        self._dataset_metadata = dataset_metadata
        self._multi_turn_config = multi_turn_config
        self._turn_timeout_s = (
            multi_turn_config.turn_timeout_s
            if multi_turn_config is not None
            else _DEFAULT_TURN_TIMEOUT_S
        )
        self._target_concurrency = target_concurrency
        self._sem: asyncio.Semaphore | None = (
            asyncio.Semaphore(target_concurrency)
            if target_concurrency is not None and target_concurrency > 0
            else None
        )
        # Live-history mode: responses are appended to the conversation history
        # and each request's messages are rebuilt from it at issue time.
        self._store_in_history = (
            not multi_turn_config.use_dataset_history
            if multi_turn_config is not None
            else False
        )

        # Maps query_id -> conversation_id for routing completions.
        # Populated by _conv_pipeline right after issue() returns a query_id.
        self._inflight: dict[str, str] = {}

        # Strong references to scheduled state-update tasks. The event loop
        # only keeps weak references, so an unreferenced task may be garbage
        # collected before it finishes.
        self._state_update_tasks: set[asyncio.Future] = set()

    async def execute(self, phase_issuer: PhaseIssuerProtocol) -> int:
        """Drive multi-turn sample issuance.

        Args:
            phase_issuer: Interface for issuing samples to the endpoint.

        Returns:
            Total count of samples issued.
        """
        conv_samples: dict[str, list[tuple[int, int]]] = defaultdict(list)
        for sample_index, sample_meta in enumerate(self._dataset_metadata["samples"]):
            conv_samples[sample_meta["conversation_id"]].append(
                (sample_index, sample_meta["turn"])
            )

        # Pre-create all conversation states before spawning tasks so no
        # pipeline ever waits on a conversation that does not exist yet.
        for conv_id, turns in conv_samples.items():
            await self._conv_manager.get_or_create(
                conv_id, expected_client_turns=len(turns)
            )

        tasks = [
            asyncio.create_task(
                self._conv_pipeline(conv_id, turns, phase_issuer),
                name=f"mt-pipeline-{conv_id}",
            )
            for conv_id, turns in conv_samples.items()
        ]

        # One failing conversation must not abort the others, but its error
        # must not vanish either: collect and report every failure.
        outcomes = await asyncio.gather(*tasks, return_exceptions=True)
        for task, outcome in zip(tasks, outcomes):
            if isinstance(outcome, BaseException):
                logger.error(
                    "Conversation pipeline %s failed: %r", task.get_name(), outcome
                )
        return phase_issuer.issued_count

    async def _conv_pipeline(
        self,
        conv_id: str,
        turns: list[tuple[int, int]],
        phase_issuer: PhaseIssuerProtocol,
    ) -> None:
        """Issue all turns of one conversation strictly in turn order.

        For each turn after the first, waits (via wait_for_turn_ready) until
        the previous turn has completed before issuing the next one. Stops
        early on a turn timeout or when the issuer reports the session is
        stopping.

        Args:
            conv_id: Conversation being driven.
            turns: (sample_index, turn_number) pairs for this conversation.
            phase_issuer: Interface for issuing samples to the endpoint.
        """
        ordered_turns = sorted(turns, key=lambda item: item[1])

        for position, (sample_index, turn) in enumerate(ordered_turns):
            if position > 0:
                ready = await self._conv_manager.wait_for_turn_ready(
                    conv_id, turn, timeout=self._turn_timeout_s
                )
                if not ready:
                    logger.warning(
                        "Turn %s of %s timed out waiting for previous turn",
                        turn,
                        conv_id,
                    )
                    await self._conv_manager.mark_turn_failed(conv_id)
                    break

            # Acquire a concurrency slot before issuing. On success the slot is
            # released later by on_query_complete().
            if self._sem is not None:
                await self._sem.acquire()

            try:
                data_override, current_turn_message = self._build_live_override(
                    conv_id, turn
                )
                query_id = phase_issuer.issue(sample_index, data_override=data_override)
            except BaseException:
                # Nothing went in flight, so on_query_complete() will never
                # release this slot; give it back before propagating.
                if self._sem is not None:
                    self._sem.release()
                raise

            if query_id is None:
                # Session stopping: release the slot and exit.
                if self._sem is not None:
                    self._sem.release()
                break

            # Register the routing entry before the next await so a completion
            # can always be mapped back to its conversation.
            self._inflight[query_id] = conv_id

            # Mark the turn as issued so wait_for_turn_ready gates the next one.
            await self._conv_manager.mark_turn_issued(
                conv_id, turn, message=current_turn_message
            )

    def _build_live_override(
        self, conv_id: str, turn: int
    ) -> tuple[dict[str, Any] | None, dict[str, Any] | None]:
        """Build the live-history ``data_override`` for one turn.

        Returns:
            ``(data_override, current_turn_message)``. Both are None when
            live-history mode is off. ``data_override`` is also None when the
            dataset has no pre-built messages for this turn or the
            conversation state is missing, in which case the dataset's own
            sample data is issued unchanged.
        """
        if not self._store_in_history:
            return None, None

        pre_built = self._dataset_metadata.get("pre_built_messages_by_key", {}).get(
            (conv_id, turn), []
        )
        # The last pre-built message is taken as the current client message.
        current_turn_message = pre_built[-1] if pre_built else None
        state = self._conv_manager.get_state(conv_id)
        if state is None or current_turn_message is None:
            return None, current_turn_message

        # NOTE(review): history starts empty here because execute() does not
        # pass a system_message to get_or_create - confirm whether a dataset
        # system prompt should be seeded in live-history mode.
        return {"messages": [*state.message_history, current_turn_message]}, (
            current_turn_message
        )

    def on_query_complete(self, query_id: str) -> None:
        """Release the concurrency slot held by a finished query.

        Response routing is handled separately by on_sample_complete, which
        receives the full QueryResult.

        Args:
            query_id: ID of the completed query.
        """
        if self._sem is not None:
            self._sem.release()

    def on_sample_complete(self, result: QueryResult) -> None:
        """Route a completed QueryResult to the ConversationManager.

        Looks up the conversation for ``result.id`` in ``_inflight`` and
        schedules mark_turn_failed (if ``result.error`` is set) or
        mark_turn_complete on the event loop. Results with unknown ids are
        ignored.

        Args:
            result: Completed QueryResult from the endpoint.
        """
        conv_id = self._inflight.pop(result.id, None)
        if conv_id is None:
            return

        if result.error is not None:
            update = self._conv_manager.mark_turn_failed(
                conv_id, store_in_history=self._store_in_history
            )
        else:
            update = self._conv_manager.mark_turn_complete(
                conv_id,
                result.get_response_output_string(),
                store_in_history=self._store_in_history,
            )

        task = asyncio.ensure_future(update)
        self._state_update_tasks.add(task)
        task.add_done_callback(self._on_state_update_done)

    def _on_state_update_done(self, task: asyncio.Future) -> None:
        """Drop the finished state-update task and report any failure."""
        self._state_update_tasks.discard(task)
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            logger.error("Conversation state update failed: %r", exc)
benchmark phase.""" @@ -68,6 +68,7 @@ class PhaseConfig: runtime_settings: RuntimeSettings dataset: Dataset phase_type: PhaseType = PhaseType.PERFORMANCE + strategy: LoadStrategy | None = field(default=None, compare=False) # --------------------------------------------------------------------------- @@ -172,11 +173,19 @@ def __init__( self.inflight: int = 0 self.issued_count: int = 0 - def issue(self, sample_index: int) -> str | None: + def issue( + self, sample_index: int, data_override: dict[str, Any] | None = None + ) -> str | None: """Load data, build Query, publish ISSUED, send to endpoint. Returns query_id on success, None if session is stopping. + Args: + sample_index: Index into the dataset. + data_override: If provided, merged over the loaded sample data. + Keys in data_override take precedence. Used by MultiTurnStrategy + to substitute live-accumulated message history. + Note: load_sample() runs synchronously before the ISSUED timestamp. For accurate timing, datasets MUST be pre-loaded into memory. Disk-backed datasets will inflate timing and delay subsequent issues. 
@@ -185,6 +194,8 @@ def issue(self, sample_index: int) -> str | None: return None query_id = uuid.uuid4().hex data = self._dataset.load_sample(sample_index) + if data_override is not None: + data = {**data, **data_override} query = Query(id=query_id, data=data) self.uuid_to_index[query_id] = sample_index ts = time.monotonic_ns() @@ -306,10 +317,13 @@ async def _run_phase(self, phase: PhaseConfig) -> PhaseResult | None: phase_start = time.monotonic_ns() # Create per-phase state - sample_order = create_sample_order(phase.runtime_settings) - strategy = create_load_strategy( - phase.runtime_settings, self._loop, sample_order - ) + if phase.strategy is not None: + strategy = phase.strategy + else: + sample_order = create_sample_order(phase.runtime_settings) + strategy = create_load_strategy( + phase.runtime_settings, self._loop, sample_order + ) phase_issuer = PhaseIssuer( dataset=phase.dataset, issuer=self._issuer, diff --git a/src/inference_endpoint/load_generator/strategy.py b/src/inference_endpoint/load_generator/strategy.py index dd311f10..8ee13722 100644 --- a/src/inference_endpoint/load_generator/strategy.py +++ b/src/inference_endpoint/load_generator/strategy.py @@ -29,7 +29,7 @@ import logging from collections.abc import Callable, Iterator from time import monotonic_ns -from typing import Protocol +from typing import Any, Protocol from ..config.runtime_settings import RuntimeSettings from ..config.schema import LoadPatternType @@ -47,8 +47,17 @@ class PhaseIssuerProtocol(Protocol): """Minimal interface that strategies see for issuing samples.""" - def issue(self, sample_index: int) -> str | None: - """Issue a sample. Returns query_id, or None if the session is stopping.""" + def issue( + self, sample_index: int, data_override: dict[str, Any] | None = None + ) -> str | None: + """Issue a sample. Returns query_id, or None if the session is stopping. + + Args: + sample_index: Index into the dataset. 
+ data_override: If provided, use this as Query.data instead of + loading from the dataset. Used by MultiTurnStrategy for + live-history mode where the messages array is built at runtime. + """ ... issued_count: int @@ -297,5 +306,11 @@ def create_load_strategy( ) return ConcurrencyStrategy(lp.target_concurrency, sample_order) + case LoadPatternType.MULTI_TURN: + raise ValueError( + "MULTI_TURN load pattern requires a MultiTurnDataset β€” " + "use 'inference-endpoint benchmark from-config' with a multi-turn dataset" + ) + case _: raise ValueError(f"Unsupported load pattern type: {lp.type}") diff --git a/tests/integration/test_multi_turn.py b/tests/integration/test_multi_turn.py new file mode 100644 index 00000000..5a4d128d --- /dev/null +++ b/tests/integration/test_multi_turn.py @@ -0,0 +1,425 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration tests for multi-turn benchmarking end-to-end. + +Validates that MultiTurnDataset + MultiTurnStrategy + BenchmarkSession work +correctly together against a real HTTP echo server. + +Tests cover: + 1. Dataset-history mode (use_dataset_history=True): pre-built messages are + issued as-is; each turn is issued sequentially per conversation. + 2. 
Live-history mode (use_dataset_history=False): messages are built at + runtime from ConversationManager.message_history; the injected messages + grow with each turn. + 3. Multiple concurrent conversations complete successfully. + 4. Turn ordering: turn N+1 is never issued before turn N completes. +""" + +import asyncio +import random +import time +from urllib.parse import urljoin + +import pandas as pd +import pytest +from inference_endpoint import metrics +from inference_endpoint.config.runtime_settings import RuntimeSettings +from inference_endpoint.config.schema import ( + LoadPattern, + LoadPatternType, + MultiTurnConfig, +) +from inference_endpoint.core.record import EventRecord +from inference_endpoint.core.types import QueryResult +from inference_endpoint.dataset_manager.multi_turn_dataset import MultiTurnDataset +from inference_endpoint.endpoint_client.config import HTTPClientConfig +from inference_endpoint.endpoint_client.http_client import HTTPEndpointClient +from inference_endpoint.endpoint_client.http_sample_issuer import HttpClientSampleIssuer +from inference_endpoint.load_generator.conversation_manager import ConversationManager +from inference_endpoint.load_generator.multi_turn_strategy import MultiTurnStrategy +from inference_endpoint.load_generator.session import ( + BenchmarkSession, + PhaseConfig, + PhaseType, +) +from inference_endpoint.testing.echo_server import EchoServer + + +class _NoOpPublisher: + def publish(self, event_record: EventRecord) -> None: + pass + + def flush(self) -> None: + pass + + +def _make_dataset(rows: list[dict]) -> MultiTurnDataset: + """Build a loaded MultiTurnDataset from a list of row dicts.""" + df = pd.DataFrame(rows) + ds = MultiTurnDataset(dataframe=df) + ds.load() + return ds + + +def _make_strategy( + ds: MultiTurnDataset, + use_dataset_history: bool = True, +) -> MultiTurnStrategy: + mt_cfg = MultiTurnConfig( + turn_timeout_s=10.0, + use_dataset_history=use_dataset_history, + ) + return MultiTurnStrategy( + 
conversation_manager=ConversationManager(), + dataset_metadata=ds.conversation_metadata, + multi_turn_config=mt_cfg, + ) + + +async def _run_session( + server_url: str, + ds: MultiTurnDataset, + strategy: MultiTurnStrategy, + responses_out: dict, +) -> int: + """Wire up HTTPEndpointClient + BenchmarkSession and run one phase. + + Populates responses_out[query_id] = response_text for every completed turn. + Returns issued_count. + """ + loop = asyncio.get_running_loop() + + def on_complete(result: QueryResult) -> None: + strategy.on_sample_complete(result) + responses_out[result.id] = result.get_response_output_string() + + http_config = HTTPClientConfig( + endpoint_urls=[urljoin(server_url, "/v1/chat/completions")], + warmup_connections=0, + num_workers=2, + ) + http_client = await HTTPEndpointClient.create(http_config, loop) + issuer = HttpClientSampleIssuer(http_client) + + try: + session = BenchmarkSession( + issuer=issuer, + event_publisher=_NoOpPublisher(), + loop=loop, + on_sample_complete=on_complete, + ) + rt = RuntimeSettings( + metrics.Throughput(1000), + [metrics.Throughput(1000)], + min_duration_ms=0, + max_duration_ms=30_000, + n_samples_from_dataset=ds.num_samples(), + n_samples_to_issue=ds.num_samples(), + min_sample_count=1, + rng_sched=random.Random(42), + rng_sample_index=random.Random(42), + load_pattern=LoadPattern(type=LoadPatternType.MAX_THROUGHPUT), + ) + phase = PhaseConfig( + "perf", + rt, + ds, + PhaseType.PERFORMANCE, + strategy=strategy, + ) + result = await asyncio.wait_for(session.run([phase]), timeout=30.0) + return result.perf_results[0].issued_count + finally: + await http_client.shutdown_async() + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def echo_server(): + server = EchoServer(port=0) + server.start() + try: + yield server + finally: + server.stop() + + +# 
--------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_single_conversation_all_turns_issued(echo_server): + """All turns of a single conversation are issued and completed.""" + rows = [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "Hello"}, + {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "Hi"}, + {"conversation_id": "c1", "turn": 3, "role": "user", "content": "Bye"}, + ] + ds = _make_dataset(rows) + strategy = _make_strategy(ds) + responses: dict = {} + + count = await _run_session(echo_server.url, ds, strategy, responses) + + # Two user turns (turns 1 and 3); turn 2 is assistant so not a client turn + assert count == 2 + assert len(responses) == 2 + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_multiple_conversations_all_issued(echo_server): + """Multiple conversations complete independently and concurrently.""" + rows = [] + for conv_idx in range(3): + conv_id = f"conv_{conv_idx}" + rows.append( + { + "conversation_id": conv_id, + "turn": 1, + "role": "user", + "content": f"Q1 {conv_idx}", + } + ) + rows.append( + { + "conversation_id": conv_id, + "turn": 2, + "role": "assistant", + "content": f"A1 {conv_idx}", + } + ) + rows.append( + { + "conversation_id": conv_id, + "turn": 3, + "role": "user", + "content": f"Q2 {conv_idx}", + } + ) + ds = _make_dataset(rows) + strategy = _make_strategy(ds) + responses: dict = {} + + count = await _run_session(echo_server.url, ds, strategy, responses) + + # 3 conversations Γ— 2 user turns each = 6 + assert count == 6 + assert len(responses) == 6 + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_dataset_history_messages_present(echo_server): + """Dataset-history mode: each request contains the messages array from the dataset.""" + received_payloads: list[dict] = [] + + # 
Override get_response to capture the incoming request body. + # EchoServer._handle_echo_chat_completions_request parses it into + # CreateChatCompletionRequest β€” we capture the raw JSON at the HTTP layer + # by subclassing and overriding get_response (called with first user content). + # Instead, use a custom echo server that logs the full payload. + class CapturingEchoServer(EchoServer): + async def _handle_echo_chat_completions_request(self, request): + try: + payload = await request.json() + received_payloads.append(payload) + except Exception: + pass + return await super()._handle_echo_chat_completions_request(request) + + server = CapturingEchoServer(port=0) + server.start() + try: + rows = [ + { + "conversation_id": "c1", + "turn": 1, + "role": "user", + "content": "First question", + }, + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant", + "content": "First answer", + }, + { + "conversation_id": "c1", + "turn": 3, + "role": "user", + "content": "Second question", + }, + ] + ds = _make_dataset(rows) + strategy = _make_strategy(ds, use_dataset_history=True) + responses: dict = {} + + count = await _run_session(server.url, ds, strategy, responses) + assert count == 2 + + # Both requests must include a "messages" array + assert len(received_payloads) == 2 + for payload in received_payloads: + assert "messages" in payload + assert len(payload["messages"]) >= 1 + + # Turn 1 should have 1 user message; turn 3 should have 3 messages + # (system? 
no system here β€” user, assistant, user) + msg_counts = sorted(len(p["messages"]) for p in received_payloads) + assert msg_counts == [1, 3] + finally: + server.stop() + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_live_history_messages_grow_each_turn(echo_server): + """Live-history mode: messages array grows with each completed turn.""" + received_payloads: list[dict] = [] + + class CapturingEchoServer(EchoServer): + async def _handle_echo_chat_completions_request(self, request): + try: + payload = await request.json() + received_payloads.append(payload) + except Exception: + pass + return await super()._handle_echo_chat_completions_request(request) + + server = CapturingEchoServer(port=0) + server.start() + try: + rows = [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "Turn one"}, + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant", + "content": "Answer one", + }, + {"conversation_id": "c1", "turn": 3, "role": "user", "content": "Turn two"}, + ] + ds = _make_dataset(rows) + strategy = _make_strategy(ds, use_dataset_history=False) + responses: dict = {} + + count = await _run_session(server.url, ds, strategy, responses) + assert count == 2 + + assert len(received_payloads) == 2 + msg_counts = sorted(len(p["messages"]) for p in received_payloads) + # Turn 1: [user msg] = 1; Turn 3: [user, assistant, user] = 3 + assert msg_counts == [1, 3] + finally: + server.stop() + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_turn_ordering_enforced_end_to_end(echo_server): + """Turn N+1 is issued after Turn N's response arrives, verified by timestamps.""" + complete_times: dict[str, float] = {} + + rows = [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "First"}, + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant", + "content": "Response", + }, + {"conversation_id": "c1", "turn": 3, "role": "user", "content": "Second"}, + ] + ds = _make_dataset(rows) + mt_cfg = 
MultiTurnConfig(turn_timeout_s=10.0, use_dataset_history=True) + conv_manager = ConversationManager() + strategy = MultiTurnStrategy( + conversation_manager=conv_manager, + dataset_metadata=ds.conversation_metadata, + multi_turn_config=mt_cfg, + ) + + # Wrap on_sample_complete to record completion timestamps + orig_on_sample_complete = strategy.on_sample_complete + + def tracked_on_sample_complete(result: QueryResult) -> None: + # Map query_id β†’ sample_index via uuid_to_index (set after session runs) + complete_times[result.id] = time.monotonic() + orig_on_sample_complete(result) + + strategy.on_sample_complete = tracked_on_sample_complete + + loop = asyncio.get_running_loop() + responses: dict[str, str] = {} + + http_config = HTTPClientConfig( + endpoint_urls=[urljoin(echo_server.url, "/v1/chat/completions")], + warmup_connections=0, + num_workers=1, + ) + http_client = await HTTPEndpointClient.create(http_config, loop) + issuer = HttpClientSampleIssuer(http_client) + + rt = RuntimeSettings( + metrics.Throughput(1000), + [metrics.Throughput(1000)], + min_duration_ms=0, + max_duration_ms=30_000, + n_samples_from_dataset=ds.num_samples(), + n_samples_to_issue=ds.num_samples(), + min_sample_count=1, + rng_sched=random.Random(42), + rng_sample_index=random.Random(42), + load_pattern=LoadPattern(type=LoadPatternType.MAX_THROUGHPUT), + ) + + try: + + def on_complete(result: QueryResult) -> None: + tracked_on_sample_complete(result) + responses[result.id] = result.get_response_output_string() + + session = BenchmarkSession( + issuer=issuer, + event_publisher=_NoOpPublisher(), + loop=loop, + on_sample_complete=on_complete, + ) + phase = PhaseConfig("perf", rt, ds, PhaseType.PERFORMANCE, strategy=strategy) + result = await asyncio.wait_for(session.run([phase]), timeout=30.0) + finally: + await http_client.shutdown_async() + + assert result.perf_results[0].issued_count == 2 + + # Build query_id β†’ sample_index from session result + uuid_to_index = 
result.perf_results[0].uuid_to_index + index_to_query = {v: k for k, v in uuid_to_index.items()} + + # Sample 0 = turn 1, sample 1 = turn 3 + q_turn1 = index_to_query[0] + q_turn3 = index_to_query[1] + + # Turn 3 must complete after turn 1 completes + assert complete_times[q_turn3] >= complete_times[q_turn1] diff --git a/tests/unit/dataset_manager/test_multi_turn_dataset.py b/tests/unit/dataset_manager/test_multi_turn_dataset.py index 09ccd224..a42fc1f3 100644 --- a/tests/unit/dataset_manager/test_multi_turn_dataset.py +++ b/tests/unit/dataset_manager/test_multi_turn_dataset.py @@ -128,8 +128,8 @@ def test_multi_turn_dataset_load_valid_data(valid_multi_turn_jsonl): ) dataset.load() - # Should have 5 rows total (3 for conv_001, 2 for conv_002) - assert len(dataset.data) == 5 + # data contains only client turns (3 user turns), not all rows + assert len(dataset.data) == 3 # Should have 3 user turns (samples) - only user turns are indexed assert dataset.num_samples() == 3 @@ -137,18 +137,18 @@ def test_multi_turn_dataset_load_valid_data(valid_multi_turn_jsonl): @pytest.mark.unit def test_multi_turn_dataset_user_turn_indexing(valid_multi_turn_jsonl): - """Test that only client turns (user + tool) are indexed as samples.""" + """Test that only client turns (user + tool) are stored as samples.""" dataset = MultiTurnDataset.load_from_file( valid_multi_turn_jsonl, format=DatasetFormat.JSONL ) dataset.load() - # Verify client turn indices are correct (fixture has only user turns) - assert len(dataset._client_turn_indices) == 3 + # data contains only client turns (fixture has only user turns) + assert dataset.num_samples() == 3 - # Check that indices point to client turns - for idx in dataset._client_turn_indices: - assert dataset.data[idx]["role"] in ("user", "tool") + # Every sample in data is a client turn + for i in range(dataset.num_samples()): + assert dataset.load_sample(i)["role"] in ("user", "tool") @pytest.mark.unit @@ -165,9 +165,9 @@ def 
test_multi_turn_dataset_load_sample(valid_multi_turn_jsonl): assert sample_0["turn"] == 1 assert sample_0["role"] == "user" assert sample_0["content"] == "Hello, how are you?" - # System prompt is in pre_built_messages, not as a separate field - assert sample_0["pre_built_messages"][0]["role"] == "system" - assert sample_0["pre_built_messages"][0]["content"] == "You are a helpful assistant" + # System prompt is the first message in the messages array + assert sample_0["messages"][0]["role"] == "system" + assert sample_0["messages"][0]["content"] == "You are a helpful assistant" # Sample 1 should be second user turn (conv_001 turn 3) sample_1 = dataset.load_sample(1) @@ -268,8 +268,8 @@ def test_multi_turn_dataset_multiple_conversations(): dataset = MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL) dataset.load() - # 9 total rows, 5 user turns (c1:t1, c1:t3, c2:t1, c2:t3, c3:t1) - assert len(dataset.data) == 9 + # data contains only client turns: 5 user turns (c1:t1, c1:t3, c2:t1, c2:t3, c3:t1) + assert len(dataset.data) == 5 assert dataset.num_samples() == 5 # Metadata checks @@ -291,7 +291,7 @@ def test_multi_turn_dataset_multiple_conversations(): @pytest.mark.unit def test_multi_turn_dataset_system_prompt_handling(valid_multi_turn_jsonl): - """Test system prompt is included as the first message in pre_built_messages. + """Test system prompt is included as the first message in the messages array. The system prompt is pre-baked into every client turn's message list so the conversation manager no longer needs to track it separately. 
@@ -301,16 +301,16 @@ def test_multi_turn_dataset_system_prompt_handling(valid_multi_turn_jsonl): ) dataset.load() - # First sample: pre_built_messages starts with system message + # First sample: messages starts with system message sample_0 = dataset.load_sample(0) - assert "pre_built_messages" in sample_0 - msgs = sample_0["pre_built_messages"] + assert "messages" in sample_0 + msgs = sample_0["messages"] assert msgs[0]["role"] == "system" assert msgs[0]["content"] == "You are a helpful assistant" # Second sample (same conversation, turn 3): system message still first sample_1 = dataset.load_sample(1) - msgs_1 = sample_1["pre_built_messages"] + msgs_1 = sample_1["messages"] assert msgs_1[0]["role"] == "system" assert msgs_1[0]["content"] == "You are a helpful assistant" @@ -383,8 +383,8 @@ def test_multi_turn_dataset_conversation_grouping(): dataset = MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL) dataset.load() - # 5 total rows, 3 user turns (c1t1, c1t3, c2t1) - assert len(dataset.data) == 5 + # data contains only client turns: 3 user turns (c1t1, c1t3, c2t1) + assert len(dataset.data) == 3 assert dataset.num_samples() == 3 # Load samples to verify conversation grouping @@ -485,14 +485,9 @@ def test_multi_turn_dataset_additional_fields(): dataset.load() sample = dataset.load_sample(0) - # Fields may or may not be present depending on how dataframe handles them - # Just check they're accessible if present - if "model" in sample: - assert sample["model"] == "gpt-4" - if "max_new_tokens" in sample: - assert sample["max_new_tokens"] == 256 - if "temperature" in sample: - assert sample["temperature"] == pytest.approx(0.7) + assert sample["model"] == "gpt-4" + assert sample["max_completion_tokens"] == 256 + assert sample["temperature"] == pytest.approx(0.7) finally: Path(temp_path).unlink() @@ -540,34 +535,33 @@ def test_multi_turn_dataset_openai_field_forwarding(): @pytest.mark.unit def test_multi_turn_dataset_all_generation_params(): - 
"""Test that all generation parameters in GENERATION_PARAMS are forwarded.""" - from inference_endpoint.dataset_manager.multi_turn_dataset import GENERATION_PARAMS - - # Create dataset with all possible generation params + """Test that dataset-supplied generation parameters are forwarded to the sample.""" + # Create dataset with a representative set of generation params + row_params = { + "model": "test-model", + "max_completion_tokens": 100, + "stream": True, + "temperature": 0.8, + "top_p": 0.95, + "top_k": 50, + "seed": 42, + "repetition_penalty": 1.1, + "frequency_penalty": 0.5, + "presence_penalty": 0.3, + "stop": ["END"], + "n": 2, + "logit_bias": {"100": 10}, + "name": "TestEntity", + "user": "test_user_001", + "chat_template": "test_template", + } data = [ { "conversation_id": "c1", "turn": 1, "role": "user", "content": "Test", - # Include all params from GENERATION_PARAMS - "model": "test-model", - "max_new_tokens": 100, - "max_completion_tokens": 100, - "stream": True, - "temperature": 0.8, - "top_p": 0.95, - "top_k": 50, - "seed": 42, - "repetition_penalty": 1.1, - "frequency_penalty": 0.5, - "presence_penalty": 0.3, - "stop": ["END"], - "n": 2, - "logit_bias": {"100": 10}, - "name": "TestEntity", - "user": "test_user_001", - "chat_template": "test_template", + **row_params, }, { "conversation_id": "c1", @@ -588,13 +582,9 @@ def test_multi_turn_dataset_all_generation_params(): sample = dataset.load_sample(0) - # Verify all GENERATION_PARAMS fields are forwarded - # (excluding conversational fields like conversation_id, turn, role, content, system) - for param in GENERATION_PARAMS: - if param in data[0]: - assert ( - param in sample - ), f"Generation parameter '{param}' not forwarded to sample" + # All non-NaN row fields must appear in the pre-baked sample + for param in row_params: + assert param in sample, f"Parameter '{param}' not forwarded to sample" finally: Path(temp_path).unlink() @@ -888,7 +878,7 @@ def 
test_load_sample_merged_tool_row_has_no_content_key(): s1 = ds.load_sample(1) assert s1["role"] == "tool" assert "content" not in s1 # must NOT emit NaN - assert "pre_built_messages" in s1 + assert "messages" in s1 @pytest.mark.unit @@ -961,27 +951,27 @@ def test_build_metadata_pre_built_messages_no_tools(): @pytest.mark.unit -def test_load_sample_includes_pre_built_messages(): - """load_sample returns pre_built_messages with the complete message list.""" +def test_load_sample_includes_messages(): + """load_sample returns messages with the complete message list.""" df = _make_tool_sequence_df() ds = MultiTurnDataset(df) ds.load() s0 = ds.load_sample(0) # user turn 1 - assert "pre_built_messages" in s0 - msgs = s0["pre_built_messages"] + assert "messages" in s0 + msgs = s0["messages"] assert msgs[0]["role"] == "system" assert msgs[-1] == {"role": "user", "content": "What is the weather?"} s1 = ds.load_sample(1) # tool turn 3 assert s1["role"] == "tool" - msgs_t3 = s1["pre_built_messages"] + msgs_t3 = s1["messages"] # system + user(1) + asst_tc(2) + tool(3) = 4 messages assert len(msgs_t3) == 4 assert msgs_t3[-1]["role"] == "tool" s2 = ds.load_sample(2) # user turn 5 - msgs_t5 = s2["pre_built_messages"] + msgs_t5 = s2["messages"] # system + user(1) + asst_tc(2) + tool(3) + asst(4) + user(5) = 6 messages assert len(msgs_t5) == 6 @@ -1003,8 +993,8 @@ def test_client_turns_include_tool_rows(): @pytest.mark.unit -def test_pre_built_messages_include_prior_assistant_response(valid_multi_turn_jsonl): - """The terminal assistant response before each user turn is included in pre_built_messages.""" +def test_messages_include_prior_assistant_response(valid_multi_turn_jsonl): + """The terminal assistant response before each user turn is included in messages.""" dataset = MultiTurnDataset.load_from_file( valid_multi_turn_jsonl, format=DatasetFormat.JSONL ) @@ -1012,26 +1002,26 @@ def test_pre_built_messages_include_prior_assistant_response(valid_multi_turn_js # Sample 0: turn 1 
(first user) → just [system, user(1)] s0 = dataset.load_sample(0) - msgs_0 = s0["pre_built_messages"] + msgs_0 = s0["messages"] assert msgs_0[0]["role"] == "system" assert msgs_0[-1]["role"] == "user" # Sample 1: turn 3 (second user) → [system, user(1), assistant(2), user(3)] s1 = dataset.load_sample(1) - msgs_1 = s1["pre_built_messages"] + msgs_1 = s1["messages"] assert len(msgs_1) == 4 assert msgs_1[2] == {"role": "assistant", "content": "I'm doing well, thank you!"} assert msgs_1[3]["role"] == "user" # Sample 2: turn 1 of conv_002 → no prior assistant row s2 = dataset.load_sample(2) - msgs_2 = s2["pre_built_messages"] + msgs_2 = s2["messages"] assert all(m["role"] != "assistant" for m in msgs_2) @pytest.mark.unit -def test_pre_built_messages_no_cross_conversation_bleed(): - """Messages for conv_001 must not appear in conv_002's pre_built_messages.""" +def test_messages_no_cross_conversation_bleed(): + """Messages for conv_001 must not appear in conv_002's messages array.""" data = [ {"conversation_id": "c1", "turn": 1, "role": "user", "content": "c1 user"}, {"conversation_id": "c2", "turn": 1, "role": "user", "content": "c2 user"}, @@ -1049,25 +1039,25 @@ def test_pre_built_messages_no_cross_conversation_bleed(): # c1: only its own user message s_c1 = dataset.load_sample(0) - assert s_c1["pre_built_messages"] == [{"role": "user", "content": "c1 user"}] + assert s_c1["messages"] == [{"role": "user", "content": "c1 user"}] # c2: only c2 messages (no c1 content) s_c2 = dataset.load_sample(1) - contents = [m.get("content") for m in s_c2["pre_built_messages"]] + contents = [m.get("content") for m in s_c2["messages"]] assert "c1 user" not in contents finally: Path(temp_path).unlink() @pytest.mark.unit -def test_pre_built_messages_with_tool_sequence_terminal_assistant(): - """Terminal assistant response (turn 4) 
appears in messages for user(5).""" df = _make_tool_sequence_df() ds = MultiTurnDataset(df) ds.load() s2 = ds.load_sample(2) # user turn 5 - msgs = s2["pre_built_messages"] + msgs = s2["messages"] # The terminal assistant at turn 4 should be included assistant_msgs = [m for m in msgs if m["role"] == "assistant" and m.get("content")] assert any(m["content"] == "The weather is 22Β°C." for m in assistant_msgs) diff --git a/tests/unit/dataset_manager/test_transforms.py b/tests/unit/dataset_manager/test_transforms.py index ab342204..5eca41b4 100644 --- a/tests/unit/dataset_manager/test_transforms.py +++ b/tests/unit/dataset_manager/test_transforms.py @@ -23,6 +23,7 @@ import pandas as pd import pytest from inference_endpoint.dataset_manager.transforms import ( + AddDefaultColumns, AddStaticColumns, ColumnFilter, ColumnRemap, @@ -824,4 +825,51 @@ def test_no_matching_columns(self): # Should not raise error or create prompt column assert "prompt" not in result.columns - assert "unrelated" in result.columns + + +class TestAddDefaultColumns: + """Unit tests for AddDefaultColumns transform.""" + + @pytest.mark.unit + def test_fills_missing_columns(self): + """New columns are added when absent.""" + df = pd.DataFrame({"a": [1, 2]}) + result = AddDefaultColumns({"b": 10, "c": "x"})(df) + assert list(result["b"]) == [10, 10] + assert list(result["c"]) == ["x", "x"] + + @pytest.mark.unit + def test_preserves_existing_non_null_values(self): + """Existing non-null values are not overwritten.""" + df = pd.DataFrame({"a": [1, 2]}) + result = AddDefaultColumns({"a": 99})(df) + assert list(result["a"]) == [1, 2] + + @pytest.mark.unit + def test_fills_nan_values_in_existing_column(self): + """NaN cells in an existing column are replaced with the default.""" + + df = pd.DataFrame({"a": [1.0, float("nan"), 3.0]}) + result = AddDefaultColumns({"a": 99})(df) + assert result["a"].tolist()[0] == 1.0 + assert result["a"].tolist()[1] == 99 + assert result["a"].tolist()[2] == 3.0 + + 
@pytest.mark.unit + def test_skips_none_default_values(self): + """A None default value is ignored; the column is not modified.""" + df = pd.DataFrame({"a": [1]}) + original_a = df["a"].copy() + result = AddDefaultColumns({"a": None, "b": None})(df) + assert list(result["a"]) == list(original_a) + assert "b" not in result.columns + + @pytest.mark.unit + def test_mixed_nan_and_real_values(self): + """Only NaN cells are filled; real values in the same column are preserved.""" + + df = pd.DataFrame({"temp": [0.9, float("nan"), 0.5]}) + result = AddDefaultColumns({"temp": 0.7})(df) + assert result["temp"].tolist()[0] == pytest.approx(0.9) + assert result["temp"].tolist()[1] == pytest.approx(0.7) + assert result["temp"].tolist()[2] == pytest.approx(0.5) diff --git a/tests/unit/load_generator/test_multi_turn_conversation_manager.py b/tests/unit/load_generator/test_multi_turn_conversation_manager.py new file mode 100644 index 00000000..62602626 --- /dev/null +++ b/tests/unit/load_generator/test_multi_turn_conversation_manager.py @@ -0,0 +1,396 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import asyncio +import logging + +import pytest +from inference_endpoint.load_generator.conversation_manager import ( + ConversationManager, + ConversationState, +) + + +@pytest.mark.unit +def test_conversation_state_initialization(): + """Test ConversationState initializes with correct default values.""" + state = ConversationState(conversation_id="conv_001") + + assert state.conversation_id == "conv_001" + assert state.current_turn == 0 + assert state.pending_client_turn is None + + +@pytest.mark.unit +def test_conversation_state_add_client_turn(): + """Test adding a client turn updates sequencing state.""" + state = ConversationState(conversation_id="conv_001") + + state.add_client_turn(1) + + assert state.pending_client_turn == 1 + assert state.issued_client_turns == 1 + assert state.current_turn == 0 # Not incremented until assistant response + + +@pytest.mark.unit +def test_conversation_state_add_assistant_turn(): + """Test adding assistant turn completes turn cycle.""" + state = ConversationState(conversation_id="conv_001") + + state.add_client_turn(1) + state.add_assistant_turn() + + assert state.current_turn == 2 + assert state.pending_client_turn is None + assert state.completed_client_turns == 1 + + +@pytest.mark.unit +def test_conversation_state_late_response_after_complete_is_silently_ignored(caplog): + """Late response for a conversation that already completed is silently dropped.""" + state = ConversationState(conversation_id="conv_001", expected_client_turns=1) + + state.add_client_turn(1) + state.add_assistant_turn() + assert state.is_complete() + + completed_before = state.completed_client_turns + current_turn_before = state.current_turn + + with caplog.at_level(logging.WARNING): + state.add_assistant_turn() + + assert state.completed_client_turns == completed_before + assert state.current_turn == current_turn_before + assert "no pending client turn" not in caplog.text + + +@pytest.mark.unit +def test_conversation_state_is_ready_for_turn(): + 
"""Test turn readiness checks using completion counts.""" + state = ConversationState(conversation_id="conv_001") + + assert not state.is_ready_for_turn() + + state.add_client_turn(1) + assert not state.is_ready_for_turn() + + state.add_assistant_turn() + assert state.is_ready_for_turn() + + state.add_client_turn(2) + assert not state.is_ready_for_turn() + + state.add_assistant_turn() + assert state.is_ready_for_turn() + + +@pytest.mark.unit +def test_conversation_state_multi_turn_sequence(): + """Test multi-turn conversation flow updates current_turn correctly.""" + state = ConversationState(conversation_id="conv_001") + + state.add_client_turn(1) + state.add_assistant_turn() + assert state.current_turn == 2 + + state.add_client_turn(3) + state.add_assistant_turn() + assert state.current_turn == 4 + + state.add_client_turn(5) + state.add_assistant_turn() + assert state.current_turn == 6 + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_conversation_manager_get_or_create(): + """Test get_or_create returns same state for same conversation_id.""" + manager = ConversationManager() + + state1 = await manager.get_or_create("conv_001") + state2 = await manager.get_or_create("conv_001") + + assert state1 is state2 + assert state1.conversation_id == "conv_001" + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_conversation_manager_multiple_conversations(): + """Test manager can track multiple conversations independently.""" + manager = ConversationManager() + + state1 = await manager.get_or_create("conv_001") + state2 = await manager.get_or_create("conv_002") + + assert state1 is not state2 + + await manager.mark_turn_issued("conv_001", 1) + await manager.mark_turn_complete("conv_001", "Response to conv_001") + + assert state1.current_turn == 2 + assert state2.current_turn == 0 + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_conversation_manager_mark_turn_issued(): + """Test mark_turn_issued updates sequencing state.""" + manager = 
ConversationManager() + state = await manager.get_or_create("conv_001") + + await manager.mark_turn_issued("conv_001", 1) + + assert state.pending_client_turn == 1 + assert state.issued_client_turns == 1 + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_conversation_manager_mark_turn_complete(): + """Test mark_turn_complete updates sequencing state.""" + manager = ConversationManager() + state = await manager.get_or_create("conv_001") + + await manager.mark_turn_issued("conv_001", 1) + await manager.mark_turn_complete("conv_001", "Assistant response") + + assert state.current_turn == 2 + assert state.pending_client_turn is None + assert state.completed_client_turns == 1 + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_conversation_manager_wait_for_turn_ready_immediate(): + """Test wait_for_turn_ready returns immediately when previous turn is complete.""" + manager = ConversationManager() + await manager.get_or_create("conv_001") + + await manager.mark_turn_issued("conv_001", 1) + await manager.mark_turn_complete("conv_001", "First response") + + result = await manager.wait_for_turn_ready("conv_001", 9, timeout=1.0) + + assert result is True + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_conversation_manager_wait_for_turn_ready_blocking(): + """Test wait_for_turn_ready blocks until previous turn completes.""" + manager = ConversationManager() + await manager.get_or_create("conv_001") + + await manager.mark_turn_issued("conv_001", 1) + + ready_flag = [] + + async def waiter(): + result = await manager.wait_for_turn_ready("conv_001", 3, timeout=2.0) + if result: + ready_flag.append(True) + + waiter_task = asyncio.create_task(waiter()) + await asyncio.sleep(0.05) + assert not ready_flag + + await manager.mark_turn_complete("conv_001", "Assistant response") + await asyncio.sleep(0.05) + await waiter_task + + assert ready_flag == [True] + + +@pytest.mark.unit +@pytest.mark.asyncio +async def 
test_conversation_manager_wait_for_turn_ready_timeout(): + """Test wait_for_turn_ready respects timeout.""" + manager = ConversationManager() + await manager.get_or_create("conv_001") + + await manager.mark_turn_issued("conv_001", 1) + + result = await manager.wait_for_turn_ready("conv_001", 3, timeout=0.1) + + assert result is False + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_conversation_completion_tracking(): + """Test conversation completion detection.""" + manager = ConversationManager() + + state = await manager.get_or_create("conv_001", expected_client_turns=2) + + assert not state.is_complete() + + await manager.mark_turn_issued("conv_001", 1) + assert not state.is_complete() + + await manager.mark_turn_complete("conv_001", "response 1") + assert not state.is_complete() + + await manager.mark_turn_issued("conv_001", 3) + await manager.mark_turn_complete("conv_001", "response 2") + + assert state.is_complete() + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_conversation_completion_without_expected_turns(): + """Test that completion tracking works when expected_client_turns is None.""" + manager = ConversationManager() + + state = await manager.get_or_create("conv_001", expected_client_turns=None) + + assert not state.is_complete() + + await manager.mark_turn_issued("conv_001", 1) + await manager.mark_turn_complete("conv_001", "response 1") + + assert not state.is_complete() + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_conversation_completion_with_failures(): + """Test that conversations complete even when turns fail.""" + manager = ConversationManager() + state = await manager.get_or_create("conv1", expected_client_turns=3) + + await manager.mark_turn_issued("conv1", 1) + await manager.mark_turn_complete("conv1", "Hi there") + assert state.completed_client_turns == 1 + assert not state.is_complete() + + await manager.mark_turn_issued("conv1", 2) + await manager.mark_turn_failed("conv1") + assert 
state.completed_client_turns == 2 + assert state.failed_client_turns == 1 + assert not state.is_complete() + + await manager.mark_turn_issued("conv1", 3) + await manager.mark_turn_complete("conv1", "Bye!") + assert state.completed_client_turns == 3 + assert state.is_complete() + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_mark_turn_failed_with_no_pending(): + """Test that marking failed turn without pending turn logs warning.""" + manager = ConversationManager() + state = await manager.get_or_create("conv1", expected_client_turns=1) + + await manager.mark_turn_failed("conv1") + + assert state.completed_client_turns == 0 + assert state.failed_client_turns == 0 + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_all_turns_fail(): + """Test conversation completion when all turns fail.""" + manager = ConversationManager() + state = await manager.get_or_create("conv1", expected_client_turns=2) + + await manager.mark_turn_issued("conv1", 1) + await manager.mark_turn_failed("conv1") + + await manager.mark_turn_issued("conv1", 2) + await manager.mark_turn_failed("conv1") + + assert state.is_complete() + assert state.completed_client_turns == 2 + assert state.failed_client_turns == 2 + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_conversation_manager_concurrent_access(): + """Test async concurrent access to multiple conversations.""" + manager = ConversationManager() + num_conversations = 10 + user_turns_per_conv = 5 + + for i in range(num_conversations): + await manager.get_or_create(f"conv_{i:03d}") + + errors = [] + + async def process_conversation(conv_id: str): + try: + for user_turn_idx in range(user_turns_per_conv): + turn = user_turn_idx * 2 + 1 + + if user_turn_idx > 0: + ready = await manager.wait_for_turn_ready( + conv_id, turn, timeout=5.0 + ) + if not ready: + errors.append(f"{conv_id} turn {turn} timeout") + return + + await manager.mark_turn_issued(conv_id, turn) + await asyncio.sleep(0.001) + await 
manager.mark_turn_complete(conv_id, f"Response {turn}") + except Exception as e: + errors.append(f"{conv_id} error: {e}") + + tasks = [ + asyncio.create_task(process_conversation(f"conv_{i:03d}")) + for i in range(num_conversations) + ] + await asyncio.gather(*tasks) + + assert not errors, f"Errors occurred: {errors}" + + for i in range(num_conversations): + conv_id = f"conv_{i:03d}" + state = manager._conversations[conv_id] + assert state.current_turn == user_turns_per_conv * 2 + assert state.completed_client_turns == user_turns_per_conv + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_conversation_manager_wait_for_turn_ready_reliably_wakes_on_completion(): + """Test completion wakeups do not depend on timing windows.""" + + async def run_one_iteration(): + mgr = ConversationManager() + await mgr.get_or_create("conv_001") + await mgr.mark_turn_issued("conv_001", 1) + + ready: list[bool] = [] + + async def waiter(m: ConversationManager, r: list) -> None: + r.append(await m.wait_for_turn_ready("conv_001", 3, timeout=0.5)) + + waiter_task = asyncio.create_task(waiter(mgr, ready)) + await asyncio.sleep(0.005) + await mgr.mark_turn_complete("conv_001", "Assistant response") + await asyncio.wait_for(waiter_task, timeout=0.5) + assert ready == [True] + + for _ in range(10): + await run_one_iteration() diff --git a/tests/unit/load_generator/test_multi_turn_strategy.py b/tests/unit/load_generator/test_multi_turn_strategy.py new file mode 100644 index 00000000..55c51994 --- /dev/null +++ b/tests/unit/load_generator/test_multi_turn_strategy.py @@ -0,0 +1,279 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for MultiTurnStrategy.""" + +import asyncio + +import pytest +from inference_endpoint.core.types import QueryResult, TextModelOutput +from inference_endpoint.load_generator.conversation_manager import ConversationManager +from inference_endpoint.load_generator.multi_turn_strategy import MultiTurnStrategy + + +class FakePhaseIssuer: + """Minimal PhaseIssuerProtocol stub.""" + + def __init__(self, stop_after: int | None = None): + self._count = 0 + self._stop_after = stop_after + self.issued: list[int] = [] + self.issued_count = 0 + + def issue(self, sample_index: int, data_override: dict | None = None) -> str | None: + if self._stop_after is not None and self._count >= self._stop_after: + return None + self._count += 1 + self.issued_count += 1 + query_id = f"q{sample_index:04d}" + self.issued.append(sample_index) + return query_id + + +def _make_dataset_metadata(conversations: dict[str, list[int]]) -> dict: + """Build dataset_metadata dict from {conv_id: [turn_numbers]} mapping.""" + samples = [] + sample_index = 0 + for conv_id, turns in conversations.items(): + for turn in turns: + samples.append( + { + "conversation_id": conv_id, + "turn": turn, + "sample_index": sample_index, + } + ) + sample_index += 1 + return {"samples": samples} + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_single_conversation_single_turn(): + """Single conversation, single turn — should issue exactly one sample.""" + conv_manager = ConversationManager() + metadata = _make_dataset_metadata({"conv1": [1]}) + strategy = MultiTurnStrategy(conv_manager, 
metadata) + issuer = FakePhaseIssuer() + + # Simulate response completion (turn 1 is issued, then completes) + async def complete_turns(): + # Wait a tick for the strategy to issue the first turn + await asyncio.sleep(0.01) + # Mark turn 1 complete + state = conv_manager.get_state("conv1") + if state: + await conv_manager.mark_turn_complete("conv1", "response 1") + + asyncio.create_task(complete_turns()) + count = await strategy.execute(issuer) + + assert count == 1 + assert issuer.issued == [0] + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_single_conversation_multi_turn(): + """Single conversation, 3 turns — turns must be issued sequentially.""" + conv_manager = ConversationManager() + metadata = _make_dataset_metadata({"conv1": [1, 3, 5]}) + strategy = MultiTurnStrategy(conv_manager, metadata) + issuer = FakePhaseIssuer() + + issued_order: list[str] = [] + original_issue = issuer.issue + + def tracked_issue(idx, data_override=None): + q = original_issue(idx, data_override=data_override) + if q: + issued_order.append(q) + return q + + issuer.issue = tracked_issue + + async def simulate_responses(): + await asyncio.sleep(0.01) + for turn_q, resp in [("q0000", "r1"), ("q0001", "r2"), ("q0002", "r3")]: + # Signal turn complete via on_sample_complete + result = QueryResult( + id=turn_q, response_output=TextModelOutput(output=resp) + ) + strategy.on_sample_complete(result) + await asyncio.sleep(0.01) + + asyncio.create_task(simulate_responses()) + count = await strategy.execute(issuer) + + assert count == 3 + assert issuer.issued == [0, 1, 2] + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_multiple_conversations_concurrent(): + """Two conversations run concurrently, each with 2 turns.""" + conv_manager = ConversationManager() + metadata = _make_dataset_metadata({"conv1": [1, 3], "conv2": [1, 3]}) + strategy = MultiTurnStrategy(conv_manager, metadata) + issuer = FakePhaseIssuer() + + async def simulate_responses(): + await 
asyncio.sleep(0.02) + # Complete all turns for both conversations + for q_prefix in range(4): + q = f"q{q_prefix:04d}" + result = QueryResult(id=q, response_output=TextModelOutput(output="resp")) + strategy.on_sample_complete(result) + await asyncio.sleep(0.01) + + asyncio.create_task(simulate_responses()) + count = await strategy.execute(issuer) + + assert count == 4 + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_turn_ordering_enforced(): + """Turn 2 must not be issued before Turn 1 completes.""" + conv_manager = ConversationManager() + metadata = _make_dataset_metadata({"conv1": [1, 3]}) + strategy = MultiTurnStrategy(conv_manager, metadata) + + issue_timestamps: dict[int, float] = {} + complete_timestamps: dict[int, float] = {} + + class TimedIssuer: + issued_count = 0 + issued: list[int] = [] + + def issue(self, idx: int, data_override: dict | None = None) -> str | None: + import time + + issue_timestamps[idx] = time.monotonic() + self.issued.append(idx) + self.issued_count += 1 + return f"q{idx:04d}" + + issuer = TimedIssuer() + + async def simulate_responses(): + import time + + await asyncio.sleep(0.02) + # Complete turn 1 (sample 0) after a delay + complete_timestamps[0] = time.monotonic() + result = QueryResult(id="q0000", response_output=TextModelOutput(output="r1")) + strategy.on_sample_complete(result) + await asyncio.sleep(0.05) + # Complete turn 2 (sample 1) + complete_timestamps[1] = time.monotonic() + result = QueryResult(id="q0001", response_output=TextModelOutput(output="r2")) + strategy.on_sample_complete(result) + + asyncio.create_task(simulate_responses()) + count = await strategy.execute(issuer) + + assert count == 2 + # Turn 2 (sample index 1) must be issued AFTER turn 1 completes + assert issue_timestamps[1] >= complete_timestamps[0] + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_turn_timeout_triggers_failure(): + """A turn that never completes should timeout and abort remaining turns.""" + conv_manager = 
ConversationManager() + metadata = _make_dataset_metadata({"conv1": [1, 3]}) + strategy = MultiTurnStrategy(conv_manager, metadata, target_concurrency=None) + strategy._turn_timeout_s = 0.1 # Very short timeout for testing + issuer = FakePhaseIssuer() + + # Do NOT simulate any response β€” turn 1 will timeout + await strategy.execute(issuer) + + # Only turn 1 should be issued (turn 2 never gets to run) + assert issuer.issued_count == 1 + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_on_query_complete_releases_semaphore(): + """on_query_complete releases the concurrency semaphore.""" + conv_manager = ConversationManager() + metadata = _make_dataset_metadata({"conv1": [1]}) + strategy = MultiTurnStrategy(conv_manager, metadata, target_concurrency=1) + assert strategy._sem is not None + + # Acquire the semaphore manually + await strategy._sem.acquire() + assert strategy._sem._value == 0 # type: ignore[attr-defined] + + strategy.on_query_complete("some-query") + assert strategy._sem._value == 1 # type: ignore[attr-defined] + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_on_sample_complete_routes_to_manager(): + """on_sample_complete marks the turn complete in the ConversationManager.""" + conv_manager = ConversationManager() + await conv_manager.get_or_create("conv1", expected_client_turns=1) + metadata = _make_dataset_metadata({"conv1": [1]}) + strategy = MultiTurnStrategy(conv_manager, metadata) + + # Simulate issuer registering conv_id in _inflight + strategy._inflight["q0001"] = "conv1" + # Pre-issue a turn so the state has pending_client_turn + await conv_manager.mark_turn_issued("conv1", 1) + + result = QueryResult(id="q0001", response_output=TextModelOutput(output="hello")) + strategy.on_sample_complete(result) + + # Allow the ensure_future coroutine to run + await asyncio.sleep(0.01) + + state = conv_manager.get_state("conv1") + assert state is not None + assert state.completed_client_turns == 1 + assert state.is_complete() + + 
+@pytest.mark.unit +@pytest.mark.asyncio +async def test_error_response_marks_turn_failed(): + """on_sample_complete marks failed when result.error is set.""" + from inference_endpoint.core.types import ErrorData + + conv_manager = ConversationManager() + await conv_manager.get_or_create("conv1", expected_client_turns=1) + metadata = _make_dataset_metadata({"conv1": [1]}) + strategy = MultiTurnStrategy(conv_manager, metadata) + + strategy._inflight["q0001"] = "conv1" + await conv_manager.mark_turn_issued("conv1", 1) + + result = QueryResult( + id="q0001", + response_output=None, + error=ErrorData(error_type="timeout", error_message="timed out"), + ) + strategy.on_sample_complete(result) + await asyncio.sleep(0.01) + + state = conv_manager.get_state("conv1") + assert state is not None + assert state.failed_client_turns == 1 From 057600b03e63f0056e50a34f4de529aa4bddd821 Mon Sep 17 00:00:00 2001 From: "Li, Tianmu" Date: Fri, 24 Apr 2026 10:32:39 -0700 Subject: [PATCH 03/13] test: add multi-turn unit and integration tests Add unit tests for MultiTurnDataset, ConversationManager, and MultiTurnStrategy; add integration tests including tool-use scenarios and large-concurrency stress tests. --- tests/integration/test_multi_turn.py | 309 ++++++++++++++++++++++++++- 1 file changed, 308 insertions(+), 1 deletion(-) diff --git a/tests/integration/test_multi_turn.py b/tests/integration/test_multi_turn.py index 5a4d128d..ca17a236 100644 --- a/tests/integration/test_multi_turn.py +++ b/tests/integration/test_multi_turn.py @@ -16,7 +16,8 @@ """Integration tests for multi-turn benchmarking end-to-end. Validates that MultiTurnDataset + MultiTurnStrategy + BenchmarkSession work -correctly together against a real HTTP echo server. +correctly together against a real HTTP echo server (echo tests) and a live +model endpoint (live tests at port 8868). Tests cover: 1. Dataset-history mode (use_dataset_history=True): pre-built messages are @@ -26,12 +27,16 @@ grow with each turn. 3. 
Multiple concurrent conversations complete successfully. 4. Turn ordering: turn N+1 is never issued before turn N completes. + 5. Live concurrency: parametrized target_concurrency levels against a real + model endpoint verify all turns complete regardless of throttle setting. """ import asyncio +import json import random import time from urllib.parse import urljoin +from urllib.request import urlopen import pandas as pd import pytest @@ -423,3 +428,305 @@ def on_complete(result: QueryResult) -> None: # Turn 3 must complete after turn 1 completes assert complete_times[q_turn3] >= complete_times[q_turn1] + + +# --------------------------------------------------------------------------- +# Live endpoint fixtures and helpers +# --------------------------------------------------------------------------- + +_LIVE_ENDPOINT = "http://localhost:8868" + + +def _query_model_name(endpoint: str) -> str: + """Return the first model name from the endpoint, or skip if unreachable.""" + try: + with urlopen(f"{endpoint}/v1/models", timeout=5.0) as resp: + data = json.loads(resp.read()) + return data["data"][0]["id"] + except Exception as e: + pytest.skip(f"Live endpoint {endpoint} not reachable: {e}") + return "" + + +def _make_live_rows( + model: str, n_conversations: int = 20, n_user_turns: int = 3 +) -> list[dict]: + """Build a multi-conversation dataset rows list. + + Each conversation has n_user_turns user turns interleaved with scripted + assistant placeholders (needed to satisfy the turn-structure validator but + never sent to the endpoint). The resulting dataset produces + n_conversations Γ— n_user_turns client-turn samples. + """ + rows = [] + _user_prompts = [ + "Reply with exactly one word: the number {n} in English.", + "Add one to the previous number. Reply with only that word.", + "Add one more. 
Reply with only that word.", + ] + for i in range(n_conversations): + conv_id = f"live_conv_{i:03d}" + turn = 1 + for j in range(n_user_turns): + prompt = _user_prompts[j % len(_user_prompts)].format(n=i + 1) + rows.append( + { + "conversation_id": conv_id, + "turn": turn, + "role": "user", + "content": prompt, + "model": model, + "max_completion_tokens": 10, + } + ) + turn += 1 + if j < n_user_turns - 1: + rows.append( + { + "conversation_id": conv_id, + "turn": turn, + "role": "assistant", + "content": "placeholder", + } + ) + turn += 1 + return rows + + +async def _run_live_session( + model: str, + n_conversations: int, + n_user_turns: int, + target_concurrency: int | None, + timeout_s: float = 300.0, +) -> tuple[int, dict[str, str]]: + """Run a live multi-turn session against the endpoint at _LIVE_ENDPOINT. + + Returns (issued_count, {query_id: response_text}). + """ + rows = _make_live_rows(model, n_conversations, n_user_turns) + ds = MultiTurnDataset(dataframe=pd.DataFrame(rows)) + ds.load() + + mt_cfg = MultiTurnConfig( + turn_timeout_s=60.0, + use_dataset_history=True, + ) + strategy = MultiTurnStrategy( + conversation_manager=ConversationManager(), + dataset_metadata=ds.conversation_metadata, + multi_turn_config=mt_cfg, + target_concurrency=target_concurrency, + ) + + loop = asyncio.get_running_loop() + responses: dict[str, str] = {} + + def on_complete(result: QueryResult) -> None: + strategy.on_sample_complete(result) + responses[result.id] = result.get_response_output_string() + + http_config = HTTPClientConfig( + endpoint_urls=[f"{_LIVE_ENDPOINT}/v1/chat/completions"], + warmup_connections=0, + num_workers=4, + ) + http_client = await HTTPEndpointClient.create(http_config, loop) + issuer = HttpClientSampleIssuer(http_client) + + try: + session = BenchmarkSession( + issuer=issuer, + event_publisher=_NoOpPublisher(), + loop=loop, + on_sample_complete=on_complete, + ) + rt = RuntimeSettings( + metrics.Throughput(1000), + [metrics.Throughput(1000)], + 
min_duration_ms=0, + max_duration_ms=int(timeout_s * 1000), + n_samples_from_dataset=ds.num_samples(), + n_samples_to_issue=ds.num_samples(), + min_sample_count=1, + rng_sched=random.Random(42), + rng_sample_index=random.Random(42), + load_pattern=LoadPattern(type=LoadPatternType.MAX_THROUGHPUT), + ) + phase = PhaseConfig("perf", rt, ds, PhaseType.PERFORMANCE, strategy=strategy) + result = await asyncio.wait_for(session.run([phase]), timeout=timeout_s) + return result.perf_results[0].issued_count, responses + finally: + await http_client.shutdown_async() + + +# --------------------------------------------------------------------------- +# Live concurrency tests +# --------------------------------------------------------------------------- + + +@pytest.mark.integration +@pytest.mark.asyncio +@pytest.mark.parametrize( + "target_concurrency", + [ + pytest.param(1, id="concurrency_1"), + pytest.param(4, id="concurrency_4"), + pytest.param(None, id="concurrency_unlimited"), + ], +) +async def test_live_concurrency(target_concurrency): + """All turns of 20 concurrent conversations complete for each concurrency level. + + Uses the live model endpoint at port 8868. Each conversation has 3 user + turns (60 total requests). Verifies that every turn receives a non-empty + response regardless of the concurrency throttle applied by target_concurrency. 
+ """ + model = _query_model_name(_LIVE_ENDPOINT) + n_conversations = 20 + n_user_turns = 3 + expected_turns = n_conversations * n_user_turns # 60 total requests + + issued, responses = await _run_live_session( + model=model, + n_conversations=n_conversations, + n_user_turns=n_user_turns, + target_concurrency=target_concurrency, + timeout_s=300.0, + ) + + assert issued == expected_turns, f"Expected {expected_turns} issued, got {issued}" + assert ( + len(responses) == expected_turns + ), f"Expected {expected_turns} responses, got {len(responses)}" + for qid, text in responses.items(): + assert text.strip(), f"Query {qid} returned empty response" + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_live_turn_ordering_multi_conversation(): + """Turn N+1 of each conversation is always issued after turn N completes. + + Runs 10 conversations with 3 turns each concurrently (30 total requests). + Records per-query completion timestamps and asserts that within every + conversation each successive turn completes no earlier than the previous. 
+ """ + model = _query_model_name(_LIVE_ENDPOINT) + n_conversations = 10 + n_user_turns = 3 + rows = _make_live_rows(model, n_conversations, n_user_turns) + + ds = MultiTurnDataset(dataframe=pd.DataFrame(rows)) + ds.load() + + conv_manager = ConversationManager() + mt_cfg = MultiTurnConfig(turn_timeout_s=60.0, use_dataset_history=True) + strategy = MultiTurnStrategy( + conversation_manager=conv_manager, + dataset_metadata=ds.conversation_metadata, + multi_turn_config=mt_cfg, + ) + + complete_times: dict[str, float] = {} + orig_on_sample_complete = strategy.on_sample_complete + + def tracked_complete(result: QueryResult) -> None: + complete_times[result.id] = time.monotonic() + orig_on_sample_complete(result) + + strategy.on_sample_complete = tracked_complete + + loop = asyncio.get_running_loop() + responses: dict[str, str] = {} + + http_config = HTTPClientConfig( + endpoint_urls=[f"{_LIVE_ENDPOINT}/v1/chat/completions"], + warmup_connections=0, + num_workers=4, + ) + http_client = await HTTPEndpointClient.create(http_config, loop) + issuer = HttpClientSampleIssuer(http_client) + + try: + + def on_complete(result: QueryResult) -> None: + tracked_complete(result) + responses[result.id] = result.get_response_output_string() + + session = BenchmarkSession( + issuer=issuer, + event_publisher=_NoOpPublisher(), + loop=loop, + on_sample_complete=on_complete, + ) + rt = RuntimeSettings( + metrics.Throughput(1000), + [metrics.Throughput(1000)], + min_duration_ms=0, + max_duration_ms=300_000, + n_samples_from_dataset=ds.num_samples(), + n_samples_to_issue=ds.num_samples(), + min_sample_count=1, + rng_sched=random.Random(42), + rng_sample_index=random.Random(42), + load_pattern=LoadPattern(type=LoadPatternType.MAX_THROUGHPUT), + ) + phase = PhaseConfig("perf", rt, ds, PhaseType.PERFORMANCE, strategy=strategy) + result = await asyncio.wait_for(session.run([phase]), timeout=300.0) + finally: + await http_client.shutdown_async() + + expected_total = n_conversations * n_user_turns 
+ assert result.perf_results[0].issued_count == expected_total + + # Build index β†’ query_id map and verify per-conversation ordering. + # Samples are grouped by conversation, turns sorted ascending within each: + # conv_0_t1, conv_0_t2, conv_0_t3, conv_1_t1, ... + uuid_to_index = result.perf_results[0].uuid_to_index + index_to_query = {v: k for k, v in uuid_to_index.items()} + + for conv_i in range(n_conversations): + base = conv_i * n_user_turns + for turn_j in range(n_user_turns - 1): + q_cur = index_to_query[base + turn_j] + q_next = index_to_query[base + turn_j + 1] + assert complete_times[q_cur] <= complete_times[q_next], ( + f"conv {conv_i}: turn {turn_j + 2} completed before turn {turn_j + 1} " + f"(t{turn_j + 1}={complete_times[q_cur]:.4f}, " + f"t{turn_j + 2}={complete_times[q_next]:.4f})" + ) + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_live_large_concurrency(): + """All turns complete correctly under a large concurrency limit (>=512). + + Uses 200 conversations Γ— 3 turns = 600 total requests with + target_concurrency=512. The semaphore allows up to 512 simultaneous + in-flight requests, so the first wave of 200 first-turns is issued + without throttling, and subsequent turns queue naturally. Verifies + that all 600 turns complete and return non-empty responses, confirming + the semaphore implementation handles large values without deadlock or + starvation. 
+ """ + model = _query_model_name(_LIVE_ENDPOINT) + n_conversations = 200 + n_user_turns = 3 + expected_turns = n_conversations * n_user_turns # 600 total requests + + issued, responses = await _run_live_session( + model=model, + n_conversations=n_conversations, + n_user_turns=n_user_turns, + target_concurrency=512, + timeout_s=300.0, + ) + + assert issued == expected_turns, f"Expected {expected_turns} issued, got {issued}" + assert ( + len(responses) == expected_turns + ), f"Expected {expected_turns} responses, got {len(responses)}" + for qid, text in responses.items(): + assert text.strip(), f"Query {qid} returned empty response" From 109434d7c97a088d219941f3ef617f0880e26438 Mon Sep 17 00:00:00 2001 From: "Li, Tianmu" Date: Fri, 24 Apr 2026 11:20:42 -0700 Subject: [PATCH 04/13] feat: wire multi-turn into benchmark execution pipeline Consolidate multi-turn dataset with single-turn transform pipeline, fix prior-row extraction, live-history mode, system prompt injection, tool_calls preservation, and asyncio.Event-based sequencing. 
--- .../09_MultiTurn/multi_turn_benchmark.yaml | 9 +- .../multi_turn_with_concurrency.yaml | 7 - .../commands/benchmark/execute.py | 1 + .../config/runtime_settings.py | 11 + src/inference_endpoint/config/schema.py | 19 + .../dataset_manager/multi_turn_dataset.py | 59 ++- .../load_generator/conversation_manager.py | 348 +++---------- .../load_generator/multi_turn_strategy.py | 96 ++-- .../openai/openai_adapter.py | 1 + tests/integration/test_multi_turn.py | 474 +++++++----------- tests/unit/config/test_schema.py | 120 +++++ .../test_multi_turn_dataset.py | 363 ++++++++++++++ .../test_multi_turn_conversation_manager.py | 386 +++++--------- .../test_multi_turn_strategy.py | 125 ++++- tests/unit/openai/test_openai_adapter.py | 147 ++++++ 15 files changed, 1282 insertions(+), 884 deletions(-) create mode 100644 tests/unit/openai/test_openai_adapter.py diff --git a/examples/09_MultiTurn/multi_turn_benchmark.yaml b/examples/09_MultiTurn/multi_turn_benchmark.yaml index 9ed6c9f1..da4773e0 100644 --- a/examples/09_MultiTurn/multi_turn_benchmark.yaml +++ b/examples/09_MultiTurn/multi_turn_benchmark.yaml @@ -24,18 +24,11 @@ settings: load_pattern: type: multi_turn - # target_concurrency: 32 # Optional: limit concurrent requests across all conversations + target_concurrency: 32 client: warmup_connections: 0 -metrics: - collect: - - throughput - - latency - - ttft - - tpot - endpoint_config: endpoints: - "http://localhost:8868" diff --git a/examples/09_MultiTurn/multi_turn_with_concurrency.yaml b/examples/09_MultiTurn/multi_turn_with_concurrency.yaml index 491e6b4b..ba5362e3 100644 --- a/examples/09_MultiTurn/multi_turn_with_concurrency.yaml +++ b/examples/09_MultiTurn/multi_turn_with_concurrency.yaml @@ -29,13 +29,6 @@ settings: client: warmup_connections: 0 -metrics: - collect: - - throughput - - latency - - ttft - - tpot - endpoint_config: endpoints: - "http://localhost:8868" diff --git a/src/inference_endpoint/commands/benchmark/execute.py 
b/src/inference_endpoint/commands/benchmark/execute.py index e107e611..30411af0 100644 --- a/src/inference_endpoint/commands/benchmark/execute.py +++ b/src/inference_endpoint/commands/benchmark/execute.py @@ -542,6 +542,7 @@ async def _run_benchmark_async( conversation_manager=ConversationManager(), dataset_metadata=ctx.dataloader.conversation_metadata, multi_turn_config=mt_cfg, + target_concurrency=ctx.config.settings.load_pattern.target_concurrency, ) def _on_sample_complete(result: QueryResult) -> None: diff --git a/src/inference_endpoint/config/runtime_settings.py b/src/inference_endpoint/config/runtime_settings.py index fb349a02..a3fb3106 100644 --- a/src/inference_endpoint/config/runtime_settings.py +++ b/src/inference_endpoint/config/runtime_settings.py @@ -194,6 +194,17 @@ def total_samples_to_issue( ) return self.n_samples_to_issue + # Multi-turn must issue exactly all client turns β€” QPS-based formulas are meaningless. + if ( + self.load_pattern is not None + and self.load_pattern.type.value == "multi_turn" + ): + result = max(self.min_sample_count, self.n_samples_from_dataset) + logger.debug( + f"Sample count: {result} (multi-turn: issuing all {self.n_samples_from_dataset} client turns)" + ) + return result + # If min_duration is 0, use all dataset samples (new CLI default behavior) if self.min_duration_ms == 0: result = max(self.min_sample_count, self.n_samples_from_dataset) diff --git a/src/inference_endpoint/config/schema.py b/src/inference_endpoint/config/schema.py index 5b122c46..f34f82d5 100644 --- a/src/inference_endpoint/config/schema.py +++ b/src/inference_endpoint/config/schema.py @@ -419,6 +419,12 @@ def _validate_completeness(self) -> Self: raise ValueError( "Concurrency requires --concurrency (e.g., --concurrency 10)" ) + if self.type == LoadPatternType.MULTI_TURN and ( + not self.target_concurrency or self.target_concurrency <= 0 + ): + raise ValueError( + "Multi-turn requires --concurrency (e.g., --concurrency 96)" + ) return self @@ 
-623,6 +629,19 @@ def _resolve_and_validate(self) -> Self: "Online mode requires --load-pattern (poisson, concurrency, or multi_turn)" ) + # Cross-validate load_pattern.type=multi_turn ↔ dataset.multi_turn config + has_multi_turn_dataset = any( + d.multi_turn is not None for d in (self.datasets or []) + ) + if lp.type == LoadPatternType.MULTI_TURN and not has_multi_turn_dataset: + raise ValueError( + "load_pattern.type=multi_turn requires at least one dataset with multi_turn config" + ) + if has_multi_turn_dataset and lp.type != LoadPatternType.MULTI_TURN: + raise ValueError( + f"Datasets with multi_turn config require load_pattern.type=multi_turn, got '{lp.type}'" + ) + return self @model_validator(mode="after") diff --git a/src/inference_endpoint/dataset_manager/multi_turn_dataset.py b/src/inference_endpoint/dataset_manager/multi_turn_dataset.py index f26c3d3e..574619c8 100644 --- a/src/inference_endpoint/dataset_manager/multi_turn_dataset.py +++ b/src/inference_endpoint/dataset_manager/multi_turn_dataset.py @@ -195,6 +195,8 @@ def _build_metadata(self) -> dict[str, Any]: # This includes assistant rows (tool dispatches or terminal responses) # so no runtime injection is required. 
pre_built_messages_by_key: dict[tuple, list[dict]] = {} + current_turn_messages_by_key: dict[tuple, list[dict]] = {} + system_prompts_by_conv: dict[str, str | None] = {} for conv_id, group in self.dataframe.groupby("conversation_id"): sorted_group = group.sort_values("turn") @@ -207,6 +209,7 @@ def _build_metadata(self) -> dict[str, Any]: if val and isinstance(val, str): system_content = val break + system_prompts_by_conv[str(conv_id)] = system_content for idx, row in client_rows.iterrows(): t_n = int(row["turn"]) @@ -220,12 +223,18 @@ def _build_metadata(self) -> dict[str, Any]: prior_rows = sorted_group[sorted_group["turn"] < t_n] for _, prior_row in prior_rows.iterrows(): msg: dict[str, Any] = {} - for key in ("role", "content", "tool_calls"): + for key in ("role", "content", "tool_calls", "tool_results"): val = prior_row.get(key) if val is not None and not ( isinstance(val, float) and pd.isna(val) ): msg[key] = val + if ( + msg.get("role") == "assistant" + and "tool_calls" in msg + and "content" not in msg + ): + msg["content"] = None if msg.get("role"): # Expand merged parallel tool results: a single row with # tool_results: [{tool_call_id, content}, ...] expands into @@ -239,9 +248,10 @@ def _build_metadata(self) -> dict[str, Any]: # Append the current client turn message. # A merged parallel-tool row carries tool_results instead of a # single tool_call_id/content pair; expand to one message per result. 
+ current_turn_msgs: list[dict] = [] expanded = _expand_tool_results(row) if expanded: - messages.extend(expanded) + current_turn_msgs = expanded else: cur: dict[str, Any] = {} for key in ("role", "content"): @@ -250,9 +260,11 @@ def _build_metadata(self) -> dict[str, Any]: isinstance(val, float) and pd.isna(val) ): cur[key] = val - messages.append(cur) + current_turn_msgs = [cur] + messages.extend(current_turn_msgs) pre_built_messages_by_key[(conv_id, t_n)] = messages + current_turn_messages_by_key[(conv_id, t_n)] = current_turn_msgs samples.append( { @@ -270,6 +282,8 @@ def _build_metadata(self) -> dict[str, Any]: .max(), "client_turns_per_conversation": client_turns_per_conv, "pre_built_messages_by_key": pre_built_messages_by_key, + "current_turn_messages_by_key": current_turn_messages_by_key, + "system_prompts_by_conv": system_prompts_by_conv, } def load( @@ -287,8 +301,8 @@ def load( per-row dataset overrides are preserved. After transforms, only client turns (user + tool) are stored in self.data as - fully assembled sample dicts (with messages, current_turn_message, system_content - attached). load_sample() and num_samples() are inherited from the base class. + fully assembled sample dicts (with messages attached). + load_sample() and num_samples() are inherited from the base class. """ if not force and self.data is not None: return @@ -325,6 +339,26 @@ def load( pre_built = self.conversation_metadata.get("pre_built_messages_by_key", {}) client_turn_samples: list[dict[str, Any]] = [] + # Collect per-conversation defaults from the first user row so that + # fields like model/max_completion_tokens propagate to tool rows. 
+ _PROPAGATED_KEYS = { + "model", + "max_completion_tokens", + "max_new_tokens", + "stream", + } + conv_defaults: dict[str, dict[str, Any]] = {} + for row in all_rows: + cid = row.get("conversation_id") + if cid not in conv_defaults and row.get("role") == "user": + conv_defaults[cid] = { + k: row[k] + for k in _PROPAGATED_KEYS + if k in row + and row[k] is not None + and not (isinstance(row[k], float) and pd.isna(row[k])) + } + for row in all_rows: if row.get("role") not in ("user", "tool"): continue @@ -336,6 +370,14 @@ def load( for k, v in row.items() if v is not None and not (isinstance(v, float) and pd.isna(v)) } + # Strip dataset-internal fields that must not reach the endpoint. + sample.pop("tool_results", None) + sample.pop("tool_calls", None) + + # Fill missing propagated fields from the first user row of this conversation. + for k, v in conv_defaults.get(row.get("conversation_id"), {}).items(): + if k not in sample: + sample[k] = v # max_new_tokens β†’ max_completion_tokens alias if "max_completion_tokens" not in sample and "max_new_tokens" in sample: @@ -350,13 +392,6 @@ def load( messages = pre_built.get(key, []) sample["messages"] = messages - # Fields for use_dataset_history=False path (live history accumulation). - sample["current_turn_message"] = messages[-1] if messages else {} - first = messages[0] if messages else {} - sample["system_content"] = ( - first.get("content") if first.get("role") == "system" else None - ) - client_turn_samples.append(sample) self.data = client_turn_samples diff --git a/src/inference_endpoint/load_generator/conversation_manager.py b/src/inference_endpoint/load_generator/conversation_manager.py index 86741711..ba9a02ea 100644 --- a/src/inference_endpoint/load_generator/conversation_manager.py +++ b/src/inference_endpoint/load_generator/conversation_manager.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-"""Async conversation state management for multi-turn benchmarking.""" +"""Conversation state management for multi-turn benchmarking.""" import asyncio import logging @@ -25,332 +25,150 @@ @dataclass class ConversationState: - """Tracks conversation sequencing for multi-turn benchmarking. + """Per-conversation state for multi-turn benchmarking. - Maintains turn counters and asyncio conditions so the strategy can enforce - sequential turn ordering within a conversation. Message history is NOT stored - here β€” it is pre-computed in MultiTurnDataset and served via load_sample(). + The pipeline task awaits ``turn_done`` between turns; ``mark_turn_complete`` + and ``mark_turn_failed`` set it synchronously from ``on_sample_complete``. Attributes: conversation_id: Unique identifier for this conversation. - current_turn: Last completed turn number (0 = not started). - pending_client_turn: Turn number of in-flight client turn (None if idle). - expected_client_turns: Expected number of client turns (for completion tracking). - issued_client_turns: Count of client turns issued. - completed_client_turns: Count of client turns with responses. - failed_client_turns: Count of client turns that failed (error/timeout). - message_history: Accumulated message list (only populated when + turn_done: Event set when a response arrives. Pipeline waits, then clears + it before issuing the next turn. + message_history: Accumulated message list (populated only when use_dataset_history=False; empty otherwise). - condition: Per-conversation asyncio.Condition for turn-ready and turn-issued waits. - Scoped to this conversation so that state changes only wake the single - pipeline task waiting on this conversation, not all pipeline tasks. + completed_turns: Turns with responses (success or failure) β€” observability only. + failed_turns: Turns that failed β€” observability only. + expected_client_turns: Expected total client turns (for completion detection). 
""" conversation_id: str - current_turn: int = 0 - pending_client_turn: int | None = None - expected_client_turns: int | None = None - issued_client_turns: int = 0 - completed_client_turns: int = 0 - failed_client_turns: int = 0 + turn_done: asyncio.Event = field(default_factory=asyncio.Event) message_history: list[dict[str, Any]] = field(default_factory=list) - condition: asyncio.Condition = field(default_factory=asyncio.Condition) - - def add_client_turn(self, turn: int, message: dict[str, Any] | None = None): - """Record that a client turn has been issued (updates sequencing counters). - - Args: - turn: Turn number for this client message. - message: Message dict to append to message_history (only used when - use_dataset_history=False). - """ - self.pending_client_turn = turn - self.issued_client_turns += 1 - if message is not None: - self.message_history.append(message) - - def add_assistant_turn(self, content: str | None = None): - """Record assistant response and mark turn complete (success). - - Args: - content: Response content to append to message_history. Only - used when use_dataset_history=False; None means no history - update (pre-built messages path). 
- """ - if content is not None: - self.message_history.append({"role": "assistant", "content": content}) - if self.pending_client_turn is not None: - self.current_turn = self.pending_client_turn + 1 - self.pending_client_turn = None - self.completed_client_turns += 1 - elif self.is_complete(): - pass - else: - logger.warning( - f"Received assistant response for {self.conversation_id} " - f"with no pending client turn (duplicate or out-of-order response)" - ) - self.current_turn = self.current_turn + 1 if self.current_turn > 0 else 1 - self.completed_client_turns += 1 - - if self.is_complete(): - if self.failed_client_turns > 0: - logger.info( - f"Conversation {self.conversation_id} completed with failures: " - f"{self.completed_client_turns - self.failed_client_turns}/" - f"{self.expected_client_turns} successful, " - f"{self.failed_client_turns} failed" - ) - else: - logger.debug( - f"Conversation {self.conversation_id} completed: " - f"{self.completed_client_turns}/{self.expected_client_turns} turns" - ) - - def mark_turn_failed(self, store_in_history: bool = False): - """Mark turn as failed (error/timeout) - still counts as completed for sequencing.""" - if self.pending_client_turn is not None: - self.current_turn = self.pending_client_turn + 1 - self.pending_client_turn = None - self.completed_client_turns += 1 - self.failed_client_turns += 1 - - if store_in_history: - self.message_history.append( - { - "role": "assistant", - "content": "[ERROR: Turn failed or timed out]", - } - ) - - logger.warning( - f"Turn {self.current_turn - 1} failed for conversation {self.conversation_id}" - ) - else: - logger.warning( - f"Attempted to mark failed turn for {self.conversation_id} " - f"with no pending client turn" - ) - - if self.is_complete(): - logger.info( - f"Conversation {self.conversation_id} completed with failures: " - f"{self.completed_client_turns - self.failed_client_turns}/" - f"{self.expected_client_turns} successful, " - f"{self.failed_client_turns} failed" 
- ) + completed_turns: int = 0 + failed_turns: int = 0 + expected_client_turns: int | None = None def is_complete(self) -> bool: - """Check if conversation is complete (all turns issued and responses received).""" + """Return True when all expected turns have a response.""" if self.expected_client_turns is None: return False - return self.completed_client_turns >= self.expected_client_turns - - def is_ready_for_turn(self) -> bool: - """Check if the previous turn has completed and the next may be issued.""" - return ( - self.pending_client_turn is None - and self.issued_client_turns == self.completed_client_turns - and self.issued_client_turns > 0 - ) + return self.completed_turns >= self.expected_client_turns class ConversationManager: - """Manages conversation sequencing for multi-turn benchmarking. - - Async manager that tracks multiple conversations and enforces turn ordering. - Conversations are identified by unique IDs. Message history is NOT maintained here - β€” it is pre-computed in MultiTurnDataset and passed directly to each request. + """Manages per-conversation state for multi-turn benchmarking. - The manager ensures that: - - Turn N+1 cannot be issued until turn N completes - - Concurrent access to conversation state is async-safe + All methods are synchronous. The pipeline task uses ``ConversationState.turn_done`` + directly for turn-done notification β€” no locks or condition variables needed. - Each ConversationState carries its own asyncio.Condition so that state changes - (turn issued / turn complete) only wake the single pipeline task waiting - on that conversation, not all pipeline tasks across all conversations. - All conversation states are pre-created by the strategy before pipeline - tasks start, so wait_for_turn_issued never races against get_or_create. + All states are pre-created by ``MultiTurnStrategy.execute()`` before any pipeline + task starts, so ``get_or_create()`` requires no locking. 
""" def __init__(self): - """Initialize conversation manager with empty state.""" + """Initialize with empty state.""" self._conversations: dict[str, ConversationState] = {} - self._lock = asyncio.Lock() def get_state(self, conversation_id: str) -> ConversationState | None: - """Get conversation state without creating (for read-only access).""" + """Return existing state without creating (read-only access).""" return self._conversations.get(conversation_id) - async def get_or_create( + def get_or_create( self, conversation_id: str, expected_client_turns: int | None = None, system_message: dict[str, Any] | None = None, ) -> ConversationState: - """Get existing or create new conversation state. + """Return existing state or create a new one. Args: conversation_id: Unique identifier for conversation. - expected_client_turns: Expected number of client turns (for completion tracking). - system_message: System message dict to pre-populate message_history with. - Only used when use_dataset_history=False and conversation is new. + expected_client_turns: Expected number of client turns. + system_message: System message to prepend to message_history + (only used when use_dataset_history=False and state is new). Returns: ConversationState for this conversation. """ - async with self._lock: - if conversation_id not in self._conversations: - initial_history: list[dict[str, Any]] = ( - [system_message] if system_message is not None else [] - ) - state = ConversationState( - conversation_id=conversation_id, - current_turn=0, - pending_client_turn=None, - expected_client_turns=expected_client_turns, - issued_client_turns=0, - completed_client_turns=0, - failed_client_turns=0, - message_history=initial_history, - ) - self._conversations[conversation_id] = state - return self._conversations[conversation_id] - - async def wait_for_turn_ready( - self, conversation_id: str, turn: int, timeout: float | None = None - ) -> bool: - """Block until conversation is ready for this turn. 
- - Uses the per-conversation asyncio.Condition so only this conversation's pipeline - task is woken on state changes, not all pipeline tasks. - - Args: - conversation_id: Conversation to wait for. - turn: Turn number to wait for (unused in readiness check; kept for - call-site compatibility). - timeout: Maximum seconds to wait (None = infinite). - - Returns: - True if ready, False if timeout. - - Raises: - KeyError: If conversation_id not found in manager. - """ - state = self._conversations.get(conversation_id) - if state is None: - logger.error(f"Conversation {conversation_id} not found in manager") - raise KeyError(f"Conversation {conversation_id} not initialized") - - async with state.condition: - if timeout is None: - await state.condition.wait_for(state.is_ready_for_turn) - return True - try: - async with asyncio.timeout(timeout): - await state.condition.wait_for(state.is_ready_for_turn) - return True - except TimeoutError: - return state.is_ready_for_turn() - - async def wait_for_turn_issued( - self, - conversation_id: str, - min_issued: int, - timeout: float | None = None, - ) -> bool: - """Block until at least min_issued client turns have been issued. - - Args: - conversation_id: Conversation to wait for. - min_issued: Minimum number of issued turns to wait for. - timeout: Maximum seconds to wait (None = infinite). - - Returns: - True if condition met, False if timeout. - - Raises: - KeyError: If conversation_id not found (programming error β€” state must be - pre-created by the strategy before pipeline tasks are spawned). 
- """ - state = self._conversations[conversation_id] - predicate = lambda: state.issued_client_turns >= min_issued # noqa: E731 - async with state.condition: - if timeout is None: - await state.condition.wait_for(predicate) - return True - try: - async with asyncio.timeout(timeout): - await state.condition.wait_for(predicate) - return True - except TimeoutError: - return state.issued_client_turns >= min_issued - - async def mark_turn_issued( - self, - conversation_id: str, - turn: int, - message: dict[str, Any] | None = None, - ): - """Mark that a client turn has been issued (updates sequencing counters). - - Args: - conversation_id: Conversation ID. - turn: Turn number being issued. - message: Message dict to append to history (used when - use_dataset_history=False). - - Raises: - KeyError: If conversation_id not found in manager. - """ - state = self._conversations.get(conversation_id) - if state is None: - raise KeyError(f"Conversation {conversation_id} not initialized") - async with state.condition: - state.add_client_turn(turn, message) - state.condition.notify_all() + if conversation_id not in self._conversations: + initial_history: list[dict[str, Any]] = ( + [system_message] if system_message is not None else [] + ) + self._conversations[conversation_id] = ConversationState( + conversation_id=conversation_id, + expected_client_turns=expected_client_turns, + message_history=initial_history, + ) + return self._conversations[conversation_id] - async def mark_turn_complete( + def mark_turn_complete( self, conversation_id: str, response: str, store_in_history: bool = False, - ): - """Mark that assistant response has arrived. + ) -> None: + """Record a successful response and wake the pipeline task. Args: conversation_id: Conversation ID. - response: Model output (stored in history when store_in_history=True). + response: Model output (appended to history when store_in_history=True). store_in_history: When True, append response to message_history. 
Raises: - KeyError: If conversation_id not found in manager. + KeyError: If conversation_id not found. """ state = self._conversations.get(conversation_id) if state is None: raise KeyError(f"Conversation {conversation_id} not initialized") - async with state.condition: - state.add_assistant_turn(response if store_in_history else None) - state.condition.notify_all() + if store_in_history and response: + state.message_history.append({"role": "assistant", "content": response}) + state.completed_turns += 1 + if state.is_complete(): + if state.failed_turns > 0: + logger.info( + f"Conversation {conversation_id} completed with failures: " + f"{state.completed_turns - state.failed_turns}/" + f"{state.expected_client_turns} successful, " + f"{state.failed_turns} failed" + ) + else: + logger.debug( + f"Conversation {conversation_id} completed: " + f"{state.completed_turns}/{state.expected_client_turns} turns" + ) + state.turn_done.set() - async def mark_turn_failed( - self, conversation_id: str, store_in_history: bool = False - ): - """Mark that assistant response failed (error/timeout). + def mark_turn_failed( + self, + conversation_id: str, + store_in_history: bool = False, + ) -> None: + """Record a failed response and wake the pipeline task. - Failed turns still count toward conversation completion to ensure - turn sequencing progresses even under errors. + Failed turns count toward completion so sequencing progresses under errors. Args: conversation_id: Conversation ID. store_in_history: When True, append error placeholder to message_history. Raises: - KeyError: If conversation_id not found in manager. + KeyError: If conversation_id not found. 
""" state = self._conversations.get(conversation_id) if state is None: raise KeyError(f"Conversation {conversation_id} not initialized") - async with state.condition: - state.mark_turn_failed(store_in_history=store_in_history) - state.condition.notify_all() + if store_in_history: + state.message_history.append( + {"role": "assistant", "content": "[ERROR: Turn failed or timed out]"} + ) + state.completed_turns += 1 + state.failed_turns += 1 + logger.warning(f"Turn failed for conversation {conversation_id}") + if state.is_complete(): + logger.info( + f"Conversation {conversation_id} completed with failures: " + f"{state.completed_turns - state.failed_turns}/" + f"{state.expected_client_turns} successful, " + f"{state.failed_turns} failed" + ) + state.turn_done.set() diff --git a/src/inference_endpoint/load_generator/multi_turn_strategy.py b/src/inference_endpoint/load_generator/multi_turn_strategy.py index 48f5b45f..cfd418bd 100644 --- a/src/inference_endpoint/load_generator/multi_turn_strategy.py +++ b/src/inference_endpoint/load_generator/multi_turn_strategy.py @@ -22,7 +22,7 @@ from ..config.schema import MultiTurnConfig from ..core.types import QueryResult -from .conversation_manager import ConversationManager +from .conversation_manager import ConversationManager, ConversationState from .strategy import PhaseIssuerProtocol logger = logging.getLogger(__name__) @@ -48,11 +48,12 @@ class MultiTurnStrategy: The response routing path: 1. _conv_pipeline issues turn N via phase_issuer.issue(idx) β†’ query_id - 2. _conv_pipeline stores (conv_id, turn) in _inflight[query_id] + 2. _conv_pipeline stores conv_id in _inflight[query_id] 3. BenchmarkSession calls on_sample_complete(result) with the QueryResult 4. on_sample_complete looks up conv_id from _inflight, calls mark_turn_complete - 5. mark_turn_complete notifies the pipeline task waiting on wait_for_turn_ready - 6. _conv_pipeline proceeds to issue turn N+1 + 5. 
mark_turn_complete sets state.turn_done synchronously + 6. _conv_pipeline's await asyncio.wait_for(state.turn_done.wait()) returns + 7. Pipeline clears the event and issues turn N+1 """ def __init__( @@ -91,8 +92,9 @@ def __init__( ) # Maps query_id -> conversation_id for routing completions. - # Populated by _conv_pipeline after issue() returns query_id. self._inflight: dict[str, str] = {} + # Cached ConversationState refs for O(1) lookup in on_sample_complete. + self._conv_states: dict[str, ConversationState] = {} async def execute(self, phase_issuer: PhaseIssuerProtocol) -> int: """Drive multi-turn sample issuance. @@ -108,11 +110,21 @@ async def execute(self, phase_issuer: PhaseIssuerProtocol) -> int: conv_id = sample_meta["conversation_id"] conv_samples[conv_id].append((sample_index, sample_meta["turn"])) - # Pre-create all conversation states before spawning tasks. + # Pre-create all conversation states before spawning tasks (no locking needed). + sys_prompts = self._dataset_metadata.get("system_prompts_by_conv", {}) for conv_id, turns in conv_samples.items(): - await self._conv_manager.get_or_create( - conv_id, expected_client_turns=len(turns) + sys_content = sys_prompts.get(conv_id) if self._store_in_history else None + system_message = ( + {"role": "system", "content": sys_content} + if sys_content is not None + else None ) + state = self._conv_manager.get_or_create( + conv_id, + expected_client_turns=len(turns), + system_message=system_message, + ) + self._conv_states[conv_id] = state tasks = [ asyncio.create_task( @@ -133,59 +145,53 @@ async def _conv_pipeline( ) -> None: """Process all turns for a single conversation sequentially. - For each turn after the first, waits for the previous turn to complete - (via wait_for_turn_ready) before issuing the next. This enforces strict - sequential ordering: turn N+1 is not issued until turn N's response arrives. + For each turn after the first, waits for state.turn_done before issuing + the next. 
This enforces strict sequential ordering within the conversation. """ + state = self._conv_states[conv_id] sorted_turns = sorted(turns, key=lambda x: x[1]) for i, (idx, turn) in enumerate(sorted_turns): if i > 0: - # Wait for the previous turn to complete before issuing the next. - ready = await self._conv_manager.wait_for_turn_ready( - conv_id, turn, timeout=self._turn_timeout_s - ) - if not ready: + try: + await asyncio.wait_for( + state.turn_done.wait(), timeout=self._turn_timeout_s + ) + except TimeoutError: logger.warning( f"Turn {turn} of {conv_id} timed out waiting for previous turn" ) - await self._conv_manager.mark_turn_failed(conv_id) + state.failed_turns += 1 break + state.turn_done.clear() - # Acquire concurrency slot before issuing + # Acquire concurrency slot before issuing. if self._sem is not None: await self._sem.acquire() - # For live-history mode: build messages from accumulated history + current turn, - # and pass as data_override so the pre-built messages from the dataset are replaced. + # Live-history mode: build messages from accumulated history + current turn. 
data_override: dict[str, Any] | None = None - current_turn_message: dict[str, Any] | None = None + current_turn_messages: list[dict[str, Any]] | None = None if self._store_in_history: - pre_built = self._dataset_metadata.get( - "pre_built_messages_by_key", {} - ).get((conv_id, turn), []) - current_turn_message = pre_built[-1] if pre_built else None - state = self._conv_manager.get_state(conv_id) - if state is not None and current_turn_message is not None: - live_messages = state.message_history.copy() + [ - current_turn_message - ] + current_turn_messages = self._dataset_metadata.get( + "current_turn_messages_by_key", {} + ).get((conv_id, turn)) + if current_turn_messages: + live_messages = state.message_history.copy() + current_turn_messages data_override = {"messages": live_messages} query_id = phase_issuer.issue(idx, data_override=data_override) if query_id is None: - # Session stopping β€” release slot and exit + # Session stopping β€” release slot and exit. if self._sem is not None: self._sem.release() break - # Register this query_id -> conv_id mapping for response routing. self._inflight[query_id] = conv_id - # Mark the turn as issued so wait_for_turn_ready can gate the next turn. - await self._conv_manager.mark_turn_issued( - conv_id, turn, message=current_turn_message - ) + # Append current-turn messages to history so the next turn sees them. + if self._store_in_history and current_turn_messages: + state.message_history.extend(current_turn_messages) def on_query_complete(self, query_id: str) -> None: """Called by BenchmarkSession when a QueryResult arrives. @@ -203,27 +209,23 @@ def on_sample_complete(self, result: QueryResult) -> None: """Route completed QueryResult to ConversationManager. Called by execute.py on_sample_complete hook after each response. - Looks up the conversation_id from _inflight and calls mark_turn_complete. + Event.set() is synchronous β€” the pipeline task is woken immediately + without needing asyncio.ensure_future. 
Args: result: Completed QueryResult from the endpoint. """ - query_id = result.id - conv_id = self._inflight.pop(query_id, None) + conv_id = self._inflight.pop(result.id, None) if conv_id is None: return response_text = result.get_response_output_string() if result.error is not None: - asyncio.ensure_future( - self._conv_manager.mark_turn_failed( - conv_id, store_in_history=self._store_in_history - ) + self._conv_manager.mark_turn_failed( + conv_id, store_in_history=self._store_in_history ) else: - asyncio.ensure_future( - self._conv_manager.mark_turn_complete( - conv_id, response_text, store_in_history=self._store_in_history - ) + self._conv_manager.mark_turn_complete( + conv_id, response_text, store_in_history=self._store_in_history ) diff --git a/src/inference_endpoint/openai/openai_adapter.py b/src/inference_endpoint/openai/openai_adapter.py index 4830c682..9c6f6ebd 100644 --- a/src/inference_endpoint/openai/openai_adapter.py +++ b/src/inference_endpoint/openai/openai_adapter.py @@ -111,6 +111,7 @@ def to_endpoint_request(cls, query: Query) -> CreateChatCompletionRequest: stream=query.data.get("stream", False), max_completion_tokens=query.data.get("max_completion_tokens", 100), temperature=query.data.get("temperature", 0.7), + tools=query.data.get("tools"), ) return request diff --git a/tests/integration/test_multi_turn.py b/tests/integration/test_multi_turn.py index ca17a236..87351700 100644 --- a/tests/integration/test_multi_turn.py +++ b/tests/integration/test_multi_turn.py @@ -16,8 +16,7 @@ """Integration tests for multi-turn benchmarking end-to-end. Validates that MultiTurnDataset + MultiTurnStrategy + BenchmarkSession work -correctly together against a real HTTP echo server (echo tests) and a live -model endpoint (live tests at port 8868). +correctly together against a real HTTP echo server. Tests cover: 1. Dataset-history mode (use_dataset_history=True): pre-built messages are @@ -27,16 +26,12 @@ grow with each turn. 3. 
Multiple concurrent conversations complete successfully. 4. Turn ordering: turn N+1 is never issued before turn N completes. - 5. Live concurrency: parametrized target_concurrency levels against a real - model endpoint verify all turns complete regardless of throttle setting. """ import asyncio -import json import random import time from urllib.parse import urljoin -from urllib.request import urlopen import pandas as pd import pytest @@ -430,303 +425,210 @@ def on_complete(result: QueryResult) -> None: assert complete_times[q_turn3] >= complete_times[q_turn1] -# --------------------------------------------------------------------------- -# Live endpoint fixtures and helpers -# --------------------------------------------------------------------------- - -_LIVE_ENDPOINT = "http://localhost:8868" - - -def _query_model_name(endpoint: str) -> str: - """Return the first model name from the endpoint, or skip if unreachable.""" - try: - with urlopen(f"{endpoint}/v1/models", timeout=5.0) as resp: - data = json.loads(resp.read()) - return data["data"][0]["id"] - except Exception as e: - pytest.skip(f"Live endpoint {endpoint} not reachable: {e}") - return "" - - -def _make_live_rows( - model: str, n_conversations: int = 20, n_user_turns: int = 3 -) -> list[dict]: - """Build a multi-conversation dataset rows list. - - Each conversation has n_user_turns user turns interleaved with scripted - assistant placeholders (needed to satisfy the turn-structure validator but - never sent to the endpoint). The resulting dataset produces - n_conversations Γ— n_user_turns client-turn samples. - """ - rows = [] - _user_prompts = [ - "Reply with exactly one word: the number {n} in English.", - "Add one to the previous number. Reply with only that word.", - "Add one more. 
Reply with only that word.", +@pytest.mark.integration +@pytest.mark.asyncio +async def test_tool_use_conversation_all_turns_issued(echo_server): + """Tool-use conversation: all client turns (user + tool) are issued and completed.""" + tool_calls = [ + { + "id": "call_1", + "type": "function", + "function": {"name": "search", "arguments": '{"q": "test"}'}, + } + ] + tool_results = [{"tool_call_id": "call_1", "content": "search result"}] + tools = [ + { + "type": "function", + "function": { + "name": "search", + "description": "Search", + "parameters": { + "type": "object", + "properties": {"q": {"type": "string"}}, + "required": ["q"], + }, + }, + } ] - for i in range(n_conversations): - conv_id = f"live_conv_{i:03d}" - turn = 1 - for j in range(n_user_turns): - prompt = _user_prompts[j % len(_user_prompts)].format(n=i + 1) - rows.append( - { - "conversation_id": conv_id, - "turn": turn, - "role": "user", - "content": prompt, - "model": model, - "max_completion_tokens": 10, - } - ) - turn += 1 - if j < n_user_turns - 1: - rows.append( - { - "conversation_id": conv_id, - "turn": turn, - "role": "assistant", - "content": "placeholder", - } - ) - turn += 1 - return rows - - -async def _run_live_session( - model: str, - n_conversations: int, - n_user_turns: int, - target_concurrency: int | None, - timeout_s: float = 300.0, -) -> tuple[int, dict[str, str]]: - """Run a live multi-turn session against the endpoint at _LIVE_ENDPOINT. - - Returns (issued_count, {query_id: response_text}). 
- """ - rows = _make_live_rows(model, n_conversations, n_user_turns) - ds = MultiTurnDataset(dataframe=pd.DataFrame(rows)) - ds.load() - - mt_cfg = MultiTurnConfig( - turn_timeout_s=60.0, - use_dataset_history=True, - ) - strategy = MultiTurnStrategy( - conversation_manager=ConversationManager(), - dataset_metadata=ds.conversation_metadata, - multi_turn_config=mt_cfg, - target_concurrency=target_concurrency, - ) - - loop = asyncio.get_running_loop() - responses: dict[str, str] = {} - - def on_complete(result: QueryResult) -> None: - strategy.on_sample_complete(result) - responses[result.id] = result.get_response_output_string() - - http_config = HTTPClientConfig( - endpoint_urls=[f"{_LIVE_ENDPOINT}/v1/chat/completions"], - warmup_connections=0, - num_workers=4, - ) - http_client = await HTTPEndpointClient.create(http_config, loop) - issuer = HttpClientSampleIssuer(http_client) - try: - session = BenchmarkSession( - issuer=issuer, - event_publisher=_NoOpPublisher(), - loop=loop, - on_sample_complete=on_complete, - ) - rt = RuntimeSettings( - metrics.Throughput(1000), - [metrics.Throughput(1000)], - min_duration_ms=0, - max_duration_ms=int(timeout_s * 1000), - n_samples_from_dataset=ds.num_samples(), - n_samples_to_issue=ds.num_samples(), - min_sample_count=1, - rng_sched=random.Random(42), - rng_sample_index=random.Random(42), - load_pattern=LoadPattern(type=LoadPatternType.MAX_THROUGHPUT), - ) - phase = PhaseConfig("perf", rt, ds, PhaseType.PERFORMANCE, strategy=strategy) - result = await asyncio.wait_for(session.run([phase]), timeout=timeout_s) - return result.perf_results[0].issued_count, responses - finally: - await http_client.shutdown_async() + rows = [ + { + "conversation_id": "c1", + "turn": 1, + "role": "user", + "content": "Find something", + "tools": tools, + }, + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant", + "content": None, + "tool_calls": tool_calls, + }, + { + "conversation_id": "c1", + "turn": 3, + "role": "tool", + 
"tool_results": tool_results, + "tools": tools, + }, + { + "conversation_id": "c1", + "turn": 4, + "role": "assistant", + "content": "Here is the result", + }, + {"conversation_id": "c1", "turn": 5, "role": "user", "content": "Thanks"}, + ] + ds = _make_dataset(rows) + strategy = _make_strategy(ds) + responses: dict = {} + count = await _run_session(echo_server.url, ds, strategy, responses) -# --------------------------------------------------------------------------- -# Live concurrency tests -# --------------------------------------------------------------------------- + # Client turns: turn 1 (user) + turn 3 (tool) + turn 5 (user) = 3 + assert count == 3 + assert len(responses) == 3 @pytest.mark.integration @pytest.mark.asyncio -@pytest.mark.parametrize( - "target_concurrency", - [ - pytest.param(1, id="concurrency_1"), - pytest.param(4, id="concurrency_4"), - pytest.param(None, id="concurrency_unlimited"), - ], -) -async def test_live_concurrency(target_concurrency): - """All turns of 20 concurrent conversations complete for each concurrency level. +async def test_conversation_ending_with_tool_row(echo_server): + """Conversation ending with a tool row completes normally (matches agentic_coding dataset pattern).""" + tool_calls = [ + { + "id": "call_1", + "type": "function", + "function": {"name": "write_file", "arguments": '{"path": "out.py"}'}, + } + ] + tool_results = [{"tool_call_id": "call_1", "content": "file written"}] + tools = [ + { + "type": "function", + "function": { + "name": "write_file", + "description": "Write a file", + "parameters": { + "type": "object", + "properties": {"path": {"type": "string"}}, + "required": ["path"], + }, + }, + } + ] - Uses the live model endpoint at port 8868. Each conversation has 3 user - turns (60 total requests). Verifies that every turn receives a non-empty - response regardless of the concurrency throttle applied by target_concurrency. 
- """ - model = _query_model_name(_LIVE_ENDPOINT) - n_conversations = 20 - n_user_turns = 3 - expected_turns = n_conversations * n_user_turns # 60 total requests - - issued, responses = await _run_live_session( - model=model, - n_conversations=n_conversations, - n_user_turns=n_user_turns, - target_concurrency=target_concurrency, - timeout_s=300.0, - ) + rows = [ + { + "conversation_id": "c1", + "turn": 1, + "role": "user", + "content": "Write a file", + "tools": tools, + }, + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant", + "content": None, + "tool_calls": tool_calls, + }, + { + "conversation_id": "c1", + "turn": 3, + "role": "tool", + "tool_results": tool_results, + "tools": tools, + }, + ] + ds = _make_dataset(rows) + strategy = _make_strategy(ds) + responses: dict = {} - assert issued == expected_turns, f"Expected {expected_turns} issued, got {issued}" - assert ( - len(responses) == expected_turns - ), f"Expected {expected_turns} responses, got {len(responses)}" - for qid, text in responses.items(): - assert text.strip(), f"Query {qid} returned empty response" + count = await _run_session(echo_server.url, ds, strategy, responses) + + # Client turns: turn 1 (user) + turn 3 (tool) = 2 + assert count == 2 + assert len(responses) == 2 @pytest.mark.integration @pytest.mark.asyncio -async def test_live_turn_ordering_multi_conversation(): - """Turn N+1 of each conversation is always issued after turn N completes. - - Runs 10 conversations with 3 turns each concurrently (30 total requests). - Records per-query completion timestamps and asserts that within every - conversation each successive turn completes no earlier than the previous. 
- """ - model = _query_model_name(_LIVE_ENDPOINT) - n_conversations = 10 - n_user_turns = 3 - rows = _make_live_rows(model, n_conversations, n_user_turns) - - ds = MultiTurnDataset(dataframe=pd.DataFrame(rows)) - ds.load() - - conv_manager = ConversationManager() - mt_cfg = MultiTurnConfig(turn_timeout_s=60.0, use_dataset_history=True) - strategy = MultiTurnStrategy( - conversation_manager=conv_manager, - dataset_metadata=ds.conversation_metadata, - multi_turn_config=mt_cfg, - ) - - complete_times: dict[str, float] = {} - orig_on_sample_complete = strategy.on_sample_complete - - def tracked_complete(result: QueryResult) -> None: - complete_times[result.id] = time.monotonic() - orig_on_sample_complete(result) - - strategy.on_sample_complete = tracked_complete - - loop = asyncio.get_running_loop() - responses: dict[str, str] = {} +async def test_tools_field_forwarded_to_endpoint(echo_server): + """The 'tools' array from the dataset reaches the endpoint in every request payload.""" + received_payloads: list[dict] = [] - http_config = HTTPClientConfig( - endpoint_urls=[f"{_LIVE_ENDPOINT}/v1/chat/completions"], - warmup_connections=0, - num_workers=4, - ) - http_client = await HTTPEndpointClient.create(http_config, loop) - issuer = HttpClientSampleIssuer(http_client) + class CapturingEchoServer(EchoServer): + async def _handle_echo_chat_completions_request(self, request): + try: + payload = await request.json() + received_payloads.append(payload) + except Exception: + pass + return await super()._handle_echo_chat_completions_request(request) + server = CapturingEchoServer(port=0) + server.start() try: + tool_calls = [ + { + "id": "call_1", + "type": "function", + "function": {"name": "search", "arguments": '{"q": "hello"}'}, + } + ] + tool_results = [{"tool_call_id": "call_1", "content": "result"}] + tools = [ + { + "type": "function", + "function": { + "name": "search", + "description": "Search", + "parameters": { + "type": "object", + "properties": {"q": {"type": 
"string"}}, + "required": ["q"], + }, + }, + } + ] - def on_complete(result: QueryResult) -> None: - tracked_complete(result) - responses[result.id] = result.get_response_output_string() - - session = BenchmarkSession( - issuer=issuer, - event_publisher=_NoOpPublisher(), - loop=loop, - on_sample_complete=on_complete, - ) - rt = RuntimeSettings( - metrics.Throughput(1000), - [metrics.Throughput(1000)], - min_duration_ms=0, - max_duration_ms=300_000, - n_samples_from_dataset=ds.num_samples(), - n_samples_to_issue=ds.num_samples(), - min_sample_count=1, - rng_sched=random.Random(42), - rng_sample_index=random.Random(42), - load_pattern=LoadPattern(type=LoadPatternType.MAX_THROUGHPUT), - ) - phase = PhaseConfig("perf", rt, ds, PhaseType.PERFORMANCE, strategy=strategy) - result = await asyncio.wait_for(session.run([phase]), timeout=300.0) - finally: - await http_client.shutdown_async() - - expected_total = n_conversations * n_user_turns - assert result.perf_results[0].issued_count == expected_total - - # Build index β†’ query_id map and verify per-conversation ordering. - # Samples are grouped by conversation, turns sorted ascending within each: - # conv_0_t1, conv_0_t2, conv_0_t3, conv_1_t1, ... 
- uuid_to_index = result.perf_results[0].uuid_to_index - index_to_query = {v: k for k, v in uuid_to_index.items()} - - for conv_i in range(n_conversations): - base = conv_i * n_user_turns - for turn_j in range(n_user_turns - 1): - q_cur = index_to_query[base + turn_j] - q_next = index_to_query[base + turn_j + 1] - assert complete_times[q_cur] <= complete_times[q_next], ( - f"conv {conv_i}: turn {turn_j + 2} completed before turn {turn_j + 1} " - f"(t{turn_j + 1}={complete_times[q_cur]:.4f}, " - f"t{turn_j + 2}={complete_times[q_next]:.4f})" - ) - + rows = [ + { + "conversation_id": "c1", + "turn": 1, + "role": "user", + "content": "Search for hello", + "tools": tools, + }, + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant", + "content": None, + "tool_calls": tool_calls, + }, + { + "conversation_id": "c1", + "turn": 3, + "role": "tool", + "tool_results": tool_results, + "tools": tools, + }, + ] + ds = _make_dataset(rows) + strategy = _make_strategy(ds, use_dataset_history=True) + responses: dict = {} -@pytest.mark.integration -@pytest.mark.asyncio -async def test_live_large_concurrency(): - """All turns complete correctly under a large concurrency limit (>=512). - - Uses 200 conversations Γ— 3 turns = 600 total requests with - target_concurrency=512. The semaphore allows up to 512 simultaneous - in-flight requests, so the first wave of 200 first-turns is issued - without throttling, and subsequent turns queue naturally. Verifies - that all 600 turns complete and return non-empty responses, confirming - the semaphore implementation handles large values without deadlock or - starvation. 
- """ - model = _query_model_name(_LIVE_ENDPOINT) - n_conversations = 200 - n_user_turns = 3 - expected_turns = n_conversations * n_user_turns # 600 total requests - - issued, responses = await _run_live_session( - model=model, - n_conversations=n_conversations, - n_user_turns=n_user_turns, - target_concurrency=512, - timeout_s=300.0, - ) + count = await _run_session(server.url, ds, strategy, responses) + assert count == 2 - assert issued == expected_turns, f"Expected {expected_turns} issued, got {issued}" - assert ( - len(responses) == expected_turns - ), f"Expected {expected_turns} responses, got {len(responses)}" - for qid, text in responses.items(): - assert text.strip(), f"Query {qid} returned empty response" + assert len(received_payloads) == 2 + for payload in received_payloads: + assert "tools" in payload + assert len(payload["tools"]) == 1 + assert payload["tools"][0]["function"]["name"] == "search" + finally: + server.stop() diff --git a/tests/unit/config/test_schema.py b/tests/unit/config/test_schema.py index 73e3363f..64987b57 100644 --- a/tests/unit/config/test_schema.py +++ b/tests/unit/config/test_schema.py @@ -450,3 +450,123 @@ def test_explicit_adapter_override_at_construction_survives(self): client = HTTPClientConfig(api_type=APIType.OPENAI, adapter=OpenAIAdapter) assert client.adapter is OpenAIAdapter assert client.adapter is not OpenAIMsgspecAdapter + + +class TestMultiTurnValidation: + """Tests for multi-turn config validation and cross-validation.""" + + def _make_online_multi_turn(self, concurrency: int | None = 4, **ds_kwargs): + lp: dict = {"type": "multi_turn"} + if concurrency is not None: + lp["target_concurrency"] = concurrency + return { + "type": TestType.ONLINE, + "model_params": {"name": "M"}, + "endpoint_config": {"endpoints": ["http://x"]}, + "datasets": [{"path": "D", "multi_turn": {}, **ds_kwargs}], + "settings": {"load_pattern": lp}, + } + + @pytest.mark.unit + def test_multi_turn_valid_config(self): + config = 
BenchmarkConfig(**self._make_online_multi_turn(concurrency=16)) + from inference_endpoint.config.schema import LoadPatternType + + assert config.settings.load_pattern.type == LoadPatternType.MULTI_TURN + assert config.settings.load_pattern.target_concurrency == 16 + + @pytest.mark.unit + def test_multi_turn_requires_target_concurrency(self): + with pytest.raises(ValueError, match="Multi-turn requires --concurrency"): + BenchmarkConfig(**self._make_online_multi_turn(concurrency=None)) + + @pytest.mark.unit + def test_multi_turn_without_multi_turn_dataset_rejected(self): + with pytest.raises(ValueError, match="requires at least one dataset"): + BenchmarkConfig( + type=TestType.ONLINE, + model_params={"name": "M"}, + endpoint_config={"endpoints": ["http://x"]}, + datasets=[{"path": "D"}], + settings={ + "load_pattern": {"type": "multi_turn", "target_concurrency": 4} + }, + ) + + @pytest.mark.unit + def test_multi_turn_dataset_without_multi_turn_load_pattern_rejected(self): + with pytest.raises(ValueError, match="require load_pattern.type=multi_turn"): + BenchmarkConfig( + type=TestType.ONLINE, + model_params={"name": "M"}, + endpoint_config={"endpoints": ["http://x"]}, + datasets=[{"path": "D", "multi_turn": {}}], + settings={"load_pattern": {"type": "poisson", "target_qps": 10}}, + ) + + +class TestMultiTurnTotalSamples: + """Tests for total_samples_to_issue() with multi_turn load pattern.""" + + @pytest.mark.unit + def test_multi_turn_uses_dataset_size_ignoring_duration(self): + from inference_endpoint.config.runtime_settings import RuntimeSettings + + config = BenchmarkConfig( + type=TestType.ONLINE, + model_params={"name": "M"}, + endpoint_config={"endpoints": ["http://x"]}, + datasets=[{"path": "D", "multi_turn": {}}], + settings={ + "load_pattern": {"type": "multi_turn", "target_concurrency": 4}, + "runtime": {"min_duration_ms": 600000}, + }, + ) + rt = RuntimeSettings.from_config(config, dataloader_num_samples=4316) + assert rt.total_samples_to_issue() == 4316 
+ + @pytest.mark.unit + def test_multi_turn_respects_min_sample_count(self): + import random + + from inference_endpoint import metrics + from inference_endpoint.config.runtime_settings import RuntimeSettings + from inference_endpoint.config.schema import LoadPattern, LoadPatternType + + lp = LoadPattern(type=LoadPatternType.MULTI_TURN, target_concurrency=4) + rt = RuntimeSettings( + metric_target=metrics.Throughput(10.0), + reported_metrics=[metrics.Throughput(10.0)], + min_duration_ms=600000, + max_duration_ms=None, + n_samples_from_dataset=5, + n_samples_to_issue=None, + min_sample_count=100, + rng_sched=random.Random(0), + rng_sample_index=random.Random(0), + load_pattern=lp, + ) + assert rt.total_samples_to_issue() == 100 + + @pytest.mark.unit + def test_multi_turn_explicit_n_samples_takes_precedence(self): + import random + + from inference_endpoint import metrics + from inference_endpoint.config.runtime_settings import RuntimeSettings + from inference_endpoint.config.schema import LoadPattern, LoadPatternType + + lp = LoadPattern(type=LoadPatternType.MULTI_TURN, target_concurrency=4) + rt = RuntimeSettings( + metric_target=metrics.Throughput(10.0), + reported_metrics=[metrics.Throughput(10.0)], + min_duration_ms=600000, + max_duration_ms=None, + n_samples_from_dataset=4316, + n_samples_to_issue=200, + min_sample_count=1, + rng_sched=random.Random(0), + rng_sample_index=random.Random(0), + load_pattern=lp, + ) + assert rt.total_samples_to_issue() == 200 diff --git a/tests/unit/dataset_manager/test_multi_turn_dataset.py b/tests/unit/dataset_manager/test_multi_turn_dataset.py index a42fc1f3..a93a3a08 100644 --- a/tests/unit/dataset_manager/test_multi_turn_dataset.py +++ b/tests/unit/dataset_manager/test_multi_turn_dataset.py @@ -1061,3 +1061,366 @@ def test_messages_with_tool_sequence_terminal_assistant(): # The terminal assistant at turn 4 should be included assistant_msgs = [m for m in msgs if m["role"] == "assistant" and m.get("content")] assert 
any(m["content"] == "The weather is 22Β°C." for m in assistant_msgs) + + +# ============================================================================ +# Tool-use flat dataset regression tests (BUG 1, BUG 2, BUG 3) +# ============================================================================ + + +@pytest.mark.unit +def test_prior_tool_row_expanded_with_tool_call_id(): + """Prior tool rows must expand to messages with tool_call_id and content (BUG 1).""" + df = _make_tool_sequence_df() + ds = MultiTurnDataset(df) + pbm = ds.conversation_metadata["pre_built_messages_by_key"] + + # Client turn 3 (user, t=5) has a prior tool row at t=3. + # msgs_t5[3] should be the expanded tool message with proper fields. + msgs_t5 = pbm[("c1", 5)] + tool_msgs = [m for m in msgs_t5 if m.get("role") == "tool"] + assert len(tool_msgs) == 1 + assert tool_msgs[0]["tool_call_id"] == "call_c1_0" + assert tool_msgs[0]["content"] == '{"temp": 22}' + + +@pytest.mark.unit +def test_prior_parallel_tool_results_expand_to_multiple_messages(): + """Prior turn with 2 parallel tool_results expands to 2 tool messages.""" + df = pd.DataFrame( + [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "Hi"}, + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "c_0", + "type": "function", + "function": {"name": "f1", "arguments": "{}"}, + }, + { + "id": "c_1", + "type": "function", + "function": {"name": "f2", "arguments": "{}"}, + }, + ], + }, + { + "conversation_id": "c1", + "turn": 3, + "role": "tool", + "tool_results": [ + {"tool_call_id": "c_0", "content": "r1"}, + {"tool_call_id": "c_1", "content": "r2"}, + ], + }, + { + "conversation_id": "c1", + "turn": 4, + "role": "assistant", + "content": "Done", + }, + {"conversation_id": "c1", "turn": 5, "role": "user", "content": "Ok"}, + ] + ) + ds = MultiTurnDataset(df) + pbm = ds.conversation_metadata["pre_built_messages_by_key"] + + # user(5) sees prior rows: user(1), 
assistant(2), tool(3)x2, assistant(4) + msgs_t5 = pbm[("c1", 5)] + tool_msgs = [m for m in msgs_t5 if m.get("role") == "tool"] + assert len(tool_msgs) == 2 + assert tool_msgs[0]["tool_call_id"] == "c_0" + assert tool_msgs[0]["content"] == "r1" + assert tool_msgs[1]["tool_call_id"] == "c_1" + assert tool_msgs[1]["content"] == "r2" + + +@pytest.mark.unit +def test_assistant_content_null_preserved_in_history(): + """Assistant messages with tool_calls and content:null include content key (BUG 2).""" + df = _make_tool_sequence_df() + ds = MultiTurnDataset(df) + pbm = ds.conversation_metadata["pre_built_messages_by_key"] + + # Client turn 2 (tool, t=3): prior includes assistant(2) with tool_calls + content: null + msgs_t3 = pbm[("c1", 3)] + asst_msg = msgs_t3[2] + assert asst_msg["role"] == "assistant" + assert "tool_calls" in asst_msg + assert "content" in asst_msg + assert asst_msg["content"] is None + + # Also verify in user(5)'s history + msgs_t5 = pbm[("c1", 5)] + asst_tc_msg = msgs_t5[2] + assert asst_tc_msg["role"] == "assistant" + assert "tool_calls" in asst_tc_msg + assert "content" in asst_tc_msg + assert asst_tc_msg["content"] is None + + +@pytest.mark.unit +def test_jsonl_round_trip_with_tools_field(): + """Load from JSONL tmpfile with tools field; verify tools survives to sample dict.""" + data = [ + { + "conversation_id": "c1", + "turn": 1, + "role": "user", + "content": "Run the test", + "tools": [ + { + "type": "function", + "function": { + "name": "bash", + "description": "Run a bash command", + "parameters": {"type": "object", "properties": {}}, + }, + } + ], + }, + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "tc_0", + "type": "function", + "function": {"name": "bash", "arguments": '{"cmd": "ls"}'}, + } + ], + }, + { + "conversation_id": "c1", + "turn": 3, + "role": "tool", + "tool_results": [{"tool_call_id": "tc_0", "content": "file1.py"}], + "tools": [ + { + "type": "function", + 
"function": { + "name": "bash", + "description": "Run a bash command", + "parameters": {"type": "object", "properties": {}}, + }, + } + ], + }, + { + "conversation_id": "c1", + "turn": 4, + "role": "assistant", + "content": "The directory contains file1.py", + }, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + try: + dataset = MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL) + dataset.load() + + # user(1) has tools + s0 = dataset.load_sample(0) + assert "tools" in s0 + assert len(s0["tools"]) == 1 + assert s0["tools"][0]["function"]["name"] == "bash" + + # tool(3) also has tools + s1 = dataset.load_sample(1) + assert "tools" in s1 + assert s1["tools"][0]["function"]["name"] == "bash" + finally: + Path(temp_path).unlink() + + +@pytest.mark.unit +def test_current_turn_messages_by_key_parallel_tools(): + """current_turn_messages_by_key stores all expanded messages for a tool turn.""" + df = pd.DataFrame( + [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "Go"}, + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "c_0", + "type": "function", + "function": {"name": "f1", "arguments": "{}"}, + }, + { + "id": "c_1", + "type": "function", + "function": {"name": "f2", "arguments": "{}"}, + }, + ], + }, + { + "conversation_id": "c1", + "turn": 3, + "role": "tool", + "tool_results": [ + {"tool_call_id": "c_0", "content": "r1"}, + {"tool_call_id": "c_1", "content": "r2"}, + ], + }, + { + "conversation_id": "c1", + "turn": 4, + "role": "assistant", + "content": "Done", + }, + ] + ) + ds = MultiTurnDataset(df) + ctm = ds.conversation_metadata["current_turn_messages_by_key"] + + # user(1) current turn is 1 message + assert len(ctm[("c1", 1)]) == 1 + assert ctm[("c1", 1)][0] == {"role": "user", "content": "Go"} + + # tool(3) current turn has 2 expanded 
messages (parallel tool_results) + assert len(ctm[("c1", 3)]) == 2 + assert ctm[("c1", 3)][0]["tool_call_id"] == "c_0" + assert ctm[("c1", 3)][1]["tool_call_id"] == "c_1" + + +# ============================================================================ +# Fix 1: system_prompts_by_conv in metadata (live-history mode) +# ============================================================================ + + +@pytest.mark.unit +def test_metadata_contains_system_prompts_by_conv(): + """_build_metadata exposes system_prompts_by_conv keyed by conversation_id.""" + data = [ + { + "conversation_id": "c1", + "turn": 1, + "role": "user", + "content": "Hi", + "system": "Be concise", + }, + {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "Ok"}, + # c2 has no system prompt + {"conversation_id": "c2", "turn": 1, "role": "user", "content": "Hello"}, + ] + df = pd.DataFrame(data) + ds = MultiTurnDataset(df) + + spc = ds.conversation_metadata["system_prompts_by_conv"] + assert spc["c1"] == "Be concise" + assert spc["c2"] is None + + +@pytest.mark.unit +def test_metadata_system_prompts_multiple_convs(): + """Each conversation gets its own system prompt entry.""" + data = [ + { + "conversation_id": "c1", + "turn": 1, + "role": "user", + "content": "A", + "system": "Sys1", + }, + {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "B"}, + { + "conversation_id": "c2", + "turn": 1, + "role": "user", + "content": "C", + "system": "Sys2", + }, + {"conversation_id": "c2", "turn": 2, "role": "assistant", "content": "D"}, + ] + df = pd.DataFrame(data) + ds = MultiTurnDataset(df) + + spc = ds.conversation_metadata["system_prompts_by_conv"] + assert spc["c1"] == "Sys1" + assert spc["c2"] == "Sys2" + + +# ============================================================================ +# Fix 2: tool_results / tool_calls stripped from sample dicts +# ============================================================================ + + +@pytest.mark.unit +def 
test_tool_results_not_in_sample_dict(): + """tool_results must not appear in the pre-baked sample dict for tool turns.""" + df = _make_tool_sequence_df() + ds = MultiTurnDataset(df) + ds.load() + + # Sample 1 is the tool turn (turn 3) + s1 = ds.load_sample(1) + assert s1["role"] == "tool" + assert "tool_results" not in s1 + + +@pytest.mark.unit +def test_tool_calls_not_in_sample_dict(): + """tool_calls must not appear in sample dicts (only relevant on assistant rows).""" + data = [ + { + "conversation_id": "c1", + "turn": 1, + "role": "user", + "content": "Go", + "tool_calls": [ + {"id": "bad", "type": "function", "function": {"name": "f"}} + ], + }, + {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "Done"}, + ] + df = pd.DataFrame(data) + ds = MultiTurnDataset(df) + ds.load() + + s0 = ds.load_sample(0) + assert "tool_calls" not in s0 + + +# ============================================================================ +# Fix 3: no dead current_turn_message / system_content fields in sample dicts +# ============================================================================ + + +@pytest.mark.unit +def test_no_dead_current_turn_message_field(): + """current_turn_message must not appear in pre-baked sample dicts.""" + df = _make_tool_sequence_df() + ds = MultiTurnDataset(df) + ds.load() + + for i in range(ds.num_samples()): + s = ds.load_sample(i) + assert ( + "current_turn_message" not in s + ), f"Sample {i} has dead field current_turn_message" + + +@pytest.mark.unit +def test_no_dead_system_content_field(): + """system_content must not appear in pre-baked sample dicts.""" + df = _make_tool_sequence_df() + ds = MultiTurnDataset(df) + ds.load() + + for i in range(ds.num_samples()): + s = ds.load_sample(i) + assert "system_content" not in s, f"Sample {i} has dead field system_content" diff --git a/tests/unit/load_generator/test_multi_turn_conversation_manager.py b/tests/unit/load_generator/test_multi_turn_conversation_manager.py index 
62602626..331e6709 100644 --- a/tests/unit/load_generator/test_multi_turn_conversation_manager.py +++ b/tests/unit/load_generator/test_multi_turn_conversation_manager.py @@ -14,7 +14,7 @@ # limitations under the License. import asyncio -import logging +import inspect import pytest from inference_endpoint.load_generator.conversation_manager import ( @@ -25,334 +25,253 @@ @pytest.mark.unit def test_conversation_state_initialization(): - """Test ConversationState initializes with correct default values.""" + """ConversationState initializes with correct defaults.""" state = ConversationState(conversation_id="conv_001") assert state.conversation_id == "conv_001" - assert state.current_turn == 0 - assert state.pending_client_turn is None + assert not state.turn_done.is_set() + assert state.message_history == [] + assert state.completed_turns == 0 + assert state.failed_turns == 0 + assert state.expected_client_turns is None @pytest.mark.unit -def test_conversation_state_add_client_turn(): - """Test adding a client turn updates sequencing state.""" +def test_conversation_state_is_complete_without_expected(): + """is_complete() returns False when expected_client_turns is None.""" state = ConversationState(conversation_id="conv_001") - - state.add_client_turn(1) - - assert state.pending_client_turn == 1 - assert state.issued_client_turns == 1 - assert state.current_turn == 0 # Not incremented until assistant response - - -@pytest.mark.unit -def test_conversation_state_add_assistant_turn(): - """Test adding assistant turn completes turn cycle.""" - state = ConversationState(conversation_id="conv_001") - - state.add_client_turn(1) - state.add_assistant_turn() - - assert state.current_turn == 2 - assert state.pending_client_turn is None - assert state.completed_client_turns == 1 + assert not state.is_complete() + state.completed_turns = 5 + assert not state.is_complete() @pytest.mark.unit -def test_conversation_state_late_response_after_complete_is_silently_ignored(caplog): - 
"""Late response for a conversation that already completed is silently dropped.""" - state = ConversationState(conversation_id="conv_001", expected_client_turns=1) - - state.add_client_turn(1) - state.add_assistant_turn() +def test_conversation_state_is_complete_with_expected(): + """is_complete() returns True once completed_turns >= expected.""" + state = ConversationState(conversation_id="conv_001", expected_client_turns=2) + assert not state.is_complete() + state.completed_turns = 1 + assert not state.is_complete() + state.completed_turns = 2 assert state.is_complete() - completed_before = state.completed_client_turns - current_turn_before = state.current_turn - - with caplog.at_level(logging.WARNING): - state.add_assistant_turn() - - assert state.completed_client_turns == completed_before - assert state.current_turn == current_turn_before - assert "no pending client turn" not in caplog.text - @pytest.mark.unit -def test_conversation_state_is_ready_for_turn(): - """Test turn readiness checks using completion counts.""" - state = ConversationState(conversation_id="conv_001") - - assert not state.is_ready_for_turn() - - state.add_client_turn(1) - assert not state.is_ready_for_turn() - - state.add_assistant_turn() - assert state.is_ready_for_turn() - - state.add_client_turn(2) - assert not state.is_ready_for_turn() - - state.add_assistant_turn() - assert state.is_ready_for_turn() - - -@pytest.mark.unit -def test_conversation_state_multi_turn_sequence(): - """Test multi-turn conversation flow updates current_turn correctly.""" - state = ConversationState(conversation_id="conv_001") - - state.add_client_turn(1) - state.add_assistant_turn() - assert state.current_turn == 2 - - state.add_client_turn(3) - state.add_assistant_turn() - assert state.current_turn == 4 - - state.add_client_turn(5) - state.add_assistant_turn() - assert state.current_turn == 6 +def test_create_is_synchronous(): + """get_or_create() must be a plain function, not a coroutine.""" + manager = 
ConversationManager() + result = manager.get_or_create("conv_001") + assert not inspect.iscoroutine(result), "get_or_create returned a coroutine" + assert isinstance(result, ConversationState) @pytest.mark.unit -@pytest.mark.asyncio -async def test_conversation_manager_get_or_create(): - """Test get_or_create returns same state for same conversation_id.""" +def test_conversation_manager_get_or_create(): + """get_or_create returns the same state for the same conversation_id.""" manager = ConversationManager() - state1 = await manager.get_or_create("conv_001") - state2 = await manager.get_or_create("conv_001") + state1 = manager.get_or_create("conv_001") + state2 = manager.get_or_create("conv_001") assert state1 is state2 assert state1.conversation_id == "conv_001" @pytest.mark.unit -@pytest.mark.asyncio -async def test_conversation_manager_multiple_conversations(): - """Test manager can track multiple conversations independently.""" +def test_conversation_manager_multiple_conversations(): + """Manager tracks multiple conversations independently.""" manager = ConversationManager() - state1 = await manager.get_or_create("conv_001") - state2 = await manager.get_or_create("conv_002") + state1 = manager.get_or_create("conv_001") + state2 = manager.get_or_create("conv_002") assert state1 is not state2 - await manager.mark_turn_issued("conv_001", 1) - await manager.mark_turn_complete("conv_001", "Response to conv_001") + manager.mark_turn_complete("conv_001", "Response to conv_001") - assert state1.current_turn == 2 - assert state2.current_turn == 0 + assert state1.completed_turns == 1 + assert state2.completed_turns == 0 @pytest.mark.unit -@pytest.mark.asyncio -async def test_conversation_manager_mark_turn_issued(): - """Test mark_turn_issued updates sequencing state.""" +def test_conversation_manager_mark_turn_complete(): + """mark_turn_complete increments counter, appends history, sets event.""" manager = ConversationManager() - state = await 
manager.get_or_create("conv_001") + state = manager.get_or_create("conv_001") - await manager.mark_turn_issued("conv_001", 1) + manager.mark_turn_complete("conv_001", "Assistant response") - assert state.pending_client_turn == 1 - assert state.issued_client_turns == 1 + assert state.completed_turns == 1 + assert state.failed_turns == 0 + assert state.turn_done.is_set() + assert state.message_history == [] # store_in_history=False by default @pytest.mark.unit -@pytest.mark.asyncio -async def test_conversation_manager_mark_turn_complete(): - """Test mark_turn_complete updates sequencing state.""" +def test_conversation_manager_mark_turn_complete_stores_history(): + """mark_turn_complete appends to history when store_in_history=True.""" manager = ConversationManager() - state = await manager.get_or_create("conv_001") + state = manager.get_or_create("conv_001") - await manager.mark_turn_issued("conv_001", 1) - await manager.mark_turn_complete("conv_001", "Assistant response") + manager.mark_turn_complete("conv_001", "Hello", store_in_history=True) - assert state.current_turn == 2 - assert state.pending_client_turn is None - assert state.completed_client_turns == 1 + assert state.message_history == [{"role": "assistant", "content": "Hello"}] @pytest.mark.unit -@pytest.mark.asyncio -async def test_conversation_manager_wait_for_turn_ready_immediate(): - """Test wait_for_turn_ready returns immediately when previous turn is complete.""" +def test_conversation_manager_mark_turn_failed(): + """mark_turn_failed increments both counters and sets event.""" manager = ConversationManager() - await manager.get_or_create("conv_001") - - await manager.mark_turn_issued("conv_001", 1) - await manager.mark_turn_complete("conv_001", "First response") + state = manager.get_or_create("conv_001", expected_client_turns=2) - result = await manager.wait_for_turn_ready("conv_001", 9, timeout=1.0) + manager.mark_turn_failed("conv_001") - assert result is True + assert state.completed_turns == 1 
+ assert state.failed_turns == 1 + assert state.turn_done.is_set() @pytest.mark.unit -@pytest.mark.asyncio -async def test_conversation_manager_wait_for_turn_ready_blocking(): - """Test wait_for_turn_ready blocks until previous turn completes.""" +def test_conversation_completion_tracking(): + """is_complete() returns True after all expected turns receive responses.""" manager = ConversationManager() - await manager.get_or_create("conv_001") - - await manager.mark_turn_issued("conv_001", 1) - - ready_flag = [] - - async def waiter(): - result = await manager.wait_for_turn_ready("conv_001", 3, timeout=2.0) - if result: - ready_flag.append(True) + state = manager.get_or_create("conv_001", expected_client_turns=2) - waiter_task = asyncio.create_task(waiter()) - await asyncio.sleep(0.05) - assert not ready_flag - - await manager.mark_turn_complete("conv_001", "Assistant response") - await asyncio.sleep(0.05) - await waiter_task - - assert ready_flag == [True] + assert not state.is_complete() + manager.mark_turn_complete("conv_001", "r1") + assert not state.is_complete() + manager.mark_turn_complete("conv_001", "r2") + assert state.is_complete() @pytest.mark.unit -@pytest.mark.asyncio -async def test_conversation_manager_wait_for_turn_ready_timeout(): - """Test wait_for_turn_ready respects timeout.""" +def test_conversation_completion_without_expected_turns(): + """Completion is never True when expected_client_turns is None.""" manager = ConversationManager() - await manager.get_or_create("conv_001") + state = manager.get_or_create("conv_001", expected_client_turns=None) - await manager.mark_turn_issued("conv_001", 1) + manager.mark_turn_complete("conv_001", "r1") - result = await manager.wait_for_turn_ready("conv_001", 3, timeout=0.1) - - assert result is False + assert not state.is_complete() @pytest.mark.unit -@pytest.mark.asyncio -async def test_conversation_completion_tracking(): - """Test conversation completion detection.""" +def 
test_conversation_completion_with_failures(): + """Conversations complete even when some turns fail.""" manager = ConversationManager() + state = manager.get_or_create("conv1", expected_client_turns=3) - state = await manager.get_or_create("conv_001", expected_client_turns=2) - + manager.mark_turn_complete("conv1", "Hi") assert not state.is_complete() - await manager.mark_turn_issued("conv_001", 1) + manager.mark_turn_failed("conv1") assert not state.is_complete() - await manager.mark_turn_complete("conv_001", "response 1") - assert not state.is_complete() - - await manager.mark_turn_issued("conv_001", 3) - await manager.mark_turn_complete("conv_001", "response 2") - + manager.mark_turn_complete("conv1", "Bye") assert state.is_complete() + assert state.failed_turns == 1 + assert state.completed_turns == 3 @pytest.mark.unit -@pytest.mark.asyncio -async def test_conversation_completion_without_expected_turns(): - """Test that completion tracking works when expected_client_turns is None.""" +def test_all_turns_fail(): + """Conversation completes when all turns fail.""" manager = ConversationManager() + state = manager.get_or_create("conv1", expected_client_turns=2) - state = await manager.get_or_create("conv_001", expected_client_turns=None) - - assert not state.is_complete() - - await manager.mark_turn_issued("conv_001", 1) - await manager.mark_turn_complete("conv_001", "response 1") + manager.mark_turn_failed("conv1") + manager.mark_turn_failed("conv1") - assert not state.is_complete() + assert state.is_complete() + assert state.completed_turns == 2 + assert state.failed_turns == 2 @pytest.mark.unit @pytest.mark.asyncio -async def test_conversation_completion_with_failures(): - """Test that conversations complete even when turns fail.""" +async def test_event_set_wakes_waiter(): + """mark_turn_complete sets turn_done so a blocked await returns.""" manager = ConversationManager() - state = await manager.get_or_create("conv1", expected_client_turns=3) + state = 
manager.get_or_create("conv_001") - await manager.mark_turn_issued("conv1", 1) - await manager.mark_turn_complete("conv1", "Hi there") - assert state.completed_client_turns == 1 - assert not state.is_complete() + woke_up: list[bool] = [] - await manager.mark_turn_issued("conv1", 2) - await manager.mark_turn_failed("conv1") - assert state.completed_client_turns == 2 - assert state.failed_client_turns == 1 - assert not state.is_complete() + async def waiter(): + await state.turn_done.wait() + woke_up.append(True) - await manager.mark_turn_issued("conv1", 3) - await manager.mark_turn_complete("conv1", "Bye!") - assert state.completed_client_turns == 3 - assert state.is_complete() + task = asyncio.create_task(waiter()) + await asyncio.sleep(0.01) + assert not woke_up + + manager.mark_turn_complete("conv_001", "response") + await asyncio.sleep(0.01) + await task + + assert woke_up == [True] @pytest.mark.unit @pytest.mark.asyncio -async def test_mark_turn_failed_with_no_pending(): - """Test that marking failed turn without pending turn logs warning.""" +async def test_failed_sets_event(): + """mark_turn_failed sets turn_done so the pipeline can unblock.""" manager = ConversationManager() - state = await manager.get_or_create("conv1", expected_client_turns=1) + state = manager.get_or_create("conv_001") + + woke_up: list[bool] = [] - await manager.mark_turn_failed("conv1") + async def waiter(): + await state.turn_done.wait() + woke_up.append(True) + + task = asyncio.create_task(waiter()) + await asyncio.sleep(0.01) + + manager.mark_turn_failed("conv_001") + await asyncio.sleep(0.01) + await task - assert state.completed_client_turns == 0 - assert state.failed_client_turns == 0 + assert woke_up == [True] @pytest.mark.unit @pytest.mark.asyncio -async def test_all_turns_fail(): - """Test conversation completion when all turns fail.""" +async def test_event_clear_resets_for_next_turn(): + """Clearing turn_done after wait() properly gates the next turn.""" manager = 
ConversationManager() - state = await manager.get_or_create("conv1", expected_client_turns=2) + state = manager.get_or_create("conv_001") - await manager.mark_turn_issued("conv1", 1) - await manager.mark_turn_failed("conv1") + # First turn: set then clear + manager.mark_turn_complete("conv_001", "r1") + await state.turn_done.wait() + state.turn_done.clear() + assert not state.turn_done.is_set() - await manager.mark_turn_issued("conv1", 2) - await manager.mark_turn_failed("conv1") - - assert state.is_complete() - assert state.completed_client_turns == 2 - assert state.failed_client_turns == 2 + # Second turn: set again + manager.mark_turn_complete("conv_001", "r2") + assert state.turn_done.is_set() @pytest.mark.unit @pytest.mark.asyncio async def test_conversation_manager_concurrent_access(): - """Test async concurrent access to multiple conversations.""" + """Concurrent pipeline tasks on independent conversations complete without errors.""" manager = ConversationManager() num_conversations = 10 - user_turns_per_conv = 5 + turns_per_conv = 5 for i in range(num_conversations): - await manager.get_or_create(f"conv_{i:03d}") + manager.get_or_create(f"conv_{i:03d}", expected_client_turns=turns_per_conv) errors = [] async def process_conversation(conv_id: str): try: - for user_turn_idx in range(user_turns_per_conv): - turn = user_turn_idx * 2 + 1 - - if user_turn_idx > 0: - ready = await manager.wait_for_turn_ready( - conv_id, turn, timeout=5.0 - ) - if not ready: - errors.append(f"{conv_id} turn {turn} timeout") - return - - await manager.mark_turn_issued(conv_id, turn) + state = manager.get_state(conv_id) + assert state is not None + for _ in range(turns_per_conv): + manager.mark_turn_complete(conv_id, "response") await asyncio.sleep(0.001) - await manager.mark_turn_complete(conv_id, f"Response {turn}") except Exception as e: errors.append(f"{conv_id} error: {e}") @@ -362,35 +281,8 @@ async def process_conversation(conv_id: str): ] await asyncio.gather(*tasks) - assert 
not errors, f"Errors occurred: {errors}" - + assert not errors for i in range(num_conversations): - conv_id = f"conv_{i:03d}" - state = manager._conversations[conv_id] - assert state.current_turn == user_turns_per_conv * 2 - assert state.completed_client_turns == user_turns_per_conv - - -@pytest.mark.unit -@pytest.mark.asyncio -async def test_conversation_manager_wait_for_turn_ready_reliably_wakes_on_completion(): - """Test completion wakeups do not depend on timing windows.""" - - async def run_one_iteration(): - mgr = ConversationManager() - await mgr.get_or_create("conv_001") - await mgr.mark_turn_issued("conv_001", 1) - - ready: list[bool] = [] - - async def waiter(m: ConversationManager, r: list) -> None: - r.append(await m.wait_for_turn_ready("conv_001", 3, timeout=0.5)) - - waiter_task = asyncio.create_task(waiter(mgr, ready)) - await asyncio.sleep(0.005) - await mgr.mark_turn_complete("conv_001", "Assistant response") - await asyncio.wait_for(waiter_task, timeout=0.5) - assert ready == [True] - - for _ in range(10): - await run_one_iteration() + state = manager._conversations[f"conv_{i:03d}"] + assert state.completed_turns == turns_per_conv + assert state.is_complete() diff --git a/tests/unit/load_generator/test_multi_turn_strategy.py b/tests/unit/load_generator/test_multi_turn_strategy.py index 55c51994..0edbb34f 100644 --- a/tests/unit/load_generator/test_multi_turn_strategy.py +++ b/tests/unit/load_generator/test_multi_turn_strategy.py @@ -75,7 +75,7 @@ async def complete_turns(): # Mark turn 1 complete state = conv_manager.get_state("conv1") if state: - await conv_manager.mark_turn_complete("conv1", "response 1") + conv_manager.mark_turn_complete("conv1", "response 1") asyncio.create_task(complete_turns()) count = await strategy.execute(issuer) @@ -231,24 +231,20 @@ async def test_on_query_complete_releases_semaphore(): async def test_on_sample_complete_routes_to_manager(): """on_sample_complete marks the turn complete in the ConversationManager.""" 
conv_manager = ConversationManager() - await conv_manager.get_or_create("conv1", expected_client_turns=1) + conv_manager.get_or_create("conv1", expected_client_turns=1) metadata = _make_dataset_metadata({"conv1": [1]}) strategy = MultiTurnStrategy(conv_manager, metadata) # Simulate issuer registering conv_id in _inflight strategy._inflight["q0001"] = "conv1" - # Pre-issue a turn so the state has pending_client_turn - await conv_manager.mark_turn_issued("conv1", 1) result = QueryResult(id="q0001", response_output=TextModelOutput(output="hello")) strategy.on_sample_complete(result) - # Allow the ensure_future coroutine to run - await asyncio.sleep(0.01) - state = conv_manager.get_state("conv1") assert state is not None - assert state.completed_client_turns == 1 + assert state.completed_turns == 1 + assert state.turn_done.is_set() assert state.is_complete() @@ -259,12 +255,11 @@ async def test_error_response_marks_turn_failed(): from inference_endpoint.core.types import ErrorData conv_manager = ConversationManager() - await conv_manager.get_or_create("conv1", expected_client_turns=1) + conv_manager.get_or_create("conv1", expected_client_turns=1) metadata = _make_dataset_metadata({"conv1": [1]}) strategy = MultiTurnStrategy(conv_manager, metadata) strategy._inflight["q0001"] = "conv1" - await conv_manager.mark_turn_issued("conv1", 1) result = QueryResult( id="q0001", @@ -272,8 +267,114 @@ async def test_error_response_marks_turn_failed(): error=ErrorData(error_type="timeout", error_message="timed out"), ) strategy.on_sample_complete(result) - await asyncio.sleep(0.01) state = conv_manager.get_state("conv1") assert state is not None - assert state.failed_client_turns == 1 + assert state.failed_turns == 1 + + +def _make_metadata_with_system( + conversations: dict[str, list[int]], + system_prompts: dict[str, str | None] | None = None, +) -> dict: + """Build metadata dict including system_prompts_by_conv.""" + samples = [] + sample_index = 0 + for conv_id, turns in 
conversations.items(): + for turn in turns: + samples.append( + { + "conversation_id": conv_id, + "turn": turn, + "sample_index": sample_index, + } + ) + sample_index += 1 + return { + "samples": samples, + "system_prompts_by_conv": system_prompts or {}, + } + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_live_history_initializes_system_prompt(): + """In live-history mode, ConversationManager.message_history starts with system message.""" + from inference_endpoint.config.schema import MultiTurnConfig + + conv_manager = ConversationManager() + metadata = _make_metadata_with_system( + {"conv1": [1]}, + system_prompts={"conv1": "Be helpful"}, + ) + mt_cfg = MultiTurnConfig(use_dataset_history=False, turn_timeout_s=10.0) + strategy = MultiTurnStrategy(conv_manager, metadata, multi_turn_config=mt_cfg) + issuer = FakePhaseIssuer() + + async def complete_turn(): + await asyncio.sleep(0.01) + await conv_manager.mark_turn_complete("conv1", "response") + + asyncio.create_task(complete_turn()) + await strategy.execute(issuer) + + state = conv_manager.get_state("conv1") + assert state is not None + # message_history[0] must be the system message + assert len(state.message_history) >= 1 + assert state.message_history[0] == {"role": "system", "content": "Be helpful"} + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_live_history_no_system_prompt_when_none(): + """In live-history mode, no system message is prepended when system_prompt is None.""" + from inference_endpoint.config.schema import MultiTurnConfig + + conv_manager = ConversationManager() + metadata = _make_metadata_with_system( + {"conv1": [1]}, + system_prompts={"conv1": None}, + ) + mt_cfg = MultiTurnConfig(use_dataset_history=False, turn_timeout_s=10.0) + strategy = MultiTurnStrategy(conv_manager, metadata, multi_turn_config=mt_cfg) + issuer = FakePhaseIssuer() + + async def complete_turn(): + await asyncio.sleep(0.01) + await conv_manager.mark_turn_complete("conv1", "response") + + 
asyncio.create_task(complete_turn()) + await strategy.execute(issuer) + + state = conv_manager.get_state("conv1") + assert state is not None + # No system message should be in history + system_msgs = [m for m in state.message_history if m.get("role") == "system"] + assert len(system_msgs) == 0 + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_dataset_history_mode_does_not_inject_system_prompt(): + """In dataset-history mode (use_dataset_history=True), system_message is not passed.""" + conv_manager = ConversationManager() + metadata = _make_metadata_with_system( + {"conv1": [1]}, + system_prompts={"conv1": "Some system"}, + ) + # Default: use_dataset_history=True → _store_in_history=False + strategy = MultiTurnStrategy(conv_manager, metadata) + issuer = FakePhaseIssuer() + + async def complete_turn(): + await asyncio.sleep(0.01) + await conv_manager.mark_turn_complete("conv1", "response") + + asyncio.create_task(complete_turn()) + await strategy.execute(issuer) + + state = conv_manager.get_state("conv1") + assert state is not None + # message_history should be empty (dataset-history mode doesn't accumulate) + assert len(state.message_history) == 0 diff --git a/tests/unit/openai/test_openai_adapter.py b/tests/unit/openai/test_openai_adapter.py new file mode 100644 index 00000000..506ec3fe --- /dev/null +++ b/tests/unit/openai/test_openai_adapter.py @@ -0,0 +1,147 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for OpenAIAdapter tool serialization.""" + +import json + +import msgspec +import pytest +from inference_endpoint.core.types import Query +from inference_endpoint.openai.openai_adapter import OpenAIAdapter + +_TOOL_DEF = { + "type": "function", + "function": { + "name": "search", + "description": "Search the web", + "parameters": { + "type": "object", + "properties": {"q": {"type": "string"}}, + "required": ["q"], + }, + }, +} + +_TOOL_CALLS = [ + { + "id": "call_1", + "type": "function", + "function": {"name": "search", "arguments": '{"q": "test"}'}, + } +] + + +@pytest.mark.unit +def test_tool_definitions_forwarded(): + """tools array in query.data is present in the encoded request.""" + messages = [ + {"role": "user", "content": "Find something"}, + ] + query = Query( + id="q1", + data={ + "model": "test-model", + "messages": messages, + "tools": [_TOOL_DEF], + "max_completion_tokens": 128, + "stream": False, + }, + ) + encoded = OpenAIAdapter.encode_query(query) + payload = json.loads(encoded) + + assert "tools" in payload + assert len(payload["tools"]) == 1 + assert payload["tools"][0]["function"]["name"] == "search" + + +@pytest.mark.unit +def test_tool_use_messages_roundtrip(): + """Full tool-use message sequence encodes and decodes without data loss.""" + messages = [ + {"role": "system", "content": "You are helpful"}, + {"role": "user", "content": "Find something"}, + {"role": "assistant", "content": None, "tool_calls": _TOOL_CALLS}, + {"role": "tool", "content": "search result", "tool_call_id": "call_1"}, + {"role": "assistant", "content": "Here is the answer"}, + ] + query = Query( + id="q1", + data={ + "model": "test-model", + "messages": messages, + "tools": [_TOOL_DEF], + "max_completion_tokens": 128, + "stream": False, + }, + ) + encoded = OpenAIAdapter.encode_query(query) + payload = json.loads(encoded) + + msgs = 
payload["messages"] + assert msgs[0]["role"] == "system" + assert msgs[1]["role"] == "user" + # assistant tool-dispatch: content is None (Pydantic model_dump includes None fields) + assert msgs[2]["role"] == "assistant" + assert msgs[2]["tool_calls"] == _TOOL_CALLS + assert msgs[2].get("content") is None + # tool result + assert msgs[3]["role"] == "tool" + assert msgs[3]["tool_call_id"] == "call_1" + assert msgs[3]["content"] == "search result" + # terminal assistant + assert msgs[4]["content"] == "Here is the answer" + + +@pytest.mark.unit +def test_encode_request_produces_valid_json_bytes(): + """encode_request returns bytes that msgspec can decode back.""" + messages = [{"role": "user", "content": "Hello"}] + query = Query( + id="q2", + data={ + "model": "m", + "messages": messages, + "max_completion_tokens": 64, + "stream": False, + }, + ) + request = OpenAIAdapter.to_endpoint_request(query) + encoded = OpenAIAdapter.encode_request(request) + + assert isinstance(encoded, bytes) + decoded = msgspec.json.decode(encoded) + assert decoded["messages"][0]["role"] == "user" + + +@pytest.mark.unit +def test_no_tools_key_when_absent(): + """When query.data has no 'tools', the encoded payload has tools=None.""" + messages = [{"role": "user", "content": "Hello"}] + query = Query( + id="q3", + data={ + "model": "m", + "messages": messages, + "max_completion_tokens": 64, + "stream": False, + }, + ) + encoded = OpenAIAdapter.encode_query(query) + payload = json.loads(encoded) + + # Pydantic model_dump includes None fields; tools must be None when not supplied + assert payload.get("tools") is None From d481c3c3fc5189afafbd8fcad9c1e59643d732d5 Mon Sep 17 00:00:00 2001 From: "Li, Tianmu" Date: Fri, 24 Apr 2026 20:58:00 -0700 Subject: [PATCH 05/13] docs: add multi-turn quickstart, examples, and conversion scripts Add MULTI_TURN_QUICKSTART.md, examples/09_MultiTurn/ configs and sample data, scripts/convert_agentic_snapshot.py, and README clarifications including conversion script 
output destination. --- docs/MULTI_TURN_QUICKSTART.md | 96 +++-- examples/09_MultiTurn/README.md | 35 +- .../agentic_coding_benchmark.yaml | 2 +- .../agentic_workflow_benchmark.yaml | 2 +- examples/09_MultiTurn/datasets/.gitkeep | 0 .../09_MultiTurn/multi_turn_benchmark.yaml | 1 - .../multi_turn_with_concurrency.yaml | 1 - scripts/convert_agentic_snapshot.py | 356 ++++++++++++++++++ scripts/validate_jsonl_schema.py | 126 +++++++ .../load_generator/conversation_manager.py | 51 +-- .../load_generator/multi_turn_strategy.py | 5 +- .../test_multi_turn_strategy.py | 111 +++++- 12 files changed, 698 insertions(+), 88 deletions(-) create mode 100644 examples/09_MultiTurn/datasets/.gitkeep create mode 100644 scripts/convert_agentic_snapshot.py create mode 100644 scripts/validate_jsonl_schema.py diff --git a/docs/MULTI_TURN_QUICKSTART.md b/docs/MULTI_TURN_QUICKSTART.md index 73ed6678..99b35aa5 100644 --- a/docs/MULTI_TURN_QUICKSTART.md +++ b/docs/MULTI_TURN_QUICKSTART.md @@ -1,6 +1,6 @@ # Multi-Turn Conversation Benchmarking - Quick Start Guide -## 🚀 Quick Start in 5 Minutes +## Quick Start in 5 Minutes ### 1. Prepare Your Dataset @@ -40,7 +40,6 @@ datasets: - name: my_conversations type: performance path: path/to/your/conversations.jsonl - format: ".jsonl" multi_turn: # ← Presence of this block enables multi-turn mode mode: independent # ← Per-conv pipelines; no cross-conv turn barrier turn_timeout_s: 300 # ← Max wait for prev turn @@ -48,7 +47,7 @@ datasets: settings: load_pattern: type: multi_turn # ← Use multi-turn scheduler - target_concurrency: 32 # ← OPTIONAL: limit concurrent requests + target_concurrency: 32 # ← Required: max concurrent requests client: workers: 4 @@ -78,24 +77,26 @@ That's it!
Your benchmark will now: --- -## 📊 Understanding Results +## Understanding Results After the benchmark completes, check the directory configured via `report_dir`: -### Events Database +### Events Log -The `events.db` SQLite database includes: +The `events.jsonl` file contains one JSON record per line: -- Standard fields: sample_uuid, event_type, timestamp_ns -- **New fields**: conversation_id, turn_number +- Standard fields: `sample_uuid`, `event_type`, `timestamp_ns` +- **New fields**: `conversation_id`, `turn_number` -Query example: +Query examples: -```sql -SELECT conversation_id, turn_number, event_type, timestamp_ns -FROM events -WHERE conversation_id = 'c1' -ORDER BY turn_number; +```bash +# All events for a specific conversation +grep '"conversation_id": "c1"' logs/my_multi_turn_benchmark/events.jsonl + +# With jq for structured output +jq 'select(.conversation_id == "c1") | {conversation_id, turn_number, event_type, timestamp_ns}' \ + logs/my_multi_turn_benchmark/events.jsonl ``` ### Metrics @@ -109,7 +110,7 @@ _Note: Per-conversation aggregation (e.g., "conversations/sec") is coming in a f --- -## 🎯 Conversation Modes Explained +## Conversation Modes Explained ### Independent Mode (Default) @@ -140,9 +141,9 @@ t=0.8: conv1-turn3 (after conv1-turn2 completes) --- -## 🎛️ Concurrency Control (NEW!) +## Concurrency Control -For benchmarks with **> 50 conversations**, use `target_concurrency` to prevent endpoint overload: +`target_concurrency` is **required** for the `multi_turn` load pattern. It limits the maximum number of in-flight requests across all conversations and prevents endpoint overload when many conversations run simultaneously. ```yaml settings: @@ -151,17 +152,15 @@ settings: target_concurrency: 32 # ← Limit to 32 concurrent requests ``` -**Why?** Without this, independent mode issues ALL turn-1s at once (could be 100+), overwhelming your endpoint.
- -**Rule of thumb**: +**Sizing guide**: -- Small (< 50 convs): No limit needed -- Medium (50-500 convs): `target_concurrency: 32` -- Large (500+ convs): `target_concurrency: 64` +- Small (< 50 convs): `target_concurrency: 32` +- Medium (50-500 convs): `target_concurrency: 64` +- Large (500+ convs): `target_concurrency: 96` or higher --- -## 🔧 Common Configurations +## Common Configurations ### Recommended: With Concurrency Control @@ -188,6 +187,9 @@ multi_turn: turn_timeout_s: 600 settings: + load_pattern: + type: multi_turn + target_concurrency: 96 client: workers: 16 # More workers for parallel conversations ``` @@ -198,17 +200,27 @@ settings: multi_turn: mode: independent turn_timeout_s: 1800 # 30 minutes for slow responses + +settings: + load_pattern: + type: multi_turn + target_concurrency: 32 ``` --- -## ❓ Troubleshooting +## Troubleshooting ### "Conversation has invalid role sequence" -**Problem**: Your dataset doesn't alternate between user/assistant. +**Problem**: Your dataset doesn't follow a valid role sequence. + +**Fix**: Check your JSONL. Valid sequences: + +- Plain chat: `user → assistant → user → assistant → ...` +- Agentic (tool-use): `user → assistant → tool → assistant → tool → ... → user` -**Fix**: Check your JSONL - must be: user, assistant, user, assistant, ... +Conversations may also end with a `tool` row (the model's response to the final tool call is the benchmark target). ### "Rows for conversation X are not consecutive" @@ -230,17 +242,19 @@ multi_turn: **Problem**: MultiTurnDataset not recognized. -**Fix**: Ensure `format: ".jsonl"` is specified in config: +**Fix**: Ensure `multi_turn:` block is present in the dataset config.
The file format +is auto-detected from the `.jsonl` extension — no `format` field is needed: ```yaml datasets: - path: your_file.jsonl - format: ".jsonl" # ← Required for JSONL + multi_turn: + mode: independent ``` --- -## 📝 Example Datasets +## Example Datasets ### Simple 2-Turn Conversation @@ -274,12 +288,12 @@ datasets: --- -## 🧪 Testing Your Setup +## Testing Your Setup ### 1. Use the Example Dataset ```bash -cd examples/multi_turn +cd examples/09_MultiTurn inference-endpoint benchmark from-config --config multi_turn_benchmark.yaml ``` @@ -293,14 +307,14 @@ cat logs/multi_turn_test/benchmark.log ### 3. Verify Event Recording ```bash -sqlite3 logs/multi_turn_test/events.db -sqlite> SELECT DISTINCT conversation_id FROM events; +# List all unique conversation IDs in the events log +jq -r '.conversation_id' logs/multi_turn_test/events.jsonl | sort -u # Should show your conversation IDs ``` --- -## 💡 Tips & Best Practices +## Tips & Best Practices ### Dataset Design @@ -318,28 +332,28 @@ sqlite> SELECT DISTINCT conversation_id FROM events; - **Start small**: Test with 1-2 conversations first - **Single conversation**: Use `mode: independent` with `target_concurrency: 1` -- **Check events.db**: Verify turn ordering in database +- **Check events.jsonl**: Verify turn ordering with `jq` --- -## 🔗 More Information +## More Information - **Full Documentation**: See `examples/09_MultiTurn/README.md` - **Architecture**: See `AGENTS.md` (Multi-Turn section) --- -## ✅ Checklist +## Checklist Before running your first multi-turn benchmark: -- [ ] Dataset follows format (alternating user/assistant roles) +- [ ] Dataset follows format (user/assistant alternation, or agentic user→assistant→tool sequences) - [ ] All rows for each conversation_id are grouped together - [ ] Config has `multi_turn:` block in the dataset section - [ ] Config has `load_pattern.type: multi_turn` - [ ] Endpoint is running and reachable -- [ ] `format: ".jsonl"` specified for JSONL
datasets +- [ ] File uses `.jsonl` extension (format is auto-detected) - [ ] Conversation IDs are unique per conversation - [ ] Turn numbers are sequential (1, 2, 3, ...) -Happy benchmarking! 🚀 +Happy benchmarking! diff --git a/examples/09_MultiTurn/README.md b/examples/09_MultiTurn/README.md index 0d3348c7..e7f9505a 100644 --- a/examples/09_MultiTurn/README.md +++ b/examples/09_MultiTurn/README.md @@ -78,17 +78,21 @@ The following commands convert each source snapshot file to the flat-row format Run from the repo root: ```bash +# First argument: input snapshot JSONL; second argument: output flat-row JSONL python scripts/convert_agentic_snapshot.py \ - /path/to/agentic_coding_dataset.jsonl \ # input snapshot JSONL - examples/09_MultiTurn/datasets/agentic_coding_flat.jsonl \ # output flat-row JSONL + /path/to/agentic_coding_dataset.jsonl \ + examples/09_MultiTurn/datasets/agentic_coding_flat.jsonl \ --verify python scripts/convert_agentic_snapshot.py \ - /path/to/agentic_workflow_dataset.jsonl \ # input snapshot JSONL - examples/09_MultiTurn/datasets/agentic_workflow_flat.jsonl \ # output flat-row JSONL + /path/to/agentic_workflow_dataset.jsonl \ + examples/09_MultiTurn/datasets/agentic_workflow_flat.jsonl \ --verify ``` +The `datasets/` directory under `examples/09_MultiTurn/` is a placeholder; run the conversion +commands above to populate it before benchmarking. + The `--verify` flag cross-checks every client turn's message history against the source snapshot and exits with code 1 if any mismatch is found.
The script also: @@ -138,8 +142,7 @@ inference-endpoint benchmark from-config \ datasets: - name: customer_support type: performance - path: examples/multi_turn/customer_support_conversations.jsonl - format: ".jsonl" + path: examples/09_MultiTurn/customer_support_conversations.jsonl multi_turn: mode: independent turn_timeout_s: 300.0 @@ -147,11 +150,12 @@ datasets: settings: load_pattern: type: multi_turn + target_concurrency: 32 # ← Required for multi_turn load pattern ``` -### Concurrency Control (Optional) +### Concurrency Control -The multi-turn scheduler supports **optional concurrency limiting** to control the maximum number of in-flight requests across all conversations: +The `target_concurrency` field is **required** for the `multi_turn` load pattern and controls the maximum number of in-flight requests across all conversations: ```yaml settings: @@ -162,15 +166,14 @@ settings: **Behavior**: -- Without `target_concurrency`: Unlimited concurrency (all turn-1s issue at t=0 in INDEPENDENT mode) - With `target_concurrency`: Limits total in-flight requests across all conversations - Combines with turn sequencing: Turn N+1 still waits for turn N, AND waits for available slot **Use cases**: -- 🎯 **Prevent endpoint overload**: Control request rate to busy endpoints -- 🎯 **Large-scale testing**: Benchmark 1000+ conversations without overwhelming system -- 🎯 **Resource management**: Stay within port limits, memory constraints +- **Prevent endpoint overload**: Control request rate to busy endpoints +- **Large-scale testing**: Benchmark 1000+ conversations without overwhelming system +- **Resource management**: Stay within port limits, memory constraints **Example**: 100 conversations with `target_concurrency: 32` @@ -221,7 +224,7 @@ If a turn times out waiting for the previous turn, it will be skipped and logged ```bash inference-endpoint benchmark from-config \ - --config examples/multi_turn/multi_turn_benchmark.yaml + --config 
examples/09_MultiTurn/multi_turn_benchmark.yaml ``` ### Viewing Results @@ -231,7 +234,7 @@ Multi-turn benchmarks produce both per-turn and per-conversation metrics: - **Per-turn metrics**: Latency, TTFT, TPOT for each individual turn - **Per-conversation metrics**: Total conversation latency, conversations per second -Results are stored in the configured `report_dir` with conversation metadata included in the events database. +Results are stored in the configured `report_dir` with conversation metadata included in the events log (`events.jsonl`). ## Example Datasets @@ -248,8 +251,7 @@ Simple customer support conversations demonstrating basic multi-turn interaction ### Key Components - **ConversationManager**: Tracks conversation state and message history -- **MultiTurnScheduler**: Enforces turn sequencing within conversations -- **ConversationSample**: Sample with conversation metadata +- **MultiTurnStrategy**: Enforces turn sequencing within conversations - **MultiTurnDataset**: Validates and structures multi-turn data ### Turn Sequencing @@ -307,5 +309,4 @@ Planned features: - [ ] Poisson conversation arrival mode implementation - [ ] Per-conversation metrics in reporting - [ ] Conversation-level latency percentiles -- [ ] Support for tool/function calls in conversations - [ ] Dynamic conversation branching diff --git a/examples/09_MultiTurn/agentic_coding_benchmark.yaml b/examples/09_MultiTurn/agentic_coding_benchmark.yaml index 5a1036a7..f3abc3cf 100644 --- a/examples/09_MultiTurn/agentic_coding_benchmark.yaml +++ b/examples/09_MultiTurn/agentic_coding_benchmark.yaml @@ -10,8 +10,8 @@ datasets: - name: agentic_coding type: performance # Run: python scripts/convert_agentic_snapshot.py examples/09_MultiTurn/datasets/agentic_coding_flat.jsonl --verify + # The datasets/ directory is a placeholder; populate it with the conversion script above. 
path: examples/09_MultiTurn/datasets/agentic_coding_flat.jsonl - format: ".jsonl" multi_turn: mode: independent turn_timeout_s: 600.0 diff --git a/examples/09_MultiTurn/agentic_workflow_benchmark.yaml b/examples/09_MultiTurn/agentic_workflow_benchmark.yaml index e8885465..239e9374 100644 --- a/examples/09_MultiTurn/agentic_workflow_benchmark.yaml +++ b/examples/09_MultiTurn/agentic_workflow_benchmark.yaml @@ -10,8 +10,8 @@ datasets: - name: agentic_workflow type: performance # Run: python scripts/convert_agentic_snapshot.py examples/09_MultiTurn/datasets/agentic_workflow_flat.jsonl --verify + # The datasets/ directory is a placeholder; populate it with the conversion script above. path: examples/09_MultiTurn/datasets/agentic_workflow_flat.jsonl - format: ".jsonl" multi_turn: mode: independent turn_timeout_s: 600.0 diff --git a/examples/09_MultiTurn/datasets/.gitkeep b/examples/09_MultiTurn/datasets/.gitkeep new file mode 100644 index 00000000..e69de29b diff --git a/examples/09_MultiTurn/multi_turn_benchmark.yaml b/examples/09_MultiTurn/multi_turn_benchmark.yaml index da4773e0..36066aa3 100644 --- a/examples/09_MultiTurn/multi_turn_benchmark.yaml +++ b/examples/09_MultiTurn/multi_turn_benchmark.yaml @@ -11,7 +11,6 @@ datasets: - name: customer_support_conversations type: performance path: examples/09_MultiTurn/customer_support_conversations.jsonl - format: ".jsonl" samples: 10 multi_turn: mode: independent diff --git a/examples/09_MultiTurn/multi_turn_with_concurrency.yaml b/examples/09_MultiTurn/multi_turn_with_concurrency.yaml index ba5362e3..e1d5f37c 100644 --- a/examples/09_MultiTurn/multi_turn_with_concurrency.yaml +++ b/examples/09_MultiTurn/multi_turn_with_concurrency.yaml @@ -11,7 +11,6 @@ datasets: - name: customer_support_conversations type: performance path: examples/09_MultiTurn/customer_support_conversations.jsonl - format: ".jsonl" samples: 10 multi_turn: mode: independent # All conv turn-1 start together diff --git 
a/scripts/convert_agentic_snapshot.py b/scripts/convert_agentic_snapshot.py new file mode 100644 index 00000000..fe217b9b --- /dev/null +++ b/scripts/convert_agentic_snapshot.py @@ -0,0 +1,356 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Convert agentic snapshot datasets to the flat-row JSONL format expected by MultiTurnDataset. + +Each snapshot record contains the full conversation history up to a checkpoint: + {"conversation_id": "sim_000001", "conversation_idx": 0, + "messages": [{"role": "system", ...}, ...], "tools": [...], "metadata": {}} + +For each conversation only the final snapshot (highest conversation_idx) is used. +Its messages array is expanded into individual flat rows, one per message. 
+ +Usage: + python scripts/convert_agentic_snapshot.py INPUT.jsonl OUTPUT.jsonl + python scripts/convert_agentic_snapshot.py INPUT.jsonl OUTPUT.jsonl --verify +""" + +import argparse +import json +import sys +from pathlib import Path + +# --------------------------------------------------------------------------- +# Helpers shared between convert() and verify() +# --------------------------------------------------------------------------- + + +def _load_final_snapshots(input_path: Path) -> dict[str, dict]: + """Return {conv_id: record} keeping only the highest conversation_idx per conv.""" + final: dict[str, dict] = {} + with input_path.open() as fh: + for line in fh: + line = line.strip() + if not line: + continue + record = json.loads(line) + conv_id = record["conversation_id"] + if ( + conv_id not in final + or record["conversation_idx"] > final[conv_id]["conversation_idx"] + ): + final[conv_id] = record + return final + + +def _apply_collapses(non_system: list[dict]) -> list[tuple[dict, int]]: + """Apply user-collapse and tool-merge passes, tracking the last source index each + output row covers. + + Returns list of (output_msg, last_source_idx) pairs where last_source_idx is the + 0-based index within non_system of the final source message folded into this row. + """ + # Pass 1: collapse consecutive user messages + collapsed: list[tuple[dict, int]] = [] # (msg, last_source_idx) + for src_idx, msg in enumerate(non_system): + if collapsed and collapsed[-1][0]["role"] == "user" and msg["role"] == "user": + prev_msg, _ = collapsed[-1] + prev_text = prev_msg.get("content") or "" + cur_text = msg.get("content") or "" + collapsed[-1] = ( + {**prev_msg, "content": f"{prev_text}\n\n{cur_text}".strip()}, + src_idx, + ) + else: + collapsed.append((msg, src_idx)) + + # Pass 2: merge consecutive tool messages + # Input messages are raw snapshot wire-format (tool_call_id + content on each msg). 
+ # On merge, upgrade the first message to a tool_results list so the output always + # uses the tool_results array form regardless of how many results there are. + merged: list[tuple[dict, int]] = [] + for msg, last_src in collapsed: + if merged and merged[-1][0]["role"] == "tool" and msg["role"] == "tool": + prev_msg, _ = merged[-1] + tool_results = prev_msg.get("tool_results") + if tool_results is None: + tool_results = [ + { + "tool_call_id": prev_msg.get("tool_call_id"), + "content": prev_msg.get("content"), + } + ] + prev_msg = {"role": "tool", "tool_results": tool_results} + tool_results.append( + { + "tool_call_id": msg.get("tool_call_id"), + "content": msg.get("content"), + } + ) + merged[-1] = (prev_msg, last_src) + else: + merged.append((msg, last_src)) + + return merged + + +def _normalize_msg(msg: dict) -> dict: + """Drop None values for comparison.""" + return {k: v for k, v in msg.items() if v is not None} + + +def _expand_row_to_wire_msgs(row: dict) -> list[dict]: + """Expand a single flat row into one or more OpenAI wire-format messages. + + Handles two tool row forms: + - Output flat rows: tool_results array (always used after conversion) + - Raw snapshot messages passed through verify(): tool_call_id + content directly + """ + if isinstance(row.get("tool_results"), list): + return [ + { + "role": "tool", + "tool_call_id": r.get("tool_call_id"), + "content": r.get("content"), + } + for r in row["tool_results"] + ] + msg: dict = {"role": row["role"], "content": row.get("content")} + if row.get("tool_calls"): + msg["tool_calls"] = row["tool_calls"] + if row.get("tool_call_id"): + msg["tool_call_id"] = row["tool_call_id"] + return [msg] + + +def verify(input_path: Path, output_path: Path) -> bool: + """Cross-check every client-turn's pre_built_messages against the source snapshot. 
+ + For each output client turn, reconstruct the pre_built_messages that + MultiTurnDataset would build from the flat rows and compare it against the + ground-truth messages built directly from the source snapshot up to the same + point (accounting for user-collapse and tool-merge). + + Returns: + True if all checks pass, False if any mismatch found. + """ + final = _load_final_snapshots(input_path) + + # Load converted rows grouped by conversation_id + conv_rows: dict[str, list[dict]] = {} + with output_path.open() as fh: + for line in fh: + line = line.strip() + if not line: + continue + row = json.loads(line) + cid = row["conversation_id"] + conv_rows.setdefault(cid, []).append(row) + for cid in conv_rows: + conv_rows[cid].sort(key=lambda r: r["turn"]) + + errors: list[str] = [] + total_checked = 0 + + for conv_id in sorted(final): + record = final[conv_id] + system_content: str | None = None + non_system: list[dict] = [] + for msg in record["messages"]: + if msg["role"] == "system": + system_content = msg.get("content") + else: + non_system.append(msg) + + # Re-apply the same collapses the converter applies, tracking source coverage + processed = _apply_collapses(non_system) # [(output_msg, last_source_idx), ...] + flat_rows = conv_rows.get(conv_id, []) + + if len(processed) != len(flat_rows): + errors.append( + f"{conv_id}: expected {len(processed)} flat rows after collapses, " + f"got {len(flat_rows)} in output" + ) + continue + + client_turn_pairs = [ + (out_pos, flat_row) + for out_pos, (flat_row, _) in enumerate( + zip(flat_rows, processed, strict=True) + ) + if flat_row["role"] in ("user", "tool") + ] + + for ct_idx, (out_pos, flat_row) in enumerate(client_turn_pairs): + # Ground truth: apply the same collapses the converter applies, then + # build the message list from the processed (collapsed/merged) rows up to + # and including this client turn. 
This correctly reflects what the + # converter produces — consecutive user/tool merges mean history is + # shorter than the raw source but content-equivalent. + expected: list[dict] = [] + if system_content: + expected.append({"role": "system", "content": system_content}) + for proc_msg, _ in processed[: out_pos + 1]: + expected.extend(_expand_row_to_wire_msgs(proc_msg)) + + # Reconstructed output: system + expand all flat rows up to this turn + got: list[dict] = [] + if system_content: + got.append({"role": "system", "content": system_content}) + for row in flat_rows[: out_pos + 1]: + got.extend(_expand_row_to_wire_msgs(row)) + + exp_norm = [_normalize_msg(m) for m in expected] + got_norm = [_normalize_msg(m) for m in got] + + if exp_norm != got_norm: + errors.append( + f"{conv_id} client-turn {ct_idx + 1} (flat turn {flat_row['turn']}):\n" + f" expected {len(exp_norm)} msgs, got {len(got_norm)}\n" + f" EXPECTED: {json.dumps(exp_norm, ensure_ascii=False)[:400]}\n" + f" GOT: {json.dumps(got_norm, ensure_ascii=False)[:400]}" + ) + total_checked += 1 + + if errors: + print( + f"FAIL: {len(errors)} mismatches out of {total_checked} client turns checked.", + file=sys.stderr, + ) + for err in errors[:20]: + print(err, file=sys.stderr) + if len(errors) > 20: + print(f" ... and {len(errors) - 20} more", file=sys.stderr) + return False + + print( + f"OK: all {total_checked} client turns verified against source.", + file=sys.stderr, + ) + return True + + +def convert(input_path: Path, output_path: Path) -> None: + # Group records by conversation_id, keep only the final snapshot per conversation.
+ final: dict[str, dict] = {} + with input_path.open() as fh: + for line in fh: + line = line.strip() + if not line: + continue + record = json.loads(line) + conv_id = record["conversation_id"] + if ( + conv_id not in final + or record["conversation_idx"] > final[conv_id]["conversation_idx"] + ): + final[conv_id] = record + + print(f"Found {len(final)} conversations in {input_path.name}", file=sys.stderr) + + rows_written = 0 + with output_path.open("w") as out: + for conv_id, record in sorted(final.items()): + messages = record["messages"] + tools = record.get("tools") or [] + + # Extract system message (always first if present). + system_content: str | None = None + non_system: list[dict] = [] + for msg in messages: + if msg["role"] == "system": + system_content = msg.get("content") + else: + non_system.append(msg) + + # Apply the same user-collapse and tool-merge passes used by verify(). + # _apply_collapses returns [(msg, last_source_idx), ...]; strip the indices. + non_system = [msg for msg, _ in _apply_collapses(non_system)] + + first_user_seen = False + for position, msg in enumerate(non_system): + role = msg["role"] + turn = position + 1 # 1-indexed + + row: dict = {"conversation_id": conv_id, "turn": turn, "role": role} + + # System prompt on the first user row only. + if role == "user" and not first_user_seen: + if system_content is not None: + row["system"] = system_content + first_user_seen = True + + # tool_calls for assistant messages that dispatch tools. + if msg.get("tool_calls"): + row["tool_calls"] = msg["tool_calls"] + + if role == "tool": + # All tool rows use tool_results array (single results have one entry). 
+ if msg.get("tool_results"): + row["tool_results"] = msg["tool_results"] + else: + row["tool_results"] = [ + { + "tool_call_id": msg.get("tool_call_id"), + "content": msg.get("content"), + } + ] + else: + # content field (may be None for tool-dispatching assistant messages) + row["content"] = msg.get("content") + + # Attach tool definitions to client-turn rows only (user + tool). + # This avoids duplicating the large tools array on every assistant row + # while still making them available via load_sample(). + if role in ("user", "tool") and tools: + row["tools"] = tools + + out.write(json.dumps(row, ensure_ascii=False) + "\n") + rows_written += 1 + + print(f"Wrote {rows_written} rows to {output_path}", file=sys.stderr) + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Convert agentic snapshot JSONL to MultiTurnDataset flat-row JSONL." + ) + parser.add_argument("input", type=Path, help="Input snapshot JSONL file") + parser.add_argument("output", type=Path, help="Output flat-row JSONL file") + parser.add_argument( + "--verify", + action="store_true", + help=( + "After converting, cross-check every client-turn's pre_built_messages " + "against the source snapshot. Exits with code 1 if any mismatch found." + ), + ) + args = parser.parse_args() + + if not args.input.exists(): + print(f"Error: input file not found: {args.input}", file=sys.stderr) + sys.exit(1) + + convert(args.input, args.output) + + if args.verify: + ok = verify(args.input, args.output) + if not ok: + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/scripts/validate_jsonl_schema.py b/scripts/validate_jsonl_schema.py new file mode 100644 index 00000000..d2bb7177 --- /dev/null +++ b/scripts/validate_jsonl_schema.py @@ -0,0 +1,126 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Validate multi-turn JSONL dataset files against multi_turn_dataset_schema.json. + +Checks each row's structure against the JSON schema (field types, required fields, +tool_results shape, etc.). Does NOT check cross-row invariants such as turn +numbering or role sequences — those are enforced by MultiTurnDataset at load time. + +Usage: + python scripts/validate_jsonl_schema.py FILE [FILE ...] + python scripts/validate_jsonl_schema.py /model/agentic_coding_flat.jsonl /model/agentic_workflow_flat.jsonl +""" + +import argparse +import json +import sys +from pathlib import Path + +try: + import jsonschema +except ImportError: + print( + "Error: jsonschema not installed. Run: pip install jsonschema", file=sys.stderr + ) + sys.exit(1) + + +def validate_file(path: Path, schema: dict, max_errors: int = 50) -> int: + """Validate every row in a JSONL file against the schema. + + Returns the number of validation errors found. 
+ """ + errors: list[str] = [] + validator = jsonschema.Draft7Validator(schema) + + with path.open() as fh: + for lineno, line in enumerate(fh, 1): + line = line.strip() + if not line: + continue + try: + row = json.loads(line) + except json.JSONDecodeError as e: + errors.append(f" line {lineno}: JSON parse error: {e}") + if len(errors) >= max_errors: + break + continue + + conv_id = row.get("conversation_id", "") + turn = row.get("turn", "?") + role = row.get("role", "?") + + row_errors = list(validator.iter_errors(row)) + for err in row_errors: + path_str = " -> ".join(str(p) for p in err.absolute_path) or "(root)" + errors.append( + f" line {lineno} [{conv_id} turn={turn} role={role}] " + f"@ {path_str}: {err.message}" + ) + + if len(errors) >= max_errors: + errors.append(f" ... stopping after {max_errors} errors") + break + + if errors: + print(f"FAIL {path.name}: {len(errors)} error(s)") + for msg in errors: + print(msg) + else: + print(f"OK {path.name}") + + return len(errors) + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Validate multi-turn JSONL files against multi_turn_dataset_schema.json." 
+ ) + parser.add_argument("files", nargs="+", type=Path, help="JSONL files to validate") + parser.add_argument( + "--schema", + type=Path, + default=Path(__file__).parent.parent / "multi_turn_dataset_schema.json", + help="Path to the JSON schema file (default: multi_turn_dataset_schema.json)", + ) + parser.add_argument( + "--max-errors", + type=int, + default=50, + help="Stop reporting after this many errors per file (default: 50)", + ) + args = parser.parse_args() + + if not args.schema.exists(): + print(f"Error: schema not found: {args.schema}", file=sys.stderr) + sys.exit(1) + + schema = json.load(args.schema.open()) + + total_errors = 0 + for path in args.files: + if not path.exists(): + print(f"Error: file not found: {path}", file=sys.stderr) + total_errors += 1 + continue + total_errors += validate_file(path, schema, max_errors=args.max_errors) + + sys.exit(1 if total_errors > 0 else 0) + + +if __name__ == "__main__": + main() diff --git a/src/inference_endpoint/load_generator/conversation_manager.py b/src/inference_endpoint/load_generator/conversation_manager.py index ba9a02ea..56bb8278 100644 --- a/src/inference_endpoint/load_generator/conversation_manager.py +++ b/src/inference_endpoint/load_generator/conversation_manager.py @@ -101,11 +101,29 @@ def get_or_create( ) return self._conversations[conversation_id] + def _log_if_complete(self, state: ConversationState, conversation_id: str) -> None: + """Log completion status once all expected turns have a response.""" + if not state.is_complete(): + return + if state.failed_turns > 0: + logger.info( + f"Conversation {conversation_id} completed with failures: " + f"{state.completed_turns - state.failed_turns}/" + f"{state.expected_client_turns} successful, " + f"{state.failed_turns} failed" + ) + else: + logger.debug( + f"Conversation {conversation_id} completed: " + f"{state.completed_turns}/{state.expected_client_turns} turns" + ) + def mark_turn_complete( self, conversation_id: str, response: str, 
store_in_history: bool = False, + metadata: dict[str, Any] | None = None, ) -> None: """Record a successful response and wake the pipeline task. @@ -113,6 +131,8 @@ def mark_turn_complete( conversation_id: Conversation ID. response: Model output (appended to history when store_in_history=True). store_in_history: When True, append response to message_history. + metadata: Optional response metadata; tool_calls are preserved in history + when present (only used when store_in_history=True). Raises: KeyError: If conversation_id not found. @@ -120,22 +140,15 @@ def mark_turn_complete( state = self._conversations.get(conversation_id) if state is None: raise KeyError(f"Conversation {conversation_id} not initialized") - if store_in_history and response: - state.message_history.append({"role": "assistant", "content": response}) + if store_in_history: + tool_calls = metadata.get("tool_calls") if metadata else None + if response or tool_calls: + msg: dict[str, Any] = {"role": "assistant", "content": response or None} + if tool_calls: + msg["tool_calls"] = tool_calls + state.message_history.append(msg) state.completed_turns += 1 - if state.is_complete(): - if state.failed_turns > 0: - logger.info( - f"Conversation {conversation_id} completed with failures: " - f"{state.completed_turns - state.failed_turns}/" - f"{state.expected_client_turns} successful, " - f"{state.failed_turns} failed" - ) - else: - logger.debug( - f"Conversation {conversation_id} completed: " - f"{state.completed_turns}/{state.expected_client_turns} turns" - ) + self._log_if_complete(state, conversation_id) state.turn_done.set() def mark_turn_failed( @@ -164,11 +177,5 @@ def mark_turn_failed( state.completed_turns += 1 state.failed_turns += 1 logger.warning(f"Turn failed for conversation {conversation_id}") - if state.is_complete(): - logger.info( - f"Conversation {conversation_id} completed with failures: " - f"{state.completed_turns - state.failed_turns}/" - f"{state.expected_client_turns} successful, " - 
f"{state.failed_turns} failed" - ) + self._log_if_complete(state, conversation_id) state.turn_done.set() diff --git a/src/inference_endpoint/load_generator/multi_turn_strategy.py b/src/inference_endpoint/load_generator/multi_turn_strategy.py index cfd418bd..0f3ba6e8 100644 --- a/src/inference_endpoint/load_generator/multi_turn_strategy.py +++ b/src/inference_endpoint/load_generator/multi_turn_strategy.py @@ -227,5 +227,8 @@ def on_sample_complete(self, result: QueryResult) -> None: ) else: self._conv_manager.mark_turn_complete( - conv_id, response_text, store_in_history=self._store_in_history + conv_id, + response_text, + store_in_history=self._store_in_history, + metadata=result.metadata, ) diff --git a/tests/unit/load_generator/test_multi_turn_strategy.py b/tests/unit/load_generator/test_multi_turn_strategy.py index 0edbb34f..1eecc75d 100644 --- a/tests/unit/load_generator/test_multi_turn_strategy.py +++ b/tests/unit/load_generator/test_multi_turn_strategy.py @@ -313,7 +313,7 @@ async def test_live_history_initializes_system_prompt(): async def complete_turn(): await asyncio.sleep(0.01) - await conv_manager.mark_turn_complete("conv1", "response") + conv_manager.mark_turn_complete("conv1", "response") asyncio.create_task(complete_turn()) await strategy.execute(issuer) @@ -342,7 +342,7 @@ async def test_live_history_no_system_prompt_when_none(): async def complete_turn(): await asyncio.sleep(0.01) - await conv_manager.mark_turn_complete("conv1", "response") + conv_manager.mark_turn_complete("conv1", "response") asyncio.create_task(complete_turn()) await strategy.execute(issuer) @@ -369,7 +369,7 @@ async def test_dataset_history_mode_does_not_inject_system_prompt(): async def complete_turn(): await asyncio.sleep(0.01) - await conv_manager.mark_turn_complete("conv1", "response") + conv_manager.mark_turn_complete("conv1", "response") asyncio.create_task(complete_turn()) await strategy.execute(issuer) @@ -378,3 +378,108 @@ async def complete_turn(): assert state is not 
None # message_history should be empty (dataset-history mode doesn't accumulate) assert len(state.message_history) == 0 + + +@pytest.mark.unit +def test_mark_turn_complete_preserves_tool_calls(): + """mark_turn_complete stores tool_calls in history when metadata contains them.""" + conv_manager = ConversationManager() + conv_manager.get_or_create("conv1", expected_client_turns=1) + + tool_calls = [ + { + "id": "call_1", + "type": "function", + "function": {"name": "bash", "arguments": '{"cmd": "ls"}'}, + } + ] + conv_manager.mark_turn_complete( + "conv1", + response="", + store_in_history=True, + metadata={"tool_calls": tool_calls}, + ) + + state = conv_manager.get_state("conv1") + assert state is not None + assert len(state.message_history) == 1 + msg = state.message_history[0] + assert msg["role"] == "assistant" + assert msg["content"] is None + assert msg["tool_calls"] == tool_calls + + +@pytest.mark.unit +def test_mark_turn_complete_with_response_and_tool_calls(): + """mark_turn_complete stores both content and tool_calls when both are present.""" + conv_manager = ConversationManager() + conv_manager.get_or_create("conv1", expected_client_turns=1) + + tool_calls = [ + { + "id": "call_1", + "type": "function", + "function": {"name": "search", "arguments": "{}"}, + } + ] + conv_manager.mark_turn_complete( + "conv1", + response="Calling search...", + store_in_history=True, + metadata={"tool_calls": tool_calls}, + ) + + state = conv_manager.get_state("conv1") + assert state is not None + msg = state.message_history[0] + assert msg["content"] == "Calling search..." 
+ assert msg["tool_calls"] == tool_calls + + +@pytest.mark.unit +def test_mark_turn_complete_no_history_when_empty(): + """mark_turn_complete does not append when response is empty and no tool_calls.""" + conv_manager = ConversationManager() + conv_manager.get_or_create("conv1", expected_client_turns=1) + + conv_manager.mark_turn_complete("conv1", response="", store_in_history=True) + + state = conv_manager.get_state("conv1") + assert state is not None + assert len(state.message_history) == 0 + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_on_sample_complete_passes_metadata(): + """on_sample_complete forwards result.metadata (including tool_calls) to ConversationManager.""" + from inference_endpoint.config.schema import MultiTurnConfig + + conv_manager = ConversationManager() + metadata_dict = _make_metadata_with_system({"conv1": [1]}) + mt_cfg = MultiTurnConfig(use_dataset_history=False, turn_timeout_s=10.0) + strategy = MultiTurnStrategy(conv_manager, metadata_dict, multi_turn_config=mt_cfg) + + conv_manager.get_or_create("conv1", expected_client_turns=1) + strategy._inflight["q0001"] = "conv1" + + tool_calls = [ + { + "id": "call_1", + "type": "function", + "function": {"name": "bash", "arguments": "{}"}, + } + ] + result = QueryResult( + id="q0001", + response_output=TextModelOutput(output=""), + metadata={"tool_calls": tool_calls}, + ) + strategy.on_sample_complete(result) + + state = conv_manager.get_state("conv1") + assert state is not None + assert state.completed_turns == 1 + assert len(state.message_history) == 1 + assert state.message_history[0]["tool_calls"] == tool_calls + assert state.message_history[0]["content"] is None From aca5431e35f1afadbe4e58e8b5c6904c62958b66 Mon Sep 17 00:00:00 2001 From: "Li, Tianmu" Date: Sat, 25 Apr 2026 15:24:50 -0700 Subject: [PATCH 06/13] fix: replace hardcoded /model/ path in validate_jsonl_schema.py docstring Co-Authored-By: Claude Sonnet 4.6 --- scripts/validate_jsonl_schema.py | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/scripts/validate_jsonl_schema.py b/scripts/validate_jsonl_schema.py index d2bb7177..c2d25eca 100644 --- a/scripts/validate_jsonl_schema.py +++ b/scripts/validate_jsonl_schema.py @@ -22,7 +22,7 @@ Usage: python scripts/validate_jsonl_schema.py FILE [FILE ...] - python scripts/validate_jsonl_schema.py /model/agentic_coding_flat.jsonl /model/agentic_workflow_flat.jsonl + python scripts/validate_jsonl_schema.py examples/09_MultiTurn/datasets/agentic_coding_flat.jsonl """ import argparse From 0a7ad3786e61be9f7e3a55fe525654e5b4008ac3 Mon Sep 17 00:00:00 2001 From: "Li, Tianmu" Date: Sat, 25 Apr 2026 15:32:13 -0700 Subject: [PATCH 07/13] chore: move multi_turn_dataset_schema.json into scripts/ and update default path Co-Authored-By: Claude Sonnet 4.6 --- scripts/multi_turn_dataset_schema.json | 557 +++++++++++++++++++++++++ scripts/validate_jsonl_schema.py | 8 +- 2 files changed, 561 insertions(+), 4 deletions(-) create mode 100644 scripts/multi_turn_dataset_schema.json diff --git a/scripts/multi_turn_dataset_schema.json b/scripts/multi_turn_dataset_schema.json new file mode 100644 index 00000000..b1b7ca13 --- /dev/null +++ b/scripts/multi_turn_dataset_schema.json @@ -0,0 +1,557 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema#", + "title": "Multi-Turn Conversation Dataset Schema", + "description": "JSON schema describing the structure and requirements for multi-turn conversation datasets in the MLPerf Inference Endpoint Benchmarking System", + "version": "1.0.0", + + "definitions": { + "basicMessageTypes": { + "title": "Basic Message Types", + "description": "Plain conversational messages without tool calls", + "oneOf": [ + { + "title": "User Message", + "type": "object", + "properties": { + "conversation_id": { + "type": "string", + "description": "Unique identifier for the conversation" + }, + "turn": { + "type": "integer", + "minimum": 1, + "description": "Turn number within conversation (1-indexed)" + }, + "role": { 
+ "const": "user", + "description": "Message role - user initiates turns" + }, + "content": { + "type": "string", + "description": "Message content from the user" + }, + "system": { + "type": "string", + "description": "System prompt (typically only on first user turn)" + } + }, + "required": ["conversation_id", "turn", "role", "content"], + "additionalProperties": true + }, + { + "title": "Assistant Message", + "type": "object", + "properties": { + "conversation_id": { + "type": "string", + "description": "Unique identifier for the conversation" + }, + "turn": { + "type": "integer", + "minimum": 1, + "description": "Turn number within conversation" + }, + "role": { + "const": "assistant", + "description": "Message role - assistant responds to user" + }, + "content": { + "type": "string", + "description": "Message content from the assistant" + } + }, + "required": ["conversation_id", "turn", "role", "content"], + "not": { "required": ["tool_calls"] }, + "additionalProperties": true + } + ] + }, + + "toolCallMessage": { + "title": "Assistant Message with Tool Calls", + "description": "Assistant message that dispatches one or more tool calls. Role must be 'assistant' with a non-empty tool_calls array (OpenAI wire format).", + "type": "object", + "properties": { + "conversation_id": { + "type": "string", + "description": "Unique identifier for the conversation" + }, + "turn": { + "type": "integer", + "minimum": 1, + "description": "Turn number within conversation" + }, + "role": { + "const": "assistant", + "description": "Role for a tool-dispatching assistant message." + }, + "content": { + "type": ["string", "null"], + "description": "Optional textual prefix alongside tool dispatch (e.g., 'I will investigate this with bash'). Typically null/absent." 
+ }, + "tool_calls": { + "type": "array", + "minItems": 1, + "description": "List of tool calls dispatched by the assistant", + "items": { + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "Unique identifier for this tool call (e.g., 'functions.bash:0')" + }, + "type": { + "type": "string", + "const": "function", + "description": "Tool type (currently only 'function' supported)" + }, + "function": { + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "Name of the tool/function to invoke" + }, + "arguments": { + "type": "string", + "description": "JSON string containing function arguments" + } + }, + "required": ["name", "arguments"] + } + }, + "required": ["id", "type", "function"] + } + } + }, + "required": ["conversation_id", "turn", "role", "tool_calls"], + "additionalProperties": true + }, + + "toolMessage": { + "title": "Tool Result Message", + "description": "Tool execution results as a list. Single results have one entry; parallel results have multiple entries.", + "type": "object", + "properties": { + "conversation_id": { + "type": "string", + "description": "Unique identifier for the conversation" + }, + "turn": { + "type": "integer", + "minimum": 1, + "description": "Turn number within conversation" + }, + "role": { + "const": "tool", + "description": "Tool result role" + }, + "tool_results": { + "type": "array", + "minItems": 1, + "description": "List of tool execution results. 
Single tool calls have one entry; parallel tool calls have multiple entries.", + "items": { + "type": "object", + "properties": { + "tool_call_id": { + "type": "string", + "description": "ID of the tool call this result corresponds to" + }, + "content": { + "type": "string", + "description": "Output/result content from the tool execution" + } + }, + "required": ["tool_call_id", "content"] + } + } + }, + "required": ["conversation_id", "turn", "role", "tool_results"], + "additionalProperties": true + }, + + "generationParameters": { + "title": "Generation Parameters", + "description": "Optional parameters controlling the model's behavior for generation", + "type": "object", + "properties": { + "model": { + "type": "string", + "description": "Model name override for this turn" + }, + "max_new_tokens": { + "type": "integer", + "minimum": 1, + "description": "Maximum number of tokens to generate" + }, + "max_completion_tokens": { + "type": "integer", + "minimum": 1, + "description": "OpenAI API compatible max tokens parameter" + }, + "stream": { + "type": "boolean", + "description": "Whether to use streaming for this turn" + }, + "temperature": { + "type": "number", + "minimum": 0, + "maximum": 2, + "description": "Sampling temperature (0 = deterministic, higher = more random)" + }, + "top_p": { + "type": "number", + "minimum": 0, + "maximum": 1, + "description": "Nucleus sampling parameter" + }, + "top_k": { + "type": "integer", + "minimum": 1, + "description": "Top-k sampling parameter" + }, + "seed": { + "type": "integer", + "description": "Random seed for reproducibility" + }, + "repetition_penalty": { + "type": "number", + "minimum": 0, + "description": "Penalty for repeating tokens" + }, + "frequency_penalty": { + "type": "number", + "minimum": -2, + "maximum": 2, + "description": "Frequency penalty for tokens" + }, + "presence_penalty": { + "type": "number", + "minimum": -2, + "maximum": 2, + "description": "Presence penalty for tokens" + }, + "stop": { + 
"oneOf": [ + { "type": "string" }, + { "type": "array", "items": { "type": "string" } } + ], + "description": "Stop sequences for generation" + }, + "n": { + "type": "integer", + "minimum": 1, + "description": "Number of completions to generate" + }, + "logit_bias": { + "type": "object", + "description": "Token probability adjustments (token_id -> bias)" + }, + "name": { + "type": "string", + "description": "Entity name for role tracking (e.g., 'Bob')" + }, + "user": { + "type": "string", + "description": "End-user identifier for monitoring/abuse detection" + }, + "chat_template": { + "type": "string", + "description": "Custom chat formatting template" + }, + "tools": { + "type": "array", + "description": "OpenAI tool definitions for tool-calling models", + "items": { "type": "object" } + } + } + } + }, + + "type": "object", + "oneOf": [ + { + "title": "Plain Conversation Row", + "description": "A single row representing a plain user or assistant message", + "allOf": [ + { "$ref": "#/definitions/basicMessageTypes" }, + { "$ref": "#/definitions/generationParameters" } + ] + }, + { + "title": "Tool Call Row", + "description": "A single row representing an assistant dispatch of tool calls", + "allOf": [ + { "$ref": "#/definitions/toolCallMessage" }, + { "$ref": "#/definitions/generationParameters" } + ] + }, + { + "title": "Tool Result Row", + "description": "A single row representing one or more tool results", + "allOf": [ + { "$ref": "#/definitions/toolMessage" }, + { "$ref": "#/definitions/generationParameters" } + ] + } + ], + + "examples": [ + { + "title": "Basic user message", + "data": { + "conversation_id": "conv_001", + "turn": 1, + "role": "user", + "content": "I need help resetting my password", + "system": "You are a helpful customer support agent" + } + }, + { + "title": "Assistant response", + "data": { + "conversation_id": "conv_001", + "turn": 2, + "role": "assistant", + "content": "I'd be happy to help. Can you provide your email address?" 
+ } + }, + { + "title": "Assistant with tool calls (converter/OpenAI wire format)", + "data": { + "conversation_id": "sim_001", + "turn": 2, + "role": "assistant", + "tool_calls": [ + { + "id": "functions.bash:0", + "type": "function", + "function": { + "name": "bash", + "arguments": "{\"cmd\": \"cat foo.py\"}" + } + } + ] + } + }, + { + "title": "Tool result from execution", + "data": { + "conversation_id": "sim_001", + "turn": 3, + "role": "tool", + "tool_results": [ + { + "tool_call_id": "functions.bash:0", + "content": "def foo():\n return 1/0" + } + ] + } + }, + { + "title": "Merged parallel tool results", + "data": { + "conversation_id": "sim_002", + "turn": 3, + "role": "tool", + "tool_results": [ + { + "tool_call_id": "functions.bash:0", + "content": "file1.txt" + }, + { + "tool_call_id": "functions.bash:1", + "content": "file2.txt" + } + ] + } + }, + { + "title": "User turn with generation parameters", + "data": { + "conversation_id": "conv_002", + "turn": 5, + "role": "user", + "content": "What's the best way to optimize this code?", + "temperature": 0.7, + "max_new_tokens": 256, + "top_p": 0.9 + } + } + ], + + "documentation": { + "overview": "Multi-turn conversation datasets enable benchmarking of realistic conversational AI workloads where each turn depends on previous responses. The system maintains conversation history and enforces turn sequencing.", + + "requiredFields": [ + { + "field": "conversation_id", + "type": "string", + "description": "Unique identifier for each conversation. All rows belonging to the same conversation must share the same conversation_id." + }, + { + "field": "turn", + "type": "integer", + "description": "Turn number within conversation (1-indexed). Must be consecutive starting at 1 (i.e., 1, 2, 3, …, N with no gaps or duplicates)." + }, + { + "field": "role", + "type": "string", + "enum": ["user", "assistant", "tool"], + "description": "Speaker role. 'user' or 'tool' are client-initiated turns. 
'assistant' is the server response — either a terminal reply or a tool dispatch (with tool_calls field)." + }, + { + "field": "content", + "type": "string", + "description": "Message content. Required for 'user' role and plain 'assistant' rows. For tool-dispatching assistant rows, content may be omitted (null/absent). For 'tool' rows using tool_results (merged parallel results), top-level content is absent — results are in the tool_results array instead." + } + ], + + "optionalFields": [ + { + "field": "system", + "type": "string", + "description": "System prompt (typically only on first user turn). Applied to all messages in the conversation." + }, + { + "field": "tool_calls", + "type": "array", + "description": "Tool calls dispatched by assistant (for tool-dispatching 'assistant' rows). Each element has {id, type, function: {name, arguments}}." + }, + { + "field": "tool_results", + "type": "array", + "description": "Tool execution results (required for all 'tool' role rows). Each element has {tool_call_id, content}. Single results have one entry; parallel results have multiple entries." + }, + { + "field": "model", + "type": "string", + "description": "Model name override for this turn." + }, + { + "field": "max_new_tokens", + "type": "integer", + "description": "Maximum tokens to generate for this turn." + }, + { + "field": "temperature", + "type": "number", + "description": "Sampling temperature (0 to 2)." + }, + { + "field": "top_p", + "type": "number", + "description": "Nucleus sampling parameter (0 to 1)." + }, + { + "field": "tools", + "type": "array", + "description": "OpenAI tool definitions forwarded to the endpoint for tool-calling models. The converter attaches this only to client-turn rows (user and tool) to avoid duplicating the large array on every assistant row. Hand-authored datasets typically place it on the first user turn." 
+ } + ], + + "validRoleSequences": [ + { + "name": "Plain conversation", + "sequence": "user → assistant → user → assistant → ...", + "description": "Standard alternating conversation without tool use." + }, + { + "name": "Agentic with tools", + "sequence": "user → assistant → tool → [assistant → tool]* → assistant → user", + "description": "Agent dispatches tools (assistant with tool_calls), executes them, and returns results before final response. 'tool → user' is also valid when no terminal assistant response is needed before the next user turn." + } + ], + "stateMachine": { + "description": "Complete valid-next-state table from _validate_conversation_structure()", + "transitions": { + "start": ["user"], + "user": ["assistant"], + "assistant": ["tool", "user"], + "tool": ["assistant", "user"] + } + }, + + "validationRules": [ + { + "rule": "Turn numbers must be consecutive starting at 1", + "violation": "Turn sequence is not exactly 1, 2, 3, …, N (missing, duplicate, or out-of-range turns)" + }, + { + "rule": "Role sequences must follow the state machine", + "violation": "Invalid transition (e.g., 'user' directly followed by 'user', consecutive 'assistant' rows). Note: 'tool → user' IS a valid transition. The state machine also implicitly enforces that the first row must be a user turn." + } + ], + "notValidated": [ + "tool_results[*].tool_call_id pairing: the validator does NOT verify that tool_call_id values inside tool_results items reference a prior assistant tool_calls entry. Correct pairing is the dataset author's or converter's responsibility." 
+ ], + + "dataTypes": [ + { + "name": "Basic Message", + "roles": ["user", "assistant"], + "fields": { + "required": ["conversation_id", "turn", "role", "content"], + "optional": ["system", "...generation parameters"] + } + }, + { + "name": "Tool Call Dispatch", + "roles": ["assistant"], + "fields": { + "required": ["conversation_id", "turn", "role", "tool_calls"], + "optional": ["content", "...generation parameters"] + } + }, + { + "name": "Tool Result", + "roles": ["tool"], + "fields": { + "required": ["conversation_id", "turn", "role", "tool_results"], + "optional": ["...generation parameters"], + "note": "Expands to one OpenAI tool message per result entry at the wire layer" + } + } + ], + + "conversionFromSnapshot": { + "description": "Agentic datasets are often stored as full-conversation snapshots. Use scripts/convert_agentic_snapshot.py to convert.", + "sourceFormat": "Each JSONL line is a complete conversation snapshot with 'messages' array", + "targetFormat": "Each JSONL line is a single message row with conversation metadata", + "process": [ + "Extract conversation_id, conversation_idx, and messages array", + "Use highest-indexed snapshot per conversation_id", + "Collapse consecutive user messages into a single user row (newline-joined content)", + "Emit all tool result rows as tool_results arrays (single results have one entry; consecutive tool results from parallel dispatch are merged into one row with multiple entries)", + "Flatten messages into individual rows, numbering turns sequentially from 1", + "Attach system prompt to first user row only", + "Attach tools array to client-turn rows (user and tool) only — not to assistant rows", + "Tool-dispatching assistant messages are written as role 'assistant' with tool_calls (OpenAI wire format)" + ] + }, + + "performanceNotes": [ + "Pre-built messages: The system pre-computes complete message lists during dataset load() for efficient turn serving", + "Memory efficiency: ~1KB per turn average; 1000 
conversations × 10 turns = ~10MB", + "Hot path: Only client turns (user/tool) are issued; assistant turns remain in backing store for history" + ], + + "commonErrors": [ + { + "error": "Invalid role sequence", + "cause": "Violates state machine (e.g., user→user, user→tool, consecutive assistant rows)", + "fix": "Verify alternation or use conversion script for agentic data" + }, + { + "error": "Turn numbers not consecutive", + "cause": "Turn sequence has gaps, duplicates, or doesn't start at 1", + "fix": "Ensure turns are numbered 1, 2, 3, …, N with no missing or duplicate values" + }, + { + "error": "Turn timeout", + "cause": "Previous turn took too long to complete", + "fix": "Increase turn_timeout_s in configuration or check endpoint performance" + } + ] + } +} diff --git a/scripts/validate_jsonl_schema.py b/scripts/validate_jsonl_schema.py index c2d25eca..1be81dd2 100644 --- a/scripts/validate_jsonl_schema.py +++ b/scripts/validate_jsonl_schema.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Validate multi-turn JSONL dataset files against multi_turn_dataset_schema.json. +"""Validate multi-turn JSONL dataset files against scripts/multi_turn_dataset_schema.json. Checks each row's structure against the JSON schema (field types, required fields, tool_results shape, etc.). Does NOT check cross-row invariants such as turn @@ -88,14 +88,14 @@ def validate_file(path: Path, schema: dict, max_errors: int = 50) -> int: def main() -> None: parser = argparse.ArgumentParser( - description="Validate multi-turn JSONL files against multi_turn_dataset_schema.json." + description="Validate multi-turn JSONL files against scripts/multi_turn_dataset_schema.json." 
) parser.add_argument("files", nargs="+", type=Path, help="JSONL files to validate") parser.add_argument( "--schema", type=Path, - default=Path(__file__).parent.parent / "multi_turn_dataset_schema.json", - help="Path to the JSON schema file (default: multi_turn_dataset_schema.json)", + default=Path(__file__).parent / "multi_turn_dataset_schema.json", + help="Path to the JSON schema file (default: scripts/multi_turn_dataset_schema.json)", ) parser.add_argument( "--max-errors", From 1140361127ebc74d5c28deea4feb6143b5800bd4 Mon Sep 17 00:00:00 2001 From: "Li, Tianmu" Date: Tue, 28 Apr 2026 10:37:22 -0700 Subject: [PATCH 08/13] fix: address PR #285 review comments for multi-turn implementation Fix 15 review issues across severity levels: - HIGH: metadata=None crash in msgspec adapter, silent exception swallowing in gather - MEDIUM: timeout state consistency, conv_id canonicalization, PromptData fallback, conv_id guard - LOW: enum comparison, frozen config, empty tool_results warning, adapter metadata extraction, groupby deduplication, live-history tool warning, asyncio.Event docs, test TODO Co-Authored-By: Claude Opus 4.6 --- .../config/runtime_settings.py | 3 +- src/inference_endpoint/config/schema.py | 2 +- .../dataset_manager/multi_turn_dataset.py | 50 +++++++++++-------- .../load_generator/conversation_manager.py | 1 + .../load_generator/multi_turn_strategy.py | 32 +++++++++++- .../load_generator/session.py | 9 +++- .../openai/openai_adapter.py | 13 ++++- .../openai/openai_msgspec_adapter.py | 2 +- tests/integration/test_multi_turn.py | 6 ++- 9 files changed, 88 insertions(+), 30 deletions(-) diff --git a/src/inference_endpoint/config/runtime_settings.py b/src/inference_endpoint/config/runtime_settings.py index a3fb3106..eac1aa47 100644 --- a/src/inference_endpoint/config/runtime_settings.py +++ b/src/inference_endpoint/config/runtime_settings.py @@ -32,6 +32,7 @@ from typing import TYPE_CHECKING from .. 
import metrics +from .schema import LoadPatternType logger = logging.getLogger(__name__) @@ -197,7 +198,7 @@ def total_samples_to_issue( # Multi-turn must issue exactly all client turns β€” QPS-based formulas are meaningless. if ( self.load_pattern is not None - and self.load_pattern.type.value == "multi_turn" + and self.load_pattern.type == LoadPatternType.MULTI_TURN ): result = max(self.min_sample_count, self.n_samples_from_dataset) logger.debug( diff --git a/src/inference_endpoint/config/schema.py b/src/inference_endpoint/config/schema.py index f34f82d5..2846268c 100644 --- a/src/inference_endpoint/config/schema.py +++ b/src/inference_endpoint/config/schema.py @@ -250,7 +250,7 @@ class MultiTurnConfig(BaseModel): use_dataset_history: If True, use pre-built message history from dataset. """ - model_config = {"extra": "forbid"} + model_config = ConfigDict(extra="forbid", frozen=True) mode: ConversationMode = ConversationMode.INDEPENDENT turn_timeout_s: float = 300.0 diff --git a/src/inference_endpoint/dataset_manager/multi_turn_dataset.py b/src/inference_endpoint/dataset_manager/multi_turn_dataset.py index 574619c8..c75f285d 100644 --- a/src/inference_endpoint/dataset_manager/multi_turn_dataset.py +++ b/src/inference_endpoint/dataset_manager/multi_turn_dataset.py @@ -15,6 +15,7 @@ """Multi-turn conversation dataset for conversational AI benchmarking.""" +import logging from typing import Any import pandas as pd @@ -29,6 +30,8 @@ get_transforms_for_api_type, ) +logger = logging.getLogger(__name__) + def _expand_tool_results(row: dict) -> list[dict]: """Expand a tool row into one OpenAI tool message per result. 
@@ -41,6 +44,13 @@ def _expand_tool_results(row: dict) -> list[dict]: tool_results = row.get("tool_results") if not isinstance(tool_results, list): return [] + if not tool_results: + logger.warning( + "Row has empty tool_results list (conversation_id=%s, turn=%s)", + row.get("conversation_id"), + row.get("turn"), + ) + return [] return [ { "role": "tool", @@ -94,6 +104,8 @@ def __init__(self, dataframe: pd.DataFrame, **kwargs): ValueError: If conversation structure is invalid. """ super().__init__(dataframe, **kwargs) + assert self.dataframe is not None, "Dataframe must be initialized" + self._conv_groups = dict(list(self.dataframe.groupby("conversation_id"))) self._validate_conversation_grouping() self._validate_conversation_structure() self._validate_turn_numbering() @@ -128,10 +140,6 @@ def _validate_conversation_structure(self): Raises: ValueError: If any conversation has invalid role sequence. """ - assert self.dataframe is not None, "Dataframe must be initialized" - - # Valid state transitions (flat 4-state machine β€” no assistant_tc node, - # no toolβ†’tool; converter always merges consecutive tool rows into tool_results) VALID_NEXT: dict[str, set[str]] = { "start": {"user"}, "user": {"assistant"}, @@ -139,7 +147,7 @@ def _validate_conversation_structure(self): "tool": {"assistant", "user"}, } - for conv_id, group in self.dataframe.groupby("conversation_id"): + for conv_id, group in self._conv_groups.items(): sorted_group = group.sort_values("turn") state = "start" @@ -159,9 +167,7 @@ def _validate_turn_numbering(self): Raises: ValueError: If turn numbers are not exactly 1, 2, 3, …, N. 
""" - assert self.dataframe is not None, "Dataframe must be initialized" - - for conv_id, group in self.dataframe.groupby("conversation_id"): + for conv_id, group in self._conv_groups.items(): turns = sorted(group["turn"].tolist()) expected = list(range(1, len(turns) + 1)) if turns != expected: @@ -180,14 +186,13 @@ def _build_metadata(self) -> dict[str, Any]: Metadata dict with samples list, num_conversations, max_turns_per_conv, client_turns_per_conversation, and pre_built_messages_by_key. """ - assert self.dataframe is not None, "Dataframe must be initialized" samples = [] - client_turns_df = self.dataframe[self.dataframe["role"].isin(["user", "tool"])] # Count client turns (user + tool) per conversation for completion tracking - client_turns_per_conv = ( - client_turns_df.groupby("conversation_id").size().to_dict() - ) + client_turns_per_conv = { + str(conv_id): int(group["role"].isin(["user", "tool"]).sum()) + for conv_id, group in self._conv_groups.items() + } # Map (conversation_id, turn) β†’ complete message list ready to send to endpoint. 
# Each entry is: [system (optional)] + all prior rows formatted as messages @@ -198,7 +203,7 @@ def _build_metadata(self) -> dict[str, Any]: current_turn_messages_by_key: dict[tuple, list[dict]] = {} system_prompts_by_conv: dict[str, str | None] = {} - for conv_id, group in self.dataframe.groupby("conversation_id"): + for conv_id, group in self._conv_groups.items(): sorted_group = group.sort_values("turn") client_rows = sorted_group[sorted_group["role"].isin(["user", "tool"])] @@ -263,23 +268,24 @@ def _build_metadata(self) -> dict[str, Any]: current_turn_msgs = [cur] messages.extend(current_turn_msgs) - pre_built_messages_by_key[(conv_id, t_n)] = messages - current_turn_messages_by_key[(conv_id, t_n)] = current_turn_msgs + str_conv_id = str(conv_id) + pre_built_messages_by_key[(str_conv_id, t_n)] = messages + current_turn_messages_by_key[(str_conv_id, t_n)] = current_turn_msgs samples.append( { "index": idx, - "conversation_id": conv_id, + "conversation_id": str_conv_id, "turn": t_n, } ) return { "samples": samples, - "num_conversations": self.dataframe["conversation_id"].nunique(), - "max_turns_per_conv": self.dataframe.groupby("conversation_id")["turn"] - .max() - .max(), + "num_conversations": len(self._conv_groups), + "max_turns_per_conv": max( + g["turn"].max() for g in self._conv_groups.values() + ), "client_turns_per_conversation": client_turns_per_conv, "pre_built_messages_by_key": pre_built_messages_by_key, "current_turn_messages_by_key": current_turn_messages_by_key, @@ -388,7 +394,7 @@ def load( sample["stream"] = False # Attach pre-built message list (system + history + current turn). 
- key = (row["conversation_id"], int(row["turn"])) + key = (str(row["conversation_id"]), int(row["turn"])) messages = pre_built.get(key, []) sample["messages"] = messages diff --git a/src/inference_endpoint/load_generator/conversation_manager.py b/src/inference_endpoint/load_generator/conversation_manager.py index 56bb8278..30276d5e 100644 --- a/src/inference_endpoint/load_generator/conversation_manager.py +++ b/src/inference_endpoint/load_generator/conversation_manager.py @@ -42,6 +42,7 @@ class ConversationState: """ conversation_id: str + # Python 3.12+: asyncio.Event no longer requires a running loop at construction. turn_done: asyncio.Event = field(default_factory=asyncio.Event) message_history: list[dict[str, Any]] = field(default_factory=list) completed_turns: int = 0 diff --git a/src/inference_endpoint/load_generator/multi_turn_strategy.py b/src/inference_endpoint/load_generator/multi_turn_strategy.py index 0f3ba6e8..69697129 100644 --- a/src/inference_endpoint/load_generator/multi_turn_strategy.py +++ b/src/inference_endpoint/load_generator/multi_turn_strategy.py @@ -134,7 +134,12 @@ async def execute(self, phase_issuer: PhaseIssuerProtocol) -> int: for conv_id, turns in conv_samples.items() ] - await asyncio.gather(*tasks, return_exceptions=True) + results = await asyncio.gather(*tasks, return_exceptions=True) + errors = [r for r in results if isinstance(r, BaseException)] + for err in errors: + logger.error(f"Conversation pipeline failed: {err}") + if errors: + raise errors[0] return phase_issuer.issued_count async def _conv_pipeline( @@ -150,6 +155,7 @@ async def _conv_pipeline( """ state = self._conv_states[conv_id] sorted_turns = sorted(turns, key=lambda x: x[1]) + last_query_id: str | None = None for i, (idx, turn) in enumerate(sorted_turns): if i > 0: @@ -161,7 +167,13 @@ async def _conv_pipeline( logger.warning( f"Turn {turn} of {conv_id} timed out waiting for previous turn" ) - state.failed_turns += 1 + if last_query_id is not None: + 
self._inflight.pop(last_query_id, None) + remaining = len(sorted_turns) - i + for _ in range(remaining): + self._conv_manager.mark_turn_failed( + conv_id, store_in_history=self._store_in_history + ) break state.turn_done.clear() @@ -177,6 +189,17 @@ async def _conv_pipeline( "current_turn_messages_by_key", {} ).get((conv_id, turn)) if current_turn_messages: + has_tool_msg = any( + m.get("role") == "tool" for m in current_turn_messages + ) + if has_tool_msg: + logger.warning( + "Live-history mode with tool messages uses dataset " + "tool_call_ids; real endpoint IDs will differ " + "(conv=%s, turn=%d)", + conv_id, + turn, + ) live_messages = state.message_history.copy() + current_turn_messages data_override = {"messages": live_messages} @@ -188,6 +211,7 @@ async def _conv_pipeline( break self._inflight[query_id] = conv_id + last_query_id = query_id # Append current-turn messages to history so the next turn sees them. if self._store_in_history and current_turn_messages: @@ -219,6 +243,10 @@ def on_sample_complete(self, result: QueryResult) -> None: if conv_id is None: return + if self._conv_manager.get_state(conv_id) is None: + logger.warning(f"on_sample_complete: unknown conversation {conv_id}") + return + response_text = result.get_response_output_string() if result.error is not None: diff --git a/src/inference_endpoint/load_generator/session.py b/src/inference_endpoint/load_generator/session.py index 8ae0e74f..d4b4c93a 100644 --- a/src/inference_endpoint/load_generator/session.py +++ b/src/inference_endpoint/load_generator/session.py @@ -202,8 +202,15 @@ def issue( prompt_data: PromptData if isinstance(data, dict): token_ids = data.get("input_tokens") or data.get("token_ids") + prompt_text = data.get("prompt") + if prompt_text is None and "messages" in data: + prompt_text = " ".join( + m.get("content", "") + for m in data["messages"] + if isinstance(m, dict) and m.get("content") + ) prompt_data = PromptData( - text=data.get("prompt"), + text=prompt_text, 
token_ids=tuple(token_ids) if token_ids is not None else None, ) else: diff --git a/src/inference_endpoint/openai/openai_adapter.py b/src/inference_endpoint/openai/openai_adapter.py index 9c6f6ebd..85f208ca 100644 --- a/src/inference_endpoint/openai/openai_adapter.py +++ b/src/inference_endpoint/openai/openai_adapter.py @@ -14,6 +14,7 @@ # limitations under the License. import time +from typing import Any import msgspec from inference_endpoint.core.types import Query, QueryResult, TextModelOutput @@ -128,9 +129,19 @@ def from_endpoint_response( if result_id is None: result_id = response.id + choice = response.choices[0] + metadata: dict[str, Any] = {} + if choice.finish_reason: + metadata["finish_reason"] = choice.finish_reason.value + if choice.message.tool_calls: + metadata["tool_calls"] = [ + tc.model_dump(mode="json") for tc in choice.message.tool_calls + ] + return QueryResult( id=result_id, - response_output=TextModelOutput(output=response.choices[0].message.content), + response_output=TextModelOutput(output=choice.message.content), + metadata=metadata, ) @classmethod diff --git a/src/inference_endpoint/openai/openai_msgspec_adapter.py b/src/inference_endpoint/openai/openai_msgspec_adapter.py index e8f15ce6..e512e22b 100644 --- a/src/inference_endpoint/openai/openai_msgspec_adapter.py +++ b/src/inference_endpoint/openai/openai_msgspec_adapter.py @@ -219,7 +219,7 @@ def from_endpoint_response( return QueryResult( id=result_id or response.id, response_output=TextModelOutput(output=choice.message.content or ""), - metadata=metadata if metadata else None, + metadata=metadata, ) @classmethod diff --git a/tests/integration/test_multi_turn.py b/tests/integration/test_multi_turn.py index 87351700..8ea3666f 100644 --- a/tests/integration/test_multi_turn.py +++ b/tests/integration/test_multi_turn.py @@ -557,7 +557,11 @@ async def test_conversation_ending_with_tool_row(echo_server): @pytest.mark.integration @pytest.mark.asyncio async def 
test_tools_field_forwarded_to_endpoint(echo_server): - """The 'tools' array from the dataset reaches the endpoint in every request payload.""" + """The 'tools' array from the dataset reaches the endpoint in every request payload. + + TODO: Add a tool-call-aware server that returns dynamic tool_call_ids to + validate live-history mode with real tool_call_id round-tripping. + """ received_payloads: list[dict] = [] class CapturingEchoServer(EchoServer): From 3b9dd1e304a98a5ca83d642df7d70b99ba808be9 Mon Sep 17 00:00:00 2001 From: "Li, Tianmu" Date: Tue, 28 Apr 2026 13:18:31 -0700 Subject: [PATCH 09/13] fix: improve multi-turn PromptData text and add concurrent stress test Use newline separators (instead of spaces) when flattening messages to text for ISL estimation, and add a 12-conversation concurrent stress test. Co-Authored-By: Claude Sonnet 4.6 --- .../load_generator/session.py | 7 +-- tests/integration/test_multi_turn.py | 46 +++++++++++++++++++ 2 files changed, 50 insertions(+), 3 deletions(-) diff --git a/src/inference_endpoint/load_generator/session.py b/src/inference_endpoint/load_generator/session.py index d4b4c93a..b602ca15 100644 --- a/src/inference_endpoint/load_generator/session.py +++ b/src/inference_endpoint/load_generator/session.py @@ -204,11 +204,12 @@ def issue( token_ids = data.get("input_tokens") or data.get("token_ids") prompt_text = data.get("prompt") if prompt_text is None and "messages" in data: - prompt_text = " ".join( - m.get("content", "") + parts: list[str] = [ + m["content"] for m in data["messages"] if isinstance(m, dict) and m.get("content") - ) + ] + prompt_text = "\n".join(parts) if parts else None prompt_data = PromptData( text=prompt_text, token_ids=tuple(token_ids) if token_ids is not None else None, diff --git a/tests/integration/test_multi_turn.py b/tests/integration/test_multi_turn.py index 8ea3666f..cfe8a68c 100644 --- a/tests/integration/test_multi_turn.py +++ b/tests/integration/test_multi_turn.py @@ -554,6 +554,52 @@ async 
def test_conversation_ending_with_tool_row(echo_server): assert len(responses) == 2 +@pytest.mark.integration +@pytest.mark.asyncio +async def test_concurrent_conversations_stress(echo_server): + """12 conversations × 3 turns each complete with correct counts.""" + num_convs = 12 + turns_per_conv = 3 # 2 user turns + 1 assistant turn each + rows = [] + for i in range(num_convs): + conv_id = f"stress_conv_{i}" + rows.append( + { + "conversation_id": conv_id, + "turn": 1, + "role": "user", + "content": f"Q1-{i}", + } + ) + rows.append( + { + "conversation_id": conv_id, + "turn": 2, + "role": "assistant", + "content": f"A1-{i}", + } + ) + rows.append( + { + "conversation_id": conv_id, + "turn": 3, + "role": "user", + "content": f"Q2-{i}", + } + ) + + ds = _make_dataset(rows) + strategy = _make_strategy(ds) + responses: dict = {} + + count = await _run_session(echo_server.url, ds, strategy, responses) + + # 12 conversations × 2 client turns each = 24 + expected_client_turns = num_convs * (turns_per_conv - 1) # 24 + assert count == expected_client_turns + assert len(responses) == expected_client_turns + + @pytest.mark.integration @pytest.mark.asyncio async def test_tools_field_forwarded_to_endpoint(echo_server): From 8ab45a1fcbb59f4239c4d5323ddc9f7d191d72a4 Mon Sep 17 00:00:00 2001 From: "Li, Tianmu" Date: Mon, 4 May 2026 11:09:32 -0700 Subject: [PATCH 10/13] refactor: replace semaphore with worker-pool concurrency in MultiTurnStrategy target_concurrency now limits active conversations (not in-flight requests). N worker tasks pull from asyncio.Queue, each processing one full conversation before taking the next. Also adds slots=True back to PhaseConfig and sort=False to groupby for file-order preservation. 
Co-Authored-By: Claude Sonnet 4.6 --- .../dataset_manager/multi_turn_dataset.py | 4 +- .../load_generator/multi_turn_strategy.py | 81 ++++++++-------- .../load_generator/session.py | 2 +- .../test_multi_turn_strategy.py | 96 +++++++++++++++---- 4 files changed, 126 insertions(+), 57 deletions(-) diff --git a/src/inference_endpoint/dataset_manager/multi_turn_dataset.py b/src/inference_endpoint/dataset_manager/multi_turn_dataset.py index c75f285d..cabfac79 100644 --- a/src/inference_endpoint/dataset_manager/multi_turn_dataset.py +++ b/src/inference_endpoint/dataset_manager/multi_turn_dataset.py @@ -105,7 +105,9 @@ def __init__(self, dataframe: pd.DataFrame, **kwargs): """ super().__init__(dataframe, **kwargs) assert self.dataframe is not None, "Dataframe must be initialized" - self._conv_groups = dict(list(self.dataframe.groupby("conversation_id"))) + self._conv_groups = dict( + list(self.dataframe.groupby("conversation_id", sort=False)) + ) self._validate_conversation_grouping() self._validate_conversation_structure() self._validate_turn_numbering() diff --git a/src/inference_endpoint/load_generator/multi_turn_strategy.py b/src/inference_endpoint/load_generator/multi_turn_strategy.py index 69697129..1ce3780e 100644 --- a/src/inference_endpoint/load_generator/multi_turn_strategy.py +++ b/src/inference_endpoint/load_generator/multi_turn_strategy.py @@ -32,18 +32,16 @@ class MultiTurnStrategy: - """Async multi-turn strategy. Spawns per-conversation asyncio.Tasks. + """Async multi-turn strategy. Uses a worker-pool to limit active conversations. - Each conversation runs as an independent asyncio.Task that enforces - sequential turn ordering: turn N+1 cannot be issued until turn N completes. - Conversations run concurrently β€” no cross-conversation synchronization. - - Optional target_concurrency limits total in-flight requests across all - conversations using asyncio.Semaphore. + N worker tasks pull from a queue of conversations. 
Each worker processes all + turns of one conversation before moving to the next, so at most N conversations + are active simultaneously. When target_concurrency is None, all conversations + run concurrently (one worker per conversation). Integration with BenchmarkSession: - - execute(): spawns conversation tasks, awaits all to complete - - on_query_complete(): releases semaphore slot (concurrency control only) + - execute(): populates queue, spawns workers, awaits all to complete + - on_query_complete(): no-op (required by LoadStrategy protocol) - on_sample_complete(): routes completed QueryResult to ConversationManager The response routing path: @@ -69,7 +67,8 @@ def __init__( conversation_manager: Manages conversation sequencing state. dataset_metadata: Metadata from MultiTurnDataset (samples list). multi_turn_config: Multi-turn conversation configuration. - target_concurrency: Optional maximum concurrent in-flight requests. + target_concurrency: Maximum number of simultaneously active conversations. + None means all conversations run concurrently. """ self._conv_manager = conversation_manager self._dataset_metadata = dataset_metadata @@ -80,11 +79,6 @@ def __init__( else _DEFAULT_TURN_TIMEOUT_S ) self._target_concurrency = target_concurrency - self._sem: asyncio.Semaphore | None = ( - asyncio.Semaphore(target_concurrency) - if target_concurrency is not None and target_concurrency > 0 - else None - ) self._store_in_history = ( not multi_turn_config.use_dataset_history if multi_turn_config is not None @@ -110,7 +104,7 @@ async def execute(self, phase_issuer: PhaseIssuerProtocol) -> int: conv_id = sample_meta["conversation_id"] conv_samples[conv_id].append((sample_index, sample_meta["turn"])) - # Pre-create all conversation states before spawning tasks (no locking needed). + # Pre-create all conversation states before spawning workers (no locking needed). 
sys_prompts = self._dataset_metadata.get("system_prompts_by_conv", {}) for conv_id, turns in conv_samples.items(): sys_content = sys_prompts.get(conv_id) if self._store_in_history else None @@ -126,15 +120,27 @@ async def execute(self, phase_issuer: PhaseIssuerProtocol) -> int: ) self._conv_states[conv_id] = state - tasks = [ + # Build queue of (conv_id, turns) pairs for workers to pull from. + conv_queue: asyncio.Queue[tuple[str, list[tuple[int, int]]]] = asyncio.Queue() + for conv_id, turns in conv_samples.items(): + await conv_queue.put((conv_id, turns)) + + n_conversations = len(conv_samples) + n_workers = ( + min(self._target_concurrency, n_conversations) + if self._target_concurrency is not None and self._target_concurrency > 0 + else n_conversations + ) + + worker_tasks = [ asyncio.create_task( - self._conv_pipeline(conv_id, turns, phase_issuer), - name=f"mt-pipeline-{conv_id}", + self._worker(conv_queue, phase_issuer), + name=f"mt-worker-{i}", ) - for conv_id, turns in conv_samples.items() + for i in range(n_workers) ] - results = await asyncio.gather(*tasks, return_exceptions=True) + results = await asyncio.gather(*worker_tasks, return_exceptions=True) errors = [r for r in results if isinstance(r, BaseException)] for err in errors: logger.error(f"Conversation pipeline failed: {err}") @@ -142,6 +148,19 @@ async def execute(self, phase_issuer: PhaseIssuerProtocol) -> int: raise errors[0] return phase_issuer.issued_count + async def _worker( + self, + conv_queue: asyncio.Queue[tuple[str, list[tuple[int, int]]]], + phase_issuer: PhaseIssuerProtocol, + ) -> None: + """Pull conversations from queue and process each one fully before taking the next.""" + while True: + try: + conv_id, turns = conv_queue.get_nowait() + except asyncio.QueueEmpty: + break + await self._conv_pipeline(conv_id, turns, phase_issuer) + async def _conv_pipeline( self, conv_id: str, @@ -177,10 +196,6 @@ async def _conv_pipeline( break state.turn_done.clear() - # Acquire concurrency slot 
before issuing. - if self._sem is not None: - await self._sem.acquire() - # Live-history mode: build messages from accumulated history + current turn. data_override: dict[str, Any] | None = None current_turn_messages: list[dict[str, Any]] | None = None @@ -205,9 +220,7 @@ async def _conv_pipeline( query_id = phase_issuer.issue(idx, data_override=data_override) if query_id is None: - # Session stopping β€” release slot and exit. - if self._sem is not None: - self._sem.release() + # Session stopping β€” exit pipeline. break self._inflight[query_id] = conv_id @@ -218,16 +231,8 @@ async def _conv_pipeline( state.message_history.extend(current_turn_messages) def on_query_complete(self, query_id: str) -> None: - """Called by BenchmarkSession when a QueryResult arrives. - - Releases the concurrency semaphore slot. Response routing is done - via on_sample_complete (which receives the full QueryResult). - - Args: - query_id: ID of the completed query. - """ - if self._sem is not None: - self._sem.release() + """No-op. Required by LoadStrategy protocol; called by BenchmarkSession.""" + pass def on_sample_complete(self, result: QueryResult) -> None: """Route completed QueryResult to ConversationManager. 
diff --git a/src/inference_endpoint/load_generator/session.py b/src/inference_endpoint/load_generator/session.py index b602ca15..2b1c39b5 100644 --- a/src/inference_endpoint/load_generator/session.py +++ b/src/inference_endpoint/load_generator/session.py @@ -60,7 +60,7 @@ class PhaseType(str, Enum): WARMUP = "warmup" -@dataclass(frozen=True) +@dataclass(frozen=True, slots=True) class PhaseConfig: """Configuration for a single benchmark phase.""" diff --git a/tests/unit/load_generator/test_multi_turn_strategy.py b/tests/unit/load_generator/test_multi_turn_strategy.py index 1eecc75d..37cdfbad 100644 --- a/tests/unit/load_generator/test_multi_turn_strategy.py +++ b/tests/unit/load_generator/test_multi_turn_strategy.py @@ -209,23 +209,6 @@ async def test_turn_timeout_triggers_failure(): assert issuer.issued_count == 1 -@pytest.mark.unit -@pytest.mark.asyncio -async def test_on_query_complete_releases_semaphore(): - """on_query_complete releases the concurrency semaphore.""" - conv_manager = ConversationManager() - metadata = _make_dataset_metadata({"conv1": [1]}) - strategy = MultiTurnStrategy(conv_manager, metadata, target_concurrency=1) - assert strategy._sem is not None - - # Acquire the semaphore manually - await strategy._sem.acquire() - assert strategy._sem._value == 0 # type: ignore[attr-defined] - - strategy.on_query_complete("some-query") - assert strategy._sem._value == 1 # type: ignore[attr-defined] - - @pytest.mark.unit @pytest.mark.asyncio async def test_on_sample_complete_routes_to_manager(): @@ -483,3 +466,82 @@ async def test_on_sample_complete_passes_metadata(): assert len(state.message_history) == 1 assert state.message_history[0]["tool_calls"] == tool_calls assert state.message_history[0]["content"] is None + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_concurrency_limits_active_conversations(): + """target_concurrency=2 starts at most 2 conversation pipelines simultaneously. 
+ + Uses 2-turn conversations so each pipeline has an await point (turn_done.wait + between turns). With 4 conversations and 2 workers, the 3rd and 4th conversations + cannot start until a worker finishes its current conversation. + """ + conv_manager = ConversationManager() + # 4 two-turn conversations; pipeline awaits turn-1 response before issuing turn-2 + metadata = _make_dataset_metadata( + {"conv1": [1, 2], "conv2": [1, 2], "conv3": [1, 2], "conv4": [1, 2]} + ) + strategy = MultiTurnStrategy(conv_manager, metadata, target_concurrency=2) + issuer = FakePhaseIssuer() + + async def auto_respond(): + already_done = 0 + while True: + while already_done < len(issuer.issued): + idx = issuer.issued[already_done] + q = f"q{idx:04d}" + strategy.on_sample_complete( + QueryResult(id=q, response_output=TextModelOutput(output="r")) + ) + already_done += 1 + await asyncio.sleep(0.02) + + responder_task = asyncio.create_task(auto_respond()) + execute_task = asyncio.create_task(strategy.execute(issuer)) + + # Let both workers start and block on turn_done.wait before auto_respond fires + await asyncio.sleep(0.01) + + # Only 2 workers → exactly 2 turn-1 queries issued (conv3/conv4 not started yet) + assert issuer.issued_count == 2 + + await asyncio.wait_for(execute_task, timeout=5.0) + responder_task.cancel() + + assert issuer.issued_count == 8 # 4 conversations × 2 turns + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_conversation_slot_reuse(): + """With target_concurrency=1, worker completes conv1 before starting conv2. + + Uses 2-turn conversations so the pipeline has an await between turns. + The single worker must process both turns of conv1 before conv2's turn 1 is issued. 
+ """ + conv_manager = ConversationManager() + # 2 two-turn conversations; sample indices: conv1β†’[0,1], conv2β†’[2,3] + metadata = _make_dataset_metadata({"conv1": [1, 2], "conv2": [1, 2]}) + strategy = MultiTurnStrategy(conv_manager, metadata, target_concurrency=1) + issuer = FakePhaseIssuer() + + async def auto_respond(): + already_done = 0 + while True: + while already_done < len(issuer.issued): + idx = issuer.issued[already_done] + q = f"q{idx:04d}" + strategy.on_sample_complete( + QueryResult(id=q, response_output=TextModelOutput(output="r")) + ) + already_done += 1 + await asyncio.sleep(0.02) + + responder_task = asyncio.create_task(auto_respond()) + await strategy.execute(issuer) + responder_task.cancel() + + # Single worker: conv1 turns (samples 0,1) must be issued before conv2 turns (2,3) + assert issuer.issued[:2] == [0, 1], "Conv1 turns should be issued before conv2" + assert issuer.issued[2:] == [2, 3], "Conv2 turns should follow conv1" From 0d66900645d6612b7e98a37265f7781b6b25d6c2 Mon Sep 17 00:00:00 2001 From: "Li, Tianmu" Date: Mon, 4 May 2026 13:25:55 -0700 Subject: [PATCH 11/13] fix: address remaining PR #285 review comments for multi-turn implementation - openai_adapter: normalize null content to "" instead of literal "None" to avoid polluting conversation history in tool-calling responses - multi_turn_dataset: validate tool_results entries have required tool_call_id and content fields; raise InputValidationError at load time - multi_turn_dataset: remove unused "index" field from samples metadata - multi_turn_strategy: wrap mark_turn_complete/mark_turn_failed in try/except KeyError in on_sample_complete - multi_turn_strategy: clear _inflight at end of execute() with warning if entries remain (transport failure or session abort) - docs: remove prescriptive concurrency sizing guide; replace with definition of what target_concurrency controls - docs: rename "Long Conversations" to "Conversations with Many Turns" - docs: add dataset validation 
utility reference in Troubleshooting Co-Authored-By: Claude Sonnet 4.6 --- docs/MULTI_TURN_QUICKSTART.md | 19 +++++++---- .../dataset_manager/multi_turn_dataset.py | 30 ++++++++++------ .../load_generator/multi_turn_strategy.py | 34 ++++++++++++++----- .../openai/openai_adapter.py | 2 +- .../test_multi_turn_dataset.py | 1 - 5 files changed, 58 insertions(+), 28 deletions(-) diff --git a/docs/MULTI_TURN_QUICKSTART.md b/docs/MULTI_TURN_QUICKSTART.md index 99b35aa5..4e8e5e58 100644 --- a/docs/MULTI_TURN_QUICKSTART.md +++ b/docs/MULTI_TURN_QUICKSTART.md @@ -152,12 +152,6 @@ settings: target_concurrency: 32 # ← Limit to 32 concurrent requests ``` -**Sizing guide**: - -- Small (< 50 convs): `target_concurrency: 32` -- Medium (50-500 convs): `target_concurrency: 64` -- Large (500+ convs): `target_concurrency: 96` or higher - --- ## Common Configurations @@ -194,7 +188,7 @@ settings: workers: 16 # More workers for parallel conversations ``` -### Long Conversations +### Conversations with Many Turns ```yaml multi_turn: @@ -211,6 +205,17 @@ settings: ## Troubleshooting +### Validate Your Dataset Before Running + +Use the bundled validation script to check your JSONL file for schema errors before benchmarking: + +```bash +python scripts/validate_jsonl_schema.py path/to/your/conversations.jsonl +``` + +This catches per-row schema errors (missing required fields, wrong field types, malformed `tool_results`). It does +not check cross-row rules — role sequence, turn numbering, and conversation grouping are validated at dataset load. + ### "Conversation has invalid role sequence" **Problem**: Your dataset doesn't follow a valid role sequence. 
diff --git a/src/inference_endpoint/dataset_manager/multi_turn_dataset.py b/src/inference_endpoint/dataset_manager/multi_turn_dataset.py index cabfac79..d2f21695 100644 --- a/src/inference_endpoint/dataset_manager/multi_turn_dataset.py +++ b/src/inference_endpoint/dataset_manager/multi_turn_dataset.py @@ -51,14 +51,24 @@ def _expand_tool_results(row: dict) -> list[dict]: row.get("turn"), ) return [] - return [ - { - "role": "tool", - "tool_call_id": result.get("tool_call_id"), - "content": result.get("content"), - } - for result in tool_results - ] + messages = [] + for i, result in enumerate(tool_results): + tool_call_id = result.get("tool_call_id") + content = result.get("content") + if tool_call_id is None: + raise InputValidationError( + f"tool_results[{i}] in conversation {row.get('conversation_id')!r} " + f"turn {row.get('turn')} is missing required field 'tool_call_id'" + ) + if content is None: + raise InputValidationError( + f"tool_results[{i}] in conversation {row.get('conversation_id')!r} " + f"turn {row.get('turn')} is missing required field 'content'" + ) + messages.append( + {"role": "tool", "tool_call_id": tool_call_id, "content": content} + ) + return messages class MultiTurnDataset(Dataset, dataset_id="multi_turn_conversations"): @@ -205,6 +215,7 @@ def _build_metadata(self) -> dict[str, Any]: current_turn_messages_by_key: dict[tuple, list[dict]] = {} system_prompts_by_conv: dict[str, str | None] = {} + assert self.dataframe is not None, "Dataframe must be initialized" for conv_id, group in self._conv_groups.items(): sorted_group = group.sort_values("turn") client_rows = sorted_group[sorted_group["role"].isin(["user", "tool"])] @@ -218,7 +229,7 @@ def _build_metadata(self) -> dict[str, Any]: break system_prompts_by_conv[str(conv_id)] = system_content - for idx, row in client_rows.iterrows(): + for _, row in client_rows.iterrows(): t_n = int(row["turn"]) messages: list[dict] = [] @@ -276,7 +287,6 @@ def _build_metadata(self) -> dict[str, Any]: 
samples.append( { - "index": idx, "conversation_id": str_conv_id, "turn": t_n, } diff --git a/src/inference_endpoint/load_generator/multi_turn_strategy.py b/src/inference_endpoint/load_generator/multi_turn_strategy.py index 1ce3780e..42395723 100644 --- a/src/inference_endpoint/load_generator/multi_turn_strategy.py +++ b/src/inference_endpoint/load_generator/multi_turn_strategy.py @@ -144,6 +144,15 @@ async def execute(self, phase_issuer: PhaseIssuerProtocol) -> int: errors = [r for r in results if isinstance(r, BaseException)] for err in errors: logger.error(f"Conversation pipeline failed: {err}") + + if self._inflight: + logger.warning( + "%d query(ies) never received a response (session stop or transport failure): %s", + len(self._inflight), + list(self._inflight.keys()), + ) + self._inflight.clear() + if errors: raise errors[0] return phase_issuer.issued_count @@ -254,14 +263,21 @@ def on_sample_complete(self, result: QueryResult) -> None: response_text = result.get_response_output_string() - if result.error is not None: - self._conv_manager.mark_turn_failed( - conv_id, store_in_history=self._store_in_history - ) - else: - self._conv_manager.mark_turn_complete( + try: + if result.error is not None: + self._conv_manager.mark_turn_failed( + conv_id, store_in_history=self._store_in_history + ) + else: + self._conv_manager.mark_turn_complete( + conv_id, + response_text, + store_in_history=self._store_in_history, + metadata=result.metadata, + ) + except KeyError: + logger.warning( + "on_sample_complete: conversation %s not found in manager (result=%s)", conv_id, - response_text, - store_in_history=self._store_in_history, - metadata=result.metadata, + result.id, ) diff --git a/src/inference_endpoint/openai/openai_adapter.py b/src/inference_endpoint/openai/openai_adapter.py index 85f208ca..ca531ed4 100644 --- a/src/inference_endpoint/openai/openai_adapter.py +++ b/src/inference_endpoint/openai/openai_adapter.py @@ -186,5 +186,5 @@ def decode_endpoint_response( 
"content" not in response_dict["choices"][0]["message"] or response_dict["choices"][0]["message"]["content"] is None ): - response_dict["choices"][0]["message"]["content"] = "None" + response_dict["choices"][0]["message"]["content"] = "" return CreateChatCompletionResponse(**response_dict, ignore_extra=True) diff --git a/tests/unit/dataset_manager/test_multi_turn_dataset.py b/tests/unit/dataset_manager/test_multi_turn_dataset.py index a93a3a08..62301940 100644 --- a/tests/unit/dataset_manager/test_multi_turn_dataset.py +++ b/tests/unit/dataset_manager/test_multi_turn_dataset.py @@ -211,7 +211,6 @@ def test_multi_turn_dataset_conversation_metadata(valid_multi_turn_jsonl): # Check sample metadata structure sample_meta = metadata["samples"][0] - assert "index" in sample_meta assert "conversation_id" in sample_meta assert "turn" in sample_meta From 0621eb89db17feedefd9ca88881c93ad1edaaec9 Mon Sep 17 00:00:00 2001 From: "Li, Tianmu" Date: Mon, 4 May 2026 14:15:30 -0700 Subject: [PATCH 12/13] fix: address remaining PR #285 review comments MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fix refusal field set to literal string "None" instead of "" in openai_adapter.py β€” made downstream refusal checks incorrectly truthy - Add test_pipeline_error_propagated to verify execute() re-raises worker exceptions instead of swallowing them via gather(return_exceptions=True) - Clarify MultiTurnStrategy docstring and MULTI_TURN_QUICKSTART.md: target_concurrency = simultaneous conversations (not requests); each active conversation has exactly 1 in-flight turn at a time - Remove unjustified "Common Configurations" section from quickstart - Correct misleading "workers = concurrent conversations" tip; clarify client.workers and target_concurrency are independent layers Co-Authored-By: Claude Sonnet 4.6 --- docs/MULTI_TURN_QUICKSTART.md | 77 ++++--------------- .../load_generator/multi_turn_strategy.py | 9 ++- .../openai/openai_adapter.py | 2 +- 
.../test_multi_turn_strategy.py | 19 +++++ 4 files changed, 40 insertions(+), 67 deletions(-) diff --git a/docs/MULTI_TURN_QUICKSTART.md b/docs/MULTI_TURN_QUICKSTART.md index 4e8e5e58..f3d5e082 100644 --- a/docs/MULTI_TURN_QUICKSTART.md +++ b/docs/MULTI_TURN_QUICKSTART.md @@ -47,7 +47,7 @@ datasets: settings: load_pattern: type: multi_turn # ← Use multi-turn scheduler - target_concurrency: 32 # ← Required: max concurrent requests + target_concurrency: 32 # ← Required: max simultaneous conversations client: workers: 4 @@ -120,21 +120,18 @@ mode: independent **Behavior**: -- Issues turn-1 of ALL conversations at t=0 -- Then sequences turns within each conversation independently -- Maximum parallelism and throughput +- Up to `target_concurrency` conversations are active simultaneously +- Turns within each conversation are strictly sequenced (turn N+1 waits for turn N) +- Conversations run independently of each other — a short conversation can finish while a long one is still on turn 2 -**Use for**: Realistic production load where short conversations finish while long ones are still running. -For single-conversation debugging, use `mode: independent` with `target_concurrency: 1`. -Note: unlike the plain `ConcurrencyScheduler`, multi-turn + `target_concurrency: 1` still enforces -per-conversation turn ordering — turn N+1 waits for turn N even at concurrency 1. +**Use for**: Realistic production load simulation. For single-conversation debugging, set `target_concurrency: 1`. -**Example timeline**: +**Example timeline** (target_concurrency: 3, 4 conversations total): ``` -t=0: conv1-turn1, conv2-turn1, conv3-turn1 (all at once) +t=0: conv1-turn1, conv2-turn1, conv3-turn1 ← 3 conversations start t=0.5: conv1-turn2 (after conv1-turn1 completes) -t=0.7: conv2-turn2 (after conv2-turn1 completes) +t=0.7: conv2 finishes → worker picks up conv4-turn1 t=0.8: conv1-turn3 (after conv1-turn2 completes) ... 
``` @@ -143,62 +140,16 @@ t=0.8: conv1-turn3 (after conv1-turn2 completes) ## Concurrency Control -`target_concurrency` is **required** for the `multi_turn` load pattern. It limits the maximum number of in-flight requests across all conversations and prevents endpoint overload when many conversations run simultaneously. +`target_concurrency` is **required** for the `multi_turn` load pattern. It controls how many +conversations are active simultaneously. Each active conversation has exactly one in-flight turn +at a time — a worker issues turn N, waits for the response, then issues turn N+1. A new +conversation starts only after a worker finishes all turns of its current one. ```yaml settings: load_pattern: type: multi_turn - target_concurrency: 32 # ← Limit to 32 concurrent requests -``` - ---- - -## Common Configurations - -### Recommended: With Concurrency Control - -```yaml -multi_turn: - mode: independent - -settings: - load_pattern: - type: multi_turn - target_concurrency: 32 # ← Prevents overload - client: - workers: 8 - -datasets: - - samples: 100 -``` - -### High Throughput Testing - -```yaml -multi_turn: - mode: independent - turn_timeout_s: 600 - -settings: - load_pattern: - type: multi_turn - target_concurrency: 96 - client: - workers: 16 # More workers for parallel conversations -``` - -### Conversations with Many Turns - -```yaml -multi_turn: - mode: independent - turn_timeout_s: 1800 # 30 minutes for slow responses - -settings: - load_pattern: - type: multi_turn - target_concurrency: 32 + target_concurrency: 32 # ← 32 conversations active simultaneously ``` --- @@ -329,7 +280,7 @@ jq -r '.conversation_id' logs/multi_turn_test/events.jsonl | sort -u ### Performance -- **Workers**: Set `workers` = number of concurrent conversations +- **Workers**: `client.workers` controls HTTP worker processes, independent of `target_concurrency`. The default (`-1`) auto-tunes based on NUMA topology. 
- **Timeout**: Set `turn_timeout_s` = 2x your longest expected turn latency - **Memory**: ~1KB per turn, plan accordingly for large datasets diff --git a/src/inference_endpoint/load_generator/multi_turn_strategy.py b/src/inference_endpoint/load_generator/multi_turn_strategy.py index 42395723..4297863b 100644 --- a/src/inference_endpoint/load_generator/multi_turn_strategy.py +++ b/src/inference_endpoint/load_generator/multi_turn_strategy.py @@ -35,9 +35,12 @@ class MultiTurnStrategy: """Async multi-turn strategy. Uses a worker-pool to limit active conversations. N worker tasks pull from a queue of conversations. Each worker processes all - turns of one conversation before moving to the next, so at most N conversations - are active simultaneously. When target_concurrency is None, all conversations - run concurrently (one worker per conversation). + turns of one conversation before moving to the next. At most N conversations + are active simultaneously, each with exactly 1 in-flight turn — a worker + issues turn N, waits for the response, then issues turn N+1. A new conversation + starts only after the worker finishes all turns of its current one. When + target_concurrency is None, all conversations run concurrently (one worker per + conversation). 
Integration with BenchmarkSession: - execute(): populates queue, spawns workers, awaits all to complete diff --git a/src/inference_endpoint/openai/openai_adapter.py b/src/inference_endpoint/openai/openai_adapter.py index ca531ed4..a458688c 100644 --- a/src/inference_endpoint/openai/openai_adapter.py +++ b/src/inference_endpoint/openai/openai_adapter.py @@ -180,7 +180,7 @@ def decode_endpoint_response( response_dict = msgspec.json.decode(response_bytes) # Set default values for optional fields if missing - response_dict["choices"][0]["message"]["refusal"] = "None" + response_dict["choices"][0]["message"]["refusal"] = "" response_dict["choices"][0]["logprobs"] = {"content": [], "refusal": []} if ( "content" not in response_dict["choices"][0]["message"] diff --git a/tests/unit/load_generator/test_multi_turn_strategy.py b/tests/unit/load_generator/test_multi_turn_strategy.py index 37cdfbad..7a9d8be1 100644 --- a/tests/unit/load_generator/test_multi_turn_strategy.py +++ b/tests/unit/load_generator/test_multi_turn_strategy.py @@ -363,6 +363,25 @@ async def complete_turn(): assert len(state.message_history) == 0 +@pytest.mark.unit +@pytest.mark.asyncio +async def test_pipeline_error_propagated(): + """execute() re-raises when a conversation pipeline raises an exception.""" + conv_manager = ConversationManager() + metadata = _make_dataset_metadata({"conv1": [1]}) + strategy = MultiTurnStrategy(conv_manager, metadata) + + class ErrorIssuer: + issued_count = 0 + issued: list[int] = [] + + def issue(self, idx: int, data_override: dict | None = None) -> str | None: + raise RuntimeError("simulated pipeline error") + + with pytest.raises(RuntimeError, match="simulated pipeline error"): + await strategy.execute(ErrorIssuer()) + + @pytest.mark.unit def test_mark_turn_complete_preserves_tool_calls(): """mark_turn_complete stores tool_calls in history when metadata contains them.""" From 9c7dcda78f81524c06c545980a3e49a0605b4f73 Mon Sep 17 00:00:00 2001 From: "Li, Tianmu" Date: Mon, 
4 May 2026 15:06:28 -0700 Subject: [PATCH 13/13] refactor: replace worker-pool with event-driven model in MultiTurnStrategy Rewrites MultiTurnStrategy to issue subsequent turns synchronously inside on_sample_complete() (zero event-loop delay), removing pre-spawned worker tasks and per-conversation asyncio.Event waiting. ConversationState no longer holds an asyncio.Event; sequencing is driven entirely by the strategy. Addresses PR #285 reviewer request to move turn issuance into the sample-complete handler. Co-Authored-By: Claude Sonnet 4.6 --- .../load_generator/conversation_manager.py | 22 +- .../load_generator/multi_turn_strategy.py | 272 ++++++++++-------- .../test_multi_turn_conversation_manager.py | 72 +---- .../test_multi_turn_strategy.py | 45 +-- 4 files changed, 181 insertions(+), 230 deletions(-) diff --git a/src/inference_endpoint/load_generator/conversation_manager.py b/src/inference_endpoint/load_generator/conversation_manager.py index 30276d5e..1b0834bb 100644 --- a/src/inference_endpoint/load_generator/conversation_manager.py +++ b/src/inference_endpoint/load_generator/conversation_manager.py @@ -15,7 +15,6 @@ """Conversation state management for multi-turn benchmarking.""" -import asyncio import logging from dataclasses import dataclass, field from typing import Any @@ -27,13 +26,8 @@ class ConversationState: """Per-conversation state for multi-turn benchmarking. - The pipeline task awaits ``turn_done`` between turns; ``mark_turn_complete`` - and ``mark_turn_failed`` set it synchronously from ``on_sample_complete``. - Attributes: conversation_id: Unique identifier for this conversation. - turn_done: Event set when a response arrives. Pipeline waits, then clears - it before issuing the next turn. message_history: Accumulated message list (populated only when use_dataset_history=False; empty otherwise). completed_turns: Turns with responses (success or failure) — observability only. 
@@ -42,8 +36,6 @@ class ConversationState: """ conversation_id: str - # Python 3.12+: asyncio.Event no longer requires a running loop at construction. - turn_done: asyncio.Event = field(default_factory=asyncio.Event) message_history: list[dict[str, Any]] = field(default_factory=list) completed_turns: int = 0 failed_turns: int = 0 @@ -59,11 +51,11 @@ def is_complete(self) -> bool: class ConversationManager: """Manages per-conversation state for multi-turn benchmarking. - All methods are synchronous. The pipeline task uses ``ConversationState.turn_done`` - directly for turn-done notification — no locks or condition variables needed. + All methods are synchronous. Turn sequencing is driven by MultiTurnStrategy + which calls on_sample_complete() → _issue_next_turn() directly. - All states are pre-created by ``MultiTurnStrategy.execute()`` before any pipeline - task starts, so ``get_or_create()`` requires no locking. + All states are pre-created by MultiTurnStrategy.execute() before any turns + are issued, so get_or_create() requires no locking. """ def __init__(self): @@ -126,7 +118,7 @@ def mark_turn_complete( store_in_history: bool = False, metadata: dict[str, Any] | None = None, ) -> None: - """Record a successful response and wake the pipeline task. + """Record a successful response. Args: conversation_id: Conversation ID. @@ -150,14 +142,13 @@ def mark_turn_complete( state.message_history.append(msg) state.completed_turns += 1 self._log_if_complete(state, conversation_id) - state.turn_done.set() def mark_turn_failed( self, conversation_id: str, store_in_history: bool = False, ) -> None: - """Record a failed response and wake the pipeline task. + """Record a failed response. Failed turns count toward completion so sequencing progresses under errors. 
@@ -179,4 +170,3 @@ def mark_turn_failed( state.failed_turns += 1 logger.warning(f"Turn failed for conversation {conversation_id}") self._log_if_complete(state, conversation_id) - state.turn_done.set() diff --git a/src/inference_endpoint/load_generator/multi_turn_strategy.py b/src/inference_endpoint/load_generator/multi_turn_strategy.py index 4297863b..d3f432d7 100644 --- a/src/inference_endpoint/load_generator/multi_turn_strategy.py +++ b/src/inference_endpoint/load_generator/multi_turn_strategy.py @@ -17,7 +17,8 @@ import asyncio import logging -from collections import defaultdict +from collections import defaultdict, deque +from collections.abc import Iterator from typing import Any from ..config.schema import MultiTurnConfig @@ -32,29 +33,28 @@ class MultiTurnStrategy: - """Async multi-turn strategy. Uses a worker-pool to limit active conversations. + """Event-driven multi-turn strategy. Completion of each turn triggers the next. - N worker tasks pull from a queue of conversations. Each worker processes all - turns of one conversation before moving to the next. At most N conversations - are active simultaneously, each with exactly 1 in-flight turn — a worker - issues turn N, waits for the response, then issues turn N+1. A new conversation - starts only after the worker finishes all turns of its current one. When - target_concurrency is None, all conversations run concurrently (one worker per - conversation). + execute() seeds the first N conversations (issues turn 1 for each), then + awaits _all_done. on_sample_complete() is called synchronously from the + receive coroutine for each response — it issues the next turn immediately + (zero event-loop iterations between response and next issuance), or starts + a new conversation when the current one finishes all turns. + + At most target_concurrency conversations are active simultaneously. When + target_concurrency is None, all conversations start at once. 
Integration with BenchmarkSession: - - execute(): populates queue, spawns workers, awaits all to complete + - execute(): seeds conversations, awaits completion - on_query_complete(): no-op (required by LoadStrategy protocol) - - on_sample_complete(): routes completed QueryResult to ConversationManager + - on_sample_complete(): routes completed QueryResult, issues next turn The response routing path: - 1. _conv_pipeline issues turn N via phase_issuer.issue(idx) → query_id - 2. _conv_pipeline stores conv_id in _inflight[query_id] + 1. _issue_next_turn issues turn N via phase_issuer.issue(idx) → query_id + 2. _issue_next_turn stores conv_id in _inflight[query_id] 3. BenchmarkSession calls on_sample_complete(result) with the QueryResult 4. on_sample_complete looks up conv_id from _inflight, calls mark_turn_complete - 5. mark_turn_complete sets state.turn_done synchronously - 6. _conv_pipeline's await asyncio.wait_for(state.turn_done.wait()) returns - 7. Pipeline clears the event and issues turn N+1 + 5. on_sample_complete calls _issue_next_turn for turn N+1 (synchronously) """ def __init__( @@ -93,6 +93,15 @@ def __init__( # Cached ConversationState refs for O(1) lookup in on_sample_complete. self._conv_states: dict[str, ConversationState] = {} + # Event-driven state — populated in execute(). + self._pending_convs: deque[tuple[str, list[tuple[int, int]]]] = deque() + self._active_iters: dict[str, Iterator[tuple[int, int]]] = {} + self._timeout_handles: dict[str, asyncio.TimerHandle] = {} + self._error: BaseException | None = None + self._all_done: asyncio.Event | None = None + self._loop: asyncio.AbstractEventLoop | None = None + self._phase_issuer: PhaseIssuerProtocol | None = None + async def execute(self, phase_issuer: PhaseIssuerProtocol) -> int: """Drive multi-turn sample issuance. @@ -102,12 +111,17 @@ async def execute(self, phase_issuer: PhaseIssuerProtocol) -> int: Returns: Total count of samples issued. 
""" + self._phase_issuer = phase_issuer + self._loop = asyncio.get_running_loop() + self._all_done = asyncio.Event() + self._error = None + conv_samples: dict[str, list[tuple[int, int]]] = defaultdict(list) for sample_index, sample_meta in enumerate(self._dataset_metadata["samples"]): conv_id = sample_meta["conversation_id"] conv_samples[conv_id].append((sample_index, sample_meta["turn"])) - # Pre-create all conversation states before spawning workers (no locking needed). + # Pre-create all conversation states before issuing any turns (no locking needed). sys_prompts = self._dataset_metadata.get("system_prompts_by_conv", {}) for conv_id, turns in conv_samples.items(): sys_content = sys_prompts.get(conv_id) if self._store_in_history else None @@ -123,30 +137,26 @@ async def execute(self, phase_issuer: PhaseIssuerProtocol) -> int: ) self._conv_states[conv_id] = state - # Build queue of (conv_id, turns) pairs for workers to pull from. - conv_queue: asyncio.Queue[tuple[str, list[tuple[int, int]]]] = asyncio.Queue() + # Build pending queue (sorted turns per conversation). 
for conv_id, turns in conv_samples.items(): - await conv_queue.put((conv_id, turns)) + self._pending_convs.append((conv_id, sorted(turns, key=lambda x: x[1]))) - n_conversations = len(conv_samples) - n_workers = ( - min(self._target_concurrency, n_conversations) + n_to_start = ( + min(self._target_concurrency, len(self._pending_convs)) if self._target_concurrency is not None and self._target_concurrency > 0 - else n_conversations + else len(self._pending_convs) ) + for _ in range(n_to_start): + self._start_conversation() - worker_tasks = [ - asyncio.create_task( - self._worker(conv_queue, phase_issuer), - name=f"mt-worker-{i}", - ) - for i in range(n_workers) - ] + if not self._active_iters and not self._inflight: + return phase_issuer.issued_count + + await self._all_done.wait() - results = await asyncio.gather(*worker_tasks, return_exceptions=True) - errors = [r for r in results if isinstance(r, BaseException)] - for err in errors: - logger.error(f"Conversation pipeline failed: {err}") + for handle in self._timeout_handles.values(): + handle.cancel() + self._timeout_handles.clear() if self._inflight: logger.warning( @@ -156,102 +166,111 @@ async def execute(self, phase_issuer: PhaseIssuerProtocol) -> int: ) self._inflight.clear() - if errors: - raise errors[0] + if self._error is not None: + raise self._error return phase_issuer.issued_count - async def _worker( - self, - conv_queue: asyncio.Queue[tuple[str, list[tuple[int, int]]]], - phase_issuer: PhaseIssuerProtocol, - ) -> None: - """Pull conversations from queue and process each one fully before taking the next.""" - while True: - try: - conv_id, turns = conv_queue.get_nowait() - except asyncio.QueueEmpty: - break - await self._conv_pipeline(conv_id, turns, phase_issuer) - - async def _conv_pipeline( - self, - conv_id: str, - turns: list[tuple[int, int]], - phase_issuer: PhaseIssuerProtocol, - ) -> None: - """Process all turns for a single conversation sequentially. 
- - For each turn after the first, waits for state.turn_done before issuing - the next. This enforces strict sequential ordering within the conversation. - """ + def _start_conversation(self) -> None: + """Pop the next conversation from the pending queue and issue its first turn.""" + conv_id, turns = self._pending_convs.popleft() + self._active_iters[conv_id] = iter(turns) + self._issue_next_turn(conv_id) + + def _issue_next_turn(self, conv_id: str) -> None: + """Issue the next turn for conv_id, or mark the conversation done.""" + it = self._active_iters.get(conv_id) + if it is None: + return + + pair = next(it, None) + if pair is None: + del self._active_iters[conv_id] + self._fill_slot() + return + + idx, turn = pair state = self._conv_states[conv_id] - sorted_turns = sorted(turns, key=lambda x: x[1]) - last_query_id: str | None = None - - for i, (idx, turn) in enumerate(sorted_turns): - if i > 0: - try: - await asyncio.wait_for( - state.turn_done.wait(), timeout=self._turn_timeout_s - ) - except TimeoutError: + + data_override: dict[str, Any] | None = None + current_turn_messages: list[dict[str, Any]] | None = None + if self._store_in_history: + current_turn_messages = self._dataset_metadata.get( + "current_turn_messages_by_key", {} + ).get((conv_id, turn)) + if current_turn_messages: + has_tool_msg = any( + m.get("role") == "tool" for m in current_turn_messages + ) + if has_tool_msg: logger.warning( - f"Turn {turn} of {conv_id} timed out waiting for previous turn" + "Live-history mode with tool messages uses dataset " + "tool_call_ids; real endpoint IDs will differ " + "(conv=%s, turn=%d)", + conv_id, + turn, ) - if last_query_id is not None: - self._inflight.pop(last_query_id, None) - remaining = len(sorted_turns) - i - for _ in range(remaining): - self._conv_manager.mark_turn_failed( - conv_id, store_in_history=self._store_in_history - ) - break - state.turn_done.clear() - - # Live-history mode: build messages from accumulated history + current turn. 
- data_override: dict[str, Any] | None = None - current_turn_messages: list[dict[str, Any]] | None = None - if self._store_in_history: - current_turn_messages = self._dataset_metadata.get( - "current_turn_messages_by_key", {} - ).get((conv_id, turn)) - if current_turn_messages: - has_tool_msg = any( - m.get("role") == "tool" for m in current_turn_messages - ) - if has_tool_msg: - logger.warning( - "Live-history mode with tool messages uses dataset " - "tool_call_ids; real endpoint IDs will differ " - "(conv=%s, turn=%d)", - conv_id, - turn, - ) - live_messages = state.message_history.copy() + current_turn_messages - data_override = {"messages": live_messages} - - query_id = phase_issuer.issue(idx, data_override=data_override) - if query_id is None: - # Session stopping — exit pipeline. - break - - self._inflight[query_id] = conv_id - last_query_id = query_id - - # Append current-turn messages to history so the next turn sees them. - if self._store_in_history and current_turn_messages: - state.message_history.extend(current_turn_messages) + live_messages = state.message_history.copy() + current_turn_messages + data_override = {"messages": live_messages} + + assert self._phase_issuer is not None + query_id = self._phase_issuer.issue(idx, data_override=data_override) + if query_id is None: + # Session stopping — signal done. 
+ assert self._all_done is not None + self._all_done.set() + return + + self._inflight[query_id] = conv_id + + if self._store_in_history and current_turn_messages: + state.message_history.extend(current_turn_messages) + + assert self._loop is not None + handle = self._loop.call_later( + self._turn_timeout_s, self._handle_timeout, query_id, conv_id + ) + self._timeout_handles[query_id] = handle + + def _fill_slot(self) -> None: + """Start a new conversation from the pending queue, or signal all done.""" + if self._pending_convs: + self._start_conversation() + elif not self._active_iters: + assert self._all_done is not None + self._all_done.set() + + def _handle_timeout(self, query_id: str, conv_id: str) -> None: + """Called by the event loop when a turn response does not arrive in time.""" + if self._inflight.pop(query_id, None) is None: + return + self._timeout_handles.pop(query_id, None) + + logger.warning( + "Turn timed out for conversation %s (query=%s)", conv_id, query_id + ) + + self._conv_manager.mark_turn_failed( + conv_id, store_in_history=self._store_in_history + ) + it = self._active_iters.pop(conv_id, None) + if it is not None: + for _ in it: + self._conv_manager.mark_turn_failed( + conv_id, store_in_history=self._store_in_history + ) + + self._fill_slot() def on_query_complete(self, query_id: str) -> None: """No-op. Required by LoadStrategy protocol; called by BenchmarkSession.""" pass def on_sample_complete(self, result: QueryResult) -> None: - """Route completed QueryResult to ConversationManager. + """Route completed QueryResult to ConversationManager and issue next turn. - Called by execute.py on_sample_complete hook after each response. - Event.set() is synchronous — the pipeline task is woken immediately - without needing asyncio.ensure_future. + Called synchronously from BenchmarkSession._handle_response(). Issues the + next turn immediately (zero event-loop delay) or starts a new conversation + when this one finishes all turns. 
Args: result: Completed QueryResult from the endpoint. @@ -260,9 +279,9 @@ def on_sample_complete(self, result: QueryResult) -> None: if conv_id is None: return - if self._conv_manager.get_state(conv_id) is None: - logger.warning(f"on_sample_complete: unknown conversation {conv_id}") - return + handle = self._timeout_handles.pop(result.id, None) + if handle is not None: + handle.cancel() response_text = result.get_response_output_string() @@ -284,3 +303,12 @@ def on_sample_complete(self, result: QueryResult) -> None: conv_id, result.id, ) + return + + try: + self._issue_next_turn(conv_id) + except Exception as exc: + logger.error("Error issuing next turn for %s: %s", conv_id, exc) + self._error = exc + if self._all_done is not None: + self._all_done.set() diff --git a/tests/unit/load_generator/test_multi_turn_conversation_manager.py b/tests/unit/load_generator/test_multi_turn_conversation_manager.py index 331e6709..c389fb5f 100644 --- a/tests/unit/load_generator/test_multi_turn_conversation_manager.py +++ b/tests/unit/load_generator/test_multi_turn_conversation_manager.py @@ -29,7 +29,6 @@ def test_conversation_state_initialization(): state = ConversationState(conversation_id="conv_001") assert state.conversation_id == "conv_001" - assert not state.turn_done.is_set() assert state.message_history == [] assert state.completed_turns == 0 assert state.failed_turns == 0 @@ -95,7 +94,7 @@ def test_conversation_manager_multiple_conversations(): @pytest.mark.unit def test_conversation_manager_mark_turn_complete(): - """mark_turn_complete increments counter, appends history, sets event.""" + """mark_turn_complete increments counter and appends history.""" manager = ConversationManager() state = manager.get_or_create("conv_001") @@ -103,7 +102,6 @@ def test_conversation_manager_mark_turn_complete(): assert state.completed_turns == 1 assert state.failed_turns == 0 - assert state.turn_done.is_set() assert state.message_history == [] # store_in_history=False by default @@ 
-120,7 +118,7 @@ def test_conversation_manager_mark_turn_complete_stores_history(): @pytest.mark.unit def test_conversation_manager_mark_turn_failed(): - """mark_turn_failed increments both counters and sets event.""" + """mark_turn_failed increments both counters.""" manager = ConversationManager() state = manager.get_or_create("conv_001", expected_client_turns=2) @@ -128,7 +126,6 @@ def test_conversation_manager_mark_turn_failed(): assert state.completed_turns == 1 assert state.failed_turns == 1 - assert state.turn_done.is_set() @pytest.mark.unit @@ -187,71 +184,6 @@ def test_all_turns_fail(): assert state.failed_turns == 2 -@pytest.mark.unit -@pytest.mark.asyncio -async def test_event_set_wakes_waiter(): - """mark_turn_complete sets turn_done so a blocked await returns.""" - manager = ConversationManager() - state = manager.get_or_create("conv_001") - - woke_up: list[bool] = [] - - async def waiter(): - await state.turn_done.wait() - woke_up.append(True) - - task = asyncio.create_task(waiter()) - await asyncio.sleep(0.01) - assert not woke_up - - manager.mark_turn_complete("conv_001", "response") - await asyncio.sleep(0.01) - await task - - assert woke_up == [True] - - -@pytest.mark.unit -@pytest.mark.asyncio -async def test_failed_sets_event(): - """mark_turn_failed sets turn_done so the pipeline can unblock.""" - manager = ConversationManager() - state = manager.get_or_create("conv_001") - - woke_up: list[bool] = [] - - async def waiter(): - await state.turn_done.wait() - woke_up.append(True) - - task = asyncio.create_task(waiter()) - await asyncio.sleep(0.01) - - manager.mark_turn_failed("conv_001") - await asyncio.sleep(0.01) - await task - - assert woke_up == [True] - - -@pytest.mark.unit -@pytest.mark.asyncio -async def test_event_clear_resets_for_next_turn(): - """Clearing turn_done after wait() properly gates the next turn.""" - manager = ConversationManager() - state = manager.get_or_create("conv_001") - - # First turn: set then clear - 
manager.mark_turn_complete("conv_001", "r1") - await state.turn_done.wait() - state.turn_done.clear() - assert not state.turn_done.is_set() - - # Second turn: set again - manager.mark_turn_complete("conv_001", "r2") - assert state.turn_done.is_set() - - @pytest.mark.unit @pytest.mark.asyncio async def test_conversation_manager_concurrent_access(): diff --git a/tests/unit/load_generator/test_multi_turn_strategy.py b/tests/unit/load_generator/test_multi_turn_strategy.py index 7a9d8be1..d3c9a22a 100644 --- a/tests/unit/load_generator/test_multi_turn_strategy.py +++ b/tests/unit/load_generator/test_multi_turn_strategy.py @@ -68,14 +68,12 @@ async def test_single_conversation_single_turn(): strategy = MultiTurnStrategy(conv_manager, metadata) issuer = FakePhaseIssuer() - # Simulate response completion (turn 1 is issued, then completes) async def complete_turns(): - # Wait a tick for the strategy to issue the first turn await asyncio.sleep(0.01) - # Mark turn 1 complete - state = conv_manager.get_state("conv1") - if state: - conv_manager.mark_turn_complete("conv1", "response 1") + result = QueryResult( + id="q0000", response_output=TextModelOutput(output="response 1") + ) + strategy.on_sample_complete(result) asyncio.create_task(complete_turns()) count = await strategy.execute(issuer) @@ -107,7 +105,6 @@ def tracked_issue(idx, data_override=None): async def simulate_responses(): await asyncio.sleep(0.01) for turn_q, resp in [("q0000", "r1"), ("q0001", "r2"), ("q0002", "r3")]: - # Signal turn complete via on_sample_complete result = QueryResult( id=turn_q, response_output=TextModelOutput(output=resp) ) @@ -132,7 +129,6 @@ async def test_multiple_conversations_concurrent(): async def simulate_responses(): await asyncio.sleep(0.02) - # Complete all turns for both conversations for q_prefix in range(4): q = f"q{q_prefix:04d}" result = QueryResult(id=q, response_output=TextModelOutput(output="resp")) @@ -174,12 +170,10 @@ async def simulate_responses(): import time await 
asyncio.sleep(0.02) - # Complete turn 1 (sample 0) after a delay complete_timestamps[0] = time.monotonic() result = QueryResult(id="q0000", response_output=TextModelOutput(output="r1")) strategy.on_sample_complete(result) await asyncio.sleep(0.05) - # Complete turn 2 (sample 1) complete_timestamps[1] = time.monotonic() result = QueryResult(id="q0001", response_output=TextModelOutput(output="r2")) strategy.on_sample_complete(result) @@ -227,7 +221,6 @@ async def test_on_sample_complete_routes_to_manager(): state = conv_manager.get_state("conv1") assert state is not None assert state.completed_turns == 1 - assert state.turn_done.is_set() assert state.is_complete() @@ -296,7 +289,10 @@ async def test_live_history_initializes_system_prompt(): async def complete_turn(): await asyncio.sleep(0.01) - conv_manager.mark_turn_complete("conv1", "response") + result = QueryResult( + id="q0000", response_output=TextModelOutput(output="response") + ) + strategy.on_sample_complete(result) asyncio.create_task(complete_turn()) await strategy.execute(issuer) @@ -325,7 +321,10 @@ async def test_live_history_no_system_prompt_when_none(): async def complete_turn(): await asyncio.sleep(0.01) - conv_manager.mark_turn_complete("conv1", "response") + result = QueryResult( + id="q0000", response_output=TextModelOutput(output="response") + ) + strategy.on_sample_complete(result) asyncio.create_task(complete_turn()) await strategy.execute(issuer) @@ -352,7 +351,10 @@ async def test_dataset_history_mode_does_not_inject_system_prompt(): async def complete_turn(): await asyncio.sleep(0.01) - conv_manager.mark_turn_complete("conv1", "response") + result = QueryResult( + id="q0000", response_output=TextModelOutput(output="response") + ) + strategy.on_sample_complete(result) asyncio.create_task(complete_turn()) await strategy.execute(issuer) @@ -366,7 +368,7 @@ async def complete_turn(): @pytest.mark.unit @pytest.mark.asyncio async def test_pipeline_error_propagated(): - """execute() re-raises when 
a conversation pipeline raises an exception.""" + """execute() re-raises when _issue_next_turn raises an exception.""" conv_manager = ConversationManager() metadata = _make_dataset_metadata({"conv1": [1]}) strategy = MultiTurnStrategy(conv_manager, metadata) @@ -492,9 +494,9 @@ async def test_on_sample_complete_passes_metadata(): async def test_concurrency_limits_active_conversations(): """target_concurrency=2 starts at most 2 conversation pipelines simultaneously. - Uses 2-turn conversations so each pipeline has an await point (turn_done.wait - between turns). With 4 conversations and 2 workers, the 3rd and 4th conversations - cannot start until a worker finishes its current conversation. + Uses 2-turn conversations so each pipeline has an await point between turns. + With 4 conversations and 2 workers, the 3rd and 4th conversations cannot start + until a worker finishes its current conversation. """ conv_manager = ConversationManager() # 4 two-turn conversations; pipeline awaits turn-1 response before issuing turn-2 @@ -519,7 +521,7 @@ async def auto_respond(): responder_task = asyncio.create_task(auto_respond()) execute_task = asyncio.create_task(strategy.execute(issuer)) - # Let both workers start and block on turn_done.wait before auto_respond fires + # Let both seed turns get issued before auto_respond fires await asyncio.sleep(0.01) # Only 2 workers → exactly 2 turn-1 queries issued (conv3/conv4 not started yet) @@ -536,8 +538,7 @@ async def auto_respond(): async def test_conversation_slot_reuse(): """With target_concurrency=1, worker completes conv1 before starting conv2. - Uses 2-turn conversations so the pipeline has an await between turns. - The single worker must process both turns of conv1 before conv2's turn 1 is issued. + The single slot must process both turns of conv1 before conv2's turn 1 is issued.
""" conv_manager = ConversationManager() # 2 two-turn conversations; sample indices: conv1→[0,1], conv2→[2,3] @@ -561,6 +562,6 @@ async def auto_respond(): await strategy.execute(issuer) responder_task.cancel() - # Single worker: conv1 turns (samples 0,1) must be issued before conv2 turns (2,3) + # Single slot: conv1 turns (samples 0,1) must be issued before conv2 turns (2,3) assert issuer.issued[:2] == [0, 1], "Conv1 turns should be issued before conv2" assert issuer.issued[2:] == [2, 3], "Conv2 turns should follow conv1"