Skip to content
Open
11 changes: 10 additions & 1 deletion examples/book/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,13 @@ def truncate_op(session: sessionmaker, model, nsize: int) -> None:
case_sensitive=False,
),
)
def main(config: str, nsize: int, daemon: bool, tg_op: str):
@click.option(
"--weight", "-w", default=0.0, help="Weight for pgsync operations"
)
def main(
config: str, nsize: int, daemon: bool, tg_op: str, weight: float
) -> None:
"""Benchmarking script for Book model operations."""
show_settings(config)

config: str = get_config(config)
Expand All @@ -144,6 +150,9 @@ def main(config: str, nsize: int, daemon: bool, tg_op: str):
Session = sessionmaker(bind=engine, autoflush=False, autocommit=False)
session = Session()

if weight:
Copy link

Copilot AI Jun 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] The condition 'if weight:' does not execute when weight is 0 (the default), which may be unintended. Consider explicitly checking for None or always setting the weight configuration regardless of its value.

Suggested change
if weight:
if weight is not None:

Copilot uses AI. Check for mistakes.
session.execute(sa.text(f"SET pgsync.weight = {weight}"))

model = Book
func: dict = {
INSERT: insert_op,
Expand Down
14 changes: 13 additions & 1 deletion pgsync/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,19 @@ class Payload(object):
new (dict): The new values of the row that was affected by the event (for INSERT and UPDATE operations).
xmin (int): The transaction ID of the event.
indices (List[str]): The indices of the affected rows (for UPDATE and DELETE operations).
weight (float): The weight of the event.
"""

__slots__ = ("tg_op", "table", "schema", "old", "new", "xmin", "indices")
__slots__ = (
"tg_op",
"table",
"schema",
"old",
"new",
"xmin",
"indices",
"weight",
)

def __init__(
self,
Expand All @@ -86,6 +96,7 @@ def __init__(
new: t.Optional[t.Dict[str, t.Any]] = None,
xmin: t.Optional[int] = None,
indices: t.Optional[t.List[str]] = None,
weight: t.Optional[float] = None,
):
self.tg_op: t.Optional[str] = tg_op
self.table: t.Optional[str] = table
Expand All @@ -94,6 +105,7 @@ def __init__(
self.new: t.Dict[str, t.Any] = new or {}
self.xmin: t.Optional[int] = xmin
self.indices: t.List[str] = indices
self.weight: float = weight

@property
def data(self) -> dict:
Expand Down
69 changes: 45 additions & 24 deletions pgsync/redisqueue.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import json
import logging
import time
import typing as t

from redis import Redis
Expand All @@ -14,14 +15,18 @@
)
from .urls import get_redis_url

# Scale factor used to fold (weight, timestamp) into a single ZSET score:
#     score = -weight * _MULTIPLIER + timestamp_ms
# It must exceed the largest possible millisecond timestamp (~1.7e12 as of
# 2025) so any non-zero weight difference always dominates the timestamp
# tie-breaker.  10**13 is safe for now.
_MULTIPLIER = 10**13


logger = logging.getLogger(__name__)


class RedisQueue(object):
"""Simple Queue with Redis Backend."""
class RedisQueue:
"""A Redis‐backed queue where items become poppable only once ready is True."""

def __init__(self, name: str, namespace: str = "queue", **kwargs):
"""Init Simple Queue with Redis Backend."""
url: str = get_redis_url(**kwargs)
self.key: str = f"{namespace}:{name}"
self._meta_key: str = f"{self.key}:meta"
Expand All @@ -38,34 +43,50 @@ def __init__(self, name: str, namespace: str = "queue", **kwargs):

@property
def qsize(self) -> int:
    """Approximate number of items currently in the queue."""
    size: int = self.__db.llen(self.key)
    return size

def pop(self, chunk_size: t.Optional[int] = None) -> t.List[dict]:
    """Remove and return up to ``chunk_size`` items from the queue.

    Args:
        chunk_size: Maximum number of items to take; defaults to
            ``REDIS_READ_CHUNK_SIZE`` when None or 0.

    Returns:
        The decoded payload dicts, or an empty list when the queue is
        empty.  (Previously this branch implicitly returned ``None``,
        contradicting the declared return type; ``[]`` is equally falsy
        so existing ``if payloads:`` callers are unaffected.)
    """
    chunk_size = chunk_size or REDIS_READ_CHUNK_SIZE
    if self.qsize <= 0:
        return []
    # LRANGE + LTRIM issued through one pipeline so the read and the
    # removal are applied together (redis-py pipelines wrap commands in
    # MULTI/EXEC by default).
    pipeline = self.__db.pipeline()
    pipeline.lrange(self.key, 0, chunk_size - 1)
    pipeline.ltrim(self.key, chunk_size, -1)
    items: t.List = pipeline.execute()
    logger.debug(f"pop size: {len(items[0])}")
    return [json.loads(value) for value in items[0]]

def push(self, items: t.List) -> None:
    """Push multiple items onto the tail of the queue.

    A no-op for an empty batch: RPUSH requires at least one value, so
    calling ``rpush`` with zero arguments would be rejected by the
    server.
    """
    if items:
        self.__db.rpush(self.key, *map(json.dumps, items))
"""Number of items currently in the ZSET (regardless of ready/not)."""
return self.__db.zcard(self.key)

def push(self, items: t.List[dict], weight: float = 0.0) -> None:
    """
    Push a batch of items onto the priority queue with the given weight.

    Scoring: ``score = -weight * _MULTIPLIER + timestamp_ms``, so a
    higher weight produces a smaller score and is popped first by
    ZPOPMIN; among equal weights, earlier timestamps pop first
    (approximate FIFO — items landing in the same millisecond tie and
    fall back to Redis' lexicographic member ordering).

    Args:
        items: JSON-serialisable payload dicts to enqueue.
        weight: Priority weight; higher values pop sooner.
    """
    if not items:
        # redis-py's zadd() raises DataError on an empty mapping,
        # so an empty batch must be a no-op.
        return
    mapping: dict = {}
    for item in items:
        now_ms: int = int(time.time() * 1_000)
        # score = -weight*M + timestamp
        score = -weight * _MULTIPLIER + now_ms
        mapping[json.dumps(item)] = score
    # ZADD adds new members and updates the score of existing ones, so
    # re-pushing an identical payload re-prioritises it rather than
    # enqueueing a duplicate.
    self.__db.zadd(self.key, mapping)

def pop(self, chunk_size: int = REDIS_READ_CHUNK_SIZE) -> t.List[dict]:
    """
    Pop up to chunk_size highest priority items (by weight, then FIFO).
    """
    # Smallest scores encode the highest priority, so ZPOPMIN hands back
    # the most urgent (member, score) pairs first.
    entries: t.List[t.Tuple[bytes, float]] = self.__db.zpopmin(
        self.key, chunk_size
    )
    decoded: t.List[dict] = []
    for member, _score in entries:
        decoded.append(json.loads(member))
    logger.debug(f"popped {len(decoded)} items (by priority)")
    return decoded

def delete(self) -> None:
    """Delete the named queue and its associated metadata key.

    The rendered diff left a stale duplicate docstring and a duplicated
    ``logger.info`` call (the pre-change, non-meta-aware variants) in
    the body; only the meta-aware versions are kept here.
    """
    logger.info(f"Deleting redis key: {self.key} and {self._meta_key}")
    self.__db.delete(self.key)
    self.__db.delete(self._meta_key)

def set_meta(self, value: t.Any) -> None:
    """Store an arbitrary JSON-serializable value in a dedicated key.

    (Removes the stale duplicate docstring left behind by the diff and
    the non-ASCII U+2010 hyphen in "JSON-serializable".)
    """
    self.__db.set(self._meta_key, json.dumps(value))

def get_meta(self, default: t.Any = None) -> t.Any:
    """Retrieve the stored metadata (or *default* if nothing is set).

    (Removes the stale duplicate docstring left behind by the diff.)
    """
    raw = self.__db.get(self._meta_key)
    return json.loads(raw) if raw is not None else default
147 changes: 101 additions & 46 deletions pgsync/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import typing as t
from collections import defaultdict
from itertools import groupby
from math import inf
from pathlib import Path

import click
Expand Down Expand Up @@ -1128,7 +1129,9 @@
self._poll_redis()

async def _async_poll_redis(self) -> None:
payloads: list = self.redis.pop()
payloads: t.List[t.Dict] = self.redis.pop(

Check warning on line 1132 in pgsync/sync.py

View check run for this annotation

Codecov / codecov/patch

pgsync/sync.py#L1132

Added line #L1132 was not covered by tests
settings.REDIS_AUTO_POP_READY_STATE
)
if payloads:
logger.debug(f"_async_poll_redis: {payloads}")
self.count["redis"] += len(payloads)
Expand All @@ -1144,34 +1147,44 @@
while True:
await self._async_poll_redis()

def _flush_payloads(self, payloads: list[dict]) -> None:
if not payloads:
return

Check warning on line 1152 in pgsync/sync.py

View check run for this annotation

Codecov / codecov/patch

pgsync/sync.py#L1152

Added line #L1152 was not covered by tests

# group by weight, default=+inf so missing weights pop first
weight_buckets: t.Dict[float, t.List[t.Dict]] = {}
for payload in payloads:
raw = payload.get("weight")
weight: float = float(raw) if raw is not None else inf
weight_buckets.setdefault(weight, []).append(payload)

# push each bucket in descending weight order (highest first)
for weight, items in sorted(
weight_buckets.items(), key=lambda kv: -kv[0]
):
logger.debug(f"Pushing {len(items)} items with weight={weight}")
self.redis.push(items, weight=weight)

@threaded
@exception
def poll_db(self) -> None:
"""
Producer which polls Postgres continuously.

Receive a notification message from the channel we are listening on
"""
conn = self.engine.connect().connection
conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
cursor = conn.cursor()
cursor.execute(f'LISTEN "{self.database}"')
conn.cursor().execute(f'LISTEN "{self.database}"')

Check warning on line 1173 in pgsync/sync.py

View check run for this annotation

Codecov / codecov/patch

pgsync/sync.py#L1173

Added line #L1173 was not covered by tests
logger.debug(
f'Listening to notifications on channel "{self.database}"'
)
payloads: list = []

payloads: t.List[t.Dict] = []

Check warning on line 1178 in pgsync/sync.py

View check run for this annotation

Codecov / codecov/patch

pgsync/sync.py#L1178

Added line #L1178 was not covered by tests

while True:
# NB: consider reducing POLL_TIMEOUT to increase throughput
if select.select([conn], [], [], settings.POLL_TIMEOUT) == (
[],
[],
[],
):
# Catch any hanging items from the last poll
if payloads:
self.redis.push(payloads)
payloads = []
self._flush_payloads(payloads)
Copy link

Copilot AI Jun 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After flushing payloads, the payloads list is not cleared. This may lead to duplicate processing of already flushed items; consider resetting the payloads list (e.g., payloads = []) after calling _flush_payloads.

Suggested change
self._flush_payloads(payloads)
self._flush_payloads(payloads)
payloads = []

Copilot uses AI. Check for mistakes.
payloads = []

Check warning on line 1187 in pgsync/sync.py

View check run for this annotation

Codecov / codecov/patch

pgsync/sync.py#L1186-L1187

Added lines #L1186 - L1187 were not covered by tests
continue

try:
Expand All @@ -1182,28 +1195,32 @@

while conn.notifies:
if len(payloads) >= settings.REDIS_WRITE_CHUNK_SIZE:
self.redis.push(payloads)
self._flush_payloads(payloads)

Check warning on line 1198 in pgsync/sync.py

View check run for this annotation

Codecov / codecov/patch

pgsync/sync.py#L1198

Added line #L1198 was not covered by tests
payloads = []

notification: t.AnyStr = conn.notifies.pop(0)
if notification.channel == self.database:
if notification.channel != self.database:
continue

Check warning on line 1203 in pgsync/sync.py

View check run for this annotation

Codecov / codecov/patch

pgsync/sync.py#L1202-L1203

Added lines #L1202 - L1203 were not covered by tests

try:
payload = json.loads(notification.payload)
except json.JSONDecodeError as e:
logger.exception(
f"Error decoding JSON payload: {e}\n"
f"Payload: {notification.payload}"
)
continue
if (
payload["indices"]
and self.index in payload["indices"]
and payload["schema"] in self.tree.schemas
):
payloads.append(payload)
logger.debug(f"poll_db: {payload}")
with self.lock:
self.count["db"] += 1
try:
payload = json.loads(notification.payload)
except json.JSONDecodeError:
logger.exception("Invalid JSON in notification, skipping")
continue

Check warning on line 1209 in pgsync/sync.py

View check run for this annotation

Codecov / codecov/patch

pgsync/sync.py#L1205-L1209

Added lines #L1205 - L1209 were not covered by tests

if (

Check warning on line 1211 in pgsync/sync.py

View check run for this annotation

Codecov / codecov/patch

pgsync/sync.py#L1211

Added line #L1211 was not covered by tests
payload.get("indices")
and self.index in payload["indices"]
and payload.get("schema") in self.tree.schemas
):
payloads.append(payload)
logger.debug(f"Queued payload: {payload}")
with self.lock:
self.count["db"] += 1

Check warning on line 1219 in pgsync/sync.py

View check run for this annotation

Codecov / codecov/patch

pgsync/sync.py#L1216-L1219

Added lines #L1216 - L1219 were not covered by tests

# flush anything left after draining notifications
self._flush_payloads(payloads)
payloads = []

Check warning on line 1223 in pgsync/sync.py

View check run for this annotation

Codecov / codecov/patch

pgsync/sync.py#L1222-L1223

Added lines #L1222 - L1223 were not covered by tests

@exception
def async_poll_db(self) -> None:
Expand All @@ -1220,16 +1237,29 @@

while self.conn.notifies:
notification: t.AnyStr = self.conn.notifies.pop(0)
if notification.channel == self.database:
if notification.channel != self.database:
continue

Check warning on line 1241 in pgsync/sync.py

View check run for this annotation

Codecov / codecov/patch

pgsync/sync.py#L1240-L1241

Added lines #L1240 - L1241 were not covered by tests

try:

Check warning on line 1243 in pgsync/sync.py

View check run for this annotation

Codecov / codecov/patch

pgsync/sync.py#L1243

Added line #L1243 was not covered by tests
payload = json.loads(notification.payload)
if (
payload["indices"]
and self.index in payload["indices"]
and payload["schema"] in self.tree.schemas
):
self.redis.push([payload])
logger.debug(f"async_poll: {payload}")
self.count["db"] += 1
except json.JSONDecodeError as e:
logger.exception(f"Error decoding JSON payload: {e}")
continue

Check warning on line 1247 in pgsync/sync.py

View check run for this annotation

Codecov / codecov/patch

pgsync/sync.py#L1245-L1247

Added lines #L1245 - L1247 were not covered by tests

if (

Check warning on line 1249 in pgsync/sync.py

View check run for this annotation

Codecov / codecov/patch

pgsync/sync.py#L1249

Added line #L1249 was not covered by tests
payload.get("indices")
and self.index in payload["indices"]
and payload.get("schema") in self.tree.schemas
):
# extract numeric weight (missing +inf for highest priority)
raw_w = payload.get("weight")
weight = float(raw_w) if raw_w is not None else inf

Check warning on line 1256 in pgsync/sync.py

View check run for this annotation

Codecov / codecov/patch

pgsync/sync.py#L1255-L1256

Added lines #L1255 - L1256 were not covered by tests

# push via priority queue
self.redis.push([payload], weight=weight)

Check warning on line 1259 in pgsync/sync.py

View check run for this annotation

Codecov / codecov/patch

pgsync/sync.py#L1259

Added line #L1259 was not covered by tests

logger.debug(f"async_poll: {payload} (weight={weight})")
self.count["db"] += 1

Check warning on line 1262 in pgsync/sync.py

View check run for this annotation

Codecov / codecov/patch

pgsync/sync.py#L1261-L1262

Added lines #L1261 - L1262 were not covered by tests

def refresh_views(self) -> None:
self._refresh_views()
Expand Down Expand Up @@ -1331,7 +1361,9 @@
if polling:
return
else:
raise
raise Exception(

Check warning on line 1364 in pgsync/sync.py

View check run for this annotation

Codecov / codecov/patch

pgsync/sync.py#L1364

Added line #L1364 was not covered by tests
f"Error while pulling logical slot changes: {e}"
) from e
self.checkpoint: int = txmax or self.txid_current
self._truncate = True

Expand All @@ -1350,9 +1382,32 @@
await asyncio.sleep(settings.REPLICATION_SLOT_CLEANUP_INTERVAL)

def _truncate_slots(self) -> None:
if self._truncate:
logger.debug(f"Truncating replication slot: {self.__name}")
self.logical_slot_get_changes(self.__name, upto_nchanges=None)
if not self._truncate:
return

Check warning on line 1386 in pgsync/sync.py

View check run for this annotation

Codecov / codecov/patch

pgsync/sync.py#L1386

Added line #L1386 was not covered by tests

"""
Handle eventual consistency of the logical replication slot.
We retry logical_slot_changes a few times in case of replication slot in use error.
"""
retries: int = 3
backoff: int = 1
txmax: int = self.txid_current
upto_lsn: str = self.current_wal_lsn

for attempt in range(1, retries + 1):
try:
logger.debug(f"Truncating replication slot: {self.__name}")
self.logical_slot_changes(txmax=txmax, upto_lsn=upto_lsn)
logger.debug("Truncation successful.")
break
except Exception as e:
logger.warning(f"Attempt {attempt} failed with {e}")
if attempt == retries:
logger.error("Max retries reached, raising exception.")
raise
sleep_time: int = backoff * (2 ** (attempt - 1))
logger.debug(f"Retrying in {sleep_time} seconds...")
time.sleep(sleep_time)

Check warning on line 1410 in pgsync/sync.py

View check run for this annotation

Codecov / codecov/patch

pgsync/sync.py#L1403-L1410

Added lines #L1403 - L1410 were not covered by tests

@threaded
@exception
Expand Down
Loading
Loading