From f036d867b6c0b1f81bc035616e7ce5e71f0af290 Mon Sep 17 00:00:00 2001 From: Daniel Knott Date: Thu, 4 Dec 2025 08:30:38 +1000 Subject: [PATCH 1/2] Fix postgresql parsing of existing function defaults The column pg_proc.proargdefaults when converted with pg_get_expr returns a comma separated string of default values. This list of values then needs to be aligned to the right most input argument. To avoid this complex rearangement, and avoid any potential edge cases where ',' might be in the default value, the function pg_get_function_arg_default is used to check for a default for each argument. --- .../dialects/postgresql/schema.py | 26 +++++++--- .../postgresql/test_function_defaults.py | 48 +++++++++++++++++++ 2 files changed, 68 insertions(+), 6 deletions(-) create mode 100644 tests/dialect/postgresql/test_function_defaults.py diff --git a/src/sqlalchemy_declarative_extensions/dialects/postgresql/schema.py b/src/sqlalchemy_declarative_extensions/dialects/postgresql/schema.py index 71f9fdb..6bf9e18 100644 --- a/src/sqlalchemy_declarative_extensions/dialects/postgresql/schema.py +++ b/src/sqlalchemy_declarative_extensions/dialects/postgresql/schema.py @@ -117,6 +117,7 @@ def char_literals(*literals: str) -> Collection[BindParameter]: column("prosecdef"), column("prokind"), column("provolatile"), + column("pronargs"), column("proargnames"), column("proargmodes"), column("proargtypes"), @@ -318,6 +319,23 @@ def get_types(arg_type_oids): ) +def get_defaults(): + arg_num = ( + func.generate_series(1, pg_proc.c.pronargs) + .table_valued("arg_num", with_ordinality="ordinality") + .alias("arg_num") + ) + return ( + select( + func.array_agg( + func.pg_get_function_arg_default(pg_proc.c.oid, arg_num.c.arg_num) + ) + ) + .select_from(arg_num) + .scalar_subquery() + ) + + def get_procedures_query(version_info): source = get_source_column(version_info) return ( @@ -334,9 +352,7 @@ def get_procedures_query(version_info): func.coalesce( get_types(pg_proc.c.proallargtypes), get_types(pg_proc.c.proargtypes) ).label("arg_types"), - func.pg_get_expr( - pg_proc.c.proargdefaults, func.cast(literal("pg_proc"), REGCLASS) - ).label("arg_defaults"), + get_defaults().label("arg_defaults"), ) .select_from( pg_proc.join(pg_namespace, pg_proc.c.pronamespace == pg_namespace.c.oid) @@ -369,9 +385,7 @@ def get_functions_query(version_info): func.coalesce( get_types(pg_proc.c.proallargtypes), get_types(pg_proc.c.proargtypes) ).label("arg_types"), - func.pg_get_expr( - pg_proc.c.proargdefaults, func.cast(literal("pg_proc"), REGCLASS) - ).label("arg_defaults"), + get_defaults().label("arg_defaults"), ) .select_from( pg_proc.join(pg_namespace, pg_proc.c.pronamespace == pg_namespace.c.oid) diff --git a/tests/dialect/postgresql/test_function_defaults.py b/tests/dialect/postgresql/test_function_defaults.py new file mode 100644 index 0000000..bfccad4 --- /dev/null +++ b/tests/dialect/postgresql/test_function_defaults.py @@ -0,0 +1,48 @@ +import pytest +from pytest_mock_resources import PostgresConfig, create_postgres_fixture +from sqlalchemy import text + +from sqlalchemy_declarative_extensions import Functions +from sqlalchemy_declarative_extensions.dialects.postgresql import ( + Function, + FunctionParam, + FunctionVolatility, +) +from sqlalchemy_declarative_extensions.function.compare import compare_functions + +pg = create_postgres_fixture(scope="function", engine_kwargs={"echo": True}) + + +@pytest.fixture +def pmr_postgres_config(): + return PostgresConfig(image="postgres:11", port=None, ci_port=None) + + +@pytest.mark.parametrize( + ("default_a", "default_b", "default_c"), + [(None, None, "''::text"), (1, 0, "'m'::text"), (None, 2, "'ft'::text")], +) +def test_function_argument_defaults(pg, default_a, default_b, default_c): + add_label_function = Function( + name="add_label", + definition=""" + BEGIN + RETURN (((a + b))::text || c); + END; + """, + parameters=[ + FunctionParam("a", "integer", default=default_a), + FunctionParam("b", "integer", default=default_b), + FunctionParam("c", "text", default=default_c), + ], + returns="TEXT", + volatility=FunctionVolatility.STABLE, + language="plpgsql", + ).normalize() + create_function = add_label_function.to_sql_create() + functions = Functions([add_label_function]) + with pg.connect() as connection: + connection.execute(text("\n".join(create_function))) + diff = compare_functions(connection, functions) + for op in diff: + assert op.from_function == op.function From a389397a2eb12d47efc6fcef682e5b287ba898a6 Mon Sep 17 00:00:00 2001 From: Dan Cardin Date: Thu, 4 Dec 2025 10:38:37 -0500 Subject: [PATCH 2/2] Update tests/dialect/postgresql/test_function_defaults.py --- tests/dialect/postgresql/test_function_defaults.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/dialect/postgresql/test_function_defaults.py b/tests/dialect/postgresql/test_function_defaults.py index bfccad4..75bef84 100644 --- a/tests/dialect/postgresql/test_function_defaults.py +++ b/tests/dialect/postgresql/test_function_defaults.py @@ -13,9 +13,6 @@ pg = create_postgres_fixture(scope="function", engine_kwargs={"echo": True}) -@pytest.fixture -def pmr_postgres_config(): - return PostgresConfig(image="postgres:11", port=None, ci_port=None) @pytest.mark.parametrize(