-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
153 lines (125 loc) · 5.42 KB
/
main.py
File metadata and controls
153 lines (125 loc) · 5.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import sqlite3
import time
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.types import Command, interrupt
from state.state_definition import AgentState
# Import your agents and evaluator
from agents.researcher import Researcher
from agents.critic import Critic
from tests.evals.evaluator import run_technical_audit
# --- 2. PERSISTENCE LAYER ---
# Graph checkpoints are written to a local SQLite file, so a thread can be
# picked up again on a later run. check_same_thread=False allows this single
# connection to be used from threads other than the one that opened it.
conn = sqlite3.connect(
    "agent_memory.db",
    check_same_thread=False,
)
memory = SqliteSaver(conn)

# --- 3. AGENT INITIALIZATION ---
# One shared instance of each agent; their ``run`` methods become graph nodes.
researcher, critic = Researcher(), Critic()
# --- 4. NEW: AUDITOR NODE ---
def audit_node(state: AgentState):
    """
    Automated Quality Gate: Rejects reports that are ungrounded or vague.

    Audits ``current_draft`` against ``sources_text`` using
    ``run_technical_audit`` and returns a partial state update for the graph.

    Args:
        state: Graph state. Reads ``query``, ``current_draft``,
            ``sources_text`` and ``revision_count``.

    Returns:
        ``{"approved": True}`` when the audit passes. Otherwise a dict with
        ``approved=False``, a ``critique`` explaining the rejection, and
        ``revision_count`` incremented by one.
    """
    print("\n⚖️ Node: Auditor is verifying report integrity...")
    # Every rejection counts as a revision, including the "no sources" case.
    # Otherwise a revision cap could never stop the auditor -> researcher loop.
    next_revision = state.get("revision_count", 0) + 1

    # Check if we have sources to verify against
    if not state.get("sources_text"):
        return {
            "critique": "Auditor failed: No source data found to verify.",
            "approved": False,
            "revision_count": next_revision,
        }

    current_test = {
        "query": state["query"],
        "critical_facts": "Ensuring technical metrics and claims are grounded in provided sources.",
    }
    # Run the DeepEval/Ragas audit. A missing draft is audited as empty text
    # (and will fail) instead of crashing the node with a KeyError.
    results = run_technical_audit(
        agent_output=state.get("current_draft") or "",
        source_context=state["sources_text"],
        test_case=current_test,
    )
    if not results["Pass"]:
        score = results["Faithfulness"]
        print(f"⚠️ AUDIT FAILED (Score: {score}): Sending back for revision.")
        return {
            "critique": f"Low Faithfulness Score ({score}). Data points in the report do not match search results. Please re-verify metrics.",
            "approved": False,
            "revision_count": next_revision,
        }
    print(f"✅ AUDIT PASSED (Score: {results['Faithfulness']})")
    return {"approved": True}  # This tells the graph the Auditor is satisfied
# --- 5. HUMAN APPROVAL NODE ---
def human_approval_node(state: AgentState):
    """
    Pause the graph and ask a human to approve, revise, or exit.

    The graph is suspended with ``interrupt``. Whatever value the run is
    resumed with is interpreted as the user's answer:

    * ``a`` / ``approve`` -> ``{"approved": True}``
    * ``e`` / ``exit``    -> approved, ``exit_requested=True``, and the
      ``"__USER_EXIT__"`` critique sentinel
    * anything else       -> treated as revision feedback for the researcher

    Matching is case-insensitive and ignores surrounding whitespace.
    """
    # NOTE: per LangGraph's interrupt semantics, a resumed node runs again from
    # its first line, so these prints appear again after the user answers.
    print("\n--- ⏸️ WAITING FOR HUMAN REVIEW ---")
    revision_count = state.get("revision_count", 0)
    if state.get("approved"):
        print(f"Auditor has cleared this draft. Turn {revision_count}.")
    else:
        # Reached only if routing escalates a draft the auditor rejected.
        print(f"⚠️ Auditor has NOT cleared this draft. Turn {revision_count}.")
    user_feedback = interrupt("Do you want to (A)pprove, (R)equest specific changes, or (E)xit?")

    # The resume value comes straight from user input. Normalise it so stray
    # whitespace, full words, or a non-string value do not get misrouted.
    feedback_text = str(user_feedback if user_feedback is not None else "").strip()
    choice = feedback_text.lower()
    if choice in {"a", "approve"}:
        return {"approved": True}
    if choice in {"e", "exit"}:
        return {"approved": True, "exit_requested": True, "critique": "__USER_EXIT__"}
    return {"critique": f"HUMAN FEEDBACK: {feedback_text}", "approved": False}
# --- 6. BUILD THE GRAPH ---
# Fixed pipeline: researcher -> critic -> auditor. The conditional routing out
# of the auditor and human_review nodes is registered further down.
builder = StateGraph(AgentState)
for _node_name, _node_fn in (
    ("researcher", researcher.run),
    ("critic", critic.run),
    ("auditor", audit_node),  # automated quality gate
    ("human_review", human_approval_node),
):
    builder.add_node(_node_name, _node_fn)
del _node_name, _node_fn  # keep loop temporaries out of the module namespace

