22// See: https://github.com/PyO3/pyo3/issues/4327
33#![ allow( clippy:: useless_conversion) ]
44
5+ use pyo3:: create_exception;
6+ use pyo3:: exceptions:: PyValueError ;
57use pyo3:: prelude:: * ;
68use pyo3:: types:: { PyBytes , PyDict , PyList } ;
79use std:: io:: Cursor ;
@@ -12,6 +14,48 @@ use ggsql::validate::{validate as rust_validate, ValidationWarning};
1214use ggsql:: writer:: { VegaLiteWriter as RustVegaLiteWriter , Writer as RustWriter } ;
1315use ggsql:: GgsqlError ;
1416
17+ // ============================================================================
18+ // Custom Exception Classes
19+ // ============================================================================
20+
21+ // All subclass ValueError for backwards compatibility
22+ create_exception ! (
23+ ggsql,
24+ ParseError ,
25+ PyValueError ,
26+ "Raised on query syntax errors."
27+ ) ;
28+ create_exception ! (
29+ ggsql,
30+ ValidationError ,
31+ PyValueError ,
32+ "Raised on semantic validation errors."
33+ ) ;
34+ create_exception ! (
35+ ggsql,
36+ ReaderError ,
37+ PyValueError ,
38+ "Raised on data source errors."
39+ ) ;
40+ create_exception ! (
41+ ggsql,
42+ WriterError ,
43+ PyValueError ,
44+ "Raised on output generation errors."
45+ ) ;
46+
47+ /// Convert a GgsqlError to the appropriate typed Python exception.
48+ fn ggsql_err_to_py ( e : GgsqlError ) -> PyErr {
49+ let msg = e. to_string ( ) ;
50+ match e {
51+ GgsqlError :: ParseError ( _) => PyErr :: new :: < ParseError , _ > ( msg) ,
52+ GgsqlError :: ValidationError ( _) => PyErr :: new :: < ValidationError , _ > ( msg) ,
53+ GgsqlError :: ReaderError ( _) => PyErr :: new :: < ReaderError , _ > ( msg) ,
54+ GgsqlError :: WriterError ( _) => PyErr :: new :: < WriterError , _ > ( msg) ,
55+ GgsqlError :: InternalError ( _) => PyErr :: new :: < PyValueError , _ > ( msg) ,
56+ }
57+ }
58+
1559use polars:: prelude:: { DataFrame , IpcReader , IpcWriter , SerReader , SerWriter } ;
1660
1761// ============================================================================
@@ -144,9 +188,13 @@ impl Reader for PyReaderBridge {
144188 Python :: attach ( |py| {
145189 let py_df =
146190 polars_to_py ( py, & df) . map_err ( |e| GgsqlError :: ReaderError ( e. to_string ( ) ) ) ?;
191+ let kwargs = PyDict :: new ( py) ;
192+ kwargs
193+ . set_item ( "replace" , replace)
194+ . map_err ( |e| GgsqlError :: ReaderError ( e. to_string ( ) ) ) ?;
147195 self . obj
148196 . bind ( py)
149- . call_method1 ( "register" , ( name, py_df, replace ) )
197+ . call_method ( "register" , ( name, py_df) , Some ( & kwargs ) )
150198 . map_err ( |e| GgsqlError :: ReaderError ( format ! ( "Reader.register() failed: {}" , e) ) ) ?;
151199 Ok ( ( ) )
152200 } )
@@ -185,7 +233,7 @@ macro_rules! try_native_readers {
185233 if let Ok ( native) = $reader. downcast:: <$native_type>( ) {
186234 return native. borrow( ) . inner. execute( $query)
187235 . map( |s| PySpec { inner: s } )
188- . map_err( |e| PyErr :: new :: <pyo3 :: exceptions :: PyValueError , _> ( e . to_string ( ) ) ) ;
236+ . map_err( ggsql_err_to_py ) ;
189237 }
190238 ) *
191239 } } ;
@@ -234,8 +282,8 @@ impl PyDuckDBReader {
234282 /// If the connection string is invalid or the database cannot be opened.
235283 #[ new]
236284 fn new ( connection : & str ) -> PyResult < Self > {
237- let inner = RustDuckDBReader :: from_connection_string ( connection )
238- . map_err ( |e| PyErr :: new :: < pyo3 :: exceptions :: PyValueError , _ > ( e . to_string ( ) ) ) ?;
285+ let inner =
286+ RustDuckDBReader :: from_connection_string ( connection ) . map_err ( ggsql_err_to_py ) ?;
239287 Ok ( Self { inner } )
240288 }
241289
@@ -265,7 +313,7 @@ impl PyDuckDBReader {
265313 let rust_df = py_to_polars ( py, df) ?;
266314 self . inner
267315 . register ( name, rust_df, replace)
268- . map_err ( |e| PyErr :: new :: < pyo3 :: exceptions :: PyValueError , _ > ( e . to_string ( ) ) )
316+ . map_err ( ggsql_err_to_py )
269317 }
270318
271319 /// Unregister a previously registered table.
@@ -280,9 +328,7 @@ impl PyDuckDBReader {
280328 /// ValueError
281329 /// If the table wasn't registered via this reader or unregistration fails.
282330 fn unregister ( & self , name : & str ) -> PyResult < ( ) > {
283- self . inner
284- . unregister ( name)
285- . map_err ( |e| PyErr :: new :: < pyo3:: exceptions:: PyValueError , _ > ( e. to_string ( ) ) )
331+ self . inner . unregister ( name) . map_err ( ggsql_err_to_py)
286332 }
287333
288334 /// Execute a SQL query and return the result as a DataFrame.
@@ -302,10 +348,7 @@ impl PyDuckDBReader {
302348 /// ValueError
303349 /// If the SQL is invalid or execution fails.
304350 fn execute_sql ( & self , py : Python < ' _ > , sql : & str ) -> PyResult < Py < PyAny > > {
305- let df = self
306- . inner
307- . execute_sql ( sql)
308- . map_err ( |e| PyErr :: new :: < pyo3:: exceptions:: PyValueError , _ > ( e. to_string ( ) ) ) ?;
351+ let df = self . inner . execute_sql ( sql) . map_err ( ggsql_err_to_py) ?;
309352 polars_to_py ( py, & df)
310353 }
311354
@@ -340,7 +383,7 @@ impl PyDuckDBReader {
340383 self . inner
341384 . execute ( query)
342385 . map ( |s| PySpec { inner : s } )
343- . map_err ( |e| PyErr :: new :: < pyo3 :: exceptions :: PyValueError , _ > ( e . to_string ( ) ) )
386+ . map_err ( ggsql_err_to_py )
344387 }
345388}
346389
@@ -401,9 +444,7 @@ impl PyVegaLiteWriter {
401444 /// >>> writer = VegaLiteWriter()
402445 /// >>> json_output = writer.render(spec)
403446 fn render ( & self , spec : & PySpec ) -> PyResult < String > {
404- self . inner
405- . render ( & spec. inner )
406- . map_err ( |e| PyErr :: new :: < pyo3:: exceptions:: PyValueError , _ > ( e. to_string ( ) ) )
447+ self . inner . render ( & spec. inner ) . map_err ( ggsql_err_to_py)
407448 }
408449}
409450
@@ -667,8 +708,7 @@ impl PySpec {
667708/// If validation fails unexpectedly (not for syntax errors, which are captured).
668709#[ pyfunction]
669710fn validate ( query : & str ) -> PyResult < PyValidated > {
670- let v = rust_validate ( query)
671- . map_err ( |e| PyErr :: new :: < pyo3:: exceptions:: PyValueError , _ > ( e. to_string ( ) ) ) ?;
711+ let v = rust_validate ( query) . map_err ( ggsql_err_to_py) ?;
672712
673713 Ok ( PyValidated {
674714 sql : v. sql ( ) . to_string ( ) ,
@@ -749,7 +789,7 @@ fn execute(query: &str, reader: &Bound<'_, PyAny>) -> PyResult<PySpec> {
749789 bridge
750790 . execute ( query)
751791 . map ( |s| PySpec { inner : s } )
752- . map_err ( |e| PyErr :: new :: < pyo3 :: exceptions :: PyValueError , _ > ( e . to_string ( ) ) )
792+ . map_err ( ggsql_err_to_py )
753793}
754794
755795// ============================================================================
@@ -758,6 +798,12 @@ fn execute(query: &str, reader: &Bound<'_, PyAny>) -> PyResult<PySpec> {
758798
759799#[ pymodule]
760800fn _ggsql ( m : & Bound < ' _ , PyModule > ) -> PyResult < ( ) > {
801+ // Exceptions
802+ m. add ( "ParseError" , m. py ( ) . get_type :: < ParseError > ( ) ) ?;
803+ m. add ( "ValidationError" , m. py ( ) . get_type :: < ValidationError > ( ) ) ?;
804+ m. add ( "ReaderError" , m. py ( ) . get_type :: < ReaderError > ( ) ) ?;
805+ m. add ( "WriterError" , m. py ( ) . get_type :: < WriterError > ( ) ) ?;
806+
761807 // Classes
762808 m. add_class :: < PyDuckDBReader > ( ) ?;
763809 m. add_class :: < PyVegaLiteWriter > ( ) ?;
0 commit comments