diff --git a/.env.sample b/.env.sample index 65c599c7..37bd921a 100644 --- a/.env.sample +++ b/.env.sample @@ -21,6 +21,9 @@ # USE_ASYNC=False # JOIN_QUERIES=True # STREAM_RESULTS=True +# db polling interval +# POLL_INTERVAL=0.1 +# FILTER_CHUNK_SIZE=5000 # Elasticsearch # ELASTICSEARCH_SCHEME=http diff --git a/.github/workflows/python-build.yml b/.github/workflows/python-build.yml index 333f55c4..f4dfca02 100644 --- a/.github/workflows/python-build.yml +++ b/.github/workflows/python-build.yml @@ -8,7 +8,7 @@ jobs: strategy: max-parallel: 4 matrix: - python-version: [3.7, 3.8, 3.9, '3.10'] + python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] services: postgres: image: debezium/postgres:15 @@ -25,7 +25,7 @@ jobs: ports: - 6379:6379 elasticsearch: - image: docker.elastic.co/elasticsearch/elasticsearch:7.17.6 + image: docker.elastic.co/elasticsearch/elasticsearch:7.17.7 ports: - 9200:9200 - 9300:9300 @@ -34,9 +34,9 @@ jobs: network.host: 127.0.0.1 http.host: 0.0.0.0 steps: - - uses: actions/checkout@v1 + - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v1 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} - name: Install dependencies diff --git a/bin/bootstrap b/bin/bootstrap index 4e29c387..c8d6e3ad 100755 --- a/bin/bootstrap +++ b/bin/bootstrap @@ -6,7 +6,7 @@ import logging import click from pgsync.sync import Sync -from pgsync.utils import get_config, load_config, show_settings +from pgsync.utils import config_loader, get_config, show_settings logger = logging.getLogger(__name__) @@ -54,7 +54,7 @@ def main(teardown, config, user, password, host, port, verbose): show_settings(config) - for document in load_config(config): + for document in config_loader(config): sync: Sync = Sync( document, verbose=verbose, repl_slots=False, **kwargs ) diff --git a/bin/es_mapping b/bin/es_mapping index 49769624..fb290bff 100755 --- a/bin/es_mapping +++ b/bin/es_mapping @@ -8,7 +8,7 @@ from 
elasticsearch import Elasticsearch, helpers from pgsync.settings import ELASTICSEARCH_TIMEOUT, ELASTICSEARCH_VERIFY_CERTS from pgsync.urls import get_elasticsearch_url -from pgsync.utils import get_config, load_config, timeit +from pgsync.utils import config_loader, get_config, timeit logger = logging.getLogger(__name__) @@ -114,7 +114,9 @@ def main(config): """Create custom NGram analyzer for the default mapping.""" config: str = get_config(config) - for index in set([document["index"] for document in load_config(config)]): + for index in set( + [document["index"] for document in config_loader(config)] + ): create_es_mapping(index) diff --git a/bin/parallel_sync b/bin/parallel_sync index a7fe3ecc..ddc29135 100755 --- a/bin/parallel_sync +++ b/bin/parallel_sync @@ -44,6 +44,7 @@ and row numbers. import asyncio import multiprocessing import os +import re import sys from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor from dataclasses import dataclass @@ -54,9 +55,32 @@ from typing import Generator, Optional, Union import click import sqlalchemy as sa -from pgsync.settings import BLOCK_SIZE +from pgsync.settings import BLOCK_SIZE, CHECKPOINT_PATH from pgsync.sync import Sync -from pgsync.utils import get_config, load_config, show_settings, timeit +from pgsync.utils import ( + compiled_query, + config_loader, + get_config, + show_settings, + timeit, +) + + +def save_ctid(page: int, row: int, name: str) -> None: + checkpoint_file: str = os.path.join(CHECKPOINT_PATH, f".{name}.ctid") + with open(checkpoint_file, "w+") as fp: + fp.write(f"{page},{row}\n") + + +def read_ctid(name: str) -> None: + checkpoint_file: str = os.path.join(CHECKPOINT_PATH, f".{name}.ctid") + if os.path.exists(checkpoint_file): + with open(checkpoint_file, "r") as fp: + pairs: str = fp.read().split()[0].split(",") + page = int(pairs[0]) + row = int(pairs[1]) + return page, row + return None, None def logical_slot_changes( @@ -92,10 +116,19 @@ class Task: @timeit -def 
fetch_tasks(doc: dict, block_size: Optional[int] = None) -> Generator: +def fetch_tasks( + doc: dict, + block_size: Optional[int] = None, +) -> Generator: block_size = block_size or BLOCK_SIZE pages: dict = {} sync: Sync = Sync(doc) + page: Optional[int] = None + row: Optional[int] = None + name: str = re.sub( + "[^0-9a-zA-Z_]+", "", f"{sync.database.lower()}_{sync.index}" + ) + page, row = read_ctid(name=name) statement: sa.sql.Select = sa.select( [ sa.literal_column("1").label("x"), @@ -103,6 +136,51 @@ def fetch_tasks(doc: dict, block_size: Optional[int] = None) -> Generator: sa.column("ctid"), ] ).select_from(sync.tree.root.model) + + # filter by Page + if page: + statement = statement.where( + sa.cast( + sa.func.SPLIT_PART( + sa.func.REPLACE( + sa.func.REPLACE( + sa.cast(sa.column("ctid"), sa.Text), + "(", + "", + ), + ")", + "", + ), + ",", + 1, + ), + sa.Integer, + ) + > page + ) + + # filter by Row + if row: + statement = statement.where( + sa.cast( + sa.func.SPLIT_PART( + sa.func.REPLACE( + sa.func.REPLACE( + sa.cast(sa.column("ctid"), sa.Text), + "(", + "", + ), + ")", + "", + ), + ",", + 2, + ), + sa.Integer, + ) + > row + ) + i: int = 1 for _, _, ctid in sync.fetchmany(statement): value: list = ctid[0].split(",") @@ -175,7 +253,6 @@ def multithreaded( queue.put(task) queue.join() # block until all tasks are done - logical_slot_changes(doc, verbose=verbose, validate=validate) @@ -194,7 +271,6 @@ def multiprocess( list(executor.map(task.process, tasks)) except Exception as e: sys.stdout.write(f"Exception: {e}\n") - logical_slot_changes(doc, verbose=verbose, validate=validate) @@ -209,13 +285,9 @@ def multithreaded_async( sys.stdout.write("Multi-threaded async\n") executor: ThreadPoolExecutor = ThreadPoolExecutor(max_workers=nprocs) event_loop = asyncio.get_event_loop() - try: - event_loop.run_until_complete( - run_tasks(executor, tasks, doc, verbose=verbose, validate=validate) - ) - finally: - event_loop.close() - + event_loop.run_until_complete( + 
run_tasks(executor, tasks, doc, verbose=verbose, validate=validate) + ) logical_slot_changes(doc, verbose=verbose, validate=validate) @@ -234,9 +306,8 @@ def multiprocess_async( event_loop.run_until_complete( run_tasks(executor, tasks, doc, verbose=verbose, validate=validate) ) - finally: - event_loop.close() - + except KeyboardInterrupt: + pass logical_slot_changes(doc, verbose=verbose, validate=validate) @@ -247,24 +318,19 @@ async def run_tasks( verbose: bool = False, validate: bool = False, ) -> None: - event_loop = asyncio.get_event_loop() + sync: Optional[Sync] = None if isinstance(executor, ThreadPoolExecutor): # threads can share a common Sync object - sync: Sync = Sync(doc, verbose=verbose, validate=validate) - tasks: list = [ - event_loop.run_in_executor( - executor, run_task, task, sync, None, verbose, validate - ) - for task in tasks - ] - else: - tasks = [ + sync = Sync(doc, verbose=verbose, validate=validate) + event_loop = asyncio.get_event_loop() + completed, pending = await asyncio.wait( + [ event_loop.run_in_executor( - executor, run_task, task, None, doc, verbose, validate + executor, run_task, task, sync, doc, verbose, validate ) for task in tasks ] - completed, pending = await asyncio.wait(tasks) + ) results: list = [task.result() for task in completed] print("results: {!r}".format(results)) print("exiting") @@ -286,7 +352,14 @@ def run_task( sync.index, sync.sync(ctid=task, txmin=txmin, txmax=txmax), ) - print("run_task complete") + if len(task) > 0: + page: int = max(task.keys()) + row: int = max(task[page]) + name: str = re.sub( + "[^0-9a-zA-Z_]+", "", f"{sync.database.lower()}_{sync.index}" + ) + save_ctid(page=page, row=row, name=name) + return 1 @@ -331,14 +404,13 @@ def main(config, nprocs, mode, verbose): """ TODO: - Track progress across cpus/threads - - Save ctid - Handle KeyboardInterrupt Exception """ show_settings() config: str = get_config(config) - for document in load_config(config): + for document in config_loader(config): 
tasks: Generator = fetch_tasks(document) if mode == "synchronous": synchronous(tasks, document, verbose=verbose) diff --git a/docker-compose.yml b/docker-compose.yml index 849c0774..1d518b1e 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -13,7 +13,7 @@ services: image: redis command: redis-server --requirepass PLEASE_CHANGE_ME elasticsearch: - image: docker.elastic.co/elasticsearch/elasticsearch:7.17.6 + image: docker.elastic.co/elasticsearch/elasticsearch:7.17.7 ports: - "9201:9200" - "9301:9300" diff --git a/examples/airbnb/data.py b/examples/airbnb/data.py index f397764f..803285c3 100644 --- a/examples/airbnb/data.py +++ b/examples/airbnb/data.py @@ -7,7 +7,7 @@ from pgsync.base import pg_engine, subtransactions from pgsync.helper import teardown -from pgsync.utils import get_config, load_config +from pgsync.utils import config_loader, get_config @click.command() @@ -21,7 +21,7 @@ def main(config): config: str = get_config(config) teardown(drop_db=False, config=config) - document: dict = next(load_config(config)) + document: dict = next(config_loader(config)) database: str = document.get("database", document["index"]) with pg_engine(database) as engine: Session = sessionmaker(bind=engine, autoflush=True) diff --git a/examples/airbnb/schema.py b/examples/airbnb/schema.py index ab3c3e23..11adda26 100644 --- a/examples/airbnb/schema.py +++ b/examples/airbnb/schema.py @@ -7,7 +7,7 @@ from pgsync.base import create_database, pg_engine from pgsync.helper import teardown -from pgsync.utils import get_config, load_config +from pgsync.utils import config_loader, get_config Base = declarative_base() @@ -91,7 +91,7 @@ class Review(Base): def setup(config: str) -> None: - for document in load_config(config): + for document in config_loader(config): database: str = document.get("database", document["index"]) create_database(database) with pg_engine(database) as engine: diff --git a/examples/ancestry/data.py b/examples/ancestry/data.py index c06059ed..f9d0c22e 
100644 --- a/examples/ancestry/data.py +++ b/examples/ancestry/data.py @@ -4,7 +4,7 @@ from pgsync.base import pg_engine, subtransactions from pgsync.helper import teardown -from pgsync.utils import get_config, load_config +from pgsync.utils import config_loader, get_config @click.command() @@ -18,7 +18,7 @@ def main(config): config: str = get_config(config) teardown(drop_db=False, config=config) - document: dict = next(load_config(config)) + document: dict = next(config_loader(config)) database: str = document.get("database", document["index"]) with pg_engine(database) as engine: Session = sessionmaker(bind=engine, autoflush=True) diff --git a/examples/ancestry/schema.py b/examples/ancestry/schema.py index a3ee360f..787c6c3a 100644 --- a/examples/ancestry/schema.py +++ b/examples/ancestry/schema.py @@ -4,7 +4,7 @@ from pgsync.base import create_database, pg_engine from pgsync.helper import teardown -from pgsync.utils import get_config, load_config +from pgsync.utils import config_loader, get_config Base = declarative_base() @@ -49,7 +49,7 @@ class GreatGrandChild(Base): def setup(config: str) -> None: - for document in load_config(config): + for document in config_loader(config): database: str = document.get("database", document["index"]) create_database(database) with pg_engine(database) as engine: diff --git a/examples/book/benchmark.py b/examples/book/benchmark.py index d5c78ac9..0572bb73 100644 --- a/examples/book/benchmark.py +++ b/examples/book/benchmark.py @@ -9,7 +9,7 @@ from pgsync.base import pg_engine from pgsync.constants import DELETE, INSERT, TG_OP, TRUNCATE, UPDATE -from pgsync.utils import get_config, load_config, show_settings, Timer +from pgsync.utils import config_loader, get_config, show_settings, Timer FIELDS = { "isbn": "isbn13", @@ -139,7 +139,7 @@ def main(config, nsize, daemon, tg_op): show_settings() config: str = get_config(config) - document: dict = next(load_config(config)) + document: dict = next(config_loader(config)) database: str = 
document.get("database", document["index"]) with pg_engine(database) as engine: Session = sessionmaker(bind=engine, autoflush=False, autocommit=False) diff --git a/examples/book/data.py b/examples/book/data.py index f7ea614b..082f03ee 100644 --- a/examples/book/data.py +++ b/examples/book/data.py @@ -25,7 +25,7 @@ from pgsync.base import pg_engine, subtransactions from pgsync.constants import DEFAULT_SCHEMA from pgsync.helper import teardown -from pgsync.utils import get_config, load_config +from pgsync.utils import config_loader, get_config @click.command() @@ -41,7 +41,7 @@ def main(config, nsize): config: str = get_config(config) teardown(drop_db=False, config=config) - for document in load_config(config): + for document in config_loader(config): database: str = document.get("database", document["index"]) with pg_engine(database) as engine: diff --git a/examples/book/schema.py b/examples/book/schema.py index 8175ed93..22eac2bd 100644 --- a/examples/book/schema.py +++ b/examples/book/schema.py @@ -6,7 +6,7 @@ from pgsync.base import create_database, create_schema, pg_engine from pgsync.constants import DEFAULT_SCHEMA from pgsync.helper import teardown -from pgsync.utils import get_config, load_config +from pgsync.utils import config_loader, get_config Base = declarative_base() @@ -199,7 +199,7 @@ class BookShelf(Base): def setup(config: str) -> None: - for document in load_config(config): + for document in config_loader(config): database: str = document.get("database", document["index"]) schema: str = document.get("schema", DEFAULT_SCHEMA) create_database(database) diff --git a/examples/book_view/benchmark.py b/examples/book_view/benchmark.py index ae693b42..0b791047 100644 --- a/examples/book_view/benchmark.py +++ b/examples/book_view/benchmark.py @@ -9,7 +9,7 @@ from pgsync.base import pg_engine from pgsync.constants import DELETE, INSERT, TG_OP, UPDATE -from pgsync.utils import get_config, load_config, show_settings, Timer +from pgsync.utils import 
config_loader, get_config, show_settings, Timer FIELDS = { "isbn": "isbn13", @@ -132,7 +132,7 @@ def main(config, nsize, daemon, tg_op): show_settings() config: str = get_config(config) - document: dict = next(load_config(config)) + document: dict = next(config_loader(config)) database: str = document.get("database", document["index"]) with pg_engine(database) as engine: Session = sessionmaker(bind=engine, autoflush=False, autocommit=False) diff --git a/examples/book_view/data.py b/examples/book_view/data.py index 84c0c5d5..34aa8960 100644 --- a/examples/book_view/data.py +++ b/examples/book_view/data.py @@ -6,7 +6,7 @@ from pgsync.constants import DEFAULT_SCHEMA from pgsync.helper import teardown from pgsync.sync import Sync -from pgsync.utils import get_config, load_config +from pgsync.utils import config_loader, get_config @click.command() @@ -21,7 +21,7 @@ def main(config): config: str = get_config(config) teardown(drop_db=False, config=config) - for document in load_config(config): + for document in config_loader(config): database: str = document.get("database", document["index"]) with pg_engine(database) as engine: diff --git a/examples/book_view/schema.py b/examples/book_view/schema.py index b0d44e39..c43e6f79 100644 --- a/examples/book_view/schema.py +++ b/examples/book_view/schema.py @@ -6,7 +6,7 @@ from pgsync.base import create_database, create_schema, pg_engine from pgsync.constants import DEFAULT_SCHEMA from pgsync.helper import teardown -from pgsync.utils import get_config, load_config +from pgsync.utils import config_loader, get_config from pgsync.view import CreateView Base = declarative_base() @@ -38,7 +38,7 @@ class Book(Base): def setup(config: str) -> None: - for document in load_config(config): + for document in config_loader(config): database: str = document.get("database", document["index"]) schema: str = document.get("schema", DEFAULT_SCHEMA) create_database(database) diff --git a/examples/node/data.py b/examples/node/data.py index 
fc8a67f6..21865de5 100644 --- a/examples/node/data.py +++ b/examples/node/data.py @@ -6,7 +6,7 @@ from pgsync.base import pg_engine, subtransactions from pgsync.helper import teardown -from pgsync.utils import get_config, load_config +from pgsync.utils import config_loader, get_config @click.command() @@ -20,7 +20,7 @@ def main(config): config: str = get_config(config) teardown(drop_db=False, config=config) - document = next(load_config(config)) + document = next(config_loader(config)) database: str = document.get("database", document["index"]) with pg_engine(database) as engine: Session = sessionmaker(bind=engine, autoflush=True) diff --git a/examples/node/schema.py b/examples/node/schema.py index f7c1a4f3..4889e7d5 100644 --- a/examples/node/schema.py +++ b/examples/node/schema.py @@ -4,7 +4,7 @@ from pgsync.base import create_database, pg_engine from pgsync.helper import teardown -from pgsync.utils import get_config, load_config +from pgsync.utils import config_loader, get_config Base = declarative_base() @@ -18,7 +18,7 @@ class Node(Base): def setup(config: str) -> None: - for document in load_config(config): + for document in config_loader(config): database: str = document.get("database", document["index"]) create_database(database) with pg_engine(database) as engine: diff --git a/examples/quiz/data.py b/examples/quiz/data.py index a04e36cf..79779f22 100644 --- a/examples/quiz/data.py +++ b/examples/quiz/data.py @@ -4,7 +4,7 @@ from pgsync.base import pg_engine, subtransactions from pgsync.helper import teardown -from pgsync.utils import get_config, load_config +from pgsync.utils import config_loader, get_config @click.command() @@ -18,7 +18,7 @@ def main(config): config: str = get_config(config) teardown(drop_db=False, config=config) - document = next(load_config(config)) + document = next(config_loader(config)) database: str = document.get("database", document["index"]) with pg_engine(database) as engine: Session = sessionmaker(bind=engine, autoflush=True) 
diff --git a/examples/quiz/schema.py b/examples/quiz/schema.py index 67e9fb2d..50b1846b 100644 --- a/examples/quiz/schema.py +++ b/examples/quiz/schema.py @@ -5,7 +5,7 @@ from pgsync.base import create_database, pg_engine from pgsync.helper import teardown -from pgsync.utils import get_config, load_config +from pgsync.utils import config_loader, get_config Base = declarative_base() @@ -103,7 +103,7 @@ class RealAnswer(Base): def setup(config: str) -> None: - for document in load_config(config): + for document in config_loader(config): database: str = document.get("database", document["index"]) create_database(database) with pg_engine(database) as engine: diff --git a/examples/schemas/data.py b/examples/schemas/data.py index c11039c3..b9bd6f16 100644 --- a/examples/schemas/data.py +++ b/examples/schemas/data.py @@ -4,7 +4,7 @@ from pgsync.base import pg_engine, subtransactions from pgsync.helper import teardown -from pgsync.utils import get_config, load_config +from pgsync.utils import config_loader, get_config @click.command() @@ -18,7 +18,7 @@ def main(config): config: str = get_config(config) teardown(drop_db=False, config=config) - document = next(load_config(config)) + document = next(config_loader(config)) database: str = document.get("database", document["index"]) with pg_engine(database) as engine: Session = sessionmaker(bind=engine, autoflush=True) diff --git a/examples/schemas/schema.py b/examples/schemas/schema.py index 88334c32..c76e9f9e 100644 --- a/examples/schemas/schema.py +++ b/examples/schemas/schema.py @@ -4,7 +4,7 @@ from pgsync.base import create_database, create_schema, pg_engine from pgsync.helper import teardown -from pgsync.utils import get_config, load_config +from pgsync.utils import config_loader, get_config Base = declarative_base() @@ -25,7 +25,7 @@ class Child(Base): def setup(config: str) -> None: - for document in load_config(config): + for document in config_loader(config): database: str = document.get("database", document["index"]) 
create_database(database) for schema in ("parent", "child"): diff --git a/examples/social/data.py b/examples/social/data.py index 1f67e261..2fa19bbd 100644 --- a/examples/social/data.py +++ b/examples/social/data.py @@ -5,7 +5,7 @@ from pgsync.base import pg_engine, subtransactions from pgsync.helper import teardown -from pgsync.utils import get_config, load_config +from pgsync.utils import config_loader, get_config @click.command() @@ -19,7 +19,7 @@ def main(config): config: str = get_config(config) teardown(drop_db=False, config=config) - document: dict = next(load_config(config)) + document: dict = next(config_loader(config)) database: str = document.get("database", document["index"]) with pg_engine(database) as engine: Session = sessionmaker(bind=engine, autoflush=True) diff --git a/examples/social/schema.py b/examples/social/schema.py index fd78ea66..d319044c 100644 --- a/examples/social/schema.py +++ b/examples/social/schema.py @@ -5,7 +5,7 @@ from pgsync.base import create_database, pg_engine from pgsync.helper import teardown -from pgsync.utils import get_config, load_config +from pgsync.utils import config_loader, get_config Base = declarative_base() @@ -88,7 +88,7 @@ class UserTag(Base): def setup(config: str) -> None: - for document in load_config(config): + for document in config_loader(config): database: str = document.get("database", document["index"]) create_database(database) with pg_engine(database) as engine: diff --git a/examples/starcraft/data.py b/examples/starcraft/data.py index c876951d..f530678c 100644 --- a/examples/starcraft/data.py +++ b/examples/starcraft/data.py @@ -4,7 +4,7 @@ from pgsync.base import pg_engine, subtransactions from pgsync.helper import teardown -from pgsync.utils import get_config, load_config +from pgsync.utils import config_loader, get_config @click.command() @@ -18,7 +18,7 @@ def main(config): config: str = get_config(config) teardown(drop_db=False, config=config) - document = next(load_config(config)) + document = 
next(config_loader(config)) database: str = document.get("database", document["index"]) with pg_engine(database) as engine: Session = sessionmaker(bind=engine, autoflush=True) diff --git a/examples/starcraft/schema.py b/examples/starcraft/schema.py index 8452f62f..e207085a 100644 --- a/examples/starcraft/schema.py +++ b/examples/starcraft/schema.py @@ -5,7 +5,7 @@ from pgsync.base import create_database, pg_engine from pgsync.helper import teardown -from pgsync.utils import get_config, load_config +from pgsync.utils import config_loader, get_config Base = declarative_base() @@ -46,7 +46,7 @@ class Structure(Base): def setup(config: str) -> None: - for document in load_config(config): + for document in config_loader(config): database: str = document.get("database", document["index"]) create_database(database) with pg_engine(database) as engine: diff --git a/pgsync/__init__.py b/pgsync/__init__.py index 0156ea8f..5d441df0 100644 --- a/pgsync/__init__.py +++ b/pgsync/__init__.py @@ -4,4 +4,4 @@ __author__ = "Tolu Aina" __email__ = "tolu@pgsync.com" -__version__ = "2.3.3" +__version__ = "2.3.4" diff --git a/pgsync/base.py b/pgsync/base.py index fbaebfc8..efdbeaa8 100644 --- a/pgsync/base.py +++ b/pgsync/base.py @@ -179,7 +179,7 @@ def models(self, table: str, schema: str) -> sa.sql.Alias: model = metadata.tables[name] model.append_column(sa.Column("xmin", sa.BigInteger)) model.append_column(sa.Column("ctid"), TupleIdentifierType) - # support SQLQlchemy/Postgres 14 which somehow now reflects + # support SQLAlchemy/Postgres 14 which somehow now reflects # the oid column if "oid" not in [column.name for column in model.columns]: model.append_column( @@ -249,8 +249,9 @@ def _materialized_views(self, schema: str) -> list: def indices(self, table: str, schema: str) -> list: """Get the database table indexes.""" if (table, schema) not in self.__indices: + indexes = sa.inspect(self.engine).get_indexes(table, schema=schema) self.__indices[(table, schema)] = sorted( - 
sa.inspect(self.engine).get_indexes(table, schema=schema) + indexes, key=lambda d: d["name"] ) return self.__indices[(table, schema)] @@ -429,7 +430,7 @@ def logical_slot_get_changes( limit=limit, offset=offset, ) - self.execute(statement) + self.execute(statement, options=dict(stream_results=STREAM_RESULTS)) def logical_slot_peek_changes( self, @@ -980,6 +981,23 @@ def drop_database(database: str, echo: bool = False) -> None: logger.debug(f"Dropped database: {database}") +def database_exists(database: str, echo: bool = False) -> bool: + """Check if database is present.""" + with pg_engine("postgres", echo=echo) as engine: + conn = engine.connect() + try: + row = conn.execute( + sa.DDL( + f"SELECT 1 FROM pg_database WHERE datname = '{database}'" + ) + ).first() + conn.close() + except Exception as e: + logger.exception(f"Exception {e}") + raise + return row is not None + + def create_extension( database: str, extension: str, echo: bool = False ) -> None: diff --git a/pgsync/elastichelper.py b/pgsync/elastichelper.py index 077b6721..05e1262f 100644 --- a/pgsync/elastichelper.py +++ b/pgsync/elastichelper.py @@ -5,6 +5,7 @@ import boto3 from elasticsearch import Elasticsearch, helpers, RequestsHttpConnection +from elasticsearch.exceptions import RequestError from elasticsearch_dsl import Q, Search from elasticsearch_dsl.query import Bool from requests_aws4auth import AWS4Auth @@ -227,8 +228,13 @@ def _search(self, index: str, table: str, fields: Optional[dict] = None): ] ) ) - for hit in search.scan(): - yield hit.meta.id + try: + for hit in search.scan(): + yield hit.meta.id + except RequestError as e: + logger.warning(f"RequestError: {e}") + if "is out of range for a long" not in str(e): + raise def search(self, index: str, body: dict): """ diff --git a/pgsync/exc.py b/pgsync/exc.py index 79e13c4b..11b99f69 100644 --- a/pgsync/exc.py +++ b/pgsync/exc.py @@ -77,6 +77,14 @@ def __str__(self): return repr(self.value) +class InvalidTGOPError(Exception): + def 
__init__(self, value): + self.value = value + + def __str__(self): + return repr(self.value) + + class NodeAttributeError(Exception): def __init__(self, value): self.value = value diff --git a/pgsync/helper.py b/pgsync/helper.py index b05fa15b..3f314dad 100644 --- a/pgsync/helper.py +++ b/pgsync/helper.py @@ -5,9 +5,9 @@ import sqlalchemy as sa -from .base import drop_database +from .base import database_exists, drop_database from .sync import Sync -from .utils import get_config, load_config +from .utils import config_loader, get_config logger = logging.getLogger(__name__) @@ -22,9 +22,14 @@ def teardown( validate: bool = False, ) -> None: """Teardown helper.""" - config = get_config(config) + config: str = get_config(config) + + for document in config_loader(config): + + if not database_exists(document["database"]): + logger.warning(f'Database {document["database"]} does not exist') + continue - for document in load_config(config): sync: Sync = Sync(document, validate=validate) if truncate_db: try: diff --git a/pgsync/node.py b/pgsync/node.py index 69987368..5752c9a4 100644 --- a/pgsync/node.py +++ b/pgsync/node.py @@ -25,6 +25,7 @@ RelationshipForeignKeyError, RelationshipTypeError, RelationshipVariantError, + SchemaError, TableNotInNodeError, ) @@ -164,7 +165,7 @@ def setup(self): for column_name in self.column_names: - tokens = None + tokens: Optional[list] = None if any(op in column_name for op in JSONB_OPERATORS): tokens = re.split( f"({'|'.join(JSONB_OPERATORS)})", @@ -275,7 +276,10 @@ def traverse_post_order(self) -> Generator: return self.root.traverse_post_order() def build(self, data: dict) -> Node: - + if not isinstance(data, dict): + raise SchemaError( + "Incompatible schema. 
Please run v2 schema migration" + ) table: str = data.get("table") schema: str = data.get("schema", DEFAULT_SCHEMA) key: Tuple[str, str] = (schema, table) diff --git a/pgsync/querybuilder.py b/pgsync/querybuilder.py index 0def2ed4..f90a8f53 100644 --- a/pgsync/querybuilder.py +++ b/pgsync/querybuilder.py @@ -27,21 +27,27 @@ def _build_filters( NB: assumption dictionary is an AND and list is an OR - - filters['book'] = [ - {'id': 1, 'uid': '001'}, - {'id': 2, 'uid': '002'} - ] + filters = { + 'book': [ + {'id': 1, 'uid': '001'}, + {'id': 2, 'uid': '002'}, + ], + 'city': [ + {'id': 1}, + {'id': 2}, + ], + } """ if filters is not None: if filters.get(node.table): - _filters: list = [] - for _filter in filters.get(node.table): + clause: list = [] + for values in filters.get(node.table): where: list = [] - for key, value in _filter.items(): - where.append(node.model.c[key] == value) - _filters.append(sa.and_(*where)) - return sa.or_(*_filters) + for column, value in values.items(): + where.append(node.model.c[column] == value) + # and clause is applied for composite primary keys + clause.append(sa.and_(*where)) + return sa.or_(*clause) def _json_build_object( self, columns: list, chunk_size: int = 100 @@ -58,7 +64,7 @@ def _json_build_object( i: int = 0 expression: sa.sql.elements.BinaryExpression = None while i < len(columns): - chunk = columns[i : i + chunk_size] + chunk: list = columns[i : i + chunk_size] if i == 0: expression = sa.cast( sa.func.JSON_BUILD_OBJECT(*chunk), diff --git a/pgsync/redisqueue.py b/pgsync/redisqueue.py index 9fd96b9c..a21d68d7 100644 --- a/pgsync/redisqueue.py +++ b/pgsync/redisqueue.py @@ -37,12 +37,13 @@ def qsize(self) -> int: def bulk_pop(self, chunk_size: Optional[int] = None) -> List[dict]: """Remove and return multiple items from the queue.""" chunk_size = chunk_size or REDIS_READ_CHUNK_SIZE - pipeline = self.__db.pipeline() - pipeline.lrange(self.key, 0, chunk_size - 1) - pipeline.ltrim(self.key, chunk_size, -1) - items: List = 
pipeline.execute() - logger.debug(f"bulk_pop nsize: {len(items[0])}") - return list(map(lambda value: json.loads(value), items[0])) + if self.qsize > 0: + pipeline = self.__db.pipeline() + pipeline.lrange(self.key, 0, chunk_size - 1) + pipeline.ltrim(self.key, chunk_size, -1) + items: List = pipeline.execute() + logger.debug(f"bulk_pop size: {len(items[0])}") + return list(map(lambda value: json.loads(value), items[0])) def bulk_push(self, items: List) -> None: """Push multiple items onto the queue.""" diff --git a/pgsync/settings.py b/pgsync/settings.py index d78a6df5..1768c005 100644 --- a/pgsync/settings.py +++ b/pgsync/settings.py @@ -28,6 +28,7 @@ QUERY_LITERAL_BINDS = env.bool("QUERY_LITERAL_BINDS", default=False) # db query chunk size (how many records to fetch at a time) QUERY_CHUNK_SIZE = env.int("QUERY_CHUNK_SIZE", default=10000) +FILTER_CHUNK_SIZE = env.int("FILTER_CHUNK_SIZE", default=5000) # replication slot cleanup interval (in secs) REPLICATION_SLOT_CLEANUP_INTERVAL = env.float( "REPLICATION_SLOT_CLEANUP_INTERVAL", @@ -37,6 +38,8 @@ SCHEMA = env.str("SCHEMA", default=None) USE_ASYNC = env.bool("USE_ASYNC", default=False) STREAM_RESULTS = env.bool("STREAM_RESULTS", default=True) +# db polling interval +POLL_INTERVAL = env.float("POLL_INTERVAL", default=0.1) # Elasticsearch: ELASTICSEARCH_API_KEY = env.str("ELASTICSEARCH_API_KEY", default=None) diff --git a/pgsync/singleton.py b/pgsync/singleton.py new file mode 100644 index 00000000..c455ec56 --- /dev/null +++ b/pgsync/singleton.py @@ -0,0 +1,20 @@ +"""PGSync Singleton.""" + +from typing import Tuple + + +class Singleton(type): + + _instances: dict = {} + + def __call__(cls, *args, **kwargs): + if not args: + return super(Singleton, cls).__call__(*args, **kwargs) + database: str = args[0]["database"] + index: str = args[0].get("index", database) + key: Tuple[str, str] = (database, index) + if key not in cls._instances: + cls._instances[key] = super(Singleton, cls).__call__( + *args, **kwargs + ) + 
return cls._instances[key] diff --git a/pgsync/sync.py b/pgsync/sync.py index 791630c2..6b82d9a3 100644 --- a/pgsync/sync.py +++ b/pgsync/sync.py @@ -19,7 +19,7 @@ from psycopg2 import OperationalError from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT -from . import __version__ +from . import __version__, settings from .base import Base, Payload from .constants import ( DELETE, @@ -34,6 +34,7 @@ from .exc import ( ForeignKeyError, InvalidSchemaError, + InvalidTGOPError, PrimaryKeyNotFoundError, RDSError, SchemaError, @@ -42,24 +43,15 @@ from .plugin import Plugins from .querybuilder import QueryBuilder from .redisqueue import RedisQueue -from .settings import ( - CHECKPOINT_PATH, - JOIN_QUERIES, - LOG_INTERVAL, - LOGICAL_SLOT_CHUNK_SIZE, - NTHREADS_POLLDB, - POLL_TIMEOUT, - REDIS_POLL_INTERVAL, - REDIS_WRITE_CHUNK_SIZE, - REPLICATION_SLOT_CLEANUP_INTERVAL, - USE_ASYNC, -) +from .singleton import Singleton from .transform import Transform from .utils import ( + chunks, compiled_query, + config_loader, exception, get_config, - load_config, + MutuallyExclusiveOption, show_settings, threaded, Timer, @@ -68,7 +60,7 @@ logger = logging.getLogger(__name__) -class Sync(Base): +class Sync(Base, metaclass=Singleton): """Main application class for Sync.""" def __init__( @@ -98,13 +90,16 @@ def __init__( self._plugins: Plugins = None self._truncate: bool = False self._checkpoint_file: str = os.path.join( - CHECKPOINT_PATH, f".{self.__name}" + settings.CHECKPOINT_PATH, f".{self.__name}" ) self.redis: RedisQueue = RedisQueue(self.__name) self.tree: Tree = Tree(self.models) + self.tree.build(self.nodes) if validate: self.validate(repl_slots=repl_slots) self.create_setting() + if self.plugins: + self._plugins: Plugins = Plugins("plugins", self.plugins) self.query_builder: QueryBuilder = QueryBuilder(verbose=verbose) self.count: dict = dict(xlog=0, db=0, redis=0) @@ -119,9 +114,6 @@ def validate(self, repl_slots: bool = True) -> None: self.connect() - if self.plugins: - 
self._plugins: Plugins = Plugins("plugins", self.plugins) - max_replication_slots: Optional[str] = self.pg_settings( "max_replication_slots" ) @@ -162,19 +154,18 @@ def validate(self, repl_slots: bool = True) -> None: ) # ensure the checkpoint dirpath is valid - if not os.path.exists(CHECKPOINT_PATH): + if not os.path.exists(settings.CHECKPOINT_PATH): raise RuntimeError( - f'Ensure the checkpoint directory exists "{CHECKPOINT_PATH}" ' - f"and is readable." + f"Ensure the checkpoint directory exists " + f'"{settings.CHECKPOINT_PATH}" and is readable.' ) - if not os.access(CHECKPOINT_PATH, os.W_OK | os.R_OK): + if not os.access(settings.CHECKPOINT_PATH, os.W_OK | os.R_OK): raise RuntimeError( - f'Ensure the checkpoint directory "{CHECKPOINT_PATH}" is ' - f"read/writable" + f'Ensure the checkpoint directory "{settings.CHECKPOINT_PATH}"' + f" is read/writable" ) - self.tree.build(self.nodes) self.tree.display() for node in self.tree.traverse_breadth_first(): @@ -256,7 +247,7 @@ def create_setting(self) -> None: def setup(self) -> None: """Create the database triggers and replication slot.""" - join_queries: bool = JOIN_QUERIES + join_queries: bool = settings.JOIN_QUERIES self.teardown(drop_view=False) for schema in self.schemas: @@ -297,7 +288,7 @@ def setup(self) -> None: def teardown(self, drop_view: bool = True) -> None: """Drop the database triggers and replication slot.""" - join_queries: bool = JOIN_QUERIES + join_queries: bool = settings.JOIN_QUERIES try: os.unlink(self._checkpoint_file) @@ -368,7 +359,7 @@ def logical_slot_changes( # by limiting to a smaller batch size. 
offset: int = 0 total: int = 0 - limit: int = LOGICAL_SLOT_CHUNK_SIZE + limit: int = settings.LOGICAL_SLOT_CHUNK_SIZE count: int = self.logical_slot_count_changes( self.__name, txmin=txmin, @@ -471,7 +462,6 @@ def _insert_op( raise # set the parent as the new entity that has changed - filters[node.parent.table] = [] foreign_keys = self.query_builder._get_foreign_keys( node.parent, node, @@ -492,7 +482,6 @@ def _insert_op( # handle case where we insert into a through table # set the parent as the new entity that has changed - filters[node.parent.table] = [] foreign_keys = self.query_builder.get_foreign_keys( node.parent, node, @@ -511,7 +500,6 @@ def _update_op( node: Node, filters: dict, payloads: List[dict], - extra: dict, ) -> dict: if node.is_root: @@ -580,13 +568,6 @@ def _update_op( for key, value in primary_fields.items(): fields[key].append(value) - if None in payload.new.values(): - extra["table"] = node.table - extra["column"] = key - - if None in payload.old.values(): - for key, value in primary_fields.items(): - fields[key].append(0) for doc_id in self.es._search(self.index, node.table, fields): where = {} @@ -665,9 +646,11 @@ def _delete_op( docs.append(doc) if docs: raise_on_exception: Optional[bool] = ( - False if USE_ASYNC else None + False if settings.USE_ASYNC else None + ) + raise_on_error: Optional[bool] = ( + False if settings.USE_ASYNC else None ) - raise_on_error: Optional[bool] = False if USE_ASYNC else None self.es.bulk( self.index, docs, @@ -771,7 +754,7 @@ def _payloads(self, payloads: List[Payload]) -> None: payload: Payload = payloads[0] if payload.tg_op not in TG_OP: logger.exception(f"Unknown tg_op {payload.tg_op}") - raise + raise InvalidTGOPError(f"Unknown tg_op {payload.tg_op}") # we might receive an event triggered for a table # that is not in the tree node. 
@@ -798,8 +781,12 @@ def _payloads(self, payloads: List[Payload]) -> None: logger.debug(f"tg_op: {payload.tg_op} table: {node.name}") - filters: dict = {node.table: [], self.tree.root.table: []} - extra: dict = {} + filters: dict = { + node.table: [], + self.tree.root.table: [], + } + if not node.is_root: + filters[node.parent.table] = [] if payload.tg_op == INSERT: @@ -810,12 +797,10 @@ def _payloads(self, payloads: List[Payload]) -> None: ) if payload.tg_op == UPDATE: - filters = self._update_op( node, filters, payloads, - extra, ) if payload.tg_op == DELETE: @@ -834,14 +819,62 @@ def _payloads(self, payloads: List[Payload]) -> None: # otherwise we would end up performing a full query # and sync the entire db! if any(filters.values()): - yield from self.sync(filters=filters, extra=extra) + """ + Filters are applied when an insert, update or delete operation + occurs. For a large table update, this normally results + in a large sql query with multiple OR clauses + + Filters is a dict of tables where each key is a list of id's + { + 'city': [ + {'id': '1'}, + {'id': '4'}, + {'id': '5'}, + ], + 'book': [ + {'id': '1'}, + {'id': '2'}, + {'id': '7'}, + ... 
+ ] + } + """ + for l1 in chunks( + filters.get(self.tree.root.table), settings.FILTER_CHUNK_SIZE + ): + if filters.get(node.table): + for l2 in chunks( + filters.get(node.table), settings.FILTER_CHUNK_SIZE + ): + if not node.is_root and filters.get(node.parent.table): + for l3 in chunks( + filters.get(node.parent.table), + settings.FILTER_CHUNK_SIZE, + ): + yield from self.sync( + filters={ + self.tree.root.table: l1, + node.table: l2, + node.parent.table: l3, + }, + ) + else: + yield from self.sync( + filters={ + self.tree.root.table: l1, + node.table: l2, + }, + ) + else: + yield from self.sync( + filters={self.tree.root.table: l1}, + ) def sync( self, filters: Optional[dict] = None, txmin: Optional[int] = None, txmax: Optional[int] = None, - extra: Optional[dict] = None, ctid: Optional[dict] = None, ) -> Generator: self.query_builder.isouter = True @@ -883,12 +916,6 @@ def sync( row: dict = Transform.transform(row, self.nodes) row[META] = Transform.get_primary_keys(keys) - if extra: - if extra["table"] not in row[META]: - row[META][extra["table"]] = {} - if extra["column"] not in row[META][extra["table"]]: - row[META][extra["table"]][extra["column"]] = [] - row[META][extra["table"]][extra["column"]].append(0) if self.verbose: print(f"{(i+1)})") @@ -935,7 +962,7 @@ def checkpoint(self, value: Optional[str] = None) -> None: self._checkpoint: int = value def _poll_redis(self) -> None: - payloads: dict = self.redis.bulk_pop() + payloads: list = self.redis.bulk_pop() if payloads: logger.debug(f"poll_redis: {payloads}") self.count["redis"] += len(payloads) @@ -943,7 +970,7 @@ def _poll_redis(self) -> None: self.on_publish( list(map(lambda payload: Payload(**payload), payloads)) ) - time.sleep(REDIS_POLL_INTERVAL) + time.sleep(settings.REDIS_POLL_INTERVAL) @threaded @exception @@ -953,7 +980,7 @@ def poll_redis(self) -> None: self._poll_redis() async def _async_poll_redis(self) -> None: - payloads: dict = self.redis.bulk_pop() + payloads: list = self.redis.bulk_pop() 
if payloads: logger.debug(f"poll_redis: {payloads}") self.count["redis"] += len(payloads) @@ -961,7 +988,7 @@ async def _async_poll_redis(self) -> None: await self.async_on_publish( list(map(lambda payload: Payload(**payload), payloads)) ) - await asyncio.sleep(REDIS_POLL_INTERVAL) + await asyncio.sleep(settings.REDIS_POLL_INTERVAL) @exception async def async_poll_redis(self) -> None: @@ -987,8 +1014,12 @@ def poll_db(self) -> None: payloads: list = [] while True: - # NB: consider reducing POLL_TIMEOUT to increase throughout - if select.select([conn], [], [], POLL_TIMEOUT) == ([], [], []): + # NB: consider reducing POLL_TIMEOUT to increase throughput + if select.select([conn], [], [], settings.POLL_TIMEOUT) == ( + [], + [], + [], + ): # Catch any hanging items from the last poll if payloads: self.redis.bulk_push(payloads) @@ -1002,7 +1033,7 @@ def poll_db(self) -> None: os._exit(-1) while conn.notifies: - if len(payloads) >= REDIS_WRITE_CHUNK_SIZE: + if len(payloads) >= settings.REDIS_WRITE_CHUNK_SIZE: self.redis.bulk_push(payloads) payloads = [] notification: AnyStr = conn.notifies.pop(0) @@ -1122,13 +1153,13 @@ def truncate_slots(self) -> None: """Truncate the logical replication slot.""" while True: self._truncate_slots() - time.sleep(REPLICATION_SLOT_CLEANUP_INTERVAL) + time.sleep(settings.REPLICATION_SLOT_CLEANUP_INTERVAL) @exception async def async_truncate_slots(self) -> None: while True: self._truncate_slots() - await asyncio.sleep(REPLICATION_SLOT_CLEANUP_INTERVAL) + await asyncio.sleep(settings.REPLICATION_SLOT_CLEANUP_INTERVAL) def _truncate_slots(self) -> None: if self._truncate: @@ -1140,13 +1171,13 @@ def _truncate_slots(self) -> None: def status(self) -> None: while True: self._status(label="Sync") - time.sleep(LOG_INTERVAL) + time.sleep(settings.LOG_INTERVAL) @exception async def async_status(self) -> None: while True: self._status(label="Async") - await asyncio.sleep(LOG_INTERVAL) + await asyncio.sleep(settings.LOG_INTERVAL) def _status(self, 
label: str) -> None: sys.stdout.write( @@ -1169,7 +1200,7 @@ def receive(self, nthreads_polldb: int) -> None: 2. Pull everything so far and also replay replication logs. 3. Consume all changes from Redis. """ - if USE_ASYNC: + if settings.USE_ASYNC: self._conn = self.engine.connect().connection self._conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT) cursor = self.conn.cursor() @@ -1207,7 +1238,21 @@ def receive(self, nthreads_polldb: int) -> None: help="Schema config", type=click.Path(exists=True), ) -@click.option("--daemon", "-d", is_flag=True, help="Run as a daemon") +@click.option( + "--daemon", + "-d", + is_flag=True, + help="Run as a daemon (Incompatible with --polling)", + cls=MutuallyExclusiveOption, + mutually_exclusive=["polling"], +) +@click.option( + "--polling", + is_flag=True, + help="Polling mode (Incompatible with -d)", + cls=MutuallyExclusiveOption, + mutually_exclusive=["daemon"], +) @click.option("--host", "-h", help="PG_HOST override") @click.option("--password", is_flag=True, help="Prompt for database password") @click.option("--port", "-p", help="PG_PORT override", type=int) @@ -1251,13 +1296,15 @@ def receive(self, nthreads_polldb: int) -> None: is_flag=True, default=False, help="Analyse database", + cls=MutuallyExclusiveOption, + mutually_exclusive=["daemon", "polling"], ) @click.option( "--nthreads_polldb", "-n", help="Number of threads to spawn for poll db", type=int, - default=NTHREADS_POLLDB, + default=settings.NTHREADS_POLLDB, ) def main( config, @@ -1272,6 +1319,7 @@ def main( version, analyze, nthreads_polldb, + polling, ): """Main application syncer.""" if version: @@ -1301,17 +1349,27 @@ def main( with Timer(): - for document in load_config(config): - sync: Sync = Sync(document, verbose=verbose, **kwargs) + if analyze: - if analyze: + for document in config_loader(config): + sync: Sync = Sync(document, verbose=verbose, **kwargs) sync.analyze() - continue - sync.pull() + elif polling: + + while True: + for document in 
config_loader(config): + sync: Sync = Sync(document, verbose=verbose, **kwargs) + sync.pull() + time.sleep(settings.POLL_INTERVAL) + + else: - if daemon: - sync.receive(nthreads_polldb) + for document in config_loader(config): + sync: Sync = Sync(document, verbose=verbose, **kwargs) + sync.pull() + if daemon: + sync.receive(nthreads_polldb) if __name__ == "__main__": diff --git a/pgsync/utils.py b/pgsync/utils.py index dc609dee..d5814204 100644 --- a/pgsync/utils.py +++ b/pgsync/utils.py @@ -7,9 +7,10 @@ from datetime import timedelta from string import Template from time import time -from typing import Callable, Generator, Optional +from typing import Callable, Generator, Optional, Set from urllib.parse import ParseResult, urlparse +import click import sqlalchemy as sa import sqlparse @@ -19,10 +20,16 @@ logger = logging.getLogger(__name__) -HIGHLIGHT_START = "\033[4m" +HIGHLIGHT_BEGIN = "\033[4m" HIGHLIGHT_END = "\033[0m:" +def chunks(value: list, size: int) -> list: + """Yield successive n-sized chunks from l""" + for i in range(0, len(value), size): + yield value[i : i + size] + + def timeit(func: Callable): def timed(*args, **kwargs): since: float = time() @@ -95,20 +102,20 @@ def get_redacted_url(result: ParseResult) -> ParseResult: def show_settings(schema: Optional[str] = None) -> None: """Show settings.""" - logger.info(f"{HIGHLIGHT_START}Settings{HIGHLIGHT_END}") + logger.info(f"{HIGHLIGHT_BEGIN}Settings{HIGHLIGHT_END}") logger.info(f'{"Schema":<10s}: {schema or SCHEMA}') logger.info("-" * 65) - logger.info(f"{HIGHLIGHT_START}Checkpoint{HIGHLIGHT_END}") + logger.info(f"{HIGHLIGHT_BEGIN}Checkpoint{HIGHLIGHT_END}") logger.info(f"Path: {CHECKPOINT_PATH}") - logger.info(f"{HIGHLIGHT_START}Postgres{HIGHLIGHT_END}") + logger.info(f"{HIGHLIGHT_BEGIN}Postgres{HIGHLIGHT_END}") result: ParseResult = get_redacted_url( urlparse(get_postgres_url("postgres")) ) logger.info(f"URL: {result.geturl()}") result = get_redacted_url(urlparse(get_elasticsearch_url())) - 
logger.info(f"{HIGHLIGHT_START}Elasticsearch{HIGHLIGHT_END}") + logger.info(f"{HIGHLIGHT_BEGIN}Elasticsearch{HIGHLIGHT_END}") logger.info(f"URL: {result.geturl()}") - logger.info(f"{HIGHLIGHT_START}Redis{HIGHLIGHT_END}") + logger.info(f"{HIGHLIGHT_BEGIN}Redis{HIGHLIGHT_END}") result = get_redacted_url(urlparse(get_redis_url())) logger.info(f"URL: {result.geturl()}") logger.info("-" * 65) @@ -128,7 +135,7 @@ def get_config(config: Optional[str] = None) -> str: return config -def load_config(config: str) -> Generator: +def config_loader(config: str) -> Generator: with open(config, "r") as documents: for document in json.load(documents): for key, value in document.items(): @@ -160,3 +167,28 @@ def compiled_query( sys.stdout.write(f"{query}\n") sys.stdout.write("-" * 79) sys.stdout.write("\n") + + +class MutuallyExclusiveOption(click.Option): + def __init__(self, *args, **kwargs): + self.mutually_exclusive: Set = set( + kwargs.pop("mutually_exclusive", []) + ) + help: str = kwargs.get("help", "") + if self.mutually_exclusive: + kwargs["help"] = help + ( + f" NOTE: This argument is mutually exclusive with " + f" arguments: [{', '.join(self.mutually_exclusive)}]." + ) + super(MutuallyExclusiveOption, self).__init__(*args, **kwargs) + + def handle_parse_result(self, ctx, opts, args): + if self.mutually_exclusive.intersection(opts) and self.name in opts: + raise click.UsageError( + f"Illegal usage: `{self.name}` is mutually exclusive with " + f"arguments `{', '.join(self.mutually_exclusive)}`." 
+ ) + + return super(MutuallyExclusiveOption, self).handle_parse_result( + ctx, opts, args + ) diff --git a/pyproject.toml b/pyproject.toml index 4c31d15d..65e68881 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,3 @@ [tool.black] line-length = 79 -target-version = ['py37', 'py38', 'py39', 'py310'] \ No newline at end of file +target-version = ['py37', 'py38', 'py39', 'py310', 'py311'] \ No newline at end of file diff --git a/requirements/base.in b/requirements/base.in index 136f2603..a25ded9c 100644 --- a/requirements/base.in +++ b/requirements/base.in @@ -12,3 +12,7 @@ redis requests-aws4auth sqlalchemy sqlparse + +# pin these libs because latest flake8 does not allow newer versions of importlib-metadata https://github.com/PyCQA/flake8/issues/1522 +importlib-metadata==4.2.0 +virtualenv==20.16.2 \ No newline at end of file diff --git a/requirements/dev.txt b/requirements/dev.txt index cf716d28..2fa67b98 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -1,6 +1,6 @@ # -# This file is autogenerated by pip-compile with python 3.9 -# To update, run: +# This file is autogenerated by pip-compile with Python 3.9 +# by the following command: # # pip-compile --output-file=requirements/dev.txt requirements/dev.in # @@ -10,13 +10,13 @@ attrs==22.1.0 # via pytest black==22.10.0 # via -r requirements/base.in -boto3==1.24.94 +boto3==1.26.22 # via -r requirements/base.in -botocore==1.27.94 +botocore==1.29.22 # via # boto3 # s3transfer -build==0.8.0 +build==0.9.0 # via pip-tools bump2version==1.0.1 # via bumpversion @@ -39,8 +39,6 @@ coverage[toml]==6.5.0 # via # -r requirements/dev.in # pytest-cov -deprecated==1.2.13 - # via redis distlib==0.3.6 # via virtualenv elasticsearch==7.13.4 @@ -51,12 +49,15 @@ elasticsearch-dsl==7.4.0 # via -r requirements/base.in environs==9.5.0 # via -r requirements/base.in -faker==15.1.1 +exceptiongroup==1.0.4 + # via pytest +faker==15.3.4 # via -r requirements/base.in filelock==3.8.0 # via virtualenv flake8==5.0.4 # via + 
# -r requirements/test.in # flake8-debugger # flake8-docstrings # flake8-isort @@ -65,7 +66,7 @@ flake8-debugger==4.1.2 # via -r requirements/test.in flake8-docstrings==1.6.0 # via -r requirements/test.in -flake8-isort==5.0.0 +flake8-isort==5.0.3 # via -r requirements/test.in flake8-print==5.0.0 # via -r requirements/test.in @@ -73,12 +74,14 @@ flake8-todo==0.7 # via -r requirements/test.in freezegun==1.2.2 # via -r requirements/test.in -greenlet==1.1.3.post0 +greenlet==2.0.1 # via sqlalchemy -identify==2.5.6 +identify==2.5.9 # via pre-commit idna==3.4 # via requests +importlib-metadata==4.2.0 + # via -r requirements/base.in iniconfig==1.1.1 # via pytest isort==5.10.1 @@ -87,7 +90,7 @@ jmespath==1.0.1 # via # boto3 # botocore -marshmallow==3.18.0 +marshmallow==3.19.0 # via environs mccabe==0.7.0 # via flake8 @@ -104,13 +107,13 @@ packaging==21.3 # pytest # pytest-sugar # redis -pathspec==0.10.1 +pathspec==0.10.2 # via black pep517==0.13.0 # via build -pip-tools==6.9.0 +pip-tools==6.11.0 # via -r requirements/dev.in -platformdirs==2.5.2 +platformdirs==2.5.4 # via # black # virtualenv @@ -118,10 +121,8 @@ pluggy==1.0.0 # via pytest pre-commit==2.20.0 # via -r requirements/dev.in -psycopg2-binary==2.9.4 +psycopg2-binary==2.9.5 # via -r requirements/base.in -py==1.11.0 - # via pytest pycodestyle==2.9.1 # via # flake8 @@ -134,7 +135,7 @@ pyflakes==2.5.0 # via flake8 pyparsing==3.0.9 # via packaging -pytest==7.1.3 +pytest==7.2.0 # via # -r requirements/test.in # pytest-cov @@ -146,7 +147,7 @@ pytest-mock==3.10.0 # via -r requirements/test.in pytest-runner==6.0.0 # via -r requirements/test.in -pytest-sugar==0.9.5 +pytest-sugar==0.9.6 # via -r requirements/test.in python-dateutil==2.8.2 # via @@ -158,7 +159,7 @@ python-dotenv==0.21.0 # via environs pyyaml==6.0 # via pre-commit -redis==4.3.4 +redis==4.3.5 # via -r requirements/base.in requests==2.28.1 # via requests-aws4auth @@ -173,11 +174,11 @@ six==1.16.0 # requests-aws4auth snowballstemmer==2.2.0 # via pydocstyle 
-sqlalchemy==1.4.42 +sqlalchemy==1.4.44 # via -r requirements/base.in sqlparse==0.4.3 # via -r requirements/base.in -termcolor==2.0.1 +termcolor==2.1.1 # via pytest-sugar toml==0.10.2 # via pre-commit @@ -189,17 +190,19 @@ tomli==2.0.1 # pytest typing-extensions==4.4.0 # via black -urllib3==1.26.12 +urllib3==1.26.13 # via # botocore # elasticsearch # requests -virtualenv==20.16.5 - # via pre-commit -wheel==0.37.1 +virtualenv==20.16.2 + # via + # -r requirements/base.in + # pre-commit +wheel==0.38.4 # via pip-tools -wrapt==1.14.1 - # via deprecated +zipp==3.11.0 + # via importlib-metadata # The following packages are considered to be unsafe in a requirements file: # pip diff --git a/requirements/prod.txt b/requirements/prod.txt index 5266947a..02ce09ba 100644 --- a/requirements/prod.txt +++ b/requirements/prod.txt @@ -1,6 +1,6 @@ # -# This file is autogenerated by pip-compile with python 3.9 -# To update, run: +# This file is autogenerated by pip-compile with Python 3.9 +# by the following command: # # pip-compile --output-file=requirements/prod.txt requirements/prod.in # @@ -8,9 +8,9 @@ async-timeout==4.0.2 # via redis black==22.10.0 # via -r requirements/base.in -boto3==1.24.94 +boto3==1.26.22 # via -r requirements/base.in -botocore==1.27.94 +botocore==1.29.22 # via # boto3 # s3transfer @@ -28,8 +28,8 @@ click==8.1.3 # via # -r requirements/base.in # black -deprecated==1.2.13 - # via redis +distlib==0.3.6 + # via virtualenv elasticsearch==7.13.4 # via # -r requirements/base.in @@ -38,31 +38,37 @@ elasticsearch-dsl==7.4.0 # via -r requirements/base.in environs==9.5.0 # via -r requirements/base.in -faker==15.1.1 +faker==15.3.4 # via -r requirements/base.in -greenlet==1.1.3.post0 +filelock==3.8.0 + # via virtualenv +greenlet==2.0.1 # via sqlalchemy idna==3.4 # via requests +importlib-metadata==4.2.0 + # via -r requirements/base.in jmespath==1.0.1 # via # boto3 # botocore -marshmallow==3.18.0 +marshmallow==3.19.0 # via environs mypy-extensions==0.4.3 # via black 
-newrelic==8.2.1 +newrelic==8.4.0 # via -r requirements/prod.in packaging==21.3 # via # marshmallow # redis -pathspec==0.10.1 - # via black -platformdirs==2.5.2 +pathspec==0.10.2 # via black -psycopg2-binary==2.9.4 +platformdirs==2.5.4 + # via + # black + # virtualenv +psycopg2-binary==2.9.5 # via -r requirements/base.in pyparsing==3.0.9 # via packaging @@ -73,7 +79,7 @@ python-dateutil==2.8.2 # faker python-dotenv==0.21.0 # via environs -redis==4.3.4 +redis==4.3.5 # via -r requirements/base.in requests==2.28.1 # via requests-aws4auth @@ -86,7 +92,7 @@ six==1.16.0 # elasticsearch-dsl # python-dateutil # requests-aws4auth -sqlalchemy==1.4.42 +sqlalchemy==1.4.44 # via -r requirements/base.in sqlparse==0.4.3 # via -r requirements/base.in @@ -94,10 +100,12 @@ tomli==2.0.1 # via black typing-extensions==4.4.0 # via black -urllib3==1.26.12 +urllib3==1.26.13 # via # botocore # elasticsearch # requests -wrapt==1.14.1 - # via deprecated +virtualenv==20.16.2 + # via -r requirements/base.in +zipp==3.11.0 + # via importlib-metadata diff --git a/requirements/test.in b/requirements/test.in index ac2aa235..f1f707d3 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -1,5 +1,6 @@ -r base.in +flake8==5.0.4 flake8_docstrings flake8-debugger flake8-print diff --git a/requirements/test.txt b/requirements/test.txt index 95a72b6f..817e7dad 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -1,6 +1,6 @@ # -# This file is autogenerated by pip-compile with python 3.9 -# To update, run: +# This file is autogenerated by pip-compile with Python 3.9 +# by the following command: # # pip-compile --output-file=requirements/test.txt requirements/test.in # @@ -10,9 +10,9 @@ attrs==22.1.0 # via pytest black==22.10.0 # via -r requirements/base.in -boto3==1.24.94 +boto3==1.26.22 # via -r requirements/base.in -botocore==1.27.94 +botocore==1.29.22 # via # boto3 # s3transfer @@ -32,8 +32,8 @@ click==8.1.3 # black coverage[toml]==6.5.0 # via pytest-cov -deprecated==1.2.13 - # 
via redis +distlib==0.3.6 + # via virtualenv elasticsearch==7.13.4 # via # -r requirements/base.in @@ -42,10 +42,15 @@ elasticsearch-dsl==7.4.0 # via -r requirements/base.in environs==9.5.0 # via -r requirements/base.in -faker==15.1.1 +exceptiongroup==1.0.4 + # via pytest +faker==15.3.4 # via -r requirements/base.in +filelock==3.8.0 + # via virtualenv flake8==5.0.4 # via + # -r requirements/test.in # flake8-debugger # flake8-docstrings # flake8-isort @@ -54,7 +59,7 @@ flake8-debugger==4.1.2 # via -r requirements/test.in flake8-docstrings==1.6.0 # via -r requirements/test.in -flake8-isort==5.0.0 +flake8-isort==5.0.3 # via -r requirements/test.in flake8-print==5.0.0 # via -r requirements/test.in @@ -62,10 +67,12 @@ flake8-todo==0.7 # via -r requirements/test.in freezegun==1.2.2 # via -r requirements/test.in -greenlet==1.1.3.post0 +greenlet==2.0.1 # via sqlalchemy idna==3.4 # via requests +importlib-metadata==4.2.0 + # via -r requirements/base.in iniconfig==1.1.1 # via pytest isort==5.10.1 @@ -74,7 +81,7 @@ jmespath==1.0.1 # via # boto3 # botocore -marshmallow==3.18.0 +marshmallow==3.19.0 # via environs mccabe==0.7.0 # via flake8 @@ -88,16 +95,16 @@ packaging==21.3 # pytest # pytest-sugar # redis -pathspec==0.10.1 - # via black -platformdirs==2.5.2 +pathspec==0.10.2 # via black +platformdirs==2.5.4 + # via + # black + # virtualenv pluggy==1.0.0 # via pytest -psycopg2-binary==2.9.4 +psycopg2-binary==2.9.5 # via -r requirements/base.in -py==1.11.0 - # via pytest pycodestyle==2.9.1 # via # flake8 @@ -110,7 +117,7 @@ pyflakes==2.5.0 # via flake8 pyparsing==3.0.9 # via packaging -pytest==7.1.3 +pytest==7.2.0 # via # -r requirements/test.in # pytest-cov @@ -122,7 +129,7 @@ pytest-mock==3.10.0 # via -r requirements/test.in pytest-runner==6.0.0 # via -r requirements/test.in -pytest-sugar==0.9.5 +pytest-sugar==0.9.6 # via -r requirements/test.in python-dateutil==2.8.2 # via @@ -132,7 +139,7 @@ python-dateutil==2.8.2 # freezegun python-dotenv==0.21.0 # via environs 
-redis==4.3.4 +redis==4.3.5 # via -r requirements/base.in requests==2.28.1 # via requests-aws4auth @@ -147,11 +154,11 @@ six==1.16.0 # requests-aws4auth snowballstemmer==2.2.0 # via pydocstyle -sqlalchemy==1.4.42 +sqlalchemy==1.4.44 # via -r requirements/base.in sqlparse==0.4.3 # via -r requirements/base.in -termcolor==2.0.1 +termcolor==2.1.1 # via pytest-sugar tomli==2.0.1 # via @@ -160,10 +167,12 @@ tomli==2.0.1 # pytest typing-extensions==4.4.0 # via black -urllib3==1.26.12 +urllib3==1.26.13 # via # botocore # elasticsearch # requests -wrapt==1.14.1 - # via deprecated +virtualenv==20.16.2 + # via -r requirements/base.in +zipp==3.11.0 + # via importlib-metadata diff --git a/scripts/del_redis.sh b/scripts/del_redis.sh new file mode 100755 index 00000000..b670f86b --- /dev/null +++ b/scripts/del_redis.sh @@ -0,0 +1,5 @@ +#!/bin/sh + +for key in `echo 'KEYS user*' | redis-cli --scan --pattern '*' | awk '{print $1}'` + do echo DEL $key +done | redis-cli diff --git a/setup.py b/setup.py index 6a288e4e..b52d2232 100644 --- a/setup.py +++ b/setup.py @@ -41,6 +41,7 @@ def get_version() -> str: "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", "License :: OSI Approved :: GNU Lesser General Public License v3 (LGPLv3)", diff --git a/tests/conftest.py b/tests/conftest.py index 2dbc6d7d..09d100b5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,6 +10,7 @@ from pgsync.base import Base, create_database, drop_database from pgsync.constants import DEFAULT_SCHEMA +from pgsync.singleton import Singleton from pgsync.sync import Sync from pgsync.urls import get_postgres_url @@ -55,6 +56,7 @@ def sync(): _sync = Sync( { "index": "testdb", + "database": "testdb", "nodes": {"table": "book"}, } ) @@ -66,6 +68,7 @@ def sync(): 
_sync.engine.connect().close() _sync.engine.dispose() _sync.session.close() + Singleton._instances = {} def pytest_addoption(parser): diff --git a/tests/helpers/__init__.py b/tests/helpers/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/test_base.py b/tests/test_base.py index 6ff23c74..5da4cf36 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -110,7 +110,24 @@ def test_tables(self, connection): def test_indices(self, connection): pg_base = Base(connection.engine.url.database) - assert pg_base.indices("book", "public") == [] + assert pg_base.indices("contact_item", "public") == [ + { + "name": "contact_item_contact_id_key", + "unique": True, + "column_names": ["contact_id"], + "include_columns": [], + "duplicates_constraint": "contact_item_contact_id_key", + "dialect_options": {"postgresql_include": []}, + }, + { + "name": "contact_item_name_key", + "unique": True, + "column_names": ["name"], + "include_columns": [], + "duplicates_constraint": "contact_item_name_key", + "dialect_options": {"postgresql_include": []}, + }, + ] @patch("pgsync.base.logger") @patch("pgsync.sync.Base.execute") diff --git a/tests/test_helper.py b/tests/test_helper.py index 25059217..6a4c34b8 100644 --- a/tests/test_helper.py +++ b/tests/test_helper.py @@ -16,12 +16,13 @@ class TestHelper(object): def test_teardown_with_drop_db(self, mock_sync, mock_config, mock_logger): mock_config.return_value = "tests/fixtures/schema.json" mock_sync.truncate_schemas.return_value = None - with patch("pgsync.helper.drop_database") as mock_db: - helper.teardown(drop_db=True, config="fixtures/schema.json") - assert mock_db.call_args_list == [ - call(ANY), - call(ANY), - ] + with patch("pgsync.helper.database_exists", return_value=True): + with patch("pgsync.helper.drop_database") as mock_db: + helper.teardown(drop_db=True, config="fixtures/schema.json") + assert mock_db.call_args_list == [ + call(ANY), + call(ANY), + ] mock_logger.warning.assert_not_called() @@ 
-30,10 +31,15 @@ def test_teardown_with_drop_db(self, mock_sync, mock_config, mock_logger): @patch("pgsync.helper.get_config") def test_teardown_without_drop_db(self, mock_config, mock_logger, mock_es): mock_config.return_value = "tests/fixtures/schema.json" - with patch("pgsync.sync.Sync") as mock_sync: - mock_sync.truncate_schemas.side_effect = sa.exc.OperationalError - helper.teardown(drop_db=False, config="fixtures/schema.json") - assert mock_logger.warning.call_args_list == [ - call(ANY), - call(ANY), - ] + + with patch("pgsync.node.Tree.build", return_value=None): + with patch("pgsync.sync.Sync") as mock_sync: + mock_sync.tree.build.return_value = None + mock_sync.truncate_schemas.side_effect = ( + sa.exc.OperationalError + ) + helper.teardown(drop_db=False, config="fixtures/schema.json") + assert mock_logger.warning.call_args_list == [ + call(ANY), + call(ANY), + ] diff --git a/tests/test_redisqueue.py b/tests/test_redisqueue.py index 2d4ec7e3..dd89cae3 100644 --- a/tests/test_redisqueue.py +++ b/tests/test_redisqueue.py @@ -69,11 +69,11 @@ def test_bulk_pop(self, mock_logger): queue.delete() queue.bulk_push([1, 2]) items = queue.bulk_pop() - mock_logger.debug.assert_called_once_with("bulk_pop nsize: 2") + mock_logger.debug.assert_called_once_with("bulk_pop size: 2") assert items == [1, 2] queue.bulk_push([3, 4, 5]) items = queue.bulk_pop() - mock_logger.debug.assert_any_call("bulk_pop nsize: 3") + mock_logger.debug.assert_any_call("bulk_pop size: 3") assert items == [3, 4, 5] queue.delete() diff --git a/tests/test_sync.py b/tests/test_sync.py index a91cf14f..d0f5c914 100644 --- a/tests/test_sync.py +++ b/tests/test_sync.py @@ -1,4 +1,5 @@ """Sync tests.""" +import importlib import os from collections import namedtuple from typing import List @@ -7,10 +8,17 @@ from mock import ANY, call, patch from pgsync.base import Base, Payload -from pgsync.exc import PrimaryKeyNotFoundError, RDSError, SchemaError +from pgsync.exc import ( + InvalidTGOPError, + 
PrimaryKeyNotFoundError, + RDSError, + SchemaError, +) from pgsync.node import Node -from pgsync.settings import LOGICAL_SLOT_CHUNK_SIZE, REDIS_POLL_INTERVAL -from pgsync.sync import Sync +from pgsync.singleton import Singleton +from pgsync.sync import settings, Sync + +from .testing_utils import override_env_var ROW = namedtuple("Row", ["data", "xid"]) @@ -20,6 +28,7 @@ def sync(): _sync = Sync( { "index": "testdb", + "database": "testdb", "nodes": { "table": "book", "columns": ["isbn", "title", "description"], @@ -40,6 +49,7 @@ def sync(): }, }, ) + Singleton._instances = {} yield _sync _sync.logical_slot_get_changes( f"{_sync.database}_testdb", @@ -69,7 +79,7 @@ def test_logical_slot_changes(self, mock_logger, sync): txmin=None, txmax=None, upto_nchanges=None, - limit=LOGICAL_SLOT_CHUNK_SIZE, + limit=settings.LOGICAL_SLOT_CHUNK_SIZE, offset=0, ) mock_sync.assert_not_called() @@ -86,7 +96,7 @@ def test_logical_slot_changes(self, mock_logger, sync): txmin=None, txmax=None, upto_nchanges=None, - limit=LOGICAL_SLOT_CHUNK_SIZE, + limit=settings.LOGICAL_SLOT_CHUNK_SIZE, offset=0, ) mock_sync.assert_not_called() @@ -115,7 +125,7 @@ def test_logical_slot_changes(self, mock_logger, sync): txmin=None, txmax=None, upto_nchanges=None, - limit=LOGICAL_SLOT_CHUNK_SIZE, + limit=settings.LOGICAL_SLOT_CHUNK_SIZE, offset=0, ) mock_get.assert_called_once() @@ -157,10 +167,12 @@ def test_logical_slot_changes(self, mock_logger, sync): @patch("pgsync.sync.ElasticHelper") def test_sync_validate(self, mock_es): + with pytest.raises(SchemaError) as excinfo: Sync( document={ "index": "testdb", + "database": "testdb", "nodes": ["foo"], }, verbose=False, @@ -174,6 +186,7 @@ def test_sync_validate(self, mock_es): Sync( document={ "index": "testdb", + "database": "testdb", "nodes": {"table": "book"}, "plugins": ["Hero"], }, @@ -208,6 +221,7 @@ def _side_effect(*args, **kwargs): Sync( document={ "index": "testdb", + "database": "testdb", "nodes": {"table": "book"}, "plugins": ["Hero"], }, @@ 
-225,6 +239,7 @@ def _side_effect(*args, **kwargs): Sync( document={ "index": "testdb", + "database": "testdb", "nodes": {"table": "book"}, "plugins": ["Hero"], }, @@ -242,6 +257,7 @@ def _side_effect(*args, **kwargs): Sync( document={ "index": "testdb", + "database": "testdb", "nodes": {"table": "book"}, "plugins": ["Hero"], }, @@ -258,6 +274,7 @@ def _side_effect(*args, **kwargs): Sync( document={ "index": "testdb", + "database": "testdb", "nodes": {"table": "book"}, "plugins": ["Hero"], }, @@ -272,6 +289,7 @@ def _side_effect(*args, **kwargs): Sync( document={ "index": "testdb", + "database": "testdb", "nodes": {"table": "book"}, "plugins": ["Hero"], }, @@ -283,6 +301,7 @@ def _side_effect(*args, **kwargs): Sync( document={ "index": "testdb", + "database": "testdb", "nodes": {"table": "book"}, "plugins": ["Hero"], }, @@ -458,9 +477,8 @@ def test__update_op(self, sync, connection): new={"isbn": "aa1"}, ) ] - extra: dict = {} assert sync.es.doc_count == 0 - _filters = sync._update_op(node, filters, payloads, extra) + _filters = sync._update_op(node, filters, payloads) sync.es.refresh("testdb") assert _filters == {"book": [{"isbn": "aa1"}]} assert sync.es.doc_count == 1 @@ -558,20 +576,6 @@ def test__truncate_op(self, mock_es, sync, connection): assert _filters == {"book": []} def test__payload(self, sync): - with patch("pgsync.sync.logger") as mock_logger: - with pytest.raises(RuntimeError): - for _ in sync._payloads( - [ - Payload( - tg_op="XXX", - table="book", - old={"id": 1}, - ), - ] - ): - pass - mock_logger.exception.assert_called_once_with("Unknown tg_op XXX") - with patch("pgsync.sync.Sync._insert_op") as mock_op: with patch("pgsync.sync.logger") as mock_logger: for _ in sync._payloads( @@ -810,12 +814,81 @@ def test_payloads(self, sync): Payload( tg_op="INSERT", table="book", + new={"isbn": "002"}, + schema="public", + ), + ] + for _ in sync._payloads(payloads): + pass + + def test_payloads_invalid_tg_op(self, mocker, sync): + payloads: List[Payload] = [ 
+ Payload( + tg_op="FOO", + table="book", old={"isbn": "001"}, new={"isbn": "002"}, schema="public", ), ] - sync._payloads(payloads) + with patch("pgsync.sync.logger") as mock_logger: + with pytest.raises(InvalidTGOPError): + for _ in sync._payloads(payloads): + pass + mock_logger.exception.assert_called_once_with("Unknown tg_op FOO") + + def test_payloads_in_batches(self, mocker, sync): + # inserting a root node + payloads: List[Payload] = [ + Payload( + tg_op="INSERT", + table="book", + new={"isbn": "002"}, + schema="public", + ) + ] * 20 + with patch("pgsync.sync.Sync.sync") as mock_sync: + with override_env_var(FILTER_CHUNK_SIZE="4"): + importlib.reload(settings) + for _ in sync._payloads(payloads): + pass + assert mock_sync.call_count == 25 + assert mock_sync.call_args_list[-1] == call( + filters={ + "book": [ + {"isbn": "002"}, + {"isbn": "002"}, + {"isbn": "002"}, + {"isbn": "002"}, + ] + }, + ) + + # updating a child table + payloads: List[Payload] = [ + Payload( + tg_op="UPDATE", + table="publisher", + new={"id": 1, "name": "foo"}, + old={"id": 1}, + schema="public", + ) + ] + filters: dict = { + "book": [ + {"isbn": "001"}, + ], + "publisher": [ + {"id": 1}, + ], + } + with patch("pgsync.sync.Sync._update_op", return_value=filters): + with patch("pgsync.sync.Sync.sync") as mock_sync: + with override_env_var(FILTER_CHUNK_SIZE="1"): + importlib.reload(settings) + for _ in sync._payloads(payloads): + pass + mock_sync.assert_called_once_with(filters=filters) @patch("pgsync.sync.compiled_query") def test_sync(self, mock_compiled_query, sync): @@ -837,5 +910,5 @@ def test_poll_redis( mock_on_publish.assert_called_once_with([ANY, ANY]) mock_refresh_views.assert_called_once() mock_logger.debug.assert_called_once_with(f"poll_redis: {items}") - mock_time.sleep.assert_called_once_with(REDIS_POLL_INTERVAL) + mock_time.sleep.assert_called_once_with(settings.REDIS_POLL_INTERVAL) assert sync.count["redis"] == 2 diff --git a/tests/test_sync_nested_children.py 
b/tests/test_sync_nested_children.py index 8ea60329..e65f2783 100644 --- a/tests/test_sync_nested_children.py +++ b/tests/test_sync_nested_children.py @@ -6,9 +6,10 @@ from pgsync.base import subtransactions from pgsync.settings import NTHREADS_POLLDB +from pgsync.singleton import Singleton from pgsync.sync import Sync -from .helpers.utils import assert_resync_empty, noop, search, sort_list +from .testing_utils import assert_resync_empty, noop, search, sort_list @pytest.mark.usefixtures("table_creator") @@ -230,6 +231,8 @@ def data( upto_nchanges=None, ) + Singleton._instances = {} + yield ( books, authors, @@ -674,6 +677,7 @@ def test_insert_root( document = { "index": "testdb", + "database": "testdb", "nodes": nodes, } @@ -798,6 +802,7 @@ def test_insert_root( def test_update_root(self, data, nodes, book_cls): document = { "index": "testdb", + "database": "testdb", "nodes": nodes, } # 1. sync first to add the initial document @@ -918,6 +923,7 @@ def test_delete_root( ): document = { "index": "testdb", + "database": "testdb", "nodes": nodes, } # 1. 
sync first to add the initial document @@ -1164,6 +1170,7 @@ def test_insert_through_child_op( document = { "index": "testdb", + "database": "testdb", "nodes": nodes, } @@ -1393,6 +1400,7 @@ def test_update_through_child_op( # update a new through child with op document = { "index": "testdb", + "database": "testdb", "nodes": nodes, } @@ -1620,6 +1628,7 @@ def test_delete_through_child_op(self, sync, data, nodes, book_author_cls): # delete a new through child with op document = { "index": "testdb", + "database": "testdb", "nodes": nodes, } @@ -1816,6 +1825,7 @@ def test_insert_nonthrough_child_noop( document = { "index": "testdb", + "database": "testdb", "nodes": nodes, } @@ -1845,6 +1855,7 @@ def test_update_nonthrough_child_noop(self, data, nodes, shelf_cls): # update a new non-through child with noop document = { "index": "testdb", + "database": "testdb", "nodes": nodes, } @@ -1878,6 +1889,7 @@ def test_delete_nonthrough_child_noop(self, data, nodes, shelf_cls): # delete a new non-through child with noop document = { "index": "testdb", + "database": "testdb", "nodes": nodes, } @@ -2031,6 +2043,7 @@ def test_insert_deep_nested_nonthrough_child_noop( document = { "index": "testdb", + "database": "testdb", "nodes": nodes, } # sync first to add the initial document diff --git a/tests/test_sync_root.py b/tests/test_sync_root.py index a7ff887e..a408cdc1 100644 --- a/tests/test_sync_root.py +++ b/tests/test_sync_root.py @@ -10,9 +10,10 @@ TableNotInNodeError, ) from pgsync.settings import NTHREADS_POLLDB +from pgsync.singleton import Singleton from pgsync.sync import Sync -from .helpers.utils import assert_resync_empty, noop, search, sort_list +from .testing_utils import assert_resync_empty, noop, search, sort_list @pytest.mark.usefixtures("table_creator") @@ -60,6 +61,7 @@ def data(self, sync, book_cls, publisher_cls): f"{sync.database}_testdb", upto_nchanges=None, ) + Singleton._instances = {} yield books @@ -414,6 +416,7 @@ def 
test_update_primary_key_non_concurrent(self, data, book_cls): """ document = { "index": "testdb", + "database": "testdb", "nodes": {"table": "book", "columns": ["isbn", "title"]}, } sync = Sync(document) @@ -456,6 +459,7 @@ def test_update_primary_key_concurrent(self, data, book_cls): """Test sync updates primary_key and then sync in concurrent mode.""" document = { "index": "testdb", + "database": "testdb", "nodes": {"table": "book", "columns": ["isbn", "title"]}, } sync = Sync(document) @@ -521,6 +525,7 @@ def test_insert_non_concurrent(self, data, book_cls): """Test sync insert and then sync in non-concurrent mode.""" document = { "index": "testdb", + "database": "testdb", "nodes": {"table": "book", "columns": ["isbn", "title"]}, } sync = Sync(document) @@ -561,6 +566,7 @@ def test_update_non_concurrent(self, data, book_cls): """Test sync update and then sync in non-concurrent mode.""" document = { "index": "testdb", + "database": "testdb", "nodes": {"table": "book", "columns": ["isbn", "title"]}, } sync = Sync(document) @@ -601,6 +607,7 @@ def test_update_concurrent(self, data, book_cls): """Test sync update and then sync in concurrent mode.""" document = { "index": "testdb", + "database": "testdb", "nodes": {"table": "book", "columns": ["isbn", "title"]}, } sync = Sync(document) @@ -662,6 +669,7 @@ def test_delete_concurrent(self, data, book_cls): """Test sync delete and then sync in concurrent mode.""" document = { "index": "testdb", + "database": "testdb", "nodes": {"table": "book", "columns": ["isbn", "title"]}, } sync = Sync(document) diff --git a/tests/test_sync_single_child_fk_on_child.py b/tests/test_sync_single_child_fk_on_child.py index a9b5fb39..474a72e6 100644 --- a/tests/test_sync_single_child_fk_on_child.py +++ b/tests/test_sync_single_child_fk_on_child.py @@ -14,9 +14,10 @@ ) from pgsync.node import Tree from pgsync.settings import NTHREADS_POLLDB +from pgsync.singleton import Singleton from pgsync.sync import Sync -from .helpers.utils import 
assert_resync_empty, noop, search, sort_list +from .testing_utils import assert_resync_empty, noop, search, sort_list @pytest.mark.usefixtures("table_creator") @@ -62,6 +63,7 @@ def data(self, sync, book_cls, rating_cls): f"{sync.database}_testdb", upto_nchanges=None, ) + Singleton._instances = {} yield ( books, @@ -589,6 +591,7 @@ def test_update_primary_key_non_concurrent( """Test sync updates primary_key then sync in non-concurrent mode.""" document = { "index": "testdb", + "database": "testdb", "nodes": { "table": "book", "columns": ["isbn", "title"], @@ -687,6 +690,7 @@ def test_update_primary_key_concurrent(self, data, book_cls, rating_cls): """Test sync updates primary_key and then sync in concurrent mode.""" document = { "index": "testdb", + "database": "testdb", "nodes": { "table": "book", "columns": ["isbn", "title"], @@ -805,6 +809,7 @@ def test_insert_non_concurrent(self, data, book_cls, rating_cls): """Test sync insert and then sync in non-concurrent mode.""" document = { "index": "testdb", + "database": "testdb", "nodes": { "table": "book", "columns": ["isbn", "title"], @@ -907,6 +912,7 @@ def test_update_non_primary_key_non_concurrent( """Test sync update and then sync in non-concurrent mode.""" document = { "index": "testdb", + "database": "testdb", "nodes": { "table": "book", "columns": ["isbn", "title"], @@ -996,6 +1002,7 @@ def test_update_non_primary_key_concurrent( """Test sync update and then sync in concurrent mode.""" document = { "index": "testdb", + "database": "testdb", "nodes": { "table": "book", "columns": ["isbn", "title"], @@ -1100,6 +1107,7 @@ def test_delete_concurrent(self, data, book_cls, rating_cls): """Test sync delete and then sync in concurrent mode.""" document = { "index": "testdb", + "database": "testdb", "nodes": { "table": "book", "columns": ["isbn", "title"], diff --git a/tests/test_sync_single_child_fk_on_parent.py b/tests/test_sync_single_child_fk_on_parent.py index 89d34e07..d4ea23c5 100644 --- 
a/tests/test_sync_single_child_fk_on_parent.py +++ b/tests/test_sync_single_child_fk_on_parent.py @@ -14,9 +14,10 @@ ) from pgsync.node import Tree from pgsync.settings import NTHREADS_POLLDB +from pgsync.singleton import Singleton from pgsync.sync import Sync -from .helpers.utils import assert_resync_empty, noop, search, sort_list +from .testing_utils import assert_resync_empty, noop, search, sort_list @pytest.mark.usefixtures("table_creator") @@ -56,6 +57,7 @@ def data(self, sync, book_cls, publisher_cls): f"{sync.database}_testdb", upto_nchanges=None, ) + Singleton._instances = {} yield books @@ -595,6 +597,7 @@ def test_update_primary_key_non_concurrent( """Test sync updates primary_key then sync in non-concurrent mode.""" document = { "index": "testdb", + "database": "testdb", "nodes": { "table": "book", "columns": ["isbn", "title"], @@ -690,6 +693,7 @@ def test_update_primary_key_concurrent( """Test sync updates primary_key and then sync in concurrent mode.""" document = { "index": "testdb", + "database": "testdb", "nodes": { "table": "book", "columns": ["isbn", "title"], @@ -802,6 +806,7 @@ def test_insert_non_concurrent(self, data, book_cls, publisher_cls): """Test sync insert and then sync in non-concurrent mode.""" document = { "index": "testdb", + "database": "testdb", "nodes": { "table": "book", "columns": ["isbn", "title"], @@ -901,6 +906,7 @@ def test_update_non_primary_key_non_concurrent( """Test sync update and then sync in non-concurrent mode.""" document = { "index": "testdb", + "database": "testdb", "nodes": { "table": "book", "columns": ["isbn", "title"], @@ -990,6 +996,7 @@ def test_update_non_primary_key_concurrent( """Test sync update and then sync in concurrent mode.""" document = { "index": "testdb", + "database": "testdb", "nodes": { "table": "book", "columns": ["isbn", "title"], @@ -1094,6 +1101,7 @@ def test_delete_concurrent(self, data, book_cls, publisher_cls): """Test sync delete and then sync in concurrent mode.""" document = { 
"index": "testdb", + "database": "testdb", "nodes": { "table": "book", "columns": ["isbn", "title"], diff --git a/tests/test_unique_behaviour.py b/tests/test_unique_behaviour.py index 855055d8..6b10c607 100644 --- a/tests/test_unique_behaviour.py +++ b/tests/test_unique_behaviour.py @@ -5,7 +5,7 @@ from pgsync.base import subtransactions -from .helpers.utils import assert_resync_empty, sort_list +from .testing_utils import assert_resync_empty, sort_list @pytest.mark.usefixtures("table_creator") diff --git a/tests/test_utils.py b/tests/test_utils.py index ec6fbc99..7dc66e44 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -12,10 +12,10 @@ from pgsync.urls import get_elasticsearch_url, get_postgres_url, get_redis_url from pgsync.utils import ( compiled_query, + config_loader, exception, get_config, get_redacted_url, - load_config, show_settings, threaded, timeit, @@ -38,11 +38,11 @@ def test_get_config(self): config: str = get_config("tests/fixtures/schema.json") assert config == "tests/fixtures/schema.json" - def test_load_config(self): + def test_config_loader(self): os.environ["foo"] = "mydb" os.environ["bar"] = "myindex" config: str = get_config("tests/fixtures/schema.json") - data = load_config(config) + data = config_loader(config) assert next(data) == { "database": "fakedb", "index": "fake_index", diff --git a/tests/helpers/utils.py b/tests/testing_utils.py similarity index 61% rename from tests/helpers/utils.py rename to tests/testing_utils.py index 6346f528..578e0abb 100644 --- a/tests/helpers/utils.py +++ b/tests/testing_utils.py @@ -1,4 +1,7 @@ """Test helper methods.""" + +import os +from contextlib import contextmanager from typing import Optional from pgsync.node import Node @@ -61,3 +64,32 @@ def sort_list(data: dict) -> dict: else: result[key] = value return result + + +@contextmanager +def override_env_var(**kwargs): + """Set the given value of the given environment variable or + unset if value is None. 
+ """ + original_values: dict = {} + envs_to_delete: list = [] + for env_name, env_value in kwargs.items(): + try: + original_values[env_name] = os.environ[env_name] + if env_value is None: + del os.environ[env_name] + except KeyError: + # Env var did not previously exist. + # If we are setting it now, we need to delete it on exit. + if env_value is not None: + envs_to_delete.append(env_name) + + if env_value is not None: + os.environ[env_name] = env_value + + yield + + for env_name in envs_to_delete: + del os.environ[env_name] + for env_name, original_env_value in original_values.items(): + os.environ[env_name] = original_env_value