-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsqlagent.py
More file actions
100 lines (88 loc) · 3.54 KB
/
Copy pathsqlagent.py
File metadata and controls
100 lines (88 loc) · 3.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# Suppress all warnings emitted while this script runs (ours and libraries').
import warnings


def warn(*args, **kwargs):
    """Swallow a warning; installed as a drop-in replacement for ``warnings.warn``.

    Accepts any positional/keyword arguments so every ``warnings.warn(...)``
    call site stays valid, and always returns ``None``.
    """
    return None


# Two layers: calls that go through the ``warnings.warn`` attribute hit the
# no-op stub above, and the global "ignore" filter covers warnings raised by
# other paths. NOTE(review): modules that did ``from warnings import warn``
# before this point keep the original function.
warnings.warn = warn
warnings.filterwarnings('ignore')
import os
import sys
# Optionally load variables from a local .env file into the process
# environment. python-dotenv is treated as an optional dependency: if it is not
# installed, only the real environment is used.
try:
    from dotenv import load_dotenv
except ImportError:
    pass
else:
    try:
        load_dotenv()
    except Exception as exc:  # best effort: a broken .env must not abort startup
        # Report instead of silently swallowing, so a bad .env is diagnosable.
        print("Warning: could not load .env file:", exc)
from ibm_watsonx_ai.foundation_models import ModelInference
from langchain.agents import AgentType
from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams
from ibm_watsonx_ai.foundation_models.utils.enums import ModelTypes
from ibm_watson_machine_learning.foundation_models.extensions.langchain import WatsonxLLM
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_community.agent_toolkits import create_sql_agent
import argparse
# Set up argument parser
parser = argparse.ArgumentParser()
parser.add_argument("--prompt", type=str, help="The prompt to send to the SQL agent")
parser.add_argument("--db-uri", type=str, help="Optional full SQLAlchemy database URI to override env vars")
args = parser.parse_args()
# Model configuration (can be overridden via env vars).


def _env_number(name, default, cast):
    """Return environment variable *name* converted with *cast* (``int``/``float``).

    Falls back to the string *default* when the variable is unset. A value that
    *cast* cannot parse (e.g. ``TEMPERATURE=hot``) ends the script with a
    readable message instead of a bare ``ValueError`` traceback.
    """
    raw = os.getenv(name, default)
    try:
        return cast(raw)
    except ValueError:
        print(f"Error: environment variable {name}={raw!r} is not a valid {cast.__name__}.")
        sys.exit(1)


# MODEL_ID is required. There is no sensible default model, and a placeholder
# id would presumably only be rejected later, during model initialization,
# with a far less specific error.
model_id = os.getenv("MODEL_ID")
if not model_id:
    print("Error: MODEL_ID is not set. Set it to the watsonx.ai model id to use in your environment or .env file.")
    print("Exiting.")
    sys.exit(1)

# Text-generation parameters; each falls back to the default string shown.
parameters = {
    GenParams.MAX_NEW_TOKENS: _env_number("MAX_NEW_TOKENS", "1024", int),
    GenParams.TEMPERATURE: _env_number("TEMPERATURE", "0.2", float),
    GenParams.TOP_P: _env_number("TOP_P", "0.95", float),
    GenParams.REPETITION_PENALTY: _env_number("REPETITION_PENALTY", "1.2", float),
}
# Watsonx credentials are taken from the environment only -- never hardcode
# them in source.
watsonx_url = os.getenv("WATSONX_URL")
watsonx_apikey = os.getenv("WATSONX_APIKEY")
project_id = os.getenv("WATSONX_PROJECT_ID")  # None when unset; passed through as-is

# Keep only the credential fields that are actually set (non-empty), in
# url -> apikey order.
credentials = {
    field: value
    for field, value in (("url", watsonx_url), ("apikey", watsonx_apikey))
    if value
}

