Skip to content

Commit 19b49cc

Browse files
committed
port over changes from posit-dev/ggsql#178
1 parent 94a7780 commit 19b49cc

3 files changed

Lines changed: 189 additions & 23 deletions

File tree

python/ggsql/__init__.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
from __future__ import annotations
22

33
import json
4-
from typing import Any, Union
4+
from typing import Any, Protocol, Union, runtime_checkable
55

66
import altair
77
import narwhals as nw
88
from narwhals.typing import IntoFrame
9+
import polars as pl
910

1011
from ggsql._ggsql import (
1112
DuckDBReader,
@@ -14,6 +15,10 @@
1415
Spec,
1516
validate,
1617
execute,
18+
ParseError,
19+
ValidationError,
20+
ReaderError,
21+
WriterError,
1722
)
1823

1924
__all__ = [
@@ -22,10 +27,16 @@
2227
"VegaLiteWriter",
2328
"Validated",
2429
"Spec",
30+
"Reader",
2531
# Functions
2632
"validate",
2733
"execute",
2834
"render_altair",
35+
# Exceptions
36+
"ParseError",
37+
"ValidationError",
38+
"ReaderError",
39+
"WriterError",
2940
]
3041
__version__ = "0.2.7"
3142

@@ -41,6 +52,29 @@
4152
]
4253

4354

55+
@runtime_checkable
56+
class Reader(Protocol):
57+
"""Protocol for ggsql database readers.
58+
59+
Any object implementing these methods can be used as a reader with
60+
``ggsql.execute()``. Native readers like ``DuckDBReader`` satisfy
61+
this protocol automatically.
62+
63+
Required methods
64+
----------------
65+
execute_sql(sql: str) -> polars.DataFrame
66+
Execute a SQL query and return results as a polars DataFrame.
67+
register(name: str, df: polars.DataFrame, replace: bool = False) -> None
68+
Register a DataFrame as a named table for SQL queries.
69+
"""
70+
71+
def execute_sql(self, sql: str) -> pl.DataFrame: ...
72+
73+
def register(
74+
self, name: str, df: pl.DataFrame, replace: bool = False
75+
) -> None: ...
76+
77+
4478
def _json_to_altair_chart(vegalite_json: str, **kwargs: Any) -> AltairChart:
4579
"""Convert a Vega-Lite JSON string to the appropriate Altair chart type."""
4680
spec = json.loads(vegalite_json)

src/lib.rs

Lines changed: 65 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
// See: https://github.com/PyO3/pyo3/issues/4327
33
#![allow(clippy::useless_conversion)]
44

5+
use pyo3::create_exception;
6+
use pyo3::exceptions::PyValueError;
57
use pyo3::prelude::*;
68
use pyo3::types::{PyBytes, PyDict, PyList};
79
use std::io::Cursor;
@@ -12,6 +14,48 @@ use ggsql::validate::{validate as rust_validate, ValidationWarning};
1214
use ggsql::writer::{VegaLiteWriter as RustVegaLiteWriter, Writer as RustWriter};
1315
use ggsql::GgsqlError;
1416

