# NOTE: removed GitHub web-page scrape residue (UI chrome and line-number
# gutter) that was accidentally captured above the module source.
# custom_code_agent.py
"""Streamlit UI for an AI code assistant: a ReAct agent (Groq / Llama3) that
proposes Python code, asks the user for approval (human-in-the-loop), executes
it, and renders results, charts, and generated files in a chat interface."""
import streamlit as st
import os
import sys
import time
from io import StringIO
import base64
from pathlib import Path
import re
import json
import pandas as pd
from datetime import datetime
# Import your agent executor and the specific tool instance
from agent.core import get_agent_executor
from tools.python_executor import PythonCodeExecutorTool
from utils.callbacks import StreamlitCodeExecutionCallbackHandler, InterceptToolCall
from tools.task_completed import TaskCompletedTool
# Load environment variables (GROQ_API_KEY etc.) from a local .env file.
from dotenv import load_dotenv
load_dotenv()
# --- Streamlit UI Setup ---
# set_page_config must be the first Streamlit command executed on each run;
# it configures the browser tab and switches to the full-width layout.
st.set_page_config(page_title="AI Code Assistant", page_icon="π€", layout="wide")
st.title("π€ AI Code Assistant (Groq)")
st.caption("A ReAct agent powered by Groq (Llama3) for Python code generation and execution.")
# --- Session State Initialization and Reset Function ---
def initialize_session_state():
    """Initialize every session-state key the app relies on.

    Safe to call on every script rerun: existing keys are left untouched, so
    state survives Streamlit's re-execution model. Also (re)builds the agent
    executor when it is missing or ``needs_reinitialization`` is set.

    Raises (indirectly): stops the script via ``st.stop()`` if the agent
    cannot be constructed (e.g. bad/missing Groq API key).
    """
    # Conversation store: conversation id -> {"messages": [...], "display_name": str}.
    if "conversations" not in st.session_state:
        st.session_state.conversations = {}
    if "current_conversation_id" not in st.session_state:
        first_conv_id = f"chat_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        st.session_state.conversations[first_conv_id] = {
            "messages": [{"role": "assistant", "content": "Hello! I am your AI Code Assistant. How can I help you with Python today?"}],
            "display_name": "New Conversation"
        }
        st.session_state.current_conversation_id = first_conv_id

    # LLM configuration defaults, set before the agent is (re)built to mirror
    # the original initialization order.
    # NOTE(review): these values are stored but get_agent_executor() is called
    # with no arguments below — confirm the executor actually reads them.
    st.session_state.setdefault("llm_model_name", "llama3-70b-8192")
    st.session_state.setdefault("llm_temperature", 0.05)

    # (Re)build the agent executor and its callback plumbing on first run or
    # after a configuration change (model, temperature, API key).
    if "agent_executor" not in st.session_state or st.session_state.get("needs_reinitialization", False):
        try:
            agent_exec = get_agent_executor()
            python_executor_tool = next(tool for tool in agent_exec.tools if tool.name == "python_code_executor")
            task_completed_tool = next(tool for tool in agent_exec.tools if tool.name == "task_completed")
            callback_logs_buffer = StringIO()
            callback_handler = StreamlitCodeExecutionCallbackHandler(python_executor_tool, task_completed_tool, callback_logs_buffer)
            agent_exec.callbacks = [callback_handler]
            st.session_state.agent_executor = agent_exec
            st.session_state.python_executor_tool = python_executor_tool
            st.session_state.task_completed_tool = task_completed_tool
            st.session_state.callback_handler = callback_handler
            st.session_state.callback_logs_buffer = callback_logs_buffer
            st.session_state.needs_reinitialization = False
        except (ValueError, RuntimeError) as e:
            st.error(f"Failed to initialize Agent: {e}. Please ensure your Groq API key is correct and valid. Check Groq's model deprecation notices.")
            st.stop()

    # Remaining per-chain defaults, applied only when the key does not exist
    # yet — equivalent to the original chain of "if key not in session_state".
    defaults = {
        "pending_action": None,
        "pending_final_answer": None,
        "last_processed_observation": None,
        "agent_continuation_needed": False,
        "last_agent_action_log_entry": None,
        "current_agent_chain_user_prompt": None,
        "last_executed_code": None,
        "last_successful_output": None,
        "last_generated_chart_data": None,
        "last_generated_plot_file": None,
        "execution_count": 0,
        "last_user_prompt": None,
        "hil_prompt_rendered": False,        # flag to prevent re-rendering HIL
        "last_agent_turn_processed": False,  # flag to show feedback prompt
        "start_time": None,
        "feedback_given": False,
    }
    for key, value in defaults.items():
        st.session_state.setdefault(key, value)

