import os
import azure.identity
from dotenv import load_dotenv
from langchain_core.messages import HumanMessage
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode

# Set up the client to use either Azure OpenAI or GitHub Models
load_dotenv(override=True)
API_HOST = os.getenv("API_HOST", "github")
if API_HOST == "azure":
    token_provider = azure.identity.get_bearer_token_provider(
        azure.identity.DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
    )
    model = ChatOpenAI(
        model=os.environ.get("AZURE_OPENAI_CHAT_DEPLOYMENT"),
        base_url=os.environ["AZURE_OPENAI_ENDPOINT"] + "/openai/v1",
        api_key=token_provider,
    )
elif API_HOST == "github":
    model = ChatOpenAI(
        model=os.getenv("GITHUB_MODEL", "gpt-4o"),
        base_url="https://models.inference.ai.azure.com",
        api_key=os.environ["GITHUB_TOKEN"],
    )
elif API_HOST == "ollama":
    model = ChatOpenAI(
        model=os.environ["OLLAMA_MODEL"],
        base_url=os.environ.get("OLLAMA_ENDPOINT", "http://localhost:11434/v1"),
        api_key="none",
    )


@tool
def play_song_on_spotify(song: str):
    """Play a song on Spotify"""
    # Call the Spotify API ...
    return f"Successfully played {song} on Spotify!"


@tool
def play_song_on_apple(song: str):
    """Play a song on Apple Music"""
    # Call the Apple Music API ...
    return f"Successfully played {song} on Apple Music!"


tools = [play_song_on_apple, play_song_on_spotify]
tool_node = ToolNode(tools)
model = model.bind_tools(tools, parallel_tool_calls=False)
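
if os.getenv("DEBUG_TOOLS"):
    # Hedged sketch, not in the original sample: DEBUG_TOOLS is a hypothetical
    # env flag used here to exercise one @tool-decorated StructuredTool directly
    # before the graph runs; the song title is illustrative only.
    print(play_song_on_spotify.invoke({"song": "Shake It Off"}))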

# Define the function that determines whether to continue or not
def should_continue(state):
    messages = state["messages"]
    last_message = messages[-1]
    # If there is no tool call, then we finish
    if not last_message.tool_calls:
        return "end"
    # Otherwise if there is, we continue
    else:
        return "continue"


# Define the function that calls the model
def call_model(state):
    messages = state["messages"]
    response = model.invoke(messages)
    # We return a list, because this will get added to the existing list
    return {"messages": [response]}

# Define a new graph
workflow = StateGraph(MessagesState)

# Define the two nodes we will cycle between
workflow.add_node("agent", call_model)
workflow.add_node("action", tool_node)

# Set the entrypoint as `agent`
# This means that this node is the first one called
workflow.add_edge(START, "agent")

# We now add a conditional edge
workflow.add_conditional_edges(
    # First, we define the start node. We use `agent`.
    # This means these are the edges taken after the `agent` node is called.
    "agent",
    # Next, we pass in the function that will determine which node is called next.
    should_continue,
    # Finally we pass in a mapping.
    # The keys are strings, and the values are other nodes.
    # END is a special node marking that the graph should finish.
    # What will happen is we will call `should_continue`, and the output of that
    # will be matched against the keys in this mapping.
    # Based on which one it matches, that node will then be called.
    {
        # If `continue`, then we call the tool node.
        "continue": "action",
        # Otherwise we finish.
        "end": END,
    },
)

# We now add a normal edge from `action` to `agent`.
# This means that after `action` is called, the `agent` node is called next.
workflow.add_edge("action", "agent")

# Set up memory
memory = MemorySaver()

# Finally, we compile it!
# This compiles it into a LangChain Runnable,
# meaning you can use it as you would any other runnable.
# Passing `checkpointer=memory` means state is saved per thread between runs.
app = workflow.compile(checkpointer=memory)
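# Optional hedged aside, not in the original sample: the compiled graph can be
# rendered as Mermaid text for inspection.
# print(app.get_graph().draw_mermaid())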
config = {"configurable": {"thread_id": "1"}}
input_message = HumanMessage(content="Can you play Taylor Swift's most popular song?")
for event in app.stream({"messages": [input_message]}, config, stream_mode="values"):
event["messages"][-1].pretty_print()