diff --git a/test/test_core.py b/test/test_core.py index f10c807..1940ec8 100644 --- a/test/test_core.py +++ b/test/test_core.py @@ -753,7 +753,8 @@ def test_ydb_credentials_bad(self, query_client_settings, driver_config_for_cred with pytest.raises(Exception) as excinfo: with engine.connect() as conn: conn.execute(sa.text("SELECT 1 as value")) - assert "Invalid password" in str(excinfo.value) + error_message = str(excinfo.value) + assert "Invalid password" in error_message or "StaticCredentials" in error_message class TestUpsert(TablesTest): diff --git a/test/test_inspect.py b/test/test_inspect.py index 0d4c9a7..ae32004 100644 --- a/test/test_inspect.py +++ b/test/test_inspect.py @@ -1,3 +1,6 @@ +import posixpath + +import pytest import sqlalchemy as sa from sqlalchemy import Column, Integer, Numeric, Table, Unicode from sqlalchemy.testing.fixtures import TablesTest @@ -14,6 +17,31 @@ def define_tables(cls, metadata): Column("num", Numeric(22, 9)), ) + @pytest.fixture + def test_view(self, connection): + raw_connection = connection.connection + driver_connection = getattr(raw_connection, "driver_connection", raw_connection) + view_name = "test_view" + table_path = posixpath.join(driver_connection.database, driver_connection.table_path_prefix, "test") + cursor = driver_connection.cursor() + try: + try: + cursor.execute_scheme(f"DROP VIEW `{view_name}`") + except Exception: + pass + + cursor.execute_scheme( + f"CREATE VIEW `{view_name}` WITH (security_invoker = TRUE) AS " + f"SELECT `id`, `value`, `num` FROM `{table_path}`" + ) + yield view_name + finally: + try: + cursor.execute_scheme(f"DROP VIEW `{view_name}`") + except Exception: + pass + cursor.close() + def test_get_columns(self, connection): inspect = sa.inspect(connection) @@ -32,3 +60,22 @@ def test_has_table(self, connection): assert inspect.has_table("test") assert not inspect.has_table("foo") + + def test_view_reflection(self, connection, test_view): + view_name = test_view + inspect = sa.inspect(connection) + + assert view_name in inspect.get_view_names() + assert inspect.has_table(view_name) + assert inspect.get_view_definition(view_name).startswith(f"CREATE VIEW `{view_name}`") + + columns = {column["name"]: column for column in inspect.get_columns(view_name)} + assert set(columns) == {"id", "value", "num"} + assert isinstance(columns["id"]["type"], sa.INTEGER) + assert columns["id"]["nullable"] is False + assert isinstance(columns["value"]["type"], sa.TEXT) + assert columns["value"]["nullable"] is True + assert isinstance(columns["num"]["type"], sa.DECIMAL) + assert columns["num"]["type"].precision == 22 + assert columns["num"]["type"].scale == 9 + assert columns["num"]["nullable"] is True diff --git a/ydb_sqlalchemy/sqlalchemy/__init__.py b/ydb_sqlalchemy/sqlalchemy/__init__.py index f5f8492..788ae18 100644 --- a/ydb_sqlalchemy/sqlalchemy/__init__.py +++ b/ydb_sqlalchemy/sqlalchemy/__init__.py @@ -5,6 +5,7 @@ import collections import collections.abc +import re from typing import Any, Mapping, Optional, Sequence, Tuple, Union import sqlalchemy as sa @@ -73,6 +74,13 @@ def upsert(table): ydb.PrimitiveType.DyNumber: sa.TEXT, } +DBAPI_COLUMN_TYPES = { + ydb_type.name: sa_type for ydb_type, sa_type in COLUMN_TYPES.items() if isinstance(ydb_type, ydb.PrimitiveType) +} + + +DECIMAL_DBAPI_TYPE_RE = re.compile(r"^Decimal\((\d+),\s*(\d+)\)$") + def _get_column_info(t): nullable = False @@ -86,6 +94,28 @@ def _get_column_info(t): return COLUMN_TYPES[t], nullable +def _get_column_info_from_dbapi_description(type_name): + nullable = type_name.endswith("?") + if nullable: + type_name = type_name[:-1] + + decimal_match = DECIMAL_DBAPI_TYPE_RE.match(type_name) + if decimal_match: + precision, scale = decimal_match.groups() + return sa.DECIMAL(precision=int(precision), scale=int(scale)), nullable + + return DBAPI_COLUMN_TYPES.get(type_name, sa.types.NullType), nullable + + +def _format_reflected_column(name, col_type, nullable): + return { + "name": name, + "type": col_type, + "nullable": nullable, + "default": None, + } + + class YdbRequestSettingsCharacteristic(characteristics.ConnectionCharacteristic): def reset_characteristic(self, dialect: "YqlDialect", dbapi_connection: ydb_dbapi.Connection) -> None: dialect.reset_ydb_request_settings(dbapi_connection) @@ -220,10 +250,13 @@ def __init__( self._add_declare_for_yql_stmt_vars = _add_declare_for_yql_stmt_vars self._statement_prefixes = tuple(_statement_prefixes_list) if _statement_prefixes_list else () - def _describe_table(self, connection, table_name, schema=None) -> ydb.TableDescription: - if schema is not None: + def _ensure_schema_unsupported(self, schema): + if schema: raise ydb_dbapi.NotSupportedError("unsupported on non empty schema") + def _describe_table(self, connection, table_name, schema=None) -> ydb.TableDescription: + self._ensure_schema_unsupported(schema) + qt = table_name if isinstance(table_name, str) else table_name.name raw_conn = connection.connection try: @@ -231,8 +264,23 @@ def _describe_table(self, connection, table_name, schema=None) -> ydb.TableDescr except ydb_dbapi.DatabaseError as e: raise NoSuchTableError(qt) from e + @reflection.cache def get_view_names(self, connection, schema=None, **kw: Any): - return [] + self._ensure_schema_unsupported(schema) + + raw_conn = connection.connection + return raw_conn.get_view_names() + + @reflection.cache + def get_view_definition(self, connection, view_name, schema=None, **kw: Any): + self._ensure_schema_unsupported(schema) + + quoted_view_name = self.identifier_preparer.quote(view_name) + result = connection.execute(sa.text(f"SHOW CREATE VIEW {quoted_view_name}")) + row = result.fetchone() + if row is None: + return None + return row._mapping.get("CreateQuery") or row[0] @reflection.cache def get_columns(self, connection, table_name, schema=None, **kw): @@ -240,21 +288,20 @@ def get_columns(self, connection, table_name, schema=None, **kw): as_compatible = [] for column in table.columns: col_type, nullable = _get_column_info(column.type) - as_compatible.append( - { - "name": column.name, - "type": col_type, - "nullable": nullable, - "default": None, - } - ) + as_compatible.append(_format_reflected_column(column.name, col_type, nullable)) + + if not as_compatible: + quoted_table_name = self.identifier_preparer.quote(table_name) + result = connection.execute(sa.text(f"SELECT * FROM {quoted_table_name} LIMIT 0")) + for column in result.cursor.description or []: + col_type, nullable = _get_column_info_from_dbapi_description(column[1]) + as_compatible.append(_format_reflected_column(column[0], col_type, nullable)) return as_compatible @reflection.cache def get_table_names(self, connection, schema=None, **kw): - if schema: - raise ydb_dbapi.NotSupportedError("unsupported on non empty schema") + self._ensure_schema_unsupported(schema) raw_conn = connection.connection return raw_conn.get_table_names() diff --git a/ydb_sqlalchemy/sqlalchemy/dbapi_adapter.py b/ydb_sqlalchemy/sqlalchemy/dbapi_adapter.py index 068fd43..88178b0 100644 --- a/ydb_sqlalchemy/sqlalchemy/dbapi_adapter.py +++ b/ydb_sqlalchemy/sqlalchemy/dbapi_adapter.py @@ -71,6 +71,9 @@ def check_exists(self, table_path: str): def get_table_names(self): return await_only(self._connection.get_table_names()) + def get_view_names(self): + return await_only(self._connection.get_view_names()) + # TODO(vgvoleg): Migrate to AsyncAdapt_dbapi_cursor and AsyncAdapt_dbapi_connection class AdaptedAsyncCursor: