Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions notebook_intelligence/ai_service_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from notebook_intelligence.api import ButtonData, ChatModel, EmbeddingModel, InlineCompletionModel, LLMProvider, ChatParticipant, ChatRequest, ChatResponse, CompletionContext, ContextRequest, Host, CompletionContextProvider, MCPPrompt, MCPServer, MarkdownData, NotebookIntelligenceExtension, RegistrationError, TelemetryEvent, TelemetryListener, Tool, Toolset
from notebook_intelligence.base_chat_participant import BaseChatParticipant
from notebook_intelligence.config import NBIConfig
from notebook_intelligence.mysql_manager import MySQLManager
from notebook_intelligence.github_copilot_chat_participant import GithubCopilotChatParticipant
from notebook_intelligence.claude import CLAUDE_CODE_CHAT_PARTICIPANT_ID, ClaudeCodeChatParticipant, ClaudeCodeInlineCompletionModel, fetch_claude_models, get_claude_models
from notebook_intelligence.llm_providers.github_copilot_llm_provider import GitHubCopilotLLMProvider
Expand Down Expand Up @@ -61,6 +62,7 @@ def __init__(self, options: Optional[dict] = None):
self._options.get("feature_policies") or {},
self._options.get("string_overrides") or {},
)
self._mysql_manager = MySQLManager(self._nbi_config.mysql_config)
self._openai_compatible_llm_provider = OpenAICompatibleLLMProvider()
self._litellm_compatible_llm_provider = LiteLLMCompatibleLLMProvider()
self._ollama_llm_provider = OllamaLLMProvider()
Expand Down Expand Up @@ -93,6 +95,14 @@ def __init__(self, options: Optional[dict] = None):
def nbi_config(self) -> NBIConfig:
return self._nbi_config

@property
def mysql_manager(self) -> MySQLManager:
return self._mysql_manager

def update_mysql_manager(self):
"""Refresh MySQL manager instance from current config."""
self._mysql_manager = MySQLManager(self._nbi_config.mysql_config)

@property
def ollama_llm_provider(self) -> OllamaLLMProvider:
return self._ollama_llm_provider
Expand Down
30 changes: 30 additions & 0 deletions notebook_intelligence/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,8 @@ class ChatRequest:
cancel_token: CancelToken = None
# NEW: Add context for rule evaluation
rule_context: Optional[RuleContext] = None
# NEW: Internal conversation id for logging
conversation_id: str = None

@dataclass
class ResponseStreamData:
Expand Down Expand Up @@ -633,6 +635,13 @@ async def _tool_call_loop(tool_call_rounds: list):
if message.get('tool_calls', None) is not None:
for tool_call in message['tool_calls']:
tool_call_rounds.append(tool_call)

# MySQL logging: Log assistant message with tool calls
if request.conversation_id:
msg_id = str(uuid.uuid4())
request.host.mysql_manager.add_message(
msg_id, request.conversation_id, "assistant", content, reasoning, message['tool_calls']
)

messages.append(message)

Expand Down Expand Up @@ -662,6 +671,12 @@ async def _tool_call_loop(tool_call_rounds: list):
args = tool_call['function']['arguments']
else:
args = fuzzy_json_loads(tool_call['function']['arguments'])

# MySQL logging: Log tool execution START (initial record)
if request.conversation_id:
request.host.mysql_manager.log_tool_execution(
tool_call['id'], request.conversation_id, tool_name, args, ""
)

tool_properties = tool_to_call.schema["function"]["parameters"]["properties"]
if type(args) is str:
Expand All @@ -688,6 +703,17 @@ async def _tool_call_loop(tool_call_rounds: list):
return

tool_call_response = await tool_to_call.handle_tool_call(request, response, tool_context, args)

# MySQL logging: Update tool execution with result
if request.conversation_id:
request.host.mysql_manager.log_tool_execution(
tool_call['id'], request.conversation_id, tool_name, args, str(tool_call_response)
)
# Also log the tool message itself
msg_id = str(uuid.uuid4())
request.host.mysql_manager.add_message(
msg_id, request.conversation_id, "tool", str(tool_call_response), tool_call_id=tool_call['id']
)

function_call_result_message = {
"role": "tool",
Expand Down Expand Up @@ -946,6 +972,10 @@ def get_skill_manager(self):
def websocket_connector(self) -> ThreadSafeWebSocketConnector:
raise NotImplementedError

@property
def mysql_manager(self) -> Any:
return NotImplementedError


class NotebookIntelligenceExtension:
@property
Expand Down
29 changes: 29 additions & 0 deletions notebook_intelligence/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,35 @@ def active_rules(self) -> dict:
"""Get dictionary of active rule states (filename -> bool)."""
return self.get('active_rules', {})

@property
def mysql_config(self) -> dict:
"""Get MySQL configuration."""
return self.get('mysql_config', {
'enabled': False,
'host': 'localhost',
'port': 3306,
'user': '',
'password': '',
'database': 'notebook_intelligence'
})

@property
def history_config(self) -> dict:
"""Get chat history storage configuration."""
cfg = self.get('history_config', {})
mode = cfg.get('mode', 'local')
local_max_messages = cfg.get('local_max_messages', 10)
try:
local_max_messages = int(local_max_messages)
except Exception:
local_max_messages = 10
if local_max_messages < 1:
local_max_messages = 1
return {
'mode': mode if mode in ['mysql', 'local', 'none'] else 'local',
'local_max_messages': local_max_messages
}

def set_rule_active(self, filename: str, active: bool):
"""Set the active state of a rule."""
active_rules = self.active_rules.copy()
Expand Down
Loading
Loading