initialize_session_state()
def get_current_messages():
    """Return the message list of the currently selected conversation."""
    conversations = st.session_state.conversations
    return conversations[st.session_state.current_conversation_id]["messages"]
def start_new_chat():
    """Create a fresh conversation, wipe per-chain agent state, and rerun."""
    # Everything that must be reset for a clean chain, applied in one pass.
    reset_values = {
        "needs_reinitialization": True,
        "pending_action": None,
        "pending_final_answer": None,
        "agent_continuation_needed": False,
        "last_processed_observation": None,
        "last_agent_action_log_entry": None,
        "current_agent_chain_user_prompt": None,
        "last_user_prompt": None,
        "last_executed_code": None,
        "last_successful_output": None,
        "last_generated_chart_data": None,
        "last_generated_plot_file": None,
        "execution_count": 0,
        "feedback_given": False,             # reset feedback status
        "hil_prompt_rendered": False,        # reset HIL flag
        "last_agent_turn_processed": False,  # reset feedback flag
        "start_time": None,                  # reset start time
    }
    for key, value in reset_values.items():
        st.session_state[key] = value

    # Register a brand-new conversation and make it current.
    new_conv_id = f"chat_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    st.session_state.conversations[new_conv_id] = {
        "messages": [{"role": "assistant", "content": "Hello! I am your AI Code Assistant. How can I help you with Python today?"}],
        "display_name": "New Conversation"
    }
    st.session_state.current_conversation_id = new_conv_id
    st.rerun()
# --- Sidebar for Configuration and Conversation Management ---
with st.sidebar:
    st.header("Configuration")
    st.subheader("Agent Settings")
    st.markdown("Adjust the agent's model and creativity. Changes will apply to the next conversation.")
    # Model picker; `index` pins the current selection across reruns.
    selected_model = st.selectbox(
        "LLM Model",
        options=["llama3-70b-8192", "llama3-8b-8192"],
        index=0 if st.session_state.llm_model_name == "llama3-70b-8192" else 1,
        help="Choose the underlying Large Language Model. `70b` is more powerful but slower, while `8b` is much faster."
    )
    if selected_model != st.session_state.llm_model_name:
        # A model change invalidates the current agent; start_new_chat() sets
        # needs_reinitialization and reruns the script.
        st.session_state.llm_model_name = selected_model
        st.session_state.needs_reinitialization = True
        st.toast(f"Model changed to {selected_model}. Starting a new chat to apply.")
        start_new_chat()
    selected_temp = st.slider(
        "Temperature",
        min_value=0.0, max_value=1.0,
        value=st.session_state.llm_temperature, step=0.05,
        help="Controls randomness. Lower values are more deterministic and factual, higher values are more creative."
    )
    if selected_temp != st.session_state.llm_temperature:
        # Same rebuild-through-new-chat flow as the model change above.
        st.session_state.llm_temperature = selected_temp
        st.session_state.needs_reinitialization = True
        st.toast(f"Temperature set to {selected_temp}. Starting a new chat to apply.")
        start_new_chat()
    st.markdown("---")
    st.header("API Key")
    groq_api_key_env = os.getenv("GROQ_API_KEY")
    if not groq_api_key_env:
        st.warning("GROQ_API_KEY not found in environment. Please enter it below.")
        groq_api_key_input = st.text_input("Enter your Groq API Key:", type="password", key="api_key_input")
        if groq_api_key_input:
            # Persist the key for this process and rebuild the agent on rerun.
            os.environ["GROQ_API_KEY"] = groq_api_key_input
            st.session_state.needs_reinitialization = True
            st.rerun()
    else:
        st.success("Groq API Key detected.")
    st.markdown("---")
    # The on_click callback does all the work; the if-body is intentionally empty.
    if st.button("New Chat", on_click=start_new_chat, use_container_width=True, help="Start a fresh conversation and clear agent memory."):
        pass
    # Conversation Selector
    st.markdown("---")
    st.header("Conversations")
    display_names = [conv["display_name"] for conv in st.session_state.conversations.values()]
    conv_ids = list(st.session_state.conversations.keys())
    try:
        current_index = conv_ids.index(st.session_state.current_conversation_id)
    except ValueError:
        current_index = 0
    selected_display_name = st.selectbox(
        "Select Conversation",
        options=display_names,
        index=current_index,
        label_visibility="collapsed"
    )
    # NOTE(review): lookup by display name — if two conversations ever share a
    # display name this resolves to the first; confirm names stay unique.
    selected_id = conv_ids[display_names.index(selected_display_name)]
    if st.session_state.current_conversation_id != selected_id:
        st.session_state.current_conversation_id = selected_id
        st.rerun()
    st.markdown("---")
    st.markdown("Developed by [Manaswi Kamble](https://github.com/man-swi).")
