Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .env.sample
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
# store checkpoint in redis instead of on filesystem
# REDIS_CHECKPOINT=False
# FORMAT_WITH_COMMAS=True
# TRIGGER_ALL_COLUMNS=True

# Elasticsearch/Opensearch
# ELASTICSEARCH_SCHEME=http
Expand Down
2 changes: 1 addition & 1 deletion pgsync/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@

__author__ = "Tolu Aina"
__email__ = "tolu@pgsync.com"
__version__ = "4.2.0"
__version__ = "4.3.0-beta.1"
30 changes: 25 additions & 5 deletions pgsync/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,30 +805,50 @@ def create_triggers(
self,
schema: str,
tables: t.Optional[t.List[str]] = None,
table_columns: t.Optional[t.Dict[str, t.List[str]]] = None,
join_queries: bool = False,
if_not_exists: bool = False,
) -> None:
"""Create a database triggers."""
"""Create database triggers with optional column filtering for updates."""
queries: t.List[str] = []
table_columns = table_columns or {}

for table in self.tables(schema):
if (tables and table not in tables) or (
table in self.views(schema)
):
if (tables and table not in tables) or (table in self.views(schema)):
continue

logger.debug(f"Creating trigger on table: {schema}.{table}")

# Get specific columns for this table (empty list or None means no filtering)
specific_cols = table_columns.get(table, [])

for name, level, tg_op in [
("notify", "ROW", ["INSERT", "UPDATE", "DELETE"]),
("truncate", "STATEMENT", ["TRUNCATE"]),
]:

# If we have specific columns and this is the UPDATE operation,
# modify the UPDATE event to UPDATE OF col1, col2, ...
if tg_op == ["INSERT", "UPDATE", "DELETE"] and specific_cols:
col_names = list(
dict.fromkeys( # preserves order while removing duplicates
col.name if hasattr(col, "name") else str(col)
for col in specific_cols
)
)
update_event = "UPDATE OF " + ", ".join(col_names)
events = ["INSERT", update_event, "DELETE"]
else:
events = tg_op

if if_not_exists or not self.view_exists(
MATERIALIZED_VIEW, schema
):

