-
Notifications
You must be signed in to change notification settings - Fork 178
Expand file tree
/
Copy pathpydanticai_basic.py
More file actions
50 lines (41 loc) · 1.93 KB
/
pydanticai_basic.py
File metadata and controls
50 lines (41 loc) · 1.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import asyncio
import os
from azure.identity.aio import DefaultAzureCredential, get_bearer_token_provider
from dotenv import load_dotenv
from openai import AsyncOpenAI
from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIChatModel
from pydantic_ai.providers.openai import OpenAIProvider
# Setup the OpenAI client to use either Azure OpenAI or GitHub Models
load_dotenv(override=True)
API_HOST = os.getenv("API_HOST", "github")
async_credential = None
if API_HOST == "azure":
async_credential = DefaultAzureCredential()
token_provider = get_bearer_token_provider(async_credential, "https://cognitiveservices.azure.com/.default")
client = AsyncOpenAI(
base_url=os.environ["AZURE_OPENAI_ENDPOINT"] + "/openai/v1",
api_key=token_provider,
)
model = OpenAIChatModel(os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT"], provider=OpenAIProvider(openai_client=client))
elif API_HOST == "github":
client = AsyncOpenAI(api_key=os.environ["GITHUB_TOKEN"], base_url="https://models.inference.ai.azure.com")
model = OpenAIChatModel(os.getenv("GITHUB_MODEL", "gpt-4o"), provider=OpenAIProvider(openai_client=client))
elif API_HOST == "ollama":
client = AsyncOpenAI(base_url=os.environ.get("OLLAMA_ENDPOINT", "http://localhost:11434/v1"), api_key="none")
model = OpenAIChatModel(os.environ["OLLAMA_MODEL"], provider=OpenAIProvider(openai_client=client))
else:
client = AsyncOpenAI(api_key=os.environ["OPENAI_API_KEY"])
model = OpenAIChatModel(os.environ.get("OPENAI_MODEL", "gpt-4o"), provider=OpenAIProvider(openai_client=client))
agent: Agent[None, str] = Agent(
model,
system_prompt="You are a Spanish tutor. Help the user learn Spanish. ONLY respond in Spanish.",
output_type=str,
)
async def main():
result = await agent.run("oh hey how are you?")
print(result.output)
if async_credential:
await async_credential.close()
if __name__ == "__main__":
asyncio.run(main())