diff --git a/arrow-json/Cargo.toml b/arrow-json/Cargo.toml index 5fcde480eb6d..82a4b8b7a32d 100644 --- a/arrow-json/Cargo.toml +++ b/arrow-json/Cargo.toml @@ -54,6 +54,7 @@ ryu = "1.0" itoa = "1.0" [dev-dependencies] +arrow-select = { workspace = true } flate2 = { version = "1", default-features = false, features = ["rust_backend"] } serde = { version = "1.0", default-features = false, features = ["derive"] } futures = "0.3" diff --git a/arrow-json/src/lib.rs b/arrow-json/src/lib.rs index 1b18e0094708..a53ca74fd618 100644 --- a/arrow-json/src/lib.rs +++ b/arrow-json/src/lib.rs @@ -86,7 +86,7 @@ pub mod reader; pub mod writer; -pub use self::reader::{Reader, ReaderBuilder}; +pub use self::reader::{ArrayDecoder, DecoderFactory, Reader, ReaderBuilder, Tape, TapeElement}; pub use self::writer::{ ArrayWriter, Encoder, EncoderFactory, EncoderOptions, LineDelimitedWriter, Writer, WriterBuilder, diff --git a/arrow-json/src/reader/list_array.rs b/arrow-json/src/reader/list_array.rs index e74fef79178a..3a39c40da1e6 100644 --- a/arrow-json/src/reader/list_array.rs +++ b/arrow-json/src/reader/list_array.rs @@ -15,9 +15,8 @@ // specific language governing permissions and limitations // under the License. 
-use crate::StructMode; use crate::reader::tape::{Tape, TapeElement}; -use crate::reader::{ArrayDecoder, make_decoder}; +use crate::reader::{ArrayDecoder, DecoderContext}; use arrow_array::OffsetSizeTrait; use arrow_array::builder::{BooleanBufferBuilder, BufferBuilder}; use arrow_buffer::buffer::NullBuffer; @@ -34,27 +33,19 @@ pub struct ListArrayDecoder { impl ListArrayDecoder { pub fn new( - data_type: DataType, - coerce_primitive: bool, - strict_mode: bool, + ctx: &DecoderContext, + data_type: &DataType, is_nullable: bool, - struct_mode: StructMode, ) -> Result { - let field = match &data_type { + let field = match data_type { DataType::List(f) if !O::IS_LARGE => f, DataType::LargeList(f) if O::IS_LARGE => f, _ => unreachable!(), }; - let decoder = make_decoder( - field.data_type().clone(), - coerce_primitive, - strict_mode, - field.is_nullable(), - struct_mode, - )?; + let decoder = ctx.make_decoder(field.data_type(), field.is_nullable(), field.metadata())?; Ok(Self { - data_type, + data_type: data_type.clone(), decoder, phantom: Default::default(), is_nullable, diff --git a/arrow-json/src/reader/map_array.rs b/arrow-json/src/reader/map_array.rs index c2068577a094..8abc3916d2b3 100644 --- a/arrow-json/src/reader/map_array.rs +++ b/arrow-json/src/reader/map_array.rs @@ -15,9 +15,8 @@ // specific language governing permissions and limitations // under the License. 
-use crate::StructMode; use crate::reader::tape::{Tape, TapeElement}; -use crate::reader::{ArrayDecoder, make_decoder}; +use crate::reader::{ArrayDecoder, DecoderContext}; use arrow_array::builder::{BooleanBufferBuilder, BufferBuilder}; use arrow_buffer::ArrowNativeType; use arrow_buffer::buffer::NullBuffer; @@ -33,46 +32,41 @@ pub struct MapArrayDecoder { impl MapArrayDecoder { pub fn new( - data_type: DataType, - coerce_primitive: bool, - strict_mode: bool, + ctx: &DecoderContext, + data_type: &DataType, is_nullable: bool, - struct_mode: StructMode, ) -> Result { - let fields = match &data_type { - DataType::Map(_, true) => { - return Err(ArrowError::NotYetImplemented( - "Decoding MapArray with sorted fields".to_string(), - )); + let DataType::Map(f, false) = data_type else { + return Err(ArrowError::NotYetImplemented( + "Decoding MapArray with sorted fields".to_string(), + )); + }; + + // TODO: Once MSRV bumps to 1.88+, use an if-let chain to directly unpack the slice + let fields = match f.data_type() { + DataType::Struct(fields) if fields.len() == 2 => fields, + d => { + return Err(ArrowError::InvalidArgumentError(format!( + "MapArray must contain struct with two fields, got {d}" + ))); } - DataType::Map(f, _) => match f.data_type() { - DataType::Struct(fields) if fields.len() == 2 => fields, - d => { - return Err(ArrowError::InvalidArgumentError(format!( - "MapArray must contain struct with two fields, got {d}" - ))); - } - }, - _ => unreachable!(), }; - let keys = make_decoder( - fields[0].data_type().clone(), - coerce_primitive, - strict_mode, - fields[0].is_nullable(), - struct_mode, + let (key_field, value_field) = (&fields[0], &fields[1]); + + let keys = ctx.make_decoder( + key_field.data_type(), + key_field.is_nullable(), + key_field.metadata(), )?; - let values = make_decoder( - fields[1].data_type().clone(), - coerce_primitive, - strict_mode, - fields[1].is_nullable(), - struct_mode, + let values = ctx.make_decoder( + value_field.data_type(), + 
value_field.is_nullable(), + value_field.metadata(), )?; Ok(Self { - data_type, + data_type: data_type.clone(), keys, values, is_nullable, @@ -82,12 +76,11 @@ impl MapArrayDecoder { impl ArrayDecoder for MapArrayDecoder { fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result { - let s = match &self.data_type { - DataType::Map(f, _) => match f.data_type() { - s @ DataType::Struct(_) => s, - _ => unreachable!(), - }, - _ => unreachable!(), + let DataType::Map(f, _) = &self.data_type else { + unreachable!() + }; + let s @ DataType::Struct(_) = f.data_type() else { + unreachable!() }; let mut offsets = BufferBuilder::::new(pos.len() + 1); diff --git a/arrow-json/src/reader/mod.rs b/arrow-json/src/reader/mod.rs index f5fd1a8e7c38..443b1943e1ab 100644 --- a/arrow-json/src/reader/mod.rs +++ b/arrow-json/src/reader/mod.rs @@ -137,6 +137,8 @@ use crate::StructMode; use crate::reader::binary_array::{ BinaryArrayDecoder, BinaryViewDecoder, FixedSizeBinaryArrayDecoder, }; +use std::borrow::Cow; +use std::collections::HashMap; use std::io::BufRead; use std::sync::Arc; @@ -149,6 +151,7 @@ use arrow_array::{RecordBatch, RecordBatchReader, StructArray, downcast_integer, use arrow_data::ArrayData; use arrow_schema::{ArrowError, DataType, FieldRef, Schema, SchemaRef, TimeUnit}; pub use schema::*; +pub use tape::*; use crate::reader::boolean_array::BooleanArrayDecoder; use crate::reader::decimal_array::DecimalArrayDecoder; @@ -159,7 +162,6 @@ use crate::reader::primitive_array::PrimitiveArrayDecoder; use crate::reader::string_array::StringArrayDecoder; use crate::reader::string_view_array::StringViewArrayDecoder; use crate::reader::struct_array::StructArrayDecoder; -use crate::reader::tape::{Tape, TapeDecoder}; use crate::reader::timestamp_array::TimestampArrayDecoder; mod binary_array; @@ -184,6 +186,7 @@ pub struct ReaderBuilder { strict_mode: bool, is_field: bool, struct_mode: StructMode, + decoder_factory: Option>, schema: SchemaRef, } @@ -205,6 +208,7 @@ impl 
ReaderBuilder { is_field: false, struct_mode: Default::default(), schema, + decoder_factory: None, } } @@ -218,7 +222,7 @@ impl ReaderBuilder { /// # use arrow_array::cast::AsArray; /// # use arrow_array::types::Int32Type; /// # use arrow_json::ReaderBuilder; - /// # use arrow_schema::{DataType, Field}; + /// # use arrow_schema::{DataType, Field, FieldRef}; /// // Root of JSON schema is a numeric type /// let data = "1\n2\n3\n"; /// let field = Arc::new(Field::new("int", DataType::Int32, true)); @@ -246,6 +250,7 @@ impl ReaderBuilder { is_field: true, struct_mode: Default::default(), schema: Arc::new(Schema::new([field.into()])), + decoder_factory: None, } } @@ -285,6 +290,14 @@ impl ReaderBuilder { } } + /// Set an optional hook for customizing decoding behavior. + pub fn with_decoder_factory(self, decoder_factory: Arc) -> Self { + Self { + decoder_factory: Some(decoder_factory), + ..self + } + } + /// Create a [`Reader`] with the provided [`BufRead`] pub fn build(self, reader: R) -> Result, ArrowError> { Ok(Reader { @@ -295,21 +308,24 @@ impl ReaderBuilder { /// Create a [`Decoder`] pub fn build_decoder(self) -> Result { - let (data_type, nullable) = match self.is_field { - false => (DataType::Struct(self.schema.fields.clone()), false), - true => { - let field = &self.schema.fields[0]; - (field.data_type().clone(), field.is_nullable()) - } + let empty_metadata = HashMap::default(); + let (data_type, nullable, metadata) = if self.is_field { + let field = &self.schema.fields[0]; + let data_type = Cow::Borrowed(field.data_type()); + (data_type, field.is_nullable(), field.metadata()) + } else { + let data_type = Cow::Owned(DataType::Struct(self.schema.fields.clone())); + (data_type, false, &empty_metadata) + }; + + let ctx = DecoderContext { + coerce_primitive: self.coerce_primitive, + strict_mode: self.strict_mode, + struct_mode: self.struct_mode, + decoder_factory: self.decoder_factory.as_deref(), }; - let decoder = make_decoder( - data_type, - 
self.coerce_primitive, - self.strict_mode, - nullable, - self.struct_mode, - )?; + let decoder = ctx.make_decoder(data_type.as_ref(), nullable, metadata)?; let num_fields = self.schema.flattened_fields().len(); @@ -373,6 +389,88 @@ impl RecordBatchReader for Reader { } } +/// A trait to create custom decoders for specific data types. +/// +/// This allows overriding the default decoders for specific data types, +/// or adding new decoders for custom data types. +/// +/// # Example +/// +/// ``` +/// # use arrow_json::reader::{ArrayDecoder, DecoderContext, DecoderFactory}; +/// # use arrow_json::{ReaderBuilder, TapeElement, Tape}; +/// # use arrow_schema::ArrowError; +/// # use arrow_schema::{DataType, Field, FieldRef, Fields, Schema}; +/// # use arrow_array::cast::AsArray; +/// # use arrow_array::Array; +/// # use arrow_array::builder::StringBuilder; +/// # use arrow_data::ArrayData; +/// # use std::collections::HashMap; +/// # use std::sync::Arc; +/// # +/// struct IncorrectStringAsNullDecoder; +/// +/// impl ArrayDecoder for IncorrectStringAsNullDecoder { +/// fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result { +/// let mut builder = StringBuilder::new(); +/// for p in pos { +/// match tape.get(*p) { +/// TapeElement::String(idx) => builder.append_value(tape.get_string(idx)), +/// _ => builder.append_null(), +/// } +/// } +/// Ok(builder.finish().into_data()) +/// } +/// } +/// +/// #[derive(Debug)] +/// struct IncorrectStringAsNullDecoderFactory; +/// +/// impl DecoderFactory for IncorrectStringAsNullDecoderFactory { +/// fn make_custom_decoder( +/// &self, +/// _ctx: &DecoderContext, +/// data_type: &DataType, +/// _is_nullable: bool, +/// _field_metadata: &HashMap, +/// ) -> Result>, ArrowError> { +/// match data_type { +/// DataType::Utf8 => Ok(Some(Box::new(IncorrectStringAsNullDecoder))), +/// _ => Ok(None), +/// } +/// } +/// } +/// +/// let json = r#" +/// {"a": "a"} +/// {"a": 12} +/// "#; +/// let fields = vec![Field::new("a", DataType::Utf8, 
true)]; +/// let batch = ReaderBuilder::new(Arc::new(Schema::new(fields))) +/// .with_decoder_factory(Arc::new(IncorrectStringAsNullDecoderFactory)) +/// .build(json.as_bytes()) +/// .unwrap() +/// .next() +/// .unwrap() +/// .unwrap(); +/// +/// let values = batch.column(0).as_string::(); +/// assert_eq!(values.len(), 2); +/// assert_eq!(values.value(0), "a"); +/// assert!(values.is_null(1)); +/// ``` +pub trait DecoderFactory: std::fmt::Debug + Send + Sync { + /// Make a decoder that overrides the default decoder for a specific data type. + /// This can be used to override how e.g. error in decoding are handled. + fn make_custom_decoder( + &self, + ctx: &DecoderContext, + data_type: &DataType, + is_nullable: bool, + field_metadata: &HashMap, + ) -> Result>, ArrowError>; +} + /// A low-level interface for reading JSON data from a byte stream /// /// See [`Reader`] for a higher-level interface for interface with [`BufRead`] @@ -674,26 +772,134 @@ impl Decoder { } } -trait ArrayDecoder: Send { +/// A trait to decode JSON values into arrow arrays +pub trait ArrayDecoder: Send { /// Decode elements from `tape` starting at the indexes contained in `pos` fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result; } +/// Context for decoder creation, containing configuration and factory reference. +/// +/// This context is passed through the decoder creation process and contains +/// all the configuration needed to create decoders recursively. 
+pub struct DecoderContext<'a> { + /// Whether to coerce primitives to strings + coerce_primitive: bool, + /// Whether to validate struct fields strictly + strict_mode: bool, + /// How to decode struct fields + struct_mode: StructMode, + /// Optional custom decoder factory + decoder_factory: Option<&'a dyn DecoderFactory>, +} + +impl DecoderContext<'_> { + /// Returns whether to coerce primitive types (e.g., number to string) + pub fn coerce_primitive(&self) -> bool { + self.coerce_primitive + } + + /// Returns whether to validate struct fields strictly + pub fn strict_mode(&self) -> bool { + self.strict_mode + } + + /// Returns how to decode struct fields + pub fn struct_mode(&self) -> StructMode { + self.struct_mode + } + + /// Create a decoder for a type, allowing the factory to intercept it. + /// + /// This is the standard way to create child decoders from within a decoder + /// implementation. The factory (if present) will be given the opportunity + /// to intercept and customize the decoder. + pub fn make_decoder( + &self, + data_type: &DataType, + is_nullable: bool, + field_metadata: &HashMap, + ) -> Result, ArrowError> { + if let Some(factory) = self.decoder_factory { + if let Some(decoder) = + factory.make_custom_decoder(self, data_type, is_nullable, field_metadata)? + { + return Ok(decoder); + } + } + + make_decoder(self, data_type, is_nullable) + } + + /// Create a decoder for a type without allowing the factory to intercept it directly, + /// but still allowing the factory to intercept children of complex types. + /// + /// This is used by custom factories when they want to delegate to the standard + /// decoder implementation while still customizing child decoders. + /// + /// # Example + /// + /// When a factory intercepts a type and wants to wrap or modify the standard decoder, + /// calling `ctx.make_decoder()` would cause infinite recursion (the factory would + /// intercept its own call). 
This method bypasses the factory check at the current + /// level but still passes the context through so child fields can be customized. + /// + /// ``` + /// # use arrow_json::reader::{ArrayDecoder, DecoderContext, DecoderFactory}; + /// # use arrow_json::TapeElement; + /// # use arrow_schema::{DataType, ArrowError}; + /// # use std::collections::HashMap; + /// # + /// #[derive(Debug)] + /// struct StructWrapperFactory; + /// + /// impl DecoderFactory for StructWrapperFactory { + /// fn make_custom_decoder( + /// &self, + /// ctx: &DecoderContext, + /// data_type: &DataType, + /// is_nullable: bool, + /// _field_metadata: &HashMap, + /// ) -> Result>, ArrowError> { + /// if matches!(data_type, DataType::Struct(_)) { + /// // Get standard struct decoder, bypassing self-interception + /// let delegate = ctx.make_delegate_decoder(data_type, is_nullable)?; + /// + /// // In real usage: wrap delegate with custom behavior + /// Ok(Some(delegate)) + /// } else { + /// Ok(None) + /// } + /// } + /// } + /// ``` + pub fn make_delegate_decoder( + &self, + data_type: &DataType, + is_nullable: bool, + ) -> Result, ArrowError> { + make_decoder(self, data_type, is_nullable) + } +} + macro_rules! primitive_decoder { ($t:ty, $data_type:expr) => { Ok(Box::new(PrimitiveArrayDecoder::<$t>::new($data_type))) }; } +/// Private workhorse function for decoder creation. +/// +/// Constructs the appropriate decoder for the given data type without +/// checking the factory. All decoder construction logic lives here. fn make_decoder( - data_type: DataType, - coerce_primitive: bool, - strict_mode: bool, + ctx: &DecoderContext, + data_type: &DataType, is_nullable: bool, - struct_mode: StructMode, ) -> Result, ArrowError> { + let coerce_primitive = ctx.coerce_primitive(); downcast_integer! 
{ - data_type => (primitive_decoder, data_type), + *data_type => (primitive_decoder, data_type), DataType::Null => Ok(Box::::default()), DataType::Float16 => primitive_decoder!(Float16Type, data_type), DataType::Float32 => primitive_decoder!(Float32Type, data_type), @@ -744,15 +950,15 @@ fn make_decoder( DataType::Utf8 => Ok(Box::new(StringArrayDecoder::::new(coerce_primitive))), DataType::Utf8View => Ok(Box::new(StringViewArrayDecoder::new(coerce_primitive))), DataType::LargeUtf8 => Ok(Box::new(StringArrayDecoder::::new(coerce_primitive))), - DataType::List(_) => Ok(Box::new(ListArrayDecoder::::new(data_type, coerce_primitive, strict_mode, is_nullable, struct_mode)?)), - DataType::LargeList(_) => Ok(Box::new(ListArrayDecoder::::new(data_type, coerce_primitive, strict_mode, is_nullable, struct_mode)?)), - DataType::Struct(_) => Ok(Box::new(StructArrayDecoder::new(data_type, coerce_primitive, strict_mode, is_nullable, struct_mode)?)), + DataType::List(_) => Ok(Box::new(ListArrayDecoder::::new(ctx, data_type, is_nullable)?)), + DataType::LargeList(_) => Ok(Box::new(ListArrayDecoder::::new(ctx, data_type, is_nullable)?)), + DataType::Struct(_) => Ok(Box::new(StructArrayDecoder::new(ctx, data_type, is_nullable)?)), DataType::Binary => Ok(Box::new(BinaryArrayDecoder::::default())), DataType::LargeBinary => Ok(Box::new(BinaryArrayDecoder::::default())), DataType::FixedSizeBinary(len) => Ok(Box::new(FixedSizeBinaryArrayDecoder::new(len))), DataType::BinaryView => Ok(Box::new(BinaryViewDecoder::default())), - DataType::Map(_, _) => Ok(Box::new(MapArrayDecoder::new(data_type, coerce_primitive, strict_mode, is_nullable, struct_mode)?)), - d => Err(ArrowError::NotYetImplemented(format!("Support for {d} in JSON reader"))) + DataType::Map(_, _) => Ok(Box::new(MapArrayDecoder::new(ctx, data_type, is_nullable)?)), + _ => Err(ArrowError::NotYetImplemented(format!("Support for {data_type} in JSON reader"))) } } @@ -2815,4 +3021,67 @@ mod tests { "Json error: whilst decoding field 
'a': failed to parse \"a\" as Int32".to_owned() ); } + + #[test] + fn test_decoder_factory() { + use arrow_array::builder; + + struct AlwaysNullStringArrayDecoder; + + impl ArrayDecoder for AlwaysNullStringArrayDecoder { + fn decode(&mut self, _tape: &Tape<'_>, pos: &[u32]) -> Result { + let mut builder = builder::StringBuilder::new(); + for _ in pos { + builder.append_null(); + } + Ok(builder.finish().into_data()) + } + } + + #[derive(Debug)] + struct AlwaysNullStringArrayDecoderFactory; + + impl DecoderFactory for AlwaysNullStringArrayDecoderFactory { + fn make_custom_decoder( + &self, + _ctx: &crate::reader::DecoderContext, + data_type: &DataType, + _is_nullable: bool, + _field_metadata: &HashMap, + ) -> Result>, ArrowError> { + match data_type { + DataType::Utf8 => Ok(Some(Box::new(AlwaysNullStringArrayDecoder))), + _ => Ok(None), + } + } + } + + let buf = r#" + {"a": "1", "b": 2} + {"a": "hello", "b": 23} + "#; + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, true), + Field::new("b", DataType::Int32, true), + ])); + + let batches = ReaderBuilder::new(schema.clone()) + .with_batch_size(2) + .with_decoder_factory(Arc::new(AlwaysNullStringArrayDecoderFactory)) + .build(Cursor::new(buf.as_bytes())) + .unwrap() + .collect::, _>>() + .unwrap(); + + assert_eq!(batches.len(), 1); + + let col1 = batches[0].column(0).as_string::(); + assert_eq!(col1.null_count(), 2); + assert!(col1.is_null(0)); + assert!(col1.is_null(1)); + + let col2 = batches[0].column(1).as_primitive::(); + assert_eq!(col2.value(0), 2); + assert_eq!(col2.value(1), 23); + } } diff --git a/arrow-json/src/reader/primitive_array.rs b/arrow-json/src/reader/primitive_array.rs index fa8464aa3251..b2bffe45e43a 100644 --- a/arrow-json/src/reader/primitive_array.rs +++ b/arrow-json/src/reader/primitive_array.rs @@ -80,9 +80,9 @@ pub struct PrimitiveArrayDecoder { } impl PrimitiveArrayDecoder

{ - pub fn new(data_type: DataType) -> Self { + pub fn new(data_type: &DataType) -> Self { Self { - data_type, + data_type: data_type.clone(), phantom: Default::default(), } } diff --git a/arrow-json/src/reader/struct_array.rs b/arrow-json/src/reader/struct_array.rs index df0d5b8a5b83..f8f64f8fe166 100644 --- a/arrow-json/src/reader/struct_array.rs +++ b/arrow-json/src/reader/struct_array.rs @@ -16,7 +16,7 @@ // under the License. use crate::reader::tape::{Tape, TapeElement}; -use crate::reader::{ArrayDecoder, StructMode, make_decoder}; +use crate::reader::{ArrayDecoder, DecoderContext, StructMode}; use arrow_array::builder::BooleanBufferBuilder; use arrow_buffer::buffer::NullBuffer; use arrow_data::{ArrayData, ArrayDataBuilder}; @@ -80,42 +80,36 @@ pub struct StructArrayDecoder { impl StructArrayDecoder { pub fn new( - data_type: DataType, - coerce_primitive: bool, - strict_mode: bool, + ctx: &DecoderContext, + data_type: &DataType, is_nullable: bool, - struct_mode: StructMode, ) -> Result { - let (decoders, field_name_to_index) = { - let fields = struct_fields(&data_type); - let decoders = fields - .iter() - .map(|f| { - // If this struct nullable, need to permit nullability in child array - // StructArrayDecoder::decode verifies that if the child is not nullable - // it doesn't contain any nulls not masked by its parent - let nullable = f.is_nullable() || is_nullable; - make_decoder( - f.data_type().clone(), - coerce_primitive, - strict_mode, - nullable, - struct_mode, - ) - }) - .collect::, ArrowError>>()?; - let field_name_to_index = if struct_mode == StructMode::ObjectOnly { - build_field_index(fields) - } else { - None - }; - (decoders, field_name_to_index) + let struct_mode = ctx.struct_mode(); + let fields = struct_fields(data_type); + + let mut decoders = Vec::with_capacity(fields.len()); + for field in fields { + // If this struct nullable, need to permit nullability in child array + // StructArrayDecoder::decode verifies that if the child is not 
nullable + // it doesn't contain any nulls not masked by its parent + let decoder = ctx.make_decoder( + field.data_type(), + field.is_nullable() || is_nullable, + field.metadata(), + )?; + decoders.push(decoder); + } + + let field_name_to_index = if struct_mode == StructMode::ObjectOnly { + build_field_index(fields) + } else { + None }; Ok(Self { - data_type, + data_type: data_type.clone(), decoders, - strict_mode, + strict_mode: ctx.strict_mode(), is_nullable, struct_mode, field_name_to_index, diff --git a/arrow-json/src/reader/tape.rs b/arrow-json/src/reader/tape.rs index 89ee3f778765..fcab173ef110 100644 --- a/arrow-json/src/reader/tape.rs +++ b/arrow-json/src/reader/tape.rs @@ -338,6 +338,7 @@ impl TapeDecoder { } } + /// Decodes JSON data from the provided buffer, returning the number of bytes consumed pub fn decode(&mut self, buf: &[u8]) -> Result { let mut iter = BufIter::new(buf); diff --git a/arrow-json/src/reader/timestamp_array.rs b/arrow-json/src/reader/timestamp_array.rs index 79f2b04eeba8..ff24d5391f9d 100644 --- a/arrow-json/src/reader/timestamp_array.rs +++ b/arrow-json/src/reader/timestamp_array.rs @@ -37,9 +37,9 @@ pub struct TimestampArrayDecoder { } impl TimestampArrayDecoder { - pub fn new(data_type: DataType, timezone: Tz) -> Self { + pub fn new(data_type: &DataType, timezone: Tz) -> Self { Self { - data_type, + data_type: data_type.clone(), timezone, phantom: Default::default(), } diff --git a/arrow-json/tests/custom_decoder_tests.rs b/arrow-json/tests/custom_decoder_tests.rs new file mode 100644 index 000000000000..df67b91ff4bc --- /dev/null +++ b/arrow-json/tests/custom_decoder_tests.rs @@ -0,0 +1,836 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Integration tests for custom decoder functionality +//! +//! This test suite demonstrates various patterns for customizing JSON decoding: +//! 1. Type-based routing - customize all fields of a given type +//! 2. Annotation-based routing - customize specific fields marked with metadata +//! 3. Type-specific behavior - custom parsing logic for specific types +//! 4. Composition - combining multiple custom factories +//! 5. Delegation with interleaving - wrapping standard decoders when direct access not possible +//! 6. 
Path-based routing - customize specific paths in the schema tree + +use arrow_array::Array as _; +use arrow_array::builder::StringBuilder; +use arrow_array::cast::AsArray; +use arrow_data::ArrayData; +use arrow_json::reader::{ArrayDecoder, DecoderContext, DecoderFactory}; +use arrow_json::{ReaderBuilder, Tape, TapeElement}; +use arrow_schema::{ArrowError, DataType, Field, FieldRef, Fields, Schema}; +use std::collections::HashMap; +use std::hash::{Hash, Hasher}; +use std::io::Cursor; +use std::sync::Arc; + +// ============================================================================ +// Test 1: Type-based lenient string decoder +// ============================================================================ + +/// A string decoder that converts type mismatches to NULL instead of erroring +struct LenientStringDecoder; + +impl ArrayDecoder for LenientStringDecoder { + fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result { + let mut builder = StringBuilder::new(); + for p in pos { + match tape.get(*p) { + TapeElement::String(idx) => builder.append_value(tape.get_string(idx)), + _ => builder.append_null(), + } + } + Ok(builder.finish().into_data()) + } +} + +/// A factory that applies the LenientStringDecoder to ALL Utf8 fields (type-based routing) +#[derive(Debug)] +struct TypeBasedLenientStringFactory; + +impl DecoderFactory for TypeBasedLenientStringFactory { + fn make_custom_decoder( + &self, + _ctx: &DecoderContext, + data_type: &DataType, + _is_nullable: bool, + _field_metadata: &HashMap, + ) -> Result>, ArrowError> { + match data_type { + DataType::Utf8 => Ok(Some(Box::new(LenientStringDecoder))), + _ => Ok(None), + } + } +} + +#[test] +fn test_type_based_lenient_strings() { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + Field::new("email", DataType::Utf8, true), + ])); + + // JSON with type mismatches in BOTH string fields + let json = r#" + {"id": 1, "name": "Alice", 
"email": "alice@example.com"} + {"id": 2, "name": 42, "email": "bob@example.com"} + {"id": 3, "name": "Charlie", "email": true} + "#; + + let reader = ReaderBuilder::new(schema) + .with_decoder_factory(Arc::new(TypeBasedLenientStringFactory)) + .build(Cursor::new(json.as_bytes())) + .unwrap(); + + let batch = reader.into_iter().next().unwrap().unwrap(); + + let names = batch.column(1).as_string::(); + assert_eq!(names.value(0), "Alice"); + assert!(names.is_null(1)); // 42 -> NULL + assert_eq!(names.value(2), "Charlie"); + + let emails = batch.column(2).as_string::(); + assert_eq!(emails.value(0), "alice@example.com"); + assert_eq!(emails.value(1), "bob@example.com"); + assert!(emails.is_null(2)); // true -> NULL +} + +// ============================================================================ +// Test 2: Annotation-based lenient string decoder +// ============================================================================ + +/// A factory that applies LenientStringDecoder only to annotated Utf8 fields +#[derive(Debug)] +struct AnnotatedLenientStringFactory; + +impl DecoderFactory for AnnotatedLenientStringFactory { + fn make_custom_decoder( + &self, + _ctx: &DecoderContext, + data_type: &DataType, + _is_nullable: bool, + field_metadata: &HashMap, + ) -> Result>, ArrowError> { + let config = field_metadata + .get("test:decoder:config") + .map(|s| s.as_str()); + match data_type { + DataType::Utf8 if config == Some("lenient") => Ok(Some(Box::new(LenientStringDecoder))), + _ => Ok(None), + } + } +} + +#[test] +fn test_annotation_based_lenient_strings() { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + // ONLY this field is marked lenient + Field::new("name", DataType::Utf8, true).with_metadata(HashMap::from([( + "test:decoder:config".to_string(), + "lenient".to_string(), + )])), + // This field is NOT marked, should use standard decoder + Field::new("email", DataType::Utf8, false), + ])); + + // JSON with type mismatches in 
BOTH string fields + let json_valid = r#" + {"id": 1, "name": "Alice", "email": "alice@example.com"} + {"id": 2, "name": 42, "email": "bob@example.com"} + "#; + + let reader = ReaderBuilder::new(schema.clone()) + .with_decoder_factory(Arc::new(AnnotatedLenientStringFactory)) + .build(Cursor::new(json_valid.as_bytes())) + .unwrap(); + + let batch = reader.into_iter().next().unwrap().unwrap(); + + let names = batch.column(1).as_string::(); + assert_eq!(names.value(0), "Alice"); + assert!(names.is_null(1)); // 42 -> NULL (lenient behavior) + + let emails = batch.column(2).as_string::(); + assert_eq!(emails.value(0), "alice@example.com"); + assert_eq!(emails.value(1), "bob@example.com"); + + // Negative test: email field without annotation should error on type mismatch + let json_invalid = r#" + {"id": 1, "name": "Alice", "email": true} + "#; + + let reader = ReaderBuilder::new(schema) + .with_decoder_factory(Arc::new(AnnotatedLenientStringFactory)) + .build(Cursor::new(json_invalid.as_bytes())) + .unwrap(); + + let result = reader.into_iter().next().unwrap(); + assert!(result.is_err()); // email field errors on type mismatch +} + +// ============================================================================ +// Shared helpers for interleaved decoding pattern +// ============================================================================ + +/// A general-purpose interleaved decoder that routes positions to different decoders +/// based on a filter predicate, then interleaves the results back to original order. +/// +/// This pattern is useful when you want to customize behavior but need to delegate +/// to standard decoders (which you can't directly access the builders of). +/// +/// Note: This example uses `Fn(TapeElement) -> bool` for simplicity. 
A production +/// implementation might use `Fn(&Tape, u32) -> bool` to support filters that need +/// to examine list/object contents before routing (e.g., routing based on list length +/// or presence of discriminator fields). Downside is more complexity in simple cases. +struct InterleavedDecoder { + primary: Box, + fallback: Box, + filter: F, +} + +impl bool + Send> ArrayDecoder for InterleavedDecoder { + fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result { + use arrow_select::interleave::interleave; + + // Partition positions based on filter + let mut primary_pos = Vec::new(); + let mut fallback_pos = Vec::new(); + let mut indices = Vec::with_capacity(pos.len()); + + for &p in pos { + if (self.filter)(tape.get(p)) { + indices.push((0, primary_pos.len())); + primary_pos.push(p); + } else { + indices.push((1, fallback_pos.len())); + fallback_pos.push(p); + } + } + + // Decode both parts + let primary = self.primary.decode(tape, &primary_pos)?; + let fallback = self.fallback.decode(tape, &fallback_pos)?; + + // Convert to arrays for interleaving + let primary = arrow_array::make_array(primary); + let fallback = arrow_array::make_array(fallback); + + // Interleave back to original order + let result = interleave(&[primary.as_ref(), fallback.as_ref()], &indices)?; + Ok(result.into_data()) + } +} + +/// A decoder that always produces NULL values (useful as fallback in InterleavedDecoder) +struct NullDecoder { + data_type: DataType, +} + +impl ArrayDecoder for NullDecoder { + fn decode(&mut self, _tape: &Tape<'_>, pos: &[u32]) -> Result { + Ok(arrow_array::new_null_array(&self.data_type, pos.len()).into_data()) + } +} + +// ============================================================================ +// Test 3: Lenient struct decoder (introduces InterleavedDecoder pattern) +// ============================================================================ + +/// A factory that makes ALL Struct fields lenient (type-based routing) +/// +/// This demonstrates the 
InterleavedDecoder pattern: we want to customize struct +/// handling but can't directly access the StructArrayDecoder's internal builder +/// and -- unlike string decoding -- the logic is too complex to replicate easily. +/// +/// Solution: partition positions and delegate to standard decoder + null decoder. +#[derive(Debug)] +struct TypeBasedLenientStructFactory; + +impl DecoderFactory for TypeBasedLenientStructFactory { + fn make_custom_decoder( + &self, + ctx: &DecoderContext, + data_type: &DataType, + is_nullable: bool, + _field_metadata: &HashMap, + ) -> Result>, ArrowError> { + if !matches!(data_type, DataType::Struct(_)) { + return Ok(None); + } + + // Delegate to a standard struct decoder for objects, with a null decoder as fallback. + Ok(Some(Box::new(InterleavedDecoder { + primary: ctx.make_delegate_decoder(data_type, is_nullable)?, + fallback: Box::new(NullDecoder { + data_type: data_type.clone(), + }), + filter: |elem| matches!(elem, TapeElement::StartObject(_)), + }))) + } +} + +#[test] +fn test_type_based_lenient_structs() { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new( + "person", + DataType::Struct(vec![Field::new("name", DataType::Utf8, false)].into()), + true, + ), + ])); + + let json = r#" + {"id": 1, "person": {"name": "Alice"}} + {"id": 2, "person": "not a struct"} + {"id": 3, "person": 42} + {"id": 4, "person": null} + "#; + + let reader = ReaderBuilder::new(schema) + .with_decoder_factory(Arc::new(TypeBasedLenientStructFactory)) + .build(Cursor::new(json.as_bytes())) + .unwrap(); + + let batch = reader.into_iter().next().unwrap().unwrap(); + let person = batch.column(1).as_struct(); + assert!(!person.is_null(0)); // Valid struct + assert!(person.is_null(1)); // "not a struct" -> NULL + assert!(person.is_null(2)); // 42 -> NULL + assert!(person.is_null(3)); // null -> NULL + + // Verify first row decoded correctly + let names = person.column(0).as_string::(); + assert_eq!(names.value(0), 
"Alice"); +} + +// ============================================================================ +// Test 4: Quirky string list decoder (reinforces InterleavedDecoder pattern) +// ============================================================================ + +/// Parse a string like "[a, b, c]" into an iterator of strings, e.g. "a", "b", "c" +fn parse_list_string(s: &str) -> Result + '_, ArrowError> { + let trimmed = s.trim(); + if !trimmed.starts_with('[') || !trimmed.ends_with(']') { + return Err(ArrowError::JsonError(format!( + "Failed to parse list string: {}", + s + ))); + } + let inner = &trimmed[1..trimmed.len() - 1]; + Ok(inner.split(',').map(|s| s.trim())) +} + +/// A decoder that ONLY handles string representations like "[a, b, c]" and parses them as lists. +/// Designed to be used with InterleavedDecoder (standard list decoder handles normal lists). +struct StringToListDecoder { + field: FieldRef, +} + +impl ArrayDecoder for StringToListDecoder { + fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result { + // Use with_field() to ensure the builder respects the schema's field (including nullability) + let mut builder = arrow_array::builder::ListBuilder::new(StringBuilder::new()) + .with_field(self.field.clone()); + + for p in pos { + let TapeElement::String(s) = tape.get(*p) else { + unreachable!("InterleavedDecoder filter should only route String elements here"); + }; + + // Parse string representation like "[x, y, z]" + for item in parse_list_string(tape.get_string(s))? { + builder.values().append_value(item); + } + builder.append(true); + } + + Ok(builder.finish().into_data()) + } +} + +/// A factory that makes ALL List fields quirky (type-based routing) +/// +/// Uses InterleavedDecoder to combine string parsing with standard list decoding. 
+#[derive(Debug)] +struct TypeBasedQuirkyListFactory; + +impl DecoderFactory for TypeBasedQuirkyListFactory { + fn make_custom_decoder( + &self, + ctx: &DecoderContext, + data_type: &DataType, + is_nullable: bool, + _field_metadata: &HashMap, + ) -> Result>, ArrowError> { + let field = match data_type { + DataType::List(f) if *f.data_type() == DataType::Utf8 => f.clone(), + _ => return Ok(None), + }; + + // Intercept and attempt to parse strings; all others fall back to standard list decoder + Ok(Some(Box::new(InterleavedDecoder { + primary: Box::new(StringToListDecoder { field }), + fallback: ctx.make_delegate_decoder(data_type, is_nullable)?, + filter: |elem| matches!(elem, TapeElement::String(_)), + }))) + } +} + +#[test] +fn test_type_based_quirky_lists() { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new( + "tags", + DataType::List(Arc::new(Field::new("item", DataType::Utf8, false))), + true, + ), + ])); + + let json = r#" +{"id": 1, "tags": ["a", "b", "c"]} +{"id": 2, "tags": "[x, y, z]"} +{"id": 3, "tags": null} +"#; + + let reader = ReaderBuilder::new(schema) + .with_decoder_factory(Arc::new(TypeBasedQuirkyListFactory)) + .build(Cursor::new(json.as_bytes())) + .unwrap(); + + let batch = reader.into_iter().next().unwrap().unwrap(); + + let tags = batch.column(1).as_list::(); + + // First row: normal JSON list + let row0 = tags.value(0); + let row0 = row0.as_string::(); + assert_eq!(row0.len(), 3); + assert_eq!(row0.value(0), "a"); + assert_eq!(row0.value(1), "b"); + assert_eq!(row0.value(2), "c"); + + // Second row: parsed from string representation + let row1 = tags.value(1); + let row1 = row1.as_string::(); + assert_eq!(row1.len(), 3); + assert_eq!(row1.value(0), "x"); + assert_eq!(row1.value(1), "y"); + assert_eq!(row1.value(2), "z"); + + // Third row: null + assert!(tags.is_null(2)); +} + +// ============================================================================ +// Test 5: Composition - combining 
multiple type-based factories +// ============================================================================ + +/// A factory that tries multiple child factories in sequence +#[derive(Debug)] +struct ComposedDecoderFactory { + factories: Vec>, +} + +impl DecoderFactory for ComposedDecoderFactory { + fn make_custom_decoder( + &self, + ctx: &DecoderContext, + data_type: &DataType, + is_nullable: bool, + field_metadata: &HashMap, + ) -> Result>, ArrowError> { + // Try each child factory in order until one returns Some or Err + for factory in &self.factories { + if let Some(decoder) = + factory.make_custom_decoder(ctx, data_type, is_nullable, field_metadata)? + { + return Ok(Some(decoder)); + } + } + Ok(None) + } +} + +#[test] +fn test_composed_type_based_factories() { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + Field::new( + "tags", + DataType::List(Arc::new(Field::new("item", DataType::Utf8, false))), + true, + ), + ])); + + let json = r#" +{"id": 1, "name": "Alice", "tags": ["a", "b"]} +{"id": 2, "name": 42, "tags": "[x, y]"} +{"id": 3, "name": "Bob", "tags": null} +"#; + + // Compose two type-based factories + let factories = vec![ + Arc::new(TypeBasedLenientStringFactory) as _, + Arc::new(TypeBasedQuirkyListFactory) as _, + ]; + let factory = Arc::new(ComposedDecoderFactory { factories }); + + let reader = ReaderBuilder::new(schema) + .with_decoder_factory(factory) + .build(Cursor::new(json.as_bytes())) + .unwrap(); + + let batch = reader.into_iter().next().unwrap().unwrap(); + + // Verify lenient string behavior + let names = batch.column(1).as_string::(); + assert_eq!(names.value(0), "Alice"); + assert!(names.is_null(1)); // 42 -> NULL + assert_eq!(names.value(2), "Bob"); + + // Verify quirky list behavior + let tags = batch.column(2).as_list::(); + + // Row 0: normal JSON list + let row0 = tags.value(0); + let row0 = row0.as_string::(); + assert_eq!(row0.len(), 2); + 
assert_eq!(row0.value(0), "a"); + assert_eq!(row0.value(1), "b"); + + // Row 1: parsed from string representation + let row1 = tags.value(1); + let row1 = row1.as_string::(); + assert_eq!(row1.len(), 2); + assert_eq!(row1.value(0), "x"); + assert_eq!(row1.value(1), "y"); + + // Row 2: null + assert!(tags.is_null(2)); +} + +// ============================================================================ +// Test 6: Path-based routing (pointer-identity-based) +// ============================================================================ + +/// Identity wrapper for DataType that uses pointer equality for HashMap lookup. +/// +/// Supports two modes: +/// - `FieldRef` variant: Stores ownership of a Field (for HashMap storage) +/// - `DataType` variant: Borrows a DataType temporarily (for HashMap lookup) +/// +/// Both variants compare by pointer identity of the DataType they reference. +/// +/// Safety: Pointer-identity comparison relies on DataType stability guarantees: +/// - We store the owning FieldRef (Arc) which keeps the Field alive +/// - We never call any potentially-mutating methods such as `Arc::get_mut` or `Arc::make_mut` +/// - We never share a reference to the FieldRef that could allow others to mutate it +/// - Our FieldRef ensures that anyone else who might attempt potentially-mutating operations +/// of the same Field through their own FieldRef will fail because `!Arc::is_unique()` +/// - The &DataType returned by `Field::data_type` is stable -- no interior mutability +/// such as `Mutex>` that could move it to a new memory location +/// +/// NOTE: We never dereference the raw pointer values used for comparison. A violation +/// of the above would only produce incorrect HashMap lookups (false positives/negatives). +#[derive(Debug)] +enum DataTypeIdentity<'a> { + FieldRef(FieldRef), + DataType(&'a DataType), +} + +impl<'a> DataTypeIdentity<'a> { + /// Extract the raw DataType pointer for identity comparison. 
+    fn as_ptr(&self) -> *const DataType {
+        match self {
+            DataTypeIdentity::FieldRef(f) => f.data_type(),
+            DataTypeIdentity::DataType(dt) => *dt,
+        }
+    }
+}
+
+impl<'a> Hash for DataTypeIdentity<'a> {
+    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+        self.as_ptr().hash(state);
+    }
+}
+
+impl<'a> PartialEq for DataTypeIdentity<'a> {
+    fn eq(&self, other: &Self) -> bool {
+        self.as_ptr() == other.as_ptr()
+    }
+}
+
+impl<'a> Eq for DataTypeIdentity<'a> {}
+
+/// A factory that routes to custom decoders based on specific field paths.
+///
+/// This allows fine-grained control: customize specific fields by name without
+/// polluting the schema with metadata or affecting all fields of a given type.
+#[derive(Debug)]
+struct PathBasedDecoderFactory {
+    // Maps DataTypeIdentity::FieldRef to factory
+    routes: HashMap<DataTypeIdentity<'static>, Arc<dyn DecoderFactory>>,
+}
+
+impl PathBasedDecoderFactory {
+    /// Create a new path-based factory by mapping field paths to factories.
+    ///
+    /// Takes a reference to the schema fields and a map of paths to factories.
+    /// Paths can be nested using dot notation: "field", "struct.nested", "struct.deep.nested"
+    fn new(fields: &Fields, path_routes: HashMap<&str, Arc<dyn DecoderFactory>>) -> Self {
+        // Walk the fields and associate DataTypeIdentity::FieldRef with factory for O(1) lookup
+        let mut routes = HashMap::new();
+        for (path, factory) in path_routes {
+            let parts: Vec<&str> = path.split('.').collect();
+            if let Some(field) = Self::find_field_by_path(fields, &parts) {
+                routes.insert(DataTypeIdentity::FieldRef(field), factory);
+            }
+        }
+
+        Self { routes }
+    }
+
+    /// Recursively find a Field by following a path of field names.
+ fn find_field_by_path(fields: &Fields, path: &[&str]) -> Option { + let (first, rest) = path.split_first()?; + let field = fields.iter().find(|f| f.name() == *first)?; + + if rest.is_empty() { + // End of path - return this field + return Some(field.clone()); + } + + // Path continues - attempt to recurse into a nested struct + let DataType::Struct(children) = field.data_type() else { + return None; + }; + + Self::find_field_by_path(children, rest) + } +} + +impl DecoderFactory for PathBasedDecoderFactory { + fn make_custom_decoder( + &self, + ctx: &DecoderContext, + data_type: &DataType, + is_nullable: bool, + field_metadata: &HashMap, + ) -> Result>, ArrowError> { + // O(1) lookup using temporary DataType variant for identity comparison + let key = DataTypeIdentity::DataType(data_type); + let Some(factory) = self.routes.get(&key) else { + return Ok(None); + }; + + // Delegate to the route-specific factory + factory.make_custom_decoder(ctx, data_type, is_nullable, field_metadata) + } +} + +#[test] +fn test_path_based_routing() { + // Create schema with both flat and nested String fields + let metadata_fields = Fields::from(vec![ + Field::new("source", DataType::Utf8, true), + Field::new("comment", DataType::Utf8, true), // Nested path: metadata.comment + ]); + + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + Field::new("email", DataType::Utf8, true), + Field::new("metadata", DataType::Struct(metadata_fields.clone()), true), + ])); + + // Create a lenient string factory (will be applied to specific paths only) + let lenient_factory = Arc::new(TypeBasedLenientStringFactory); + + // Build routes: ONLY the nested "metadata.comment" field gets lenient handling + // Demonstrates nested path routing with dot notation + let mut path_routes = HashMap::new(); + path_routes.insert( + "metadata.comment", + lenient_factory.clone() as Arc, + ); + + let factory = 
Arc::new(PathBasedDecoderFactory::new(schema.fields(), path_routes)); + + // JSON with type mismatches in multiple string fields + let json = r#" +{"id": 1, "name": "Alice", "email": "alice@example.com", "metadata": {"source": "web", "comment": "Good"}} +{"id": 2, "name": 42, "email": "bob@example.com", "metadata": {"source": "api", "comment": 100}} +{"id": 3, "name": "Charlie", "email": 999, "metadata": {"source": "mobile", "comment": "Excellent"}} +"#; + + let mut reader = ReaderBuilder::new(schema.clone()) + .with_decoder_factory(factory) + .build(Cursor::new(json)) + .unwrap(); + + // The decode should FAIL because "name" and "email" don't have lenient handling + // Only "metadata.comment" is lenient, but "name" has a type mismatch in row 2 + let result = reader.next().unwrap(); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("name") || err.contains("expected string")); + + // Now let's test with data that only has issues in the lenient nested field + let json_lenient_only = r#" +{"id": 1, "name": "Alice", "email": "alice@example.com", "metadata": {"source": "web", "comment": "Good"}} +{"id": 2, "name": "Bob", "email": "bob@example.com", "metadata": {"source": "api", "comment": 100}} +{"id": 3, "name": "Charlie", "email": "charlie@example.com", "metadata": {"source": "mobile", "comment": true}} +"#; + + // Rebuild routes for the second test + let mut path_routes2 = HashMap::new(); + path_routes2.insert( + "metadata.comment", + lenient_factory as Arc, + ); + + let mut reader = ReaderBuilder::new(schema.clone()) + .with_decoder_factory(Arc::new(PathBasedDecoderFactory::new( + schema.fields(), + path_routes2, + ))) + .build(Cursor::new(json_lenient_only)) + .unwrap(); + + let batch = reader.next().unwrap().unwrap(); + + // Verify all fields + assert_eq!(batch.num_rows(), 3); + assert_eq!(batch.num_columns(), 4); + + // ID column: all valid + let ids = batch + .column(0) + .as_primitive::(); + assert_eq!(ids.value(0), 
1); + assert_eq!(ids.value(1), 2); + assert_eq!(ids.value(2), 3); + + // Name column: all valid (no type mismatches) + let names = batch.column(1).as_string::(); + assert_eq!(names.value(0), "Alice"); + assert_eq!(names.value(1), "Bob"); + assert_eq!(names.value(2), "Charlie"); + + // Email column: all valid (no type mismatches) + let emails = batch.column(2).as_string::(); + assert_eq!(emails.value(0), "alice@example.com"); + assert_eq!(emails.value(1), "bob@example.com"); + assert_eq!(emails.value(2), "charlie@example.com"); + + // Metadata struct column + let metadata = batch.column(3).as_struct(); + + // metadata.source: all valid (no lenient handling) + let sources = metadata.column(0).as_string::(); + assert_eq!(sources.value(0), "web"); + assert_eq!(sources.value(1), "api"); + assert_eq!(sources.value(2), "mobile"); + + // metadata.comment: LENIENT - non-strings become NULL + let comments = metadata.column(1).as_string::(); + assert_eq!(comments.value(0), "Good"); + assert!(comments.is_null(1)); // 100 -> NULL + assert!(comments.is_null(2)); // true -> NULL +} + +// ============================================================================ +// Test 7: Recursive factory propagation +// ============================================================================ + +#[test] +fn test_recursive_factory_propagation() { + // Schema with nested struct containing string fields + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new( + "person", + DataType::Struct( + vec![ + Field::new("name", DataType::Utf8, true), // Make nullable + Field::new("email", DataType::Utf8, true), + ] + .into(), + ), + true, + ), + ])); + + let json = r#" +{"id": 1, "person": {"name": "Alice", "email": "alice@example.com"}} +{"id": 2, "person": 42} +{"id": 3, "person": {"name": 123, "email": "charlie@example.com"}} +{"id": 4, "person": {"name": "Dave", "email": true}} +"#; + + // Compose lenient struct + lenient string factories + // This tests 
that the factory propagates through struct decoder creation + let factories = vec![ + Arc::new(TypeBasedLenientStructFactory) as Arc, + Arc::new(TypeBasedLenientStringFactory) as Arc, + ]; + let factory = Arc::new(ComposedDecoderFactory { factories }); + + let mut reader = ReaderBuilder::new(schema) + .with_decoder_factory(factory) + .build(Cursor::new(json)) + .unwrap(); + + let batch = reader.next().unwrap().unwrap(); + + // ID column: all valid + let ids = batch + .column(0) + .as_primitive::(); + assert_eq!(ids.value(0), 1); + assert_eq!(ids.value(1), 2); + assert_eq!(ids.value(2), 3); + assert_eq!(ids.value(3), 4); + + // Person struct column + let person = batch.column(1).as_struct(); + + // Row 0: Valid struct with valid strings + assert!(!person.is_null(0)); + let names = person.column(0).as_string::(); + let emails = person.column(1).as_string::(); + assert_eq!(names.value(0), "Alice"); + assert_eq!(emails.value(0), "alice@example.com"); + + // Row 1: Not a struct (42) -> NULL from struct factory + assert!(person.is_null(1)); + + // Row 2: Valid struct but name has type mismatch (123) + // This tests that the string factory was invoked for the nested field! 
+ assert!(!person.is_null(2)); + assert!(names.is_null(2)); // 123 -> NULL from string factory + assert_eq!(emails.value(2), "charlie@example.com"); + + // Row 3: Valid struct but email has type mismatch (true) + assert!(!person.is_null(3)); + assert_eq!(names.value(3), "Dave"); + assert!(emails.is_null(3)); // true -> NULL from string factory +} diff --git a/arrow-schema/src/extension/canonical/bool8.rs b/arrow-schema/src/extension/canonical/bool8.rs index 362a2cc018c7..c94c8217b8ff 100644 --- a/arrow-schema/src/extension/canonical/bool8.rs +++ b/arrow-schema/src/extension/canonical/bool8.rs @@ -96,7 +96,7 @@ mod tests { } #[test] - #[should_panic(expected = "Field extension type name missing")] + #[should_panic(expected = "Extension type name missing")] fn missing_name() { let field = Field::new("", DataType::Int8, false).with_metadata( [(EXTENSION_TYPE_METADATA_KEY.to_owned(), "".to_owned())] diff --git a/arrow-schema/src/extension/canonical/fixed_shape_tensor.rs b/arrow-schema/src/extension/canonical/fixed_shape_tensor.rs index b6bd1c1223f4..5157eefe9ebb 100644 --- a/arrow-schema/src/extension/canonical/fixed_shape_tensor.rs +++ b/arrow-schema/src/extension/canonical/fixed_shape_tensor.rs @@ -471,7 +471,7 @@ mod tests { } #[test] - #[should_panic(expected = "Field extension type name missing")] + #[should_panic(expected = "Extension type name missing")] fn missing_name() { let field = Field::new_fixed_size_list("", Field::new("", DataType::Float32, false), 3, false) diff --git a/arrow-schema/src/extension/canonical/json.rs b/arrow-schema/src/extension/canonical/json.rs index 297a2d99aa04..d2a54b9189b7 100644 --- a/arrow-schema/src/extension/canonical/json.rs +++ b/arrow-schema/src/extension/canonical/json.rs @@ -222,7 +222,7 @@ mod tests { } #[test] - #[should_panic(expected = "Field extension type name missing")] + #[should_panic(expected = "Extension type name missing")] fn missing_name() { let field = Field::new("", DataType::Int8, false).with_metadata( 
[(EXTENSION_TYPE_METADATA_KEY.to_owned(), "{}".to_owned())] diff --git a/arrow-schema/src/extension/canonical/opaque.rs b/arrow-schema/src/extension/canonical/opaque.rs index fceae8d3711d..acfc1331a670 100644 --- a/arrow-schema/src/extension/canonical/opaque.rs +++ b/arrow-schema/src/extension/canonical/opaque.rs @@ -285,7 +285,7 @@ mod tests { } #[test] - #[should_panic(expected = "Field extension type name missing")] + #[should_panic(expected = "Extension type name missing")] fn missing_name() { let field = Field::new("", DataType::Null, false).with_metadata( [( diff --git a/arrow-schema/src/extension/canonical/uuid.rs b/arrow-schema/src/extension/canonical/uuid.rs index 09533564ed44..3e897f47318d 100644 --- a/arrow-schema/src/extension/canonical/uuid.rs +++ b/arrow-schema/src/extension/canonical/uuid.rs @@ -100,7 +100,7 @@ mod tests { } #[test] - #[should_panic(expected = "Field extension type name missing")] + #[should_panic(expected = "Extension type name missing")] fn missing_name() { let field = Field::new("", DataType::FixedSizeBinary(16), false); field.extension_type::(); diff --git a/arrow-schema/src/extension/canonical/variable_shape_tensor.rs b/arrow-schema/src/extension/canonical/variable_shape_tensor.rs index b5403dcf684f..fbc641f54366 100644 --- a/arrow-schema/src/extension/canonical/variable_shape_tensor.rs +++ b/arrow-schema/src/extension/canonical/variable_shape_tensor.rs @@ -529,7 +529,7 @@ mod tests { } #[test] - #[should_panic(expected = "Field extension type name missing")] + #[should_panic(expected = "Extension type name missing")] fn missing_name() { let field = Field::new_struct( "", diff --git a/arrow-schema/src/extension/mod.rs b/arrow-schema/src/extension/mod.rs index cd17272e15ab..4b9ddf1a4548 100644 --- a/arrow-schema/src/extension/mod.rs +++ b/arrow-schema/src/extension/mod.rs @@ -23,6 +23,7 @@ mod canonical; pub use canonical::*; use crate::{ArrowError, DataType}; +use std::collections::HashMap; /// The metadata key for the string 
name identifying an [`ExtensionType`]. pub const EXTENSION_TYPE_NAME_KEY: &str = "ARROW:extension:name"; @@ -255,4 +256,41 @@ pub trait ExtensionType: Sized { /// This should return an error if the given data type is not supported by /// this extension type. fn try_new(data_type: &DataType, metadata: Self::Metadata) -> Result; + + /// Construct this extension type from field metadata and data type. + /// + /// This is a provided method that extracts extension type information from + /// metadata (using [`EXTENSION_TYPE_NAME_KEY`] and + /// [`EXTENSION_TYPE_METADATA_KEY`]) and delegates to [`Self::try_new`]. + /// + /// Returns an error if: + /// - The extension type name is missing or doesn't match [`Self::NAME`] + /// - Metadata deserialization fails + /// - The data type is not supported + /// + /// This method enables extension type checking without requiring a full + /// [`Field`] instance, useful when only metadata and data type are available. + /// + /// [`Field`]: crate::Field + fn try_from_parts( + metadata: &HashMap, + data_type: &DataType, + ) -> Result { + match metadata.get(EXTENSION_TYPE_NAME_KEY).map(|s| s.as_str()) { + Some(name) if name == Self::NAME => { + let ext_metadata = metadata + .get(EXTENSION_TYPE_METADATA_KEY) + .map(|s| s.as_str()); + let parsed = Self::deserialize_metadata(ext_metadata)?; + Self::try_new(data_type, parsed) + } + Some(name) => Err(ArrowError::InvalidArgumentError(format!( + "Extension type name mismatch: expected {}, got {name}", + Self::NAME + ))), + None => Err(ArrowError::InvalidArgumentError( + "Extension type name missing".to_string(), + )), + } + } } diff --git a/arrow-schema/src/field.rs b/arrow-schema/src/field.rs index c4566e41bfa8..27d0b0c46e51 100644 --- a/arrow-schema/src/field.rs +++ b/arrow-schema/src/field.rs @@ -575,25 +575,7 @@ impl Field { /// } /// ``` pub fn try_extension_type(&self) -> Result { - // Check the extension name in the metadata - match self.extension_type_name() { - // It should match the 
name of the given extension type - Some(name) if name == E::NAME => { - // Deserialize the metadata and try to construct the extension - // type - E::deserialize_metadata(self.extension_type_metadata()) - .and_then(|metadata| E::try_new(self.data_type(), metadata)) - } - // Name mismatch - Some(name) => Err(ArrowError::InvalidArgumentError(format!( - "Field extension type name mismatch, expected {}, found {name}", - E::NAME - ))), - // Name missing - None => Err(ArrowError::InvalidArgumentError( - "Field extension type name missing".to_owned(), - )), - } + E::try_from_parts(self.metadata(), self.data_type()) } /// Returns an instance of the given [`ExtensionType`] of this [`Field`], diff --git a/parquet-variant-compute/Cargo.toml b/parquet-variant-compute/Cargo.toml index 85d66a9cf706..8f15f99aa58d 100644 --- a/parquet-variant-compute/Cargo.toml +++ b/parquet-variant-compute/Cargo.toml @@ -30,9 +30,13 @@ rust-version = { workspace = true } [dependencies] arrow = { workspace = true , features = ["canonical_extension_types"]} +arrow-array = { workspace = true } +arrow-data = { workspace = true } +arrow-json = { workspace = true } arrow-schema = { workspace = true } half = { version = "2.1", default-features = false } indexmap = "2.10.0" +lexical-core = { version = "1.0", default-features = false} parquet-variant = { workspace = true } parquet-variant-json = { workspace = true } chrono = { workspace = true } diff --git a/parquet-variant-compute/src/decoder.rs b/parquet-variant-compute/src/decoder.rs new file mode 100644 index 000000000000..0e51cd370f69 --- /dev/null +++ b/parquet-variant-compute/src/decoder.rs @@ -0,0 +1,287 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::{VariantArrayBuilder, VariantType}; +use arrow_array::{Array, StructArray}; +use arrow_data::ArrayData; +use arrow_json::reader::{ArrayDecoder, DecoderContext, DecoderFactory, Tape, TapeElement}; +use arrow_schema::extension::ExtensionType; +use arrow_schema::{ArrowError, DataType}; +use parquet_variant::{ObjectFieldBuilder, Variant, VariantBuilderExt}; +use std::collections::HashMap; + +/// An [`ArrayDecoder`] implementation that decodes JSON values into a Variant array. +/// +/// This decoder converts JSON tape elements (parsed JSON tokens) into Parquet Variant +/// format, preserving the full structure of arbitrary JSON including nested objects, +/// arrays, and primitive types. +/// +/// This decoder is typically used indirectly via [`VariantArrayDecoderFactory`] when +/// reading JSON data into Variant columns. 
+#[derive(Default)]
+pub struct VariantArrayDecoder;
+
+impl ArrayDecoder for VariantArrayDecoder {
+    fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result<ArrayData, ArrowError> {
+        let mut array_builder = VariantArrayBuilder::new(pos.len());
+        for p in pos {
+            variant_from_tape_element(&mut array_builder, *p, tape)?;
+        }
+        let variant_struct_array = StructArray::from(array_builder.build());
+        Ok(variant_struct_array.into_data())
+    }
+}
+
+/// A [`DecoderFactory`] that integrates with the Arrow JSON reader to automatically
+/// decode JSON values into Variant arrays when the target field is registered as a
+/// [`VariantType`] extension type.
+///
+/// # Example
+///
+/// ```ignore
+/// use arrow_json::reader::ReaderBuilder;
+/// use arrow_json::StructMode;
+/// use std::sync::Arc;
+///
+/// let builder = ReaderBuilder::new(Arc::new(schema));
+/// let reader = builder
+///     .with_struct_mode(StructMode::ObjectOnly)
+///     .with_decoder_factory(Arc::new(VariantArrayDecoderFactory))
+///     .build(json_input)?;
+/// ```
+#[derive(Default, Debug)]
+#[allow(unused)]
+pub struct VariantArrayDecoderFactory;
+
+impl DecoderFactory for VariantArrayDecoderFactory {
+    fn make_custom_decoder(
+        &self,
+        _ctx: &DecoderContext,
+        data_type: &DataType,
+        _is_nullable: bool,
+        field_metadata: &HashMap<String, String>,
+    ) -> Result<Option<Box<dyn ArrayDecoder>>, ArrowError> {
+        // Check if this is a Variant extension type using metadata
+        let result = VariantType::try_from_parts(field_metadata, data_type);
+        Ok(result.ok().map(|_| Box::new(VariantArrayDecoder) as _))
+    }
+}
+
+fn variant_from_tape_element(
+    builder: &mut impl VariantBuilderExt,
+    mut p: u32,
+    tape: &Tape,
+) -> Result<u32, ArrowError> {
+    match tape.get(p) {
+        TapeElement::StartObject(end_idx) => {
+            let mut object_builder = builder.try_new_object()?;
+            p += 1;
+            while p < end_idx {
+                // Read field name
+                let field_name = match tape.get(p) {
+                    TapeElement::String(s) => tape.get_string(s),
+                    _ => return Err(tape.error(p, "field name")),
+                };
+
+                let mut field_builder =
ObjectFieldBuilder::new(field_name, &mut object_builder); + p = tape.next(p, "field value")?; + p = variant_from_tape_element(&mut field_builder, p, tape)?; + } + object_builder.finish(); + } + TapeElement::EndObject(_u32) => { + return Err(ArrowError::JsonError( + "unexpected end of object".to_string(), + )); + } + TapeElement::StartList(end_idx) => { + let mut list_builder = builder.try_new_list()?; + p += 1; + while p < end_idx { + p = variant_from_tape_element(&mut list_builder, p, tape)?; + } + list_builder.finish(); + } + TapeElement::EndList(_u32) => { + return Err(ArrowError::JsonError("unexpected end of list".to_string())); + } + TapeElement::String(idx) => builder.append_value(tape.get_string(idx)), + TapeElement::Number(idx) => { + let s = tape.get_string(idx); + builder.append_value(parse_number(s)?) + } + TapeElement::I64(i) => { + return Err(ArrowError::JsonError(format!( + "I64 tape element not supported: {i}" + ))); + } + TapeElement::I32(i) => { + return Err(ArrowError::JsonError(format!( + "I32 tape element not supported: {i}" + ))); + } + TapeElement::F64(f) => { + return Err(ArrowError::JsonError(format!( + "F64 tape element not supported: {f}" + ))); + } + TapeElement::F32(f) => { + return Err(ArrowError::JsonError(format!( + "F32 tape element not supported: {f}" + ))); + } + TapeElement::True => builder.append_value(true), + TapeElement::False => builder.append_value(false), + TapeElement::Null => builder.append_value(Variant::Null), + } + p += 1; + Ok(p) +} + +fn parse_number<'a, 'b>(s: &'a str) -> Result, ArrowError> { + if let Ok(v) = lexical_core::parse(s.as_bytes()) { + return Ok(Variant::Int64(v)); + } + + match lexical_core::parse(s.as_bytes()) { + Ok(v) => Ok(Variant::Double(v)), + Err(_) => Err(ArrowError::JsonError(format!( + "failed to parse {s} as number" + ))), + } +} + +#[cfg(test)] +mod tests { + use crate::VariantArray; + + use super::*; + use arrow_array::Int32Array; + use arrow_json::StructMode; + use 
arrow_json::reader::ReaderBuilder; + use arrow_schema::{DataType, Field, Schema}; + use parquet_variant::VariantBuilder; + use std::io::Cursor; + use std::sync::Arc; + + #[test] + fn test_variant() { + let do_test = |json_input: &str, ids: Vec, variants: Vec>| { + let variant_array = VariantArrayBuilder::new(0).build(); + + let struct_field = Schema::new(vec![ + Field::new("id", DataType::Int32, false), + // call VariantArray::field to get the correct Field + variant_array.field("var"), + ]); + + let builder = ReaderBuilder::new(Arc::new(struct_field.clone())); + let result = builder + .with_struct_mode(StructMode::ObjectOnly) + .with_decoder_factory(Arc::new(VariantArrayDecoderFactory)) + .build(Cursor::new(json_input.as_bytes())) + .unwrap() + .next() + .unwrap() + .unwrap(); + + assert_eq!(result.num_columns(), 2); + let int_array = arrow_array::array::Int32Array::from(ids); + assert_eq!( + result + .column(0) + .as_any() + .downcast_ref::() + .unwrap(), + &int_array + ); + + let result_variant_array: VariantArray = + VariantArray::try_new(result.column(1)).unwrap(); + let values = result_variant_array.iter().collect::>(); + + assert_eq!(values, variants); + }; + + do_test( + "{\"id\": 1, \"var\": \"a\"}\n{\"id\": 2, \"var\": \"b\"}", + vec![1, 2], + vec![Some(Variant::from("a")), Some(Variant::from("b"))], + ); + + let mut builder = VariantBuilder::new(); + let mut object_builder = builder.new_object(); + object_builder.insert("int64", Variant::Int64(1)); + object_builder.insert("double", Variant::Double(1.0)); + object_builder.insert("null", Variant::Null); + object_builder.insert("true", Variant::BooleanTrue); + object_builder.insert("false", Variant::BooleanFalse); + object_builder.insert("string", Variant::from("a")); + object_builder.finish(); + let (metadata, value) = builder.finish(); + let variant = Variant::try_new(&metadata, &value).unwrap(); + + do_test( + "{\"id\": 1, \"var\": {\"int64\": 1, \"double\": 1.0, \"null\": null, \"true\": true, 
\"false\": false, \"string\": \"a\"}}", + vec![1], + vec![Some(variant)], + ); + + // nested structs + let mut builder = VariantBuilder::new(); + let mut object_builder = builder.new_object(); + { + let mut list_builder = object_builder.new_list("somelist"); + { + let mut nested_object_builder = list_builder.new_object(); + nested_object_builder.insert("num", Variant::Int64(2)); + nested_object_builder.finish(); + } + { + let mut nested_object_builder = list_builder.new_object(); + nested_object_builder.insert("num", Variant::Int64(3)); + nested_object_builder.finish(); + } + list_builder.finish(); + object_builder.insert("scalar", Variant::from("a")); + } + object_builder.finish(); + + let (metadata, value) = builder.finish(); + let variant = Variant::try_new(&metadata, &value).unwrap(); + + do_test( + "{\"id\": 1, \"var\": {\"somelist\": [{\"num\": 2}, {\"num\": 3}], \"scalar\": \"a\"}}", + vec![1], + vec![Some(variant)], + ); + + let mut builder = VariantBuilder::new(); + let mut list_builder = builder.new_list(); + list_builder.append_value(Variant::Int64(1000000000000)); + list_builder.append_value(Variant::Double(std::f64::consts::E)); + list_builder.finish(); + let (metadata, value) = builder.finish(); + let variant = Variant::try_new(&metadata, &value).unwrap(); + + do_test( + "{\"id\": 1, \"var\": [1000000000000, 2.718281828459045]}", + vec![1], + vec![Some(variant)], + ); + } +} diff --git a/parquet-variant-compute/src/lib.rs b/parquet-variant-compute/src/lib.rs index b05d0e023653..d33eece28d9b 100644 --- a/parquet-variant-compute/src/lib.rs +++ b/parquet-variant-compute/src/lib.rs @@ -41,6 +41,7 @@ mod arrow_to_variant; mod cast_to_variant; +mod decoder; mod from_json; mod shred_variant; mod to_json; @@ -55,6 +56,7 @@ pub use variant_array::{BorrowedShreddingState, ShreddingState, VariantArray, Va pub use variant_array_builder::{VariantArrayBuilder, VariantValueArrayBuilder}; pub use cast_to_variant::{cast_to_variant, cast_to_variant_with_options}; +pub 
use decoder::VariantArrayDecoderFactory; pub use from_json::json_to_variant; pub use shred_variant::{IntoShreddingField, ShreddedSchemaBuilder, shred_variant}; pub use to_json::variant_to_json;