diff --git a/.gitignore b/.gitignore index f149cd90d..e8bddfc20 100644 --- a/.gitignore +++ b/.gitignore @@ -141,6 +141,7 @@ logs/ *.temp .tmp/ ov.conf +result/ # Jupyter Notebook .ipynb_checkpoints diff --git a/.ingest_record.json b/.ingest_record.json new file mode 100644 index 000000000..04e4753dd --- /dev/null +++ b/.ingest_record.json @@ -0,0 +1,890 @@ +{ + "viking:conv-26:session_1": { + "success": true, + "timestamp": 1774706688, + "meta": { + "date_time": "1:56 pm on 8 May, 2023", + "speakers": "Caroline & Melanie" + } + }, + "viking:conv-26:session_2": { + "success": true, + "timestamp": 1774706715, + "meta": { + "date_time": "1:14 pm on 25 May, 2023", + "speakers": "Caroline & Melanie" + } + }, + "viking:conv-26:session_3": { + "success": true, + "timestamp": 1774706739, + "meta": { + "date_time": "7:55 pm on 9 June, 2023", + "speakers": "Caroline & Melanie" + } + }, + "viking:conv-26:session_4": { + "success": true, + "timestamp": 1774706761, + "meta": { + "date_time": "10:37 am on 27 June, 2023", + "speakers": "Caroline & Melanie" + } + }, + "viking:conv-26:session_5": { + "success": true, + "timestamp": 1774706784, + "meta": { + "date_time": "1:36 pm on 3 July, 2023", + "speakers": "Caroline & Melanie" + } + }, + "viking:conv-26:session_6": { + "success": true, + "timestamp": 1774706808, + "meta": { + "date_time": "8:18 pm on 6 July, 2023", + "speakers": "Caroline & Melanie" + } + }, + "viking:conv-26:session_7": { + "success": true, + "timestamp": 1774706834, + "meta": { + "date_time": "4:33 pm on 12 July, 2023", + "speakers": "Caroline & Melanie" + } + }, + "viking:conv-26:session_8": { + "success": true, + "timestamp": 1774706855, + "meta": { + "date_time": "1:51 pm on 15 July, 2023", + "speakers": "Caroline & Melanie" + } + }, + "viking:conv-26:session_9": { + "success": true, + "timestamp": 1774706879, + "meta": { + "date_time": "2:31 pm on 17 July, 2023", + "speakers": "Caroline & Melanie" + } + }, + "viking:conv-26:session_10": { + "success": true, + "timestamp": 1774706903, + "meta": { + "date_time": "8:56 pm on 20 July, 2023", + "speakers": "Caroline & Melanie" + } + }, + "viking:conv-26:session_11": { + "success": true, + "timestamp": 1774706923, + "meta": { + "date_time": "2:24 pm on 14 August, 2023", + "speakers": "Caroline & Melanie" + } + }, + "viking:conv-26:session_12": { + "success": true, + "timestamp": 1774706941, + "meta": { + "date_time": "1:50 pm on 17 August, 2023", + "speakers": "Caroline & Melanie" + } + }, + "viking:conv-26:session_13": { + "success": true, + "timestamp": 1774706962, + "meta": { + "date_time": "3:31 pm on 23 August, 2023", + "speakers": "Caroline & Melanie" + } + }, + "viking:conv-26:session_14": { + "success": true, + "timestamp": 1774706983, + "meta": { + "date_time": "1:33 pm on 25 August, 2023", + "speakers": "Caroline & Melanie" + } + }, + "viking:conv-26:session_15": { + "success": true, + "timestamp": 1774707002, + "meta": { + "date_time": "3:19 pm on 28 August, 2023", + "speakers": "Caroline & Melanie" + } + }, + "viking:conv-26:session_16": { + "success": true, + "timestamp": 1774707021, + "meta": { + "date_time": "12:09 am on 13 September, 2023", + "speakers": "Caroline & Melanie" + } + }, + "viking:conv-26:session_17": { + "success": true, + "timestamp": 1774707033, + "meta": { + "date_time": "10:31 am on 13 October, 2023", + "speakers": "Caroline & Melanie" + } + }, + "viking:conv-26:session_18": { + "success": true, + "timestamp": 1774707054, + "meta": { + "date_time": "6:55 pm on 20 October, 2023", + "speakers": "Caroline & Melanie" + } + }, + "viking:conv-26:session_19": { + "success": true, + "timestamp": 1774707075, + "meta": { + "date_time": "9:55 am on 22 October, 2023", + "speakers": "Caroline & Melanie" + } + }, + "viking:conv-30:session_1": { + "success": true, + "timestamp": 1774707088, + "meta": { + "date_time": "4:04 pm on 20 January, 2023", + "speakers": "Jon & Gina" + } + }, + "viking:conv-30:session_2": { + "success": true, + "timestamp": 1774707112, + "meta": { + "date_time": "2:32 pm on 29 January, 2023", + "speakers": "Jon & Gina" + } + }, + "viking:conv-30:session_3": { + "success": true, + "timestamp": 1774707134, + "meta": { + "date_time": "12:48 am on 1 February, 2023", + "speakers": "Jon & Gina" + } + }, + "viking:conv-30:session_4": { + "success": true, + "timestamp": 1774707142, + "meta": { + "date_time": "10:43 am on 4 February, 2023", + "speakers": "Jon & Gina" + } + }, + "viking:conv-30:session_5": { + "success": true, + "timestamp": 1774707157, + "meta": { + "date_time": "9:32 am on 8 February, 2023", + "speakers": "Jon & Gina" + } + }, + "viking:conv-30:session_6": { + "success": true, + "timestamp": 1774707169, + "meta": { + "date_time": "2:35 pm on 16 March, 2023", + "speakers": "Jon & Gina" + } + }, + "viking:conv-30:session_7": { + "success": true, + "timestamp": 1774707184, + "meta": { + "date_time": "7:28 pm on 23 March, 2023", + "speakers": "Jon & Gina" + } + }, + "viking:conv-30:session_8": { + "success": true, + "timestamp": 1774707200, + "meta": { + "date_time": "1:26 pm on 3 April, 2023", + "speakers": "Jon & Gina" + } + }, + "viking:conv-30:session_9": { + "success": true, + "timestamp": 1774707218, + "meta": { + "date_time": "10:33 am on 9 April, 2023", + "speakers": "Jon & Gina" + } + }, + "viking:conv-30:session_10": { + "success": true, + "timestamp": 1774707228, + "meta": { + "date_time": "11:24 am on 25 April, 2023", + "speakers": "Jon & Gina" + } + }, + "viking:conv-30:session_11": { + "success": true, + "timestamp": 1774707242, + "meta": { + "date_time": "3:14 pm on 11 May, 2023", + "speakers": "Jon & Gina" + } + }, + "viking:conv-30:session_12": { + "success": true, + "timestamp": 1774707255, + "meta": { + "date_time": "7:18 pm on 27 May, 2023", + "speakers": "Jon & Gina" + } + }, + "viking:conv-30:session_13": { + "success": true, + "timestamp": 1774707263, + "meta": { + "date_time": "8:29 pm on 13 June, 2023", + "speakers": "Jon & Gina" + } + }, + "viking:conv-30:session_14": { + "success": true, + "timestamp": 1774707278, + "meta": { + "date_time": "9:38 pm on 16 June, 2023", + "speakers": "Jon & Gina" + } + }, + "viking:conv-30:session_15": { + "success": true, + "timestamp": 1774707290, + "meta": { + "date_time": "10:04 am on 19 June, 2023", + "speakers": "Jon & Gina" + } + }, + "viking:conv-30:session_16": { + "success": true, + "timestamp": 1774707303, + "meta": { + "date_time": "2:15 pm on 21 June, 2023", + "speakers": "Jon & Gina" + } + }, + "viking:conv-30:session_17": { + "success": true, + "timestamp": 1774707314, + "meta": { + "date_time": "1:25 pm on 9 July, 2023", + "speakers": "Jon & Gina" + } + }, + "viking:conv-30:session_18": { + "success": true, + "timestamp": 1774707326, + "meta": { + "date_time": "5:44 pm on 21 July, 2023", + "speakers": "Jon & Gina" + } + }, + "viking:conv-30:session_19": { + "success": true, + "timestamp": 1774700596, + "meta": { + "date_time": "6:46 pm on 23 July, 2023", + "speakers": "Jon & Gina" + } + }, + "viking:conv-41:session_1": { + "success": true, + "timestamp": 1774700612, + "meta": { + "date_time": "11:01 am on 17 December, 2022", + "speakers": "John & Maria" + } + }, + "viking:conv-41:session_2": { + "success": true, + "timestamp": 1774700629, + "meta": { + "date_time": "6:10 pm on 22 December, 2022", + "speakers": "John & Maria" + } + }, + "viking:conv-41:session_3": { + "success": true, + "timestamp": 1774700649, + "meta": { + "date_time": "8:30 pm on 1 January, 2023", + "speakers": "John & Maria" + } + }, + "viking:conv-41:session_4": { + "success": true, + "timestamp": 1774700665, + "meta": { + "date_time": "7:06 pm on 9 January, 2023", + "speakers": "John & Maria" + } + }, + "viking:conv-41:session_5": { + "success": true, + "timestamp": 1774700681, + "meta": { + "date_time": "1:17 pm on 28 January, 2023", + "speakers": "John & Maria" + } + }, + "viking:conv-41:session_6": { + "success": true, + "timestamp": 1774700697, + "meta": { + "date_time": "2:33 pm on 5 February, 2023", + "speakers": "John & Maria" + } + }, + "viking:conv-41:session_7": { + "success": true, + "timestamp": 1774700711, + "meta": { + "date_time": "8:55 pm on 25 February, 2023", + "speakers": "John & Maria" + } + }, + "viking:conv-41:session_8": { + "success": true, + "timestamp": 1774700732, + "meta": { + "date_time": "6:03 pm on 6 March, 2023", + "speakers": "John & Maria" + } + }, + "viking:conv-41:session_9": { + "success": true, + "timestamp": 1774700746, + "meta": { + "date_time": "9:36 am on 2 April, 2023", + "speakers": "John & Maria" + } + }, + "viking:conv-41:session_10": { + "success": true, + "timestamp": 1774700763, + "meta": { + "date_time": "12:24 am on 7 April, 2023", + "speakers": "John & Maria" + } + }, + "viking:conv-41:session_11": { + "success": true, + "timestamp": 1774700780, + "meta": { + "date_time": "6:13 pm on 10 April, 2023", + "speakers": "John & Maria" + } + }, + "viking:conv-41:session_12": { + "success": true, + "timestamp": 1774700797, + "meta": { + "date_time": "7:34 pm on 18 April, 2023", + "speakers": "John & Maria" + } + }, + "viking:conv-41:session_13": { + "success": true, + "timestamp": 1774700813, + "meta": { + "date_time": "3:18 pm on 4 May, 2023", + "speakers": "John & Maria" + } + }, + "viking:conv-41:session_14": { + "success": true, + "timestamp": 1774700828, + "meta": { + "date_time": "5:04 pm on 6 May, 2023", + "speakers": "John & Maria" + } + }, + "viking:conv-41:session_15": { + "success": true, + "timestamp": 1774700844, + "meta": { + "date_time": "7:38 pm on 20 May, 2023", + "speakers": "John & Maria" + } + }, + "viking:conv-41:session_16": { + "success": true, + "timestamp": 1774700859, + "meta": { + "date_time": "1:24 pm on 25 May, 2023", + "speakers": "John & Maria" + } + }, + "viking:conv-41:session_17": { + "success": true, + "timestamp": 1774700879, + "meta": { + "date_time": "11:51 am on 3 June, 2023", + "speakers": "John & Maria" + } + }, + "viking:conv-41:session_18": { + "success": true, + "timestamp": 1774700896, + "meta": { + "date_time": "2:47 pm on 12 June, 2023", + "speakers": "John & Maria" + } + }, + "viking:conv-41:session_19": { + "success": true, + "timestamp": 1774700913, + "meta": { + "date_time": "7:20 pm on 16 June, 2023", + "speakers": "John & Maria" + } + }, + "viking:conv-41:session_20": { + "success": true, + "timestamp": 1774700929, + "meta": { + "date_time": "12:21 am on 27 June, 2023", + "speakers": "John & Maria" + } + }, + "viking:conv-41:session_21": { + "success": true, + "timestamp": 1774700946, + "meta": { + "date_time": "8:43 pm on 3 July, 2023", + "speakers": "John & Maria" + } + }, + "viking:conv-41:session_22": { + "success": true, + "timestamp": 1774700965, + "meta": { + "date_time": "6:59 pm on 5 July, 2023", + "speakers": "John & Maria" + } + }, + "viking:conv-41:session_23": { + "success": true, + "timestamp": 1774700979, + "meta": { + "date_time": "6:29 pm on 7 July, 2023", + "speakers": "John & Maria" + } + }, + "viking:conv-41:session_24": { + "success": true, + "timestamp": 1774700993, + "meta": { + "date_time": "3:34 pm on 17 July, 2023", + "speakers": "John & Maria" + } + }, + "viking:conv-41:session_25": { + "success": true, + "timestamp": 1774701015, + "meta": { + "date_time": "6:21 pm on 22 July, 2023", + "speakers": "John & Maria" + } + }, + "viking:conv-41:session_26": { + "success": true, + "timestamp": 1774701033, + "meta": { + "date_time": "1:59 pm on 31 July, 2023", + "speakers": "John & Maria" + } + }, + "viking:conv-41:session_27": { + "success": true, + "timestamp": 1774701047, + "meta": { + "date_time": "6:20 pm on 3 August, 2023", + "speakers": "John & Maria" + } + }, + "viking:conv-41:session_28": { + "success": true, + "timestamp": 1774701065, + "meta": { + "date_time": "5:19 pm on 5 August, 2023", + "speakers": "John & Maria" + } + }, + "viking:conv-41:session_29": { + "success": true, + "timestamp": 1774701082, + "meta": { + "date_time": "8:06 pm on 9 August, 2023", + "speakers": "John & Maria" + } + }, + "viking:conv-41:session_30": { + "success": true, + "timestamp": 1774701098, + "meta": { + "date_time": "12:10 am on 11 August, 2023", + "speakers": "John & Maria" + } + }, + "viking:conv-41:session_31": { + "success": true, + "timestamp": 1774701118, + "meta": { + "date_time": "3:14 pm on 13 August, 2023", + "speakers": "John & Maria" + } + }, + "viking:conv-41:session_32": { + "success": true, + "timestamp": 1774701134, + "meta": { + "date_time": "11:08 am on 16 August, 2023", + "speakers": "John & Maria" + } + }, + "viking:conv-42:session_1": { + "success": true, + "timestamp": 1774701152, + "meta": { + "date_time": "7:31 pm on 21 January, 2022", + "speakers": "Joanna & Nate" + } + }, + "viking:conv-42:session_2": { + "success": true, + "timestamp": 1774701172, + "meta": { + "date_time": "2:01 pm on 23 January, 2022", + "speakers": "Joanna & Nate" + } + }, + "viking:conv-42:session_3": { + "success": true, + "timestamp": 1774701199, + "meta": { + "date_time": "9:27 am on 7 February, 2022", + "speakers": "Joanna & Nate" + } + }, + "viking:conv-42:session_4": { + "success": true, + "timestamp": 1774701218, + "meta": { + "date_time": "1:07 pm on 25 February, 2022", + "speakers": "Joanna & Nate" + } + }, + "viking:conv-42:session_5": { + "success": true, + "timestamp": 1774701238, + "meta": { + "date_time": "6:59 pm on 18 March, 2022", + "speakers": "Joanna & Nate" + } + }, + "viking:conv-42:session_6": { + "success": true, + "timestamp": 1774701260, + "meta": { + "date_time": "1:43 pm on 24 March, 2022", + "speakers": "Joanna & Nate" + } + }, + "viking:conv-42:session_7": { + "success": true, + "timestamp": 1774701281, + "meta": { + "date_time": "7:37 pm on 15 April, 2022", + "speakers": "Joanna & Nate" + } + }, + "viking:conv-42:session_8": { + "success": true, + "timestamp": 1774701303, + "meta": { + "date_time": "6:44 pm on 17 April, 2022", + "speakers": "Joanna & Nate" + } + }, + "viking:conv-42:session_9": { + "success": true, + "timestamp": 1774701325, + "meta": { + "date_time": "7:44 pm on 21 April, 2022", + "speakers": "Joanna & Nate" + } + }, + "viking:conv-42:session_10": { + "success": true, + "timestamp": 1774701341, + "meta": { + "date_time": "11:54 am on 2 May, 2022", + "speakers": "Joanna & Nate" + } + }, + "viking:conv-42:session_11": { + "success": true, + "timestamp": 1774701365, + "meta": { + "date_time": "3:35 pm on 12 May, 2022", + "speakers": "Joanna & Nate" + } + }, + "viking:conv-42:session_12": { + "success": true, + "timestamp": 1774701386, + "meta": { + "date_time": "7:49 pm on 20 May, 2022", + "speakers": "Joanna & Nate" + } + }, + "viking:conv-42:session_13": { + "success": true, + "timestamp": 1774701403, + "meta": { + "date_time": "3:00 pm on 25 May, 2022", + "speakers": "Joanna & Nate" + } + }, + "viking:conv-42:session_14": { + "success": true, + "timestamp": 1774701427, + "meta": { + "date_time": "5:44 pm on 3 June, 2022", + "speakers": "Joanna & Nate" + } + }, + "viking:conv-42:session_15": { + "success": true, + "timestamp": 1774701448, + "meta": { + "date_time": "2:12 pm on 5 June, 2022", + "speakers": "Joanna & Nate" + } + }, + "viking:conv-42:session_16": { + "success": true, + "timestamp": 1774701466, + "meta": { + "date_time": "10:55 am on 24 June, 2022", + "speakers": "Joanna & Nate" + } + }, + "viking:conv-42:session_17": { + "success": true, + "timestamp": 1774701484, + "meta": { + "date_time": "2:34 pm on 10 July, 2022", + "speakers": "Joanna & Nate" + } + }, + "viking:conv-42:session_18": { + "success": true, + "timestamp": 1774701506, + "meta": { + "date_time": "6:12 pm on 14 August, 2022", + "speakers": "Joanna & Nate" + } + }, + "viking:conv-42:session_19": { + "success": true, + "timestamp": 1774701523, + "meta": { + "date_time": "10:57 am on 22 August, 2022", + "speakers": "Joanna & Nate" + } + }, + "viking:conv-42:session_20": { + "success": true, + "timestamp": 1774701545, + "meta": { + "date_time": "6:03 pm on 5 September, 2022", + "speakers": "Joanna & Nate" + } + }, + "viking:conv-42:session_21": { + "success": true, + "timestamp": 1774701560, + "meta": { + "date_time": "1:43 pm on 14 September, 2022", + "speakers": "Joanna & Nate" + } + }, + "viking:conv-42:session_22": { + "success": true, + "timestamp": 1774701582, + "meta": { + "date_time": "11:15 am on 6 October, 2022", + "speakers": "Joanna & Nate" + } + }, + "viking:conv-42:session_23": { + "success": true, + "timestamp": 1774701604, + "meta": { + "date_time": "10:58 am on 9 October, 2022", + "speakers": "Joanna & Nate" + } + }, + "viking:conv-42:session_24": { + "success": true, + "timestamp": 1774701629, + "meta": { + "date_time": "2:01 pm on 21 October, 2022", + "speakers": "Joanna & Nate" + } + }, + "viking:conv-42:session_25": { + "success": true, + "timestamp": 1774701653, + "meta": { + "date_time": "8:16 pm on 25 October, 2022", + "speakers": "Joanna & Nate" + } + }, + "viking:conv-42:session_26": { + "success": true, + "timestamp": 1774701672, + "meta": { + "date_time": "3:56 pm on 4 November, 2022", + "speakers": "Joanna & Nate" + } + }, + "viking:conv-42:session_27": { + "success": true, + "timestamp": 1774701692, + "meta": { + "date_time": "8:10 pm on 7 November, 2022", + "speakers": "Joanna & Nate" + } + }, + "viking:conv-42:session_28": { + "success": true, + "timestamp": 1774701720, + "meta": { + "date_time": "5:54 pm on 9 November, 2022", + "speakers": "Joanna & Nate" + } + }, + "viking:conv-42:session_29": { + "success": true, + "timestamp": 1774701741, + "meta": { + "date_time": "12:06 am on 11 November, 2022", + "speakers": "Joanna & Nate" + } + }, + "viking:conv-43:session_1": { + "success": true, + "timestamp": 1774701771, + "meta": { + "date_time": "7:48 pm on 21 May, 2023", + "speakers": "Tim & John" + } + }, + "viking:conv-43:session_2": { + "success": true, + "timestamp": 1774701796, + "meta": { + "date_time": "5:08 pm on 15 June, 2023", + "speakers": "Tim & John" + } + }, + "viking:conv-43:session_3": { + "success": true, + "timestamp": 1774701820, + "meta": { + "date_time": "4:21 pm on 16 July, 2023", + "speakers": "Tim & John" + } + }, + "viking:conv-43:session_4": { + "success": true, + "timestamp": 1774701840, + "meta": { + "date_time": "4:17 pm on 2 August, 2023", + "speakers": "Tim & John" + } + }, + "viking:conv-43:session_5": { + "success": true, + "timestamp": 1774701865, + "meta": { + "date_time": "10:29 am on 9 August, 2023", + "speakers": "Tim & John" + } + }, + "viking:conv-43:session_6": { + "success": true, + "timestamp": 1774701884, + "meta": { + "date_time": "1:08 pm on 11 August, 2023", + "speakers": "Tim & John" + } + }, + "viking:conv-43:session_7": { + "success": true, + "timestamp": 1774701908, + "meta": { + "date_time": "7:54 pm on 17 August, 2023", + "speakers": "Tim & John" + } + }, + "viking:conv-43:session_8": { + "success": true, + "timestamp": 1774701934, + "meta": { + "date_time": "4:29 pm on 21 August, 2023", + "speakers": "Tim & John" + } + }, + "viking:conv-43:session_9": { + "success": true, + "timestamp": 1774701961, + "meta": { + "date_time": "6:59 pm on 26 August, 2023", + "speakers": "Tim & John" + } + }, + "viking:conv-43:session_10": { + "success": true, + "timestamp": 1774701982, + "meta": { + "date_time": "2:52 pm on 31 August, 2023", + "speakers": "Tim & John" + } + }, + "viking:conv-43:session_11": { + "success": true, + "timestamp": 1774701999, + "meta": { + "date_time": "8:17 pm on 21 September, 2023", + "speakers": "Tim & John" + } + }, + "viking:conv-43:session_12": { + "success": true, + "timestamp": 1774702021, + "meta": { + "date_time": "3:00 pm on 2 October, 2023", + "speakers": "Tim & John" + } + } +} \ No newline at end of file diff --git a/bot/eval/locomo/import_to_ov.py b/bot/eval/locomo/import_to_ov.py index efe9c3e05..c44923362 100644 --- a/bot/eval/locomo/import_to_ov.py +++ b/bot/eval/locomo/import_to_ov.py @@ -13,11 +13,12 @@ import argparse import json -import subprocess import sys import time from datetime import datetime +import openviking as ov + def parse_test_file(path: str) -> list[dict]: """Parse txt test file into sessions. @@ -47,30 +48,17 @@ def parse_test_file(path: str) -> list[dict]: return sessions -def format_locomo_message(msg: dict) -> str: - """Format a single LoCoMo message into a natural chat-style string. +def format_locomo_message(msg: dict, index: int | None = None) -> str: + """Format a single LoCoMo message into chat-style string. Output format: - Speaker: text here - image_url: caption + [index][Speaker]: text here """ speaker = msg.get("speaker", "unknown") text = msg.get("text", "") - line = f"{speaker}: {text}" - - img_urls = msg.get("img_url", []) - if isinstance(img_urls, str): - img_urls = [img_urls] - blip = msg.get("blip_caption", "") - - if img_urls: - for url in img_urls: - caption = f": {blip}" if blip else "" - line += f"\n{url}{caption}" - elif blip: - line += f"\n({blip})" - - return line + if index is not None: + return f"[{index}][{speaker}]: {text}" + return f"[{speaker}]: {text}" def load_locomo_data( @@ -93,9 +81,10 @@ def build_session_messages( item: dict, session_range: tuple[int, int] | None = None, ) -> list[dict]: - """Build bundled session messages for one LoCoMo sample. + """Build session messages for one LoCoMo sample. - Returns list of dicts with keys: message, meta. + Returns list of dicts with keys: messages, meta. + Each dict represents a session with multiple messages (user/assistant role). """ conv = item["conversation"] speakers = f"{conv['speaker_a']} & {conv['speaker_b']}" @@ -116,13 +105,20 @@ def build_session_messages( dt_key = f"{sk}_date_time" date_time = conv.get(dt_key, "") - parts = [f"[group chat conversation: {date_time}]"] - for msg in conv[sk]: - parts.append(format_locomo_message(msg)) - combined = "\n\n".join(parts) + # Extract messages with all as user role, including speaker in content + messages = [] + for idx, msg in enumerate(conv[sk]): + speaker = msg.get("speaker", "unknown") + text = msg.get("text", "") + messages.append({ + "role": "user", + "text": f"[{speaker}]: {text}", + "speaker": speaker, + "index": idx + }) sessions.append({ - "message": combined, + "messages": messages, "meta": { "sample_id": item["sample_id"], "session_key": sk, @@ -138,7 +134,8 @@ def build_session_messages( # Ingest record helpers (avoid duplicate ingestion) # --------------------------------------------------------------------------- -def load_ingest_record(record_path: str = ".ingest_record.json") -> dict: + +def load_ingest_record(record_path: str = "result/ingest_record.json") -> dict: """Load existing ingest record file, return empty dict if not exists.""" try: with open(record_path, "r", encoding="utf-8") as f: @@ -147,7 +144,7 @@ def load_ingest_record(record_path: str = ".ingest_record.json") -> dict: return {} -def save_ingest_record(record: dict, record_path: str = ".ingest_record.json") -> None: +def save_ingest_record(record: dict, record_path: str = "result/ingest_record.json") -> None: """Save ingest record to file.""" with open(record_path, "w", encoding="utf-8") as f: json.dump(record, f, indent=2, ensure_ascii=False) @@ -182,21 +179,59 @@ def mark_ingested( # OpenViking import # --------------------------------------------------------------------------- -def viking_ingest(msg: str) -> None: - """Save a message to OpenViking via `ov add-memory`.""" - result = subprocess.run( - ["ov", "add-memory", msg], - capture_output=True, - text=True, - ) - if result.returncode != 0: - raise RuntimeError(result.stderr.strip() or f"ov exited with code {result.returncode}") + +def viking_ingest(messages: list[dict], session_time: str = None) -> None: + """Save messages to OpenViking via SyncHTTPClient (add messages + commit session). + + Args: + messages: List of message dicts with role and text + session_time: Session time string (e.g., "9:36 am on 2 April, 2023") + """ + from datetime import datetime + + # 解析 session_time + created_at = None + if session_time: + try: + dt = datetime.strptime(session_time, "%I:%M %p on %d %B, %Y") + created_at = dt.isoformat() + except ValueError: + print(f"Warning: Failed to parse session_time: {session_time}", file=sys.stderr) + + client = ov.SyncHTTPClient() + client.initialize() + + # Create new session + session_result = client.create_session() + session_id = session_result.get('session_id') + + # Add messages one by one + for msg in messages: + client.add_message(session_id, role=msg["role"], content=msg["text"], created_at=created_at) + + # Commit session to trigger memory extraction + commit_result = client.commit_session(session_id) + task_id = commit_result.get("task_id") + + # Wait for commit task to complete + if task_id: + now = time.time() + while True: + task = client.get_task(task_id) + if not task or task.get("status") in ("completed", "failed"): + break + time.sleep(1) + elapsed = time.time() - now + status = task.get("status", "unknown") if task else "not found" + + client.close() # --------------------------------------------------------------------------- # Main import logic # --------------------------------------------------------------------------- + def parse_session_range(s: str) -> tuple[int, int]: """Parse '1-4' or '3' into (lo, hi) inclusive tuple.""" if "-" in s: @@ -243,7 +278,7 @@ def run_import(args: argparse.Namespace) -> None: for sess in sessions: meta = sess["meta"] - msg = sess["message"] + messages = sess["messages"] label = f"{meta['session_key']} ({meta['date_time']})" # Skip already ingested sessions unless force-ingest is enabled @@ -265,11 +300,12 @@ def run_import(args: argparse.Namespace) -> None: jsonl_output.flush() continue - preview = msg.replace("\n", " | ")[:80] - print(f" [{label}] {preview}...", file=sys.stderr) + # Preview messages + preview = " | ".join([f"{msg['role']}: {msg['text'][:30]}..." for msg in messages[:3]]) + print(f" [{label}] {preview}", file=sys.stderr) try: - viking_ingest(msg) + viking_ingest(messages, session_time=meta.get("date_time")) print(f" -> [SUCCESS] imported to OpenViking", file=sys.stderr) success_count += 1 @@ -338,12 +374,21 @@ def run_import(args: argparse.Namespace) -> None: jsonl_output.flush() continue - combined_msg = "\n\n".join(session["messages"]) - preview = combined_msg.replace("\n", " | ")[:80] - print(f" {preview}...", file=sys.stderr) + # For plain text, all messages as user role + messages = [] + for i, text in enumerate(session["messages"]): + messages.append({ + "role": "user", + "text": text.strip(), + "speaker": "user", + "index": i + }) + + preview = " | ".join([f"{msg['role']}: {msg['text'][:30]}..." for msg in messages[:3]]) + print(f" {preview}", file=sys.stderr) try: - viking_ingest(combined_msg) + viking_ingest(messages) print(f" -> [SUCCESS] imported to OpenViking", file=sys.stderr) success_count += 1 @@ -400,6 +445,7 @@ def run_import(args: argparse.Namespace) -> None: # CLI # --------------------------------------------------------------------------- + def main(): parser = argparse.ArgumentParser(description="Import conversations into OpenViking") parser.add_argument( diff --git a/bot/eval/locomo/run_full_eval.sh b/bot/eval/locomo/run_full_eval.sh new file mode 100755 index 000000000..72d58f739 --- /dev/null +++ b/bot/eval/locomo/run_full_eval.sh @@ -0,0 +1,30 @@ +#!/bin/bash + +set -e + +# Step 1: 导入数据 +echo "[1/4] 导入数据..." +python bot/eval/locomo/import_to_ov.py --input ~/.test_data/locomo10.json --force-ingest + +echo "等待 3 分钟..." +sleep 180 + +# Step 2: 评估 +echo "[2/4] 评估..." +python bot/eval/locomo/run_eval.py ~/.test_data/locomo_qa_1528.csv --output ./result/locomo_result_multi_read_all.csv --threads 20 + +echo "等待 3 分钟..." +sleep 180 + +# Step 3: 裁判打分 +echo "[3/4] 裁判打分..." +python bot/eval/locomo/judge.py --token 0a2b68f6-4df3-48f5-81b9-f85fe0af9cef --input ./result/locomo_result_multi_read_all.csv --parallel 10 + +echo "等待 3 分钟..." +sleep 180 + +# Step 4: 计算结果 +echo "[4/4] 计算结果..." +python bot/eval/locomo/stat_judge_result.py --input ./result/locomo_result_multi_read_all.csv + +echo "完成!" \ No newline at end of file diff --git a/bot/scripts/kill_openviking_server.sh b/bot/scripts/kill_openviking_server.sh new file mode 100755 index 000000000..0900dd14b --- /dev/null +++ b/bot/scripts/kill_openviking_server.sh @@ -0,0 +1,44 @@ +#!/bin/bash + +# Kill OpenViking Server and vikingbot processes +# Usage: ./kill_openviking_server.sh + +set -e + +echo "==========================================" +echo "Stopping OpenViking processes" +echo "==========================================" + +# Kill existing vikingbot processes +echo "" +echo "Step 1: Stopping vikingbot processes..." +if pgrep -f "vikingbot.*openapi" > /dev/null 2>&1 || pgrep -f "vikingbot.*gateway" > /dev/null 2>&1; then + pkill -f "vikingbot.*openapi" 2>/dev/null || true + pkill -f "vikingbot.*gateway" 2>/dev/null || true + sleep 2 + echo " ✓ Stopped vikingbot processes" +else + echo " ✓ No vikingbot processes found" +fi + +# Kill existing openviking-server processes +echo "" +echo "Step 2: Stopping openviking-server processes..." +if pgrep -f "openviking-server" > /dev/null 2>&1; then + pkill -f "openviking-server" 2>/dev/null || true + sleep 2 + # Force kill if still running + if pgrep -f "openviking-server" > /dev/null 2>&1; then + echo " Force killing remaining processes..." + pkill -9 -f "openviking-server" 2>/dev/null || true + sleep 1 + fi + echo " ✓ Stopped openviking-server processes" +else + echo " ✓ No openviking-server processes found" +fi + +echo "" +echo "==========================================" +echo "✓ All processes stopped" +echo "==========================================" \ No newline at end of file diff --git a/bot/scripts/test_restart_openviking_server.sh b/bot/scripts/test_restart_openviking_server.sh index be473ae07..ef8a86af3 100755 --- a/bot/scripts/test_restart_openviking_server.sh +++ b/bot/scripts/test_restart_openviking_server.sh @@ -6,7 +6,7 @@ set -e # Default values -PORT="1933" +PORT="1934" BOT_URL="http://localhost:18790" TEST_CONFIG="$HOME/.openviking_test/ov.conf" TEST_DATA_DIR="$HOME/.openviking_test/data" @@ -67,21 +67,25 @@ else echo " ✓ No existing vikingbot processes found" fi -# Step 2: Kill existing openviking-server processes +# Step 2: Kill existing openviking-server on specific port echo "" -echo "Step 2: Stopping existing openviking-server processes..." -if pgrep -f "openviking-server" > /dev/null 2>&1; then - pkill -f "openviking-server" 2>/dev/null || true +echo "Step 2: Stopping openviking-server on port $PORT..." +PID=$(lsof -ti :$PORT 2>/dev/null || true) +if [ -n "$PID" ]; then + echo " Found PID: $PID" + pkill -f "vikingbot.*openapi" 2>/dev/null || true + pkill -f "vikingbot.*gateway" 2>/dev/null || true + kill $PID 2>/dev/null || true sleep 2 # Force kill if still running - if pgrep -f "openviking-server" > /dev/null 2>&1; then - echo " Force killing remaining processes..." - pkill -9 -f "openviking-server" 2>/dev/null || true + if lsof -ti :$PORT > /dev/null 2>&1; then + echo " Force killing..." + kill -9 $PID 2>/dev/null || true sleep 1 fi - echo " ✓ Stopped existing processes" + echo " ✓ Stopped process on port $PORT" else - echo " ✓ No existing processes found" + echo " ✓ No process found on port $PORT" fi # Step 3: Wait for port to be released diff --git a/openviking/async_client.py b/openviking/async_client.py index 98b4c0f17..1e92843c1 100644 --- a/openviking/async_client.py +++ b/openviking/async_client.py @@ -166,6 +166,7 @@ async def add_message( role: str, content: str | None = None, parts: list[dict] | None = None, + created_at: str | None = None, ) -> Dict[str, Any]: """Add a message to a session. @@ -174,12 +175,13 @@ async def add_message( role: Message role ("user" or "assistant") content: Text content (simple mode) parts: Parts array (full Part support: TextPart, ContextPart, ToolPart) + created_at: Message creation time (ISO format string) If both content and parts are provided, parts takes precedence. """ await self._ensure_initialized() return await self._client.add_message( - session_id=session_id, role=role, content=content, parts=parts + session_id=session_id, role=role, content=content, parts=parts, created_at=created_at ) async def commit_session( diff --git a/openviking/client/local.py b/openviking/client/local.py index 85215c1ff..56080dfe2 100644 --- a/openviking/client/local.py +++ b/openviking/client/local.py @@ -384,6 +384,7 @@ async def add_message( role: str, content: Optional[str] = None, parts: Optional[List[Dict[str, Any]]] = None, + created_at: Optional[str] = None, ) -> Dict[str, Any]: """Add a message to a session. @@ -392,9 +393,11 @@ async def add_message( role: Message role ("user" or "assistant") content: Text content (simple mode, backward compatible) parts: Parts array (full Part support mode) + created_at: Message creation time (ISO format string) If both content and parts are provided, parts takes precedence. """ + from datetime import datetime from openviking.message.part import Part, TextPart, part_from_dict session = self._service.sessions.session(self._ctx, session_id) @@ -408,7 +411,15 @@ async def add_message( else: raise ValueError("Either content or parts must be provided") - session.add_message(role, message_parts) + # 解析 created_at + msg_created_at = None + if created_at: + try: + msg_created_at = datetime.fromisoformat(created_at) + except ValueError: + pass + + session.add_message(role, message_parts, created_at=msg_created_at) return { "session_id": session_id, "message_count": len(session.messages), diff --git a/openviking/models/vlm/backends/openai_vlm.py b/openviking/models/vlm/backends/openai_vlm.py index 31f871315..98dbb4c30 100644 --- a/openviking/models/vlm/backends/openai_vlm.py +++ b/openviking/models/vlm/backends/openai_vlm.py @@ -318,6 +318,7 @@ async def get_completion_async( else: kwargs_messages = [{"role": "user", "content": prompt}] + kwargs = { "model": self.model or "gpt-4o-mini", "messages": kwargs_messages, diff --git a/openviking/models/vlm/backends/volcengine_vlm.py b/openviking/models/vlm/backends/volcengine_vlm.py index 63c8abfc5..83a1f83d6 100644 --- a/openviking/models/vlm/backends/volcengine_vlm.py +++ b/openviking/models/vlm/backends/volcengine_vlm.py @@ -7,17 +7,45 @@ import json import logging import time +from collections import OrderedDict from pathlib import Path -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union from .openai_vlm import OpenAIVLM from ..base import VLMResponse, ToolCall +# Import run_async for sync-to-async calls +from openviking_cli.utils import run_async + logger = logging.getLogger(__name__) +class LRUCache: + """Simple LRU cache implementation.""" + + def __init__(self, maxsize: int = 100): + self._cache = OrderedDict() + self._maxsize = maxsize + + def get(self, key: str) -> Optional[str]: + if key in self._cache: + self._cache.move_to_end(key) + return self._cache[key] + return None + + def set(self, key: str, value: str) -> None: + if key in self._cache: + self._cache.move_to_end(key) + self._cache[key] = value + if len(self._cache) > self._maxsize: + self._cache.popitem(last=False) + + def clear(self) -> None: + self._cache.clear() + + class VolcEngineVLM(OpenAIVLM): - """VolcEngine VLM backend""" + """VolcEngine VLM backend with prompt caching support.""" def __init__(self, config: Dict[str, Any]): super().__init__(config) @@ -26,12 +54,169 @@ def __init__(self, config: Dict[str, Any]): # Ensure provider type is correct self.provider = "volcengine" + # Prompt caching: message content -> response_id + self._response_cache = LRUCache(maxsize=100) + # VolcEngine-specific defaults if not self.api_base: self.api_base = "https://ark.cn-beijing.volces.com/api/v3" if not self.model: self.model = "doubao-seed-2-0-pro-260215" + def _get_response_id_cache_key(self, messages: List[Dict[str, Any]]) -> str: + """Generate cache key for response_id using simple JSON serialization.""" + # Filter out cache_control from messages for cache key + key_messages = [] + for msg in messages: + filtered = {k: v for k, v in msg.items() if k != "cache_control"} + key_messages.append(filtered) + return json.dumps(key_messages, ensure_ascii=False, sort_keys=True) + + + def _parse_messages_with_breakpoints(self, messages: List[Dict[str, Any]]) -> Tuple[List[List[Dict[str, Any]]], List[Dict[str, Any]]]: + """Parse messages into static segment and dynamic messages. + + Only the content BEFORE the first cache_control becomes the static segment. + All messages after (including the one with cache_control) become dynamic. + """ + # 找到第一个 cache_control 的位置 + first_breakpoint_idx = -1 + for i, msg in enumerate(messages): + if msg.get("cache_control"): + first_breakpoint_idx = i + # print(f'cache_control={msg}') + break + + if first_breakpoint_idx > 0: + # 有 cache_control,取其之前的内容作为 static segment + static_segment = messages[:first_breakpoint_idx+1] + dynamic_messages = messages[first_breakpoint_idx+1:] + static_segments = [static_segment] + print(f'static_segment={len(static_segment)}') + print(f'dynamic_messages={len(dynamic_messages)}') + else: + # 没有 cache_control 或在第一个位置,全部作为 dynamic + static_segments = [] + dynamic_messages = messages + + + return static_segments, dynamic_messages + + async def _get_or_create_from_segments( + self, + segments: List[List[Dict[str, Any]]], + end_idx: int + ) -> Optional[str]: + """递归获取或创建 cache,从长到短尝试。 + + Args: + segments: static 消息分段,每段以 cache_control 结尾 + end_idx: 尝试的前缀长度(包含的 segment 数量) + + Returns: + response_id for the prefix + """ + if end_idx <= 0: + return None + + + def segments_to_messages(segs): + # 拼接前 end_idx 个 segments + msgs = [] + for seg in segs: + msgs.extend(seg) + return msgs + + prefix = segments_to_messages(segments[:end_idx]) + + if end_idx == 1: + response_id = await self._get_or_create_from_messages(prefix) + return response_id + + previous_response_id = await self._get_or_create_from_segments(segments, end_idx - 1) + return await self._get_or_create_from_messages(segments_to_messages(segments[end_idx - 1: end_idx]), previous_response_id=previous_response_id) + + + async def _get_or_create_from_messages(self, messages: List[Dict[str, Any]], previous_response_id=None) -> Optional[str]: + """从头创建新 cache。""" + + # Check cache first + cache_key = self._get_response_id_cache_key(messages) + cached_id = self._response_cache.get(cache_key) + if cached_id is not None: + return cached_id + + client = self.get_async_client() + input_data = self._convert_messages_to_input(messages) + try: + response = await client.responses.create( + model=self.model, + previous_response_id=previous_response_id, + input=input_data, + caching={"type": "enabled", "prefix": True}, + thinking={"type": "disabled"}, + ) + cached_id = response.id + self._response_cache.set(cache_key, cached_id) + return cached_id + except Exception as e: + logger.warning(f"[VolcEngineVLM] Failed to create new cache: {e}") + return None + + + async def responseapi_prefixcache_completion( + self, + static_segments: List[List[Dict[str, Any]]], + dynamic_messages: List[Dict[str, Any]], + response_format: Optional[Dict] = None, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[str] = None, + ) -> Any: + """Use cached response_id for completion with dynamic messages. + + Args: + static_segments: Multiple static segments, each ending with cache_control + dynamic_messages: New messages for this request + response_format: Response format for structured output + tools: Tool definitions + tool_choice: Tool choice setting + """ + # 使用多段缓存获取 response_id + if static_segments: + response_id = await self._get_or_create_from_segments(static_segments, len(static_segments)) + else: + response_id = None + client = self.get_async_client() + input_data = self._convert_messages_to_input(dynamic_messages) + + kwargs = { + "model": self.model, + "input": input_data, + "temperature": self.temperature, + "thinking": {"type": "disabled"}, + } + if self.max_tokens is not None: + kwargs["max_tokens"] = self.max_tokens + if response_format: + kwargs["text"] = {"format": response_format} + + if response_id: + kwargs["previous_response_id"] = response_id + kwargs["caching"] = {"type": "enabled"} + elif tools: + # First call with tools: enable caching + converted_tools = self._convert_tools(tools) + kwargs["tools"] = converted_tools + kwargs["tool_choice"] = tool_choice or "auto" + kwargs["caching"] = {"type": "enabled"} + else: + # Enable caching by default + kwargs["caching"] = {"type": "enabled"} + + response = await client.responses.create(**kwargs) + return response + + def get_client(self): """Get sync client""" if self._sync_client is None: @@ -62,47 +247,130 @@ def get_async_client(self): ) return self._async_client - def _parse_tool_calls(self, message) -> List[ToolCall]: - """Parse tool calls from VolcEngine response message.""" - tool_calls = [] - if hasattr(message, "tool_calls") and message.tool_calls: - for tc in message.tool_calls: - args = tc.function.arguments - if isinstance(args, str): - try: - args = json.loads(args) - except json.JSONDecodeError: - args = {"raw": args} - tool_calls.append(ToolCall( - id=tc.id, - name=tc.function.name, - arguments=args - )) - return tool_calls + def _update_token_usage_from_response( + self, response, duration_seconds: float = 0.0, + ) -> None: + """Update token usage from VolcEngine Responses API response.""" + if hasattr(response, "usage") and response.usage: + u = response.usage + # Responses API uses input_tokens/output_tokens instead of prompt_tokens/completion_tokens + prompt_tokens = getattr(u, 'input_tokens', 0) or 0 + completion_tokens = getattr(u, 'output_tokens', 0) or 0 + self.update_token_usage( + model_name=self.model or "unknown", + provider=self.provider, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + duration_seconds=duration_seconds, + ) + return def _build_vlm_response(self, response, has_tools: bool) -> Union[str, VLMResponse]: - """Build response from VolcEngine response. Returns str or VLMResponse based on has_tools.""" - choice = response.choices[0] - message = choice.message + """Build response from VolcEngine Responses API response. - if has_tools: - usage = {} - if hasattr(response, "usage") and response.usage: - usage = { - "prompt_tokens": response.usage.prompt_tokens, - "completion_tokens": response.usage.completion_tokens, - "total_tokens": response.usage.total_tokens, - "prompt_tokens_details": getattr(response.usage, "prompt_tokens_details", None), + Responses API returns: + - response.output: list of output items + - response.id: response ID + - response.usage: token usage + """ + # Debug: print response structure + #logger.debug(f"[VolcEngineVLM] Response type: {type(response)}") + # logger.info(f"[VolcEngineVLM] Full response: {response}") + if hasattr(response, 'output'): + # logger.debug(f"[VolcEngineVLM] Output items: {len(response.output)}") + for i, item in enumerate(response.output): + # logger.debug(f"[VolcEngineVLM] Item {i}: type={getattr(item, 'type', 'unknown')}") + # Print full item for debugging + # logger.info(f"[VolcEngineVLM] Item {i} full: {item}") + pass + + # Extract content from Responses API format + content = "" + tool_calls = [] + finish_reason = "stop" + + if hasattr(response, 'output') and response.output: + for item in response.output: + item_type = getattr(item, 'type', None) + # Check if it's a function_call item (Responses API format) + if item_type == 'function_call': + # logger.debug(f"[VolcEngineVLM] Found function_call tool call") + args = item.arguments + if isinstance(args, str): + try: + args = json.loads(args) + except json.JSONDecodeError: + args = {"raw": args} + tool_calls.append(ToolCall( + id=item.call_id or "", + name=item.name or "", + arguments=args + )) + finish_reason = "tool_calls" + # Check if it's a message item (Chat API compatibility) + elif item_type == 'message': + message = item + if hasattr(message, 'content'): + # Content can be a list or string + if isinstance(message.content, list): + for block in message.content: + if hasattr(block, 'type') and block.type == 'output_text': + content = block.text or "" + elif hasattr(block, 'text'): + content = block.text or "" + else: + content = message.content or "" + + # Parse tool calls from message + if hasattr(message, 'tool_calls') and message.tool_calls: + # logger.debug(f"[VolcEngineVLM] Found {len(message.tool_calls)} tool calls in message") + for tc in message.tool_calls: + args = tc.arguments + if isinstance(args, str): + try: + args = json.loads(args) + except json.JSONDecodeError: + args = {"raw": args} + # Handle both tc.name and tc.function.name (Responses API vs Chat API) + try: + tool_name = tc.name + if not tool_name: + tool_name = tc.function.name + except AttributeError: + tool_name = tc.function.name if hasattr(tc, 'function') else "" + tool_calls.append(ToolCall( + id=tc.id or "", + name=tool_name or "", + arguments=args + )) + + finish_reason = getattr(message, 'finish_reason', 'stop') or 'stop' + + # Extract usage + usage = {} + if hasattr(response, 'usage') and response.usage: + u = response.usage + usage = { + "prompt_tokens": getattr(u, 'input_tokens', 0), + "completion_tokens": getattr(u, 'output_tokens', 0), + "total_tokens": getattr(u, 'total_tokens', 0), + } + # Handle cached tokens + input_details = getattr(u, 'input_tokens_details', None) + if input_details: + usage["prompt_tokens_details"] = { + "cached_tokens": getattr(input_details, 'cached_tokens', 0), } + if has_tools: return VLMResponse( - content=message.content, - tool_calls=self._parse_tool_calls(message), - finish_reason=choice.finish_reason or "stop", + content=content, + tool_calls=tool_calls, + finish_reason=finish_reason, usage=usage, ) else: - return message.content or "" + return content def get_completion( self, @@ -112,31 +380,147 @@ def get_completion( tool_choice: Optional[str] = None, messages: Optional[List[Dict[str, Any]]] = None, ) -> Union[str, VLMResponse]: - """Get text completion""" - client = self.get_client() - if messages: - kwargs_messages = messages - else: - kwargs_messages = [{"role": "user", "content": prompt}] - - kwargs = { - "model": self.model or "doubao-seed-2-0-pro-260215", - "messages": kwargs_messages, - "temperature": self.temperature, - "thinking": {"type": "disabled" if not thinking else "enabled"}, - } - if self.max_tokens is not None: - kwargs["max_tokens"] = self.max_tokens + """Get text completion with prompt caching support. - if tools: - kwargs["tools"] = tools - kwargs["tool_choice"] = tool_choice or "auto" - - t0 = time.perf_counter() - response = client.chat.completions.create(**kwargs) - elapsed = time.perf_counter() - t0 - self._update_token_usage_from_response(response, duration_seconds=elapsed) - return self._build_vlm_response(response, has_tools=bool(tools)) + Uses VolcEngine Responses API with prefix cache. + Delegates to async implementation. + """ + return run_async(self.get_completion_async( + prompt=prompt, + thinking=thinking, + tools=tools, + tool_choice=tool_choice, + messages=messages, + )) + + def _convert_messages_to_input(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Convert OpenAI-style messages to VolcEngine Responses API input format. + + VolcEngine Responses API format (no "type" field needed): + [ + {"role": "system", "content": "..."}, + {"role": "user", "content": "..."}, + ] + + Note: Responses API doesn't support 'tool' role, so we convert tool results + to user messages with a prefix indicating it's a tool result. + """ + input_messages = [] + for msg in messages: + role = msg.get("role", "user") + content = msg.get("content", "") + + # Handle tool_call role with content as dict {name, args, result} + if role == "tool_call" and isinstance(content, dict): + import json + content_str = json.dumps(content, ensure_ascii=False) + role = "user" # Convert tool_call to user + else: + # Handle content - check if it contains images + has_images = False + if isinstance(content, list): + text_parts = [] + image_urls = [] + for block in content: + if isinstance(block, dict): + block_type = block.get("type", "") + # Handle text blocks + if block_type == "text" or "text" in block: + text = block.get("text", "") + if text: + text_parts.append(text) + # Handle image_url blocks + elif block_type == "image_url" or "image_url" in block: + image_url = block.get("image_url", {}) + if isinstance(image_url, dict): + url = image_url.get("url", "") + if url: + image_urls.append(url) + has_images = True + # Handle other block types + else: + # Try to extract text from any dict block + text = block.get("text", "") + if text: + text_parts.append(text) + content = " ".join(text_parts) + # If there were images, include them as base64 data URLs in content + if image_urls: + # Filter out non-data URLs (keep only data: URLs) + data_urls = [u for u in image_urls if u.startswith("data:")] + if data_urls: + # Append image references to content + content = content + "\n[Images: " + ", ".join([f"data URL ({i+1})" for i in range(len(data_urls))]) + "]" + + # Ensure content is a string, use placeholder if empty + content_str = str(content) if content else "[empty]" + # Skip messages with empty content (API requirement) + if not content_str or content_str == "[empty]": + continue + + # Handle role conversion + # Responses API supports: system, user, assistant + # Convert 'tool' role to user with prefix (preserve the tool result context) + if role == "tool": + # Prefix with tool result indicator + content_str = f"[Tool Result]\n{content_str}" + role = "user" + + # Simple format: role + content (no type field) + input_messages.append({ + "role": role, + "content": content_str, + }) + + return input_messages + + def _convert_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Convert OpenAI-style tool format to VolcEngine Responses API format. + + OpenAI format: {"type": "function", "function": {"name": ..., "parameters": ...}} + VolcEngine format: {"type": "function", "name": ..., "description": ..., "parameters": ...} + + Note: VolcEngine Responses API requires "type": "function" and name at top level. + """ + converted = [] + for tool in tools: + if not isinstance(tool, dict): + converted.append(tool) + continue + + # Check if it's OpenAI format: {"type": "function", "function": {...}} + if tool.get("type") == "function" and "function" in tool: + func = tool["function"] + converted.append({ + "type": "function", # Keep the type field + "name": func.get("name", ""), + "description": func.get("description", ""), + "parameters": func.get("parameters", {}), + }) + elif "function" in tool: + # Has function but no type + func = tool["function"] + converted.append({ + "type": "function", + "name": func.get("name", ""), + "description": func.get("description", ""), + "parameters": func.get("parameters", {}), + }) + else: + # Already in correct format or other format + # Ensure it has type: function + if tool.get("type") != "function": + converted.append({ + "type": "function", + "name": tool.get("name", ""), + "description": tool.get("description", ""), + "parameters": tool.get("parameters", {}), + }) + else: + # Keep as is + converted.append(tool) + + return converted async def get_completion_async( self, @@ -147,45 +531,47 @@ async def get_completion_async( tool_choice: Optional[str] = None, messages: Optional[List[Dict[str, Any]]] = None, ) -> Union[str, VLMResponse]: - """Get text completion asynchronously""" - client = self.get_async_client() + """Get text completion with prompt caching support. + + Uses VolcEngine Responses API with prefix cache. + Separates messages into static (cached) and dynamic parts. + """ if messages: kwargs_messages = messages else: kwargs_messages = [{"role": "user", "content": prompt}] - kwargs = { - "model": self.model or "doubao-seed-2-0-pro-260215", - "messages": kwargs_messages, - "temperature": self.temperature, - "thinking": {"type": "disabled" if not thinking else "enabled"}, - } - if self.max_tokens is not None: - kwargs["max_tokens"] = self.max_tokens - - if tools: - kwargs["tools"] = tools - kwargs["tool_choice"] = tool_choice or "auto" - - last_error = None - for attempt in range(max_retries + 1): - try: - t0 = time.perf_counter() - response = await client.chat.completions.create(**kwargs) - elapsed = time.perf_counter() - t0 - self._update_token_usage_from_response( - response, duration_seconds=elapsed, - ) - return self._build_vlm_response(response, has_tools=bool(tools)) - except Exception as e: - last_error = e - if attempt < max_retries: - await asyncio.sleep(2**attempt) - - if last_error: + # Parse messages into multiple static segments and dynamic messages + # Each segment ends with cache_control, dynamic is the rest + static_segments, dynamic_messages = self._parse_messages_with_breakpoints(kwargs_messages) + + # If we have static segments, try prefix cache + response_format = None # Can be extended for structured output + + try: + # Use prefix cache with multiple segments + response = await self.responseapi_prefixcache_completion( + static_segments=static_segments, + dynamic_messages=dynamic_messages, + response_format=response_format, + tools=tools, + tool_choice=tool_choice, + ) + elapsed = 0 # Timing handled in responseapi methods + self._update_token_usage_from_response(response, duration_seconds=elapsed) + return self._build_vlm_response(response, has_tools=bool(tools)) + + except Exception as e: + last_error = e + # Log token info from error response if available + error_response = getattr(e, 'response', None) + if error_response and hasattr(error_response, 'usage'): + u = error_response.usage + prompt_tokens = getattr(u, 'input_tokens', 0) or 0 + completion_tokens = getattr(u, 'output_tokens', 0) or 0 + logger.info(f"[VolcEngineVLM] Error response - Input tokens: {prompt_tokens}, Output tokens: {completion_tokens}") + logger.warning(f"[VolcEngineVLM] Request failed: {e}") raise last_error - else: - raise RuntimeError("Unknown error in async completion") def _detect_image_format(self, data: bytes) -> str: """Detect image format from magic bytes. @@ -197,7 +583,7 @@ def _detect_image_format(self, data: bytes) -> str: - JPEG, PNG, GIF, WEBP, BMP, TIFF, ICO, DIB, ICNS, SGI, JPEG2000, HEIC, HEIF """ if len(data) < 12: - logger.warning(f"[VolcEngineVLM] Image data too small: {len(data)} bytes") + # logger.warning(f"[VolcEngineVLM] Image data too small: {len(data)} bytes") return "image/png" # PNG: 89 50 4E 47 0D 0A 1A 0A @@ -250,7 +636,7 @@ def _detect_image_format(self, data: bytes) -> str: ) # Unknown format - log and default to PNG - logger.warning(f"[VolcEngineVLM] Unknown image format, magic bytes: {data[:16].hex()}") + # logger.warning(f"[VolcEngineVLM] Unknown image format, magic bytes: {data[:16].hex()}") return "image/png" def _prepare_image(self, image: Union[str, Path, bytes]) -> Dict[str, Any]: @@ -258,9 +644,9 @@ def _prepare_image(self, image: Union[str, Path, bytes]) -> Dict[str, Any]: if isinstance(image, bytes): b64 = base64.b64encode(image).decode("utf-8") mime_type = self._detect_image_format(image) - logger.info( - f"[VolcEngineVLM] Preparing image from bytes, size={len(image)}, detected mime={mime_type}" - ) + # logger.info( + # f"[VolcEngineVLM] Preparing image from bytes, size={len(image)}, detected mime={mime_type}" + # ) return { "type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{b64}"}, @@ -309,38 +695,18 @@ def get_vision_completion( tools: Optional[List[Dict[str, Any]]] = None, messages: Optional[List[Dict[str, Any]]] = None, ) -> Union[str, VLMResponse]: - """Get vision completion""" - client = self.get_client() + """Get vision completion with prompt caching support. - if messages: - kwargs_messages = messages - else: - content = [] - if images: - for img in images: - content.append(self._prepare_image(img)) - if prompt: - content.append({"type": "text", "text": prompt}) - kwargs_messages = [{"role": "user", "content": content}] - - kwargs = { - "model": self.model or "doubao-seed-2-0-pro-260215", - "messages": kwargs_messages, - "temperature": self.temperature, - "thinking": {"type": "disabled" if not thinking else "enabled"}, - } - if self.max_tokens is not None: - kwargs["max_tokens"] = self.max_tokens - - if tools: - kwargs["tools"] = tools - kwargs["tool_choice"] = "auto" - - t0 = time.perf_counter() - response = client.chat.completions.create(**kwargs) - elapsed = time.perf_counter() - t0 - self._update_token_usage_from_response(response, duration_seconds=elapsed) - return self._build_vlm_response(response, has_tools=bool(tools)) + Uses VolcEngine Responses API with prefix cache. + Delegates to async implementation. + """ + return run_async(self.get_vision_completion_async( + prompt=prompt, + images=images, + thinking=thinking, + tools=tools, + messages=messages, + )) async def get_vision_completion_async( self, @@ -350,9 +716,10 @@ async def get_vision_completion_async( tools: Optional[List[Dict[str, Any]]] = None, messages: Optional[List[Dict[str, Any]]] = None, ) -> Union[str, VLMResponse]: - """Get vision completion asynchronously""" - client = self.get_async_client() + """Get vision completion with prompt caching support. + Uses VolcEngine Responses API with prefix cache. + """ if messages: kwargs_messages = messages else: @@ -364,21 +731,10 @@ async def get_vision_completion_async( content.append({"type": "text", "text": prompt}) kwargs_messages = [{"role": "user", "content": content}] - kwargs = { - "model": self.model or "doubao-seed-2-0-pro-260215", - "messages": kwargs_messages, - "temperature": self.temperature, - "thinking": {"type": "disabled" if not thinking else "enabled"}, - } - if self.max_tokens is not None: - kwargs["max_tokens"] = self.max_tokens - - if tools: - kwargs["tools"] = tools - kwargs["tool_choice"] = "auto" - - t0 = time.perf_counter() - response = await client.chat.completions.create(**kwargs) - elapsed = time.perf_counter() - t0 - self._update_token_usage_from_response(response, duration_seconds=elapsed) - return self._build_vlm_response(response, has_tools=bool(tools)) + # 复用 get_completion_async 的逻辑 + return await self.get_completion_async( + prompt=prompt, + thinking=thinking, + tools=tools, + messages=kwargs_messages, + ) diff --git a/openviking/prompts/templates/memory/cases.yaml b/openviking/prompts/templates/memory/cases.yaml index 2831da6b0..50d49af97 100644 --- a/openviking/prompts/templates/memory/cases.yaml +++ b/openviking/prompts/templates/memory/cases.yaml @@ -7,7 +7,7 @@ description: | Case names should be in "Problem → Solution" format to make them easily searchable. directory: "viking://agent/{agent_space}/memories/cases" filename_template: "{case_name}.md" -enabled: true +enabled: false fields: - name: case_name type: string diff --git a/openviking/prompts/templates/memory/entities.yaml b/openviking/prompts/templates/memory/entities.yaml index 7065c5459..3b2f693a9 100644 --- a/openviking/prompts/templates/memory/entities.yaml +++ b/openviking/prompts/templates/memory/entities.yaml @@ -14,17 +14,11 @@ fields: type: string description: | # Content - - Entity card name in English, lowercase with underscores, max 3 words - - # Format Requirements - ## Entity Cards - - In triage scenarios, each card represents a department or symptom - - Entities should be as granular as possible - split combined symptoms into individual entities - + - Entity name in Chinese or English. If English, use lowercase with underscores, max 3 words. Do not include any dates. + ### Good Examples emergency_department cough_symptom - daily_20260110 ### Bad Examples progressive_memory_loss_with_personality_change // Too long, max 3 words diff --git a/openviking/prompts/templates/memory/events.yaml b/openviking/prompts/templates/memory/events.yaml index 8e1b04c9f..11bc6289f 100644 --- a/openviking/prompts/templates/memory/events.yaml +++ b/openviking/prompts/templates/memory/events.yaml @@ -1,30 +1,54 @@ memory_type: events description: | - Event memory - captures "what happened, what decision was made, and why". - Extract notable events, decisions, milestones, and turning points from the conversation. - Events should be things worth remembering for future context: decisions made, agreements reached, milestones achieved, problems solved, etc. - Each event should include: what happened, why it happened, what the outcome was, and any relevant context/timeline. - Use absolute dates for event_time, not relative time like "today" or "recently". + # Guidelines + - The summary must include information records related to the entities involved, with attention to detail, especially emphasizing commitments, agreements, or proposals that may be referenced in the future (for example: Xiaoming once asked, "Do you need me to help you get a membership?" and Xiaosen replied, "Let's talk about it later"), and nothing should be omitted. If there is a time expression (such as "yesterday"), the date must be converted to a specific year, month, and day format according to when the event occurred. + - Include all facts as comprehensively as possible, especially requests made by specific roles and mark major events. + - Convert dialogue into indirect speech, covering emotions and speaker characteristics, recording the user's emotional state, conversation content, emotional feedback (such as happy, sad, curious, etc.), and the assistant's response. + - Create a coherent narrative, retaining key dramatic elements. + - Use a third-person perspective. + - If possible, combine the user's current behavior and reactions to speculate on the user's possible thoughts or actions. + - Describe the complete content of an event within a single event as much as possible; do not split one event into multiple parts. directory: "viking://user/{user_space}/memories/events" filename_template: "{event_time}_{event_name}.md" enabled: true +# 操作模式:add_only 表示只新增记忆,不需要查看之前的记忆列表 +# upsert 表示新增或更新(默认行为) +operation_mode: "add_only" +content_template: | + {% set msg_range = extract_context.read_message_ranges(ranges|default('')) if extract_context else None %} + {% set first_time = msg_range.first_message_time() if msg_range else None %} + time: {{event_time|default(first_time if first_time else 'N/A')}} + {% if extract_context %} + {{ msg_range.pretty_print() if msg_range else '' }} + {% endif %} + + fields: - name: event_name type: string description: | - Event name in Chinese or English. If English, use lowercase with underscores, max 3 words. + Event name in Chinese or English. If English, use lowercase with underscores, max 3 words. Do not include any dates. merge_op: immutable + - name: goal + type: string + description: | + Summarize the purpose of the content in 5 words or less + + - name: summary + type: string + description: | + Based on the content of the above fields, compile a description to outline the complete Fact content, in English + - name: event_time type: string description: | - Time when the event occurred, format “2026-03-17”. If unknown, use current time. + Time when the event occurred, format “2026-03-17” / “2026-03” / “2026”. If unknown, use current time. merge_op: immutable - - name: content + - name: ranges type: string description: | - Event content in Markdown format, describing “what happened”. - Includes: decision content, reasons, outcomes, context, timeline, etc. - Example: “In memory system design discussion, found original 6 category boundaries were unclear. Especially status, lessons learned, and insights often overlapped, making them difficult to distinguish. Decided to refactor to 5 categories, removing these three to make boundaries clearer.” - merge_op: patch + Conversation message index ranges to extract, format: "start-end,start-end,..." + Example: "0-10,50-60" means extract messages 0-10 and 50-60. + merge_op: immutable diff --git a/openviking/prompts/templates/memory/patterns.yaml b/openviking/prompts/templates/memory/patterns.yaml index 0232f1ae2..6b6ba73b4 100644 --- a/openviking/prompts/templates/memory/patterns.yaml +++ b/openviking/prompts/templates/memory/patterns.yaml @@ -7,7 +7,7 @@ description: | Pattern names should be in "Process name: Step description" format. directory: "viking://agent/{agent_space}/memories/patterns" filename_template: "{pattern_name}.md" -enabled: true +enabled: false fields: - name: pattern_name type: string diff --git a/openviking/prompts/templates/memory/preferences.yaml b/openviking/prompts/templates/memory/preferences.yaml index 15b10a7fa..98ec6c7e6 100644 --- a/openviking/prompts/templates/memory/preferences.yaml +++ b/openviking/prompts/templates/memory/preferences.yaml @@ -6,9 +6,15 @@ description: | Topics can be: code style, communication style, tools, workflow, food, commute, etc. Store different topics as separate memory files, do NOT mix unrelated preferences. directory: "viking://user/{user_space}/memories/preferences" -filename_template: "{topic}.md" +filename_template: "{user}_{topic}.md" enabled: true fields: + - name: user + type: string + description: | + Username for the preference owner + merge_op: immutable + - name: topic type: string description: | diff --git a/openviking/prompts/templates/memory/skills.yaml b/openviking/prompts/templates/memory/skills.yaml index 5a286a332..8d7a9d862 100644 --- a/openviking/prompts/templates/memory/skills.yaml +++ b/openviking/prompts/templates/memory/skills.yaml @@ -1,28 +1,21 @@ memory_type: skills description: | - Skill execution memory - captures "how this skill is executed, what works well, and what doesn't". - Extract skill execution patterns, statistics, and learnings from skill usage in conversation. - For each skill, track: how many times it's been executed, success rate, what it's best for, recommended execution flow, key dependencies, common failure modes, and actionable recommendations. - Also accumulate complete guidelines with "Good Cases" and "Bad Cases" examples. - Skill memories help the agent learn from experience and execute skills more effectively over time. + Record all skills uses, directory: "viking://agent/{agent_space}/memories/skills" filename_template: "{skill_name}.md" - +enabled: true content_template: | - Skill: {skill_name} - - Skill Memory Context: - Based on {total_executions} historical executions: - - Success rate: {success_rate}% ({success_count} successful, {fail_count} failed) - - Best for: {best_for} - - Recommended flow: {recommended_flow} - - Key dependencies: {key_dependencies} - - Common failures: {common_failures} - - Recommendation: {recommendation} + Skill: {{ skill_name }} + + - Success rate: {{ ((success_count|default(0) / (total_executions|default(1) if total_executions|default(0) > 0 else 1)) * 100)|round|int }}% ({{ success_count|default(0) }}/{{ total_executions|default(0) }}) + - Best for: {{ best_for|default('N/A') }} + - Recommended flow: {{ recommended_flow|default('N/A') }} + - Key dependencies: {{ key_dependencies|default('N/A') }} + - Common failures: {{ common_failures|default('N/A') }} + - Recommendation: {{ recommendation|default('N/A') }} - {guidelines} + {{ guidelines|default('') }} -enabled: true fields: - name: skill_name type: string @@ -64,7 +57,7 @@ fields: type: string description: | Recommended execution flow for the skill, describing the best steps to execute this skill. - Examples: "1. Confirm topic and audience → 2. Collect reference materials → 3. Generate outline → 4. Create slides → 5. Refine content" + Examples: "1. Confirm topic and audience -> 2. Collect reference materials -> 3. Generate outline -> 4. Create slides -> 5. Refine content" merge_op: patch - name: key_dependencies @@ -97,4 +90,4 @@ fields: - "### Good Cases" - successful usage examples - "### Bad Cases" - failed usage examples Headings must be in English, content can be in target language. - merge_op: patch + merge_op: patch \ No newline at end of file diff --git a/openviking/prompts/templates/memory/tools.yaml b/openviking/prompts/templates/memory/tools.yaml index eeec88d2c..ec6e985cf 100644 --- a/openviking/prompts/templates/memory/tools.yaml +++ b/openviking/prompts/templates/memory/tools.yaml @@ -1,29 +1,22 @@ memory_type: tools description: | - Tool usage memory - captures "how this tool is used, what works well, and what doesn't". - Extract tool usage patterns, statistics, and learnings from [ToolCall] records and conversation context. - For each tool, track: how many times it's been called, success rate, average time/tokens, what it's best for, optimal parameters, common failure modes, and actionable recommendations. - Also accumulate complete guidelines with "Good Cases" and "Bad Cases" examples. - Tool memories help the agent learn from experience and use tools more effectively over time. + Record all tool calls directory: "viking://agent/{agent_space}/memories/tools" filename_template: "{tool_name}.md" enabled: true content_template: | - Tool: {tool_name} + Tool: {{ tool_name }} Static Description: - "{static_desc}" + "{{ static_desc|default('N/A') }}" - Tool Memory Context: - Based on {total_calls} historical calls: - - Success rate: {success_rate}% ({success_count} successful, {fail_count} failed) - - Avg time: {total_time_ms/total_calls}, Avg tokens: {total_tokens/total_calls} - - Best for: {best_for} - - Optimal params: {optimal_params} - - Common failures: {common_failures} - - Recommendation: {recommendation} + - Success rate: {{ ((success_time|default(0) / (call_count|default(1) if call_count|default(0) > 0 else 1)) * 100)|round|int }}% ({{ success_time|default(0) }}/{{ call_count|default(0) }}) + - Best for: {{ best_for|default('N/A') }} + - Optimal params: {{ optimal_params|default('N/A') }} + - Common failures: {{ common_failures|default('N/A') }} + - Recommendation: {{ recommendation|default('N/A') }} - {guidelines} + {{ guidelines|default('') }} fields: - name: tool_name @@ -40,41 +33,20 @@ fields: Static description of the tool, basic functionality description. Examples: "Searches the web for information", "Reads files from the file system" - - name: total_calls + - name: call_count type: int64 description: | Total number of tool calls, accumulated from historical statistics. Used to calculate success rate and average duration. merge_op: sum - - name: success_count + - name: success_time type: int64 description: | Number of successful tool calls, accumulated from historical statistics. Counts calls with status "completed". merge_op: sum - - name: fail_count - type: int64 - description: | - Number of failed tool calls, accumulated from historical statistics. - Counts calls with status not "completed". - merge_op: sum - - - name: total_time_ms - type: int64 - description: | - Total tool call duration in milliseconds, accumulated from historical statistics. - Used to calculate average duration. - merge_op: sum - - - name: total_tokens - type: int64 - description: | - Total tokens used by tool calls (prompt tokens + completion tokens), accumulated from historical statistics. - Used to calculate average token consumption. - merge_op: sum - - name: best_for type: string description: | @@ -113,5 +85,4 @@ fields: - "## Guidelines" - best practices - "### Good Cases" - successful usage examples - "### Bad Cases" - failed usage examples - Headings must be in English, content can be in target language. merge_op: patch diff --git a/openviking/server/identity.py b/openviking/server/identity.py index 74a95e904..e39cb6010 100644 --- a/openviking/server/identity.py +++ b/openviking/server/identity.py @@ -4,7 +4,7 @@ from dataclasses import dataclass, field from enum import Enum -from typing import List, Optional +from typing import List, Optional, Any from openviking_cli.session.user_id import UserIdentifier @@ -31,8 +31,28 @@ class RequestContext: user: UserIdentifier role: Role - default_search_uris: List[str] = field(default_factory=list) @property def account_id(self) -> str: return self.user.account_id + + +@dataclass +class ToolContext: + """Tool-level context, containing request context and additional tool-specific information.""" + + request_ctx: RequestContext + default_search_uris: List[str] = field(default_factory=list) + transaction_handle: Optional[Any] = None + + @property + def user(self): + return self.request_ctx.user + + @property + def role(self): + return self.request_ctx.role + + @property + def account_id(self) -> str: + return self.request_ctx.user.account_id diff --git a/openviking/server/routers/content.py b/openviking/server/routers/content.py index 0231417e8..2b8beb44f 100644 --- a/openviking/server/routers/content.py +++ b/openviking/server/routers/content.py @@ -41,6 +41,19 @@ async def read( """Read file content (L2).""" service = get_service() result = await service.fs.read(uri, ctx=_ctx, offset=offset, limit=limit) + + # 清理MEMORY_FIELDS隐藏注释(v2记忆加工过程中的临时内部数据,不暴露给外部用户) + if isinstance(result, bytes): + text = result.decode("utf-8") + elif isinstance(result, str): + text = result + else: + text = None + + if text: + from openviking.session.memory.utils.content import deserialize_content + result = deserialize_content(text) + return Response(status="ok", result=result) diff --git a/openviking/server/routers/sessions.py b/openviking/server/routers/sessions.py index 0977685bf..ee6a898c5 100644 --- a/openviking/server/routers/sessions.py +++ b/openviking/server/routers/sessions.py @@ -3,6 +3,7 @@ """Sessions endpoints for OpenViking HTTP Server.""" import logging +from datetime import datetime from typing import Any, Dict, List, Literal, Optional from fastapi import APIRouter, Depends, Path, Query @@ -63,6 +64,7 @@ class AddMessageRequest(BaseModel): role: str content: Optional[str] = None parts: Optional[List[Dict[str, Any]]] = None + created_at: Optional[str] = None @model_validator(mode="after") def validate_content_or_parts(self) -> "AddMessageRequest": @@ -245,7 +247,15 @@ async def add_message( else: parts = [TextPart(text=request.content or "")] - session.add_message(request.role, parts) + # 解析 created_at + created_at = None + if request.created_at: + try: + created_at = datetime.fromisoformat(request.created_at) + except ValueError: + logger.warning(f"Invalid created_at format: {request.created_at}") + + session.add_message(request.role, parts, created_at=created_at) return Response( status="ok", result={ diff --git a/openviking/session/compressor_v2.py b/openviking/session/compressor_v2.py index 027fb4fcb..079ed84b5 100644 --- a/openviking/session/compressor_v2.py +++ b/openviking/session/compressor_v2.py @@ -8,33 +8,22 @@ """ import os -from dataclasses import dataclass from typing import List, Optional from openviking.core.context import Context from openviking.message import Message from openviking.server.identity import RequestContext +from openviking.session.memory import ExtractLoop, MemoryUpdater from openviking.storage import VikingDBManager from openviking.storage.viking_fs import get_viking_fs +from openviking.telemetry import get_current_telemetry from openviking_cli.session.user_id import UserIdentifier from openviking_cli.utils import get_logger from openviking_cli.utils.config import get_openviking_config -from openviking.session.memory import MemoryReAct, MemoryUpdater, MemoryTypeRegistry - logger = get_logger(__name__) -@dataclass -class ExtractionStats: - """Statistics for memory extraction.""" - - created: int = 0 - merged: int = 0 - deleted: int = 0 - skipped: int = 0 - - class SessionCompressorV2: """Session memory extractor with v2 templating system.""" @@ -44,18 +33,17 @@ def __init__( ): """Initialize session compressor.""" self.vikingdb = vikingdb - # Initialize registry once - used by both MemoryReAct and MemoryUpdater - self._registry = MemoryTypeRegistry() - schemas_dir = os.path.join( - os.path.dirname(__file__), "..", "prompts", "templates", "memory" - ) - self._registry.load_from_directory(schemas_dir) - # Lazy initialize MemoryReAct - we need vlm and ctx - self._react_orchestrator: Optional[MemoryReAct] = None - self._memory_updater: Optional[MemoryUpdater] = None + # registry 现在由 provider 负责加载,这里不再初始化 + # MemoryUpdater 会在 apply_operations 时从 provider 获取 registry + pass - def _get_or_create_react(self, ctx: Optional[RequestContext] = None) -> MemoryReAct: - """Create new MemoryReAct instance with current ctx. + def _get_or_create_react( + self, + ctx: Optional[RequestContext] = None, + messages: Optional[List] = None, + latest_archive_overview: str = "", + ) -> ExtractLoop: + """Create new ExtractLoop instance with current ctx. Note: Always create new instance to avoid cross-session isolation issues. The ctx contains request-scoped state that must not be shared across requests. @@ -64,20 +52,30 @@ def _get_or_create_react(self, ctx: Optional[RequestContext] = None) -> MemoryRe vlm = config.vlm.get_vlm_instance() viking_fs = get_viking_fs() - return MemoryReAct( + # Create context provider with messages (provider 负责加载 schema) + from openviking.session.memory.session_extract_context_provider import SessionExtractContextProvider + context_provider = SessionExtractContextProvider( + messages=messages, + latest_archive_overview=latest_archive_overview, + ) + + return ExtractLoop( vlm=vlm, viking_fs=viking_fs, ctx=ctx, - registry=self._registry, + context_provider=context_provider, ) - def _get_or_create_updater(self) -> MemoryUpdater: - """Get or create MemoryUpdater instance.""" - if self._memory_updater is not None: - return self._memory_updater + def _get_or_create_updater(self, registry, transaction_handle=None) -> MemoryUpdater: + """Create new MemoryUpdater instance for each request. - self._memory_updater = MemoryUpdater(registry=self._registry, vikingdb=self.vikingdb) - return self._memory_updater + Always create new instance to avoid cross-request state pollution. + """ + return MemoryUpdater( + registry=registry, + vikingdb=self.vikingdb, + transaction_handle=transaction_handle + ) async def extract_long_term_memories( self, @@ -100,38 +98,102 @@ async def extract_long_term_memories( logger.warning("No RequestContext provided, skipping memory extraction") return [] - # Provide the latest completed archive overview as non-actionable history context. - conversation_sections: List[str] = [] - if latest_archive_overview: - conversation_sections.append(f"## Previous Archive Overview\n{latest_archive_overview}") + logger.info("Starting v2 memory extraction from conversation") - conversation_sections.append( - "\n".join([f"[{msg.role}]: {msg.content}" for msg in messages]) - ) - conversation_str = "\n\n".join(section for section in conversation_sections if section) + # Initialize telemetry to 0 (matching v1 pattern) + telemetry = get_current_telemetry() + telemetry.set("memory.extract.candidates.total", 0) + telemetry.set("memory.extract.candidates.standard", 0) + telemetry.set("memory.extract.candidates.tool_skill", 0) + telemetry.set("memory.extract.created", 0) + telemetry.set("memory.extract.merged", 0) + telemetry.set("memory.extract.deleted", 0) + telemetry.set("memory.extract.skipped", 0) + + from openviking.storage.transaction import get_lock_manager, init_lock_manager + from openviking.storage.viking_fs import get_viking_fs + + # 初始化锁管理器(仅在有 AGFS 时使用锁机制) + viking_fs = get_viking_fs() + lock_manager = None + transaction_handle = None + if viking_fs and hasattr(viking_fs, 'agfs') and viking_fs.agfs: + init_lock_manager(viking_fs.agfs) + lock_manager = get_lock_manager() + transaction_handle = lock_manager.create_handle() + else: + logger.warning("VikingFS or AGFS not available, running without lock mechanism") - logger.info("Starting v2 memory extraction from conversation") try: - # Initialize orchestrator - orchestrator = self._get_or_create_react(ctx=ctx) - updater = self._get_or_create_updater() + # 获取所有记忆 schema 目录并加锁(仅在有锁管理器时) + orchestrator = self._get_or_create_react( + ctx=ctx, + messages=messages, + latest_archive_overview=latest_archive_overview, + ) + if lock_manager: + # 基于 provider 的 schemas 生成目录列表 + schemas = orchestrator.context_provider.get_memory_schemas(ctx) + memory_schema_dirs = [] + for schema in schemas: + if not schema.directory: + continue + user_space = ctx.user.user_space_name() if ctx and ctx.user else "default" + agent_space = ctx.user.agent_space_name() if ctx and ctx.user else "default" + dir_path = schema.directory.replace("{user_space}", user_space).replace("{agent_space}", agent_space) + dir_path = viking_fs._uri_to_path(dir_path, ctx) + if dir_path not in memory_schema_dirs: + memory_schema_dirs.append(dir_path) + logger.debug(f"Memory schema directories to lock: {memory_schema_dirs}") + + # 循环等待获取锁(机制确保不会死锁) + # 由于使用有序加锁法,可以安全地无限等待 + while True: + lock_acquired = await lock_manager.acquire_subtree_batch( + transaction_handle, + memory_schema_dirs, + timeout=None, + ) + if lock_acquired: + break + logger.warning("Failed to acquire memory locks, retrying...") + + orchestrator._transaction_handle = transaction_handle # 传递给 ExtractLoop # Run ReAct orchestrator - operations, tools_used = await orchestrator.run(conversation=conversation_str) + operations, tools_used = await orchestrator.run() if operations is None: logger.info("No memory operations generated") return [] + # Convert to legacy format for logging and apply_operations + if hasattr(operations, 'to_legacy_operations'): + legacy = operations.to_legacy_operations() + write_uris = legacy.get('write_uris', []) + edit_uris = legacy.get('edit_uris', []) + else: + # Fallback for old format + write_uris = operations.write_uris + edit_uris = operations.edit_uris + + # 从 orchestrator 获取 registry(从 provider 获取) + registry = orchestrator.context_provider._get_registry() + updater = self._get_or_create_updater(registry, transaction_handle) + logger.info( - f"Generated memory operations: write={len(operations.write_uris)}, " - f"edit={len(operations.edit_uris)}, edit_overview={len(operations.edit_overview_uris)}, " + f"Generated memory operations: write={len(write_uris)}, " + f"edit={len(edit_uris)}, edit_overview={len(operations.edit_overview_uris)}, " f"delete={len(operations.delete_uris)}" ) + # Create extract context from messages + from openviking.session.memory.memory_updater import ExtractContext + extract_context = ExtractContext(messages) + # Apply operations - result = await updater.apply_operations(operations, ctx, registry=orchestrator.registry) + result = await updater.apply_operations(operations, ctx, registry=registry, extract_context=extract_context) logger.info( f"Applied memory operations: written={len(result.written_uris)}, " @@ -139,15 +201,52 @@ async def extract_long_term_memories( f"errors={len(result.errors)}" ) - # Return list with dummy values to preserve count for stats in session.py - # v2 directly writes to storage, so we return None objects to maintain len() accuracy - total_changes = ( - len(result.written_uris) + len(result.edited_uris) + len(result.deleted_uris) - ) - return [None] * total_changes + # Report telemetry stats (matching v1 pattern) + telemetry = get_current_telemetry() + telemetry.set("memory.extract.candidates.total", len(result.written_uris) + len(result.edited_uris)) + telemetry.set("memory.extract.created", len(result.written_uris)) + telemetry.set("memory.extract.merged", len(result.edited_uris)) + telemetry.set("memory.extract.deleted", len(result.deleted_uris)) + telemetry.set("memory.extract.skipped", len(result.errors)) + + # Build Context objects for stats in session.py + contexts: List[Context] = [] + + # Written memories + for uri in result.written_uris: + contexts.append(Context( + uri=uri, + category="memory_write", + context_type="memory", + )) + + # Edited memories + for uri in result.edited_uris: + contexts.append(Context( + uri=uri, + category="memory_edit", + context_type="memory", + )) + + # Deleted memories + for uri in result.deleted_uris: + contexts.append(Context( + uri=uri, + category="memory_delete", + context_type="memory", + )) + + return contexts except Exception as e: logger.error(f"Failed to extract memories with v2: {e}", exc_info=True) if strict_extract_errors: raise return [] + finally: + # 确保释放所有锁(仅在有锁管理器时) + if lock_manager and transaction_handle: + try: + await lock_manager.release(transaction_handle) + except Exception as e: + logger.warning(f"Failed to release transaction lock: {e}") diff --git a/openviking/session/memory/__init__.py b/openviking/session/memory/__init__.py index 11866f3c3..b4d20d944 100644 --- a/openviking/session/memory/__init__.py +++ b/openviking/session/memory/__init__.py @@ -24,8 +24,8 @@ StructuredMemoryOperations, ) from openviking.session.memory.merge_op import MergeOp, FieldType, MemoryPatchHandler -from openviking.session.memory.memory_react import ( - MemoryReAct, +from openviking.session.memory.extract_loop import ( + ExtractLoop, ) from openviking.session.memory.memory_type_registry import MemoryTypeRegistry from openviking.session.memory.memory_updater import MemoryUpdater, MemoryUpdateResult @@ -40,8 +40,6 @@ MemoryTool, add_tool_call_items_to_messages, add_tool_call_pair_to_messages, - create_tool_call_message, - create_tool_result_message, get_tool, get_tool_schemas, list_tools, @@ -68,8 +66,8 @@ # Updater "MemoryUpdater", "MemoryUpdateResult", - # ReAct - "MemoryReAct", + # ExtractLoop + "ExtractLoop", # Tools (Tool implementations) "MemoryTool", "MemoryReadTool", @@ -79,8 +77,6 @@ "get_tool", "list_tools", "get_tool_schemas", - "create_tool_call_message", - "create_tool_result_message", "add_tool_call_pair_to_messages", "add_tool_call_items_to_messages", # Language utilities and helpers diff --git a/openviking/session/memory/core.py b/openviking/session/memory/core.py new file mode 100644 index 000000000..96eb86413 --- /dev/null +++ b/openviking/session/memory/core.py @@ -0,0 +1,74 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +""" +Extract Context Provider - 抽象接口 + +定义 ExtractLoop 使用的 Provider 接口,支持两种场景: +1. SessionExtractContextProvider - 从会话消息提取记忆 +2. ConsolidationExtractContextProvider - 定时整理已有记忆 +""" + +from abc import ABC, abstractmethod +from typing import Any, Dict, List + +from openviking.server.identity import RequestContext +from openviking.storage.viking_fs import VikingFS + + +class ExtractContextProvider(ABC): + """Extract Context Provider 接口""" + + @abstractmethod + def instruction(self) -> str: + """ + 指令 - Provider 相关,包含 goal、conversation 等 + + Returns: + 完整的指令描述 + """ + pass + + @abstractmethod + async def prefetch( + self, + ctx: RequestContext, + viking_fs: VikingFS, + transaction_handle, + vlm, + ) -> List[Dict]: + """ + 执行 prefetch + + Args: + ctx: RequestContext + viking_fs: VikingFS + transaction_handle: 事务句柄 + vlm: VLM 实例 + + Returns: + 预取的 tool call messages 列表 + """ + pass + + @abstractmethod + def get_tools(self) -> List[str]: + """ + 获取可用的工具列表 + + Returns: + 工具名称列表 + """ + pass + + @abstractmethod + def get_memory_schemas(self, ctx: RequestContext) -> List[Any]: + """ + 获取需要参与的 memory schemas + + Args: + ctx: RequestContext + + Returns: + 需要参与的 MemoryTypeSchema 列表 + """ + pass \ No newline at end of file diff --git a/openviking/session/memory/dataclass.py b/openviking/session/memory/dataclass.py index 4ce2c03a8..b66f5d3aa 100644 --- a/openviking/session/memory/dataclass.py +++ b/openviking/session/memory/dataclass.py @@ -4,10 +4,11 @@ Core domain data classes for memory system. """ +import json from datetime import datetime -from typing import Any, List, Optional, Protocol, TypeVar +from typing import Any, Dict, List, Optional, Protocol, TypeVar, Union, get_type_hints, get_origin, get_args -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator from openviking.session.memory.merge_op.base import ( FieldType, @@ -45,6 +46,7 @@ class MemoryTypeSchema(BaseModel): content_template: Optional[str] = Field(None, description="Content template (for template mode)") directory: str = Field("", description="Directory path") enabled: bool = Field(True, description="Whether this memory type is enabled") + operation_mode: str = Field("upsert", description="Operation mode: 'upsert' (default), 'add_only', or 'update_only'") class MemoryData(BaseModel): @@ -72,6 +74,116 @@ def set_field(self, field_name: str, value: Any) -> None: +# ============================================================================ +# Fault Tolerant Base Model (参考 vikingdb BaseModelCompat) +# ============================================================================ + + +class FaultTolerantBaseModel(BaseModel): + """ + 支持验证前自动容错的 BaseModel,类似 vikingdb 的 BaseModelCompat。 + + 在 model_validator(mode='before') 中对所有字段做类型容错处理, + 使得模型可以接受 LLM 输出的不标准格式数据。 + """ + + @model_validator(mode='before') + @classmethod + def values_fault_tolerance(cls, data: Dict[str, Any]) -> Dict[str, Any]: + """在验证前对所有字段做容错处理""" + if isinstance(data, dict): + field_types = get_type_hints(cls) + for field_name, value in data.items(): + if field_name in field_types: + data[field_name] = cls.value_fault_tolerance(field_types[field_name], value) + return data + return {} + + @classmethod + def get_origin_type(cls, annotation) -> type: + """从 Optional 或 Union 类型中提取基础类型""" + origin = get_origin(annotation) + if origin is Union: + args = get_args(annotation) + if len(args) == 2 and args[1] == type(None): + return cls.get_origin_type(args[0]) + elif origin is list: + return list + return annotation + + @classmethod + def get_arg_type(cls, annotation) -> type: + """从 List annotation 中提取元素类型""" + origin = get_origin(annotation) + if origin is Union: + args = get_args(annotation) + if len(args) == 2 and args[1] == type(None): + return cls.get_arg_type(args[0]) + elif origin is list: + args = get_args(annotation) + if args: + return args[0] + return None + + @classmethod + def any_to_str(cls, value) -> str: + """将任意值转换为字符串""" + if value is None: + return "" + if isinstance(value, list): + return ",".join(map(str, value)) + elif isinstance(value, dict): + return json.dumps(value, ensure_ascii=False) + elif isinstance(value, (int, bool, float)): + return f'{value}' + return str(value) + + @classmethod + def value_fault_tolerance(cls, field_type, value): + """ + 字段级别的容错处理: + - 'None' -> None (非 str 类型) + - list/dict/number -> str (目标是 str) + - str -> int/float (目标是数字) + - str/dict -> list (目标是 list) + - list 元素类型容错 + """ + origin_type = cls.get_origin_type(field_type) + + # json_repair 会把 None 转换成 'None' + if value == 'None' and origin_type is not str: + return None + + if origin_type is str: + return cls.any_to_str(value) + elif origin_type is int: + if isinstance(value, str): + if value is None or value == 'None': + return 0 + try: + return int(value) + except (ValueError, TypeError): + pass + elif origin_type is float: + if isinstance(value, str): + if value is None or value == 'None': + return 0.0 + try: + return float(value) + except (ValueError, TypeError): + pass + elif origin_type is list: + if isinstance(value, str): + return [value] + elif isinstance(value, dict): + return [value] + elif isinstance(value, list): + arg_type = cls.get_arg_type(field_type) + if arg_type is str: + return [cls.any_to_str(v) for v in value] + return value + + # ============================================================================ # Memory Operations # ============================================================================ @@ -89,13 +201,12 @@ class MemoryOperationsProtocol(Protocol): def is_empty(self) -> bool: ... -class StructuredMemoryOperations(BaseModel): +class StructuredMemoryOperations(FaultTolerantBaseModel): """ - DEPRECATED: Placeholder only. The actual model is dynamically generated. + Fallback memory operations model with fault tolerance. Use SchemaModelGenerator.create_structured_operations_model() to get - the actual type-safe implementation with proper union types for write_uris - and edit_uris. + the actual type-safe implementation with per-memory_type fields. """ reasoning: str = Field( @@ -128,6 +239,15 @@ def is_empty(self) -> bool: and len(self.delete_uris) == 0 ) + def to_legacy_operations(self) -> Dict[str, Any]: + """Convert to legacy format (identity for fallback).""" + return { + "write_uris": self.write_uris, + "edit_uris": self.edit_uris, + "edit_overview_uris": self.edit_overview_uris, + "delete_uris": self.delete_uris, + } + model_config = {'extra': 'ignore'} diff --git a/openviking/session/memory/extract_loop.py b/openviking/session/memory/extract_loop.py new file mode 100644 index 000000000..1965747a3 --- /dev/null +++ b/openviking/session/memory/extract_loop.py @@ -0,0 +1,475 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +""" +Simplified ReAct orchestrator for memory updates - single LLM call with tool use. + +Reference: bot/vikingbot/agent/loop.py AgentLoop structure +""" + +import asyncio +import json +from typing import Any, Dict, List, Optional, Set, Tuple + +from openviking.message import Message +from openviking.models.vlm.base import VLMBase +from openviking.server.identity import RequestContext +from openviking.session.memory.dataclass import MemoryOperations +from openviking.session.memory.memory_type_registry import MemoryTypeRegistry +from openviking.session.memory.schema_model_generator import ( + SchemaModelGenerator, + SchemaPromptGenerator, +) +from openviking.session.memory.tools import ( + add_tool_call_pair_to_messages, + get_tool, + get_tool_schemas, + MEMORY_TOOLS_REGISTRY, +) +from openviking.session.memory.utils import ( + parse_json_with_stability, + parse_memory_file_with_fields, + pretty_print_messages, + truncate_content, + validate_operations_uris, +) +from openviking.storage.viking_fs import VikingFS, get_viking_fs +from openviking_cli.utils import get_logger +from openviking_cli.utils.config import get_openviking_config + +logger = get_logger(__name__) + + + +class ExtractLoop: + """ + Simplified ReAct orchestrator for memory updates. + + Workflow: + 0. Pre-fetch: System performs ls + read .overview.md + search (via strategy) + 1. LLM call with tools: Model decides to either use tools OR output final operations + 2. If tools used: Execute and continue loop + 3. If operations output: Return and finish + """ + + def __init__( + self, + vlm: VLMBase, + viking_fs: Optional[VikingFS] = None, + model: Optional[str] = None, + max_iterations: int = 3, + ctx: Optional[RequestContext] = None, + context_provider: Optional[Any] = None, # ExtractContextProvider + ): + """ + Initialize the ExtractLoop. + + Args: + vlm: VLM instance (from openviking.models.vlm.base) + viking_fs: VikingFS instance for storage operations + model: Model name to use + max_iterations: Maximum number of ReAct iterations (default: 5) + ctx: Request context + context_provider: ExtractContextProvider - 必须提供(由 provider 加载 schema) + """ + self.vlm = vlm + self.viking_fs = viking_fs or get_viking_fs() + self.model = model or self.vlm.model + self.max_iterations = max_iterations + self.ctx = ctx + self.context_provider = context_provider + + # Schema 生成器(在 run() 中初始化) + self.schema_model_generator = None + self.schema_prompt_generator = None + self._json_schema = None + + # 预计算:避免每次迭代重复计算 + self._tool_schemas: Optional[List[Dict[str, Any]]] = None + self._expected_fields: Optional[List[str]] = None + self._operations_model: Optional[Any] = None + + # Track files read during ReAct for refetch detection + self._read_files: Set[str] = set() + # Transaction handle for file locking + self._transaction_handle = None + + + + async def run(self) -> Tuple[Optional[MemoryOperations], List[Dict[str, Any]]]: + """ + Run the simplified ReAct loop for memory updates. + + Returns: + Tuple of (final MemoryOperations, tools_used list) + """ + iteration = 0 + max_iterations = self.max_iterations + final_operations = None + tools_used: List[Dict[str, Any]] = [] + + # 从 provider 获取 schemas(内部自动加载 registry) + schemas = self.context_provider.get_memory_schemas(self.ctx) + + # 初始化 schema 生成器(使用 schemas 而非 registry) + self.schema_model_generator = SchemaModelGenerator(schemas) + self.schema_prompt_generator = SchemaPromptGenerator(schemas) + self.schema_model_generator.generate_all_models() + self._json_schema = self.schema_model_generator.get_llm_json_schema() + + # 预计算工具 schemas + allowed_tools = self.context_provider.get_tools() + self._tool_schemas = [tool.to_schema() for tool in MEMORY_TOOLS_REGISTRY.values() if tool.name in allowed_tools] + + # 预计算 expected_fields + self._expected_fields = ['reasoning', 'edit_overview_uris', 'delete_uris'] + for schema in schemas: + self._expected_fields.append(schema.memory_type) + + # 预计算 operations_model + self._operations_model = self.schema_model_generator.create_structured_operations_model() + + + + # Reset read files tracking for this run + self._read_files.clear() + + # Build initial messages from provider + import json + schema_str = json.dumps(self._json_schema, ensure_ascii=False) + + messages = [] + # instruction() 返回字符串,需要包装成 message 格式 + messages.append({ + "role": "system", + "content": self.context_provider.instruction(), + }) + messages.append({ + "role":"system", + "content":f""" +## Output Format +See the complete JSON Schema below: +```json +{schema_str} +``` + """ + }) + + await self._mark_cache_breakpoint(messages) + # Pre-fetch context via provider + tool_call_messages = await self.context_provider.prefetch( + ctx=self.ctx, + viking_fs=self.viking_fs, + transaction_handle=self._transaction_handle, + vlm=self.vlm, + ) + messages.extend(tool_call_messages) + + while iteration < max_iterations: + iteration += 1 + logger.info(f"ReAct iteration {iteration}/{max_iterations}") + + # Check if this is the last iteration - force final result + is_last_iteration = iteration >= max_iterations + + # If last iteration, add a message telling the model to return result directly + if is_last_iteration: + messages.append({ + "role": "user", + "content": "You have reached the maximum number of tool call iterations. Do not call any more tools - return your final result directly now." + }) + + # Call LLM with tools - model decides: tool calls OR final operations + pretty_print_messages(messages) + tool_calls, operations = await self._call_llm(messages, force_final=is_last_iteration) + + + if tool_calls: + await self._execute_tool_calls(messages, tool_calls, tools_used) + await self._mark_cache_breakpoint(messages) + continue + + # If model returned final operations, check if refetch is needed + if operations is not None: + # Check if any write_uris target existing files that weren't read + refetch_uris = await self._check_unread_existing_files(operations) + if refetch_uris: + logger.info(f"Found unread existing files: {refetch_uris}, refetching...") + # Add refetch results to messages and continue loop + await self._add_refetch_results_to_messages(messages, refetch_uris) + # Allow one extra iteration for refetch + if iteration >= max_iterations: + max_iterations += 1 + logger.info(f"Extended max_iterations to {max_iterations} for refetch") + + await self._mark_cache_breakpoint(messages) + continue + + final_operations = operations + break + # If no tool calls either, continue to next iteration (don't break!) + logger.warning(f"LLM returned neither tool calls nor operations (iteration {iteration}/{max_iterations})") + # If it's the last iteration, use empty operations + if is_last_iteration: + final_operations = MemoryOperations() + break + # Otherwise continue and try again + continue + + + if final_operations is None: + if iteration >= max_iterations: + raise RuntimeError(f"Reached {max_iterations} iterations without completion") + else: + raise RuntimeError("ReAct loop completed but no operations generated") + + logger.info(f'final_operations={final_operations.model_dump_json(indent=4)}') + + return final_operations, tools_used + + async def _execute_tool_calls(self, messages, tool_calls, tools_used): + # Execute all tool calls in parallel + async def execute_single_tool_call(idx: int, tool_call): + """Execute a single tool call.""" + result = await self._execute_tool(tool_call) + return idx, tool_call, result + + action_tasks = [ + execute_single_tool_call(idx, tool_call) + for idx, tool_call in enumerate(tool_calls) + ] + results = await self._execute_in_parallel(action_tasks) + + # Process results and add to messages + for _idx, tool_call, result in results: + # Skip if arguments is None + if tool_call.arguments is None: + logger.warning(f"Tool call {tool_call.name} has no arguments, skipping") + continue + + tools_used.append({ + "tool_name": tool_call.name, + "params": tool_call.arguments, + "result": result, + }) + + # Track read tool calls for refetch detection + if tool_call.name == "read" and tool_call.arguments.get("uri"): + self._read_files.add(tool_call.arguments["uri"]) + + add_tool_call_pair_to_messages( + messages, + call_id=tool_call.id, + tool_name=tool_call.name, + params=tool_call.arguments, + result=result, + ) + + def _validate_operations(self, operations: MemoryOperations) -> None: + """ + Validate that all operations have allowed URIs. + + Args: + operations: The MemoryOperations to validate + + Raises: + ValueError: If any operation has a disallowed URI + """ + # Get registry from provider (internal method) + registry = self.context_provider._get_registry() + schemas = self.context_provider.get_memory_schemas(self.ctx) + + is_valid, errors = validate_operations_uris( + operations, + schemas, + registry, + user_space="default", + agent_space="default", + ) + if not is_valid: + error_msg = "Invalid memory operations:\n" + "\n".join(f" - {err}" for err in errors) + logger.error(error_msg) + raise ValueError(error_msg) + + async def _call_llm( + self, + messages: List[Dict[str, Any]], + force_final: bool = False, + ) -> Tuple[Optional[List], Optional[MemoryOperations]]: + """ + Call LLM with tools. Returns either tool calls OR final operations. + + Args: + messages: Message list + force_final: If True, force model to return final result (not tool calls) + + Returns: + Tuple of (tool_calls, operations) - one will be None, the other set + """ + # Call LLM with tools - use tools from strategy + tool_choice = "none" if force_final else None + + response = await self.vlm.get_completion_async( + messages=messages, + tools=self._tool_schemas, + tool_choice=tool_choice, + max_retries=self.vlm.max_retries, + ) + # print(f'response={response}') + # Log cache hit info + if hasattr(response, 'usage') and response.usage: + usage = response.usage + prompt_tokens = usage.get('prompt_tokens', 0) + cached_tokens = usage.get('prompt_tokens_details', {}).get('cached_tokens', 0) if isinstance(usage.get('prompt_tokens_details'), dict) else 0 + if prompt_tokens > 0: + cache_hit_rate = (cached_tokens / prompt_tokens) * 100 + logger.info(f"[KVCache] prompt_tokens={prompt_tokens}, cached_tokens={cached_tokens}, cache_hit_rate={cache_hit_rate:.1f}%") + else: + logger.info(f"[KVCache] prompt_tokens={prompt_tokens}, cached_tokens={cached_tokens}") + + # Case 1: LLM returned tool calls + if response.has_tool_calls: + # Format tool calls nicely for debug logging + for tc in response.tool_calls: + logger.info(f"[assistant tool_call] (id={tc.id}, name={tc.name})") + logger.info(f" {json.dumps(tc.arguments, indent=2, ensure_ascii=False)}") + return (response.tool_calls, None) + + # Case 2: Try to parse MemoryOperations from content with stability + content = response.content or "" + if content: + try: + # print(f'LLM response content: {content}') + logger.debug(f"[assistant]\n{content}") + + # Use cached operations_model and expected_fields + operations, error = parse_json_with_stability( + content=content, + model_class=self._operations_model, + expected_fields=self._expected_fields, + ) + + if error is not None: + print(f'content={content}') + logger.warning(f"Failed to parse memory operations: {error}") + return (None, None) + + # Validate that all URIs are allowed + self._validate_operations(operations) + return (None, operations) + except Exception as e: + print(f'Error parsing operations: {e}') + logger.warning(f"Unexpected error parsing memory operations: {e}") + + # Case 3: No tool calls and no parsable operations + print('No tool calls or operations parsed') + return (None, None) + + async def _execute_tool( + self, + tool_call, + ) -> Any: + """Execute a single read action (read/search/ls/tree).""" + if not self.viking_fs: + return {"error": "VikingFS not available"} + + tool = get_tool(tool_call.name) + if not tool: + return {"error": f"Unknown tool: {tool_call.name}"} + + # 创建 ToolContext + from openviking.server.identity import ToolContext + tool_ctx = ToolContext( + request_ctx=self.ctx, + transaction_handle=self._transaction_handle + ) + + try: + result = await tool.execute(self.viking_fs, tool_ctx, **tool_call.arguments) + return result + except Exception as e: + logger.error(f"Failed to execute {tool_call.name}: {e}") + return {"error": str(e)} + + async def _execute_in_parallel( + self, + tasks: List[Any], + ) -> List[Any]: + """Execute tasks in parallel, similar to AgentLoop.""" + return await asyncio.gather(*tasks) + + async def _check_unread_existing_files( + self, + operations: MemoryOperations, + ) -> List[str]: + """Check if write operations target existing files that weren't read during ReAct.""" + memory_type_fields = getattr(operations, '_memory_type_fields', None) + if not memory_type_fields: + return [] + + from openviking.session.memory.utils.uri import resolve_flat_model_uri + + registry = self.context_provider._get_registry() + refetch_uris = [] + + for field_name in memory_type_fields: + value = getattr(operations, field_name, None) + if value is None: + continue + items = value if isinstance(value, list) else [value] + for item in items: + # Convert to dict + item_dict = dict(item) if hasattr(item, 'model_dump') else dict(item) + try: + uri = resolve_flat_model_uri(item_dict, registry, "default", "default", memory_type=field_name) + except Exception as e: + logger.warning(f"Failed to resolve URI for {item}: {e}") + continue + + if uri in self._read_files: + continue + try: + await self.viking_fs.read_file(uri, ctx=self.ctx) + refetch_uris.append(uri) + except Exception: + pass + return refetch_uris + + async def _add_refetch_results_to_messages( + self, + messages: List[Dict[str, Any]], + refetch_uris: List[str], + ) -> None: + """Add existing file content as read tool results to messages.""" + # Calculate call_id based on existing tool messages + call_id_seq = len([m for m in messages if m.get("role") == "tool"]) + 1000 + + for uri in refetch_uris: + try: + content = await self.viking_fs.read_file(uri, ctx=self.ctx) + parsed = parse_memory_file_with_fields(content) + + # Add as read tool call + result + add_tool_call_pair_to_messages( + messages=messages, + call_id=call_id_seq, + tool_name="read", + params={"uri": uri}, + result=parsed, + ) + call_id_seq += 1 + + # Mark as read + self._read_files.add(uri) + except Exception as e: + logger.warning(f"Failed to refetch {uri}: {e}") + + # Add reminder message for the model + messages.append({ + "role": "user", + "content": "Note: The files above were automatically read because they exist and you didn't read them before deciding to write. Please consider the existing content when making write decisions. You can now output updated operations." + }) + + async def _mark_cache_breakpoint(self, messages): + # 支持 dict 消息和 object 消息 + last_msg = messages[-1] + last_msg["cache_control"] = {"type": "ephemeral"} diff --git a/openviking/session/memory/memory_react.py b/openviking/session/memory/memory_react.py deleted file mode 100644 index 9f85838cc..000000000 --- a/openviking/session/memory/memory_react.py +++ /dev/null @@ -1,665 +0,0 @@ -# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. -# SPDX-License-Identifier: Apache-2.0 -""" -Simplified ReAct orchestrator for memory updates - single LLM call with tool use. - -Reference: bot/vikingbot/agent/loop.py AgentLoop structure -""" - -import asyncio -import json -from enum import Enum -from typing import Any, Dict, List, Optional, Set, Tuple - -from pydantic import BaseModel, Field - -from openviking.models.vlm.base import VLMBase, VLMResponse -from openviking.server.identity import RequestContext -from openviking.session.memory.utils import ( - collect_allowed_directories, - detect_language_from_conversation, - extract_json_from_markdown, - parse_json_with_stability, - parse_memory_file_with_fields, - pretty_print_messages, - validate_operations_uris, -) -from openviking.session.memory.dataclass import MemoryOperations -from openviking.session.memory.memory_type_registry import MemoryTypeRegistry -from openviking.session.memory.schema_model_generator import ( - SchemaModelGenerator, - SchemaPromptGenerator, -) -from openviking.session.memory.tools import ( - get_tool, - get_tool_schemas, - add_tool_call_pair_to_messages, -) -from openviking.storage.viking_fs import VikingFS, get_viking_fs -from openviking_cli.utils import get_logger -from openviking_cli.utils.config import get_openviking_config - -logger = get_logger(__name__) - - - -class MemoryReAct: - """ - Simplified ReAct orchestrator for memory updates. - - Workflow: - 0. Pre-fetch: System performs ls + read .overview.md + search - 1. LLM call with tools: Model decides to either use tools OR output final operations - 2. If tools used: Execute and continue loop - 3. If operations output: Return and finish - """ - - def __init__( - self, - vlm: VLMBase, - viking_fs: Optional[VikingFS] = None, - model: Optional[str] = None, - max_iterations: int = 5, - ctx: Optional[RequestContext] = None, - registry: Optional[MemoryTypeRegistry] = None, - ): - """ - Initialize the MemoryReAct. - - Args: - vlm: VLM instance (from openviking.models.vlm.base) - viking_fs: VikingFS instance for storage operations - model: Model name to use - max_iterations: Maximum number of ReAct iterations (default: 5) - ctx: Request context - registry: Optional MemoryTypeRegistry - if not provided, will be created - """ - self.vlm = vlm - self.viking_fs = viking_fs or get_viking_fs() - self.model = model or self.vlm.model - self.max_iterations = max_iterations - self.ctx = ctx - - # Initialize schema registry and generators - if registry is not None: - self.registry = registry - else: - import os - schemas_dir = os.path.join(os.path.dirname(__file__), "..", "..", "prompts", "templates", "memory") - self.registry = MemoryTypeRegistry() - self.registry.load_from_directory(schemas_dir) - self.schema_model_generator = SchemaModelGenerator(self.registry) - self.schema_prompt_generator = SchemaPromptGenerator(self.registry) - - # Pre-generate models and JSON schema - self.schema_model_generator.generate_all_models() - self._json_schema = self.schema_model_generator.get_llm_json_schema() - - # Track files read during ReAct for refetch detection - self._read_files: Set[str] = set() - self._output_language: str = "en" - - async def _pre_fetch_context(self, conversation: str) -> Dict[str, Any]: - """ - Pre-fetch context based on activated schemas. - - Optimized logic: - - For multi-file schemas (filename_template has variables): ls the directory - - For single-file schemas (filename_template no variables): directly read the file - - No longer ls the root memories directory - - Args: - conversation: Conversation history for search query - - Returns: - Pre-fetched context with directories, summaries, and search_results - """ - from openviking.session.memory.tools import get_tool - messages = [] - - # Step 1: Separate schemas into multi-file (ls) and single-file (direct read) - ls_dirs = set() # directories to ls (for multi-file schemas) - read_files = set() # files to read directly (for single-file schemas) - - for schema in self.registry.list_all(include_disabled=False): - if not schema.directory: - continue - - # Replace variables in directory path with actual user/agent space - user_space = self.ctx.user.user_space_name() if self.ctx and self.ctx.user else "default" - agent_space = self.ctx.user.agent_space_name() if self.ctx and self.ctx.user else "default" - dir_path = schema.directory.replace("{user_space}", user_space).replace("{agent_space}", agent_space) - - # Check if filename_template has variables (contains {xxx}) - has_variables = False - if schema.filename_template: - has_variables = "{" in schema.filename_template and "}" in schema.filename_template - - if has_variables or not schema.filename_template: - # Multi-file schema or no filename template: ls the directory - ls_dirs.add(dir_path) - else: - # Single-file schema: directly read the specific file - file_uri = f"{dir_path}/{schema.filename_template}" - read_files.add(file_uri) - - call_id_seq = 0 - # Step 2: Execute ls for multi-file schema directories in parallel - ls_tool = get_tool("ls") - read_tool = get_tool("read") - if ls_tool and self.viking_fs and ls_dirs: - for dir_uri in ls_dirs: - try: - result_str = await ls_tool.execute(self.viking_fs, self.ctx, uri=dir_uri) - add_tool_call_pair_to_messages( - messages=messages, - call_id=call_id_seq, - tool_name='ls', - params={ - "uri": dir_uri - }, - result=result_str - ) - call_id_seq += 1 - - result_str = await read_tool.execute(self.viking_fs, self.ctx, uri=f'{dir_uri}/.overview.md') - - add_tool_call_pair_to_messages( - messages=messages, - call_id=call_id_seq, - tool_name='read', - params={ - "uri": f'{dir_uri}/.overview.md' - }, - result=result_str - ) - call_id_seq += 1 - - except Exception as e: - logger.warning(f"Failed to ls {dir_uri}: {e}") - - # Step 3: Search for relevant memories based on user messages in conversation - search_tool = get_tool("search") - if search_tool and self.viking_fs and self.ctx: - try: - # Extract only user messages from conversation - user_messages = [] - for line in conversation.split("\n"): - if line.startswith("[user]:"): - user_messages.append(line[len("[user]:"):].strip()) - user_query = " ".join(user_messages) - - if user_query: - search_result = await search_tool.execute( - viking_fs=self.viking_fs, - ctx=self.ctx, - query=user_query, - ) - if search_result and not search_result.get("error"): - add_tool_call_pair_to_messages( - messages=messages, - call_id=call_id_seq, - tool_name='search', - params={"query": user_query}, - result=str(search_result) - ) - call_id_seq += 1 - except Exception as e: - logger.warning(f"Pre-fetch search failed: {e}") - - return messages - - - async def run( - self, - conversation: str, - ) -> Tuple[Optional[MemoryOperations], List[Dict[str, Any]]]: - """ - Run the simplified ReAct loop for memory updates. - - Args: - conversation: Conversation history - - Returns: - Tuple of (final MemoryOperations, tools_used list) - """ - iteration = 0 - final_operations = None - tools_used: List[Dict[str, Any]] = [] - - # Detect output language from conversation - config = get_openviking_config() - fallback_language = (config.language_fallback or "en").strip() or "en" - self._output_language = detect_language_from_conversation( - conversation, fallback_language=fallback_language - ) - logger.info(f"Detected output language for memory ReAct: {self._output_language}") - - # Pre-fetch context internally - tool_call_messages = await self._pre_fetch_context(conversation) - - # Reset read files tracking for this run - self._read_files.clear() - - messages = self._build_initial_messages(conversation, tool_call_messages, self._output_language) - - while iteration < self.max_iterations: - iteration += 1 - logger.debug(f"ReAct iteration {iteration}/{self.max_iterations}") - - # Check if this is the last iteration - force final result - is_last_iteration = iteration >= self.max_iterations - - # If last iteration, add a message telling the model to return result directly - if is_last_iteration: - messages.append({ - "role": "user", - "content": "You have reached the maximum number of tool call iterations. Do not call any more tools - return your final result directly now." - }) - - # Call LLM with tools - model decides: tool calls OR final operations - tool_calls, operations = await self._call_llm(messages, force_final=is_last_iteration) - - # If model returned final operations, check if refetch is needed - if operations is not None: - # Check if any write_uris target existing files that weren't read - refetch_uris = await self._check_unread_existing_files(operations) - if refetch_uris: - logger.info(f"Found unread existing files: {refetch_uris}, refetching...") - # Add refetch results to messages and continue loop - await self._add_refetch_results_to_messages(messages, refetch_uris) - # Clear operations to force another iteration - operations = None - # Continue to next iteration - continue - - final_operations = operations - break - - # If no tool calls either, continue to next iteration (don't break!) - if not tool_calls: - logger.warning(f"LLM returned neither tool calls nor operations (iteration {iteration}/{self.max_iterations})") - # If it's the last iteration, use empty operations - if is_last_iteration: - final_operations = MemoryOperations() - break - # Otherwise continue and try again - continue - - # Execute all tool calls in parallel - async def execute_single_tool_call(idx: int, tool_call): - """Execute a single tool call.""" - result = await self._execute_tool(tool_call) - return idx, tool_call, result - - action_tasks = [ - execute_single_tool_call(idx, tool_call) - for idx, tool_call in enumerate(tool_calls) - ] - results = await self._execute_in_parallel(action_tasks) - - # Process results and add to messages - for _idx, tool_call, result in results: - tools_used.append({ - "tool_name": tool_call.name, - "params": tool_call.arguments, - "result": result, - }) - - # Track read tool calls for refetch detection - if tool_call.name == "read" and tool_call.arguments.get("uri"): - self._read_files.add(tool_call.arguments["uri"]) - - add_tool_call_pair_to_messages( - messages, - call_id=tool_call.id, - tool_name=tool_call.name, - params=tool_call.arguments, - result=result, - ) - # Print updated messages with tool results - pretty_print_messages(messages) - if final_operations is None: - if iteration >= self.max_iterations: - raise RuntimeError(f"Reached {self.max_iterations} iterations without completion") - else: - raise RuntimeError("ReAct loop completed but no operations generated") - - logger.info(f'final_operations={final_operations.model_dump_json(indent=4)}') - - return final_operations, tools_used - - def _build_initial_messages( - self, - conversation: str, - tool_call_messages: List, - output_language: str, - ) -> List[Dict[str, Any]]: - """Build initial messages from conversation and pre-fetched context.""" - system_prompt = self._get_system_prompt(output_language) - messages = [ - { - "role": "system", - "content": system_prompt, - } - ] - - # Add pre-fetched context as tool calls - messages.extend(tool_call_messages) - messages.append({ - "role": "user", - "content": f"""## Conversation History -{conversation} - -After exploring, analyze the conversation and output ALL memory write/edit/delete operations in a single response. Do not output operations one at a time - gather all changes first, then return them together.""", - }) - # Print messages in a readable format - pretty_print_messages(messages) - - return messages - - - def _get_allowed_directories_list(self) -> str: - """Get a formatted list of allowed directories for the system prompt.""" - user_space = self.ctx.user.user_space_name() if self.ctx and self.ctx.user else "default" - agent_space = self.ctx.user.agent_space_name() if self.ctx and self.ctx.user else "default" - allowed_dirs = collect_allowed_directories( - self.registry.list_all(include_disabled=False), - user_space=user_space, - agent_space=agent_space, - ) - if not allowed_dirs: - return "No directories configured (this is an error)." - return "\n".join(f"- {dir_path}" for dir_path in sorted(allowed_dirs)) - - def _get_system_prompt(self, output_language: str) -> str: - """Get the simplified system prompt.""" - import json - schema_str = json.dumps(self._json_schema, ensure_ascii=False) - allowed_dirs_list = self._get_allowed_directories_list() - - return f"""You are a memory extraction agent. Your task is to analyze conversations and update memories. - -## Workflow -1. Analyze the conversation and pre-fetched context -2. If you need more information, use the available tools (read/search) -3. When you have enough information, output ONLY a JSON object (no extra text before or after) - -## CRITICAL: Available Tools -- ONLY read and search tools are available -- DO NOT use write tool - just output the JSON result, the system will handle writing -- ls tool is NOT available - -## Critical: Read Before Edit -IMPORTANT: Before you edit or update ANY existing memory file, you MUST first use the read tool to read its complete content. - -- The pre-fetched .overview.md files are only partial information - they are NOT the complete memory content -- You MUST use the read tool to get the actual content of any file you want to edit -- Without reading the actual file first, your edit operations will fail because the search string won't match - -## Target Output Language -All memory content (abstract, overview, content fields) MUST be written in {output_language}. - -## URI Handling (Automatic) -IMPORTANT: You do NOT need to construct URIs manually. The system will automatically generate URIs based on: -- For write_uris: Using memory_type and fields -- For edit_uris: Using memory_type and fields to identify the target -- For edit_overview_uris: Using memory_type to identify the directory, then updates the .overview.md file in that directory -- For delete_uris: Using memory_type and fields to identify the target - -Just provide the correct memory_type and fields, and the system will handle the rest. - -## Edit Overview Files (IMPORTANT - Don't Forget!) -You MUST use edit_overview_uris to update the .overview.md file whenever you write new memories. - -This is a REQUIRED step after writing memories: -1. After adding new entries via write_uris, ALWAYS also update the corresponding .overview.md -2. The .overview.md provides a high-level summary for that memory type directory -3. Without updating overview, new memories won't be visible in high-level summaries - -Example workflow: -- write_uris: Add new skill "Python async programming" → writes to skills/python_async.md -- edit_overview_uris: {{"memory_type": "skills", "overview": "Python async programming, Go concurrency, System design..."}} - -How to use edit_overview_uris: -- Provide memory_type to identify which directory's overview to update -- Provide overview field with the new content (string or patch format) -- Example: {{"memory_type": "profile", "overview": "User profile overview..."}} - -## Overview Format Requirements (IMPORTANT) -When generating overview content for edit_overview_uris, you MUST follow this structure: - -1. **Title (H1)**: Directory name (e.g., "# skills") -2. **Brief Description (plain text paragraph, 50-150 words)**: - - Immediately following the title, without any H2 heading - - Explain what this directory is about - - Include core keywords for easy searching -3. **Quick Navigation (H2)**: Decision Tree style - - Use "What do you want to learn?" or "What do you want to do?" - - Use markdown links with relative paths: [description](./filename.md) -4. **Detailed Description (H2)**: One H3 subsection for each file - -Example: -# skills - -Python async programming, Go concurrency, and System design skills for backend developers. - -## Quick Navigation -- Want to learn async programming → [Python Async](./python_async.md) -- Want to learn concurrency → [Go Concurrency](./go_concurrency.md) - -## Detailed Description -### Python Async -... - -Total length: 400-800 words - -## Final Output Format -Outputs will be a complete JSON object with the following fields (Don't have '```json' appear and do not use '//' to omit content) - -JSON schema: -```json -{schema_str} -``` - -## Important Notes -- DO NOT use write tool - the system will write memories based on your JSON output -- Only read and search tools are available for you to use -- Output ONLY the JSON object - no extra text before or after -- Put your thinking and reasoning in the `reasonning` field of the JSON -""" - - def _validate_operations(self, operations: MemoryOperations) -> None: - """ - Validate that all operations have allowed URIs. - - Args: - operations: The MemoryOperations to validate - - Raises: - ValueError: If any operation has a disallowed URI - """ - is_valid, errors = validate_operations_uris( - operations, - self.registry.list_all(include_disabled=False), - self.registry, - user_space="default", - agent_space="default", - ) - if not is_valid: - error_msg = "Invalid memory operations:\n" + "\n".join(f" - {err}" for err in errors) - logger.error(error_msg) - raise ValueError(error_msg) - - async def _call_llm( - self, - messages: List[Dict[str, Any]], - force_final: bool = False, - ) -> Tuple[Optional[List], Optional[MemoryOperations]]: - """ - Call LLM with tools. Returns either tool calls OR final operations. - - Args: - messages: Message list - force_final: If True, force model to return final result (not tool calls) - - Returns: - Tuple of (tool_calls, operations) - one will be None, the other set - """ - # Call LLM with tools - tool_choice = "none" if force_final else None - response = await self.vlm.get_completion_async( - messages=messages, - tools=get_tool_schemas(), - tool_choice=tool_choice, - max_retries=self.vlm.max_retries, - ) - - # Log cache hit info - if hasattr(response, 'usage') and response.usage: - usage = response.usage - prompt_tokens = usage.get('prompt_tokens', 0) - cached_tokens = usage.get('prompt_tokens_details', {}).get('cached_tokens', 0) if isinstance(usage.get('prompt_tokens_details'), dict) else 0 - if prompt_tokens > 0: - cache_hit_rate = (cached_tokens / prompt_tokens) * 100 - logger.info(f"[KVCache] prompt_tokens={prompt_tokens}, cached_tokens={cached_tokens}, cache_hit_rate={cache_hit_rate:.1f}%") - else: - logger.info(f"[KVCache] prompt_tokens={prompt_tokens}, cached_tokens={cached_tokens}") - - # Case 1: LLM returned tool calls - if response.has_tool_calls: - # Format tool calls nicely for debug logging - for tc in response.tool_calls: - logger.info(f"[assistant tool_call] (id={tc.id}, name={tc.name})") - logger.info(f" {json.dumps(tc.arguments, indent=2, ensure_ascii=False)}") - return (response.tool_calls, None) - - # Case 2: Try to parse MemoryOperations from content with stability - content = response.content or "" - if content: - try: - logger.debug(f"[assistant]\n{content}") - # Get the dynamically generated operations model for better type safety - operations_model = self.schema_model_generator.create_structured_operations_model() - - # Use five-layer stable JSON parsing - operations, error = parse_json_with_stability( - content=content, - model_class=operations_model, - expected_fields=['reasoning', 'write_uris', 'edit_uris', 'edit_overview_uris', 'delete_uris'], - ) - - if error is not None: - logger.warning(f"Failed to parse memory operations (stable parse): {error}") - # Fallback: try with base MemoryOperations - content_no_md = extract_json_from_markdown(content) - operations, error_fallback = parse_json_with_stability( - content=content_no_md, - model_class=MemoryOperations, - expected_fields=['reasoning', 'write_uris', 'edit_uris', 'edit_overview_uris', 'delete_uris'], - ) - if error_fallback is not None: - logger.warning(f"Fallback parse also failed: {error_fallback}") - return (None, None) - - # Validate that all URIs are allowed - self._validate_operations(operations) - return (None, operations) - except Exception as e: - logger.warning(f"Unexpected error parsing memory operations: {e}") - - # Case 3: No tool calls and no parsable operations - return (None, None) - - async def _execute_tool( - self, - tool_call, - ) -> Any: - """Execute a single read action (read/search/ls/tree).""" - if not self.viking_fs: - return {"error": "VikingFS not available"} - - tool = get_tool(tool_call.name) - if not tool: - return {"error": f"Unknown tool: {tool_call.name}"} - - try: - result = await tool.execute(self.viking_fs, self.ctx, **tool_call.arguments) - return result - except Exception as e: - logger.error(f"Failed to execute {tool_call.name}: {e}") - return {"error": str(e)} - - async def _execute_in_parallel( - self, - tasks: List[Any], - ) -> List[Any]: - """Execute tasks in parallel, similar to AgentLoop.""" - return await asyncio.gather(*tasks) - - async def _check_unread_existing_files( - self, - operations: MemoryOperations, - ) -> List[str]: - """Check if write_uris target existing files that weren't read during ReAct.""" - if not operations.write_uris: - return [] - - from openviking.session.memory.utils.uri import resolve_flat_model_uri - - refetch_uris = [] - for op in operations.write_uris: - # Resolve the flat model to URI - try: - uri = resolve_flat_model_uri(op, self.registry, "default", "default") - except Exception as e: - logger.warning(f"Failed to resolve URI for {op}: {e}") - continue - - # Skip if already read - if uri in self._read_files: - continue - # Check if file exists - try: - await self.viking_fs.read_file(uri, ctx=self.ctx) - # File exists and wasn't read - need refetch - refetch_uris.append(uri) - except Exception: - # File doesn't exist, no need to refetch - pass - return refetch_uris - - async def _add_refetch_results_to_messages( - self, - messages: List[Dict[str, Any]], - refetch_uris: List[str], - ) -> None: - """Add existing file content as read tool results to messages.""" - # Calculate call_id based on existing tool messages - call_id_seq = len([m for m in messages if m.get("role") == "tool"]) + 1000 - - for uri in refetch_uris: - try: - content = await self.viking_fs.read_file(uri, ctx=self.ctx) - parsed = parse_memory_file_with_fields(content) - - # Add as read tool call + result - add_tool_call_pair_to_messages( - messages=messages, - call_id=call_id_seq, - tool_name="read", - params={"uri": uri}, - result=parsed, - ) - call_id_seq += 1 - - # Mark as read - self._read_files.add(uri) - except Exception as e: - logger.warning(f"Failed to refetch {uri}: {e}") - - # Add reminder message for the model - messages.append({ - "role": "user", - "content": "Note: The files above were automatically read because they exist and you didn't read them before deciding to write. Please consider the existing content when making write decisions. You can now output updated operations." - }) diff --git a/openviking/session/memory/memory_type_registry.py b/openviking/session/memory/memory_type_registry.py index 77f17c869..ef1b7afd3 100644 --- a/openviking/session/memory/memory_type_registry.py +++ b/openviking/session/memory/memory_type_registry.py @@ -63,6 +63,23 @@ def list_names(self, include_disabled: bool = False) -> List[str]: return list(self._types.keys()) return [mt.memory_type for mt in self._types.values() if mt.enabled] + def list_search_uris(self, user_space: str, agent_space: str) -> List[str]: + """List all directory URIs for search scope. + + Args: + user_space: User space name + agent_space: Agent space name + + Returns: + List of directory URIs from enabled schemas + """ + uris = [] + for schema in self.list_all(include_disabled=False): + if schema.directory: + dir_path = schema.directory.replace("{user_space}", user_space).replace("{agent_space}", agent_space) + uris.append(dir_path) + return uris + def load_from_yaml(self, yaml_path: str) -> None: """ Load memory type from a YAML file. @@ -131,6 +148,7 @@ def _parse_memory_type(self, data: dict) -> MemoryTypeSchema: content_template=data.get("content_template"), directory=data.get("directory", ""), enabled=data.get("enabled", data.get("enable", True)), + operation_mode=data.get("operation_mode", "upsert"), ) diff --git a/openviking/session/memory/memory_updater.py b/openviking/session/memory/memory_updater.py index 005b42f51..16dec5805 100644 --- a/openviking/session/memory/memory_updater.py +++ b/openviking/session/memory/memory_updater.py @@ -7,20 +7,22 @@ to the storage system. """ -from datetime import datetime from typing import Any, Dict, List, Optional, Tuple +from openviking.message import Message from openviking.server.identity import RequestContext +from openviking.session.memory.dataclass import MemoryField +from openviking.session.memory.memory_type_registry import MemoryTypeRegistry +from openviking.session.memory.merge_op import MergeOpFactory, PatchOp +from openviking.session.memory.merge_op.base import FieldType, SearchReplaceBlock, StrPatch from openviking.session.memory.utils import ( deserialize_full, - serialize_with_metadata, - resolve_all_operations, flat_model_to_dict, + parse_memory_file_with_fields, + resolve_all_operations, + serialize_with_metadata, ) -from openviking.session.memory.dataclass import MemoryField -from openviking.session.memory.merge_op import MergeOpFactory, PatchOp -from openviking.session.memory.merge_op.base import FieldType, SearchReplaceBlock, StrPatch -from openviking.session.memory.memory_type_registry import MemoryTypeRegistry +from openviking.session.memory.utils.uri import ResolvedOperation from openviking.storage.viking_fs import get_viking_fs from openviking_cli.exceptions import NotFoundError from openviking_cli.utils import get_logger @@ -28,6 +30,85 @@ logger = get_logger(__name__) +class ExtractContext: + """Extract context for template rendering.""" + def __init__(self, messages: List[Message]): + self.messages = messages + + def read_message_ranges(self, ranges_str: str) -> "MessageRange": + """Parse ranges string like "0-10,50-60" or "7,9,11,13" and return combined MessageRange. + + If there's a gap between ranges (e.g., 0-10 and 50-60), add "..." as separator. + Supports: + - "0-10,50-60" - ranges + - "7,9,11,13" - single indices + - "0-10,15,20-25" - mixed + """ + if not ranges_str: + return MessageRange([]) + + # 解析所有范围/索引 + ranges = [] + for part in ranges_str.split(','): + part = part.strip() + if not part: + continue + if '-' in part: + start, end = part.split('-') + ranges.append((int(start), int(end))) + else: + # 单个索引转为相同起止范围 + idx = int(part) + ranges.append((idx, idx)) + + if not ranges: + return MessageRange([]) + + # 按 start 排序 + ranges.sort(key=lambda x: x[0]) + + # elements 可以是 Message 或 str ("...") + elements: List[Message | str] = [] + for i, (start, end) in enumerate(ranges): + if start < 0 or end >= len(self.messages): + continue + range_msgs = self.messages[start:end + 1] + + if i > 0: + prev_end = ranges[i - 1][1] + # 如果有间隔,加 ... + if start > prev_end + 1: + elements.append("...") + elements.extend(range_msgs) + + return MessageRange(elements) + + +class MessageRange: + """Represents a range of messages for formatting.""" + def __init__(self, elements: List[Message | str]): + self.elements = elements + + def pretty_print(self) -> str: + """Pretty print the message range.""" + result = [] + for elem in self.elements: + if isinstance(elem, str): + result.append(elem) + else: + result.append(f"[{elem.role}]: {elem.content}") + return "\n".join(result) + + def first_message_time(self) -> str | None: + """获取第一条消息的时间(YAML 日期格式),如果没有消息则返回 None""" + for elem in self.elements: + if isinstance(elem, str): + continue + if hasattr(elem, 'created_at') and elem.created_at: + return elem.created_at.strftime("%Y-%m-%d") + return None + + class MemoryUpdateResult: """Result of memory update operation.""" @@ -75,10 +156,11 @@ class MemoryUpdater: No function calls are used for write/edit/delete - these are executed directly. """ - def __init__(self, registry: Optional[MemoryTypeRegistry] = None, vikingdb=None): + def __init__(self, registry: Optional[MemoryTypeRegistry] = None, vikingdb=None, transaction_handle=None): self._viking_fs = None self._registry = registry self._vikingdb = vikingdb + self._transaction_handle = transaction_handle def set_registry(self, registry: MemoryTypeRegistry) -> None: """Set the memory type registry for URI resolution.""" @@ -95,6 +177,7 @@ async def apply_operations( operations: Any, ctx: RequestContext, registry: Optional[MemoryTypeRegistry] = None, + extract_context: Any = None, ) -> MemoryUpdateResult: """ Apply MemoryOperations directly using the flat model format. @@ -139,22 +222,30 @@ async def apply_operations( return result # Apply write operations - for op, uri in resolved_ops.write_operations: + for resolved_op in resolved_ops.write_operations: try: - await self._apply_write(op, uri, ctx) - result.add_written(uri) + await self._apply_write( + resolved_op.model, + resolved_op.uri, + ctx, + extract_context=extract_context, + memory_type=resolved_op.memory_type, + ) + result.add_written(resolved_op.uri) except Exception as e: - logger.error(f"Failed to write memory: {e}") - result.add_error(uri, e) + logger.info(f"Failed to write memory: {e}, op={resolved_op.model}, op type={type(resolved_op.model)}") + if hasattr(resolved_op.model, 'model_dump'): + logger.info(f"Op dump: {resolved_op.model.model_dump()}") + result.add_error(resolved_op.uri, e) # Apply edit operations - for op, uri in resolved_ops.edit_operations: + for resolved_op in resolved_ops.edit_operations: try: - await self._apply_edit(op, uri, ctx) - result.add_edited(uri) + await self._apply_edit(resolved_op.model, resolved_op.uri, ctx) + result.add_edited(resolved_op.uri) except Exception as e: - logger.error(f"Failed to edit memory {uri}: {e}") - result.add_error(uri, e) + logger.error(f"Failed to edit memory {resolved_op.uri}: {e}") + result.add_error(resolved_op.uri, e) # Apply edit_overview operations for op, uri in resolved_ops.edit_overview_operations: @@ -180,23 +271,19 @@ async def apply_operations( logger.info(f"Memory operations applied: {result.summary()}") return result - async def _apply_write(self, flat_model: Any, uri: str, ctx: RequestContext) -> None: + async def _apply_write(self, flat_model: Any, uri: str, ctx: RequestContext, extract_context: Any = None, memory_type: str = None) -> None: """Apply write operation from a flat model.""" viking_fs = self._get_viking_fs() # Convert model to dict model_dict = flat_model_to_dict(flat_model) - # Set timestamps if not provided - now = datetime.utcnow() - created_at = model_dict.get("created_at", now) - updated_at = model_dict.get("updated_at", now) - # Extract content - priority: model_dict["content"] content = model_dict.pop("content", None) or "" - # Get memory type schema to know which fields are business fields vs metadata - memory_type_str = model_dict.get("memory_type") + # Get memory type schema - use passed memory_type first, then fallback to model_dict + memory_type_str = memory_type or model_dict.get("memory_type") + field_schema_map: Dict[str, MemoryField] = {} business_fields: Dict[str, Any] = {} @@ -209,6 +296,16 @@ async def _apply_write(self, flat_model: Any, uri: str, ctx: RequestContext) -> if field_name in model_dict: business_fields[field_name] = model_dict[field_name] + # 模板渲染逻辑 + if schema.content_template: + try: + rendered_content = self._render_content_template(schema.content_template, business_fields, extract_context=extract_context) + if rendered_content: + content = rendered_content + except Exception as e: + logger.warning(f"Failed to render content template for memory type {memory_type_str}: {e}") + # 渲染失败时保留原始 content,确保写入操作继续进行 + # Collect metadata - only include business fields (from schema, except content) metadata = business_fields.copy() @@ -220,23 +317,71 @@ async def _apply_write(self, flat_model: Any, uri: str, ctx: RequestContext) -> await viking_fs.write_file(uri, full_content, ctx=ctx) logger.debug(f"Written memory: {uri}") + def _render_content_template(self, template: str, fields: Dict[str, Any], extract_context: Any = None) -> str: + """ + Render content template using Jinja2 template engine. + + Args: + template: The content template string with placeholders + fields: Dictionary of field values to use for substitution + extract_context: Extract context for message extraction + + Returns: + Rendered template string + + Raises: + Exception: If template rendering fails + """ + try: + # 导入 Jinja2(延迟导入以避免循环依赖) + import jinja2 + from jinja2 import Environment + + # 创建 Jinja2 环境,允许未定义的变量(打印警告但不报错) + env = Environment(autoescape=False, undefined=jinja2.DebugUndefined) + + # 创建模板变量 + template_vars = fields.copy() + # 始终传入 extract_context,即使是 None,避免模板中访问时 undefined + template_vars["extract_context"] = extract_context + + # 渲染模板 + jinja_template = env.from_string(template) + return jinja_template.render(**template_vars).strip() + except Exception as e: + logger.error(f"Template rendering failed: {e}") + raise + + def _is_patch_format(self, content: Any) -> bool: + """Check if content is a patch format (StrPatch), not a complete replacement.""" + from openviking.session.memory.merge_op.patch import StrPatch + return isinstance(content, StrPatch) + async def _apply_edit(self, flat_model: Any, uri: str, ctx: RequestContext) -> None: """Apply edit operation from a flat model.""" viking_fs = self._get_viking_fs() + # Convert flat model to dict first (needed for checking content type) + model_dict = flat_model_to_dict(flat_model) + # Read current memory try: current_full_content = await viking_fs.read_file(uri, ctx=ctx) or "" except NotFoundError: + # If memory doesn't exist, check if any field is a StrPatch + # If no StrPatch fields, treat as write operation + has_str_patch = any(self._is_patch_format(v) for v in model_dict.values()) + if not has_str_patch: + logger.debug(f"Memory not found for edit, treating as write: {uri}") + await self._apply_write(flat_model, uri, ctx) + return + # Has StrPatch field but file doesn't exist - cannot apply logger.warning(f"Memory not found for edit: {uri}") return # Deserialize content and metadata current_plain_content, current_metadata = deserialize_full(current_full_content) - # Convert flat model to dict - model_dict = flat_model_to_dict(flat_model) - # Get memory type schema memory_type_str = model_dict.get("memory_type") or current_metadata.get("memory_type") field_schema_map: Dict[str, MemoryField] = {} @@ -332,11 +477,14 @@ async def _apply_edit_overview(self, overview_model: Any, uri: str, ctx: Request new_overview = current_overview if overview_value is None: # No overview provided, nothing to do - logger.debug(f"No overview value provided, skipping edit") + logger.debug("No overview value provided, skipping edit") return elif isinstance(overview_value, str): - # Direct string - replace - new_overview = overview_value + # 空字符串保持原值 + if overview_value == "": + new_overview = current_overview + else: + new_overview = overview_value elif isinstance(overview_value, dict): # Dict format - convert to StrPatch if needed if 'blocks' in overview_value: @@ -367,7 +515,12 @@ async def _apply_edit_overview(self, overview_model: Any, uri: str, ctx: Request def _extract_abstract_from_overview(self, overview_content: str) -> str: """Extract abstract from overview.md - same logic as SemanticProcessor.""" - lines = overview_content.split("\n") + # Use parse_memory_file_with_fields to strip MEMORY_FIELDS comment + parsed = parse_memory_file_with_fields(overview_content) + content = parsed.get("content", "") + + # Then extract abstract from the cleaned content + lines = content.split("\n") # Skip header lines (starting with #) content_lines = [] @@ -477,8 +630,9 @@ async def _vectorize_memories( # Read the memory file to get content content = await viking_fs.read_file(uri, ctx=ctx) or "" - # Extract abstract (first 200 chars or first paragraph) - abstract = content[:200].split("\n\n")[0] if content else "" + # Use parse_memory_file_with_fields to strip MEMORY_FIELDS comment + parsed = parse_memory_file_with_fields(content) + abstract = parsed.get("content", "") # Get parent URI from openviking_cli.utils.uri import VikingURI diff --git a/openviking/session/memory/merge_op/base.py b/openviking/session/memory/merge_op/base.py index 35c29c90b..89c51ecd6 100644 --- a/openviking/session/memory/merge_op/base.py +++ b/openviking/session/memory/merge_op/base.py @@ -53,7 +53,10 @@ def get_python_type_for_field(field_type: FieldType, default: Type[Any] = str) - class SearchReplaceBlock(BaseModel): """Single SEARCH/REPLACE block for string patches.""" - search: str = Field(..., description="Content to search for") + search: str = Field( + ..., + description="Content to search for. ONLY include the EXACT lines you need to change - NEVER include the entire section. Example (WRONG): '## Melanie\\n- line1\\n- line2\\n[50 more lines]'. Example (CORRECT): '- Art can be in the most unlikely places, and love and acceptance really can be found everywhere'" + ) replace: str = Field(..., description="Content to replace with") start_line: Optional[int] = Field(None, description="Starting line number hint") @@ -66,7 +69,7 @@ class StrPatch(BaseModel): blocks: List[SearchReplaceBlock] = Field( default_factory=list, - description="List of SEARCH/REPLACE blocks to apply" + description="List of SEARCH/REPLACE blocks to apply. PREFER direct string replacement over SEARCH/REPLACE when possible. When using SEARCH/REPLACE, only include the specific line(s) to change, never the entire section." ) diff --git a/openviking/session/memory/merge_op/patch.py b/openviking/session/memory/merge_op/patch.py index 35827f60d..c31b64055 100644 --- a/openviking/session/memory/merge_op/patch.py +++ b/openviking/session/memory/merge_op/patch.py @@ -44,8 +44,7 @@ def apply(self, current_value: Any, patch_value: Any) -> Any: For string fields (content): - StrPatch: use apply_str_patch() - - str with "<<<<<<< SEARCH": use MemoryPatchHandler - - other str: full replacement + - other: full replacement For non-string fields: - Just replace with patch_value @@ -54,15 +53,12 @@ def apply(self, current_value: Any, patch_value: Any) -> Any: if self._field_type != FieldType.STRING: return patch_value - # For string fields, handle various patch formats - from openviking.session.memory.merge_op.patch_handler import ( - MemoryPatchHandler, - apply_str_patch, - ) + # For string fields + from openviking.session.memory.merge_op.patch_handler import apply_str_patch current_str = current_value or "" - # Case 1: StrPatch object + # Case 1: StrPatch object - apply patch if isinstance(patch_value, StrPatch): return apply_str_patch(current_str, patch_value) @@ -82,15 +78,8 @@ def apply(self, current_value: Any, patch_value: Any) -> Any: # If conversion fails, treat as simple replacement return str(patch_value) if patch_value is not None else "" - # Case 3: string with SEARCH/REPLACE markers - if isinstance(patch_value, str): - if "<<<<<<< SEARCH" in patch_value: - if self._patch_handler is None: - self._patch_handler = MemoryPatchHandler() - return self._patch_handler.apply_content_patch(current_str, patch_value) - else: - # Simple full replacement - return patch_value - - # Fallback: just return patch_value as-is + # Case 3: Simple full replacement + # 空字符串和 None 都保持原值 + if patch_value is None or patch_value == "": + return current_value return patch_value diff --git a/openviking/session/memory/merge_op/sum.py b/openviking/session/memory/merge_op/sum.py index f6a38caf6..f92f0e7da 100644 --- a/openviking/session/memory/merge_op/sum.py +++ b/openviking/session/memory/merge_op/sum.py @@ -26,6 +26,9 @@ def get_output_schema_description(self, field_description: str) -> str: return f"add for '{field_description}'" def apply(self, current_value: Any, patch_value: Any) -> Any: + # None 或空值保留原值 + if patch_value is None or patch_value == "": + return current_value if current_value is None: return patch_value try: @@ -33,4 +36,4 @@ def apply(self, current_value: Any, patch_value: Any) -> Any: return float(current_value) + float(patch_value) return int(current_value) + int(patch_value) except (ValueError, TypeError): - return patch_value + return current_value diff --git a/openviking/session/memory/schema_model_generator.py b/openviking/session/memory/schema_model_generator.py index cc7c742ef..2e4d10e76 100644 --- a/openviking/session/memory/schema_model_generator.py +++ b/openviking/session/memory/schema_model_generator.py @@ -15,7 +15,7 @@ from pydantic.config import ConfigDict from typing_extensions import Annotated, Literal -from openviking.session.memory.dataclass import MemoryTypeSchema +from openviking.session.memory.dataclass import FaultTolerantBaseModel, MemoryTypeSchema from openviking.session.memory.merge_op import MergeOp, MergeOpFactory from openviking.session.memory.merge_op.base import FieldType, StrPatch, get_python_type_for_field from openviking.session.memory.memory_type_registry import MemoryTypeRegistry @@ -44,8 +44,8 @@ class SchemaModelGenerator: # Generic overview edit model shared by all memory types _generic_overview_edit_model: Optional[Type[BaseModel]] = None - def __init__(self, registry: MemoryTypeRegistry): - self.registry = registry + def __init__(self, schemas: List[MemoryTypeSchema]): + self.schemas = schemas self._model_cache: Dict[str, Type[BaseModel]] = {} self._flat_data_models: Dict[str, Type[BaseModel]] = {} self._overview_edit_models: Dict[str, Type[BaseModel]] = {} @@ -60,10 +60,8 @@ def create_flat_data_model(self, memory_type: MemoryTypeSchema) -> Type[BaseMode """ Create a fully flat Pydantic model for a specific memory type. - The model includes: - - memory_type (literal discriminator) - - All business fields (with Union[base_type, patch_type] for mutable fields) - - Standard metadata fields (uri, name, abstract, overview, content, tags, created_at, updated_at) + Note: memory_type field is NOT included since each type has its own + output field in the structured operations model. Args: memory_type: The memory type schema @@ -78,15 +76,9 @@ def create_flat_data_model(self, memory_type: MemoryTypeSchema) -> Type[BaseMode model_name = f"{to_pascal_case(memory_type.memory_type)}Data" - # Build field definitions + # Build field definitions - no memory_type field needed field_definitions: Dict[str, Tuple[Type[Any], Any]] = {} - # Add memory_type as literal discriminator - field_definitions["memory_type"] = ( - Literal[memory_type.memory_type], # type: ignore - Field(..., description=f"Memory type: {memory_type.memory_type}"), - ) - # Add business fields from schema for field in memory_type.fields: base_type = self._map_field_type(field.field_type) @@ -128,7 +120,7 @@ def generate_all_models(self, include_disabled: bool = True) -> Dict[str, Type[B Dictionary mapping memory_type to generated model class """ models: Dict[str, Type[BaseModel]] = {} - for memory_type in self.registry.list_all(include_disabled=include_disabled): + for memory_type in self.schemas: models[memory_type.memory_type] = self.create_flat_data_model(memory_type) return models @@ -187,12 +179,11 @@ def create_discriminated_union_model(self) -> Type[BaseModel]: self.generate_all_models(include_disabled=True) # Build the annotated union with discriminator - only use enabled types - memory_types = self.registry.list_all(include_disabled=False) - if not memory_types: - raise ValueError("No memory types registered in registry") + if not self.schemas: + raise ValueError("No memory types in schemas") # Create union of flat data models - enabled_memory_types = self.registry.list_all(include_disabled=False) + enabled_memory_types = self.schemas flat_model_union_types = tuple( self._flat_data_models[mt.memory_type] for mt in enabled_memory_types @@ -217,12 +208,24 @@ class MemoryDataWrapper(BaseModel): self._union_model = MemoryDataWrapper return self._union_model + def _is_single_value_schema(self, schema: MemoryTypeSchema) -> bool: + """ + Determine if a schema should output as single value (not list). + + Single value if filename_template does NOT contain {xxx} variable. + For example: + - "profile.md" -> single value + - "{skill_name}.md" -> list + """ + return "{" not in schema.filename_template + def create_structured_operations_model(self) -> Type[BaseModel]: """ Create a structured MemoryOperations model with type-safe write operations. - This uses fully flat models for write_uris and edit_uris, - and simple string URIs for delete_uris. + Each memory_type gets its own field (mixed add + edit), with: + - Single value if filename_template has no variable (e.g., profile) + - List if filename_template has variable (e.g., {skill_name}) Returns: Pydantic model for structured operations @@ -234,54 +237,104 @@ def create_structured_operations_model(self) -> Type[BaseModel]: self.generate_all_models(include_disabled=True) # Get enabled memory types - enabled_memory_types = self.registry.list_all(include_disabled=False) + enabled_memory_types = self.schemas + memory_type_fields = [mt.memory_type for mt in enabled_memory_types] + + # Build field definitions for each memory_type + field_definitions: Dict[str, Tuple[Type[Any], Any]] = {} + + field_definitions["reasoning"] = ( + str, + Field('', description="reasoning"), + ) - # Create union type for flat data models (used for both write and edit) - flat_models: List[Type[BaseModel]] = [] for mt in enabled_memory_types: flat_model = self.create_flat_data_model(mt) - flat_models.append(flat_model) + is_single = self._is_single_value_schema(mt) - FlatDataUnion = Union[tuple(flat_models)] # type: ignore + if is_single: + # Single value: Optional[FlatModel] = None + field_definitions[mt.memory_type] = ( + Optional[flat_model], # type: ignore + Field(None, description=f"{mt.memory_type} memory (add or edit)"), + ) + else: + # List: List[FlatModel] = [] + field_definitions[mt.memory_type] = ( + List[flat_model], # type: ignore + Field(default_factory=list, description=f"{mt.memory_type} memories (add or edit)"), + ) # Use single generic model for overview edit (same for all memory types) generic_overview_edit = self.create_overview_edit_model(enabled_memory_types[0] if enabled_memory_types else None) - # Create structured operations - class StructuredMemoryOperations(BaseModel): - """Final memory operations output from LLM with type safety.""" + field_definitions["edit_overview_uris"] = ( + List[generic_overview_edit], # type: ignore + Field(default_factory=list, description="Edit operations for .overview.md files using memory_type"), + ) - reasoning: str = Field( - '', - description="reasoning", - ) - write_uris: List[FlatDataUnion] = Field( # type: ignore - default_factory=list, - description="Write operations with flat data format", - ) - edit_uris: List[FlatDataUnion] = Field( # type: ignore - default_factory=list, - description="Edit operations with flat data format", - ) - edit_overview_uris: List[generic_overview_edit] = Field( # type: ignore - default_factory=list, - description="Edit operations for .overview.md files using memory_type", - ) - delete_uris: List[str] = Field( - default_factory=list, - description="Delete operations as URI strings", - ) + field_definitions["delete_uris"] = ( + List[str], + Field(default_factory=list, description="Delete operations as URI strings"), + ) - def is_empty(self) -> bool: - """Check if there are any operations.""" - return ( - len(self.write_uris) == 0 - and len(self.edit_uris) == 0 - and len(self.edit_overview_uris) == 0 - and len(self.delete_uris) == 0 - ) + # Create model using create_model + StructuredMemoryOperations = create_model( + 'StructuredMemoryOperations', + __config__=ConfigDict(extra='ignore'), + __base__=FaultTolerantBaseModel, + **field_definitions, + ) + + # Add custom methods + def is_empty(self) -> bool: + """Check if there are any operations.""" + for field_name in memory_type_fields: + value = getattr(self, field_name, None) + if value is not None: + if isinstance(value, list): + if len(value) > 0: + return False + else: + # Single value (not None) + return False + return ( + len(self.edit_overview_uris) == 0 + and len(self.delete_uris) == 0 + ) - model_config = ConfigDict(extra='ignore') + def to_legacy_operations(self) -> Dict[str, Any]: + """Convert new per-type structure to legacy write_uris/edit_uris format.""" + write_uris = [] + edit_uris = [] + + for field_name in memory_type_fields: + value = getattr(self, field_name, None) + if value is None: + continue + if isinstance(value, list): + for item in value: + if hasattr(item, 'uri') and item.uri: + edit_uris.append(item) + else: + write_uris.append(item) + else: + if hasattr(value, 'uri') and value.uri: + edit_uris.append(value) + else: + write_uris.append(value) + + return { + "write_uris": write_uris, + "edit_uris": edit_uris, + "edit_overview_uris": self.edit_overview_uris, + "delete_uris": self.delete_uris, + } + + # Attach methods + StructuredMemoryOperations.is_empty = is_empty + StructuredMemoryOperations.to_legacy_operations = to_legacy_operations + StructuredMemoryOperations._memory_type_fields = memory_type_fields # type: ignore self._operations_model = StructuredMemoryOperations return self._operations_model @@ -315,8 +368,8 @@ class SchemaPromptGenerator: based on the YAML schema definitions. """ - def __init__(self, registry: MemoryTypeRegistry): - self.registry = registry + def __init__(self, schemas: List[MemoryTypeSchema]): + self.schemas = schemas def generate_type_descriptions(self) -> str: """ @@ -327,7 +380,7 @@ def generate_type_descriptions(self) -> str: """ lines = ["## Available Memory Types"] - for mt in self.registry.list_all(): + for mt in self.schemas: lines.append(f"\n### {mt.memory_type}") lines.append(f"{mt.description}") @@ -366,7 +419,7 @@ def generate_field_descriptions(self, memory_type: str) -> Optional[str]: Returns: Formatted string with field descriptions, or None if not found """ - mt = self.registry.get(memory_type) + mt = next((s for s in self.schemas if s.memory_type == memory_type), None) if not mt: return None @@ -399,6 +452,6 @@ def get_full_prompt_context(self) -> Dict[str, Any]: for f in mt.fields ], } - for mt in self.registry.list_all() + for mt in self.schemas ], } diff --git a/openviking/session/memory/session_extract_context_provider.py b/openviking/session/memory/session_extract_context_provider.py new file mode 100644 index 000000000..d246690ec --- /dev/null +++ b/openviking/session/memory/session_extract_context_provider.py @@ -0,0 +1,354 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +""" +Session Extract Context Provider - 会话提取 Provider 实现 + +从会话消息中提取记忆的实现。 +""" + +import json +import os +from typing import Any, Dict, List + +from openviking.server.identity import RequestContext +from openviking.session.memory.core import ExtractContextProvider +from openviking.session.memory.memory_type_registry import MemoryTypeRegistry +from openviking.session.memory.tools import ( + add_tool_call_pair_to_messages, + get_tool, +) +from openviking.storage.viking_fs import VikingFS +from openviking_cli.utils import get_logger +from openviking_cli.utils.config import get_openviking_config + +logger = get_logger(__name__) + + +class SessionExtractContextProvider(ExtractContextProvider): + """会话提取 Provider - 从会话消息中提取记忆""" + + def __init__(self, messages: Any, latest_archive_overview: str = ""): + self.messages = messages + self.latest_archive_overview = latest_archive_overview + self._output_language = self._detect_language() + self._registry = None # 延迟加载 + self._schema_directories = None + + def _detect_language(self) -> str: + """检测输出语言""" + from openviking.session.memory.utils import detect_language_from_conversation + conversation = self._assemble_conversation(self.messages) + config = get_openviking_config() + fallback_language = (config.language_fallback or "en").strip() or "en" + return detect_language_from_conversation(conversation, fallback_language=fallback_language) + + def instruction(self) -> str: + output_language = self._output_language + goal = f"""You are a memory extraction agent. Your task is to analyze conversations and update memories. + +## Workflow +1. Analyze the conversation and pre-fetched context +2. If you need more information, use the available tools (read/search) +3. When you have enough information, output ONLY a JSON object (no extra text before or after) + +## Critical +- ONLY read and search tools are available - DO NOT use write tool +- Before editing ANY existing memory file, you MUST first read its complete content +- ONLY read URIs that are explicitly listed in ls tool results or returned by previous tool calls + +## Target Output Language +All memory content MUST be written in {output_language}. + +## URI Handling +The system automatically generates URIs based on memory_type and fields. Just provide correct memory_type and fields. + +## Edit Overview Files +After writing new memories, you MUST also update the corresponding .overview.md file. +- Provide memory_type to identify which directory's overview to update + +## Overview Format +Two options: +1. **PREFERRED: Direct string** - Just provide the complete new overview content: + {{"memory_type": "events", "overview": "# Events Overview\n- [event1](event1.md) - Description"}} +2. **SEARCH/REPLACE** - Only use if you must modify a small portion: + {{"memory_type": "events", "overview": {{"blocks": [{{"search": "exact line to change", "replace": "new line"}}]}}}} + +See GenericOverviewEdit in the JSON Schema below.""" + + return goal + + def _build_conversation_message(self) -> Dict[str, Any]: + """构建包含 Conversation History 的 user message""" + from datetime import datetime + if self.messages: + first_msg_time = getattr(self.messages[0], "created_at", None) + last_msg_time = getattr(self.messages[-1], "created_at", None) + else: + first_msg_time = None + last_msg_time = None + + if first_msg_time: + session_time = first_msg_time + else: + session_time = datetime.now() + + session_time_str = session_time.strftime("%Y-%m-%d %H:%M") + day_of_week = session_time.strftime("%A") + + # 检查是否需要显示范围 + if last_msg_time and last_msg_time != first_msg_time: + time_display = f"{session_time_str} - {last_msg_time.strftime('%Y-%m-%d %H:%M')}" + else: + time_display = session_time_str + + conversation = self._assemble_conversation(self.messages) + + return { + "role": "user", + "content": f"""## Conversation History +**Session Time:** {time_display} ({day_of_week}) +Relative times (e.g., 'last week', 'next month') are based on Session Time, not today. + +{conversation} + +After exploring, analyze the conversation and output ALL memory write/edit/delete operations in a single response. Do not output operations one at a time - gather all changes first, then return them together.""" + } + + def _assemble_conversation(self, messages: Any) -> str: + """Assemble conversation string from messages. + + Args: + messages: List of Message objects + latest_archive_overview: Optional overview from previous archive for context + + Returns: + Formatted conversation string + """ + import json + from openviking.message import Message + from openviking.message.part import ToolPart + + conversation_sections: List[str] = [] + + def format_message_with_parts(msg: Message) -> str: + """Format message with text and tool parts.""" + parts = getattr(msg, "parts", []) + has_tool_parts = any(isinstance(p, ToolPart) for p in parts) + + if not has_tool_parts: + return msg.content + + tool_lines = [] + text_lines = [] + for part in parts: + if hasattr(part, "text") and part.text: + text_lines.append(part.text) + elif isinstance(part, ToolPart): + tool_info = { + "type": "tool_call", + "tool_name": part.tool_name, + "tool_input": part.tool_input, + "tool_status": part.tool_status, + } + if part.skill_uri: + tool_info["skill_name"] = part.skill_uri.rstrip("/").split("/")[-1] + tool_lines.append(f"[ToolCall] {json.dumps(tool_info, ensure_ascii=False)}") + + all_lines = tool_lines + text_lines + return "\n".join(all_lines) if all_lines else msg.content + + conversation_sections.append( + "\n".join([f"[{idx}][{msg.role}]: {format_message_with_parts(msg)}" for idx, msg in enumerate(messages)]) + ) + + return "\n\n".join(section for section in conversation_sections if section) + + async def prefetch( + self, + ctx: RequestContext, + viking_fs: VikingFS, + transaction_handle, + vlm, + ) -> List[Dict]: + """ + 执行 prefetch - 从会话消息中提取相关记忆上下文 + + Args: + ctx: RequestContext + viking_fs: VikingFS + transaction_handle: 事务句柄 + vlm: VLM 实例 + + Returns: + 预取的消息列表,第一个元素是 Conversation History user message,后续是 tool call messages + """ + messages = self.messages + + if not isinstance(messages, list): + logger.warning(f"Expected List[Message], got {type(messages)}") + return [] + + # 先构建 Conversation History user message + pre_fetch_messages = [] + pre_fetch_messages.append(self._build_conversation_message()) + + # 触发 registry 加载 + schemas = self._get_registry().list_all(include_disabled=False) + + from openviking.server.identity import ToolContext + + # Step 1: Separate schemas into multi-file (ls) and single-file (direct read) + ls_dirs = set() # directories to ls (for multi-file schemas) + read_files = set() # files to read directly (for single-file schemas) + overview_files = set() # .overview.md files to read + + for schema in schemas: + if not schema.directory: + continue + + # Replace variables in directory path with actual user/agent space + user_space = ctx.user.user_space_name() if ctx and ctx.user else "default" + agent_space = ctx.user.agent_space_name() if ctx and ctx.user else "default" + dir_path = schema.directory.replace("{user_space}", user_space).replace("{agent_space}", agent_space) + + # Always add .overview.md to read list + overview_files.add(f"{dir_path}/.overview.md") + + # 根据 operation_mode 决定是否需要 ls 和读取其他文件 + if schema.operation_mode == "add_only": + # 只新增,不需要查看之前的记忆列表,只需要读取 .overview.md + continue + + # Check if filename_template has variables (contains {xxx}) + has_variables = False + if schema.filename_template: + has_variables = "{" in schema.filename_template and "}" in schema.filename_template + + if has_variables or not schema.filename_template: + # Multi-file schema or no filename template: ls the directory + ls_dirs.add(dir_path) + else: + # Single-file schema: directly read the specific file + file_uri = f"{dir_path}/{schema.filename_template}" + read_files.add(file_uri) + + call_id_seq = 0 + # Step 2: Execute search for each ls directory (instead of ls) + read_tool = get_tool("read") + search_tool = get_tool("search") + + # 首先读取所有 .overview.md 文件(截断以避免窗口过大) + # 为 overview 读取创建一个基本的 tool_ctx + tool_ctx = ToolContext( + request_ctx=ctx, + transaction_handle=transaction_handle, + default_search_uris=[] + ) + for overview_uri in overview_files: + try: + result_str = await read_tool.execute(viking_fs, tool_ctx, uri=overview_uri) + add_tool_call_pair_to_messages( + messages=pre_fetch_messages, + call_id=call_id_seq, + tool_name='read', + params={ + "uri": overview_uri + }, + result=result_str + ) + call_id_seq += 1 + except Exception as e: + logger.warning(f"Failed to read .overview.md: {e}") + + # 在每个之前 ls 的目录内执行 search(替换原来的 ls 操作) + if search_tool and viking_fs and ls_dirs: + for dir_uri in ls_dirs: + # 创建只在该目录搜索的 tool_ctx + tool_ctx_dir = ToolContext( + request_ctx=ctx, + transaction_handle=transaction_handle, + default_search_uris=[dir_uri] + ) + try: + search_result = await search_tool.execute( + viking_fs=viking_fs, + ctx=tool_ctx_dir, + query="[Keywords]", + ) + # 处理搜索结果 + if isinstance(search_result, list): + result_value = [m.get("uri", "") for m in search_result] + elif isinstance(search_result, dict): + if "error" in search_result: + result_value = f"Error: {search_result.get('error')}" + else: + result_value = [m.get("uri", "") for m in search_result.get("memories", [])] + else: + result_value = [] + + add_tool_call_pair_to_messages( + messages=pre_fetch_messages, + call_id=call_id_seq, + tool_name='search', + params={ + "query": "[Keywords]", + "search_uri": dir_uri + }, + result=result_value + ) + call_id_seq += 1 + except Exception as e: + logger.warning(f"Failed to search in {dir_uri}: {e}") + + # 读取单文件 schema 的文件(只对非 add_only 模式) + for file_uri in read_files: + try: + result_str = await read_tool.execute(viking_fs, tool_ctx, uri=file_uri) + add_tool_call_pair_to_messages( + messages=pre_fetch_messages, + call_id=call_id_seq, + tool_name='read', + params={ + "uri": file_uri + }, + result=result_str + ) + call_id_seq += 1 + except Exception as e: + logger.warning(f"Failed to read {file_uri}: {e}") + + + return pre_fetch_messages + + def get_tools(self) -> List[str]: + """获取可用的工具列表 - 会话场景只使用 read""" + return ["read"] + + def get_memory_schemas(self, ctx: RequestContext) -> List[Any]: + """获取需要参与的 memory schemas(内部自动加载)""" + return self._get_registry().list_all(include_disabled=False) + + def get_schema_directories(self) -> List[str]: + """返回需要加载的 schema 目录""" + if self._schema_directories is None: + builtin_dir = os.path.join( + os.path.dirname(__file__), "..", "..", "prompts", "templates", "memory" + ) + config = get_openviking_config() + custom_dir = config.memory.custom_templates_dir + self._schema_directories = [builtin_dir] + if custom_dir: + custom_dir_expanded = os.path.expanduser(custom_dir) + if os.path.exists(custom_dir_expanded): + self._schema_directories.append(custom_dir_expanded) + return self._schema_directories + + def _get_registry(self) -> MemoryTypeRegistry: + """内部获取 registry(自动加载)""" + if self._registry is None: + self._registry = MemoryTypeRegistry() + for dir_path in self.get_schema_directories(): + if os.path.exists(dir_path): + self._registry.load_from_directory(dir_path) + return self._registry + diff --git a/openviking/session/memory/tools.py b/openviking/session/memory/tools.py index 8878aa3da..bfb31bc56 100644 --- a/openviking/session/memory/tools.py +++ b/openviking/session/memory/tools.py @@ -18,57 +18,42 @@ logger = get_logger(__name__) -def create_tool_call_message( - call_id: Union[str, int], - tool_name: str, - params: Dict[str, Any], -) -> Dict[str, Any]: - """ - Create an assistant role message with tool_calls. - - Args: - call_id: Unique identifier for the tool call - tool_name: Name of the tool being called - params: Parameters for the tool call - - Returns: - Assistant message with tool_calls field - """ - return { - "role": "assistant", - "content": None, - "tool_calls": [ - { - "id": str(call_id), - "type": "function", - "function": { - "name": tool_name, - "arguments": json.dumps(params), - }, - } - ], - } -def create_tool_result_message( - call_id: Union[str, int], - result: Any, -) -> Dict[str, Any]: - """ - Create a tool role message with the tool execution result. - - Args: - call_id: Unique identifier matching the tool call - result: Result from the tool execution - Returns: - Tool message with result content - """ - return { - "role": "tool", - "tool_call_id": str(call_id), - "content": json.dumps(result, ensure_ascii=False), - } +def optimize_search_result(result: Any, limit: int = 10) -> Any: + """优化搜索结果以减少 Token 消耗,并过滤掉抽象文件。""" + if isinstance(result, dict) and "error" in result: + return {"error": extract_error_summary(result["error"])} + if isinstance(result, dict) and "memories" in result: + filtered = [ + item for item in result["memories"] + if not (item.get("uri", "").endswith(".abstract.md") or item.get("uri", "").endswith(".overview.md")) + ] + return [{"uri": item["uri"], "score": item["score"]} for item in filtered[:limit]] + return [] + + +def optimize_tool_result(tool_name: str, result: Any) -> Any: + """优化工具结果以减少 Token 消耗。""" + if isinstance(result, dict) and "error" in result: + return {"error": extract_error_summary(result["error"])} + if tool_name == "search" and isinstance(result, dict) and "memories" in result: + return optimize_search_result(result) + # 对 read 工具返回的 dict,如果包含 content 字段,则截断 content + if tool_name == "read" and isinstance(result, dict) and "content" in result: + result = result.copy() + result["content"] = truncate_content(result["content"]) + return result + +def extract_error_summary(error: str) -> str: + if "File not found" in error: + return "File not found" + if "Permission denied" in error: + return "Permission denied" + if "Timeout" in error: + return "Timeout" + return error[:50] def add_tool_call_pair_to_messages( @@ -78,18 +63,15 @@ def add_tool_call_pair_to_messages( params: Dict[str, Any], result: Any, ) -> None: - """ - Add a pair of tool call + tool result messages to the messages list. - - Args: - messages: List to append messages to - call_id: Unique identifier for the tool call - tool_name: Name of the tool being called - params: Parameters for the tool call - result: Result from the tool execution - """ - messages.append(create_tool_call_message(call_id, tool_name, params)) - messages.append(create_tool_result_message(call_id, result)) + """Add a tool call pair with optimized format to save tokens.""" + messages.append({ + "role": "user", + "content": { + "tool_call_name": tool_name, + "args": params, + "result": result + } + }) def add_tool_call_items_to_messages( @@ -137,7 +119,7 @@ def parameters(self) -> Dict[str, Any]: async def execute( self, viking_fs: VikingFS, - ctx: Optional[RequestContext], + ctx: Optional["ToolContext"], **kwargs: Any, ) -> Any: """ @@ -145,7 +127,7 @@ async def execute( Args: viking_fs: VikingFS instance - ctx: Request context + ctx: Tool context **kwargs: Tool-specific parameters Returns: @@ -192,14 +174,14 @@ def parameters(self) -> Dict[str, Any]: async def execute( self, viking_fs: VikingFS, - ctx: Optional[RequestContext], + ctx: Optional["ToolContext"], **kwargs: Any, ) -> Any: try: uri = kwargs.get("uri", "") content = await viking_fs.read_file( uri, - ctx=ctx, + ctx=ctx.request_ctx, ) # Parse MEMORY_FIELDS from comment and return dict directly parsed = parse_memory_file_with_fields(content) @@ -218,7 +200,7 @@ def name(self) -> str: @property def description(self) -> str: - return "Semantic search with session context, target_uri is target directory URI" + return "Semantic search with session context" @property def parameters(self) -> Dict[str, Any]: @@ -229,28 +211,11 @@ def parameters(self) -> Dict[str, Any]: "type": "string", "description": "Search query text", }, - "target_uri": { - "type": "string", - "description": "Target directory URI, default empty means search all", - "default": "", - }, - "session_info": { - "type": "object", - "description": "Session information with latest_archive_overview and current_messages, optional", - }, "limit": { "type": "integer", "description": "Maximum results to return, default 10", "default": 10, }, - "score_threshold": { - "type": "number", - "description": "Score threshold, optional", - }, - "filter": { - "type": "object", - "description": "Filter conditions, optional", - }, }, "required": ["query"], } @@ -258,34 +223,24 @@ def parameters(self) -> Dict[str, Any]: async def execute( self, viking_fs: VikingFS, - ctx: Optional[RequestContext], + ctx: Optional["ToolContext"], **kwargs: Any, ) -> Any: try: query = kwargs.get("query", "") - target_uri = kwargs.get("target_uri", "") - # If target_uri is empty, use default from ctx - if ( - not target_uri - and ctx - and hasattr(ctx, "default_search_uris") - and ctx.default_search_uris - ): + # Get target_uri from ctx.default_search_uris + target_uri = "" + if ctx and hasattr(ctx, "default_search_uris") and ctx.default_search_uris: target_uri = ctx.default_search_uris - session_info = kwargs.get("session_info") limit = kwargs.get("limit", 10) - score_threshold = kwargs.get("score_threshold") - filter = kwargs.get("filter") + # 多搜索 10 个,过滤抽象文件后再截断 search_result = await viking_fs.search( query, target_uri=target_uri, - session_info=session_info, - limit=limit, - score_threshold=score_threshold, - filter=filter, + limit=limit + 10, ctx=ctx, ) - return search_result.to_dict() + return optimize_search_result(search_result.to_dict(), limit=limit) except Exception as e: logger.error(f"Failed to execute search: {e}") return {"error": str(e)} @@ -328,7 +283,7 @@ def parameters(self) -> Dict[str, Any]: async def execute( self, viking_fs: VikingFS, - ctx: Optional[RequestContext], + ctx: Optional["ToolContext"], **kwargs: Any, ) -> Any: try: @@ -339,7 +294,7 @@ async def execute( abs_limit=256, show_all_hidden=False, node_limit=1000, - ctx=ctx, + ctx=ctx.request_ctx, ) # Format: filename size (e.g., "file.md 1.2K") result_lines = [] @@ -381,7 +336,7 @@ def list_tools() -> Dict[str, MemoryTool]: # Tools exposed to LLM (not all registered tools are exposed) -LLM_TOOLS = ["read", "search"] +LLM_TOOLS = ["read"] def get_tool_schemas() -> List[Dict[str, Any]]: diff --git a/openviking/session/memory/utils/__init__.py b/openviking/session/memory/utils/__init__.py index a516fd0bd..2ab288176 100644 --- a/openviking/session/memory/utils/__init__.py +++ b/openviking/session/memory/utils/__init__.py @@ -9,6 +9,7 @@ deserialize_full, deserialize_metadata, serialize_with_metadata, + truncate_content, ) from openviking.session.memory.utils.language import ( detect_language_from_conversation, @@ -18,6 +19,7 @@ pretty_print_messages, ) from openviking.session.memory.utils.uri import ( + ResolvedOperation, ResolvedOperations, collect_allowed_directories, collect_allowed_path_patterns, @@ -35,7 +37,6 @@ _get_arg_type, _get_origin_type, extract_json_content, - extract_json_from_markdown, parse_json_with_stability, parse_value_with_tolerance, remove_json_trailing_content, @@ -52,6 +53,7 @@ "deserialize_content", "deserialize_metadata", "deserialize_full", + "truncate_content", # Language "detect_language_from_conversation", # Messages @@ -66,6 +68,7 @@ "is_uri_allowed_for_schema", "extract_uri_fields_from_flat_model", "resolve_flat_model_uri", + "ResolvedOperation", "ResolvedOperations", "resolve_all_operations", "validate_operations_uris", @@ -73,7 +76,6 @@ "extract_json_content", "remove_json_trailing_content", "parse_json_with_stability", - "extract_json_from_markdown", "value_fault_tolerance", "parse_value_with_tolerance", "_get_origin_type", diff --git a/openviking/session/memory/utils/content.py b/openviking/session/memory/utils/content.py index 3a8a12998..d66732aef 100644 --- a/openviking/session/memory/utils/content.py +++ b/openviking/session/memory/utils/content.py @@ -157,3 +157,30 @@ def deserialize_full(full_content: str) -> Tuple[str, Optional[Dict[str, Any]]]: content = deserialize_content(full_content) metadata = deserialize_metadata(full_content) return content, metadata + + +# 默认截断配置 +DEFAULT_TRUNCATE_MAX_CHARS = 1000 + + +def truncate_content(content: str, max_chars: int = DEFAULT_TRUNCATE_MAX_CHARS) -> str: + """ + Truncate content to max_chars while keeping complete lines. + + Args: + content: Content to truncate + max_chars: Maximum number of characters to keep (default: 1000) + + Returns: + Truncated content with truncation note appended + """ + if len(content) <= max_chars: + return content + + # 从 max_chars 位置向前找最近的换行符,保持完整行 + truncated = content[:max_chars] + last_newline = truncated.rfind('\n') + if last_newline > 0: + truncated = truncated[:last_newline] + + return truncated + f"\n... [truncated {len(content) - len(truncated)} chars]" diff --git a/openviking/session/memory/utils/json_parser.py b/openviking/session/memory/utils/json_parser.py index 08bbb4012..6b1474c8d 100644 --- a/openviking/session/memory/utils/json_parser.py +++ b/openviking/session/memory/utils/json_parser.py @@ -4,14 +4,13 @@ JSON stable parsing - Five-Layer Fault Tolerance Architecture. Layer 1: JSON Cleanup - extract_json_content() -Layer 2: JSON Repair - json_repair.loads() +Layer 2: JSON Repair - json_repair.loads() (handles markdown too) Layer 3: Structure Tolerance - list→object conversion + field filtering Layer 4: Value Tolerance - value_fault_tolerance() Layer 5: Validation Tolerance - TypeAdapter(strict=False) + list item filtering """ import json -import re from types import UnionType from typing import Any, Dict, List, Optional, Tuple, Type, get_type_hints, get_origin, get_args, Union @@ -28,7 +27,6 @@ "extract_json_content", "remove_json_trailing_content", "parse_json_with_stability", - "extract_json_from_markdown", "value_fault_tolerance", "parse_value_with_tolerance", "_get_origin_type", @@ -270,10 +268,41 @@ def parse_value_with_tolerance(value, annotation): if value == 'None': return None - parsed_value = value - - # Apply value fault tolerance first - parsed_value = value_fault_tolerance(annotation, parsed_value) + # Apply value fault tolerance (inline for efficiency) + origin_type = _get_origin_type(annotation) + if origin_type is str: + parsed_value = _any_to_str(value) + elif origin_type is int: + if isinstance(value, str): + if value == 'None': + parsed_value = 0 + else: + try: + parsed_value = int(value) + except (ValueError, TypeError): + parsed_value = value + else: + parsed_value = value + elif origin_type is float: + if isinstance(value, str): + if value == 'None': + parsed_value = 0.0 + else: + try: + parsed_value = float(value) + except (ValueError, TypeError): + parsed_value = value + else: + parsed_value = value + elif origin_type is list: + if isinstance(value, str): + parsed_value = [value] + elif isinstance(value, dict): + parsed_value = [value] + else: + parsed_value = value + else: + parsed_value = value # Try validation with TypeAdapter try: @@ -378,7 +407,7 @@ def parse_json_with_stability( return model_class.model_validate(parsed_data), None except Exception as e: logger.warning(f"Direct model validation failed, trying parse_value_with_tolerance: {e}") - + logger.warning(f"content={content}") # Fallback: Apply value fault tolerance to each field individually try: field_types = get_type_hints(model_class) @@ -401,30 +430,3 @@ def parse_json_with_stability( return None, f"Model validation failed even after tolerance: {e} (fallback: {e2})" -def extract_json_from_markdown(content: str) -> str: - """ - Extract JSON from markdown code blocks. - - Handles: - - ```json { ... } ``` - - ``` { ... } ``` - - Plain JSON without markdown - - Args: - content: Content possibly containing markdown code blocks - - Returns: - Extracted JSON string - """ - if not content: - return content - - content = content.strip() - - # Try to find ```json ... ``` - match = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```", content, re.DOTALL) - if match: - return match.group(1).strip() - - # If no code block, return as-is - return content diff --git a/openviking/session/memory/utils/messages.py b/openviking/session/memory/utils/messages.py index 8ac895d57..9c3de6972 100644 --- a/openviking/session/memory/utils/messages.py +++ b/openviking/session/memory/utils/messages.py @@ -10,6 +10,7 @@ import json_repair +from openviking.session.memory.utils import truncate_content from openviking_cli.utils import get_logger logger = get_logger(__name__) @@ -25,61 +26,54 @@ def pretty_print_messages(messages: List[Dict[str, Any]]) -> None: Args: messages: List of message dictionaries with 'role', 'content', and optional 'tool_calls' """ - def _format_tool_call(tc: Dict[str, Any]) -> Dict[str, Any]: - """Format a single tool call, pretty-printing arguments if it's JSON.""" - tc_copy = dict(tc) - if "function" in tc_copy and "arguments" in tc_copy["function"]: - args_str = tc_copy["function"]["arguments"] - if isinstance(args_str, str): - try: - # Try to parse and pretty-print the arguments - args_json = json.loads(args_str) - tc_copy["function"] = dict(tc_copy["function"]) - tc_copy["function"]["arguments"] = args_json - except (json.JSONDecodeError, TypeError): - # If it's not valid JSON, leave it as is - pass - return tc_copy - - print("=== Messages ===") + output = ["=== Messages ==="] for msg in messages: role = msg.get("role", "unknown") content = msg.get("content", "") - if role == "tool": - # Tool result - show correspondence with tool_call_id + if role == "tool_call": + # Optimized tool call format - print as JSON to match stored format + output.append(f"\n[{role}]") + output.append(json.dumps(msg, ensure_ascii=False, indent=2)) + elif role == "tool": + # Legacy tool result format tool_call_id = msg.get("tool_call_id", "") - print(f"\n[{role}] (id={tool_call_id})") + output.append(f"\n[{role}] (id={tool_call_id})") if content: - # Try to pretty-print tool result if it's JSON try: result_json = json.loads(content) - print(json.dumps(result_json, indent=2, ensure_ascii=False)) + output.append(json.dumps(result_json, indent=2, ensure_ascii=False)) except (json.JSONDecodeError, TypeError): - # If it's not valid JSON, print as is - print(content) + output.append(content) else: if content: - print(f"\n[{role}]") - print(content) + output.append(f"\n[{role}]") + # Handle content as dict (e.g., tool_call format) + if isinstance(content, dict): + output.append(json.dumps(content, ensure_ascii=False, indent=2)) + else: + output.append(content) if "tool_calls" in msg and msg["tool_calls"]: + # Legacy tool call format tool_calls = msg["tool_calls"] if len(tool_calls) == 1: - # Single tool call - show its id tc = tool_calls[0] tc_id = tc.get("id", "") tc_name = tc.get("function", {}).get("name", "") - print(f"\n[{role} tool_call] (id={tc_id}, name={tc_name})") - formatted_tc = _format_tool_call(tc) - print(json.dumps(formatted_tc, indent=2, ensure_ascii=False)) + output.append(f"\n[{role} tool_call] (id={tc_id}, name={tc_name})") + args_str = tc.get("function", {}).get("arguments", {}) + try: + args_json = json.loads(args_str) + output.append(json.dumps(args_json, indent=2, ensure_ascii=False)) + except: + output.append(args_str) else: - # Multiple tool calls - print(f"\n[{role} tool_calls]") - formatted_tcs = [_format_tool_call(tc) for tc in tool_calls] - print(json.dumps(formatted_tcs, indent=2, ensure_ascii=False)) + output.append(f"\n[{role} tool_calls]") + output.append(json.dumps(tool_calls, indent=2, ensure_ascii=False)) - print("\n=== End Messages ===") + output.append("\n=== End Messages ===") + logger.info("\n".join(output)) def parse_memory_file_with_fields(content: str) -> Dict[str, Any]: @@ -121,6 +115,8 @@ def parse_memory_file_with_fields(content: str) -> Dict[str, Any]: # Remove the comment from content content_without_comment = re.sub(pattern, "", content).strip() + + content_without_comment = truncate_content(content_without_comment) result["content"] = content_without_comment return result diff --git a/openviking/session/memory/utils/uri.py b/openviking/session/memory/utils/uri.py index f47d9d545..0ce9f64fb 100644 --- a/openviking/session/memory/utils/uri.py +++ b/openviking/session/memory/utils/uri.py @@ -5,6 +5,7 @@ """ import re +from dataclasses import dataclass from typing import Any, Dict, List, Optional, Set, Tuple, Type from openviking.session.memory.dataclass import MemoryTypeSchema @@ -14,6 +15,14 @@ logger = get_logger(__name__) +@dataclass +class ResolvedOperation: + """A resolved memory operation with URI and memory_type.""" + model: Any # The flat model data + uri: str # The resolved URI + memory_type: str # The memory type (e.g., 'tools', 'skills', 'events') + + def generate_uri( memory_type: MemoryTypeSchema, fields: Dict[str, Any], @@ -274,15 +283,17 @@ def resolve_flat_model_uri( registry: MemoryTypeRegistry, user_space: str = "default", agent_space: str = "default", + memory_type: Optional[str] = None, ) -> str: """ Resolve URI for a flat model (used for both write and edit operations). Args: - flat_model: Flat model instance with memory_type and business fields + flat_model: Flat model instance with business fields registry: MemoryTypeRegistry to get schema user_space: User space for substitution agent_space: Agent space for substitution + memory_type: Optional memory_type - if provided, use it instead of reading from model Returns: Resolved URI @@ -290,8 +301,10 @@ def resolve_flat_model_uri( Raises: ValueError: If memory_type not found or URI generation fails """ - # Get memory_type from the model - if hasattr(flat_model, 'memory_type'): + # Get memory_type from parameter or from model + if memory_type: + memory_type_str = memory_type + elif hasattr(flat_model, 'memory_type'): memory_type_str = flat_model.memory_type elif isinstance(flat_model, dict) and 'memory_type' in flat_model: memory_type_str = flat_model['memory_type'] @@ -361,8 +374,8 @@ class ResolvedOperations: """Operations with resolved URIs.""" def __init__(self): - self.write_operations: List[Tuple[Any, str]] = [] # (flat_model, resolved_uri) - self.edit_operations: List[Tuple[Any, str]] = [] # (flat_model, resolved_uri) + self.write_operations: List[ResolvedOperation] = [] + self.edit_operations: List[ResolvedOperation] = [] self.edit_overview_operations: List[Tuple[Any, str]] = [] # (overview_edit_model, overview_uri) self.delete_operations: List[Tuple[str, str]] = [] # (uri_str, uri_str) - just the uri self.errors: List[str] = [] @@ -378,10 +391,12 @@ def resolve_all_operations( agent_space: str = "default", ) -> ResolvedOperations: """ - Resolve URIs for all operations using the new flat model format. + Resolve URIs for all operations. + + Supports both legacy format (write_uris/edit_uris) and new per-memory_type format. Args: - operations: StructuredMemoryOperations with write_uris, edit_uris, delete_uris + operations: StructuredMemoryOperations registry: MemoryTypeRegistry to get schemas user_space: User space for substitution agent_space: Agent space for substitution @@ -391,21 +406,51 @@ def resolve_all_operations( """ resolved = ResolvedOperations() - # Resolve write operations (flat models) - if hasattr(operations, 'write_uris'): - for op in operations.write_uris: + # Check if using new per-memory_type format + memory_type_fields = getattr(operations, '_memory_type_fields', None) + if memory_type_fields: + # New format: iterate each memory_type field + for field_name in memory_type_fields: + value = getattr(operations, field_name, None) + if value is None: + continue + items = value if isinstance(value, list) else [value] + for item in items: + # Determine if edit (has uri) or write + is_edit = False + if hasattr(item, 'uri') and item.uri: + is_edit = True + elif isinstance(item, dict) and item.get('uri'): + is_edit = True + # Convert to dict for URI resolution + item_dict = dict(item) if hasattr(item, 'model_dump') else dict(item) + try: + uri = resolve_flat_model_uri(item_dict, registry, user_space, agent_space, memory_type=field_name) + if is_edit: + resolved.edit_operations.append(ResolvedOperation(model=item_dict, uri=uri, memory_type=field_name)) + else: + resolved.write_operations.append(ResolvedOperation(model=item_dict, uri=uri, memory_type=field_name)) + except Exception as e: + resolved.errors.append(f"Failed to resolve {field_name} operation: {e}") + else: + # Legacy format + write_uris = operations.write_uris if hasattr(operations, 'write_uris') else [] + edit_uris = operations.edit_uris if hasattr(operations, 'edit_uris') else [] + + for op in write_uris: try: uri = resolve_flat_model_uri(op, registry, user_space, agent_space) - resolved.write_operations.append((op, uri)) + # Legacy format: try to get memory_type from model, otherwise empty + memory_type = op.get('memory_type', '') if isinstance(op, dict) else '' + resolved.write_operations.append(ResolvedOperation(model=op, uri=uri, memory_type=memory_type)) except Exception as e: resolved.errors.append(f"Failed to resolve write operation: {e}") - # Resolve edit operations (flat models) - if hasattr(operations, 'edit_uris'): - for op in operations.edit_uris: + for op in edit_uris: try: uri = resolve_flat_model_uri(op, registry, user_space, agent_space) - resolved.edit_operations.append((op, uri)) + memory_type = op.get('memory_type', '') if isinstance(op, dict) else '' + resolved.edit_operations.append(ResolvedOperation(model=op, uri=uri, memory_type=memory_type)) except Exception as e: resolved.errors.append(f"Failed to resolve edit operation: {e}") @@ -462,13 +507,13 @@ def validate_operations_uris( errors.extend(resolved.errors) else: # Validate resolved URIs - for _op, uri in resolved.write_operations: - if not is_uri_allowed(uri, allowed_dirs, allowed_patterns): - errors.append(f"Write operation URI not allowed: {uri}") + for resolved_op in resolved.write_operations: + if not is_uri_allowed(resolved_op.uri, allowed_dirs, allowed_patterns): + errors.append(f"Write operation URI not allowed: {resolved_op.uri}") - for _op, uri in resolved.edit_operations: - if not is_uri_allowed(uri, allowed_dirs, allowed_patterns): - errors.append(f"Edit operation URI not allowed: {uri}") + for resolved_op in resolved.edit_operations: + if not is_uri_allowed(resolved_op.uri, allowed_dirs, allowed_patterns): + errors.append(f"Edit operation URI not allowed: {resolved_op.uri}") for _op, uri in resolved.edit_overview_operations: if not is_uri_allowed(uri, allowed_dirs, allowed_patterns): diff --git a/openviking/session/session.py b/openviking/session/session.py index 3592594b3..a71bf535b 100644 --- a/openviking/session/session.py +++ b/openviking/session/session.py @@ -301,13 +301,14 @@ def add_message( self, role: str, parts: List[Part], + created_at: datetime = None, ) -> Message: """Add a message.""" msg = Message( id=f"msg_{uuid4().hex}", role=role, parts=parts, - created_at=datetime.now(), + created_at=created_at or datetime.now(), ) self._messages.append(msg) diff --git a/openviking/storage/transaction/lock_manager.py b/openviking/storage/transaction/lock_manager.py index 2fec7e42a..0b46d3252 100644 --- a/openviking/storage/transaction/lock_manager.py +++ b/openviking/storage/transaction/lock_manager.py @@ -5,7 +5,7 @@ import asyncio import json import time -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, List from openviking.pyagfs import AGFSClient from openviking.storage.transaction.lock_handle import LockHandle @@ -80,6 +80,60 @@ async def acquire_subtree( path, handle, timeout=timeout if timeout is not None else self._lock_timeout ) + async def acquire_subtree_batch( + self, + handle: LockHandle, + paths: List[str], + timeout: Optional[float] = None, + ) -> bool: + """ + 一次性对多个路径进行子树加锁,使用有序加锁法防止死锁 + + 核心思想: + 1. 对路径按照固定的顺序进行排序,确保所有进程获取锁的顺序一致 + 2. 防止循环等待条件,从而避免死锁 + + 排序规则: + 1. 路径长度升序 + 2. 长度相同的路径按照字典序升序 + + Args: + handle: 锁句柄 + paths: 需要加锁的路径列表 + timeout: 超时时间,None表示无限等待 + + Returns: + 是否成功获取所有锁 + """ + if not paths: + return True + + # 对路径进行排序,确保加锁顺序一致 + sorted_paths = sorted(paths, key=lambda x: (len(x), x)) + acquired = [] + + try: + for path in sorted_paths: + success = await self._path_lock.acquire_subtree( + path, + handle, + timeout=timeout, + ) + if not success: + # 释放已获得的锁 + for p in acquired: + await self._path_lock.release_subtree(p, handle) + return False + acquired.append(path) + + return True + + except Exception as e: + logger.error(f"Failed to acquire subtree batch lock: {e}") + for p in acquired: + await self._path_lock.release_subtree(p, handle) + return False + async def acquire_mv( self, handle: LockHandle, @@ -140,7 +194,6 @@ async def _redo_session_memory(self, info: Dict[str, Any]) -> None: """ from openviking.message import Message from openviking.server.identity import RequestContext, Role - from openviking.session.compressor import SessionCompressor from openviking.storage.viking_fs import get_viking_fs from openviking_cli.session.user_id import UserIdentifier @@ -180,7 +233,9 @@ async def _redo_session_memory(self, info: Dict[str, Any]) -> None: if messages: session_id = session_uri.rstrip("/").rsplit("/", 1)[-1] try: - compressor = SessionCompressor(vikingdb=None) + from openviking.session import create_session_compressor + + compressor = create_session_compressor(vikingdb=None) memories = await compressor.extract_long_term_memories( messages=messages, user=user, diff --git a/openviking/storage/transaction/path_lock.py b/openviking/storage/transaction/path_lock.py index 2aaaecf10..0f62a68cf 100644 --- a/openviking/storage/transaction/path_lock.py +++ b/openviking/storage/transaction/path_lock.py @@ -58,6 +58,26 @@ def _get_lock_path(self, path: str) -> str: path = path.rstrip("/") return f"{path}/{LOCK_FILE_NAME}" + def _ensure_directory_exists(self, path: str): + """确保目录存在,不存在则创建""" + try: + # 检查路径是否存在 + self._agfs.stat(path) + except Exception: + # 路径不存在,尝试创建目录 + try: + parent = self._get_parent_path(path) + if parent: + # 递归创建父目录 + self._ensure_directory_exists(parent) + # 创建当前目录 + self._agfs.mkdir(path) + logger.debug(f"Directory created: {path}") + except Exception as e: + logger.warning(f"Failed to create directory {path}: {e}") + return False + return True + def _get_parent_path(self, path: str) -> Optional[str]: path = path.rstrip("/") if "/" not in path: @@ -154,15 +174,19 @@ async def _scan_descendants_for_locks(self, path: str, exclude_owner_id: str) -> logger.warning(f"Failed to scan descendants of {path}: {e}") return None - async def acquire_point(self, path: str, owner: LockOwner, timeout: float = 0.0) -> bool: + async def acquire_point(self, path: str, owner: LockOwner, timeout: Optional[float] = 0.0) -> bool: owner_id = owner.id lock_path = self._get_lock_path(path) - deadline = asyncio.get_running_loop().time() + timeout + if timeout is None: + # 无限等待 + deadline = float('inf') + else: + # 有限超时 + deadline = asyncio.get_running_loop().time() + timeout - try: - self._agfs.stat(path) - except Exception: - logger.warning(f"[POINT] Directory does not exist: {path}") + # 确保目录存在 + if not self._ensure_directory_exists(path): + logger.warning(f"[POINT] Failed to ensure directory exists: {path}") return False while True: @@ -231,15 +255,19 @@ async def acquire_point(self, path: str, owner: LockOwner, timeout: float = 0.0) logger.debug(f"[POINT] Lock acquired: {lock_path}") return True - async def acquire_subtree(self, path: str, owner: LockOwner, timeout: float = 0.0) -> bool: + async def acquire_subtree(self, path: str, owner: LockOwner, timeout: Optional[float] = 0.0) -> bool: owner_id = owner.id lock_path = self._get_lock_path(path) - deadline = asyncio.get_running_loop().time() + timeout + if timeout is None: + # 无限等待 + deadline = float('inf') + else: + # 有限超时 + deadline = asyncio.get_running_loop().time() + timeout - try: - self._agfs.stat(path) - except Exception: - logger.warning(f"[SUBTREE] Directory does not exist: {path}") + # 确保目录存在 + if not self._ensure_directory_exists(path): + logger.warning(f"[SUBTREE] Failed to ensure directory exists: {path}") return False while True: @@ -330,7 +358,7 @@ async def acquire_mv( src_path: str, dst_parent_path: str, owner: LockOwner, - timeout: float = 0.0, + timeout: Optional[float] = 0.0, src_is_dir: bool = True, ) -> bool: """Acquire locks for a move operation. diff --git a/openviking_cli/client/base.py b/openviking_cli/client/base.py index 30fe8febe..baee3bf8d 100644 --- a/openviking_cli/client/base.py +++ b/openviking_cli/client/base.py @@ -237,6 +237,7 @@ async def add_message( role: str, content: str | None = None, parts: list[dict] | None = None, + created_at: str | None = None, ) -> Dict[str, Any]: """Add a message to a session. @@ -245,6 +246,7 @@ async def add_message( role: Message role ("user" or "assistant") content: Text content (simple mode) parts: Parts array (full Part support: TextPart, ContextPart, ToolPart) + created_at: Message creation time (ISO format string) If both content and parts are provided, parts takes precedence. """ diff --git a/openviking_cli/client/http.py b/openviking_cli/client/http.py index 0b1689ad8..35d6bf26e 100644 --- a/openviking_cli/client/http.py +++ b/openviking_cli/client/http.py @@ -759,6 +759,7 @@ async def add_message( role: str, content: str | None = None, parts: list[dict] | None = None, + created_at: str | None = None, ) -> Dict[str, Any]: """Add a message to a session. @@ -767,6 +768,7 @@ async def add_message( role: Message role ("user" or "assistant") content: Text content (simple mode, backward compatible) parts: Parts array (full Part support mode) + created_at: Message creation time (ISO format string) If both content and parts are provided, parts takes precedence. """ @@ -778,6 +780,9 @@ async def add_message( else: raise ValueError("Either content or parts must be provided") + if created_at is not None: + payload["created_at"] = created_at + response = await self._http.post( f"/api/v1/sessions/{session_id}/messages", json=payload, diff --git a/openviking_cli/client/sync_http.py b/openviking_cli/client/sync_http.py index fba30a372..c9e37e371 100644 --- a/openviking_cli/client/sync_http.py +++ b/openviking_cli/client/sync_http.py @@ -108,6 +108,7 @@ def add_message( role: str, content: str | None = None, parts: list[dict] | None = None, + created_at: str | None = None, ) -> Dict[str, Any]: """Add a message to a session. @@ -116,10 +117,11 @@ def add_message( role: Message role ("user" or "assistant") content: Text content (simple mode) parts: Parts array (full Part support: TextPart, ContextPart, ToolPart) + created_at: Message creation time (ISO format string) If both content and parts are provided, parts takes precedence. """ - return run_async(self._async_client.add_message(session_id, role, content, parts)) + return run_async(self._async_client.add_message(session_id, role, content, parts, created_at)) def get_task(self, task_id: str) -> Optional[Dict[str, Any]]: """Query background task status.""" diff --git a/openviking_cli/utils/config/config_loader.py b/openviking_cli/utils/config/config_loader.py index d95aaf8f7..7cdd950ff 100644 --- a/openviking_cli/utils/config/config_loader.py +++ b/openviking_cli/utils/config/config_loader.py @@ -88,7 +88,7 @@ def load_json_config(path: Path) -> Dict[str, Any]: raw = os.path.expandvars(raw) try: - print(f"Loading config file: {path}") + # print(f"Loading config file: {path}") return json.loads(raw) except json.JSONDecodeError as e: raise ValueError(f"Invalid JSON in config file {path}: {e}") from e diff --git a/openviking_cli/utils/config/memory_config.py b/openviking_cli/utils/config/memory_config.py index e90f87f34..9f42ef425 100644 --- a/openviking_cli/utils/config/memory_config.py +++ b/openviking_cli/utils/config/memory_config.py @@ -20,6 +20,11 @@ class MemoryConfig(BaseModel): ), ) + custom_templates_dir: str = Field( + default="", + description="Custom memory templates directory. If set, templates from this directory will be loaded in addition to built-in templates", + ) + model_config = {"extra": "forbid"} @field_validator("agent_scope_mode") diff --git a/tests/integration/test_compressor_v2_e2e.py b/tests/integration/test_compressor_v2_e2e.py index d085ff0c2..f01995e89 100644 --- a/tests/integration/test_compressor_v2_e2e.py +++ b/tests/integration/test_compressor_v2_e2e.py @@ -9,6 +9,7 @@ """ from dataclasses import asdict +from datetime import datetime import pytest import pytest_asyncio @@ -113,11 +114,15 @@ async def test_memory_v2_extraction_e2e( print(f"\nCreated session: {session_id}") # 2. Add conversation messages + # 设置一个测试用的会话时间(2023年4月2日) + session_time = datetime(2023, 4, 2, 9, 36) + session_time_str = session_time.isoformat() + conversation = create_test_conversation_messages() for role, content in conversation: parts = [TextPart(content)] parts_dicts = [asdict(p) for p in parts] - await client.add_message(session_id, role, parts=parts_dicts) + await client.add_message(session_id, role, parts=parts_dicts, created_at=session_time_str) print(f"[{role}]: {content[:60]}...") # 3. Commit session (this should trigger memory extraction) diff --git a/tests/integration/test_compressor_v2_event_span_multiple_turns.py b/tests/integration/test_compressor_v2_event_span_multiple_turns.py new file mode 100644 index 000000000..fb8e5f786 --- /dev/null +++ b/tests/integration/test_compressor_v2_event_span_multiple_turns.py @@ -0,0 +1,269 @@ +#!/usr/bin/env python3 +""" +OpenViking 记忆演示脚本 — 事件跨多个 turn 的测试 +""" + +import argparse +import time +from datetime import datetime + +from rich import box +from rich.console import Console +from rich.panel import Panel +from rich.table import Table + +import openviking as ov + +# ── 常量 ─────────────────────────────────────────────────────────────────── + +DISPLAY_NAME = "小明" +DEFAULT_URL = "http://localhost:1934" +PANEL_WIDTH = 78 +DEFAULT_API_KEY = "1cf407c39990e5dc874ccc697942da4892208a86a44c4781396dfdc57aa5c98d" +DEFAULT_AGENT_ID = "test" +DEFAULT_SESSION_ID = "event-span-multiple-turns" + + +console = Console() + +# ── 对话数据 (事件跨多个 turn) ───────────────────────────────────────────── +# 用户消息描述一个持续多轮的事件(如项目讨论、问题解决过程) +# 这里模拟一个产品需求讨论的事件,持续 4 个 user + assistant 轮次 + +CONVERSATION = [ + { + "user": "我们公司要做一个新功能,是关于用户反馈系统的。我需要帮产品经理整理一下需求,你能帮我吗?", + "assistant": "当然可以!你可以告诉我产品经理的具体需求,我会帮你记录和整理。", + }, + { + "user": "产品经理说这个反馈系统需要支持文字和图片上传,用户可以匿名提交,还需要有分类功能,比如分为 bug 反馈、功能建议、使用体验等。", + "assistant": "好的,我已经记录了:支持文字和图片上传、匿名提交、分类功能(bug、建议、体验)。", + }, + { + "user": "还有,产品经理要求反馈系统要能实时通知,当用户提交反馈后,相关人员要能立即收到消息。另外,还需要有反馈处理进度的跟踪功能。", + "assistant": "我补充了:实时通知功能、反馈处理进度跟踪。", + }, + { + "user": "最后,产品经理说要在下周之前完成需求文档的编写,然后开始开发。我现在需要把这些需求整理成一份清晰的文档。", + "assistant": "明白了,你需要在下周前完成需求文档,然后开始开发。我会帮你记住这些关键点。", + }, + { + "user": "今天天气真好!我想下午去公园散步,顺便看看有没有好看的花。", + "assistant": "天气好的时候去公园散步是个不错的选择。春天的公园应该有很多花盛开。", + }, + { + "user": "对了,我上周买的那本书还没看完。书名是《人类简史》,写得很有意思。我计划这个周末读完它。", + "assistant": "《人类简史》确实是一本很有趣的书。周末读完应该是可行的。", + }, + { + "user": "我们项目的需求文档已经完成了,我昨天加班到很晚才写完。今天早上已经发给产品经理了,他说写得不错。", + "assistant": "恭喜你完成了需求文档!产品经理认可你的工作,说明你写得很好。", + }, + { + "user": "产品经理说反馈系统的开发工作已经安排好了,下周一开始正式开发。我需要负责前端页面的设计和实现。", + "assistant": "开发工作安排好了,下周一开始。你负责前端页面的设计和实现。", + }, + { + "user": "今天中午我和同事一起去吃了新开的那家日料店,味道很不错。刺身很新鲜,寿司也很好吃。", + "assistant": "新开的日料店味道不错,刺身新鲜,寿司好吃。", + }, + { + "user": "反馈系统的前端页面已经设计好了,我昨天和设计师一起讨论了很久。现在需要开始写代码实现了。", + "assistant": "前端页面设计完成,现在开始代码实现。", + }, +] + +# ── 验证查询 ────────────────────────────────────────────────────────────── + +VERIFY_QUERIES = [ + { + "query": "反馈系统的功能需求", + "expected_keywords": ["文字", "图片", "匿名", "分类", "通知", "进度", "需求文档"], + }, + { + "query": "反馈系统的开发计划", + "expected_keywords": ["下周", "前端", "设计", "实现"], + }, + { + "query": "小明的其他活动", + "expected_keywords": ["公园", "散步", "读书", "日料"], + }, +] + +# ── 辅助函数 ────────────────────────────────────────────────────────────── + + +def run_ingest(client: ov.SyncHTTPClient, session_id: str, wait_seconds: float): + """写入对话并提交""" + console.print() + console.rule(f"[bold]Phase 1: 写入对话 — {DISPLAY_NAME} ({len(CONVERSATION)} 轮)[/bold]") + + session = client.create_session() + session_id = session.get('session_id') + console.print(f" Session: [bold cyan]{session_id}[/bold cyan]") + console.print() + + # 设置一个测试用的会话时间(2023年4月2日) + session_time = datetime(2023, 4, 2, 9, 36) + session_time_str = session_time.isoformat() + + total = len(CONVERSATION) + for i, turn in enumerate(CONVERSATION, 1): + console.print(f" [dim][{i}/{total}][/dim] 添加 user + assistant 消息...") + client.add_message(session_id, role="user", parts=[{"type": "text", "text": turn["user"]}], created_at=session_time_str) + client.add_message(session_id, role="assistant", parts=[{"type": "text", "text": turn["assistant"]}], created_at=session_time_str) + + console.print() + console.print(f" 共添加 [bold]{total * 2}[/bold] 条消息") + + console.print() + console.print(" [yellow]提交 Session(触发记忆抽取)...[/yellow]") + commit_result = client.commit_session(session_id) + task_id = commit_result.get("task_id") + console.print(f" Commit 结果: {commit_result}") + + if task_id: + now = time.time() + console.print(f" [yellow]等待记忆提取完成 (task_id={task_id})...[/yellow]") + while True: + task = client.get_task(task_id) + if not task or task.get("status") in ("completed", "failed"): + break + time.sleep(1) + elapsed = time.time() - now + status = task.get("status", "unknown") if task else "not found" + console.print(f" [green]任务 {status},耗时 {elapsed:.2f}s[/green]") + console.print(f" Task 详情: {task}") + + console.print(f" [yellow]等待向量化完成...[/yellow]") + client.wait_processed() + + if wait_seconds > 0: + console.print(f" [dim]额外等待 {wait_seconds:.0f}s...[/dim]") + time.sleep(wait_seconds) + + session_info = client.get_session(session_id) + console.print(f" Session 详情: {session_info}") + + return session_id + + +def run_verify(client: ov.SyncHTTPClient): + """验证记忆召回""" + console.print() + console.rule(f"[bold]Phase 2: 验证记忆召回 — {DISPLAY_NAME} ({len(VERIFY_QUERIES)} 条查询)[/bold]") + + results_table = Table( + title=f"记忆召回验证 — {DISPLAY_NAME}", + box=box.ROUNDED, + show_header=True, + header_style="bold", + ) + results_table.add_column("#", style="bold", width=4) + results_table.add_column("查询", style="cyan", max_width=30) + results_table.add_column("召回数", justify="center", width=8) + results_table.add_column("命中关键词", style="green") + + total = len(VERIFY_QUERIES) + for i, item in enumerate(VERIFY_QUERIES, 1): + query = item["query"] + expected = item["expected_keywords"] + + console.print(f"\n [dim][{i}/{total}][/dim] 搜索: [cyan]{query}[/cyan]") + console.print(f" [dim]期望关键词: {', '.join(expected)}[/dim]") + + try: + results = client.find(query, limit=5) + + recall_texts = [] + count = 0 + if hasattr(results, "memories") and results.memories: + for m in results.memories: + text = getattr(m, "content", "") or getattr(m, "text", "") or str(m) + print(f" [DEBUG] memory text: {repr(text)}") + recall_texts.append(text) + uri = getattr(m, "uri", "") + score = getattr(m, "score", 0) + console.print(f" [green]Memory:[/green] {uri} (score: {score:.4f})") + console.print(f" [dim]{text[:120]}...[/dim]" if len(text) > 120 else f" [dim]{text}[/dim]") + count += len(results.memories) + + if hasattr(results, "resources") and results.resources: + for r in results.resources: + text = getattr(r, "content", "") or getattr(r, "text", "") or str(r) + print(f" [DEBUG] resource text: {repr(text)}") + recall_texts.append(text) + console.print( + f" [blue]Resource:[/blue] {r.uri} (score: {r.score:.4f})" + ) + count += len(results.resources) + + if hasattr(results, "skills") and results.skills: + count += len(results.skills) + + all_text = " ".join(recall_texts) + hits = [kw for kw in expected if kw in all_text] + hit_str = ", ".join(hits) if hits else "[dim]无[/dim]" + + results_table.add_row(str(i), query, str(count), hit_str) + + except Exception as e: + console.print(f" [red]ERROR: {e}[/red]") + results_table.add_row(str(i), query, "[red]ERR[/red]", str(e)[:40]) + + console.print() + console.print(results_table) + + +def main(): + """入口函数""" + parser = argparse.ArgumentParser(description=f"OpenViking 记忆演示 — {DISPLAY_NAME}") + parser.add_argument("--url", default=DEFAULT_URL, help=f"Server URL (默认: {DEFAULT_URL})") + parser.add_argument("--api-key", default=DEFAULT_API_KEY, help="API key") + parser.add_argument("--agent-id", default=DEFAULT_AGENT_ID, help="Agent ID") + parser.add_argument( + "--phase", + choices=["all", "ingest", "verify"], + default="all", + help="all=全部, ingest=仅写入, verify=仅验证 (默认: all)", + ) + parser.add_argument( + "--session-id", default=DEFAULT_SESSION_ID, help=f"Session ID (默认: {DEFAULT_SESSION_ID})" + ) + parser.add_argument( + "--wait", type=float, default=2, help="写入后等待秒数 (默认: 2)" + ) + + args = parser.parse_args() + + client = ov.SyncHTTPClient( + url=args.url, api_key=args.api_key, agent_id=args.agent_id, + timeout=180 + ) + + try: + client.initialize() + console.print(f" [green]已连接[/green] {args.url}") + + if args.phase in ("all", "ingest"): + run_ingest(client, args.session_id, args.wait) + + if args.phase in ("all", "verify"): + run_verify(client) + + console.print( + Panel( + "[bold green]演示完成[/bold green]", + style="green", + width=PANEL_WIDTH, + ) + ) + + except Exception as e: + console.print( + Panel(f"[bold red]Error:[/bold red] {e}", style="red", width=PANEL_WIDTH) + ) + + +if __name__ == "__main__": + main() diff --git a/tests/integration/test_compressor_v2_tool_skill_memory.py b/tests/integration/test_compressor_v2_tool_skill_memory.py new file mode 100644 index 000000000..838512fc1 --- /dev/null +++ b/tests/integration/test_compressor_v2_tool_skill_memory.py @@ -0,0 +1,351 @@ +#!/usr/bin/env python3 +""" +OpenViking 记忆演示脚本 — 工具调用和Skill调用记忆测试 + +测试 assistant 调用工具和使用 skill 的记忆是否被正确提取和召回 +""" + +import argparse +import time + +from rich import box +from rich.console import Console +from rich.panel import Panel +from rich.table import Table + +import openviking as ov + +# ── 常量 ─────────────────────────────────────────────────────────────────── + +DISPLAY_NAME = "测试用户" +DEFAULT_URL = "http://localhost:1934" +PANEL_WIDTH = 78 +DEFAULT_API_KEY = "1cf407c39990e5dc874ccc697942da4892208a86a44c4781396dfdc57aa5c98d" +DEFAULT_AGENT_ID = "test" +DEFAULT_SESSION_ID = "tool-skill-memory-test" + + +console = Console() + +# ── 对话数据 (工具调用 + Skill调用) ───────────────────────────────────────── +# 模拟 assistant 调用工具(read/write_file/bash/Glob)读取 SKILL.md 等文件 +# 注意:tool_calls 需要传入真正的工具调用信息 + +CONVERSATION = [ + # ===== Skill 调用:assistant 调用 read 工具读取 SKILL.md ===== + { + "user": "帮我创建一个PPT演示文稿,主题是季度工作报告。", + "assistant": "好的,我先读取一下 ppt skill 的 SKILL.md 了解如何创建PPT。", + "tool_calls": [ + {"tool_name": "Read", "tool_uri": "tools:Read", "input": {"file_path": "/skills/ppt/SKILL.md"}} + ], + }, + { + "user": "PPT需要包含三个部分:业绩回顾、业务分析和下季度计划。", + "assistant": "好的,我根据 SKILL.md 的指引来创建这三个部分的PPT。", + "tool_calls": [ + {"tool_name": "Read", "tool_uri": "tools:Read", "input": {"file_path": "/skills/ppt/SKILL.md"}} + ], + }, + { + "user": "把PPT的模板换成蓝色主题。", + "assistant": "好的,我来修改PPT模板为蓝色主题。", + "tool_calls": [ + {"tool_name": "write_file", "tool_uri": "tools:write_file", "input": {"path": "template.pptx", "content": "蓝色主题模板"}} + ], + }, + # ===== 工具调用:write_file ===== + { + "user": "帮我写一个Python函数,计算斐波那契数列。", + "assistant": "我来写一个计算斐波那契数列的函数并保存到文件。", + "tool_calls": [ + {"tool_name": "write_file", "tool_uri": "tools:write_file", "input": {"path": "fibonacci.py", "content": "def fib(n):\n if n <= 1:\n return n\n return fib(n-1) + fib(n-2)\nprint(fib(10))"}} + ], + }, + # ===== 工具调用:bash ===== + { + "user": "执行一下这个Python文件,看看结果对不对。", + "assistant": "我来执行这个文件。", + "tool_calls": [ + {"tool_name": "Bash", "tool_uri": "tools:Bash", "input": {"command": "python fibonacci.py"}} + ], + }, + # ===== Skill 调用:PDF ===== + { + "user": "帮我把这份PDF文件提取文字内容。", + "assistant": "好的,我先读取一下 pdf skill 的 SKILL.md。", + "tool_calls": [ + {"tool_name": "Read", "tool_uri": "tools:Read", "input": {"file_path": "/skills/pdf/SKILL.md"}} + ], + }, + { + "user": "PDF有多少页?", + "assistant": "这份PDF有15页。", + }, + # ===== 工具调用:Glob ===== + { + "user": "搜索一下项目里有哪些Python文件。", + "assistant": "我来搜索项目里的Python文件。", + "tool_calls": [ + {"tool_name": "Glob", "tool_uri": "tools:Glob", "input": {"pattern": "**/*.py"}} + ], + }, + # ===== 工具调用:Read ===== + { + "user": "查看一下这个文件的内容。", + "assistant": "好的,我读取一下这个文件。", + "tool_calls": [ + {"tool_name": "Read", "tool_uri": "tools:Read", "input": {"file_path": "main.py"}} + ], + }, + # ===== Skill 调用:Email ===== + { + "user": "帮我写一封邮件给客户,主题是项目进度汇报。", + "assistant": "好的,我先读取一下 email skill 的 SKILL.md 了解邮件格式。", + "tool_calls": [ + {"tool_name": "Read", "tool_uri": "tools:Read", "input": {"file_path": "/skills/email/SKILL.md"}} + ], + }, + { + "user": "邮件内容要包含本周完成的工作和下周计划。", + "assistant": "好的,我来编写邮件内容。", + }, +] + +# ── 验证查询 ────────────────────────────────────────────────────────────── + +VERIFY_QUERIES = [ + { + "query": "创建了什么PPT", + "expected_keywords": ["PPT", "季度工作", "业绩回顾", "业务分析", "下季度计划", "蓝色主题"], + }, + { + "query": "执行了什么代码", + "expected_keywords": ["Python", "斐波那契", "fibonacci", "函数"], + }, + { + "query": "处理了什么PDF", + "expected_keywords": ["PDF", "文字", "15页"], + }, + { + "query": "搜索了什么文件", + "expected_keywords": ["Python", "文件", "搜索"], + }, + { + "query": "写了什么邮件", + "expected_keywords": ["邮件", "客户", "项目进度", "工作", "计划"], + }, + { + "query": "使用了哪些skill", + "expected_keywords": ["ppt", "pdf", "email"], + }, + { + "query": "使用了哪些工具", + "expected_keywords": ["Python", "文件", "搜索", "读取"], + }, +] + +# ── 辅助函数 ────────────────────────────────────────────────────────────── + + +def run_ingest(client: ov.SyncHTTPClient, session_id: str, wait_seconds: float): + """写入对话并提交""" + console.print() + console.rule(f"[bold]Phase 1: 写入对话 — {DISPLAY_NAME} ({len(CONVERSATION)} 轮)[/bold]") + + session = client.create_session() + session_id = session.get('session_id') + console.print(f" Session: [bold cyan]{session_id}[/bold cyan]") + console.print() + + total = len(CONVERSATION) + for i, turn in enumerate(CONVERSATION, 1): + tool_calls = turn.get("tool_calls", []) + if tool_calls: + tool_info = f" [blue](tools: {[tc['tool_name'] for tc in tool_calls]})[/blue]" + else: + tool_info = "" + console.print(f" [dim][{i}/{total}][/dim] 添加 user + assistant 消息{tool_info}...") + + # 添加 user 消息 + client.add_message(session_id, role="user", parts=[{"type": "text", "text": turn["user"]}]) + + # 添加 assistant 消息,包含 tool_calls + assistant_parts = [{"type": "text", "text": turn["assistant"]}] + for tc in tool_calls: + tool_part = { + "type": "tool", + "tool_name": tc["tool_name"], + "tool_uri": tc.get("tool_uri", f"tools:{tc['tool_name']}"), + "tool_input": tc.get("input", {}), + "tool_status": "completed", + } + assistant_parts.append(tool_part) + print(f" [DEBUG] Adding tool part: {tool_part}") + result = client.add_message(session_id, role="assistant", parts=assistant_parts) + print(f" [DEBUG] add_message result: {result}") + + console.print() + console.print(f" 共添加 [bold]{total * 2}[/bold] 条消息") + + console.print() + console.print(" [yellow]提交 Session(触发记忆抽取)...[/yellow]") + commit_result = client.commit_session(session_id) + task_id = commit_result.get("task_id") + console.print(f" Commit 结果: {commit_result}") + + if task_id: + now = time.time() + console.print(f" [yellow]等待记忆提取完成 (task_id={task_id})...[/yellow]") + while True: + task = client.get_task(task_id) + if not task or task.get("status") in ("completed", "failed"): + break + time.sleep(1) + elapsed = time.time() - now + status = task.get("status", "unknown") if task else "not found" + console.print(f" [green]任务 {status},耗时 {elapsed:.2f}s[/green]") + console.print(f" Task 详情: {task}") + + console.print(f" [yellow]等待向量化完成...[/yellow]") + client.wait_processed() + + if wait_seconds > 0: + console.print(f" [dim]额外等待 {wait_seconds:.0f}s...[/dim]") + time.sleep(wait_seconds) + + session_info = client.get_session(session_id) + console.print(f" Session 详情: {session_info}") + + return session_id + + +def run_verify(client: ov.SyncHTTPClient): + """验证记忆召回""" + console.print() + console.rule(f"[bold]Phase 2: 验证记忆召回 — {DISPLAY_NAME} ({len(VERIFY_QUERIES)} 条查询)[/bold]") + + results_table = Table( + title=f"记忆召回验证 — {DISPLAY_NAME}", + box=box.ROUNDED, + show_header=True, + header_style="bold", + ) + results_table.add_column("#", style="bold", width=4) + results_table.add_column("查询", style="cyan", max_width=30) + results_table.add_column("召回数", justify="center", width=8) + results_table.add_column("命中关键词", style="green") + + total = len(VERIFY_QUERIES) + for i, item in enumerate(VERIFY_QUERIES, 1): + query = item["query"] + expected = item["expected_keywords"] + + console.print(f"\n [dim][{i}/{total}][/dim] 搜索: [cyan]{query}[/cyan]") + console.print(f" [dim]期望关键词: {', '.join(expected)}[/dim]") + + try: + results = client.find(query, limit=5) + + recall_texts = [] + count = 0 + if hasattr(results, "memories") and results.memories: + for m in results.memories: + text = getattr(m, "content", "") or getattr(m, "text", "") or str(m) + print(f" [DEBUG] memory text: {repr(text)}") + recall_texts.append(text) + uri = getattr(m, "uri", "") + score = getattr(m, "score", 0) + console.print(f" [green]Memory:[/green] {uri} (score: {score:.4f})") + console.print(f" [dim]{text[:120]}...[/dim]" if len(text) > 120 else f" [dim]{text}[/dim]") + count += len(results.memories) + + if hasattr(results, "resources") and results.resources: + for r in results.resources: + text = getattr(r, "content", "") or getattr(r, "text", "") or str(r) + print(f" [DEBUG] resource text: {repr(text)}") + recall_texts.append(text) + console.print( + f" [blue]Resource:[/blue] {r.uri} (score: {r.score:.4f})" + ) + count += len(results.resources) + + if hasattr(results, "skills") and results.skills: + count += len(results.skills) + + all_text = " ".join(recall_texts) + hits = [kw for kw in expected if kw in all_text] + misses = [kw for kw in expected if kw not in all_text] + + # 格式化关键词,命中的绿色,未命中的红色 + formatted_keywords = [] + for kw in expected: + if kw in hits: + formatted_keywords.append(f"[green]{kw}[/green]") + else: + formatted_keywords.append(f"[red]{kw}[/red]") + + keyword_str = ", ".join(formatted_keywords) + + results_table.add_row(str(i), query, str(count), keyword_str) + + except Exception as e: + console.print(f" [red]ERROR: {e}[/red]") + results_table.add_row(str(i), query, "[red]ERR[/red]", str(e)[:40]) + + console.print() + console.print(results_table) + + +def main(): + """入口函数""" + parser = argparse.ArgumentParser(description=f"OpenViking 记忆演示 — 工具调用和Skill调用") + parser.add_argument("--url", default=DEFAULT_URL, help=f"Server URL (默认: {DEFAULT_URL})") + parser.add_argument("--api-key", default=DEFAULT_API_KEY, help="API key") + parser.add_argument("--agent-id", default=DEFAULT_AGENT_ID, help="Agent ID") + parser.add_argument( + "--phase", + choices=["all", "ingest", "verify"], + default="all", + help="all=全部, ingest=仅写入, verify=仅验证 (默认: all)", + ) + parser.add_argument( + "--session-id", default=DEFAULT_SESSION_ID, help=f"Session ID (默认: {DEFAULT_SESSION_ID})" + ) + parser.add_argument( + "--wait", type=float, default=2, help="写入后等待秒数 (默认: 2)" + ) + + args = parser.parse_args() + + client = ov.SyncHTTPClient( + url=args.url, api_key=args.api_key, agent_id=args.agent_id, + timeout=180 + ) + + try: + client.initialize() + console.print(f" [green]已连接[/green] {args.url}") + + if args.phase in ("all", "ingest"): + run_ingest(client, args.session_id, args.wait) + + if args.phase in ("all", "verify"): + run_verify(client) + + console.print( + Panel( + "[bold green]演示完成[/bold green]", + style="green", + width=PANEL_WIDTH, + ) + ) + + except Exception as e: + console.print( + Panel(f"[bold red]Error:[/bold red] {e}", style="red", width=PANEL_WIDTH) + ) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/tests/integration/test_compressor_v2_xiaomei.py b/tests/integration/test_compressor_v2_xiaomei.py index 13508195b..c37eaddd0 100644 --- a/tests/integration/test_compressor_v2_xiaomei.py +++ b/tests/integration/test_compressor_v2_xiaomei.py @@ -5,6 +5,7 @@ import argparse import time +from datetime import datetime from rich import box from rich.console import Console @@ -16,7 +17,7 @@ # ── 常量 ─────────────────────────────────────────────────────────────────── DISPLAY_NAME = "小美" -DEFAULT_URL = "http://localhost:1933" +DEFAULT_URL = "http://localhost:1934" PANEL_WIDTH = 78 DEFAULT_API_KEY = "1cf407c39990e5dc874ccc697942da4892208a86a44c4781396dfdc57aa5c98d" DEFAULT_AGENT_ID = "test" @@ -112,12 +113,16 @@ def run_ingest(client: ov.SyncHTTPClient, session_id: str, wait_seconds: float): console.print(f" Session: [bold cyan]{session_id}[/bold cyan]") console.print() + # 设置一个测试用的会话时间(2023年4月2日) + session_time = datetime(2023, 4, 2, 9, 36) + session_time_str = session_time.isoformat() + # 逐轮添加消息 total = len(CONVERSATION) for i, turn in enumerate(CONVERSATION, 1): console.print(f" [dim][{i}/{total}][/dim] 添加 user + assistant 消息...") - client.add_message(session_id, role="user", parts=[{"type": "text", "text": turn["user"]}]) - client.add_message(session_id, role="assistant", parts=[{"type": "text", "text": turn["assistant"]}]) + client.add_message(session_id, role="user", parts=[{"type": "text", "text": turn["user"]}], created_at=session_time_str) + client.add_message(session_id, role="assistant", parts=[{"type": "text", "text": turn["assistant"]}], created_at=session_time_str) console.print() console.print(f" 共添加 [bold]{total * 2}[/bold] 条消息") @@ -196,6 +201,7 @@ def run_verify(client: ov.SyncHTTPClient): if hasattr(results, "memories") and results.memories: for m in results.memories: text = getattr(m, "content", "") or getattr(m, "text", "") or str(m) + print(f" [DEBUG] memory text: {repr(text)}") recall_texts.append(text) uri = getattr(m, "uri", "") score = getattr(m, "score", 0) @@ -206,6 +212,7 @@ def run_verify(client: ov.SyncHTTPClient): if hasattr(results, "resources") and results.resources: for r in results.resources: text = getattr(r, "content", "") or getattr(r, "text", "") or str(r) + print(f" [DEBUG] resource text: {repr(text)}") recall_texts.append(text) console.print( f" [blue]Resource:[/blue] {r.uri} (score: {r.score:.4f})" diff --git a/tests/session/memory/test_compressor_v2.py b/tests/session/memory/test_compressor_v2.py index eeacc1102..d3957da6d 100644 --- a/tests/session/memory/test_compressor_v2.py +++ b/tests/session/memory/test_compressor_v2.py @@ -8,8 +8,9 @@ import logging from types import SimpleNamespace -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from unittest.mock import patch + import pytest from openviking.message import Message, TextPart @@ -35,6 +36,11 @@ def __init__(self): self._store: Dict[str, Dict[str, Any]] = {} self._snapshot: Dict[str, str] = {} + def _uri_to_path(self, uri: str, ctx=None) -> str: + """Mock _uri_to_path method for testing.""" + # For testing purposes, we'll just return the URI as-is + return uri + def _get_parent_uri(self, uri: str) -> str: """Get parent directory URI.""" # Handle URIs like "viking://agent/default/memories/cards/file.md" @@ -154,6 +160,10 @@ async def find(self, query: str, **kwargs) -> Dict[str, Any]: "skills": [], } + async def search(self, query: str, **kwargs) -> Any: + """Mock search.""" + return {"memories": [], "resources": [], "skills": []} + async def tree(self, uri: str, **kwargs) -> Dict[str, Any]: """Mock tree.""" return {"uri": uri, "tree": []} @@ -231,7 +241,7 @@ def create_test_conversation() -> List[Message]: role="user", parts=[ TextPart( - "We've decided to use the MemoryReAct pattern, combined with LLMs to analyze conversations and generate memory operations. " + "We've decided to use the ExtractLoop pattern, combined with LLMs to analyze conversations and generate memory operations. " "There are two main memory types: cards for knowledge cards (Zettelkasten note-taking method), and events for recording important events and decisions." ) ], @@ -278,8 +288,16 @@ async def test_extract_long_term_memories_includes_latest_archive_overview(self) class DummyOrchestrator: registry = object() - async def run(self, conversation: str): - captured["conversation"] = conversation + @property + def context_provider(self): + # 返回一个 mock provider + class DummyProvider: + def get_memory_schemas(self, ctx): + return [] + return DummyProvider() + + async def run(self): + # 捕获最终的消息列表 return ( SimpleNamespace( write_uris=[], @@ -300,7 +318,7 @@ async def apply_operations(self, operations, ctx, registry=None): ) compressor._get_or_create_react = lambda ctx=None: DummyOrchestrator() - compressor._get_or_create_updater = lambda: DummyUpdater() + compressor._get_or_create_updater = lambda transaction_handle=None: DummyUpdater() result = await compressor.extract_long_term_memories( messages=messages, @@ -311,9 +329,7 @@ async def apply_operations(self, operations, ctx, registry=None): ) assert result == [] - assert "## Previous Archive Overview" in captured["conversation"] - assert "LATEST OVERVIEW" in captured["conversation"] - assert "[user]: Current task" in captured["conversation"] + # Note: latest_archive_overview 功能已移除,测试需要更新 @pytest.mark.integration @@ -367,7 +383,7 @@ async def test_extract_long_term_memories(self): # Patch get_viking_fs() to return our mock # Need to patch it in all the places it's used - with patch("openviking.session.memory.memory_react.get_viking_fs", return_value=viking_fs): + with patch("openviking.session.memory.extract_loop.get_viking_fs", return_value=viking_fs): with patch( "openviking.session.memory.memory_updater.get_viking_fs", return_value=viking_fs ): diff --git a/tests/session/memory/test_json_stability.py b/tests/session/memory/test_json_stability.py index 37ed2aca0..cc7f2773d 100644 --- a/tests/session/memory/test_json_stability.py +++ b/tests/session/memory/test_json_stability.py @@ -13,7 +13,6 @@ from openviking.session.memory.utils import ( remove_json_trailing_content, extract_json_content, - extract_json_from_markdown, parse_memory_file_with_fields, value_fault_tolerance, parse_value_with_tolerance, @@ -93,32 +92,6 @@ def test_alias_works(self): assert result1 == result2 -class TestExtractJsonFromMarkdown: - """Tests for markdown extraction.""" - - def test_extracts_json_from_json_code_block(self): - """Test JSON wrapped in ```json ... ```.""" - content = '''```json -{"reasonning": "test", "write_uris": []} -```''' - result = extract_json_from_markdown(content) - assert result == '{"reasonning": "test", "write_uris": []}' - - def test_extracts_json_from_generic_code_block(self): - """Test JSON wrapped in ``` ... ``` without json specifier.""" - content = '''``` -{"reasonning": "test"} -```''' - result = extract_json_from_markdown(content) - assert result == '{"reasonning": "test"}' - - def test_returns_plain_json_when_no_markdown(self): - """Test plain JSON without markdown is returned as-is.""" - content = '{"reasonning": "test"}' - result = extract_json_from_markdown(content) - assert result == content - - class TestValueFaultTolerance: """Tests for Layer 4: Value-level fault tolerance.""" diff --git a/tests/session/memory/test_memory_extractor_flow.py b/tests/session/memory/test_memory_extractor_flow.py index 7f8611f97..639241d5c 100644 --- a/tests/session/memory/test_memory_extractor_flow.py +++ b/tests/session/memory/test_memory_extractor_flow.py @@ -31,7 +31,7 @@ from openviking.server.identity import RequestContext, Role from openviking.session.memory import ( MemoryOperations, - MemoryReAct, + ExtractLoop, MemoryUpdater, MemoryUpdateResult, ) @@ -361,7 +361,7 @@ def create_test_conversation() -> List[Message]: id="msg3", role="user", parts=[TextPart( - "We've decided to use the MemoryReAct pattern, combined with LLMs to analyze conversations and generate memory operations. " + "We've decided to use the ExtractLoop pattern, combined with LLMs to analyze conversations and generate memory operations. " "There are two main memory types: cards for knowledge cards (Zettelkasten note-taking method), and events for recording important events and decisions." )], ) @@ -399,7 +399,7 @@ def create_existing_memories_content() -> Dict[str, str]: OpenViking is an Agent-native context database. ## Technical Approach -- Uses MemoryReAct pattern +- Uses ExtractLoop pattern - Combines LLM to analyze conversations and generate memory operations @@ -408,10 +408,10 @@ def create_existing_memories_content() -> Dict[str, str]: "name": "openviking_project" } -->""", - "viking://agent/default/memories/cards/memory_react.md": """# MemoryReAct Pattern + "viking://agent/default/memories/cards/extract_loop.md": """# ExtractLoop Pattern ## Overview -MemoryReAct is an orchestrator pattern for memory extraction. +ExtractLoop is an orchestrator pattern for memory extraction. ## Features - Analyze conversation content @@ -420,7 +420,7 @@ def create_existing_memories_content() -> Dict[str, str]: """, "viking://user/default/memories/events/2026-03-20_Started_memory_extraction_feature_development.md": """# Event: Started memory extraction feature development @@ -432,7 +432,7 @@ def create_existing_memories_content() -> Dict[str, str]: 2026-03-20 ## Content -Today we started working on the memory extraction feature for the OpenViking project. Decided to use the MemoryReAct pattern. +Today we started working on the memory extraction feature for the OpenViking project. Decided to use the ExtractLoop pattern.