Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ai_chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
# Pattern-matching backend
# ---------------------------------------------------------------------------


class SimpleBot:
"""Offline chatbot backed by a CSV dialog dataset."""

Expand Down
31 changes: 23 additions & 8 deletions api.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,15 @@
# Per-session bots (keyed by session_id string)
_sessions: dict[str, ChatBot] = {}

# Dialog file used when creating new sessions (updated by /train)
_default_dialog_file: Optional[str] = None


def clear_all_sessions() -> None:
    """Remove all active sessions and reset the default dialog file.

    Intended for use in tests. Besides emptying ``_sessions``, this sets
    ``_default_dialog_file`` back to ``None`` so that no dialog file
    configured earlier stays in effect for sessions created afterwards.
    """
    # The module-level name is rebound below, so it must be declared global.
    global _default_dialog_file
    _sessions.clear()
    _default_dialog_file = None


def _get_or_create_session(session_id: Optional[str]) -> tuple[str, ChatBot]:
Expand All @@ -59,7 +64,14 @@ def _get_or_create_session(session_id: Optional[str]) -> tuple[str, ChatBot]:
session_id = str(uuid.uuid4())

if session_id not in _sessions:
_sessions[session_id] = ChatBot()
# When a dialog file has been explicitly configured via /train, create
# the bot in offline mode so the file is always used (no Kaggle fallback).
if _default_dialog_file is not None:
_sessions[session_id] = ChatBot(
dialog_file=_default_dialog_file, force_offline=True
)
else:
_sessions[session_id] = ChatBot()

return session_id, _sessions[session_id]

Expand Down Expand Up @@ -127,21 +139,24 @@ def reset_session(session_id: str):
def train(request: TrainRequest):
    """Reload pattern-matching data from a CSV file on the server.

    The file must exist on the server filesystem. All existing sessions are
    cleared so that subsequent requests create new sessions using the updated
    training data.

    Args:
        request: Carries ``dialog_file``, the server-side path of the CSV.

    Returns:
        A ``TrainResponse`` with status ``"retrained"`` and the number of
        patterns loaded from the file.

    Raises:
        HTTPException: With status 404 when ``request.dialog_file`` does not
            exist on the server.
    """
    # Declared up front because the module-level default is rebound below.
    global _default_dialog_file

    if not os.path.exists(request.dialog_file):
        # The message does not echo the requested path back to the client.
        raise HTTPException(
            status_code=404,
            detail="Training file not found on server. Please verify the file path.",
        )

    # Build a throwaway bot to validate the file and count its patterns.
    # This happens before any shared state is modified, so an error raised
    # here leaves the current default and the live sessions untouched.
    fresh_bot = ChatBot(dialog_file=request.dialog_file, force_offline=True)
    patterns = fresh_bot.pattern_count

    # Persist the dialog file so sessions created after this point use it,
    # then drop existing sessions so they are rebuilt with the new data on
    # their next request.
    _default_dialog_file = request.dialog_file
    _sessions.clear()

    return TrainResponse(status="retrained", patterns_loaded=patterns)
31 changes: 30 additions & 1 deletion test_chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import os
import tempfile
import pytest
from fastapi.testclient import TestClient

# ---------------------------------------------------------------------------
# Fixtures
Expand Down Expand Up @@ -216,3 +215,33 @@ def test_train_missing_file(self, api_client):
client, _ = api_client
resp = client.post("/train", json={"dialog_file": "/nonexistent/file.csv"})
assert resp.status_code == 404

def test_train_affects_new_chat_session(self, api_client):
    """After /train, a session created without a session_id uses the new file."""
    client, _unused = api_client

    # Write a tiny dialog CSV; delete=False keeps it on disk after the
    # handle is closed so the server can open it by path.
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".csv", delete=False
    ) as csv_handle:
        csv_handle.write(
            "dialog_id,line_id,text\n"
            "1,1,hello\n"
            "1,2,Hello from NEW dialog!\n"
        )
        csv_handle.flush()
        tmp_path = csv_handle.name

    try:
        # Point the server at the freshly written dialog file
        train_resp = client.post("/train", json={"dialog_file": tmp_path})
        assert train_resp.status_code == 200
        assert train_resp.json()["patterns_loaded"] > 0

        # No session_id is sent, so the server has to create a new session
        reply_resp = client.post("/chat", json={"message": "hello"})
        assert reply_resp.status_code == 200
        assert "Hello from NEW dialog" in reply_resp.json()["reply"]
    finally:
        # Remove the temporary CSV whether or not the assertions passed
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
9 changes: 1 addition & 8 deletions web_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
OPENAI_API_KEY Set this to enable the LLM backend (optional).
"""

import os

import streamlit as st

# Load .env if available
Expand Down Expand Up @@ -39,15 +37,10 @@
# ---------------------------------------------------------------------------
# Session-level chatbot instance
# ---------------------------------------------------------------------------
@st.cache_resource
def _get_bot_factory() -> ChatBot:
"""Return a template bot (loads data once); actual per-session bots copy from this."""
return ChatBot()


# Per-session bot stored in session_state so each browser tab/user gets its own history
if "bot" not in st.session_state:
st.session_state.bot = _get_bot_factory()
st.session_state.bot = ChatBot()

bot: ChatBot = st.session_state.bot

Expand Down