Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 68 additions & 35 deletions backend/agents/note/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@
支持:
- 重试机制(指数退避,最多 3 次)
- 多页分批处理(超过阈值时分批整理再合并)
- 多模式结构化输出(function_calling → prompt JSON fallback)
"""

import json
import logging
import re
import time

from langchain_core.messages import SystemMessage, HumanMessage
Expand All @@ -22,49 +25,79 @@
BATCH_CHAR_LIMIT = 6000
MAX_RETRIES = 3

def _messages(ocr_text):
    """Build the base chat messages for a note-organizing request.

    Args:
        ocr_text: Raw OCR text of the class notes; interpolated verbatim
            into the human message.

    Returns:
        A two-element list: the system message carrying
        ``NOTE_ORGANIZE_PROMPT`` followed by the human message that wraps
        ``ocr_text``.
    """
    # A named function instead of a lambda bound to a name (PEP 8 E731):
    # tracebacks show ``_messages`` and the helper can carry a docstring.
    return [
        SystemMessage(content=NOTE_ORGANIZE_PROMPT),
        HumanMessage(
            content=f"以下是 OCR 识别出的课堂笔记原始文本,请整理为结构化笔记:\n\n{ocr_text}"
        ),
    ]


def _extract_json_from_text(text: str) -> dict | None:
"""从 LLM 输出文本中提取 JSON 对象(兼容 ```json ... ``` 包裹)"""
# 尝试 ```json ... ``` 代码块
m = re.search(r"```json\s*\n?(.*?)\n?\s*```", text, re.DOTALL)
if m:
try:
return json.loads(m.group(1))
except json.JSONDecodeError:
pass
# 尝试裸 JSON 对象
m = re.search(r"\{[\s\S]*\}", text)
if m:
try:
return json.loads(m.group(0))
except json.JSONDecodeError:
pass
return None


def _invoke_once(
    model,
    ocr_text: str,
    provider: str = "openai",
    supports_function_calling: bool = True,
) -> OrganizedNote:
    """Run a single LLM call and return the structured note.

    Output modes are tried in priority order:

    1. ``model.with_structured_output(OrganizedNote)`` when
       ``supports_function_calling`` is true.
    2. Prompt-JSON: the system prompt is extended with the parser's format
       instructions and the reply is parsed, first with
       ``PydanticOutputParser`` and then by extracting a JSON object from
       the raw text.

    Args:
        model: Chat model exposing ``invoke`` and ``with_structured_output``.
        ocr_text: Raw OCR text of the notes.
        provider: Not referenced in this function body; kept so existing
            callers remain valid.
        supports_function_calling: When false, mode 1 is skipped entirely.

    Returns:
        The parsed ``OrganizedNote``.

    Raises:
        RuntimeError: If the prompt-JSON reply contains no usable JSON
            object. Errors raised by ``model.invoke`` or by
            ``OrganizedNote.model_validate`` are not caught here.
    """
    messages = _messages(ocr_text)

    # Mode 1: with_structured_output. Any failure falls through to mode 2
    # rather than aborting the call.
    if supports_function_calling:
        try:
            structured_model = model.with_structured_output(OrganizedNote)
            return structured_model.invoke(messages)
        except Exception as exc:
            logger.warning(
                "笔记整理: with_structured_output 失败 (%s),fallback 到 prompt JSON", exc
            )

    # Mode 2: prompt JSON - ask for JSON in the system prompt and parse the
    # reply manually. Imported locally, as in the original code.
    from langchain_core.output_parsers import PydanticOutputParser

    parser = PydanticOutputParser(pydantic_object=OrganizedNote)
    format_instructions = parser.get_format_instructions()

    _, human_message = messages
    prompt_messages = [
        SystemMessage(
            content=NOTE_ORGANIZE_PROMPT
            + f"\n\n你必须以 JSON 格式输出,且遵循以下结构:\n{format_instructions}"
        ),
        human_message,
    ]
    response = model.invoke(prompt_messages)

    # Strict parse first; it depends on the reply following the format
    # instructions exactly.
    try:
        return parser.parse(response.content)
    except Exception as exc:
        # Deliberate best-effort fallback, but record why the strict parse
        # failed instead of swallowing the error silently.
        logger.debug(
            "笔记整理: PydanticOutputParser 解析失败 (%s),尝试从文本提取 JSON", exc
        )

    # Lenient parse: pull a JSON object out of the text and validate it.
    data = _extract_json_from_text(response.content)
    if data:
        return OrganizedNote.model_validate(data)

    raise RuntimeError("LLM 未返回合法的 JSON 结构")


def _invoke_with_retry(
Expand Down
6 changes: 1 addition & 5 deletions backend/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,10 +316,6 @@ def build_provider_config(
) -> LLMProviderConfig:
key = self._normalize_provider(name)

# 处理可能包含多个模型的字符串(取第一个作为默认)
if model_name and "," in model_name:
model_name = [m.strip() for m in model_name.split(",") if m.strip()][0]

if key == "openai":
return OpenAICompatibleConfig(
api_key=api_key,
Expand Down Expand Up @@ -508,7 +504,7 @@ def load_providers_from_db(
if owns_db:
db = SessionLocal()
try:
for category in [("openai"), ("anthropic")]:
for category in ("openai", "anthropic"):
provider = get_active_provider(db, user_id, category)
if provider and provider.api_key:
cfg = self.build_provider_config(
Expand Down
4 changes: 3 additions & 1 deletion backend/core/model_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ class LLMSelectionError(Exception):


def split_models(model_name: str | None) -> list[str]:
return [item.strip() for item in (model_name or "").split(",") if item.strip()]
"""返回模型名称列表。现在只支持单个模型,但保留列表形式以兼容调用方。"""
name = (model_name or "").strip()
return [name] if name else []


def build_managed_provider_context(db):
Expand Down
25 changes: 16 additions & 9 deletions backend/db/crud/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@

from sqlalchemy.orm import Session

from db.models import Note, Project, Question, UploadBatch
from db.models import (
ChatSession, Note, NoteTagMapping, Project, Question,
QuestionEmbedding, QuestionTagMapping, UploadBatch,
)


VALID_PROJECT_TYPES = {"question", "note"}
Expand Down Expand Up @@ -135,14 +138,18 @@ def delete_project(db: Session, project_id: int, user_id=None) -> bool:
if project.is_default:
raise ValueError("DEFAULT_PROJECT_IMMUTABLE")

# 检查是否有题目或笔记
has_questions = db.query(Question.id).filter(Question.project_id == project.id).first()
has_notes = db.query(Note.id).filter(Note.project_id == project.id).first()

if has_questions or has_notes:
raise ValueError("PROJECT_NOT_EMPTY")

# 如果没有题目和笔记了,自动清理关联的空批次(UploadBatch)
# 先删除题目和笔记的关联子表,再删除题目/笔记本身
question_ids = [q.id for q in db.query(Question.id).filter(Question.project_id == project.id).all()]
if question_ids:
db.query(QuestionEmbedding).filter(QuestionEmbedding.question_id.in_(question_ids)).delete(synchronize_session=False)
db.query(ChatSession).filter(ChatSession.question_id.in_(question_ids)).delete(synchronize_session=False)
db.query(QuestionTagMapping).filter(QuestionTagMapping.question_id.in_(question_ids)).delete(synchronize_session=False)
db.query(Question).filter(Question.id.in_(question_ids)).delete(synchronize_session=False)

note_ids = [n.id for n in db.query(Note.id).filter(Note.project_id == project.id).all()]
if note_ids:
db.query(NoteTagMapping).filter(NoteTagMapping.note_id.in_(note_ids)).delete(synchronize_session=False)
db.query(Note).filter(Note.id.in_(note_ids)).delete(synchronize_session=False)
db.query(UploadBatch).filter(UploadBatch.project_id == project.id).delete()

db.delete(project)
Expand Down
2 changes: 0 additions & 2 deletions backend/routes/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,6 @@ def delete_project(project_id):
except ValueError as exc:
if str(exc) == "DEFAULT_PROJECT_IMMUTABLE":
return jsonify({"success": False, "error": "默认项目不能删除"}), 400
if str(exc) == "PROJECT_NOT_EMPTY":
return jsonify({"success": False, "error": "项目里还有内容,暂时不能删除"}), 400
raise
if not deleted:
return jsonify({"success": False, "error": "项目不存在"}), 404
Expand Down
56 changes: 37 additions & 19 deletions backend/routes/questions.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,10 @@ def save_to_db():

Request Body:
{
"selected_ids": ["q_0", "q_1", ...] # 选中的题目 ID 列表
"selected_ids": ["q_0", "q_1", ...], # 选中的题目 ID 列表
"run_id": "xxx", # 可选,从 WorkflowRun 读取题目
"record_id": 123, # 可选,从 SplitRecord 读取题目(历史记录导入)
"project_id": 456 # 目标错题库 ID
}
"""
try:
Expand All @@ -516,14 +519,39 @@ def save_to_db():
return jsonify({'success': False, 'error': '请选择至少一道题目'}), 400