# --- Display Chat Messages ---
# Replays the current conversation's history. Dict-valued contents encode rich
# messages (charts, file downloads, pre-formatted text); plain strings render
# as markdown.
for message in get_current_messages():
    with st.chat_message(message["role"]):
        if isinstance(message["content"], dict):
            if message["content"].get("type") == "chart":
                st.subheader(message["content"].get("title", "Generated Chart"))
                try:
                    df = message["content"]["data"]
                    x_label = message["content"].get("x_label")
                    y_label = message["content"].get("y_label")
                    # Use explicit axis labels when both exist in the frame;
                    # otherwise fall back to the first two columns.
                    if x_label and y_label and x_label in df.columns and y_label in df.columns:
                        st.line_chart(df, x=x_label, y=y_label, use_container_width=True)
                    else:
                        if len(df.columns) >= 2:
                            x_col, y_col = df.columns[0], df.columns[1]
                            st.line_chart(df, x=x_col, y=y_col, use_container_width=True)
                            st.info(f"Chart labels not specified. Using '{x_col}' (x-axis) and '{y_col}' (y-axis).")
                        else:
                            st.error("Chart data must have at least two columns.")
                except Exception as chart_error:
                    st.error(f"Error rendering chart: {chart_error}")
                st.caption("Generated visualization. Always verify results.")
            elif message["content"].get("type") == "file_display":
                # Images render inline; other files get an info line.
                if message["content"]["mime"].startswith("image/"):
                    st.image(message["content"]["data"], caption=message["content"].get("caption", "Generated Image"), use_container_width=True)
                else:
                    st.info(f"File created: {message['content']['file_name']}")
                if message["content"].get("download_label"):
                    st.download_button(
                        label=message["content"].get("download_label"),
                        data=message["content"].get("data"),
                        file_name=message["content"].get("file_name", "download.bin"),
                        mime=message["content"].get("mime", "application/octet-stream")
                    )
            elif "content_text" in message["content"]:  # For displaying text-based messages like thoughts
                st.markdown(message["content"]["content_text"])
        else:
            st.markdown(message["content"])
