From dd50b74cbb39ca8980e3f22a1991c89eef3dc808 Mon Sep 17 00:00:00 2001 From: Luis Coimbra <31017690+luisggc@users.noreply.github.com> Date: Thu, 1 Jan 2026 23:10:17 -0300 Subject: [PATCH 1/3] Refactor discovery and execution modules --- sqlcheck/cli/commands/plan.py | 2 +- sqlcheck/cli/commands/run.py | 8 ++- sqlcheck/cli/common.py | 115 +++--------------------------- sqlcheck/cli/connections.py | 29 ++++++++ sqlcheck/cli/discovery.py | 19 +++++ sqlcheck/cli/output.py | 74 +++++++++++++++++++ sqlcheck/connectors/__init__.py | 5 ++ sqlcheck/connectors/sqlalchemy.py | 100 ++++++++++++++++++++++++++ sqlcheck/db_connector.py | 98 +++---------------------- sqlcheck/discovery.py | 32 +++++++++ sqlcheck/execution.py | 68 ++++++++++++++++++ sqlcheck/runner.py | 98 +------------------------ 12 files changed, 355 insertions(+), 293 deletions(-) create mode 100644 sqlcheck/cli/connections.py create mode 100644 sqlcheck/cli/discovery.py create mode 100644 sqlcheck/cli/output.py create mode 100644 sqlcheck/connectors/__init__.py create mode 100644 sqlcheck/connectors/sqlalchemy.py create mode 100644 sqlcheck/discovery.py create mode 100644 sqlcheck/execution.py diff --git a/sqlcheck/cli/commands/plan.py b/sqlcheck/cli/commands/plan.py index 1ee0a6e..ca2e24c 100644 --- a/sqlcheck/cli/commands/plan.py +++ b/sqlcheck/cli/commands/plan.py @@ -5,7 +5,7 @@ import typer -from sqlcheck.cli.common import discover_cases +from sqlcheck.cli.discovery import discover_cases from sqlcheck.reports import build_plan_payload, write_case_plan diff --git a/sqlcheck/cli/commands/run.py b/sqlcheck/cli/commands/run.py index a162228..0c53d85 100644 --- a/sqlcheck/cli/commands/run.py +++ b/sqlcheck/cli/commands/run.py @@ -4,7 +4,9 @@ import typer -from sqlcheck.cli.common import build_adapter, discover_cases, print_results +from sqlcheck.cli.connections import build_connector +from sqlcheck.cli.discovery import discover_cases +from sqlcheck.cli.output import print_results from sqlcheck.function_registry import default_registry from sqlcheck.plugins import load_plugins from sqlcheck.reports import write_json, write_junit, write_plan @@ -42,9 +44,9 @@ def run( if plugin: load_plugins(plugin, registry) - adapter = build_adapter(connection) + connector = build_connector(connection) - results = run_cases(cases, adapter, registry, workers=workers) + results = run_cases(cases, connector, registry, workers=workers) print_results(results, engine=connection) diff --git a/sqlcheck/cli/common.py b/sqlcheck/cli/common.py index f9d588a..bde2e79 100644 --- a/sqlcheck/cli/common.py +++ b/sqlcheck/cli/common.py @@ -1,104 +1,11 @@ -from __future__ import annotations - -import os -import re -from pathlib import Path - -import typer -from rich import box -from rich.console import Console -from rich.panel import Panel -from rich.table import Table - -from sqlcheck.db_connector import DBConnector, SQLAlchemyConnector -from sqlcheck.models import TestCase, TestResult -from sqlcheck.runner import build_test_case, discover_files - - -def discover_cases(target: Path, pattern: str) -> list[TestCase]: - paths = discover_files(target, pattern) - if not paths: - print("No test files found.") - raise typer.Exit(code=1) - return [build_test_case(path) for path in paths] - - -def _connection_env_var(name: str) -> str: - slug = re.sub(r"[^A-Za-z0-9]+", "_", name).strip("_").upper() - return f"SQLCHECK_CONN_{slug}" - - -def resolve_connection_uri(name: str) -> str: - env_var = _connection_env_var(name) - value = os.getenv(env_var) - if not value: - raise ValueError(f"Missing connection URI. Set {env_var}.") - return value - - -def build_adapter(connection: str) -> DBConnector: - connection_uri = resolve_connection_uri(connection) - return SQLAlchemyConnector(connection_uri=connection_uri) - - -def print_results(results: list[TestResult], engine: str | None = None) -> None: - total = len(results) - failures = [result for result in results if not result.success] - passed = total - len(failures) - console = Console() - - header = "SQLCheck" - if engine: - header += f" ({engine})" - header += f" — {total} tests, {passed} passed" - if failures: - header += f", {len(failures)} failed" - - if failures: - console.print("[bold]Failures:[/bold]") - for result in failures: - console.print( - f"[red]FAIL[/red] {result.case.metadata.name} [dim]{result.case.path}[/dim]" - ) - for func_result in result.function_results: - if not func_result.success: - message = func_result.message or "Expectation failed" - console.print(f" {message}") - if result.output.stderr: - console.print( - Panel( - result.output.stderr.strip(), - title="STDERR", - border_style="red", - ) - ) - if result.output.stdout: - console.print( - Panel( - result.output.stdout.strip(), - title="STDOUT", - border_style="yellow", - ) - ) - console.print() - - table = Table(box=box.ASCII, show_header=True, header_style="bold") - table.add_column("STATUS", style="bold") - table.add_column("TEST") - table.add_column("DURATION", justify="right") - table.add_column("PATH") - - for result in results: - duration = f"{result.status.duration_s:.2f}s" - status = "PASS" if result.success else "FAIL" - status_style = "green" if result.success else "red" - table.add_row( - f"[{status_style}]{status}[/{status_style}]", - result.case.metadata.name, - duration, - str(result.case.path), - ) - - console.print(table) - console.print() - console.print(f"[bold]{header}[/bold]") +from sqlcheck.cli.connections import build_adapter, build_connector, resolve_connection_uri +from sqlcheck.cli.discovery import discover_cases +from sqlcheck.cli.output import print_results + +__all__ = [ + "build_adapter", + "build_connector", + "discover_cases", + "print_results", + "resolve_connection_uri", +] diff --git a/sqlcheck/cli/connections.py b/sqlcheck/cli/connections.py new file mode 100644 index 0000000..08f375f --- /dev/null +++ b/sqlcheck/cli/connections.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +import os +import re + +from sqlcheck.db_connector import DBConnector, SQLAlchemyConnector + + +def _connection_env_var(name: str) -> str: + slug = re.sub(r"[^A-Za-z0-9]+", "_", name).strip("_").upper() + return f"SQLCHECK_CONN_{slug}" + + +def resolve_connection_uri(name: str) -> str: + env_var = _connection_env_var(name) + value = os.getenv(env_var) + if not value: + raise ValueError(f"Missing connection URI. Set {env_var}.") + return value + + +def build_connector(connection: str) -> DBConnector: + connection_uri = resolve_connection_uri(connection) + return SQLAlchemyConnector(connection_uri=connection_uri) + + +build_adapter = build_connector + +__all__ = ["build_adapter", "build_connector", "resolve_connection_uri"] diff --git a/sqlcheck/cli/discovery.py b/sqlcheck/cli/discovery.py new file mode 100644 index 0000000..4c098be --- /dev/null +++ b/sqlcheck/cli/discovery.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +from pathlib import Path + +import typer + +from sqlcheck.discovery import build_test_case, discover_files +from sqlcheck.models import TestCase + + +def discover_cases(target: Path, pattern: str) -> list[TestCase]: + paths = discover_files(target, pattern) + if not paths: + print("No test files found.") + raise typer.Exit(code=1) + return [build_test_case(path) for path in paths] + + +__all__ = ["discover_cases"] diff --git a/sqlcheck/cli/output.py b/sqlcheck/cli/output.py new file mode 100644 index 0000000..ad8492c --- /dev/null +++ b/sqlcheck/cli/output.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +from rich import box +from rich.console import Console +from rich.panel import Panel +from rich.table import Table + +from sqlcheck.models import TestResult + + +def print_results(results: list[TestResult], engine: str | None = None) -> None: + total = len(results) + failures = [result for result in results if not result.success] + passed = total - len(failures) + console = Console() + + header = "SQLCheck" + if engine: + header += f" ({engine})" + header += f" — {total} tests, {passed} passed" + if failures: + header += f", {len(failures)} failed" + + if failures: + console.print("[bold]Failures:[/bold]") + for result in failures: + console.print( + f"[red]FAIL[/red] {result.case.metadata.name} [dim]{result.case.path}[/dim]" + ) + for func_result in result.function_results: + if not func_result.success: + message = func_result.message or "Expectation failed" + console.print(f" {message}") + if result.output.stderr: + console.print( + Panel( + result.output.stderr.strip(), + title="STDERR", + border_style="red", + ) + ) + if result.output.stdout: + console.print( + Panel( + result.output.stdout.strip(), + title="STDOUT", + border_style="yellow", + ) + ) + console.print() + + table = Table(box=box.ASCII, show_header=True, header_style="bold") + table.add_column("STATUS", style="bold") + table.add_column("TEST") + table.add_column("DURATION", justify="right") + table.add_column("PATH") + + for result in results: + duration = f"{result.status.duration_s:.2f}s" + status = "PASS" if result.success else "FAIL" + status_style = "green" if result.success else "red" + table.add_row( + f"[{status_style}]{status}[/{status_style}]", + result.case.metadata.name, + duration, + str(result.case.path), + ) + + console.print(table) + console.print() + console.print(f"[bold]{header}[/bold]") + + +__all__ = ["print_results"] diff --git a/sqlcheck/connectors/__init__.py b/sqlcheck/connectors/__init__.py new file mode 100644 index 0000000..ca726c2 --- /dev/null +++ b/sqlcheck/connectors/__init__.py @@ -0,0 +1,5 @@ +"""Database connector implementations.""" + +from sqlcheck.connectors.sqlalchemy import SQLAlchemyConnector + +__all__ = ["SQLAlchemyConnector"] diff --git a/sqlcheck/connectors/sqlalchemy.py b/sqlcheck/connectors/sqlalchemy.py new file mode 100644 index 0000000..33b72c8 --- /dev/null +++ b/sqlcheck/connectors/sqlalchemy.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +import time +from contextlib import contextmanager +from typing import Iterator +from urllib.parse import urlparse + +from sqlalchemy import create_engine +from sqlalchemy.exc import NoSuchModuleError, SQLAlchemyError + +from sqlcheck.db_connector import CommandDBConnector, DBSession, ExecutionResult +from sqlcheck.models import ExecutionOutput, ExecutionStatus, SQLParsed + + +class SQLAlchemyConnector(CommandDBConnector): + name = "sqlalchemy" + + def __init__(self, connection_uri: str) -> None: + self.connection_uri = connection_uri + try: + self.engine = create_engine(connection_uri) + except NoSuchModuleError as exc: + dialect = _dialect_from_uri(connection_uri) + hint = _driver_hint(dialect) + message = ( + f"Missing SQLAlchemy driver for '{dialect}'. {hint} " + f"Original error: {exc}" + ) + raise ValueError(message) from exc + + def execute(self, sql_parsed: SQLParsed, timeout: float | None = None) -> ExecutionResult: + with self.engine.connect() as connection: + return self._execute_with_connection(connection, sql_parsed, timeout) + + @contextmanager + def open_session(self) -> Iterator[DBSession]: + with self.engine.connect() as connection: + def _execute(sql_parsed: SQLParsed, timeout: float | None = None) -> ExecutionResult: + return self._execute_with_connection(connection, sql_parsed, timeout) + + yield DBSession(_execute) + + def _execute_with_connection( + self, + connection: object, + sql_parsed: SQLParsed, + timeout: float | None = None, + ) -> ExecutionResult: + start = time.perf_counter() + stdout = "" + stderr = "" + rows: list[list[object]] = [] + returncode = 0 + success = True + try: + exec_connection = connection + if timeout is not None: + exec_connection = connection.execution_options(timeout=timeout) + with exec_connection.begin(): + statements = sql_parsed.statements + if not statements: + statements = [] + for statement in statements or []: + result = exec_connection.exec_driver_sql(statement.text) + if result.returns_rows: + rows = [list(row) for row in result.fetchall()] + if not statements and sql_parsed.source.strip(): + result = exec_connection.exec_driver_sql(sql_parsed.source) + if result.returns_rows: + rows = [list(row) for row in result.fetchall()] + except SQLAlchemyError as exc: + success = False + returncode = 1 + stderr = str(exc) + duration = time.perf_counter() - start + status = ExecutionStatus(success=success, returncode=returncode, duration_s=duration) + output = ExecutionOutput(stdout=stdout, stderr=stderr, rows=rows) + return ExecutionResult(status=status, output=output) + + +def _dialect_from_uri(connection_uri: str) -> str: + scheme = urlparse(connection_uri).scheme + return scheme.split("+", maxsplit=1)[0] if scheme else "unknown" + + +def _driver_hint(dialect: str) -> str: + hints = { + "snowflake": "Install it with: pip install snowflake-sqlalchemy", + "duckdb": "Install it with: pip install duckdb duckdb-engine", + "postgresql": "Install it with: pip install psycopg[binary]", + "mysql": "Install it with: pip install pymysql", + "databricks": "Install it with: pip install databricks-sql-connector", + "mssql": "Install it with: pip install pyodbc", + "oracle": "Install it with: pip install oracledb", + } + message = hints.get( + dialect, + "Install the database-specific SQLAlchemy dialect/driver for this URI.", + ) + return f"{message} See https://docs.sqlalchemy.org/en/20/dialects/ for details." diff --git a/sqlcheck/db_connector.py b/sqlcheck/db_connector.py index ec9b286..6edf8ac 100644 --- a/sqlcheck/db_connector.py +++ b/sqlcheck/db_connector.py @@ -1,13 +1,8 @@ from __future__ import annotations -import time from contextlib import contextmanager from dataclasses import dataclass from typing import Callable, Iterator -from urllib.parse import urlparse - -from sqlalchemy import create_engine -from sqlalchemy.exc import NoSuchModuleError, SQLAlchemyError from sqlcheck.models import ExecutionOutput, ExecutionStatus, SQLParsed @@ -38,89 +33,12 @@ class CommandDBConnector(DBConnector): pass -class SQLAlchemyConnector(CommandDBConnector): - name = "sqlalchemy" - - def __init__(self, connection_uri: str) -> None: - self.connection_uri = connection_uri - try: - self.engine = create_engine(connection_uri) - except NoSuchModuleError as exc: - dialect = _dialect_from_uri(connection_uri) - hint = _driver_hint(dialect) - message = ( - f"Missing SQLAlchemy driver for '{dialect}'. {hint} " - f"Original error: {exc}" - ) - raise ValueError(message) from exc - - def execute(self, sql_parsed: SQLParsed, timeout: float | None = None) -> ExecutionResult: - with self.engine.connect() as connection: - return self._execute_with_connection(connection, sql_parsed, timeout) - - @contextmanager - def open_session(self) -> Iterator[DBSession]: - with self.engine.connect() as connection: - def _execute(sql_parsed: SQLParsed, timeout: float | None = None) -> ExecutionResult: - return self._execute_with_connection(connection, sql_parsed, timeout) - - yield DBSession(_execute) - - def _execute_with_connection( - self, - connection: object, - sql_parsed: SQLParsed, - timeout: float | None = None, - ) -> ExecutionResult: - start = time.perf_counter() - stdout = "" - stderr = "" - rows: list[list[object]] = [] - returncode = 0 - success = True - try: - exec_connection = connection - if timeout is not None: - exec_connection = connection.execution_options(timeout=timeout) - with exec_connection.begin(): - statements = sql_parsed.statements - if not statements: - statements = [] - for statement in statements or []: - result = exec_connection.exec_driver_sql(statement.text) - if result.returns_rows: - rows = [list(row) for row in result.fetchall()] - if not statements and sql_parsed.source.strip(): - result = exec_connection.exec_driver_sql(sql_parsed.source) - if result.returns_rows: - rows = [list(row) for row in result.fetchall()] - except SQLAlchemyError as exc: - success = False - returncode = 1 - stderr = str(exc) - duration = time.perf_counter() - start - status = ExecutionStatus(success=success, returncode=returncode, duration_s=duration) - output = ExecutionOutput(stdout=stdout, stderr=stderr, rows=rows) - return ExecutionResult(status=status, output=output) - - -def _dialect_from_uri(connection_uri: str) -> str: - scheme = urlparse(connection_uri).scheme - return scheme.split("+", maxsplit=1)[0] if scheme else "unknown" - +from sqlcheck.connectors.sqlalchemy import SQLAlchemyConnector -def _driver_hint(dialect: str) -> str: - hints = { - "snowflake": "Install it with: pip install snowflake-sqlalchemy", - "duckdb": "Install it with: pip install duckdb duckdb-engine", - "postgresql": "Install it with: pip install psycopg[binary]", - "mysql": "Install it with: pip install pymysql", - "databricks": "Install it with: pip install databricks-sql-connector", - "mssql": "Install it with: pip install pyodbc", - "oracle": "Install it with: pip install oracledb", - } - message = hints.get( - dialect, - "Install the database-specific SQLAlchemy dialect/driver for this URI.", - ) - return f"{message} See https://docs.sqlalchemy.org/en/20/dialects/ for details." +__all__ = [ + "CommandDBConnector", + "DBConnector", + "DBSession", + "ExecutionResult", + "SQLAlchemyConnector", +] diff --git a/sqlcheck/discovery.py b/sqlcheck/discovery.py new file mode 100644 index 0000000..9c5b942 --- /dev/null +++ b/sqlcheck/discovery.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +from pathlib import Path + +from sqlcheck.models import DirectiveCall, TestCase, TestMetadata +from sqlcheck.parser import ParsedFile, parse_file, summarize_directives + + +def discover_files(target: Path, pattern: str) -> list[Path]: + if target.is_file(): + return [target] + return sorted(target.rglob(pattern)) + + +def build_test_case(path: Path) -> TestCase: + parsed: ParsedFile = parse_file(path) + directives = parsed.directives or [DirectiveCall(name="success", args=(), kwargs={}, raw="")] + summary = summarize_directives(directives) + metadata = TestMetadata( + name=summary["name"] or path.stem, + tags=summary["tags"], + serial=summary["serial"], + timeout=summary["timeout"], + retries=summary["retries"], + ) + return TestCase( + path=path, + sql_parsed=parsed.sql_parsed, + directives=directives, + segments=parsed.segments, + metadata=metadata, + ) diff --git a/sqlcheck/execution.py b/sqlcheck/execution.py new file mode 100644 index 0000000..01d833d --- /dev/null +++ b/sqlcheck/execution.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +import concurrent.futures +from typing import Iterable + +from sqlcheck.db_connector import DBConnector, ExecutionResult +from sqlcheck.function_context import execution_context +from sqlcheck.function_registry import FunctionRegistry +from sqlcheck.models import FunctionResult, TestCase, TestResult + + +def run_test_case(case: TestCase, adapter: DBConnector, registry: FunctionRegistry) -> TestResult: + execution: ExecutionResult | None = None + function_results: list[FunctionResult] = [] + with adapter.open_session() as session: + for segment in case.segments: + for attempt in range(case.metadata.retries + 1): + execution = session.execute(segment.sql_parsed, timeout=case.metadata.timeout) + if execution.status.success or attempt >= case.metadata.retries: + break + if execution is None: + raise RuntimeError("Execution never started") + status = execution.status + output = execution.output + exit_on_failure = segment.directive.kwargs.get("exit_on_failure", True) + func = registry.resolve(segment.directive.name) + kwargs = { + key: value + for key, value in segment.directive.kwargs.items() + if key != "exit_on_failure" + } + with execution_context(segment.sql_parsed, status, output): + result = func(*segment.directive.args, **kwargs) + function_results.append(result) + if exit_on_failure and not result.success: + break + if execution is None: + raise RuntimeError("Execution never started") + return TestResult( + case=case, + status=execution.status, + output=execution.output, + function_results=function_results, + ) + + +def run_cases( + cases: Iterable[TestCase], + adapter: DBConnector, + registry: FunctionRegistry, + workers: int, +) -> list[TestResult]: + parallel_cases = [case for case in cases if not case.metadata.serial] + serial_cases = [case for case in cases if case.metadata.serial] + results: list[TestResult] = [] + + with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor: + future_map = { + executor.submit(run_test_case, case, adapter, registry): case + for case in parallel_cases + } + for future in concurrent.futures.as_completed(future_map): + results.append(future.result()) + + for case in serial_cases: + results.append(run_test_case(case, adapter, registry)) + + return results diff --git a/sqlcheck/runner.py b/sqlcheck/runner.py index c8bef36..c4a6a05 100644 --- a/sqlcheck/runner.py +++ b/sqlcheck/runner.py @@ -1,96 +1,4 @@ -from __future__ import annotations +from sqlcheck.discovery import build_test_case, discover_files +from sqlcheck.execution import run_cases, run_test_case -import concurrent.futures -from pathlib import Path -from typing import Iterable - -from sqlcheck.db_connector import DBConnector, ExecutionResult -from sqlcheck.function_context import execution_context -from sqlcheck.function_registry import FunctionRegistry -from sqlcheck.models import ( - DirectiveCall, - FunctionResult, - TestCase, - TestMetadata, - TestResult, -) -from sqlcheck.parser import ParsedFile, parse_file, summarize_directives - - -def discover_files(target: Path, pattern: str) -> list[Path]: - if target.is_file(): - return [target] - return sorted(target.rglob(pattern)) - - -def build_test_case(path: Path) -> TestCase: - parsed: ParsedFile = parse_file(path) - directives = parsed.directives or [DirectiveCall(name="success", args=(), kwargs={}, raw="")] - summary = summarize_directives(directives) - metadata = TestMetadata( - name=summary["name"] or path.stem, - tags=summary["tags"], - serial=summary["serial"], - timeout=summary["timeout"], - retries=summary["retries"], - ) - return TestCase( - path=path, - sql_parsed=parsed.sql_parsed, - directives=directives, - segments=parsed.segments, - metadata=metadata, - ) - - -def run_test_case(case: TestCase, adapter: DBConnector, registry: FunctionRegistry) -> TestResult: - execution: ExecutionResult | None = None - function_results: list[FunctionResult] = [] - with adapter.open_session() as session: - for segment in case.segments: - for attempt in range(case.metadata.retries + 1): - execution = session.execute(segment.sql_parsed, timeout=case.metadata.timeout) - if execution.status.success or attempt >= case.metadata.retries: - break - if execution is None: - raise RuntimeError("Execution never started") - status = execution.status - output = execution.output - exit_on_failure = segment.directive.kwargs.get("exit_on_failure", True) - func = registry.resolve(segment.directive.name) - kwargs = { - key: value - for key, value in segment.directive.kwargs.items() - if key != "exit_on_failure" - } - with execution_context(segment.sql_parsed, status, output): - result = func(*segment.directive.args, **kwargs) - function_results.append(result) - if exit_on_failure and not result.success: - break - if execution is None: - raise RuntimeError("Execution never started") - return TestResult(case=case, status=execution.status, output=execution.output, function_results=function_results) - - -def run_cases( - cases: Iterable[TestCase], - adapter: DBConnector, - registry: FunctionRegistry, - workers: int, -) -> list[TestResult]: - parallel_cases = [case for case in cases if not case.metadata.serial] - serial_cases = [case for case in cases if case.metadata.serial] - results: list[TestResult] = [] - - with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor: - future_map = { - executor.submit(run_test_case, case, adapter, registry): case for case in parallel_cases - } - for future in concurrent.futures.as_completed(future_map): - results.append(future.result()) - - for case in serial_cases: - results.append(run_test_case(case, adapter, registry)) - - return results +__all__ = ["build_test_case", "discover_files", "run_cases", "run_test_case"] From 8764fc844deca6fa214ff9c077a60bb3cd233976 Mon Sep 17 00:00:00 2001 From: Luis Coimbra <31017690+luisggc@users.noreply.github.com> Date: Fri, 2 Jan 2026 14:37:02 -0300 Subject: [PATCH 2/3] Simplify connector naming in helpers --- sqlcheck/cli/common.py | 3 +-- sqlcheck/cli/connections.py | 5 +---- sqlcheck/execution.py | 14 +++++++++----- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/sqlcheck/cli/common.py b/sqlcheck/cli/common.py index bde2e79..6a672ed 100644 --- a/sqlcheck/cli/common.py +++ b/sqlcheck/cli/common.py @@ -1,9 +1,8 @@ -from sqlcheck.cli.connections import build_adapter, build_connector, resolve_connection_uri +from sqlcheck.cli.connections import build_connector, resolve_connection_uri from sqlcheck.cli.discovery import discover_cases from sqlcheck.cli.output import print_results __all__ = [ - "build_adapter", "build_connector", "discover_cases", "print_results", diff --git a/sqlcheck/cli/connections.py b/sqlcheck/cli/connections.py index 08f375f..d7c6db9 100644 --- a/sqlcheck/cli/connections.py +++ b/sqlcheck/cli/connections.py @@ -23,7 +23,4 @@ def build_connector(connection: str) -> DBConnector: connection_uri = resolve_connection_uri(connection) return SQLAlchemyConnector(connection_uri=connection_uri) - -build_adapter = build_connector - -__all__ = ["build_adapter", "build_connector", "resolve_connection_uri"] +__all__ = ["build_connector", "resolve_connection_uri"] diff --git a/sqlcheck/execution.py b/sqlcheck/execution.py index 01d833d..7b25f5a 100644 --- a/sqlcheck/execution.py +++ b/sqlcheck/execution.py @@ -9,10 +9,14 @@ from sqlcheck.models import FunctionResult, TestCase, TestResult -def run_test_case(case: TestCase, adapter: DBConnector, registry: FunctionRegistry) -> TestResult: +def run_test_case( + case: TestCase, + connector: DBConnector, + registry: FunctionRegistry, +) -> TestResult: execution: ExecutionResult | None = None function_results: list[FunctionResult] = [] - with adapter.open_session() as session: + with connector.open_session() as session: for segment in case.segments: for attempt in range(case.metadata.retries + 1): execution = session.execute(segment.sql_parsed, timeout=case.metadata.timeout) @@ -46,7 +50,7 @@ def run_test_case(case: TestCase, adapter: DBConnector, registry: FunctionRegist def run_cases( cases: Iterable[TestCase], - adapter: DBConnector, + connector: DBConnector, registry: FunctionRegistry, workers: int, ) -> list[TestResult]: @@ -56,13 +60,13 @@ def run_cases( with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor: future_map = { - executor.submit(run_test_case, case, adapter, registry): case + executor.submit(run_test_case, case, connector, registry): case for case in parallel_cases } for future in concurrent.futures.as_completed(future_map): results.append(future.result()) for case in serial_cases: - results.append(run_test_case(case, adapter, registry)) + results.append(run_test_case(case, connector, registry)) return results From 5f8b06d99fe00f399acb92cea2952907372c6cc0 Mon Sep 17 00:00:00 2001 From: Luis Coimbra <31017690+luisggc@users.noreply.github.com> Date: Fri, 2 Jan 2026 14:47:11 -0300 Subject: [PATCH 3/3] Update connector driver hints to use extras --- sqlcheck/connectors/sqlalchemy.py | 12 +++++------- tests/test_sqlalchemy_adapter.py | 2 +- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/sqlcheck/connectors/sqlalchemy.py b/sqlcheck/connectors/sqlalchemy.py index 33b72c8..803eb7e 100644 --- a/sqlcheck/connectors/sqlalchemy.py +++ b/sqlcheck/connectors/sqlalchemy.py @@ -85,13 +85,11 @@ def _dialect_from_uri(connection_uri: str) -> str: def _driver_hint(dialect: str) -> str: hints = { - "snowflake": "Install it with: pip install snowflake-sqlalchemy", - "duckdb": "Install it with: pip install duckdb duckdb-engine", - "postgresql": "Install it with: pip install psycopg[binary]", - "mysql": "Install it with: pip install pymysql", - "databricks": "Install it with: pip install databricks-sql-connector", - "mssql": "Install it with: pip install pyodbc", - "oracle": "Install it with: pip install oracledb", + "snowflake": "Install the optional dependency with: pip install sqlcheck[snowflake]", + "duckdb": "Install the optional dependency with: pip install sqlcheck[duckdb]", + "postgresql": "Install the optional dependency with: pip install sqlcheck[postgres]", + "mysql": "Install the optional dependency with: pip install sqlcheck[mysql]", + "databricks": "Install the optional dependency with: pip install sqlcheck[databricks]", } message = hints.get( dialect, diff --git a/tests/test_sqlalchemy_adapter.py b/tests/test_sqlalchemy_adapter.py index d43bc46..9a09d21 100644 --- a/tests/test_sqlalchemy_adapter.py +++ b/tests/test_sqlalchemy_adapter.py @@ -10,7 +10,7 @@ def test_missing_driver_reports_install_hint(self) -> None: message = str(context.exception) self.assertIn("snowflake", message) - self.assertIn("snowflake-sqlalchemy", message) + self.assertIn("sqlcheck[snowflake]", message) self.assertIn("Can't load plugin", message)