-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathrun_agent.py
More file actions
132 lines (111 loc) · 4.63 KB
/
run_agent.py
File metadata and controls
132 lines (111 loc) · 4.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
#!/usr/bin/env python3
# Tencent is pleased to support the open source community by making tRPC-Agent-Python available.
#
# Copyright (C) 2026 Tencent. All rights reserved.
#
# tRPC-Agent-Python is licensed under Apache-2.0.
"""Run the weather query agent demo"""
import asyncio
import os
from urllib.parse import quote_plus

from dotenv import load_dotenv
from trpc_agent_sdk.memory import MemoryServiceConfig
from trpc_agent_sdk.memory import SqlMemoryService
from trpc_agent_sdk.runners import Runner
from trpc_agent_sdk.sessions import InMemorySessionService
from trpc_agent_sdk.types import Content
from trpc_agent_sdk.types import Part
# Load environment variables (e.g. the MYSQL_* settings read below) from a .env file.
load_dotenv()


def create_memory_service(
    is_async: bool = False,
    *,
    ttl_seconds: int = 20,
    cleanup_interval_seconds: int = 10,
):
    """Create the MySQL-backed memory service used by the demo.

    Connection settings come from the environment, with local-development
    defaults: ``MYSQL_USER`` (``root``), ``MYSQL_PASSWORD`` (empty),
    ``MYSQL_HOST`` (``127.0.0.1``), ``MYSQL_PORT`` (``3306``) and
    ``MYSQL_DB`` (``trpc_agent_memory``). The URL uses the ``pymysql``
    driver, which must be installed (``pip install pymysql``).

    Memory TTL is enabled, so stored memories expire. The defaults (20 s TTL,
    10 s cleanup interval) are the values the demo's pauses between runs were
    written against.

    Args:
        is_async: Forwarded unchanged to ``SqlMemoryService``.
        ttl_seconds: TTL passed to ``MemoryServiceConfig.create_ttl_config``.
        cleanup_interval_seconds: Cleanup interval passed to
            ``create_ttl_config`` (presumably how often expired memories are
            purged; confirm against the SDK).

    Returns:
        A configured ``SqlMemoryService``.
    """
    # Handy SQL for preparing / inspecting the database by hand:
    #   DROP DATABASE IF EXISTS trpc_agent_memory;
    #   CREATE DATABASE trpc_agent_memory;
    #   USE trpc_agent_memory;
    #   SELECT * FROM mem_events;
    db_user = os.environ.get("MYSQL_USER", "root")
    db_password = os.environ.get("MYSQL_PASSWORD", "")
    db_host = os.environ.get("MYSQL_HOST", "127.0.0.1")
    db_port = os.environ.get("MYSQL_PORT", "3306")
    db_name = os.environ.get("MYSQL_DB", "trpc_agent_memory")

    # Credentials are percent-encoded: characters such as '@', ':', '/' or '%'
    # in a password would otherwise be parsed as URL delimiters and corrupt
    # the connection string. Plain alphanumeric values are left unchanged.
    # Example: mysql+pymysql://user:pass@host:3306/dbname?charset=utf8mb4
    db_url = (
        f"mysql+pymysql://{quote_plus(db_user)}:{quote_plus(db_password)}"
        f"@{db_host}:{db_port}/{db_name}?charset=utf8mb4"
    )
    # NOTE(review): the URL always names the synchronous `pymysql` driver, even
    # when is_async=True. An async SQLAlchemy engine normally needs an async
    # driver (e.g. "mysql+aiomysql://...") -- confirm how SqlMemoryService
    # handles this before relying on is_async=True.

    memory_service_config = MemoryServiceConfig(
        enabled=True,
        ttl=MemoryServiceConfig.create_ttl_config(
            enable=True,
            ttl_seconds=ttl_seconds,
            cleanup_interval_seconds=cleanup_interval_seconds,
        ),
    )
    return SqlMemoryService(
        memory_service_config=memory_service_config,
        is_async=is_async,
        db_url=db_url,
        # Presumably forwarded to SQLAlchemy's engine: ping pooled connections
        # before use and recycle them after an hour so that connections closed
        # server-side are not reused.
        pool_pre_ping=True,
        pool_recycle=3600,
    )
def _print_event(event):
    """Write the printable parts of a single runner event to stdout.

    Events without content or parts are ignored. Partial (streaming) events
    emit each text chunk inline without a newline. Non-partial events emit
    one line per tool call and per tool result. Thought parts are skipped
    there, because their text was already printed by the partial branch.
    """
    parts = event.content.parts if event.content else None
    if not parts:
        return

    if event.partial:
        for part in parts:
            if part.text:
                print(part.text, end="", flush=True)
        return

    for part in parts:
        if part.thought:
            continue
        if part.function_call:
            call = part.function_call
            print(f"\n🔧 [Invoke Tool: {call.name}({call.args})]")
        elif part.function_response:
            print(f"📊 [Tool Result: {part.function_response.response}]")
        # Final plain-text parts are intentionally not re-printed: their text
        # already arrived through the partial events above.


async def run_weather_agent():
    """Run one pass of the weather-agent demo over a fixed list of queries.

    A fresh in-memory session service and a SQL-backed memory service are
    created, and every demo query is sent for the same user in its own
    session (``sql_memory_session_<index>``). The agent's reply is streamed
    to stdout, followed by a separator line. Because no two queries share a
    session, any recall across queries presumably comes from the memory
    service rather than from session history.
    """
    app_name = "weather_agent_demo"
    # Deferred import of the agent definition (kept inside the function, as
    # in the original script).
    from agent.agent import root_agent

    runner = Runner(
        app_name=app_name,
        agent=root_agent,
        session_service=InMemorySessionService(),
        memory_service=create_memory_service(),
    )

    user_id = "sql_memory_user"
    session_prefix = "sql_memory_session"
    queries = (
        "Do you remember my name?",
        "Do you remember my favorite color?",
        "what is the weather like in paris?",
        "Hello! My name is Alice. What's your name?",
        "Do you remember my name?",
        "Hello! My favorite color is blue. What's your favorite color?",
        "Do you remember my favorite color?",
    )

    for turn, query in enumerate(queries):
        message = Content(parts=[Part.from_text(text=query)])
        print("🤖 Assistant: ", end="", flush=True)
        events = runner.run_async(
            user_id=user_id,
            session_id=f"{session_prefix}_{turn}",  # a new session per query
            new_message=message,
        )
        async for event in events:
            _print_event(event)
        print("\n" + "-" * 40)
async def main():
    """Run the weather-agent demo three times, pausing after each run.

    Each run is introduced by a banner and followed by a pause:

    * after the first run, 2 s -- short enough that memories written in that
      run should still be live for the second run;
    * after the second run, 30 s -- longer than the 20 s memory TTL configured
      in ``create_memory_service``, so the third run should find those
      memories expired (confirm this against the 10 s cleanup interval);
    * after the third run, 30 s -- waits for the memory to be expired before
      the process exits.
    """
    schedule = (
        ("First run", 2),
        ("Second run", 30),
        ("Third run", 30),
    )
    banner = "=" * 60
    for title, pause_seconds in schedule:
        print(banner)
        print(title)
        print(banner)
        await run_weather_agent()
        await asyncio.sleep(pause_seconds)


if __name__ == "__main__":
    asyncio.run(main())