From 430134e95a546becab17ac4147fafc852aa01542 Mon Sep 17 00:00:00 2001 From: Jonathan Raiman Date: Sat, 16 Sep 2023 16:52:45 +0200 Subject: [PATCH] fix issue with outerjoin when there are repeated tables --- pgsync/node.py | 12 ++++++++++-- pgsync/querybuilder.py | 14 +++++++++----- tests/test_node.py | 41 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 60 insertions(+), 7 deletions(-) diff --git a/pgsync/node.py b/pgsync/node.py index ff945506..17b9bab4 100644 --- a/pgsync/node.py +++ b/pgsync/node.py @@ -103,6 +103,7 @@ class Node(object): models: Callable table: str schema: str + table_count: Optional[int] = None primary_key: Optional[list] = None label: Optional[str] = None transform: Optional[dict] = None @@ -259,6 +260,7 @@ class Tree: def __post_init__(self): self.tables: Set[str] = set() + self.table_counts: Dict[str, int] = {} self.__nodes: Dict[Node] = {} self.root: Optional[Node] = None @@ -271,7 +273,7 @@ def traverse_breadth_first(self) -> Generator: def traverse_post_order(self) -> Generator: return self.root.traverse_post_order() - def build(self, data: dict) -> Node: + def build(self, data: dict, is_root: bool = True) -> Node: if not isinstance(data, dict): raise SchemaError( "Incompatible schema. Please run v2 schema migration" @@ -302,13 +304,19 @@ def build(self, data: dict) -> Node: self.root = node self.tables.add(node.table) + if node.table not in self.table_counts: + self.table_counts[node.table] = 0 + self.table_counts[node.table] += 1 for through in node.relationship.throughs: self.tables.add(through.table) for child in data.get("children", []): - node.add_child(self.build(child)) + node.add_child(self.build(child, is_root=False)) self.__nodes[key] = node + if is_root: + for child_node in self.traverse_post_order(): + child_node.table_count = self.table_counts[child_node.table] return node def get_node(self, table: str, schema: str) -> Node: diff --git a/pgsync/querybuilder.py b/pgsync/querybuilder.py index 07d920d3..481ed2bf 100644 --- a/pgsync/querybuilder.py +++ b/pgsync/querybuilder.py @@ -396,9 +396,11 @@ def _children(self, node: Node) -> None: self.from_obj = child.parent.model if child._filters: - self.isouter = False + if child.table_count <= 1: + self.isouter = False for _filter in child._filters: if isinstance(_filter, sa.sql.elements.BinaryExpression): + for column in _filter._orig: if hasattr(column, "value"): _column = child._subquery.c @@ -546,8 +548,8 @@ def _through(self, node: Node) -> None: # noqa: C901 from_obj = node.model if child._filters: - self.isouter = False - + if child.table_count <= 1: + self.isouter = False for _filter in child._filters: if isinstance(_filter, sa.sql.elements.BinaryExpression): for column in _filter._orig: @@ -667,7 +669,8 @@ def _through(self, node: Node) -> None: # noqa: C901 ) if node._filters: - self.isouter = False + if node.table_count <= 1: + self.isouter = False op = sa.and_ if node.table == node.parent.table: @@ -723,7 +726,8 @@ def _non_through(self, node: Node) -> None: # noqa: C901 from_obj = node.model if child._filters: - self.isouter = False + if child.table_count <= 1: + self.isouter = False for _filter in child._filters: if isinstance(_filter, sa.sql.elements.BinaryExpression): diff --git a/tests/test_node.py b/tests/test_node.py index 2a193cdb..e295c0b5 100644 --- a/tests/test_node.py +++ b/tests/test_node.py @@ -340,3 +340,44 @@ def test_init(self): } ) assert "Multiple through tables" in str(excinfo.value) + + +@pytest.mark.usefixtures("table_creator") +class TestNodeTableCount(object): + """Node tests.""" + + @pytest.fixture(scope="function") + def nodes(self): + return { + "table": "book", + "columns": ["isbn", "title", "description"], + "children": [ + { + "table": "author", + "columns": ["id", "name"], + "label": "authors", + "relationship": { + "type": "one_to_many", + "variant": "object", + "through_tables": ["book_author"], + }, + "children": [ + { + "table": "book", + "columns": ["isbn", "title", "description"], + "label": "favorite_book", + "relationship": { + "type": "one_to_many", + "variant": "object", + "through_tables": ["author_favorite_book"], + }, + } + ], + }, + ], + } + + def test_tree_build_table_count(self, sync, nodes): + tree = Tree(sync.models).build(nodes) + assert tree.table_count == {"author": 1, "book": 2} + sync.search_client.close()