diff --git a/python/ggsql/__init__.py b/python/ggsql/__init__.py index 9b20a56..508df52 100644 --- a/python/ggsql/__init__.py +++ b/python/ggsql/__init__.py @@ -1,11 +1,12 @@ from __future__ import annotations import json -from typing import Any, Union +from typing import Any, Protocol, Union, runtime_checkable import altair import narwhals as nw from narwhals.typing import IntoFrame +import polars as pl from ggsql._ggsql import ( DuckDBReader, @@ -14,6 +15,10 @@ Spec, validate, execute, + ParseError, + ValidationError, + ReaderError, + WriterError, ) __all__ = [ @@ -22,10 +27,16 @@ "VegaLiteWriter", "Validated", "Spec", + "Reader", # Functions "validate", "execute", "render_altair", + # Exceptions + "ParseError", + "ValidationError", + "ReaderError", + "WriterError", ] __version__ = "0.2.7" @@ -41,6 +52,29 @@ ] +@runtime_checkable +class Reader(Protocol): + """Protocol for ggsql database readers. + + Any object implementing these methods can be used as a reader with + ``ggsql.execute()``. Native readers like ``DuckDBReader`` satisfy + this protocol automatically. + + Required methods + ---------------- + execute_sql(sql: str) -> polars.DataFrame + Execute a SQL query and return results as a polars DataFrame. + register(name: str, df: polars.DataFrame, replace: bool = False) -> None + Register a DataFrame as a named table for SQL queries. + """ + + def execute_sql(self, sql: str) -> pl.DataFrame: ... + + def register( + self, name: str, df: pl.DataFrame, replace: bool = False + ) -> None: ... 
+ + def _json_to_altair_chart(vegalite_json: str, **kwargs: Any) -> AltairChart: """Convert a Vega-Lite JSON string to the appropriate Altair chart type.""" spec = json.loads(vegalite_json) diff --git a/python/ggsql/_ggsql.pyi b/python/ggsql/_ggsql.pyi new file mode 100644 index 0000000..756a287 --- /dev/null +++ b/python/ggsql/_ggsql.pyi @@ -0,0 +1,438 @@ +"""Type stubs for the ggsql native module (_ggsql).""" + +from __future__ import annotations + +import polars as pl + +# --------------------------------------------------------------------------- +# Exceptions (all subclass ValueError for backwards compatibility) +# --------------------------------------------------------------------------- + +class ParseError(ValueError): ... +class ValidationError(ValueError): ... +class ReaderError(ValueError): ... +class WriterError(ValueError): ... + +# --------------------------------------------------------------------------- +# DuckDBReader +# --------------------------------------------------------------------------- + +class DuckDBReader: + """DuckDB database reader for executing SQL queries. + + Creates an in-memory or file-based DuckDB connection that can execute + SQL queries and register DataFrames as queryable tables. + + Parameters + ---------- + connection + DuckDB connection string. Use ``"duckdb://memory"`` for in-memory + database or ``"duckdb://path/to/file.db"`` for file-based database. + + Raises + ------ + ReaderError + If the connection string is invalid or the database cannot be opened. + """ + + def __init__(self, connection: str) -> None: ... + def execute_sql(self, sql: str) -> pl.DataFrame: + """Execute a SQL query and return results as a polars DataFrame. + + Parameters + ---------- + sql + The SQL query to execute. + + Returns + ------- + polars.DataFrame + The query result as a polars DataFrame. + + Raises + ------ + ReaderError + If the SQL is invalid or execution fails. + """ + ... 
+ + def register( + self, name: str, df: pl.DataFrame, replace: bool = False + ) -> None: + """Register a polars DataFrame as a named table. + + After registration the DataFrame can be queried by name in SQL. + + Parameters + ---------- + name + The table name to register under. + df + The DataFrame to register. Must be a polars DataFrame. + replace + Whether to replace an existing table with the same name. + + Raises + ------ + ReaderError + If registration fails or the table name is invalid. + """ + ... + + def unregister(self, name: str) -> None: + """Unregister a previously registered table. + + Parameters + ---------- + name + The table name to unregister. + + Raises + ------ + ReaderError + If the table was not registered or unregistration fails. + """ + ... + + def execute( + self, + query: str, + *, + data: dict[str, pl.DataFrame] | None = None, + ) -> Spec: + """Execute a ggsql query and return the visualization specification. + + This is the main entry point for creating visualizations. It parses + the query, executes the SQL portion, and returns a ``Spec`` ready + for rendering. + + Parameters + ---------- + query + The ggsql query (SQL + VISUALISE clause). + data + Optional dictionary mapping table names to DataFrames. Tables are + registered before execution and unregistered afterward (even on + error). + + Returns + ------- + Spec + The resolved visualization specification ready for rendering. + + Raises + ------ + ParseError + If the query syntax is invalid. + ValidationError + If the query has no VISUALISE clause or fails semantic checks. + ReaderError + If SQL execution fails. + """ + ... + +# --------------------------------------------------------------------------- +# VegaLiteWriter +# --------------------------------------------------------------------------- + +class VegaLiteWriter: + """Vega-Lite v6 JSON output writer. + + Converts visualization specifications to Vega-Lite v6 JSON. + """ + + def __init__(self) -> None: ... 
+ def render(self, spec: Spec) -> str: + """Render a Spec to a Vega-Lite JSON string. + + Parameters + ---------- + spec + The visualization specification from ``reader.execute()``. + + Returns + ------- + str + The Vega-Lite JSON string. + + Raises + ------ + WriterError + If rendering fails. + """ + ... + +# --------------------------------------------------------------------------- +# Validated +# --------------------------------------------------------------------------- + +class Validated: + """Result of ``validate()`` — query inspection without SQL execution. + + Contains information about query structure and any validation + errors/warnings. + """ + + def has_visual(self) -> bool: + """Whether the query contains a VISUALISE clause. + + Returns + ------- + bool + ``True`` if the query has a VISUALISE clause. + """ + ... + + def sql(self) -> str: + """The SQL portion (before VISUALISE). + + Returns + ------- + str + The SQL part of the query. + """ + ... + + def visual(self) -> str: + """The VISUALISE portion (raw text). + + Returns + ------- + str + The VISUALISE part of the query. + """ + ... + + def valid(self) -> bool: + """Whether the query is valid (no errors). + + Returns + ------- + bool + ``True`` if the query is syntactically and semantically valid. + """ + ... + + def errors(self) -> list[dict[str, object]]: + """Validation errors (fatal issues). + + Returns + ------- + list[dict] + List of error dictionaries with ``"message"`` (str) and + ``"location"`` (``{"line": int, "column": int}`` or ``None``) + keys. + """ + ... + + def warnings(self) -> list[dict[str, object]]: + """Validation warnings (non-fatal issues). + + Returns + ------- + list[dict] + List of warning dictionaries with ``"message"`` (str) and + ``"location"`` (``{"line": int, "column": int}`` or ``None``) + keys. + """ + ... 
+ +# --------------------------------------------------------------------------- +# Spec +# --------------------------------------------------------------------------- + +class Spec: + """Result of ``reader.execute()`` — resolved visualization spec. + + Contains the resolved plot specification, data, and metadata. + Use ``writer.render(spec)`` to generate output. + """ + + def metadata(self) -> dict[str, object]: + """Get visualization metadata. + + Returns + ------- + dict + Dictionary with ``"rows"`` (int), ``"columns"`` (list[str]), + and ``"layer_count"`` (int) keys. + """ + ... + + def sql(self) -> str: + """The main SQL query that was executed. + + Returns + ------- + str + The SQL query string. + """ + ... + + def visual(self) -> str: + """The VISUALISE portion (raw text). + + Returns + ------- + str + The VISUALISE clause text. + """ + ... + + def layer_count(self) -> int: + """Number of DRAW layers. + + Returns + ------- + int + The number of DRAW clauses in the visualization. + """ + ... + + def data(self) -> pl.DataFrame | None: + """Main query result DataFrame. + + Returns + ------- + polars.DataFrame or None + The main query result DataFrame, or ``None`` if not available. + """ + ... + + def layer_data(self, index: int) -> pl.DataFrame | None: + """Layer-specific DataFrame (from FILTER or FROM clause). + + Parameters + ---------- + index + The layer index (0-based). + + Returns + ------- + polars.DataFrame or None + The layer-specific DataFrame, or ``None`` if the layer uses + global data. + """ + ... + + def stat_data(self, index: int) -> pl.DataFrame | None: + """Statistical transform DataFrame. + + Parameters + ---------- + index + The layer index (0-based). + + Returns + ------- + polars.DataFrame or None + The stat transform DataFrame, or ``None`` if no stat transform. + """ + ... + + def layer_sql(self, index: int) -> str | None: + """Layer filter/source query. + + Parameters + ---------- + index + The layer index (0-based). 
+ + Returns + ------- + str or None + The filter SQL query, or ``None`` if the layer uses global data. + """ + ... + + def stat_sql(self, index: int) -> str | None: + """Stat transform query. + + Parameters + ---------- + index + The layer index (0-based). + + Returns + ------- + str or None + The stat transform SQL query, or ``None`` if no stat transform. + """ + ... + + def warnings(self) -> list[dict[str, object]]: + """Validation warnings from preparation. + + Returns + ------- + list[dict] + List of warning dictionaries with ``"message"`` (str) and + ``"location"`` (``{"line": int, "column": int}`` or ``None``) + keys. + """ + ... + +# --------------------------------------------------------------------------- +# Module-level functions +# --------------------------------------------------------------------------- + +def validate(query: str) -> Validated: + """Validate query syntax and semantics without executing SQL. + + Parameters + ---------- + query + The ggsql query to validate. + + Returns + ------- + Validated + Validation result with query inspection methods. + + Raises + ------ + ParseError + If validation fails unexpectedly (syntax errors are captured in + the returned ``Validated`` object, not raised). + """ + ... + +def execute( + query: str, + reader: object, + *, + data: dict[str, pl.DataFrame] | None = None, +) -> Spec: + """Execute a ggsql query with a reader (native or custom Python object). + + This is a convenience function for custom readers. For native readers, + prefer using ``reader.execute()`` directly. + + Parameters + ---------- + query + The ggsql query to execute. + reader + The database reader to execute SQL against. Can be a native + ``DuckDBReader`` for optimal performance, or any Python object with + an ``execute_sql(sql: str) -> polars.DataFrame`` method. + data + Optional dictionary mapping table names to DataFrames. Tables are + registered before execution and unregistered afterward (even on + error). 
+ + Returns + ------- + Spec + The resolved visualization specification ready for rendering. + + Raises + ------ + ParseError + If the query syntax is invalid. + ValidationError + If semantic validation fails. + ReaderError + If SQL execution fails. + """ + ... diff --git a/src/lib.rs b/src/lib.rs index 6e5c543..2041b14 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,6 +2,8 @@ // See: https://github.com/PyO3/pyo3/issues/4327 #![allow(clippy::useless_conversion)] +use pyo3::create_exception; +use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use pyo3::types::{PyBytes, PyDict, PyList}; use std::io::Cursor; @@ -12,6 +14,48 @@ use ggsql::validate::{validate as rust_validate, ValidationWarning}; use ggsql::writer::{VegaLiteWriter as RustVegaLiteWriter, Writer as RustWriter}; use ggsql::GgsqlError; +// ============================================================================ +// Custom Exception Classes +// ============================================================================ + +// All subclass ValueError for backwards compatibility +create_exception!( + ggsql, + ParseError, + PyValueError, + "Raised on query syntax errors." +); +create_exception!( + ggsql, + ValidationError, + PyValueError, + "Raised on semantic validation errors." +); +create_exception!( + ggsql, + ReaderError, + PyValueError, + "Raised on data source errors." +); +create_exception!( + ggsql, + WriterError, + PyValueError, + "Raised on output generation errors." +); + +/// Convert a GgsqlError to the appropriate typed Python exception. 
+fn ggsql_err_to_py(e: GgsqlError) -> PyErr { + let msg = e.to_string(); + match e { + GgsqlError::ParseError(_) => PyErr::new::<ParseError, _>(msg), + GgsqlError::ValidationError(_) => PyErr::new::<ValidationError, _>(msg), + GgsqlError::ReaderError(_) => PyErr::new::<ReaderError, _>(msg), + GgsqlError::WriterError(_) => PyErr::new::<WriterError, _>(msg), + GgsqlError::InternalError(_) => PyErr::new::<PyValueError, _>(msg), + } +} + use polars::prelude::{DataFrame, IpcReader, IpcWriter, SerReader, SerWriter}; // ============================================================================ @@ -144,9 +188,13 @@ impl Reader for PyReaderBridge { Python::attach(|py| { let py_df = polars_to_py(py, &df).map_err(|e| GgsqlError::ReaderError(e.to_string()))?; + let kwargs = PyDict::new(py); + kwargs + .set_item("replace", replace) + .map_err(|e| GgsqlError::ReaderError(e.to_string()))?; self.obj .bind(py) - .call_method1("register", (name, py_df, replace)) + .call_method("register", (name, py_df), Some(&kwargs)) .map_err(|e| GgsqlError::ReaderError(format!("Reader.register() failed: {}", e)))?; Ok(()) }) @@ -173,24 +221,6 @@ impl Reader for PyReaderBridge { } } -// ============================================================================ -// Native Reader Detection Macro -// ============================================================================ - -/// Macro to try native readers and fall back to bridge. -/// Adding new native readers = add to the macro invocation list. -macro_rules!
try_native_readers { - ($query:expr, $reader:expr, $($native_type:ty),*) => {{ - $( - if let Ok(native) = $reader.downcast::<$native_type>() { - return native.borrow().inner.execute($query) - .map(|s| PySpec { inner: s }) - .map_err(|e| PyErr::new::(e.to_string())); - } - )* - }}; -} - // ============================================================================ // PyDuckDBReader // ============================================================================ @@ -234,8 +264,8 @@ impl PyDuckDBReader { /// If the connection string is invalid or the database cannot be opened. #[new] fn new(connection: &str) -> PyResult { - let inner = RustDuckDBReader::from_connection_string(connection) - .map_err(|e| PyErr::new::(e.to_string()))?; + let inner = + RustDuckDBReader::from_connection_string(connection).map_err(ggsql_err_to_py)?; Ok(Self { inner }) } @@ -265,7 +295,7 @@ impl PyDuckDBReader { let rust_df = py_to_polars(py, df)?; self.inner .register(name, rust_df, replace) - .map_err(|e| PyErr::new::(e.to_string())) + .map_err(ggsql_err_to_py) } /// Unregister a previously registered table. @@ -280,9 +310,7 @@ impl PyDuckDBReader { /// ValueError /// If the table wasn't registered via this reader or unregistration fails. fn unregister(&self, name: &str) -> PyResult<()> { - self.inner - .unregister(name) - .map_err(|e| PyErr::new::(e.to_string())) + self.inner.unregister(name).map_err(ggsql_err_to_py) } /// Execute a SQL query and return the result as a DataFrame. @@ -302,10 +330,7 @@ impl PyDuckDBReader { /// ValueError /// If the SQL is invalid or execution fails. fn execute_sql(&self, py: Python<'_>, sql: &str) -> PyResult> { - let df = self - .inner - .execute_sql(sql) - .map_err(|e| PyErr::new::(e.to_string()))?; + let df = self.inner.execute_sql(sql).map_err(ggsql_err_to_py)?; polars_to_py(py, &df) } @@ -319,6 +344,9 @@ impl PyDuckDBReader { /// ---------- /// query : str /// The ggsql query (SQL + VISUALISE clause). 
+ /// data : dict[str, polars.DataFrame] | None + /// Optional dictionary mapping table names to DataFrames. Tables are + /// registered before execution and unregistered afterward (even on error). /// /// Returns /// ------- @@ -336,11 +364,65 @@ /// >>> spec = reader.execute("SELECT 1 AS x, 2 AS y VISUALISE x, y DRAW point") /// >>> writer = VegaLiteWriter() /// >>> json_output = writer.render(spec) - fn execute(&self, query: &str) -> PyResult<PySpec> { - self.inner + #[pyo3(signature = (query, *, data=None))] + fn execute( + &self, + py: Python<'_>, + query: &str, + data: Option<&Bound<'_, PyDict>>, + ) -> PyResult<PySpec> { + // Register DataFrames from data dict + let registered_names = if let Some(data_dict) = data { + self.register_data_dict(py, data_dict)? + } else { + vec![] + }; + + // Execute query (capture result, don't return early) + let result = self + .inner .execute(query) .map(|s| PySpec { inner: s }) - .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string())) + .map_err(ggsql_err_to_py); + + // Cleanup: unregister temporary tables (even on error) + for name in &registered_names { + let _ = self.inner.unregister(name); + } + + result + } +} + +impl PyDuckDBReader { + /// Check whether a table already exists in the reader. + fn table_exists(&self, name: &str) -> bool { + // A lightweight probe: try to select zero rows from the table. + self.inner + .execute_sql(&format!("SELECT 1 FROM {name} LIMIT 0")) + .is_ok() + } + + /// Register DataFrames from a Python dict. Returns list of *newly created* + /// table names for cleanup (pre-existing tables are not tracked).
+ fn register_data_dict( + &self, + py: Python<'_>, + data: &Bound<'_, PyDict>, + ) -> PyResult<Vec<String>> { + let mut names = Vec::new(); + for (key, value) in data.iter() { + let name: String = key.extract()?; + let existed = self.table_exists(&name); + let df = py_to_polars(py, &value)?; + self.inner + .register(&name, df, true) + .map_err(ggsql_err_to_py)?; + if !existed { + names.push(name); + } + } + Ok(names) } } @@ -401,9 +483,7 @@ impl PyVegaLiteWriter { /// >>> writer = VegaLiteWriter() /// >>> json_output = writer.render(spec) fn render(&self, spec: &PySpec) -> PyResult<String> { - self.inner - .render(&spec.inner) - .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string())) + self.inner.render(&spec.inner).map_err(ggsql_err_to_py) } } @@ -667,8 +747,7 @@ impl PySpec { /// If validation fails unexpectedly (not for syntax errors, which are captured). #[pyfunction] fn validate(query: &str) -> PyResult<PyValidated> { - let v = rust_validate(query) - .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?; + let v = rust_validate(query).map_err(ggsql_err_to_py)?; Ok(PyValidated { sql: v.sql().to_string(), @@ -711,6 +790,9 @@ fn validate(query: &str) -> PyResult<PyValidated> { /// The database reader to execute SQL against. Can be a native Reader /// for optimal performance, or any Python object with an /// `execute_sql(sql: str) -> polars.DataFrame` method. +/// data : dict[str, polars.DataFrame] | None +/// Optional dictionary mapping table names to DataFrames. Tables are +/// registered before execution and unregistered afterward (even on error).
/// /// Returns /// ------- @@ -737,19 +819,102 @@ /// >>> reader = MyReader() /// >>> spec = execute("SELECT * FROM data VISUALISE x, y DRAW point", reader) #[pyfunction] -fn execute(query: &str, reader: &Bound<'_, PyAny>) -> PyResult<PySpec> { - // Fast path: try all known native reader types - // Add new native readers to this list as they're implemented - try_native_readers!(query, reader, PyDuckDBReader); +#[pyo3(signature = (query, reader, *, data=None))] +fn execute( + py: Python<'_>, + query: &str, + reader: &Bound<'_, PyAny>, + data: Option<&Bound<'_, PyDict>>, +) -> PyResult<PySpec> { + // Native reader fast path: DuckDBReader + // Note: we can't use the try_native_readers! macro here because it uses `return` + // which would skip cleanup of registered tables. + if let Ok(native) = reader.downcast::<PyDuckDBReader>() { + // Register DataFrames if provided + let registered_names = if let Some(data_dict) = data { + native.borrow().register_data_dict(py, data_dict)? + } else { + vec![] + }; + + // Execute (capture result for cleanup) + let result = native + .borrow() + .inner + .execute(query) + .map(|s| PySpec { inner: s }) + .map_err(ggsql_err_to_py); + + // Cleanup: unregister temporary tables (even on error) + for name in &registered_names { + let _ = native.borrow().inner.unregister(name); + } + + return result; + } // Bridge path: wrap Python object as Reader + // Register DataFrames if provided + let registered_names = if let Some(data_dict) = data { + register_data_on_reader(py, reader, data_dict)? + } else { + vec![] + }; + let bridge = PyReaderBridge { obj: reader.clone().unbind(), }; - bridge + let result = bridge .execute(query) .map(|s| PySpec { inner: s }) - .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string())) + .map_err(ggsql_err_to_py); + + // Cleanup for bridge path + for name in &registered_names { + let _ = call_unregister(py, reader, name); + } + + result +} + +/// Register DataFrames from a Python dict onto a Python reader object.
+/// Returns list of registered names for cleanup. +/// Check whether a table exists via a Python reader's execute_sql method. +fn reader_table_exists(reader: &Bound<'_, PyAny>, name: &str) -> bool { + reader + .call_method1("execute_sql", (format!("SELECT 1 FROM {name} LIMIT 0"),)) + .is_ok() +} + +/// Register DataFrames from a Python dict onto a Python reader object. +/// Returns list of *newly created* table names for cleanup. +fn register_data_on_reader( + py: Python<'_>, + reader: &Bound<'_, PyAny>, + data: &Bound<'_, PyDict>, +) -> PyResult<Vec<String>> { + let mut names = Vec::new(); + for (key, value) in data.iter() { + let name: String = key.extract()?; + let existed = reader_table_exists(reader, &name); + let df = py_to_polars(py, &value)?; + let py_df = polars_to_py(py, &df)?; + let kwargs = PyDict::new(py); + kwargs.set_item("replace", true)?; + reader.call_method("register", (&name, py_df), Some(&kwargs))?; + if !existed { + names.push(name); + } + } + Ok(names) +} + +/// Call unregister on a reader if the method exists. +fn call_unregister(_py: Python<'_>, reader: &Bound<'_, PyAny>, name: &str) -> PyResult<()> { + if reader.hasattr("unregister")?
{ + reader.call_method1("unregister", (name,))?; + } + Ok(()) } // ============================================================================ @@ -758,6 +923,12 @@ fn execute(query: &str, reader: &Bound<'_, PyAny>) -> PyResult<PySpec> { #[pymodule] fn _ggsql(m: &Bound<'_, PyModule>) -> PyResult<()> { + // Exceptions + m.add("ParseError", m.py().get_type::<ParseError>())?; + m.add("ValidationError", m.py().get_type::<ValidationError>())?; + m.add("ReaderError", m.py().get_type::<ReaderError>())?; + m.add("WriterError", m.py().get_type::<WriterError>())?; + // Classes m.add_class::<PyDuckDBReader>()?; m.add_class::<PyVegaLiteWriter>()?; diff --git a/tests/test_ggsql.py b/tests/test_ggsql.py index fbe4b13..01ac3ac 100644 --- a/tests/test_ggsql.py +++ b/tests/test_ggsql.py @@ -402,7 +402,7 @@ def __init__(self): def execute_sql(self, sql: str) -> pl.DataFrame: return self.conn.execute(sql).pl() - def register(self, name: str, df: pl.DataFrame, _replace: bool) -> None: + def register(self, name: str, df: pl.DataFrame, replace: bool = False) -> None: self.conn.register(name, df) reader = RegisterReader() @@ -453,7 +453,7 @@ def __init__(self): def execute_sql(self, sql: str) -> pl.DataFrame: return self.conn.execute(sql).pl() - def register(self, name: str, df: pl.DataFrame, _replace: bool) -> None: + def register(self, name: str, df: pl.DataFrame, replace: bool = False) -> None: self.conn.register(name, df) reader = DuckDBBackedReader() @@ -484,7 +484,7 @@ def execute_sql(self, sql: str) -> pl.DataFrame: self.execute_calls.append(sql) return self.conn.execute(sql).pl() - def register(self, name: str, df: pl.DataFrame, _replace: bool) -> None: + def register(self, name: str, df: pl.DataFrame, replace: bool = False) -> None: self.conn.register(name, df) reader = RecordingReader() @@ -532,6 +532,92 @@ def unregister(self, name: str) -> None: assert "point" in json_output + +class TestExceptions: + """Tests for typed exception classes.""" + + def test_parse_error_on_invalid_syntax(self): + """Invalid syntax raises ParseError when executing.""" + with
pytest.raises(ggsql.ParseError): + reader = ggsql.DuckDBReader("duckdb://memory") + reader.execute("SELECT 1 AS x VISUALISE DRAW not_a_geom") + + def test_parse_error_is_value_error(self): + """ParseError is a subclass of ValueError for backwards compat.""" + assert issubclass(ggsql.ParseError, ValueError) + + def test_validation_error_on_missing_aesthetic(self): + """Missing required aesthetic raises ValidationError.""" + with pytest.raises(ggsql.ValidationError): + reader = ggsql.DuckDBReader("duckdb://memory") + reader.execute("SELECT 1 AS x VISUALISE DRAW point MAPPING x AS x") + + def test_validation_error_is_value_error(self): + """ValidationError is a subclass of ValueError for backwards compat.""" + assert issubclass(ggsql.ValidationError, ValueError) + + def test_reader_error_on_bad_sql(self): + """Bad SQL raises ReaderError.""" + with pytest.raises(ggsql.ReaderError): + reader = ggsql.DuckDBReader("duckdb://memory") + reader.execute( + "SELECT * FROM nonexistent_table VISUALISE DRAW point MAPPING x AS x, y AS y" + ) + + def test_reader_error_is_value_error(self): + """ReaderError is a subclass of ValueError for backwards compat.""" + assert issubclass(ggsql.ReaderError, ValueError) + + def test_writer_error_is_value_error(self): + """WriterError is a subclass of ValueError for backwards compat.""" + assert issubclass(ggsql.WriterError, ValueError) + + def test_all_exceptions_exported(self): + """All exception classes are accessible from ggsql module.""" + assert hasattr(ggsql, "ParseError") + assert hasattr(ggsql, "ValidationError") + assert hasattr(ggsql, "ReaderError") + assert hasattr(ggsql, "WriterError") + + +class TestReaderProtocol: + """Tests for Reader protocol.""" + + def test_duckdb_reader_is_reader(self): + """Native DuckDBReader satisfies the Reader protocol.""" + reader = ggsql.DuckDBReader("duckdb://memory") + assert isinstance(reader, ggsql.Reader) + + def test_custom_reader_is_reader(self): + """Custom reader with correct methods 
satisfies the Reader protocol.""" + + class MyReader: + def execute_sql(self, sql: str) -> pl.DataFrame: + return pl.DataFrame({"x": [1]}) + + def register( + self, name: str, df: pl.DataFrame, replace: bool = False + ) -> None: + pass + + reader = MyReader() + assert isinstance(reader, ggsql.Reader) + + def test_incomplete_reader_is_not_reader(self): + """Object missing required methods is not a Reader.""" + + class NotAReader: + def execute_sql(self, sql: str) -> pl.DataFrame: + return pl.DataFrame({"x": [1]}) + # Missing register() + + obj = NotAReader() + assert not isinstance(obj, ggsql.Reader) + + def test_reader_is_exported(self): + """Reader is accessible from ggsql module.""" + assert hasattr(ggsql, "Reader") + + class TestVegaLiteWriterRenderChart: """Tests for VegaLiteWriter.render_chart() method.""" @@ -569,3 +655,135 @@ def test_render_chart_facet(self): chart = writer.render_chart(spec, validate=False) assert isinstance(chart, altair.FacetChart) + +class TestExecuteWithData: + """Tests for reader.execute() with data= parameter.""" + + def test_execute_with_single_dataframe(self): + """Can pass a single DataFrame via data dict.""" + reader = ggsql.DuckDBReader("duckdb://memory") + df = pl.DataFrame({"x": [1, 2, 3], "y": [10, 20, 30]}) + spec = reader.execute( + "SELECT * FROM mydata VISUALISE x, y DRAW point", + data={"mydata": df}, + ) + assert spec.metadata()["rows"] == 3 + + def test_execute_with_multiple_dataframes(self): + """Can pass multiple DataFrames via data dict.""" + reader = ggsql.DuckDBReader("duckdb://memory") + df1 = pl.DataFrame({"id": [1, 2, 3], "y": [10, 20, 30]}) + df2 = pl.DataFrame({"id": [2, 3], "category": ["A", "B"]}) + spec = reader.execute( + "SELECT t1.id AS x, t1.y FROM t1 JOIN t2 ON t1.id = t2.id " + "VISUALISE x, y DRAW point", + data={"t1": df1, "t2": df2}, + ) + assert spec.metadata()["rows"] == 2 + + def test_execute_with_data_cleans_up(self): + """DataFrames passed via data= are unregistered after execution.""" + 
reader = ggsql.DuckDBReader("duckdb://memory") + df = pl.DataFrame({"x": [1, 2, 3], "y": [10, 20, 30]}) + reader.execute( + "SELECT * FROM temp VISUALISE x, y DRAW point", + data={"temp": df}, + ) + # Table should be cleaned up — querying it should fail + with pytest.raises(ggsql.ReaderError): + reader.execute_sql("SELECT * FROM temp") + + def test_execute_with_data_cleans_up_on_error(self): + """DataFrames are unregistered even if execution fails.""" + reader = ggsql.DuckDBReader("duckdb://memory") + df = pl.DataFrame({"x": [1, 2, 3], "y": [10, 20, 30]}) + with pytest.raises(ggsql.ParseError): + reader.execute( + "SELECT * FROM temp VISUALISE DRAW not_a_geom", + data={"temp": df}, + ) + # Table should still be cleaned up + with pytest.raises(ggsql.ReaderError): + reader.execute_sql("SELECT * FROM temp") + + def test_execute_without_data_still_works(self): + """Calling execute() without data= still works as before.""" + reader = ggsql.DuckDBReader("duckdb://memory") + spec = reader.execute("SELECT 1 AS x, 2 AS y VISUALISE x, y DRAW point") + assert spec.metadata()["rows"] == 1 + + def test_execute_with_empty_data(self): + """Passing empty data= dict works fine.""" + reader = ggsql.DuckDBReader("duckdb://memory") + spec = reader.execute( + "SELECT 1 AS x, 2 AS y VISUALISE x, y DRAW point", + data={}, + ) + assert spec.metadata()["rows"] == 1 + + def test_execute_with_data_preserves_preexisting_table(self): + """data= does not unregister a table that existed before the call.""" + reader = ggsql.DuckDBReader("duckdb://memory") + existing = pl.DataFrame({"x": [1, 2], "y": [10, 20]}) + reader.register("mytable", existing) + + # Pass same name via data= — should replace temporarily, then NOT unregister + override = pl.DataFrame({"x": [3, 4, 5], "y": [30, 40, 50]}) + spec = reader.execute( + "SELECT * FROM mytable VISUALISE x, y DRAW point", + data={"mytable": override}, + ) + assert spec.metadata()["rows"] == 3 + + # The pre-existing table should still be queryable (not 
unregistered) + result = reader.execute_sql("SELECT * FROM mytable") + assert result.shape[0] > 0 + + def test_module_execute_with_data(self): + """Module-level execute() also supports data= parameter.""" + reader = ggsql.DuckDBReader("duckdb://memory") + df = pl.DataFrame({"x": [1, 2, 3], "y": [10, 20, 30]}) + spec = ggsql.execute( + "SELECT * FROM mydata VISUALISE x, y DRAW point", + reader, + data={"mydata": df}, + ) + assert spec.metadata()["rows"] == 3 + + +class TestTypeStubs: + """Tests for type stub presence and correctness.""" + + def test_stub_file_exists(self): + """Type stub file exists for the native module.""" + import pathlib + + assert ggsql.__file__ is not None + ggsql_dir = pathlib.Path(ggsql.__file__).parent + stub_path = ggsql_dir / "_ggsql.pyi" + assert stub_path.exists(), f"Type stub not found at {stub_path}" + + def test_stub_exports_match_module(self): + """All public names from _ggsql are in the stub.""" + import pathlib + + assert ggsql.__file__ is not None + ggsql_dir = pathlib.Path(ggsql.__file__).parent + stub_path = ggsql_dir / "_ggsql.pyi" + stub_text = stub_path.read_text() + + # Core classes and functions should be in the stub + for name in [ + "DuckDBReader", + "VegaLiteWriter", + "Validated", + "Spec", + "validate", + "execute", + "ParseError", + "ValidationError", + "ReaderError", + "WriterError", + ]: + assert name in stub_text, f"{name} not found in type stub" +