diff --git a/.env.sample b/.env.sample index 3b25fe51..7863e48b 100644 --- a/.env.sample +++ b/.env.sample @@ -27,6 +27,7 @@ # store checkpoint in redis instead of on filesystem # REDIS_CHECKPOINT=False # FORMAT_WITH_COMMAS=True +# TRIGGER_ALL_COLUMNS=True # Elasticsearch/Opensearch # ELASTICSEARCH_SCHEME=http diff --git a/pgsync/__init__.py b/pgsync/__init__.py index 17f4bc64..68e15c6a 100644 --- a/pgsync/__init__.py +++ b/pgsync/__init__.py @@ -2,4 +2,4 @@ __author__ = "Tolu Aina" __email__ = "tolu@pgsync.com" -__version__ = "4.2.0" +__version__ = "4.3.0-beta.1" diff --git a/pgsync/base.py b/pgsync/base.py index 94bdf089..6f5d3d56 100644 --- a/pgsync/base.py +++ b/pgsync/base.py @@ -805,22 +805,42 @@ def create_triggers( self, schema: str, tables: t.Optional[t.List[str]] = None, + table_columns: t.Optional[t.Dict[str, t.List[str]]] = None, join_queries: bool = False, if_not_exists: bool = False, ) -> None: - """Create a database triggers.""" + """Create database triggers with optional column filtering for updates.""" queries: t.List[str] = [] + table_columns = table_columns or {} + for table in self.tables(schema): - if (tables and table not in tables) or ( - table in self.views(schema) - ): + if (tables and table not in tables) or (table in self.views(schema)): continue + logger.debug(f"Creating trigger on table: {schema}.{table}") + + # Get specific columns for this table (empty list or None means no filtering) + specific_cols = table_columns.get(table, []) + for name, level, tg_op in [ ("notify", "ROW", ["INSERT", "UPDATE", "DELETE"]), ("truncate", "STATEMENT", ["TRUNCATE"]), ]: + # If we have specific columns and this is the UPDATE operation, + # modify the UPDATE event to UPDATE OF col1, col2, ... 
+ if tg_op == ["INSERT", "UPDATE", "DELETE"] and specific_cols: + col_names = list( + dict.fromkeys( # preserves order while removing duplicates + col.name if hasattr(col, "name") else str(col) + for col in specific_cols + ) + ) + update_event = "UPDATE OF " + ", ".join(col_names) + events = ["INSERT", update_event, "DELETE"] + else: + events = tg_op + if if_not_exists or not self.view_exists( MATERIALIZED_VIEW, schema ): @@ -828,7 +848,7 @@ def create_triggers( self.drop_triggers(schema, [table]) queries.append( f'CREATE TRIGGER "{schema}_{table}_{name}" ' - f'AFTER {" OR ".join(tg_op)} ON "{schema}"."{table}" ' + f'AFTER {" OR ".join(events)} ON "{schema}"."{table}" ' f"FOR EACH {level} EXECUTE PROCEDURE " f"{schema}.{TRIGGER_FUNC}()", ) diff --git a/pgsync/settings.py b/pgsync/settings.py index f05e24e5..68c1205e 100644 --- a/pgsync/settings.py +++ b/pgsync/settings.py @@ -36,6 +36,8 @@ # db query chunk size (how many records to fetch at a time) QUERY_CHUNK_SIZE = env.int("QUERY_CHUNK_SIZE", default=10000) FILTER_CHUNK_SIZE = env.int("FILTER_CHUNK_SIZE", default=5000) +# Whether to create triggers for all columns in a table (True) or only for columns defined in the schema (False). 
+TRIGGER_ALL_COLUMNS = env.bool("TRIGGER_ALL_COLUMNS", default=True) # replication slot cleanup interval (in secs) REPLICATION_SLOT_CLEANUP_INTERVAL = env.float( "REPLICATION_SLOT_CLEANUP_INTERVAL", diff --git a/pgsync/sync.py b/pgsync/sync.py index 947fe5ff..fa6023ac 100644 --- a/pgsync/sync.py +++ b/pgsync/sync.py @@ -14,6 +14,7 @@ from collections import defaultdict from itertools import groupby from pathlib import Path +from typing import Dict, List, Optional import click import sqlalchemy as sa @@ -284,6 +285,7 @@ def setup(self, no_create: bool = False) -> None: if_not_exists: bool = not no_create join_queries: bool = settings.JOIN_QUERIES + trigger_all_columns: bool = settings.TRIGGER_ALL_COLUMNS with self.advisory_lock( self.database, max_retries=None, retry_interval=0.1 @@ -302,6 +304,8 @@ def setup(self, no_create: bool = False) -> None: # tables with user defined foreign keys user_defined_fkey_tables: dict = {} + table_columns: Optional[Dict[str, List[str]]] = {} + for node in self.tree.traverse_breadth_first(): if node.schema != schema: continue @@ -314,6 +318,8 @@ def setup(self, no_create: bool = False) -> None: tables |= set([node.table]) # we also need to bootstrap the base tables tables |= set(node.base_tables) + cols = list(node.columns) if node.columns else [] + table_columns[node.table] = cols # we want to get both the parent and the child keys here # even though only one of them is the foreign_key. 
@@ -342,6 +348,7 @@ def setup(self, no_create: bool = False) -> None: self.create_triggers( schema, tables=tables, + table_columns=None if trigger_all_columns else table_columns, join_queries=join_queries, if_not_exists=if_not_exists, ) @@ -505,6 +512,44 @@ def logical_slot_changes( upto_lsn=upto_lsn, ) + def _root_primary_key_resolver_bulk( + self, node: Node, fields: dict, filters: list + ) -> list: + for doc_id in self.search_client._search( + self.index, node.table, fields + ): + where: dict = {} + params: list = doc_id.split(PRIMARY_KEY_DELIMITER) + for i, key in enumerate(self.tree.root.model.primary_keys): + where[key] = params[i] + filters.append(where) + + return filters + + def _root_foreign_key_resolver_bulk( + self, node: Node, fields: dict, filters: list + ) -> list: + """ + Bulk foreign key resolver logic: + + Lookup this value in the meta section of Elasticsearch/OpenSearch + Then get the root node returned and re-sync that root record. + Essentially, we want to lookup the root node affected by + our insert/update operation and sync the tree branch for that root. 
+ """ + for doc_id in self.search_client._search( + self.index, + node.parent.table, + fields, + ): + where: dict = {} + params: dict = doc_id.split(PRIMARY_KEY_DELIMITER) + for i, key in enumerate(self.tree.root.model.primary_keys): + where[key] = params[i] + filters.append(where) + + return filters + def _root_primary_key_resolver( self, node: Node, payload: Payload, filters: list ) -> list: @@ -663,7 +708,7 @@ def _update_op( ) -> dict: if node.is_root: # Here, we are performing two operations: - # 1) Build a filter to sync the updated record(s) + # 1) Build a filter by using Bulk operation to sync the updated record(s) # 2) Delete the old record(s) in Elasticsearch/OpenSearch if the # primary key has changed # 2.1) This is crucial otherwise we can have the old and new @@ -715,12 +760,18 @@ def _update_op( else: # update the child tables + fields: dict = defaultdict(list) + foreign_fields: dict = defaultdict(list) for payload in payloads: - _filters: list = [] - _filters = self._root_primary_key_resolver( - node, payload, _filters + primary_values: list = [ + payload.data[key] for key in node.model.primary_keys + ] + primary_fields: dict = dict( + zip(node.model.primary_keys, primary_values) ) - # also handle foreign_keys + for key, value in primary_fields.items(): + fields[key].append(value) + # _filters: list = self._root_primary_key_resolver(node, payload, []) if node.parent: try: foreign_keys = self.query_builder.get_foreign_keys( @@ -732,13 +783,28 @@ def _update_op( node.parent, node, ) + + foreign_values: list = [ + payload.new.get(key) for key in foreign_keys[node.name] + ] + for key in [key.name for key in node.primary_keys]: + for value in foreign_values: + if value: + foreign_fields[key].append(value) + # _filters = self._root_foreign_key_resolver( + # node, payload, foreign_keys, _filters + # ) - _filters = self._root_foreign_key_resolver( - node, payload, foreign_keys, _filters - ) - if _filters: - filters[self.tree.root.table].extend(_filters) + # 
if _filters: + # filters[self.tree.root.table].extend(_filters) + + _filters: list = self._root_primary_key_resolver_bulk(node, fields, []) + if _filters: + filters[self.tree.root.table].extend(_filters) + _foreign_filters: list = self._root_foreign_key_resolver_bulk(node, foreign_fields, []) + if _foreign_filters: + filters[self.tree.root.table].extend(_foreign_filters) return filters