builder.set_entry_point("researcher")
builder.add_edge("researcher", "critic")
builder.add_edge("critic", "auditor")  # every critiqued draft is audited
def audit_router(state):
    """
    Choose the next step after the auditor node.

    Returns:
        ``"human"`` when the audit passed, or when a failed draft has already
        used up the revision budget (``revision_count >= 3``). Otherwise
        ``"revise"``, which sends the draft back to the researcher.
    """
    if state.get("approved"):
        return "human"
    # Without a cap here, a draft that never passes the audit loops
    # researcher -> critic -> auditor until LangGraph's recursion limit aborts
    # the run. The cap of 3 matches final_router; at the cap the draft is
    # escalated to a human instead of being revised again.
    if state.get("revision_count", 0) >= 3:
        return "human"
    return "revise"
builder.add_conditional_edges(
    "auditor",
    audit_router,
    # audit_router's return value -> name of the next node
    {"revise": "researcher", "human": "human_review"},
)
def final_router(state):
    """
    Decide whether the run finishes after human review.

    Returns ``"end"`` if the human approved, the revision cap (3) was reached,
    or the user asked to exit; otherwise ``"continue"`` for another pass.
    """
    if state.get("approved"):
        return "end"
    if state.get("revision_count", 0) >= 3:
        return "end"
    if state.get("exit_requested"):
        return "end"
    return "continue"
# After human review, either loop back to the researcher or finish the run.
builder.add_conditional_edges(
    "human_review",
    final_router,
    {"continue": "researcher", "end": END},
)
# Compiling with the SQLite checkpointer lets a paused run be resumed later.
app = builder.compile(checkpointer=memory)
# --- 7. INTERACTIVE EXECUTION LOOP ---
def _choose_thread_id(default_thread_id):
    """Ask whether to resume *default_thread_id* or start a fresh timestamped thread."""
    resume_choice = input("\nResume previous session? [Y/n]: ").strip().lower()
    if resume_choice in {"n", "no"}:
        thread_id = f"session_{int(time.time())}"
        print(f"🆕 Starting fresh session: {thread_id}")
    else:
        thread_id = default_thread_id
        print(f"📌 Using session: {thread_id}")
    return thread_id


def _pending_interrupt_prompts(snapshot):
    """Return the values passed to any ``interrupt()`` the graph is paused on."""
    return [
        getattr(pending, "value", pending)
        for task in (getattr(snapshot, "tasks", None) or ())
        for pending in (getattr(task, "interrupts", None) or ())
    ]


def _show_pending_review(snapshot):
    """Print the draft awaiting review and the question the graph paused on."""
    draft = snapshot.values.get("current_draft")
    if draft:
        print("\n📝 DRAFT AWAITING REVIEW:")
        print("-" * 50)
        print(draft)
        print("-" * 50)
    for prompt in _pending_interrupt_prompts(snapshot):
        print(f"❓ {prompt}")


def _new_task_inputs(user_query):
    """
    Build the initial state for a new research query.

    ``critique`` and ``current_draft`` are also cleared. The checkpointer keeps
    the previous task's values on the same thread, and input only overwrites
    the keys it provides, so stale feedback or drafts would otherwise leak
    into the new task.
    """
    return {
        "query": user_query,
        "revision_count": 0,
        "approved": False,
        "exit_requested": False,
        "sources_text": [],
        "critique": "",
        "current_draft": "",
    }


def _stream_progress(inputs, config):
    """Run the graph until it finishes or pauses, printing each node that ran."""
    for event in app.stream(inputs, config=config):
        for node in event:
            # The interrupt marker is not a real node. The pending question is
            # shown from the saved state instead (see _show_pending_review).
            if node == "__interrupt__":
                continue
            print(f" └─ ⚙️ {node.capitalize()} is processing...")


def run_interactive():
    """
    Interactive CLI: take research queries and drive the graph through review.

    On each turn the function does one of two things:

    * If the thread is paused at the human-review interrupt, it shows the
      pending draft and resumes the graph with the user's decision.
    * Otherwise it asks for a new query.

    Typing ``exit``/``quit`` at the query prompt, choosing Exit during review,
    or sending EOF/Ctrl-C ends the session.
    """
    default_thread_id = "main_user_session"
    print("\n" + "=" * 50)
    print("🤖 2026 IEEE RESEARCH AGENT (WITH AUTO-AUDIT)")
    print("=" * 50)
    try:
        thread_id = _choose_thread_id(default_thread_id)
        config = {"configurable": {"thread_id": thread_id}}
        while True:
            state_snapshot = app.get_state(config)
            if state_snapshot.next:
                # Paused at human review: show what is being reviewed first.
                print("\n📬 RESUMING PREVIOUS TASK...")
                _show_pending_review(state_snapshot)
                user_input = input("Your decision/feedback: ")
                inputs = Command(resume=user_input)
            else:
                user_query = input("\n🔍 What should I research today? ").strip()
                if user_query.lower() in ["exit", "quit"]:
                    break
                if not user_query:
                    continue  # don't start a research run on an empty query
                inputs = _new_task_inputs(user_query)

            _stream_progress(inputs, config)

            final_state = app.get_state(config)
            if final_state.values.get("exit_requested"):
                print("\n👋 Session ended by user.")
                break
            if not final_state.next and final_state.values.get("current_draft"):
                print("\n" + "✨" + "-" * 48 + "✨")
                print(final_state.values.get("current_draft"))
                print("-" * 50)
    except (EOFError, KeyboardInterrupt):
        # Closed stdin or Ctrl-C: leave cleanly instead of dumping a traceback.
        print("\n👋 Session interrupted.")
# Script entry point: start the interactive CLI only when run directly,
# not when this module is imported.
if __name__ == "__main__":
    run_interactive()