run_id = data.get('run_id')
if not run_id:
return jsonify({'success': False, 'code': 'MISSING_RUN_ID', 'error': '缺少 run_id,请重新分割题目'}), 400
record_id = data.get('record_id')
user_id = session.get('user_id')
with SessionLocal() as db:
run = run_store.get_split_run(db, run_id, user_id=user_id)
if not run or run.status != run_store.STATUS_SUCCEEDED:
return jsonify({'success': False, 'error': '请先分割题目'}), 400
questions = run_store.read_questions(run)

# 优先从 SplitRecord 读取(历史记录导入场景)
if record_id:
with SessionLocal() as db:
record = crud.get_split_record_by_id(db, record_id, user_id=user_id)
if not record:
return jsonify({'success': False, 'error': '分割记录不存在'}), 404
questions = json.loads(record.questions_json) if record.questions_json else []
subject = record.subject or ''
file_names = json.loads(record.file_names_json) if record.file_names_json else []
batch_info = {
"original_filename": ", ".join(file_names) or "Unknown",
"subject": subject,
"file_path": "",
}
# 从 WorkflowRun 读取(正常分割后导入场景)
elif run_id:
with SessionLocal() as db:
run = run_store.get_split_run(db, run_id, user_id=user_id)
if not run or run.status != run_store.STATUS_SUCCEEDED:
return jsonify({'success': False, 'error': '请先分割题目'}), 400
questions = run_store.read_questions(run)
subject = run_store.read_subject(run)
file_names = json.loads(run.file_names_json) if run.file_names_json else []
batch_info = {
"original_filename": ", ".join(file_names) or "Unknown",
"subject": subject,
"file_path": run.result_dir,
}
else:
return jsonify({'success': False, 'code': 'MISSING_SOURCE', 'error': '缺少 run_id 或 record_id'}), 400

