-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathapp.py
More file actions
286 lines (235 loc) · 10.4 KB
/
app.py
File metadata and controls
286 lines (235 loc) · 10.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
from time import sleep
from typing import Any, Dict, Optional

from langgraph.types import Command

from data import DatabaseManager, LlmConfigManager
from data.knowledge_manager import KnowledgeManager
from sql_copilot import get_sql_copilot
from utils.exceptions import (
    LLMServiceError, WorkflowExecutionError,
    ConfigurationError, InvalidQueryError
)
from utils.logger import setup_logger, get_logger
# Configure application-wide logging once at import time, then obtain a
# module-scoped logger for this file.
setup_logger()
logger = get_logger(__name__)
def init(database_id: str) -> None:
    """
    Initialize a target database: persist its table/column metadata into the
    system tables and seed the knowledge store with column comments.

    Args:
        database_id: Target database ID.

    Raises:
        ConfigurationError: If no configuration exists for ``database_id``.
        DatabaseConnectionError: If the database connection fails.
        WorkflowExecutionError: If any other initialization step fails.
    """
    try:
        logger.info(f"Starting database initialization for {database_id}")
        manager = DatabaseManager()

        # Fail fast when the database is unknown.
        config = manager.databases.get(database_id)
        if not config:
            raise ConfigurationError(f"Database configuration not found: {database_id}")
        logger.info(f"Initializing database {database_id} ({config.name})")

        # Persist the table structure into the system database.
        manager.save_table_schemas_to_system(database_id)

        # Feed column comments into the knowledge store, if any columns exist.
        knowledge = KnowledgeManager(database_id)
        columns = manager.get_column(database_id)
        if not columns.empty:
            records = columns[['id', 'column_comment']].to_dict('records')
            knowledge.add_column(records)

        logger.info(f"Database {database_id} initialization completed successfully")
    except ConfigurationError:
        logger.error(f"Database configuration error for {database_id}")
        raise
    except Exception as e:
        logger.error(f"Database initialization failed for {database_id}: {str(e)}")
        raise WorkflowExecutionError(f"Database initialization failed: {str(e)}") from e
# Global checkpointer shared by process_chat/stream_chat so LangGraph can
# persist per-session workflow state (keyed by thread_id) within this process.
# In-memory only: state is lost on restart.
# NOTE(review): mid-file import — consider moving it to the top of the module.
from langgraph.checkpoint.memory import InMemorySaver
checkpointer = InMemorySaver()
def process_chat(query: Optional[str], database_id: str = "chenjie", session_id: str = "default", resume_input: Optional[str] = None) -> Dict[str, Any]:
    """
    Process a chat message or resume from an interruption without blocking.

    Args:
        query: Natural language query for a new run. May be None when
            resuming (only ``resume_input`` is used in that case) — the
            original ``str`` annotation was wrong: ``ask`` passes None here.
        database_id: Target database ID.
        session_id: Session ID used as the LangGraph thread id.
        resume_input: User input to resume an interrupted run, or None to
            start a new one.

    Returns:
        Dictionary containing the result, interruption info, or error.

    Raises:
        InvalidQueryError: If starting a new run with an empty query.
    """
    try:
        # Reuse the shared checkpointer so a resumed call sees prior state
        # for the same thread_id.
        graph = get_sql_copilot(database_id, checkpointer=checkpointer)
        config = {"configurable": {"thread_id": session_id}}

        if resume_input is not None:
            # Resume a previously interrupted execution with the user's answer.
            logger.info(f"Resuming execution for session {session_id} with input: {resume_input}")
            result = graph.invoke(Command(resume=resume_input), config=config)
        else:
            # Start a new execution.
            if not query or not query.strip():
                raise InvalidQueryError("Query cannot be empty")
            logger.info(f"Processing query: '{query}' for database: {database_id}")
            initial_state = {
                "session_id": session_id,
                "user_input": query,
                "database": database_id,
                "intent_ambiguous": {},
                "messages": [],
                "intent_type": "other",
                "schema_context": "",
                "final_sql": "",
                "answer": "",
                "cnt": 0,
                "actual_query": "",
                "tables": []
            }
            result = graph.invoke(initial_state, config=config)

        # Completed run: the workflow produced a final answer.
        if "answer" in result and result["answer"]:
            return {
                "success": True,
                "answer": result["answer"],
                "sql": result.get("final_sql"),
                "query_data": result.get("query_data"),
                "session_id": session_id,
                "status": "completed"
            }

        # Interrupted run: the workflow needs a clarification from the user.
        if "__interrupt__" in result and result["__interrupt__"]:
            question = result["__interrupt__"][0].value["content"]
            return {
                "success": True,
                "question": question,
                "session_id": session_id,
                "status": "interrupted"
            }

        # Fallback for unexpected state.
        return {
            "success": False,
            "error": "Unexpected workflow state",
            "status": "error"
        }
    except Exception as e:
        logger.error(f"Error in process_chat: {str(e)}")
        raise
