diff --git a/arrow-pyarrow-testing/tests/pyarrow.rs b/arrow-pyarrow-testing/tests/pyarrow.rs index 1e4ad43817b7..b67377a9b4f2 100644 --- a/arrow-pyarrow-testing/tests/pyarrow.rs +++ b/arrow-pyarrow-testing/tests/pyarrow.rs @@ -41,9 +41,9 @@ use arrow_array::builder::{BinaryViewBuilder, StringViewBuilder}; use arrow_array::{ Array, ArrayRef, BinaryViewArray, Int32Array, RecordBatch, StringArray, StringViewArray, }; -use arrow_schema::Schema; use arrow_pyarrow::{FromPyArrow, ToPyArrow}; -use pyo3::exceptions::PyTypeError; +use arrow_schema::Schema; +use pyo3::exceptions::{PyTypeError, PyValueError}; use pyo3::types::{PyAnyMethods, PyModule}; use pyo3::{IntoPyObject, Python}; use std::ffi::CString; @@ -83,7 +83,7 @@ fn test_to_pyarrow_pair() { let tuple = (record_batch.clone(), record_batch).into_pyobject(py)?; Vec::::from_pyarrow_bound(&tuple) }) - .unwrap(); + .unwrap(); assert_eq!(input, res[0]); assert_eq!(input, res[1]); } @@ -156,7 +156,7 @@ class NotACapsule: value = NotACapsule() "#, ) - .unwrap(); + .unwrap(); let module = PyModule::from_code(py, code.as_c_str(), c"test.py", c"test_module").unwrap(); let value = module.getattr("value").unwrap(); @@ -170,6 +170,35 @@ value = NotACapsule() }); } +#[test] +fn test_from_pyarrow_nullable_struct_array() { + Python::initialize(); + + Python::attach(|py| { + let code = CString::new( + r#" +import pyarrow as pa + +value = pa.array( + [{"a": 1}, None], + type=pa.struct([pa.field("a", pa.int32())]), +) +"#, + ) + .unwrap(); + + let module = PyModule::from_code(py, code.as_c_str(), c"test.py", c"test_module").unwrap(); + let value = module.getattr("value").unwrap(); + + let err = RecordBatch::from_pyarrow_bound(&value).unwrap_err(); + assert!(err.is_instance_of::(py)); + assert_eq!( + err.to_string(), + "ValueError: Cannot convert nullable StructArray to RecordBatch, see StructArray documentation" + ); + }); +} + fn binary_view_column(num_variadic_buffers: usize) -> BinaryViewArray { let long_scalar = b"but soft what light through yonder window breaks".as_slice(); let mut builder = BinaryViewBuilder::new().with_fixed_block_size(long_scalar.len() as u32); diff --git a/arrow-pyarrow/src/lib.rs b/arrow-pyarrow/src/lib.rs index 1085716df5c5..a82256d9ed9b 100644 --- a/arrow-pyarrow/src/lib.rs +++ b/arrow-pyarrow/src/lib.rs @@ -305,11 +305,11 @@ impl FromPyArrow for RecordBatch { let schema = unsafe { Arc::new(Schema::try_from(schema_ptr.as_ref()).map_err(to_py_err)?) }; let (_fields, columns, nulls) = array.into_parts(); - assert_eq!( - nulls.map(|n| n.null_count()).unwrap_or_default(), - 0, - "Cannot convert nullable StructArray to RecordBatch, see StructArray documentation" - ); + if nulls.map(|n| n.null_count()).unwrap_or_default() != 0 { + return Err(PyValueError::new_err( + "Cannot convert nullable StructArray to RecordBatch, see StructArray documentation", + )); + } return RecordBatch::try_new_with_options(schema, columns, &options).map_err(to_py_err); }