uid_set = set(selected_uids)
selected_questions = [q for q in questions if q.get('uid') in uid_set]
Expand All @@ -540,16 +568,6 @@ def save_to_db():
if 'user_answer' in answers_map[uid]:
sq['user_answer'] = answers_map[uid]['user_answer']

# 读取科目信息
subject = run_store.read_subject(run)

file_names = json.loads(run.file_names_json) if run.file_names_json else []
batch_info = {
"original_filename": ", ".join(file_names) or "Unknown",
"subject": subject,
"file_path": run.result_dir,
}

with SessionLocal() as db:
try:
project_id = (
Expand All @@ -572,7 +590,7 @@ def save_to_db():
return jsonify({
'success': True,
'message': f'已导入 {result["created"]} 道题目(跳过 {result["duplicates"]} 道重复)',
'run_id': run.public_id,
'run_id': run_id,
'created': result['created'],
'duplicates': result['duplicates'],
})
Expand Down
12 changes: 4 additions & 8 deletions backend/routes/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,7 @@ def get_status():
crud.get_active_provider(db, user_id, category) if user_id else None
)
if provider and provider.api_key:
models = (
[m.strip() for m in provider.model_name.split(",")]
if provider.model_name
else []
)
models = [provider.model_name] if provider.model_name else []
available_models.append(
{
"value": category,
Expand All @@ -128,7 +124,7 @@ def get_status():
else:
managed_cfg = managed_llm.get(category)
managed_models = (
[m.strip() for m in managed_cfg.model_name.split(",")]
[managed_cfg.model_name]
if managed_cfg
and managed_cfg.configured
and managed_cfg.model_name
Expand Down Expand Up @@ -302,7 +298,7 @@ def list_models():
else:
provider = crud.get_active_provider(db, user_id, provider_type)
if provider:
api_key = api_key or provider.api_key or ""
api_key = provider.api_key or ""
base_url = base_url or provider.base_url or ""
if not api_key and provider_type in ("openai", "anthropic"):
system_provider = crud.get_active_system_provider(db, provider_type)
Expand Down Expand Up @@ -413,7 +409,7 @@ def test_paddleocr():
else:
provider = crud.get_active_provider(db, user_id, "paddleocr")
if provider:
api_token = api_token or provider.api_key or ""
api_token = provider.api_key or ""
api_url = api_url or provider.base_url or ""
if not api_token or not api_url:
system_provider = crud.get_active_system_provider(db, "paddleocr")
Expand Down
3 changes: 2 additions & 1 deletion frontend/src/api/upload.js
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,10 @@ export async function exportQuestions(selectedIds, runId) {
}

/** 将选中的分割题目和答案保存到错题库。 */
export async function saveToDb(selectedIds, answers = [], runId, projectId) {
export async function saveToDb(selectedIds, answers = [], runId, projectId, recordId) {
const body = { selected_ids: selectedIds, answers }
if (runId) body.run_id = runId
if (recordId) body.record_id = recordId
if (projectId) body.project_id = projectId
const resp = await fetch('/api/save-to-db', {
method: 'POST',
Expand Down
4 changes: 3 additions & 1 deletion frontend/src/components/base/BaseModal.vue
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ const props = defineProps({
bodyClass: { type: String, default: 'px-6 py-5' },
blurBackdrop: { type: Boolean, default: true },
sidebarOffset: { type: Number, default: null },
zIndex: { type: Number, default: null },
})

const emit = defineEmits(['close'])
Expand All @@ -28,7 +29,7 @@ const backdropStyle = computed(() => ({
<div
v-if="open"
class="dialog-backdrop fixed inset-0 z-[100] bg-black/40 transition-all duration-300"
:style="backdropStyle"
:style="{ ...backdropStyle, ...(zIndex ? { zIndex: zIndex - 1 } : {}) }"
@click="emit('close')"
></div>
</Transition>
Expand All @@ -37,6 +38,7 @@ const backdropStyle = computed(() => ({
<div
v-if="open"
class="fixed inset-0 z-[101] flex items-center justify-center p-4 transition-all duration-300"
:style="zIndex ? { zIndex } : undefined"
@click.self="emit('close')"
>
<div
Expand Down
Loading