17+
// ============================================================================
18+
// Custom Exception Classes
19+
// ============================================================================
20+
21+
// All of these subclass Python's ValueError, so existing `except ValueError` handlers keep working.
22+
create_exception!(
23+
ggsql,
24+
ParseError,
25+
PyValueError,
26+
"Raised on query syntax errors."
27+
);
28+
create_exception!(
29+
ggsql,
30+
ValidationError,
31+
PyValueError,
32+
"Raised on semantic validation errors."
33+
);
34+
create_exception!(
35+
ggsql,
36+
ReaderError,
37+
PyValueError,
38+
"Raised on data source errors."
39+
);
40+
create_exception!(
41+
ggsql,
42+
WriterError,
43+
PyValueError,
44+
"Raised on output generation errors."
45+
);
46+
47+
/// Convert a GgsqlError into its matching typed Python exception; InternalError has no dedicated class and maps to plain ValueError.
48+
fn ggsql_err_to_py(e: GgsqlError) -> PyErr {
49+
let msg = e.to_string();
50+
match e {
51+
GgsqlError::ParseError(_) => PyErr::new::<ParseError, _>(msg),
52+
GgsqlError::ValidationError(_) => PyErr::new::<ValidationError, _>(msg),
53+
GgsqlError::ReaderError(_) => PyErr::new::<ReaderError, _>(msg),
54+
GgsqlError::WriterError(_) => PyErr::new::<WriterError, _>(msg),
55+
GgsqlError::InternalError(_) => PyErr::new::<PyValueError, _>(msg),
56+
}
57+
}
58+
1559
use polars::prelude::{DataFrame, IpcReader, IpcWriter, SerReader, SerWriter};
1660

1761
// ============================================================================
@@ -144,9 +188,13 @@ impl Reader for PyReaderBridge {
144188
Python::attach(|py| {
145189
let py_df =
146190
polars_to_py(py, &df).map_err(|e| GgsqlError::ReaderError(e.to_string()))?;
191+
let kwargs = PyDict::new(py);
192+
kwargs
193+
.set_item("replace", replace)
194+
.map_err(|e| GgsqlError::ReaderError(e.to_string()))?;
147195
self.obj
148196
.bind(py)
149-
.call_method1("register", (name, py_df, replace))
197+
.call_method("register", (name, py_df), Some(&kwargs))
150198
.map_err(|e| GgsqlError::ReaderError(format!("Reader.register() failed: {}", e)))?;
151199
Ok(())
152200
})
@@ -185,7 +233,7 @@ macro_rules! try_native_readers {
185233
if let Ok(native) = $reader.downcast::<$native_type>() {
186234
return native.borrow().inner.execute($query)
187235
.map(|s| PySpec { inner: s })
188-
.map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()));
236+
.map_err(ggsql_err_to_py);
189237
}
190238
)*
191239
}};
@@ -234,8 +282,8 @@ impl PyDuckDBReader {
234282
/// If the connection string is invalid or the database cannot be opened.
235283
#[new]
236284
fn new(connection: &str) -> PyResult<Self> {
237-
let inner = RustDuckDBReader::from_connection_string(connection)
238-
.map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))?;
285+
let inner =
286+
RustDuckDBReader::from_connection_string(connection).map_err(ggsql_err_to_py)?;
239287
Ok(Self { inner })
240288
}
241289

@@ -265,7 +313,7 @@ impl PyDuckDBReader {
265313
let rust_df = py_to_polars(py, df)?;
266314
self.inner
267315
.register(name, rust_df, replace)
268-
.map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))
316+
.map_err(ggsql_err_to_py)
269317
}
270318

271319
/// Unregister a previously registered table.
@@ -280,9 +328,7 @@ impl PyDuckDBReader {
280328
/// ValueError
281329
/// If the table wasn't registered via this reader or unregistration fails.
282330
fn unregister(&self, name: &str) -> PyResult<()> {
283-
self.inner
284-
.unregister(name)
285-
.map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))
331+
self.inner.unregister(name).map_err(ggsql_err_to_py)
286332
}
287333

288334
/// Execute a SQL query and return the result as a DataFrame.
@@ -302,10 +348,7 @@ impl PyDuckDBReader {
302348
/// ValueError
303349
/// If the SQL is invalid or execution fails.
304350
fn execute_sql(&self, py: Python<'_>, sql: &str) -> PyResult<Py<PyAny>> {
305-
let df = self
306-
.inner
307-
.execute_sql(sql)
308-
.map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))?;
351+
let df = self.inner.execute_sql(sql).map_err(ggsql_err_to_py)?;
309352
polars_to_py(py, &df)
310353
}
311354