def stream_chat(query: Optional[str], database_id: str = "chenjie", session_id: str = "default", resume_input: Optional[str] = None):
    """
    Stream chat events from the LangGraph workflow.

    Args:
        query: Natural language query for a new run. May be None when
            resuming (only ``resume_input`` is used in that case).
        database_id: Target database ID.
        session_id: Session ID used as the LangGraph thread id.
        resume_input: User input to resume an interrupted run, or None to
            start a new one.

    Yields:
        Dicts with "type" of "chunk", "interrupt", "complete", or "error",
        plus "node" and "content" fields.
    """
    try:
        # Reuse the shared checkpointer so resumed runs see prior state.
        graph = get_sql_copilot(database_id, checkpointer=checkpointer)
        config = {"configurable": {"thread_id": session_id}, "recursion_limit": 100}

        if resume_input is not None:
            # Resume a previously interrupted execution with the user's answer.
            logger.info(f"Resuming execution for session {session_id} with input: {resume_input}")
            input_data = Command(resume=resume_input)
        else:
            # Start a new execution.
            if not query or not query.strip():
                raise InvalidQueryError("Query cannot be empty")
            logger.info(f"Processing query: '{query}' for database: {database_id}")
            input_data = {
                "session_id": session_id,
                "user_input": query,
                "database": database_id,
                "intent_ambiguous": {},
                "messages": [],
                "intent_type": "other",
                "schema_context": "",
                "final_sql": "",
                "answer": "",
                "cnt": 0,
                "actual_query": "",
                "tables": []
            }

        # stream_mode="updates" yields one {node_name: state_update} dict per
        # node execution.
        final_state = None
        for event in graph.stream(input_data, config=config, stream_mode="updates"):
            for node_name, state_update in event.items():
                final_state = state_update
                if node_name == "__interrupt__":
                    # Interrupt payloads carry the clarification question.
                    question = state_update[0].value["content"]
                    yield {
                        "type": "interrupt",
                        "node": "system",
                        "content": question
                    }
                else:
                    # Extract something human-readable from the update; the
                    # exact keys depend on what each node writes to state.
                    content = ""
                    if "messages" in state_update and state_update["messages"]:
                        last_msg = state_update["messages"][-1]
                        # NOTE(review): assumes message objects expose .role —
                        # verify against the actual message type used upstream.
                        node_name = last_msg.role
                        if hasattr(last_msg, "content"):
                            content = last_msg.content
                        else:
                            content = str(last_msg)
                    elif "answer" in state_update:
                        content = state_update["answer"]
                    elif "final_sql" in state_update:
                        content = f"Generated SQL: \n```sql\n{state_update['final_sql']}\n```"
                    elif "question" in state_update:  # For semantic router or ambiguity
                        content = f"Clarification needed: {state_update['question']}"
                    if content:
                        yield {
                            "type": "chunk",
                            "node": node_name,
                            "content": content,
                        }

        # BUG FIX: final_state stays None when the stream yields no events;
        # the original membership test then raised TypeError. Guard first.
        if final_state and "answer" in final_state and final_state["answer"]:
            yield {
                "type": "complete",
                "node": "system",
                "content": final_state["answer"],
                "sql": final_state.get("final_sql")
            }
    except Exception as e:
        # BUG FIX: the original called logger.error(f"...", e), passing `e` as
        # a %-style argument with no placeholder, which makes the logging
        # module raise internally. logger.exception also records the traceback.
        logger.exception(f"Error in stream_chat: {str(e)}")
        yield {
            "type": "error",
            "node": "system",
            "content": str(e)
        }
def ask(query: str, database_id: str = "chenjie", session_id: str = "default") -> Dict[str, Any]:
    """
    Interactive CLI driver around ``process_chat``.

    Runs the query, and whenever the workflow interrupts for clarification,
    prompts the user on stdin and resumes until a terminal result is reached.
    """
    outcome = process_chat(query, database_id, session_id)

    # Keep answering clarification requests until the run completes or errors.
    while outcome.get("status") == "interrupted":
        clarification = outcome.get("question")
        logger.info(f"Requesting user clarification: {clarification}")
        print("input----")
        reply = input(clarification)
        outcome = process_chat(None, database_id, session_id, resume_input=reply)

    return outcome
def main():
    """Entry point: initialize the demo database and return an exit code."""
    logger.info("Starting SQL Copilot application")
    init("chenjie")

    # Example usage (disabled):
    # query = "在各公司所有品牌收入排名中,给出每一个品牌,其所在公司以及收入占该公司的总收入比例,同时给出该公司的年营业额"
    # result = ask(query)
    # logger.info(result)
    # print(result["answer"])

    return 0
if __name__ == '__main__':
    # exit() is the site-module interactive helper and is not guaranteed in
    # every execution mode; raising SystemExit is the portable way to
    # propagate main()'s exit code.
    raise SystemExit(main())