-
Notifications
You must be signed in to change notification settings - Fork 805
fix(session): tolerate corrupted startup artifacts #1728
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
4a92ac1
f8fc6e0
9228044
122447c
9938f6e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,10 +4,12 @@ | |
| import json | ||
| from collections.abc import Sequence | ||
| from pathlib import Path | ||
| from typing import Any, cast | ||
|
|
||
| import aiofiles | ||
| import aiofiles.os | ||
| from kosong.message import Message | ||
| from pydantic import ValidationError | ||
|
|
||
| from kimi_cli.soul.compaction import estimate_text_tokens | ||
| from kimi_cli.soul.message import system | ||
|
|
@@ -38,24 +40,26 @@ async def restore(self) -> bool: | |
| return False | ||
|
|
||
| messages_after_last_usage: list[Message] = [] | ||
| async with aiofiles.open(self._file_backend, encoding="utf-8") as f: | ||
| async with aiofiles.open(self._file_backend, encoding="utf-8", errors="replace") as f: | ||
| line_no = 0 | ||
| async for line in f: | ||
| line_no += 1 | ||
| if not line.strip(): | ||
| continue | ||
|
Comment on lines
42
to
48
|
||
| line_json = json.loads(line, strict=False) | ||
| if line_json["role"] == "_system_prompt": | ||
| self._system_prompt = line_json["content"] | ||
| line_json = self._parse_context_line( | ||
| line, | ||
| file_backend=self._file_backend, | ||
| line_no=line_no, | ||
| ) | ||
| if line_json is None: | ||
| continue | ||
| if line_json["role"] == "_usage": | ||
| self._token_count = line_json["token_count"] | ||
| messages_after_last_usage.clear() | ||
| continue | ||
| if line_json["role"] == "_checkpoint": | ||
| self._next_checkpoint_id = line_json["id"] + 1 | ||
| continue | ||
| message = Message.model_validate(line_json) | ||
| self._history.append(message) | ||
| messages_after_last_usage.append(message) | ||
| self._apply_context_record( | ||
| line_json, | ||
| history=self._history, | ||
| messages_after_last_usage=messages_after_last_usage, | ||
| file_backend=self._file_backend, | ||
| line_no=line_no, | ||
| ) | ||
|
|
||
| self._pending_token_estimate = estimate_text_tokens(messages_after_last_usage) | ||
| return True | ||
|
|
@@ -164,29 +168,34 @@ async def revert_to(self, checkpoint_id: int): | |
| self._system_prompt = None | ||
| messages_after_last_usage: list[Message] = [] | ||
| async with ( | ||
| aiofiles.open(rotated_file_path, encoding="utf-8") as old_file, | ||
| aiofiles.open(rotated_file_path, encoding="utf-8", errors="replace") as old_file, | ||
| aiofiles.open(self._file_backend, "w", encoding="utf-8") as new_file, | ||
| ): | ||
| line_no = 0 | ||
| async for line in old_file: | ||
| line_no += 1 | ||
| if not line.strip(): | ||
| continue | ||
|
|
||
| line_json = json.loads(line, strict=False) | ||
| if line_json["role"] == "_checkpoint" and line_json["id"] == checkpoint_id: | ||
| line_json = self._parse_context_line( | ||
| line, | ||
| file_backend=rotated_file_path, | ||
| line_no=line_no, | ||
| ) | ||
| if line_json is None: | ||
| continue | ||
| if line_json.get("role") == "_checkpoint" and line_json.get("id") == checkpoint_id: | ||
| break | ||
|
|
||
| await new_file.write(line) | ||
| if line_json["role"] == "_system_prompt": | ||
| self._system_prompt = line_json["content"] | ||
| elif line_json["role"] == "_usage": | ||
| self._token_count = line_json["token_count"] | ||
| messages_after_last_usage.clear() | ||
| elif line_json["role"] == "_checkpoint": | ||
| self._next_checkpoint_id = line_json["id"] + 1 | ||
| else: | ||
| message = Message.model_validate(line_json) | ||
| self._history.append(message) | ||
| messages_after_last_usage.append(message) | ||
| keep_line = self._apply_context_record( | ||
| line_json, | ||
| history=self._history, | ||
| messages_after_last_usage=messages_after_last_usage, | ||
| file_backend=rotated_file_path, | ||
| line_no=line_no, | ||
| ) | ||
| if keep_line: | ||
| await new_file.write(line) | ||
|
|
||
| self._pending_token_estimate = estimate_text_tokens(messages_after_last_usage) | ||
|
|
||
|
|
@@ -237,3 +246,94 @@ async def update_token_count(self, token_count: int): | |
|
|
||
| async with aiofiles.open(self._file_backend, "a", encoding="utf-8") as f: | ||
| await f.write(json.dumps({"role": "_usage", "token_count": token_count}) + "\n") | ||
|
|
||
| def _parse_context_line( | ||
| self, | ||
| line: str, | ||
| *, | ||
| file_backend: Path, | ||
| line_no: int, | ||
| ) -> dict[str, Any] | None: | ||
| try: | ||
| line_json = json.loads(line, strict=False) | ||
| except json.JSONDecodeError as exc: | ||
| logger.warning( | ||
| "Skipping malformed context line {line_no} in {file}: {error}", | ||
| line_no=line_no, | ||
| file=file_backend, | ||
| error=exc, | ||
| ) | ||
| return None | ||
| if not isinstance(line_json, dict): | ||
| logger.warning( | ||
| "Skipping non-object context line {line_no} in {file}", | ||
| line_no=line_no, | ||
| file=file_backend, | ||
| ) | ||
| return None | ||
| return cast(dict[str, Any], line_json) | ||
|
|
||
| def _apply_context_record( | ||
| self, | ||
| line_json: dict[str, Any], | ||
| *, | ||
| history: list[Message], | ||
| messages_after_last_usage: list[Message], | ||
| file_backend: Path, | ||
| line_no: int, | ||
| ) -> bool: | ||
| role = line_json.get("role") | ||
| if not isinstance(role, str): | ||
| logger.warning( | ||
| "Skipping context line {line_no} in {file}: missing or invalid role", | ||
| line_no=line_no, | ||
| file=file_backend, | ||
| ) | ||
| return False | ||
| if role == "_system_prompt": | ||
| content = line_json.get("content") | ||
| if not isinstance(content, str): | ||
| logger.warning( | ||
| "Skipping invalid system prompt line {line_no} in {file}", | ||
| line_no=line_no, | ||
| file=file_backend, | ||
| ) | ||
| return False | ||
| self._system_prompt = content | ||
| return True | ||
| if role == "_usage": | ||
| token_count = line_json.get("token_count") | ||
| if not isinstance(token_count, int): | ||
| logger.warning( | ||
| "Skipping invalid usage line {line_no} in {file}", | ||
| line_no=line_no, | ||
| file=file_backend, | ||
| ) | ||
| return False | ||
| self._token_count = token_count | ||
| messages_after_last_usage.clear() | ||
| return True | ||
| if role == "_checkpoint": | ||
| checkpoint_id = line_json.get("id") | ||
| if not isinstance(checkpoint_id, int): | ||
| logger.warning( | ||
| "Skipping invalid checkpoint line {line_no} in {file}", | ||
| line_no=line_no, | ||
| file=file_backend, | ||
| ) | ||
| return False | ||
| self._next_checkpoint_id = checkpoint_id + 1 | ||
| return True | ||
| try: | ||
| message = Message.model_validate(line_json) | ||
| except ValidationError as exc: | ||
| logger.warning( | ||
| "Skipping invalid context message line {line_no} in {file}: {error}", | ||
| line_no=line_no, | ||
| file=file_backend, | ||
| error=exc, | ||
| ) | ||
| return False | ||
| history.append(message) | ||
| messages_after_last_usage.append(message) | ||
| return True | ||
Uh oh!
There was an error while loading. Please reload this page.