From a36c808e31934dea23857ea9648dd99b8ac5bb8c Mon Sep 17 00:00:00 2001 From: DuTao Date: Tue, 31 Mar 2026 14:27:18 +0800 Subject: [PATCH 1/2] import eval --- bot/eval/locomo/import_to_ov.py | 511 +++++++++++++++++++------------- 1 file changed, 308 insertions(+), 203 deletions(-) diff --git a/bot/eval/locomo/import_to_ov.py b/bot/eval/locomo/import_to_ov.py index c4492336..0f9db7e5 100644 --- a/bot/eval/locomo/import_to_ov.py +++ b/bot/eval/locomo/import_to_ov.py @@ -12,12 +12,21 @@ """ import argparse +import asyncio +import csv import json +import subprocess import sys import time from datetime import datetime +from pathlib import Path +from typing import List, Dict, Any import openviking as ov +from openviking.message.part import TextPart + +# 全局信号量用于控制并发 +semaphore: asyncio.Semaphore = None def parse_test_file(path: str) -> list[dict]: @@ -48,17 +57,30 @@ def parse_test_file(path: str) -> list[dict]: return sessions -def format_locomo_message(msg: dict, index: int | None = None) -> str: - """Format a single LoCoMo message into chat-style string. +def format_locomo_message(msg: dict) -> str: + """Format a single LoCoMo message into a natural chat-style string. Output format: - [index][Speaker]: text here + Speaker: text here + image_url: caption """ speaker = msg.get("speaker", "unknown") text = msg.get("text", "") - if index is not None: - return f"[{index}][{speaker}]: {text}" - return f"[{speaker}]: {text}" + line = f"{speaker}: {text}" + + img_urls = msg.get("img_url", []) + if isinstance(img_urls, str): + img_urls = [img_urls] + blip = msg.get("blip_caption", "") + + if img_urls: + for url in img_urls: + caption = f": {blip}" if blip else "" + line += f"\n{url}{caption}" + elif blip: + line += f"\n({blip})" + + return line def load_locomo_data( @@ -81,10 +103,9 @@ def build_session_messages( item: dict, session_range: tuple[int, int] | None = None, ) -> list[dict]: - """Build session messages for one LoCoMo sample. + """Build bundled session messages for one LoCoMo sample. - Returns list of dicts with keys: messages, meta. - Each dict represents a session with multiple messages (user/assistant role). + Returns list of dicts with keys: message, meta. """ conv = item["conversation"] speakers = f"{conv['speaker_a']} & {conv['speaker_b']}" @@ -105,20 +126,13 @@ def build_session_messages( dt_key = f"{sk}_date_time" date_time = conv.get(dt_key, "") - # Extract messages with all as user role, including speaker in content - messages = [] - for idx, msg in enumerate(conv[sk]): - speaker = msg.get("speaker", "unknown") - text = msg.get("text", "") - messages.append({ - "role": "user", - "text": f"[{speaker}]: {text}", - "speaker": speaker, - "index": idx - }) + parts = [f"[group chat conversation: {date_time}]"] + for msg in conv[sk]: + parts.append(format_locomo_message(msg)) + combined = "\n\n".join(parts) sessions.append({ - "messages": messages, + "message": combined, "meta": { "sample_id": item["sample_id"], "session_key": sk, @@ -134,8 +148,55 @@ def build_session_messages( # Ingest record helpers (avoid duplicate ingestion) # --------------------------------------------------------------------------- +def load_success_csv(csv_path: str = "import_success.csv") -> set: + """加载成功导入的CSV记录,返回已成功的键集合""" + success_keys = set() + if Path(csv_path).exists(): + with open(csv_path, "r", encoding="utf-8") as f: + reader = csv.DictReader(f) + for row in reader: + key = f"viking:{row['sample_id']}:{row['session']}" + success_keys.add(key) + return success_keys + + +def write_success_record(record: Dict[str, Any], csv_path: str = "import_success.csv") -> None: + """写入成功记录到CSV文件""" + file_exists = Path(csv_path).exists() + fieldnames = ["timestamp", "sample_id", "session", "date_time", "speakers", + "embedding_tokens", "vlm_tokens", "llm_input_tokens", + "llm_output_tokens", "total_tokens"] + + with open(csv_path, "a", encoding="utf-8", newline="") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + if not file_exists: + writer.writeheader() + + writer.writerow({ + "timestamp": record["timestamp"], + "sample_id": record["sample_id"], + "session": record["session"], + "date_time": record.get("meta", {}).get("date_time", ""), + "speakers": record.get("meta", {}).get("speakers", ""), + "embedding_tokens": record["token_usage"].get("embedding", 0), + "vlm_tokens": record["token_usage"].get("vlm", 0), + "llm_input_tokens": record["token_usage"].get("llm_input", 0), + "llm_output_tokens": record["token_usage"].get("llm_output", 0), + "total_tokens": record["token_usage"].get("total", 0) + }) + + +def write_error_record(record: Dict[str, Any], error_path: str = "import_errors.log") -> None: + """写入错误记录到日志文件""" + with open(error_path, "a", encoding="utf-8") as f: + timestamp = record["timestamp"] + sample_id = record["sample_id"] + session = record["session"] + error = record["error"] + f.write(f"[{timestamp}] ERROR [{sample_id}/{session}]: {error}\n") + -def load_ingest_record(record_path: str = "result/ingest_record.json") -> dict: +def load_ingest_record(record_path: str = "./result/.ingest_record.json") -> dict: """Load existing ingest record file, return empty dict if not exists.""" try: with open(record_path, "r", encoding="utf-8") as f: @@ -144,7 +205,7 @@ def load_ingest_record(record_path: str = "result/ingest_record.json") -> dict: return {} -def save_ingest_record(record: dict, record_path: str = "result/ingest_record.json") -> None: +def save_ingest_record(record: dict, record_path: str = "./result/.ingest_record.json") -> None: """Save ingest record to file.""" with open(record_path, "w", encoding="utf-8") as f: json.dump(record, f, indent=2, ensure_ascii=False) @@ -154,9 +215,12 @@ def is_already_ingested( sample_id: str | int, session_key: str, record: dict, + success_keys: set = None, ) -> bool: """Check if a specific session has already been successfully ingested.""" key = f"viking:{sample_id}:{session_key}" + if success_keys is not None and key in success_keys: + return True return key in record and record[key].get("success", False) @@ -178,60 +242,84 @@ def mark_ingested( # --------------------------------------------------------------------------- # OpenViking import # --------------------------------------------------------------------------- +def _parse_token_usage(token_data: dict) -> dict: + """解析Token使用数据(仅支持新版token_usage格式)""" + usage = token_data["token_usage"] + return { + "embedding": usage["embedding"]["total_tokens"], + "vlm": usage["llm"]["total_tokens"], + "llm_input": usage["llm"]["prompt_tokens"], + "llm_output": usage["llm"]["completion_tokens"], + "total": usage["total"]["total_tokens"] + } -def viking_ingest(messages: list[dict], session_time: str = None) -> None: - """Save messages to OpenViking via SyncHTTPClient (add messages + commit session). - - Args: - messages: List of message dicts with role and text - session_time: Session time string (e.g., "9:36 am on 2 April, 2023") +async def viking_ingest(msg: str, openviking_url: str) -> dict: + """Save a message to OpenViking via OpenViking SDK client. + Returns token usage dict with embedding and vlm token counts. """ - from datetime import datetime + # 使用信号量控制并发 + async with semaphore: + # Create client + client = ov.AsyncHTTPClient(url=openviking_url) + await client.initialize() - # 解析 session_time - created_at = None - if session_time: try: - dt = datetime.strptime(session_time, "%I:%M %p on %d %B, %Y") - created_at = dt.isoformat() - except ValueError: - print(f"Warning: Failed to parse session_time: {session_time}", file=sys.stderr) + # Create session + create_res = await client.create_session() + session_id = create_res["session_id"] + session = client.session(session_id) + + # Add message + await session.add_message( + role="user", + parts=[TextPart(text=msg)] + ) + + # Commit + result = await session.commit(telemetry=True) - client = ov.SyncHTTPClient() - client.initialize() + if not (result.get("status") == "accepted" and result.get("task_id")): + raise RuntimeError(f"Commit failed: {result}") - # Create new session - session_result = client.create_session() - session_id = session_result.get('session_id') + # 轮询等待异步任务完成 + task_id = result["task_id"] + max_wait = 1200 # 最多等待20分钟 + waited = 0 - # Add messages one by one - for msg in messages: - client.add_message(session_id, role=msg["role"], content=msg["text"], created_at=created_at) + while waited < max_wait: + task = await client.get_task(task_id) + if task["status"] == "completed": + token_usage = _parse_token_usage(task["result"]) + break + elif task["status"] == "failed": + raise RuntimeError(f"Commit failed: {task.get('error', 'Unknown error')}") - # Commit session to trigger memory extraction - commit_result = client.commit_session(session_id) - task_id = commit_result.get("task_id") + await asyncio.sleep(1) + waited += 1 + else: + raise RuntimeError(f"Commit timed out after {max_wait} seconds") - # Wait for commit task to complete - if task_id: - now = time.time() - while True: - task = client.get_task(task_id) - if not task or task.get("status") in ("completed", "failed"): - break - time.sleep(1) - elapsed = time.time() - now - status = task.get("status", "unknown") if task else "not found" + return token_usage - client.close() + finally: + await client.close() + +async def task_status(task_id: str): + client = ov.AsyncHTTPClient(url="http://localhost:1933") + await client.initialize() + task = await client.get_task(task_id) + print(f"Task status: {task}") + +def sync_viking_ingest(msg: str) -> dict: + """Synchronous wrapper for viking_ingest to maintain existing API.""" + return asyncio.run(viking_ingest(msg)) # --------------------------------------------------------------------------- # Main import logic # --------------------------------------------------------------------------- - def parse_session_range(s: str) -> tuple[int, int]: """Parse '1-4' or '3' into (lo, hi) inclusive tuple.""" if "-" in s: @@ -241,7 +329,69 @@ def parse_session_range(s: str) -> tuple[int, int]: return n, n -def run_import(args: argparse.Namespace) -> None: +async def process_single_session( + msg: str, + sample_id: str | int, + session_key: str, + meta: Dict[str, Any], + run_time: str, + ingest_record: Dict, + args: argparse.Namespace, +) -> Dict[str, Any]: + """处理单个会话的导入任务""" + try: + token_usage = await viking_ingest(msg, args.openviking_url) + print(f" -> [SUCCESS] [{sample_id}/{session_key}] imported to OpenViking", file=sys.stderr) + + # Extract token counts + embedding_tokens = token_usage.get("embedding", 0) + vlm_tokens = token_usage.get("vlm", 0) + print(f" -> [USAGE] [{sample_id}/{session_key}] Embedding tokens: {embedding_tokens}, VLM tokens: {vlm_tokens}", file=sys.stderr) + + # Write success record + result = { + "timestamp": run_time, + "sample_id": sample_id, + "session": session_key, + "status": "success", + "meta": meta, + "token_usage": token_usage, + "embedding_tokens": embedding_tokens, + "vlm_tokens": vlm_tokens + } + + # 写入成功CSV + write_success_record(result, args.success_csv) + + # Mark as successfully ingested + mark_ingested(sample_id, session_key, ingest_record, meta) + save_ingest_record(ingest_record) # Save immediately after success + + return result + + except Exception as e: + print(f" -> [ERROR] [{sample_id}/{session_key}] {e}", file=sys.stderr) + + # Write error record + result = { + "timestamp": run_time, + "sample_id": sample_id, + "session": session_key, + "status": "error", + "error": str(e) + } + + # 写入错误日志 + write_error_record(result, args.error_log) + + return result + + +async def run_import(args: argparse.Namespace) -> None: + global semaphore + # 初始化信号量控制并发 + semaphore = asyncio.Semaphore(args.parallel) + session_range = parse_session_range(args.sessions) if args.sessions else None # Handle ingest record operations @@ -252,18 +402,21 @@ def run_import(args: argparse.Namespace) -> None: else: ingest_record = load_ingest_record() - # Open output files for incremental writing - txt_output = open(args.output, "a", encoding="utf-8") - jsonl_output = open(f"{args.output}.jsonl", "a", encoding="utf-8") + # 加载成功CSV记录用于去重 + success_keys = set() + if not args.force_ingest: + success_keys = load_success_csv(args.success_csv) + print(f"[INFO] Loaded {len(success_keys)} existing success records from {args.success_csv}", file=sys.stderr) # Write run header run_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - txt_output.write(f"\n=== Import run at {run_time} ===\n") - txt_output.flush() skipped_count = 0 success_count = 0 error_count = 0 + total_embedding_tokens = 0 + total_vlm_tokens = 0 + tasks: List[asyncio.Task] = [] if args.input.endswith(".json"): # LoCoMo JSON format @@ -278,73 +431,32 @@ def run_import(args: argparse.Namespace) -> None: for sess in sessions: meta = sess["meta"] - messages = sess["messages"] - label = f"{meta['session_key']} ({meta['date_time']})" + msg = sess["message"] + session_key = meta["session_key"] + label = f"{session_key} ({meta['date_time']})" # Skip already ingested sessions unless force-ingest is enabled - if not args.force_ingest and is_already_ingested(sample_id, meta['session_key'], ingest_record): + if not args.force_ingest and is_already_ingested(sample_id, session_key, ingest_record, success_keys): print(f" [{label}] [SKIP] already imported (use --force-ingest to reprocess)", file=sys.stderr) skipped_count += 1 - - # Write skip record - result = { - "timestamp": run_time, - "sample_id": sample_id, - "session": meta["session_key"], - "status": "skipped", - "reason": "already imported" - } - txt_output.write(f"[{sample_id}/{meta['session_key']}] SKIPPED: already imported\n") - jsonl_output.write(json.dumps(result, ensure_ascii=False) + "\n") - txt_output.flush() - jsonl_output.flush() continue - # Preview messages - preview = " | ".join([f"{msg['role']}: {msg['text'][:30]}..." for msg in messages[:3]]) - print(f" [{label}] {preview}", file=sys.stderr) - - try: - viking_ingest(messages, session_time=meta.get("date_time")) - print(f" -> [SUCCESS] imported to OpenViking", file=sys.stderr) - success_count += 1 - - # Write success record - result = { - "timestamp": run_time, - "sample_id": sample_id, - "session": meta["session_key"], - "status": "success", - "meta": meta - } - txt_output.write(f"[{sample_id}/{meta['session_key']}] SUCCESS\n") - jsonl_output.write(json.dumps(result, ensure_ascii=False) + "\n") - txt_output.flush() - jsonl_output.flush() - - # Mark as successfully ingested - mark_ingested(sample_id, meta['session_key'], ingest_record, { - "date_time": meta['date_time'], - "speakers": meta['speakers'] - }) - save_ingest_record(ingest_record) # Save immediately after success - - except Exception as e: - print(f" -> [ERROR] {e}", file=sys.stderr) - error_count += 1 - - # Write error record - result = { - "timestamp": run_time, - "sample_id": sample_id, - "session": meta["session_key"], - "status": "error", - "error": str(e) - } - txt_output.write(f"[{sample_id}/{meta['session_key']}] ERROR: {str(e)}\n") - jsonl_output.write(json.dumps(result, ensure_ascii=False) + "\n") - txt_output.flush() - jsonl_output.flush() + preview = msg.replace("\n", " | ")[:80] + print(f" [{label}] {preview}...", file=sys.stderr) + + # 创建异步任务 + task = asyncio.create_task( + process_single_session( + msg=msg, + sample_id=sample_id, + session_key=session_key, + meta=meta, + run_time=run_time, + ingest_record=ingest_record, + args=args, + ) + ) + tasks.append(task) else: # Plain text format @@ -356,80 +468,46 @@ def run_import(args: argparse.Namespace) -> None: print(f"\n=== Text Session {idx} ===", file=sys.stderr) # Skip already ingested sessions unless force-ingest is enabled - if not args.force_ingest and is_already_ingested("txt", session_key, ingest_record): + if not args.force_ingest and is_already_ingested("txt", session_key, ingest_record, success_keys): print(f" [SKIP] already imported (use --force-ingest to reprocess)", file=sys.stderr) skipped_count += 1 - - # Write skip record - result = { - "timestamp": run_time, - "sample_id": "txt", - "session": session_key, - "status": "skipped", - "reason": "already imported" - } - txt_output.write(f"[txt/{session_key}] SKIPPED: already imported\n") - jsonl_output.write(json.dumps(result, ensure_ascii=False) + "\n") - txt_output.flush() - jsonl_output.flush() continue - # For plain text, all messages as user role - messages = [] - for i, text in enumerate(session["messages"]): - messages.append({ - "role": "user", - "text": text.strip(), - "speaker": "user", - "index": i - }) - - preview = " | ".join([f"{msg['role']}: {msg['text'][:30]}..." for msg in messages[:3]]) - print(f" {preview}", file=sys.stderr) - - try: - viking_ingest(messages) - print(f" -> [SUCCESS] imported to OpenViking", file=sys.stderr) - success_count += 1 - - # Write success record - result = { - "timestamp": run_time, - "sample_id": "txt", - "session": session_key, - "status": "success", - "session_index": idx - } - txt_output.write(f"[txt/{session_key}] SUCCESS\n") - jsonl_output.write(json.dumps(result, ensure_ascii=False) + "\n") - txt_output.flush() - jsonl_output.flush() - - mark_ingested("txt", session_key, ingest_record, { - "session_index": idx - }) - save_ingest_record(ingest_record) # Save immediately after success - - except Exception as e: - print(f" -> [ERROR] {e}", file=sys.stderr) - error_count += 1 - - # Write error record - result = { - "timestamp": run_time, - "sample_id": "txt", - "session": session_key, - "status": "error", - "error": str(e) - } - txt_output.write(f"[txt/{session_key}] ERROR: {str(e)}\n") - jsonl_output.write(json.dumps(result, ensure_ascii=False) + "\n") - txt_output.flush() - jsonl_output.flush() - - # Close output files - txt_output.close() - jsonl_output.close() + combined_msg = "\n\n".join(session["messages"]) + preview = combined_msg.replace("\n", " | ")[:80] + print(f" {preview}...", file=sys.stderr) + + # 创建异步任务 + task = asyncio.create_task( + process_single_session( + msg=combined_msg, + sample_id="txt", + session_key=session_key, + meta={"session_index": idx}, + run_time=run_time, + ingest_record=ingest_record, + args=args, + ) + ) + tasks.append(task) + + # 等待所有任务完成 + print(f"\n[INFO] Starting import with {args.parallel} concurrent workers, {len(tasks)} tasks to process", file=sys.stderr) + results = await asyncio.gather(*tasks, return_exceptions=True) + + # 统计结果 + for result in results: + if isinstance(result, Exception): + error_count += 1 + print(f"[UNEXPECTED ERROR] Task failed with exception: {result}", file=sys.stderr) + continue + + if result["status"] == "success": + success_count += 1 + total_embedding_tokens += result["embedding_tokens"] + total_vlm_tokens += result["vlm_tokens"] + elif result["status"] == "error": + error_count += 1 # Final summary total_processed = success_count + error_count + skipped_count @@ -438,14 +516,21 @@ def run_import(args: argparse.Namespace) -> None: print(f"Successfully imported: {success_count}", file=sys.stderr) print(f"Failed: {error_count}", file=sys.stderr) print(f"Skipped (already imported): {skipped_count}", file=sys.stderr) - print(f"Results saved to: {args.output} (text) and {args.output}.jsonl (JSON Lines)", file=sys.stderr) + print(f"\n=== Token usage summary ===", file=sys.stderr) + print(f"Total Embedding tokens: {total_embedding_tokens}", file=sys.stderr) + print(f"Total VLM tokens: {total_vlm_tokens}", file=sys.stderr) + if success_count > 0: + print(f"Average Embedding per session: {total_embedding_tokens // success_count}", file=sys.stderr) + print(f"Average VLM per session: {total_vlm_tokens // success_count}", file=sys.stderr) + print(f"\nResults saved to:", file=sys.stderr) + print(f" - Success records: {args.success_csv}", file=sys.stderr) + print(f" - Error logs: {args.error_log}", file=sys.stderr) # --------------------------------------------------------------------------- # CLI # --------------------------------------------------------------------------- - def main(): parser = argparse.ArgumentParser(description="Import conversations into OpenViking") parser.add_argument( @@ -454,9 +539,25 @@ def main(): help="Path to input file (.txt or LoCoMo .json)" ) parser.add_argument( - "--output", - default="./result/import_results", - help="Path to output file (default: import_results)", + "--success-csv", + default="./result/import_success.csv", + help="Path to success records CSV file (default: import_success.csv)", + ) + parser.add_argument( + "--error-log", + default="./result/import_errors.log", + help="Path to error log file (default: import_errors.log)", + ) + parser.add_argument( + "--openviking-url", + default="http://localhost:1933", + help="OpenViking service URL (default: http://localhost:1933)", + ) + parser.add_argument( + "--parallel", + type=int, + default=5, + help="Number of concurrent import workers (default: 5)", ) parser.add_argument( "--sample", @@ -483,7 +584,11 @@ def main(): ) args = parser.parse_args() - run_import(args) + # 确保输出目录存在 + Path(args.success_csv).parent.mkdir(parents=True, exist_ok=True) + Path(args.error_log).parent.mkdir(parents=True, exist_ok=True) + + asyncio.run(run_import(args)) if __name__ == "__main__": From a70390c4686964da12b3f3afcb49577f29adc7da Mon Sep 17 00:00:00 2001 From: DuTao Date: Tue, 31 Mar 2026 14:39:08 +0800 Subject: [PATCH 2/2] import eval --- bot/eval/locomo/import_to_ov.py | 77 +++++++++++++++++---------------- 1 file changed, 40 insertions(+), 37 deletions(-) diff --git a/bot/eval/locomo/import_to_ov.py b/bot/eval/locomo/import_to_ov.py index 0f9db7e5..7a531d04 100644 --- a/bot/eval/locomo/import_to_ov.py +++ b/bot/eval/locomo/import_to_ov.py @@ -15,21 +15,22 @@ import asyncio import csv import json -import subprocess import sys import time from datetime import datetime from pathlib import Path -from typing import List, Dict, Any +from typing import List, Dict, Any, Tuple, Optional import openviking as ov from openviking.message.part import TextPart -# 全局信号量用于控制并发 -semaphore: asyncio.Semaphore = None +def _get_session_number(session_key: str) -> int: + """Extract session number from session key.""" + return int(session_key.split("_")[1]) -def parse_test_file(path: str) -> list[dict]: + +def parse_test_file(path: str) -> List[Dict[str, Any]]: """Parse txt test file into sessions. Each session is a dict with: @@ -57,7 +58,7 @@ def parse_test_file(path: str) -> list[dict]: return sessions -def format_locomo_message(msg: dict) -> str: +def format_locomo_message(msg: Dict[str, Any]) -> str: """Format a single LoCoMo message into a natural chat-style string. Output format: @@ -85,24 +86,23 @@ def format_locomo_message(msg: dict) -> str: def load_locomo_data( path: str, - sample_index: int | None = None, -) -> list[dict]: + sample_index: Optional[int] = None, +) -> List[Dict[str, Any]]: """Load LoCoMo JSON and optionally filter to one sample.""" with open(path, "r", encoding="utf-8") as f: data = json.load(f) if sample_index is not None: if sample_index < 0 or sample_index >= len(data): - print(f"Error: sample index {sample_index} out of range (0-{len(data)-1})", file=sys.stderr) - sys.exit(1) + raise ValueError(f"Sample index {sample_index} out of range (0-{len(data)-1})") return [data[sample_index]] return data def build_session_messages( - item: dict, - session_range: tuple[int, int] | None = None, -) -> list[dict]: + item: Dict[str, Any], + session_range: Optional[Tuple[int, int]] = None, +) -> List[Dict[str, Any]]: """Build bundled session messages for one LoCoMo sample. Returns list of dicts with keys: message, meta. @@ -112,12 +112,12 @@ def build_session_messages( session_keys = sorted( [k for k in conv if k.startswith("session_") and not k.endswith("_date_time")], - key=lambda k: int(k.split("_")[1]), + key=_get_session_number, ) sessions = [] for sk in session_keys: - sess_num = int(sk.split("_")[1]) + sess_num = _get_session_number(sk) if session_range: lo, hi = session_range if sess_num < lo or sess_num > hi: @@ -196,7 +196,7 @@ def write_error_record(record: Dict[str, Any], error_path: str = "import_errors. f.write(f"[{timestamp}] ERROR [{sample_id}/{session}]: {error}\n") -def load_ingest_record(record_path: str = "./result/.ingest_record.json") -> dict: +def load_ingest_record(record_path: str = "./result/.ingest_record.json") -> Dict[str, Any]: """Load existing ingest record file, return empty dict if not exists.""" try: with open(record_path, "r", encoding="utf-8") as f: @@ -205,7 +205,7 @@ def load_ingest_record(record_path: str = "./result/.ingest_record.json") -> dic return {} -def save_ingest_record(record: dict, record_path: str = "./result/.ingest_record.json") -> None: +def save_ingest_record(record: Dict[str, Any], record_path: str = "./result/.ingest_record.json") -> None: """Save ingest record to file.""" with open(record_path, "w", encoding="utf-8") as f: json.dump(record, f, indent=2, ensure_ascii=False) @@ -214,8 +214,8 @@ def save_ingest_record(record: dict, record_path: str = "./result/.ingest_record def is_already_ingested( sample_id: str | int, session_key: str, - record: dict, - success_keys: set = None, + record: Dict[str, Any], + success_keys: Optional[set] = None, ) -> bool: """Check if a specific session has already been successfully ingested.""" key = f"viking:{sample_id}:{session_key}" @@ -227,8 +227,8 @@ def is_already_ingested( def mark_ingested( sample_id: str | int, session_key: str, - record: dict, - meta: dict | None = None, + record: Dict[str, Any], + meta: Optional[Dict[str, Any]] = None, ) -> None: """Mark a session as successfully ingested.""" key = f"viking:{sample_id}:{session_key}" @@ -242,7 +242,7 @@ def mark_ingested( # --------------------------------------------------------------------------- # OpenViking import # --------------------------------------------------------------------------- -def _parse_token_usage(token_data: dict) -> dict: +def _parse_token_usage(token_data: Dict[str, Any]) -> Dict[str, int]: """解析Token使用数据(仅支持新版token_usage格式)""" usage = token_data["token_usage"] return { @@ -254,7 +254,7 @@ def _parse_token_usage(token_data: dict) -> dict: } -async def viking_ingest(msg: str, openviking_url: str) -> dict: +async def viking_ingest(msg: str, openviking_url: str, semaphore: asyncio.Semaphore) -> Dict[str, int]: """Save a message to OpenViking via OpenViking SDK client. Returns token usage dict with embedding and vlm token counts. """ @@ -295,7 +295,8 @@ async def viking_ingest(msg: str, openviking_url: str) -> dict: elif task["status"] == "failed": raise RuntimeError(f"Commit failed: {task.get('error', 'Unknown error')}") - await asyncio.sleep(1) + # 指数退避策略,避免频繁请求 + await asyncio.sleep(min(1 << (waited // 10), 60)) waited += 1 else: raise RuntimeError(f"Commit timed out after {max_wait} seconds") @@ -305,22 +306,18 @@ async def viking_ingest(msg: str, openviking_url: str) -> dict: finally: await client.close() -async def task_status(task_id: str): - client = ov.AsyncHTTPClient(url="http://localhost:1933") - await client.initialize() - task = await client.get_task(task_id) - print(f"Task status: {task}") -def sync_viking_ingest(msg: str) -> dict: +def sync_viking_ingest(msg: str, openviking_url: str) -> Dict[str, int]: """Synchronous wrapper for viking_ingest to maintain existing API.""" - return asyncio.run(viking_ingest(msg)) + semaphore = asyncio.Semaphore(1) # 同步调用时使用信号量为1 + return asyncio.run(viking_ingest(msg, openviking_url, semaphore)) # --------------------------------------------------------------------------- # Main import logic # --------------------------------------------------------------------------- -def parse_session_range(s: str) -> tuple[int, int]: +def parse_session_range(s: str) -> Tuple[int, int]: """Parse '1-4' or '3' into (lo, hi) inclusive tuple.""" if "-" in s: lo, hi = s.split("-", 1) @@ -335,12 +332,13 @@ async def process_single_session( session_key: str, meta: Dict[str, Any], run_time: str, - ingest_record: Dict, + ingest_record: Dict[str, Any], args: argparse.Namespace, + semaphore: asyncio.Semaphore ) -> Dict[str, Any]: """处理单个会话的导入任务""" try: - token_usage = await viking_ingest(msg, args.openviking_url) + token_usage = await viking_ingest(msg, args.openviking_url, semaphore) print(f" -> [SUCCESS] [{sample_id}/{session_key}] imported to OpenViking", file=sys.stderr) # Extract token counts @@ -388,7 +386,6 @@ async def process_single_session( async def run_import(args: argparse.Namespace) -> None: - global semaphore # 初始化信号量控制并发 semaphore = asyncio.Semaphore(args.parallel) @@ -454,6 +451,7 @@ async def run_import(args: argparse.Namespace) -> None: run_time=run_time, ingest_record=ingest_record, args=args, + semaphore=semaphore ) ) tasks.append(task) @@ -487,6 +485,7 @@ async def run_import(args: argparse.Namespace) -> None: run_time=run_time, ingest_record=ingest_record, args=args, + semaphore=semaphore ) ) tasks.append(task) @@ -588,8 +587,12 @@ def main(): Path(args.success_csv).parent.mkdir(parents=True, exist_ok=True) Path(args.error_log).parent.mkdir(parents=True, exist_ok=True) - asyncio.run(run_import(args)) + try: + asyncio.run(run_import(args)) + except ValueError as e: + print(f"Error: {e}", file=sys.stderr) + sys.exit(1) if __name__ == "__main__": - main() + main() \ No newline at end of file