Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@

## 0.16

### 0.16.7

- feat: Add missing support for `exclude` on triggers.
- feat: Add support for `include` on triggers and functions.

### 0.16.6

- fix: postgresql parsing of existing function defaults.
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "sqlalchemy-declarative-extensions"
version = "0.16.6"
version = "0.16.7"
authors = [{ name = "Dan Cardin", email = "ddcardin@gmail.com" }]
description = "Library to declare additional kinds of objects not natively supported by SQLAlchemy/Alembic."
license = { file = "LICENSE" }
Expand Down
11 changes: 10 additions & 1 deletion src/sqlalchemy_declarative_extensions/function/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ class Functions:

functions: list[Function] = field(default_factory=list)

include: list[str] | None = None
ignore: list[str] = field(default_factory=list)
ignore_unspecified: bool = False

Expand Down Expand Up @@ -130,10 +131,18 @@ def extract(cls, metadata: MetaData | list[MetaData | None] | None) -> Self | No
)

functions = [s for instance in instances for s in instance.functions]
# Preserve None if all instances have include=None, otherwise combine all non-None includes
include_values = [
instance.include for instance in instances if instance.include is not None
]
include = [s for inc in include_values for s in inc] if include_values else None
ignore = [s for instance in instances for s in instance.ignore]
ignore_unspecified = instances[0].ignore_unspecified
return cls(
functions=functions, ignore_unspecified=ignore_unspecified, ignore=ignore
functions=functions,
ignore_unspecified=ignore_unspecified,
ignore=ignore,
include=include,
)

