-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path12-MemoryBetweenThreads.py
More file actions
330 lines (271 loc) Β· 12.1 KB
/
12-MemoryBetweenThreads.py
File metadata and controls
330 lines (271 loc) Β· 12.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
# LangGraph Cross-Thread Memory Tutorial
# For documentation - https://langchain-ai.github.io/langgraph/how-tos/cross_thread_persistence/
import time
import uuid
from typing import Annotated, Any, Dict, List, Optional

from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.runnables import RunnableConfig
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import MessagesState, START, END, StateGraph
from langgraph.store.base import BaseStore
from langgraph.store.memory import InMemoryStore
from typing_extensions import TypedDict

from config.config import get_llm, get_embeddings, EMBEDDING_DIMENSIONS
from config.secret_keys import OPENAI_API_KEY
from utils.graph_img_generation import save_and_show_graph
# Base system prompt sent with every LLM call. The Assistant node appends
# the memories retrieved for the current user to this text before invoking
# the model, which is how per-user context reaches the LLM.
SYSTEM_PROMPT = """
You're Jarvis, a helpful digital assistant with an excellent memory.
You remember details about users across different conversations.
You have two main users: Krishna and Shiv. Each has their own personality and preferences.
- When talking to Shiv, be respectful and use "ji" as an honorific.
- When talking to Krishna, be more casual and energetic.
When you learn something about a user, acknowledge it and use it in future conversations.
If you don't know something about a user, be honest about it.
"""
# Initialize OpenAI components
llm = get_llm()
# Embeddings model used by the store to index memories for semantic search.
embeddings = get_embeddings()
# In-memory store configured with an embedding index so store.search()
# can rank stored memories by semantic similarity to a query string.
try:
    memory_store = InMemoryStore(
        index={
            "embed": embeddings,
            "dims": EMBEDDING_DIMENSIONS,
        }
    )
    print("✅ Memory Store: Successfully initialized")
except Exception as e:
    # The whole tutorial depends on the store; abort rather than limp on.
    print(f"❌ Memory Store Error: {e}")
    raise SystemExit("Memory store initialization failed. Exiting...")
# Function to extract memory from user message
def extract_memory(message: str) -> Optional[str]:
    """Return the text following a 'remember:' marker, or None if absent.

    The marker is matched case-insensitively ('Remember:', 'REMEMBER:'
    all work), but the returned memory text keeps the user's original
    capitalization so names like 'Krishna' are stored as typed —
    previously the whole message was lowercased before extraction.

    Args:
        message: raw user message text.

    Returns:
        The stripped memory text after the first 'remember:' marker,
        or None when the message contains no marker.
    """
    lowered = message.lower()
    if "remember:" in lowered:
        # Locate the marker in the lowercased copy, then slice the
        # ORIGINAL string at that offset to preserve casing.
        start = lowered.index("remember:") + len("remember:")
        return message[start:].strip()
    return None
# Assistant node with memory capabilities
def Assistant(state: MessagesState, config: RunnableConfig, *, store: BaseStore):
    """Graph node: store any new user memory, retrieve relevant ones, reply.

    Reads ``user_id`` from the run config to pick the per-user namespace,
    persists a memory when the latest message carries a 'remember:' marker,
    then searches the store for up to 5 memories relevant to the message
    and prepends them to the system prompt before calling the LLM.
    """
    uid = config["configurable"]["user_id"]
    ns = ("memories", uid)
    latest_text = state["messages"][-1].content

    # Persist a new memory under a random unique key if one was requested.
    new_memory = extract_memory(latest_text)
    if new_memory:
        store.put(ns, str(uuid.uuid4()), {"data": new_memory})
        print(f"π Stored new memory for user {uid}: {new_memory}")

    # Semantic search over this user's namespace for relevant memories.
    hits = store.search(ns, query=latest_text, limit=5)
    if hits:
        bullets = "\n".join(f"- {hit.value['data']}" for hit in hits)
        user_info = "User information:\n" + bullets
    else:
        user_info = ""

    # Augment the base prompt with whatever we know about this user.
    prompt = f"{SYSTEM_PROMPT}\n\n{user_info}"
    reply = llm.invoke([SystemMessage(content=prompt)] + state["messages"])
    return {"messages": reply}
