diff --git a/.gitignore b/.gitignore index c757ddf..01a1c90 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,8 @@ examples/demo .sidekickvenv models/ +db/ +sdk_quick_tutorial.ipynb # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/examples/notebooks/sdk_quick_tutorial.ipynb b/examples/notebooks/sdk_quick_tutorial.ipynb index de888d8..0fb7fbe 100644 --- a/examples/notebooks/sdk_quick_tutorial.ipynb +++ b/examples/notebooks/sdk_quick_tutorial.ipynb @@ -2,18 +2,18 @@ "cells": [ { "cell_type": "code", - "execution_count": 20, + "execution_count": null, "id": "60080b7e-2e80-4154-aa35-87c13b6ab371", "metadata": {}, "outputs": [], "source": [ "# https://github.com/h2oai/sql-sidekick/releases\n", - "#!python3 -m pip install --force-reinstall sql_sidekick-0.2.2-py3-none-any.whl" + "# !python -m pip uninstall sql_sidekick-0.2.4-py3-none-any.whl -y" ] }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "id": "f480e37a-4327-48da-8c84-aba0ac1eef23", "metadata": {}, "outputs": [], @@ -27,7 +27,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "id": "c91887ce-c74a-432b-a3f9-120c8abc0003", "metadata": {}, "outputs": [], @@ -39,7 +39,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "9fc212c8-dc73-4330-a07f-7394fd198395", "metadata": {}, "outputs": [], @@ -51,18 +51,18 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "id": "6421a995-f846-4a1e-8292-374bd7500382", "metadata": {}, "outputs": [], "source": [ - "import pandas as pd\n", - "f = pd.read_csv(\"./sleep_health_and_lifestyle_dataset.csv\")" + "# import pandas as pd\n", + "# f = pd.read_csv(\"../demo/demo_data.csv\")" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "id": "eac0fa65-bb06-415a-aa87-1185789f878d", "metadata": {}, "outputs": [], @@ -71,76 +71,21 @@ "import os\n", "\n", "os.environ['OPENAI_API_KEY'] = \"\"\n", - "os.environ['H2OGPT_URL'] = 'http://38.128.233.247'\n", + "os.environ['H2OGPT_URL'] = \"\"\n", "os.environ['H2OGPT_API_TOKEN'] = \"\"\n", "# To get access to h2ogpte endpoint, reach out to cloud-feedback@h2o.ai\n", - "os.environ['H2OGPTE_URL'] = \"https://h2ogpte.genai.h2o.ai\" # e.g. https://<>.h2ogpte.h2o.ai\n", - "os.environ['H2OGPTE_API_TOKEN'] = \"\"" + "os.environ['H2OGPTE_URL'] = \"\" # e.g. https://<>.h2ogpte.h2o.ai\n", + "os.environ['H2OGPTE_API_TOKEN'] = \"\"\n", + "\n", + "os.environ['H2OGPT_BASE_URL'] = \"\"" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "id": "62e23b39-caa8-4e2f-bf12-678dd586f0df", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - " Information supplied:\n", - " querydb, localhost, sqlite, abc, 5432\n", - "Database already exists!\n", - "Table name: sleep_health_eda\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[32m2024-01-27 20:35:06.568\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.db_config\u001b[0m:\u001b[36m_extract_schema_info\u001b[0m:\u001b[36m162\u001b[0m - \u001b[34m\u001b[1mUsing schema information from: .//var/lib/tmp/sleep_health_eda_table_info.jsonl\u001b[0m\n", - "\u001b[32m2024-01-27 20:35:06.572\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.db_config\u001b[0m:\u001b[36mcreate_table\u001b[0m:\u001b[36m186\u001b[0m - \u001b[34m\u001b[1mSchema info used for creating table:\n", - " Person_ID NUMERIC,\n", - "Gender TEXT COLLATE NOCASE,\n", - "Age NUMERIC,\n", - "Occupation TEXT COLLATE NOCASE,\n", - "Sleep_Duration NUMERIC,\n", - "Quality_of_Sleep NUMERIC,\n", - "Physical_Activity_Level NUMERIC,\n", - "Stress_Level NUMERIC,\n", - "BMI_Category TEXT COLLATE NOCASE,\n", - "Blood_Pressure TEXT COLLATE NOCASE,\n", - "Heart_Rate NUMERIC,\n", - "Daily_Steps NUMERIC,\n", - "Sleep_Disorder TEXT COLLATE NOCASE\u001b[0m\n", - "\u001b[32m2024-01-27 20:35:06.578\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.db_config\u001b[0m:\u001b[36mcreate_table\u001b[0m:\u001b[36m198\u001b[0m - \u001b[1mTable created: sleep_health_eda\u001b[0m\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Checked table sleep_health_eda exists in the DB.\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[32m2024-01-27 20:35:06.586\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.db_config\u001b[0m:\u001b[36madd_samples\u001b[0m:\u001b[36m222\u001b[0m - \u001b[34m\u001b[1mAdding sample values to table: ./sleep_health_and_lifestyle_dataset.csv\u001b[0m\n", - "\u001b[32m2024-01-27 20:35:06.597\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.db_config\u001b[0m:\u001b[36madd_samples\u001b[0m:\u001b[36m228\u001b[0m - \u001b[34m\u001b[1mInserting chunk: 0\u001b[0m\n", - "\u001b[32m2024-01-27 20:35:06.755\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.db_config\u001b[0m:\u001b[36madd_samples\u001b[0m:\u001b[36m233\u001b[0m - \u001b[1mData inserted into table: sleep_health_eda\u001b[0m\n", - "\u001b[32m2024-01-27 20:35:06.759\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.db_config\u001b[0m:\u001b[36madd_samples\u001b[0m:\u001b[36m238\u001b[0m - \u001b[1mNumber of rows inserted: 2618\u001b[0m\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Created a Database querydb. Inserted sample values from ./sleep_health_and_lifestyle_dataset.csv into table sleep_health_eda, please ask questions!\n" - ] - } - ], + "outputs": [], "source": [ "HOST_NAME = \"localhost\"\n", "USER_NAME = \"sqlite\"\n", @@ -151,7 +96,7 @@ "\n", "# Given .csv file, auto-generate schema\n", "# Download dataset --> https://www.kaggle.com/datasets/uom190346a/sleep-health-and-lifestyle-dataset\n", - "data_path = \"./sleep_health_and_lifestyle_dataset.csv\"\n", + "data_path = \"examples/demo/demo_data.csv\"\n", "table_name = \"sleep_health_eda\"\n", "\n", "r, table_info_path = generate_schema(data_path=data_path, output_path=f\"{cache_path}/{table_name}_table_info.jsonl\")\n", @@ -171,29 +116,10 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "id": "80dec22c-362e-41a0-8f34-0690465542e6", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "['h2ogpt-sql-sqlcoder2-4bit',\n", - " 'h2ogpt-sql-sqlcoder-34b-alpha-4bit',\n", - " 'h2ogpt-sql-nsql-llama-2-7B-4bit',\n", - " 'h2ogpt-sql-sqlcoder2',\n", - " 'h2ogpt-sql-sqlcoder-34b-alpha',\n", - " 'h2ogpt-sql-nsql-llama-2-7B',\n", - " 'gpt-3.5-turbo',\n", - " 'gpt-4-8k',\n", - " 'gpt-4-1106-preview-128k']" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "# List supported models\n", "list_models()" @@ -201,7 +127,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "id": "2b3db015-1d9e-46b0-ad58-2f5aac0c6e4c", "metadata": {}, "outputs": [], @@ -225,7 +151,7 @@ " sample_queries_path=sample_qna_path,\n", " table_name=table_name,\n", " is_command=False,\n", - " model_name=\"h2ogpt-sql-sqlcoder2-4bit\", #Other default model option: h2ogpt-sql-sqlcoder-34b-alpha\n", + " model_name=\"gpt-4o\", #Other default model option: h2ogpt-sql-sqlcoder-34b-alpha\n", " is_regenerate=regenerate,\n", " is_regen_with_options=regenerate_with_options,\n", " execute_query=False,\n", @@ -236,311 +162,21 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "id": "523f1a88-eea8-414c-89b1-b7a2b3126535", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[32m2024-01-27 20:35:33.226\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m500\u001b[0m - \u001b[1mTable in use: ['sleep_health_eda']\u001b[0m\n", - "\u001b[32m2024-01-27 20:35:33.229\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m501\u001b[0m - \u001b[1mSQL dialect for generation: sqlite\u001b[0m\n", - "\u001b[32m2024-01-27 20:35:33.231\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m534\u001b[0m - \u001b[1mSetting context...\u001b[0m\n", - "\u001b[32m2024-01-27 20:35:33.232\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m535\u001b[0m - \u001b[1mQuestion: What is the average sleep duration for each gender?\u001b[0m\n", - "\u001b[32m2024-01-27 20:35:33.234\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m553\u001b[0m - \u001b[34m\u001b[1mTable info path: .//var/lib/tmp/sleep_health_eda_table_info.jsonl\u001b[0m\n", - "\u001b[32m2024-01-27 20:35:33.235\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mis_resource_low\u001b[0m:\u001b[36m355\u001b[0m - \u001b[1mNumber of GPUs: 1\u001b[0m\n", - "\u001b[32m2024-01-27 20:35:33.236\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mis_resource_low\u001b[0m:\u001b[36m359\u001b[0m - \u001b[34m\u001b[1mInformation on device: 0\u001b[0m\n", - "\u001b[32m2024-01-27 20:35:35.049\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mis_resource_low\u001b[0m:\u001b[36m362\u001b[0m - \u001b[1mTotal Memory: 23GB\u001b[0m\n", - "\u001b[32m2024-01-27 20:35:35.055\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mis_resource_low\u001b[0m:\u001b[36m363\u001b[0m - \u001b[1mFree GPU memory: 20GB\u001b[0m\n", - "\u001b[32m2024-01-27 20:35:35.057\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.query\u001b[0m:\u001b[36m__new__\u001b[0m:\u001b[36m83\u001b[0m - \u001b[1mLoading local model: h2ogpt-sql-sqlcoder2-4bit\u001b[0m\n", - "\u001b[32m2024-01-27 20:35:35.058\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mload_causal_lm_model\u001b[0m:\u001b[36m382\u001b[0m - \u001b[1mTotal GPUs: 1\u001b[0m\n", - "\u001b[32m2024-01-27 20:35:35.059\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36m_load_llm\u001b[0m:\u001b[36m390\u001b[0m - \u001b[1mFree GPU memory: 20GB\u001b[0m\n", - "\u001b[32m2024-01-27 20:35:35.060\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36m_load_llm\u001b[0m:\u001b[36m393\u001b[0m - \u001b[1mLoading model: defog/sqlcoder2 on device id: 0\u001b[0m\n", - "\u001b[32m2024-01-27 20:35:35.062\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36m_load_llm\u001b[0m:\u001b[36m394\u001b[0m - \u001b[34m\u001b[1mModel cache: .//models/\u001b[0m\n", - "\u001b[32m2024-01-27 20:35:35.063\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36m_load_llm\u001b[0m:\u001b[36m432\u001b[0m - \u001b[34m\u001b[1mLoading in 4 bit mode: True with device {'': 0}\u001b[0m\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "dec7435d27704941a96dcdb9951ed10e", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Loading checkpoint shards: 0%| | 0/4 [00:00> or ->\n", - "- Use prepared statements with parameterized queries to prevent SQL injection\n", - "\n", - "\n", - "### Input:\n", - "For SQL TABLE 'sleep_health_eda' with sample question/answer pairs,\n", - "(), create a valid SQL (dialect:sqlite) query to answer the following question:\n", - "What is the average sleep duration for each gender?.\n", - "This query will run on a database whose schema is represented in this string:\n", - "CREATE TABLE 'sleep_health_eda' (['Person_ID NUMERIC, Gender TEXT, Age NUMERIC, Occupation TEXT, Sleep_Duration NUMERIC, Quality_of_Sleep NUMERIC, Physical_Activity_Level NUMERIC, Stress_Level NUMERIC, BMI_Category TEXT, Blood_Pressure TEXT, Heart_Rate NUMERIC, Daily_Steps NUMERIC, Sleep_Disorder TEXT,']\n", - ");\n", - "\n", - "-- Table 'sleep_health_eda', , has sample values ({'sleep_health_eda': [\"'Gender' contains values similar to Male,Female.\", \"'Occupation' contains values similar to Lawyer,Teacher,Doctor,Software Engineer,Scientist,Sales Representative,Accountant,Salesperson,Manager,Nurse.\", \"'BMI_Category' contains values similar to Overweight,Normal,Obese,Normal Weight.\", \"'Sleep_Disorder' contains values similar to None,Sleep Apnea,Insomnia.\"]})\n", - "\n", - "### Response:\n", - "Based on your instructions, here is the SELECT SQL query I have generated to answer the question `What is the average sleep duration for each gender?`:\n", - "```SELECT\u001b[0m\n", - "\u001b[32m2024-01-27 20:36:22.461\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.query\u001b[0m:\u001b[36mgenerate_sql\u001b[0m:\u001b[36m590\u001b[0m - \u001b[1mContext length: 743\u001b[0m\n", - "Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.\n", - "\u001b[32m2024-01-27 20:36:30.891\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m608\u001b[0m - \u001b[1mInput query: What is the average sleep duration for each gender?\u001b[0m\n", - "\u001b[32m2024-01-27 20:36:30.895\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m609\u001b[0m - \u001b[1mGenerated response:\n", - "\n", - "SELECT \"gender\", AVG(\"sleep_duration\") AS \"average_sleep_duration\" FROM \"sleep_health_eda\" GROUP BY \"gender\" LIMIT 100\u001b[0m\n", - "\u001b[32m2024-01-27 20:36:30.905\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m635\u001b[0m - \u001b[1mAlternate responses:\n", - "\n", - "[]\u001b[0m\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Exiting...\n" - ] - } - ], + "outputs": [], "source": [ - "res = query(\"What is the average sleep duration for each gender?\", table_name=\"sleep_health_eda\", \n", + "res = query(\"What is the average sleep duration for each gender?\", table_name=\"sleep_health_eda\",\n", " table_info_path=table_info_path, sample_qna_path=None)" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "id": "b17e2b4f-8736-4d44-addc-db8d2be4ce51", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Question = **Generated response for question,**\n", - "What is the average sleep duration for each gender?\n", - "\n", - "----\n", - "Generated SQL = ``` sql\n", - "SELECT \"gender\",\n", - " AVG(\"sleep_duration\") AS \"average_sleep_duration\"\n", - "FROM \"sleep_health_eda\"\n", - "GROUP BY \"gender\"\n", - "LIMIT 100\n", - "```\n", - "\n", - "\n" - ] - } - ], + "outputs": [], "source": [ "print(f\"Question = {res[0][0]}\")\n", "print(\"----\")\n", @@ -549,170 +185,23 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": null, "id": "03c5dfc0-c6f0-4573-b36d-56dc7bcbe8bc", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[32m2024-01-27 20:39:50.016\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m500\u001b[0m - \u001b[1mTable in use: ['sleep_health_eda']\u001b[0m\n", - "\u001b[32m2024-01-27 20:39:50.017\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m501\u001b[0m - \u001b[1mSQL dialect for generation: sqlite\u001b[0m\n", - "\u001b[32m2024-01-27 20:39:50.018\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m534\u001b[0m - \u001b[1mSetting context...\u001b[0m\n", - "\u001b[32m2024-01-27 20:39:50.019\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m535\u001b[0m - \u001b[1mQuestion: What are the most common occupations among individuals in the dataset?\u001b[0m\n", - "\u001b[32m2024-01-27 20:39:50.020\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m553\u001b[0m - \u001b[34m\u001b[1mTable info path: .//var/lib/tmp/sleep_health_eda_table_info.jsonl\u001b[0m\n", - "\u001b[32m2024-01-27 20:39:50.021\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mis_resource_low\u001b[0m:\u001b[36m355\u001b[0m - \u001b[1mNumber of GPUs: 1\u001b[0m\n", - "\u001b[32m2024-01-27 20:39:50.022\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mis_resource_low\u001b[0m:\u001b[36m359\u001b[0m - \u001b[34m\u001b[1mInformation on device: 0\u001b[0m\n", - "\u001b[32m2024-01-27 20:39:50.023\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mis_resource_low\u001b[0m:\u001b[36m362\u001b[0m - \u001b[1mTotal Memory: 23GB\u001b[0m\n", - "\u001b[32m2024-01-27 20:39:50.024\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mis_resource_low\u001b[0m:\u001b[36m363\u001b[0m - \u001b[1mFree GPU memory: 8GB\u001b[0m\n", - "\u001b[32m2024-01-27 20:39:50.038\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36m_check_file_info\u001b[0m:\u001b[36m469\u001b[0m - \u001b[1mUsing information info from path .//var/lib/tmp/sleep_health_eda_table_info.jsonl\u001b[0m\n", - "\u001b[32m2024-01-27 20:39:50.039\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m603\u001b[0m - \u001b[1mComputing user request ...\u001b[0m\n", - "\u001b[32m2024-01-27 20:39:50.043\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36msemantic_search\u001b[0m:\u001b[36m155\u001b[0m - \u001b[34m\u001b[1mInput questions: # query: what are the most common occupations among individuals in the dataset?\u001b[0m\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "e4589b85d3514f2ea3c88a505f15698c", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Batches: 0%| | 0/1 [00:00> or ->\n", - "- Use prepared statements with parameterized queries to prevent SQL injection\n", - "\n", - "\n", - "### Input:\n", - "For SQL TABLE 'sleep_health_eda' with sample question/answer pairs,\n", - "(), create a valid SQL (dialect:sqlite) query to answer the following question:\n", - "What are the most common occupations among individuals in the dataset?.\n", - "This query will run on a database whose schema is represented in this string:\n", - "CREATE TABLE 'sleep_health_eda' (['Person_ID NUMERIC, Gender TEXT, Age NUMERIC, Occupation TEXT, Sleep_Duration NUMERIC, Quality_of_Sleep NUMERIC, Physical_Activity_Level NUMERIC, Stress_Level NUMERIC, BMI_Category TEXT, Blood_Pressure TEXT, Heart_Rate NUMERIC, Daily_Steps NUMERIC, Sleep_Disorder TEXT,']\n", - ");\n", - "\n", - "-- Table 'sleep_health_eda', , has sample values ({'sleep_health_eda': [\"'Gender' contains values similar to Male,Female.\", \"'Occupation' contains values similar to Lawyer,Teacher,Doctor,Software Engineer,Scientist,Sales Representative,Accountant,Salesperson,Manager,Nurse.\", \"'BMI_Category' contains values similar to Overweight,Normal,Obese,Normal Weight.\", \"'Sleep_Disorder' contains values similar to None,Sleep Apnea,Insomnia.\"]})\n", - "\n", - "### Response:\n", - "Based on your instructions, here is the SELECT SQL query I have generated to answer the question `What are the most common occupations among individuals in the dataset?`:\n", - "```SELECT\u001b[0m\n", - "\u001b[32m2024-01-27 20:39:50.161\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.query\u001b[0m:\u001b[36mgenerate_sql\u001b[0m:\u001b[36m590\u001b[0m - \u001b[1mContext length: 749\u001b[0m\n", - "\u001b[32m2024-01-27 20:39:50.162\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.query\u001b[0m:\u001b[36mgenerate_sql\u001b[0m:\u001b[36m651\u001b[0m - \u001b[1mRegeneration requested on previous query ...\u001b[0m\n", - "\u001b[32m2024-01-27 20:39:50.163\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.query\u001b[0m:\u001b[36mgenerate_sql\u001b[0m:\u001b[36m652\u001b[0m - \u001b[34m\u001b[1mSelected temperature for fast regeneration : 0.8\u001b[0m\n", - "Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.\n", - "\u001b[32m2024-01-27 20:39:52.499\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.query\u001b[0m:\u001b[36mgenerate_sql\u001b[0m:\u001b[36m675\u001b[0m - \u001b[34m\u001b[1mTemperature saved: 0.8\u001b[0m\n", - "\u001b[32m2024-01-27 20:39:52.512\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m608\u001b[0m - \u001b[1mInput query: What are the most common occupations among individuals in the dataset?\u001b[0m\n", - "\u001b[32m2024-01-27 20:39:52.513\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m609\u001b[0m - \u001b[1mGenerated response:\n", - "\n", - "SELECT \"occupation\", COUNT(1) AS \"COUNT\" FROM \"sleep_health_eda\" GROUP BY \"occupation\" ORDER BY \"COUNT\" DESC LIMIT 100\u001b[0m\n", - "\u001b[32m2024-01-27 20:39:52.516\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m635\u001b[0m - \u001b[1mAlternate responses:\n", - "\n", - "[]\u001b[0m\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Exiting...\n" - ] - } - ], + "outputs": [], "source": [ - "# On using re-generation flag we toggle the temperature values between 0 and 1 alternating between low \n", + "# On using re-generation flag we toggle the temperature values between 0 and 1 alternating between low\n", "# (focus/conservative generation and high values (random/creative generation)\n", - "res = query(\"What are the most common occupations among individuals in the dataset?\", table_name=\"sleep_health_eda\", \n", + "res = query(\"What are the most common occupations among individuals in the dataset?\", table_name=\"sleep_health_eda\",\n", " table_info_path=table_info_path, sample_qna_path=None, regenerate=True)" ] }, { "cell_type": "code", - "execution_count": 17, + "execution_count": null, "id": "cf2fc33d-ea21-4ab2-9019-329f5bc2051d", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Question = **Generated response for question,**\n", - "What are the most common occupations among individuals in the dataset?\n", - "\n", - "----\n", - "Generated SQL = ``` sql\n", - "SELECT \"occupation\",\n", - " COUNT(1) AS \"COUNT\"\n", - "FROM \"sleep_health_eda\"\n", - "GROUP BY \"occupation\"\n", - "ORDER BY \"COUNT\" DESC\n", - "LIMIT 100\n", - "```\n", - "\n", - "\n" - ] - } - ], + "outputs": [], "source": [ "print(f\"Question = {res[0][0]}\")\n", "print(\"----\")\n", @@ -721,389 +210,22 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": null, "id": "b47bef8d-c991-4581-a7fc-23a056911c3f", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[32m2024-01-27 20:39:56.595\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m500\u001b[0m - \u001b[1mTable in use: ['sleep_health_eda']\u001b[0m\n", - "\u001b[32m2024-01-27 20:39:56.597\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m501\u001b[0m - \u001b[1mSQL dialect for generation: sqlite\u001b[0m\n", - "\u001b[32m2024-01-27 20:39:56.598\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m534\u001b[0m - \u001b[1mSetting context...\u001b[0m\n", - "\u001b[32m2024-01-27 20:39:56.599\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m535\u001b[0m - \u001b[1mQuestion: What is the average sleep duration for each gender?\u001b[0m\n", - "\u001b[32m2024-01-27 20:39:56.601\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m553\u001b[0m - \u001b[34m\u001b[1mTable info path: .//var/lib/tmp/sleep_health_eda_table_info.jsonl\u001b[0m\n", - "\u001b[32m2024-01-27 20:39:56.602\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mis_resource_low\u001b[0m:\u001b[36m355\u001b[0m - \u001b[1mNumber of GPUs: 1\u001b[0m\n", - "\u001b[32m2024-01-27 20:39:56.604\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mis_resource_low\u001b[0m:\u001b[36m359\u001b[0m - \u001b[34m\u001b[1mInformation on device: 0\u001b[0m\n", - "\u001b[32m2024-01-27 20:39:56.605\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mis_resource_low\u001b[0m:\u001b[36m362\u001b[0m - \u001b[1mTotal Memory: 23GB\u001b[0m\n", - "\u001b[32m2024-01-27 20:39:56.607\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mis_resource_low\u001b[0m:\u001b[36m363\u001b[0m - \u001b[1mFree GPU memory: 8GB\u001b[0m\n", - "\u001b[32m2024-01-27 20:39:56.629\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36m_check_file_info\u001b[0m:\u001b[36m469\u001b[0m - \u001b[1mUsing information info from path .//var/lib/tmp/sleep_health_eda_table_info.jsonl\u001b[0m\n", - "\u001b[32m2024-01-27 20:39:56.631\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m603\u001b[0m - \u001b[1mComputing user request ...\u001b[0m\n", - "\u001b[32m2024-01-27 20:39:56.640\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36msemantic_search\u001b[0m:\u001b[36m155\u001b[0m - \u001b[34m\u001b[1mInput questions: # query: what is the average sleep duration for each gender?\u001b[0m\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "5ed8c4c529c54952a30bfb4d99b7ec95", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Batches: 0%| | 0/1 [00:00> or ->\n", - "- Use prepared statements with parameterized queries to prevent SQL injection\n", - "\n", - "\n", - "### Input:\n", - "For SQL TABLE 'sleep_health_eda' with sample question/answer pairs,\n", - "(), create a valid SQL (dialect:sqlite) query to answer the following question:\n", - "What is the average sleep duration for each gender?.\n", - "This query will run on a database whose schema is represented in this string:\n", - "CREATE TABLE 'sleep_health_eda' (['Person_ID NUMERIC, Gender TEXT, Age NUMERIC, Occupation TEXT, Sleep_Duration NUMERIC, Quality_of_Sleep NUMERIC, Physical_Activity_Level NUMERIC, Stress_Level NUMERIC, BMI_Category TEXT, Blood_Pressure TEXT, Heart_Rate NUMERIC, Daily_Steps NUMERIC, Sleep_Disorder TEXT,']\n", - ");\n", - "\n", - "-- Table 'sleep_health_eda', , has sample values ({'sleep_health_eda': [\"'Gender' contains values similar to Male,Female.\", \"'Occupation' contains values similar to Lawyer,Teacher,Doctor,Software Engineer,Scientist,Sales Representative,Accountant,Salesperson,Manager,Nurse.\", \"'BMI_Category' contains values similar to Overweight,Normal,Obese,Normal Weight.\", \"'Sleep_Disorder' contains values similar to None,Sleep Apnea,Insomnia.\"]})\n", - "\n", - "### Response:\n", - "Based on your instructions, here is the SELECT SQL query I have generated to answer the question `What is the average sleep duration for each gender?`:\n", - "```SELECT\u001b[0m\n", - "\u001b[32m2024-01-27 20:39:56.785\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.query\u001b[0m:\u001b[36mgenerate_sql\u001b[0m:\u001b[36m590\u001b[0m - \u001b[1mContext length: 743\u001b[0m\n", - "\u001b[32m2024-01-27 20:39:56.787\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.query\u001b[0m:\u001b[36mgenerate_sql\u001b[0m:\u001b[36m677\u001b[0m - \u001b[1mRegeneration with options requested on previous query ...\u001b[0m\n", - "\u001b[32m2024-01-27 20:39:56.788\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.query\u001b[0m:\u001b[36mgenerate_sql\u001b[0m:\u001b[36m692\u001b[0m - \u001b[34m\u001b[1mSelected temperature for diverse beam search: 0.4\u001b[0m\n", - "/home/pramit/.jupyterven/lib/python3.8/site-packages/transformers/generation/configuration_utils.py:392: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.4` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n", - " warnings.warn(\n", - "/home/pramit/.jupyterven/lib/python3.8/site-packages/transformers/generation/configuration_utils.py:407: UserWarning: `do_sample` is set to `False`. However, `top_k` is set to `5` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_k`.\n", - " warnings.warn(\n", - "Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.\n", - "\u001b[32m2024-01-27 20:42:01.434\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.query\u001b[0m:\u001b[36mgenerate_sql\u001b[0m:\u001b[36m727\u001b[0m - \u001b[1mGenerated options:\n", - "\u001b[0m\n", - "\u001b[32m2024-01-27 20:42:01.440\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.query\u001b[0m:\u001b[36mgenerate_sql\u001b[0m:\u001b[36m748\u001b[0m - \u001b[1mAlternate options:\n", - "Option 1: (_probability_: 0.381034255027771)\n", - "``` sql\n", - "SELECT gender,\n", - " AVG(sleep_duration) AS average_sleep_duration\n", - "FROM sleep_health_eda\n", - "GROUP BY gender\n", - "ORDER BY average_sleep_duration DESC NULLS LAST\n", - "LIMIT 100;\n", - "```\n", - "\n", - "\n", - "\u001b[0m\n", - "\u001b[32m2024-01-27 20:42:01.444\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.query\u001b[0m:\u001b[36mgenerate_sql\u001b[0m:\u001b[36m748\u001b[0m - \u001b[1mAlternate options:\n", - "Option 2: (_probability_: 0.2624567449092865)\n", - "``` sql\n", - "SELECT AVG(sleep_duration) AS average_sleep_duration,\n", - " gender\n", - "FROM sleep_health_eda\n", - "GROUP BY gender\n", - "LIMIT 100;\n", - "```\n", - "\n", - "\n", - "\u001b[0m\n", - "\u001b[32m2024-01-27 20:42:01.446\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.query\u001b[0m:\u001b[36mgenerate_sql\u001b[0m:\u001b[36m748\u001b[0m - \u001b[1mAlternate options:\n", - "Option 3: (_probability_: 0.22498156130313873)\n", - "``` sql\n", - "SELECT Gender,\n", - " AVG(Sleep_Duration) AS average_duration\n", - "FROM sleep_health_eda\n", - "GROUP BY Gender\n", - "ORDER BY average_duration DESC NULLS LAST\n", - "LIMIT 100;\n", - "```\n", - "\n", - "\n", - "\u001b[0m\n", - "\u001b[32m2024-01-27 20:42:01.451\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.query\u001b[0m:\u001b[36mgenerate_sql\u001b[0m:\u001b[36m748\u001b[0m - \u001b[1mAlternate options:\n", - "Option 4: (_probability_: 0.13085876405239105)\n", - "``` sql\n", - "SELECT 'Gender',\n", - " AVG('Sleep_Duration') AS average_sleep_duration\n", - "FROM'sleep_health_eda'\n", - "GROUP BY 'Gender'\n", - "ORDER BY average_sleep_duration DESC NULLS LAST\n", - "LIMIT 100;\n", - "```\n", - "\n", - "\n", - "\u001b[0m\n", - "\u001b[32m2024-01-27 20:42:01.474\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.query\u001b[0m:\u001b[36mgenerate_sql\u001b[0m:\u001b[36m748\u001b[0m - \u001b[1mAlternate options:\n", - "Option 5: (_probability_: 0.0006686743581667542)\n", - "``` sql\n", - "SELECT sleep_health_eda.gender,\n", - " AVG(sleep_health_eda.sleep_duration) AS average_sleep_duration\n", - "FROM sleep_health_eda\n", - "GROUP BY sleep_health_eda.gender\n", - "ORDER BY average_sleep_duration DESC NULLS LAST;\n", - "\n", - "SELECT sleep_health_eda.gender,\n", - " AVG(sleep_health_eda.sleep_duration) AS average_sleep_duration\n", - "FROM sleep_health_eda\n", - "GROUP BY sleep_health_eda.gender\n", - "ORDER BY average_sleep_duration DESC NULLS LAST;\n", - "\n", - "SELECT sleep_health_eda.gender,\n", - " AVG(sleep_health_eda.sleep_duration) AS average_sleep_duration\n", - "FROM sleep_health_eda\n", - "GROUP BY sleep_health_eda.gender\n", - "ORDER BY average_sleep_duration DESC NULLS LAST;\n", - "\n", - "SELECT sleep_health_eda.gender,\n", - " AVG(sleep_health_eda.sleep_duration) AS average_sleep_duration\n", - "FROM sleep_health_eda\n", - "GROUP BY sleep_health_eda.gender\n", - "ORDER BY average_sleep_duration DESC NULLS LAST;\n", - "\n", - "SELECT sleep_health_eda.gender,\n", - " AVG(sleep_health_eda.sleep_duration) AS average_sleep_duration\n", - "FROM sleep_health_eda\n", - "GROUP BY sleep_health_eda.gender\n", - "ORDER BY average_sleep_duration DESC NULLS LAST;\n", - "\n", - "SELECT sleep_health_eda.gender,\n", - " AVG(sleep_health_eda.sleep_duration) AS average_sleep_duration\n", - "FROM sleep_health_eda\n", - "GROUP BY sleep_health_eda.gender\n", - "ORDER BY average_sleep_duration DESC NULLS LAST;\n", - "\n", - "SELECT sleep_health_eda.gender,\n", - " AVG(sleep_health_eda.sleep_duration) AS average_sleep_duration\n", - "FROM sleep_health_eda\n", - "GROUP BY sleep_health_eda.gender\n", - "ORDER BY average_sleep_duration DESC NULLS LAST;\n", - "\n", - "SELECT sleep_health_eda.gender,\n", - " AVG(sleep_health_eda.sleep_duration) AS average_sleep_duration\n", - "FROM sleep_health_eda\n", - "GROUP BY sleep_health_eda.gender\n", - "ORDER BY average_sleep_duration DESC NULLS LAST;\n", - "\n", - "SELECT sleep_health_eda.gender,\n", - " AVG(sleep_health_eda.sleep_duration) AS average_sleep_duration\n", - "FROM sleep_health_eda\n", - "GROUP BY sleep_health_eda.gender\n", - "ORDER BY average_sleep_duration DESC NULLS LAST;\n", - "\n", - "SELECT sleep_health_eda.gender,\n", - " AVG(sleep_health_eda\n", - "LIMIT 100;\n", - "```\n", - "\n", - "\n", - "\u001b[0m\n", - "\u001b[32m2024-01-27 20:42:01.488\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m608\u001b[0m - \u001b[1mInput query: What is the average sleep duration for each gender?\u001b[0m\n", - "\u001b[32m2024-01-27 20:42:01.489\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m609\u001b[0m - \u001b[1mGenerated response:\n", - "\n", - "SELECT \"gender\", AVG(\"sleep_duration\") AS \"average_sleep_duration\" FROM \"sleep_health_eda\" GROUP BY \"gender\" ORDER BY \"average_sleep_duration\" DESC LIMIT 100\u001b[0m\n", - "\u001b[32m2024-01-27 20:42:01.492\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m635\u001b[0m - \u001b[1mAlternate responses:\n", - "\n", - "['Option 1: (_probability_: 0.381034255027771)\\n``` sql\\nSELECT gender,\\n AVG(sleep_duration) AS average_sleep_duration\\nFROM sleep_health_eda\\nGROUP BY gender\\nORDER BY average_sleep_duration DESC NULLS LAST\\nLIMIT 100;\\n```\\n\\n\\n', 'Option 2: (_probability_: 0.2624567449092865)\\n``` sql\\nSELECT AVG(sleep_duration) AS average_sleep_duration,\\n gender\\nFROM sleep_health_eda\\nGROUP BY gender\\nLIMIT 100;\\n```\\n\\n\\n', 'Option 3: (_probability_: 0.22498156130313873)\\n``` sql\\nSELECT Gender,\\n AVG(Sleep_Duration) AS average_duration\\nFROM sleep_health_eda\\nGROUP BY Gender\\nORDER BY average_duration DESC NULLS LAST\\nLIMIT 100;\\n```\\n\\n\\n', \"Option 4: (_probability_: 0.13085876405239105)\\n``` sql\\nSELECT 'Gender',\\n AVG('Sleep_Duration') AS average_sleep_duration\\nFROM'sleep_health_eda'\\nGROUP BY 'Gender'\\nORDER BY average_sleep_duration DESC NULLS LAST\\nLIMIT 100;\\n```\\n\\n\\n\", 'Option 5: (_probability_: 0.0006686743581667542)\\n``` sql\\nSELECT sleep_health_eda.gender,\\n AVG(sleep_health_eda.sleep_duration) AS average_sleep_duration\\nFROM sleep_health_eda\\nGROUP BY sleep_health_eda.gender\\nORDER BY average_sleep_duration DESC NULLS LAST;\\n\\nSELECT sleep_health_eda.gender,\\n AVG(sleep_health_eda.sleep_duration) AS average_sleep_duration\\nFROM sleep_health_eda\\nGROUP BY sleep_health_eda.gender\\nORDER BY average_sleep_duration DESC NULLS LAST;\\n\\nSELECT sleep_health_eda.gender,\\n AVG(sleep_health_eda.sleep_duration) AS average_sleep_duration\\nFROM sleep_health_eda\\nGROUP BY sleep_health_eda.gender\\nORDER BY average_sleep_duration DESC NULLS LAST;\\n\\nSELECT sleep_health_eda.gender,\\n AVG(sleep_health_eda.sleep_duration) AS average_sleep_duration\\nFROM sleep_health_eda\\nGROUP BY sleep_health_eda.gender\\nORDER BY average_sleep_duration DESC NULLS LAST;\\n\\nSELECT sleep_health_eda.gender,\\n AVG(sleep_health_eda.sleep_duration) AS average_sleep_duration\\nFROM sleep_health_eda\\nGROUP BY sleep_health_eda.gender\\nORDER BY average_sleep_duration DESC NULLS LAST;\\n\\nSELECT sleep_health_eda.gender,\\n AVG(sleep_health_eda.sleep_duration) AS average_sleep_duration\\nFROM sleep_health_eda\\nGROUP BY sleep_health_eda.gender\\nORDER BY average_sleep_duration DESC NULLS LAST;\\n\\nSELECT sleep_health_eda.gender,\\n AVG(sleep_health_eda.sleep_duration) AS average_sleep_duration\\nFROM sleep_health_eda\\nGROUP BY sleep_health_eda.gender\\nORDER BY average_sleep_duration DESC NULLS LAST;\\n\\nSELECT sleep_health_eda.gender,\\n AVG(sleep_health_eda.sleep_duration) AS average_sleep_duration\\nFROM sleep_health_eda\\nGROUP BY sleep_health_eda.gender\\nORDER BY average_sleep_duration DESC NULLS LAST;\\n\\nSELECT sleep_health_eda.gender,\\n AVG(sleep_health_eda.sleep_duration) AS average_sleep_duration\\nFROM sleep_health_eda\\nGROUP BY sleep_health_eda.gender\\nORDER BY average_sleep_duration DESC NULLS LAST;\\n\\nSELECT sleep_health_eda.gender,\\n AVG(sleep_health_eda\\nLIMIT 100;\\n```\\n\\n\\n']\u001b[0m\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Exiting...\n" - ] - } - ], + "outputs": [], "source": [ "# Alternate options\n", - "res = query(\"What is the average sleep duration for each gender?\", table_name=\"sleep_health_eda\", \n", + "res = query(\"What is the average sleep duration for each gender?\", table_name=\"sleep_health_eda\",\n", " table_info_path=table_info_path, sample_qna_path=None, regenerate_with_options=True)" ] }, { "cell_type": "code", - "execution_count": 19, + "execution_count": null, "id": "e415c0b9-466e-4417-ac1e-493914a83c36", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Question = **Generated response for question,**\n", - "What is the average sleep duration for each gender?\n", - "\n", - "----Options----\n", - "Option 1: (_probability_: 0.381034255027771)\n", - "``` sql\n", - "SELECT gender,\n", - " AVG(sleep_duration) AS average_sleep_duration\n", - "FROM sleep_health_eda\n", - "GROUP BY gender\n", - "ORDER BY average_sleep_duration DESC NULLS LAST\n", - "LIMIT 100;\n", - "```\n", - "\n", - "\n", - "\n", - "Option 2: (_probability_: 0.2624567449092865)\n", - "``` sql\n", - "SELECT AVG(sleep_duration) AS average_sleep_duration,\n", - " gender\n", - "FROM sleep_health_eda\n", - "GROUP BY gender\n", - "LIMIT 100;\n", - "```\n", - "\n", - "\n", - "\n", - "Option 3: (_probability_: 0.22498156130313873)\n", - "``` sql\n", - "SELECT Gender,\n", - " AVG(Sleep_Duration) AS average_duration\n", - "FROM sleep_health_eda\n", - "GROUP BY Gender\n", - "ORDER BY average_duration DESC NULLS LAST\n", - "LIMIT 100;\n", - "```\n", - "\n", - "\n", - "\n", - "Option 4: (_probability_: 0.13085876405239105)\n", - "``` sql\n", - "SELECT 'Gender',\n", - " AVG('Sleep_Duration') AS average_sleep_duration\n", - "FROM'sleep_health_eda'\n", - "GROUP BY 'Gender'\n", - "ORDER BY average_sleep_duration DESC NULLS LAST\n", - "LIMIT 100;\n", - "```\n", - "\n", - "\n", - "\n", - "Option 5: (_probability_: 0.0006686743581667542)\n", - "``` sql\n", - "SELECT sleep_health_eda.gender,\n", - " AVG(sleep_health_eda.sleep_duration) AS average_sleep_duration\n", - "FROM sleep_health_eda\n", - "GROUP BY sleep_health_eda.gender\n", - "ORDER BY average_sleep_duration DESC NULLS LAST;\n", - "\n", - "SELECT sleep_health_eda.gender,\n", - " AVG(sleep_health_eda.sleep_duration) AS average_sleep_duration\n", - "FROM sleep_health_eda\n", - "GROUP BY sleep_health_eda.gender\n", - "ORDER BY average_sleep_duration DESC NULLS LAST;\n", - "\n", - "SELECT sleep_health_eda.gender,\n", - " AVG(sleep_health_eda.sleep_duration) AS average_sleep_duration\n", - "FROM sleep_health_eda\n", - "GROUP BY sleep_health_eda.gender\n", - "ORDER BY average_sleep_duration DESC NULLS LAST;\n", - "\n", - "SELECT sleep_health_eda.gender,\n", - " AVG(sleep_health_eda.sleep_duration) AS average_sleep_duration\n", - "FROM sleep_health_eda\n", - "GROUP BY sleep_health_eda.gender\n", - "ORDER BY average_sleep_duration DESC NULLS LAST;\n", - "\n", - "SELECT sleep_health_eda.gender,\n", - " AVG(sleep_health_eda.sleep_duration) AS average_sleep_duration\n", - "FROM sleep_health_eda\n", - "GROUP BY sleep_health_eda.gender\n", - "ORDER BY average_sleep_duration DESC NULLS LAST;\n", - "\n", - "SELECT sleep_health_eda.gender,\n", - " AVG(sleep_health_eda.sleep_duration) AS average_sleep_duration\n", - "FROM sleep_health_eda\n", - "GROUP BY sleep_health_eda.gender\n", - "ORDER BY average_sleep_duration DESC NULLS LAST;\n", - "\n", - "SELECT sleep_health_eda.gender,\n", - " AVG(sleep_health_eda.sleep_duration) AS average_sleep_duration\n", - "FROM sleep_health_eda\n", - "GROUP BY sleep_health_eda.gender\n", - "ORDER BY average_sleep_duration DESC NULLS LAST;\n", - "\n", - "SELECT sleep_health_eda.gender,\n", - " AVG(sleep_health_eda.sleep_duration) AS average_sleep_duration\n", - "FROM sleep_health_eda\n", - "GROUP BY sleep_health_eda.gender\n", - "ORDER BY average_sleep_duration DESC NULLS LAST;\n", - "\n", - "SELECT sleep_health_eda.gender,\n", - " AVG(sleep_health_eda.sleep_duration) AS average_sleep_duration\n", - "FROM sleep_health_eda\n", - "GROUP BY sleep_health_eda.gender\n", - "ORDER BY average_sleep_duration DESC NULLS LAST;\n", - "\n", - "SELECT sleep_health_eda.gender,\n", - " AVG(sleep_health_eda\n", - "LIMIT 100;\n", - "```\n", - "\n", - "\n", - "\n" - ] - } - ], + "outputs": [], "source": [ "print(f\"Question = {res[0][0]}\")\n", "print(\"----Options----\")\n", @@ -1136,7 +258,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.16" + "version": "3.10.0" } }, "nbformat": 4, diff --git a/sidekick/query.py b/sidekick/query.py index 94974ce..d150154 100644 --- a/sidekick/query.py +++ b/sidekick/query.py @@ -223,6 +223,8 @@ def _query_tasks(self, question_str, data_info, sample_queries, table_name: list MODEL_CHOICE_MAP = MODEL_CHOICE_MAP_EVAL_MODE m_name = MODEL_CHOICE_MAP.get(self.model_name) + if m_name is None: + raise ValueError(f"Invalid model name {self.model_name}. Available models: {MODEL_CHOICE_MAP.keys()}") completion = self.openai_client.chat.completions.create( model=m_name, @@ -233,6 +235,8 @@ def _query_tasks(self, question_str, data_info, sample_queries, table_name: list ) res = completion.choices[0].message.content return res + except ValueError as ve: + raise ve except Exception as se: _, ex_value, _ = sys.exc_info() res = ex_value.statement if ex_value.statement else None @@ -245,7 +249,12 @@ def self_correction(self, error_msg, input_query, remote_url, client_key): user_prompt = DEBUGGING_PROMPT["user_prompt"].format(ex_traceback=error_msg, qry_txt=input_query).strip() _response = [] _res = input_query - self_correction_model = os.getenv("SELF_CORRECTION_MODEL", "h2oai/h2ogpt-4096-llama2-70b-chat") + if os.getenv("OPENAI_API_KEY", None): + default_correction_model = "gpt-4" + else: + default_correction_model = "h2ogpt-4096-llama2-70b-chat" + self_correction_model = os.getenv("SELF_CORRECTION_MODEL", default_correction_model) + logger.info(f"Using LLM model: {self_correction_model} for self-correction") if "h2ogpt-" in self_correction_model: if remote_url and client_key and remote_url != "" and client_key != "": from h2ogpte import H2OGPTE @@ -310,6 +319,7 @@ def generate_response( res = response.metadata["sql_query"] return res except Exception as se: + logger.info(f"Error in generating response: {se}") # Take the SQL and make an attempt for correction _, ex_value, ex_traceback = sys.exc_info() qry_txt = ex_value.statement @@ -337,6 +347,7 @@ def generate_response( res = qry_txt return res except Exception as se: + logger.info(f"Error in generate_response, self correction: {se}") # Another exception occurred, return the original SQL res = qry_txt return res diff --git a/sidekick/utils.py b/sidekick/utils.py index 148133b..04c6a89 100644 --- a/sidekick/utils.py +++ b/sidekick/utils.py @@ -25,7 +25,10 @@ REMOTE_LLMS = ["h2ogpt-sql-sqlcoder-34b-alpha", "h2ogpt-sql-sqlcoder2", "h2ogpt-sql-nsql-llama-2-7B", - "h2ogpt-sql-sqlcoder-7b-2", "gpt-3.5-turbo", "gpt-4-8k", "gpt-4-1106-preview-128k"] + "h2ogpt-sql-sqlcoder-7b-2", "gpt-3.5-turbo", "gpt-4-8k", "gpt-4-1106-preview-128k", + "gpt-4o", "gpt-4-turbo", "gpt-4-turbo-2024-04-09", "gpt-4-turbo-preview", "gpt-4-0125-preview", + "gpt-4-vision-preview", "gpt-4-1106-vision-preview", "gpt-4", "gpt-4-0613", "gpt-4-32k", + "gpt-4-32k-0613"] # clone of models from https://huggingface.co/models # suffix `h2ogpt-sql-` is added to avoid conflict with the original models (we haven't done any changes to the original models yet) @@ -38,9 +41,19 @@ "h2ogpt-sql-sqlcoder-34b-alpha": "defog/sqlcoder-34b-alpha", "h2ogpt-sql-nsql-llama-2-7B": "NumbersStation/nsql-llama-2-7B", "gpt-3.5-turbo": "gpt-3.5-turbo-1106", - "gpt-4-8k": "gpt-4", - "gpt-4-1106-preview-128k": "gpt-4-1106-preview" - + "gpt-4-8k": "gpt-4", # leaving this for backward compatibility + "gpt-4-1106-preview-128k": "gpt-4-1106-preview", # leaving this for backward compatibility + "gpt-4o": "gpt-4o", + "gpt-4-turbo": "gpt-4-turbo", + "gpt-4-turbo-2024-04-09": "gpt-4-turbo-2024-04-09", + "gpt-4-turbo-preview": "gpt-4-turbo-preview", + "gpt-4-0125-preview": "gpt-4-0125-preview", + "gpt-4-vision-preview": "gpt-4o", # legacy to be deprecated + "gpt-4-1106-vision-preview": "gpt-4o", # legacy to be deprecated + "gpt-4": "gpt-4", + "gpt-4-0613": "gpt-4-0613", + "gpt-4-32k": "gpt-4-turbo", # legacy to be deprecated + "gpt-4-32k-0613": "gpt-4-turbo", # legacy to be deprecated } MODEL_CHOICE_MAP_DEFAULT = { @@ -573,7 +586,12 @@ def check_vulnerability(input_query: str): _user_prompt = GUARDRAIL_PROMPT["user_prompt"].format(query_txt=input_query, schema=output_schema).strip() temp_result = None try: - llm_scanner = os.getenv("VULNERABILITY_SCANNER", "h2oai/h2ogpt-4096-llama2-70b-chat") + if os.getenv("OPENAI_API_KEY", None): + default_scanner_model = "gpt-4" + else: + default_scanner_model = "h2ogpt-4096-llama2-70b-chat" + llm_scanner = os.getenv("VULNERABILITY_SCANNER", default_scanner_model) + logger.info(f"Using LLM model: {llm_scanner} for vulnerability scan") if "h2ogpt-" in llm_scanner and h2ogpte_client_url !='' and h2ogpte_client_url and h2ogpte_client_key != '' and h2ogpte_client_key: from h2ogpte import H2OGPTE client = H2OGPTE(address=h2ogpte_client_url, api_key=h2ogpte_client_key) @@ -636,7 +654,12 @@ def generate_suggestions(remote_url, client_key:str, column_names: list, n_qs: i _user_prompt = RECOMMENDATION_PROMPT.format(data_schema=column_info, n_questions=n_qs ) - recommender_model = os.getenv("RECOMMENDATION_MODEL", "h2oai/h2ogpt-4096-llama2-70b-chat") + if os.getenv("OPENAI_API_KEY", None): + default_recommendation_model = "gpt-4" + else: + default_recommendation_model = "h2ogpt-4096-llama2-70b-chat" + recommender_model = os.getenv("RECOMMENDATION_MODEL", default_recommendation_model) + logger.info(f"Using LLM model: {recommender_model} for recommendation") if "h2ogpt-" in recommender_model: try: client = H2OGPTE(address=remote_url, api_key=client_key)