self.drop_triggers(schema, [table])
queries.append(
f'CREATE TRIGGER "{schema}_{table}_{name}" '
f'AFTER {" OR ".join(tg_op)} ON "{schema}"."{table}" '
f'AFTER {" OR ".join(events)} ON "{schema}"."{table}" '
f"FOR EACH {level} EXECUTE PROCEDURE "
f"{schema}.{TRIGGER_FUNC}()",
)
Expand Down
2 changes: 2 additions & 0 deletions pgsync/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
# db query chunk size (how many records to fetch at a time)
QUERY_CHUNK_SIZE = env.int("QUERY_CHUNK_SIZE", default=10000)
FILTER_CHUNK_SIZE = env.int("FILTER_CHUNK_SIZE", default=5000)
# Whether to create triggers for all columns in a table (True) or only for columns defined in the schema (False).
TRIGGER_ALL_COLUMNS = env.bool("TRIGGER_ALL_COLUMNS", default=True)
# replication slot cleanup interval (in secs)
REPLICATION_SLOT_CLEANUP_INTERVAL = env.float(
"REPLICATION_SLOT_CLEANUP_INTERVAL",
Expand Down
86 changes: 76 additions & 10 deletions pgsync/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from collections import defaultdict
from itertools import groupby
from pathlib import Path
from typing import Dict, List, Optional

import click
import sqlalchemy as sa
Expand Down Expand Up @@ -284,6 +285,7 @@ def setup(self, no_create: bool = False) -> None:
if_not_exists: bool = not no_create

join_queries: bool = settings.JOIN_QUERIES
trigger_all_columns: bool = settings.TRIGGER_ALL_COLUMNS

with self.advisory_lock(
self.database, max_retries=None, retry_interval=0.1
Expand All @@ -302,6 +304,8 @@ def setup(self, no_create: bool = False) -> None:
# tables with user defined foreign keys
user_defined_fkey_tables: dict = {}

table_columns: Optional[Dict[str, List[str]]] = {}

for node in self.tree.traverse_breadth_first():
if node.schema != schema:
continue
Expand All @@ -314,6 +318,8 @@ def setup(self, no_create: bool = False) -> None:
tables |= set([node.table])
# we also need to bootstrap the base tables
tables |= set(node.base_tables)
cols = list(node.columns) if node.columns else []
table_columns[node.table] = cols

# we want to get both the parent and the child keys here
# even though only one of them is the foreign_key.
Expand Down Expand Up @@ -342,6 +348,7 @@ def setup(self, no_create: bool = False) -> None:
self.create_triggers(
schema,
tables=tables,
table_columns=None if trigger_all_columns else table_columns,
join_queries=join_queries,
if_not_exists=if_not_exists,
)
Expand Down Expand Up @@ -505,6 +512,44 @@ def logical_slot_changes(
upto_lsn=upto_lsn,
)

def _root_primary_key_resolver_bulk(
    self, node: Node, fields: dict, filters: list
) -> list:
    """Resolve root primary keys for a batch of rows with one search.

    ``fields`` is a mapping of column name -> list of values (built by
    the caller, presumably from the payload primary keys — confirm
    against ``_update_op``).  A single search against
    Elasticsearch/OpenSearch returns the matching document ids; each
    doc id is the root primary-key values joined by
    ``PRIMARY_KEY_DELIMITER``, so it is split back into parts and
    turned into a ``where`` dict appended to ``filters``.

    Returns the (mutated) ``filters`` list.
    """
    for doc_id in self.search_client._search(
        self.index, node.table, fields
    ):
        # str.split returns a list of key parts, not a dict —
        # the original annotation (`params: dict`) was wrong.
        params: List[str] = doc_id.split(PRIMARY_KEY_DELIMITER)
        where: dict = {}
        for i, key in enumerate(self.tree.root.model.primary_keys):
            where[key] = params[i]
        filters.append(where)

    return filters

def _root_foreign_key_resolver_bulk(
    self, node: Node, fields: dict, filters: list
) -> list:
    """Resolve root records via foreign keys for a batch of rows.

    Bulk foreign-key resolver logic: look up the foreign-key values
    (``fields`` maps column name -> list of values) in the meta section
    of Elasticsearch/OpenSearch, then take the root document ids
    returned and re-sync those root records.  Essentially, we want to
    look up the root nodes affected by our insert/update operation and
    sync the tree branch for each such root.

    Returns the (mutated) ``filters`` list.
    """
    for doc_id in self.search_client._search(
        self.index,
        node.parent.table,
        fields,
    ):
        # doc ids encode the root primary keys joined by
        # PRIMARY_KEY_DELIMITER; split yields a list (the original
        # `params: dict` annotation was incorrect).
        params: List[str] = doc_id.split(PRIMARY_KEY_DELIMITER)
        where: dict = {}
        for i, key in enumerate(self.tree.root.model.primary_keys):
            where[key] = params[i]
        filters.append(where)

    return filters

def _root_primary_key_resolver(
self, node: Node, payload: Payload, filters: list
) -> list:
Expand Down Expand Up @@ -663,7 +708,7 @@ def _update_op(
) -> dict:
if node.is_root:
# Here, we are performing two operations:
# 1) Build a filter to sync the updated record(s)
# 1) Build a filter using a bulk lookup to sync the updated record(s)
# 2) Delete the old record(s) in Elasticsearch/OpenSearch if the
# primary key has changed
# 2.1) This is crucial otherwise we can have the old and new
Expand Down Expand Up @@ -715,12 +760,18 @@ def _update_op(

else:
# update the child tables
fields: dict = defaultdict(list)
foreign_fields: dict = defaultdict(list)
for payload in payloads:
_filters: list = []
_filters = self._root_primary_key_resolver(
node, payload, _filters
primary_values: list = [
payload.data[key] for key in node.model.primary_keys
]
primary_fields: dict = dict(
zip(node.model.primary_keys, primary_values)
)
# also handle foreign_keys
for key, value in primary_fields.items():
fields[key].append(value)
# _filters: list = self._root_primary_key_resolver(node, payload, [])
if node.parent:
try:
foreign_keys = self.query_builder.get_foreign_keys(
Expand All @@ -732,13 +783,28 @@ def _update_op(
node.parent,
node,
)

foreign_values: list = [
payload.new.get(key) for key in foreign_keys[node.name]
]
for key in [key.name for key in node.primary_keys]:
for value in foreign_values:
if value:
foreign_fields[key].append(value)
# _filters = self._root_foreign_key_resolver(
# node, payload, foreign_keys, _filters
# )

_filters = self._root_foreign_key_resolver(
node, payload, foreign_keys, _filters
)

if _filters:
filters[self.tree.root.table].extend(_filters)
# if _filters:
# filters[self.tree.root.table].extend(_filters)

_filters: list = self._root_primary_key_resolver_bulk(node, fields, [])
if _filters:
filters[self.tree.root.table].extend(_filters)
_forign_filters: list = self._root_foreign_key_resolver_bulk(node, foreign_fields, _filters)
if _forign_filters:
filters[self.tree.root.table].extend(_forign_filters)

return filters

Expand Down