# --- Human-in-the-Loop (HIL) Section for Code Execution ---
# Renders only while an action is pending AND the prompt has not already been
# rendered for this action (hil_prompt_rendered guards against duplicates).
if st.session_state.pending_action and not st.session_state.get("hil_prompt_rendered", False):
    with st.chat_message("assistant"):
        # The agent's plan (thought) and the code it proposes to run.
        thought_text = st.session_state.pending_action.get("thought", "Agent's Plan could not be determined.")
        proposed_code_raw = st.session_state.pending_action.get("tool_input", "# No code was proposed.")

        # --- Enhanced Cleaning for Proposed Code ---
        # Strip 'undefined' artifacts, agent prefixes ("Thought:", "Proposed
        # Code:"), and blank lines so only runnable code is shown/executed.
        if proposed_code_raw.strip().lower() == "undefined" or proposed_code_raw.strip() == "":
            proposed_code = "# No valid code was proposed."
        else:
            # Robust against variations in how the agent might prefix the code.
            cleaned_code = re.sub(r"^.*?(?:Thought:|Proposed Code:|\n)+", "", proposed_code_raw, flags=re.DOTALL | re.IGNORECASE)
            proposed_code = "\n".join(line for line in cleaned_code.splitlines() if line.strip())
            if not proposed_code:
                proposed_code = "# No valid code was proposed."
            # Remove any lingering standalone "undefined" tokens, case-insensitively.
            proposed_code = re.sub(r'\bundefined\b', '', proposed_code, flags=re.IGNORECASE).strip()
            if not proposed_code:
                proposed_code = "# No valid code was proposed."
        # --- End Enhanced Cleaning ---

        # Show the plan and, when real code survived cleaning, the code itself.
        st.markdown(thought_text)
        if proposed_code != "# No code was proposed." and proposed_code.strip():
            st.code(proposed_code, language="python")

        # Buttons for approval/cancellation.
        col1, col2, col3 = st.columns([1, 1, 2])
        with col1:
            # FIX: the label was a syntax-broken multi-line string literal.
            if st.button("✅ Approve & Execute", key="approve_code"):
                # Record the thought and code in history BEFORE execution.
                # FIX: store under "content_text" so the chat renderer's
                # `"content_text" in message["content"]` branch displays it
                # (the previous "text" key was never rendered).
                get_current_messages().append({"role": "assistant", "content": {"type": "content_text", "content_text": thought_text}})
                if proposed_code != "# No code was proposed." and proposed_code.strip():
                    get_current_messages().append({"role": "assistant", "content": {"type": "content_text", "content_text": f"**Executed Code:**\n```python\n{proposed_code}\n```"}})
                    st.session_state.last_executed_code = proposed_code
                st.session_state.last_successful_output = None  # reset for new execution
                with st.spinner("Executing proposed code..."):
                    execution_result = st.session_state.python_executor_tool.execute_code_after_approval(proposed_code)

                # Show the raw execution result and record it.
                result_display_content = f"**Code Execution Result:**\n```\n{execution_result}\n```"
                get_current_messages().append({"role": "assistant", "content": result_display_content})
                st.session_state.execution_count += 1
                st.session_state.last_successful_output = execution_result if "Standard Error:" not in execution_result else None

                # Reset chart/file state from any previous execution.
                st.session_state.last_generated_chart_data = None
                st.session_state.last_generated_plot_file = None

                # --- Process Output for PLOT_DATA_JSON ---
                # The executor may embed chart data as a JSON payload between
                # sentinel markers; extract it and strip it from the output.
                plot_data_match = re.search(r"PLOT_DATA_JSON_START:(.*):PLOT_DATA_JSON_END", execution_result, re.DOTALL)
                if plot_data_match:
                    try:
                        plot_data_json = plot_data_match.group(1).strip()
                        parsed_plot_data = json.loads(plot_data_json)
                        if not isinstance(parsed_plot_data, dict) or not all(isinstance(v, list) for v in parsed_plot_data.values()):
                            raise ValueError("Plot data must be a dictionary of lists.")
                        df_chart = pd.DataFrame(parsed_plot_data)
                        chart_title = "Generated Chart"
                        # Try to derive a chart title from the agent's last thought.
                        # FIX: guard against last_agent_action_log_entry being None,
                        # which previously made re.search raise TypeError.
                        log_entry = st.session_state.last_agent_action_log_entry or ""
                        thought_match_for_title = re.search(r"Thought:\s*(.*?)(?=\nAction:)", log_entry, re.DOTALL)
                        if thought_match_for_title:
                            thought_text_for_title = thought_match_for_title.group(1).strip()
                            title_match = re.search(r"(?:plot|chart|visualize)\s+(.*?)(?:\.|,|\n|$)", thought_text_for_title, re.IGNORECASE)
                            if title_match:
                                chart_title = title_match.group(1).strip().capitalize()
                        x_key, y_key = (df_chart.columns[0], df_chart.columns[1]) if len(df_chart.columns) >= 2 else (None, None)
                        st.session_state.last_generated_chart_data = {"type": "chart", "data": df_chart, "title": chart_title, "x_label": x_key, "y_label": y_key}
                        execution_result = execution_result.replace(f"PLOT_DATA_JSON_START:{plot_data_json}:PLOT_DATA_JSON_END", "").strip()
                    except (json.JSONDecodeError, ValueError) as e:
                        st.warning(f"Failed to parse plot data from agent output: {e}")

                # --- Process Output for Generated Files ---
                file_creation_match = re.search(r"Files created during execution: (.*)", execution_result, re.DOTALL)
                if file_creation_match:
                    filenames_str = file_creation_match.group(1)
                    filenames = [f.strip() for f in filenames_str.split(',')]
                    execution_result = execution_result.replace(f"Files created during execution: {filenames_str}", "").strip()
                    for filename in filenames:
                        file_path = Path(filename)
                        if file_path.is_file():
                            mime_map = {'.png': 'image/png', '.jpg': 'image/jpeg', '.jpeg': 'image/jpeg', '.gif': 'image/gif', '.csv': 'text/csv', '.txt': 'text/plain'}
                            mime_type = mime_map.get(file_path.suffix.lower(), "application/octet-stream")
                            try:
                                with open(file_path, "rb") as f:
                                    file_bytes = f.read()
                                # FIX: captions/labels previously contained the
                                # literal placeholder "(unknown)"; interpolate
                                # the actual filename instead.
                                file_display_entry = {"type": "file_display", "data": file_bytes, "caption": f"Generated: {filename}", "download_label": f"Download {filename}", "file_name": filename, "mime": mime_type}
                                if mime_type.startswith("image/"):
                                    st.session_state.last_generated_plot_file = file_display_entry
                                else:
                                    get_current_messages().append({"role": "assistant", "content": file_display_entry})
                            except Exception as e:
                                st.warning(f"Could not load file {filename}: {e}")
                        else:
                            st.warning(f"Agent mentioned file '{filename}' but it does not exist.")

                # Hand the (cleaned) observation back to the agent loop.
                st.session_state.last_processed_observation = execution_result.strip()
                st.session_state.agent_continuation_needed = True
                st.session_state.pending_action = None       # clear pending action after processing
                st.session_state.hil_prompt_rendered = True  # mark HIL as rendered for this action
                st.session_state.last_agent_turn_processed = True  # enable the feedback prompt

                # --- Metrics: show elapsed time for this turn ---
                if st.session_state.start_time:
                    end_time = time.time()
                    processing_time = end_time - st.session_state.start_time
                    st.info(f"π‘ **Total Processing Time:** {processing_time:.2f} seconds")
                    st.session_state.start_time = None  # reset timer
                st.rerun()
        with col2:
            # FIX: restore the garbled cancel-button emoji.
            if st.button("❌ Cancel", key="cancel_code"):
                cancellation_message = "Code execution CANCELED by user."
                get_current_messages().append({"role": "assistant", "content": {"type": "content_text", "content_text": f"*{cancellation_message}*"}})
                st.session_state.last_processed_observation = cancellation_message
                st.session_state.agent_continuation_needed = True  # let the agent react to the cancellation
                st.session_state.pending_action = None
                st.session_state.hil_prompt_rendered = True
                st.session_state.last_agent_turn_processed = True
                # --- Metrics: show elapsed time for this turn ---
                if st.session_state.start_time:
                    end_time = time.time()
                    processing_time = end_time - st.session_state.start_time
                    st.info(f"π‘ **Total Processing Time:** {processing_time:.2f} seconds")
                    st.session_state.start_time = None  # reset timer
                st.rerun()
