Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion test/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,7 +753,8 @@ def test_ydb_credentials_bad(self, query_client_settings, driver_config_for_cred
with pytest.raises(Exception) as excinfo:
with engine.connect() as conn:
conn.execute(sa.text("SELECT 1 as value"))
assert "Invalid password" in str(excinfo.value)
error_message = str(excinfo.value)
assert "Invalid password" in error_message or "StaticCredentials" in error_message


class TestUpsert(TablesTest):
Expand Down
47 changes: 47 additions & 0 deletions test/test_inspect.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import posixpath

import pytest
import sqlalchemy as sa
from sqlalchemy import Column, Integer, Numeric, Table, Unicode
from sqlalchemy.testing.fixtures import TablesTest
Expand All @@ -14,6 +17,31 @@ def define_tables(cls, metadata):
Column("num", Numeric(22, 9)),
)

@pytest.fixture
def test_view(self, connection):
raw_connection = connection.connection
driver_connection = getattr(raw_connection, "driver_connection", raw_connection)
view_name = "test_view"
table_path = posixpath.join(driver_connection.database, driver_connection.table_path_prefix, "test")
cursor = driver_connection.cursor()
try:
try:
cursor.execute_scheme(f"DROP VIEW `{view_name}`")
except Exception:
pass

cursor.execute_scheme(
f"CREATE VIEW `{view_name}` WITH (security_invoker = TRUE) AS "
f"SELECT `id`, `value`, `num` FROM `{table_path}`"
)
yield view_name
finally:
try:
cursor.execute_scheme(f"DROP VIEW `{view_name}`")
except Exception:
pass
cursor.close()
Comment on lines +38 to +43
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (non-blocking): Move Try and Except part with creation of the view to a yield fixture.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed


def test_get_columns(self, connection):
inspect = sa.inspect(connection)

Expand All @@ -32,3 +60,22 @@ def test_has_table(self, connection):

assert inspect.has_table("test")
assert not inspect.has_table("foo")

def test_view_reflection(self, connection, test_view):
view_name = test_view
inspect = sa.inspect(connection)

assert view_name in inspect.get_view_names()
assert inspect.has_table(view_name)
assert inspect.get_view_definition(view_name).startswith(f"CREATE VIEW `{view_name}`")

columns = {column["name"]: column for column in inspect.get_columns(view_name)}
assert set(columns) == {"id", "value", "num"}
assert isinstance(columns["id"]["type"], sa.INTEGER)
assert columns["id"]["nullable"] is False
assert isinstance(columns["value"]["type"], sa.TEXT)
assert columns["value"]["nullable"] is True
assert isinstance(columns["num"]["type"], sa.DECIMAL)
assert columns["num"]["type"].precision == 22
assert columns["num"]["type"].scale == 9
assert columns["num"]["nullable"] is True
73 changes: 60 additions & 13 deletions ydb_sqlalchemy/sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import collections
import collections.abc
import re
from typing import Any, Mapping, Optional, Sequence, Tuple, Union

import sqlalchemy as sa
Expand Down Expand Up @@ -73,6 +74,13 @@ def upsert(table):
ydb.PrimitiveType.DyNumber: sa.TEXT,
}

DBAPI_COLUMN_TYPES = {
ydb_type.name: sa_type for ydb_type, sa_type in COLUMN_TYPES.items() if isinstance(ydb_type, ydb.PrimitiveType)
}


DECIMAL_DBAPI_TYPE_RE = re.compile(r"^Decimal\((\d+),\s*(\d+)\)$")


def _get_column_info(t):
nullable = False
Expand All @@ -86,6 +94,28 @@ def _get_column_info(t):
return COLUMN_TYPES[t], nullable


def _get_column_info_from_dbapi_description(type_name):
nullable = type_name.endswith("?")
if nullable:
type_name = type_name[:-1]

