-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path5-SummaryInputGraph.py
More file actions
129 lines (85 loc) · 3.57 KB
/
5-SummaryInputGraph.py
File metadata and controls
129 lines (85 loc) · 3.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
from langchain_openai import ChatOpenAI
from langgraph.graph import MessagesState
from langchain_core.messages import SystemMessage, HumanMessage, RemoveMessage
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph, START, END
from utils.graph_img_generation import save_and_show_graph
from config.secret_keys import OPENAI_API_KEY
from config.config import get_llm
# defining the LLM
# Project-configured chat model shared by all nodes below.
# NOTE(review): presumably a ChatOpenAI instance (see the ChatOpenAI /
# OPENAI_API_KEY imports above) — confirm against config.config.get_llm.
llm = get_llm()
class State(MessagesState):
    # Graph state: inherits the `messages` list from MessagesState and adds
    # a running summary of older, pruned-away messages (absent/empty until
    # the first summarization pass — nodes read it via state.get("summary", "")).
    summary: str
# Node: answer the latest user turn, prepending the running summary (if any)
# as extra context so the model "remembers" pruned messages.
def call_model(state: State):
    # Summary of earlier, already-deleted messages; empty before the first
    # summarization pass.
    prior_summary = state.get("summary", "")

    if prior_summary:
        # Inject the summary as a system message ahead of the live history.
        context = SystemMessage(content=f"Summary of conversation earlier: {prior_summary}")
        prompt = [context] + state["messages"]
    else:
        prompt = state["messages"]

    # Returning a single message; MessagesState appends it to the history.
    return {"messages": llm.invoke(prompt)}
# Node: fold the current history into a new (or extended) summary, then
# prune the state down to only the two most recent messages.
def summarize_conversation(state: State):
    existing = state.get("summary", "")

    # Build the summarization instruction, extending any prior summary.
    if existing:
        instruction = (
            f"This is summary of the conversation to date: {existing}\n\n"
            "Extend the summary by taking into account the new messages above:"
        )
    else:
        # First summarization pass — no prior summary to extend.
        instruction = "Create a summary of the conversation above:"

    # Ask the model to summarize the full history plus the instruction.
    reply = llm.invoke(state["messages"] + [HumanMessage(content=instruction)])

    # Emit RemoveMessage markers for everything except the last two messages;
    # MessagesState's reducer deletes them from the stored history.
    pruned = [RemoveMessage(id=msg.id) for msg in state["messages"][:-2]]
    return {"summary": reply.content, "messages": pruned}
# Conditional edge: decide whether the turn ends or the history gets summarized.
def should_continue(state: State):
    """Return the next node to execute."""
    # Summarize once the history has grown past six messages; otherwise stop.
    return "summarize_conversation" if len(state["messages"]) > 6 else "__end__"
# Define a new graph
workflow = StateGraph(State)
workflow.add_node("conversation", call_model)
# Node name defaults to the function's name: "summarize_conversation".
workflow.add_node(summarize_conversation)
# Set the entrypoint as conversation
workflow.add_edge(START, "conversation")
# After each model reply, route via should_continue: either summarize the
# history or end the turn (mapping keys match should_continue's return values).
workflow.add_conditional_edges("conversation", should_continue,
    {
        "summarize_conversation": "summarize_conversation",
        "__end__": END
    })
workflow.add_edge("summarize_conversation", END)
# Compile
# In-memory checkpointer so state (messages + summary) persists across
# invocations that share the same thread_id.
memory = MemorySaver()
summarize_conversation_graph = workflow.compile(checkpointer=memory)
# Use the utility function to save and optionally show the graph
save_and_show_graph(summarize_conversation_graph, filename="5-SummaryInputGraph", show_image=False)
# Session (thread) identifier: reusing "1" resumes the checkpointed state
# stored by the MemorySaver across invocations.
config = {"configurable": {"thread_id": "1"}}

# Interactive loop: feed each user turn through the graph until 'exit'.
while True:
    user_text = input("Enter your message (or type 'exit' to quit): ")
    if user_text.lower() == 'exit':
        print("Exiting the session...")
        break
    result = summarize_conversation_graph.invoke(
        {"messages": [HumanMessage(content=user_text)]}, config
    )
    # Show only the assistant's latest reply ([-1:] is safe if history is empty).
    for reply in result['messages'][-1:]:
        reply.pretty_print()