diff --git a/bot/eval/locomo/import_to_ov.py b/bot/eval/locomo/import_to_ov.py index 7a531d04a..d85618bd8 100644 --- a/bot/eval/locomo/import_to_ov.py +++ b/bot/eval/locomo/import_to_ov.py @@ -17,12 +17,12 @@ import json import sys import time -from datetime import datetime +import traceback +from datetime import datetime, timedelta from pathlib import Path from typing import List, Dict, Any, Tuple, Optional import openviking as ov -from openviking.message.part import TextPart def _get_session_number(session_key: str) -> int: @@ -58,32 +58,6 @@ def parse_test_file(path: str) -> List[Dict[str, Any]]: return sessions -def format_locomo_message(msg: Dict[str, Any]) -> str: - """Format a single LoCoMo message into a natural chat-style string. - - Output format: - Speaker: text here - image_url: caption - """ - speaker = msg.get("speaker", "unknown") - text = msg.get("text", "") - line = f"{speaker}: {text}" - - img_urls = msg.get("img_url", []) - if isinstance(img_urls, str): - img_urls = [img_urls] - blip = msg.get("blip_caption", "") - - if img_urls: - for url in img_urls: - caption = f": {blip}" if blip else "" - line += f"\n{url}{caption}" - elif blip: - line += f"\n({blip})" - - return line - - def load_locomo_data( path: str, sample_index: Optional[int] = None, @@ -103,9 +77,10 @@ def build_session_messages( item: Dict[str, Any], session_range: Optional[Tuple[int, int]] = None, ) -> List[Dict[str, Any]]: - """Build bundled session messages for one LoCoMo sample. + """Build session messages for one LoCoMo sample. - Returns list of dicts with keys: message, meta. + Returns list of dicts with keys: messages, meta. + Each dict represents a session with multiple messages (user/assistant role). """ conv = item["conversation"] speakers = f"{conv['speaker_a']} & {conv['speaker_b']}" @@ -126,13 +101,20 @@ def build_session_messages( dt_key = f"{sk}_date_time" date_time = conv.get(dt_key, "") - parts = [f"[group chat conversation: {date_time}]"] - for msg in conv[sk]: - parts.append(format_locomo_message(msg)) - combined = "\n\n".join(parts) + # Extract messages with all as user role, including speaker in content + messages = [] + for idx, msg in enumerate(conv[sk]): + speaker = msg.get("speaker", "unknown") + text = msg.get("text", "") + messages.append({ + "role": "user", + "text": f"[{speaker}]: {text}", + "speaker": speaker, + "index": idx + }) sessions.append({ - "message": combined, + "messages": messages, "meta": { "sample_id": item["sample_id"], "session_key": sk, @@ -148,7 +130,7 @@ def build_session_messages( # Ingest record helpers (avoid duplicate ingestion) # --------------------------------------------------------------------------- -def load_success_csv(csv_path: str = "import_success.csv") -> set: +def load_success_csv(csv_path: str = "./result/import_success.csv") -> set: """加载成功导入的CSV记录,返回已成功的键集合""" success_keys = set() if Path(csv_path).exists(): @@ -160,7 +142,7 @@ def load_success_csv(csv_path: str = "import_success.csv") -> set: return success_keys -def write_success_record(record: Dict[str, Any], csv_path: str = "import_success.csv") -> None: +def write_success_record(record: Dict[str, Any], csv_path: str = "./result/import_success.csv") -> None: """写入成功记录到CSV文件""" file_exists = Path(csv_path).exists() fieldnames = ["timestamp", "sample_id", "session", "date_time", "speakers", @@ -186,7 +168,7 @@ def write_success_record(record: Dict[str, Any], csv_path: str = "import_success }) -def write_error_record(record: Dict[str, Any], error_path: str = "import_errors.log") -> None: +def write_error_record(record: Dict[str, Any], error_path: str = "./result/import_errors.log") -> None: """写入错误记录到日志文件""" with open(error_path, "a", encoding="utf-8") as f: timestamp = record["timestamp"] @@ -242,22 +224,42 @@ def mark_ingested( # --------------------------------------------------------------------------- # OpenViking import # --------------------------------------------------------------------------- -def _parse_token_usage(token_data: Dict[str, Any]) -> Dict[str, int]: - """解析Token使用数据(仅支持新版token_usage格式)""" - usage = token_data["token_usage"] +def _parse_token_usage(commit_result: Dict[str, Any]) -> Dict[str, int]: + """解析Token使用数据(从commit返回的telemetry中提取)""" + telemetry = commit_result.get("telemetry", {}).get("summary", {}) + tokens = telemetry.get("tokens", {}) return { - "embedding": usage["embedding"]["total_tokens"], - "vlm": usage["llm"]["total_tokens"], - "llm_input": usage["llm"]["prompt_tokens"], - "llm_output": usage["llm"]["completion_tokens"], - "total": usage["total"]["total_tokens"] + "embedding": tokens.get("embedding", {}).get("total", 0), + "vlm": tokens.get("llm", {}).get("total", 0), + "llm_input": tokens.get("llm", {}).get("input", 0), + "llm_output": tokens.get("llm", {}).get("output", 0), + "total": tokens.get("total", 0) } -async def viking_ingest(msg: str, openviking_url: str, semaphore: asyncio.Semaphore) -> Dict[str, int]: - """Save a message to OpenViking via OpenViking SDK client. +async def viking_ingest( + messages: List[Dict[str, Any]], + openviking_url: str, + semaphore: asyncio.Semaphore, + session_time: Optional[str] = None +) -> Dict[str, int]: + """Save messages to OpenViking via OpenViking SDK client. Returns token usage dict with embedding and vlm token counts. + + Args: + messages: List of message dicts with role and text + openviking_url: OpenViking service URL + semaphore: Async semaphore for concurrency control + session_time: Session time string (e.g., "9:36 am on 2 April, 2023") """ + # 解析 session_time - 为每条消息计算递增的时间戳 + base_datetime = None + if session_time: + try: + base_datetime = datetime.strptime(session_time, "%I:%M %p on %d %B, %Y") + except ValueError: + print(f"Warning: Failed to parse session_time: {session_time}", file=sys.stderr) + # 使用信号量控制并发 async with semaphore: # Create client @@ -268,38 +270,30 @@ async def viking_ingest(msg: str, openviking_url: str, semaphore: asyncio.Semaph # Create session create_res = await client.create_session() session_id = create_res["session_id"] - session = client.session(session_id) - # Add message - await session.add_message( - role="user", - parts=[TextPart(text=msg)] - ) + # Add messages one by one with created_at + for idx, msg in enumerate(messages): + msg_created_at = None + if base_datetime: + # 每条消息递增1秒,确保时间顺序 + msg_dt = base_datetime + timedelta(seconds=idx) + msg_created_at = msg_dt.isoformat() + + await client.add_message( + session_id=session_id, + role=msg["role"], + parts=[{"type": "text", "text": msg["text"]}], + created_at=msg_created_at + ) # Commit - result = await session.commit(telemetry=True) + result = await client.commit_session(session_id, telemetry=True) - if not (result.get("status") == "accepted" and result.get("task_id")): + if result.get("status") != "committed": raise RuntimeError(f"Commit failed: {result}") - # 轮询等待异步任务完成 - task_id = result["task_id"] - max_wait = 1200 # 最多等待20分钟 - waited = 0 - - while waited < max_wait: - task = await client.get_task(task_id) - if task["status"] == "completed": - token_usage = _parse_token_usage(task["result"]) - break - elif task["status"] == "failed": - raise RuntimeError(f"Commit failed: {task.get('error', 'Unknown error')}") - - # 指数退避策略,避免频繁请求 - await asyncio.sleep(min(1 << (waited // 10), 60)) - waited += 1 - else: - raise RuntimeError(f"Commit timed out after {max_wait} seconds") + # 直接从commit结果中提取token使用情况 + token_usage = _parse_token_usage(result) return token_usage @@ -307,10 +301,10 @@ async def viking_ingest(msg: str, openviking_url: str, semaphore: asyncio.Semaph await client.close() -def sync_viking_ingest(msg: str, openviking_url: str) -> Dict[str, int]: +def sync_viking_ingest(messages: List[Dict[str, Any]], openviking_url: str, session_time: Optional[str] = None) -> Dict[str, int]: """Synchronous wrapper for viking_ingest to maintain existing API.""" semaphore = asyncio.Semaphore(1) # 同步调用时使用信号量为1 - return asyncio.run(viking_ingest(msg, openviking_url, semaphore)) + return asyncio.run(viking_ingest(messages, openviking_url, semaphore, session_time)) # --------------------------------------------------------------------------- @@ -327,7 +321,7 @@ def parse_session_range(s: str) -> Tuple[int, int]: async def process_single_session( - msg: str, + messages: List[Dict[str, Any]], sample_id: str | int, session_key: str, meta: Dict[str, Any], @@ -338,7 +332,7 @@ async def process_single_session( ) -> Dict[str, Any]: """处理单个会话的导入任务""" try: - token_usage = await viking_ingest(msg, args.openviking_url, semaphore) + token_usage = await viking_ingest(messages, args.openviking_url, semaphore, meta.get("date_time")) print(f" -> [SUCCESS] [{sample_id}/{session_key}] imported to OpenViking", file=sys.stderr) # Extract token counts @@ -369,6 +363,7 @@ async def process_single_session( except Exception as e: print(f" -> [ERROR] [{sample_id}/{session_key}] {e}", file=sys.stderr) + traceback.print_exc(file=sys.stderr) # Write error record result = { @@ -428,7 +423,7 @@ async def run_import(args: argparse.Namespace) -> None: for sess in sessions: meta = sess["meta"] - msg = sess["message"] + messages = sess["messages"] session_key = meta["session_key"] label = f"{session_key} ({meta['date_time']})" @@ -438,13 +433,14 @@ async def run_import(args: argparse.Namespace) -> None: skipped_count += 1 continue - preview = msg.replace("\n", " | ")[:80] - print(f" [{label}] {preview}...", file=sys.stderr) + # Preview messages + preview = " | ".join([f"{msg['role']}: {msg['text'][:30]}..." for msg in messages[:3]]) + print(f" [{label}] {preview}", file=sys.stderr) # 创建异步任务 task = asyncio.create_task( process_single_session( - msg=msg, + messages=messages, sample_id=sample_id, session_key=session_key, meta=meta, @@ -471,14 +467,23 @@ async def run_import(args: argparse.Namespace) -> None: skipped_count += 1 continue - combined_msg = "\n\n".join(session["messages"]) - preview = combined_msg.replace("\n", " | ")[:80] - print(f" {preview}...", file=sys.stderr) + # For plain text, all messages as user role + messages = [] + for i, text in enumerate(session["messages"]): + messages.append({ + "role": "user", + "text": text.strip(), + "speaker": "user", + "index": i + }) + + preview = " | ".join([f"{msg['role']}: {msg['text'][:30]}..." for msg in messages[:3]]) + print(f" {preview}", file=sys.stderr) # 创建异步任务 task = asyncio.create_task( process_single_session( - msg=combined_msg, + messages=messages, sample_id="txt", session_key=session_key, meta={"session_index": idx}, @@ -499,6 +504,8 @@ async def run_import(args: argparse.Namespace) -> None: if isinstance(result, Exception): error_count += 1 print(f"[UNEXPECTED ERROR] Task failed with exception: {result}", file=sys.stderr) + if hasattr(result, '__traceback__'): + traceback.print_exception(type(result), result, result.__traceback__, file=sys.stderr) continue if result["status"] == "success":