# Initialize LangGraph components
builder = StateGraph(MessagesState)
# Single-node graph: START -> Assistant -> END.
builder.add_node("Assistant", Assistant)
builder.add_edge(START, "Assistant")
builder.add_edge("Assistant", END)
# Compile with BOTH a checkpointer (per-thread conversation history) and a
# store (per-user memories). The store is what makes memories survive
# across different thread_ids for the same user_id.
try:
    memory_graph = builder.compile(
        checkpointer=MemorySaver(),  # In-memory checkpointer for conversation history
        store=memory_store,          # In-memory store for cross-thread memories
    )
    print("✅ Graph compiled with cross-thread memory")
except Exception as e:
    print(f"❌ Graph Compilation Error: {e}")
    raise SystemExit("Graph compilation failed. Exiting...")
# Save and show the graph image
save_and_show_graph(memory_graph, filename="12-MemoryBetweenThreads", show_image=False)
# Function to run automated tests
def run_automated_test(graph, users):
    """Run an automated test to demonstrate cross-thread memory capabilities.

    Phases:
      1. Store name / favorite food / nickname memories for each user.
      2. Open a NEW thread per user and ask for the stored facts
         (cross-thread persistence check).
      3. Ask each user whether the OTHER user's nickname is theirs
         (per-user memory isolation check).

    Args:
        graph: compiled LangGraph app with a checkpointer and a store.
        users: sequence of exactly two user-id strings
               (expected to contain "Krishna" and "Shiv").

    Returns:
        Dict mapping user -> last thread number used, so the caller can
        keep its own thread counters in sync after the test.
    """
    test_results = []
    thread_counters = {"Krishna": 1, "Shiv": 1}
    print("\n" + "=" * 50)
    print("π§ͺ STARTING AUTOMATED TEST")
    print("=" * 50)

    # Function to simulate conversation
    def simulate_conversation(user, message, config):
        """Send one message through the graph and echo the exchange."""
        thread_id = config["configurable"]["thread_id"]
        print(f"\nπ Thread: {thread_id}")
        print(f"π€ {user}: {message}")
        response = graph.invoke({"messages": [HumanMessage(content=message)]}, config=config)
        # Check if a memory was stored and display it
        if "remember:" in message.lower():
            memory_content = extract_memory(message)
            if memory_content:
                print(f"π Stored new memory for user {user}: {memory_content}")
        ai_response = response['messages'][-1].content
        print(f"π€ Jarvis: {ai_response}")
        time.sleep(1)  # Add short delay for readability
        return response, ai_response

    # Phase 1: Store memories for both users
    for user in users:
        print(f"\n{'-'*20} User Session: {user} {'-'*20}")
        # Create config for this user
        config = {
            "configurable": {
                "thread_id": f"{user}_thread_{thread_counters[user]}",
                "user_id": user
            }
        }
        # Store name
        message = f"Remember: My name is {user}"
        response, ai_response = simulate_conversation(user, message, config)
        test_results.append(f"β Stored name for {user}")
        # Store favorite food
        food = "butter" if user == "Krishna" else "thandai"
        message = f"Remember: My favorite food is {food}"
        response, ai_response = simulate_conversation(user, message, config)
        test_results.append(f"β Stored food preference for {user}")
        # Store nickname
        nickname = "Kanha" if user == "Krishna" else "Bholenath"
        message = f"Remember: My nickname is {nickname}"
        response, ai_response = simulate_conversation(user, message, config)
        test_results.append(f"β Stored nickname for {user}")

    # Phase 2: Test cross-thread memory (new thread, same user)
    print("\n" + "=" * 50)
    print("π TESTING CROSS-THREAD MEMORY")
    print("=" * 50)
    print("Starting new conversation threads, but memory should persist...")
    time.sleep(2)
    for user in users:
        # Increment thread counter to simulate new conversation
        thread_counters[user] += 1
        config = {
            "configurable": {
                "thread_id": f"{user}_thread_{thread_counters[user]}",
                "user_id": user
            }
        }
        print(f"\n{'-'*20} New Thread for {user} {'-'*20}")
        message = "What is my name, favorite food, and nickname?"
        response, ai_response = simulate_conversation(user, message, config)
        test_results.append(f"β Memory persists across threads for {user}")

    # Phase 3: Test user isolation (Krishna shouldn't know Shiv's details and vice versa)
    print("\n" + "=" * 50)
    print("π TESTING USER MEMORY ISOLATION")
    print("=" * 50)
    print("Verifying that memories are isolated between users...")
    time.sleep(2)
    for i, user in enumerate(users):
        other_user = users[1 - i]  # Get the other user
        config = {
            "configurable": {
                "thread_id": f"{user}_thread_{thread_counters[user]}",
                "user_id": user
            }
        }
        other_nickname = "Kanha" if other_user == "Krishna" else "Bholenath"
        print(f"\n{'-'*20} Testing {user}'s Memory Isolation {'-'*20}")
        message = f"Is my nickname {other_nickname}?"
        response, ai_response = simulate_conversation(user, message, config)
        # Heuristic pass/fail: a denial in the reply means isolation held.
        if ("no" in ai_response.lower() or "not" in ai_response.lower()):
            test_results.append(f"β Memory isolation confirmed for {user}")
        else:
            test_results.append(f"β Memory isolation failed for {user}")

    # Print test summary
    print("\n" + "=" * 50)
    print("π§ͺ TEST RESULTS SUMMARY")
    print("=" * 50)
    for result in test_results:
        print(result)
    print("\n" + "=" * 50)
    return thread_counters
