diff --git a/examples/book/benchmark.py b/examples/book/benchmark.py index f804d780..0dd5f1aa 100644 --- a/examples/book/benchmark.py +++ b/examples/book/benchmark.py @@ -134,7 +134,13 @@ def truncate_op(session: sessionmaker, model, nsize: int) -> None: case_sensitive=False, ), ) -def main(config: str, nsize: int, daemon: bool, tg_op: str): +@click.option( + "--weight", "-w", default=0.0, help="Weight for pgsync operations" +) +def main( + config: str, nsize: int, daemon: bool, tg_op: str, weight: float +) -> None: + """Benchmarking script for Book model operations.""" show_settings(config) config: str = get_config(config) @@ -144,6 +150,9 @@ def main(config: str, nsize: int, daemon: bool, tg_op: str): Session = sessionmaker(bind=engine, autoflush=False, autocommit=False) session = Session() + if weight: + session.execute(sa.text(f"SET pgsync.weight = {weight}")) + model = Book func: dict = { INSERT: insert_op, diff --git a/pgsync/base.py b/pgsync/base.py index eadb5a7a..2bf8d223 100644 --- a/pgsync/base.py +++ b/pgsync/base.py @@ -73,9 +73,19 @@ class Payload(object): new (dict): The new values of the row that was affected by the event (for INSERT and UPDATE operations). xmin (int): The transaction ID of the event. indices (List[str]): The indices of the affected rows (for UPDATE and DELETE operations). + weight (float): The weight of the event. 
""" - __slots__ = ("tg_op", "table", "schema", "old", "new", "xmin", "indices") + __slots__ = ( + "tg_op", + "table", + "schema", + "old", + "new", + "xmin", + "indices", + "weight", + ) def __init__( self, @@ -86,6 +96,7 @@ def __init__( new: t.Optional[t.Dict[str, t.Any]] = None, xmin: t.Optional[int] = None, indices: t.Optional[t.List[str]] = None, + weight: t.Optional[float] = None, ): self.tg_op: t.Optional[str] = tg_op self.table: t.Optional[str] = table @@ -94,6 +105,7 @@ def __init__( self.new: t.Dict[str, t.Any] = new or {} self.xmin: t.Optional[int] = xmin self.indices: t.List[str] = indices + self.weight: float = weight @property def data(self) -> dict: diff --git a/pgsync/redisqueue.py b/pgsync/redisqueue.py index 92f9c173..8135074e 100644 --- a/pgsync/redisqueue.py +++ b/pgsync/redisqueue.py @@ -2,6 +2,7 @@ import json import logging +import time import typing as t from redis import Redis @@ -14,14 +15,18 @@ ) from .urls import get_redis_url +# Pick a MULTIPLIER > max timestamp_ms (~1.7e12). +# 10**13 is safe for now. 
+_MULTIPLIER = 10**13 + + logger = logging.getLogger(__name__) -class RedisQueue(object): - """Simple Queue with Redis Backend.""" +class RedisQueue: + """A Redis-backed priority queue: higher weight pops first; equal weights pop FIFO.""" def __init__(self, name: str, namespace: str = "queue", **kwargs): - """Init Simple Queue with Redis Backend.""" url: str = get_redis_url(**kwargs) self.key: str = f"{namespace}:{name}" self._meta_key: str = f"{self.key}:meta" @@ -38,34 +43,50 @@ def __init__(self, name: str, namespace: str = "queue", **kwargs): @property def qsize(self) -> int: - """Return the approximate size of the queue.""" - return self.__db.llen(self.key) - - def pop(self, chunk_size: t.Optional[int] = None) -> t.List[dict]: - """Remove and return multiple items from the queue.""" - chunk_size = chunk_size or REDIS_READ_CHUNK_SIZE - if self.qsize > 0: - pipeline = self.__db.pipeline() - pipeline.lrange(self.key, 0, chunk_size - 1) - pipeline.ltrim(self.key, chunk_size, -1) - items: t.List = pipeline.execute() - logger.debug(f"pop size: {len(items[0])}") - return list(map(lambda value: json.loads(value), items[0])) - - def push(self, items: t.List) -> None: - """Push multiple items onto the queue.""" - self.__db.rpush(self.key, *map(json.dumps, items)) + """Number of items currently in the ZSET.""" + return self.__db.zcard(self.key) + + def push(self, items: t.List[dict], weight: float = 0.0) -> None: + """ + Push a batch of items with the given numeric weight. + + - Higher weight -> higher priority. + - Among equal weight, FIFO order. 
+ """ + mapping: dict = {} + for item in items: + now_ms: int = int(time.time() * 1_000) + # score = -weight*M + timestamp + score = -weight * _MULTIPLIER + now_ms + mapping[json.dumps(item)] = score + # ZADD will add/update each member's score + self.__db.zadd(self.key, mapping) + + def pop(self, chunk_size: int = REDIS_READ_CHUNK_SIZE) -> t.List[dict]: + """ + Pop up to chunk_size highest priority items (by weight, then FIFO). + """ + # ZPOPMIN pulls the entries with the smallest score first + popped: t.List[t.Tuple[bytes, float]] = self.__db.zpopmin( + self.key, chunk_size + ) + results: t.List[dict] = [ + json.loads(member) for member, score in popped + ] + logger.debug(f"popped {len(results)} items (by priority)") + return results def delete(self) -> None: - """Delete all items from the named queue.""" - logger.info(f"Deleting redis key: {self.key}") + """Delete all items from the named queue including its metadata.""" + logger.info(f"Deleting redis key: {self.key} and {self._meta_key}") self.__db.delete(self.key) + self.__db.delete(self._meta_key) def set_meta(self, value: t.Any) -> None: - """Store an arbitrary JSON-serialisable value in a dedicated key.""" + """Store an arbitrary JSON-serializable value in a dedicated key.""" self.__db.set(self._meta_key, json.dumps(value)) def get_meta(self, default: t.Any = None) -> t.Any: - """Retrieve the stored value (or *default* if nothing is set).""" + """Retrieve the stored metadata (or *default* if nothing is set).""" raw = self.__db.get(self._meta_key) return json.loads(raw) if raw is not None else default diff --git a/pgsync/sync.py b/pgsync/sync.py index 47c420cf..c731cd4f 100644 --- a/pgsync/sync.py +++ b/pgsync/sync.py @@ -13,6 +13,7 @@ import typing as t from collections import defaultdict from itertools import groupby +from math import inf from pathlib import Path import click @@ -1128,7 +1129,9 @@ def poll_redis(self) -> None: self._poll_redis() async def _async_poll_redis(self) -> None: - payloads: list = 
self.redis.pop() + payloads: t.List[t.Dict] = self.redis.pop( + settings.REDIS_READ_CHUNK_SIZE + ) if payloads: logger.debug(f"_async_poll_redis: {payloads}") self.count["redis"] += len(payloads) @@ -1144,34 +1147,44 @@ async def async_poll_redis(self) -> None: while True: await self._async_poll_redis() + def _flush_payloads(self, payloads: list[dict]) -> None: + if not payloads: + return + + # group by weight, default=+inf so missing weights pop first + weight_buckets: t.Dict[float, t.List[t.Dict]] = {} + for payload in payloads: + raw = payload.get("weight") + weight: float = float(raw) if raw is not None else inf + weight_buckets.setdefault(weight, []).append(payload) + + # push each bucket in descending weight order (highest first) + for weight, items in sorted( + weight_buckets.items(), key=lambda kv: -kv[0] + ): + logger.debug(f"Pushing {len(items)} items with weight={weight}") + self.redis.push(items, weight=weight) + @threaded @exception def poll_db(self) -> None: - """ - Producer which polls Postgres continuously. 
- - Receive a notification message from the channel we are listening on - """ conn = self.engine.connect().connection conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT) - cursor = conn.cursor() - cursor.execute(f'LISTEN "{self.database}"') + conn.cursor().execute(f'LISTEN "{self.database}"') logger.debug( f'Listening to notifications on channel "{self.database}"' ) - payloads: list = [] + + payloads: t.List[t.Dict] = [] while True: - # NB: consider reducing POLL_TIMEOUT to increase throughput if select.select([conn], [], [], settings.POLL_TIMEOUT) == ( [], [], [], ): - # Catch any hanging items from the last poll - if payloads: - self.redis.push(payloads) - payloads = [] + self._flush_payloads(payloads) + payloads = [] continue try: @@ -1182,28 +1195,32 @@ def poll_db(self) -> None: while conn.notifies: if len(payloads) >= settings.REDIS_WRITE_CHUNK_SIZE: - self.redis.push(payloads) + self._flush_payloads(payloads) payloads = [] + notification: t.AnyStr = conn.notifies.pop(0) - if notification.channel == self.database: + if notification.channel != self.database: + continue - try: - payload = json.loads(notification.payload) - except json.JSONDecodeError as e: - logger.exception( - f"Error decoding JSON payload: {e}\n" - f"Payload: {notification.payload}" - ) - continue - if ( - payload["indices"] - and self.index in payload["indices"] - and payload["schema"] in self.tree.schemas - ): - payloads.append(payload) - logger.debug(f"poll_db: {payload}") - with self.lock: - self.count["db"] += 1 + try: + payload = json.loads(notification.payload) + except json.JSONDecodeError: + logger.exception("Invalid JSON in notification, skipping") + continue + + if ( + payload.get("indices") + and self.index in payload["indices"] + and payload.get("schema") in self.tree.schemas + ): + payloads.append(payload) + logger.debug(f"Queued payload: {payload}") + with self.lock: + self.count["db"] += 1 + + # flush anything left after draining notifications + self._flush_payloads(payloads) 
+ payloads = [] @exception def async_poll_db(self) -> None: @@ -1220,16 +1237,29 @@ def async_poll_db(self) -> None: while self.conn.notifies: notification: t.AnyStr = self.conn.notifies.pop(0) - if notification.channel == self.database: + if notification.channel != self.database: + continue + + try: payload = json.loads(notification.payload) - if ( - payload["indices"] - and self.index in payload["indices"] - and payload["schema"] in self.tree.schemas - ): - self.redis.push([payload]) - logger.debug(f"async_poll: {payload}") - self.count["db"] += 1 + except json.JSONDecodeError as e: + logger.exception(f"Error decoding JSON payload: {e}") + continue + + if ( + payload.get("indices") + and self.index in payload["indices"] + and payload.get("schema") in self.tree.schemas + ): + # extract numeric weight (missing +inf for highest priority) + raw_w = payload.get("weight") + weight = float(raw_w) if raw_w is not None else inf + + # push via priority queue + self.redis.push([payload], weight=weight) + + logger.debug(f"async_poll: {payload} (weight={weight})") + self.count["db"] += 1 def refresh_views(self) -> None: self._refresh_views() @@ -1331,7 +1361,9 @@ def pull(self, polling: bool = False) -> None: if polling: return else: - raise + raise Exception( + f"Error while pulling logical slot changes: {e}" + ) from e self.checkpoint: int = txmax or self.txid_current self._truncate = True @@ -1350,9 +1382,32 @@ async def async_truncate_slots(self) -> None: await asyncio.sleep(settings.REPLICATION_SLOT_CLEANUP_INTERVAL) def _truncate_slots(self) -> None: - if self._truncate: - logger.debug(f"Truncating replication slot: {self.__name}") - self.logical_slot_get_changes(self.__name, upto_nchanges=None) + if not self._truncate: + return + + """ + Handle eventual consistency of the logical replication slot. + We retry logical_slot_changes a few times in case of replication slot in use error. 
+ """ + retries: int = 3 + backoff: int = 1 + txmax: int = self.txid_current + upto_lsn: str = self.current_wal_lsn + + for attempt in range(1, retries + 1): + try: + logger.debug(f"Truncating replication slot: {self.__name}") + self.logical_slot_changes(txmax=txmax, upto_lsn=upto_lsn) + logger.debug("Truncation successful.") + break + except Exception as e: + logger.warning(f"Attempt {attempt} failed with {e}") + if attempt == retries: + logger.error("Max retries reached, raising exception.") + raise + sleep_time: int = backoff * (2 ** (attempt - 1)) + logger.debug(f"Retrying in {sleep_time} seconds...") + time.sleep(sleep_time) @threaded @exception diff --git a/pgsync/trigger.py b/pgsync/trigger.py index 4e690ae2..b3a33efc 100644 --- a/pgsync/trigger.py +++ b/pgsync/trigger.py @@ -18,11 +18,19 @@ _indices TEXT []; _primary_keys TEXT []; _foreign_keys TEXT []; - + weight NUMERIC := 0; BEGIN -- database is also the channel name. channel := CURRENT_DATABASE(); + -- load your numeric weight (default 0 if unset) + BEGIN + weight := CURRENT_SETTING('pgsync.weight', true)::NUMERIC; + EXCEPTION WHEN undefined_object THEN + -- setting not defined leave weight = 0 + NULL; + END; + IF TG_OP = 'DELETE' THEN SELECT primary_keys, indices @@ -71,7 +79,8 @@ 'indices', _indices, 'tg_op', TG_OP, 'table', TG_TABLE_NAME, - 'schema', TG_TABLE_SCHEMA + 'schema', TG_TABLE_SCHEMA, + 'weight', weight ); -- Notify/Listen updates occur asynchronously, diff --git a/tests/test_redisqueue.py b/tests/test_redisqueue.py index 39cf1c5e..a4f97eeb 100644 --- a/tests/test_redisqueue.py +++ b/tests/test_redisqueue.py @@ -1,10 +1,15 @@ """RedisQueues tests.""" +import json +import time +import typing as t + import pytest +from freezegun import freeze_time from mock import patch from redis.exceptions import ConnectionError -from pgsync.redisqueue import RedisQueue +from pgsync.redisqueue import _MULTIPLIER, RedisQueue class TestRedisQueue(object): @@ -69,11 +74,13 @@ def test_pop(self, mock_logger): 
queue.delete() queue.push([1, 2]) items = queue.pop() - mock_logger.debug.assert_called_once_with("pop size: 2") + mock_logger.debug.assert_called_once_with( + "popped 2 items (by priority)" + ) assert items == [1, 2] queue.push([3, 4, 5]) items = queue.pop() - mock_logger.debug.assert_any_call("pop size: 3") + mock_logger.debug.assert_any_call("popped 3 items (by priority)") assert items == [3, 4, 5] queue.delete() @@ -86,6 +93,47 @@ def test_delete(self, mock_logger): assert queue.qsize == 6 queue.delete() mock_logger.info.assert_called_once_with( - "Deleting redis key: queue:something" + "Deleting redis key: queue:something and queue:something:meta" ) assert queue.qsize == 0 + + @freeze_time("2025-06-25T12:00:00Z") + def test_push_and_pop_respects_weight_and_fifo(self): + queue: RedisQueue = RedisQueue("test") + a: dict = {"id": "A"} + b: dict = {"id": "B"} + c: dict = {"id": "C"} + # A has no explicit weight → default 0.0 + queue.push([a]) + # wait a millisecond for a different timestamp + time.sleep(0.001) + # B and C both weight=5 + queue.push([b], weight=5) + time.sleep(0.001) + queue.push([c], weight=5) + # popping 3 items + out = queue.pop(3) + # B then C (both weight=5, FIFO), then A (weight=0) + assert [x["id"] for x in out] == ["B", "C", "A"] + + @freeze_time("2024-06-25T12:00:00Z") + def test_push_adds_correct_scores(self): + queue: RedisQueue = RedisQueue("test") + items: t.List[t.Dict] = [{"id": 1}, {"id": 2}] + weight: float = 5.0 + with ( + patch.object(queue, "_RedisQueue__db") as mock_db, + patch( + "time.time", side_effect=[1_717_267_200.100, 1_717_267_200.200] + ), + ): + queue.push(items, weight=weight) + expected_mapping: dict = { + json.dumps({"id": 1}, sort_keys=True): -weight * _MULTIPLIER + + int(1_717_267_200.100 * 1_000), + json.dumps({"id": 2}, sort_keys=True): -weight * _MULTIPLIER + + int(1_717_267_200.200 * 1_000), + } + mock_db.zadd.assert_called_once_with( + "queue:test", expected_mapping + ) diff --git a/tests/test_sync.py 
b/tests/test_sync.py index 72d143ac..59cafb18 100644 --- a/tests/test_sync.py +++ b/tests/test_sync.py @@ -4,6 +4,7 @@ import os import typing as t from collections import namedtuple +from math import inf import pytest from mock import ANY, call, patch @@ -453,16 +454,19 @@ def test_status(self, sync): @patch("pgsync.sync.logger") def test_truncate_slots(self, mock_logger, sync): with patch( - "pgsync.sync.Sync.logical_slot_get_changes" + "pgsync.sync.Sync.logical_slot_changes" ) as mock_logical_slot_changes: sync._truncate = True sync._truncate_slots() mock_logical_slot_changes.assert_called_once_with( - "testdb_testdb", upto_nchanges=None - ) - mock_logger.debug.assert_called_once_with( - "Truncating replication slot: testdb_testdb" + txmax=ANY, upto_lsn=ANY ) + assert mock_logger.debug.call_args_list == [ + call( + "Truncating replication slot: testdb_testdb", + ), + call("Truncation successful."), + ] @patch("pgsync.sync.SearchClient.bulk") @patch("pgsync.sync.logger") @@ -1069,3 +1073,36 @@ def test_poll_redis( mock_logger.debug.assert_called_once_with(f"_poll_redis: {items}") mock_time.sleep.assert_called_once_with(settings.REDIS_POLL_INTERVAL) assert sync.count["redis"] == 2 + + def test_flush_groups_and_orders(self, sync): + + class DummyRedis: + def __init__(self): + self.calls = [] + + def push(self, items, weight): + self.calls.append((tuple(items), weight)) + + sync.redis = DummyRedis() + + # four payloads; one has no weight + payloads = [ + {"id": 1, "indices": ["x"], "schema": "public"}, + {"id": 2, "weight": 1, "indices": ["x"], "schema": "public"}, + {"id": 3, "weight": 2, "indices": ["x"], "schema": "public"}, + {"id": 4, "weight": 1, "indices": ["x"], "schema": "public"}, + ] + + sync._flush_payloads(payloads) + + # expect three pushes, in this order: + # 1) id=1 (no weight → inf) + # 2) id=3 (weight=2) + # 3) ids=2 and 4 (weight=1) + calls = sync.redis.calls + assert calls[0][1] == inf + assert calls[0][0] == (payloads[0],) + assert calls[1][1] 
== 2.0 + assert calls[1][0] == (payloads[2],) + assert calls[2][1] == 1.0 + assert set(item["id"] for item in calls[2][0]) == {2, 4} diff --git a/tests/test_trigger.py b/tests/test_trigger.py index e3f8c90c..e1984121 100644 --- a/tests/test_trigger.py +++ b/tests/test_trigger.py @@ -23,11 +23,19 @@ def test_trigger_template(self): _indices TEXT []; _primary_keys TEXT []; _foreign_keys TEXT []; - + weight NUMERIC := 0; BEGIN -- database is also the channel name. channel := CURRENT_DATABASE(); + -- load your numeric weight (default 0 if unset) + BEGIN + weight := CURRENT_SETTING('pgsync.weight', true)::NUMERIC; + EXCEPTION WHEN undefined_object THEN + -- setting not defined leave weight = 0 + NULL; + END; + IF TG_OP = 'DELETE' THEN SELECT primary_keys, indices @@ -76,7 +84,8 @@ def test_trigger_template(self): 'indices', _indices, 'tg_op', TG_OP, 'table', TG_TABLE_NAME, - 'schema', TG_TABLE_SCHEMA + 'schema', TG_TABLE_SCHEMA, + 'weight', weight ); -- Notify/Listen updates occur asynchronously,