@@ -340,7 +383,7 @@ impl PyDuckDBReader {
340383
self.inner
341384
.execute(query)
342385
.map(|s| PySpec { inner: s })
343-
.map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))
386+
.map_err(ggsql_err_to_py)
344387
}
345388
}
346389

@@ -401,9 +444,7 @@ impl PyVegaLiteWriter {
401444
/// >>> writer = VegaLiteWriter()
402445
/// >>> json_output = writer.render(spec)
403446
fn render(&self, spec: &PySpec) -> PyResult<String> {
404-
self.inner
405-
.render(&spec.inner)
406-
.map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))
447+
self.inner.render(&spec.inner).map_err(ggsql_err_to_py)
407448
}
408449
}
409450

@@ -667,8 +708,7 @@ impl PySpec {
667708
/// If validation fails unexpectedly (not for syntax errors, which are captured).
668709
#[pyfunction]
669710
fn validate(query: &str) -> PyResult<PyValidated> {
670-
let v = rust_validate(query)
671-
.map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))?;
711+
let v = rust_validate(query).map_err(ggsql_err_to_py)?;
672712

673713
Ok(PyValidated {
674714
sql: v.sql().to_string(),
@@ -749,7 +789,7 @@ fn execute(query: &str, reader: &Bound<'_, PyAny>) -> PyResult<PySpec> {
749789
bridge
750790
.execute(query)
751791
.map(|s| PySpec { inner: s })
752-
.map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))
792+
.map_err(ggsql_err_to_py)
753793
}
754794