decimal_match = DECIMAL_DBAPI_TYPE_RE.match(type_name)
if decimal_match:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

question: What about Array and other container types? They could be in View columns, don't they?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the problem is same for tables, let's fix this outside of this pr

precision, scale = decimal_match.groups()
return sa.DECIMAL(precision=int(precision), scale=int(scale)), nullable

return DBAPI_COLUMN_TYPES.get(type_name, sa.types.NullType), nullable


def _format_reflected_column(name, col_type, nullable):
return {
"name": name,
"type": col_type,
"nullable": nullable,
"default": None,
}


class YdbRequestSettingsCharacteristic(characteristics.ConnectionCharacteristic):
def reset_characteristic(self, dialect: "YqlDialect", dbapi_connection: ydb_dbapi.Connection) -> None:
dialect.reset_ydb_request_settings(dbapi_connection)
Expand Down Expand Up @@ -220,41 +250,58 @@ def __init__(
self._add_declare_for_yql_stmt_vars = _add_declare_for_yql_stmt_vars
self._statement_prefixes = tuple(_statement_prefixes_list) if _statement_prefixes_list else ()

def _describe_table(self, connection, table_name, schema=None) -> ydb.TableDescription:
if schema is not None:
def _ensure_schema_unsupported(self, schema):
if schema:
raise ydb_dbapi.NotSupportedError("unsupported on non empty schema")

def _describe_table(self, connection, table_name, schema=None) -> ydb.TableDescription:
self._ensure_schema_unsupported(schema)

qt = table_name if isinstance(table_name, str) else table_name.name
raw_conn = connection.connection
try:
return raw_conn.describe(qt)
except ydb_dbapi.DatabaseError as e:
raise NoSuchTableError(qt) from e

@reflection.cache
def get_view_names(self, connection, schema=None, **kw: Any):
return []
self._ensure_schema_unsupported(schema)

raw_conn = connection.connection
return raw_conn.get_view_names()

@reflection.cache
def get_view_definition(self, connection, view_name, schema=None, **kw: Any):
self._ensure_schema_unsupported(schema)

quoted_view_name = self.identifier_preparer.quote(view_name)
result = connection.execute(sa.text(f"SHOW CREATE VIEW {quoted_view_name}"))
row = result.fetchone()
if row is None:
return None
return row._mapping.get("CreateQuery") or row[0]

@reflection.cache
def get_columns(self, connection, table_name, schema=None, **kw):
table = self._describe_table(connection, table_name, schema)
as_compatible = []
for column in table.columns:
col_type, nullable = _get_column_info(column.type)
as_compatible.append(
{
"name": column.name,
"type": col_type,
"nullable": nullable,
"default": None,
}
)
as_compatible.append(_format_reflected_column(column.name, col_type, nullable))

if not as_compatible:
quoted_table_name = self.identifier_preparer.quote(table_name)
result = connection.execute(sa.text(f"SELECT * FROM {quoted_table_name} LIMIT 0"))
for column in result.cursor.description or []:
col_type, nullable = _get_column_info_from_dbapi_description(column[1])
as_compatible.append(_format_reflected_column(column[0], col_type, nullable))

return as_compatible

@reflection.cache
def get_table_names(self, connection, schema=None, **kw):
if schema:
raise ydb_dbapi.NotSupportedError("unsupported on non empty schema")
self._ensure_schema_unsupported(schema)

raw_conn = connection.connection
return raw_conn.get_table_names()
Expand Down
3 changes: 3 additions & 0 deletions ydb_sqlalchemy/sqlalchemy/dbapi_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ def check_exists(self, table_path: str):
def get_table_names(self):
return await_only(self._connection.get_table_names())

def get_view_names(self):
return await_only(self._connection.get_view_names())


# TODO(vgvoleg): Migrate to AsyncAdapt_dbapi_cursor and AsyncAdapt_dbapi_connection
class AdaptedAsyncCursor:
Expand Down
Loading