Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sqlcheck/cli/commands/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import typer

from sqlcheck.cli.common import discover_cases
from sqlcheck.cli.discovery import discover_cases
from sqlcheck.reports import build_plan_payload, write_case_plan


Expand Down
8 changes: 5 additions & 3 deletions sqlcheck/cli/commands/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@

import typer

from sqlcheck.cli.common import build_adapter, discover_cases, print_results
from sqlcheck.cli.connections import build_connector
from sqlcheck.cli.discovery import discover_cases
from sqlcheck.cli.output import print_results
from sqlcheck.function_registry import default_registry
from sqlcheck.plugins import load_plugins
from sqlcheck.reports import write_json, write_junit, write_plan
Expand Down Expand Up @@ -42,9 +44,9 @@ def run(
if plugin:
load_plugins(plugin, registry)

adapter = build_adapter(connection)
connector = build_connector(connection)

results = run_cases(cases, adapter, registry, workers=workers)
results = run_cases(cases, connector, registry, workers=workers)

print_results(results, engine=connection)

Expand Down
114 changes: 10 additions & 104 deletions sqlcheck/cli/common.py
Original file line number Diff line number Diff line change
@@ -1,104 +1,10 @@
from __future__ import annotations

import os
import re
from pathlib import Path

import typer
from rich import box
from rich.console import Console
from rich.panel import Panel
from rich.table import Table

from sqlcheck.db_connector import DBConnector, SQLAlchemyConnector
from sqlcheck.models import TestCase, TestResult
from sqlcheck.runner import build_test_case, discover_files


def discover_cases(target: Path, pattern: str) -> list[TestCase]:
paths = discover_files(target, pattern)
if not paths:
print("No test files found.")
raise typer.Exit(code=1)
return [build_test_case(path) for path in paths]


def _connection_env_var(name: str) -> str:
slug = re.sub(r"[^A-Za-z0-9]+", "_", name).strip("_").upper()
return f"SQLCHECK_CONN_{slug}"


def resolve_connection_uri(name: str) -> str:
env_var = _connection_env_var(name)
value = os.getenv(env_var)
if not value:
raise ValueError(f"Missing connection URI. Set {env_var}.")
return value


def build_adapter(connection: str) -> DBConnector:
connection_uri = resolve_connection_uri(connection)
return SQLAlchemyConnector(connection_uri=connection_uri)


def print_results(results: list[TestResult], engine: str | None = None) -> None:
total = len(results)
failures = [result for result in results if not result.success]
passed = total - len(failures)
console = Console()

header = "SQLCheck"
if engine:
header += f" ({engine})"
header += f" — {total} tests, {passed} passed"
if failures:
header += f", {len(failures)} failed"

if failures:
console.print("[bold]Failures:[/bold]")
for result in failures:
console.print(
f"[red]FAIL[/red] {result.case.metadata.name} [dim]{result.case.path}[/dim]"
)
for func_result in result.function_results:
if not func_result.success:
message = func_result.message or "Expectation failed"
console.print(f" {message}")
if result.output.stderr:
console.print(
Panel(
result.output.stderr.strip(),
title="STDERR",
border_style="red",
)
)
if result.output.stdout:
console.print(
Panel(
result.output.stdout.strip(),
title="STDOUT",
border_style="yellow",
)
)
console.print()

table = Table(box=box.ASCII, show_header=True, header_style="bold")
table.add_column("STATUS", style="bold")
table.add_column("TEST")
table.add_column("DURATION", justify="right")
table.add_column("PATH")

for result in results:
duration = f"{result.status.duration_s:.2f}s"
status = "PASS" if result.success else "FAIL"
status_style = "green" if result.success else "red"
table.add_row(
f"[{status_style}]{status}[/{status_style}]",
result.case.metadata.name,
duration,
str(result.case.path),
)

console.print(table)
console.print()
console.print(f"[bold]{header}[/bold]")
from sqlcheck.cli.connections import build_connector, resolve_connection_uri
from sqlcheck.cli.discovery import discover_cases
from sqlcheck.cli.output import print_results

__all__ = [
"build_connector",
"discover_cases",
"print_results",
"resolve_connection_uri",
]
26 changes: 26 additions & 0 deletions sqlcheck/cli/connections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from __future__ import annotations

import os
import re

from sqlcheck.db_connector import DBConnector, SQLAlchemyConnector


def _connection_env_var(name: str) -> str:
slug = re.sub(r"[^A-Za-z0-9]+", "_", name).strip("_").upper()
return f"SQLCHECK_CONN_{slug}"


def resolve_connection_uri(name: str) -> str:
env_var = _connection_env_var(name)
value = os.getenv(env_var)
if not value:
raise ValueError(f"Missing connection URI. Set {env_var}.")
return value


def build_connector(connection: str) -> DBConnector:
connection_uri = resolve_connection_uri(connection)
return SQLAlchemyConnector(connection_uri=connection_uri)

__all__ = ["build_connector", "resolve_connection_uri"]
19 changes: 19 additions & 0 deletions sqlcheck/cli/discovery.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from __future__ import annotations

from pathlib import Path

import typer

from sqlcheck.discovery import build_test_case, discover_files
from sqlcheck.models import TestCase


def discover_cases(target: Path, pattern: str) -> list[TestCase]:
paths = discover_files(target, pattern)
if not paths:
print("No test files found.")
raise typer.Exit(code=1)
return [build_test_case(path) for path in paths]


__all__ = ["discover_cases"]
74 changes: 74 additions & 0 deletions sqlcheck/cli/output.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from __future__ import annotations

from rich import box
from rich.console import Console
from rich.panel import Panel
from rich.table import Table

from sqlcheck.models import TestResult


def print_results(results: list[TestResult], engine: str | None = None) -> None:
total = len(results)
failures = [result for result in results if not result.success]
passed = total - len(failures)
console = Console()

header = "SQLCheck"
if engine:
header += f" ({engine})"
header += f" — {total} tests, {passed} passed"
if failures:
header += f", {len(failures)} failed"

if failures:
console.print("[bold]Failures:[/bold]")
for result in failures:
console.print(
f"[red]FAIL[/red] {result.case.metadata.name} [dim]{result.case.path}[/dim]"
)
for func_result in result.function_results:
if not func_result.success:
message = func_result.message or "Expectation failed"
console.print(f" {message}")
if result.output.stderr:
console.print(
Panel(
result.output.stderr.strip(),
title="STDERR",
border_style="red",
)
)
if result.output.stdout:
console.print(
Panel(
result.output.stdout.strip(),
title="STDOUT",
border_style="yellow",
)
)
console.print()

table = Table(box=box.ASCII, show_header=True, header_style="bold")
table.add_column("STATUS", style="bold")
table.add_column("TEST")
table.add_column("DURATION", justify="right")
table.add_column("PATH")

for result in results:
duration = f"{result.status.duration_s:.2f}s"
status = "PASS" if result.success else "FAIL"
status_style = "green" if result.success else "red"
table.add_row(
f"[{status_style}]{status}[/{status_style}]",
result.case.metadata.name,
duration,
str(result.case.path),
)

console.print(table)
console.print()
console.print(f"[bold]{header}[/bold]")


__all__ = ["print_results"]
5 changes: 5 additions & 0 deletions sqlcheck/connectors/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Database connector implementations."""

from sqlcheck.connectors.sqlalchemy import SQLAlchemyConnector

__all__ = ["SQLAlchemyConnector"]
98 changes: 98 additions & 0 deletions sqlcheck/connectors/sqlalchemy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
from __future__ import annotations

import time
from contextlib import contextmanager
from typing import Iterator
from urllib.parse import urlparse

from sqlalchemy import create_engine
from sqlalchemy.exc import NoSuchModuleError, SQLAlchemyError

from sqlcheck.db_connector import CommandDBConnector, DBSession, ExecutionResult
from sqlcheck.models import ExecutionOutput, ExecutionStatus, SQLParsed


class SQLAlchemyConnector(CommandDBConnector):
name = "sqlalchemy"

def __init__(self, connection_uri: str) -> None:
self.connection_uri = connection_uri
try:
self.engine = create_engine(connection_uri)
except NoSuchModuleError as exc:
dialect = _dialect_from_uri(connection_uri)
hint = _driver_hint(dialect)
message = (
f"Missing SQLAlchemy driver for '{dialect}'. {hint} "
f"Original error: {exc}"
)
raise ValueError(message) from exc

def execute(self, sql_parsed: SQLParsed, timeout: float | None = None) -> ExecutionResult:
with self.engine.connect() as connection:
return self._execute_with_connection(connection, sql_parsed, timeout)

@contextmanager
def open_session(self) -> Iterator[DBSession]:
with self.engine.connect() as connection:
def _execute(sql_parsed: SQLParsed, timeout: float | None = None) -> ExecutionResult:
return self._execute_with_connection(connection, sql_parsed, timeout)

yield DBSession(_execute)

def _execute_with_connection(
self,
connection: object,
sql_parsed: SQLParsed,
timeout: float | None = None,
) -> ExecutionResult:
start = time.perf_counter()
stdout = ""
stderr = ""
rows: list[list[object]] = []
returncode = 0
success = True
try:
exec_connection = connection
if timeout is not None:
exec_connection = connection.execution_options(timeout=timeout)
with exec_connection.begin():
statements = sql_parsed.statements
if not statements:
statements = []
for statement in statements or []:
result = exec_connection.exec_driver_sql(statement.text)
if result.returns_rows:
rows = [list(row) for row in result.fetchall()]
if not statements and sql_parsed.source.strip():
result = exec_connection.exec_driver_sql(sql_parsed.source)
if result.returns_rows:
rows = [list(row) for row in result.fetchall()]
except SQLAlchemyError as exc:
success = False
returncode = 1
stderr = str(exc)
duration = time.perf_counter() - start
status = ExecutionStatus(success=success, returncode=returncode, duration_s=duration)
output = ExecutionOutput(stdout=stdout, stderr=stderr, rows=rows)
return ExecutionResult(status=status, output=output)


def _dialect_from_uri(connection_uri: str) -> str:
scheme = urlparse(connection_uri).scheme
return scheme.split("+", maxsplit=1)[0] if scheme else "unknown"


def _driver_hint(dialect: str) -> str:
hints = {
"snowflake": "Install the optional dependency with: pip install sqlcheck[snowflake]",
"duckdb": "Install the optional dependency with: pip install sqlcheck[duckdb]",
"postgresql": "Install the optional dependency with: pip install sqlcheck[postgres]",
"mysql": "Install the optional dependency with: pip install sqlcheck[mysql]",
"databricks": "Install the optional dependency with: pip install sqlcheck[databricks]",
}
message = hints.get(
dialect,
"Install the database-specific SQLAlchemy dialect/driver for this URI.",
)
return f"{message} See https://docs.sqlalchemy.org/en/20/dialects/ for details."
Loading