diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 00000000..283b591d --- /dev/null +++ b/.dockerignore @@ -0,0 +1,21 @@ +.git +.gitignore +.venv +__pycache__ +*.pyc +*.pyo +*.pyd +.pytest_cache +.mypy_cache +.vscode +.idea + +backend/.env +backend/tests +frontend +docs +example_uploads + +node_modules +npm-debug.log* +yarn-error.log* diff --git a/backend/.env.example b/backend/.env.example index 30864721..ad75bce2 100644 --- a/backend/.env.example +++ b/backend/.env.example @@ -20,6 +20,15 @@ LANGSMITH_TRACING=false # 如需覆盖运行时输出的根目录,请设置 APP_RUNTIME_DIR(绝对路径) # APP_RUNTIME_DIR= +# ============================================================================ +# PostgreSQL / pgvector(推荐) +# ============================================================================ +# 开发环境可直接使用 pgsql.txt 中的 Docker 启动方式,然后把连接串指向容器。 +# 默认建议创建独立数据库,如 ctb;若直接使用 postgres 默认库,也可改为 /postgres。 +# APP_DATABASE_URL=postgresql+psycopg://postgres:root@localhost:15432/ctb +# APP_DATABASE_ECHO=false +# APP_POSTGRES_VECTOR_DIMENSIONS=1536 + # ============================================================================ # 文字擦除模型权重(可选) # ============================================================================ @@ -45,6 +54,14 @@ LANGSMITH_TRACING=false # APP_DEFAULT_PADDLEOCR_USE_DOC_UNWARPING=false # APP_DEFAULT_PADDLEOCR_USE_CHART_RECOGNITION=false +# ============================================================================ +# RAG 语义检索配置(可选) +# +# Embedding 模型复用 APP_DEFAULT_OPENAI_* 的 key 和 base_url。 +# 如果未配置 OpenAI provider,RAG 将降级为不生成向量,语义搜索自动回退到 hash 模糊匹配。 +# ============================================================================ +# APP_RAG_EMBEDDING_MODEL=text-embedding-3-small + # ============================================================================ # 验证码邮件 — SMTP 配置(注册验证码 + 找回密码共用) # diff --git a/backend/agents/error_correction/agent.py b/backend/agents/error_correction/agent.py index 7527cd67..8396f030 100644 --- a/backend/agents/error_correction/agent.py 
+++ b/backend/agents/error_correction/agent.py @@ -8,6 +8,21 @@ import re import threading import time + +# ============================================================ +# Monkeypatch: 修复 langgraph 依赖冲突 +# ============================================================ +try: + import langgraph.runtime + if not hasattr(langgraph.runtime, 'ExecutionInfo'): + class ExecutionInfo: pass + langgraph.runtime.ExecutionInfo = ExecutionInfo + if not hasattr(langgraph.runtime, 'ServerInfo'): + class ServerInfo: pass + langgraph.runtime.ServerInfo = ServerInfo +except ImportError: + pass + from langchain.agents import create_agent from langchain.agents.structured_output import ToolStrategy from langchain_core.messages import SystemMessage, HumanMessage diff --git a/backend/benchmark/recall.py b/backend/benchmark/recall.py new file mode 100644 index 00000000..1520eca4 --- /dev/null +++ b/backend/benchmark/recall.py @@ -0,0 +1,359 @@ +import argparse +import json +import logging +import random +from dataclasses import dataclass + +from sqlalchemy.orm import joinedload + +from core.config import settings +from db import SessionLocal, init_db +from db.models import Question, QuestionTagMapping, UploadBatch + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class QueryItem: + question_id: int + user_id: int | None + project_id: int | None + subject: str + tags: frozenset[str] + query_text: str + + +def _safe_json_loads(value, fallback): + if not value: + return fallback + try: + return json.loads(value) + except (TypeError, ValueError): + return fallback + + +def _blocks_text(content_json: str) -> str: + blocks = _safe_json_loads(content_json, []) + parts: list[str] = [] + for block in blocks or []: + if isinstance(block, dict): + content = (block.get("content") or block.get("text") or "").strip() + if content: + parts.append(content) + elif isinstance(block, str): + content = block.strip() + if content: + parts.append(content) + return "\n".join(parts).strip() + + 
+def _options_text(options_json: str) -> str: + options = _safe_json_loads(options_json, []) + parts: list[str] = [] + if isinstance(options, dict): + options = list(options.values()) + for opt in options or []: + if isinstance(opt, dict): + label = str(opt.get("label") or "").strip() + content = str(opt.get("content") or "").strip() + if not (label or content): + continue + parts.append(f"{label}. {content}" if label else content) + elif isinstance(opt, str): + content = opt.strip() + if content: + parts.append(content) + return "\n".join(parts).strip() + + +def _build_query_text(q: Question) -> str: + subject = q.batch.subject if q.batch else "" + tags: list[str] = [] + for mapping in q.tags or []: + if mapping.tag and mapping.tag.tag_name: + tags.append(mapping.tag.tag_name) + content_text = _blocks_text(q.content_json) + options_text = _options_text(q.options_json) + parts: list[str] = [] + if subject: + parts.append(f"科目:{subject}") + if q.question_type: + parts.append(f"题型:{q.question_type}") + if tags: + parts.append(f"知识点:{'、'.join(tags)}") + if content_text: + parts.append(f"题干:{content_text}") + if options_text: + parts.append(f"选项:{options_text}") + text = "\n".join(parts).strip() + return text[:4000] + + +def _load_queries(*, limit: int, seed: int | None) -> list[QueryItem]: + with SessionLocal() as db: + questions = ( + db.query(Question) + .options(joinedload(Question.batch), joinedload(Question.tags).joinedload(QuestionTagMapping.tag)) + .order_by(Question.id.desc()) + .all() + ) + + items: list[QueryItem] = [] + if questions: + for q in questions: + subject = q.batch.subject if q.batch else "" + tags = [] + for mapping in q.tags or []: + if mapping.tag and mapping.tag.tag_name: + tags.append(mapping.tag.tag_name) + items.append( + QueryItem( + question_id=q.id, + user_id=q.user_id, + project_id=q.project_id, + subject=subject or "", + tags=frozenset(tags), + query_text=_build_query_text(q), + ) + ) + else: + runs_dir = settings.runs_dir + candidates 
= sorted(runs_dir.glob("**/questions.json"), key=lambda p: p.stat().st_mtime, reverse=True) + latest = candidates[0] if candidates else None + if latest and latest.exists(): + payload = _safe_json_loads(latest.read_text(encoding="utf-8"), []) + if isinstance(payload, list): + for row in payload: + if not isinstance(row, dict): + continue + qid_raw = row.get("question_id") + try: + qid = int(qid_raw) + except (TypeError, ValueError): + qid = len(items) + 1 + tags = row.get("knowledge_tags") or [] + if not isinstance(tags, list): + tags = [] + content_blocks = row.get("content_blocks") or [] + options = row.get("options") or [] + content_json = json.dumps(content_blocks, ensure_ascii=False) + options_json = json.dumps(options, ensure_ascii=False) + q_stub = Question( + id=qid, + user_id=None, + project_id=None, + question_type=row.get("question_type") or "", + content_json=content_json, + options_json=options_json, + ) + q_stub.batch = UploadBatch(subject="") + q_stub.tags = [] + items.append( + QueryItem( + question_id=qid, + user_id=None, + project_id=None, + subject="", + tags=frozenset(str(t).strip() for t in tags if str(t).strip()), + query_text=_build_query_text(q_stub), + ) + ) + + if seed is not None: + rnd = random.Random(seed) + rnd.shuffle(items) + if limit > 0: + items = items[:limit] + return items + + +def _group_relevance(items: list[QueryItem]) -> dict[tuple[int | None, int | None], dict[str, set[int]]]: + by_scope: dict[tuple[int | None, int | None], dict[str, set[int]]] = {} + for item in items: + key = (item.user_id, item.project_id) + tag_map = by_scope.setdefault(key, {}) + for tag in item.tags: + tag_map.setdefault(tag, set()).add(item.question_id) + return by_scope + + +def _relevant_ids( + item: QueryItem, + *, + tag_index: dict[tuple[int | None, int | None], dict[str, set[int]]], +) -> set[int]: + key = (item.user_id, item.project_id) + tag_map = tag_index.get(key, {}) + relevant: set[int] = set() + for tag in item.tags: + 
relevant.update(tag_map.get(tag, set())) + relevant.discard(item.question_id) + return relevant + + +def _hash_retrieve_ids(*, query_text: str, user_id: int | None, project_id: int | None, top_k: int) -> list[int]: + from db import crud + + with SessionLocal() as db: + matches = crud.find_questions_by_natural_language( + db, + query_text=query_text, + limit=top_k, + user_id=user_id, + project_id=project_id, + ) + ids: list[int] = [] + for match in matches: + q = match.get("question") + if q and getattr(q, "id", None) is not None: + ids.append(q.id) + return ids + + +def _hash_retrieve_ids_in_memory( + *, + query_index: int, + items: list[QueryItem], + vectors: list[list[float]], + top_k: int, +) -> list[int]: + query_vec = vectors[query_index] + scored: list[tuple[int, float]] = [] + for idx, item in enumerate(items): + if idx == query_index: + continue + score = sum(a * b for a, b in zip(query_vec, vectors[idx])) + scored.append((item.question_id, score)) + scored.sort(key=lambda x: x[1], reverse=True) + return [qid for qid, _ in scored[:top_k]] + + +def _rag_retrieve_ids(*, query_text: str, user_id: int | None, project_id: int | None, top_k: int) -> list[int] | None: + try: + from core.rag import retrieve_context + except Exception as e: + logger.warning("RAG 模块不可用: %s", e) + return None + + with SessionLocal() as db: + results = retrieve_context( + db, + query=query_text, + user_id=user_id, + project_id=project_id, + top_k=top_k, + ) + return [int(r["source_id"]) for r in results if r.get("source_id") is not None] + + +def _evaluate( + items: list[QueryItem], + *, + top_ks: list[int], + method: str, +) -> dict: + tag_index = _group_relevance(items) + max_k = max(top_ks) + + use_in_memory_hash = False + hash_vectors: list[list[float]] = [] + if method == "hash": + with SessionLocal() as db: + has_db_questions = bool(db.query(Question.id).limit(1).first()) + if not has_db_questions: + from db.crud.questions import _hash_embedding + use_in_memory_hash = True + 
hash_vectors = [_hash_embedding(item.query_text) for item in items] + + valid = 0 + skipped_no_relevant = 0 + hit_counts = {k: 0 for k in top_ks} + recall_sums = {k: 0.0 for k in top_ks} + + for idx, item in enumerate(items): + relevant = _relevant_ids(item, tag_index=tag_index) + if not relevant: + skipped_no_relevant += 1 + continue + + if method == "hash": + if use_in_memory_hash: + retrieved = _hash_retrieve_ids_in_memory( + query_index=idx, + items=items, + vectors=hash_vectors, + top_k=max_k, + ) + else: + retrieved = _hash_retrieve_ids( + query_text=item.query_text, + user_id=item.user_id, + project_id=item.project_id, + top_k=max_k, + ) + elif method == "rag": + retrieved = _rag_retrieve_ids( + query_text=item.query_text, + user_id=item.user_id, + project_id=item.project_id, + top_k=max_k, + ) + if retrieved is None: + return { + "method": method, + "error": "RAG 不可用(导入失败或未配置 embedding)", + } + else: + raise ValueError(f"unknown method: {method}") + + valid += 1 + for k in top_ks: + topk = set(retrieved[:k]) + hit = bool(topk & relevant) + if hit: + hit_counts[k] += 1 + recall_sums[k] += (len(topk & relevant) / len(relevant)) + + report = { + "method": method, + "total_queries": len(items), + "evaluated_queries": valid, + "skipped_no_relevant": skipped_no_relevant, + "top_ks": top_ks, + "hit_at_k": {}, + "recall_at_k": {}, + } + if valid == 0: + return {**report, "error": "没有可评测的样本(可能是没有知识点标签或只有单题)"} + + for k in top_ks: + report["hit_at_k"][str(k)] = round(hit_counts[k] / valid, 4) + report["recall_at_k"][str(k)] = round(recall_sums[k] / valid, 4) + return report + + +def main(): + logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s: %(message)s") + + parser = argparse.ArgumentParser(description="Recall@K(召回率)评测") + parser.add_argument("--method", choices=["hash", "rag"], default="hash") + parser.add_argument("--k", default="1,3,5,10", help="K 值列表,如 1,3,5,10") + parser.add_argument("--limit", type=int, default=200, help="最多评测多少条 query(0 
表示全部)") + parser.add_argument("--seed", type=int, default=42, help="随机采样种子") + args = parser.parse_args() + + init_db() + + top_ks = sorted({int(x) for x in str(args.k).split(",") if str(x).strip().isdigit() and int(x) > 0}) + if not top_ks: + raise SystemExit("k 参数非法") + + items = _load_queries(limit=args.limit, seed=args.seed) + report = _evaluate(items, top_ks=top_ks, method=args.method) + + print(f"DB: {settings.database_url}") + print(json.dumps(report, ensure_ascii=False, indent=2)) + + +if __name__ == "__main__": + main() diff --git a/backend/check_pg.py b/backend/check_pg.py new file mode 100644 index 00000000..2e9690cb --- /dev/null +++ b/backend/check_pg.py @@ -0,0 +1,20 @@ +import psycopg +from core.config import settings + +def check_pg(): + try: + # Try to connect to the local PG on 5432 first as per original .env + conn_str = "postgresql://postgres:123456@localhost:5432/ctb" + with psycopg.connect(conn_str) as conn: + with conn.cursor() as cur: + cur.execute("SELECT * FROM pg_available_extensions WHERE name = 'vector';") + result = cur.fetchone() + if result: + print(f"FOUND: {result}") + else: + print("NOT FOUND: vector extension not available in pg_available_extensions") + except Exception as e: + print(f"ERROR: {e}") + +if __name__ == "__main__": + check_pg() diff --git a/backend/core/config.py b/backend/core/config.py index 9876d50b..3192cd66 100644 --- a/backend/core/config.py +++ b/backend/core/config.py @@ -21,7 +21,10 @@ _BACKEND_ROOT = Path(__file__).resolve().parent.parent # backend/core/ → backend/ _PROJECT_ROOT = _BACKEND_ROOT.parent -_ENV_FILE = _PROJECT_ROOT / ".env" +_ENV_FILES = ( + _BACKEND_ROOT / ".env", + _PROJECT_ROOT / ".env", +) # --------------------------------------------------------------------------- @@ -187,7 +190,7 @@ def create_llm(self, *, model: str, temperature: float): class Settings(BaseSettings): model_config = SettingsConfigDict( env_prefix="APP_", - env_file=_ENV_FILE, + env_file=_ENV_FILES, extra="ignore", ) @@ 
-200,6 +203,8 @@ class Settings(BaseSettings): # 错题库数据库路径,可通过 APP_DB_PATH 覆盖 db_path: Path | None = None + database_url: str = "" + database_echo: bool = False # 各类子目录(由 validator 从 runtime_dir 派生,可独立覆盖以便测试) upload_dir: Path | None = None @@ -220,6 +225,10 @@ class Settings(BaseSettings): True # 是否信任系统代理环境变量,Windows 下设为 False 可解 WinError 10054 ) + rag_embedding_model: str = "text-embedding-v3" + rag_embedding_api_key: str = "" + rag_embedding_base_url: str = "" + # 注册验证码邮件(SMTP,环境变量前缀仍为 APP_,如 APP_SMTP_HOST) smtp_host: str = "" smtp_port: int = 587 @@ -258,6 +267,8 @@ class Settings(BaseSettings): def _resolve_defaults(self): if self.db_path is None: self.db_path = self.runtime_dir / "error_book.db" + if not self.database_url: + self.database_url = f"sqlite:///{self.db_path}" if self.upload_dir is None: self.upload_dir = self.runtime_dir / "uploads" if self.pages_dir is None: diff --git a/backend/core/rag.py b/backend/core/rag.py new file mode 100644 index 00000000..10a1cf61 --- /dev/null +++ b/backend/core/rag.py @@ -0,0 +1,536 @@ +""" +RAG 核心模块 — 错题库语义检索 + +提供 embedding 生成、索引构建和向量检索能力。 +第一版使用 SQLite 存储向量,纯 Python 计算余弦相似度。 +""" + +import json +import logging +from typing import Optional + +from sqlalchemy import bindparam, text +from sqlalchemy.orm import Session + +from db import is_postgresql_backend +from db.models import Question, RagDocumentChunk, QuestionTagMapping, KnowledgeTag, UploadBatch + +logger = logging.getLogger(__name__) + + +def _serialize_vector(vector: Optional[list[float]]) -> Optional[str]: + if vector is None: + return None + return json.dumps(vector, ensure_ascii=False, separators=(",", ":")) + + +def _write_postgres_vector(db: Session, chunk_id: int, vector: Optional[list[float]]): + if not is_postgresql_backend(db.bind): + return + if vector is None: + db.execute( + text( + "UPDATE rag_document_chunks " + "SET embedding_vector = NULL " + "WHERE id = :chunk_id" + ), + {"chunk_id": chunk_id}, + ) + return + db.execute( + text( + "UPDATE 
rag_document_chunks " + "SET embedding_vector = CAST(:vector AS vector) " + "WHERE id = :chunk_id" + ), + { + "chunk_id": chunk_id, + "vector": _serialize_vector(vector), + }, + ) + + +# --------------------------------------------------------------------------- +# 文本构建 +# --------------------------------------------------------------------------- + +def _extract_text_from_blocks(content_json: str) -> str: + """从 content_json 提取纯文本""" + if not content_json: + return "" + try: + blocks = json.loads(content_json) if isinstance(content_json, str) else content_json + except (json.JSONDecodeError, TypeError): + return "" + parts = [] + for block in blocks: + if isinstance(block, dict): + text = block.get("content", "").strip() + if text: + parts.append(text) + elif isinstance(block, str): + parts.append(block.strip()) + return "\n".join(parts) + + +def _extract_options_text(options_json: str) -> str: + """从 options_json 提取选项文本""" + if not options_json: + return "" + try: + options = json.loads(options_json) if isinstance(options_json, str) else options_json + except (json.JSONDecodeError, TypeError): + return "" + if not options: + return "" + parts = [] + for opt in options: + if isinstance(opt, dict): + label = opt.get("label", "") + content = opt.get("content", "") + if label or content: + parts.append(f"{label}. {content}" if label else content) + elif isinstance(opt, str): + parts.append(opt) + return "\n".join(parts) + + +def build_question_chunks(question: Question) -> list[dict]: + """将 Question 转换为可索引的文本块列表 + + Args: + question: Question ORM 对象(需要已加载 batch 和 tags 关系) + + Returns: + [{"content": str, "metadata": dict, "content_hash": str}, ...] 
+ """ + content_text = _extract_text_from_blocks(question.content_json) + options_text = _extract_options_text(question.options_json) + answer_text = (question.answer or "").strip() + user_answer_text = (question.user_answer or "").strip() + + subject = question.batch.subject if question.batch else "" + question_type = question.question_type or "" + tags = [] + if question.tags: + for mapping in question.tags: + if mapping.tag: + tags.append(mapping.tag.tag_name) + + parts = [] + if subject: + parts.append(f"科目:{subject}") + if question_type: + parts.append(f"题型:{question_type}") + if tags: + parts.append(f"知识点:{'、'.join(tags)}") + if content_text: + parts.append(f"题干:{content_text}") + if options_text: + parts.append(f"选项:{options_text}") + if answer_text: + parts.append(f"答案:{answer_text}") + if user_answer_text: + parts.append(f"用户作答:{user_answer_text}") + + chunk_content = "\n".join(parts) + + if len(chunk_content) > 8000: + chunk_content = chunk_content[:8000] + + metadata = { + "subject": subject, + "question_type": question_type, + "tags": tags, + } + + return [{ + "content": chunk_content, + "metadata": metadata, + "content_hash": question.content_hash or "", + }] + + +# --------------------------------------------------------------------------- +# Embedding 生成 +# --------------------------------------------------------------------------- + +def _get_embedding_client(): + """获取 OpenAI 兼容的 embedding 客户端 + + 优先使用专门的 RAG Embedding 配置,若未配置则尝试复用 OpenAI provider。 + """ + from core.config import settings + + # 1. 优先尝试独立配置的 RAG Embedding 凭据 + api_key = settings.rag_embedding_api_key + base_url = settings.rag_embedding_base_url + + # 2. 
如果没有独立配置,尝试回退到默认的 OpenAI provider + if not api_key: + try: + provider = settings.get_provider("openai") + if provider.configured: + api_key = provider.api_key + # 仅在未显式指定时复用 provider 的 base_url + if not base_url: + base_url = provider.base_url + except (ValueError, KeyError): + pass + + if not api_key: + return None, None + + try: + from openai import OpenAI + import httpx + + kwargs = {"api_key": api_key, "timeout": 30} + if base_url: + # 兼容性处理:如果 base_url 指向 DeepSeek,则自动置空(DeepSeek 无 embedding 接口) + # 除非用户显式在 APP_RAG_EMBEDDING_BASE_URL 中指定了它 + if "deepseek.com" in base_url.lower() and not settings.rag_embedding_base_url: + logger.warning("检测到 OpenAI Base URL 指向 DeepSeek,已自动跳过其 Embedding 调用") + return None, None + kwargs["base_url"] = base_url + + if settings.trust_env: + kwargs["http_client"] = httpx.Client(trust_env=True) + + client = OpenAI(**kwargs) + return client, settings.rag_embedding_model + except Exception as e: + logger.warning("创建 embedding 客户端失败: %s", e) + return None, None + + +def embed_texts(texts: list[str], batch_size: int = 100) -> list[Optional[list[float]]]: + """批量生成 embedding 向量 + + Args: + texts: 文本列表 + batch_size: 每批处理数量(embedding API 通常限制 100) + + Returns: + 与 texts 等长的列表,每个元素是浮点数列表或 None(失败时) + """ + if not texts: + return [] + + client, model = _get_embedding_client() + if client is None: + logger.warning("embedding 客户端未配置,跳过向量生成") + return [None] * len(texts) + + results: list[Optional[list[float]]] = [None] * len(texts) + + for start in range(0, len(texts), batch_size): + batch = texts[start:start + batch_size] + try: + response = client.embeddings.create(model=model, input=batch) + for i, item in enumerate(response.data): + results[start + i] = item.embedding + except Exception as e: + logger.error("embedding API 调用失败 (batch %d-%d): %s", start, start + len(batch), e) + + return results + + +# --------------------------------------------------------------------------- +# 余弦相似度 +# 
--------------------------------------------------------------------------- + +def cosine_similarity(a: list[float], b: list[float]) -> float: + """计算两个向量的余弦相似度""" + dot = sum(x * y for x, y in zip(a, b)) + norm_a = sum(x * x for x in a) ** 0.5 + norm_b = sum(x * x for x in b) ** 0.5 + if norm_a == 0 or norm_b == 0: + return 0.0 + return dot / (norm_a * norm_b) + + +# --------------------------------------------------------------------------- +# 索引操作 +# --------------------------------------------------------------------------- + +def index_question(db: Session, question_id: int) -> bool: + """为单道错题建立或刷新 RAG 索引 + + Args: + db: 数据库会话 + question_id: 题目 ID + + Returns: + True 如果成功索引(或已是最新),False 如果失败 + """ + from sqlalchemy.orm import joinedload + + question = ( + db.query(Question) + .options( + joinedload(Question.batch), + joinedload(Question.tags).joinedload(QuestionTagMapping.tag), + ) + .filter(Question.id == question_id) + .first() + ) + if not question: + return False + + chunks = build_question_chunks(question) + if not chunks: + return False + + chunk_data = chunks[0] + + existing = ( + db.query(RagDocumentChunk) + .filter( + RagDocumentChunk.source_type == "question", + RagDocumentChunk.source_id == question_id, + RagDocumentChunk.chunk_index == 0, + ) + .first() + ) + + # 检查是否已有相同内容的索引 + has_vector = False + if existing: + # 检查 PostgreSQL 向量列是否真的有值 + if is_postgresql_backend(db.bind): + res = db.execute( + text("SELECT 1 FROM rag_document_chunks WHERE id = :id AND embedding_vector IS NOT NULL"), + {"id": existing.id} + ).first() + has_vector = bool(res) + else: + has_vector = bool(existing.vector_json) + + if existing and has_vector and existing.content_hash == chunk_data["content_hash"] and existing.content == chunk_data["content"]: + return True + + vectors = embed_texts([chunk_data["content"]]) + vector = vectors[0] if vectors else None + + from core.config import settings + + vector_json = _serialize_vector(vector) + + if existing: + existing.content = 
chunk_data["content"] + existing.metadata_json = json.dumps(chunk_data["metadata"], ensure_ascii=False) + existing.content_hash = chunk_data["content_hash"] + existing.embedding_model = settings.rag_embedding_model if vector else None + existing.vector_json = vector_json + else: + chunk = RagDocumentChunk( + user_id=question.user_id, + project_id=question.project_id, + source_type="question", + source_id=question_id, + chunk_index=0, + content=chunk_data["content"], + metadata_json=json.dumps(chunk_data["metadata"], ensure_ascii=False), + content_hash=chunk_data["content_hash"], + embedding_model=settings.rag_embedding_model if vector else None, + vector_json=vector_json, + ) + db.add(chunk) + + try: + db.flush() + target_chunk = existing or chunk + _write_postgres_vector(db, target_chunk.id, vector) + db.commit() + return True + except Exception as e: + db.rollback() + logger.error("索引题目 %d 失败: %s", question_id, e) + return False + + +def delete_question_chunks(db: Session, question_id: int) -> int: + """删除错题关联的所有 RAG chunk""" + count = ( + db.query(RagDocumentChunk) + .filter( + RagDocumentChunk.source_type == "question", + RagDocumentChunk.source_id == question_id, + ) + .delete(synchronize_session=False) + ) + try: + db.commit() + except Exception as e: + db.rollback() + logger.error("删除题目 %d 的 chunk 失败: %s", question_id, e) + return 0 + return count + + +# --------------------------------------------------------------------------- +# 语义检索 +# --------------------------------------------------------------------------- + +def retrieve_context( + db: Session, + query: str, + user_id: Optional[int], + project_id: Optional[int] = None, + subject: Optional[str] = None, + question_type: Optional[str] = None, + knowledge_tag: Optional[str] = None, + top_k: int = 6, +) -> list[dict]: + """语义检索错题上下文""" + vectors = embed_texts([query]) + query_vector = vectors[0] if vectors else None + + if query_vector is None: + logger.warning("查询 embedding 生成失败,无法执行语义检索") + return [] + 
+ allowed_ids = None + if subject or question_type or knowledge_tag: + question_query = db.query(Question.id).filter(Question.user_id == user_id) + if user_id is None: + question_query = db.query(Question.id) + if project_id is not None: + question_query = question_query.filter(Question.project_id == project_id) + if subject: + question_query = question_query.join(UploadBatch).filter(UploadBatch.subject == subject) + if question_type: + question_query = question_query.filter(Question.question_type == question_type) + if knowledge_tag: + from db.crud.tags import _parse_tag_list + tag_list = _parse_tag_list(knowledge_tag) + if tag_list: + question_query = ( + question_query + .join(QuestionTagMapping) + .join(KnowledgeTag) + .filter(KnowledgeTag.tag_name.in_(tag_list)) + ) + allowed_ids = [row[0] for row in question_query.distinct().all()] + if not allowed_ids: + return [] + + if is_postgresql_backend(db.bind): + return _retrieve_context_postgresql( + db, + query_vector=query_vector, + user_id=user_id, + project_id=project_id, + allowed_ids=allowed_ids, + top_k=top_k, + ) + + q = ( + db.query(RagDocumentChunk) + .filter( + RagDocumentChunk.source_type == "question", + RagDocumentChunk.vector_json.isnot(None), + ) + ) + if user_id is not None: + q = q.filter(RagDocumentChunk.user_id == user_id) + if project_id is not None: + q = q.filter(RagDocumentChunk.project_id == project_id) + + chunks = q.all() + if allowed_ids is not None: + allowed_set = set(allowed_ids) + chunks = [c for c in chunks if c.source_id in allowed_set] + if not chunks: + return [] + + scored = [] + for chunk in chunks: + try: + chunk_vector = json.loads(chunk.vector_json) + score = cosine_similarity(query_vector, chunk_vector) + except (json.JSONDecodeError, TypeError, ValueError): + continue + scored.append((chunk, score)) + + scored.sort(key=lambda x: x[1], reverse=True) + top_results = scored[:top_k] + + results = [] + for chunk, score in top_results: + metadata = {} + if chunk.metadata_json: + 
try: + metadata = json.loads(chunk.metadata_json) + except (json.JSONDecodeError, TypeError): + pass + results.append({ + "chunk_id": chunk.id, + "source_id": chunk.source_id, + "content": chunk.content, + "metadata": metadata, + "score": round(score, 4), + }) + + return results + + +def _retrieve_context_postgresql( + db: Session, + *, + query_vector: list[float], + user_id: Optional[int], + project_id: Optional[int], + allowed_ids: Optional[list[int]], + top_k: int, +) -> list[dict]: + where_clauses = [ + "source_type = 'question'", + "embedding_vector IS NOT NULL", + ] + params = { + "query_vector": _serialize_vector(query_vector), + "limit": top_k, + } + + if user_id is not None: + where_clauses.append("user_id = :user_id") + params["user_id"] = user_id + if project_id is not None: + where_clauses.append("project_id = :project_id") + params["project_id"] = project_id + if allowed_ids is not None: + where_clauses.append("source_id IN :allowed_ids") + params["allowed_ids"] = allowed_ids + + sql = text( + "SELECT id, source_id, content, metadata_json, " + "1 - (embedding_vector <=> CAST(:query_vector AS vector)) AS score " + "FROM rag_document_chunks " + f"WHERE {' AND '.join(where_clauses)} " + "ORDER BY embedding_vector <=> CAST(:query_vector AS vector) " + "LIMIT :limit" + ) + if allowed_ids is not None: + sql = sql.bindparams(bindparam("allowed_ids", expanding=True)) + + rows = db.execute(sql, params).mappings().all() + results = [] + for row in rows: + metadata = {} + if row["metadata_json"]: + try: + metadata = json.loads(row["metadata_json"]) + except (json.JSONDecodeError, TypeError): + metadata = {} + results.append( + { + "chunk_id": row["id"], + "source_id": row["source_id"], + "content": row["content"], + "metadata": metadata, + "score": round(float(row["score"] or 0.0), 4), + } + ) + return results diff --git a/backend/db/__init__.py b/backend/db/__init__.py index 10e116a2..3397a280 100644 --- a/backend/db/__init__.py +++ b/backend/db/__init__.py @@ 
-17,21 +17,31 @@ db_dir.mkdir(parents=True, exist_ok=True) # 创建引擎 -engine = create_engine(f"sqlite:///{settings.db_path}", echo=False) +engine = create_engine(settings.database_url, echo=settings.database_echo) -# 启用 SQLite 外键约束 -@event.listens_for(engine, "connect") -def set_sqlite_pragma(dbapi_connection, connection_record): - cursor = dbapi_connection.cursor() - cursor.execute("PRAGMA foreign_keys=ON") - cursor.close() +# 启用 SQLite 外键约束 (仅针对 SQLite) +if settings.database_url.startswith("sqlite"): + @event.listens_for(engine, "connect") + def set_sqlite_pragma(dbapi_connection, connection_record): + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() # Session 工厂 SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) +def is_postgresql_backend(bind) -> bool: + dialect = getattr(bind, "dialect", None) + if dialect is not None: + return getattr(dialect, "name", "") == "postgresql" + return False + + def _migrate_schema(): """轻量级自动迁移:为已有表补充新列""" + if not settings.database_url.startswith("sqlite"): + return import sqlite3 import uuid conn = sqlite3.connect(str(settings.db_path)) diff --git a/backend/db/crud/questions.py b/backend/db/crud/questions.py index 2d67c6b9..0fbbbda7 100644 --- a/backend/db/crud/questions.py +++ b/backend/db/crud/questions.py @@ -422,7 +422,15 @@ def save_questions_to_db( question.tags = db.query(QuestionTagMapping).filter( QuestionTagMapping.question_id == question.id ).all() + + # 同时更新本地 Hash 向量和 RAG 语义索引 ensure_question_embedding(db, question) + try: + from core.rag import index_question + index_question(db, question.id) + except Exception as e: + logger.warning("Failed to auto-index question %d: %s", question.id, e) + created += 1 db.commit() @@ -585,7 +593,8 @@ def query_questions( page_size: int = 20, user_id=None, project_id=None, -) -> Tuple[List[Question], int]: + include_grand_total: bool = False, +) -> Tuple[List[Question], int, Optional[int]]: """ 统一查询题目(合并 
get_history_questions 和 search_questions 的能力) @@ -599,7 +608,9 @@ def query_questions( query = query.filter(Question.project_id == project_id) # 未筛选的总收录数(仅按用户隔离) - grand_total = query.distinct().count() + grand_total = 0 + if include_grand_total: + grand_total = query.distinct().count() subject_list = _parse_filter_list(subject) if subject_list: @@ -642,7 +653,9 @@ def query_questions( .all() ) - return questions, total, grand_total + if include_grand_total: + return questions, total, grand_total + return questions, total def get_questions_by_ids(db: Session, question_ids: List[int], user_id=None) -> List[Question]: diff --git a/backend/db/models.py b/backend/db/models.py index 23c5566b..3b3b9d79 100644 --- a/backend/db/models.py +++ b/backend/db/models.py @@ -313,3 +313,22 @@ class EmailVerification(Base): last_sent_at = Column(DateTime, nullable=True) attempts = Column(Integer, default=0) created_at = Column(DateTime, default=datetime.utcnow) + + +class RagDocumentChunk(Base): + """RAG 文档分块(用于错题库语义检索)""" + __tablename__ = "rag_document_chunks" + + id = Column(Integer, primary_key=True) + user_id = Column(Integer, ForeignKey("users.id"), nullable=True, index=True) + project_id = Column(Integer, ForeignKey("projects.id"), nullable=True, index=True) + source_type = Column(String(20), nullable=False, index=True) + source_id = Column(Integer, nullable=False, index=True) + chunk_index = Column(Integer, default=0, nullable=False) + content = Column(Text, nullable=False) + metadata_json = Column(Text, default="") + content_hash = Column(String(64), default="", index=True) + embedding_model = Column(String(100), nullable=True) + vector_json = Column(Text, nullable=True) + created_at = Column(DateTime, default=datetime.utcnow, index=True) + updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) diff --git a/backend/models/weight/best.pth b/backend/models/weight/best.pth deleted file mode 100644 index 61723552..00000000 --- 
a/backend/models/weight/best.pth +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:5ce399f0447fdc016bbe3957559264735a5b9e5de30725ca2ed4d66699fd3f47 -size 362841171 diff --git a/backend/routes/questions.py b/backend/routes/questions.py index ee4ad4fa..f89d7843 100644 --- a/backend/routes/questions.py +++ b/backend/routes/questions.py @@ -329,6 +329,7 @@ def get_error_bank(): page=page, page_size=page_size, project_id=project_id, + include_grand_total=True, ) total_pages = (total + page_size - 1) // page_size @@ -353,20 +354,64 @@ def get_error_bank(): @bp.route('/error-bank/find', methods=['GET']) def find_error_bank_questions(): - """Use a natural language description to find likely questions.""" + """Use a natural language description to find likely questions. + + 优先使用真实 embedding(RAG),降级到 hash 向量。 + Query参数: + q / query: 自然语言描述 + limit: 返回数量(默认8,最大20) + project_id: 项目ID + """ try: query_text = (request.args.get('q') or request.args.get('query') or '').strip() if not query_text: return jsonify({'success': False, 'error': '请输入要查找的题目描述'}), 400 limit = min(20, max(1, request.args.get('limit', 8, type=int))) project_id = _project_id_arg() - + user_id = _effective_user_id() + + search_mode = "hash" # 默认降级模式 + + # 尝试真实 embedding 语义检索 + try: + from core.rag import retrieve_context + with SessionLocal() as db: + rag_results = retrieve_context( + db, + query=query_text, + user_id=user_id, + project_id=project_id, + top_k=limit, + ) + if rag_results: + source_ids = [r["source_id"] for r in rag_results] + questions = crud.get_questions_by_ids(db, source_ids, user_id=user_id) + q_map = {q.id: q for q in questions} + items = [] + for r in rag_results: + q = q_map.get(r["source_id"]) + if q: + payload = _serialize_question_detail(q) + payload['match_score'] = round(r["score"] * 100) + payload['match_reasons'] = ["向量语义相似"] + items.append(payload) + if items: + return jsonify({ + 'success': True, + 'items': items, + 'total': len(items), + 
'search_mode': 'pgvector' if db.bind and db.bind.dialect.name.startswith('postgresql') else 'embedding', + }) + except Exception as e: + logger.debug("RAG 语义检索不可用,降级到 hash: %s", e) + + # 降级到 hash 向量检索 with SessionLocal() as db: matches = crud.find_questions_by_natural_language( db, query_text=query_text, limit=limit, - user_id=_effective_user_id(), + user_id=user_id, project_id=project_id, ) items = [] @@ -380,12 +425,58 @@ def find_error_bank_questions(): 'success': True, 'items': items, 'total': len(items), + 'search_mode': search_mode, }) except Exception: logger.exception("AI find questions failed") return jsonify({'success': False, 'error': '找题失败,请稍后重试'}), 500 +@bp.route('/rag/reindex', methods=['POST']) +def rag_reindex(): + """重建当前用户的 RAG 索引""" + try: + from core.rag import index_question + + user_id = session.get('user_id') + project_id = _project_id_body(request.get_json(silent=True) or {}) + + with SessionLocal() as db: + query = db.query(Question.id).filter(Question.user_id == user_id) + if project_id is not None: + query = query.filter(Question.project_id == project_id) + question_ids = [row[0] for row in query.all()] + + indexed = 0 + skipped = 0 + errors = 0 + + for qid in question_ids: + try: + with SessionLocal() as db: + success = index_question(db, qid) + if success: + indexed += 1 + else: + skipped += 1 + except Exception as e: + logger.warning("索引题目 %d 失败: %s", qid, e) + errors += 1 + + return jsonify({ + 'success': True, + 'message': f'索引完成:成功 {indexed},跳过 {skipped},失败 {errors}', + 'indexed': indexed, + 'skipped': skipped, + 'errors': errors, + 'total': len(question_ids), + }) + + except Exception as e: + logger.exception("重建 RAG 索引失败") + return jsonify({'success': False, 'error': '重建索引失败,请稍后重试'}), 500 + + @bp.route('/subjects', methods=['GET']) def get_subjects(): """获取所有科目列表""" @@ -432,10 +523,19 @@ def update_question(question_id): question.answer = str(data['answer'])[:10000] if data['answer'] else None question.updated_at = 
datetime.utcnow() db.commit() - return jsonify({'success': True}) except Exception: db.rollback() raise + + # 内容更新后重新建立 RAG 索引 + try: + from core.rag import index_question + with SessionLocal() as db: + index_question(db, question_id) + except Exception as e: + logger.warning(f"更新题目 {question_id} 的 RAG 索引失败: {e}") + + return jsonify({'success': True}) except Exception: logger.exception("编辑题目失败") return jsonify({'success': False, 'error': '保存失败,请稍后重试'}), 500 @@ -457,12 +557,19 @@ def update_question_answer(question_id): if not question: return jsonify({'success': False, 'error': '题目不存在'}), 404 - return jsonify({ - 'success': True, - 'message': '答案已保存', - 'user_answer': question.user_answer, - 'updated_at': question.updated_at.isoformat() if question.updated_at else None, - }) + try: + from core.rag import index_question + with SessionLocal() as db: + index_question(db, question_id) + except Exception as e: + logger.warning(f"更新题目 {question_id} 的用户作答索引失败: {e}") + + return jsonify({ + 'success': True, + 'message': '答案已保存', + 'user_answer': question.user_answer, + 'updated_at': question.updated_at.isoformat() if question.updated_at else None, + }) except Exception as e: logger.exception("保存答案失败") @@ -553,9 +660,14 @@ def save_to_db(): with SessionLocal() as db: try: project_id = ( - crud.require_project_id(db, project_id, user_id=session.get('user_id'), project_type="question") + crud.require_project_id( + db, + project_id, + user_id=session.get('user_id'), + project_type="question", + ) if project_id - else crud.resolve_project_id(db, project_id, user_id=session.get('user_id'), project_type="question") + else None ) except ValueError as exc: if str(exc) == "PROJECT_REQUIRED": @@ -744,11 +856,19 @@ def save_question_answer(question_id): if not question: return jsonify({'success': False, 'error': '题目不存在'}), 404 - return jsonify({ - 'success': True, - 'message': '答案已保存', - 'answer': question.answer, - }) + # 答案更新后重新建立 RAG 索引 + try: + from core.rag import index_question + 
with SessionLocal() as db: + index_question(db, question_id) + except Exception as e: + logger.warning(f"更新题目 {question_id} 的 RAG 索引失败: {e}") + + return jsonify({ + 'success': True, + 'message': '答案已保存', + 'answer': question.answer, + }) except Exception as e: logger.exception("保存答案失败") diff --git a/backend/src/paddleocr_client.py b/backend/src/paddleocr_client.py index ce53b012..a51dc822 100644 --- a/backend/src/paddleocr_client.py +++ b/backend/src/paddleocr_client.py @@ -12,6 +12,7 @@ import requests import aiohttp from rich.console import Console +from requests.utils import get_environ_proxies, should_bypass_proxies console = Console() @@ -60,7 +61,14 @@ def _request_kwargs(self) -> Dict[str, Any]: from core.config import settings kwargs = {"headers": self._headers} if not settings.trust_env: - kwargs["proxies"] = {"http": None, "https": None} + kwargs["proxies"] = {} + return kwargs + try: + proxies = get_environ_proxies(self.api_url) or {} + if proxies and not should_bypass_proxies(self.api_url, proxies): + kwargs["proxies"] = proxies + except Exception: + pass return kwargs @property @@ -68,7 +76,8 @@ def _public_download_kwargs(self) -> Dict[str, Any]: from core.config import settings kwargs: Dict[str, Any] = {} if not settings.trust_env: - kwargs["proxies"] = {"http": None, "https": None} + kwargs["proxies"] = {} + return kwargs return kwargs @staticmethod diff --git a/backend/web_app.py b/backend/web_app.py index f9cb4a39..f788b5d4 100644 --- a/backend/web_app.py +++ b/backend/web_app.py @@ -14,6 +14,23 @@ import os import sys import logging +from dotenv import load_dotenv + +# ============================================================ +# Monkeypatch: 修复 langgraph 依赖冲突 +# ============================================================ +# 部分版本的 langchain.agents 依赖 langgraph.runtime.ExecutionInfo / ServerInfo, +# 若当前环境的 langgraph 版本缺失这些类,手动补齐以防止导入阶段崩溃。 +try: + import langgraph.runtime + if not hasattr(langgraph.runtime, 'ExecutionInfo'): + class 
ExecutionInfo: pass + langgraph.runtime.ExecutionInfo = ExecutionInfo + if not hasattr(langgraph.runtime, 'ServerInfo'): + class ServerInfo: pass + langgraph.runtime.ServerInfo = ServerInfo +except ImportError: + pass # 无论从项目根目录执行 `python backend/web_app.py` 还是在 `backend` 下执行 `python web_app.py`, # 都把 backend 目录加入 sys.path,保证 `core`、`routes`、`db` 等包解析一致。 @@ -21,9 +38,11 @@ if _BACKEND_ROOT not in sys.path: sys.path.insert(0, _BACKEND_ROOT) +# 加载 backend/.env(无论从哪个目录启动都指向同一文件) +load_dotenv(os.path.join(_BACKEND_ROOT, ".env")) + from flask import Flask, request, jsonify, send_file, session from flask_cors import CORS -from dotenv import load_dotenv from core.config import settings from core import workflow_run_store as run_store @@ -31,9 +50,6 @@ from db import crud from routes import register_routes -# 加载 backend/.env(无论从哪个目录启动都指向同一文件) -load_dotenv(os.path.join(_BACKEND_ROOT, ".env")) - # 模块级日志记录器,日志名称为 'web_app' logger = logging.getLogger(__name__) diff --git a/example_uploads/notes/test.jpg b/example_uploads/notes/test.jpg deleted file mode 100644 index 367ba5e8..00000000 Binary files a/example_uploads/notes/test.jpg and /dev/null differ diff --git a/frontend/.vscode/extensions.json b/frontend/.vscode/extensions.json deleted file mode 100644 index a7cea0b0..00000000 --- a/frontend/.vscode/extensions.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "recommendations": ["Vue.volar"] -} diff --git a/rag-development-plan.md b/rag-development-plan.md new file mode 100644 index 00000000..2acb3322 --- /dev/null +++ b/rag-development-plan.md @@ -0,0 +1,309 @@ +# 错题库与笔记 RAG 开发方案 + +## 目标 + +为现有错题库检索和 AI 对话能力加入真正的 RAG 能力,使系统可以根据用户问题自动召回相关错题、笔记和知识点上下文,并把可追溯的引用内容提供给大模型生成回答。 + +当前项目已有两个可复用基础: + +- 错题库已有轻量向量雏形:`backend/db/crud/questions.py` 中的 `local-hash-v1` 和 `QuestionEmbedding`。 +- AI 对话已有手动引用上下文能力:`backend/routes/chat.py` 中的 `_build_project_context()` 可将用户选中的错题拼入 prompt。 + +本方案的核心改造是:将“手动引用 + 本地 hash 检索”升级为“真实 embedding + 自动召回 + 可评估的上下文注入”。 + +## 建设范围 + +### 1. 
错题库检索模块 + +将现有 `/api/error-bank/find` 从本地 hash 相似度升级为语义检索。 + +需要索引的错题字段: + +- 题干文本 +- 选项 +- 答案与解析 +- 学科 +- 题型 +- 知识点标签 +- 用户作答或错因记录,若后续有该字段 + +建议第一版继续使用 SQLite 存储向量,避免一开始引入复杂向量库。数据量超过单用户 5 万 chunk 或检索 P95 延迟超过验收阈值后,再切换 FAISS、Chroma 或 Qdrant。 + +### 2. 笔记模块 + +给笔记建立可检索索引,使 AI 回答可以结合用户自己的学习笔记。 + +需要索引的笔记字段: + +- 笔记标题 +- `content_markdown` +- `ocr_text` +- 学科 +- 知识点标签 +- 所属项目 + +笔记建议按段落切块,每块 300-800 中文字符,保留标题和标签作为 metadata。 + +## 推荐数据模型 + +优先新增统一索引表,而不是分别维护 `QuestionEmbedding` 和 `NoteEmbedding` 两套逻辑。 + +建议表名:`rag_document_chunks` + +核心字段: + +- `id` +- `user_id` +- `project_id` +- `source_type`: `question` 或 `note` +- `source_id`: 原始错题或笔记 ID +- `chunk_index` +- `content`: 用于 embedding 和注入 prompt 的文本 +- `metadata_json`: 学科、题型、知识点、标题等 +- `content_hash`: 判断是否需要重建索引 +- `embedding_model` +- `vector_json`: 第一版可直接存 JSON 数组 +- `created_at` +- `updated_at` + +保留现有 `QuestionEmbedding` 可作为过渡,但新检索服务应优先读取统一 chunk 表。 + +## 后端模块设计 + +建议新增: + +`backend/core/rag.py` + +职责: + +- `build_question_chunks(question)`: 将错题转换为可索引文本块 +- `build_note_chunks(note)`: 将笔记切块 +- `embed_texts(texts, provider=None)`: 批量生成 embedding +- `index_question(db, question_id)`: 建立或刷新错题索引 +- `index_note(db, note_id)`: 建立或刷新笔记索引 +- `retrieve_context(db, query, user_id, project_id=None, source_types=None, top_k=6)`: 统一召回 +- `format_rag_context(results)`: 转成 AI 对话可注入的引用上下文 + +建议新增或调整接口: + +- `POST /api/rag/reindex`: 管理员或开发调试用,重建当前用户/项目索引 +- `GET /api/error-bank/find`: 复用 RAG 检索能力,替换 hash 检索 +- `POST /api/chat/<session_id>/stream`: 在用户未手动引用资料时,自动调用 `retrieve_context` + +## 索引触发点 + +错题索引: + +- `save_questions_to_db()` 成功保存题目后触发 +- 题目编辑、答案更新、知识点更新后触发 +- 删除题目时删除对应 chunk + +笔记索引: + +- 新建或更新笔记后触发 +- OCR 文本更新后触发 +- 删除笔记时删除对应 chunk + +第一版可以同步生成索引;如果保存接口变慢,再改为后台任务。 + +## AI 对话接入方式 + +当前流程: + +用户手动选择错题 -> 前端发送 `context_refs` -> 后端按 ID 查内容 -> 注入 prompt + +目标流程: + +1. 用户发送问题。 +2. 后端检查是否有手动 `context_refs`。 +3. 如果没有手动引用,调用 `retrieve_context()` 自动召回错题和笔记。 +4. 将召回结果经 `format_rag_context()` 格式化后放入 prompt 的引用上下文区块。 +5. 
大模型回答时要求优先基于引用内容,并在必要时说明引用来源。 + +建议策略: + +- 手动引用优先级最高。 +- 自动 RAG 只补充上下文,不覆盖手动引用。 +- 召回结果必须按用户和项目权限过滤。 +- 每条引用保留 `source_type/source_id/title/score`,便于前端展示来源。 + +## 验收指标 + +### 检索效果指标 + +需要准备一套离线评测集,建议不少于 200 条查询。 + +评测集来源: + +- 用户自然语言描述:“我想找那道关于二次函数顶点式的题” +- 题干改写 +- 知识点查询 +- 错因描述 +- 笔记主题查询 + +每条查询标注 1-3 条正确目标错题或笔记。 + +验收线: + +- `Recall@5 >= 75%`: 前 5 条召回中包含至少一条标注正确结果。 +- `Recall@10 >= 85%` +- `MRR@10 >= 0.55`: 正确结果越靠前越好。 +- `Context Precision@5 >= 55%`: 前 5 条中相关内容比例不低于 55%。 +- 无权限数据泄漏率 `0%`。 + +### 对话回答指标 + +抽样不少于 100 条 AI 对话问题,其中 50 条需要依赖错题库或笔记上下文。 + +验收线: + +- RAG 命中率 `>= 80%`: 需要上下文的问题中,系统成功召回至少一条相关资料。 +- 引用有效率 `>= 85%`: AI 实际使用的引用内容与问题相关。 +- 回答可溯源率 `>= 90%`: 需要引用时,回答能对应到具体错题或笔记来源。 +- 幻觉引用率 `<= 3%`: 不允许编造不存在的题目、笔记或知识点。 +- 人工评分平均分 `>= 4/5`: 从相关性、准确性、解释清晰度三项打分。 + +### 性能指标 + +第一版 SQLite 向量检索验收线: + +- 单用户 1 万 chunk 内,检索 P95 延迟 `<= 800ms`。 +- 单用户 5 万 chunk 内,检索 P95 延迟 `<= 2000ms`。 +- AI 对话因 RAG 增加的额外 P95 延迟 `<= 1500ms`。 +- 索引生成成功率 `>= 99%`。 +- 题目或笔记更新后,索引可用时间 `<= 5s`。 + +如果超过以上延迟,应评估迁移到专用向量库。 + +### 数据一致性指标 + +- 新增题目后,对应 chunk 覆盖率 `100%`。 +- 更新题目或笔记后,旧 content hash 不应继续作为有效索引。 +- 删除题目或笔记后,关联 chunk 删除率 `100%`。 +- 不同用户、不同项目之间检索隔离测试全部通过。 + +## 分阶段实施 + +### 第一期:错题库 RAG + +目标: + +- 新增统一 chunk 表。 +- 实现真实 embedding 生成和 SQLite top-k 检索。 +- 将 `/api/error-bank/find` 接入 RAG 检索。 + +验收: + +- `Recall@5 >= 75%` +- `Recall@10 >= 85%` +- 错题新增、更新、删除时索引同步正确 + +### 第二期:AI 对话自动召回 + +目标: + +- `stream_chat()` 自动根据用户消息召回错题上下文。 +- 保留手动引用优先级。 +- 回答中可展示引用来源。 + +验收: + +- RAG 命中率 `>= 80%` +- 回答可溯源率 `>= 90%` +- 幻觉引用率 `<= 3%` + +### 第三期:笔记 RAG + +目标: + +- 笔记切块与索引。 +- 对话召回同时支持错题和笔记。 +- 错题解析时可结合笔记内容。 + +验收: + +- 笔记查询 `Recall@5 >= 70%` +- 混合查询中错题和笔记均可被召回 +- 笔记更新后索引 5 秒内刷新 + +### 第四期:混合检索与排序优化 + +目标: + +- 结合关键词、知识点过滤、embedding 相似度和 recency。 +- 支持按项目、学科、题型过滤。 +- 需要时加入 reranker。 + +验收: + +- `MRR@10 >= 0.65` +- `Context Precision@5 >= 65%` +- 检索 P95 延迟仍满足性能指标 + +## 风险与约束 + +- embedding provider 未配置时,需要降级到现有 `local-hash-v1`,但前端和日志应标识为“降级检索”。 +- 向量 JSON 存 SQLite 适合第一版,数据量增大后性能会下降。 +- RAG 召回内容可能增加 token 
成本,需要限制 `top_k` 和总字符数。 +- 用户数据隔离必须优先于召回效果,所有检索必须带 `user_id/project_id` 过滤。 + +## 建议默认参数 + +- `chunk_size`: 500 +- `chunk_overlap`: 50 +- `top_k`: 6 +- `embedding_model`: `text-embedding-3-small` (OpenAI) 或 `bge-small-zh-v1.5` (Local) +- 单条 chunk 长度:300-800 中文字符 +- 注入 prompt 的总上下文上限:8000-12000 字符 +- 相似度阈值:先用 `0.25` 起步,按评测集调参 +- 混合排序初始权重:embedding `0.70`,关键词 `0.20`,知识点/学科过滤 `0.10` + +## 向量数据库扩展:pgvector + +> **发布时间**:2026-04-16 +> **分类**:pgsql + +由于插件 `pgvector` 的存在,PostgreSQL 已经成为目前 RAG(检索增强生成)架构中最主流的选择之一。你不需要为了向量检索专门去购买 Pinecone 或 Milvus,用 PG 就能实现“一库多用”。 + +### 部署与启动 (Docker) + +```bash +docker pull pgvector/pgvector:pg16 + +docker run -d \ + --name pg-vector \ + -e POSTGRES_PASSWORD=root \ + -p 15432:5432 \ + pgvector/pgvector:pg16 +``` + +### 核心操作示例 (SQL) + +```sql +-- 1. 开启扩展 +CREATE EXTENSION IF NOT EXISTS vector; + +-- 2. 验证:尝试创建一个 3 维向量字段 +CREATE TABLE test_vector ( + id serial PRIMARY KEY, + embedding vector(3) +); + +-- 3. 插入一个向量数据 +INSERT INTO test_vector (embedding) VALUES ('[1, 2, 3]'), ('[4, 5, 6]'); + +-- 4. 
计算余弦相似度查询 +SELECT + id, embedding, 1 - (embedding <=> '[0, 2, 1]') AS similarity_score +FROM test_vector +ORDER BY similarity_score DESC; -- 相似度越高,排在越前面 +``` + +### 三大核心操作符 + +| 操作符 | 计算方式 | 适用场景 | +| :--- | :--- | :--- | +| `<=>` | 余弦距离 (Cosine) | RAG 最常用。只关注方向,不关注长度(适合文本语义匹配)。 | +| `<->` | L2 距离 (欧氏距离) | 适合图像检索或需要考虑数值绝对大小的场景。 | +| `<#>` | 内积 (Inner Product) | 适合推荐系统,或者 Embedding 已经过归一化的场景。 | + diff --git a/reindex_all.py b/reindex_all.py new file mode 100644 index 00000000..1df6fafc --- /dev/null +++ b/reindex_all.py @@ -0,0 +1,52 @@ + +import sys +import os +import logging +from dotenv import load_dotenv + +# 加载环境变量 +load_dotenv(os.path.join(os.path.dirname(__file__), 'backend', '.env')) + +# 将 backend 目录加入路径 +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), 'backend'))) + +from db import SessionLocal +from db.models import Question +from core.rag import index_question + +logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s') +logger = logging.getLogger(__name__) + +def reindex_all(): + """为数据库中所有现有的题目重建 RAG 索引""" + logger.info("开始为所有现有题目重建 RAG 索引...") + + with SessionLocal() as db: + # 获取所有题目 + questions = db.query(Question).all() + total = len(questions) + logger.info(f"找到 {total} 道题目需要处理") + + success_count = 0 + fail_count = 0 + + for i, q in enumerate(questions): + try: + # 传入 db session 处理单道题目 + # index_question 内部会处理向量生成和数据库写入 + if index_question(db, q.id): + success_count += 1 + else: + fail_count += 1 + + if (i + 1) % 5 == 0 or (i + 1) == total: + logger.info(f"进度: {i + 1}/{total} (成功: {success_count}, 失败: {fail_count})") + + except Exception as e: + fail_count += 1 + logger.error(f"索引题目 ID={q.id} 时发生错误: {e}") + + logger.info(f"索引重建完成!成功: {success_count}, 失败: {fail_count}") + +if __name__ == "__main__": + reindex_all() diff --git a/requirements.txt b/requirements.txt index 7a5cfbfe..ce1d9bb8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,8 @@ deepagents==0.3.5 langchain==1.2.3 
langchain-core==1.2.7 langgraph==1.0.5 +# 注:若遇到 'ExecutionInfo' 导入错误,是因为 langgraph-prebuilt 与 langgraph 版本不匹配。 +# 已在 backend/web_app.py 与 backend/agents/error_correction/agent.py 中通过 Monkeypatch 补齐缺失的类作为临时兜底,或可尝试升级到 langgraph>=1.2.0 langgraph-cli[inmem]==0.4.12 # --- 模型提供商 / 适配器 --- @@ -39,6 +41,8 @@ pydantic-settings>=2.0.0 # --- 数据库 --- sqlalchemy>=2.0 +psycopg[binary]>=3.2.0 +pgvector>=0.3.6 # --- Web 框架 --- flask==3.1.2