# --- Handle Task Completed interception ---
# Runs when the task_completed tool was intercepted: renders the final answer
# plus any chart/plot file captured during the chain, then clears chain state.
if st.session_state.pending_final_answer:
    with st.chat_message("assistant"):
        st.success("Task Completed!")
        # --- Metrics Logic Moved Here for Final Answer ---
        if st.session_state.start_time:
            end_time = time.time()
            processing_time = end_time - st.session_state.start_time
            st.info(f"π‘ **Total Processing Time:** {processing_time:.2f} seconds")
            st.session_state.start_time = None # Reset timer
        # --- End Metrics Logic ---
        # Chart captured from PLOT_DATA_JSON output, if any.
        if st.session_state.last_generated_chart_data:
            chart_msg = st.session_state.last_generated_chart_data
            st.subheader(chart_msg.get("title", "Generated Chart"))
            st.line_chart(chart_msg["data"], x=chart_msg.get("x_label"), y=chart_msg.get("y_label"), use_container_width=True)
            st.caption("Generated visualization. Always verify results.")
        # Image file produced during execution, if any.
        if st.session_state.last_generated_plot_file:
            file_msg = st.session_state.last_generated_plot_file
            st.image(file_msg["data"], caption=file_msg.get("caption", "Generated Plot"), use_container_width=True)
            # NOTE(review): label is None when "download_label" is absent —
            # confirm the entry always carries a download_label.
            st.download_button(
                label=file_msg.get("download_label"), data=file_msg.get("data"),
                file_name=file_msg.get("file_name", "download.png"), mime=file_msg.get("mime", "image/png")
            )
            st.caption("Generated visualization. Always verify results.")
        # pending_final_answer may be a dict ({"final_answer": ...}) or a plain string.
        final_answer_content = st.session_state.pending_final_answer["final_answer"] if isinstance(st.session_state.pending_final_answer, dict) else st.session_state.pending_final_answer
        st.markdown(final_answer_content)
        get_current_messages().append({"role": "assistant", "content": final_answer_content})
        if st.session_state.last_agent_action_log_entry:
            with st.expander("Show Agent's Final Thought Process"):
                st.code(st.session_state.last_agent_action_log_entry, language='ansi')
        # Reset state after completion
        st.session_state.pending_final_answer = None
        st.session_state.agent_continuation_needed = False
        st.session_state.current_agent_chain_user_prompt = None
        st.session_state.last_user_prompt = None
        st.session_state.last_agent_action_log_entry = None
        st.session_state.last_executed_code = None
        st.session_state.last_successful_output = None
        st.session_state.last_generated_chart_data = None
        st.session_state.last_generated_plot_file = None
        st.session_state.execution_count = 0
        st.session_state.hil_prompt_rendered = False # Ensure HIL is reset
        st.session_state.last_agent_turn_processed = True # Mark that an agent turn has been processed for feedback