755795
// ============================================================================
@@ -758,6 +798,12 @@ fn execute(query: &str, reader: &Bound<'_, PyAny>) -> PyResult<PySpec> {
758798

759799
#[pymodule]
760800
fn _ggsql(m: &Bound<'_, PyModule>) -> PyResult<()> {
801+
// Exceptions
802+
m.add("ParseError", m.py().get_type::<ParseError>())?;
803+
m.add("ValidationError", m.py().get_type::<ValidationError>())?;
804+
m.add("ReaderError", m.py().get_type::<ReaderError>())?;
805+
m.add("WriterError", m.py().get_type::<WriterError>())?;
806+
761807
// Classes
762808
m.add_class::<PyDuckDBReader>()?;
763809
m.add_class::<PyVegaLiteWriter>()?;

tests/test_ggsql.py

Lines changed: 89 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,7 @@ def __init__(self):
402402
def execute_sql(self, sql: str) -> pl.DataFrame:
403403
return self.conn.execute(sql).pl()
404404

405-
def register(self, name: str, df: pl.DataFrame, _replace: bool) -> None:
405+
def register(self, name: str, df: pl.DataFrame, replace: bool = False) -> None:
406406
self.conn.register(name, df)
407407

408408
reader = RegisterReader()
@@ -453,7 +453,7 @@ def __init__(self):
453453
def execute_sql(self, sql: str) -> pl.DataFrame:
454454
return self.conn.execute(sql).pl()
455455

456-
def register(self, name: str, df: pl.DataFrame, _replace: bool) -> None:
456+
def register(self, name: str, df: pl.DataFrame, replace: bool = False) -> None:
457457
self.conn.register(name, df)
458458

459459
reader = DuckDBBackedReader()
@@ -484,7 +484,7 @@ def execute_sql(self, sql: str) -> pl.DataFrame:
484484
self.execute_calls.append(sql)
485485
return self.conn.execute(sql).pl()
486486

487-
def register(self, name: str, df: pl.DataFrame, _replace: bool) -> None:
487+
def register(self, name: str, df: pl.DataFrame, replace: bool = False) -> None:
488488
self.conn.register(name, df)
489489

490490
reader = RecordingReader()
@@ -532,6 +532,92 @@ def unregister(self, name: str) -> None:
532532
assert "point" in json_output
533533

534534

535+
class TestExceptions:
536+
"""Tests for typed exception classes."""
537+
538+
def test_parse_error_on_invalid_syntax(self):
539+
"""Invalid syntax raises ParseError when executing."""
540+
with pytest.raises(ggsql.ParseError):
541+
reader = ggsql.DuckDBReader("duckdb://memory")
542+
reader.execute("SELECT 1 AS x VISUALISE DRAW not_a_geom")
543+
544+
def test_parse_error_is_value_error(self):
545+
"""ParseError is a subclass of ValueError for backwards compat."""
546+
assert issubclass(ggsql.ParseError, ValueError)
547+
548+
def test_validation_error_on_missing_aesthetic(self):
549+
"""Missing required aesthetic raises ValidationError."""
550+
with pytest.raises(ggsql.ValidationError):
551+
reader = ggsql.DuckDBReader("duckdb://memory")
552+
reader.execute("SELECT 1 AS x VISUALISE DRAW point MAPPING x AS x")
553+
554+
def test_validation_error_is_value_error(self):
555+
"""ValidationError is a subclass of ValueError for backwards compat."""
556+
assert issubclass(ggsql.ValidationError, ValueError)
557+
558+
def test_reader_error_on_bad_sql(self):
559+
"""Bad SQL raises ReaderError."""
560+
with pytest.raises(ggsql.ReaderError):
561+
reader = ggsql.DuckDBReader("duckdb://memory")
562+
reader.execute(
563+
"SELECT * FROM nonexistent_table VISUALISE DRAW point MAPPING x AS x, y AS y"
564+
)
565+
566+
def test_reader_error_is_value_error(self):
567+
"""ReaderError is a subclass of ValueError for backwards compat."""
568+
assert issubclass(ggsql.ReaderError, ValueError)
569+
570+
def test_writer_error_is_value_error(self):
571+
"""WriterError is a subclass of ValueError for backwards compat."""
572+
assert issubclass(ggsql.WriterError, ValueError)
573+
574+
def test_all_exceptions_exported(self):
575+
"""All exception classes are accessible from ggsql module."""
576+
assert hasattr(ggsql, "ParseError")
577+
assert hasattr(ggsql, "ValidationError")
578+
assert hasattr(ggsql, "ReaderError")
579+
assert hasattr(ggsql, "WriterError")
580+
581+
582+
class TestReaderProtocol:
583+
"""Tests for Reader protocol."""
584+
585+
def test_duckdb_reader_is_reader(self):
586+
"""Native DuckDBReader satisfies the Reader protocol."""
587+
reader = ggsql.DuckDBReader("duckdb://memory")
588+
assert isinstance(reader, ggsql.Reader)
589+
590+
def test_custom_reader_is_reader(self):
591+
"""Custom reader with correct methods satisfies the Reader protocol."""
592+
593+
class MyReader:
594+
def execute_sql(self, sql: str) -> pl.DataFrame:
595+
return pl.DataFrame({"x": [1]})
596+
597+
def register(
598+
self, name: str, df: pl.DataFrame, replace: bool = False
599+
) -> None:
600+
pass
601+
602+
reader = MyReader()
603+
assert isinstance(reader, ggsql.Reader)
604+
605+
def test_incomplete_reader_is_not_reader(self):
606+
"""Object missing required methods is not a Reader."""
607+
608+
class NotAReader:
609+
def execute_sql(self, sql: str) -> pl.DataFrame:
610+
return pl.DataFrame({"x": [1]})
611+
# Missing register()
612+
613+
obj = NotAReader()
614+
assert not isinstance(obj, ggsql.Reader)
615+
616+
def test_reader_is_exported(self):
617+
"""Reader is accessible from ggsql module."""
618+
assert hasattr(ggsql, "Reader")
619+
620+
535621
class TestVegaLiteWriterRenderChart:
536622
"""Tests for VegaLiteWriter.render_chart() method."""
537623

0 commit comments

Comments
 (0)