# Abort early when neither the URL nor the API key was provided.
if not credentials:
    for message in (
        "Error: Watsonx credentials not found. Set WATSONX_URL and/or WATSONX_APIKEY in your environment or .env file.",
        "See .env.example for variable names. Exiting.",
    ):
        print(message)
    sys.exit(1)
# Create the watsonx.ai inference client and wrap it as a LangChain LLM.
# Any exception raised by either constructor is reported and ends the script.
try:
    _inference_kwargs = {
        "model_id": model_id,
        "params": parameters,
        "credentials": credentials,
        "project_id": project_id,  # may be None when WATSONX_PROJECT_ID is unset
    }
    model = ModelInference(**_inference_kwargs)
    # NOTE(review): ModelInference is imported from ibm_watsonx_ai, but
    # WatsonxLLM is imported from the legacy ibm_watson_machine_learning
    # package. Confirm the installed versions accept this mix, or use the
    # LangChain wrapper from the same SDK.
    llm = WatsonxLLM(model=model)
except Exception as init_error:
    print("Failed to initialize Watsonx model:", init_error)
    print("Ensure your Watsonx credentials and SDK are configured correctly.")
    sys.exit(1)
# Database connection: prefer --db-uri, then MYSQL_URI, then component env vars.
def build_mysql_uri_from_env():
    """Return a SQLAlchemy MySQL URI built from environment variables, or None.

    A non-empty ``MYSQL_URI`` wins and is returned unchanged. Otherwise the URI
    is assembled for the ``mysql+mysqlconnector`` driver from MYSQL_USERNAME,
    MYSQL_PASSWORD, MYSQL_HOST, MYSQL_DATABASE and MYSQL_PORT (default
    ``"3306"``).

    Returns:
        The URI string, or None when MYSQL_URI is unset and any of username,
        password, host or database is missing or empty.

    The username and password are percent-encoded, so characters such as
    ``@``, ``:``, ``/`` or ``#`` cannot break the URI structure.
    """
    from urllib.parse import quote  # local import: only this helper needs it

    mysql_uri = os.getenv("MYSQL_URI")
    if mysql_uri:
        return mysql_uri
    user = os.getenv("MYSQL_USERNAME")
    password = os.getenv("MYSQL_PASSWORD")
    host = os.getenv("MYSQL_HOST")
    port = os.getenv("MYSQL_PORT", "3306")
    dbname = os.getenv("MYSQL_DATABASE")
    if not all([user, password, host, dbname]):
        return None
    # safe="" also escapes "/", which quote() leaves alone by default; %-escapes
    # (rather than quote_plus's "+" for spaces) decode correctly either way.
    user_part = quote(user, safe="")
    password_part = quote(password, safe="")
    return f"mysql+mysqlconnector://{user_part}:{password_part}@{host}:{port}/{dbname}"
# Resolve the database URI: an explicit --db-uri flag takes precedence over
# anything derived from the environment.
mysql_uri = args.db_uri if args.db_uri else build_mysql_uri_from_env()

if not mysql_uri:
    print("Error: No database URI provided. Set MYSQL_URI or MYSQL_USERNAME, MYSQL_PASSWORD, MYSQL_HOST, and MYSQL_DATABASE.")
    print("See .env.example for variable names. Exiting.")
    sys.exit(1)

# Open the database through LangChain's SQLDatabase wrapper; any failure is fatal.
try:
    db = SQLDatabase.from_uri(mysql_uri)
except Exception as connect_error:
    print("Failed to connect to database:", connect_error)
    sys.exit(1)
# Build a SQL agent (ReAct-style, per AgentType.ZERO_SHOT_REACT_DESCRIPTION)
# over the connected database using the watsonx LLM.
agent_executor = create_sql_agent(
    llm=llm,
    db=db,
    verbose=True,
    handle_parsing_errors=True,
    agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
)

# Run the agent on the prompt from the command line and print its final
# answer explicitly, rather than discarding the invoke() result.
if args.prompt:
    result = agent_executor.invoke(args.prompt)
    print(result["output"])
else:
    print("Please provide a prompt using --prompt argument")