# --- Feedback Section ---
# Rendered after any completed agent output (HIL turn or final answer) until
# the user clicks one of the feedback buttons.
if st.session_state.get("last_agent_turn_processed", False) and not st.session_state.get("feedback_given", False):
    with st.chat_message("assistant"):
        st.markdown("---")
        st.write("Was this response helpful?")
        fb_col1, fb_col2, fb_col3 = st.columns([1, 1, 5])
        with fb_col1:
            if st.button("π", key="helpful"):
                st.session_state.feedback_given = True
                # FIX: the icon was a syntax-broken multi-line string literal;
                # st.toast's icon must be a single emoji.
                st.toast("Thank you for your feedback!", icon="✅")
                st.session_state.last_agent_turn_processed = False  # reset for next turn
                # Metrics: capture time after feedback is given.
                if st.session_state.start_time:
                    end_time = time.time()
                    processing_time = end_time - st.session_state.start_time
                    st.info(f"π‘ **Total Processing Time:** {processing_time:.2f} seconds")
                    st.session_state.start_time = None  # reset timer
                st.rerun()
        with fb_col2:
            if st.button("π", key="unhelpful"):
                st.session_state.feedback_given = True
                st.toast("Thank you for your feedback! We'll use it to improve.", icon="π‘")
                st.session_state.last_agent_turn_processed = False  # reset for next turn
                # Metrics: capture time after feedback is given.
                if st.session_state.start_time:
                    end_time = time.time()
                    processing_time = end_time - st.session_state.start_time
                    st.info(f"π‘ **Total Processing Time:** {processing_time:.2f} seconds")
                    st.session_state.start_time = None  # reset timer
                st.rerun()
        with fb_col3:
            st.empty()  # push buttons to the left