def append(self, function: Function):
Expand Down
14 changes: 11 additions & 3 deletions src/sqlalchemy_declarative_extensions/function/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ def compare_functions(connection: Connection, functions: Functions) -> list[Oper
expected_function_names = set(functions_by_name)

raw_existing_functions = get_functions(connection)
existing_functions = filter_functions(raw_existing_functions, functions.ignore)
existing_functions = filter_functions(
raw_existing_functions, exclude=functions.ignore, include=functions.include
)
existing_functions_by_name = {
f.qualified_name: f.normalize() for f in existing_functions
}
Expand Down Expand Up @@ -93,12 +95,18 @@ def compare_functions(connection: Connection, functions: Functions) -> list[Oper


def filter_functions(
functions: Sequence[Function], exclude: list[str]
functions: Sequence[Function], *, exclude: list[str], include: list[str] | None
) -> list[Function]:
return [
f
for f in functions
if not any(
if (
include is None
or any(
fnmatch.fnmatch(f.qualified_name, inclusion) for inclusion in include
)
)
and not any(
fnmatch.fnmatch(f.qualified_name, exclusion) for exclusion in exclude
)
]
15 changes: 14 additions & 1 deletion src/sqlalchemy_declarative_extensions/trigger/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ def to_sql_drop(self):
class Triggers:
triggers: list[Trigger] = field(default_factory=list)

include: list[str] | None = None
ignore: list[str] = field(default_factory=list)
ignore_unspecified: bool = False

@classmethod
Expand Down Expand Up @@ -76,8 +78,19 @@ def extract(cls, metadata: MetaData | list[MetaData | None] | None) -> Self | No
)

triggers = [s for instance in instances for s in instance.triggers]
# Preserve None if all instances have include=None, otherwise combine all non-None includes
include_values = [
instance.include for instance in instances if instance.include is not None
]
include = [s for inc in include_values for s in inc] if include_values else None
ignore = [s for instance in instances for s in instance.ignore]
ignore_unspecified = instances[0].ignore_unspecified
return cls(triggers=triggers, ignore_unspecified=ignore_unspecified)
return cls(
triggers=triggers,
ignore_unspecified=ignore_unspecified,
ignore=ignore,
include=include,
)

def append(self, trigger: Trigger):
self.triggers.append(trigger)
Expand Down
23 changes: 21 additions & 2 deletions src/sqlalchemy_declarative_extensions/trigger/compare.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

import fnmatch
from dataclasses import dataclass
from typing import Union
from typing import Sequence, Union

from sqlalchemy.engine import Connection

Expand Down Expand Up @@ -53,7 +54,11 @@ def compare_triggers(connection: Connection, triggers: Triggers) -> list[Operati
triggers_by_name = {r.name: r for r in triggers.triggers}
expected_trigger_names = set(triggers_by_name)

existing_triggers = get_triggers(connection)
raw_existing_triggers = get_triggers(connection)
existing_triggers = filter_triggers(
raw_existing_triggers, exclude=triggers.ignore, include=triggers.include
)

existing_triggers_by_name = {r.name: r for r in existing_triggers}
existing_trigger_names = set(existing_triggers_by_name)

Expand All @@ -77,3 +82,17 @@ def compare_triggers(connection: Connection, triggers: Triggers) -> list[Operati
result.append(DropTriggerOp(trigger))

return result


def filter_triggers(
triggers: Sequence[Trigger], *, exclude: list[str], include: list[str] | None
) -> list[Trigger]:
return [
t
for t in triggers
if (
include is None
or any(fnmatch.fnmatch(t.name, inclusion) for inclusion in include)
)
and not any(fnmatch.fnmatch(t.name, exclusion) for exclusion in exclude)
]
4 changes: 1 addition & 3 deletions tests/dialect/postgresql/test_function_defaults.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
from pytest_mock_resources import PostgresConfig, create_postgres_fixture
from pytest_mock_resources import create_postgres_fixture
from sqlalchemy import text

from sqlalchemy_declarative_extensions import Functions
Expand All @@ -13,8 +13,6 @@
pg = create_postgres_fixture(scope="function", engine_kwargs={"echo": True})




@pytest.mark.parametrize(
("default_a", "default_b", "default_c"),
[(None, None, "''::text"), (1, 0, "'m'::text"), (None, 2, "'ft'::text")],
Expand Down
177 changes: 177 additions & 0 deletions tests/function/test_include_exclude.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
import pytest
from pytest_mock_resources import create_postgres_fixture
from sqlalchemy import text
from sqlalchemy.exc import ProgrammingError

from sqlalchemy_declarative_extensions import (
Functions,
declarative_database,
register_sqlalchemy_events,
)
from sqlalchemy_declarative_extensions.sqlalchemy import declarative_base

_Base = declarative_base()


@declarative_database
class BaseIncludeOnly(_Base): # type: ignore
__abstract__ = True

functions = Functions(include=["test_*"])


@declarative_database
class BaseExcludeOnly(_Base): # type: ignore
__abstract__ = True

functions = Functions(ignore=["ignore_*"])


@declarative_database
class BaseIncludeAndExclude(_Base): # type: ignore
__abstract__ = True

functions = Functions(include=["test_*", "keep_*"], ignore=["*_ignore"])


register_sqlalchemy_events(BaseIncludeOnly.metadata, functions=True)
register_sqlalchemy_events(BaseExcludeOnly.metadata, functions=True)
register_sqlalchemy_events(BaseIncludeAndExclude.metadata, functions=True)

pg_include = create_postgres_fixture(engine_kwargs={"echo": True}, session=True)
pg_exclude = create_postgres_fixture(engine_kwargs={"echo": True}, session=True)
pg_both = create_postgres_fixture(engine_kwargs={"echo": True}, session=True)


def test_include_only(pg_include):
# Matches the include pattern, thus dropped because it's not declared.
pg_include.execute(
text(
"CREATE FUNCTION test_func() RETURNS INTEGER language sql as $$ select 1 $$;"
)
)
# Doesn't match the include pattern, thus kept because it's unmanaged.
pg_include.execute(
text(
"CREATE FUNCTION other_func() RETURNS INTEGER language sql as $$ select 2 $$;"
)
)
pg_include.commit()

BaseIncludeOnly.metadata.create_all(bind=pg_include.connection())
pg_include.commit()

with pytest.raises(ProgrammingError):
pg_include.execute(text("SELECT test_func()")).scalar()
pg_include.rollback()

result = pg_include.execute(text("SELECT other_func()")).scalar()
assert result == 2


def test_exclude_only(pg_exclude):
# Matches the exclude pattern, thus kept because it's being ignored.
pg_exclude.execute(
text(
"CREATE FUNCTION ignore_this() RETURNS INTEGER language sql as $$ select 1 $$;"
)
)
# Doesn't match the exclude pattern, thus dropped because it's not being ignored.
pg_exclude.execute(
text(
"CREATE FUNCTION manage_this() RETURNS INTEGER language sql as $$ select 2 $$;"
)
)
pg_exclude.commit()

BaseExcludeOnly.metadata.create_all(bind=pg_exclude.connection())
pg_exclude.commit()

result = pg_exclude.execute(text("SELECT ignore_this()")).scalar()
assert result == 1

with pytest.raises(ProgrammingError):
pg_exclude.execute(text("SELECT manage_this()")).scalar()


def test_include_and_exclude_interaction(pg_both):
"""Test the interaction between include and exclude.

A function that matches include becomes managed, but can become unmanaged if also matching the
exclude.
"""
pg_both.execute(
text(
"CREATE FUNCTION test_func() RETURNS INTEGER language sql as $$ select 1 $$;"
)
)
pg_both.execute(
text(
"CREATE FUNCTION test_ignore() RETURNS INTEGER language sql as $$ select 2 $$;"
)
)
pg_both.execute(
text(
"CREATE FUNCTION keep_this() RETURNS INTEGER language sql as $$ select 3 $$;"
)
)
pg_both.execute(
text(
"CREATE FUNCTION other_func() RETURNS INTEGER language sql as $$ select 4 $$;"
)
)

pg_both.commit()

BaseIncludeAndExclude.metadata.create_all(bind=pg_both.connection())
pg_both.commit()

with pytest.raises(ProgrammingError):
pg_both.execute(text("SELECT test_func()")).scalar()
pg_both.rollback()

result = pg_both.execute(text("SELECT test_ignore()")).scalar()
assert result == 2

with pytest.raises(ProgrammingError):
pg_both.execute(text("SELECT keep_this()")).scalar()
pg_both.rollback()

result = pg_both.execute(text("SELECT other_func()")).scalar()
assert result == 4


def test_include_with_schema_patterns(pg_include):
pg_include.execute(text("CREATE SCHEMA foo"))
pg_include.execute(text("CREATE SCHEMA bar"))

pg_include.execute(
text(
"CREATE FUNCTION test_one() RETURNS INTEGER language sql as $$ select 1 $$;"
)
)
pg_include.execute(
text(
"CREATE FUNCTION foo.test_two() RETURNS INTEGER language sql as $$ select 2 $$;"
)
)
pg_include.execute(
text(
"CREATE FUNCTION bar.other() RETURNS INTEGER language sql as $$ select 3 $$;"
)
)

pg_include.commit()

BaseIncludeOnly.metadata.create_all(bind=pg_include.connection())
pg_include.commit()

with pytest.raises(ProgrammingError):
pg_include.execute(text("SELECT test_one()")).scalar()
pg_include.rollback()

result = pg_include.execute(text("SELECT foo.test_two()")).scalar()
assert result == 2

result = pg_include.execute(text("SELECT bar.other()")).scalar()
assert result == 3
Loading
Loading