diff --git a/README.md b/README.md index 351da16..38a930e 100644 --- a/README.md +++ b/README.md @@ -8,11 +8,18 @@ Select AI for Python enables you to leverage the broader Python ecosystem in com ## Installation -Run +Install the Python package: + ```bash python3 -m pip install select_ai ``` +Install the optional command line interface: + +```bash +python3 -m pip install 'select_ai[cli]' +``` + ## Documentation See [Select AI for Python documentation][documentation] @@ -21,6 +28,17 @@ See [Select AI for Python documentation][documentation] Examples can be found in the [/samples][samples] directory +## Command Line Interface + +The optional `select-ai` command provides an interactive chat REPL for Select AI +profiles: + +```bash +select-ai chat --profile OCI_AI_PROFILE +``` + +![Select AI CLI demo](doc/source/image/select_ai_cli_demo.gif) + ### Basic Example ```python diff --git a/doc/source/image/select_ai_cli_demo.gif b/doc/source/image/select_ai_cli_demo.gif new file mode 100644 index 0000000..4235309 Binary files /dev/null and b/doc/source/image/select_ai_cli_demo.gif differ diff --git a/doc/source/image/select_ai_cli_demo.png b/doc/source/image/select_ai_cli_demo.png new file mode 100644 index 0000000..260f184 Binary files /dev/null and b/doc/source/image/select_ai_cli_demo.png differ diff --git a/doc/source/index.rst b/doc/source/index.rst index 55b2e84..00ece82 100644 --- a/doc/source/index.rst +++ b/doc/source/index.rst @@ -78,6 +78,15 @@ Profile user_guide/profile.rst +Command Line Interface +====================== + +.. toctree:: + :numbered: + :maxdepth: 3 + + user_guide/cli.rst + Conversation ============ @@ -125,3 +134,12 @@ AI Agent :maxdepth: 3 user_guide/agent.rst + +Web Frameworks +============== + +.. toctree:: + :numbered: + :maxdepth: 3 + + user_guide/web_frameworks.rst diff --git a/doc/source/user_guide/agent.rst b/doc/source/user_guide/agent.rst index ad549e7..f2059da 100644 --- a/doc/source/user_guide/agent.rst +++ b/doc/source/user_guide/agent.rst @@ -45,11 +45,6 @@ Python layer and persist the tool in the Database using - ``recipient`` - ``sender`` - ``smtp_host`` - * - ``HTTP`` - - ``select_ai.agent.Tool.create_http_tool`` - - - ``tool_name`` - - ``credential_name`` - - ``endpoint`` * - ``SQL`` - ``select_ai.agent.Tool.create_sql_tool`` - - ``tool_name`` @@ -288,6 +283,58 @@ output:: .. latex:clearpage:: +Export and Import Team +++++++++++++++++++++++ + +Select AI agent teams can be exported into a portable specification and +imported into the same database, a different database, or another Select AI +service. The specification describes the team composition and the associated +agent, task, and tool definitions that are needed to recreate the team. + +``Team.export_team()`` returns the specification as a JSON string by default. +``Team.import_team()`` accepts either that JSON string or a Python mapping. On +import, ``profile_name`` identifies the Select AI profile to use in the target +database. ``team_name`` can be provided to create the imported team under a new +name; this is useful when importing into the same database as the source team. + +If imported object names conflict with existing agents, tasks, tools, or teams, +set ``force=True`` to let the database replace the conflicting objects. Use this +carefully when importing into a shared schema because conflicting components can +be dropped and recreated. + +.. literalinclude:: ../../../samples/agent/team_export_import.py + :language: python + :lines: 14- + +output:: + + Exported specification: + { + "name": "EXPORT_IMPORT_MOVIE_ANALYST", + "component_type": "Agent", + "task": { + "task_name": "EXPORT_IMPORT_MOVIE_TASK", + "instruction": "Help the user with movie questions. Question: {query}", + "task_attributes": { + "enable_human_tool": "false", + "tools": [] + } + }, + "llm_config": { + "name": "LLAMA_4_MAVERICK", + "component_type": "oci" + } + } + Imported team: Team(team_name=IMPORTED_MOVIE_AGENT_TEAM, ...) + +The same APIs can also read from or write to object storage by passing both +``object_storage_credential_name`` and ``location``. When exporting to object +storage, ``Team.export_team()`` writes the specification to the location and +returns ``None``. When importing from object storage, pass the same credential +and location instead of ``specification``. + +.. latex:clearpage:: + ***************** AI agent examples ***************** diff --git a/doc/source/user_guide/async_agent.rst b/doc/source/user_guide/async_agent.rst index 9d64737..3c2cf8a 100644 --- a/doc/source/user_guide/async_agent.rst +++ b/doc/source/user_guide/async_agent.rst @@ -19,11 +19,6 @@ - ``recipient`` - ``sender`` - ``smtp_host`` - * - ``HTTP`` - - ``select_ai.agent.AsyncTool.create_http_tool`` - - - ``tool_name`` - - ``credential_name`` - - ``endpoint`` * - ``SQL`` - ``select_ai.agent.AsyncTool.create_sql_tool`` - - ``tool_name`` @@ -227,6 +222,60 @@ output:: .. latex:clearpage:: +Export and Import Team +++++++++++++++++++++++ + +Select AI agent teams can be exported into a portable specification and +imported into the same database, a different database, or another Select AI +service. The specification describes the team composition and the associated +agent, task, and tool definitions that are needed to recreate the team. + +``AsyncTeam.export_team()`` returns the specification as a JSON string by +default. ``AsyncTeam.import_team()`` accepts either that JSON string or a Python +mapping. On import, ``profile_name`` identifies the Select AI profile to use in +the target database. ``team_name`` can be provided to create the imported team +under a new name; this is useful when importing into the same database as the +source team. + +If imported object names conflict with existing agents, tasks, tools, or teams, +set ``force=True`` to let the database replace the conflicting objects. Use this +carefully when importing into a shared schema because conflicting components can +be dropped and recreated. + +.. literalinclude:: ../../../samples/agent/async/team_export_import.py + :language: python + :lines: 14- + +output:: + + Exported specification: + { + "name": "EXPORT_IMPORT_MOVIE_ANALYST", + "component_type": "Agent", + "task": { + "task_name": "EXPORT_IMPORT_MOVIE_TASK", + "instruction": "Help the user with movie questions. Question: {query}", + "task_attributes": { + "enable_human_tool": "false", + "tools": [] + } + }, + "llm_config": { + "name": "LLAMA_4_MAVERICK", + "component_type": "oci" + } + } + Imported team: AsyncTeam(team_name=IMPORTED_MOVIE_AGENT_TEAM, ...) + +The same APIs can also read from or write to object storage by passing both +``object_storage_credential_name`` and ``location``. When exporting to object +storage, ``AsyncTeam.export_team()`` writes the specification to the location +and returns ``None``. When importing from object storage, pass the same +credential and location instead of ``specification``. + +.. latex:clearpage:: + + List Teams ++++++++++ diff --git a/doc/source/user_guide/async_profile.rst b/doc/source/user_guide/async_profile.rst index 8280857..b602f5e 100644 --- a/doc/source/user_guide/async_profile.rst +++ b/doc/source/user_guide/async_profile.rst @@ -177,6 +177,27 @@ output:: .. latex:clearpage:: +*********************** +Async streaming chat +*********************** + +.. literalinclude:: ../../../samples/async/profile_chat_stream.py + :language: python + :lines: 14- + +``stream=True`` lets callers consume generated CLOB responses chunk by chunk, +reducing memory pressure and making it easier to progressively forward output +to files, services, or user interfaces. Async streaming text APIs return an +async iterator of ``str`` chunks after the awaited method call. The +``chunk_size`` parameter controls the number of CLOB characters read per chunk; +it is not a byte count. + +Streaming is supported by ``generate()``, ``chat()``, ``narrate()``, +``explain_sql()``, ``show_sql()``, and ``show_prompt()``. It is not supported +for ``run_sql()``, which returns a ``pandas.DataFrame``. + +.. latex:clearpage:: + ************************** Summarize ************************** diff --git a/doc/source/user_guide/cli.rst b/doc/source/user_guide/cli.rst new file mode 100644 index 0000000..0f63ad5 --- /dev/null +++ b/doc/source/user_guide/cli.rst @@ -0,0 +1,94 @@ +.. _cli: + +************************** +Command line interface +************************** + +.. only:: html + + .. image:: /image/select_ai_cli_demo.gif + :alt: Select AI CLI demo + :width: 100% + +.. only:: latex + + .. image:: /image/select_ai_cli_demo.png + :alt: Select AI CLI demo + :width: 100% + +The package provides an optional ``select-ai`` command line tool. Install the +CLI extra to use it: + +.. code-block:: bash + + pip install 'select_ai[cli]' + +Set the database connection details as environment variables, or pass them as +command line options: + +.. code-block:: bash + + export SELECT_AI_USER= + export SELECT_AI_PASSWORD= + export SELECT_AI_DB_CONNECT_STRING= + +Interactive chat +================ + +The ``chat`` subcommand starts an interactive profile chat REPL. Pass an +existing Select AI profile with ``--profile``: + +.. code-block:: bash + + select-ai chat --profile OCI_AI_PROFILE + +The REPL uses ``Profile.chat_session()`` so prompts in the same terminal session +share conversation context. Responses stream by default. Use ``--no-stream`` to +print each response after it is fully generated. + +.. code-block:: text + + Connected to Select AI profile: OCI_AI_PROFILE + Type /help for commands. Type /exit to quit. + select_ai> What tables can I ask about? + ... + select_ai> /exit + +Useful options: + +- ``--user``, ``--password``, and ``--dsn`` override the environment values. +- ``--wallet-location`` and ``--wallet-password`` configure wallet connections. +- ``--chunk-size`` controls the number of CLOB characters read per stream chunk. +- ``--conversation-length`` controls how many prompts are retained in context. +- ``--keep-conversation`` keeps the database conversation after the REPL exits. + +SQL commands +============ + +SQL operations are one-shot subcommands instead of a REPL: + +.. code-block:: bash + + select-ai sql show --profile OCI_AI_PROFILE "count movies by genre" + select-ai sql run --profile OCI_AI_PROFILE "count movies by genre" + select-ai sql explain --profile OCI_AI_PROFILE "count movies by genre" + select-ai sql narrate --profile OCI_AI_PROFILE "count movies by genre" + +Profile commands +================ + +Summarize and translate are available under the ``profile`` command group: + +.. code-block:: bash + + select-ai profile list + select-ai profile list --pattern "OCI.*" + + select-ai profile summarize --profile OCI_AI_PROFILE "Text to summarize" + select-ai profile summarize --profile OCI_AI_PROFILE --file notes.txt + + select-ai profile translate \ + --profile OCI_AI_PROFILE \ + --source-language English \ + --target-language German \ + "Thank you" diff --git a/doc/source/user_guide/privileges.rst b/doc/source/user_guide/privileges.rst index b61d135..945bab9 100644 --- a/doc/source/user_guide/privileges.rst +++ b/doc/source/user_guide/privileges.rst @@ -37,6 +37,7 @@ output:: Granted privileges to: +.. latex:clearpage:: **************** Revoke privilege @@ -52,4 +53,55 @@ Similarly, to revoke use the method output:: - Granted privileges to: + Revoked privileges from: + +.. latex:clearpage:: + +*************************** +Grant network access +*************************** + +Connect as admin and run +``select_ai.grant_network_access(...)`` to add a network ACL entry for +host access. This wraps ``DBMS_NETWORK_ACL_ADMIN.APPEND_HOST_ACE`` and can be +used for hosts that require privileges such as ``connect``, ``http``, or +``smtp``. + +.. literalinclude:: ../../../samples/grant_network_access.py + :language: python + :lines: 14- + +output:: + + Granted network access to: + +The async API is ``select_ai.async_grant_network_access(...)``. + +.. literalinclude:: ../../../samples/async/grant_network_access.py + :language: python + :lines: 14- + +.. latex:clearpage:: + +*************************** +Revoke network access +*************************** + +Connect as admin and run +``select_ai.revoke_network_access(...)`` to remove a network ACL entry for +host access. This wraps ``DBMS_NETWORK_ACL_ADMIN.REMOVE_HOST_ACE`` and should +use the same host, privileges, and port range that were used to grant access. + +.. literalinclude:: ../../../samples/revoke_network_access.py + :language: python + :lines: 14- + +output:: + + Revoked network access from: + +The async API is ``select_ai.async_revoke_network_access(...)``. + +.. literalinclude:: ../../../samples/async/revoke_network_access.py + :language: python + :lines: 14- diff --git a/doc/source/user_guide/profile.rst b/doc/source/user_guide/profile.rst index bc0dbc6..1cdc1fb 100644 --- a/doc/source/user_guide/profile.rst +++ b/doc/source/user_guide/profile.rst @@ -139,6 +139,26 @@ output:: .. latex:clearpage:: +************************** +Streaming chat +************************** + +.. literalinclude:: ../../../samples/profile_chat_stream.py + :language: python + :lines: 14- + +``stream=True`` lets callers consume generated CLOB responses chunk by chunk, +reducing memory pressure and making it easier to progressively forward output +to files, services, or user interfaces. Streaming text APIs return an iterator +of ``str`` chunks. The ``chunk_size`` parameter controls the number of CLOB +characters read per chunk; it is not a byte count. + +Streaming is supported by ``generate()``, ``chat()``, ``narrate()``, +``explain_sql()``, ``show_sql()``, and ``show_prompt()``. It is not supported +for ``run_sql()``, which returns a ``pandas.DataFrame``. + +.. latex:clearpage:: + ************************** Summarize ************************** diff --git a/doc/source/user_guide/web_frameworks.rst b/doc/source/user_guide/web_frameworks.rst new file mode 100644 index 0000000..40f354e --- /dev/null +++ b/doc/source/user_guide/web_frameworks.rst @@ -0,0 +1,182 @@ +.. _web_frameworks: + +************************************************** +Using ``select_ai`` with Python web frameworks +************************************************** + +Python web applications should create a Select AI connection pool when the +application starts and close it when the application shuts down. A pool lets +concurrent requests share a bounded set of database connections instead of +creating standalone connections per request. + +This pattern works with Python WSGI and ASGI frameworks. FastAPI is used below +as a concrete example, but the same approach applies to frameworks such as +Flask, Django, Starlette, Sanic, and Quart: initialize the pool during +application startup, use ``select_ai`` APIs inside request handlers, and close +the pool during application shutdown. + +For background and concurrency measurements, see this +`connection pooling blog `__. + +Install dependencies +==================== + +Install ``select_ai`` and FastAPI server dependencies: + +.. code-block:: sh + + python -m pip install select_ai fastapi uvicorn + +For local development, set the database connection details as environment +variables: + +.. code-block:: sh + + export SELECT_AI_USER= + export SELECT_AI_PASSWORD= + export SELECT_AI_DB_CONNECT_STRING= + export SELECT_AI_POOL_MIN=5 + export SELECT_AI_POOL_MAX=10 + export SELECT_AI_POOL_INCREMENT=5 + +If you use an mTLS wallet, also set ``TNS_ADMIN`` or pass wallet parameters to +``select_ai.create_pool()`` / ``select_ai.create_pool_async()``. + +FastAPI synchronous endpoints +============================= + +Create a file named ``app.py``: + +.. code-block:: python + + import os + from contextlib import asynccontextmanager + + from fastapi import FastAPI + + import select_ai + + user = os.getenv("SELECT_AI_USER") + password = os.getenv("SELECT_AI_PASSWORD") + dsn = os.getenv("SELECT_AI_DB_CONNECT_STRING") + + pool_min = int(os.getenv("SELECT_AI_POOL_MIN", "5")) + pool_max = int(os.getenv("SELECT_AI_POOL_MAX", "10")) + pool_increment = int(os.getenv("SELECT_AI_POOL_INCREMENT", "5")) + + + @asynccontextmanager + async def lifespan(app: FastAPI): + select_ai.create_pool( + user=user, + password=password, + dsn=dsn, + min_size=pool_min, + max_size=pool_max, + increment=pool_increment, + ) + yield + select_ai.disconnect() + + + app = FastAPI(lifespan=lifespan) + + + @app.get("/chat") + def chat(prompt: str): + profile = select_ai.Profile(profile_name="oci_ai_profile") + return {"response": profile.chat(prompt=prompt)} + + + @app.get("/show_sql") + def show_sql(prompt: str): + profile = select_ai.Profile(profile_name="oci_ai_profile") + return {"sql": profile.show_sql(prompt=prompt)} + +Start the server: + +.. code-block:: sh + + uvicorn app:app --host 0.0.0.0 --port 8000 + +Call the service: + +.. code-block:: sh + + curl "http://localhost:8000/chat?prompt=What%20is%20OCI%3F" + +Stop the server by pressing ``Ctrl+C`` in the terminal where ``uvicorn`` is +running. FastAPI runs the lifespan shutdown hook and ``select_ai.disconnect()`` +closes the pool. + +FastAPI asynchronous endpoints +============================== + +For async endpoints, initialize the async pool with +``select_ai.create_pool_async()`` and close it with +``select_ai.async_disconnect()``. + +.. code-block:: python + + import os + from contextlib import asynccontextmanager + + from fastapi import FastAPI + + import select_ai + + user = os.getenv("SELECT_AI_USER") + password = os.getenv("SELECT_AI_PASSWORD") + dsn = os.getenv("SELECT_AI_DB_CONNECT_STRING") + + + @asynccontextmanager + async def lifespan(app: FastAPI): + select_ai.create_pool_async( + user=user, + password=password, + dsn=dsn, + min_size=5, + max_size=10, + increment=5, + ) + yield + await select_ai.async_disconnect() + + + app = FastAPI(lifespan=lifespan) + + + @app.get("/chat") + async def chat(prompt: str): + profile = await select_ai.AsyncProfile( + profile_name="async_oci_ai_profile" + ) + return {"response": await profile.chat(prompt=prompt)} + +Start and stop the server the same way: + +.. code-block:: sh + + uvicorn app:app --host 0.0.0.0 --port 8000 + +Press ``Ctrl+C`` to stop it. + +Pool sizing +=========== + +Use connection pooling for concurrent services such as API applications, +workloads with mixed fast and slow requests, and applications with tail-latency +requirements. Use standalone connections for simple scripts, command-line +tools, or low-concurrency batch jobs. + +Set pool sizing based on expected request concurrency and database capacity. +In multi-worker deployments, each worker process creates its own pool, so total +possible database connections are approximately: + +.. code-block:: text + + workers * SELECT_AI_POOL_MAX + +Choose pool sizes that leave capacity for other database clients and avoid +overwhelming small database deployments. diff --git a/pyproject.toml b/pyproject.toml index 7758495..23eb72a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,9 @@ dependencies = [ ] [project.optional-dependencies] +cli = [ + "click", +] test = [ "anyio", "pytest", @@ -56,6 +59,9 @@ Repository = "https://github.com/oracle/python-select-ai" Issues = "https://github.com/oracle/python-select-ai/issues" Documentation = "https://oracle.github.io/python-select-ai/" +[project.scripts] +select-ai = "select_ai.cli.main:cli" + [tool.setuptools.packages.find] where = ["src"] diff --git a/samples/agent/async/team_export_import.py b/samples/agent/async/team_export_import.py new file mode 100644 index 0000000..9172923 --- /dev/null +++ b/samples/agent/async/team_export_import.py @@ -0,0 +1,91 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2026, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +# ----------------------------------------------------------------------------- +# agent/async/team_export_import.py +# +# Export a team specification and import it as a new team. +# ----------------------------------------------------------------------------- + +import asyncio +import json +import os + +import select_ai +from select_ai.agent import ( + AgentAttributes, + AsyncAgent, + AsyncTask, + AsyncTeam, + TaskAttributes, + TeamAttributes, +) + +user = os.getenv("SELECT_AI_USER") +password = os.getenv("SELECT_AI_PASSWORD") +dsn = os.getenv("SELECT_AI_DB_CONNECT_STRING") +profile_name = os.getenv("SELECT_AI_PROFILE_NAME", "LLAMA_4_MAVERICK") + + +async def main(): + await select_ai.async_connect(user=user, password=password, dsn=dsn) + + task = AsyncTask( + task_name="EXPORT_IMPORT_MOVIE_TASK", + description="Task used by the team export/import sample", + attributes=TaskAttributes( + instruction="Help the user with movie questions. Question: {query}", + tools=[], + enable_human_tool=False, + ), + ) + await task.create(replace=True) + + agent = AsyncAgent( + agent_name="EXPORT_IMPORT_MOVIE_ANALYST", + description="Agent used by the team export/import sample", + attributes=AgentAttributes( + profile_name=profile_name, + role="You are an AI Movie Analyst.", + enable_human_tool=False, + ), + ) + await agent.create(enabled=True, replace=True) + + source_team = AsyncTeam( + team_name="EXPORT_IMPORT_MOVIE_TEAM", + attributes=TeamAttributes( + agents=[ + { + "name": agent.agent_name, + "task": task.task_name, + } + ], + process="sequential", + ), + ) + await source_team.create(enabled=True, replace=True) + + specification = json.loads(await source_team.export()) + print("Exported specification:") + print(json.dumps(specification, indent=2)) + + specification["name"] = "IMPORTED_MOVIE_ANALYST" + specification["task"]["task_name"] = "IMPORTED_ANALYZE_MOVIE_TASK" + + await AsyncTeam.import_team( + profile_name=profile_name, + team_name="IMPORTED_MOVIE_AGENT_TEAM", + specification=specification, + force=True, + ) + + team = await AsyncTeam.fetch("IMPORTED_MOVIE_AGENT_TEAM") + print("Imported team:", team) + + +asyncio.run(main()) diff --git a/samples/agent/team_export_import.py b/samples/agent/team_export_import.py new file mode 100644 index 0000000..066c58f --- /dev/null +++ b/samples/agent/team_export_import.py @@ -0,0 +1,85 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2026, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +# ----------------------------------------------------------------------------- +# team_export_import.py +# +# Export a team specification and import it as a new team. +# ----------------------------------------------------------------------------- + +import json +import os + +import select_ai +from select_ai.agent import ( + Agent, + AgentAttributes, + Task, + TaskAttributes, + Team, + TeamAttributes, +) + +user = os.getenv("SELECT_AI_USER") +password = os.getenv("SELECT_AI_PASSWORD") +dsn = os.getenv("SELECT_AI_DB_CONNECT_STRING") +profile_name = os.getenv("SELECT_AI_PROFILE_NAME", "LLAMA_4_MAVERICK") + +select_ai.connect(user=user, password=password, dsn=dsn) + +task = Task( + task_name="EXPORT_IMPORT_MOVIE_TASK", + description="Task used by the team export/import sample", + attributes=TaskAttributes( + instruction="Help the user with movie questions. Question: {query}", + tools=[], + enable_human_tool=False, + ), +) +task.create(replace=True) + +agent = Agent( + agent_name="EXPORT_IMPORT_MOVIE_ANALYST", + description="Agent used by the team export/import sample", + attributes=AgentAttributes( + profile_name=profile_name, + role="You are an AI Movie Analyst.", + enable_human_tool=False, + ), +) +agent.create(enabled=True, replace=True) + +source_team = Team( + team_name="EXPORT_IMPORT_MOVIE_TEAM", + attributes=TeamAttributes( + agents=[ + { + "name": agent.agent_name, + "task": task.task_name, + } + ], + process="sequential", + ), +) +source_team.create(enabled=True, replace=True) + +specification = json.loads(source_team.export()) +print("Exported specification:") +print(json.dumps(specification, indent=2)) + +specification["name"] = "IMPORTED_MOVIE_ANALYST" +specification["task"]["task_name"] = "IMPORTED_ANALYZE_MOVIE_TASK" + +Team.import_team( + profile_name=profile_name, + team_name="IMPORTED_MOVIE_AGENT_TEAM", + specification=specification, + force=True, +) + +team = Team.fetch("IMPORTED_MOVIE_AGENT_TEAM") +print("Imported team:", team) diff --git a/samples/async/grant_network_access.py b/samples/async/grant_network_access.py new file mode 100644 index 0000000..a406604 --- /dev/null +++ b/samples/async/grant_network_access.py @@ -0,0 +1,37 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2026, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +# ----------------------------------------------------------------------------- +# async/grant_network_access.py +# +# Add a network ACL entry for host access +# ----------------------------------------------------------------------------- + +import asyncio +import os + +import select_ai + +admin_user = os.getenv("SELECT_AI_ADMIN_USER") +password = os.getenv("SELECT_AI_ADMIN_PASSWORD") +dsn = os.getenv("SELECT_AI_DB_CONNECT_STRING") +select_ai_user = os.getenv("SELECT_AI_USER") + + +async def main(): + await select_ai.async_connect(user=admin_user, password=password, dsn=dsn) + await select_ai.async_grant_network_access( + users=select_ai_user, + host="smtp.example.com", + privileges=["connect", "smtp"], + lower_port=587, + upper_port=587, + ) + print("Granted network access to: ", select_ai_user) + + +asyncio.run(main()) diff --git a/samples/async/profile_chat_stream.py b/samples/async/profile_chat_stream.py new file mode 100644 index 0000000..2a22bb2 --- /dev/null +++ b/samples/async/profile_chat_stream.py @@ -0,0 +1,38 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2026, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +# ----------------------------------------------------------------------------- +# async/profile_chat_stream.py +# +# Stream chat response chunks using an async AI Profile +# ----------------------------------------------------------------------------- + +import asyncio +import os + +import select_ai + +user = os.getenv("SELECT_AI_USER") +password = os.getenv("SELECT_AI_PASSWORD") +dsn = os.getenv("SELECT_AI_DB_CONNECT_STRING") + + +async def main(): + await select_ai.async_connect(user=user, password=password, dsn=dsn) + async_profile = await select_ai.AsyncProfile( + profile_name="async_oci_ai_profile" + ) + + chunks = await async_profile.chat( + prompt="What is OCI ?", stream=True, chunk_size=4096 + ) + async for chunk in chunks: + print(chunk, end="") + print() + + +asyncio.run(main()) diff --git a/samples/async/revoke_network_access.py b/samples/async/revoke_network_access.py new file mode 100644 index 0000000..8177d7d --- /dev/null +++ b/samples/async/revoke_network_access.py @@ -0,0 +1,37 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2026, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +# ----------------------------------------------------------------------------- +# async/revoke_network_access.py +# +# Remove a network ACL entry for host access +# ----------------------------------------------------------------------------- + +import asyncio +import os + +import select_ai + +admin_user = os.getenv("SELECT_AI_ADMIN_USER") +password = os.getenv("SELECT_AI_ADMIN_PASSWORD") +dsn = os.getenv("SELECT_AI_DB_CONNECT_STRING") +select_ai_user = os.getenv("SELECT_AI_USER") + + +async def main(): + await select_ai.async_connect(user=admin_user, password=password, dsn=dsn) + await select_ai.async_revoke_network_access( + users=select_ai_user, + host="smtp.example.com", + privileges=["connect", "smtp"], + lower_port=587, + upper_port=587, + ) + print("Revoked network access from: ", select_ai_user) + + +asyncio.run(main()) diff --git a/samples/grant_network_access.py b/samples/grant_network_access.py new file mode 100644 index 0000000..da1ada9 --- /dev/null +++ b/samples/grant_network_access.py @@ -0,0 +1,31 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2026, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +# ----------------------------------------------------------------------------- +# grant_network_access.py +# +# Add a network ACL entry for host access +# ----------------------------------------------------------------------------- + +import os + +import select_ai + +admin_user = os.getenv("SELECT_AI_ADMIN_USER") +password = os.getenv("SELECT_AI_ADMIN_PASSWORD") +dsn = os.getenv("SELECT_AI_DB_CONNECT_STRING") +select_ai_user = os.getenv("SELECT_AI_USER") + +select_ai.connect(user=admin_user, password=password, dsn=dsn) +select_ai.grant_network_access( + users=select_ai_user, + host="smtp.example.com", + privileges=["connect", "smtp"], + lower_port=587, + upper_port=587, +) +print("Granted network access to: ", select_ai_user) diff --git a/samples/profile_chat_stream.py b/samples/profile_chat_stream.py new file mode 100644 index 0000000..44acf44 --- /dev/null +++ b/samples/profile_chat_stream.py @@ -0,0 +1,29 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2026, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +# ----------------------------------------------------------------------------- +# profile_chat_stream.py +# +# Stream chat response chunks using an AI Profile +# ----------------------------------------------------------------------------- + +import os + +import select_ai + +user = os.getenv("SELECT_AI_USER") +password = os.getenv("SELECT_AI_PASSWORD") +dsn = os.getenv("SELECT_AI_DB_CONNECT_STRING") + +select_ai.connect(user=user, password=password, dsn=dsn) +profile = select_ai.Profile(profile_name="oci_ai_profile") + +for chunk in profile.chat( + prompt="What is OCI ?", stream=True, chunk_size=4096 +): + print(chunk, end="") +print() diff --git a/samples/revoke_network_access.py b/samples/revoke_network_access.py new file mode 100644 index 0000000..0d5ea7a --- /dev/null +++ b/samples/revoke_network_access.py @@ -0,0 +1,31 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2026, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +# ----------------------------------------------------------------------------- +# revoke_network_access.py +# +# Remove a network ACL entry for host access +# ----------------------------------------------------------------------------- + +import os + +import select_ai + +admin_user = os.getenv("SELECT_AI_ADMIN_USER") +password = os.getenv("SELECT_AI_ADMIN_PASSWORD") +dsn = os.getenv("SELECT_AI_DB_CONNECT_STRING") +select_ai_user = os.getenv("SELECT_AI_USER") + +select_ai.connect(user=admin_user, password=password, dsn=dsn) +select_ai.revoke_network_access( + users=select_ai_user, + host="smtp.example.com", + privileges=["connect", "smtp"], + lower_port=587, + upper_port=587, +) +print("Revoked network access from: ", select_ai_user) diff --git a/src/select_ai/__init__.py b/src/select_ai/__init__.py index fa13d31..d1327a5 100644 --- a/src/select_ai/__init__.py +++ b/src/select_ai/__init__.py @@ -34,12 +34,16 @@ from .errors import * from .privilege import ( async_grant_http_access, + async_grant_network_access, async_grant_privileges, async_revoke_http_access, + async_revoke_network_access, async_revoke_privileges, grant_http_access, + grant_network_access, grant_privileges, revoke_http_access, + revoke_network_access, revoke_privileges, ) from .profile import Profile diff --git a/src/select_ai/agent/team.py b/src/select_ai/agent/team.py index 2a7ae6b..34ea8cb 100644 --- a/src/select_ai/agent/team.py +++ b/src/select_ai/agent/team.py @@ -1,5 +1,5 @@ # ------------------------------------------------------------------------------ -# Copyright (c) 2025, Oracle and/or its affiliates. +# Copyright (c) 2025, 2026, Oracle and/or its affiliates. # # Licensed under the Universal Permissive License v 1.0 as shown at # http://oss.oracle.com/licenses/upl. @@ -76,6 +76,27 @@ def __repr__(self): ) +def _json_or_none(value: Optional[Union[str, Mapping]]) -> Optional[str]: + if value is None: + return None + if isinstance(value, str): + return value + if isinstance(value, Mapping): + return json.dumps(value) + raise TypeError("value must be a JSON string or mapping") + + +def _validate_object_storage_location( + object_storage_credential_name: Optional[str], + location: Optional[str], +) -> None: + if bool(object_storage_credential_name) != bool(location): + raise ValueError( + "object_storage_credential_name and location must be specified " + "together" + ) + + class Team(BaseTeam): """ A Team of AI agents work together to accomplish tasks @@ -313,6 +334,177 @@ def run(self, prompt: str = None, params: Mapping = None): result = None return result + @classmethod + def export_team( + cls, + team_name: str, + object_storage_credential_name: Optional[str] = None, + location: Optional[str] = None, + params: Optional[Union[str, Mapping]] = None, + ) -> Optional[str]: + """ + Export an AI agent team specification. + + If object storage details are provided, the specification is written to + the given location and None is returned. Otherwise, the specification is + returned as a string. + + :param str team_name: Name of the AI agent team to export. + + :param str object_storage_credential_name: Optional credential name + used to write the exported specification to object storage. Must be + specified together with ``location``. + + :param str location: Optional object storage URI where the exported + specification should be written. Must be specified together with + ``object_storage_credential_name``. + + :param params: Optional export parameters. May be a JSON string or a + Python mapping. + :type params: str or Mapping + + :return: Exported team specification as a JSON string when exporting + inline, or None when exporting to object storage. + :rtype: str or None + """ + _validate_object_storage_location( + object_storage_credential_name, location + ) + parameters = { + "team_name": team_name, + } + params_json = _json_or_none(params) if params is not None else "{}" + if params_json is not None: + parameters["params"] = params_json + + with cursor() as cr: + if object_storage_credential_name: + parameters["object_storage_credential_name"] = ( + object_storage_credential_name + ) + parameters["location"] = location + cr.callproc( + "DBMS_CLOUD_AI_AGENT.EXPORT_TEAM", + keyword_parameters=parameters, + ) + return None + + data = cr.callfunc( + "DBMS_CLOUD_AI_AGENT.EXPORT_TEAM", + oracledb.DB_TYPE_CLOB, + keyword_parameters=parameters, + ) + return data.read() if data is not None else None + + @classmethod + def import_team( + cls, + profile_name: str, + team_name: Optional[str] = None, + specification: Optional[Union[str, Mapping]] = None, + object_storage_credential_name: Optional[str] = None, + location: Optional[str] = None, + force: Optional[bool] = False, + params: Optional[Union[str, Mapping]] = None, + ) -> None: + """ + Import an AI agent team specification and create the associated team, + agents, tasks, and tools in the database. + + :param str profile_name: Name of the Select AI profile to use for the + imported team and agents in the target database. + + :param str team_name: Optional name for the imported team. If omitted, + the team name from the specification is used. + + :param specification: Team specification to import. May be a JSON + string or a Python mapping. Omit this when importing from object + storage. + :type specification: str or Mapping + + :param str object_storage_credential_name: Optional credential name + used to read the specification from object storage. Must be specified + together with ``location``. + + :param str location: Optional object storage URI of the specification + to import. Must be specified together with + ``object_storage_credential_name``. + + :param bool force: Whether to replace conflicting database objects + during import. Default value is False. + + :param params: Optional import parameters. May be a JSON string or a + Python mapping. + :type params: str or Mapping + """ + _validate_object_storage_location( + object_storage_credential_name, location + ) + if specification is None and object_storage_credential_name is None: + raise ValueError( + "specification or object storage location must be specified" + ) + + parameters = { + "profile_name": profile_name, + "force": force, + } + if team_name is not None: + parameters["team_name"] = team_name + specification_json = _json_or_none(specification) + if specification_json is not None: + parameters["specification"] = specification_json + if object_storage_credential_name is not None: + parameters["object_storage_credential_name"] = ( + object_storage_credential_name + ) + parameters["location"] = location + params_json = _json_or_none(params) + if params_json is not None: + parameters["params"] = params_json + + with cursor() as cr: + cr.callproc( + "DBMS_CLOUD_AI_AGENT.IMPORT_TEAM", + keyword_parameters=parameters, + ) + + def export( + self, + object_storage_credential_name: Optional[str] = None, + location: Optional[str] = None, + params: Optional[Union[str, Mapping]] = None, + ) -> Optional[str]: + """ + Export this AI agent team specification. + + If object storage details are provided, the specification is written to + the given location and None is returned. Otherwise, the specification is + returned as a string. + + :param str object_storage_credential_name: Optional credential name + used to write the exported specification to object storage. Must be + specified together with ``location``. + + :param str location: Optional object storage URI where the exported + specification should be written. Must be specified together with + ``object_storage_credential_name``. + + :param params: Optional export parameters. May be a JSON string or a + Python mapping. + :type params: str or Mapping + + :return: Exported team specification as a JSON string when exporting + inline, or None when exporting to object storage. + :rtype: str or None + """ + return self.export_team( + team_name=self.team_name, + object_storage_credential_name=object_storage_credential_name, + location=location, + params=params, + ) + def set_attributes(self, attributes: TeamAttributes) -> None: """ Set the attributes of the AI Agent team @@ -588,6 +780,177 @@ async def run(self, prompt: str = None, params: Mapping = None): result = None return result + @classmethod + async def export_team( + cls, + team_name: str, + object_storage_credential_name: Optional[str] = None, + location: Optional[str] = None, + params: Optional[Union[str, Mapping]] = None, + ) -> Optional[str]: + """ + Export an AI agent team specification. + + If object storage details are provided, the specification is written to + the given location and None is returned. Otherwise, the specification is + returned as a string. + + :param str team_name: Name of the AI agent team to export. + + :param str object_storage_credential_name: Optional credential name + used to write the exported specification to object storage. Must be + specified together with ``location``. + + :param str location: Optional object storage URI where the exported + specification should be written. Must be specified together with + ``object_storage_credential_name``. + + :param params: Optional export parameters. May be a JSON string or a + Python mapping. + :type params: str or Mapping + + :return: Exported team specification as a JSON string when exporting + inline, or None when exporting to object storage. + :rtype: str or None + """ + _validate_object_storage_location( + object_storage_credential_name, location + ) + parameters = { + "team_name": team_name, + } + params_json = _json_or_none(params) if params is not None else "{}" + if params_json is not None: + parameters["params"] = params_json + + async with async_cursor() as cr: + if object_storage_credential_name: + parameters["object_storage_credential_name"] = ( + object_storage_credential_name + ) + parameters["location"] = location + await cr.callproc( + "DBMS_CLOUD_AI_AGENT.EXPORT_TEAM", + keyword_parameters=parameters, + ) + return None + + data = await cr.callfunc( + "DBMS_CLOUD_AI_AGENT.EXPORT_TEAM", + oracledb.DB_TYPE_CLOB, + keyword_parameters=parameters, + ) + return await data.read() if data is not None else None + + @classmethod + async def import_team( + cls, + profile_name: str, + team_name: Optional[str] = None, + specification: Optional[Union[str, Mapping]] = None, + object_storage_credential_name: Optional[str] = None, + location: Optional[str] = None, + force: Optional[bool] = False, + params: Optional[Union[str, Mapping]] = None, + ) -> None: + """ + Import an AI agent team specification and create the associated team, + agents, tasks, and tools in the database. + + :param str profile_name: Name of the Select AI profile to use for the + imported team and agents in the target database. + + :param str team_name: Optional name for the imported team. If omitted, + the team name from the specification is used. + + :param specification: Team specification to import. May be a JSON + string or a Python mapping. Omit this when importing from object + storage. + :type specification: str or Mapping + + :param str object_storage_credential_name: Optional credential name + used to read the specification from object storage. Must be specified + together with ``location``. + + :param str location: Optional object storage URI of the specification + to import. Must be specified together with + ``object_storage_credential_name``. + + :param bool force: Whether to replace conflicting database objects + during import. Default value is False. + + :param params: Optional import parameters. May be a JSON string or a + Python mapping. + :type params: str or Mapping + """ + _validate_object_storage_location( + object_storage_credential_name, location + ) + if specification is None and object_storage_credential_name is None: + raise ValueError( + "specification or object storage location must be specified" + ) + + parameters = { + "profile_name": profile_name, + "force": force, + } + if team_name is not None: + parameters["team_name"] = team_name + specification_json = _json_or_none(specification) + if specification_json is not None: + parameters["specification"] = specification_json + if object_storage_credential_name is not None: + parameters["object_storage_credential_name"] = ( + object_storage_credential_name + ) + parameters["location"] = location + params_json = _json_or_none(params) + if params_json is not None: + parameters["params"] = params_json + + async with async_cursor() as cr: + await cr.callproc( + "DBMS_CLOUD_AI_AGENT.IMPORT_TEAM", + keyword_parameters=parameters, + ) + + async def export( + self, + object_storage_credential_name: Optional[str] = None, + location: Optional[str] = None, + params: Optional[Union[str, Mapping]] = None, + ) -> Optional[str]: + """ + Export this AI agent team specification. + + If object storage details are provided, the specification is written to + the given location and None is returned. Otherwise, the specification is + returned as a string. + + :param str object_storage_credential_name: Optional credential name + used to write the exported specification to object storage. Must be + specified together with ``location``. + + :param str location: Optional object storage URI where the exported + specification should be written. Must be specified together with + ``object_storage_credential_name``. + + :param params: Optional export parameters. May be a JSON string or a + Python mapping. + :type params: str or Mapping + + :return: Exported team specification as a JSON string when exporting + inline, or None when exporting to object storage. + :rtype: str or None + """ + return await self.export_team( + team_name=self.team_name, + object_storage_credential_name=object_storage_credential_name, + location=location, + params=params, + ) + async def set_attributes(self, attributes: TeamAttributes) -> None: """ Set the attributes of the AI Agent team diff --git a/src/select_ai/agent/tool.py b/src/select_ai/agent/tool.py index 19a29e4..f2cd523 100644 --- a/src/select_ai/agent/tool.py +++ b/src/select_ai/agent/tool.py @@ -438,28 +438,6 @@ def create_email_notification_tool( instruction=instruction, ) - @classmethod - def create_http_tool( - cls, - tool_name: str, - credential_name: str, - endpoint: str, - description: Optional[str] = None, - replace: bool = False, - instruction: Optional[str] = None, - ) -> "Tool": - http_tool_params = HTTPToolParams( - credential_name=credential_name, endpoint=endpoint - ) - return cls.create_built_in_tool( - tool_name=tool_name, - tool_type=ToolType.HTTP, - tool_params=http_tool_params, - description=description, - replace=replace, - instruction=instruction, - ) - @classmethod def create_pl_sql_tool( cls, @@ -926,28 +904,6 @@ async def create_email_notification_tool( instruction=instruction, ) - @classmethod - async def create_http_tool( - cls, - tool_name: str, - credential_name: str, - endpoint: str, - description: Optional[str] = None, - replace: bool = False, - instruction: Optional[str] = None, - ) -> "AsyncTool": - http_tool_params = HTTPToolParams( - credential_name=credential_name, endpoint=endpoint - ) - return await cls.create_built_in_tool( - tool_name=tool_name, - tool_type=ToolType.HTTP, - tool_params=http_tool_params, - description=description, - replace=replace, - instruction=instruction, - ) - @classmethod async def create_pl_sql_tool( cls, diff --git a/src/select_ai/async_profile.py b/src/select_ai/async_profile.py index 0881b6b..43ce4cb 100644 --- a/src/select_ai/async_profile.py +++ b/src/select_ai/async_profile.py @@ -425,6 +425,28 @@ async def _generate_with_cursor( conversation_id for context-aware chats :return: Union[pandas.DataFrame, str] """ + parameters = self._generate_parameters(prompt, action, params) + + data = await cr.callfunc( + "DBMS_CLOUD_AI.GENERATE", + oracledb.DB_TYPE_CLOB, + keyword_parameters=parameters, + ) + if data is not None: + result = await data.read() + else: + result = None + if action == Action.RUNSQL: + return convert_json_rows_to_df(result) + else: + return result + + def _generate_parameters( + self, + prompt: str, + action, + params: Mapping = None, + ) -> Mapping: if not prompt: raise ValueError("prompt cannot be empty or None") @@ -436,45 +458,102 @@ async def _generate_with_cursor( } if params: parameters["params"] = json.dumps(params) + return parameters + async def _generate_stream( + self, + prompt: str, + action, + params: Mapping = None, + chunk_size: int = 8192, + ) -> AsyncGenerator[str, None]: + async with async_cursor() as cr: + async for chunk in self._generate_stream_with_cursor( + cr, + prompt=prompt, + action=action, + params=params, + chunk_size=chunk_size, + ): + yield chunk + + async def _generate_stream_with_cursor( + self, + cr, + prompt: str, + action, + params: Mapping = None, + chunk_size: int = 8192, + ) -> AsyncGenerator[str, None]: + if action == Action.RUNSQL: + raise ValueError("stream=True is not supported for run_sql") + if chunk_size <= 0: + raise ValueError("chunk_size must be greater than 0") + + parameters = self._generate_parameters(prompt, action, params) data = await cr.callfunc( "DBMS_CLOUD_AI.GENERATE", oracledb.DB_TYPE_CLOB, keyword_parameters=parameters, ) - if data is not None: - result = await data.read() - else: - result = None - if action == Action.RUNSQL: - return convert_json_rows_to_df(result) - else: - return result + if data is None: + return + + offset = 1 + while True: + chunk = await data.read(offset=offset, amount=chunk_size) + if not chunk: + break + yield chunk + offset += len(chunk) async def generate( - self, prompt: str, action=Action.SHOWSQL, params: Mapping = None - ) -> Union[pandas.DataFrame, str, None]: + self, + prompt: str, + action=Action.SHOWSQL, + params: Mapping = None, + stream: bool = False, + chunk_size: int = 8192, + ) -> Union[pandas.DataFrame, str, AsyncGenerator[str, None], None]: """Asynchronously perform AI translation using this profile :param str prompt: Natural language prompt to translate :param select_ai.profile.Action action: :param params: Parameters to include in the LLM request. For e.g. conversation_id for context-aware chats + :param bool stream: Return an async iterator of response chunks + :param int chunk_size: Number of characters to read per stream chunk :return: Union[pandas.DataFrame, str] """ + if stream: + return self._generate_stream(prompt, action, params, chunk_size) async with async_cursor() as cr: return await self._generate_with_cursor( cr, prompt=prompt, action=action, params=params ) - async def chat(self, prompt, params: Mapping = None) -> str: + async def chat( + self, + prompt, + params: Mapping = None, + stream: bool = False, + chunk_size: int = 8192, + ) -> Union[str, AsyncGenerator[str, None]]: """Asynchronously chat with the LLM :param str prompt: Natural language prompt :param params: Parameters to include in the LLM request + :param bool stream: Return an async iterator of response chunks + :param int chunk_size: Number of characters to read per stream chunk :return: str """ - return await self.generate(prompt, action=Action.CHAT, params=params) + return await self.generate( + prompt, + action=Action.CHAT, + params=params, + stream=stream, + chunk_size=chunk_size, + ) @asynccontextmanager async def chat_session( @@ -502,26 +581,50 @@ async def chat_session( if delete: await conversation.delete() - async def narrate(self, prompt, params: Mapping = None) -> str: + async def narrate( + self, + prompt, + params: Mapping = None, + stream: bool = False, + chunk_size: int = 8192, + ) -> Union[str, AsyncGenerator[str, None]]: """Narrate the result of the SQL :param str prompt: Natural language prompt :param params: Parameters to include in the LLM request + :param bool stream: Return an async iterator of response chunks + :param int chunk_size: Number of characters to read per stream chunk :return: str """ return await self.generate( - prompt, action=Action.NARRATE, params=params + prompt, + action=Action.NARRATE, + params=params, + stream=stream, + chunk_size=chunk_size, ) - async def explain_sql(self, prompt: str, params: Mapping = None): + async def explain_sql( + self, + prompt: str, + params: Mapping = None, + stream: bool = False, + chunk_size: int = 8192, + ): """Explain the generated SQL :param str prompt: Natural language prompt :param params: Parameters to include in the LLM request + :param bool stream: Return an async iterator of response chunks + :param int chunk_size: Number of characters to read per stream chunk :return: str """ return await self.generate( - prompt, action=Action.EXPLAINSQL, params=params + prompt, + action=Action.EXPLAINSQL, + params=params, + stream=stream, + chunk_size=chunk_size, ) async def run_sql( @@ -535,26 +638,50 @@ async def run_sql( """ return await self.generate(prompt, action=Action.RUNSQL, params=params) - async def show_sql(self, prompt, params: Mapping = None): + async def show_sql( + self, + prompt, + params: Mapping = None, + stream: bool = False, + chunk_size: int = 8192, + ): """Show the generated SQL :param str prompt: Natural language prompt :param params: Parameters to include in the LLM request + :param bool stream: Return an async iterator of response chunks + :param int chunk_size: Number of characters to read per stream chunk :return: str """ return await self.generate( - prompt, action=Action.SHOWSQL, params=params + prompt, + action=Action.SHOWSQL, + params=params, + stream=stream, + chunk_size=chunk_size, ) - async def show_prompt(self, prompt: str, params: Mapping = None): + async def show_prompt( + self, + prompt: str, + params: Mapping = None, + stream: bool = False, + chunk_size: int = 8192, + ): """Show the prompt sent to LLM :param str prompt: Natural language prompt :param params: Parameters to include in the LLM request + :param bool stream: Return an async iterator of response chunks + :param int chunk_size: Number of characters to read per stream chunk :return: str """ return await self.generate( - prompt, action=Action.SHOWPROMPT, params=params + prompt, + action=Action.SHOWPROMPT, + params=params, + stream=stream, + chunk_size=chunk_size, ) async def summarize( @@ -708,27 +835,61 @@ def __init__(self, async_profile: AsyncProfile, params: Mapping): self._conn_cm = None self._cursor = None - async def chat(self, prompt: str): + async def chat( + self, prompt: str, stream: bool = False, chunk_size: int = 8192 + ) -> Union[str, AsyncGenerator[str, None]]: + if stream: + return self.async_profile._generate_stream_with_cursor( + self._cursor, + prompt=prompt, + action=Action.CHAT, + params=self.params, + chunk_size=chunk_size, + ) return await self.async_profile._generate_with_cursor( self._cursor, prompt=prompt, action=Action.CHAT, params=self.params ) - async def narrate(self, prompt) -> str: + async def narrate( + self, prompt, stream: bool = False, chunk_size: int = 8192 + ) -> Union[str, AsyncGenerator[str, None]]: """Narrate the result of the SQL :param str prompt: Natural language prompt + :param bool stream: Return an async iterator of response chunks + :param int chunk_size: Number of characters to read per stream chunk :return: str """ + if stream: + return self.async_profile._generate_stream_with_cursor( + self._cursor, + prompt=prompt, + action=Action.NARRATE, + params=self.params, + chunk_size=chunk_size, + ) return await self.async_profile._generate_with_cursor( self._cursor, prompt, action=Action.NARRATE, params=self.params ) - async def explain_sql(self, prompt: str) -> str: + async def explain_sql( + self, prompt: str, stream: bool = False, chunk_size: int = 8192 + ) -> Union[str, AsyncGenerator[str, None]]: """Explain the generated SQL :param str prompt: Natural language prompt + :param bool stream: Return an async iterator of response chunks + :param int chunk_size: Number of characters to read per stream chunk :return: str """ + if stream: + return self.async_profile._generate_stream_with_cursor( + self._cursor, + prompt=prompt, + action=Action.EXPLAINSQL, + params=self.params, + chunk_size=chunk_size, + ) return await self.async_profile._generate_with_cursor( self._cursor, prompt, action=Action.EXPLAINSQL, params=self.params ) @@ -743,22 +904,46 @@ async def run_sql(self, prompt: str) -> pandas.DataFrame: self._cursor, prompt, action=Action.RUNSQL, params=self.params ) - async def show_sql(self, prompt) -> str: + async def show_sql( + self, prompt, stream: bool = False, chunk_size: int = 8192 + ) -> Union[str, AsyncGenerator[str, None]]: """Show the generated SQL :param str prompt: Natural language prompt + :param bool stream: Return an async iterator of response chunks + :param int chunk_size: Number of characters to read per stream chunk :return: str """ + if stream: + return self.async_profile._generate_stream_with_cursor( + self._cursor, + prompt=prompt, + action=Action.SHOWSQL, + params=self.params, + chunk_size=chunk_size, + ) return await self.async_profile._generate_with_cursor( self._cursor, prompt, action=Action.SHOWSQL, params=self.params ) - async def show_prompt(self, prompt: str) -> str: + async def show_prompt( + self, prompt: str, stream: bool = False, chunk_size: int = 8192 + ) -> Union[str, AsyncGenerator[str, None]]: """Show the prompt sent to LLM :param str prompt: Natural language prompt + :param bool stream: Return an async iterator of response chunks + :param int chunk_size: Number of characters to read per stream chunk :return: str """ + if stream: + return self.async_profile._generate_stream_with_cursor( + self._cursor, + prompt=prompt, + action=Action.SHOWPROMPT, + params=self.params, + chunk_size=chunk_size, + ) return await self.async_profile._generate_with_cursor( self._cursor, prompt, action=Action.SHOWPROMPT, params=self.params ) diff --git a/src/select_ai/base_profile.py b/src/select_ai/base_profile.py index 370d1cb..02103dd 100644 --- a/src/select_ai/base_profile.py +++ b/src/select_ai/base_profile.py @@ -132,7 +132,9 @@ async def async_create(cls, **kwargs): def set_attribute(self, key, value): if key in Provider.keys() and not isinstance(value, Provider): - setattr(self.provider, key, value) + if self.provider is None: + self.provider = Provider() + setattr(self.provider, Provider.key_alias(key), value) else: setattr(self, key, value) diff --git a/src/select_ai/cli/__init__.py b/src/select_ai/cli/__init__.py new file mode 100644 index 0000000..4a90a0e --- /dev/null +++ b/src/select_ai/cli/__init__.py @@ -0,0 +1,8 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2026, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +__all__ = [] diff --git a/src/select_ai/cli/chat.py b/src/select_ai/cli/chat.py new file mode 100644 index 0000000..a470ce4 --- /dev/null +++ b/src/select_ai/cli/chat.py @@ -0,0 +1,132 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2026, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +import click + +import select_ai +from select_ai.cli.common import ( + connect, + connection_options, + echo_command, + echo_status, + print_chunks, + profile, +) + + +def _print_help() -> None: + click.secho("Commands:", fg="cyan") + echo_command("/help", "Show this help") + echo_command("/clear", "Start a fresh conversation") + echo_command("/exit", "Exit the chat session") + echo_command("/quit", "Exit the chat session") + + +def _conversation(profile_name: str, conversation_length: int): + return select_ai.Conversation( + attributes=select_ai.ConversationAttributes( + title=f"select-ai chat: {profile_name}", + conversation_length=conversation_length, + ) + ) + + +@click.command() +@click.option( + "--profile", + "profile_name", + required=True, + help="Select AI profile name to use for the chat session.", +) +@connection_options +@click.option( + "--no-stream", + is_flag=True, + help="Print each response after it is fully generated.", +) +@click.option( + "--chunk-size", + default=8192, + show_default=True, + help="Number of characters to read per streaming chunk.", +) +@click.option( + "--conversation-length", + default=10, + show_default=True, + help="Number of prompts retained in the conversation context.", +) +@click.option( + "--keep-conversation", + is_flag=True, + help="Keep the database conversation after the REPL exits.", +) +def chat( + profile_name, + user, + password, + dsn, + wallet_location, + wallet_password, + no_stream, + chunk_size, + conversation_length, + keep_conversation, +): + """Start a context-aware interactive chat REPL.""" + connect(user, password, dsn, wallet_location, wallet_password) + prof = profile(profile_name) + conversation = _conversation(profile_name, conversation_length) + + echo_status(f"Connected to Select AI profile: {profile_name}") + click.secho( + "Type /help for commands. Type /exit to quit.", fg="bright_black" + ) + + try: + while True: + with prof.chat_session( + conversation=conversation, + delete=not keep_conversation, + ) as session: + while True: + try: + prompt = click.prompt( + click.style("select_ai", fg="cyan"), + prompt_suffix="> ", + ) + except (EOFError, KeyboardInterrupt): + click.echo() + return + + prompt = prompt.strip() + if not prompt: + continue + if prompt in ("/exit", "/quit"): + return + if prompt == "/help": + _print_help() + continue + if prompt == "/clear": + conversation = _conversation( + profile_name, conversation_length + ) + echo_status("Started a fresh conversation.") + break + + if no_stream: + click.echo(session.chat(prompt)) + else: + print_chunks( + session.chat( + prompt, + stream=True, + chunk_size=chunk_size, + ) + ) + finally: + select_ai.disconnect() diff --git a/src/select_ai/cli/common.py b/src/select_ai/cli/common.py new file mode 100644 index 0000000..c6360f4 --- /dev/null +++ b/src/select_ai/cli/common.py @@ -0,0 +1,123 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2026, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +import getpass +import os +import sys +from typing import Iterable, Optional + +import click + +import select_ai + + +def env(name: str) -> Optional[str]: + value = os.getenv(name) + return value if value else None + + +def connection_options(func): + options = [ + click.option( + "--wallet-password", + default=env("SELECT_AI_WALLET_PASSWORD"), + help="Wallet password. Defaults to SELECT_AI_WALLET_PASSWORD.", + ), + click.option( + "--wallet-location", + default=env("SELECT_AI_WALLET_LOCATION"), + help="Wallet location. Defaults to SELECT_AI_WALLET_LOCATION.", + ), + click.option( + "--dsn", + default=env("SELECT_AI_DB_CONNECT_STRING"), + help=( + "Database connect string. Defaults to " + "SELECT_AI_DB_CONNECT_STRING." + ), + ), + click.option( + "--password", + default=env("SELECT_AI_PASSWORD"), + help="Database password. Defaults to SELECT_AI_PASSWORD.", + ), + click.option( + "--user", + default=env("SELECT_AI_USER"), + help="Database user. Defaults to SELECT_AI_USER.", + ), + ] + for option in options: + func = option(func) + return func + + +def connect( + user: Optional[str], + password: Optional[str], + dsn: Optional[str], + wallet_location: Optional[str], + wallet_password: Optional[str], +) -> None: + missing = [] + if user is None: + missing.append("--user or SELECT_AI_USER") + if dsn is None: + missing.append("--dsn or SELECT_AI_DB_CONNECT_STRING") + if missing: + raise click.ClickException( + "Missing required connection values: " + ", ".join(missing) + ) + if password is None: + password = getpass.getpass("Database password: ") + + connect_args = { + "user": user, + "password": password, + "dsn": dsn, + } + if wallet_location: + connect_args["wallet_location"] = wallet_location + connect_args["config_dir"] = wallet_location + if wallet_password: + connect_args["wallet_password"] = wallet_password + select_ai.connect(**connect_args) + + +def profile(profile_name: str) -> select_ai.Profile: + return select_ai.Profile(profile_name=profile_name) + + +def echo_command(command: str, description: str) -> None: + click.echo( + f" {click.style(command, fg='green')} " + f"{click.style(description, fg='bright_black')}" + ) + + +def echo_profile(profile_name: str) -> None: + click.echo(click.style(profile_name, fg="cyan")) + + +def echo_status(message: str) -> None: + click.secho(message, fg="cyan") + + +def print_chunks(chunks: Iterable[str], color: Optional[str] = None) -> None: + for chunk in chunks: + if color: + click.echo(click.style(chunk, fg=color), nl=False) + else: + sys.stdout.write(chunk) + sys.stdout.flush() + click.echo() + sys.stdout.flush() + + +def print_text_result(result: object) -> None: + if result is not None: + click.echo(result) diff --git a/src/select_ai/cli/main.py b/src/select_ai/cli/main.py new file mode 100644 index 0000000..3016f79 --- /dev/null +++ b/src/select_ai/cli/main.py @@ -0,0 +1,33 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2026, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +try: + import click +except ImportError: + + def cli(): + raise SystemExit( + "The Select AI CLI requires the optional 'cli' extra. " + "Install it with: pip install 'select_ai[cli]'" + ) + +else: + from select_ai.cli.chat import chat + from select_ai.cli.profile import profile_group + from select_ai.cli.sql import sql + + @click.group(context_settings={"help_option_names": ["-h", "--help"]}) + def cli(): + """Command line tools for Select AI.""" + + cli.add_command(chat) + cli.add_command(sql) + cli.add_command(profile_group, "profile") + + +if __name__ == "__main__": + cli() diff --git a/src/select_ai/cli/profile.py b/src/select_ai/cli/profile.py new file mode 100644 index 0000000..0a604c0 --- /dev/null +++ b/src/select_ai/cli/profile.py @@ -0,0 +1,155 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2026, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +import click + +import select_ai +from select_ai.cli.common import ( + connect, + connection_options, + echo_profile, + print_text_result, + profile, +) + + +@click.group() +def profile_group(): + """Run one-shot Select AI profile operations.""" + + +@profile_group.command("list") +@connection_options +@click.option( + "--pattern", + default=".*", + show_default=True, + help="Regular expression used to match profile names.", +) +def list_profiles( + user, + password, + dsn, + wallet_location, + wallet_password, + pattern, +): + """List Select AI profile names.""" + connect(user, password, dsn, wallet_location, wallet_password) + try: + for fetched_profile in select_ai.Profile.list(pattern): + echo_profile(fetched_profile.profile_name) + finally: + select_ai.disconnect() + + +@profile_group.command("translate") +@click.option( + "--profile", + "profile_name", + required=True, + help="Select AI profile name.", +) +@connection_options +@click.option( + "--source-language", + required=True, + help="Source language for the input text.", +) +@click.option( + "--target-language", + required=True, + help="Target language for the translated text.", +) +@click.argument("text") +def translate( + profile_name, + user, + password, + dsn, + wallet_location, + wallet_password, + source_language, + target_language, + text, +): + """Translate text using a Select AI profile.""" + connect(user, password, dsn, wallet_location, wallet_password) + try: + print_text_result( + profile(profile_name).translate( + text=text, + source_language=source_language, + target_language=target_language, + ) + ) + finally: + select_ai.disconnect() + + +@profile_group.command("summarize") +@click.option( + "--profile", + "profile_name", + required=True, + help="Select AI profile name.", +) +@connection_options +@click.option("--prompt", help="Optional prompt to guide the summary.") +@click.option( + "--location-uri", + help="URI or local file path containing content to summarize.", +) +@click.option( + "--file", + "file_path", + type=click.Path(exists=True, dir_okay=False, readable=True), + help="Read local file content and summarize it.", +) +@click.option( + "--credential-name", + help="Credential used to access object storage content.", +) +@click.argument("content", required=False) +def summarize( + profile_name, + user, + password, + dsn, + wallet_location, + wallet_password, + prompt, + location_uri, + file_path, + credential_name, + content, +): + """Summarize inline content or content from a location URI.""" + if file_path and content: + raise click.ClickException( + "Use either inline content or --file, not both." + ) + if file_path and location_uri: + raise click.ClickException( + "Use either --file or --location-uri, not both." + ) + if file_path: + with open(file_path, encoding="utf-8") as file: + content = file.read() + + connect(user, password, dsn, wallet_location, wallet_password) + try: + print_text_result( + profile(profile_name).summarize( + content=content, + prompt=prompt, + location_uri=location_uri, + credential_name=credential_name, + ) + ) + finally: + select_ai.disconnect() diff --git a/src/select_ai/cli/sql.py b/src/select_ai/cli/sql.py new file mode 100644 index 0000000..503b7fd --- /dev/null +++ b/src/select_ai/cli/sql.py @@ -0,0 +1,102 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2026, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +import click + +import select_ai +from select_ai.cli.common import ( + connect, + connection_options, + print_chunks, + print_text_result, + profile, +) + + +def _text_action(action): + @click.command() + @click.option( + "--profile", + "profile_name", + required=True, + help="Select AI profile name.", + ) + @connection_options + @click.option( + "--no-stream", + is_flag=True, + help="Print the response after it is fully generated.", + ) + @click.option( + "--chunk-size", + default=8192, + show_default=True, + help="Number of characters to read per streaming chunk.", + ) + @click.argument("prompt") + def command( + profile_name, + user, + password, + dsn, + wallet_location, + wallet_password, + no_stream, + chunk_size, + prompt, + ): + connect(user, password, dsn, wallet_location, wallet_password) + try: + prof = profile(profile_name) + method = getattr(prof, action) + if no_stream: + print_text_result(method(prompt)) + else: + print_chunks( + method(prompt, stream=True, chunk_size=chunk_size) + ) + finally: + select_ai.disconnect() + + return command + + +@click.group() +def sql(): + """Generate, run, explain, and narrate SQL from natural language.""" + + +@sql.command("run") +@click.option( + "--profile", + "profile_name", + required=True, + help="Select AI profile name.", +) +@connection_options +@click.argument("prompt") +def run_sql( + profile_name, + user, + password, + dsn, + wallet_location, + wallet_password, + prompt, +): + """Generate SQL, run it, and print the result table.""" + connect(user, password, dsn, wallet_location, wallet_password) + try: + result = profile(profile_name).run_sql(prompt) + click.echo(result.to_string(index=False)) + finally: + select_ai.disconnect() + + +sql.add_command(_text_action("show_sql"), "show") +sql.add_command(_text_action("explain_sql"), "explain") +sql.add_command(_text_action("narrate"), "narrate") diff --git a/src/select_ai/privilege.py b/src/select_ai/privilege.py index 1bf4a56..8a445b3 100644 --- a/src/select_ai/privilege.py +++ b/src/select_ai/privilege.py @@ -4,7 +4,7 @@ # Licensed under the Universal Permissive License v 1.0 as shown at # http://oss.oracle.com/licenses/upl. # ----------------------------------------------------------------------------- -from typing import List, Union +from typing import List, Optional, Union from .db import async_cursor, cursor from .sql import ( @@ -22,6 +22,74 @@ def _normalize_schema_user(user: str) -> str: return user.upper() +def _as_list(value: Union[str, List[str]], name: str) -> List[str]: + if isinstance(value, str): + value = [value] + if not value: + raise ValueError(f"'{name}' cannot be empty") + return value + + +def _append_host_ace_statement(privileges: List[str]) -> str: + privilege_bind_names = [ + f"privilege_{idx}" for idx, _ in enumerate(privileges) + ] + privilege_list = ", ".join(f":{name}" for name in privilege_bind_names) + return f""" + BEGIN + DBMS_NETWORK_ACL_ADMIN.APPEND_HOST_ACE( + host => :host, + lower_port => :lower_port, + upper_port => :upper_port, + ace => xs$ace_type( + privilege_list => xs$name_list({privilege_list}), + principal_name => :user, + principal_type => xs_acl.ptype_db + ) + ); + END; + """ + + +def _remove_host_ace_statement(privileges: List[str]) -> str: + privilege_bind_names = [ + f"privilege_{idx}" for idx, _ in enumerate(privileges) + ] + privilege_list = ", ".join(f":{name}" for name in privilege_bind_names) + return f""" + BEGIN + DBMS_NETWORK_ACL_ADMIN.REMOVE_HOST_ACE( + host => :host, + lower_port => :lower_port, + upper_port => :upper_port, + ace => xs$ace_type( + privilege_list => xs$name_list({privilege_list}), + principal_name => :user, + principal_type => xs_acl.ptype_db + ) + ); + END; + """ + + +def _append_host_ace_parameters( + user: str, + host: str, + privileges: List[str], + lower_port: Optional[int], + upper_port: Optional[int], +): + parameters = { + "host": host, + "user": _normalize_schema_user(user), + "lower_port": lower_port, + "upper_port": upper_port, + } + for idx, privilege in enumerate(privileges): + parameters[f"privilege_{idx}"] = privilege + return parameters + + async def async_grant_privileges(users: Union[str, List[str]]): """ This method grants execute privilege on the packages DBMS_CLOUD, @@ -71,6 +139,34 @@ async def async_grant_http_access( ) +async def async_grant_network_access( + users: Union[str, List[str]], + host: str, + privileges: Union[str, List[str]], + lower_port: Optional[int] = None, + upper_port: Optional[int] = None, +): + """ + Async method to add a network ACL entry for host access. + """ + users = _as_list(users, "users") + privileges = _as_list(privileges, "privileges") + statement = _append_host_ace_statement(privileges) + + async with async_cursor() as cr: + for user in users: + await cr.execute( + statement, + **_append_host_ace_parameters( + user=user, + host=host, + privileges=privileges, + lower_port=lower_port, + upper_port=upper_port, + ), + ) + + async def async_revoke_http_access( users: Union[str, List[str]], provider_endpoint: str, @@ -90,6 +186,34 @@ async def async_revoke_http_access( ) +async def async_revoke_network_access( + users: Union[str, List[str]], + host: str, + privileges: Union[str, List[str]], + lower_port: Optional[int] = None, + upper_port: Optional[int] = None, +): + """ + Async method to remove a network ACL entry for host access. + """ + users = _as_list(users, "users") + privileges = _as_list(privileges, "privileges") + statement = _remove_host_ace_statement(privileges) + + async with async_cursor() as cr: + for user in users: + await cr.execute( + statement, + **_append_host_ace_parameters( + user=user, + host=host, + privileges=privileges, + lower_port=lower_port, + upper_port=upper_port, + ), + ) + + def grant_privileges(users: Union[str, List[str]]): """ This method grants execute privilege on the packages DBMS_CLOUD, @@ -131,6 +255,34 @@ def grant_http_access(users: Union[str, List[str]], provider_endpoint: str): ) +def grant_network_access( + users: Union[str, List[str]], + host: str, + privileges: Union[str, List[str]], + lower_port: Optional[int] = None, + upper_port: Optional[int] = None, +): + """ + Adds a network ACL entry for host access. + """ + users = _as_list(users, "users") + privileges = _as_list(privileges, "privileges") + statement = _append_host_ace_statement(privileges) + + with cursor() as cr: + for user in users: + cr.execute( + statement, + **_append_host_ace_parameters( + user=user, + host=host, + privileges=privileges, + lower_port=lower_port, + upper_port=upper_port, + ), + ) + + def revoke_http_access(users: Union[str, List[str]], provider_endpoint: str): """ Removes ACL entry for HTTP access @@ -144,3 +296,31 @@ def revoke_http_access(users: Union[str, List[str]], provider_endpoint: str): user=user, host=provider_endpoint, ) + + +def revoke_network_access( + users: Union[str, List[str]], + host: str, + privileges: Union[str, List[str]], + lower_port: Optional[int] = None, + upper_port: Optional[int] = None, +): + """ + Removes a network ACL entry for host access. + """ + users = _as_list(users, "users") + privileges = _as_list(privileges, "privileges") + statement = _remove_host_ace_statement(privileges) + + with cursor() as cr: + for user in users: + cr.execute( + statement, + **_append_host_ace_parameters( + user=user, + host=host, + privileges=privileges, + lower_port=lower_port, + upper_port=upper_port, + ), + ) diff --git a/src/select_ai/profile.py b/src/select_ai/profile.py index 8b4a0d2..69a7c5b 100644 --- a/src/select_ai/profile.py +++ b/src/select_ai/profile.py @@ -394,6 +394,27 @@ def _generate_with_cursor( conversation_id for context-aware chats :return: Union[pandas.DataFrame, str] """ + parameters = self._generate_parameters(prompt, action, params) + data = cr.callfunc( + "DBMS_CLOUD_AI.GENERATE", + oracledb.DB_TYPE_CLOB, + keyword_parameters=parameters, + ) + if data is not None: + result = data.read() + else: + result = None + if action == Action.RUNSQL: + return convert_json_rows_to_df(result) + else: + return result + + def _generate_parameters( + self, + prompt: str, + action: Optional[Action], + params: Mapping = None, + ) -> Mapping: if not prompt: raise ValueError("prompt cannot be empty or None") parameters = { @@ -404,47 +425,101 @@ def _generate_with_cursor( } if params: parameters["params"] = json.dumps(params) + return parameters + + def _generate_stream( + self, + prompt: str, + action: Optional[Action], + params: Mapping = None, + chunk_size: int = 8192, + ) -> Generator[str, None, None]: + with cursor() as cr: + yield from self._generate_stream_with_cursor( + cr, + prompt=prompt, + action=action, + params=params, + chunk_size=chunk_size, + ) + + def _generate_stream_with_cursor( + self, + cr, + prompt: str, + action: Optional[Action], + params: Mapping = None, + chunk_size: int = 8192, + ) -> Generator[str, None, None]: + if action == Action.RUNSQL: + raise ValueError("stream=True is not supported for run_sql") + if chunk_size <= 0: + raise ValueError("chunk_size must be greater than 0") + + parameters = self._generate_parameters(prompt, action, params) data = cr.callfunc( "DBMS_CLOUD_AI.GENERATE", oracledb.DB_TYPE_CLOB, keyword_parameters=parameters, ) - if data is not None: - result = data.read() - else: - result = None - if action == Action.RUNSQL: - return convert_json_rows_to_df(result) - else: - return result + if data is None: + return + + offset = 1 + while True: + chunk = data.read(offset=offset, amount=chunk_size) + if not chunk: + break + yield chunk + offset += len(chunk) def generate( self, prompt: str, action: Optional[Action] = Action.RUNSQL, params: Mapping = None, - ) -> Union[pandas.DataFrame, str, None]: + stream: bool = False, + chunk_size: int = 8192, + ) -> Union[pandas.DataFrame, str, Generator[str, None, None], None]: """Perform AI translation using this profile :param str prompt: Natural language prompt to translate :param select_ai.profile.Action action: :param params: Parameters to include in the LLM request. For e.g. conversation_id for context-aware chats + :param bool stream: Return an iterator of response chunks + :param int chunk_size: Number of characters to read per stream chunk :return: Union[pandas.DataFrame, str] """ + if stream: + return self._generate_stream(prompt, action, params, chunk_size) with cursor() as cr: return self._generate_with_cursor( cr, prompt=prompt, action=action, params=params ) - def chat(self, prompt: str, params: Mapping = None) -> str: + def chat( + self, + prompt: str, + params: Mapping = None, + stream: bool = False, + chunk_size: int = 8192, + ) -> Union[str, Generator[str, None, None]]: """Chat with the LLM :param str prompt: Natural language prompt :param params: Parameters to include in the LLM request + :param bool stream: Return an iterator of response chunks + :param int chunk_size: Number of characters to read per stream chunk :return: str """ - return self.generate(prompt, action=Action.CHAT, params=params) + return self.generate( + prompt, + action=Action.CHAT, + params=params, + stream=stream, + chunk_size=chunk_size, + ) @contextmanager def chat_session(self, conversation: Conversation, delete: bool = False): @@ -469,23 +544,51 @@ def chat_session(self, conversation: Conversation, delete: bool = False): if delete: conversation.delete() - def narrate(self, prompt: str, params: Mapping = None) -> str: + def narrate( + self, + prompt: str, + params: Mapping = None, + stream: bool = False, + chunk_size: int = 8192, + ) -> Union[str, Generator[str, None, None]]: """Narrate the result of the SQL :param str prompt: Natural language prompt :param params: Parameters to include in the LLM request + :param bool stream: Return an iterator of response chunks + :param int chunk_size: Number of characters to read per stream chunk :return: str """ - return self.generate(prompt, action=Action.NARRATE, params=params) + return self.generate( + prompt, + action=Action.NARRATE, + params=params, + stream=stream, + chunk_size=chunk_size, + ) - def explain_sql(self, prompt: str, params: Mapping = None) -> str: + def explain_sql( + self, + prompt: str, + params: Mapping = None, + stream: bool = False, + chunk_size: int = 8192, + ) -> Union[str, Generator[str, None, None]]: """Explain the generated SQL :param str prompt: Natural language prompt :param params: Parameters to include in the LLM request + :param bool stream: Return an iterator of response chunks + :param int chunk_size: Number of characters to read per stream chunk :return: str """ - return self.generate(prompt, action=Action.EXPLAINSQL, params=params) + return self.generate( + prompt, + action=Action.EXPLAINSQL, + params=params, + stream=stream, + chunk_size=chunk_size, + ) def run_sql(self, prompt: str, params: Mapping = None) -> pandas.DataFrame: """Run the generate SQL statement and return a pandas Dataframe built @@ -497,23 +600,51 @@ def run_sql(self, prompt: str, params: Mapping = None) -> pandas.DataFrame: """ return self.generate(prompt, action=Action.RUNSQL, params=params) - def show_sql(self, prompt: str, params: Mapping = None) -> str: + def show_sql( + self, + prompt: str, + params: Mapping = None, + stream: bool = False, + chunk_size: int = 8192, + ) -> Union[str, Generator[str, None, None]]: """Show the generated SQL :param str prompt: Natural language prompt :param params: Parameters to include in the LLM request + :param bool stream: Return an iterator of response chunks + :param int chunk_size: Number of characters to read per stream chunk :return: str """ - return self.generate(prompt, action=Action.SHOWSQL, params=params) + return self.generate( + prompt, + action=Action.SHOWSQL, + params=params, + stream=stream, + chunk_size=chunk_size, + ) - def show_prompt(self, prompt: str, params: Mapping = None) -> str: + def show_prompt( + self, + prompt: str, + params: Mapping = None, + stream: bool = False, + chunk_size: int = 8192, + ) -> Union[str, Generator[str, None, None]]: """Show the prompt sent to LLM :param str prompt: Natural language prompt :param params: Parameters to include in the LLM request + :param bool stream: Return an iterator of response chunks + :param int chunk_size: Number of characters to read per stream chunk :return: str """ - return self.generate(prompt, action=Action.SHOWPROMPT, params=params) + return self.generate( + prompt, + action=Action.SHOWPROMPT, + params=params, + stream=stream, + chunk_size=chunk_size, + ) def summarize( self, @@ -627,27 +758,61 @@ def __init__(self, profile: Profile, params: Mapping): self._conn_cm = None self._cursor = None - def chat(self, prompt: str): + def chat( + self, prompt: str, stream: bool = False, chunk_size: int = 8192 + ) -> Union[str, Generator[str, None, None]]: + if stream: + return self.profile._generate_stream_with_cursor( + self._cursor, + prompt=prompt, + action=Action.CHAT, + params=self.params, + chunk_size=chunk_size, + ) return self.profile._generate_with_cursor( self._cursor, prompt=prompt, action=Action.CHAT, params=self.params ) - def narrate(self, prompt: str) -> str: + def narrate( + self, prompt: str, stream: bool = False, chunk_size: int = 8192 + ) -> Union[str, Generator[str, None, None]]: """Narrate the result of the SQL :param str prompt: Natural language prompt + :param bool stream: Return an iterator of response chunks + :param int chunk_size: Number of characters to read per stream chunk :return: str """ + if stream: + return self.profile._generate_stream_with_cursor( + self._cursor, + prompt=prompt, + action=Action.NARRATE, + params=self.params, + chunk_size=chunk_size, + ) return self.profile._generate_with_cursor( self._cursor, prompt, action=Action.NARRATE, params=self.params ) - def explain_sql(self, prompt: str) -> str: + def explain_sql( + self, prompt: str, stream: bool = False, chunk_size: int = 8192 + ) -> Union[str, Generator[str, None, None]]: """Explain the generated SQL :param str prompt: Natural language prompt + :param bool stream: Return an iterator of response chunks + :param int chunk_size: Number of characters to read per stream chunk :return: str """ + if stream: + return self.profile._generate_stream_with_cursor( + self._cursor, + prompt=prompt, + action=Action.EXPLAINSQL, + params=self.params, + chunk_size=chunk_size, + ) return self.profile._generate_with_cursor( self._cursor, prompt, action=Action.EXPLAINSQL, params=self.params ) @@ -663,24 +828,48 @@ def run_sql(self, prompt: str) -> pandas.DataFrame: self._cursor, prompt, action=Action.RUNSQL, params=self.params ) - def show_sql(self, prompt: str) -> str: + def show_sql( + self, prompt: str, stream: bool = False, chunk_size: int = 8192 + ) -> Union[str, Generator[str, None, None]]: """Show the generated SQL :param str prompt: Natural language prompt :param params: Parameters to include in the LLM request + :param bool stream: Return an iterator of response chunks + :param int chunk_size: Number of characters to read per stream chunk :return: str """ + if stream: + return self.profile._generate_stream_with_cursor( + self._cursor, + prompt=prompt, + action=Action.SHOWSQL, + params=self.params, + chunk_size=chunk_size, + ) return self.profile._generate_with_cursor( self._cursor, prompt, action=Action.SHOWSQL, params=self.params ) - def show_prompt(self, prompt: str) -> str: + def show_prompt( + self, prompt: str, stream: bool = False, chunk_size: int = 8192 + ) -> Union[str, Generator[str, None, None]]: """Show the prompt sent to LLM :param str prompt: Natural language prompt :param params: Parameters to include in the LLM request + :param bool stream: Return an iterator of response chunks + :param int chunk_size: Number of characters to read per stream chunk :return: str """ + if stream: + return self.profile._generate_stream_with_cursor( + self._cursor, + prompt=prompt, + action=Action.SHOWPROMPT, + params=self.params, + chunk_size=chunk_size, + ) return self.profile._generate_with_cursor( self._cursor, prompt, action=Action.SHOWPROMPT, params=self.params ) diff --git a/src/select_ai/version.py b/src/select_ai/version.py index 472aff4..7eaa33f 100644 --- a/src/select_ai/version.py +++ b/src/select_ai/version.py @@ -5,4 +5,4 @@ # http://oss.oracle.com/licenses/upl. # ----------------------------------------------------------------------------- -__version__ = "1.3.1" +__version__ = "1.4.0" diff --git a/tests/agents/test_3001_async_tools.py b/tests/agents/test_3001_async_tools.py index a5d4668..a1b9de5 100644 --- a/tests/agents/test_3001_async_tools.py +++ b/tests/agents/test_3001_async_tools.py @@ -1,5 +1,5 @@ # ----------------------------------------------------------------------------- -# Copyright (c) 2025, Oracle and/or its affiliates. +# Copyright (c) 2025, 2026, Oracle and/or its affiliates. # # Licensed under the Universal Permissive License v 1.0 as shown at # http://oss.oracle.com/licenses/upl. @@ -60,8 +60,6 @@ DISABLED_TOOL_NAME = f"PYSAI_3001_DISABLED_TOOL_{UUID}" DEFAULT_STATUS_TOOL_NAME = f"PYSAI_3001_DEFAULT_STATUS_TOOL_{UUID}" DROP_FORCE_MISSING_TOOL = f"PYSAI_3001_DROP_MISSING_{UUID}" -HTTP_TOOL_NAME = f"PYSAI_3001_HTTP_TOOL_{UUID}" -HTTP_ENDPOINT = "https://example.com/api/tool" EMAIL_TOOL_NAME = f"PYSAI_3001_EMAIL_TOOL_{UUID}" SLACK_TOOL_NAME = f"PYSAI_3001_SLACK_TOOL_{UUID}" @@ -796,34 +794,3 @@ async def test_3023_drop_tool_force_false_non_existent_raises(): with pytest.raises(oracledb.Error) as exc: await tool.delete(force=False) logger.info("Received expected drop error: %s", exc.value) - - -async def test_3024_http_tool_created(email_credential): - logger.info("Creating HTTP tool: %s", HTTP_TOOL_NAME) - try: - tool = await AsyncTool.create_http_tool( - tool_name=HTTP_TOOL_NAME, - credential_name=email_credential, - endpoint=HTTP_ENDPOINT, - description="HTTP Tool", - replace=True, - ) - except oracledb.DatabaseError as e: - if "ORA-20052" in str(e): - logger.info( - "HTTP tool creation failed with expected backend-side error: %s", - e, - ) - return - raise - try: - fetched = await AsyncTool.fetch(HTTP_TOOL_NAME) - assert fetched.tool_name == HTTP_TOOL_NAME - assert fetched.attributes.tool_type == select_ai.agent.ToolType.HTTP - assert ( - fetched.attributes.tool_params.credential_name == email_credential - ) - assert fetched.attributes.tool_params.endpoint == HTTP_ENDPOINT - finally: - logger.info("Deleting HTTP tool: %s", HTTP_TOOL_NAME) - await tool.delete(force=True) diff --git a/tests/agents/test_3001_tools.py b/tests/agents/test_3001_tools.py index f72eb30..e23b0ec 100644 --- a/tests/agents/test_3001_tools.py +++ b/tests/agents/test_3001_tools.py @@ -1,5 +1,5 @@ # ----------------------------------------------------------------------------- -# Copyright (c) 2025, Oracle and/or its affiliates. +# Copyright (c) 2025, 2026, Oracle and/or its affiliates. # # Licensed under the Universal Permissive License v 1.0 as shown at # http://oss.oracle.com/licenses/upl. @@ -10,17 +10,20 @@ (with logging for behavior visibility) """ -import uuid import logging -import pytest import os -import select_ai +import uuid + import oracledb +import pytest +import select_ai from select_ai.agent import Tool from select_ai.errors import AgentToolNotFoundError # Path -PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) +PROJECT_ROOT = os.path.abspath( + os.path.join(os.path.dirname(__file__), "../..") +) LOG_FILE = os.path.join(PROJECT_ROOT, "log", "tkex_test_3001_tools.log") os.makedirs(os.path.dirname(LOG_FILE), exist_ok=True) @@ -41,6 +44,7 @@ # Per-test logging # ----------------------------------------------------------------------------- + @pytest.fixture(autouse=True) def log_test_name(request): logger.info(f"--- Starting test: {request.function.__name__} ---") @@ -52,16 +56,21 @@ def log_test_name(request): # Helper Functions # ----------------------------------------------------------------------------- + def get_tool_status(tool_name): with select_ai.cursor() as cur: - cur.execute(""" + cur.execute( + """ SELECT status FROM USER_AI_AGENT_TOOLS WHERE tool_name = :tool_name - """, {"tool_name": tool_name}) + """, + {"tool_name": tool_name}, + ) row = cur.fetchone() return row[0] if row else None + # ----------------------------------------------------------------------------- # Constants # ----------------------------------------------------------------------------- @@ -87,8 +96,6 @@ def get_tool_status(tool_name): DISABLED_TOOL_NAME = f"PYSAI_3001_DISABLED_TOOL_{UUID}" DEFAULT_STATUS_TOOL_NAME = f"PYSAI_3001_DEFAULT_STATUS_TOOL_{UUID}" DROP_FORCE_MISSING_TOOL = f"PYSAI_3001_DROP_MISSING_{UUID}" -HTTP_TOOL_NAME = f"PYSAI_3001_HTTP_TOOL_{UUID}" -HTTP_ENDPOINT = "https://example.com/api/tool" EMAIL_RECIPIENT = os.getenv("PYSAI_TEST_EMAIL_RECIPIENT") EMAIL_SENDER = os.getenv("PYSAI_TEST_EMAIL_SENDER") EMAIL_SMTP_HOST = os.getenv("PYSAI_TEST_EMAIL_SMTPHOST") @@ -97,6 +104,7 @@ def get_tool_status(tool_name): slack_username = os.getenv("PYSAI_TEST_SLACK_USERNAME") slack_password = os.getenv("PYSAI_TEST_SLACK_PASSWORD") + @pytest.fixture(scope="module") def email_credential(): cred_name = "EMAIL_CRED" @@ -107,7 +115,9 @@ def email_credential(): select_ai.delete_credential(cred_name) logger.info("Dropped existing EMAIL credential: %s", cred_name) except Exception as e: - logger.info("EMAIL credential did not exist or could not be dropped: %s", e) + logger.info( + "EMAIL credential did not exist or could not be dropped: %s", e + ) # Create fresh credential credential = { @@ -116,10 +126,7 @@ def email_credential(): "password": smtp_password, } - select_ai.create_credential( - credential=credential, - replace=True - ) + select_ai.create_credential(credential=credential, replace=True) logger.info("Created EMAIL credential: %s", cred_name) yield cred_name @@ -128,7 +135,10 @@ def email_credential(): try: select_ai.delete_credential(cred_name) except Exception as e: - logger.warning("Failed to delete EMAIL credential during teardown: %s", e) + logger.warning( + "Failed to delete EMAIL credential during teardown: %s", e + ) + @pytest.fixture(scope="module") def slack_credential(): @@ -140,7 +150,9 @@ def slack_credential(): select_ai.delete_credential(cred_name) logger.info("Dropped existing SLACK credential: %s", cred_name) except Exception as e: - logger.info("SLACK credential did not exist or could not be dropped: %s", e) + logger.info( + "SLACK credential did not exist or could not be dropped: %s", e + ) # Create fresh SLACK credential (backend-required fields) credential = { @@ -149,10 +161,7 @@ def slack_credential(): "password": slack_password, } - select_ai.create_credential( - credential=credential, - replace=True - ) + select_ai.create_credential(credential=credential, replace=True) logger.info("Created SLACK credential: %s", cred_name) yield cred_name @@ -161,7 +170,9 @@ def slack_credential(): try: select_ai.delete_credential(cred_name) except Exception as e: - logger.warning("Failed to delete SLACK credential during teardown: %s", e) + logger.warning( + "Failed to delete SLACK credential during teardown: %s", e + ) # ----------------------------------------------------------------------------- @@ -254,6 +265,7 @@ def plsql_tool(plsql_function): logger.info("Deleting PL/SQL tool") tool.delete(force=True) + @pytest.fixture(scope="module") def web_search_tool(): """Fixture for Web Search Tool positive case.""" @@ -264,11 +276,14 @@ def web_search_tool(): credential_name="OPENAI_CRED", replace=True, ) - logger.info("WEBSEARCH Tool created successfully: %s", WEB_SEARCH_TOOL_NAME) + logger.info( + "WEBSEARCH Tool created successfully: %s", WEB_SEARCH_TOOL_NAME + ) yield tool logger.info("Deleting Web Search tool") tool.delete(force=True) + @pytest.fixture(scope="module") def email_tool(email_credential): logger.info("Creating EMAIL tool: EMAIL_TOOL") @@ -286,6 +301,7 @@ def email_tool(email_credential): logger.info("Deleting EMAIL tool") tool.delete(force=True) + @pytest.fixture(scope="module") def slack_tool(slack_credential): logger.info("Creating SLACK tool: SLACK_TOOL") @@ -306,10 +322,11 @@ def slack_tool(slack_credential): else: raise e finally: - if 'tool' in locals(): + if "tool" in locals(): logger.info("Deleting SLACK tool") tool.delete(force=True) + @pytest.fixture(scope="module") def neg_sql_tool(): logger.info("Creating SQL tool with INVALID profile: NEG_SQL_TOOL") @@ -323,6 +340,7 @@ def neg_sql_tool(): logger.info("Deleting NEG_SQL_TOOL") tool.delete(force=True) + @pytest.fixture(scope="module") def neg_rag_tool(): logger.info("Creating RAG tool with INVALID profile: NEG_RAG_TOOL") @@ -350,10 +368,12 @@ def neg_plsql_tool(): logger.info("Deleting NEG_PLSQL_TOOL") tool.delete(force=True) + # ----------------------------------------------------------------------------- # POSITIVE TESTS # ----------------------------------------------------------------------------- + def test_3000_sql_tool_created(sql_tool): logger.info("Validating SQL tool creation") logger.info("SQL Tool created successfully: %s", SQL_TOOL_NAME) @@ -373,7 +393,9 @@ def test_3001_rag_tool_created(rag_tool): def test_3002_plsql_tool_created(plsql_tool): logger.info("Validating PL/SQL tool creation") logger.info("PL/SQL Tool created successfully: %s", PLSQL_TOOL_NAME) - logger.info("PL/SQL function created successfully: %s", PLSQL_FUNCTION_NAME) + logger.info( + "PL/SQL function created successfully: %s", PLSQL_FUNCTION_NAME + ) assert plsql_tool.tool_name == PLSQL_TOOL_NAME assert plsql_tool.attributes.function == PLSQL_FUNCTION_NAME @@ -443,10 +465,13 @@ def test_3009_slack_tool_created(slack_tool): # If the tool is None (because of expected ORA-20052 error), skip the assertion if slack_tool is None: - logger.info("SLACK tool creation failed with expected error ORA-20052, but continuing test.") + logger.info( + "SLACK tool creation failed with expected error ORA-20052, but continuing test." + ) else: assert slack_tool.tool_name == "SLACK_TOOL" + def test_3010_custom_tool_attributes_roundtrip(): logger.info( "Validating custom tool attribute roundtrip: instruction/tool_inputs/description" @@ -483,7 +508,10 @@ def test_3010_custom_tool_attributes_roundtrip(): ) assert isinstance(fetched.attributes.tool_inputs, list) assert fetched.attributes.tool_inputs[0]["name"] == "p_birth_date" - assert "birth date" in fetched.attributes.tool_inputs[0]["description"].lower() + assert ( + "birth date" + in fetched.attributes.tool_inputs[0]["description"].lower() + ) finally: tool.delete(force=True) @@ -522,7 +550,9 @@ def test_3012_custom_tool_with_tool_type_without_instruction(sql_profile): ) tool.create(replace=True) try: - fetched = select_ai.agent.Tool.fetch(CUSTOM_WITH_TYPE_NO_INSTR_TOOL_NAME) + fetched = select_ai.agent.Tool.fetch( + CUSTOM_WITH_TYPE_NO_INSTR_TOOL_NAME + ) logger.info( "Fetched custom tool | name=%s | type=%s | instruction=%s | profile=%s", fetched.tool_name, @@ -554,7 +584,9 @@ def test_3013_custom_tool_with_tool_type_and_instruction(sql_profile): ) tool.create(replace=True) try: - fetched = select_ai.agent.Tool.fetch(CUSTOM_WITH_TYPE_AND_INSTR_TOOL_NAME) + fetched = select_ai.agent.Tool.fetch( + CUSTOM_WITH_TYPE_AND_INSTR_TOOL_NAME + ) assert fetched.tool_name == CUSTOM_WITH_TYPE_AND_INSTR_TOOL_NAME assert fetched.attributes.tool_type == select_ai.agent.ToolType.SQL assert fetched.attributes.instruction is not None @@ -567,13 +599,19 @@ def test_3013_custom_tool_with_tool_type_and_instruction(sql_profile): def test_3014_sql_tool_with_invalid_profile_created(neg_sql_tool): logger.info("Validating SQL tool with invalid profile is stored") assert neg_sql_tool.tool_name == "NEG_SQL_TOOL" - assert neg_sql_tool.attributes.tool_params.profile_name == "NON_EXISTENT_PROFILE" + assert ( + neg_sql_tool.attributes.tool_params.profile_name + == "NON_EXISTENT_PROFILE" + ) def test_3015_rag_tool_with_invalid_profile_created(neg_rag_tool): logger.info("Validating RAG tool with invalid profile is stored") assert neg_rag_tool.tool_name == "NEG_RAG_TOOL" - assert neg_rag_tool.attributes.tool_params.profile_name == "NON_EXISTENT_RAG_PROFILE" + assert ( + neg_rag_tool.attributes.tool_params.profile_name + == "NON_EXISTENT_RAG_PROFILE" + ) def test_3016_plsql_tool_with_invalid_function_created(neg_plsql_tool): @@ -611,7 +649,9 @@ def test_3020_create_tool_default_status_enabled(sql_profile): tool = select_ai.agent.Tool.create_built_in_tool( tool_name=DEFAULT_STATUS_TOOL_NAME, tool_type=select_ai.agent.ToolType.SQL, - tool_params=select_ai.agent.SQLToolParams(profile_name=SQL_PROFILE_NAME), + tool_params=select_ai.agent.SQLToolParams( + profile_name=SQL_PROFILE_NAME + ), ) try: status = get_tool_status(DEFAULT_STATUS_TOOL_NAME) @@ -661,32 +701,3 @@ def test_3023_drop_tool_force_false_non_existent_raises(): with pytest.raises(oracledb.Error) as exc: tool.delete(force=False) logger.info("Received expected drop error: %s", exc.value) - - -def test_3024_http_tool_created(email_credential): - logger.info("Creating HTTP tool: %s", HTTP_TOOL_NAME) - try: - tool = select_ai.agent.Tool.create_http_tool( - tool_name=HTTP_TOOL_NAME, - credential_name=email_credential, - endpoint=HTTP_ENDPOINT, - description="HTTP Tool", - replace=True, - ) - except oracledb.DatabaseError as e: - if "ORA-20052" in str(e): - logger.info( - "HTTP tool creation failed with expected backend-side error: %s", - e, - ) - return - raise - try: - fetched = select_ai.agent.Tool.fetch(HTTP_TOOL_NAME) - assert fetched.tool_name == HTTP_TOOL_NAME - assert fetched.attributes.tool_type == select_ai.agent.ToolType.HTTP - assert fetched.attributes.tool_params.credential_name == email_credential - assert fetched.attributes.tool_params.endpoint == HTTP_ENDPOINT - finally: - logger.info("Deleting HTTP tool: %s", HTTP_TOOL_NAME) - tool.delete(force=True) diff --git a/tests/agents/test_3301_async_teams.py b/tests/agents/test_3301_async_teams.py index 62b1384..8ae2547 100644 --- a/tests/agents/test_3301_async_teams.py +++ b/tests/agents/test_3301_async_teams.py @@ -1,5 +1,5 @@ # ----------------------------------------------------------------------------- -# Copyright (c) 2025, Oracle and/or its affiliates. +# Copyright (c) 2025, 2026, Oracle and/or its affiliates. # # Licensed under the Universal Permissive License v 1.0 as shown at # http://oss.oracle.com/licenses/upl. @@ -9,6 +9,7 @@ 3301 - Async contract, regression and corner-case tests for select_ai.agent.AsyncTeam """ +import json import logging import os import uuid @@ -93,6 +94,15 @@ async def expect_async_error(expected_code, coro_fn): pytest.fail(f"Expected error {expected_code} did not occur") +async def ignore_async_missing(coro_fn): + try: + await coro_fn() + except (AgentTeamNotFoundError, oracledb.DatabaseError) as exc: + msg = str(exc) + if "ORA-20053" not in msg and "ORA-20051" not in msg: + raise + + def log_team_details(context: str, team) -> None: attrs = getattr(team, "attributes", None) details = { @@ -172,6 +182,7 @@ async def python_gen_ai_profile(profile_attributes): def task_attributes(): return TaskAttributes( instruction="Help the user. Question: {query}", + tools=[], enable_human_tool=False, ) @@ -316,6 +327,53 @@ async def test_3308_fetch_non_existing(): await expect_async_error("NOT_FOUND", lambda: AsyncTeam.fetch(name)) +async def test_3309_export_team(team): + logger.info("Exporting team: %s", team.team_name) + specification = await team.export() + logger.info("Exported team specification: %s", specification) + spec = json.loads(specification) + assert isinstance(specification, str) + assert len(specification) > 0 + assert spec["name"] == PYSAI_TEAM_AGENT_NAME + assert spec["task"]["task_name"] == PYSAI_TEAM_TASK_NAME + + +async def test_3310_import_team(team, python_gen_ai_profile): + imported_team_name = f"PYSAI_IMPORTED_TEAM_{uuid.uuid4().hex.upper()}" + imported_agent_name = f"PYSAI_IMPORTED_AGENT_{uuid.uuid4().hex.upper()}" + imported_task_name = f"PYSAI_IMPORTED_TASK_{uuid.uuid4().hex.upper()}" + + logger.info("Exporting source team for import: %s", team.team_name) + spec = json.loads(await team.export()) + spec["name"] = imported_agent_name + spec["task"]["task_name"] = imported_task_name + + try: + logger.info("Importing team: %s", imported_team_name) + await AsyncTeam.import_team( + profile_name=PYSAI_TEAM_PROFILE_NAME, + team_name=imported_team_name, + specification=spec, + force=True, + ) + imported = await AsyncTeam.fetch(imported_team_name) + log_team_details("test_3310_import_team", imported) + assert imported.team_name == imported_team_name + assert imported.attributes.agents == [ + {"name": imported_agent_name, "task": imported_task_name} + ] + finally: + await ignore_async_missing( + lambda: AsyncTeam.delete_team(imported_team_name, force=True) + ) + await ignore_async_missing( + lambda: AsyncTask.delete_task(imported_task_name, force=True) + ) + await ignore_async_missing( + lambda: AsyncAgent.delete_agent(imported_agent_name, force=True) + ) + + async def test_3311_set_attribute_invalid_key(team): await expect_async_error( "ORA-20053", lambda: team.set_attribute("no_such_attr", "x") diff --git a/tests/agents/test_3301_teams.py b/tests/agents/test_3301_teams.py index 66398f9..9ac5664 100644 --- a/tests/agents/test_3301_teams.py +++ b/tests/agents/test_3301_teams.py @@ -1,5 +1,5 @@ # ----------------------------------------------------------------------------- -# Copyright (c) 2025, Oracle and/or its affiliates. +# Copyright (c) 2025, 2026, Oracle and/or its affiliates. # # Licensed under the Universal Permissive License v 1.0 as shown at # http://oss.oracle.com/licenses/upl. @@ -9,13 +9,14 @@ 3301 - Final contract, regression and corner-case tests for select_ai.agent.Team """ -import uuid +import json import logging import os +import uuid + +import oracledb import pytest import select_ai -import oracledb - from select_ai.agent import ( Agent, AgentAttributes, @@ -24,14 +25,15 @@ Team, TeamAttributes, ) - from select_ai.errors import AgentTeamNotFoundError # ----------------------------------------------------------------------------- # Logging # ----------------------------------------------------------------------------- -PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) +PROJECT_ROOT = os.path.abspath( + os.path.join(os.path.dirname(__file__), "../..") +) LOG_DIR = os.path.join(PROJECT_ROOT, "log") os.makedirs(LOG_DIR, exist_ok=True) @@ -50,11 +52,15 @@ LOGGER = logging.getLogger(__name__) LOGGER.setLevel(logging.INFO) + def log_step(msg): LOGGER.info("%s", msg) + def log_ok(msg): LOGGER.info("%s", msg) + + logger = LOGGER @@ -62,6 +68,7 @@ def log_ok(msg): # Per-test logging # ----------------------------------------------------------------------------- + @pytest.fixture(autouse=True) def log_test_name(request): logger.info(f"--- Starting test: {request.function.__name__} ---") @@ -73,6 +80,7 @@ def log_test_name(request): # Strict error checker (LIKE 3101 / 3201) # ----------------------------------------------------------------------------- + def expect_error(expected_code, fn): """ expected_code: @@ -95,6 +103,16 @@ def expect_error(expected_code, fn): else: pytest.fail(f"Expected error {expected_code} did not occur") + +def ignore_missing(fn): + try: + fn() + except (AgentTeamNotFoundError, oracledb.DatabaseError) as exc: + msg = str(exc) + if "ORA-20053" not in msg and "ORA-20051" not in msg: + raise + + # ----------------------------------------------------------------------------- # Test constants # ----------------------------------------------------------------------------- @@ -122,6 +140,7 @@ def expect_error(expected_code, fn): # "model": "cohere.command-r-plus" # } + @pytest.fixture(scope="module") def python_gen_ai_profile(profile_attributes): log_step(f"Creating profile: {PYSAI_TEAM_PROFILE_NAME}") @@ -135,8 +154,7 @@ def python_gen_ai_profile(profile_attributes): # ---- STRICT TYPE CHECK ---- assert isinstance( - profile_attributes, - select_ai.ProfileAttributes + profile_attributes, select_ai.ProfileAttributes ), "profile_attributes must be ProfileAttributes object" profile = select_ai.Profile( @@ -157,9 +175,11 @@ def python_gen_ai_profile(profile_attributes): def task_attributes(): return TaskAttributes( instruction="Help the user. Question: {query}", + tools=[], enable_human_tool=False, ) + @pytest.fixture(scope="module") def task(task_attributes): log_step(f"Creating task: {PYSAI_TEAM_TASK_NAME}") @@ -173,6 +193,7 @@ def task(task_attributes): log_step(f"Deleting task: {PYSAI_TEAM_TASK_NAME}") task.delete(force=True) + @pytest.fixture(scope="module") def agent(python_gen_ai_profile): log_step(f"Creating agent: {PYSAI_TEAM_AGENT_NAME}") @@ -190,6 +211,7 @@ def agent(python_gen_ai_profile): log_step(f"Deleting agent: {PYSAI_TEAM_AGENT_NAME}") agent.delete(force=True) + @pytest.fixture(scope="module") def team_attributes(agent, task): return TeamAttributes( @@ -197,6 +219,7 @@ def team_attributes(agent, task): process="sequential", ) + @pytest.fixture(scope="module") def team(team_attributes): log_step(f"Creating team: {PYSAI_TEAM_NAME}") @@ -210,6 +233,7 @@ def team(team_attributes): log_step(f"Deleting team: {PYSAI_TEAM_NAME}") team.delete(force=True) + # ----------------------------------------------------------------------------- # Tests # ----------------------------------------------------------------------------- @@ -218,6 +242,7 @@ def team(team_attributes): # Logging-enhanced Team tests # ----------------------------------------------------------------------------- + def test_3300_create_and_identity(team, team_attributes): log_step("Validating team identity and attributes") log_step(f"Team name: {team.team_name}") @@ -268,7 +293,9 @@ def test_3304_disable_enable_contract(team): def test_3305_set_attribute_process(team): - log_step(f"Setting team attribute 'process' to 'sequential': {team.team_name}") + log_step( + f"Setting team attribute 'process' to 'sequential': {team.team_name}" + ) team.set_attribute("process", "sequential") fetched = Team.fetch(PYSAI_TEAM_NAME) log_step(f"Fetched attribute process: {fetched.attributes.process}") @@ -307,6 +334,57 @@ def test_3308_fetch_non_existing(): log_ok("Fetch non-existing confirmed error") +def test_3309_export_team(team): + log_step(f"Exporting team: {team.team_name}") + specification = team.export() + log_step(f"Exported team specification: {specification}") + spec = json.loads(specification) + assert isinstance(specification, str) + assert len(specification) > 0 + assert spec["name"] == PYSAI_TEAM_AGENT_NAME + assert spec["task"]["task_name"] == PYSAI_TEAM_TASK_NAME + log_ok("Export team OK") + + +def test_3310_import_team(team, python_gen_ai_profile): + imported_team_name = f"PYSAI_IMPORTED_TEAM_{uuid.uuid4().hex.upper()}" + imported_agent_name = f"PYSAI_IMPORTED_AGENT_{uuid.uuid4().hex.upper()}" + imported_task_name = f"PYSAI_IMPORTED_TASK_{uuid.uuid4().hex.upper()}" + + log_step(f"Exporting source team for import: {team.team_name}") + spec = json.loads(team.export()) + spec["name"] = imported_agent_name + spec["task"]["task_name"] = imported_task_name + + try: + log_step(f"Importing team: {imported_team_name}") + Team.import_team( + profile_name=PYSAI_TEAM_PROFILE_NAME, + team_name=imported_team_name, + specification=spec, + force=True, + ) + imported = Team.fetch(imported_team_name) + log_step(f"Imported team attributes: {imported.attributes}") + assert imported.team_name == imported_team_name + assert imported.attributes.agents == [ + {"name": imported_agent_name, "task": imported_task_name} + ] + finally: + log_step(f"Deleting imported team: {imported_team_name}") + ignore_missing( + lambda: Team.delete_team(imported_team_name, force=True) + ) + log_step(f"Deleting imported task: {imported_task_name}") + ignore_missing( + lambda: Task.delete_task(imported_task_name, force=True) + ) + log_step(f"Deleting imported agent: {imported_agent_name}") + ignore_missing( + lambda: Agent.delete_agent(imported_agent_name, force=True) + ) + + def test_3311_set_attribute_invalid_key(team): log_step(f"Setting invalid attribute key on team: {team.team_name}") expect_error("ORA-20053", lambda: team.set_attribute("no_such_attr", "x")) @@ -320,14 +398,21 @@ def test_3312_set_attribute_none(team): def test_3313_set_attribute_empty(team): - log_step(f"Setting team attribute 'process' to empty string: {team.team_name}") + log_step( + f"Setting team attribute 'process' to empty string: {team.team_name}" + ) expect_error("ORA-20053", lambda: team.set_attribute("process", "")) log_ok("Set attribute empty confirmed error") def test_3314_set_attribute_invalid_value(team): - log_step(f"Setting team attribute 'process' to invalid value: {team.team_name}") - expect_error("ORA-20053", lambda: team.set_attribute("process", "not_a_real_process")) + log_step( + f"Setting team attribute 'process' to invalid value: {team.team_name}" + ) + expect_error( + "ORA-20053", + lambda: team.set_attribute("process", "not_a_real_process"), + ) log_ok("Set attribute invalid value confirmed error") @@ -385,7 +470,10 @@ def test_3319_create_existing_without_replace(team_attributes): t1 = Team(name, team_attributes, "TMP1") t1.create(replace=False) log_step(f"Attempting to create existing team without replace: {name}") - expect_error("ORA-20053", lambda: Team(name, team_attributes, "TMP2").create(replace=False)) + expect_error( + "ORA-20053", + lambda: Team(name, team_attributes, "TMP2").create(replace=False), + ) t1.delete(force=True) log_ok("Create existing without replace confirmed error") diff --git a/tests/conftest.py b/tests/conftest.py index a191ef5..139dca1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,5 @@ # ----------------------------------------------------------------------------- -# Copyright (c) 2025, Oracle and/or its affiliates. +# Copyright (c) 2025, 2026, Oracle and/or its affiliates. # # Licensed under the Universal Permissive License v 1.0 as shown at # http://oss.oracle.com/licenses/upl. @@ -90,24 +90,6 @@ def _grant_http_access(cur, username: str, provider_endpoint: str): ) -def _append_host_ace(cur, host: str, privileges, username: str): - privilege_list = ",".join([f"'{p}'" for p in privileges]) - cur.execute( - f""" - BEGIN - DBMS_NETWORK_ACL_ADMIN.APPEND_HOST_ACE( - host => '{host}', - ace => xs$ace_type( - privilege_list => xs$name_list({privilege_list}), - principal_name => '{username}', - principal_type => xs_acl.ptype_db - ) - ); - END; - """ - ) - - def get_env_value(name, default_value=None, required=False): """ Returns the value of the environment variable if it is present and the @@ -261,35 +243,42 @@ def allow_network_acl(test_env): email_smtp_host = get_env_value("EMAIL_SMTPHOST") http_hosts = ["api.openai.com", "a.co", "amazon.in"] - with oracledb.connect(**test_env.connect_params(admin=True)) as conn: - cur = conn.cursor() - try: - if email_smtp_host: - try: - _append_host_ace( - cur, email_smtp_host, ["connect", "smtp"], username - ) - except Exception as exc: - msg = str(exc) - if ( - "ORA-46212" not in msg - and "ORA-46313" not in msg - and "already exists" not in msg - ): - raise - - for host in http_hosts: - try: - _append_host_ace(cur, host, ["connect", "http"], username) - except Exception as exc: - msg = str(exc) - if ( - "ORA-46212" not in msg - and "ORA-46313" not in msg - and "already exists" not in msg - ): - raise - finally: - cur.close() + select_ai.disconnect() + select_ai.create_pool(**test_env.connect_params(admin=True, use_pool=True)) + try: + if email_smtp_host: + try: + select_ai.grant_network_access( + users=username, + host=email_smtp_host, + privileges=["connect", "smtp"], + ) + except Exception as exc: + msg = str(exc) + if ( + "ORA-46212" not in msg + and "ORA-46313" not in msg + and "already exists" not in msg + ): + raise + + for host in http_hosts: + try: + select_ai.grant_network_access( + users=username, + host=host, + privileges=["connect", "http"], + ) + except Exception as exc: + msg = str(exc) + if ( + "ORA-46212" not in msg + and "ORA-46313" not in msg + and "already exists" not in msg + ): + raise + finally: + select_ai.disconnect() + select_ai.create_pool(**test_env.connect_params(use_pool=True)) yield diff --git a/tests/profiles/test_1200_profile.py b/tests/profiles/test_1200_profile.py index 030d006..16d5626 100644 --- a/tests/profiles/test_1200_profile.py +++ b/tests/profiles/test_1200_profile.py @@ -1,5 +1,5 @@ # ----------------------------------------------------------------------------- -# Copyright (c) 2025, Oracle and/or its affiliates. +# Copyright (c) 2025, 2026, Oracle and/or its affiliates. # # Licensed under the Universal Permissive License v 1.0 as shown at # http://oss.oracle.com/licenses/upl. @@ -225,6 +225,14 @@ def test_1207(): assert profile.attributes.provider.model == "meta.llama-3.1-70b-instruct" +def test_1207_profile_attributes_set_provider_attribute_without_provider(): + """Set a provider attribute when provider is not specified""" + attributes = ProfileAttributes() + attributes.set_attribute("model", "meta.llama-3.1-70b-instruct") + + assert attributes.provider.model == "meta.llama-3.1-70b-instruct" + + def test_1208(oci_credential, oci_compartment_id): """Set multiple attributes for a Profile""" profile = Profile(PYSAI_1200_PROFILE) diff --git a/tests/profiles/test_1600_generate.py b/tests/profiles/test_1600_generate.py index b130805..d9661ff 100644 --- a/tests/profiles/test_1600_generate.py +++ b/tests/profiles/test_1600_generate.py @@ -1,5 +1,5 @@ # ----------------------------------------------------------------------------- -# Copyright (c) 2025, Oracle and/or its affiliates. +# Copyright (c) 2025, 2026, Oracle and/or its affiliates. # # Licensed under the Universal Permissive License v 1.0 as shown at # http://oss.oracle.com/licenses/upl. @@ -271,6 +271,19 @@ def test_1615_generate_explainsql(generate_profile): assert len(explain_sql) > 0 +def test_1616_chat_stream(generate_profile): + """chat with stream=True returns text chunks""" + logger.info("Validating chat with stream=True returns text chunks") + chunks = generate_profile.chat( + prompt="What is OCI ?", stream=True, chunk_size=1024 + ) + response = "".join(chunks) + logger.debug("Response = %s", response) + assert isinstance(response, str) + assert len(response) > 0 + assert "Oracle Cloud Infrastructure" in response + + def test_1616_empty_prompt_raises_value_error(negative_profile): """Empty prompts raise ValueError for profile methods""" logger.info( diff --git a/tests/profiles/test_1700_generate_async.py b/tests/profiles/test_1700_generate_async.py index 2f2b9e0..fdb1d9b 100644 --- a/tests/profiles/test_1700_generate_async.py +++ b/tests/profiles/test_1700_generate_async.py @@ -1,5 +1,5 @@ # ----------------------------------------------------------------------------- -# Copyright (c) 2025, Oracle and/or its affiliates. +# Copyright (c) 2025, 2026, Oracle and/or its affiliates. # # Licensed under the Universal Permissive License v 1.0 as shown at # http://oss.oracle.com/licenses/upl. @@ -301,6 +301,23 @@ async def test_1715_generate_explainsql(async_generate_profile): assert len(explain_sql) > 0 +@pytest.mark.anyio +async def test_1716_chat_stream(async_generate_profile): + """chat with stream=True returns text chunks""" + logger.info("Validating async chat with stream=True returns text chunks") + chunks = await async_generate_profile.chat( + prompt="What is OCI ?", stream=True, chunk_size=1024 + ) + response_chunks = [] + async for chunk in chunks: + response_chunks.append(chunk) + response = "".join(response_chunks) + logger.debug("Response = %s", response) + assert isinstance(response, str) + assert len(response) > 0 + assert "Oracle Cloud Infrastructure" in response + + @pytest.mark.anyio async def test_1716_empty_prompt_raises_value_error(async_negative_profile): """Empty prompts raise ValueError for async profile methods""" diff --git a/tests/test_1150_privilege.py b/tests/test_1150_privilege.py new file mode 100644 index 0000000..2ab6827 --- /dev/null +++ b/tests/test_1150_privilege.py @@ -0,0 +1,153 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2026, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +""" +1150 - Privilege API tests +""" + +import uuid + +import pytest +import select_ai + + +def _network_ace_exists(cursor, host, principal, privilege, lower_port=None): + cursor.execute( + """ + SELECT COUNT(*) + FROM DBA_HOST_ACES + WHERE host = :host + AND principal = :principal + AND privilege = :privilege + AND (lower_port = :lower_port OR (lower_port IS NULL AND :lower_port IS NULL)) + """, + host=host, + principal=principal, + privilege=privilege, + lower_port=lower_port, + ) + return cursor.fetchone()[0] > 0 + + +@pytest.fixture +def admin_connect(test_env): + select_ai.disconnect() + select_ai.create_pool(**test_env.connect_params(admin=True, use_pool=True)) + yield + select_ai.disconnect() + select_ai.create_pool(**test_env.connect_params(use_pool=True)) + + +@pytest.fixture +async def async_admin_connect(test_env): + await select_ai.async_disconnect() + select_ai.create_pool_async( + **test_env.connect_params(admin=True, use_pool=True) + ) + yield + await select_ai.async_disconnect() + select_ai.create_pool_async(**test_env.connect_params(use_pool=True)) + + +def test_1150_grant_network_access(admin_connect, test_env): + host = f"pysai-{uuid.uuid4().hex}.example.com" + principal = test_env.test_user.upper() + + select_ai.grant_network_access( + users=test_env.test_user, + host=host, + privileges=["connect", "smtp"], + lower_port=587, + upper_port=587, + ) + + with select_ai.cursor() as cursor: + assert _network_ace_exists(cursor, host, principal, "CONNECT", 587) + assert _network_ace_exists(cursor, host, principal, "SMTP", 587) + + +def test_1151_revoke_network_access(admin_connect, test_env): + host = f"pysai-{uuid.uuid4().hex}.example.com" + principal = test_env.test_user.upper() + + select_ai.grant_network_access( + users=test_env.test_user, + host=host, + privileges=["connect", "smtp"], + lower_port=587, + upper_port=587, + ) + select_ai.revoke_network_access( + users=test_env.test_user, + host=host, + privileges=["connect", "smtp"], + lower_port=587, + upper_port=587, + ) + + with select_ai.cursor() as cursor: + assert not _network_ace_exists(cursor, host, principal, "CONNECT", 587) + assert not _network_ace_exists(cursor, host, principal, "SMTP", 587) + + +@pytest.mark.anyio +async def test_1152_async_grant_network_access(async_admin_connect, test_env): + host = f"pysai-{uuid.uuid4().hex}.example.com" + principal = test_env.test_user.upper() + + await select_ai.async_grant_network_access( + users=test_env.test_user, + host=host, + privileges="connect", + ) + + async with select_ai.async_cursor() as cursor: + await cursor.execute( + """ + SELECT COUNT(*) + FROM DBA_HOST_ACES + WHERE host = :host + AND principal = :principal + AND privilege = 'CONNECT' + """, + host=host, + principal=principal, + ) + count = await cursor.fetchone() + assert count[0] > 0 + + +@pytest.mark.anyio +async def test_1153_async_revoke_network_access(async_admin_connect, test_env): + host = f"pysai-{uuid.uuid4().hex}.example.com" + principal = test_env.test_user.upper() + + await select_ai.async_grant_network_access( + users=test_env.test_user, + host=host, + privileges="connect", + ) + await select_ai.async_revoke_network_access( + users=test_env.test_user, + host=host, + privileges="connect", + ) + + async with select_ai.async_cursor() as cursor: + await cursor.execute( + """ + SELECT COUNT(*) + FROM DBA_HOST_ACES + WHERE host = :host + AND principal = :principal + AND privilege = 'CONNECT' + """, + host=host, + principal=principal, + ) + count = await cursor.fetchone() + assert count[0] == 0