def chat():
    """Interactive REPL demonstrating per-user, cross-thread memory.

    Commands (case-insensitive): 'new thread', 'remember: [info]',
    'switch', 'run test', 'help', 'exit'. Anything else is sent to the
    graph as a normal chat message under the current user's thread.
    """
    current_user = "Krishna"  # Default user starts with Krishna
    thread_counters = {"Krishna": 1, "Shiv": 1}  # Track thread numbers per user
    users = ["Krishna", "Shiv"]  # Our two users
    print("\n" + "=" * 50)
    print("π€ Jarvis AI Chat | Cross-Thread Memory Tutorial")
    print("=" * 50)
    print("- Type 'new thread' to start a new conversation")
    print("- Type 'remember: [information]' to store information")
    print("- Type 'switch' to toggle between Krishna and Shiv")
    print("- Type 'run test' to run an automated test case")
    print("- Type 'help' to show these commands")
    print("- Type 'exit' to end the conversation")
    while True:
        # Display current user and thread
        print("\n" + "=" * 50)
        print(f"π€ Current user: {current_user}")
        print(f"π Thread: {current_user}_thread_{thread_counters[current_user]}")
        # Get user input
        user_msg = input("You: ")
        # Normalize once for command matching; raw text is still what we send.
        command = user_msg.lower()
        # Handle commands
        if command == 'exit':
            print("Ending the conversation")
            break
        # Command to show help
        if command == 'help':
            print("Commands:")
            print("- 'new thread': Start a new conversation")
            print("- 'remember: [information]': Store information")
            print("- 'switch': Toggle between Krishna and Shiv")
            print("- 'run test': Run an automated test case")
            print("- 'exit': End the conversation")
            continue
        # Command to switch between users
        if command == 'switch':
            # Toggle between Krishna and Shiv
            current_user = "Shiv" if current_user == "Krishna" else "Krishna"
            print(f"✅ Switched to user: {current_user}")
            continue
        # Command to start a new thread
        if command == 'new thread':
            thread_counters[current_user] += 1
            print(f"Starting new thread: {thread_counters[current_user]} for {current_user}")
            continue
        # Command to run automated test
        if command == 'run test':
            updated_counters = run_automated_test(memory_graph, users)
            thread_counters = updated_counters  # Update thread counters after test
            continue
        # Regular message - prepare and invoke
        try:
            # Prepare the message and config
            human_msg = [HumanMessage(content=user_msg)]
            config = {
                "configurable": {
                    "thread_id": f"{current_user}_thread_{thread_counters[current_user]}",
                    "user_id": current_user
                }
            }
            # Check if memory is being stored
            memory_to_store = extract_memory(user_msg)
            if memory_to_store:
                print(f"π Storing new memory for user {current_user}: {memory_to_store}")
            # Invoke the graph
            response = memory_graph.invoke({"messages": human_msg}, config=config)
            print(f"π€ Jarvis: {response['messages'][-1].content}")
        except Exception as e:
            # Keep the REPL alive on model/store errors; report and re-prompt.
            print(f"β Error: {e}")
# Run the chat function only when executed as a script (not on import).
if __name__ == "__main__":
    chat()