# --- Agent Invocation Logic ---
# Runs when the agent must take its next step: either continuing a chain after
# an observation (HIL approval/cancel) or starting a new user prompt. Skipped
# while a pending action or final answer is awaiting the user.
if (st.session_state.agent_continuation_needed or st.session_state.get("last_user_prompt")) and \
        not st.session_state.pending_action and not st.session_state.pending_final_answer:
    agent_input = None
    if st.session_state.agent_continuation_needed:
        # Continue the ReAct chain: replay the prior scratchpad plus the new
        # observation produced by the approved/cancelled tool call.
        base_scratchpad = st.session_state.last_agent_action_log_entry or ""
        current_observation = st.session_state.last_processed_observation
        agent_input = {
            "input": st.session_state.current_agent_chain_user_prompt,
            "agent_scratchpad": base_scratchpad + f"\nObservation: {current_observation}"
        }
        st.session_state.agent_continuation_needed = False
        st.session_state.last_processed_observation = None
        st.session_state.last_agent_action_log_entry = None
        st.session_state.hil_prompt_rendered = False # Reset HIL flag as we are now processing the agent's next step
    elif st.session_state.get("last_user_prompt"):
        # This is the initial invocation for a new user prompt
        agent_input = {"input": st.session_state.last_user_prompt}
        st.session_state.current_agent_chain_user_prompt = st.session_state.last_user_prompt
        st.session_state.last_user_prompt = None # Clear the direct user prompt
        st.session_state.execution_count = 0 # Reset execution count for new chain
        st.session_state.feedback_given = False # Reset feedback for new chain
        st.session_state.hil_prompt_rendered = False # Reset HIL flag
        st.session_state.last_agent_turn_processed = False # Reset feedback prompt flag
        # Metrics: start timer when a new user prompt is submitted
        st.session_state.start_time = time.time()
    if agent_input:
        with st.chat_message("assistant"):
            thought_process_placeholder = st.empty()
            try:
                # Fresh log buffer for this invocation.
                st.session_state.callback_logs_buffer.seek(0)
                st.session_state.callback_logs_buffer.truncate(0)
                with st.spinner("Agent is thinking..."):
                    response = st.session_state.agent_executor.invoke(agent_input)
                agent_output = response["output"]
                captured_logs = st.session_state.callback_logs_buffer.getvalue()
                # Display thought process only if there are logs
                if captured_logs:
                    with thought_process_placeholder.expander("Show Agent's Thought Process"):
                        st.code(captured_logs, language='ansi')
                # Append the agent's actual output to the message history
                get_current_messages().append({"role": "assistant", "content": agent_output})
                # This turn finished without tool interception: clear chain state.
                st.session_state.last_user_prompt = None
                st.session_state.agent_continuation_needed = False
                st.session_state.current_agent_chain_user_prompt = None
                st.session_state.last_agent_action_log_entry = None # Clear logs after displaying/processing
                st.session_state.last_processed_observation = None # Clear observation after processing
                st.session_state.last_agent_turn_processed = True # Mark that an agent turn has been processed for feedback
                st.rerun()
            except InterceptToolCall:
                # Raised by the callback handler when the agent calls a tool;
                # the HIL prompt or final-answer section takes over from here.
                pass
            except Exception as e:
                # Ensure stdout is restored if an error occurs mid-execution.
                # NOTE(review): reaches into the handler's private attribute —
                # confirm StreamlitCodeExecutionCallbackHandler exposes
                # _original_stdout as part of its contract.
                if sys.stdout != st.session_state.callback_handler._original_stdout:
                    sys.stdout = st.session_state.callback_handler._original_stdout
                captured_logs = st.session_state.callback_logs_buffer.getvalue()
                error_message_str = str(e)
                display_error_message = f"An unexpected error occurred: {error_message_str}"
                st.error(display_error_message)
                # Display logs if available
                if captured_logs:
                    with thought_process_placeholder.expander("Agent's Process before error"):
                        st.code(captured_logs, language='ansi')
                # Add error message to chat history
                get_current_messages().append({"role": "assistant", "content": {"type": "content_text", "content_text": f"**Error:** {display_error_message}\n**Logs:**\n```ansi\n{captured_logs}\n```"}})
                # Reset for a clean slate and inform user
                st.warning("Agent session has encountered an error and will be reset. Please start a new chat.")
                start_new_chat()
                # NOTE(review): start_new_chat() ends in st.rerun(), which raises
                # to restart the script — the two lines below look unreachable.
                get_current_messages().append({"role": "assistant", "content": {"type": "content_text", "content_text": "*(Agent session has been reset due to an error. Please start a new chat.)*"}})
                st.rerun()
# --- Chat Input ---
# The input box is hidden while the agent awaits HIL approval, a mid-chain
# continuation, or final-answer rendering.
agent_is_busy = (
    st.session_state.pending_action
    or st.session_state.agent_continuation_needed
    or st.session_state.pending_final_answer
)
if not agent_is_busy:
    user_prompt = st.chat_input("Ask me to write or execute Python code...")
    if user_prompt:
        active_conv = st.session_state.conversations[st.session_state.current_conversation_id]
        # The first real user message names the conversation after its prompt.
        if len(active_conv["messages"]) == 1 and "New Conversation" in active_conv["display_name"]:
            active_conv["display_name"] = user_prompt[:50] + "..." if len(user_prompt) > 50 else user_prompt
        get_current_messages().append({"role": "user", "content": user_prompt})
        st.session_state.last_user_prompt = user_prompt  # consumed by the agent-invocation block
        # Fresh flags for the new turn.
        st.session_state.feedback_given = False
        st.session_state.hil_prompt_rendered = False
        st.session_state.last_agent_turn_processed = False
        st.session_state.start_time = time.time()  # start the performance timer
        st.rerun()