diff --git a/.gitignore b/.gitignore index b80f310..d2ff10f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,7 @@ my_examples +logs/ .idea +.env .venv .DS_Store select_ai.egg-info diff --git a/tests/conftest.py b/tests/conftest.py index f9c54f0..1dbea20 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -33,6 +33,35 @@ PYSAI_TEST_USER = "PYSAI_TEST_USER" PYSAI_OCI_CREDENTIAL_NAME = f"PYSAI_OCI_CREDENTIAL_{uuid.uuid4().hex.upper()}" +_BASIC_SCHEMA_PRIVILEGES = ( + "CREATE SESSION", + "CREATE TABLE", + "UNLIMITED TABLESPACE", +) + + +def _ensure_test_user_exists(username: str, password: str): + username_upper = username.upper() + with select_ai.cursor() as cr: + cr.execute( + "SELECT 1 FROM dba_users WHERE username = :username", + username=username_upper, + ) + if cr.fetchone(): + return + escaped_password = password.replace('"', '""') + cr.execute( + f'CREATE USER {username_upper} IDENTIFIED BY "{escaped_password}"' + ) + select_ai.db.get_connection().commit() + + +def _grant_basic_schema_privileges(username: str): + username_upper = username.upper() + with select_ai.cursor() as cr: + for privilege in _BASIC_SCHEMA_PRIVILEGES: + cr.execute(f"GRANT {privilege} TO {username_upper}") + select_ai.db.get_connection().commit() def get_env_value(name, default_value=None, required=False): @@ -93,6 +122,11 @@ def test_env(pytestconfig): @pytest.fixture(autouse=True, scope="session") def setup_test_user(test_env): select_ai.connect(**test_env.connect_params(admin=True)) + _ensure_test_user_exists( + username=test_env.test_user, + password=test_env.test_user_password, + ) + _grant_basic_schema_privileges(username=test_env.test_user) select_ai.grant_privileges(users=[test_env.test_user]) select_ai.grant_http_access( users=[test_env.test_user], diff --git a/tests/gsd/conftest.py b/tests/gsd/conftest.py new file mode 100644 index 0000000..f5df91f --- /dev/null +++ b/tests/gsd/conftest.py @@ -0,0 +1,56 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +import logging +from pathlib import Path + +import pytest +import select_ai + +LOG_FORMAT = "%(levelname)s: [%(name)s] %(message)s" + + +def _configure_logger(logger: logging.Logger, module_file: str) -> None: + logger.setLevel(logging.DEBUG) + log_dir = Path(__file__).resolve().parents[2] / "logs" + log_dir.mkdir(parents=True, exist_ok=True) + log_file = log_dir / f"tkex_{Path(module_file).stem}.log" + + formatter = logging.Formatter(fmt=LOG_FORMAT) + + file_handler = logging.FileHandler(log_file, mode="w", encoding="utf-8") + file_handler.setLevel(logging.DEBUG) + file_handler.setFormatter(formatter) + + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.WARNING) + console_handler.setFormatter(formatter) + + logger.handlers.clear() + logger.propagate = False + logger.addHandler(file_handler) + logger.addHandler(console_handler) + logger.info("Configured logging for module") + + +@pytest.fixture(scope="module", autouse=True) +def configure_module_logging(request): + module = request.module + logger = logging.getLogger(module.__name__) + _configure_logger(logger, module.__file__) + yield + for handler in logger.handlers: + handler.close() + logger.handlers.clear() + + +@pytest.fixture(autouse=True) +def log_test_case(request, configure_module_logging): + logger = logging.getLogger(request.module.__name__) + logger.info("Starting test %s", request.node.name) + yield + logger.info("Finished test %s", request.node.name) diff --git a/tests/gsd/test_2000_synthetic_data.py b/tests/gsd/test_2000_synthetic_data.py new file mode 100644 index 0000000..89b753e --- /dev/null +++ b/tests/gsd/test_2000_synthetic_data.py @@ -0,0 +1,191 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +""" +2000 - Synthetic data generation tests +""" + +import logging +import uuid + +import pytest +import select_ai +from oracledb import DatabaseError +from select_ai import ( + Profile, + ProfileAttributes, + SyntheticDataAttributes, + SyntheticDataParams, +) + +logger = logging.getLogger(__name__) + +PROFILE_PREFIX = f"PYSAI_2000_{uuid.uuid4().hex.upper()}" + + +def _build_attributes(record_count=1, **kwargs): + logger.debug( + "Building synthetic data attributes with record_count=%s and extras=%s", + record_count, + kwargs, + ) + return SyntheticDataAttributes( + object_name="people", + record_count=record_count, + **kwargs, + ) + + +@pytest.fixture(scope="module") +def synthetic_provider(oci_compartment_id): + return select_ai.OCIGenAIProvider( + oci_compartment_id=oci_compartment_id, + oci_apiformat="GENERIC", + ) + + +@pytest.fixture(scope="module") +def synthetic_profile_attributes(oci_credential, synthetic_provider): + return ProfileAttributes( + credential_name=oci_credential["credential_name"], + object_list=[ + {"owner": "ADMIN", "name": "people"}, + {"owner": "ADMIN", "name": "gymnast"}, + ], + provider=synthetic_provider, + ) + + +@pytest.fixture(scope="module") +def synthetic_profile(synthetic_profile_attributes): + logger.info("Creating synthetic profile %s", f"{PROFILE_PREFIX}_SYNC") + profile = Profile( + profile_name=f"{PROFILE_PREFIX}_SYNC", + attributes=synthetic_profile_attributes, + description="Synthetic data test profile", + replace=True, + ) + yield profile + try: + logger.info("Deleting synthetic profile %s", profile.profile_name) + profile.delete(force=True) + except Exception: + logger.warning("Failed to delete profile %s", profile.profile_name) + pass + + +def test_2000_generate_with_full_params(synthetic_profile): + """Generate synthetic data with full parameter set""" + logger.info( + "Generating synthetic data with full params for profile %s", + synthetic_profile.profile_name, + ) + params = SyntheticDataParams(sample_rows=10, priority="HIGH") + attributes = _build_attributes( + record_count=5, + params=params, + user_prompt="age must be greater than 20", + ) + logger.info("Attributes = %s", attributes) + assert attributes.record_count is 5 + result = synthetic_profile.generate_synthetic_data(attributes) + assert result is None + + +def test_2001_generate_minimum_fields(synthetic_profile): + """Generate synthetic data with minimum fields""" + logger.info("Generating synthetic data with minimum fields") + attributes = _build_attributes() + logger.info("Attributes = %s", attributes) + result = synthetic_profile.generate_synthetic_data(attributes) + logger.info("Result = %s", result) + assert result is None + + +def test_2002_generate_zero_sample_rows(synthetic_profile): + """Generate synthetic data with zero sample rows""" + logger.info("Generating synthetic data with zero sample rows") + params = SyntheticDataParams(sample_rows=0, priority="HIGH") + attributes = _build_attributes(params=params) + logger.info("Attributes = %s", attributes) + assert attributes.params.sample_rows is 0 + result = synthetic_profile.generate_synthetic_data(attributes) + logger.info("Result = %s", result) + assert result is None + + +def test_2003_generate_single_sample_row(synthetic_profile): + """Generate synthetic data with single sample row""" + logger.info("Generating synthetic data with single sample row") + params = SyntheticDataParams(sample_rows=1, priority="HIGH") + attributes = _build_attributes(params=params) + logger.info("Attributes = %s", attributes) + assert attributes.params.sample_rows is 1 + result = synthetic_profile.generate_synthetic_data(attributes) + logger.info("Result = %s", result) + assert result is None + + +def test_2004_generate_low_priority(synthetic_profile): + """Generate synthetic data with low priority""" + logger.info("Generating synthetic data with low priority") + params = SyntheticDataParams(sample_rows=1, priority="LOW") + attributes = _build_attributes(params=params) + logger.info("Attributes = %s", attributes) + assert attributes.params.sample_rows is 1 + assert attributes.params.priority is "LOW" + result = synthetic_profile.generate_synthetic_data(attributes) + logger.info("Result = %s", result) + assert result is None + + +def test_2005_generate_missing_object_name(synthetic_profile): + """Missing object_name raises error""" + logger.info("Validating missing object_name raises ValueError") + attributes = SyntheticDataAttributes(record_count=1) + with pytest.raises( + ValueError, match="One of object_name and object_list should be set" + ): + synthetic_profile.generate_synthetic_data(attributes) + + +def test_2006_generate_invalid_priority(synthetic_profile): + """Invalid priority raises error""" + logger.info("Validating invalid priority raises error") + params = SyntheticDataParams(sample_rows=1, priority="CRITICAL") + attributes = _build_attributes(params=params) + logger.info("Attributes = %s", attributes) + with pytest.raises(DatabaseError) as exc_info: + synthetic_profile.generate_synthetic_data(attributes) + (error,) = exc_info.value.args + logger.debug("Error code: %s", error.code) + logger.debug("Error message:\n%s", error.message) + assert error.code == 20000 + assert "Invalid value for priority" in error.message + + +def test_2007_generate_negative_record_count(synthetic_profile): + """Negative record count raises error""" + logger.info("Validating negative record count raises error") + attributes = _build_attributes(record_count=-5) + logger.info("Attributes = %s", attributes) + with pytest.raises(DatabaseError) as exc_info: + synthetic_profile.generate_synthetic_data(attributes) + (error,) = exc_info.value.args + logger.debug("Error code: %s", error.code) + logger.debug("Error message:\n%s", error.message) + assert error.code == 20000 + assert "record_count cannot be smaller than 0" in error.message + + +def test_2008_generate_with_none_attributes(synthetic_profile): + """Passing None as attributes raises error""" + logger.info("Validating None attributes raise ValueError") + with pytest.raises( + ValueError, match="'synthetic_data_attributes' cannot be None" + ): + synthetic_profile.generate_synthetic_data(None) diff --git a/tests/gsd/test_2100_synthetic_data_async.py b/tests/gsd/test_2100_synthetic_data_async.py new file mode 100644 index 0000000..1c40878 --- /dev/null +++ b/tests/gsd/test_2100_synthetic_data_async.py @@ -0,0 +1,205 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +""" +2100 - Synthetic data generation tests (async) +""" + +import logging +import uuid + +import pytest +import select_ai +from oracledb import DatabaseError +from select_ai import ( + AsyncProfile, + ProfileAttributes, + SyntheticDataAttributes, + SyntheticDataParams, +) + +logger = logging.getLogger(__name__) + +PROFILE_PREFIX = f"PYSAI_2100_{uuid.uuid4().hex.upper()}" + + +def _build_attributes(record_count=1, **kwargs): + logger.debug( + "Building async synthetic data attributes with record_count=%s and extras=%s", + record_count, + kwargs, + ) + return SyntheticDataAttributes( + object_name="people", + record_count=record_count, + **kwargs, + ) + + +@pytest.fixture(scope="module") +def async_synthetic_provider(oci_compartment_id): + return select_ai.OCIGenAIProvider( + oci_compartment_id=oci_compartment_id, + oci_apiformat="GENERIC", + ) + + +@pytest.fixture(scope="module") +def async_synthetic_profile_attributes( + oci_credential, async_synthetic_provider +): + return ProfileAttributes( + credential_name=oci_credential["credential_name"], + object_list=[ + {"owner": "ADMIN", "name": "people"}, + {"owner": "ADMIN", "name": "gymnast"}, + ], + provider=async_synthetic_provider, + ) + + +@pytest.fixture(scope="module") +async def async_synthetic_profile(async_synthetic_profile_attributes): + logger.info( + "Creating async synthetic profile %s", f"{PROFILE_PREFIX}_ASYNC" + ) + profile = await AsyncProfile( + profile_name=f"{PROFILE_PREFIX}_ASYNC", + attributes=async_synthetic_profile_attributes, + description="Synthetic data async test profile", + replace=True, + ) + yield profile + try: + logger.info( + "Deleting async synthetic profile %s", profile.profile_name + ) + await profile.delete(force=True) + except Exception: + logger.warning( + "Failed to delete async synthetic profile %s", profile.profile_name + ) + pass + + +@pytest.mark.anyio +async def test_2100_generate_with_full_params(async_synthetic_profile): + """Generate synthetic data with full parameter set""" + logger.info( + "Generating async synthetic data with full params for profile %s", + async_synthetic_profile.profile_name, + ) + params = SyntheticDataParams(sample_rows=10, priority="HIGH") + attributes = _build_attributes( + record_count=5, + params=params, + user_prompt="age must be greater than 20", + ) + logger.info("Attributes = %s", attributes) + assert attributes.record_count is 5 + result = await async_synthetic_profile.generate_synthetic_data(attributes) + assert result is None + + +@pytest.mark.anyio +async def test_2101_generate_minimum_fields(async_synthetic_profile): + """Generate synthetic data with minimum fields""" + logger.info("Generating async synthetic data with minimum fields") + attributes = _build_attributes() + logger.info("Attributes = %s", attributes) + result = await async_synthetic_profile.generate_synthetic_data(attributes) + assert result is None + + +@pytest.mark.anyio +async def test_2102_generate_zero_sample_rows(async_synthetic_profile): + """Generate synthetic data with zero sample rows""" + logger.info("Generating async synthetic data with zero sample rows") + params = SyntheticDataParams(sample_rows=0, priority="HIGH") + attributes = _build_attributes(params=params) + logger.info("Attributes = %s", attributes) + assert attributes.params.sample_rows is 0 + result = await async_synthetic_profile.generate_synthetic_data(attributes) + assert result is None + + +@pytest.mark.anyio +async def test_2103_generate_single_sample_row(async_synthetic_profile): + """Generate synthetic data with single sample row""" + logger.info("Generating async synthetic data with single sample row") + params = SyntheticDataParams(sample_rows=1, priority="HIGH") + attributes = _build_attributes(params=params) + logger.info("Attributes = %s", attributes) + assert attributes.params.sample_rows is 1 + result = await async_synthetic_profile.generate_synthetic_data(attributes) + assert result is None + + +@pytest.mark.anyio +async def test_2104_generate_low_priority(async_synthetic_profile): + """Generate synthetic data with low priority""" + logger.info("Generating async synthetic data with low priority") + params = SyntheticDataParams(sample_rows=1, priority="LOW") + attributes = _build_attributes(params=params) + logger.info("Attributes = %s", attributes) + assert attributes.params.sample_rows is 1 + assert attributes.params.priority is "LOW" + result = await async_synthetic_profile.generate_synthetic_data(attributes) + assert result is None + + +@pytest.mark.anyio +async def test_2105_generate_missing_object_name(async_synthetic_profile): + """Missing object_name raises error""" + logger.info("Validating async missing object_name raises error") + attributes = SyntheticDataAttributes(record_count=1) + logger.info("Attributes = %s", attributes) + with pytest.raises( + ValueError, match="One of object_name and object_list should be set" + ): + await async_synthetic_profile.generate_synthetic_data(attributes) + + +@pytest.mark.anyio +async def test_2106_generate_invalid_priority(async_synthetic_profile): + """Invalid priority raises error""" + logger.info("Validating async invalid priority raises error") + params = SyntheticDataParams(sample_rows=1, priority="CRITICAL") + attributes = _build_attributes(params=params) + logger.info("Attributes = %s", attributes) + with pytest.raises(DatabaseError) as exc_info: + await async_synthetic_profile.generate_synthetic_data(attributes) + (error,) = exc_info.value.args + logger.debug("Error code: %s", error.code) + logger.debug("Error message:\n%s", error.message) + assert error.code == 20000 + assert "Invalid value for priority" in error.message + + +@pytest.mark.anyio +async def test_2107_generate_negative_record_count(async_synthetic_profile): + """Negative record count raises error""" + logger.info("Validating async negative record count raises error") + attributes = _build_attributes(record_count=-5) + logger.info("Attributes = %s", attributes) + with pytest.raises(DatabaseError) as exc_info: + await async_synthetic_profile.generate_synthetic_data(attributes) + (error,) = exc_info.value.args + logger.debug("Error code: %s", error.code) + logger.debug("Error message:\n%s", error.message) + assert error.code == 20000 + assert "record_count cannot be smaller than 0" in error.message + + +@pytest.mark.anyio +async def test_2108_generate_with_none_attributes(async_synthetic_profile): + """Passing None as attributes raises error""" + logger.info("Validating async None attributes raise error") + with pytest.raises( + ValueError, match="'synthetic_data_attributes' cannot be None" + ): + await async_synthetic_profile.generate_synthetic_data(None) diff --git a/tests/profiles/conftest.py b/tests/profiles/conftest.py index 63ea821..601cc2f 100644 --- a/tests/profiles/conftest.py +++ b/tests/profiles/conftest.py @@ -5,9 +5,56 @@ # http://oss.oracle.com/licenses/upl. # ----------------------------------------------------------------------------- +import logging +from pathlib import Path + import pytest import select_ai +LOG_FORMAT = "%(levelname)s: [%(name)s] %(message)s" + + +def _configure_logger(logger: logging.Logger, module_file: str) -> None: + logger.setLevel(logging.DEBUG) + log_dir = Path(__file__).resolve().parents[2] / "logs" + log_dir.mkdir(parents=True, exist_ok=True) + log_file = log_dir / f"tkex_{Path(module_file).stem}.log" + + formatter = logging.Formatter(fmt=LOG_FORMAT) + + file_handler = logging.FileHandler(log_file, mode="w", encoding="utf-8") + file_handler.setLevel(logging.DEBUG) + file_handler.setFormatter(formatter) + + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.WARNING) + console_handler.setFormatter(formatter) + + logger.handlers.clear() + logger.propagate = False + logger.addHandler(file_handler) + logger.addHandler(console_handler) + logger.info("Configured logging for module") + + +@pytest.fixture(scope="module", autouse=True) +def configure_module_logging(request): + module = request.module + logger = logging.getLogger(module.__name__) + _configure_logger(logger, module.__file__) + yield + for handler in logger.handlers: + handler.close() + logger.handlers.clear() + + +@pytest.fixture(autouse=True) +def log_test_case(request, configure_module_logging): + logger = logging.getLogger(request.module.__name__) + logger.info("Starting test %s", request.node.name) + yield + logger.info("Finished test %s", request.node.name) + @pytest.fixture(scope="module") def provider(): diff --git a/tests/profiles/test_1200_profile.py b/tests/profiles/test_1200_profile.py index 5226af7..e4d327e 100644 --- a/tests/profiles/test_1200_profile.py +++ b/tests/profiles/test_1200_profile.py @@ -9,6 +9,7 @@ 1200 - Module for testing the Profile proxy object """ import collections +import logging import uuid import oracledb @@ -16,6 +17,8 @@ import select_ai from select_ai import Profile, ProfileAttributes +logger = logging.getLogger(__name__) + PYSAI_1200_PROFILE = f"PYSAI_1200_{uuid.uuid4().hex.upper()}" PYSAI_1200_PROFILE_2 = f"PYSAI_1200_2_{uuid.uuid4().hex.upper()}" PYSAI_1200_MIN_ATTR_PROFILE = f"PYSAI_1200_MIN_{uuid.uuid4().hex.upper()}" @@ -24,45 +27,59 @@ @pytest.fixture(scope="module") def python_gen_ai_profile(profile_attributes): + logger.info("Creating profile %s", PYSAI_1200_PROFILE) profile = select_ai.Profile( profile_name=PYSAI_1200_PROFILE, description="OCI GENAI Profile", attributes=profile_attributes, ) yield profile + logger.info("Deleting profile %s", profile.profile_name) profile.delete(force=True) @pytest.fixture(scope="module") def python_gen_ai_profile_2(profile_attributes): + logger.info("Creating profile %s", PYSAI_1200_PROFILE_2) profile = select_ai.Profile( profile_name=PYSAI_1200_PROFILE_2, description="OCI GENAI Profile 2", attributes=profile_attributes, ) profile.create(replace=True) + logger.debug("Profile = \n %s", profile) yield profile + logger.info("Deleting profile %s", profile.profile_name) profile.delete(force=True) @pytest.fixture(scope="module") def python_gen_ai_min_attr_profile(min_profile_attributes): + logger.info( + "Creating profile with minimum attributes %s", + PYSAI_1200_MIN_ATTR_PROFILE, + ) profile = select_ai.Profile( profile_name=PYSAI_1200_MIN_ATTR_PROFILE, attributes=min_profile_attributes, description=None, ) + logger.debug("Profile = \n %s", profile) yield profile + logger.info("Deleting minimum attributes profile %s", profile.profile_name) profile.delete(force=True) @pytest.fixture def python_gen_ai_duplicate_profile(min_profile_attributes): + logger.info("Creating duplicate profile %s", PYSAI_1200_DUP_PROFILE) profile = Profile( profile_name=PYSAI_1200_DUP_PROFILE, attributes=min_profile_attributes, ) + logger.debug("Profile = \n %s", profile) yield profile + logger.info("Deleting duplicate profile %s", profile.profile_name) profile.delete(force=True) @@ -206,7 +223,7 @@ def test_1207(): assert profile.attributes.provider.model == "meta.llama-3.1-70b-instruct" -def test_1208(oci_credential): +def test_1208(oci_credential, oci_compartment_id): """Set multiple attributes for a Profile""" profile = Profile(PYSAI_1200_PROFILE) profile_attrs = ProfileAttributes( @@ -214,6 +231,7 @@ def test_1208(oci_credential): provider=select_ai.OCIGenAIProvider( model="meta.llama-4-maverick-17b-128e-instruct-fp8", region="us-chicago-1", + oci_compartment_id=oci_compartment_id, oci_apiformat="GENERIC", ), object_list=[{"owner": "ADMIN", "name": "gymnasts"}], diff --git a/tests/profiles/test_1300_profile_async.py b/tests/profiles/test_1300_profile_async.py index 4f43d81..7626f74 100644 --- a/tests/profiles/test_1300_profile_async.py +++ b/tests/profiles/test_1300_profile_async.py @@ -9,6 +9,7 @@ 1300 - Module for testing the AsyncProfile proxy object """ import collections +import logging import uuid import oracledb @@ -16,6 +17,7 @@ import select_ai from select_ai import AsyncProfile, ProfileAttributes +logger = logging.getLogger(__name__) PYSAI_ASYNC_1300_PROFILE = f"PYSAI_ASYNC_1300_{uuid.uuid4().hex.upper()}" PYSAI_ASYNC_1300_PROFILE_2 = f"PYSAI_ASYNC_1300_2_{uuid.uuid4().hex.upper()}" PYSAI_ASYNC_1300_MIN_ATTR_PROFILE = ( @@ -28,45 +30,62 @@ @pytest.fixture(scope="module") async def python_gen_ai_profile(profile_attributes): + logger.info("Creating async profile %s", PYSAI_ASYNC_1300_PROFILE) profile = await AsyncProfile( profile_name=PYSAI_ASYNC_1300_PROFILE, description="OCI GENAI Profile", attributes=profile_attributes, ) + logger.debug("AsyncProfile = \n %s", profile) yield profile + logger.info("Deleting async profile %s", profile.profile_name) await profile.delete(force=True) @pytest.fixture(scope="module") async def python_gen_ai_profile_2(profile_attributes): + logger.info("Creating async profile %s", PYSAI_ASYNC_1300_PROFILE_2) profile = await AsyncProfile( profile_name=PYSAI_ASYNC_1300_PROFILE_2, description="OCI GENAI Profile 2", attributes=profile_attributes, ) await profile.create(replace=True) + logger.debug("AsyncProfile = \n %s", profile) yield profile + logger.info("Deleting async profile %s", profile.profile_name) await profile.delete(force=True) @pytest.fixture(scope="module") async def python_gen_ai_min_attr_profile(min_profile_attributes): + logger.info( + "Creating async profile with minimum attributes %s", + PYSAI_ASYNC_1300_MIN_ATTR_PROFILE, + ) profile = await AsyncProfile( profile_name=PYSAI_ASYNC_1300_MIN_ATTR_PROFILE, attributes=min_profile_attributes, description=None, ) + logger.debug("AsyncProfile = \n %s", profile) yield profile + logger.info("Deleting async profile %s", profile.profile_name) await profile.delete(force=True) @pytest.fixture async def python_gen_ai_duplicate_profile(min_profile_attributes): + logger.info( + "Creating duplicate async profile %s", PYSAI_ASYNC_1300_DUP_PROFILE + ) profile = await AsyncProfile( profile_name=PYSAI_ASYNC_1300_DUP_PROFILE, attributes=min_profile_attributes, ) + logger.debug("AsyncProfile = \n %s", profile) yield profile + logger.info("Cleaning up duplicate async profile %s", profile.profile_name) await profile.delete(force=True) @@ -94,6 +113,10 @@ async def python_gen_ai_neg_feedback(async_cursor, python_gen_ai_profile): await async_cursor.execute(sql_text) feedback_response = "SELECT * from gymnast" feedback_content = "print in ascending order of total_points" + logger.info( + "Adding negative feedback for async profile %s", + python_gen_ai_profile.profile_name, + ) await python_gen_ai_profile.add_negative_feedback( prompt_spec=(prompt, action), response=feedback_response, @@ -106,6 +129,10 @@ async def python_gen_ai_neg_feedback(async_cursor, python_gen_ai_profile): feedback_content=feedback_content, sql_text=sql_text, ) + logger.info( + "Removing negative feedback for async profile %s", + python_gen_ai_profile.profile_name, + ) await python_gen_ai_profile.delete_feedback(prompt_spec=(prompt, action)) @@ -125,6 +152,10 @@ async def python_gen_ai_pos_feedback(async_cursor, python_gen_ai_profile): action = select_ai.Action.SHOWSQL sql_text = f"select ai {action.value} {prompt}" await async_cursor.execute(sql_text) + logger.info( + "Adding positive feedback for async profile %s", + python_gen_ai_profile.profile_name, + ) await python_gen_ai_profile.add_positive_feedback( prompt_spec=(prompt, action), ) @@ -133,11 +164,18 @@ async def python_gen_ai_pos_feedback(async_cursor, python_gen_ai_profile): action=action, sql_text=sql_text, ) + logger.info( + "Removing positive feedback for async profile %s", + python_gen_ai_profile.profile_name, + ) await python_gen_ai_profile.delete_feedback(prompt_spec=(prompt, action)) def test_1300(python_gen_ai_profile, profile_attributes): """Create basic Profile""" + logger.info( + "Validating async profile %s", python_gen_ai_profile.profile_name + ) assert python_gen_ai_profile.profile_name == PYSAI_ASYNC_1300_PROFILE assert python_gen_ai_profile.attributes == profile_attributes assert python_gen_ai_profile.description == "OCI GENAI Profile" @@ -145,6 +183,10 @@ def test_1300(python_gen_ai_profile, profile_attributes): def test_1301(python_gen_ai_profile_2, profile_attributes): """Create Profile using create method""" + logger.info( + "Validating async profile created via create %s", + python_gen_ai_profile_2.profile_name, + ) assert python_gen_ai_profile_2.profile_name == PYSAI_ASYNC_1300_PROFILE_2 assert python_gen_ai_profile_2.attributes == profile_attributes assert python_gen_ai_profile_2.description == "OCI GENAI Profile 2" @@ -152,11 +194,15 @@ def test_1301(python_gen_ai_profile_2, profile_attributes): async def test_1302(profile_attributes): """Create duplicate profile with replace=True""" + logger.info("Creating duplicate async profile with replace=True") duplicate = await AsyncProfile( profile_name=PYSAI_ASYNC_1300_PROFILE, attributes=profile_attributes, replace=True, ) + logger.info( + "Validating duplicate async profile %s", duplicate.profile_name + ) assert duplicate.profile_name == PYSAI_ASYNC_1300_PROFILE assert duplicate.attributes == profile_attributes assert duplicate.description is None @@ -164,6 +210,10 @@ async def test_1302(profile_attributes): def test_1303(python_gen_ai_min_attr_profile, min_profile_attributes): """Create Profile with minimum required attributes""" + logger.info( + "Validating async minimum attribute profile %s", + python_gen_ai_min_attr_profile.profile_name, + ) assert ( python_gen_ai_min_attr_profile.profile_name == PYSAI_ASYNC_1300_MIN_ATTR_PROFILE @@ -174,6 +224,7 @@ def test_1303(python_gen_ai_min_attr_profile, min_profile_attributes): async def test_1304(): """List profiles without regex""" + logger.info("Listing async profiles without regex") profile_list = [profile async for profile in AsyncProfile.list()] profile_names = set(profile.profile_name for profile in profile_list) descriptions = set(profile.description for profile in profile_list) @@ -185,6 +236,7 @@ async def test_1304(): async def test_1305(): """List profiles with regex""" + logger.info("Listing async profiles with regex pattern") profile_list = [ profile async for profile in AsyncProfile.list( @@ -201,6 +253,9 @@ async def test_1305(): async def test_1306(profile_attributes): """Get attributes for a Profile""" + logger.info( + "Fetching attributes for async profile %s", PYSAI_ASYNC_1300_PROFILE + ) profile = await AsyncProfile(PYSAI_ASYNC_1300_PROFILE) fetched_attributes = await profile.get_attributes() assert fetched_attributes == profile_attributes @@ -208,6 +263,10 @@ async def test_1306(profile_attributes): async def test_1307(): """Set attributes for a Profile""" + logger.info( + "Setting single attribute on async profile %s", + PYSAI_ASYNC_1300_PROFILE, + ) profile = await AsyncProfile(PYSAI_ASYNC_1300_PROFILE) assert profile.attributes.provider.model is None await profile.set_attribute( @@ -218,6 +277,10 @@ async def test_1307(): async def test_1308(oci_credential): """Set multiple attributes for a Profile""" + logger.info( + "Setting multiple attributes for async profile %s", + PYSAI_ASYNC_1300_PROFILE, + ) profile = await AsyncProfile(PYSAI_ASYNC_1300_PROFILE) profile_attrs = ProfileAttributes( credential_name=oci_credential["credential_name"], @@ -235,12 +298,19 @@ async def test_1308(oci_credential): ] assert profile.attributes.comments is True fetched_attributes = await profile.get_attributes() + logger.debug( + "Fetched async provider attributes: %s", fetched_attributes.provider + ) assert fetched_attributes == profile_attrs async def test_1309(python_gen_ai_duplicate_profile): """Create duplicate profile without replace""" # expected - ProfileExistsError + logger.info( + "Expecting ProfileExistsError for duplicate async profile %s", + python_gen_ai_duplicate_profile.profile_name, + ) with pytest.raises(select_ai.errors.ProfileExistsError): await AsyncProfile( profile_name=python_gen_ai_duplicate_profile.profile_name, @@ -251,6 +321,10 @@ async def test_1309(python_gen_ai_duplicate_profile): async def test_1310(python_gen_ai_duplicate_profile): """Create duplicate profile with replace=False""" # expected - select_ai.ProfileExistsError + logger.info( + "Expecting ProfileExistsError with replace=False for async profile %s", + python_gen_ai_duplicate_profile.profile_name, + ) with pytest.raises(select_ai.errors.ProfileExistsError): await AsyncProfile( profile_name=python_gen_ai_duplicate_profile.profile_name, @@ -270,6 +344,9 @@ async def test_1310(python_gen_ai_duplicate_profile): async def test_1311(invalid_provider): """Create Profile with invalid providers""" # expected - ValueError + logger.info( + "Validating async invalid provider handling: %s", invalid_provider + ) with pytest.raises(ValueError): await AsyncProfile( profile_name="PYTHON_INVALID_PROFILE", @@ -282,6 +359,7 @@ async def test_1311(invalid_provider): async def test_1312(): # provider=None # expected - ORA-20047: Either provider or provider_endpoint must be specified + logger.info("Validating async provider=None raises DatabaseError") with pytest.raises(oracledb.DatabaseError): await AsyncProfile( profile_name="PYTHON_INVALID_PROFILE", @@ -301,6 +379,10 @@ async def test_1312(): async def test_1313(invalid_profile_name, min_profile_attributes): """Create Profile with empty profile_name""" # expected - ValueError + logger.info( + "Validating async empty profile name handling: %s", + invalid_profile_name, + ) with pytest.raises(ValueError): await AsyncProfile( profile_name=invalid_profile_name, @@ -311,6 +393,7 @@ async def test_1313(invalid_profile_name, min_profile_attributes): async def test_1314(): """List Profile with invalid regex""" # expected - ORA-12726: unmatched bracket in regular expression + logger.info("Validating async invalid regex handling during list") with pytest.raises(oracledb.DatabaseError): profiles = [ await profile @@ -322,6 +405,7 @@ async def test_1314(): async def test_1315(profile_attributes): """Test AsyncProfile.fetch""" + logger.info("Fetching async profile %s", PYSAI_ASYNC_1300_PROFILE_2) async_profile = await AsyncProfile.fetch( profile_name=PYSAI_ASYNC_1300_PROFILE_2 ) @@ -334,6 +418,10 @@ async def test_1316( async_cursor, python_gen_ai_profile, python_gen_ai_neg_feedback ): """Test profile negative feedback""" + logger.info( + "Validating negative feedback persistence for async profile %s", + python_gen_ai_profile.profile_name, + ) await async_cursor.execute( f"select CONTENT, ATTRIBUTES " f"from {python_gen_ai_profile.profile_name.upper()}_FEEDBACK_VECINDEX$VECTAB " @@ -356,6 +444,10 @@ async def test_1317( async_cursor, python_gen_ai_profile, python_gen_ai_pos_feedback ): """Test profile positive feedback""" + logger.info( + "Validating positive feedback persistence for async profile %s", + python_gen_ai_profile.profile_name, + ) await async_cursor.execute( f"select CONTENT, ATTRIBUTES " f"from {python_gen_ai_profile.profile_name.upper()}_FEEDBACK_VECINDEX$VECTAB " @@ -372,6 +464,10 @@ async def test_1317( async def test_1318(python_gen_ai_profile): """Test translate""" + logger.info( + "Testing translate for async profile %s", + python_gen_ai_profile.profile_name, + ) response = await python_gen_ai_profile.translate( text="Thank you", source_language="en", target_language="de" ) diff --git a/tests/profiles/test_1400_conversation.py b/tests/profiles/test_1400_conversation.py new file mode 100644 index 0000000..1cc132c --- /dev/null +++ b/tests/profiles/test_1400_conversation.py @@ -0,0 +1,278 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +""" +1400 - Conversation API tests +""" + +import logging +import uuid + +import pytest +import select_ai +from oracledb import DatabaseError +from select_ai import Conversation, ConversationAttributes + +logger = logging.getLogger(__name__) + +CONVERSATION_PREFIX = f"PYSAI_1400_{uuid.uuid4().hex.upper()}" + + +@pytest.fixture +def conversation_factory(): + created = [] + + def _create(**kwargs): + logger.info("Creating conversation with params %s", kwargs) + attributes = ConversationAttributes(**kwargs) + conv = Conversation(attributes=attributes) + conv.create() + created.append(conv) + return conv + + yield _create + + for conv in created: + logger.info("Deleting conversation %s", conv.conversation_id) + conv.delete(force=True) + + +@pytest.fixture +def conversation(conversation_factory): + logger.info("Creating default conversation instance") + return conversation_factory(title=f"{CONVERSATION_PREFIX}_ACTIVE") + + +def test_1400_create_with_title(conversation): + """Create a conversation with title""" + logger.info("Validating conversation creation with title") + logger.info("Conversation = %s", conversation) + assert conversation.conversation_id + + +def test_1401_create_with_description(conversation_factory): + """Create a conversation with title and description""" + logger.info("Creating conversation with title and description") + conv = conversation_factory( + title=f"{CONVERSATION_PREFIX}_HISTORY", + description="LLM's understanding of history of science", + ) + logger.info("Conversation = %s", conv) + attrs = conv.get_attributes() + logger.debug("Fetched attributes: %s", attrs) + assert attrs.title == f"{CONVERSATION_PREFIX}_HISTORY" + assert attrs.description == "LLM's understanding of history of science" + + +def test_1402_create_without_title(conversation_factory): + """Create a conversation without providing a title""" + logger.info("Creating conversation without explicit title") + conv = conversation_factory() + logger.info("Conversation = %s", conv) + attrs = conv.get_attributes() + logger.debug("Fetched attributes: %s", attrs) + assert attrs.title == "New Conversation" + + +def test_1403_create_with_missing_attributes(): + """Missing attributes raise AttributeError""" + logger.info("Validating missing attributes raise AttributeError") + conv = Conversation(attributes=None) + logger.info("Conversation = %s", conv) + with pytest.raises( + AttributeError, match="'NoneType' object has no attribute 'json'" + ): + conv.create() + + +def test_1404_get_attributes(conversation): + """Fetch conversation attributes""" + logger.info( + "Fetching attributes for conversation %s", conversation.conversation_id + ) + attrs = conversation.get_attributes() + assert attrs.title == f"{CONVERSATION_PREFIX}_ACTIVE" + assert attrs.description is None + + +def test_1405_set_attributes(conversation): + """Update conversation attributes""" + logger.info( + "Updating conversation attributes for %s", conversation.conversation_id + ) + updated = ConversationAttributes( + title=f"{CONVERSATION_PREFIX}_UPDATED", + description="Updated Description", + ) + conversation.set_attributes(updated) + attrs = conversation.get_attributes() + assert attrs.title == f"{CONVERSATION_PREFIX}_UPDATED" + assert attrs.description == "Updated Description" + + +def test_1406_set_attributes_with_none(conversation): + """Setting empty attributes raises AttributeError""" + logger.info( + "Validating setting None attributes raises AttributeError for %s", + conversation.conversation_id, + ) + with pytest.raises( + AttributeError, match="'NoneType' object has no attribute 'json'" + ): + conversation.set_attributes(None) + + +def test_1407_delete_conversation(conversation_factory): + """Delete conversation and validate removal""" + logger.info("Creating conversation to validate deletion") + conv = conversation_factory(title=f"{CONVERSATION_PREFIX}_DELETE") + conv.delete(force=True) + with pytest.raises(select_ai.errors.ConversationNotFoundError): + conv.get_attributes() + + +def test_1408_delete_twice(conversation_factory): + """Deleting an already deleted conversation raises DatabaseError""" + logger.info("Validating double deletion raises DatabaseError") + conv = conversation_factory(title=f"{CONVERSATION_PREFIX}_DELETE_TWICE") + conv.delete(force=True) + with pytest.raises(DatabaseError) as exc_info: + conv.delete() + (error,) = exc_info.value.args + logger.debug("Error code: %s", error.code) + logger.debug("Error message:\n%s", error.message) + assert error.code == 20050 + assert "does not exist" in error.message + + +def test_1409_list_contains_created_conversation(conversation): + """Conversation list contains the created conversation""" + logger.info("Ensuring conversation list includes created conversation") + conversation_ids = {item.conversation_id for item in Conversation.list()} + assert conversation.conversation_id in conversation_ids + + +def test_1410_multiple_conversations_have_unique_ids(conversation_factory): + """Multiple conversations produce unique identifiers""" + logger.info("Creating multiple conversations to verify unique IDs") + titles = [ + f"{CONVERSATION_PREFIX}_AI", + f"{CONVERSATION_PREFIX}_DB", + f"{CONVERSATION_PREFIX}_MATH", + ] + conversations = [conversation_factory(title=title) for title in titles] + ids = {conv.conversation_id for conv in conversations} + assert len(ids) == len(titles) + + +def test_1411_create_with_long_values(): + """Creating conversation with overly long values fails""" + logger.info("Validating long attribute values trigger failure") + conv = Conversation( + attributes=ConversationAttributes( + title="A" * 255, + description="B" * 1000, + ) + ) + with pytest.raises(DatabaseError) as exc_info: + conv.create() + (error,) = exc_info.value.args + logger.debug("Error code: %s", error.code) + logger.debug("Error message:\n%s", error.message) + assert error.code == 20050 + assert ( + "Value is too long for conversation attribute - title" in error.message + ) + + +def test_1412_set_attributes_with_invalid_id(): + """Updating conversation with invalid id raises DatabaseError""" + logger.info( + "Validating set_attributes with invalid ID raises DatabaseError" + ) + conv = Conversation(conversation_id="fake_id") + with pytest.raises(DatabaseError) as exc_info: + conv.set_attributes(ConversationAttributes(title="Invalid")) + (error,) = exc_info.value.args + logger.debug("Error code: %s", error.code) + logger.debug("Error message:\n%s", error.message) + assert error.code == 20050 + assert "Invalid value for conversation id" in error.message + + +def test_1413_delete_with_invalid_id(): + """Deleting conversation with invalid id raises DatabaseError""" + logger.info("Validating delete with invalid ID raises DatabaseError") + conv = Conversation(conversation_id="fake_id") + with pytest.raises(DatabaseError) as exc_info: + conv.delete() + (error,) = exc_info.value.args + logger.debug("Error code: %s", error.code) + logger.debug("Error message:\n%s", error.message) + assert error.code == 20050 + assert "Invalid value for conversation id" in error.message + + +def test_1414_get_attributes_with_invalid_id(): + """Fetching attributes for invalid conversation raises ConversationNotFound""" + logger.info( + "Validating get_attributes with invalid ID raises ConversationNotFound" + ) + conv = Conversation(conversation_id="invalid") + with pytest.raises( + select_ai.errors.ConversationNotFoundError, match="not found" + ): + conv.get_attributes() + + +def test_1415_get_attributes_for_deleted_conversation(conversation_factory): + """Fetching attributes after deletion raises ConversationNotFound""" + logger.info("Validating get_attributes after deletion raises error") + conv = conversation_factory(title=f"{CONVERSATION_PREFIX}_TO_DELETE") + conv.delete(force=True) + with pytest.raises( + select_ai.errors.ConversationNotFoundError, match="not found" + ): + conv.get_attributes() + + +def test_1416_list_contains_new_conversation(conversation_factory): + """List reflects newly created conversation""" + logger.info("Ensuring list reflects newly created conversation") + conv = conversation_factory(title=f"{CONVERSATION_PREFIX}_LIST") + listed = list(Conversation.list()) + logger.info("List = %s", listed) + assert any(item.conversation_id == conv.conversation_id for item in listed) + + +def test_1417_list_returns_conversation_instances(): + """List returns Conversation objects""" + logger.info("Validating Conversation.list returns Conversation instances") + listed = list(Conversation.list()) + logger.info("List = %s", listed) + assert all(isinstance(item, Conversation) for item in listed) + + +def test_1418_get_attributes_without_description(conversation_factory): + """Conversation created without description has None description""" + logger.info("Creating conversation without description") + conv = conversation_factory(title=f"{CONVERSATION_PREFIX}_NO_DESC") + attrs = conv.get_attributes() + assert attrs.title == f"{CONVERSATION_PREFIX}_NO_DESC" + assert attrs.description is None + + +def test_1419_create_with_description_none(conversation_factory): + """Explicitly setting description to None is allowed""" + logger.info("Creating conversation with description explicitly None") + conv = conversation_factory( + title=f"{CONVERSATION_PREFIX}_NONE_DESC", + description=None, + ) + attrs = conv.get_attributes() + assert attrs.title == f"{CONVERSATION_PREFIX}_NONE_DESC" + assert attrs.description is None diff --git a/tests/profiles/test_1500_conversation_async.py b/tests/profiles/test_1500_conversation_async.py new file mode 100644 index 0000000..cca8f84 --- /dev/null +++ b/tests/profiles/test_1500_conversation_async.py @@ -0,0 +1,327 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +""" +1500 - AsyncConversation API tests +""" + +import logging +import uuid + +import pytest +import select_ai +from oracledb import DatabaseError +from select_ai import AsyncConversation, ConversationAttributes + +logger = logging.getLogger(__name__) + +CONVERSATION_PREFIX = f"PYSAI_1500_{uuid.uuid4().hex.upper()}" + + +@pytest.fixture +async def async_conversation_factory(): + created = [] + + async def _create(**kwargs): + logger.info("Creating async conversation with params %s", kwargs) + attributes = ConversationAttributes(**kwargs) + conversation = AsyncConversation(attributes=attributes) + await conversation.create() + created.append(conversation) + return conversation + + yield _create + + for conversation in created: + logger.info( + "Deleting async conversation %s", conversation.conversation_id + ) + await conversation.delete(force=True) + + +@pytest.fixture +async def async_conversation(async_conversation_factory): + logger.info("Creating default async conversation instance") + return await async_conversation_factory( + title=f"{CONVERSATION_PREFIX}_ACTIVE" + ) + + +@pytest.mark.anyio +async def test_1500_create_with_title(async_conversation): + """Create an async conversation with title""" + logger.info("Validating async conversation creation with title") + assert async_conversation.conversation_id + + +@pytest.mark.anyio +async def test_1501_create_with_description(async_conversation_factory): + """Create an async conversation with title and description""" + logger.info("Creating async conversation with title and description") + conversation = await async_conversation_factory( + title=f"{CONVERSATION_PREFIX}_HISTORY", + description="LLM's understanding of history of science", + ) + attributes = await conversation.get_attributes() + logger.debug("Fetched async conversation attributes: %s", attributes) + assert attributes.title == f"{CONVERSATION_PREFIX}_HISTORY" + assert ( + attributes.description == "LLM's understanding of history of science" + ) + + +@pytest.mark.anyio +async def test_1502_create_without_title(async_conversation_factory): + """Create an async conversation without providing a title""" + logger.info("Creating async conversation without explicit title") + conversation = await async_conversation_factory() + attributes = await conversation.get_attributes() + assert attributes.title == "New Conversation" + + +@pytest.mark.anyio +async def test_1503_create_with_missing_attributes(): + """Missing attributes raise AttributeError""" + logger.info("Validating missing async attributes raise AttributeError") + conversation = AsyncConversation(attributes=None) + with pytest.raises(AttributeError): + await conversation.create() + + +@pytest.mark.anyio +async def test_1504_get_attributes(async_conversation): + """Fetch async conversation attributes""" + logger.info( + "Fetching attributes for async conversation %s", + async_conversation.conversation_id, + ) + attributes = await async_conversation.get_attributes() + assert attributes.title == f"{CONVERSATION_PREFIX}_ACTIVE" + assert attributes.description is None + + +@pytest.mark.anyio +async def test_1505_set_attributes(async_conversation): + """Update async conversation attributes""" + logger.info( + "Updating async conversation attributes for %s", + async_conversation.conversation_id, + ) + updated = ConversationAttributes( + title=f"{CONVERSATION_PREFIX}_UPDATED", + description="Updated Description", + ) + await async_conversation.set_attributes(updated) + attributes = await async_conversation.get_attributes() + assert attributes.title == f"{CONVERSATION_PREFIX}_UPDATED" + assert attributes.description == "Updated Description" + + +@pytest.mark.anyio +async def test_1506_set_attributes_with_none(async_conversation): + """Setting empty attributes raises AttributeError""" + logger.info( + "Validating async set_attributes(None) raises AttributeError for %s", + async_conversation.conversation_id, + ) + with pytest.raises( + AttributeError, match="'NoneType' object has no attribute 'json'" + ): + await async_conversation.set_attributes(None) + + +@pytest.mark.anyio +async def test_1507_delete_conversation(async_conversation_factory): + """Delete async conversation and validate removal""" + logger.info("Creating async conversation to validate deletion") + conversation = await async_conversation_factory( + title=f"{CONVERSATION_PREFIX}_DELETE" + ) + await conversation.delete(force=True) + with pytest.raises(select_ai.errors.ConversationNotFoundError): + await conversation.get_attributes() + + +@pytest.mark.anyio +async def test_1508_delete_twice(async_conversation_factory): + """Deleting an already deleted async conversation raises DatabaseError""" + logger.info( + "Validating double deletion raises DatabaseError for async conversation" + ) + conversation = await async_conversation_factory( + title=f"{CONVERSATION_PREFIX}_DELETE_TWICE" + ) + await conversation.delete(force=True) + with pytest.raises(DatabaseError) as exc_info: + await conversation.delete() + (error,) = exc_info.value.args + logger.debug("Error code: %s", error.code) + logger.debug("Error message:\n%s", error.message) + assert error.code == 20050 + assert "does not exist" in error.message + + +@pytest.mark.anyio +async def test_1509_list_contains_created_conversation(async_conversation): + """Async conversation list contains the created conversation""" + logger.info( + "Ensuring async conversation list includes created conversation" + ) + ids = {item.conversation_id async for item in AsyncConversation.list()} + assert async_conversation.conversation_id in ids + + +@pytest.mark.anyio +async def test_1510_multiple_conversations_have_unique_ids( + async_conversation_factory, +): + """Multiple async conversations produce unique identifiers""" + logger.info("Creating multiple async conversations to verify unique IDs") + titles = [ + f"{CONVERSATION_PREFIX}_AI", + f"{CONVERSATION_PREFIX}_DB", + f"{CONVERSATION_PREFIX}_MATH", + ] + conversations = [ + await async_conversation_factory(title=title) for title in titles + ] + ids = {conversation.conversation_id for conversation in conversations} + assert len(ids) == len(titles) + + +@pytest.mark.anyio +async def test_1511_create_with_long_values(): + """Creating async conversation with overly long values fails""" + logger.info("Validating long attribute values trigger async failure") + conversation = AsyncConversation( + attributes=ConversationAttributes( + title="A" * 255, + description="B" * 1000, + ) + ) + with pytest.raises(DatabaseError) as exc_info: + await conversation.create() + (error,) = exc_info.value.args + logger.debug("Error code: %s", error.code) + logger.debug("Error message:\n%s", error.message) + assert error.code == 20050 + assert ( + "Value is too long for conversation attribute - title" in error.message + ) + + +@pytest.mark.anyio +async def test_1512_set_attributes_with_invalid_id(): + """Updating async conversation with invalid id raises DatabaseError""" + logger.info( + "Validating async set_attributes invalid ID raises DatabaseError" + ) + conversation = AsyncConversation(conversation_id="fake_id") + with pytest.raises(DatabaseError) as exc_info: + await conversation.set_attributes( + ConversationAttributes(title="Invalid") + ) + (error,) = exc_info.value.args + logger.debug("Error code: %s", error.code) + logger.debug("Error message:\n%s", error.message) + assert error.code == 20050 + assert "Invalid value for conversation id" in error.message + + +@pytest.mark.anyio +async def test_1513_delete_with_invalid_id(): + """Deleting async conversation with invalid id raises DatabaseError""" + logger.info("Validating async delete invalid ID raises DatabaseError") + conversation = AsyncConversation(conversation_id="fake_id") + with pytest.raises(DatabaseError) as exc_info: + await conversation.delete() + (error,) = exc_info.value.args + logger.debug("Error code: %s", error.code) + logger.debug("Error message:\n%s", error.message) + assert error.code == 20050 + assert "Invalid value for conversation id" in error.message + + +@pytest.mark.anyio +async def test_1514_get_attributes_with_invalid_id(): + """Fetching attributes for invalid async conversation raises ConversationNotFound""" + logger.info( + "Validating async get_attributes with invalid ID raises ConversationNotFound" + ) + conversation = AsyncConversation(conversation_id="invalid") + with pytest.raises( + select_ai.errors.ConversationNotFoundError, match="not found" + ): + await conversation.get_attributes() + + +@pytest.mark.anyio +async def test_1515_get_attributes_for_deleted_conversation( + async_conversation_factory, +): + """Fetching attributes after deletion raises ConversationNotFound""" + logger.info("Validating async get_attributes after deletion raises error") + conversation = await async_conversation_factory( + title=f"{CONVERSATION_PREFIX}_TO_DELETE" + ) + await conversation.delete(force=True) + with pytest.raises( + select_ai.errors.ConversationNotFoundError, match="not found" + ): + await conversation.get_attributes() + + +@pytest.mark.anyio +async def test_1516_list_contains_new_conversation(async_conversation_factory): + """List reflects newly created async conversation""" + logger.info("Ensuring async list reflects newly created conversation") + conversation = await async_conversation_factory( + title=f"{CONVERSATION_PREFIX}_LIST" + ) + listed = [item async for item in AsyncConversation.list()] + logger.info("List = %s", listed) + assert any( + item.conversation_id == conversation.conversation_id for item in listed + ) + + +@pytest.mark.anyio +async def test_1517_list_returns_async_conversation_instances(): + """List returns AsyncConversation objects""" + logger.info( + "Validating AsyncConversation.list returns AsyncConversation instances" + ) + listed = [item async for item in AsyncConversation.list()] + logger.info("List = %s", listed) + assert all(isinstance(item, AsyncConversation) for item in listed) + + +@pytest.mark.anyio +async def test_1518_get_attributes_without_description( + async_conversation_factory, +): + """Async conversation created without description has None description""" + logger.info("Creating async conversation without description") + conversation = await async_conversation_factory( + title=f"{CONVERSATION_PREFIX}_NO_DESC" + ) + attributes = await conversation.get_attributes() + assert attributes.title == f"{CONVERSATION_PREFIX}_NO_DESC" + assert attributes.description is None + + +@pytest.mark.anyio +async def test_1519_create_with_description_none(async_conversation_factory): + """Explicitly setting description to None is allowed""" + logger.info("Creating async conversation with description explicitly None") + conversation = await async_conversation_factory( + title=f"{CONVERSATION_PREFIX}_NONE_DESC", + description=None, + ) + attributes = await conversation.get_attributes() + assert attributes.title == f"{CONVERSATION_PREFIX}_NONE_DESC" + assert attributes.description is None diff --git a/tests/profiles/test_1600_generate.py b/tests/profiles/test_1600_generate.py new file mode 100644 index 0000000..b130805 --- /dev/null +++ b/tests/profiles/test_1600_generate.py @@ -0,0 +1,327 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +""" +1600 - Profile generate API tests +""" + +import json +import logging +import uuid + +import oracledb +import pandas as pd +import pytest +import select_ai +from select_ai import ( + Conversation, + ConversationAttributes, + Profile, + ProfileAttributes, +) +from select_ai.profile import Action + +logger = logging.getLogger(__name__) + +PROFILE_PREFIX = f"PYSAI_1600_{uuid.uuid4().hex.upper()}" + +PROMPTS = [ + "What is a database?", + "How many gymnasts in database?", + "How many people are there in the database?", +] + + +@pytest.fixture(scope="module") +def generate_provider(oci_compartment_id): + return select_ai.OCIGenAIProvider( + oci_compartment_id=oci_compartment_id, + oci_apiformat="GENERIC", + ) + + +@pytest.fixture(scope="module") +def generate_profile_attributes(test_env, oci_credential, generate_provider): + return ProfileAttributes( + credential_name=oci_credential["credential_name"], + object_list=[ + {"owner": test_env.test_user, "name": "people"}, + {"owner": test_env.test_user, "name": "gymnast"}, + ], + provider=generate_provider, + ) + + +@pytest.fixture(scope="module") +def generate_profile(generate_profile_attributes): + logger.info("Creating generate profile %s", f"{PROFILE_PREFIX}_POSITIVE") + profile = Profile( + profile_name=f"{PROFILE_PREFIX}_POSITIVE", + attributes=generate_profile_attributes, + description="Generate Calls Test Profile", + replace=True, + ) + profile.set_attribute( + attribute_name="model", + attribute_value="meta.llama-3.1-405b-instruct", + ) + yield profile + logger.info("Deleting generate profile %s", profile.profile_name) + profile.delete(force=True) + + +@pytest.fixture +def negative_profile(test_env, oci_credential, generate_provider): + logger.info("Creating negative generate profile") + profile_name = f"{PROFILE_PREFIX}_NEG_{uuid.uuid4().hex.upper()}" + attributes = ProfileAttributes( + credential_name=oci_credential["credential_name"], + provider=generate_provider, + ) + profile = Profile( + profile_name=profile_name, + attributes=attributes, + description="Generate Calls Negative Test Profile", + replace=True, + ) + profile.set_attribute( + attribute_name="object_list", + attribute_value=json.dumps( + [ + {"owner": test_env.test_user, "name": "people"}, + {"owner": test_env.test_user, "name": "gymnast"}, + ] + ), + ) + profile.set_attribute( + attribute_name="model", + attribute_value="meta.llama-3.1-405b-instruct", + ) + yield profile + logger.info("Deleting negative generate profile %s", profile.profile_name) + profile.delete(force=True) + + +def test_1600_action_enum_members(): + """Validate Action enum exposes expected members""" + logger.info("Validating Action enum exposes expected members") + for member in [ + "RUNSQL", + "SHOWSQL", + "EXPLAINSQL", + "NARRATE", + "CHAT", + "SHOWPROMPT", + ]: + assert hasattr(Action, member) + + +def test_1601_action_enum_values(): + """Validate Action enum values""" + logger.info("Validating Action enum values") + assert Action.RUNSQL.value == "runsql" + assert Action.SHOWSQL.value == "showsql" + assert Action.EXPLAINSQL.value == "explainsql" + assert Action.NARRATE.value == "narrate" + assert Action.CHAT.value == "chat" + + +def test_1602_action_from_string(): + """Validate Action enum construction from string""" + logger.info("Validating Action enum from string conversions") + assert Action("runsql") is Action.RUNSQL + assert Action("chat") is Action.CHAT + assert Action("explainsql") is Action.EXPLAINSQL + assert Action("narrate") is Action.NARRATE + assert Action("showsql") is Action.SHOWSQL + + +def test_1603_action_invalid_string(): + """Invalid enum string raises ValueError""" + logger.info("Validating invalid Action string raises ValueError") + with pytest.raises(ValueError): + Action("invalid_action") + + +def test_1604_show_sql(generate_profile): + """show_sql returns SQL text""" + logger.info("Validating show_sql returns SQL text") + for prompt in PROMPTS[1:]: + show_sql = generate_profile.show_sql(prompt=prompt) + logger.debug("Response = %s", show_sql) + assert isinstance(show_sql, str) + assert "SELECT" in show_sql.upper() + + +def test_1605_show_prompt(generate_profile): + """show_prompt returns prompt text""" + logger.info("Validating show_prompt returns text") + for prompt in PROMPTS: + show_prompt = generate_profile.show_prompt(prompt=prompt) + logger.debug("Response = %s", show_prompt) + assert isinstance(show_prompt, str) + assert len(show_prompt) > 0 + assert '"type" : "TEXT"' in show_prompt + + +def test_1606_run_sql(generate_profile): + """run_sql returns DataFrame""" + logger.info("Validating run_sql returns DataFrame") + df = generate_profile.run_sql(prompt=PROMPTS[1]) + logger.debug("Response = %s", df) + assert isinstance(df, pd.DataFrame) + assert len(df.columns) > 0 + + +def test_1607_chat(generate_profile): + """chat returns text response""" + logger.info("Validating chat returns text response") + response = generate_profile.chat(prompt="What is OCI ?") + logger.debug("Response = %s", response) + assert isinstance(response, str) + assert len(response) > 0 + assert "Oracle Cloud Infrastructure" in response + + +def test_1608_narrate(generate_profile): + """narrate returns narrative text""" + logger.info("Validating narrate returns narrative text") + for prompt in PROMPTS[1:]: + narration = generate_profile.narrate(prompt=prompt) + logger.info("Response = %s", narration) + assert isinstance(narration, str) + assert len(narration) > 0 + assert "in the database" in narration + + +def test_1609_chat_session(generate_profile): + """chat_session provides a session context""" + logger.info("Validating chat_session context manager") + conversation = Conversation(attributes=ConversationAttributes()) + with generate_profile.chat_session( + conversation=conversation, delete=True + ) as session: + assert session is not None + + +def test_1610_explain_sql(generate_profile): + """explain_sql returns explanation text""" + logger.info("Validating explain_sql returns explanation text") + for prompt in PROMPTS: + explain_sql = generate_profile.explain_sql(prompt=prompt) + logger.debug("Response = %s", explain_sql) + assert isinstance(explain_sql, str) + assert len(explain_sql) > 0 + + +def test_1611_generate_runsql(generate_profile): + """generate with RUNSQL returns DataFrame""" + logger.info("Validating generate with RUNSQL returns DataFrame") + df = generate_profile.generate(prompt=PROMPTS[1], action=Action.RUNSQL) + logger.debug("Response = %s", df) + assert isinstance(df, pd.DataFrame) + + +def test_1612_generate_showsql(generate_profile): + """generate with SHOWSQL returns SQL""" + logger.info("Validating generate with SHOWSQL returns SQL") + sql = generate_profile.generate(prompt=PROMPTS[1], action=Action.SHOWSQL) + logger.debug("Response = %s", sql) + assert isinstance(sql, str) + assert "SELECT" in sql.upper() + + +def test_1613_generate_chat(generate_profile): + """generate with CHAT returns response""" + logger.info("Validating generate with CHAT returns response") + chat_resp = generate_profile.generate( + prompt="Tell me about OCI", action=Action.CHAT + ) + logger.debug("Response = %s", chat_resp) + assert isinstance(chat_resp, str) + assert len(chat_resp) > 0 + assert "Oracle Cloud Infrastructure" in chat_resp + + +def test_1614_generate_narrate(generate_profile): + """generate with NARRATE returns response""" + logger.info("Validating generate with NARRATE returns response") + narrate_resp = generate_profile.generate( + prompt=PROMPTS[1], action=Action.NARRATE + ) + logger.debug("Response = %s", narrate_resp) + assert isinstance(narrate_resp, str) + assert len(narrate_resp) > 0 + assert "in the database" in narrate_resp + + +def test_1615_generate_explainsql(generate_profile): + """generate with EXPLAINSQL returns explanation""" + logger.info("Validating generate with EXPLAINSQL returns explanation") + for prompt in PROMPTS: + explain_sql = generate_profile.generate( + prompt=prompt, action=Action.EXPLAINSQL + ) + logger.debug("Response = %s", explain_sql) + assert isinstance(explain_sql, str) + assert len(explain_sql) > 0 + + +def test_1616_empty_prompt_raises_value_error(negative_profile): + """Empty prompts raise ValueError for profile methods""" + logger.info( + "Validating empty prompts raise ValueError for generate methods" + ) + with pytest.raises(ValueError): + negative_profile.chat(prompt="") + with pytest.raises(ValueError): + negative_profile.narrate(prompt="") + with pytest.raises(ValueError): + negative_profile.show_sql(prompt="") + with pytest.raises(ValueError): + negative_profile.show_prompt(prompt="") + with pytest.raises(ValueError): + negative_profile.run_sql(prompt="") + with pytest.raises(ValueError): + negative_profile.explain_sql(prompt="") + + +def test_1617_none_prompt_raises_value_error(negative_profile): + """None prompts raise ValueError for profile methods""" + logger.info( + "Validating None prompts raise ValueError for generate methods" + ) + with pytest.raises(ValueError): + negative_profile.chat(prompt=None) + with pytest.raises(ValueError): + negative_profile.narrate(prompt=None) + with pytest.raises(ValueError): + negative_profile.show_sql(prompt=None) + with pytest.raises(ValueError): + negative_profile.show_prompt(prompt=None) + with pytest.raises(ValueError): + negative_profile.run_sql(prompt=None) + with pytest.raises(ValueError): + negative_profile.explain_sql(prompt=None) + + +# def test_1618_run_sql_with_ambiguous_prompt(negative_profile): +# """Ambiguous prompt raises DatabaseError for run_sql""" +# with pytest.raises(oracledb.DatabaseError): +# negative_profile.run_sql(prompt="delete data from user") + + +# def test_1619_run_sql_with_invalid_object_list(negative_profile, test_env): +# """run_sql with non existent table raises DatabaseError""" +# negative_profile.set_attribute( +# attribute_name="object_list", +# attribute_value=json.dumps( +# [{"owner": test_env.test_user, "name": "non_existent_table"}] +# ), +# ) +# with pytest.raises(oracledb.DatabaseError): +# negative_profile.run_sql(prompt="How many entries in the table") diff --git a/tests/profiles/test_1700_generate_async.py b/tests/profiles/test_1700_generate_async.py new file mode 100644 index 0000000..2f2b9e0 --- /dev/null +++ b/tests/profiles/test_1700_generate_async.py @@ -0,0 +1,357 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +""" +1700 - AsyncProfile generate API tests +""" + +import json +import logging +import uuid + +import oracledb +import pandas as pd +import pytest +import select_ai +from select_ai import ( + AsyncConversation, + AsyncProfile, + ConversationAttributes, + ProfileAttributes, +) +from select_ai.profile import Action + +logger = logging.getLogger(__name__) + +PROFILE_PREFIX = f"PYSAI_1700_{uuid.uuid4().hex.upper()}" + +PROMPTS = [ + "What is a database?", + "How many gymnasts in database?", + "How many people are in the database?", +] + + +@pytest.fixture(scope="module") +def async_generate_provider(oci_compartment_id): + return select_ai.OCIGenAIProvider( + oci_compartment_id=oci_compartment_id, + oci_apiformat="GENERIC", + ) + + +@pytest.fixture(scope="module") +def async_generate_profile_attributes( + oci_credential, async_generate_provider, test_env +): + return ProfileAttributes( + credential_name=oci_credential["credential_name"], + object_list=[ + {"owner": test_env.test_user, "name": "people"}, + {"owner": test_env.test_user, "name": "gymnast"}, + ], + provider=async_generate_provider, + ) + + +@pytest.fixture(scope="module") +async def async_generate_profile(async_generate_profile_attributes): + logger.info( + "Creating async generate profile %s", f"{PROFILE_PREFIX}_POSITIVE" + ) + profile = await AsyncProfile( + profile_name=f"{PROFILE_PREFIX}_POSITIVE", + attributes=async_generate_profile_attributes, + description="Async generate calls test profile", + replace=True, + ) + await profile.set_attribute( + attribute_name="model", + attribute_value="meta.llama-3.1-405b-instruct", + ) + yield profile + logger.info("Deleting async generate profile %s", profile.profile_name) + await profile.delete(force=True) + + +@pytest.fixture +async def async_negative_profile( + oci_credential, async_generate_provider, test_env +): + logger.info("Creating async negative generate profile") + profile_name = f"{PROFILE_PREFIX}_NEG_{uuid.uuid4().hex.upper()}" + attributes = ProfileAttributes( + credential_name=oci_credential["credential_name"], + provider=async_generate_provider, + ) + profile = await AsyncProfile( + profile_name=profile_name, + attributes=attributes, + description="Async generate calls negative test profile", + replace=True, + ) + await profile.set_attribute( + attribute_name="object_list", + attribute_value=json.dumps( + [ + {"owner": test_env.test_user, "name": "people"}, + {"owner": test_env.test_user, "name": "gymnast"}, + ] + ), + ) + await profile.set_attribute( + attribute_name="model", + attribute_value="meta.llama-3.1-405b-instruct", + ) + yield profile + logger.info( + "Deleting async negative generate profile %s", profile.profile_name + ) + await profile.delete(force=True) + + +@pytest.mark.anyio +async def test_1700_action_enum_members(): + """Validate Action enum exposes expected members""" + logger.info("Validating async Action enum exposes expected members") + for member in [ + "RUNSQL", + "SHOWSQL", + "EXPLAINSQL", + "NARRATE", + "CHAT", + "SHOWPROMPT", + ]: + assert hasattr(Action, member) + + +@pytest.mark.anyio +async def test_1701_action_enum_values(): + """Validate Action enum values""" + logger.info("Validating async Action enum values") + assert Action.RUNSQL.value == "runsql" + assert Action.SHOWSQL.value == "showsql" + assert Action.EXPLAINSQL.value == "explainsql" + assert Action.NARRATE.value == "narrate" + assert Action.CHAT.value == "chat" + + +@pytest.mark.anyio +async def test_1702_action_from_string(): + """Validate Action enum construction from string""" + logger.info("Validating async Action enum from string conversions") + assert Action("runsql") is Action.RUNSQL + assert Action("chat") is Action.CHAT + assert Action("explainsql") is Action.EXPLAINSQL + assert Action("narrate") is Action.NARRATE + assert Action("showsql") is Action.SHOWSQL + + +@pytest.mark.anyio +async def test_1703_action_invalid_string(): + """Invalid enum string raises ValueError""" + logger.info("Validating async invalid Action string raises ValueError") + with pytest.raises(ValueError): + Action("invalid_action") + + +@pytest.mark.anyio +async def test_1704_show_sql(async_generate_profile): + """show_sql returns SQL text""" + logger.info("Validating async show_sql returns SQL text") + for prompt in PROMPTS[1:]: + show_sql = await async_generate_profile.show_sql(prompt=prompt) + logger.debug("Response = %s", show_sql) + assert isinstance(show_sql, str) + assert "SELECT" in show_sql.upper() + + +@pytest.mark.anyio +async def test_1705_show_prompt(async_generate_profile): + """show_prompt returns prompt text""" + logger.info("Validating async show_prompt returns text") + for prompt in PROMPTS: + show_prompt = await async_generate_profile.show_prompt(prompt=prompt) + logger.debug("Response = %s", show_prompt) + assert isinstance(show_prompt, str) + assert len(show_prompt) > 0 + assert '"type" : "TEXT"' in show_prompt + + +@pytest.mark.anyio +async def test_1706_run_sql(async_generate_profile): + """run_sql returns DataFrame""" + logger.info("Validating async run_sql returns DataFrame") + dataframe = await async_generate_profile.run_sql(prompt=PROMPTS[1]) + logger.debug("Response = %s", dataframe) + assert isinstance(dataframe, pd.DataFrame) + assert len(dataframe.columns) > 0 + + +@pytest.mark.anyio +async def test_1707_chat(async_generate_profile): + """chat returns text response""" + logger.info("Validating async chat returns text response") + response = await async_generate_profile.chat(prompt="What is OCI ?") + logger.debug("Response = %s", response) + assert isinstance(response, str) + assert len(response) > 0 + assert "Oracle Cloud Infrastructure" in response + + +@pytest.mark.anyio +async def test_1708_narrate(async_generate_profile): + """narrate returns narrative text""" + logger.info("Validating async narrate returns narrative text") + for prompt in PROMPTS[1:0]: + narration = await async_generate_profile.narrate(prompt=prompt) + logger.info("Response = %s", narration) + assert isinstance(narration, str) + assert len(narration) > 0 + assert "in the database" in narration + + +@pytest.mark.anyio +async def test_1709_chat_session(async_generate_profile): + """chat_session provides a session context""" + logger.info("Validating async chat_session context manager") + conversation = AsyncConversation(attributes=ConversationAttributes()) + async with async_generate_profile.chat_session( + conversation=conversation, delete=True + ) as session: + assert session is not None + + +@pytest.mark.anyio +async def test_1710_explain_sql(async_generate_profile): + """explain_sql returns explanation text""" + logger.info("Validating async explain_sql returns explanation text") + for prompt in PROMPTS: + explain_sql = await async_generate_profile.explain_sql(prompt=prompt) + logger.debug("Response = %s", explain_sql) + assert isinstance(explain_sql, str) + assert len(explain_sql) > 0 + + +@pytest.mark.anyio +async def test_1711_generate_runsql(async_generate_profile): + """generate with RUNSQL returns DataFrame""" + logger.info("Validating async generate with RUNSQL returns DataFrame") + dataframe = await async_generate_profile.generate( + prompt=PROMPTS[1], action=Action.RUNSQL + ) + logger.debug("Response = %s", dataframe) + assert isinstance(dataframe, pd.DataFrame) + + +@pytest.mark.anyio +async def test_1712_generate_showsql(async_generate_profile): + """generate with SHOWSQL returns SQL""" + logger.info("Validating async generate with SHOWSQL returns SQL") + sql = await async_generate_profile.generate( + prompt=PROMPTS[1], action=Action.SHOWSQL + ) + logger.debug("Response = %s", sql) + assert isinstance(sql, str) + assert "SELECT" in sql.upper() + + +@pytest.mark.anyio +async def test_1713_generate_chat(async_generate_profile): + """generate with CHAT returns response""" + logger.info("Validating async generate with CHAT returns response") + chat_response = await async_generate_profile.generate( + prompt="Tell me about OCI", action=Action.CHAT + ) + logger.debug("Response = %s", chat_response) + assert isinstance(chat_response, str) + assert len(chat_response) > 0 + assert "Oracle Cloud Infrastructure" in chat_response + + +@pytest.mark.anyio +async def test_1714_generate_narrate(async_generate_profile): + """generate with NARRATE returns response""" + logger.info("Validating async generate with NARRATE returns response") + narrate_response = await async_generate_profile.generate( + prompt=PROMPTS[1], action=Action.NARRATE + ) + logger.debug("Response = %s", narrate_response) + assert isinstance(narrate_response, str) + assert len(narrate_response) > 0 + assert "in the database" in narrate_response + + +@pytest.mark.anyio +async def test_1715_generate_explainsql(async_generate_profile): + """generate with EXPLAINSQL returns explanation""" + logger.info( + "Validating async generate with EXPLAINSQL returns explanation" + ) + for prompt in PROMPTS: + explain_sql = await async_generate_profile.generate( + prompt=prompt, action=Action.EXPLAINSQL + ) + logger.debug("Response = %s", explain_sql) + assert isinstance(explain_sql, str) + assert len(explain_sql) > 0 + + +@pytest.mark.anyio +async def test_1716_empty_prompt_raises_value_error(async_negative_profile): + """Empty prompts raise ValueError for async profile methods""" + logger.info("Validating async empty prompts raise ValueError") + with pytest.raises(ValueError): + await async_negative_profile.chat(prompt="") + with pytest.raises(ValueError): + await async_negative_profile.narrate(prompt="") + with pytest.raises(ValueError): + await async_negative_profile.show_sql(prompt="") + with pytest.raises(ValueError): + await async_negative_profile.show_prompt(prompt="") + with pytest.raises(ValueError): + await async_negative_profile.run_sql(prompt="") + with pytest.raises(ValueError): + await async_negative_profile.explain_sql(prompt="") + + +@pytest.mark.anyio +async def test_1717_none_prompt_raises_value_error(async_negative_profile): + """None prompts raise ValueError for async profile methods""" + logger.info("Validating async None prompts raise ValueError") + with pytest.raises(ValueError): + await async_negative_profile.chat(prompt=None) + with pytest.raises(ValueError): + await async_negative_profile.narrate(prompt=None) + with pytest.raises(ValueError): + await async_negative_profile.show_sql(prompt=None) + with pytest.raises(ValueError): + await async_negative_profile.show_prompt(prompt=None) + with pytest.raises(ValueError): + await async_negative_profile.run_sql(prompt=None) + with pytest.raises(ValueError): + await async_negative_profile.explain_sql(prompt=None) + + +# @pytest.mark.anyio +# async def test_1718_run_sql_with_ambiguous_prompt(async_negative_profile): +# """Ambiguous prompt raises DatabaseError for run_sql""" +# with pytest.raises(oracledb.DatabaseError): +# await async_negative_profile.run_sql(prompt="select from user") + + +# @pytest.mark.anyio +# async def test_1719_run_sql_with_invalid_object_list(async_negative_profile, test_env): +# """run_sql with non existent table raises DatabaseError""" +# await async_negative_profile.set_attribute( +# attribute_name="object_list", +# attribute_value=json.dumps( +# [{"owner": test_env.test_user, "name": "non_existent_table"}] +# ), +# ) +# with pytest.raises(oracledb.DatabaseError): +# await async_negative_profile.run_sql(prompt="How many entries in the table") diff --git a/tests/profiles/test_1800_chat_session.py b/tests/profiles/test_1800_chat_session.py new file mode 100644 index 0000000..4fcbadd --- /dev/null +++ b/tests/profiles/test_1800_chat_session.py @@ -0,0 +1,279 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +""" +1800 - Chat session API tests +""" + +import logging +import uuid + +import pytest +import select_ai +from select_ai import ( + Conversation, + ConversationAttributes, + Profile, + ProfileAttributes, +) + +logger = logging.getLogger(__name__) + +PROFILE_PREFIX = f"PYSAI_1800_{uuid.uuid4().hex.upper()}" + +CATEGORY_PROMPTS = { + "database": [ + ("What is a database?", "database"), + ("Explain the difference between SQL and NoSQL.", "sql"), + ("Give me an example of a SQL SELECT query.", "select"), + ("How do transactions ensure consistency?", "transaction"), + ("What are indexes and why are they used?", "index"), + ], + "cloud": [ + ("What is cloud computing?", "cloud"), + ("Explain IaaS, PaaS, and SaaS briefly.", "iaas"), + ("What is the benefit of auto-scaling?", "scaling"), + ("How do cloud regions and availability zones differ?", "region"), + ("What is serverless computing?", "serverless"), + ], + "ai": [ + ("What is artificial intelligence?", "intelligence"), + ("Explain supervised vs unsupervised learning.", "supervised"), + ("What are neural networks?", "neural"), + ("How does reinforcement learning work?", "reinforcement"), + ("Give me a real-world use case of AI.", "ai"), + ], + "physics": [ + ("What is Newton's first law?", "newton"), + ("Explain the concept of gravity.", "gravity"), + ("How does friction affect motion?", "friction"), + ("What is the difference between speed and velocity?", "velocity"), + ("Explain kinetic and potential energy with examples.", "energy"), + ], + "general": [ + ("What is the capital of Japan?", "tokyo"), + ("Tell me a fun fact about space.", "space"), + ("Who invented the telephone?", "telephone"), + ("What is the fastest land animal?", "cheetah"), + ("Explain why the sky looks blue.", "sky"), + ], +} + + +@pytest.fixture(scope="module") +def chat_session_provider(oci_compartment_id): + return select_ai.OCIGenAIProvider( + oci_compartment_id=oci_compartment_id, + oci_apiformat="GENERIC", + ) + + +@pytest.fixture(scope="module") +def chat_session_profile(oci_credential, chat_session_provider): + logger.info( + "Creating chat session profile %s", f"{PROFILE_PREFIX}_PROFILE" + ) + profile = Profile( + profile_name=f"{PROFILE_PREFIX}_PROFILE", + attributes=ProfileAttributes( + credential_name=oci_credential["credential_name"], + object_list=[ + {"owner": "ADMIN", "name": "people"}, + {"owner": "ADMIN", "name": "gymnast"}, + ], + provider=chat_session_provider, + ), + description="Chat session test profile", + replace=True, + ) + profile.set_attribute( + attribute_name="model", + attribute_value="meta.llama-3.1-405b-instruct", + ) + yield profile + logger.info("Deleting chat session profile %s", profile.profile_name) + profile.delete(force=True) + + +@pytest.fixture +def conversation_factory(): + conversations = [] + + def _create(**kwargs): + logger.info("Creating conversation with params %s", kwargs) + conversation = Conversation( + attributes=ConversationAttributes(**kwargs) + ) + conversation.create() + conversations.append(conversation) + return conversation + + yield _create + + for conversation in conversations: + logger.info("Deleting conversation %s", conversation.conversation_id) + conversation.delete(force=True) + + +def _assert_keywords(session, prompts): + for prompt, keyword in prompts: + response = session.chat(prompt=prompt) + logger.debug("Received response for prompt '%s': %s", prompt, response) + assert keyword.lower() in response.lower() + + +def test_1800_database_chat_session( + chat_session_profile, conversation_factory +): + """Chat session processes database prompts""" + logger.info("Starting database chat session test") + conversation = conversation_factory( + title="Database", + description="LLM's understanding of databases", + ) + with chat_session_profile.chat_session( + conversation=conversation, delete=False + ) as session: + logger.info( + "Chat session started with conversation %s", + conversation.conversation_id, + ) + assert session is not None + _assert_keywords(session, CATEGORY_PROMPTS["database"]) + + +def test_1801_physics_chat_session_delete_true( + chat_session_profile, conversation_factory +): + """Chat session deletes conversation when delete=True""" + logger.info("Starting physics chat session with delete=True") + conversation = conversation_factory(title="Physics") + with chat_session_profile.chat_session( + conversation=conversation, delete=True + ) as session: + logger.info( + "Chat session started for conversation %s with delete=True", + conversation.conversation_id, + ) + _assert_keywords(session, CATEGORY_PROMPTS["physics"]) + with pytest.raises(Exception): + conversation.delete() + + +def test_1802_multiple_sessions_same_conversation( + chat_session_profile, conversation_factory +): + """Same conversation supports multiple chat sessions""" + logger.info("Validating multiple sessions for same conversation") + conversation = conversation_factory( + title="Cloud Two Session", + description="LLM's understanding of cloud using multiple chat sessions.", + ) + with chat_session_profile.chat_session( + conversation=conversation + ) as session_one: + logger.info( + "First session started for conversation %s", + conversation.conversation_id, + ) + _assert_keywords(session_one, CATEGORY_PROMPTS["cloud"][:3]) + with chat_session_profile.chat_session( + conversation=conversation + ) as session_two: + logger.info( + "Second session started for conversation %s", + conversation.conversation_id, + ) + _assert_keywords(session_two, CATEGORY_PROMPTS["cloud"][3:]) + + +def test_1803_many_sessions_same_conversation( + chat_session_profile, conversation_factory +): + """Conversation reused across several sessions""" + logger.info("Validating many sessions for same conversation") + conversation = conversation_factory( + title="Multi Session", + description="LLM's understanding of cloud using multiple chat sessions.", + ) + with chat_session_profile.chat_session( + conversation=conversation, delete=False + ) as session_one: + logger.info( + "Session one started for conversation %s", + conversation.conversation_id, + ) + _assert_keywords(session_one, CATEGORY_PROMPTS["cloud"][:3]) + with chat_session_profile.chat_session( + conversation=conversation, delete=False + ) as session_two: + logger.info( + "Session two started for conversation %s", + conversation.conversation_id, + ) + _assert_keywords(session_two, CATEGORY_PROMPTS["cloud"][3:]) + with chat_session_profile.chat_session( + conversation=conversation, delete=False + ) as session_three: + logger.info( + "Session three started for conversation %s", + conversation.conversation_id, + ) + _assert_keywords(session_three, CATEGORY_PROMPTS["ai"][:3]) + with chat_session_profile.chat_session( + conversation=conversation, delete=False + ) as session_four: + logger.info( + "Session four started for conversation %s", + conversation.conversation_id, + ) + _assert_keywords(session_four, CATEGORY_PROMPTS["ai"][3:]) + with chat_session_profile.chat_session( + conversation=conversation, delete=False + ) as session_five: + logger.info( + "Session five started for conversation %s", + conversation.conversation_id, + ) + _assert_keywords(session_five, CATEGORY_PROMPTS["general"]) + + +def test_1804_special_characters(chat_session_profile, conversation_factory): + """Chat session handles special characters""" + logger.info("Validating special character handling in chat session") + conversation = conversation_factory( + title="Special Character Test ✨😊你", + description="♥️✨你好", + ) + with chat_session_profile.chat_session( + conversation=conversation, delete=True + ) as session: + logger.info( + "Chat session started for special character conversation %s", + conversation.conversation_id, + ) + response = session.chat( + prompt="Tell me something with lot of emojis and special characters 🚀🔥" + ) + assert isinstance(response, str) + assert "error" not in response.lower() + + +def test_1805_invalid_conversation_object(chat_session_profile): + """Passing non conversation object raises error""" + logger.info("Validating invalid conversation object handling") + with pytest.raises(Exception): + with chat_session_profile.chat_session(conversation="fake-object"): + pass + + +# def test_1806_missing_conversation_attributes(chat_session_profile): +# """Conversation without attributes raises error""" +# conversation = Conversation(attributes=None) +# with pytest.raises(Exception): +# with chat_session_profile.chat_session(conversation=conversation): +# _assert_keywords(chat_session_profile, [("Hello World", "hello")]) diff --git a/tests/profiles/test_1900_chat_session_async.py b/tests/profiles/test_1900_chat_session_async.py new file mode 100644 index 0000000..1450eac --- /dev/null +++ b/tests/profiles/test_1900_chat_session_async.py @@ -0,0 +1,297 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +""" +1900 - Async chat session API tests +""" + +import logging +import uuid + +import pytest +import select_ai +from select_ai import ( + AsyncConversation, + AsyncProfile, + ConversationAttributes, + ProfileAttributes, +) + +logger = logging.getLogger(__name__) + +PROFILE_PREFIX = f"PYSAI_1900_{uuid.uuid4().hex.upper()}" + +CATEGORY_PROMPTS = { + "database": [ + ("What is a database?", "database"), + ("Explain the difference between SQL and NoSQL.", "sql"), + ("Give me an example of a SQL SELECT query.", "select"), + ("How do transactions ensure consistency?", "transaction"), + ("What are indexes and why are they used?", "index"), + ], + "cloud": [ + ("What is cloud computing?", "cloud"), + ("Explain IaaS, PaaS, and SaaS briefly.", "iaas"), + ("What is the benefit of auto-scaling?", "scaling"), + ("How do cloud regions and availability zones differ?", "region"), + ("What is serverless computing?", "serverless"), + ], + "ai": [ + ("What is artificial intelligence?", "intelligence"), + ("Explain supervised vs unsupervised learning.", "supervised"), + ("What are neural networks?", "neural"), + ("How does reinforcement learning work?", "reinforcement"), + ("Give me a real-world use case of AI.", "ai"), + ], + "physics": [ + ("What is Newton's first law?", "newton"), + ("Explain the concept of gravity.", "gravity"), + ("How does friction affect motion?", "friction"), + ("What is the difference between speed and velocity?", "velocity"), + ("Explain kinetic and potential energy with examples.", "energy"), + ], + "general": [ + ("What is the capital of Japan?", "tokyo"), + ("Tell me a fun fact about space.", "space"), + ("Who invented the telephone?", "telephone"), + ("What is the fastest land animal?", "cheetah"), + ("Explain why the sky looks blue.", "sky"), + ], +} + + +@pytest.fixture(scope="module") +def async_chat_session_provider(oci_compartment_id): + return select_ai.OCIGenAIProvider( + oci_compartment_id=oci_compartment_id, + oci_apiformat="GENERIC", + ) + + +@pytest.fixture(scope="module") +async def async_chat_session_profile( + oci_credential, async_chat_session_provider +): + logger.info( + "Creating async chat session profile %s", f"{PROFILE_PREFIX}_PROFILE" + ) + profile = await AsyncProfile( + profile_name=f"{PROFILE_PREFIX}_PROFILE", + attributes=ProfileAttributes( + credential_name=oci_credential["credential_name"], + object_list=[ + {"owner": "ADMIN", "name": "people"}, + {"owner": "ADMIN", "name": "gymnast"}, + ], + provider=async_chat_session_provider, + ), + description="Async chat session test profile", + replace=True, + ) + await profile.set_attribute( + attribute_name="model", + attribute_value="meta.llama-3.1-405b-instruct", + ) + yield profile + logger.info("Deleting async chat session profile %s", profile.profile_name) + await profile.delete(force=True) + + +@pytest.fixture +async def async_conversation_factory(): + conversations = [] + + async def _create(**kwargs): + logger.info("Creating async conversation with params %s", kwargs) + conversation = AsyncConversation( + attributes=ConversationAttributes(**kwargs) + ) + await conversation.create() + conversations.append(conversation) + return conversation + + yield _create + + for conversation in conversations: + logger.info( + "Deleting async conversation %s", conversation.conversation_id + ) + await conversation.delete(force=True) + + +async def _assert_keywords(session, prompts): + for prompt, keyword in prompts: + response = await session.chat(prompt=prompt) + logger.debug("Async response for prompt '%s': %s", prompt, response) + assert keyword.lower() in response.lower() + + +@pytest.mark.anyio +async def test_1900_database_chat_session( + async_chat_session_profile, async_conversation_factory +): + """Async chat session processes database prompts""" + logger.info("Starting async database chat session test") + conversation = await async_conversation_factory( + title="Database", + description="LLM's understanding of databases", + ) + async with async_chat_session_profile.chat_session( + conversation=conversation, delete=False + ) as session: + logger.info( + "Async chat session started with conversation %s", + conversation.conversation_id, + ) + assert session is not None + await _assert_keywords(session, CATEGORY_PROMPTS["database"]) + + +@pytest.mark.anyio +async def test_1901_physics_chat_session_delete_true( + async_chat_session_profile, async_conversation_factory +): + """Async chat session deletes conversation when delete=True""" + logger.info("Starting async physics chat session with delete=True") + conversation = await async_conversation_factory(title="Physics") + async with async_chat_session_profile.chat_session( + conversation=conversation, delete=True + ) as session: + logger.info( + "Async chat session started for conversation %s with delete=True", + conversation.conversation_id, + ) + await _assert_keywords(session, CATEGORY_PROMPTS["physics"]) + with pytest.raises(Exception): + await conversation.delete() + + +@pytest.mark.anyio +async def test_1902_multiple_sessions_same_conversation( + async_chat_session_profile, async_conversation_factory +): + """Same async conversation supports multiple chat sessions""" + logger.info("Validating multiple async sessions for same conversation") + conversation = await async_conversation_factory( + title="Cloud Two Session", + description="LLM's understanding of cloud using multiple chat sessions.", + ) + async with async_chat_session_profile.chat_session( + conversation=conversation + ) as session_one: + logger.info( + "Async session one started for conversation %s", + conversation.conversation_id, + ) + await _assert_keywords(session_one, CATEGORY_PROMPTS["cloud"][:3]) + async with async_chat_session_profile.chat_session( + conversation=conversation + ) as session_two: + logger.info( + "Async session two started for conversation %s", + conversation.conversation_id, + ) + await _assert_keywords(session_two, CATEGORY_PROMPTS["cloud"][3:]) + + +@pytest.mark.anyio +async def test_1903_many_sessions_same_conversation( + async_chat_session_profile, async_conversation_factory +): + """Conversation reused across several async sessions""" + logger.info("Validating many async sessions for same conversation") + conversation = await async_conversation_factory( + title="Multi Session", + description="LLM's understanding of cloud using multiple chat sessions.", + ) + async with async_chat_session_profile.chat_session( + conversation=conversation, delete=False + ) as session_one: + logger.info( + "Async session one started for conversation %s", + conversation.conversation_id, + ) + await _assert_keywords(session_one, CATEGORY_PROMPTS["cloud"][:3]) + async with async_chat_session_profile.chat_session( + conversation=conversation, delete=False + ) as session_two: + logger.info( + "Async session two started for conversation %s", + conversation.conversation_id, + ) + await _assert_keywords(session_two, CATEGORY_PROMPTS["cloud"][3:]) + async with async_chat_session_profile.chat_session( + conversation=conversation, delete=False + ) as session_three: + logger.info( + "Async session three started for conversation %s", + conversation.conversation_id, + ) + await _assert_keywords(session_three, CATEGORY_PROMPTS["ai"][:3]) + async with async_chat_session_profile.chat_session( + conversation=conversation, delete=False + ) as session_four: + logger.info( + "Async session four started for conversation %s", + conversation.conversation_id, + ) + await _assert_keywords(session_four, CATEGORY_PROMPTS["ai"][3:]) + async with async_chat_session_profile.chat_session( + conversation=conversation, delete=False + ) as session_five: + logger.info( + "Async session five started for conversation %s", + conversation.conversation_id, + ) + await _assert_keywords(session_five, CATEGORY_PROMPTS["general"]) + + +@pytest.mark.anyio +async def test_1904_special_characters( + async_chat_session_profile, async_conversation_factory +): + """Async chat session handles special characters""" + logger.info("Validating async special character handling in chat session") + conversation = await async_conversation_factory( + title="Special Character Test ✨😊你", + description="♥️✨你好", + ) + async with async_chat_session_profile.chat_session( + conversation=conversation, delete=True + ) as session: + logger.info( + "Async chat session started for special character conversation %s", + conversation.conversation_id, + ) + response = await session.chat( + prompt="Tell me something with lot of emojis and special characters 🚀🔥" + ) + assert isinstance(response, str) + assert "error" not in response.lower() + + +@pytest.mark.anyio +async def test_1905_invalid_conversation_object(async_chat_session_profile): + """Passing non conversation object raises error""" + with pytest.raises(Exception): + async with async_chat_session_profile.chat_session( + conversation="fake-object" + ): + pass + + +@pytest.mark.anyio +async def test_1906_missing_conversation_attributes( + async_chat_session_profile, +): + """Conversation without attributes raises error""" + conversation = AsyncConversation(attributes=None) + with pytest.raises(Exception): + async with async_chat_session_profile.chat_session( + conversation=conversation + ): + await conversation.chat(prompt="Hello World")