From 6c8c47a35ae8de102bc38f56181ecef5bdcd4d52 Mon Sep 17 00:00:00 2001 From: Abhishek Singh Date: Wed, 11 Feb 2026 23:49:01 -0800 Subject: [PATCH 1/6] Added support for connection pools --- src/select_ai/__init__.py | 2 + src/select_ai/agent/tool.py | 20 +-- src/select_ai/db.py | 263 +++++++++++++++++++++++++++++++----- src/select_ai/version.py | 2 +- tests/conftest.py | 22 ++- 5 files changed, 261 insertions(+), 48 deletions(-) diff --git a/src/select_ai/__init__.py b/src/select_ai/__init__.py index 82cf4f3..56073cd 100644 --- a/src/select_ai/__init__.py +++ b/src/select_ai/__init__.py @@ -25,6 +25,8 @@ async_disconnect, async_is_connected, connect, + create_pool, + create_pool_async, cursor, disconnect, is_connected, diff --git a/src/select_ai/agent/tool.py b/src/select_ai/agent/tool.py index 0918500..92e15e7 100644 --- a/src/select_ai/agent/tool.py +++ b/src/select_ai/agent/tool.py @@ -73,7 +73,7 @@ class ToolParams(SelectAIDataClass): :param str sender: Sender used for EMAIL notification - :param str slack_channel: Slack channel to use + :param str channel: Slack channel to use :param str smtp_host: SMTP host to use for EMAIL notification @@ -87,7 +87,7 @@ class ToolParams(SelectAIDataClass): profile_name: Optional[str] = None recipient: Optional[str] = None sender: Optional[str] = None - slack_channel: Optional[str] = None + channel: Optional[str] = None smtp_host: Optional[str] = None def __post_init__(self): @@ -119,7 +119,7 @@ def keys(cls): "profile_name", "recipient", "sender", - "slack_channel", + "channel", "smtp_host", } @@ -145,7 +145,7 @@ class NotificationToolParams(ToolParams): @dataclass class SlackNotificationToolParams(NotificationToolParams): - _REQUIRED_FIELDS = ["credential_name", "slack_channel"] + _REQUIRED_FIELDS = ["credential_name", "channel"] notification_type: NotificationType = NotificationType.SLACK @@ -523,7 +523,7 @@ def create_slack_notification_tool( cls, tool_name: str, credential_name: str, - slack_channel: str, + channel: str, description: Optional[str] = None, replace: bool = False, ) -> "Tool": @@ -532,7 +532,7 @@ def create_slack_notification_tool( :param str tool_name: The name of the Slack notification tool :param str credential_name: The name of the Slack credential - :param str slack_channel: The name of the Slack channel + :param str channel: The name of the Slack channel :param str description: The description of the Slack notification tool :param bool replace: Whether to replace existing tool. Default value is False @@ -540,7 +540,7 @@ def create_slack_notification_tool( """ slack_notification_tool_params = SlackNotificationToolParams( credential_name=credential_name, - slack_channel=slack_channel, + channel=channel, ) return cls.create_built_in_tool( tool_name=tool_name, @@ -968,7 +968,7 @@ async def create_slack_notification_tool( cls, tool_name: str, credential_name: str, - slack_channel: str, + channel: str, description: Optional[str] = None, replace: bool = False, ) -> "AsyncTool": @@ -977,7 +977,7 @@ async def create_slack_notification_tool( :param str tool_name: The name of the Slack notification tool :param str credential_name: The name of the Slack credential - :param str slack_channel: The name of the Slack channel + :param str channel: The name of the Slack channel :param str description: The description of the Slack notification tool :param bool replace: Whether to replace existing tool. Default value is False @@ -985,7 +985,7 @@ async def create_slack_notification_tool( """ slack_notification_tool_params = SlackNotificationToolParams( credential_name=credential_name, - slack_channel=slack_channel, + channel=channel, ) return await cls.create_built_in_tool( tool_name=tool_name, diff --git a/src/select_ai/db.py b/src/select_ai/db.py index 19da695..51f7845 100644 --- a/src/select_ai/db.py +++ b/src/select_ai/db.py @@ -8,17 +8,23 @@ import contextlib import os from threading import get_ident -from typing import Dict, Hashable +from typing import Any, Dict, Generator, Hashable, Optional import oracledb +from oracledb import Connection from select_ai.errors import DatabaseNotConnectedError __conn__: Dict[Hashable, oracledb.Connection] = {} __async_conn__: Dict[Hashable, oracledb.AsyncConnection] = {} +__pool__: Dict[Hashable, oracledb.ConnectionPool] = {} +__async_pool__: Dict[Hashable, oracledb.AsyncConnectionPool] = {} + __all__ = [ "connect", + "create_pool", + "create_pool_async", "async_connect", "is_connected", "async_is_connected", @@ -50,6 +56,54 @@ def connect(user: str, password: str, dsn: str, *args, **kwargs): _set_connection(conn=conn) +def create_pool( + user: str, + password: str, + dsn: str, + min_size: Optional[int] = 1, + max_size: Optional[int] = 1, + increment: Optional[int] = 1, + *args, + **kwargs, +): + pool = oracledb.create_pool( + user=user, + password=password, + dsn=dsn, + min=min_size, + max=max_size, + increment=increment, + connection_id_prefix="python-select-ai", + *args, + **kwargs, + ) + _set_connection_pool(pool=pool) + + +def create_pool_async( + user: str, + password: str, + dsn: str, + min_size: Optional[int] = 1, + max_size: Optional[int] = 1, + increment: Optional[int] = 1, + *args, + **kwargs, +): + async_pool = oracledb.create_pool_async( + user=user, + password=password, + dsn=dsn, + min=min_size, + max=max_size, + increment=increment, + connection_id_prefix="async-python-select-ai", + *args, + **kwargs, + ) + _set_connection_pool(async_pool=async_pool) + + async def async_connect(user: str, password: str, dsn: str, *args, **kwargs): """Creates an oracledb.AsyncConnection object and saves it global dictionary __async_conn__ @@ -115,22 +169,36 @@ def _set_connection( __async_conn__[key] = async_conn +def _set_connection_pool( + pool: Optional[oracledb.ConnectionPool] = None, + async_pool: Optional[oracledb.AsyncConnectionPool] = None, +): + """Set existing connection pool for select_ai Python API to reuse + + :param pool: python-oracledb ConnectionPool object + :param async_pool: python-oracledb AsyncConnectionPool object + + :return: None + """ + key = (os.getpid(), get_ident()) + if pool: + global __pool__ + __pool__[key] = pool + if async_pool: + global __async_pool__ + __async_pool__[key] = async_pool + + def get_connection() -> oracledb.Connection: """Returns the connection object if connection is healthy""" - if not is_connected(): - raise DatabaseNotConnectedError() - global __conn__ - key = (os.getpid(), get_ident()) - return __conn__[key] + with ConnectionManager().get_connection() as conn: + return conn async def async_get_connection() -> oracledb.AsyncConnection: """Returns the AsyncConnection object if connection is healthy""" - if not await async_is_connected(): - raise DatabaseNotConnectedError() - global __async_conn__ - key = (os.getpid(), get_ident()) - return __async_conn__[key] + async with ConnectionManager().get_connection() as conn: + return conn @contextlib.contextmanager @@ -147,11 +215,12 @@ def cursor(): of whether an exception occurred """ - cr = get_connection().cursor() - try: - yield cr - finally: - cr.close() + with ConnectionManager().get_connection() as conn: + cr = conn.cursor() + try: + yield cr + finally: + cr.close() @contextlib.asynccontextmanager @@ -165,27 +234,153 @@ async def async_cursor(): await cr.execute() :return: """ - conn = await async_get_connection() - cr = conn.cursor() - try: - yield cr - finally: - cr.close() + async with AsyncConnectionManager().get_connection() as conn: + cr = conn.cursor() + try: + yield cr + finally: + cr.close() def disconnect(): - try: - conn = get_connection() - except DatabaseNotConnectedError: - pass - else: - conn.close() + connection_manager = ConnectionManager() + connection_manager.disconnect() async def async_disconnect(): - try: - conn = await async_get_connection() - except DatabaseNotConnectedError: - pass - else: - await conn.close() + connection_manager = AsyncConnectionManager() + await connection_manager.disconnect() + + +class ConnectionManager: + """ + Manages standalone connections and connection pools + """ + + def __init__(self): + global __conn__, __pool__ + self.key = (os.getpid(), get_ident()) + self.conn = __conn__.get(self.key) + self.pool = __pool__.get(self.key) + if self.conn and self.pool: + raise ValueError( + "Use either a standalone connection " "or a connection pool" + ) + + @property + def is_standalone(self): + return self.conn is not None + + @property + def is_pool(self): + return self.pool is not None + + @contextlib.contextmanager + def get_connection(self) -> Generator[Connection, Any, None]: + if self.is_pool: + with self.connection_from_pool() as conn: + yield conn + else: + with self.standalone_connection() as conn: + yield conn + + @contextlib.contextmanager + def connection_from_pool(self) -> Generator[Connection, Any, None]: + if self.is_pool: + try: + conn = self.pool.acquire() + except (oracledb.DatabaseError, oracledb.InterfaceError): + raise DatabaseNotConnectedError() + else: + raise DatabaseNotConnectedError() + try: + yield conn + finally: + self.pool.release(conn) + + @contextlib.contextmanager + def standalone_connection(self) -> Generator[Connection, Any, None]: + if self.is_standalone: + try: + self.conn.ping() + except (oracledb.DatabaseError, oracledb.InterfaceError): + raise DatabaseNotConnectedError() + yield self.conn + else: + raise DatabaseNotConnectedError() + + def disconnect(self, force=False): + global __pool__, __conn__ + if self.is_pool: + self.pool.close(force=force) + __pool__.pop(self.key, None) + elif self.is_standalone: + self.conn.close() + __conn__.pop(self.key, None) + + +class AsyncConnectionManager: + """ + Manages async standalone connections and connection pools + """ + + def __init__(self): + global __async_conn__, __async_pool__ + self.key = (os.getpid(), get_ident()) + self.conn = __async_conn__.get(self.key) + self.pool = __async_pool__.get(self.key) + if self.conn and self.pool: + raise ValueError( + "Use either a standalone connection " "or a connection pool" + ) + + @property + def is_standalone(self): + return self.conn is not None + + @property + def is_pool(self): + return self.pool is not None + + @contextlib.asynccontextmanager + async def get_connection(self): + if self.is_pool: + async with self.connection_from_pool() as conn: + yield conn + else: + async with self.standalone_connection() as conn: + yield conn + + @contextlib.asynccontextmanager + async def connection_from_pool(self): + if self.is_pool: + try: + conn = await self.pool.acquire() + except (oracledb.DatabaseError, oracledb.InterfaceError): + raise DatabaseNotConnectedError() + else: + raise DatabaseNotConnectedError() + try: + yield conn + finally: + await self.pool.release(conn) + + @contextlib.asynccontextmanager + async def standalone_connection(self): + if self.is_standalone: + try: + await self.conn.ping() + except (oracledb.DatabaseError, oracledb.InterfaceError): + raise DatabaseNotConnectedError() + yield self.conn + else: + raise DatabaseNotConnectedError() + + async def disconnect(self, force=False): + global __async_conn__, __async_pool__ + if self.is_pool: + await self.pool.close(force=force) + __async_pool__.pop(self.key, None) + elif self.is_standalone: + await self.conn.close() + __async_conn__.pop(self.key, None) diff --git a/src/select_ai/version.py b/src/select_ai/version.py index 8e5d932..fa9fe9d 100644 --- a/src/select_ai/version.py +++ b/src/select_ai/version.py @@ -5,4 +5,4 @@ # http://oss.oracle.com/licenses/upl. # ----------------------------------------------------------------------------- -__version__ = "1.2.2" +__version__ = "1.3.0" diff --git a/tests/conftest.py b/tests/conftest.py index 1dbea20..90a22b6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,6 +14,9 @@ # PYSAI_TEST_CONNECT_STRING: connect string for test suite # PYSAI_TEST_WALLET_LOCATION: location of wallet file (thin mode, mTLS) # PYSAI_TEST_WALLET_PASSWORD: password for wallet file (thin mode, mTLS) +# PYSAI_TEST_MIN_POOL_SIZE: Minimum number of connections in the pool +# PYSAI_TEST_MAX_POOL_SIZE: Maximum number of connections in the pool +# PYSAI_TEST_POOL_INCREMENT # # OCI Gen AI # PYSAI_TEST_OCI_USER_OCID @@ -90,8 +93,17 @@ def __init__(self): self.admin_password = get_env_value("ADMIN_PASSWORD") self.wallet_location = get_env_value("WALLET_LOCATION") self.wallet_password = get_env_value("WALLET_PASSWORD") + self.min_pool_size = int( + get_env_value("MIN_POOL_SIZE", default_value=2) + ) + self.max_pool_size = int( + get_env_value("MAX_POOL_SIZE", default_value=4) + ) + self.pool_increment = int( + get_env_value("POOL_INCREMENT", default_value=1) + ) - def connect_params(self, admin: bool = False): + def connect_params(self, admin: bool = False, use_pool: bool = False): """ Returns connect params """ @@ -105,6 +117,10 @@ def connect_params(self, admin: bool = False): "wallet_password": self.wallet_password, "config_dir": self.wallet_location, } + if use_pool: + connect_params["min_size"] = self.min_pool_size + connect_params["max_size"] = self.max_pool_size + connect_params["increment"] = self.pool_increment return connect_params @@ -137,14 +153,14 @@ def setup_test_user(test_env): @pytest.fixture(autouse=True, scope="session") def connect(setup_test_user, test_env): - select_ai.connect(**test_env.connect_params()) + select_ai.create_pool(**test_env.connect_params(use_pool=True)) yield select_ai.disconnect() @pytest.fixture(autouse=True, scope="session") async def async_connect(setup_test_user, test_env, anyio_backend): - await select_ai.async_connect(**test_env.connect_params()) + select_ai.create_pool_async(**test_env.connect_params(use_pool=True)) yield await select_ai.async_disconnect() From 9de7f3ee42422f69cbf37b8ccdfa275a113b4987 Mon Sep 17 00:00:00 2001 From: Abhishek Singh Date: Thu, 12 Feb 2026 13:39:00 -0800 Subject: [PATCH 2/6] Fix test fixtures and refactored Select AI agent code --- src/select_ai/agent/core.py | 39 +++++++++++++----------------------- src/select_ai/agent/task.py | 35 ++++++++++++++++++++------------ src/select_ai/agent/team.py | 33 ++++++++++++++++++++---------- src/select_ai/agent/tool.py | 23 ++++++++++----------- src/select_ai/db.py | 17 +++++++++------- src/select_ai/errors.py | 40 +++++++++++++++++++++++++++++++++++++ tests/conftest.py | 12 +++++++---- 7 files changed, 126 insertions(+), 73 deletions(-) diff --git a/src/select_ai/agent/core.py b/src/select_ai/agent/core.py index 02590eb..bed1edc 100644 --- a/src/select_ai/agent/core.py +++ b/src/select_ai/agent/core.py @@ -24,7 +24,7 @@ LIST_USER_AI_AGENTS, ) from select_ai.db import async_cursor, cursor -from select_ai.errors import AgentNotFoundError +from select_ai.errors import AgentAttributesEmptyError, AgentNotFoundError @dataclass @@ -97,7 +97,7 @@ def _get_attributes(agent_name: str) -> AgentAttributes: post_processed_attributes[k] = v return AgentAttributes(**post_processed_attributes) else: - raise AgentNotFoundError(agent_name=agent_name) + raise AgentAttributesEmptyError(agent_name=agent_name) @staticmethod def _get_description(agent_name: str) -> Union[str, None]: @@ -223,7 +223,10 @@ def fetch(cls, agent_name: str) -> "Agent": If the AI Agent is not found """ - attributes = cls._get_attributes(agent_name=agent_name) + try: + attributes = cls._get_attributes(agent_name=agent_name) + except AgentAttributesEmptyError: + attributes = None description = cls._get_description(agent_name=agent_name) return cls( agent_name=agent_name, @@ -251,16 +254,7 @@ def list( ) for row in cr.fetchall(): agent_name = row[0] - if row[1]: - description = row[1].read() # Oracle.LOB - else: - description = None - attributes = cls._get_attributes(agent_name=agent_name) - yield cls( - agent_name=agent_name, - description=description, - attributes=attributes, - ) + yield cls.fetch(agent_name=agent_name) def set_attributes(self, attributes: AgentAttributes) -> None: """ @@ -326,7 +320,7 @@ async def _get_attributes(agent_name: str) -> AgentAttributes: post_processed_attributes[k] = v return AgentAttributes(**post_processed_attributes) else: - raise AgentNotFoundError(agent_name=agent_name) + raise AgentAttributesEmptyError(agent_name=agent_name) @staticmethod async def _get_description(agent_name: str) -> Union[str, None]: @@ -454,7 +448,10 @@ async def fetch(cls, agent_name: str) -> "AsyncAgent": If the AI Agent is not found """ - attributes = await cls._get_attributes(agent_name=agent_name) + try: + attributes = await cls._get_attributes(agent_name=agent_name) + except AgentAttributesEmptyError: + attributes = None description = await cls._get_description(agent_name=agent_name) return cls( agent_name=agent_name, @@ -483,16 +480,8 @@ async def list( rows = await cr.fetchall() for row in rows: agent_name = row[0] - if row[1]: - description = await row[1].read() # Oracle.AsyncLOB - else: - description = None - attributes = await cls._get_attributes(agent_name=agent_name) - yield cls( - agent_name=agent_name, - description=description, - attributes=attributes, - ) + agent = await cls.fetch(agent_name=agent_name) + yield agent async def set_attributes(self, attributes: AgentAttributes) -> None: """ diff --git a/src/select_ai/agent/task.py b/src/select_ai/agent/task.py index d0fe70d..598c337 100644 --- a/src/select_ai/agent/task.py +++ b/src/select_ai/agent/task.py @@ -5,7 +5,6 @@ # http://oss.oracle.com/licenses/upl. # ----------------------------------------------------------------------------- -import json from abc import ABC from dataclasses import dataclass from typing import ( @@ -13,25 +12,23 @@ AsyncGenerator, Iterator, List, - Mapping, Optional, Union, ) import oracledb -from select_ai import BaseProfile from select_ai._abc import SelectAIDataClass -from select_ai._enums import StrEnum from select_ai.agent.sql import ( GET_USER_AI_AGENT_TASK, GET_USER_AI_AGENT_TASK_ATTRIBUTES, LIST_USER_AI_AGENT_TASKS, ) -from select_ai.async_profile import AsyncProfile from select_ai.db import async_cursor, cursor -from select_ai.errors import AgentTaskNotFoundError -from select_ai.profile import Profile +from select_ai.errors import ( + AgentTaskAttributesEmptyError, + AgentTaskNotFoundError, +) @dataclass @@ -111,7 +108,7 @@ def _get_attributes(task_name: str) -> TaskAttributes: post_processed_attributes[k] = v return TaskAttributes(**post_processed_attributes) else: - raise AgentTaskNotFoundError(task_name=task_name) + raise AgentTaskAttributesEmptyError(task_name=task_name) @staticmethod def _get_description(task_name: str) -> Union[str, None]: @@ -244,7 +241,10 @@ def list(cls, task_name_pattern: Optional[str] = ".*") -> Iterator["Task"]: description = row[1].read() # Oracle.LOB else: description = None - attributes = cls._get_attributes(task_name=task_name) + try: + attributes = cls._get_attributes(task_name=task_name) + except AgentTaskAttributesEmptyError: + attributes = None yield cls( task_name=task_name, description=description, @@ -264,7 +264,10 @@ def fetch(cls, task_name: str) -> "Task": :raises select_ai.errors.AgentTaskNotFoundError: If the AI Task is not found """ - attributes = cls._get_attributes(task_name=task_name) + try: + attributes = cls._get_attributes(task_name=task_name) + except AgentTaskAttributesEmptyError: + attributes = None description = cls._get_description(task_name=task_name) return cls( task_name=task_name, @@ -338,7 +341,7 @@ async def _get_attributes(task_name: str) -> TaskAttributes: post_processed_attributes[k] = v return TaskAttributes(**post_processed_attributes) else: - raise AgentTaskNotFoundError(task_name=task_name) + raise AgentTaskAttributesEmptyError(task_name=task_name) @staticmethod async def _get_description(task_name: str) -> Union[str, None]: @@ -476,7 +479,10 @@ async def list( description = await row[1].read() # Oracle.AsyncLOB else: description = None - attributes = await cls._get_attributes(task_name=task_name) + try: + attributes = await cls._get_attributes(task_name=task_name) + except AgentTaskAttributesEmptyError: + attributes = None yield cls( task_name=task_name, description=description, @@ -496,7 +502,10 @@ async def fetch(cls, task_name: str) -> "AsyncTask": :raises select_ai.errors.AgentTaskNotFoundError: If the AI Task is not found """ - attributes = await cls._get_attributes(task_name=task_name) + try: + attributes = await cls._get_attributes(task_name=task_name) + except AgentTaskAttributesEmptyError: + attributes = None description = await cls._get_description(task_name=task_name) return cls( task_name=task_name, diff --git a/src/select_ai/agent/team.py b/src/select_ai/agent/team.py index 1392218..2a7ae6b 100644 --- a/src/select_ai/agent/team.py +++ b/src/select_ai/agent/team.py @@ -20,18 +20,17 @@ import oracledb -from select_ai import BaseProfile from select_ai._abc import SelectAIDataClass -from select_ai._enums import StrEnum from select_ai.agent.sql import ( GET_USER_AI_AGENT_TEAM, GET_USER_AI_AGENT_TEAM_ATTRIBUTES, LIST_USER_AI_AGENT_TEAMS, ) -from select_ai.async_profile import AsyncProfile from select_ai.db import async_cursor, cursor -from select_ai.errors import AgentTeamNotFoundError -from select_ai.profile import Profile +from select_ai.errors import ( + AgentTeamAttributesEmptyError, + AgentTeamNotFoundError, +) @dataclass @@ -105,7 +104,7 @@ def _get_attributes(team_name: str) -> TeamAttributes: post_processed_attributes[k] = v return TeamAttributes(**post_processed_attributes) else: - raise AgentTeamNotFoundError(team_name=team_name) + raise AgentTeamAttributesEmptyError(team_name=team_name) @staticmethod def _get_description(team_name: str) -> Union[str, None]: @@ -228,7 +227,10 @@ def fetch(cls, team_name: str) -> "Team": :raises select_ai.errors.AgentTeamNotFoundError: If the AI Team is not found """ - attributes = cls._get_attributes(team_name) + try: + attributes = cls._get_attributes(team_name) + except AgentTeamAttributesEmptyError: + attributes = None description = cls._get_description(team_name) return cls( team_name=team_name, @@ -259,7 +261,10 @@ def list(cls, team_name_pattern: Optional[str] = ".*") -> Iterator["Team"]: description = row[1].read() # Oracle.LOB else: description = None - attributes = cls._get_attributes(team_name=team_name) + try: + attributes = cls._get_attributes(team_name=team_name) + except AgentTeamAttributesEmptyError: + attributes = None yield cls( team_name=team_name, description=description, @@ -369,7 +374,7 @@ async def _get_attributes(team_name: str) -> TeamAttributes: post_processed_attributes[k] = v return TeamAttributes(**post_processed_attributes) else: - raise AgentTeamNotFoundError(team_name=team_name) + raise AgentTeamAttributesEmptyError(team_name=team_name) @staticmethod async def _get_description(team_name: str) -> Union[str, None]: @@ -494,7 +499,10 @@ async def fetch(cls, team_name: str) -> "AsyncTeam": :raises select_ai.errors.AgentTeamNotFoundError: If the AI Team is not found """ - attributes = await cls._get_attributes(team_name) + try: + attributes = await cls._get_attributes(team_name) + except AgentTeamAttributesEmptyError: + attributes = None description = await cls._get_description(team_name) return cls( team_name=team_name, @@ -528,7 +536,10 @@ async def list( description = await row[1].read() # Oracle.AsyncLOB else: description = None - attributes = await cls._get_attributes(team_name=team_name) + try: + attributes = await cls._get_attributes(team_name=team_name) + except AgentTeamAttributesEmptyError: + attributes = None yield cls( team_name=team_name, description=description, diff --git a/src/select_ai/agent/tool.py b/src/select_ai/agent/tool.py index 92e15e7..4189822 100644 --- a/src/select_ai/agent/tool.py +++ b/src/select_ai/agent/tool.py @@ -30,7 +30,10 @@ ) from select_ai.async_profile import AsyncProfile from select_ai.db import async_cursor, cursor -from select_ai.errors import AgentToolNotFoundError +from select_ai.errors import ( + AgentToolAttributesEmptyError, + AgentToolNotFoundError, +) from select_ai.profile import Profile @@ -291,7 +294,7 @@ def _get_attributes(tool_name: str) -> ToolAttributes: post_processed_attributes[k] = v return ToolAttributes.create(**post_processed_attributes) else: - raise AgentToolNotFoundError(tool_name=tool_name) + raise AgentToolAttributesEmptyError(tool_name=tool_name) @staticmethod def _get_description(tool_name: str) -> Union[str, None]: @@ -644,7 +647,10 @@ def fetch(cls, tool_name: str) -> "Tool": If the AI Tool is not found """ - attributes = cls._get_attributes(tool_name) + try: + attributes = cls._get_attributes(tool_name) + except AgentToolAttributesEmptyError: + attributes = None description = cls._get_description(tool_name) return cls( tool_name=tool_name, attributes=attributes, description=description @@ -667,16 +673,7 @@ def list(cls, tool_name_pattern: str = ".*") -> Iterator["Tool"]: ) for row in cr.fetchall(): tool_name = row[0] - if row[1]: - description = row[1].read() # Oracle.LOB - else: - description = None - attributes = cls._get_attributes(tool_name=tool_name) - yield cls( - tool_name=tool_name, - description=description, - attributes=attributes, - ) + yield cls.fetch(tool_name=tool_name) def set_attributes(self, attributes: ToolAttributes) -> None: """ diff --git a/src/select_ai/db.py b/src/select_ai/db.py index 51f7845..0c2f168 100644 --- a/src/select_ai/db.py +++ b/src/select_ai/db.py @@ -8,7 +8,7 @@ import contextlib import os from threading import get_ident -from typing import Any, Dict, Generator, Hashable, Optional +from typing import Any, AsyncGenerator, Dict, Generator, Hashable, Optional import oracledb from oracledb import Connection @@ -61,7 +61,7 @@ def create_pool( password: str, dsn: str, min_size: Optional[int] = 1, - max_size: Optional[int] = 1, + max_size: Optional[int] = 2, increment: Optional[int] = 1, *args, **kwargs, @@ -74,6 +74,7 @@ def create_pool( max=max_size, increment=increment, connection_id_prefix="python-select-ai", + getmode=oracledb.POOL_GETMODE_NOWAIT, *args, **kwargs, ) @@ -189,16 +190,18 @@ def _set_connection_pool( __async_pool__[key] = async_pool -def get_connection() -> oracledb.Connection: +@contextlib.contextmanager +def get_connection() -> Generator[Connection, Any, None]: """Returns the connection object if connection is healthy""" with ConnectionManager().get_connection() as conn: - return conn + yield conn -async def async_get_connection() -> oracledb.AsyncConnection: +@contextlib.asynccontextmanager +async def async_get_connection() -> AsyncGenerator[Any, Any]: """Returns the AsyncConnection object if connection is healthy""" async with ConnectionManager().get_connection() as conn: - return conn + yield conn @contextlib.contextmanager @@ -309,7 +312,7 @@ def standalone_connection(self) -> Generator[Connection, Any, None]: else: raise DatabaseNotConnectedError() - def disconnect(self, force=False): + def disconnect(self, force=True): global __pool__, __conn__ if self.is_pool: self.pool.close(force=force) diff --git a/src/select_ai/errors.py b/src/select_ai/errors.py index 718c22d..6682479 100644 --- a/src/select_ai/errors.py +++ b/src/select_ai/errors.py @@ -95,6 +95,16 @@ def __str__(self): return f"Agent {self.agent_name} not found" +class AgentAttributesEmptyError(SelectAIError): + """Agent attributes not found in the database""" + + def __init__(self, agent_name: str): + self.agent_name = agent_name + + def __str__(self): + return f"Agent {self.agent_name} attributes empty in the database." + + class AgentTaskNotFoundError(SelectAIError): """Agent task not found in the database""" @@ -105,6 +115,16 @@ def __str__(self): return f"Agent Task {self.task_name} not found" +class AgentTaskAttributesEmptyError(SelectAIError): + """Agent task attributes not found in the database""" + + def __init__(self, task_name: str): + self.task_name = task_name + + def __str__(self): + return f"Agent Task {self.task_name} attributes empty in the database." + + class AgentToolNotFoundError(SelectAIError): """Agent tool not found in the database""" @@ -115,6 +135,16 @@ def __str__(self): return f"Agent Tool {self.tool_name} not found" +class AgentToolAttributesEmptyError(SelectAIError): + """Agent team attributes empty in the database""" + + def __init__(self, tool_name: str): + self.tool_name = tool_name + + def __str__(self): + return f"Agent tool {self.tool_name} attributes empty in the database." + + class AgentTeamNotFoundError(SelectAIError): """Agent team not found in the database""" @@ -125,6 +155,16 @@ def __str__(self): return f"Agent Team {self.team_name} not found" +class AgentTeamAttributesEmptyError(SelectAIError): + """Agent team attributes empty in the database""" + + def __init__(self, team_name: str): + self.team_name = team_name + + def __str__(self): + return f"Agent team {self.team_name} attributes empty in the database." + + class InvalidSQLError(SelectAIError): """Invalid SQL generated""" diff --git a/tests/conftest.py b/tests/conftest.py index 90a22b6..f7e8a73 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -56,7 +56,8 @@ def _ensure_test_user_exists(username: str, password: str): cr.execute( f'CREATE USER {username_upper} IDENTIFIED BY "{escaped_password}"' ) - select_ai.db.get_connection().commit() + with select_ai.db.get_connection() as conn: + conn.commit() def _grant_basic_schema_privileges(username: str): @@ -64,7 +65,8 @@ def _grant_basic_schema_privileges(username: str): with select_ai.cursor() as cr: for privilege in _BASIC_SCHEMA_PRIVILEGES: cr.execute(f"GRANT {privilege} TO {username_upper}") - select_ai.db.get_connection().commit() + with select_ai.db.get_connection() as conn: + conn.commit() def get_env_value(name, default_value=None, required=False): @@ -167,12 +169,14 @@ async def async_connect(setup_test_user, test_env, anyio_backend): @pytest.fixture def connection(): - return select_ai.db.get_connection() + with select_ai.db.get_connection() as conn: + yield conn @pytest.fixture def async_connection(): - return select_ai.db.async_get_connection() + with select_ai.db.async_get_connection() as conn: + yield conn @pytest.fixture(scope="module") From 551679d483aa96527bebff070b63bf715a920e79 Mon Sep 17 00:00:00 2001 From: Abhishek Singh Date: Thu, 5 Mar 2026 13:40:37 -0800 Subject: [PATCH 3/6] Added instruction to Tool create API and bugfix in EMAIL tool --- src/select_ai/__init__.py | 2 +- src/select_ai/_abc.py | 2 +- src/select_ai/_enums.py | 2 +- src/select_ai/_validations.py | 2 +- src/select_ai/action.py | 2 +- src/select_ai/agent/tool.py | 105 ++++++++++++++++++++++++++++++-- src/select_ai/async_profile.py | 2 +- src/select_ai/base_profile.py | 2 +- src/select_ai/conversation.py | 2 +- src/select_ai/credential.py | 2 +- src/select_ai/db.py | 49 +++++++-------- src/select_ai/errors.py | 2 +- src/select_ai/feedback.py | 2 +- src/select_ai/privilege.py | 2 +- src/select_ai/profile.py | 2 +- src/select_ai/provider.py | 2 +- src/select_ai/sql.py | 25 +++++++- src/select_ai/summary.py | 2 +- src/select_ai/synthetic_data.py | 2 +- src/select_ai/vector_index.py | 2 +- src/select_ai/version.py | 2 +- 21 files changed, 164 insertions(+), 51 deletions(-) diff --git a/src/select_ai/__init__.py b/src/select_ai/__init__.py index 56073cd..fa13d31 100644 --- a/src/select_ai/__init__.py +++ b/src/select_ai/__init__.py @@ -1,5 +1,5 @@ # ----------------------------------------------------------------------------- -# Copyright (c) 2025, Oracle and/or its affiliates. +# Copyright (c) 2026, Oracle and/or its affiliates. # # Licensed under the Universal Permissive License v 1.0 as shown at # http://oss.oracle.com/licenses/upl. diff --git a/src/select_ai/_abc.py b/src/select_ai/_abc.py index 08c9957..44621fc 100644 --- a/src/select_ai/_abc.py +++ b/src/select_ai/_abc.py @@ -1,5 +1,5 @@ # ----------------------------------------------------------------------------- -# Copyright (c) 2025, Oracle and/or its affiliates. +# Copyright (c) 2026, Oracle and/or its affiliates. # # Licensed under the Universal Permissive License v 1.0 as shown at # http://oss.oracle.com/licenses/upl. diff --git a/src/select_ai/_enums.py b/src/select_ai/_enums.py index 8007185..3f1ea9c 100644 --- a/src/select_ai/_enums.py +++ b/src/select_ai/_enums.py @@ -1,5 +1,5 @@ # ----------------------------------------------------------------------------- -# Copyright (c) 2025, Oracle and/or its affiliates. +# Copyright (c) 2026, Oracle and/or its affiliates. # # Licensed under the Universal Permissive License v 1.0 as shown at # http://oss.oracle.com/licenses/upl. diff --git a/src/select_ai/_validations.py b/src/select_ai/_validations.py index 70bf1ec..2de68f3 100644 --- a/src/select_ai/_validations.py +++ b/src/select_ai/_validations.py @@ -1,5 +1,5 @@ # ----------------------------------------------------------------------------- -# Copyright (c) 2025, Oracle and/or its affiliates. +# Copyright (c) 2026, Oracle and/or its affiliates. # # Licensed under the Universal Permissive License v 1.0 as shown at # http://oss.oracle.com/licenses/upl. diff --git a/src/select_ai/action.py b/src/select_ai/action.py index 5735d30..755238b 100644 --- a/src/select_ai/action.py +++ b/src/select_ai/action.py @@ -1,5 +1,5 @@ # ----------------------------------------------------------------------------- -# Copyright (c) 2025, Oracle and/or its affiliates. +# Copyright (c) 2026, Oracle and/or its affiliates. # # Licensed under the Universal Permissive License v 1.0 as shown at # http://oss.oracle.com/licenses/upl. diff --git a/src/select_ai/agent/tool.py b/src/select_ai/agent/tool.py index 4189822..19a29e4 100644 --- a/src/select_ai/agent/tool.py +++ b/src/select_ai/agent/tool.py @@ -1,5 +1,5 @@ # ----------------------------------------------------------------------------- -# Copyright (c) 2025, Oracle and/or its affiliates. +# Copyright (c) 2026, Oracle and/or its affiliates. # # Licensed under the Universal Permissive License v 1.0 as shown at # http://oss.oracle.com/licenses/upl. @@ -80,6 +80,8 @@ class ToolParams(SelectAIDataClass): :param str smtp_host: SMTP host to use for EMAIL notification + :param str subject: Email subject to use + """ _REQUIRED_FIELDS: Optional[List] = None @@ -92,6 +94,7 @@ class ToolParams(SelectAIDataClass): sender: Optional[str] = None channel: Optional[str] = None smtp_host: Optional[str] = None + subject: Optional[str] = None def __post_init__(self): super().__post_init__() @@ -124,6 +127,7 @@ def keys(cls): "sender", "channel", "smtp_host", + "subject", } @@ -352,6 +356,7 @@ def create_built_in_tool( tool_type: ToolType, description: Optional[str] = None, replace: Optional[bool] = False, + instruction: Optional[str] = None, ) -> "Tool": """ Register a built-in tool @@ -363,6 +368,9 @@ def create_built_in_tool( :param str description: Description of the tool :param bool replace: Whether to replace the existing tool. Default value is False + :param str instruction: A clear, concise statement that describes + what the tool should accomplish and how to do it. This + text is included in the prompt sent to the LLM. :return: select_ai.agent.Tool """ @@ -372,7 +380,9 @@ def create_built_in_tool( "type select_ai.agent.ToolParams" ) attributes = ToolAttributes( - tool_params=tool_params, tool_type=tool_type + tool_params=tool_params, + tool_type=tool_type, + instruction=instruction, ) tool = cls( tool_name=tool_name, attributes=attributes, description=description @@ -389,7 +399,9 @@ def create_email_notification_tool( sender: str, smtp_host: str, description: Optional[str], + subject: Optional[str] = None, replace: bool = False, + instruction: Optional[str] = None, ) -> "Tool": """ Register an email notification tool @@ -400,8 +412,12 @@ def create_email_notification_tool( :param str sender: The sender of the email :param str smtp_host: The SMTP host of the email server :param str description: The description of the tool + :param str subject: Subject of the email. :param bool replace: Whether to replace the existing tool. Default value is False + :param str instruction: A clear, concise statement that describes + what the tool should accomplish and how to do it. This + text is included in the prompt sent to the LLM. :return: select_ai.agent.Tool @@ -411,6 +427,7 @@ def create_email_notification_tool( recipient=recipient, sender=sender, smtp_host=smtp_host, + subject=subject, ) return cls.create_built_in_tool( tool_name=tool_name, @@ -418,6 +435,7 @@ def create_email_notification_tool( tool_params=email_notification_tool_params, description=description, replace=replace, + instruction=instruction, ) @classmethod @@ -428,6 +446,7 @@ def create_http_tool( endpoint: str, description: Optional[str] = None, replace: bool = False, + instruction: Optional[str] = None, ) -> "Tool": http_tool_params = HTTPToolParams( credential_name=credential_name, endpoint=endpoint @@ -438,6 +457,7 @@ def create_http_tool( tool_params=http_tool_params, description=description, replace=replace, + instruction=instruction, ) @classmethod @@ -447,6 +467,7 @@ def create_pl_sql_tool( function: str, description: Optional[str] = None, replace: bool = False, + instruction: Optional[str] = None, ) -> "Tool": """ Create a custom tool to invoke PL/SQL procedure or function @@ -456,9 +477,14 @@ def create_pl_sql_tool( :param str description: The description of the tool :param bool replace: Whether to replace existing tool. Default value is False + :param str instruction: A clear, concise statement that describes + what the tool should accomplish and how to do it. This + text is included in the prompt sent to the LLM. """ - tool_attributes = ToolAttributes(function=function) + tool_attributes = ToolAttributes( + function=function, instruction=instruction + ) tool = cls( tool_name=tool_name, attributes=tool_attributes, @@ -474,6 +500,7 @@ def create_rag_tool( profile_name: str, description: Optional[str] = None, replace: bool = False, + instruction: Optional[str] = None, ) -> "Tool": """ Register a RAG tool, which will use a VectorIndex linked AI Profile @@ -484,6 +511,9 @@ def create_rag_tool( :param str description: The description of the tool :param bool replace: Whether to replace existing tool. Default value is False + :param str instruction: A clear, concise statement that describes + what the tool should accomplish and how to do it. This + text is included in the prompt sent to the LLM. """ tool_params = RAGToolParams(profile_name=profile_name) return cls.create_built_in_tool( @@ -492,6 +522,7 @@ def create_rag_tool( tool_params=tool_params, description=description, replace=replace, + instruction=instruction, ) @classmethod @@ -501,6 +532,7 @@ def create_sql_tool( profile_name: str, description: Optional[str] = None, replace: bool = False, + instruction: Optional[str] = None, ) -> "Tool": """ Register a SQL tool to perform natural language to SQL translation @@ -511,6 +543,9 @@ def create_sql_tool( :param str description: The description of the tool :param bool replace: Whether to replace existing tool. Default value is False + :param str instruction: A clear, concise statement that describes + what the tool should accomplish and how to do it. This + text is included in the prompt sent to the LLM. """ tool_params = SQLToolParams(profile_name=profile_name) return cls.create_built_in_tool( @@ -519,6 +554,7 @@ def create_sql_tool( tool_params=tool_params, description=description, replace=replace, + instruction=instruction, ) @classmethod @@ -529,6 +565,7 @@ def create_slack_notification_tool( channel: str, description: Optional[str] = None, replace: bool = False, + instruction: Optional[str] = None, ) -> "Tool": """ Register a Slack notification tool @@ -539,6 +576,9 @@ def create_slack_notification_tool( :param str description: The description of the Slack notification tool :param bool replace: Whether to replace existing tool. Default value is False + :param str instruction: A clear, concise statement that describes + what the tool should accomplish and how to do it. This + text is included in the prompt sent to the LLM. """ slack_notification_tool_params = SlackNotificationToolParams( @@ -551,6 +591,7 @@ def create_slack_notification_tool( tool_params=slack_notification_tool_params, description=description, replace=replace, + instruction=instruction, ) @classmethod @@ -560,6 +601,7 @@ def create_websearch_tool( credential_name: str, description: Optional[str], replace: bool = False, + instruction: Optional[str] = None, ) -> "Tool": """ Register a built-in websearch tool to search information @@ -570,6 +612,9 @@ def create_websearch_tool( storing OpenAI credentials :param str description: The description of the tool :param bool replace: Whether to replace the existing tool + :param str instruction: A clear, concise statement that describes + what the tool should accomplish and how to do it. This + text is included in the prompt sent to the LLM. """ web_search_tool_params = WebSearchToolParams( @@ -581,6 +626,7 @@ def create_websearch_tool( tool_params=web_search_tool_params, description=description, replace=replace, + instruction=instruction, ) def delete(self, force: bool = False): @@ -751,6 +797,13 @@ async def _get_description(tool_name: str) -> Union[str, None]: async def create( self, enabled: Optional[bool] = True, replace: Optional[bool] = False ): + """ + Create an AI Tool in the database + :param Optional[bool] enabled: Whether the tool should be enabled. + Default: True + :param Optional[bool] replace: Whether the tool should be replaced. + Default: False + """ if self.tool_name is None: raise AttributeError("Tool must have a name") if self.attributes is None: @@ -791,6 +844,7 @@ async def create_built_in_tool( tool_type: ToolType, description: Optional[str] = None, replace: Optional[bool] = False, + instruction: Optional[str] = None, ) -> "AsyncTool": """ Register a built-in tool @@ -802,6 +856,9 @@ async def create_built_in_tool( :param str description: Description of the tool :param bool replace: Whether to replace the existing tool. Default value is False + :param str instruction: A clear, concise statement that describes + what the tool should accomplish and how to do it. This + text is included in the prompt sent to the LLM. :return: select_ai.agent.Tool """ @@ -811,7 +868,9 @@ async def create_built_in_tool( "type select_ai.agent.ToolParams" ) attributes = ToolAttributes( - tool_params=tool_params, tool_type=tool_type + tool_params=tool_params, + tool_type=tool_type, + instruction=instruction, ) tool = cls( tool_name=tool_name, attributes=attributes, description=description @@ -828,7 +887,9 @@ async def create_email_notification_tool( sender: str, smtp_host: str, description: Optional[str], + subject: Optional[str] = None, replace: bool = False, + instruction: Optional[str] = None, ) -> "AsyncTool": """ Register an email notification tool @@ -839,8 +900,12 @@ async def create_email_notification_tool( :param str sender: The sender of the email :param str smtp_host: The SMTP host of the email server :param str description: The description of the tool + :param str subject: Subject of the email. :param bool replace: Whether to replace the existing tool. Default value is False + :param str instruction: A clear, concise statement that describes + what the tool should accomplish and how to do it. This + text is included in the prompt sent to the LLM. :return: select_ai.agent.Tool @@ -850,6 +915,7 @@ async def create_email_notification_tool( recipient=recipient, sender=sender, smtp_host=smtp_host, + subject=subject, ) return await cls.create_built_in_tool( tool_name=tool_name, @@ -857,6 +923,7 @@ async def create_email_notification_tool( tool_params=email_notification_tool_params, description=description, replace=replace, + instruction=instruction, ) @classmethod @@ -867,6 +934,7 @@ async def create_http_tool( endpoint: str, description: Optional[str] = None, replace: bool = False, + instruction: Optional[str] = None, ) -> "AsyncTool": http_tool_params = HTTPToolParams( credential_name=credential_name, endpoint=endpoint @@ -877,6 +945,7 @@ async def create_http_tool( tool_params=http_tool_params, description=description, replace=replace, + instruction=instruction, ) @classmethod @@ -886,6 +955,7 @@ async def create_pl_sql_tool( function: str, description: Optional[str] = None, replace: bool = False, + instruction: Optional[str] = None, ) -> "AsyncTool": """ Create a custom tool to invoke PL/SQL procedure or function @@ -895,9 +965,14 @@ async def create_pl_sql_tool( :param str description: The description of the tool :param bool replace: Whether to replace existing tool. Default value is False + :param str instruction: A clear, concise statement that describes + what the tool should accomplish and how to do it. This + text is included in the prompt sent to the LLM. """ - tool_attributes = ToolAttributes(function=function) + tool_attributes = ToolAttributes( + function=function, instruction=instruction + ) tool = cls( tool_name=tool_name, attributes=tool_attributes, @@ -913,6 +988,7 @@ async def create_rag_tool( profile_name: str, description: Optional[str] = None, replace: bool = False, + instruction: Optional[str] = None, ) -> "AsyncTool": """ Register a RAG tool, which will use a VectorIndex linked AI Profile @@ -923,6 +999,9 @@ async def create_rag_tool( :param str description: The description of the tool :param bool replace: Whether to replace existing tool. Default value is False + :param str instruction: A clear, concise statement that describes + what the tool should accomplish and how to do it. This + text is included in the prompt sent to the LLM. """ tool_params = RAGToolParams(profile_name=profile_name) return await cls.create_built_in_tool( @@ -931,6 +1010,7 @@ async def create_rag_tool( tool_params=tool_params, description=description, replace=replace, + instruction=instruction, ) @classmethod @@ -940,6 +1020,7 @@ async def create_sql_tool( profile_name: str, description: Optional[str] = None, replace: bool = False, + instruction: Optional[str] = None, ) -> "AsyncTool": """ Register a SQL tool to perform natural language to SQL translation @@ -950,6 +1031,9 @@ async def create_sql_tool( :param str description: The description of the tool :param bool replace: Whether to replace existing tool. Default value is False + :param str instruction: A clear, concise statement that describes + what the tool should accomplish and how to do it. This + text is included in the prompt sent to the LLM. """ tool_params = SQLToolParams(profile_name=profile_name) return await cls.create_built_in_tool( @@ -958,6 +1042,7 @@ async def create_sql_tool( tool_params=tool_params, description=description, replace=replace, + instruction=instruction, ) @classmethod @@ -968,6 +1053,7 @@ async def create_slack_notification_tool( channel: str, description: Optional[str] = None, replace: bool = False, + instruction: Optional[str] = None, ) -> "AsyncTool": """ Register a Slack notification tool @@ -978,6 +1064,9 @@ async def create_slack_notification_tool( :param str description: The description of the Slack notification tool :param bool replace: Whether to replace existing tool. Default value is False + :param str instruction: A clear, concise statement that describes + what the tool should accomplish and how to do it. This + text is included in the prompt sent to the LLM. """ slack_notification_tool_params = SlackNotificationToolParams( @@ -990,6 +1079,7 @@ async def create_slack_notification_tool( tool_params=slack_notification_tool_params, description=description, replace=replace, + instruction=instruction, ) @classmethod @@ -999,6 +1089,7 @@ async def create_websearch_tool( credential_name: str, description: Optional[str], replace: bool = False, + instruction: Optional[str] = None, ) -> "AsyncTool": """ Register a built-in websearch tool to search information @@ -1009,6 +1100,9 @@ async def create_websearch_tool( storing OpenAI credentials :param str description: The description of the tool :param bool replace: Whether to replace the existing tool + :param str instruction: A clear, concise statement that describes + what the tool should accomplish and how to do it. This + text is included in the prompt sent to the LLM. """ web_search_tool_params = WebSearchToolParams( @@ -1020,6 +1114,7 @@ async def create_websearch_tool( tool_params=web_search_tool_params, description=description, replace=replace, + instruction=instruction, ) async def delete(self, force: bool = False): diff --git a/src/select_ai/async_profile.py b/src/select_ai/async_profile.py index c980a96..aa1e3c9 100644 --- a/src/select_ai/async_profile.py +++ b/src/select_ai/async_profile.py @@ -1,5 +1,5 @@ # ----------------------------------------------------------------------------- -# Copyright (c) 2025, Oracle and/or its affiliates. +# Copyright (c) 2026, Oracle and/or its affiliates. # # Licensed under the Universal Permissive License v 1.0 as shown at # http://oss.oracle.com/licenses/upl. diff --git a/src/select_ai/base_profile.py b/src/select_ai/base_profile.py index 558406f..370d1cb 100644 --- a/src/select_ai/base_profile.py +++ b/src/select_ai/base_profile.py @@ -1,5 +1,5 @@ # ----------------------------------------------------------------------------- -# Copyright (c) 2025, Oracle and/or its affiliates. +# Copyright (c) 2026, Oracle and/or its affiliates. # # Licensed under the Universal Permissive License v 1.0 as shown at # http://oss.oracle.com/licenses/upl. diff --git a/src/select_ai/conversation.py b/src/select_ai/conversation.py index e108df6..64cf99a 100644 --- a/src/select_ai/conversation.py +++ b/src/select_ai/conversation.py @@ -1,5 +1,5 @@ # ----------------------------------------------------------------------------- -# Copyright (c) 2025, Oracle and/or its affiliates. +# Copyright (c) 2026, Oracle and/or its affiliates. # # Licensed under the Universal Permissive License v 1.0 as shown at # http://oss.oracle.com/licenses/upl. diff --git a/src/select_ai/credential.py b/src/select_ai/credential.py index 17df49b..d36f940 100644 --- a/src/select_ai/credential.py +++ b/src/select_ai/credential.py @@ -1,5 +1,5 @@ # ----------------------------------------------------------------------------- -# Copyright (c) 2025, Oracle and/or its affiliates. +# Copyright (c) 2026, Oracle and/or its affiliates. # # Licensed under the Universal Permissive License v 1.0 as shown at # http://oss.oracle.com/licenses/upl. diff --git a/src/select_ai/db.py b/src/select_ai/db.py index 0c2f168..92a8d36 100644 --- a/src/select_ai/db.py +++ b/src/select_ai/db.py @@ -1,5 +1,5 @@ # ----------------------------------------------------------------------------- -# Copyright (c) 2025, Oracle and/or its affiliates. +# Copyright (c) 2026, Oracle and/or its affiliates. # # Licensed under the Universal Permissive License v 1.0 as shown at # http://oss.oracle.com/licenses/upl. @@ -126,28 +126,21 @@ async def async_connect(user: str, password: str, dsn: str, *args, **kwargs): def is_connected() -> bool: """Checks if database connection is open and healthy""" - global __conn__ - key = (os.getpid(), get_ident()) - conn = __conn__.get(key) - if conn is None: - return False try: - return conn.ping() is None - except (oracledb.DatabaseError, oracledb.InterfaceError): + with ConnectionManager().get_connection() as conn: + pass + return True + except DatabaseNotConnectedError: return False async def async_is_connected() -> bool: """Asynchronously checks if database connection is open and healthy""" - - global __async_conn__ - key = (os.getpid(), get_ident()) - conn = __async_conn__.get(key) - if conn is None: - return False try: - return await conn.ping() is None - except (oracledb.DatabaseError, oracledb.InterfaceError): + async with AsyncConnectionManager().get_connection() as conn: + pass + return True + except DatabaseNotConnectedError: return False @@ -181,7 +174,7 @@ def _set_connection_pool( :return: None """ - key = (os.getpid(), get_ident()) + key = os.getpid() if pool: global __pool__ __pool__[key] = pool @@ -262,9 +255,10 @@ class ConnectionManager: def __init__(self): global __conn__, __pool__ - self.key = (os.getpid(), get_ident()) - self.conn = __conn__.get(self.key) - self.pool = __pool__.get(self.key) + self.conn_key = (os.getpid(), get_ident()) + self.pool_key = os.getpid() + self.conn = __conn__.get(self.conn_key) + self.pool = __pool__.get(self.pool_key) if self.conn and self.pool: raise ValueError( "Use either a standalone connection " "or a connection pool" @@ -316,10 +310,10 @@ def disconnect(self, force=True): global __pool__, __conn__ if self.is_pool: self.pool.close(force=force) - __pool__.pop(self.key, None) + __pool__.pop(self.pool_key, None) elif self.is_standalone: self.conn.close() - __conn__.pop(self.key, None) + __conn__.pop(self.conn_key, None) class AsyncConnectionManager: @@ -329,9 +323,10 @@ class AsyncConnectionManager: def __init__(self): global __async_conn__, __async_pool__ - self.key = (os.getpid(), get_ident()) - self.conn = __async_conn__.get(self.key) - self.pool = __async_pool__.get(self.key) + self.conn_key = (os.getpid(), get_ident()) + self.pool_key = os.getpid() + self.conn = __async_conn__.get(self.conn_key) + self.pool = __async_pool__.get(self.pool_key) if self.conn and self.pool: raise ValueError( "Use either a standalone connection " "or a connection pool" @@ -383,7 +378,7 @@ async def disconnect(self, force=False): global __async_conn__, __async_pool__ if self.is_pool: await self.pool.close(force=force) - __async_pool__.pop(self.key, None) + __async_pool__.pop(self.pool_key, None) elif self.is_standalone: await self.conn.close() - __async_conn__.pop(self.key, None) + __async_conn__.pop(self.conn_key, None) diff --git a/src/select_ai/errors.py b/src/select_ai/errors.py index 6682479..6b67292 100644 --- a/src/select_ai/errors.py +++ b/src/select_ai/errors.py @@ -1,5 +1,5 @@ # ----------------------------------------------------------------------------- -# Copyright (c) 2025, Oracle and/or its affiliates. +# Copyright (c) 2026, Oracle and/or its affiliates. # # Licensed under the Universal Permissive License v 1.0 as shown at # http://oss.oracle.com/licenses/upl. diff --git a/src/select_ai/feedback.py b/src/select_ai/feedback.py index 039d765..3e92e25 100644 --- a/src/select_ai/feedback.py +++ b/src/select_ai/feedback.py @@ -1,5 +1,5 @@ # ----------------------------------------------------------------------------- -# Copyright (c) 2025, Oracle and/or its affiliates. +# Copyright (c) 2026, Oracle and/or its affiliates. # # Licensed under the Universal Permissive License v 1.0 as shown at # http://oss.oracle.com/licenses/upl. diff --git a/src/select_ai/privilege.py b/src/select_ai/privilege.py index ebf1410..e2fcaef 100644 --- a/src/select_ai/privilege.py +++ b/src/select_ai/privilege.py @@ -1,5 +1,5 @@ # ----------------------------------------------------------------------------- -# Copyright (c) 2025, Oracle and/or its affiliates. +# Copyright (c) 2026, Oracle and/or its affiliates. # # Licensed under the Universal Permissive License v 1.0 as shown at # http://oss.oracle.com/licenses/upl. diff --git a/src/select_ai/profile.py b/src/select_ai/profile.py index ec777b9..7f46716 100644 --- a/src/select_ai/profile.py +++ b/src/select_ai/profile.py @@ -1,5 +1,5 @@ # ----------------------------------------------------------------------------- -# Copyright (c) 2025, Oracle and/or its affiliates. +# Copyright (c) 2026, Oracle and/or its affiliates. # # Licensed under the Universal Permissive License v 1.0 as shown at # http://oss.oracle.com/licenses/upl. diff --git a/src/select_ai/provider.py b/src/select_ai/provider.py index a023abe..dd00cf6 100644 --- a/src/select_ai/provider.py +++ b/src/select_ai/provider.py @@ -1,5 +1,5 @@ # ----------------------------------------------------------------------------- -# Copyright (c) 2025, Oracle and/or its affiliates. +# Copyright (c) 2026, Oracle and/or its affiliates. # # Licensed under the Universal Permissive License v 1.0 as shown at # http://oss.oracle.com/licenses/upl. diff --git a/src/select_ai/sql.py b/src/select_ai/sql.py index d105a6d..481ba61 100644 --- a/src/select_ai/sql.py +++ b/src/select_ai/sql.py @@ -1,5 +1,5 @@ # ----------------------------------------------------------------------------- -# Copyright (c) 2025, Oracle and/or its affiliates. +# Copyright (c) 2026, Oracle and/or its affiliates. # # Licensed under the Universal Permissive License v 1.0 as shown at # http://oss.oracle.com/licenses/upl. @@ -115,3 +115,26 @@ from USER_CLOUD_AI_CONVERSATIONS WHERE conversation_id = :conversation_id """ + +CHECK_USER_PRIVILEGES = """ +WITH direct_privs AS ( + SELECT table_name + FROM all_tab_privs + WHERE grantee = USER + AND privilege = 'EXECUTE' + AND table_name IN ({placeholders}) + ), + role_privs AS ( + SELECT rtp.table_name + FROM session_roles sr + JOIN role_tab_privs rtp + ON rtp.role = sr.role + WHERE rtp.privilege = 'EXECUTE' + AND rtp.table_name IN ({placeholders}) + ) + SELECT DISTINCT table_name + FROM direct_privs + UNION + SELECT DISTINCT table_name + FROM role_privs +""" diff --git a/src/select_ai/summary.py b/src/select_ai/summary.py index 0fc53d8..f9fa4ed 100644 --- a/src/select_ai/summary.py +++ b/src/select_ai/summary.py @@ -1,5 +1,5 @@ # ----------------------------------------------------------------------------- -# Copyright (c) 2025, Oracle and/or its affiliates. +# Copyright (c) 2026, Oracle and/or its affiliates. # # Licensed under the Universal Permissive License v 1.0 as shown at # http://oss.oracle.com/licenses/upl. diff --git a/src/select_ai/synthetic_data.py b/src/select_ai/synthetic_data.py index 0daa5ba..34dcebe 100644 --- a/src/select_ai/synthetic_data.py +++ b/src/select_ai/synthetic_data.py @@ -1,5 +1,5 @@ # ----------------------------------------------------------------------------- -# Copyright (c) 2025, Oracle and/or its affiliates. +# Copyright (c) 2026, Oracle and/or its affiliates. # # Licensed under the Universal Permissive License v 1.0 as shown at # http://oss.oracle.com/licenses/upl. diff --git a/src/select_ai/vector_index.py b/src/select_ai/vector_index.py index dd5166a..f5d60ca 100644 --- a/src/select_ai/vector_index.py +++ b/src/select_ai/vector_index.py @@ -1,5 +1,5 @@ # ----------------------------------------------------------------------------- -# Copyright (c) 2025, Oracle and/or its affiliates. +# Copyright (c) 2026, Oracle and/or its affiliates. # # Licensed under the Universal Permissive License v 1.0 as shown at # http://oss.oracle.com/licenses/upl. diff --git a/src/select_ai/version.py b/src/select_ai/version.py index fa9fe9d..499fd8c 100644 --- a/src/select_ai/version.py +++ b/src/select_ai/version.py @@ -1,5 +1,5 @@ # ----------------------------------------------------------------------------- -# Copyright (c) 2025, Oracle and/or its affiliates. +# Copyright (c) 2026, Oracle and/or its affiliates. # # Licensed under the Universal Permissive License v 1.0 as shown at # http://oss.oracle.com/licenses/upl. From efc6bc17d351b89d1c88081982c006473f31189a Mon Sep 17 00:00:00 2001 From: Abhishek Singh Date: Sat, 7 Mar 2026 18:10:14 -0800 Subject: [PATCH 4/6] Enhancements and bug fixes to support connection pool --- src/select_ai/async_profile.py | 123 ++++++++++++++++++++++++++++----- src/select_ai/db.py | 5 +- src/select_ai/profile.py | 111 +++++++++++++++++++++++++---- tests/conftest.py | 6 +- 4 files changed, 210 insertions(+), 35 deletions(-) diff --git a/src/select_ai/async_profile.py b/src/select_ai/async_profile.py index aa1e3c9..0881b6b 100644 --- a/src/select_ai/async_profile.py +++ b/src/select_ai/async_profile.py @@ -28,7 +28,11 @@ validate_params_for_summary, ) from select_ai.conversation import AsyncConversation -from select_ai.db import async_cursor, async_get_connection +from select_ai.db import ( + AsyncConnectionManager, + async_cursor, + async_get_connection, +) from select_ai.errors import ( ProfileAttributesEmptyError, ProfileNotFoundError, @@ -406,8 +410,12 @@ async def list( raise_error_on_empty_attributes=False, ) - async def generate( - self, prompt: str, action=Action.SHOWSQL, params: Mapping = None + async def _generate_with_cursor( + self, + cr, + prompt: str, + action=Action.SHOWSQL, + params: Mapping = None, ) -> Union[pandas.DataFrame, str, None]: """Asynchronously perform AI translation using this profile @@ -429,12 +437,11 @@ async def generate( if params: parameters["params"] = json.dumps(params) - async with async_cursor() as cr: - data = await cr.callfunc( - "DBMS_CLOUD_AI.GENERATE", - oracledb.DB_TYPE_CLOB, - keyword_parameters=parameters, - ) + data = await cr.callfunc( + "DBMS_CLOUD_AI.GENERATE", + oracledb.DB_TYPE_CLOB, + keyword_parameters=parameters, + ) if data is not None: result = await data.read() else: @@ -444,6 +451,22 @@ async def generate( else: return result + async def generate( + self, prompt: str, action=Action.SHOWSQL, params: Mapping = None + ) -> Union[pandas.DataFrame, str, None]: + """Asynchronously perform AI translation using this profile + + :param str prompt: Natural language prompt to translate + :param select_ai.profile.Action action: + :param params: Parameters to include in the LLM request. For e.g. + conversation_id for context-aware chats + :return: Union[pandas.DataFrame, str] + """ + async with async_cursor() as cr: + return await self._generate_with_cursor( + cr, prompt=prompt, action=action, params=params + ) + async def chat(self, prompt, params: Mapping = None) -> str: """Asynchronously chat with the LLM @@ -471,8 +494,10 @@ async def chat_session( ): await conversation.create() params = {"conversation_id": conversation.conversation_id} - async_session = AsyncSession(async_profile=self, params=params) - yield async_session + async with AsyncSession( + async_profile=self, params=params + ) as async_session: + yield async_session finally: if delete: await conversation.delete() @@ -623,10 +648,10 @@ async def run_pipeline( return_type=oracledb.DB_TYPE_CLOB, keyword_parameters=parameters, ) - async_connection = await async_get_connection() - pipeline_results = await async_connection.run_pipeline( - pipeline, continue_on_error=continue_on_error - ) + async with async_get_connection() as async_connection: + pipeline_results = await async_connection.run_pipeline( + pipeline, continue_on_error=continue_on_error + ) responses = [] for result in pipeline_results: if not result.error: @@ -679,12 +704,76 @@ def __init__(self, async_profile: AsyncProfile, params: Mapping): """ self.params = params self.async_profile = async_profile + self._conn = None + self._conn_cm = None + self._cursor = None async def chat(self, prompt: str): - return await self.async_profile.chat(prompt=prompt, params=self.params) + return await self.async_profile._generate_with_cursor( + self._cursor, prompt=prompt, action=Action.CHAT, params=self.params + ) + + async def narrate(self, prompt) -> str: + """Narrate the result of the SQL + + :param str prompt: Natural language prompt + :return: str + """ + return await self.async_profile._generate_with_cursor( + self._cursor, prompt, action=Action.NARRATE, params=self.params + ) + + async def explain_sql(self, prompt: str) -> str: + """Explain the generated SQL + + :param str prompt: Natural language prompt + :return: str + """ + return await self.async_profile._generate_with_cursor( + self._cursor, prompt, action=Action.EXPLAINSQL, params=self.params + ) + + async def run_sql(self, prompt: str) -> pandas.DataFrame: + """Explain the generated SQL + + :param str prompt: Natural language prompt + :return: pandas.DataFrame + """ + return await self.async_profile._generate_with_cursor( + self._cursor, prompt, action=Action.RUNSQL, params=self.params + ) + + async def show_sql(self, prompt) -> str: + """Show the generated SQL + + :param str prompt: Natural language prompt + :return: str + """ + return await self.async_profile._generate_with_cursor( + self._cursor, prompt, action=Action.SHOWSQL, params=self.params + ) + + async def show_prompt(self, prompt: str) -> str: + """Show the prompt sent to LLM + + :param str prompt: Natural language prompt + :return: str + """ + return await self.async_profile._generate_with_cursor( + self._cursor, prompt, action=Action.SHOWPROMPT, params=self.params + ) async def __aenter__(self): + self._conn_cm = AsyncConnectionManager().get_connection() + self._conn = await self._conn_cm.__aenter__() + self._cursor = self._conn.cursor() return self async def __aexit__(self, exc_type, exc_val, exc_tb): - pass + if self._cursor is not None: + self._cursor.close() + if self._conn_cm is not None: + await self._conn_cm.__aexit__(exc_type, exc_val, exc_tb) + self._conn = None + self._conn_cm = None + self._cursor = None diff --git a/src/select_ai/db.py b/src/select_ai/db.py index 92a8d36..3a9e10f 100644 --- a/src/select_ai/db.py +++ b/src/select_ai/db.py @@ -86,7 +86,7 @@ def create_pool_async( password: str, dsn: str, min_size: Optional[int] = 1, - max_size: Optional[int] = 1, + max_size: Optional[int] = 2, increment: Optional[int] = 1, *args, **kwargs, @@ -99,6 +99,7 @@ def create_pool_async( max=max_size, increment=increment, connection_id_prefix="async-python-select-ai", + getmode=oracledb.POOL_GETMODE_NOWAIT, *args, **kwargs, ) @@ -193,7 +194,7 @@ def get_connection() -> Generator[Connection, Any, None]: @contextlib.asynccontextmanager async def async_get_connection() -> AsyncGenerator[Any, Any]: """Returns the AsyncConnection object if connection is healthy""" - async with ConnectionManager().get_connection() as conn: + async with AsyncConnectionManager().get_connection() as conn: yield conn diff --git a/src/select_ai/profile.py b/src/select_ai/profile.py index 7f46716..8b4a0d2 100644 --- a/src/select_ai/profile.py +++ b/src/select_ai/profile.py @@ -21,7 +21,7 @@ validate_params_for_feedback, validate_params_for_summary, ) -from select_ai.db import cursor +from select_ai.db import ConnectionManager, cursor from select_ai.errors import ( ProfileAttributesEmptyError, ProfileNotFoundError, @@ -379,8 +379,9 @@ def list( raise_error_on_empty_attributes=False, ) - def generate( + def _generate_with_cursor( self, + cr, prompt: str, action: Optional[Action] = Action.RUNSQL, params: Mapping = None, @@ -403,12 +404,11 @@ def generate( } if params: parameters["params"] = json.dumps(params) - with cursor() as cr: - data = cr.callfunc( - "DBMS_CLOUD_AI.GENERATE", - oracledb.DB_TYPE_CLOB, - keyword_parameters=parameters, - ) + data = cr.callfunc( + "DBMS_CLOUD_AI.GENERATE", + oracledb.DB_TYPE_CLOB, + keyword_parameters=parameters, + ) if data is not None: result = data.read() else: @@ -418,6 +418,25 @@ def generate( else: return result + def generate( + self, + prompt: str, + action: Optional[Action] = Action.RUNSQL, + params: Mapping = None, + ) -> Union[pandas.DataFrame, str, None]: + """Perform AI translation using this profile + + :param str prompt: Natural language prompt to translate + :param select_ai.profile.Action action: + :param params: Parameters to include in the LLM request. For e.g. + conversation_id for context-aware chats + :return: Union[pandas.DataFrame, str] + """ + with cursor() as cr: + return self._generate_with_cursor( + cr, prompt=prompt, action=action, params=params + ) + def chat(self, prompt: str, params: Mapping = None) -> str: """Chat with the LLM @@ -444,8 +463,8 @@ def chat_session(self, conversation: Conversation, delete: bool = False): ): conversation.create() params = {"conversation_id": conversation.conversation_id} - session = Session(profile=self, params=params) - yield session + with Session(profile=self, params=params) as session: + yield session finally: if delete: conversation.delete() @@ -604,13 +623,79 @@ def __init__(self, profile: Profile, params: Mapping): """ self.params = params self.profile = profile + self._conn = None + self._conn_cm = None + self._cursor = None def chat(self, prompt: str): - # params = {"conversation_id": self.conversation_id} - return self.profile.chat(prompt=prompt, params=self.params) + return self.profile._generate_with_cursor( + self._cursor, prompt=prompt, action=Action.CHAT, params=self.params + ) + + def narrate(self, prompt: str) -> str: + """Narrate the result of the SQL + + :param str prompt: Natural language prompt + :return: str + """ + return self.profile._generate_with_cursor( + self._cursor, prompt, action=Action.NARRATE, params=self.params + ) + + def explain_sql(self, prompt: str) -> str: + """Explain the generated SQL + + :param str prompt: Natural language prompt + :return: str + """ + return self.profile._generate_with_cursor( + self._cursor, prompt, action=Action.EXPLAINSQL, params=self.params + ) + + def run_sql(self, prompt: str) -> pandas.DataFrame: + """Run the generate SQL statement and return a pandas Dataframe built + using the result set + + :param str prompt: Natural language prompt + :return: pandas.DataFrame + """ + return self.profile._generate_with_cursor( + self._cursor, prompt, action=Action.RUNSQL, params=self.params + ) + + def show_sql(self, prompt: str) -> str: + """Show the generated SQL + + :param str prompt: Natural language prompt + :param params: Parameters to include in the LLM request + :return: str + """ + return self.profile._generate_with_cursor( + self._cursor, prompt, action=Action.SHOWSQL, params=self.params + ) + + def show_prompt(self, prompt: str) -> str: + """Show the prompt sent to LLM + + :param str prompt: Natural language prompt + :param params: Parameters to include in the LLM request + :return: str + """ + return self.profile._generate_with_cursor( + self._cursor, prompt, action=Action.SHOWPROMPT, params=self.params + ) def __enter__(self): + self._conn_cm = ConnectionManager().get_connection() + self._conn = self._conn_cm.__enter__() + self._cursor = self._conn.cursor() return self def __exit__(self, exc_type, exc_val, exc_tb): - pass + if self._cursor is not None: + self._cursor.close() + if self._conn_cm is not None: + self._conn_cm.__exit__(exc_type, exc_val, exc_tb) + self._conn = None + self._conn_cm = None + self._cursor = None diff --git a/tests/conftest.py b/tests/conftest.py index f7e8a73..bb6cc1b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -153,14 +153,14 @@ def setup_test_user(test_env): select_ai.disconnect() -@pytest.fixture(autouse=True, scope="session") +@pytest.fixture(autouse=True, scope="module") def connect(setup_test_user, test_env): select_ai.create_pool(**test_env.connect_params(use_pool=True)) yield select_ai.disconnect() -@pytest.fixture(autouse=True, scope="session") +@pytest.fixture(autouse=True, scope="module") async def async_connect(setup_test_user, test_env, anyio_backend): select_ai.create_pool_async(**test_env.connect_params(use_pool=True)) yield @@ -191,7 +191,7 @@ async def async_cursor(): yield cr -@pytest.fixture(autouse=True, scope="session") +@pytest.fixture(autouse=True, scope="module") def oci_credential(connect, test_env): credential = { "credential_name": PYSAI_OCI_CREDENTIAL_NAME, From bf1d5a540206128e000d14262bae6e081feb95f9 Mon Sep 17 00:00:00 2001 From: Abhishek Singh Date: Tue, 10 Mar 2026 10:38:18 -0700 Subject: [PATCH 5/6] Added missing params for VectorIndex --- src/select_ai/db.py | 10 ++++++++-- src/select_ai/vector_index.py | 19 +++++++++++++++++-- 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/src/select_ai/db.py b/src/select_ai/db.py index 3a9e10f..f690bb1 100644 --- a/src/select_ai/db.py +++ b/src/select_ai/db.py @@ -63,6 +63,8 @@ def create_pool( min_size: Optional[int] = 1, max_size: Optional[int] = 2, increment: Optional[int] = 1, + getmode=None, + wait_timeout: Optional[int] = None, *args, **kwargs, ): @@ -74,7 +76,8 @@ def create_pool( max=max_size, increment=increment, connection_id_prefix="python-select-ai", - getmode=oracledb.POOL_GETMODE_NOWAIT, + getmode=getmode, + wait_timeout=wait_timeout, *args, **kwargs, ) @@ -88,6 +91,8 @@ def create_pool_async( min_size: Optional[int] = 1, max_size: Optional[int] = 2, increment: Optional[int] = 1, + getmode=None, + wait_timeout: Optional[int] = None, *args, **kwargs, ): @@ -99,7 +104,8 @@ def create_pool_async( max=max_size, increment=increment, connection_id_prefix="async-python-select-ai", - getmode=oracledb.POOL_GETMODE_NOWAIT, + getmode=getmode, + wait_timeout=wait_timeout, *args, **kwargs, ) diff --git a/src/select_ai/vector_index.py b/src/select_ai/vector_index.py index f5d60ca..de936fc 100644 --- a/src/select_ai/vector_index.py +++ b/src/select_ai/vector_index.py @@ -48,6 +48,8 @@ class VectorIndexAttributes(SelectAIDataClass): :param int chunk_size: Text size of chunking the input data. :param int chunk_overlap: Specifies the amount of overlapping characters between adjacent chunks of text. + :param enable_sources: Provides document source links and filenames in RAG + output :param str location: Location of the object store. :param int match_limit: Specifies the maximum number of results to return in a vector search query @@ -74,6 +76,7 @@ class VectorIndexAttributes(SelectAIDataClass): chunk_size: Optional[int] = None chunk_overlap: Optional[int] = None + enable_sources: Optional[bool] = None location: Optional[str] = None match_limit: Optional[int] = None object_storage_credential_name: Optional[str] = None @@ -193,11 +196,16 @@ def _get_description(index_name) -> Union[str, None]: else: raise VectorIndexNotFoundError(index_name=index_name) - def create(self, replace: Optional[bool] = False): + def create( + self, + replace: Optional[bool] = False, + wait_for_completion: bool = False, + ): """Create a vector index in the database and populates the index with data from an object store bucket using an async scheduler job :param bool replace: Replace vector index if it exists + :param bool wait_for_completion: True to wait for index creation :return: None """ @@ -207,6 +215,7 @@ def create(self, replace: Optional[bool] = False): parameters = { "index_name": self.index_name, "attributes": self.attributes.json(), + "wait_for_completion": wait_for_completion, } if self.description: @@ -493,11 +502,16 @@ async def _get_description(index_name) -> Union[str, None]: else: raise VectorIndexNotFoundError(index_name=index_name) - async def create(self, replace: Optional[bool] = False) -> None: + async def create( + self, + replace: Optional[bool] = False, + wait_for_completion: Optional[bool] = False, + ) -> None: """Create a vector index in the database and populates it with data from an object store bucket using an async scheduler job :param bool replace: True to replace existing vector index + :param bool wait_for_completion: True to wait for index creation """ @@ -506,6 +520,7 @@ async def create(self, replace: Optional[bool] = False) -> None: parameters = { "index_name": self.index_name, "attributes": self.attributes.json(), + "wait_for_completion": wait_for_completion, } if self.description: parameters["description"] = self.description From 83539286df326b4f8dca5439db8b52be8d2c48c8 Mon Sep 17 00:00:00 2001 From: Abhishek Singh Date: Fri, 13 Mar 2026 18:23:38 -0700 Subject: [PATCH 6/6] Added vector_index next_refresh_timestamp() --- samples/async/vector_index_create.py | 2 +- samples/async/vector_index_fetch.py | 1 + samples/vector_index_create.py | 2 +- samples/vector_index_fetch.py | 1 + src/select_ai/sql.py | 26 +++----------- src/select_ai/vector_index.py | 54 ++++++++++++++++++++++++++++ 6 files changed, 63 insertions(+), 23 deletions(-) diff --git a/samples/async/vector_index_create.py b/samples/async/vector_index_create.py index 34f8a2d..ccae29a 100644 --- a/samples/async/vector_index_create.py +++ b/samples/async/vector_index_create.py @@ -51,7 +51,7 @@ async def main(): description="Vector index for conda environments", profile=async_profile, ) - await async_vector_index.create(replace=True) + await async_vector_index.create(replace=True, wait_for_completion=True) print("Created vector index: test_vector_index") diff --git a/samples/async/vector_index_fetch.py b/samples/async/vector_index_fetch.py index 27aba54..da1d7ab 100644 --- a/samples/async/vector_index_fetch.py +++ b/samples/async/vector_index_fetch.py @@ -28,6 +28,7 @@ async def main(): ) print(async_vector_index.attributes) print(async_vector_index.profile) + print(await async_vector_index.get_next_refresh_timestamp()) asyncio.run(main()) diff --git a/samples/vector_index_create.py b/samples/vector_index_create.py index c29c21b..839283c 100644 --- a/samples/vector_index_create.py +++ b/samples/vector_index_create.py @@ -56,5 +56,5 @@ description="Test vector index", profile=profile, ) -vector_index.create(replace=True) +vector_index.create(replace=True, wait_for_completion=True) print("Created vector index: test_vector_index") diff --git a/samples/vector_index_fetch.py b/samples/vector_index_fetch.py index 7530064..e43076b 100644 --- a/samples/vector_index_fetch.py +++ b/samples/vector_index_fetch.py @@ -24,3 +24,4 @@ vector_index = select_ai.VectorIndex.fetch(index_name="test_vector_index") print(vector_index.attributes) print(vector_index.profile) +print(vector_index.get_next_refresh_timestamp()) diff --git a/src/select_ai/sql.py b/src/select_ai/sql.py index 481ba61..7aeab78 100644 --- a/src/select_ai/sql.py +++ b/src/select_ai/sql.py @@ -116,25 +116,9 @@ WHERE conversation_id = :conversation_id """ -CHECK_USER_PRIVILEGES = """ -WITH direct_privs AS ( - SELECT table_name - FROM all_tab_privs - WHERE grantee = USER - AND privilege = 'EXECUTE' - AND table_name IN ({placeholders}) - ), - role_privs AS ( - SELECT rtp.table_name - FROM session_roles sr - JOIN role_tab_privs rtp - ON rtp.role = sr.role - WHERE rtp.privilege = 'EXECUTE' - AND rtp.table_name IN ({placeholders}) - ) - SELECT DISTINCT table_name - FROM direct_privs - UNION - SELECT DISTINCT table_name - FROM role_privs + +GET_VECTOR_PIPELINE_LAST_EXECUTION = """ +SELECT CAST(last_execution AT TIME ZONE 'UTC' AS TIMESTAMP) +FROM USER_CLOUD_PIPELINES +WHERE pipeline_name = :pipeline_name """ diff --git a/src/select_ai/vector_index.py b/src/select_ai/vector_index.py index de936fc..1685c67 100644 --- a/src/select_ai/vector_index.py +++ b/src/select_ai/vector_index.py @@ -8,6 +8,7 @@ import json from abc import ABC from dataclasses import dataclass +from datetime import datetime, timedelta, timezone from typing import AsyncGenerator, Iterator, Optional, Union import oracledb @@ -22,6 +23,7 @@ from select_ai.sql import ( GET_USER_VECTOR_INDEX, GET_USER_VECTOR_INDEX_ATTRIBUTES, + GET_VECTOR_PIPELINE_LAST_EXECUTION, LIST_USER_VECTOR_INDEXES, ) @@ -411,6 +413,33 @@ def get_attributes(self) -> VectorIndexAttributes: """ return self._get_attributes(self.index_name) + def get_next_refresh_timestamp(self) -> Optional[datetime]: + """ + Returns the UTC timestamp of the next scheduled refresh + """ + if not self.index_name: + raise AttributeError("'index_name' is required") + attributes = self.attributes or self.get_attributes() + self.attributes = attributes + refresh_rate = attributes.refresh_rate + if refresh_rate is None: + return None + pipeline_name = ( + attributes.pipeline_name + or f"{self.index_name.upper()}$VECPIPELINE" + ) + with cursor() as cr: + cr.execute( + GET_VECTOR_PIPELINE_LAST_EXECUTION, + pipeline_name=pipeline_name, + ) + row = cr.fetchone() + if not row or row[0] is None: + return None + last_execution = row[0] + naive_ts = last_execution + timedelta(minutes=int(refresh_rate)) + return naive_ts.astimezone(timezone.utc) + def get_profile(self) -> Profile: """Get Profile object linked to this vector index @@ -710,6 +739,31 @@ async def get_attributes(self) -> VectorIndexAttributes: """ return await self._get_attributes(index_name=self.index_name) + async def get_next_refresh_timestamp(self) -> Optional[datetime]: + """Return the UTC timestamp for the next scheduled refresh.""" + if not self.index_name: + raise AttributeError("'index_name' is required") + attributes = self.attributes or await self.get_attributes() + self.attributes = attributes + refresh_rate = attributes.refresh_rate + if refresh_rate is None: + return None + pipeline_name = ( + attributes.pipeline_name + or f"{self.index_name.upper()}$VECPIPELINE" + ) + async with async_cursor() as cr: + await cr.execute( + GET_VECTOR_PIPELINE_LAST_EXECUTION, + pipeline_name=pipeline_name, + ) + row = await cr.fetchone() + if not row or row[0] is None: + return None + last_execution = row[0] + naive_ts = last_execution + timedelta(minutes=int(refresh_rate)) + return naive_ts.astimezone(timezone.utc) + async def get_profile(self) -> AsyncProfile: """Get AsyncProfile object linked to this vector index