diff --git a/ai_chatbot.py b/ai_chatbot.py index baf17bd..6f40a21 100644 --- a/ai_chatbot.py +++ b/ai_chatbot.py @@ -26,6 +26,7 @@ # Pattern-matching backend # --------------------------------------------------------------------------- + class SimpleBot: """Offline chatbot backed by a CSV dialog dataset.""" diff --git a/api.py b/api.py index 42b2689..7865b99 100644 --- a/api.py +++ b/api.py @@ -46,10 +46,15 @@ # Per-session bots (keyed by session_id string) _sessions: dict[str, ChatBot] = {} +# Dialog file used when creating new sessions (updated by /train) +_default_dialog_file: Optional[str] = None + def clear_all_sessions() -> None: - """Remove all active sessions. Intended for use in tests.""" + """Remove all active sessions and reset the default dialog file. Intended for use in tests.""" + global _default_dialog_file _sessions.clear() + _default_dialog_file = None def _get_or_create_session(session_id: Optional[str]) -> tuple[str, ChatBot]: @@ -59,7 +64,14 @@ def _get_or_create_session(session_id: Optional[str]) -> tuple[str, ChatBot]: session_id = str(uuid.uuid4()) if session_id not in _sessions: - _sessions[session_id] = ChatBot() + # When a dialog file has been explicitly configured via /train, create + # the bot in offline mode so the file is always used (no Kaggle fallback). + if _default_dialog_file is not None: + _sessions[session_id] = ChatBot( + dialog_file=_default_dialog_file, force_offline=True + ) + else: + _sessions[session_id] = ChatBot() return session_id, _sessions[session_id] @@ -127,21 +139,24 @@ def reset_session(session_id: str): def train(request: TrainRequest): """Reload pattern-matching data from a CSV file on the server. - The file must exist on the server filesystem. This endpoint only - affects future sessions created after the reload (existing sessions - keep their current bot instance). + The file must exist on the server filesystem. All existing sessions are + cleared so that subsequent requests create new sessions using the updated + training data. """ if not os.path.exists(request.dialog_file): raise HTTPException( status_code=404, - detail=f"File not found: {request.dialog_file}", + detail="Training file not found on server. Please verify the file path.", ) - # Retrain a fresh bot and store it as the template for new sessions + # Retrain a fresh bot to validate and count the patterns fresh_bot = ChatBot(dialog_file=request.dialog_file, force_offline=True) patterns = fresh_bot.pattern_count - # Clear all existing sessions so next requests pick up new data + # Persist the dialog file so new sessions created after this point use it, + # then clear existing sessions so they pick up the new data on next request. + global _default_dialog_file + _default_dialog_file = request.dialog_file _sessions.clear() return TrainResponse(status="retrained", patterns_loaded=patterns) diff --git a/test_chatbot.py b/test_chatbot.py index fb814a2..4c4c07f 100644 --- a/test_chatbot.py +++ b/test_chatbot.py @@ -8,7 +8,6 @@ import os import tempfile import pytest -from fastapi.testclient import TestClient # --------------------------------------------------------------------------- # Fixtures @@ -216,3 +215,33 @@ def test_train_missing_file(self, api_client): client, _ = api_client resp = client.post("/train", json={"dialog_file": "/nonexistent/file.csv"}) assert resp.status_code == 404 + + def test_train_affects_new_chat_session(self, api_client): + client, _ = api_client + + new_dialog_content = ( + "dialog_id,line_id,text\n" + "1,1,hello\n" + "1,2,Hello from NEW dialog!\n" + ) + tmp = tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False) + try: + tmp.write(new_dialog_content) + tmp.flush() + tmp_path = tmp.name + finally: + tmp.close() + + try: + # Retrain with the new dialog file + resp = client.post("/train", json={"dialog_file": tmp_path}) + assert resp.status_code == 200 + assert resp.json()["patterns_loaded"] > 0 + + # A brand-new session (no session_id) should use the new patterns + chat_resp = client.post("/chat", json={"message": "hello"}) + assert chat_resp.status_code == 200 + assert "Hello from NEW dialog" in chat_resp.json()["reply"] + finally: + if os.path.exists(tmp_path): + os.remove(tmp_path) diff --git a/web_demo.py b/web_demo.py index d141b9f..2909f04 100644 --- a/web_demo.py +++ b/web_demo.py @@ -8,8 +8,6 @@ OPENAI_API_KEY Set this to enable the LLM backend (optional). """ -import os - import streamlit as st # Load .env if available @@ -39,15 +37,10 @@ # --------------------------------------------------------------------------- # Session-level chatbot instance # --------------------------------------------------------------------------- -@st.cache_resource -def _get_bot_factory() -> ChatBot: - """Return a template bot (loads data once); actual per-session bots copy from this.""" - return ChatBot() - # Per-session bot stored in session_state so each browser tab/user gets its own history if "bot" not in st.session_state: - st.session_state.bot = _get_bot_factory() + st.session_state.bot = ChatBot() bot: ChatBot = st.session_state.bot