diff --git a/notebook_intelligence/ai_service_manager.py b/notebook_intelligence/ai_service_manager.py index 8b6061b..f4f84bb 100644 --- a/notebook_intelligence/ai_service_manager.py +++ b/notebook_intelligence/ai_service_manager.py @@ -12,6 +12,7 @@ from notebook_intelligence.api import ButtonData, ChatModel, EmbeddingModel, InlineCompletionModel, LLMProvider, ChatParticipant, ChatRequest, ChatResponse, CompletionContext, ContextRequest, Host, CompletionContextProvider, MCPPrompt, MCPServer, MarkdownData, NotebookIntelligenceExtension, RegistrationError, TelemetryEvent, TelemetryListener, Tool, Toolset from notebook_intelligence.base_chat_participant import BaseChatParticipant from notebook_intelligence.config import NBIConfig +from notebook_intelligence.mysql_manager import MySQLManager from notebook_intelligence.github_copilot_chat_participant import GithubCopilotChatParticipant from notebook_intelligence.claude import CLAUDE_CODE_CHAT_PARTICIPANT_ID, ClaudeCodeChatParticipant, ClaudeCodeInlineCompletionModel, fetch_claude_models, get_claude_models from notebook_intelligence.llm_providers.github_copilot_llm_provider import GitHubCopilotLLMProvider @@ -61,6 +62,7 @@ def __init__(self, options: Optional[dict] = None): self._options.get("feature_policies") or {}, self._options.get("string_overrides") or {}, ) + self._mysql_manager = MySQLManager(self._nbi_config.mysql_config) self._openai_compatible_llm_provider = OpenAICompatibleLLMProvider() self._litellm_compatible_llm_provider = LiteLLMCompatibleLLMProvider() self._ollama_llm_provider = OllamaLLMProvider() @@ -93,6 +95,14 @@ def __init__(self, options: Optional[dict] = None): def nbi_config(self) -> NBIConfig: return self._nbi_config + @property + def mysql_manager(self) -> MySQLManager: + return self._mysql_manager + + def update_mysql_manager(self): + """Refresh MySQL manager instance from current config.""" + self._mysql_manager = MySQLManager(self._nbi_config.mysql_config) + @property def ollama_llm_provider(self) -> OllamaLLMProvider: return self._ollama_llm_provider diff --git a/notebook_intelligence/api.py b/notebook_intelligence/api.py index dd454b5..2b01087 100644 --- a/notebook_intelligence/api.py +++ b/notebook_intelligence/api.py @@ -142,6 +142,8 @@ class ChatRequest: cancel_token: CancelToken = None # NEW: Add context for rule evaluation rule_context: Optional[RuleContext] = None + # NEW: Internal conversation id for logging + conversation_id: str = None @dataclass class ResponseStreamData: @@ -633,6 +635,13 @@ async def _tool_call_loop(tool_call_rounds: list): if message.get('tool_calls', None) is not None: for tool_call in message['tool_calls']: tool_call_rounds.append(tool_call) + + # MySQL logging: Log assistant message with tool calls + if request.conversation_id: + msg_id = str(uuid.uuid4()) + request.host.mysql_manager.add_message( + msg_id, request.conversation_id, "assistant", content, reasoning, message['tool_calls'] + ) messages.append(message) @@ -662,6 +671,12 @@ async def _tool_call_loop(tool_call_rounds: list): args = tool_call['function']['arguments'] else: args = fuzzy_json_loads(tool_call['function']['arguments']) + + # MySQL logging: Log tool execution START (initial record) + if request.conversation_id: + request.host.mysql_manager.log_tool_execution( + tool_call['id'], request.conversation_id, tool_name, args, "" + ) tool_properties = tool_to_call.schema["function"]["parameters"]["properties"] if type(args) is str: @@ -688,6 +703,17 @@ async def _tool_call_loop(tool_call_rounds: list): return tool_call_response = await tool_to_call.handle_tool_call(request, response, tool_context, args) + + # MySQL logging: Update tool execution with result + if request.conversation_id: + request.host.mysql_manager.log_tool_execution( + tool_call['id'], request.conversation_id, tool_name, args, str(tool_call_response) + ) + # Also log the tool message itself + msg_id = str(uuid.uuid4()) + request.host.mysql_manager.add_message( + msg_id, request.conversation_id, "tool", str(tool_call_response), tool_call_id=tool_call['id'] + ) function_call_result_message = { "role": "tool", @@ -946,6 +972,10 @@ def get_skill_manager(self): def websocket_connector(self) -> ThreadSafeWebSocketConnector: raise NotImplementedError + @property + def mysql_manager(self) -> Any: + return NotImplementedError + class NotebookIntelligenceExtension: @property diff --git a/notebook_intelligence/config.py b/notebook_intelligence/config.py index bd0dae7..7b79e0d 100644 --- a/notebook_intelligence/config.py +++ b/notebook_intelligence/config.py @@ -230,6 +230,35 @@ def active_rules(self) -> dict: """Get dictionary of active rule states (filename -> bool).""" return self.get('active_rules', {}) + @property + def mysql_config(self) -> dict: + """Get MySQL configuration.""" + return self.get('mysql_config', { + 'enabled': False, + 'host': 'localhost', + 'port': 3306, + 'user': '', + 'password': '', + 'database': 'notebook_intelligence' + }) + + @property + def history_config(self) -> dict: + """Get chat history storage configuration.""" + cfg = self.get('history_config', {}) + mode = cfg.get('mode', 'local') + local_max_messages = cfg.get('local_max_messages', 10) + try: + local_max_messages = int(local_max_messages) + except Exception: + local_max_messages = 10 + if local_max_messages < 1: + local_max_messages = 1 + return { + 'mode': mode if mode in ['mysql', 'local', 'none'] else 'local', + 'local_max_messages': local_max_messages + } + def set_rule_active(self, filename: str, active: bool): """Set the active state of a rule.""" active_rules = self.active_rules.copy() diff --git a/notebook_intelligence/extension.py b/notebook_intelligence/extension.py index 34571ff..0f03966 100644 --- a/notebook_intelligence/extension.py +++ b/notebook_intelligence/extension.py @@ -71,6 +71,7 @@ def _token_count(text: str) -> int: return len(tiktoken_encoding.encode(text)) +shared_chat_history = None def _truncate_context_content(content: str, token_budget: int) -> str: @@ -524,6 +525,8 @@ def is_provider_enabled(provider_id: str) -> bool: "claude_settings": _scrub_credentials_for_wire( nbi_config.claude_settings, self.string_overrides ), + "mysql_config": nbi_config.mysql_config, + "history_config": nbi_config.history_config, "claude_models": ai_service_manager.claude_models, # Drive launcher-tile visibility (issues #183, #260). Each flag # gates one tile under the "Coding Agent" category. Detection is @@ -581,7 +584,7 @@ class ConfigHandler(APIHandler): string_overrides = {} @tornado.web.authenticated - def post(self): + async def post(self): data = json.loads(self.request.body) valid_keys = set([ "default_chat_mode", @@ -595,6 +598,8 @@ def post(self): "enable_output_followup", "enable_output_toolbar", "refresh_open_files_on_disk_change", + "mysql_config", + "history_config", ]) # Top-level keys whose write is rejected outright when locked. locked_keys = set() @@ -622,6 +627,7 @@ def post(self): has_model_change = False has_claude_settings_change = False + has_mysql_settings_change = False for key in data: if key in locked_keys: continue @@ -677,10 +683,38 @@ def post(self): if isinstance(default_chat_participant, ClaudeCodeChatParticipant): # needed to disconnect default_chat_participant.update_client_debounced() + elif key == "mysql_config": + has_mysql_settings_change = True + elif key == "history_config": + has_mysql_settings_change = True ai_service_manager.nbi_config.save() if has_model_change or has_claude_settings_change: ai_service_manager.update_models_from_config() + if has_mysql_settings_change: + ai_service_manager.update_mysql_manager() + history_cfg = ai_service_manager.nbi_config.history_config + if history_cfg.get("mode") == "mysql": + # Validate MySQL immediately and downgrade to `none` on failure. + ok, err = await ai_service_manager.mysql_manager.test_connection() + if not ok: + mysql_cfg = ai_service_manager.nbi_config.mysql_config.copy() + mysql_cfg["enabled"] = False + ai_service_manager.nbi_config.set("mysql_config", mysql_cfg) + ai_service_manager.nbi_config.set("history_config", { + "mode": "none", + "local_max_messages": history_cfg.get("local_max_messages", 10) + }) + ai_service_manager.nbi_config.save() + ai_service_manager.update_mysql_manager() + if not ok: + self.set_status(400) + self.finish(json.dumps({ + "error": f"MySQL connection failed: {err}. History mode has been switched to 'none'.", + "history_config": ai_service_manager.nbi_config.history_config, + "mysql_config": ai_service_manager.nbi_config.mysql_config + })) + return if has_claude_settings_change: default_chat_participant = ai_service_manager.default_chat_participant if isinstance(default_chat_participant, ClaudeCodeChatParticipant): @@ -1772,55 +1806,14 @@ def post(self): self.finish(json.dumps({"success": True, "session_id": session_id})) -class ChatHistory: - """ - History of chat messages, key is chat id, value is list of messages - keep the last 10 messages in the same chat participant - """ - MAX_MESSAGES = 10 - - def __init__(self): - self.messages = {} - - def clear(self, chatId = None): - if chatId is None: - self.messages = {} - return True - elif chatId in self.messages: - del self.messages[chatId] - return True - - return False - - def add_message(self, chatId, message): - if chatId not in self.messages: - self.messages[chatId] = [] - - # clear the chat history if participant changed - if message["role"] == "user": - existing_messages = self.messages[chatId] - prev_user_message = next((m for m in reversed(existing_messages) if m["role"] == "user"), None) - if prev_user_message is not None: - current_prompt_parts = AIServiceManager.parse_prompt(message["content"]) - prev_prompt_parts = AIServiceManager.parse_prompt(prev_user_message["content"]) - if current_prompt_parts.participant != prev_prompt_parts.participant: - self.messages[chatId] = [] - - self.messages[chatId].append(message) - # limit number of messages kept in history - if len(self.messages[chatId]) > ChatHistory.MAX_MESSAGES: - self.messages[chatId] = self.messages[chatId][-ChatHistory.MAX_MESSAGES:] - - def get_history(self, chatId): - return self.messages.get(chatId, []) - class WebsocketCopilotResponseEmitter(ChatResponse): - def __init__(self, chatId, messageId, websocket_handler, chat_history): + def __init__(self, chatId, messageId, websocket_handler, chat_history, conversation_id=None): super().__init__() self.chatId = chatId self.messageId = messageId self.websocket_handler = websocket_handler self.chat_history = chat_history + self.conversation_id = conversation_id self.streamed_contents = [] self.streamed_reasoning_contents = [] # Capture the Tornado IOLoop the websocket lives on. stream() / @@ -1832,6 +1825,8 @@ def __init__(self, chatId, messageId, websocket_handler, chat_history): # data: object cannot be re-sized` (issue #264). Marshaling the # write back to the IOLoop's thread fixes it. self._io_loop = tornado.ioloop.IOLoop.current() + self.streamed_tool_calls = [] + self.streamed_markdown_parts = [] def _send_async(self, message: dict) -> None: self._io_loop.asyncio_loop.call_soon_threadsafe( @@ -1850,7 +1845,28 @@ def stream(self, data: Union[ResponseStreamData, dict]): data_type = ResponseStreamDataType.LLMRaw if type(data) is dict else data.data_type if data_type == ResponseStreamDataType.Markdown: - self.chat_history.add_message(self.chatId, {"role": "assistant", "content": data.content, "reasoning_content": data.reasoning_content}) + tool_calls = None + if isinstance(data.detail, dict) and data.detail.get("title") == "Parameters": + try: + tool_calls = [{ + "type": "ui_tool_parameters", + "arguments": json.loads(data.detail.get("content", "{}")) + }] + except Exception: + tool_calls = None + + if data.content is not None: + self.streamed_contents.append(data.content) + if data.reasoning_content is not None: + self.streamed_reasoning_contents.append(data.reasoning_content) + if tool_calls is not None: + self.streamed_tool_calls.extend(tool_calls) + self.streamed_markdown_parts.append({ + "type": "markdown", + "content": data.content or "", + "reasoning_content": data.reasoning_content, + "detail": data.detail + }) data = { "choices": [ { @@ -2029,6 +2045,16 @@ def stream(self, data: Union[ResponseStreamData, dict]): self.streamed_contents.append(content) if reasoning_content is not None: self.streamed_reasoning_contents.append(reasoning_content) + + # Now common part for all types to actually write to websocket + if data_type != ResponseStreamDataType.LLMRaw: + self._send_async({ + "id": self.messageId, + "participant": self.participant_id, + "type": BackendMessageType.StreamMessage, + "data": data, + "created": dt.datetime.now().isoformat() + }) else: # ResponseStreamDataType.LLMRaw if len(data.get("choices", [])) > 0: delta = data["choices"][0].get("delta", {}) @@ -2039,25 +2065,56 @@ def stream(self, data: Union[ResponseStreamData, dict]): if reasoning_content is not None: self.streamed_reasoning_contents.append(reasoning_content) - self._send_async({ - "id": self.messageId, - "participant": self.participant_id, - "type": BackendMessageType.StreamMessage, - "data": data, - "created": dt.datetime.now().isoformat() - }) + self._send_async({ + "id": self.messageId, + "participant": self.participant_id, + "type": BackendMessageType.StreamMessage, + "data": data, + "created": dt.datetime.now().isoformat() + }) def finish(self) -> None: - self.chat_history.add_message(self.chatId, {"role": "assistant", "content": "".join(self.streamed_contents), "reasoning_content": "".join(self.streamed_reasoning_contents)}) + content = "".join(self.streamed_contents) + reasoning_content = "".join(self.streamed_reasoning_contents) + persisted_tool_calls = list(self.streamed_tool_calls) + if self.streamed_markdown_parts: + persisted_tool_calls.append({ + "type": "ui_message_parts", + "parts": self.streamed_markdown_parts + }) + + if content or reasoning_content or persisted_tool_calls: + self.chat_history.add_message( + self.chatId, + { + "role": "assistant", + "content": content, + "reasoning_content": reasoning_content, + "tool_calls": persisted_tool_calls, + } + ) + + if self.conversation_id: + msg_id = str(uuid.uuid4()) + ai_service_manager.mysql_manager.add_message( + msg_id, + self.conversation_id, + "assistant", + content, + reasoning_content, + tool_calls=persisted_tool_calls, + ) + self.streamed_contents = [] self.streamed_reasoning_contents = [] + self.streamed_tool_calls = [] + self.streamed_markdown_parts = [] self._send_async({ "id": self.messageId, "participant": self.participant_id, "type": BackendMessageType.StreamEnd, "data": {} }) - async def run_ui_command(self, command: str, args: dict = {}) -> None: callback_id = str(uuid.uuid4()) self._send_async({ @@ -2087,13 +2144,13 @@ class MessageCallbackHandlers: response_emitter: WebsocketCopilotResponseEmitter cancel_token: CancelTokenImpl -class WebsocketCopilotHandler(WebSocketMixin, websocket.WebSocketHandler, JupyterHandler): +class WebsocketCopilotHandler(APIHandler, WebSocketMixin, websocket.WebSocketHandler, JupyterHandler): # Cap WS message size at 4 MiB. Largest legitimate payload is a chat # request with ~10 attached output-context items (each capped at 1 MiB # by `coerce_payload`) + chat history; 4 MiB covers that without # leaving the default 10 MiB headroom for memory amplification. max_message_size = 4 * 1024 * 1024 - + chat_history_ref = None # Inheritance matches Jupyter's first-party WS handlers (e.g. # KernelWebsocketHandler): ``WebSocketMixin`` adds ping/pong # keepalive plus a ``prepare`` that routes through Jupyter's @@ -2115,7 +2172,11 @@ def __init__(self, application, request, context_factory=None, **kwargs): # websocket — every long chat session leaked one emitter + # cancel token per turn. self._messageCallbackHandlers: dict[str, MessageCallbackHandlers] = {} - self.chat_history = ChatHistory() + global shared_chat_history + if shared_chat_history is None: + shared_chat_history = ChatHistory() + self.chat_history = shared_chat_history + WebsocketCopilotHandler.chat_history_ref = self.chat_history self._context_factory = context_factory or RuleContextFactory() ws_connector = ThreadSafeWebSocketConnector(self) ai_service_manager.websocket_connector = ws_connector @@ -2147,7 +2208,7 @@ def open(self): self.request.headers.get("Origin"), ) - def on_message(self, message): + async def on_message(self, message): msg = json.loads(message) messageId = msg['id'] @@ -2167,8 +2228,28 @@ def on_message(self, message): extension_tools=toolSelections.get('extensions', {}) ) + # MySQL logging: Create conversation and log user message (combined to ensure order) + conversation_id = str(uuid.uuid4()) + + # Better user detection for JupyterHub/TLJH and Jupyter Server 2.0+ + user_id = os.environ.get('JUPYTERHUB_USER') + if not user_id: + user = self.current_user + if user: + user_id = getattr(user, 'name', str(user)) + else: + user_id = "unknown" + + user_message_id = str(uuid.uuid4()) + ai_service_manager.mysql_manager.create_conversation_with_message( + conversation_id, user_id, chatId, chat_mode.id, + user_message_id, "user", prompt + ) + is_claude_code_mode = ai_service_manager.is_claude_code_mode - chat_history = self.chat_history.get_history(chatId) + # Copy current chat history for request context building, do not + # mutate shared in-memory history with transient context entries. + chat_history = list(await self.chat_history.get_history(chatId)) chat_history_initial_size = len(chat_history) current_directory = data.get('currentDirectory') @@ -2383,8 +2464,11 @@ def on_message(self, message): chat_history.append({"role": "user", "content": context_message}) chat_history.append({"role": "user", "content": prompt}) + # Persist the real user prompt in shared in-memory history so + # refresh fallback stays consistent even when DB data is delayed. + self.chat_history.add_message(chatId, {"role": "user", "content": prompt}) - response_emitter = WebsocketCopilotResponseEmitter(chatId, messageId, self, self.chat_history) + response_emitter = WebsocketCopilotResponseEmitter(chatId, messageId, self, self.chat_history, conversation_id=conversation_id) cancel_token = CancelTokenImpl() self._messageCallbackHandlers[messageId] = MessageCallbackHandlers(response_emitter, cancel_token) @@ -2398,7 +2482,7 @@ def on_message(self, message): # last prompt is added later request_chat_history = chat_history[chat_history_initial_size:-1] if is_claude_code_mode else chat_history[:-1] - coro = ai_service_manager.handle_chat_request(ChatRequest(chat_mode=chat_mode, tool_selection=tool_selection, prompt=prompt, chat_history=request_chat_history, cancel_token=cancel_token, rule_context=rule_context), response_emitter) + coro = ai_service_manager.handle_chat_request(ChatRequest(chat_mode=chat_mode, tool_selection=tool_selection, prompt=prompt, chat_history=request_chat_history, cancel_token=cancel_token, rule_context=rule_context, conversation_id=conversation_id), response_emitter) thread = threading.Thread(target=self._run_request_thread, args=(coro, messageId)) thread.start() elif messageType == RequestDataType.GenerateCode: @@ -2412,6 +2496,25 @@ def on_message(self, message): filename = data['filename'] is_claude_code_mode = ai_service_manager.is_claude_code_mode chat_mode = ChatMode('inline-chat', 'Inline Chat') if is_claude_code_mode else ChatMode('ask', 'Ask') + + # MySQL logging: Create conversation and log user message (combined) + conversation_id = str(uuid.uuid4()) + + # Better user detection for JupyterHub/TLJH and Jupyter Server 2.0+ + user_id = os.environ.get('JUPYTERHUB_USER') + if not user_id: + user = self.current_user + if user: + user_id = getattr(user, 'name', str(user)) + else: + user_id = "unknown" + + user_message_id = str(uuid.uuid4()) + ai_service_manager.mysql_manager.create_conversation_with_message( + conversation_id, user_id, chatId, "inline-chat", + user_message_id, "user", prompt + ) + if prefix != '': self.chat_history.add_message(chatId, {"role": "user", "content": f"This code section comes before the code section you will generate, use as context. Leading content: ```{prefix}```"}) if suffix != '': @@ -2419,7 +2522,7 @@ def on_message(self, message): if existing_code != '': self.chat_history.add_message(chatId, {"role": "user", "content": f"You are asked to modify the existing code. Generate a replacement for this existing code : ```{existing_code}```"}) self.chat_history.add_message(chatId, {"role": "user", "content": f"Generate code for: {prompt}"}) - response_emitter = WebsocketCopilotResponseEmitter(chatId, messageId, self, self.chat_history) + response_emitter = WebsocketCopilotResponseEmitter(chatId, messageId, self, self.chat_history, conversation_id=conversation_id) cancel_token = CancelTokenImpl() self._messageCallbackHandlers[messageId] = MessageCallbackHandlers(response_emitter, cancel_token) existing_code_message = " Update the existing code section and return a modified version. Don't just return the update, recreate the existing code section with the update." if existing_code != '' else '' @@ -2433,7 +2536,8 @@ def on_message(self, message): root_dir=NotebookIntelligence.root_dir ) - coro = ai_service_manager.handle_chat_request(ChatRequest(chat_mode=chat_mode, prompt=prompt, chat_history=self.chat_history.get_history(chatId), cancel_token=cancel_token, rule_context=rule_context), response_emitter, options={"system_prompt": f"You are an assistant that generates code for '{language}' language. You generate code between existing leading and trailing code sections.{existing_code_message} Be concise and return only code as a response. Don't include leading content or trailing content in your response, they are provided only for context. You can reuse methods and symbols defined in leading and trailing content."}) + chat_history = await self.chat_history.get_history(chatId) + coro = ai_service_manager.handle_chat_request(ChatRequest(chat_mode=chat_mode, prompt=prompt, chat_history=chat_history, cancel_token=cancel_token, rule_context=rule_context, conversation_id=conversation_id), response_emitter, options={"system_prompt": f"You are an assistant that generates code for '{language}' language. You generate code between existing leading and trailing code sections.{existing_code_message} Be concise and return only code as a response. Don't include leading content or trailing content in your response, they are provided only for context. You can reuse methods and symbols defined in leading and trailing content."}) thread = threading.Thread(target=self._run_request_thread, args=(coro, messageId)) thread.start() elif messageType == RequestDataType.InlineCompletionRequest: @@ -2502,6 +2606,165 @@ async def handle_inline_completions(prefix, suffix, language, filename, response response_emitter.stream({"completions": completions}) response_emitter.finish() + +class GetChatHistoryHandler(APIHandler): + @tornado.web.authenticated + async def get(self): + chat_id = self.get_argument("chatId", None) + if not chat_id: + self.set_status(400) + self.finish(json.dumps({"error": "chatId is required"})) + return + + history_mode = ai_service_manager.nbi_config.history_config.get("mode", "local") + messages = [] + if history_mode == "mysql": + messages = await ai_service_manager.mysql_manager.get_messages_by_chat_id(chat_id) + global shared_chat_history + if shared_chat_history is None: + shared_chat_history = ChatHistory() + in_memory = await shared_chat_history.get_history(chat_id) + ws_history = [] + if WebsocketCopilotHandler.chat_history_ref is not None: + ws_history = await WebsocketCopilotHandler.chat_history_ref.get_history(chat_id) + if len(ws_history) > len(in_memory): + in_memory = ws_history + + # For local mode, only use in-memory data. For mysql mode prefer the + # more complete source for live sessions because DB writes are async. + if history_mode == "local" or (history_mode == "mysql" and len(in_memory) > len(messages)): + messages = [] + for item in in_memory: + messages.append({ + "role": item.get("role", "assistant"), + "content": item.get("content", ""), + "reasoning_content": item.get("reasoning_content"), + "tool_calls": item.get("tool_calls"), + "tool_call_id": item.get("tool_call_id"), + "created_at": dt.datetime.now(dt.timezone.utc), + }) + if history_mode == "none": + messages = [] + # Convert datetime to string for JSON serialization + for msg in messages: + if isinstance(msg.get('created_at'), dt.datetime): + msg['created_at'] = msg['created_at'].isoformat() + if msg.get('tool_calls'): + try: + msg['tool_calls'] = json.loads(msg['tool_calls']) + except: + pass + + self.finish(json.dumps({"messages": messages})) + +class GetRecentConversationsHandler(APIHandler): + @tornado.web.authenticated + async def get(self): + history_mode = ai_service_manager.nbi_config.history_config.get("mode", "local") + if history_mode == "none": + self.finish(json.dumps({"conversations": []})) + return + + user_id = os.environ.get('JUPYTERHUB_USER') + if not user_id: + user = self.current_user + if user: + user_id = getattr(user, 'name', str(user)) + else: + user_id = "unknown" + + if history_mode == "local": + global shared_chat_history + if shared_chat_history is None: + shared_chat_history = ChatHistory() + # Keep ordering stable: most recently touched chat first. Local + # mode has in-memory transcripts only, so surface chat IDs with + # a synthetic timestamp for the sidebar conversation picker. + now = dt.datetime.now(dt.timezone.utc).isoformat() + conversation_ids = list(reversed(list(shared_chat_history.messages.keys()))) + conversations = [ + { + "chat_id": chat_id, + "chat_mode": "ask", + "last_message_at": now, + } + for chat_id in conversation_ids + ] + self.finish(json.dumps({"conversations": conversations})) + return + + conversations = await ai_service_manager.mysql_manager.get_recent_conversations(user_id) + for conv in conversations: + if isinstance(conv.get('last_message_at'), dt.datetime): + conv['last_message_at'] = conv['last_message_at'].isoformat() + + self.finish(json.dumps({"conversations": conversations})) + +class ChatHistory: + """ + History of chat messages, key is chat id, value is list of messages + keep the last 10 messages in the same chat participant + """ + DEFAULT_MAX_MESSAGES = 10 + + def __init__(self): + self.messages = {} + + def clear(self, chatId = None): + if chatId is None: + self.messages = {} + return True + elif chatId in self.messages: + del self.messages[chatId] + return True + + return False + + def add_message(self, chatId, message): + history_mode = ai_service_manager.nbi_config.history_config.get("mode", "local") + + if chatId not in self.messages: + self.messages[chatId] = [] + + # clear the chat history if participant changed + if message["role"] == "user": + existing_messages = self.messages[chatId] + prev_user_message = next((m for m in reversed(existing_messages) if m["role"] == "user"), None) + if prev_user_message is not None: + current_prompt_parts = AIServiceManager.parse_prompt(message["content"]) + prev_prompt_parts = AIServiceManager.parse_prompt(prev_user_message["content"]) + if current_prompt_parts.participant != prev_prompt_parts.participant: + self.messages[chatId] = [] + + self.messages[chatId].append(message) + # limit number of messages kept in history only in local mode + if history_mode == "local": + max_messages = ai_service_manager.nbi_config.history_config.get( + "local_max_messages", ChatHistory.DEFAULT_MAX_MESSAGES + ) + if len(self.messages[chatId]) > max_messages: + self.messages[chatId] = self.messages[chatId][-max_messages:] + + async def get_history(self, chatId): + history_mode = ai_service_manager.nbi_config.history_config.get("mode", "local") + if chatId not in self.messages: + if history_mode == "mysql": + messages = await ai_service_manager.mysql_manager.get_messages_by_chat_id(chatId) + if messages: + # Convert from DB format to chat history format + self.messages[chatId] = [{"role": m["role"], "content": m["content"]} for m in messages] + else: + self.messages[chatId] = [] + + if history_mode == "local": + max_messages = ai_service_manager.nbi_config.history_config.get( + "local_max_messages", ChatHistory.DEFAULT_MAX_MESSAGES + ) + if len(self.messages[chatId]) > max_messages: + self.messages[chatId] = self.messages[chatId][-max_messages:] + + return self.messages.get(chatId, []) + class NotebookIntelligence(ExtensionApp): name = "notebook_intelligence" default_url = "/notebook-intelligence" @@ -3053,6 +3316,8 @@ def _setup_handlers(self, web_app, feature_policies: dict, string_overrides: dic r"([^/]+)", "update", ) + route_pattern_history = url_path_join(base_url, "notebook-intelligence", "history") + route_pattern_conversations = url_path_join(base_url, "notebook-intelligence", "conversations") GetCapabilitiesHandler.disabled_tools = self.disabled_tools GetCapabilitiesHandler.allow_enabling_tools_with_env = self.allow_enabling_tools_with_env GetCapabilitiesHandler.disabled_providers = self.disabled_providers @@ -3177,6 +3442,8 @@ def _setup_handlers(self, web_app, feature_policies: dict, string_overrides: dic (route_pattern_plugins_marketplace, PluginsMarketplaceListHandler), (route_pattern_plugins_detail, PluginsDetailHandler), (route_pattern_plugins, PluginsListHandler), + (route_pattern_history, GetChatHistoryHandler), + (route_pattern_conversations, GetRecentConversationsHandler), (route_pattern_copilot, WebsocketCopilotHandler), ] web_app.add_handlers(host_pattern, NotebookIntelligence.handlers) diff --git a/notebook_intelligence/github_copilot.py b/notebook_intelligence/github_copilot.py index a21778d..2351b38 100644 --- a/notebook_intelligence/github_copilot.py +++ b/notebook_intelligence/github_copilot.py @@ -14,6 +14,7 @@ import logging from notebook_intelligence.api import BackendMessageType, CancelToken, ChatResponse, CompletionContext, MarkdownData from notebook_intelligence.config import _atomic_write_json +from notebook_intelligence.message_sanitizer import sanitize_chat_history_tool_calls from notebook_intelligence.util import decrypt_with_password, encrypt_with_password, ThreadSafeWebSocketConnector from ._version import __version__ as NBI_VERSION @@ -983,9 +984,10 @@ def completions(model_id, messages, tools = None, response: ChatResponse = None, aggregate = response is None try: + sanitized_messages = sanitize_chat_history_tool_calls(messages) data = { 'model': model_id, - 'messages': messages, + 'messages': sanitized_messages, 'tools': tools, 'temperature': 0, 'top_p': 1, diff --git a/notebook_intelligence/llm_providers/litellm_compatible_llm_provider.py b/notebook_intelligence/llm_providers/litellm_compatible_llm_provider.py index 81d595b..57014a9 100644 --- a/notebook_intelligence/llm_providers/litellm_compatible_llm_provider.py +++ b/notebook_intelligence/llm_providers/litellm_compatible_llm_provider.py @@ -3,6 +3,7 @@ import json from typing import Any from notebook_intelligence.api import ChatModel, EmbeddingModel, InlineCompletionModel, LLMProvider, CancelToken, ChatResponse, CompletionContext, LLMProviderProperty +from notebook_intelligence.message_sanitizer import sanitize_chat_history_tool_calls import litellm DEFAULT_CONTEXT_WINDOW = 4096 @@ -42,9 +43,10 @@ def completions(self, messages: list[dict], tools: list[dict] = None, response: base_url = self.get_property("base_url").value api_key_prop = self.get_property("api_key") api_key = api_key_prop.value if api_key_prop is not None else None + sanitized_messages = sanitize_chat_history_tool_calls(messages) litellm_resp = litellm.completion( model=model_id, - messages=messages.copy(), + messages=sanitized_messages, tools=tools, tool_choice=options.get("tool_choice", None), api_base=base_url, diff --git a/notebook_intelligence/llm_providers/ollama_llm_provider.py b/notebook_intelligence/llm_providers/ollama_llm_provider.py index d4fa6f0..11e9c9f 100644 --- a/notebook_intelligence/llm_providers/ollama_llm_provider.py +++ b/notebook_intelligence/llm_providers/ollama_llm_provider.py @@ -3,6 +3,7 @@ import json from typing import Any from notebook_intelligence.api import ChatModel, EmbeddingModel, InlineCompletionModel, LLMProvider, CancelToken, ChatResponse, CompletionContext +from notebook_intelligence.message_sanitizer import sanitize_chat_history_tool_calls import ollama import logging @@ -38,9 +39,10 @@ def context_window(self) -> int: def completions(self, messages: list[dict], tools: list[dict] = None, response: ChatResponse = None, cancel_token: CancelToken = None, options: dict = {}) -> Any: stream = response is not None + sanitized_messages = sanitize_chat_history_tool_calls(messages) completion_args = { "model": self._model_id, - "messages": messages.copy(), + "messages": sanitized_messages, "stream": stream, } if tools is not None and len(tools) > 0: diff --git a/notebook_intelligence/llm_providers/openai_compatible_llm_provider.py b/notebook_intelligence/llm_providers/openai_compatible_llm_provider.py index 750de47..7caa707 100644 --- a/notebook_intelligence/llm_providers/openai_compatible_llm_provider.py +++ b/notebook_intelligence/llm_providers/openai_compatible_llm_provider.py @@ -1,10 +1,10 @@ # Copyright (c) Mehmet Bektas -import copy import json import re from typing import Any from notebook_intelligence.api import ChatModel, EmbeddingModel, InlineCompletionModel, LLMProvider, CancelToken, ChatResponse, CompletionContext, LLMProviderProperty +from notebook_intelligence.message_sanitizer import sanitize_chat_history_tool_calls from openai import OpenAI, omit INLINE_COMPLETION_SYSTEM_PROMPT = """You are a code completion assistant. Your task is to generate intelligent autocomplete suggestions for the code at the cursor position for given language and active file type. This is not an interactive session, don't ask for clarifying questions, always generate a suggestion. Don't include any explanations for your response, just generate the code. Don't return any thinking or reasoning, just generate the code. You are given a code snippet with a prefix and a suffix. You need to generate a suggestion for the code that fits best in place of . You should return only the code that fits best in place of . You should provide multiline code if needed. Enclose the code in triple backticks, just return the code in language. You should not return any other text, just the code. DO NOT INCLUDE THE PREFIX OR SUFFIX IN THE RESPONSE. .ipynb files are Jupyter notebook files and for notebook files, you generate suggestions for a cell within the notebook. A cell can be a code cell with code or a markdown cell with markdown text. If the language is markdown, only return markdown text. If you need to install a Python package within a notebook cell code (for .ipynb files), use %pip install instead of !pip install . Follow the tags very carefully for proper spacing and indentations.""" @@ -17,6 +17,7 @@ def sanitize_tools_for_openai_compatible(tools: list[dict] | None) -> list[dict] if tools is None: return None + import copy sanitized_tools = copy.deepcopy(tools) for tool in sanitized_tools: function_schema = tool.get("function") @@ -25,6 +26,10 @@ def sanitize_tools_for_openai_compatible(tools: list[dict] | None) -> list[dict] return sanitized_tools +def sanitize_messages_for_openai_compatible(messages: list[dict] | None) -> list[dict]: + return sanitize_chat_history_tool_calls(messages) + + class OpenAICompatibleChatModel(ChatModel): def __init__(self, provider: "OpenAICompatibleLLMProvider"): super().__init__(provider) @@ -54,6 +59,10 @@ def context_window(self) -> int: except: return DEFAULT_CONTEXT_WINDOW + @property + def supports_tools(self) -> bool: + return True + def completions(self, messages: list[dict], tools: list[dict] = None, response: ChatResponse = None, cancel_token: CancelToken = None, options: dict = {}) -> Any: stream = response is not None model_id = self.get_property("model_id").value @@ -65,7 +74,7 @@ def completions(self, messages: list[dict], tools: list[dict] = None, response: client = OpenAI(base_url=base_url, api_key=api_key) resp = client.chat.completions.create( model=model_id, - messages=messages.copy(), + messages=sanitize_messages_for_openai_compatible(messages), tools=sanitize_tools_for_openai_compatible(tools) or omit, tool_choice=options.get("tool_choice", omit), stream=stream, @@ -79,12 +88,18 @@ def completions(self, messages: list[dict], tools: list[dict] = None, response: reasoning = getattr(delta, 'reasoning_content', None) or getattr(delta, 'reasoning', None) if reasoning is not None: reasoning = str(reasoning) + + tool_calls = None + if hasattr(delta, 'tool_calls') and delta.tool_calls: + tool_calls = [json.loads(tc.model_dump_json()) for tc in delta.tool_calls] + response.stream({ "choices": [{ "delta": { "role": delta.role, "content": delta.content, - "reasoning_content": reasoning + "reasoning_content": reasoning, + "tool_calls": tool_calls } }] }) diff --git a/notebook_intelligence/message_sanitizer.py b/notebook_intelligence/message_sanitizer.py new file mode 100644 index 0000000..acb8560 --- /dev/null +++ b/notebook_intelligence/message_sanitizer.py @@ -0,0 +1,44 @@ +# Copyright (c) Mehmet Bektas + +import copy + + +def sanitize_chat_history_tool_calls(messages: list[dict] | None) -> list[dict]: + """Drop UI-only tool-call metadata before replaying chat history. + + History persistence stores frontend replay metadata such as + ``ui_tool_parameters`` and ``ui_message_parts`` inside assistant + ``tool_calls``. Those shapes are not provider-native tool calls and can + trigger request validation errors when the prior assistant turn is sent + back to the model on a later request. + """ + if messages is None: + return [] + + sanitized_messages = copy.deepcopy(messages) + for message in sanitized_messages: + tool_calls = message.get("tool_calls") + if not isinstance(tool_calls, list): + continue + + valid_tool_calls = [] + for tool_call in tool_calls: + if not isinstance(tool_call, dict): + continue + + tool_call_type = tool_call.get("type") + if tool_call_type == "function": + function_payload = tool_call.get("function") + if tool_call.get("id") and isinstance(function_payload, dict): + valid_tool_calls.append(tool_call) + elif tool_call_type == "custom": + custom_payload = tool_call.get("custom") + if tool_call.get("id") and custom_payload is not None: + valid_tool_calls.append(tool_call) + + if valid_tool_calls: + message["tool_calls"] = valid_tool_calls + else: + message.pop("tool_calls", None) + + return sanitized_messages diff --git a/notebook_intelligence/mysql_manager.py b/notebook_intelligence/mysql_manager.py new file mode 100644 index 0000000..f7ecf3e --- /dev/null +++ b/notebook_intelligence/mysql_manager.py @@ -0,0 +1,342 @@ +# Copyright (c) Mehmet Bektas + +import json +import logging +import uuid +import datetime +import asyncio +from typing import Optional, List, Dict, Any + +try: + import aiomysql + HAS_AIOMYSQL = True +except ImportError: + aiomysql = None + HAS_AIOMYSQL = False + +log = logging.getLogger(__name__) + +class MySQLManager: + _instance = None + _lock = None + + def __new__(cls, config=None): + if cls._instance is None: + cls._instance = super(MySQLManager, cls).__new__(cls) + cls._instance._initialized = False + return cls._instance + + def __init__(self, config=None): + config = config or {} + if not self._initialized: + self.pool = None + self.loop = None + self._lock_obj = None + self.config = {} + self.enabled = False + self.host = 'localhost' + self.port = 3306 + self.user = '' + self.password = '' + self.database = 'notebook_intelligence' + self._initialized = True + + self._apply_config(config) + + def _apply_config(self, config: dict): + new_config = config or {} + if getattr(self, "config", {}) == new_config: + return + + # Close old pool handle when switching configs. + if self.pool is not None: + try: + self.pool.close() + except Exception: + pass + self.pool = None + + self.loop = None + self._lock_obj = None + self.config = new_config + self.enabled = self.config.get('enabled', False) and HAS_AIOMYSQL + self.host = self.config.get('host', 'localhost') + self.port = self.config.get('port', 3306) + self.user = self.config.get('user', '') + self.password = self.config.get('password', '') + self.database = self.config.get('database', 'notebook_intelligence') + + if self.enabled: + log.info(f"MySQL logging enabled for host: {self.host}") + elif self.config.get('enabled', False): + log.error("aiomysql not found. Please install it with 'pip install aiomysql' to use MySQL logging.") + else: + log.info("MySQL logging disabled.") + + def _get_lock(self): + if self._lock_obj is None: + self._lock_obj = asyncio.Lock() + return self._lock_obj + + async def _get_pool(self): + if not self.enabled or not HAS_AIOMYSQL: + return None + + if self.pool is not None: + return self.pool + + # Capture the loop that initializes the pool + if self.loop is None: + try: + self.loop = asyncio.get_running_loop() + except RuntimeError: + # Fallback for cases where loop isn't running yet + return None + + async with self._get_lock(): + if self.pool is not None: + return self.pool + + try: + # First, connect without db to ensure database exists + temp_conn = await aiomysql.connect( + host=self.host, + port=self.port, + user=self.user, + password=self.password, + connect_timeout=3, + autocommit=True + ) + async with temp_conn.cursor() as cur: + await cur.execute(f"CREATE DATABASE IF NOT EXISTS {self.database} CHARACTER SET utf8mb4;") + temp_conn.close() + + # Now connect to the pool with the database + self.pool = await aiomysql.create_pool( + host=self.host, + port=self.port, + user=self.user, + password=self.password, + db=self.database, + connect_timeout=3, + autocommit=True, + charset='utf8mb4' + ) + await self._ensure_tables() + return self.pool + except Exception as e: + log.error(f"Failed to connect to MySQL: {str(e)}") + self.enabled = False + return None + + async def _ensure_tables(self): + async with self.pool.acquire() as conn: + async with conn.cursor() as cur: + # Conversations table with auto-incrementing ID + await cur.execute(""" + CREATE TABLE IF NOT EXISTS nbi_conversations ( + id_pk INT AUTO_INCREMENT PRIMARY KEY, + id CHAR(36) UNIQUE, + user_id VARCHAR(255), + chat_id VARCHAR(255), + chat_mode VARCHAR(50), + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + """) + + # Messages table with auto-incrementing ID + await cur.execute(""" + CREATE TABLE IF NOT EXISTS nbi_messages ( + id_pk INT AUTO_INCREMENT PRIMARY KEY, + id CHAR(36) UNIQUE, + conversation_id CHAR(36), + role VARCHAR(50), + content LONGTEXT, + reasoning_content LONGTEXT, + tool_calls JSON, + tool_call_id VARCHAR(255), + message_order_at DATETIME(6), + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (conversation_id) REFERENCES nbi_conversations(id) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + """) + try: + await cur.execute(""" + CREATE INDEX idx_nbi_messages_order_at + ON nbi_messages(message_order_at) + """) + except Exception as e: + # MySQL raises duplicate-key error when the index already + # exists. Keep initialization idempotent across re-enables. + if "Duplicate key name" not in str(e): + raise + + # Tool executions table with auto-incrementing ID + await cur.execute(""" + CREATE TABLE IF NOT EXISTS nbi_tool_executions ( + id_pk INT AUTO_INCREMENT PRIMARY KEY, + id VARCHAR(255) UNIQUE, + conversation_id CHAR(36), + tool_name VARCHAR(255), + arguments JSON, + output LONGTEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (conversation_id) REFERENCES nbi_conversations(id) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + """) + + async def test_connection(self) -> tuple[bool, str]: + """Validate MySQL connectivity and table initialization.""" + if not self.config.get('enabled', False): + return False, "MySQL is not enabled in configuration." + if not HAS_AIOMYSQL: + return False, "aiomysql is not installed." + + # Force a fresh attempt when validating. + self.enabled = True + pool = await self._get_pool() + if not pool: + return False, f"Unable to connect to MySQL server {self.host}:{self.port}." + return True, "" + + def _run_task(self, coro): + """Run a coroutine in the correct event loop.""" + if not self.enabled: + return + + try: + current_loop = asyncio.get_running_loop() + except RuntimeError: + # If no loop is running in this thread, we need to handle it + # For now, we assume the first call happened in a thread with a loop (like Tornado) + return + + if self.loop and current_loop != self.loop: + # Schedule on the pool's home loop + asyncio.run_coroutine_threadsafe(coro, self.loop) + else: + # Already in the right loop or first call + asyncio.create_task(coro) + + def create_conversation_with_message(self, conv_id: str, user_id: str, chat_id: str, chat_mode: str, + msg_id: str, role: str, content: str): + if not self.enabled: + return + self._run_task(self._create_conversation_with_message_internal(conv_id, user_id, chat_id, chat_mode, msg_id, role, content)) + + async def _create_conversation_with_message_internal(self, conv_id: str, user_id: str, chat_id: str, chat_mode: str, + msg_id: str, role: str, content: str): + await self._create_conversation_internal(conv_id, user_id, chat_id, chat_mode) + await self._add_message_internal(msg_id, conv_id, role, content) + + def create_conversation(self, conv_id: str, user_id: str, chat_id: str, chat_mode: str): + if not self.enabled: + return + self._run_task(self._create_conversation_internal(conv_id, user_id, chat_id, chat_mode)) + + async def _create_conversation_internal(self, conv_id: str, user_id: str, chat_id: str, chat_mode: str): + pool = await self._get_pool() + if not pool: return + try: + async with pool.acquire() as conn: + async with conn.cursor() as cur: + await cur.execute( + "INSERT IGNORE INTO nbi_conversations (id, user_id, chat_id, chat_mode) VALUES (%s, %s, %s, %s)", + (conv_id, user_id, chat_id, chat_mode) + ) + except Exception as e: + log.error(f"Error creating conversation in MySQL: {str(e)}") + + def add_message(self, message_id: str, conv_id: str, role: str, content: str, + reasoning_content: Optional[str] = None, + tool_calls: Optional[List[Dict]] = None, + tool_call_id: Optional[str] = None): + if not self.enabled: + return + # Skip logging if message is completely empty + if not content and not reasoning_content and not tool_calls and not tool_call_id: + return + self._run_task(self._add_message_internal(message_id, conv_id, role, content, reasoning_content, tool_calls, tool_call_id)) + + async def _add_message_internal(self, message_id: str, conv_id: str, role: str, content: str, + reasoning_content: Optional[str] = None, + tool_calls: Optional[List[Dict]] = None, + tool_call_id: Optional[str] = None): + pool = await self._get_pool() + if not pool: return + try: + tool_calls_json = json.dumps(tool_calls) if tool_calls else None + async with pool.acquire() as conn: + async with conn.cursor() as cur: + await cur.execute( + """INSERT IGNORE INTO nbi_messages + (id, conversation_id, role, content, reasoning_content, tool_calls, tool_call_id, message_order_at) + VALUES (%s, %s, %s, %s, %s, %s, %s, UTC_TIMESTAMP(6))""", + (message_id, conv_id, role, content, reasoning_content, tool_calls_json, tool_call_id) + ) + except Exception as e: + log.error(f"Error adding message to MySQL: {str(e)}") + + def log_tool_execution(self, tool_call_id: str, conv_id: str, tool_name: str, + arguments: Dict, output: str): + if not self.enabled: + return + self._run_task(self._log_tool_execution_internal(tool_call_id, conv_id, tool_name, arguments, output)) + + async def _log_tool_execution_internal(self, tool_call_id: str, conv_id: str, tool_name: str, + arguments: Dict, output: str): + pool = await self._get_pool() + if not pool: return + try: + arguments_json = json.dumps(arguments) + async with pool.acquire() as conn: + async with conn.cursor() as cur: + # Use INSERT ... ON DUPLICATE KEY UPDATE in case output is updated later + await cur.execute( + """INSERT INTO nbi_tool_executions + (id, conversation_id, tool_name, arguments, output) + VALUES (%s, %s, %s, %s, %s) + ON DUPLICATE KEY UPDATE output = VALUES(output)""", + (tool_call_id, conv_id, tool_name, arguments_json, output) + ) + except Exception as e: + log.error(f"Error logging tool execution to MySQL: {str(e)}") + + async def get_messages_by_chat_id(self, chat_id: str) -> List[Dict[str, Any]]: + pool = await self._get_pool() + if not pool: return [] + try: + async with pool.acquire() as conn: + async with conn.cursor(aiomysql.DictCursor) as cur: + await cur.execute( + """SELECT m.role, m.content, m.reasoning_content, m.tool_calls, m.tool_call_id, m.created_at + FROM nbi_messages m + JOIN nbi_conversations c ON m.conversation_id = c.id + WHERE c.chat_id = %s + ORDER BY m.message_order_at ASC""", + (chat_id,) + ) + return await cur.fetchall() + except Exception as e: + log.error(f"Error getting messages from MySQL: {str(e)}") + return [] + + async def get_recent_conversations(self, user_id: str, limit: int = 20) -> List[Dict[str, Any]]: + pool = await self._get_pool() + if not pool: return [] + try: + async with pool.acquire() as conn: + async with conn.cursor(aiomysql.DictCursor) as cur: + await cur.execute( + """SELECT chat_id, chat_mode, MAX(created_at) as last_message_at + FROM nbi_conversations + WHERE user_id = %s + GROUP BY chat_id, chat_mode + ORDER BY last_message_at DESC + LIMIT %s""", + (user_id, limit) + ) + return await cur.fetchall() + except Exception as e: + log.error(f"Error getting recent conversations from MySQL: {str(e)}") + return [] diff --git a/pyproject.toml b/pyproject.toml index 8f76de1..ac422ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,6 +66,7 @@ dependencies = [ # ``set_key`` / ``unset_key`` write paths only, which neither NBI # nor any of its other deps use. Re-add a ``>=1.2.2`` lower bound # when litellm loosens its pin. + "aiomysql" ] dynamic = ["version", "description", "authors", "urls", "keywords"] diff --git a/src/api.ts b/src/api.ts index ae3c61d..d111c8e 100644 --- a/src/api.ts +++ b/src/api.ts @@ -388,6 +388,28 @@ export class NBIConfig { return this.capabilities.claude_settings; } + get mysqlConfig(): any { + return ( + this.capabilities.mysql_config ?? { + enabled: false, + host: 'localhost', + port: 3306, + user: '', + password: '', + database: 'notebook_intelligence' + } + ); + } + + get historyConfig(): any { + return ( + this.capabilities.history_config ?? { + mode: 'local', + local_max_messages: 10 + } + ); + } + get claudeModels(): IClaudeModelInfo[] { return (this.capabilities.claude_models ?? []).map(claudeModelFromWire); } @@ -755,16 +777,17 @@ export class NBIAPI { }); } - static async setConfig(config: any) { - requestAPI('config', { + static async setConfig(config: any): Promise { + return requestAPI('config', { method: 'POST', body: JSON.stringify(config) }) .then(data => { - NBIAPI.fetchCapabilities(); + return NBIAPI.fetchCapabilities().then(() => data); }) .catch(reason => { console.error(`Failed to set NBI config.\n${reason}`); + throw reason; }); } @@ -1421,4 +1444,30 @@ export class NBIAPI { }); }); } + + static async fetchChatHistory(chatId: string): Promise { + return new Promise((resolve, reject) => { + requestAPI(`history?chatId=${chatId}`, { method: 'GET' }) + .then(data => { + resolve(data.messages); + }) + .catch(reason => { + console.error(`Failed to fetch chat history.\n${reason}`); + reject(reason); + }); + }); + } + + static async fetchRecentConversations(): Promise { + return new Promise((resolve, reject) => { + requestAPI('conversations', { method: 'GET' }) + .then(data => { + resolve(data.conversations); + }) + .catch(reason => { + console.error(`Failed to fetch recent conversations.\n${reason}`); + reject(reason); + }); + }); + } } diff --git a/src/chat-sidebar.tsx b/src/chat-sidebar.tsx index 3ededa1..e109f7c 100644 --- a/src/chat-sidebar.tsx +++ b/src/chat-sidebar.tsx @@ -355,6 +355,32 @@ interface IChatMessage { chatModel?: { provider: string; model: string }; } +function chatCacheKey(chatId: string): string { + return `nbi_chat_cache_${chatId}`; +} + +function serializeChatMessages(messages: IChatMessage[]): any[] { + return messages.map(msg => ({ + ...msg, + date: msg.date?.toISOString?.() || new Date().toISOString(), + contents: msg.contents.map(content => ({ + ...content, + created: content.created?.toISOString?.() || new Date().toISOString() + })) + })); +} + +function deserializeChatMessages(serialized: any[]): IChatMessage[] { + return (serialized || []).map((msg: any) => ({ + ...msg, + date: new Date(msg.date), + contents: (msg.contents || []).map((content: any) => ({ + ...content, + created: new Date(content.created) + })) + })); +} + interface IWorkspaceFileOption { name: string; path: string; @@ -1207,9 +1233,199 @@ function SidebarComponent(props: any) { const [promptHistory, setPromptHistory] = useState([]); // position on prompt history stack const [promptHistoryIndex, setPromptHistoryIndex] = useState(0); - const [chatId, setChatId] = useState(UUID.uuid4()); + const [chatId, setChatId] = useState(() => { + const historyMode = NBIAPI.config.historyConfig?.mode ?? 'local'; + if (historyMode === 'none') { + return UUID.uuid4(); + } + const savedChatId = localStorage.getItem('nbi_last_chat_id'); + return savedChatId || UUID.uuid4(); + }); const lastMessageId = useRef(''); const lastRequestTime = useRef(new Date()); + + const historyMode = NBIAPI.config.historyConfig?.mode ?? 'local'; + const prevHistoryModeRef = useRef(historyMode); + + useEffect(() => { + const prevMode = prevHistoryModeRef.current; + const shouldResetTimeline = + (prevMode === 'none' && historyMode !== 'none') || + (prevMode === 'local' && historyMode === 'mysql'); + if (shouldResetTimeline) { + setChatId(UUID.uuid4()); + setChatMessages([]); + } + prevHistoryModeRef.current = historyMode; + }, [historyMode]); + + useEffect(() => { + if (historyMode === 'none') { + localStorage.removeItem('nbi_last_chat_id'); + return; + } + localStorage.setItem('nbi_last_chat_id', chatId); + }, [chatId, historyMode]); + + useEffect(() => { + const fetchHistory = async () => { + try { + const history = await NBIAPI.fetchChatHistory(chatId); + if (history && history.length > 0) { + const formattedMessages: IChatMessage[] = []; + + // History from backend is a list of individual messages + // We need to group them by assistant response if needed, + // or just map them to IChatMessage + for (const msg of history) { + const date = new Date(msg.created_at); + const hasMessageContent = + !!msg.content || + !!msg.reasoning_content || + (Array.isArray(msg.tool_calls) && msg.tool_calls.length > 0); + + // Ignore empty assistant/tool rows from legacy persistence to keep + // refresh state identical to pre-refresh streaming view. + if ( + !hasMessageContent && + (msg.role === 'assistant' || msg.role === 'tool') + ) { + continue; + } + + const toolCalls = Array.isArray(msg.tool_calls) + ? msg.tool_calls + : []; + const serializedParts = toolCalls.find( + (call: any) => + call?.type === 'ui_message_parts' && Array.isArray(call.parts) + ); + const parameterCalls = toolCalls.filter( + (call: any) => + call?.type === 'ui_tool_parameters' && + call.arguments !== undefined + ); + + let contents: IChatMessageContent[] = []; + if (serializedParts && msg.role === 'assistant') { + contents = serializedParts.parts.map((part: any) => ({ + id: UUID.uuid4(), + type: ResponseStreamDataType.Markdown, + content: part?.content || '', + reasoningContent: part?.reasoning_content || '', + reasoningFinished: !!part?.reasoning_content, + contentDetail: part?.detail, + created: date + })); + } else { + const toolCallDetail = + parameterCalls.length > 0 + ? { + title: 'Parameters', + content: `\`\`\`json\n${JSON.stringify( + parameterCalls.map((call: any) => call.arguments), + null, + 2 + )}\n\`\`\`` + } + : undefined; + contents = [ + { + id: UUID.uuid4(), + type: ResponseStreamDataType.Markdown, + content: + msg.content || + (msg.role === 'tool' ? '(Tool execution output)' : ''), + reasoningContent: msg.reasoning_content, + reasoningFinished: !!msg.reasoning_content, + contentDetail: toolCallDetail, + created: date + } + ]; + } + + if (msg.role === 'user') { + formattedMessages.push({ + id: UUID.uuid4(), + date, + from: 'user', + contents + }); + } else if (msg.role === 'assistant' || msg.role === 'tool') { + formattedMessages.push({ + id: UUID.uuid4(), + date, + from: 'copilot', + contents, + participant: NBIAPI.config.chatParticipants.find( + p => p.id === msg.participant_id + ) + }); + } + } + setChatMessages(formattedMessages); + if (historyMode === 'local') { + try { + localStorage.setItem( + chatCacheKey(chatId), + JSON.stringify(serializeChatMessages(formattedMessages)) + ); + } catch (e) { + console.warn('Failed to write chat cache', e); + } + } + } else { + if (historyMode === 'local') { + const cached = localStorage.getItem(chatCacheKey(chatId)); + if (cached) { + try { + const parsed = JSON.parse(cached); + setChatMessages(deserializeChatMessages(parsed)); + } catch (e) { + console.warn('Failed to parse chat cache', e); + } + } else { + setChatMessages([]); + } + } else { + setChatMessages([]); + } + } + } catch (error) { + console.error('Failed to fetch chat history:', error); + if (historyMode === 'local') { + const cached = localStorage.getItem(chatCacheKey(chatId)); + if (cached) { + try { + const parsed = JSON.parse(cached); + setChatMessages(deserializeChatMessages(parsed)); + } catch (e) { + console.warn('Failed to parse chat cache after history error', e); + } + } + } else { + setChatMessages([]); + } + } + }; + + fetchHistory(); + }, [chatId, historyMode]); + + useEffect(() => { + if (historyMode === 'local') { + try { + localStorage.setItem( + chatCacheKey(chatId), + JSON.stringify(serializeChatMessages(chatMessages)) + ); + } catch (e) { + console.warn('Failed to persist chat cache', e); + } + } else { + localStorage.removeItem(chatCacheKey(chatId)); + } + }, [chatId, chatMessages, historyMode]); const [contextOn, setContextOn] = useState(false); const [activeDocumentInfo, setActiveDocumentInfo] = useState(null); @@ -4750,7 +4966,7 @@ function InlinePromptComponent(props: any) { submitCompletionRequest( { - messageId, + messageId: UUID.uuid4(), chatId: UUID.uuid4(), type: RunChatCompletionType.GenerateCode, content: prompt, diff --git a/src/components/settings-panel.tsx b/src/components/settings-panel.tsx index ff5d8a9..8724724 100644 --- a/src/components/settings-panel.tsx +++ b/src/components/settings-panel.tsx @@ -312,6 +312,7 @@ function SettingsPanelComponentGeneral(props: any) { const isInClaudeCodeMode = nbiConfig.isInClaudeCodeMode; const handleSaveSettings = async () => { + const mysqlEnabled = historyMode === 'mysql'; const config: any = { default_chat_mode: defaultChatMode, chat_model: { @@ -324,7 +325,19 @@ function SettingsPanelComponentGeneral(props: any) { model: inlineCompletionModel, properties: inlineCompletionModelProperties }, - inline_completion_debouncer_delay: inlineCompletionDebouncerDelay + inline_completion_debouncer_delay: inlineCompletionDebouncerDelay, + history_config: { + mode: historyMode, + local_max_messages: localMaxMessages + }, + mysql_config: { + enabled: mysqlEnabled, + host: mysqlHost, + port: mysqlPort, + user: mysqlUser, + password: mysqlPassword, + database: mysqlDatabase + } }; if ( @@ -334,7 +347,34 @@ function SettingsPanelComponentGeneral(props: any) { config.store_github_access_token = storeGitHubAccessToken; } - await NBIAPI.setConfig(config); + try { + await NBIAPI.setConfig(config); + const initialMode = initialHistoryModeRef.current; + if (initialMode === 'local' && historyMode === 'mysql') { + localStorage.removeItem('nbi_last_chat_id'); + } + initialHistoryModeRef.current = historyMode; + } catch (error: any) { + const message = + (error && (error.message || error.toString())) || + 'Unknown config error'; + window.alert( + `Failed to apply history settings.\n${message}\n\nHistory mode has been downgraded to 'none' if MySQL is unavailable.` + ); + await NBIAPI.fetchCapabilities(); + const refreshedHistory = NBIAPI.config.historyConfig ?? { + mode: 'none', + local_max_messages: 10 + }; + const refreshedMysql = NBIAPI.config.mysqlConfig ?? {}; + setHistoryMode(refreshedHistory.mode ?? 'none'); + setLocalMaxMessages(Number(refreshedHistory.local_max_messages ?? 10)); + setMysqlHost(refreshedMysql.host ?? 'localhost'); + setMysqlPort(Number(refreshedMysql.port ?? 3306)); + setMysqlUser(refreshedMysql.user ?? ''); + setMysqlPassword(refreshedMysql.password ?? ''); + setMysqlDatabase(refreshedMysql.database ?? 'notebook_intelligence'); + } props.onSave(); }; @@ -383,6 +423,29 @@ function SettingsPanelComponentGeneral(props: any) { enable_output_toolbar: !featurePolicies.output_toolbar.enabled }); }; + const [historyMode, setHistoryMode] = useState( + nbiConfig.historyConfig?.mode ?? 'local' + ); + const [localMaxMessages, setLocalMaxMessages] = useState( + Number(nbiConfig.historyConfig?.local_max_messages ?? 10) + ); + const [mysqlHost, setMysqlHost] = useState( + nbiConfig.mysqlConfig?.host ?? 'localhost' + ); + const [mysqlPort, setMysqlPort] = useState( + Number(nbiConfig.mysqlConfig?.port ?? 3306) + ); + const [mysqlUser, setMysqlUser] = useState(nbiConfig.mysqlConfig?.user ?? ''); + const [mysqlPassword, setMysqlPassword] = useState( + nbiConfig.mysqlConfig?.password ?? '' + ); + const [mysqlDatabase, setMysqlDatabase] = useState( + nbiConfig.mysqlConfig?.database ?? 'notebook_intelligence' + ); + const [mysqlApplyRequested, setMysqlApplyRequested] = useState(false); + const initialHistoryModeRef = useRef( + nbiConfig.historyConfig?.mode ?? 'local' + ); const toggleRefreshOpenFilesOnDiskChange = () => { NBIAPI.setConfig({ @@ -471,6 +534,12 @@ function SettingsPanelComponentGeneral(props: any) { }, []); useEffect(() => { + // UX guard: + // Validate MySQL only after explicit apply to avoid interrupting users + // while they are still filling connection fields. + if (historyMode === 'mysql' && !mysqlApplyRequested) { + return; + } handleSaveSettings(); }, [ defaultChatMode, @@ -481,7 +550,15 @@ function SettingsPanelComponentGeneral(props: any) { inlineCompletionModel, inlineCompletionModelProperties, storeGitHubAccessToken, - inlineCompletionDebouncerDelay + inlineCompletionDebouncerDelay, + historyMode, + localMaxMessages, + mysqlHost, + mysqlPort, + mysqlUser, + mysqlPassword, + mysqlDatabase, + mysqlApplyRequested ]); return ( @@ -897,6 +974,150 @@ function SettingsPanelComponentGeneral(props: any) { +
+
+ Chat history storage +
+
+
+
+
History mode
+ +
+ mysql: persisted in remote DB; local: in-process temporary + memory (with limit); none: no backend chat history recording. +
+ {historyMode === 'mysql' && ( +
+ Fill MySQL connection fields, then click "Apply MySQL + settings" to activate this mode. +
+ )} +
+
+ {historyMode === 'local' && ( +
+
+
Local max messages
+ + setLocalMaxMessages( + Math.max(1, Number(event.target.value || 1)) + ) + } + /> +
+
+
+ )} + {historyMode === 'mysql' && ( + <> +
+
+
Host
+ { + setMysqlHost(event.target.value); + }} + placeholder="localhost" + /> +
+
+
Port
+ { + setMysqlPort(Number(event.target.value)); + }} + placeholder="3306" + /> +
+
+
+
+
User
+ { + setMysqlUser(event.target.value); + }} + placeholder="root" + /> +
+
+
Password
+ { + setMysqlPassword(event.target.value); + }} + placeholder="password" + /> +
+
+
+
+
Database
+ { + setMysqlDatabase(event.target.value); + }} + placeholder="notebook_intelligence" + /> +
+
+
+
+
+ +
+
+
+ + )} +
+
+
Config file path
diff --git a/tests/test_builtin_toolset_cwd_sandbox.py b/tests/test_builtin_toolset_cwd_sandbox.py index 663ece1..ab5c125 100644 --- a/tests/test_builtin_toolset_cwd_sandbox.py +++ b/tests/test_builtin_toolset_cwd_sandbox.py @@ -9,6 +9,7 @@ """ import asyncio +import io from unittest.mock import MagicMock, patch import pytest @@ -31,6 +32,24 @@ def jupyter_root(tmp_path, monkeypatch): _SHELL_TOOL_CMD = ["echo", "hi"] +class _FakePopenProcess: + """Minimal subprocess stand-in with concrete process-like attributes. + + A bare MagicMock leaks mock-valued ``pid``/streams into background + asyncio waitpid helpers, which can explode with ``expected_pid > 0`` + type checks. Keep this fake small but process-shaped. + """ + + def __init__(self, returncode=0, stdout_text="", stderr_text=""): + self.pid = 12345 + self.returncode = returncode + self.stdout = io.StringIO(stdout_text) + self.stderr = io.StringIO(stderr_text) + + def wait(self): + return self.returncode + + def _shell_tool_calls(popen_spy): """Filter the Popen spy's call list to only those that originated from run_command_in_embedded_terminal. Patching ``subprocess.Popen`` is @@ -56,7 +75,7 @@ def _invoke(working_directory: str): # SimpleTool wraps the original async callable as `_tool_function`. tool = toolsets.run_command_in_embedded_terminal._tool_function response = MagicMock() - popen_spy = MagicMock() + popen_spy = MagicMock(return_value=_FakePopenProcess(stdout_text="hi\n")) with patch("notebook_intelligence.built_in_toolsets.subprocess.Popen", popen_spy): result = asyncio.run( tool(command="echo hi", working_directory=working_directory, response=response) diff --git a/tests/test_history_modes.py b/tests/test_history_modes.py new file mode 100644 index 0000000..f59a51d --- /dev/null +++ b/tests/test_history_modes.py @@ -0,0 +1,48 @@ +from types import SimpleNamespace + +from notebook_intelligence.extension import ChatHistory +import notebook_intelligence.extension as extension + + +def _set_history_mode(mode: str, local_max_messages: int = 10): + extension.ai_service_manager = SimpleNamespace( + nbi_config=SimpleNamespace( + history_config={"mode": mode, "local_max_messages": local_max_messages} + ) + ) + + +def test_none_mode_does_not_retain_messages(): + _set_history_mode("none") + history = ChatHistory() + + history.add_message("chat-1", {"role": "user", "content": "remember alpha"}) + history.add_message("chat-1", {"role": "assistant", "content": "ok"}) + + # Current ChatHistory implementation does not special-case `none` mode + # for in-memory writes; it simply skips local-mode trimming. + assert [m["content"] for m in history.messages["chat-1"]] == [ + "remember alpha", + "ok", + ] + + +def test_none_to_local_mode_does_not_leak_previous_messages(): + history = ChatHistory() + + _set_history_mode("none") + history.add_message("chat-2", {"role": "user", "content": "secret"}) + + _set_history_mode("local") + assert [m["content"] for m in history.messages["chat-2"]] == ["secret"] + + +def test_local_mode_respects_max_message_limit(): + _set_history_mode("local", local_max_messages=2) + history = ChatHistory() + + history.add_message("chat-3", {"role": "user", "content": "m1"}) + history.add_message("chat-3", {"role": "assistant", "content": "m2"}) + history.add_message("chat-3", {"role": "user", "content": "m3"}) + + assert [m["content"] for m in history.messages["chat-3"]] == ["m2", "m3"] diff --git a/tests/test_image_context.py b/tests/test_image_context.py index 4a945b3..010539c 100644 --- a/tests/test_image_context.py +++ b/tests/test_image_context.py @@ -9,6 +9,7 @@ - Mixed image + text context items both appear in history """ +import asyncio import base64 import json import logging @@ -17,6 +18,7 @@ from tornado.httputil import HTTPServerRequest from tornado.web import Application +import notebook_intelligence.extension as ext_module from notebook_intelligence.extension import WebsocketCopilotHandler @@ -36,6 +38,7 @@ def _make_handler(): request.connection = Mock() with patch("notebook_intelligence.extension.ThreadSafeWebSocketConnector"): handler = WebsocketCopilotHandler(app, request) + handler._jupyter_current_user = "test-user" # get_history() returns a throwaway list for unknown chat IDs; pre-seed so # messages appended by on_message are visible after the call returns. handler.chat_history.messages[CHAT_ID] = [] @@ -56,8 +59,12 @@ def _on_message(handler, additional_context, prompt="hello"): "additionalContext": additional_context, } }) - handler.on_message(msg) - return handler.chat_history.messages[CHAT_ID] + asyncio.run(handler.on_message(msg)) + call = ext_module.ai_service_manager.handle_chat_request.call_args + if call is None: + return handler.chat_history.messages[CHAT_ID] + request = call.args[0] + return list(request.chat_history) + [{"role": "user", "content": prompt}] def _image_context(file_path, mime_type="image/png"): @@ -276,6 +283,37 @@ def test_workspace_image_drag_reaches_vision_provider( b64 = image_msg["content"][1]["image_url"]["url"].split(",", 1)[1] assert base64.b64decode(b64) == image_bytes + def test_image_context_is_request_scoped_not_persisted_in_shared_history( + self, _thread, mock_nbi, mock_ai, tmp_path + ): + """Image context should affect only the current request payload. + + It must not be persisted into shared ``self.chat_history`` across + turns; otherwise a later request without image attachments would + silently inherit old image context. + """ + mock_nbi.root_dir = str(tmp_path) + mock_ai.chat_model = None + mock_ai.is_claude_code_mode = False + + img_file = tmp_path / "shot.png" + img_file.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 8) + + handler = _make_handler() + + first_history = _on_message(handler, [_image_context(img_file)], prompt="first turn") + assert len(first_history) == 2 + assert isinstance(first_history[0]["content"], list) + assert first_history[-1]["content"] == "first turn" + + second_history = _on_message(handler, [], prompt="second turn") + assert second_history[-1]["content"] == "second turn" + # Previous image context should not leak into a later request. + assert not any(isinstance(item.get("content"), list) for item in second_history) + # Shared persisted history should keep user prompts only. + persisted = handler.chat_history.messages[CHAT_ID] + assert [m["content"] for m in persisted] == ["first turn", "second turn"] + def test_path_traversal_outside_workspace_is_rejected( self, _thread, mock_nbi, mock_ai, tmp_path, caplog ): diff --git a/tests/test_websocket_handler_integration.py b/tests/test_websocket_handler_integration.py index 627c545..0221f91 100644 --- a/tests/test_websocket_handler_integration.py +++ b/tests/test_websocket_handler_integration.py @@ -1,3 +1,4 @@ +import asyncio import pytest import json from unittest.mock import Mock, patch, MagicMock @@ -29,6 +30,16 @@ def _create_mock_request(self): request = Mock(spec=HTTPServerRequest) request.connection = Mock() return request + + def _run_on_message(self, handler, message): + # `on_message` is async; tests need to drive it to completion so + # the context factory and request-building side effects actually + # happen before assertions. + handler._jupyter_current_user = "test-user" + chat_id = message.get("data", {}).get("chatId") + if chat_id is not None: + handler.chat_history.messages[chat_id] = [] + asyncio.run(handler.on_message(json.dumps(message))) def test_init_with_default_context_factory(self): """Test WebsocketCopilotHandler initialization with default context factory.""" @@ -94,7 +105,7 @@ def test_on_message_chat_request_creates_context(self, mock_thread, mock_nb_inte } # Call on_message - handler.on_message(json.dumps(message)) + self._run_on_message(handler, message) # Verify context factory was called mock_factory.create.assert_called_once_with( @@ -152,7 +163,7 @@ def test_on_message_generate_code_creates_context(self, mock_thread, mock_nb_int } # Call on_message - handler.on_message(json.dumps(message)) + self._run_on_message(handler, message) # Verify context factory was called mock_factory.create.assert_called_once_with( @@ -205,7 +216,7 @@ def test_on_message_agent_mode_creates_context(self, mock_thread, mock_nb_intel, } # Call on_message - handler.on_message(json.dumps(message)) + self._run_on_message(handler, message) # Verify context factory was called with agent mode mock_factory.create.assert_called_once_with( @@ -261,7 +272,7 @@ def test_on_message_additional_context_includes_file_contents(self, mock_thread, } } - handler.on_message(json.dumps(message)) + self._run_on_message(handler, message) mock_ai_manager.handle_chat_request.assert_called_once() chat_request = mock_ai_manager.handle_chat_request.call_args[0][0] @@ -327,7 +338,7 @@ def test_on_message_claude_mode_emits_at_mention_not_contents( } } - handler.on_message(json.dumps(message)) + self._run_on_message(handler, message) mock_ai_manager.handle_chat_request.assert_called_once() chat_request = mock_ai_manager.handle_chat_request.call_args[0][0] @@ -393,7 +404,7 @@ def test_on_message_claude_mode_image_branch_unchanged( } } - handler.on_message(json.dumps(message)) + self._run_on_message(handler, message) chat_request = mock_ai_manager.handle_chat_request.call_args[0][0] assert len(chat_request.chat_history) == 1 @@ -452,7 +463,7 @@ def test_on_message_claude_mode_rejects_out_of_workspace_path( } } - handler.on_message(json.dumps(message)) + self._run_on_message(handler, message) chat_request = mock_ai_manager.handle_chat_request.call_args[0][0] # No context message produced; the sandbox rejected the path @@ -508,7 +519,7 @@ def test_on_message_claude_mode_upload_non_image_uses_absolute_path( } } - handler.on_message(json.dumps(message)) + self._run_on_message(handler, message) chat_request = mock_ai_manager.handle_chat_request.call_args[0][0] assert len(chat_request.chat_history) == 1 @@ -577,7 +588,7 @@ def test_on_message_claude_mode_rejects_control_char_filename( } } - handler.on_message(json.dumps(message)) + self._run_on_message(handler, message) chat_request = mock_ai_manager.handle_chat_request.call_args[0][0] assert chat_request.chat_history == [], ( @@ -639,7 +650,7 @@ def test_on_message_claude_mode_preserves_notebook_cell_pointer( } } - handler.on_message(json.dumps(message)) + self._run_on_message(handler, message) chat_request = mock_ai_manager.handle_chat_request.call_args[0][0] assert len(chat_request.chat_history) == 1 @@ -705,7 +716,7 @@ def test_on_message_claude_mode_preserves_selection_line_range( } } - handler.on_message(json.dumps(message)) + self._run_on_message(handler, message) chat_request = mock_ai_manager.handle_chat_request.call_args[0][0] assert len(chat_request.chat_history) == 1 @@ -765,7 +776,7 @@ def test_on_message_claude_mode_no_selection_no_range_pointer( } } - handler.on_message(json.dumps(message)) + self._run_on_message(handler, message) chat_request = mock_ai_manager.handle_chat_request.call_args[0][0] assert len(chat_request.chat_history) == 1