From 21581ea199471bfffee12e8de003557c04819b0f Mon Sep 17 00:00:00 2001 From: Cyprien Huet Date: Thu, 24 Apr 2025 13:05:33 +0400 Subject: [PATCH 01/14] Add custom decoder in arrow-json --- arrow-json/src/lib.rs | 2 +- arrow-json/src/reader/list_array.rs | 5 + arrow-json/src/reader/map_array.rs | 7 + arrow-json/src/reader/mod.rs | 191 +++++++++++++++++++++++++- arrow-json/src/reader/struct_array.rs | 6 + arrow-json/src/reader/tape.rs | 1 + 6 files changed, 205 insertions(+), 7 deletions(-) diff --git a/arrow-json/src/lib.rs b/arrow-json/src/lib.rs index 1b18e0094708..a53ca74fd618 100644 --- a/arrow-json/src/lib.rs +++ b/arrow-json/src/lib.rs @@ -86,7 +86,7 @@ pub mod reader; pub mod writer; -pub use self::reader::{Reader, ReaderBuilder}; +pub use self::reader::{ArrayDecoder, DecoderFactory, Reader, ReaderBuilder, Tape, TapeElement}; pub use self::writer::{ ArrayWriter, Encoder, EncoderFactory, EncoderOptions, LineDelimitedWriter, Writer, WriterBuilder, diff --git a/arrow-json/src/reader/list_array.rs b/arrow-json/src/reader/list_array.rs index e74fef79178a..72a2d0e1bb9b 100644 --- a/arrow-json/src/reader/list_array.rs +++ b/arrow-json/src/reader/list_array.rs @@ -24,6 +24,9 @@ use arrow_buffer::buffer::NullBuffer; use arrow_data::{ArrayData, ArrayDataBuilder}; use arrow_schema::{ArrowError, DataType}; use std::marker::PhantomData; +use std::sync::Arc; + +use super::DecoderFactory; pub struct ListArrayDecoder { data_type: DataType, @@ -39,6 +42,7 @@ impl ListArrayDecoder { strict_mode: bool, is_nullable: bool, struct_mode: StructMode, + decoder_factory: Option>, ) -> Result { let field = match &data_type { DataType::List(f) if !O::IS_LARGE => f, @@ -51,6 +55,7 @@ impl ListArrayDecoder { strict_mode, field.is_nullable(), struct_mode, + decoder_factory, )?; Ok(Self { diff --git a/arrow-json/src/reader/map_array.rs b/arrow-json/src/reader/map_array.rs index c2068577a094..572f4f38522c 100644 --- a/arrow-json/src/reader/map_array.rs +++ b/arrow-json/src/reader/map_array.rs 
@@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +use std::sync::Arc; + use crate::StructMode; use crate::reader::tape::{Tape, TapeElement}; use crate::reader::{ArrayDecoder, make_decoder}; @@ -24,6 +26,8 @@ use arrow_buffer::buffer::NullBuffer; use arrow_data::{ArrayData, ArrayDataBuilder}; use arrow_schema::{ArrowError, DataType}; +use super::DecoderFactory; + pub struct MapArrayDecoder { data_type: DataType, keys: Box, @@ -38,6 +42,7 @@ impl MapArrayDecoder { strict_mode: bool, is_nullable: bool, struct_mode: StructMode, + decoder_factory: Option>, ) -> Result { let fields = match &data_type { DataType::Map(_, true) => { @@ -62,6 +67,7 @@ impl MapArrayDecoder { strict_mode, fields[0].is_nullable(), struct_mode, + decoder_factory.clone(), )?; let values = make_decoder( fields[1].data_type().clone(), @@ -69,6 +75,7 @@ impl MapArrayDecoder { strict_mode, fields[1].is_nullable(), struct_mode, + decoder_factory, )?; Ok(Self { diff --git a/arrow-json/src/reader/mod.rs b/arrow-json/src/reader/mod.rs index f5fd1a8e7c38..f1bfd979f391 100644 --- a/arrow-json/src/reader/mod.rs +++ b/arrow-json/src/reader/mod.rs @@ -149,6 +149,7 @@ use arrow_array::{RecordBatch, RecordBatchReader, StructArray, downcast_integer, use arrow_data::ArrayData; use arrow_schema::{ArrowError, DataType, FieldRef, Schema, SchemaRef, TimeUnit}; pub use schema::*; +pub use tape::*; use crate::reader::boolean_array::BooleanArrayDecoder; use crate::reader::decimal_array::DecimalArrayDecoder; @@ -159,7 +160,6 @@ use crate::reader::primitive_array::PrimitiveArrayDecoder; use crate::reader::string_array::StringArrayDecoder; use crate::reader::string_view_array::StringViewArrayDecoder; use crate::reader::struct_array::StructArrayDecoder; -use crate::reader::tape::{Tape, TapeDecoder}; use crate::reader::timestamp_array::TimestampArrayDecoder; mod binary_array; @@ -184,6 +184,7 @@ pub struct ReaderBuilder { strict_mode: bool, is_field: bool, struct_mode: 
StructMode, + decoder_factory: Option>, schema: SchemaRef, } @@ -205,6 +206,7 @@ impl ReaderBuilder { is_field: false, struct_mode: Default::default(), schema, + decoder_factory: None, } } @@ -246,6 +248,7 @@ impl ReaderBuilder { is_field: true, struct_mode: Default::default(), schema: Arc::new(Schema::new([field.into()])), + decoder_factory: None, } } @@ -285,6 +288,14 @@ impl ReaderBuilder { } } + /// Set an optional hook for customizing decoding behavior. + pub fn with_decoder_factory(self, decoder_factory: Arc) -> Self { + Self { + decoder_factory: Some(decoder_factory), + ..self + } + } + /// Create a [`Reader`] with the provided [`BufRead`] pub fn build(self, reader: R) -> Result, ArrowError> { Ok(Reader { @@ -309,6 +320,7 @@ impl ReaderBuilder { self.strict_mode, nullable, self.struct_mode, + self.decoder_factory, )?; let num_fields = self.schema.flattened_fields().len(); @@ -373,6 +385,95 @@ impl RecordBatchReader for Reader { } } +/// A trait to create custom decoders for specific data types. +/// +/// This allows overriding the default decoders for specific data types, +/// or adding new decoders for custom data types. 
+/// +/// # Examples +/// +/// ``` +/// use arrow_json::{ArrayDecoder, DecoderFactory, TapeElement, Tape, ReaderBuilder, StructMode}; +/// use arrow_schema::ArrowError; +/// use arrow_schema::{DataType, Field, Fields, Schema}; +/// use arrow_array::cast::AsArray; +/// use arrow_array::Array; +/// use arrow_array::builder::StringBuilder; +/// use arrow_data::ArrayData; +/// use std::sync::Arc; +/// +/// struct IncorrectStringAsNullDecoder {} +/// +/// impl ArrayDecoder for IncorrectStringAsNullDecoder { +/// fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result { +/// let mut builder = StringBuilder::new(); +/// for p in pos { +/// match tape.get(*p) { +/// TapeElement::String(idx) => { +/// builder.append_value(tape.get_string(idx)); +/// } +/// _ => builder.append_null(), +/// } +/// } +/// Ok(builder.finish().into_data()) +/// } +/// } +/// +/// #[derive(Debug)] +/// struct IncorrectStringAsNullDecoderFactory; +/// +/// impl DecoderFactory for IncorrectStringAsNullDecoderFactory { +/// fn make_default_decoder<'a>( +/// &self, +/// data_type: DataType, +/// _coerce_primitive: bool, +/// _strict_mode: bool, +/// _is_nullable: bool, +/// _struct_mode: StructMode, +/// ) -> Result>, ArrowError> { +/// match data_type { +/// DataType::Utf8 => Ok(Some(Box::new(IncorrectStringAsNullDecoder {}))), +/// _ => Ok(None), +/// } +/// } +/// } +/// +/// let json = r#" +/// {"a": "a"} +/// {"a": 12} +/// "#; +/// let batch = ReaderBuilder::new(Arc::new(Schema::new(Fields::from(vec![Field::new( +/// "a", +/// DataType::Utf8, +/// true, +/// )])))) +/// .with_decoder_factory(Arc::new(IncorrectStringAsNullDecoderFactory)) +/// .build(json.as_bytes()) +/// .unwrap() +/// .next() +/// .unwrap() +/// .unwrap(); +/// +/// let values = batch.column(0).as_string::(); +/// assert_eq!(values.len(), 2); +/// assert_eq!(values.value(0), "a"); +/// assert!(values.is_null(1)); +/// ``` +pub trait DecoderFactory: std::fmt::Debug + Send + Sync { + /// Make a decoder that overrides the 
default decoder for a specific data type. + /// This can be used to override how e.g. error in decoding are handled. + fn make_default_decoder( + &self, + _data_type: DataType, + _coerce_primitive: bool, + _strict_mode: bool, + _is_nullable: bool, + _struct_mode: StructMode, + ) -> Result>, ArrowError> { + Ok(None) + } +} + /// A low-level interface for reading JSON data from a byte stream /// /// See [`Reader`] for a higher-level interface for interface with [`BufRead`] @@ -674,7 +775,8 @@ impl Decoder { } } -trait ArrayDecoder: Send { +/// A trait to decode JSON values into arrow arrays +pub trait ArrayDecoder: Send { /// Decode elements from `tape` starting at the indexes contained in `pos` fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result; } @@ -691,7 +793,20 @@ fn make_decoder( strict_mode: bool, is_nullable: bool, struct_mode: StructMode, + decoder_factory: Option>, ) -> Result, ArrowError> { + if let Some(ref factory) = decoder_factory { + if let Some(decoder) = factory.make_default_decoder( + data_type.clone(), + coerce_primitive, + strict_mode, + is_nullable, + struct_mode, + )? { + return Ok(decoder); + } + } + downcast_integer! 
{ data_type => (primitive_decoder, data_type), DataType::Null => Ok(Box::::default()), @@ -744,14 +859,14 @@ fn make_decoder( DataType::Utf8 => Ok(Box::new(StringArrayDecoder::::new(coerce_primitive))), DataType::Utf8View => Ok(Box::new(StringViewArrayDecoder::new(coerce_primitive))), DataType::LargeUtf8 => Ok(Box::new(StringArrayDecoder::::new(coerce_primitive))), - DataType::List(_) => Ok(Box::new(ListArrayDecoder::::new(data_type, coerce_primitive, strict_mode, is_nullable, struct_mode)?)), - DataType::LargeList(_) => Ok(Box::new(ListArrayDecoder::::new(data_type, coerce_primitive, strict_mode, is_nullable, struct_mode)?)), - DataType::Struct(_) => Ok(Box::new(StructArrayDecoder::new(data_type, coerce_primitive, strict_mode, is_nullable, struct_mode)?)), + DataType::List(_) => Ok(Box::new(ListArrayDecoder::::new(data_type, coerce_primitive, strict_mode, is_nullable, struct_mode, decoder_factory)?)), + DataType::LargeList(_) => Ok(Box::new(ListArrayDecoder::::new(data_type, coerce_primitive, strict_mode, is_nullable, struct_mode, decoder_factory)?)), + DataType::Struct(_) => Ok(Box::new(StructArrayDecoder::new(data_type, coerce_primitive, strict_mode, is_nullable, struct_mode, decoder_factory)?)), DataType::Binary => Ok(Box::new(BinaryArrayDecoder::::default())), DataType::LargeBinary => Ok(Box::new(BinaryArrayDecoder::::default())), DataType::FixedSizeBinary(len) => Ok(Box::new(FixedSizeBinaryArrayDecoder::new(len))), DataType::BinaryView => Ok(Box::new(BinaryViewDecoder::default())), - DataType::Map(_, _) => Ok(Box::new(MapArrayDecoder::new(data_type, coerce_primitive, strict_mode, is_nullable, struct_mode)?)), + DataType::Map(_, _) => Ok(Box::new(MapArrayDecoder::new(data_type, coerce_primitive, strict_mode, is_nullable, struct_mode, decoder_factory)?)), d => Err(ArrowError::NotYetImplemented(format!("Support for {d} in JSON reader"))) } } @@ -2815,4 +2930,68 @@ mod tests { "Json error: whilst decoding field 'a': failed to parse \"a\" as Int32".to_owned() ); } 
+ + #[test] + fn test_decoder_factory() { + use arrow_array::builder; + + struct AlwaysNullStringArrayDecoder; + + impl ArrayDecoder for AlwaysNullStringArrayDecoder { + fn decode(&mut self, _tape: &Tape<'_>, pos: &[u32]) -> Result { + let mut builder = builder::StringBuilder::new(); + for _ in pos { + builder.append_null(); + } + Ok(builder.finish().into_data()) + } + } + + #[derive(Debug)] + struct AlwaysNullStringArrayDecoderFactory; + + impl DecoderFactory for AlwaysNullStringArrayDecoderFactory { + fn make_default_decoder<'a>( + &self, + data_type: DataType, + _coerce_primitive: bool, + _strict_mode: bool, + _is_nullable: bool, + _struct_mode: StructMode, + ) -> Result>, ArrowError> { + match data_type { + DataType::Utf8 => Ok(Some(Box::new(AlwaysNullStringArrayDecoder {}))), + _ => Ok(None), + } + } + } + + let buf = r#" + {"a": "1", "b": 2} + {"a": "hello", "b": 23} + "#; + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, true), + Field::new("b", DataType::Int32, true), + ])); + + let batches = ReaderBuilder::new(schema.clone()) + .with_batch_size(2) + .with_decoder_factory(Arc::new(AlwaysNullStringArrayDecoderFactory)) + .build(Cursor::new(buf.as_bytes())) + .unwrap() + .collect::, _>>() + .unwrap(); + + assert_eq!(batches.len(), 1); + + let col1 = batches[0].column(0).as_string::(); + assert_eq!(col1.null_count(), 2); + assert!(col1.is_null(0)); + assert!(col1.is_null(1)); + + let col2 = batches[0].column(1).as_primitive::(); + assert_eq!(col2.value(0), 2); + assert_eq!(col2.value(1), 23); + } } diff --git a/arrow-json/src/reader/struct_array.rs b/arrow-json/src/reader/struct_array.rs index 262097ace396..0b1832c530ed 100644 --- a/arrow-json/src/reader/struct_array.rs +++ b/arrow-json/src/reader/struct_array.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. 
+use std::sync::Arc; + use crate::reader::tape::{Tape, TapeElement}; use crate::reader::{ArrayDecoder, StructMode, make_decoder}; use arrow_array::builder::BooleanBufferBuilder; @@ -22,6 +24,8 @@ use arrow_buffer::buffer::NullBuffer; use arrow_data::{ArrayData, ArrayDataBuilder}; use arrow_schema::{ArrowError, DataType, Fields}; +use super::DecoderFactory; + pub struct StructArrayDecoder { data_type: DataType, decoders: Vec>, @@ -37,6 +41,7 @@ impl StructArrayDecoder { strict_mode: bool, is_nullable: bool, struct_mode: StructMode, + decoder_factory: Option>, ) -> Result { let decoders = struct_fields(&data_type) .iter() @@ -51,6 +56,7 @@ impl StructArrayDecoder { strict_mode, nullable, struct_mode, + decoder_factory.clone(), ) }) .collect::, ArrowError>>()?; diff --git a/arrow-json/src/reader/tape.rs b/arrow-json/src/reader/tape.rs index 89ee3f778765..fcab173ef110 100644 --- a/arrow-json/src/reader/tape.rs +++ b/arrow-json/src/reader/tape.rs @@ -338,6 +338,7 @@ impl TapeDecoder { } } + /// Decodes JSON data from the provided buffer, returning the number of bytes consumed pub fn decode(&mut self, buf: &[u8]) -> Result { let mut iter = BufIter::new(buf); From 5fe364287881ad91ace7b6f4bfca70c1ec384cf1 Mon Sep 17 00:00:00 2001 From: Doug Miller Date: Fri, 19 Dec 2025 08:23:33 -0500 Subject: [PATCH 02/14] include field --- arrow-json/src/reader/list_array.rs | 1 + arrow-json/src/reader/map_array.rs | 2 + arrow-json/src/reader/mod.rs | 10 +- arrow-json/src/reader/struct_array.rs | 1 + parquet-variant-compute/Cargo.toml | 4 + parquet-variant-compute/src/decoder.rs | 252 +++++++++++++++++++++++++ parquet-variant-compute/src/lib.rs | 2 + 7 files changed, 269 insertions(+), 3 deletions(-) create mode 100644 parquet-variant-compute/src/decoder.rs diff --git a/arrow-json/src/reader/list_array.rs b/arrow-json/src/reader/list_array.rs index 72a2d0e1bb9b..7a697faf9bc2 100644 --- a/arrow-json/src/reader/list_array.rs +++ b/arrow-json/src/reader/list_array.rs @@ -50,6 +50,7 @@ impl 
ListArrayDecoder { _ => unreachable!(), }; let decoder = make_decoder( + Some(field.clone()), field.data_type().clone(), coerce_primitive, strict_mode, diff --git a/arrow-json/src/reader/map_array.rs b/arrow-json/src/reader/map_array.rs index 572f4f38522c..4a115d49bea3 100644 --- a/arrow-json/src/reader/map_array.rs +++ b/arrow-json/src/reader/map_array.rs @@ -62,6 +62,7 @@ impl MapArrayDecoder { }; let keys = make_decoder( + Some(fields[0].clone()), fields[0].data_type().clone(), coerce_primitive, strict_mode, @@ -70,6 +71,7 @@ impl MapArrayDecoder { decoder_factory.clone(), )?; let values = make_decoder( + Some(fields[1].clone()), fields[1].data_type().clone(), coerce_primitive, strict_mode, diff --git a/arrow-json/src/reader/mod.rs b/arrow-json/src/reader/mod.rs index f1bfd979f391..cdd035d83347 100644 --- a/arrow-json/src/reader/mod.rs +++ b/arrow-json/src/reader/mod.rs @@ -315,6 +315,7 @@ impl ReaderBuilder { }; let decoder = make_decoder( + None, data_type, self.coerce_primitive, self.strict_mode, @@ -425,6 +426,7 @@ impl RecordBatchReader for Reader { /// impl DecoderFactory for IncorrectStringAsNullDecoderFactory { /// fn make_default_decoder<'a>( /// &self, +/// _field: Option, /// data_type: DataType, /// _coerce_primitive: bool, /// _strict_mode: bool, @@ -464,14 +466,13 @@ pub trait DecoderFactory: std::fmt::Debug + Send + Sync { /// This can be used to override how e.g. error in decoding are handled. fn make_default_decoder( &self, + _field: Option, _data_type: DataType, _coerce_primitive: bool, _strict_mode: bool, _is_nullable: bool, _struct_mode: StructMode, - ) -> Result>, ArrowError> { - Ok(None) - } + ) -> Result>, ArrowError>; } /// A low-level interface for reading JSON data from a byte stream @@ -788,6 +789,7 @@ macro_rules! 
primitive_decoder { } fn make_decoder( + field: Option, data_type: DataType, coerce_primitive: bool, strict_mode: bool, @@ -797,6 +799,7 @@ fn make_decoder( ) -> Result, ArrowError> { if let Some(ref factory) = decoder_factory { if let Some(decoder) = factory.make_default_decoder( + field.clone(), data_type.clone(), coerce_primitive, strict_mode, @@ -2953,6 +2956,7 @@ mod tests { impl DecoderFactory for AlwaysNullStringArrayDecoderFactory { fn make_default_decoder<'a>( &self, + _field: Option, data_type: DataType, _coerce_primitive: bool, _strict_mode: bool, diff --git a/arrow-json/src/reader/struct_array.rs b/arrow-json/src/reader/struct_array.rs index 0b1832c530ed..71a17141a4d4 100644 --- a/arrow-json/src/reader/struct_array.rs +++ b/arrow-json/src/reader/struct_array.rs @@ -51,6 +51,7 @@ impl StructArrayDecoder { // it doesn't contain any nulls not masked by its parent let nullable = f.is_nullable() || is_nullable; make_decoder( + Some(f.clone()), f.data_type().clone(), coerce_primitive, strict_mode, diff --git a/parquet-variant-compute/Cargo.toml b/parquet-variant-compute/Cargo.toml index 74c3dd3fb72f..c88d479393ad 100644 --- a/parquet-variant-compute/Cargo.toml +++ b/parquet-variant-compute/Cargo.toml @@ -30,9 +30,13 @@ rust-version = { workspace = true } [dependencies] arrow = { workspace = true , features = ["canonical_extension_types"]} +arrow-array = { workspace = true } +arrow-data = { workspace = true } +arrow-json = { workspace = true } arrow-schema = { workspace = true } half = { version = "2.1", default-features = false } indexmap = "2.10.0" +lexical-core = { version = "1.0", default-features = false} parquet-variant = { workspace = true } parquet-variant-json = { workspace = true } chrono = { workspace = true } diff --git a/parquet-variant-compute/src/decoder.rs b/parquet-variant-compute/src/decoder.rs new file mode 100644 index 000000000000..d20ff20e0a00 --- /dev/null +++ b/parquet-variant-compute/src/decoder.rs @@ -0,0 +1,252 @@ +// Licensed to the 
Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_array::{Array, StructArray}; +use arrow_json::{DecoderFactory, StructMode}; +use parquet_variant::{ObjectFieldBuilder, Variant, VariantBuilder, VariantBuilderExt}; +use crate::{VariantArrayBuilder, VariantType}; +use arrow_data::ArrayData; +use arrow_schema::{ArrowError, DataType, FieldRef}; + +use arrow_json::reader::ArrayDecoder; +use arrow_json::reader::{Tape, TapeElement}; + +/// An [`ArrayDecoder`] implementation that decodes JSON values into a Variant array. +/// +/// This decoder converts JSON tape elements (parsed JSON tokens) into Parquet Variant +/// format, preserving the full structure of arbitrary JSON including nested objects, +/// arrays, and primitive types. +/// +/// This decoder is typically used indirectly via [`VariantArrayDecoderFactory`] when +/// reading JSON data into Variant columns. 
+#[derive(Default)] +pub struct VariantArrayDecoder; + +impl ArrayDecoder for VariantArrayDecoder { + fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result { + let mut array_builder = VariantArrayBuilder::new(pos.len()); + for p in pos { + let mut builder = VariantBuilder::new(); + variant_from_tape_element(&mut builder, *p, tape)?; + let (metadata, value) = builder.finish(); + array_builder.append_value(Variant::new(&metadata, &value)); + } + let variant_struct_array: StructArray = array_builder.build().into(); + Ok(variant_struct_array.into_data()) + } +} + +/// A [`DecoderFactory`] that creates [`VariantArrayDecoder`] instances for Variant-typed fields. +/// +/// This factory integrates with the Arrow JSON reader to automatically decode JSON values +/// into Variant arrays when the target field is registered as a [`VariantType`] extension type. +/// +/// # Example +/// +/// ```ignore +/// use arrow_json::reader::ReaderBuilder; +/// use arrow_json::StructMode; +/// use std::sync::Arc; +/// +/// let builder = ReaderBuilder::new(Arc::new(schema)); +/// let reader = builder +/// .with_struct_mode(StructMode::ObjectOnly) +/// .with_decoder_factory(Arc::new(VariantArrayDecoderFactory)) +/// .build(json_input)?; +/// ``` +#[derive(Default, Debug)] +#[allow(unused)] +pub struct VariantArrayDecoderFactory; + +impl DecoderFactory for VariantArrayDecoderFactory { + fn make_default_decoder<'a>(&self, field: Option, + _data_type: DataType, + _coerce_primitive: bool, + _strict_mode: bool, + _is_nullable: bool, + _struct_mode: StructMode, + ) -> Result>, ArrowError> { + if let Some(field) = field && field.try_extension_type::().is_ok() { + return Ok(Some(Box::new(VariantArrayDecoder))); + } + Ok(None) + } +} + +fn variant_from_tape_element(builder: &mut impl VariantBuilderExt, mut p: u32, tape: &Tape) -> Result { + match tape.get(p) { + TapeElement::StartObject(end_idx) => { + let mut object_builder = builder.try_new_object()?; + p += 1; + while p < end_idx { + // Read 
field name + let field_name = match tape.get(p) { + TapeElement::String(s) => tape.get_string(s), + _ => return Err(tape.error(p, "field name")), + }; + + let mut field_builder = ObjectFieldBuilder::new(field_name, &mut object_builder); + p = tape.next(p, "field value")?; + p = variant_from_tape_element(&mut field_builder, p, tape)?; + } + object_builder.finish(); + } + TapeElement::EndObject(_u32) => { return Err(ArrowError::JsonError("unexpected end of object".to_string())) }, + TapeElement::StartList(end_idx) => { + let mut list_builder = builder.try_new_list()?; + p+= 1; + while p < end_idx { + p = variant_from_tape_element(&mut list_builder, p, tape)?; + } + list_builder.finish(); + } + TapeElement::EndList(_u32) => { return Err(ArrowError::JsonError("unexpected end of list".to_string())) }, + TapeElement::String(idx) => builder.append_value(tape.get_string(idx)), + TapeElement::Number(idx) => { + let s = tape.get_string(idx); + builder.append_value(parse_number(s)?) + }, + TapeElement::I64(i) => builder.append_value(i), + TapeElement::I32(i) => builder.append_value(i), + TapeElement::F64(f) => builder.append_value(f), + TapeElement::F32(f) => builder.append_value(f), + TapeElement::True => builder.append_value(true), + TapeElement::False => builder.append_value(false), + TapeElement::Null => builder.append_value(Variant::Null), + } + p += 1; + Ok(p) +} + +fn parse_number<'a, 'b>(s: &'a str) -> Result, ArrowError> { + if let Ok(v) = lexical_core::parse(s.as_bytes()) { + return Ok(Variant::Int64(v)); + } + + match lexical_core::parse(s.as_bytes()) { + Ok(v) => Ok(Variant::Double(v)), + Err(_) => Err(ArrowError::JsonError(format!("failed to parse {s} as number"))), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_schema::{Schema, Field, DataType}; + use arrow_json::reader::ReaderBuilder; + use arrow_json::StructMode; + use std::sync::Arc; + use std::io::Cursor; + use arrow_array::RecordBatch; + + #[test] + fn test_variant() { + let do_test = 
|json_input: &str, ids: Vec, variants: Vec| { + let variant_array = VariantArrayBuilder::new(0).build(); + + let struct_field = Schema::new(vec![ + Field::new("id", DataType::Int32, false), + // call VariantArray::field to get the correct Field + variant_array.field("var"), + ]); + + let builder = ReaderBuilder::new(Arc::new(struct_field.clone())); + let result = builder + .with_struct_mode(StructMode::ObjectOnly) + .with_decoder_factory(Arc::new(VariantArrayDecoderFactory)) + .build(Cursor::new(json_input.as_bytes())) + .unwrap() + .next() + .unwrap() + .unwrap(); + + let int_array = arrow_array::array::Int32Array::from(ids); + + let variant_array = { + let mut variant_builder = VariantArrayBuilder::new(variants.len()); + for v in variants { + variant_builder.append_variant(v); + } + variant_builder.build() + }; + + let variant_struct_array: StructArray = variant_array.into(); + + let expected = RecordBatch::try_new( + struct_field.into(), + vec![Arc::new(int_array), Arc::new(variant_struct_array)], + ) + .unwrap(); + + assert_eq!(result, expected); + }; + + do_test( + "{\"id\": 1, \"var\": \"a\"}\n{\"id\": 2, \"var\": \"b\"}", + vec![1, 2], + vec![Variant::from("a"), Variant::from("b")], + ); + + let mut builder = VariantBuilder::new(); + let mut object_builder = builder.new_object(); + object_builder.insert("int64", Variant::Int64(1)); + object_builder.insert("double", Variant::Double(1.0)); + object_builder.insert("null", Variant::Null); + object_builder.insert("true", Variant::BooleanTrue); + object_builder.insert("false", Variant::BooleanFalse); + object_builder.insert("string", Variant::from("a")); + object_builder.finish(); + let (metadata, value) = builder.finish(); + let variant = Variant::try_new(&metadata, &value).unwrap(); + + do_test( + "{\"id\": 1, \"var\": {\"int64\": 1, \"double\": 1.0, \"null\": null, \"true\": true, \"false\": false, \"string\": \"a\"}}", + vec![1], + vec![variant], + ); + + // nested structs + let mut builder = 
VariantBuilder::new(); + let mut object_builder = builder.new_object(); + { + let mut list_builder = object_builder.new_list("somelist"); + { + let mut nested_object_builder = list_builder.new_object(); + nested_object_builder.insert("num", Variant::Int64(2)); + nested_object_builder.finish(); + } + { + let mut nested_object_builder = list_builder.new_object(); + nested_object_builder.insert("num", Variant::Int64(3)); + nested_object_builder.finish(); + } + list_builder.finish(); + object_builder.insert("scalar", Variant::from("a")); + } + object_builder.finish(); + + let (metadata, value) = builder.finish(); + let variant = Variant::try_new(&metadata, &value).unwrap(); + + do_test( + "{\"id\": 1, \"var\": {\"somelist\": [{\"num\": 2}, {\"num\": 3}], \"scalar\": \"a\"}}", + vec![1], + vec![variant], + ); + } + +} diff --git a/parquet-variant-compute/src/lib.rs b/parquet-variant-compute/src/lib.rs index 9b8008f58422..cd8025e011aa 100644 --- a/parquet-variant-compute/src/lib.rs +++ b/parquet-variant-compute/src/lib.rs @@ -41,6 +41,7 @@ mod arrow_to_variant; mod cast_to_variant; +mod decoder; mod from_json; mod shred_variant; mod to_json; @@ -61,3 +62,4 @@ pub use to_json::variant_to_json; pub use type_conversion::CastOptions; pub use unshred_variant::unshred_variant; pub use variant_get::{GetOptions, variant_get}; +pub use decoder::VariantArrayDecoder; From 53bffa29a028eb28a31f16283a55d5b69409b8d7 Mon Sep 17 00:00:00 2001 From: Doug Miller Date: Fri, 19 Dec 2025 17:44:02 -0500 Subject: [PATCH 03/14] change export --- parquet-variant-compute/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/parquet-variant-compute/src/lib.rs b/parquet-variant-compute/src/lib.rs index cd8025e011aa..e6eab80283fd 100644 --- a/parquet-variant-compute/src/lib.rs +++ b/parquet-variant-compute/src/lib.rs @@ -62,4 +62,4 @@ pub use to_json::variant_to_json; pub use type_conversion::CastOptions; pub use unshred_variant::unshred_variant; pub use 
variant_get::{GetOptions, variant_get}; -pub use decoder::VariantArrayDecoder; +pub use decoder::VariantArrayDecoderFactory; From 66b75126b2426c6a6f4decc0082099150c3f7db7 Mon Sep 17 00:00:00 2001 From: Doug Miller Date: Mon, 22 Dec 2025 07:30:23 -0500 Subject: [PATCH 04/14] cleanup --- arrow-json/src/reader/mod.rs | 8 +- parquet-variant-compute/src/decoder.rs | 117 +++++++++++++++---------- parquet-variant-compute/src/lib.rs | 2 +- 3 files changed, 77 insertions(+), 50 deletions(-) diff --git a/arrow-json/src/reader/mod.rs b/arrow-json/src/reader/mod.rs index cdd035d83347..dadb437611dc 100644 --- a/arrow-json/src/reader/mod.rs +++ b/arrow-json/src/reader/mod.rs @@ -424,7 +424,7 @@ impl RecordBatchReader for Reader { /// struct IncorrectStringAsNullDecoderFactory; /// /// impl DecoderFactory for IncorrectStringAsNullDecoderFactory { -/// fn make_default_decoder<'a>( +/// fn make_custom_decoder<'a>( /// &self, /// _field: Option, /// data_type: DataType, @@ -464,7 +464,7 @@ impl RecordBatchReader for Reader { pub trait DecoderFactory: std::fmt::Debug + Send + Sync { /// Make a decoder that overrides the default decoder for a specific data type. /// This can be used to override how e.g. error in decoding are handled. 
- fn make_default_decoder( + fn make_custom_decoder( &self, _field: Option, _data_type: DataType, @@ -798,7 +798,7 @@ fn make_decoder( decoder_factory: Option>, ) -> Result, ArrowError> { if let Some(ref factory) = decoder_factory { - if let Some(decoder) = factory.make_default_decoder( + if let Some(decoder) = factory.make_custom_decoder( field.clone(), data_type.clone(), coerce_primitive, @@ -2954,7 +2954,7 @@ mod tests { struct AlwaysNullStringArrayDecoderFactory; impl DecoderFactory for AlwaysNullStringArrayDecoderFactory { - fn make_default_decoder<'a>( + fn make_custom_decoder<'a>( &self, _field: Option, data_type: DataType, diff --git a/parquet-variant-compute/src/decoder.rs b/parquet-variant-compute/src/decoder.rs index d20ff20e0a00..f05668bfa65e 100644 --- a/parquet-variant-compute/src/decoder.rs +++ b/parquet-variant-compute/src/decoder.rs @@ -15,12 +15,13 @@ // specific language governing permissions and limitations // under the License. -use arrow_array::{Array, StructArray}; -use arrow_json::{DecoderFactory, StructMode}; -use parquet_variant::{ObjectFieldBuilder, Variant, VariantBuilder, VariantBuilderExt}; use crate::{VariantArrayBuilder, VariantType}; +use arrow_array::{Array, StructArray}; use arrow_data::ArrayData; +use arrow_json::{DecoderFactory, StructMode}; +use arrow_schema::extension::ExtensionType; use arrow_schema::{ArrowError, DataType, FieldRef}; +use parquet_variant::{ObjectFieldBuilder, Variant, VariantBuilderExt}; use arrow_json::reader::ArrayDecoder; use arrow_json::reader::{Tape, TapeElement}; @@ -40,12 +41,9 @@ impl ArrayDecoder for VariantArrayDecoder { fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result { let mut array_builder = VariantArrayBuilder::new(pos.len()); for p in pos { - let mut builder = VariantBuilder::new(); - variant_from_tape_element(&mut builder, *p, tape)?; - let (metadata, value) = builder.finish(); - array_builder.append_value(Variant::new(&metadata, &value)); + variant_from_tape_element(&mut 
array_builder, *p, tape)?; } - let variant_struct_array: StructArray = array_builder.build().into(); + let variant_struct_array = StructArray::from(array_builder.build()); Ok(variant_struct_array.into_data()) } } @@ -73,21 +71,29 @@ impl ArrayDecoder for VariantArrayDecoder { pub struct VariantArrayDecoderFactory; impl DecoderFactory for VariantArrayDecoderFactory { - fn make_default_decoder<'a>(&self, field: Option, + fn make_custom_decoder<'a>( + &self, + field: Option, _data_type: DataType, _coerce_primitive: bool, _strict_mode: bool, _is_nullable: bool, _struct_mode: StructMode, ) -> Result>, ArrowError> { - if let Some(field) = field && field.try_extension_type::().is_ok() { + if let Some(field) = field + && field.extension_type_name() == Some(VariantType::NAME) + { return Ok(Some(Box::new(VariantArrayDecoder))); } - Ok(None) + Ok(None) } } -fn variant_from_tape_element(builder: &mut impl VariantBuilderExt, mut p: u32, tape: &Tape) -> Result { +fn variant_from_tape_element( + builder: &mut impl VariantBuilderExt, + mut p: u32, + tape: &Tape, +) -> Result { match tape.get(p) { TapeElement::StartObject(end_idx) => { let mut object_builder = builder.try_new_object()?; @@ -98,28 +104,34 @@ fn variant_from_tape_element(builder: &mut impl VariantBuilderExt, mut p: u32, t TapeElement::String(s) => tape.get_string(s), _ => return Err(tape.error(p, "field name")), }; - + let mut field_builder = ObjectFieldBuilder::new(field_name, &mut object_builder); p = tape.next(p, "field value")?; p = variant_from_tape_element(&mut field_builder, p, tape)?; } object_builder.finish(); } - TapeElement::EndObject(_u32) => { return Err(ArrowError::JsonError("unexpected end of object".to_string())) }, + TapeElement::EndObject(_u32) => { + return Err(ArrowError::JsonError( + "unexpected end of object".to_string(), + )); + } TapeElement::StartList(end_idx) => { let mut list_builder = builder.try_new_list()?; - p+= 1; + p += 1; while p < end_idx { p = variant_from_tape_element(&mut 
list_builder, p, tape)?; } list_builder.finish(); } - TapeElement::EndList(_u32) => { return Err(ArrowError::JsonError("unexpected end of list".to_string())) }, + TapeElement::EndList(_u32) => { + return Err(ArrowError::JsonError("unexpected end of list".to_string())); + } TapeElement::String(idx) => builder.append_value(tape.get_string(idx)), TapeElement::Number(idx) => { let s = tape.get_string(idx); builder.append_value(parse_number(s)?) - }, + } TapeElement::I64(i) => builder.append_value(i), TapeElement::I32(i) => builder.append_value(i), TapeElement::F64(f) => builder.append_value(f), @@ -139,23 +151,28 @@ fn parse_number<'a, 'b>(s: &'a str) -> Result, ArrowError> { match lexical_core::parse(s.as_bytes()) { Ok(v) => Ok(Variant::Double(v)), - Err(_) => Err(ArrowError::JsonError(format!("failed to parse {s} as number"))), + Err(_) => Err(ArrowError::JsonError(format!( + "failed to parse {s} as number" + ))), } } #[cfg(test)] mod tests { + use crate::VariantArray; + use super::*; - use arrow_schema::{Schema, Field, DataType}; - use arrow_json::reader::ReaderBuilder; + use arrow_array::Int32Array; use arrow_json::StructMode; - use std::sync::Arc; + use arrow_json::reader::ReaderBuilder; + use arrow_schema::{DataType, Field, Schema}; + use parquet_variant::VariantBuilder; use std::io::Cursor; - use arrow_array::RecordBatch; + use std::sync::Arc; #[test] fn test_variant() { - let do_test = |json_input: &str, ids: Vec, variants: Vec| { + let do_test = |json_input: &str, ids: Vec, variants: Vec>| { let variant_array = VariantArrayBuilder::new(0).build(); let struct_field = Schema::new(vec![ @@ -174,31 +191,28 @@ mod tests { .unwrap() .unwrap(); + assert_eq!(result.num_columns(), 2); let int_array = arrow_array::array::Int32Array::from(ids); + assert_eq!( + result + .column(0) + .as_any() + .downcast_ref::() + .unwrap(), + &int_array + ); - let variant_array = { - let mut variant_builder = VariantArrayBuilder::new(variants.len()); - for v in variants { - 
variant_builder.append_variant(v); - } - variant_builder.build() - }; - - let variant_struct_array: StructArray = variant_array.into(); + let result_variant_array: VariantArray = + VariantArray::try_new(result.column(1)).unwrap(); + let values = result_variant_array.iter().collect::>(); - let expected = RecordBatch::try_new( - struct_field.into(), - vec![Arc::new(int_array), Arc::new(variant_struct_array)], - ) - .unwrap(); - - assert_eq!(result, expected); + assert_eq!(values, variants); }; do_test( "{\"id\": 1, \"var\": \"a\"}\n{\"id\": 2, \"var\": \"b\"}", vec![1, 2], - vec![Variant::from("a"), Variant::from("b")], + vec![Some(Variant::from("a")), Some(Variant::from("b"))], ); let mut builder = VariantBuilder::new(); @@ -216,10 +230,10 @@ mod tests { do_test( "{\"id\": 1, \"var\": {\"int64\": 1, \"double\": 1.0, \"null\": null, \"true\": true, \"false\": false, \"string\": \"a\"}}", vec![1], - vec![variant], + vec![Some(variant)], ); - // nested structs + // nested structs let mut builder = VariantBuilder::new(); let mut object_builder = builder.new_object(); { @@ -233,7 +247,7 @@ mod tests { let mut nested_object_builder = list_builder.new_object(); nested_object_builder.insert("num", Variant::Int64(3)); nested_object_builder.finish(); - } + } list_builder.finish(); object_builder.insert("scalar", Variant::from("a")); } @@ -245,8 +259,21 @@ mod tests { do_test( "{\"id\": 1, \"var\": {\"somelist\": [{\"num\": 2}, {\"num\": 3}], \"scalar\": \"a\"}}", vec![1], - vec![variant], + vec![Some(variant)], ); - } + let mut builder = VariantBuilder::new(); + let mut list_builder = builder.new_list(); + list_builder.append_value(Variant::Int64(1000000000000)); + list_builder.append_value(Variant::Double(2.718281828459045)); + list_builder.finish(); + let (metadata, value) = builder.finish(); + let variant = Variant::try_new(&metadata, &value).unwrap(); + + do_test( + "{\"id\": 1, \"var\": [1000000000000, 2.718281828459045]}", + vec![1], + vec![Some(variant)], + ); + } } 
diff --git a/parquet-variant-compute/src/lib.rs b/parquet-variant-compute/src/lib.rs index e6eab80283fd..afc064b3c106 100644 --- a/parquet-variant-compute/src/lib.rs +++ b/parquet-variant-compute/src/lib.rs @@ -56,10 +56,10 @@ pub use variant_array::{BorrowedShreddingState, ShreddingState, VariantArray, Va pub use variant_array_builder::{VariantArrayBuilder, VariantValueArrayBuilder}; pub use cast_to_variant::{cast_to_variant, cast_to_variant_with_options}; +pub use decoder::VariantArrayDecoderFactory; pub use from_json::json_to_variant; pub use shred_variant::{IntoShreddingField, ShreddedSchemaBuilder, shred_variant}; pub use to_json::variant_to_json; pub use type_conversion::CastOptions; pub use unshred_variant::unshred_variant; pub use variant_get::{GetOptions, variant_get}; -pub use decoder::VariantArrayDecoderFactory; From 6aeb892706742afc69e421cf9b3766ff26a4e459 Mon Sep 17 00:00:00 2001 From: Doug Miller Date: Mon, 22 Dec 2025 07:37:34 -0500 Subject: [PATCH 05/14] more obvious --- parquet-variant-compute/src/decoder.rs | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/parquet-variant-compute/src/decoder.rs b/parquet-variant-compute/src/decoder.rs index f05668bfa65e..a7150338b1f1 100644 --- a/parquet-variant-compute/src/decoder.rs +++ b/parquet-variant-compute/src/decoder.rs @@ -132,10 +132,26 @@ fn variant_from_tape_element( let s = tape.get_string(idx); builder.append_value(parse_number(s)?) 
} - TapeElement::I64(i) => builder.append_value(i), - TapeElement::I32(i) => builder.append_value(i), - TapeElement::F64(f) => builder.append_value(f), - TapeElement::F32(f) => builder.append_value(f), + TapeElement::I64(i) => { + return Err(ArrowError::JsonError(format!( + "I64 tape element not supported: {i}" + ))); + } + TapeElement::I32(i) => { + return Err(ArrowError::JsonError(format!( + "I32 tape element not supported: {i}" + ))); + } + TapeElement::F64(f) => { + return Err(ArrowError::JsonError(format!( + "F64 tape element not supported: {f}" + ))); + } + TapeElement::F32(f) => { + return Err(ArrowError::JsonError(format!( + "F32 tape element not supported: {f}" + ))); + } TapeElement::True => builder.append_value(true), TapeElement::False => builder.append_value(false), TapeElement::Null => builder.append_value(Variant::Null), From 93ad08c3a046b3daf38ba5d189e7cbe69e47ac54 Mon Sep 17 00:00:00 2001 From: Doug Miller Date: Thu, 8 Jan 2026 14:23:06 -0500 Subject: [PATCH 06/14] fix doctest and more exhaustive check --- arrow-json/src/reader/mod.rs | 4 ++-- parquet-variant-compute/src/decoder.rs | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/arrow-json/src/reader/mod.rs b/arrow-json/src/reader/mod.rs index dadb437611dc..5602d5e92f64 100644 --- a/arrow-json/src/reader/mod.rs +++ b/arrow-json/src/reader/mod.rs @@ -220,7 +220,7 @@ impl ReaderBuilder { /// # use arrow_array::cast::AsArray; /// # use arrow_array::types::Int32Type; /// # use arrow_json::ReaderBuilder; - /// # use arrow_schema::{DataType, Field}; + /// # use arrow_schema::{DataType, Field, FieldRef}; /// // Root of JSON schema is a numeric type /// let data = "1\n2\n3\n"; /// let field = Arc::new(Field::new("int", DataType::Int32, true)); @@ -396,7 +396,7 @@ impl RecordBatchReader for Reader { /// ``` /// use arrow_json::{ArrayDecoder, DecoderFactory, TapeElement, Tape, ReaderBuilder, StructMode}; /// use arrow_schema::ArrowError; -/// use arrow_schema::{DataType, Field, Fields, 
Schema}; +/// use arrow_schema::{DataType, Field, FieldRef, Fields, Schema}; /// use arrow_array::cast::AsArray; /// use arrow_array::Array; /// use arrow_array::builder::StringBuilder; diff --git a/parquet-variant-compute/src/decoder.rs b/parquet-variant-compute/src/decoder.rs index a7150338b1f1..68cc6c9c6299 100644 --- a/parquet-variant-compute/src/decoder.rs +++ b/parquet-variant-compute/src/decoder.rs @@ -82,6 +82,7 @@ impl DecoderFactory for VariantArrayDecoderFactory { ) -> Result>, ArrowError> { if let Some(field) = field && field.extension_type_name() == Some(VariantType::NAME) + && field.try_extension_type::().is_ok() { return Ok(Some(Box::new(VariantArrayDecoder))); } From 1fb7c47f3cd105b35014e9050776c06f14f9fbb8 Mon Sep 17 00:00:00 2001 From: Doug Miller Date: Sun, 11 Jan 2026 12:10:13 -0500 Subject: [PATCH 07/14] more lint fixes --- parquet-variant-compute/src/decoder.rs | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/parquet-variant-compute/src/decoder.rs b/parquet-variant-compute/src/decoder.rs index 68cc6c9c6299..7f9eaa5267c4 100644 --- a/parquet-variant-compute/src/decoder.rs +++ b/parquet-variant-compute/src/decoder.rs @@ -48,10 +48,9 @@ impl ArrayDecoder for VariantArrayDecoder { } } -/// A [`DecoderFactory`] that creates [`VariantArrayDecoder`] instances for Variant-typed fields. -/// -/// This factory integrates with the Arrow JSON reader to automatically decode JSON values -/// into Variant arrays when the target field is registered as a [`VariantType`] extension type. +/// A [`DecoderFactory`] that integrates with the Arrow JSON reader to automatically +/// decode JSON values into Variant arrays when the target field is registered as a +/// [`VariantType`] extension type. 
/// /// # Example /// @@ -80,11 +79,14 @@ impl DecoderFactory for VariantArrayDecoderFactory { _is_nullable: bool, _struct_mode: StructMode, ) -> Result>, ArrowError> { - if let Some(field) = field - && field.extension_type_name() == Some(VariantType::NAME) + let field = match field { + Some(inner_field) => inner_field, + None => return Ok(None), + }; + if field.extension_type_name() == Some(VariantType::NAME) && field.try_extension_type::().is_ok() { - return Ok(Some(Box::new(VariantArrayDecoder))); + return Ok(Some(Box::new(VariantArrayDecoder))) } Ok(None) } @@ -282,7 +284,7 @@ mod tests { let mut builder = VariantBuilder::new(); let mut list_builder = builder.new_list(); list_builder.append_value(Variant::Int64(1000000000000)); - list_builder.append_value(Variant::Double(2.718281828459045)); + list_builder.append_value(Variant::Double(std::f64::consts::E)); list_builder.finish(); let (metadata, value) = builder.finish(); let variant = Variant::try_new(&metadata, &value).unwrap(); From 1f50f34623948d85333e0f944fe7449a3b4973e1 Mon Sep 17 00:00:00 2001 From: Ryan Johnson Date: Wed, 21 Jan 2026 13:31:54 -0800 Subject: [PATCH 08/14] move extension type construction logic out of Field --- arrow-schema/src/extension/canonical/bool8.rs | 2 +- .../extension/canonical/fixed_shape_tensor.rs | 2 +- arrow-schema/src/extension/canonical/json.rs | 2 +- .../src/extension/canonical/opaque.rs | 2 +- arrow-schema/src/extension/canonical/uuid.rs | 2 +- .../canonical/variable_shape_tensor.rs | 2 +- arrow-schema/src/extension/mod.rs | 38 +++++++++++++++++++ arrow-schema/src/field.rs | 20 +--------- 8 files changed, 45 insertions(+), 25 deletions(-) diff --git a/arrow-schema/src/extension/canonical/bool8.rs b/arrow-schema/src/extension/canonical/bool8.rs index 362a2cc018c7..c94c8217b8ff 100644 --- a/arrow-schema/src/extension/canonical/bool8.rs +++ b/arrow-schema/src/extension/canonical/bool8.rs @@ -96,7 +96,7 @@ mod tests { } #[test] - #[should_panic(expected = "Field extension type 
name missing")] + #[should_panic(expected = "Extension type name missing")] fn missing_name() { let field = Field::new("", DataType::Int8, false).with_metadata( [(EXTENSION_TYPE_METADATA_KEY.to_owned(), "".to_owned())] diff --git a/arrow-schema/src/extension/canonical/fixed_shape_tensor.rs b/arrow-schema/src/extension/canonical/fixed_shape_tensor.rs index b6bd1c1223f4..5157eefe9ebb 100644 --- a/arrow-schema/src/extension/canonical/fixed_shape_tensor.rs +++ b/arrow-schema/src/extension/canonical/fixed_shape_tensor.rs @@ -471,7 +471,7 @@ mod tests { } #[test] - #[should_panic(expected = "Field extension type name missing")] + #[should_panic(expected = "Extension type name missing")] fn missing_name() { let field = Field::new_fixed_size_list("", Field::new("", DataType::Float32, false), 3, false) diff --git a/arrow-schema/src/extension/canonical/json.rs b/arrow-schema/src/extension/canonical/json.rs index 297a2d99aa04..d2a54b9189b7 100644 --- a/arrow-schema/src/extension/canonical/json.rs +++ b/arrow-schema/src/extension/canonical/json.rs @@ -222,7 +222,7 @@ mod tests { } #[test] - #[should_panic(expected = "Field extension type name missing")] + #[should_panic(expected = "Extension type name missing")] fn missing_name() { let field = Field::new("", DataType::Int8, false).with_metadata( [(EXTENSION_TYPE_METADATA_KEY.to_owned(), "{}".to_owned())] diff --git a/arrow-schema/src/extension/canonical/opaque.rs b/arrow-schema/src/extension/canonical/opaque.rs index fceae8d3711d..acfc1331a670 100644 --- a/arrow-schema/src/extension/canonical/opaque.rs +++ b/arrow-schema/src/extension/canonical/opaque.rs @@ -285,7 +285,7 @@ mod tests { } #[test] - #[should_panic(expected = "Field extension type name missing")] + #[should_panic(expected = "Extension type name missing")] fn missing_name() { let field = Field::new("", DataType::Null, false).with_metadata( [( diff --git a/arrow-schema/src/extension/canonical/uuid.rs b/arrow-schema/src/extension/canonical/uuid.rs index 
09533564ed44..3e897f47318d 100644 --- a/arrow-schema/src/extension/canonical/uuid.rs +++ b/arrow-schema/src/extension/canonical/uuid.rs @@ -100,7 +100,7 @@ mod tests { } #[test] - #[should_panic(expected = "Field extension type name missing")] + #[should_panic(expected = "Extension type name missing")] fn missing_name() { let field = Field::new("", DataType::FixedSizeBinary(16), false); field.extension_type::(); diff --git a/arrow-schema/src/extension/canonical/variable_shape_tensor.rs b/arrow-schema/src/extension/canonical/variable_shape_tensor.rs index b5403dcf684f..fbc641f54366 100644 --- a/arrow-schema/src/extension/canonical/variable_shape_tensor.rs +++ b/arrow-schema/src/extension/canonical/variable_shape_tensor.rs @@ -529,7 +529,7 @@ mod tests { } #[test] - #[should_panic(expected = "Field extension type name missing")] + #[should_panic(expected = "Extension type name missing")] fn missing_name() { let field = Field::new_struct( "", diff --git a/arrow-schema/src/extension/mod.rs b/arrow-schema/src/extension/mod.rs index cd17272e15ab..b356d0b61422 100644 --- a/arrow-schema/src/extension/mod.rs +++ b/arrow-schema/src/extension/mod.rs @@ -23,6 +23,7 @@ mod canonical; pub use canonical::*; use crate::{ArrowError, DataType}; +use std::collections::HashMap; /// The metadata key for the string name identifying an [`ExtensionType`]. pub const EXTENSION_TYPE_NAME_KEY: &str = "ARROW:extension:name"; @@ -255,4 +256,41 @@ pub trait ExtensionType: Sized { /// This should return an error if the given data type is not supported by /// this extension type. fn try_new(data_type: &DataType, metadata: Self::Metadata) -> Result; + + /// Construct this extension type from field metadata and data type. + /// + /// This is a provided method that extracts extension type information from + /// metadata (using [`EXTENSION_TYPE_NAME_KEY`] and + /// [`EXTENSION_TYPE_METADATA_KEY`]) and delegates to [`Self::try_new`]. 
+ /// + /// Returns an error if: + /// - The extension type name is missing or doesn't match [`Self::NAME`] + /// - Metadata deserialization fails + /// - The data type is not supported + /// + /// This method enables extension type checking without requiring a full + /// [`Field`] instance, useful when only metadata and data type are available. + /// + /// [`Field`]: crate::Field + fn try_from_parts( + metadata: &HashMap, + data_type: &DataType, + ) -> Result { + match metadata.get(EXTENSION_TYPE_NAME_KEY).map(|s| s.as_str()) { + Some(Self::NAME) => { + let ext_metadata = metadata + .get(EXTENSION_TYPE_METADATA_KEY) + .map(|s| s.as_str()); + let parsed = Self::deserialize_metadata(ext_metadata)?; + Self::try_new(data_type, parsed) + } + Some(name) => Err(ArrowError::InvalidArgumentError(format!( + "Extension type name mismatch: expected {}, got {name}", + Self::NAME + ))), + None => Err(ArrowError::InvalidArgumentError( + "Extension type name missing".to_string() + )), + } + } } diff --git a/arrow-schema/src/field.rs b/arrow-schema/src/field.rs index c4566e41bfa8..27d0b0c46e51 100644 --- a/arrow-schema/src/field.rs +++ b/arrow-schema/src/field.rs @@ -575,25 +575,7 @@ impl Field { /// } /// ``` pub fn try_extension_type(&self) -> Result { - // Check the extension name in the metadata - match self.extension_type_name() { - // It should match the name of the given extension type - Some(name) if name == E::NAME => { - // Deserialize the metadata and try to construct the extension - // type - E::deserialize_metadata(self.extension_type_metadata()) - .and_then(|metadata| E::try_new(self.data_type(), metadata)) - } - // Name mismatch - Some(name) => Err(ArrowError::InvalidArgumentError(format!( - "Field extension type name mismatch, expected {}, found {name}", - E::NAME - ))), - // Name missing - None => Err(ArrowError::InvalidArgumentError( - "Field extension type name missing".to_owned(), - )), - } + E::try_from_parts(self.metadata(), self.data_type()) } /// Returns an 
instance of the given [`ExtensionType`] of this [`Field`], From 6290db90fac8a427ebc63a1c24789e6b09ef195c Mon Sep 17 00:00:00 2001 From: Ryan Johnson Date: Wed, 21 Jan 2026 15:05:15 -0800 Subject: [PATCH 09/14] checkpoint - reworked signatures --- arrow-json/src/reader/list_array.rs | 9 +-- arrow-json/src/reader/map_array.rs | 18 ++--- arrow-json/src/reader/mod.rs | 100 +++++++++++++++++-------- arrow-json/src/reader/struct_array.rs | 12 ++- arrow-schema/src/extension/mod.rs | 4 +- parquet-variant-compute/src/decoder.rs | 23 +++--- 6 files changed, 98 insertions(+), 68 deletions(-) diff --git a/arrow-json/src/reader/list_array.rs b/arrow-json/src/reader/list_array.rs index 7a697faf9bc2..cc9eb356a97d 100644 --- a/arrow-json/src/reader/list_array.rs +++ b/arrow-json/src/reader/list_array.rs @@ -24,7 +24,6 @@ use arrow_buffer::buffer::NullBuffer; use arrow_data::{ArrayData, ArrayDataBuilder}; use arrow_schema::{ArrowError, DataType}; use std::marker::PhantomData; -use std::sync::Arc; use super::DecoderFactory; @@ -42,7 +41,7 @@ impl ListArrayDecoder { strict_mode: bool, is_nullable: bool, struct_mode: StructMode, - decoder_factory: Option>, + decoder_factory: Option<&dyn DecoderFactory>, ) -> Result { let field = match &data_type { DataType::List(f) if !O::IS_LARGE => f, @@ -50,11 +49,11 @@ impl ListArrayDecoder { _ => unreachable!(), }; let decoder = make_decoder( - Some(field.clone()), - field.data_type().clone(), + field.data_type(), + field.is_nullable(), + field.metadata(), coerce_primitive, strict_mode, - field.is_nullable(), struct_mode, decoder_factory, )?; diff --git a/arrow-json/src/reader/map_array.rs b/arrow-json/src/reader/map_array.rs index 4a115d49bea3..2575aacb55a6 100644 --- a/arrow-json/src/reader/map_array.rs +++ b/arrow-json/src/reader/map_array.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. 
-use std::sync::Arc; - use crate::StructMode; use crate::reader::tape::{Tape, TapeElement}; use crate::reader::{ArrayDecoder, make_decoder}; @@ -42,7 +40,7 @@ impl MapArrayDecoder { strict_mode: bool, is_nullable: bool, struct_mode: StructMode, - decoder_factory: Option>, + decoder_factory: Option<&dyn DecoderFactory>, ) -> Result { let fields = match &data_type { DataType::Map(_, true) => { @@ -62,20 +60,20 @@ impl MapArrayDecoder { }; let keys = make_decoder( - Some(fields[0].clone()), - fields[0].data_type().clone(), + fields[0].data_type(), + fields[0].is_nullable(), + fields[0].metadata(), coerce_primitive, strict_mode, - fields[0].is_nullable(), struct_mode, - decoder_factory.clone(), + decoder_factory, )?; let values = make_decoder( - Some(fields[1].clone()), - fields[1].data_type().clone(), + fields[1].data_type(), + fields[1].is_nullable(), + fields[1].metadata(), coerce_primitive, strict_mode, - fields[1].is_nullable(), struct_mode, decoder_factory, )?; diff --git a/arrow-json/src/reader/mod.rs b/arrow-json/src/reader/mod.rs index 5602d5e92f64..a5bfe6ca67e2 100644 --- a/arrow-json/src/reader/mod.rs +++ b/arrow-json/src/reader/mod.rs @@ -137,6 +137,8 @@ use crate::StructMode; use crate::reader::binary_array::{ BinaryArrayDecoder, BinaryViewDecoder, FixedSizeBinaryArrayDecoder, }; +use std::borrow::Cow; +use std::collections::HashMap; use std::io::BufRead; use std::sync::Arc; @@ -306,22 +308,24 @@ impl ReaderBuilder { /// Create a [`Decoder`] pub fn build_decoder(self) -> Result { - let (data_type, nullable) = match self.is_field { - false => (DataType::Struct(self.schema.fields.clone()), false), - true => { - let field = &self.schema.fields[0]; - (field.data_type().clone(), field.is_nullable()) - } + let empty_metadata = HashMap::default(); + let (data_type, nullable, metadata) = if self.is_field { + let field = &self.schema.fields[0]; + let data_type = Cow::Borrowed(field.data_type()); + (data_type, field.is_nullable(), field.metadata()) + } else { + let 
data_type = Cow::Owned(DataType::Struct(self.schema.fields.clone())); + (data_type, false, &empty_metadata) }; let decoder = make_decoder( - None, - data_type, + data_type.as_ref(), + nullable, + metadata, self.coerce_primitive, self.strict_mode, - nullable, self.struct_mode, - self.decoder_factory, + self.decoder_factory.as_deref(), )?; let num_fields = self.schema.flattened_fields().len(); @@ -424,14 +428,15 @@ impl RecordBatchReader for Reader { /// struct IncorrectStringAsNullDecoderFactory; /// /// impl DecoderFactory for IncorrectStringAsNullDecoderFactory { -/// fn make_custom_decoder<'a>( +/// fn make_custom_decoder( /// &self, -/// _field: Option, -/// data_type: DataType, +/// data_type: &DataType, +/// _is_nullable: bool, +/// _field_metadata: &HashMap, /// _coerce_primitive: bool, /// _strict_mode: bool, -/// _is_nullable: bool, /// _struct_mode: StructMode, +/// _decoder_factory: Option<&dyn DecoderFactory>, /// ) -> Result>, ArrowError> { /// match data_type { /// DataType::Utf8 => Ok(Some(Box::new(IncorrectStringAsNullDecoder {}))), @@ -466,15 +471,26 @@ pub trait DecoderFactory: std::fmt::Debug + Send + Sync { /// This can be used to override how e.g. error in decoding are handled. fn make_custom_decoder( &self, - _field: Option, - _data_type: DataType, + _data_type: &DataType, + _is_nullable: bool, + _field_metadata: &HashMap, _coerce_primitive: bool, _strict_mode: bool, - _is_nullable: bool, _struct_mode: StructMode, + _decoder_factory: Option<&dyn DecoderFactory>, ) -> Result>, ArrowError>; } +/// +/// Metadata key for JSON decoder configuration. +/// +/// This well-known metadata key can be used to annotate schema fields with +/// custom decoding instructions. Custom [`DecoderFactory`] implementations +/// can inspect this metadata to determine how to decode specific fields. +/// +/// This metadata key is automatically stripped from the output array schema. 
+pub const JSON_DECODER_CONFIG_KEY: &str = "arrow-rs:json:decoder"; + /// A low-level interface for reading JSON data from a byte stream /// /// See [`Reader`] for a higher-level interface for interface with [`BufRead`] @@ -788,28 +804,49 @@ macro_rules! primitive_decoder { }; } -fn make_decoder( - field: Option, - data_type: DataType, +/// Creates an [`ArrayDecoder`] for decoding JSON values into Arrow arrays. +/// +/// This function is the primary entry point for constructing decoders. It first +/// attempts to use a custom [`DecoderFactory`] if provided, then falls back to +/// default decoder implementations based on the data type. +/// +/// Custom decoders can use this function to recursively create child decoders +/// for complex types by passing along the `decoder_factory` parameter. +/// +/// # Arguments +/// +/// * `data_type` - The Arrow data type to decode into +/// * `is_nullable` - Whether the field is nullable +/// * `field_metadata` - Schema metadata for the field (can contain [`JSON_DECODER_CONFIG_KEY`]) +/// * `coerce_primitive` - Whether to coerce primitive types (e.g., string to number) +/// * `strict_mode` - Whether to validate struct fields strictly +/// * `struct_mode` - How to decode struct fields +/// * `decoder_factory` - Optional custom decoder factory for recursion +pub fn make_decoder( + data_type: &DataType, + is_nullable: bool, + field_metadata: &HashMap, coerce_primitive: bool, strict_mode: bool, - is_nullable: bool, struct_mode: StructMode, - decoder_factory: Option>, + decoder_factory: Option<&dyn DecoderFactory>, ) -> Result, ArrowError> { - if let Some(ref factory) = decoder_factory { + if let Some(factory) = decoder_factory { if let Some(decoder) = factory.make_custom_decoder( - field.clone(), - data_type.clone(), + data_type, + is_nullable, + field_metadata, coerce_primitive, strict_mode, - is_nullable, struct_mode, + decoder_factory, )? { return Ok(decoder); } } + let data_type = data_type.clone(); + downcast_integer! 
{ data_type => (primitive_decoder, data_type), DataType::Null => Ok(Box::::default()), @@ -2954,17 +2991,18 @@ mod tests { struct AlwaysNullStringArrayDecoderFactory; impl DecoderFactory for AlwaysNullStringArrayDecoderFactory { - fn make_custom_decoder<'a>( + fn make_custom_decoder( &self, - _field: Option, - data_type: DataType, + data_type: &DataType, + _is_nullable: bool, + _field_metadata: &HashMap, _coerce_primitive: bool, _strict_mode: bool, - _is_nullable: bool, _struct_mode: StructMode, + _decoder_factory: Option<&dyn DecoderFactory>, ) -> Result>, ArrowError> { match data_type { - DataType::Utf8 => Ok(Some(Box::new(AlwaysNullStringArrayDecoder {}))), + DataType::Utf8 => Ok(Some(Box::new(AlwaysNullStringArrayDecoder))), _ => Ok(None), } } diff --git a/arrow-json/src/reader/struct_array.rs b/arrow-json/src/reader/struct_array.rs index 7025fe73d76c..12db53b103ba 100644 --- a/arrow-json/src/reader/struct_array.rs +++ b/arrow-json/src/reader/struct_array.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. 
-use std::sync::Arc; - use crate::reader::tape::{Tape, TapeElement}; use crate::reader::{ArrayDecoder, StructMode, make_decoder}; use arrow_array::builder::BooleanBufferBuilder; @@ -89,7 +87,7 @@ impl StructArrayDecoder { strict_mode: bool, is_nullable: bool, struct_mode: StructMode, - decoder_factory: Option>, + decoder_factory: Option<&dyn DecoderFactory>, ) -> Result { let (decoders, field_name_to_index) = { let fields = struct_fields(&data_type); @@ -101,13 +99,13 @@ impl StructArrayDecoder { // it doesn't contain any nulls not masked by its parent let nullable = f.is_nullable() || is_nullable; make_decoder( - Some(f.clone()), - f.data_type().clone(), + f.data_type(), + nullable, + f.metadata(), coerce_primitive, strict_mode, - nullable, struct_mode, - decoder_factory.clone(), + decoder_factory, ) }) .collect::, ArrowError>>()?; diff --git a/arrow-schema/src/extension/mod.rs b/arrow-schema/src/extension/mod.rs index b356d0b61422..4b9ddf1a4548 100644 --- a/arrow-schema/src/extension/mod.rs +++ b/arrow-schema/src/extension/mod.rs @@ -277,7 +277,7 @@ pub trait ExtensionType: Sized { data_type: &DataType, ) -> Result { match metadata.get(EXTENSION_TYPE_NAME_KEY).map(|s| s.as_str()) { - Some(Self::NAME) => { + Some(name) if name == Self::NAME => { let ext_metadata = metadata .get(EXTENSION_TYPE_METADATA_KEY) .map(|s| s.as_str()); @@ -289,7 +289,7 @@ pub trait ExtensionType: Sized { Self::NAME ))), None => Err(ArrowError::InvalidArgumentError( - "Extension type name missing".to_string() + "Extension type name missing".to_string(), )), } } diff --git a/parquet-variant-compute/src/decoder.rs b/parquet-variant-compute/src/decoder.rs index 7f9eaa5267c4..eb323cda67ad 100644 --- a/parquet-variant-compute/src/decoder.rs +++ b/parquet-variant-compute/src/decoder.rs @@ -20,7 +20,8 @@ use arrow_array::{Array, StructArray}; use arrow_data::ArrayData; use arrow_json::{DecoderFactory, StructMode}; use arrow_schema::extension::ExtensionType; -use arrow_schema::{ArrowError, 
DataType, FieldRef}; +use arrow_schema::{ArrowError, DataType}; +use std::collections::HashMap; use parquet_variant::{ObjectFieldBuilder, Variant, VariantBuilderExt}; use arrow_json::reader::ArrayDecoder; @@ -70,23 +71,19 @@ impl ArrayDecoder for VariantArrayDecoder { pub struct VariantArrayDecoderFactory; impl DecoderFactory for VariantArrayDecoderFactory { - fn make_custom_decoder<'a>( + fn make_custom_decoder( &self, - field: Option, - _data_type: DataType, + data_type: &DataType, + _is_nullable: bool, + field_metadata: &HashMap, _coerce_primitive: bool, _strict_mode: bool, - _is_nullable: bool, _struct_mode: StructMode, + _decoder_factory: Option<&dyn DecoderFactory>, ) -> Result>, ArrowError> { - let field = match field { - Some(inner_field) => inner_field, - None => return Ok(None), - }; - if field.extension_type_name() == Some(VariantType::NAME) - && field.try_extension_type::().is_ok() - { - return Ok(Some(Box::new(VariantArrayDecoder))) + // Check if this is a Variant extension type using metadata + if VariantType::try_from_parts(field_metadata, data_type).is_ok() { + return Ok(Some(Box::new(VariantArrayDecoder))); } Ok(None) } From 06fb5f5c95a98c60c18826fc63e08491f73a604b Mon Sep 17 00:00:00 2001 From: Ryan Johnson Date: Wed, 21 Jan 2026 15:59:54 -0800 Subject: [PATCH 10/14] checkpoint - add metadata stripping --- arrow-flight/tests/flight_sql_client_cli.rs | 2 +- arrow-json/src/reader/list_array.rs | 37 +++++++- arrow-json/src/reader/map_array.rs | 88 ++++++++++++++---- arrow-json/src/reader/mod.rs | 22 ++++- arrow-json/src/reader/primitive_array.rs | 4 +- arrow-json/src/reader/struct_array.rs | 99 ++++++++++++++------- arrow-json/src/reader/timestamp_array.rs | 4 +- parquet-variant-compute/src/decoder.rs | 2 +- 8 files changed, 196 insertions(+), 62 deletions(-) diff --git a/arrow-flight/tests/flight_sql_client_cli.rs b/arrow-flight/tests/flight_sql_client_cli.rs index c161caae8ca4..716c626d6adb 100644 --- a/arrow-flight/tests/flight_sql_client_cli.rs 
+++ b/arrow-flight/tests/flight_sql_client_cli.rs @@ -48,7 +48,7 @@ const QUERY: &str = "SELECT * FROM table;"; /// Return a Command instance for running the `flight_sql_client` CLI fn flight_sql_client_cmd() -> Command { - Command::new(assert_cmd::cargo::cargo_bin!("flight_sql_client")) + Command::new(assert_cmd::cargo::cargo_bin("flight_sql_client")) } #[tokio::test] diff --git a/arrow-json/src/reader/list_array.rs b/arrow-json/src/reader/list_array.rs index cc9eb356a97d..201e725dc155 100644 --- a/arrow-json/src/reader/list_array.rs +++ b/arrow-json/src/reader/list_array.rs @@ -17,18 +17,20 @@ use crate::StructMode; use crate::reader::tape::{Tape, TapeElement}; -use crate::reader::{ArrayDecoder, make_decoder}; +use crate::reader::{ArrayDecoder, JSON_DECODER_CONFIG_KEY, make_decoder}; use arrow_array::OffsetSizeTrait; use arrow_array::builder::{BooleanBufferBuilder, BufferBuilder}; use arrow_buffer::buffer::NullBuffer; use arrow_data::{ArrayData, ArrayDataBuilder}; -use arrow_schema::{ArrowError, DataType}; +use arrow_schema::{ArrowError, DataType, Field}; use std::marker::PhantomData; +use std::sync::Arc; use super::DecoderFactory; pub struct ListArrayDecoder { data_type: DataType, + type_changed: bool, decoder: Box, phantom: PhantomData, is_nullable: bool, @@ -36,14 +38,14 @@ pub struct ListArrayDecoder { impl ListArrayDecoder { pub fn new( - data_type: DataType, + data_type: &DataType, coerce_primitive: bool, strict_mode: bool, is_nullable: bool, struct_mode: StructMode, decoder_factory: Option<&dyn DecoderFactory>, ) -> Result { - let field = match &data_type { + let field = match data_type { DataType::List(f) if !O::IS_LARGE => f, DataType::LargeList(f) if O::IS_LARGE => f, _ => unreachable!(), @@ -58,8 +60,31 @@ impl ListArrayDecoder { decoder_factory, )?; + // Check if field needs modification + let type_changed = field.metadata().contains_key(JSON_DECODER_CONFIG_KEY) + || decoder.output_data_type().is_some(); + + let data_type = if type_changed { + // 
Strip decoder metadata and update data type if needed + let data_type = decoder.output_data_type().unwrap_or(field.data_type()); + let mut metadata = field.metadata().clone(); + metadata.remove(JSON_DECODER_CONFIG_KEY); + + let field = Field::new(field.name(), data_type.clone(), field.is_nullable()); + let field = Arc::new(field.with_metadata(metadata)); + + if O::IS_LARGE { + DataType::LargeList(field) + } else { + DataType::List(field) + } + } else { + data_type.clone() + }; + Ok(Self { data_type, + type_changed, decoder, phantom: Default::default(), is_nullable, @@ -68,6 +93,10 @@ impl ListArrayDecoder { } impl ArrayDecoder for ListArrayDecoder { + fn output_data_type(&self) -> Option<&DataType> { + self.type_changed.then_some(&self.data_type) + } + fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result { let mut child_pos = Vec::with_capacity(pos.len()); let mut offsets = BufferBuilder::::new(pos.len() + 1); diff --git a/arrow-json/src/reader/map_array.rs b/arrow-json/src/reader/map_array.rs index 2575aacb55a6..f39c9ae1f332 100644 --- a/arrow-json/src/reader/map_array.rs +++ b/arrow-json/src/reader/map_array.rs @@ -17,17 +17,20 @@ use crate::StructMode; use crate::reader::tape::{Tape, TapeElement}; -use crate::reader::{ArrayDecoder, make_decoder}; +use crate::reader::{ArrayDecoder, JSON_DECODER_CONFIG_KEY, make_decoder}; use arrow_array::builder::{BooleanBufferBuilder, BufferBuilder}; use arrow_buffer::ArrowNativeType; use arrow_buffer::buffer::NullBuffer; use arrow_data::{ArrayData, ArrayDataBuilder}; -use arrow_schema::{ArrowError, DataType}; +use arrow_schema::{ArrowError, DataType, Field}; +use std::borrow::Cow; +use std::sync::Arc; use super::DecoderFactory; pub struct MapArrayDecoder { data_type: DataType, + type_changed: bool, keys: Box, values: Box, is_nullable: bool, @@ -35,21 +38,21 @@ pub struct MapArrayDecoder { impl MapArrayDecoder { pub fn new( - data_type: DataType, + data_type: &DataType, coerce_primitive: bool, strict_mode: bool, 
is_nullable: bool, struct_mode: StructMode, decoder_factory: Option<&dyn DecoderFactory>, ) -> Result { - let fields = match &data_type { + let (map_field, fields) = match data_type { DataType::Map(_, true) => { return Err(ArrowError::NotYetImplemented( "Decoding MapArray with sorted fields".to_string(), )); } DataType::Map(f, _) => match f.data_type() { - DataType::Struct(fields) if fields.len() == 2 => fields, + DataType::Struct(fields) if fields.len() == 2 => (f, fields), d => { return Err(ArrowError::InvalidArgumentError(format!( "MapArray must contain struct with two fields, got {d}" @@ -59,27 +62,75 @@ impl MapArrayDecoder { _ => unreachable!(), }; + let (key_field, value_field) = (&fields[0], &fields[1]); + let keys = make_decoder( - fields[0].data_type(), - fields[0].is_nullable(), - fields[0].metadata(), + key_field.data_type(), + key_field.is_nullable(), + key_field.metadata(), coerce_primitive, strict_mode, struct_mode, decoder_factory, )?; let values = make_decoder( - fields[1].data_type(), - fields[1].is_nullable(), - fields[1].metadata(), + value_field.data_type(), + value_field.is_nullable(), + value_field.metadata(), coerce_primitive, strict_mode, struct_mode, decoder_factory, )?; + // Check if fields need modification + let key_has_metadata = key_field.metadata().contains_key(JSON_DECODER_CONFIG_KEY); + let value_has_metadata = value_field.metadata().contains_key(JSON_DECODER_CONFIG_KEY); + + let key_field = if key_has_metadata || keys.output_data_type().is_some() { + let key_type = keys.output_data_type().unwrap_or(key_field.data_type()); + let mut metadata = key_field.metadata().clone(); + metadata.remove(JSON_DECODER_CONFIG_KEY); + + let field = Field::new(key_field.name(), key_type.clone(), key_field.is_nullable()); + Cow::Owned(Arc::new(field.with_metadata(metadata))) + } else { + Cow::Borrowed(key_field) + }; + + let value_field = if value_has_metadata || values.output_data_type().is_some() { + let value_type = 
values.output_data_type().unwrap_or(value_field.data_type()); + let mut metadata = value_field.metadata().clone(); + metadata.remove(JSON_DECODER_CONFIG_KEY); + + let field = Field::new( + value_field.name(), + value_type.clone(), + value_field.is_nullable(), + ); + Cow::Owned(Arc::new(field.with_metadata(metadata))) + } else { + Cow::Borrowed(value_field) + }; + + let type_changed = + matches!(key_field, Cow::Owned(_)) || matches!(value_field, Cow::Owned(_)); + + let data_type = if type_changed { + let struct_fields = vec![key_field.into_owned(), value_field.into_owned()]; + let struct_field = Arc::new(Field::new( + map_field.name(), + DataType::Struct(struct_fields.into()), + map_field.is_nullable(), + )); + DataType::Map(struct_field, false) + } else { + data_type.clone() + }; + Ok(Self { data_type, + type_changed, keys, values, is_nullable, @@ -88,13 +139,16 @@ impl MapArrayDecoder { } impl ArrayDecoder for MapArrayDecoder { + fn output_data_type(&self) -> Option<&DataType> { + self.type_changed.then_some(&self.data_type) + } + fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result { - let s = match &self.data_type { - DataType::Map(f, _) => match f.data_type() { - s @ DataType::Struct(_) => s, - _ => unreachable!(), - }, - _ => unreachable!(), + let DataType::Map(f, _) = &self.data_type else { + unreachable!() + }; + let s @ DataType::Struct(_) = f.data_type() else { + unreachable!() }; let mut offsets = BufferBuilder::::new(pos.len() + 1); diff --git a/arrow-json/src/reader/mod.rs b/arrow-json/src/reader/mod.rs index a5bfe6ca67e2..7adf5f76669d 100644 --- a/arrow-json/src/reader/mod.rs +++ b/arrow-json/src/reader/mod.rs @@ -794,6 +794,22 @@ impl Decoder { /// A trait to decode JSON values into arrow arrays pub trait ArrayDecoder: Send { + /// Returns the output data type if it differs from the input data type. + /// + /// Custom decoders may produce output with a different data type than was + /// specified in the input schema. 
For example, a custom decoder might strip + /// extension metadata or change the type entirely. + /// + /// Returns: + /// - `Some(&DataType)` if the decoder's output type differs from its input + /// - `None` if the decoder produces the same type as its input + /// + /// This is used by parent decoders (Struct, List, Map) to determine whether + /// they need to update their output type to reflect child type changes. + fn output_data_type(&self) -> Option<&DataType> { + None + } + /// Decode elements from `tape` starting at the indexes contained in `pos` fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result; } @@ -845,10 +861,8 @@ pub fn make_decoder( } } - let data_type = data_type.clone(); - downcast_integer! { - data_type => (primitive_decoder, data_type), + *data_type => (primitive_decoder, data_type), DataType::Null => Ok(Box::::default()), DataType::Float16 => primitive_decoder!(Float16Type, data_type), DataType::Float32 => primitive_decoder!(Float32Type, data_type), @@ -907,7 +921,7 @@ pub fn make_decoder( DataType::FixedSizeBinary(len) => Ok(Box::new(FixedSizeBinaryArrayDecoder::new(len))), DataType::BinaryView => Ok(Box::new(BinaryViewDecoder::default())), DataType::Map(_, _) => Ok(Box::new(MapArrayDecoder::new(data_type, coerce_primitive, strict_mode, is_nullable, struct_mode, decoder_factory)?)), - d => Err(ArrowError::NotYetImplemented(format!("Support for {d} in JSON reader"))) + _ => Err(ArrowError::NotYetImplemented(format!("Support for {data_type} in JSON reader"))) } } diff --git a/arrow-json/src/reader/primitive_array.rs b/arrow-json/src/reader/primitive_array.rs index fa8464aa3251..b2bffe45e43a 100644 --- a/arrow-json/src/reader/primitive_array.rs +++ b/arrow-json/src/reader/primitive_array.rs @@ -80,9 +80,9 @@ pub struct PrimitiveArrayDecoder { } impl PrimitiveArrayDecoder

{ - pub fn new(data_type: DataType) -> Self { + pub fn new(data_type: &DataType) -> Self { Self { - data_type, + data_type: data_type.clone(), phantom: Default::default(), } } diff --git a/arrow-json/src/reader/struct_array.rs b/arrow-json/src/reader/struct_array.rs index 12db53b103ba..39e663b60554 100644 --- a/arrow-json/src/reader/struct_array.rs +++ b/arrow-json/src/reader/struct_array.rs @@ -16,12 +16,14 @@ // under the License. use crate::reader::tape::{Tape, TapeElement}; -use crate::reader::{ArrayDecoder, StructMode, make_decoder}; +use crate::reader::{ArrayDecoder, JSON_DECODER_CONFIG_KEY, StructMode, make_decoder}; use arrow_array::builder::BooleanBufferBuilder; use arrow_buffer::buffer::NullBuffer; use arrow_data::{ArrayData, ArrayDataBuilder}; -use arrow_schema::{ArrowError, DataType, Fields}; +use arrow_schema::{ArrowError, DataType, Field, Fields}; +use std::borrow::Cow; use std::collections::HashMap; +use std::sync::Arc; /// Reusable buffer for tape positions, indexed by (field_idx, row_idx). /// A value of 0 indicates the field is absent for that row. 
@@ -72,6 +74,7 @@ use super::DecoderFactory; pub struct StructArrayDecoder { data_type: DataType, + type_changed: bool, decoders: Vec>, strict_mode: bool, is_nullable: bool, @@ -82,43 +85,73 @@ pub struct StructArrayDecoder { impl StructArrayDecoder { pub fn new( - data_type: DataType, + data_type: &DataType, coerce_primitive: bool, strict_mode: bool, is_nullable: bool, struct_mode: StructMode, decoder_factory: Option<&dyn DecoderFactory>, ) -> Result { - let (decoders, field_name_to_index) = { - let fields = struct_fields(&data_type); - let decoders = fields - .iter() - .map(|f| { - // If this struct nullable, need to permit nullability in child array - // StructArrayDecoder::decode verifies that if the child is not nullable - // it doesn't contain any nulls not masked by its parent - let nullable = f.is_nullable() || is_nullable; - make_decoder( - f.data_type(), - nullable, - f.metadata(), - coerce_primitive, - strict_mode, - struct_mode, - decoder_factory, - ) - }) - .collect::, ArrowError>>()?; - let field_name_to_index = if struct_mode == StructMode::ObjectOnly { - build_field_index(fields) + let fields = struct_fields(data_type); + + // Create decoders and track whether any fields need modification + let mut decoders = Vec::with_capacity(fields.len()); + let mut output_fields: Vec>> = Vec::with_capacity(fields.len()); + let mut type_changed = false; + + for field in fields { + // If this struct nullable, need to permit nullability in child array + // StructArrayDecoder::decode verifies that if the child is not nullable + // it doesn't contain any nulls not masked by its parent + let nullable = field.is_nullable() || is_nullable; + let decoder = make_decoder( + field.data_type(), + nullable, + field.metadata(), + coerce_primitive, + strict_mode, + struct_mode, + decoder_factory, + )?; + + // Check if field needs modification + let has_decoder_metadata = field.metadata().contains_key(JSON_DECODER_CONFIG_KEY); + let child_type_changed = 
decoder.output_data_type().is_some(); + + let field = if child_type_changed || has_decoder_metadata { + type_changed = true; + // Strip decoder metadata and update data type if needed + let data_type = decoder.output_data_type().unwrap_or(field.data_type()); + let mut metadata = field.metadata().clone(); + metadata.remove(JSON_DECODER_CONFIG_KEY); + + let field = Field::new(field.name(), data_type.clone(), field.is_nullable()); + Cow::Owned(Arc::new(field.with_metadata(metadata))) } else { - None + Cow::Borrowed(field) }; - (decoders, field_name_to_index) + + decoders.push(decoder); + output_fields.push(field); + } + + // Only create new DataType if something actually changed + let data_type = if type_changed { + let owned: Vec> = output_fields.into_iter().map(Cow::into_owned).collect(); + DataType::Struct(owned.into()) + } else { + data_type.clone() + }; + + let field_name_to_index = if struct_mode == StructMode::ObjectOnly { + build_field_index(fields) + } else { + None }; Ok(Self { data_type, + type_changed, decoders, strict_mode, is_nullable, @@ -130,6 +163,10 @@ impl StructArrayDecoder { } impl ArrayDecoder for StructArrayDecoder { + fn output_data_type(&self) -> Option<&DataType> { + self.type_changed.then_some(&self.data_type) + } + fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result { let fields = struct_fields(&self.data_type); let row_count = pos.len(); @@ -275,10 +312,10 @@ impl ArrayDecoder for StructArrayDecoder { } fn struct_fields(data_type: &DataType) -> &Fields { - match &data_type { - DataType::Struct(f) => f, - _ => unreachable!(), - } + let DataType::Struct(f) = data_type else { + unreachable!() + }; + f } fn build_field_index(fields: &Fields) -> Option> { diff --git a/arrow-json/src/reader/timestamp_array.rs b/arrow-json/src/reader/timestamp_array.rs index 79f2b04eeba8..ff24d5391f9d 100644 --- a/arrow-json/src/reader/timestamp_array.rs +++ b/arrow-json/src/reader/timestamp_array.rs @@ -37,9 +37,9 @@ pub struct TimestampArrayDecoder { } 
impl TimestampArrayDecoder { - pub fn new(data_type: DataType, timezone: Tz) -> Self { + pub fn new(data_type: &DataType, timezone: Tz) -> Self { Self { - data_type, + data_type: data_type.clone(), timezone, phantom: Default::default(), } diff --git a/parquet-variant-compute/src/decoder.rs b/parquet-variant-compute/src/decoder.rs index eb323cda67ad..2ebc677f31e8 100644 --- a/parquet-variant-compute/src/decoder.rs +++ b/parquet-variant-compute/src/decoder.rs @@ -21,8 +21,8 @@ use arrow_data::ArrayData; use arrow_json::{DecoderFactory, StructMode}; use arrow_schema::extension::ExtensionType; use arrow_schema::{ArrowError, DataType}; -use std::collections::HashMap; use parquet_variant::{ObjectFieldBuilder, Variant, VariantBuilderExt}; +use std::collections::HashMap; use arrow_json::reader::ArrayDecoder; use arrow_json::reader::{Tape, TapeElement}; From f7b36d72e27b92f315795ef35ecc55a678ade807 Mon Sep 17 00:00:00 2001 From: Ryan Johnson Date: Fri, 23 Jan 2026 16:19:10 -0800 Subject: [PATCH 11/14] Ditch annotations. 
Got it working --- arrow-json/Cargo.toml | 1 + arrow-json/src/reader/list_array.rs | 35 +- arrow-json/src/reader/map_array.rs | 86 +-- arrow-json/src/reader/mod.rs | 181 +++-- arrow-json/src/reader/struct_array.rs | 48 +- arrow-json/tests/custom_decoder_tests.rs | 800 +++++++++++++++++++++++ parquet-variant-compute/src/decoder.rs | 7 +- 7 files changed, 944 insertions(+), 214 deletions(-) create mode 100644 arrow-json/tests/custom_decoder_tests.rs diff --git a/arrow-json/Cargo.toml b/arrow-json/Cargo.toml index 5fcde480eb6d..82a4b8b7a32d 100644 --- a/arrow-json/Cargo.toml +++ b/arrow-json/Cargo.toml @@ -54,6 +54,7 @@ ryu = "1.0" itoa = "1.0" [dev-dependencies] +arrow-select = { workspace = true } flate2 = { version = "1", default-features = false, features = ["rust_backend"] } serde = { version = "1.0", default-features = false, features = ["derive"] } futures = "0.3" diff --git a/arrow-json/src/reader/list_array.rs b/arrow-json/src/reader/list_array.rs index 201e725dc155..ecb7c7bbaf27 100644 --- a/arrow-json/src/reader/list_array.rs +++ b/arrow-json/src/reader/list_array.rs @@ -17,20 +17,18 @@ use crate::StructMode; use crate::reader::tape::{Tape, TapeElement}; -use crate::reader::{ArrayDecoder, JSON_DECODER_CONFIG_KEY, make_decoder}; +use crate::reader::{ArrayDecoder, make_decoder}; use arrow_array::OffsetSizeTrait; use arrow_array::builder::{BooleanBufferBuilder, BufferBuilder}; use arrow_buffer::buffer::NullBuffer; use arrow_data::{ArrayData, ArrayDataBuilder}; -use arrow_schema::{ArrowError, DataType, Field}; +use arrow_schema::{ArrowError, DataType}; use std::marker::PhantomData; -use std::sync::Arc; use super::DecoderFactory; pub struct ListArrayDecoder { data_type: DataType, - type_changed: bool, decoder: Box, phantom: PhantomData, is_nullable: bool, @@ -60,31 +58,8 @@ impl ListArrayDecoder { decoder_factory, )?; - // Check if field needs modification - let type_changed = field.metadata().contains_key(JSON_DECODER_CONFIG_KEY) - || 
decoder.output_data_type().is_some(); - - let data_type = if type_changed { - // Strip decoder metadata and update data type if needed - let data_type = decoder.output_data_type().unwrap_or(field.data_type()); - let mut metadata = field.metadata().clone(); - metadata.remove(JSON_DECODER_CONFIG_KEY); - - let field = Field::new(field.name(), data_type.clone(), field.is_nullable()); - let field = Arc::new(field.with_metadata(metadata)); - - if O::IS_LARGE { - DataType::LargeList(field) - } else { - DataType::List(field) - } - } else { - data_type.clone() - }; - Ok(Self { - data_type, - type_changed, + data_type: data_type.clone(), decoder, phantom: Default::default(), is_nullable, @@ -93,10 +68,6 @@ impl ListArrayDecoder { } impl ArrayDecoder for ListArrayDecoder { - fn output_data_type(&self) -> Option<&DataType> { - self.type_changed.then_some(&self.data_type) - } - fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result { let mut child_pos = Vec::with_capacity(pos.len()); let mut offsets = BufferBuilder::::new(pos.len() + 1); diff --git a/arrow-json/src/reader/map_array.rs b/arrow-json/src/reader/map_array.rs index f39c9ae1f332..e3fd6d6572be 100644 --- a/arrow-json/src/reader/map_array.rs +++ b/arrow-json/src/reader/map_array.rs @@ -17,20 +17,17 @@ use crate::StructMode; use crate::reader::tape::{Tape, TapeElement}; -use crate::reader::{ArrayDecoder, JSON_DECODER_CONFIG_KEY, make_decoder}; +use crate::reader::{ArrayDecoder, make_decoder}; use arrow_array::builder::{BooleanBufferBuilder, BufferBuilder}; use arrow_buffer::ArrowNativeType; use arrow_buffer::buffer::NullBuffer; use arrow_data::{ArrayData, ArrayDataBuilder}; -use arrow_schema::{ArrowError, DataType, Field}; -use std::borrow::Cow; -use std::sync::Arc; +use arrow_schema::{ArrowError, DataType}; use super::DecoderFactory; pub struct MapArrayDecoder { data_type: DataType, - type_changed: bool, keys: Box, values: Box, is_nullable: bool, @@ -45,21 +42,20 @@ impl MapArrayDecoder { struct_mode: StructMode, 
decoder_factory: Option<&dyn DecoderFactory>, ) -> Result { - let (map_field, fields) = match data_type { - DataType::Map(_, true) => { - return Err(ArrowError::NotYetImplemented( - "Decoding MapArray with sorted fields".to_string(), - )); + let DataType::Map(f, false) = data_type else { + return Err(ArrowError::NotYetImplemented( + "Decoding MapArray with sorted fields".to_string(), + )); + }; + + // TODO: Once MSRV bumps to 1.88+, use an if-let chain to directly unpack the slice + let fields = match f.data_type() { + DataType::Struct(fields) if fields.len() == 2 => fields, + d => { + return Err(ArrowError::InvalidArgumentError(format!( + "MapArray must contain struct with two fields, got {d}" + ))); } - DataType::Map(f, _) => match f.data_type() { - DataType::Struct(fields) if fields.len() == 2 => (f, fields), - d => { - return Err(ArrowError::InvalidArgumentError(format!( - "MapArray must contain struct with two fields, got {d}" - ))); - } - }, - _ => unreachable!(), }; let (key_field, value_field) = (&fields[0], &fields[1]); @@ -83,54 +79,8 @@ impl MapArrayDecoder { decoder_factory, )?; - // Check if fields need modification - let key_has_metadata = key_field.metadata().contains_key(JSON_DECODER_CONFIG_KEY); - let value_has_metadata = value_field.metadata().contains_key(JSON_DECODER_CONFIG_KEY); - - let key_field = if key_has_metadata || keys.output_data_type().is_some() { - let key_type = keys.output_data_type().unwrap_or(key_field.data_type()); - let mut metadata = key_field.metadata().clone(); - metadata.remove(JSON_DECODER_CONFIG_KEY); - - let field = Field::new(key_field.name(), key_type.clone(), key_field.is_nullable()); - Cow::Owned(Arc::new(field.with_metadata(metadata))) - } else { - Cow::Borrowed(key_field) - }; - - let value_field = if value_has_metadata || values.output_data_type().is_some() { - let value_type = values.output_data_type().unwrap_or(value_field.data_type()); - let mut metadata = value_field.metadata().clone(); - 
metadata.remove(JSON_DECODER_CONFIG_KEY); - - let field = Field::new( - value_field.name(), - value_type.clone(), - value_field.is_nullable(), - ); - Cow::Owned(Arc::new(field.with_metadata(metadata))) - } else { - Cow::Borrowed(value_field) - }; - - let type_changed = - matches!(key_field, Cow::Owned(_)) || matches!(value_field, Cow::Owned(_)); - - let data_type = if type_changed { - let struct_fields = vec![key_field.into_owned(), value_field.into_owned()]; - let struct_field = Arc::new(Field::new( - map_field.name(), - DataType::Struct(struct_fields.into()), - map_field.is_nullable(), - )); - DataType::Map(struct_field, false) - } else { - data_type.clone() - }; - Ok(Self { - data_type, - type_changed, + data_type: data_type.clone(), keys, values, is_nullable, @@ -139,10 +89,6 @@ impl MapArrayDecoder { } impl ArrayDecoder for MapArrayDecoder { - fn output_data_type(&self) -> Option<&DataType> { - self.type_changed.then_some(&self.data_type) - } - fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result { let DataType::Map(f, _) = &self.data_type else { unreachable!() diff --git a/arrow-json/src/reader/mod.rs b/arrow-json/src/reader/mod.rs index 7adf5f76669d..7dcefbbe3a92 100644 --- a/arrow-json/src/reader/mod.rs +++ b/arrow-json/src/reader/mod.rs @@ -395,28 +395,27 @@ impl RecordBatchReader for Reader { /// This allows overriding the default decoders for specific data types, /// or adding new decoders for custom data types. 
/// -/// # Examples +/// # Example /// /// ``` -/// use arrow_json::{ArrayDecoder, DecoderFactory, TapeElement, Tape, ReaderBuilder, StructMode}; -/// use arrow_schema::ArrowError; -/// use arrow_schema::{DataType, Field, FieldRef, Fields, Schema}; -/// use arrow_array::cast::AsArray; -/// use arrow_array::Array; -/// use arrow_array::builder::StringBuilder; -/// use arrow_data::ArrayData; -/// use std::sync::Arc; -/// -/// struct IncorrectStringAsNullDecoder {} +/// # use arrow_json::{ArrayDecoder, DecoderFactory, TapeElement, Tape, ReaderBuilder, StructMode}; +/// # use arrow_schema::ArrowError; +/// # use arrow_schema::{DataType, Field, FieldRef, Fields, Schema}; +/// # use arrow_array::cast::AsArray; +/// # use arrow_array::Array; +/// # use arrow_array::builder::StringBuilder; +/// # use arrow_data::ArrayData; +/// # use std::collections::HashMap; +/// # use std::sync::Arc; +/// # +/// struct IncorrectStringAsNullDecoder; /// /// impl ArrayDecoder for IncorrectStringAsNullDecoder { /// fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result { /// let mut builder = StringBuilder::new(); /// for p in pos { /// match tape.get(*p) { -/// TapeElement::String(idx) => { -/// builder.append_value(tape.get_string(idx)); -/// } +/// TapeElement::String(idx) => builder.append_value(tape.get_string(idx)), /// _ => builder.append_null(), /// } /// } @@ -436,10 +435,9 @@ impl RecordBatchReader for Reader { /// _coerce_primitive: bool, /// _strict_mode: bool, /// _struct_mode: StructMode, -/// _decoder_factory: Option<&dyn DecoderFactory>, /// ) -> Result>, ArrowError> { /// match data_type { -/// DataType::Utf8 => Ok(Some(Box::new(IncorrectStringAsNullDecoder {}))), +/// DataType::Utf8 => Ok(Some(Box::new(IncorrectStringAsNullDecoder))), /// _ => Ok(None), /// } /// } @@ -449,17 +447,14 @@ impl RecordBatchReader for Reader { /// {"a": "a"} /// {"a": 12} /// "#; -/// let batch = ReaderBuilder::new(Arc::new(Schema::new(Fields::from(vec![Field::new( -/// "a", -/// 
DataType::Utf8, -/// true, -/// )])))) -/// .with_decoder_factory(Arc::new(IncorrectStringAsNullDecoderFactory)) -/// .build(json.as_bytes()) -/// .unwrap() -/// .next() -/// .unwrap() -/// .unwrap(); +/// let fields = vec![Field::new("a", DataType::Utf8, true)]; +/// let batch = ReaderBuilder::new(Arc::new(Schema::new(fields.into()))) +/// .with_decoder_factory(Arc::new(IncorrectStringAsNullDecoderFactory)) +/// .build(json.as_bytes()) +/// .unwrap() +/// .next() +/// .unwrap() +/// .unwrap(); /// /// let values = batch.column(0).as_string::(); /// assert_eq!(values.len(), 2); @@ -471,25 +466,83 @@ pub trait DecoderFactory: std::fmt::Debug + Send + Sync { /// This can be used to override how e.g. error in decoding are handled. fn make_custom_decoder( &self, - _data_type: &DataType, - _is_nullable: bool, - _field_metadata: &HashMap, - _coerce_primitive: bool, - _strict_mode: bool, - _struct_mode: StructMode, - _decoder_factory: Option<&dyn DecoderFactory>, + data_type: &DataType, + is_nullable: bool, + field_metadata: &HashMap, + coerce_primitive: bool, + strict_mode: bool, + struct_mode: StructMode, ) -> Result>, ArrowError>; -} -/// -/// Metadata key for JSON decoder configuration. -/// -/// This well-known metadata key can be used to annotate schema fields with -/// custom decoding instructions. Custom [`DecoderFactory`] implementations -/// can inspect this metadata to determine how to decode specific fields. -/// -/// This metadata key is automatically stripped from the output array schema. -pub const JSON_DECODER_CONFIG_KEY: &str = "arrow-rs:json:decoder"; + /// Create a decoder for a type, without allowing this factory to directly intercept it, + /// but still allowing the factory to intercept children of complex types. 
+ /// + /// Solves the delegation pattern: When a factory intercepts a type and wants to + /// delegate to the default implementation, calling `make_decoder` would cause + /// infinite recursion (the factory intercepts its own `make_decoder` call). This method + /// skips the factory check at the current level but still passes the factory through so + /// child fields are customized. + /// + /// # Example + /// + /// ``` + /// # use arrow_json::reader::{DecoderFactory, ArrayDecoder}; + /// # use arrow_json::StructMode; + /// # use arrow_schema::{DataType, ArrowError}; + /// # use std::collections::HashMap; + /// # + /// #[derive(Debug)] + /// struct NoOpFactory; + /// + /// impl DecoderFactory for NoOpFactory { + /// fn make_custom_decoder( + /// &self, + /// data_type: &DataType, + /// is_nullable: bool, + /// _field_metadata: &HashMap, + /// coerce_primitive: bool, + /// strict_mode: bool, + /// struct_mode: StructMode, + /// ) -> Result>, ArrowError> { + /// if matches!(data_type, DataType::Struct(_)) { + /// // Bypass self-interception, children still use this factory + /// let delegate = self.make_delegate_decoder( + /// data_type, + /// is_nullable, + /// coerce_primitive, + /// strict_mode, + /// struct_mode, + /// )?; + /// + /// // In real usage: wrap the delegate with some custom behavior + /// Ok(Some(delegate)) + /// } else { + /// Ok(None) + /// } + /// } + /// } + /// ``` + fn make_delegate_decoder( + &self, + data_type: &DataType, + is_nullable: bool, + coerce_primitive: bool, + strict_mode: bool, + struct_mode: StructMode, + ) -> Result, ArrowError> + where + Self: Sized, + { + make_decoder_impl( + data_type, + is_nullable, + coerce_primitive, + strict_mode, + struct_mode, + Some(self), + ) + } +} /// A low-level interface for reading JSON data from a byte stream /// @@ -794,22 +847,6 @@ impl Decoder { /// A trait to decode JSON values into arrow arrays pub trait ArrayDecoder: Send { - /// Returns the output data type if it differs from the input 
data type. - /// - /// Custom decoders may produce output with a different data type than was - /// specified in the input schema. For example, a custom decoder might strip - /// extension metadata or change the type entirely. - /// - /// Returns: - /// - `Some(&DataType)` if the decoder's output type differs from its input - /// - `None` if the decoder produces the same type as its input - /// - /// This is used by parent decoders (Struct, List, Map) to determine whether - /// they need to update their output type to reflect child type changes. - fn output_data_type(&self) -> Option<&DataType> { - None - } - /// Decode elements from `tape` starting at the indexes contained in `pos` fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result; } @@ -833,7 +870,7 @@ macro_rules! primitive_decoder { /// /// * `data_type` - The Arrow data type to decode into /// * `is_nullable` - Whether the field is nullable -/// * `field_metadata` - Schema metadata for the field (can contain [`JSON_DECODER_CONFIG_KEY`]) +/// * `field_metadata` - Schema metadata for the field /// * `coerce_primitive` - Whether to coerce primitive types (e.g., string to number) /// * `strict_mode` - Whether to validate struct fields strictly /// * `struct_mode` - How to decode struct fields @@ -855,12 +892,31 @@ pub fn make_decoder( coerce_primitive, strict_mode, struct_mode, - decoder_factory, )? { return Ok(decoder); } } + make_decoder_impl( + data_type, + is_nullable, + coerce_primitive, + strict_mode, + struct_mode, + decoder_factory, + ) +} + +/// Private implementation of decoder creation that skips factory checks. +/// Used by both `make_decoder` and `DecoderFactory::make_delegate_decoder`. +fn make_decoder_impl( + data_type: &DataType, + is_nullable: bool, + coerce_primitive: bool, + strict_mode: bool, + struct_mode: StructMode, + decoder_factory: Option<&dyn DecoderFactory>, +) -> Result, ArrowError> { downcast_integer! 
{ *data_type => (primitive_decoder, data_type), DataType::Null => Ok(Box::::default()), @@ -3013,7 +3069,6 @@ mod tests { _coerce_primitive: bool, _strict_mode: bool, _struct_mode: StructMode, - _decoder_factory: Option<&dyn DecoderFactory>, ) -> Result>, ArrowError> { match data_type { DataType::Utf8 => Ok(Some(Box::new(AlwaysNullStringArrayDecoder))), diff --git a/arrow-json/src/reader/struct_array.rs b/arrow-json/src/reader/struct_array.rs index 39e663b60554..68e07b8bdf4e 100644 --- a/arrow-json/src/reader/struct_array.rs +++ b/arrow-json/src/reader/struct_array.rs @@ -16,14 +16,12 @@ // under the License. use crate::reader::tape::{Tape, TapeElement}; -use crate::reader::{ArrayDecoder, JSON_DECODER_CONFIG_KEY, StructMode, make_decoder}; +use crate::reader::{ArrayDecoder, StructMode, make_decoder}; use arrow_array::builder::BooleanBufferBuilder; use arrow_buffer::buffer::NullBuffer; use arrow_data::{ArrayData, ArrayDataBuilder}; -use arrow_schema::{ArrowError, DataType, Field, Fields}; -use std::borrow::Cow; +use arrow_schema::{ArrowError, DataType, Fields}; use std::collections::HashMap; -use std::sync::Arc; /// Reusable buffer for tape positions, indexed by (field_idx, row_idx). /// A value of 0 indicates the field is absent for that row. 
@@ -74,7 +72,6 @@ use super::DecoderFactory; pub struct StructArrayDecoder { data_type: DataType, - type_changed: bool, decoders: Vec>, strict_mode: bool, is_nullable: bool, @@ -94,55 +91,23 @@ impl StructArrayDecoder { ) -> Result { let fields = struct_fields(data_type); - // Create decoders and track whether any fields need modification let mut decoders = Vec::with_capacity(fields.len()); - let mut output_fields: Vec>> = Vec::with_capacity(fields.len()); - let mut type_changed = false; - for field in fields { // If this struct nullable, need to permit nullability in child array // StructArrayDecoder::decode verifies that if the child is not nullable // it doesn't contain any nulls not masked by its parent - let nullable = field.is_nullable() || is_nullable; let decoder = make_decoder( field.data_type(), - nullable, + field.is_nullable() || is_nullable, field.metadata(), coerce_primitive, strict_mode, struct_mode, decoder_factory, )?; - - // Check if field needs modification - let has_decoder_metadata = field.metadata().contains_key(JSON_DECODER_CONFIG_KEY); - let child_type_changed = decoder.output_data_type().is_some(); - - let field = if child_type_changed || has_decoder_metadata { - type_changed = true; - // Strip decoder metadata and update data type if needed - let data_type = decoder.output_data_type().unwrap_or(field.data_type()); - let mut metadata = field.metadata().clone(); - metadata.remove(JSON_DECODER_CONFIG_KEY); - - let field = Field::new(field.name(), data_type.clone(), field.is_nullable()); - Cow::Owned(Arc::new(field.with_metadata(metadata))) - } else { - Cow::Borrowed(field) - }; - decoders.push(decoder); - output_fields.push(field); } - // Only create new DataType if something actually changed - let data_type = if type_changed { - let owned: Vec> = output_fields.into_iter().map(Cow::into_owned).collect(); - DataType::Struct(owned.into()) - } else { - data_type.clone() - }; - let field_name_to_index = if struct_mode == StructMode::ObjectOnly { 
build_field_index(fields) } else { @@ -150,8 +115,7 @@ impl StructArrayDecoder { }; Ok(Self { - data_type, - type_changed, + data_type: data_type.clone(), decoders, strict_mode, is_nullable, @@ -163,10 +127,6 @@ impl StructArrayDecoder { } impl ArrayDecoder for StructArrayDecoder { - fn output_data_type(&self) -> Option<&DataType> { - self.type_changed.then_some(&self.data_type) - } - fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result { let fields = struct_fields(&self.data_type); let row_count = pos.len(); diff --git a/arrow-json/tests/custom_decoder_tests.rs b/arrow-json/tests/custom_decoder_tests.rs new file mode 100644 index 000000000000..2756fcb6db78 --- /dev/null +++ b/arrow-json/tests/custom_decoder_tests.rs @@ -0,0 +1,800 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Integration tests for custom decoder functionality +//! +//! This test suite demonstrates various patterns for customizing JSON decoding: +//! 1. Type-based routing - customize all fields of a given type +//! 2. Annotation-based routing - customize specific fields marked with metadata +//! 3. Type-specific behavior - custom parsing logic for specific types +//! 4. Composition - combining multiple custom factories +//! 5. 
Delegation with interleaving - wrapping standard decoders when direct access not possible +//! 6. Path-based routing - customize specific paths in the schema tree + +use arrow_array::Array as _; +use arrow_array::builder::StringBuilder; +use arrow_array::cast::AsArray; +use arrow_data::ArrayData; +use arrow_json::reader::{ArrayDecoder, DecoderFactory}; +use arrow_json::{ReaderBuilder, StructMode, Tape, TapeElement}; +use arrow_schema::{ArrowError, DataType, Field, FieldRef, Fields, Schema}; +use std::collections::HashMap; +use std::hash::{Hash, Hasher}; +use std::io::Cursor; +use std::sync::Arc; + +// ============================================================================ +// Test 1: Type-based lenient string decoder +// ============================================================================ + +/// A string decoder that converts type mismatches to NULL instead of erroring +struct LenientStringDecoder; + +impl ArrayDecoder for LenientStringDecoder { + fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result { + let mut builder = StringBuilder::new(); + for p in pos { + match tape.get(*p) { + TapeElement::String(idx) => builder.append_value(tape.get_string(idx)), + _ => builder.append_null(), + } + } + Ok(builder.finish().into_data()) + } +} + +/// A factory that applies the LenientStringDecoder to ALL Utf8 fields (type-based routing) +#[derive(Debug)] +struct TypeBasedLenientStringFactory; + +impl DecoderFactory for TypeBasedLenientStringFactory { + fn make_custom_decoder( + &self, + data_type: &DataType, + _is_nullable: bool, + _field_metadata: &HashMap, + _coerce_primitive: bool, + _strict_mode: bool, + _struct_mode: StructMode, + ) -> Result>, ArrowError> { + match data_type { + DataType::Utf8 => Ok(Some(Box::new(LenientStringDecoder))), + _ => Ok(None), + } + } +} + +#[test] +fn test_type_based_lenient_strings() { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + 
Field::new("email", DataType::Utf8, true), + ])); + + // JSON with type mismatches in BOTH string fields + let json = r#" + {"id": 1, "name": "Alice", "email": "alice@example.com"} + {"id": 2, "name": 42, "email": "bob@example.com"} + {"id": 3, "name": "Charlie", "email": true} + "#; + + let reader = ReaderBuilder::new(schema) + .with_decoder_factory(Arc::new(TypeBasedLenientStringFactory)) + .build(Cursor::new(json.as_bytes())) + .unwrap(); + + let batch = reader.into_iter().next().unwrap().unwrap(); + + let names = batch.column(1).as_string::(); + assert_eq!(names.value(0), "Alice"); + assert!(names.is_null(1)); // 42 -> NULL + assert_eq!(names.value(2), "Charlie"); + + let emails = batch.column(2).as_string::(); + assert_eq!(emails.value(0), "alice@example.com"); + assert_eq!(emails.value(1), "bob@example.com"); + assert!(emails.is_null(2)); // true -> NULL +} + +// ============================================================================ +// Test 2: Annotation-based lenient string decoder +// ============================================================================ + +/// A factory that applies LenientStringDecoder only to annotated Utf8 fields +#[derive(Debug)] +struct AnnotatedLenientStringFactory; + +impl DecoderFactory for AnnotatedLenientStringFactory { + fn make_custom_decoder( + &self, + data_type: &DataType, + _is_nullable: bool, + field_metadata: &HashMap, + _coerce_primitive: bool, + _strict_mode: bool, + _struct_mode: StructMode, + ) -> Result>, ArrowError> { + let config = field_metadata + .get("test:decoder:config") + .map(|s| s.as_str()); + match data_type { + DataType::Utf8 if config == Some("lenient") => Ok(Some(Box::new(LenientStringDecoder))), + _ => Ok(None), + } + } +} + +#[test] +fn test_annotation_based_lenient_strings() { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + // ONLY this field is marked lenient + Field::new("name", DataType::Utf8, true).with_metadata(HashMap::from([( + 
"test:decoder:config".to_string(), + "lenient".to_string(), + )])), + // This field is NOT marked, should use standard decoder + Field::new("email", DataType::Utf8, false), + ])); + + // JSON with type mismatches in BOTH string fields + let json_valid = r#" + {"id": 1, "name": "Alice", "email": "alice@example.com"} + {"id": 2, "name": 42, "email": "bob@example.com"} + "#; + + let reader = ReaderBuilder::new(schema.clone()) + .with_decoder_factory(Arc::new(AnnotatedLenientStringFactory)) + .build(Cursor::new(json_valid.as_bytes())) + .unwrap(); + + let batch = reader.into_iter().next().unwrap().unwrap(); + + let names = batch.column(1).as_string::(); + assert_eq!(names.value(0), "Alice"); + assert!(names.is_null(1)); // 42 -> NULL (lenient behavior) + + let emails = batch.column(2).as_string::(); + assert_eq!(emails.value(0), "alice@example.com"); + assert_eq!(emails.value(1), "bob@example.com"); + + // Negative test: email field without annotation should error on type mismatch + let json_invalid = r#" + {"id": 1, "name": "Alice", "email": true} + "#; + + let reader = ReaderBuilder::new(schema) + .with_decoder_factory(Arc::new(AnnotatedLenientStringFactory)) + .build(Cursor::new(json_invalid.as_bytes())) + .unwrap(); + + let result = reader.into_iter().next().unwrap(); + assert!(result.is_err()); // email field errors on type mismatch +} + +// ============================================================================ +// Shared helpers for interleaved decoding pattern +// ============================================================================ + +/// A general-purpose interleaved decoder that routes positions to different decoders +/// based on a filter predicate, then interleaves the results back to original order. +/// +/// This pattern is useful when you want to customize behavior but need to delegate +/// to standard decoders (which you can't directly access the builders of). +/// +/// Note: This example uses `Fn(TapeElement) -> bool` for simplicity. 
A production +/// implementation might use `Fn(&Tape, u32) -> bool` to support filters that need +/// to examine list/object contents before routing (e.g., routing based on list length +/// or presence of discriminator fields). Downside is more complexity in simple cases. +struct InterleavedDecoder { + primary: Box, + fallback: Box, + filter: F, +} + +impl bool + Send> ArrayDecoder for InterleavedDecoder { + fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result { + use arrow_select::interleave::interleave; + + // Partition positions based on filter + let mut primary_pos = Vec::new(); + let mut fallback_pos = Vec::new(); + let mut indices = Vec::with_capacity(pos.len()); + + for &p in pos { + if (self.filter)(tape.get(p)) { + indices.push((0, primary_pos.len())); + primary_pos.push(p); + } else { + indices.push((1, fallback_pos.len())); + fallback_pos.push(p); + } + } + + // Decode both parts + let primary = self.primary.decode(tape, &primary_pos)?; + let fallback = self.fallback.decode(tape, &fallback_pos)?; + + // Convert to arrays for interleaving + let primary = arrow_array::make_array(primary); + let fallback = arrow_array::make_array(fallback); + + // Interleave back to original order + let result = interleave(&[primary.as_ref(), fallback.as_ref()], &indices)?; + Ok(result.into_data()) + } +} + +/// A decoder that always produces NULL values (useful as fallback in InterleavedDecoder) +struct NullDecoder { + data_type: DataType, +} + +impl ArrayDecoder for NullDecoder { + fn decode(&mut self, _tape: &Tape<'_>, pos: &[u32]) -> Result { + Ok(arrow_array::new_null_array(&self.data_type, pos.len()).into_data()) + } +} + +// ============================================================================ +// Test 3: Lenient struct decoder (introduces InterleavedDecoder pattern) +// ============================================================================ + +/// A factory that makes ALL Struct fields lenient (type-based routing) +/// +/// This demonstrates the 
InterleavedDecoder pattern: we want to customize struct +/// handling but can't directly access the StructArrayDecoder's internal builder +/// and -- unlike string decoding -- the logic is too complex to replicate easily. +/// +/// Solution: partition positions and delegate to standard decoder + null decoder. +#[derive(Debug)] +struct TypeBasedLenientStructFactory; + +impl DecoderFactory for TypeBasedLenientStructFactory { + fn make_custom_decoder( + &self, + data_type: &DataType, + is_nullable: bool, + _field_metadata: &HashMap, + coerce_primitive: bool, + strict_mode: bool, + struct_mode: StructMode, + ) -> Result>, ArrowError> { + if !matches!(data_type, DataType::Struct(_)) { + return Ok(None); + } + + // Create standard struct decoder (using make_delegate_decoder to avoid infinite recursion) + let primary = self.make_delegate_decoder( + data_type, + is_nullable, + coerce_primitive, + strict_mode, + struct_mode, + )?; + + // Create null decoder as fallback + let fallback = Box::new(NullDecoder { + data_type: data_type.clone(), + }); + + Ok(Some(Box::new(InterleavedDecoder { + primary, + fallback, + filter: |elem| matches!(elem, TapeElement::StartObject(_)), + }))) + } +} + +#[test] +fn test_type_based_lenient_structs() { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new( + "person", + DataType::Struct(vec![Field::new("name", DataType::Utf8, false)].into()), + true, + ), + ])); + + let json = r#" + {"id": 1, "person": {"name": "Alice"}} + {"id": 2, "person": "not a struct"} + {"id": 3, "person": 42} + {"id": 4, "person": null} + "#; + + let reader = ReaderBuilder::new(schema) + .with_decoder_factory(Arc::new(TypeBasedLenientStructFactory)) + .build(Cursor::new(json.as_bytes())) + .unwrap(); + + let batch = reader.into_iter().next().unwrap().unwrap(); + let person = batch.column(1).as_struct(); + assert!(!person.is_null(0)); // Valid struct + assert!(person.is_null(1)); // "not a struct" -> NULL + 
assert!(person.is_null(2)); // 42 -> NULL + assert!(person.is_null(3)); // null -> NULL + + // Verify first row decoded correctly + let names = person.column(0).as_string::(); + assert_eq!(names.value(0), "Alice"); +} + +// ============================================================================ +// Test 4: Quirky string list decoder (reinforces InterleavedDecoder pattern) +// ============================================================================ + +/// Parse a string like "[a, b, c]" into an iterator of strings, e.g. "a", "b", "c" +fn parse_list_string(s: &str) -> Result + '_, ArrowError> { + let trimmed = s.trim(); + if !trimmed.starts_with('[') || !trimmed.ends_with(']') { + return Err(ArrowError::JsonError(format!( + "Failed to parse list string: {}", + s + ))); + } + let inner = &trimmed[1..trimmed.len() - 1]; + Ok(inner.split(',').map(|s| s.trim())) +} + +/// A decoder that ONLY handles string representations like "[a, b, c]" and parses them as lists. +/// Designed to be used with InterleavedDecoder (standard list decoder handles normal lists). +struct StringToListDecoder { + field: FieldRef, +} + +impl ArrayDecoder for StringToListDecoder { + fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result { + // Use with_field() to ensure the builder respects the schema's field (including nullability) + let mut builder = arrow_array::builder::ListBuilder::new(StringBuilder::new()) + .with_field(self.field.clone()); + + for p in pos { + let TapeElement::String(s) = tape.get(*p) else { + unreachable!("InterleavedDecoder filter should only route String elements here"); + }; + + // Parse string representation like "[x, y, z]" + for item in parse_list_string(tape.get_string(s))? { + builder.values().append_value(item); + } + builder.append(true); + } + + Ok(builder.finish().into_data()) + } +} + +/// A factory that makes ALL List fields quirky (type-based routing) +/// +/// Uses InterleavedDecoder to combine string parsing with standard list decoding. 
+#[derive(Debug)] +struct TypeBasedQuirkyListFactory; + +impl DecoderFactory for TypeBasedQuirkyListFactory { + fn make_custom_decoder( + &self, + data_type: &DataType, + is_nullable: bool, + _field_metadata: &HashMap, + coerce_primitive: bool, + strict_mode: bool, + struct_mode: StructMode, + ) -> Result>, ArrowError> { + let field = match data_type { + DataType::List(f) if *f.data_type() == DataType::Utf8 => f.clone(), + _ => return Ok(None), + }; + + // Primary: parse string representations + // Fallback: standard list decoder for normal JSON lists + Ok(Some(Box::new(InterleavedDecoder { + primary: Box::new(StringToListDecoder { field }), + fallback: self.make_delegate_decoder( + data_type, + is_nullable, + coerce_primitive, + strict_mode, + struct_mode, + )?, + filter: |elem| matches!(elem, TapeElement::String(_)), + }))) + } +} + +#[test] +fn test_type_based_quirky_lists() { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new( + "tags", + DataType::List(Arc::new(Field::new("item", DataType::Utf8, false))), + true, + ), + ])); + + let json = r#" +{"id": 1, "tags": ["a", "b", "c"]} +{"id": 2, "tags": "[x, y, z]"} +{"id": 3, "tags": null} +"#; + + let reader = ReaderBuilder::new(schema) + .with_decoder_factory(Arc::new(TypeBasedQuirkyListFactory)) + .build(Cursor::new(json.as_bytes())) + .unwrap(); + + let batch = reader.into_iter().next().unwrap().unwrap(); + + let tags = batch.column(1).as_list::(); + + // First row: normal JSON list + let row0 = tags.value(0); + let row0 = row0.as_string::(); + assert_eq!(row0.len(), 3); + assert_eq!(row0.value(0), "a"); + assert_eq!(row0.value(1), "b"); + assert_eq!(row0.value(2), "c"); + + // Second row: parsed from string representation + let row1 = tags.value(1); + let row1 = row1.as_string::(); + assert_eq!(row1.len(), 3); + assert_eq!(row1.value(0), "x"); + assert_eq!(row1.value(1), "y"); + assert_eq!(row1.value(2), "z"); + + // Third row: null + assert!(tags.is_null(2)); +} 
+ +// ============================================================================ +// Test 5: Composition - combining multiple type-based factories +// ============================================================================ + +/// A factory that tries multiple child factories in sequence +#[derive(Debug)] +struct ComposedDecoderFactory { + factories: Vec>, +} + +impl DecoderFactory for ComposedDecoderFactory { + fn make_custom_decoder( + &self, + data_type: &DataType, + is_nullable: bool, + field_metadata: &HashMap, + coerce_primitive: bool, + strict_mode: bool, + struct_mode: StructMode, + ) -> Result>, ArrowError> { + // Try each child factory in order until one returns Some or Err + for factory in &self.factories { + if let Some(decoder) = factory.make_custom_decoder( + data_type, + is_nullable, + field_metadata, + coerce_primitive, + strict_mode, + struct_mode, + )? { + return Ok(Some(decoder)); + } + } + Ok(None) + } +} + +#[test] +fn test_composed_type_based_factories() { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + Field::new( + "tags", + DataType::List(Arc::new(Field::new("item", DataType::Utf8, false))), + true, + ), + ])); + + let json = r#" +{"id": 1, "name": "Alice", "tags": ["a", "b"]} +{"id": 2, "name": 42, "tags": "[x, y]"} +{"id": 3, "name": "Bob", "tags": null} +"#; + + // Compose two type-based factories + let factories = vec![ + Arc::new(TypeBasedLenientStringFactory) as _, + Arc::new(TypeBasedQuirkyListFactory) as _, + ]; + let factory = Arc::new(ComposedDecoderFactory { factories }); + + let reader = ReaderBuilder::new(schema) + .with_decoder_factory(factory) + .build(Cursor::new(json.as_bytes())) + .unwrap(); + + let batch = reader.into_iter().next().unwrap().unwrap(); + + // Verify lenient string behavior + let names = batch.column(1).as_string::(); + assert_eq!(names.value(0), "Alice"); + assert!(names.is_null(1)); // 42 -> NULL + 
assert_eq!(names.value(2), "Bob"); + + // Verify quirky list behavior + let tags = batch.column(2).as_list::(); + + // Row 0: normal JSON list + let row0 = tags.value(0); + let row0 = row0.as_string::(); + assert_eq!(row0.len(), 2); + assert_eq!(row0.value(0), "a"); + assert_eq!(row0.value(1), "b"); + + // Row 1: parsed from string representation + let row1 = tags.value(1); + let row1 = row1.as_string::(); + assert_eq!(row1.len(), 2); + assert_eq!(row1.value(0), "x"); + assert_eq!(row1.value(1), "y"); + + // Row 2: null + assert!(tags.is_null(2)); +} + +// ============================================================================ +// Test 6: Path-based routing (pointer-identity-based) +// ============================================================================ + +/// Identity wrapper for DataType that uses pointer equality for HashMap lookup. +/// +/// Supports two modes: +/// - `FieldRef` variant: Stores ownership of a Field (for HashMap storage) +/// - `DataType` variant: Borrows a DataType temporarily (for HashMap lookup) +/// +/// Both variants compare by pointer identity of the DataType they reference. +/// +/// Safety: Pointer-identity comparison relies on DataType stability guarantees: +/// - We store the owning FieldRef (Arc) which keeps the Field alive +/// - We never call any potentially-mutating methods such as `Arc::get_mut` or `Arc::make_mut` +/// - We never share a reference to the FieldRef that could allow others to mutate it +/// - Our FieldRef ensures that anyone else who might attempt potentially-mutating operations +/// of the same Field through their own FieldRef will fail because `!Arc::is_unique()` +/// - The &DataType returned by `Field::data_type` is stable -- no interior mutability +/// such as `Mutex>` that could move it to a new memory location +/// +/// NOTE: We never dereference the raw pointer values used for comparison. A violation +/// of the above would only produce incorrect HashMap lookups (false positives/negatives). 
+#[derive(Debug)] +enum DataTypeIdentity<'a> { + FieldRef(FieldRef), + DataType(&'a DataType), +} + +impl<'a> DataTypeIdentity<'a> { + /// Extract the raw DataType pointer for identity comparison. + fn as_ptr(&self) -> *const DataType { + match self { + DataTypeIdentity::FieldRef(f) => f.data_type(), + DataTypeIdentity::DataType(dt) => *dt, + } + } +} + +impl<'a> Hash for DataTypeIdentity<'a> { + fn hash(&self, state: &mut H) { + self.as_ptr().hash(state); + } +} + +impl<'a> PartialEq for DataTypeIdentity<'a> { + fn eq(&self, other: &Self) -> bool { + self.as_ptr() == other.as_ptr() + } +} + +impl<'a> Eq for DataTypeIdentity<'a> {} + +/// A factory that routes to custom decoders based on specific field paths. +/// +/// This allows fine-grained control: customize specific fields by name without +/// polluting the schema with metadata or affecting all fields of a given type. +#[derive(Debug)] +struct PathBasedDecoderFactory { + // Maps DataTypeIdentity::FieldRef to factory + routes: HashMap, Arc>, +} + +impl PathBasedDecoderFactory { + /// Create a new path-based factory by mapping field paths to factories. + /// + /// Takes a reference to the schema fields and a map of paths to factories. + /// Paths can be nested using dot notation: "field", "struct.nested", "struct.deep.nested" + fn new(fields: &Fields, path_routes: HashMap<&str, Arc>) -> Self { + // Walk the fields and associate DataTypeIdentity::FieldRef with factory for O(1) lookup + let mut routes = HashMap::new(); + for (path, factory) in path_routes { + let parts: Vec<&str> = path.split('.').collect(); + if let Some(field) = Self::find_field_by_path(fields, &parts) { + routes.insert(DataTypeIdentity::FieldRef(field), factory); + } + } + + Self { routes } + } + + /// Recursively find a Field by following a path of field names. 
+ fn find_field_by_path(fields: &Fields, path: &[&str]) -> Option { + let (first, rest) = path.split_first()?; + let field = fields.iter().find(|f| f.name() == *first)?; + + if rest.is_empty() { + // End of path - return this field + return Some(field.clone()); + } + + // Path continues - attempt to recurse into a nested struct + let DataType::Struct(children) = field.data_type() else { + return None; + }; + + Self::find_field_by_path(children, rest) + } +} + +impl DecoderFactory for PathBasedDecoderFactory { + fn make_custom_decoder( + &self, + data_type: &DataType, + is_nullable: bool, + field_metadata: &HashMap, + coerce_primitive: bool, + strict_mode: bool, + struct_mode: StructMode, + ) -> Result>, ArrowError> { + // O(1) lookup using temporary DataType variant for identity comparison + let key = DataTypeIdentity::DataType(data_type); + let Some(factory) = self.routes.get(&key) else { + return Ok(None); + }; + + // Delegate to the route-specific factory + factory.make_custom_decoder( + data_type, + is_nullable, + field_metadata, + coerce_primitive, + strict_mode, + struct_mode, + ) + } +} + +#[test] +fn test_path_based_routing() { + // Create schema with both flat and nested String fields + let metadata_fields = Fields::from(vec![ + Field::new("source", DataType::Utf8, true), + Field::new("comment", DataType::Utf8, true), // Nested path: metadata.comment + ]); + + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + Field::new("email", DataType::Utf8, true), + Field::new("metadata", DataType::Struct(metadata_fields.clone()), true), + ])); + + // Create a lenient string factory (will be applied to specific paths only) + let lenient_factory = Arc::new(TypeBasedLenientStringFactory); + + // Build routes: ONLY the nested "metadata.comment" field gets lenient handling + // Demonstrates nested path routing with dot notation + let mut path_routes = HashMap::new(); + path_routes.insert( + 
"metadata.comment", + lenient_factory.clone() as Arc, + ); + + let factory = Arc::new(PathBasedDecoderFactory::new(schema.fields(), path_routes)); + + // JSON with type mismatches in multiple string fields + let json = r#" +{"id": 1, "name": "Alice", "email": "alice@example.com", "metadata": {"source": "web", "comment": "Good"}} +{"id": 2, "name": 42, "email": "bob@example.com", "metadata": {"source": "api", "comment": 100}} +{"id": 3, "name": "Charlie", "email": 999, "metadata": {"source": "mobile", "comment": "Excellent"}} +"#; + + let mut reader = ReaderBuilder::new(schema.clone()) + .with_decoder_factory(factory) + .build(Cursor::new(json)) + .unwrap(); + + // The decode should FAIL because "name" and "email" don't have lenient handling + // Only "metadata.comment" is lenient, but "name" has a type mismatch in row 2 + let result = reader.next().unwrap(); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("name") || err.contains("expected string")); + + // Now let's test with data that only has issues in the lenient nested field + let json_lenient_only = r#" +{"id": 1, "name": "Alice", "email": "alice@example.com", "metadata": {"source": "web", "comment": "Good"}} +{"id": 2, "name": "Bob", "email": "bob@example.com", "metadata": {"source": "api", "comment": 100}} +{"id": 3, "name": "Charlie", "email": "charlie@example.com", "metadata": {"source": "mobile", "comment": true}} +"#; + + // Rebuild routes for the second test + let mut path_routes2 = HashMap::new(); + path_routes2.insert( + "metadata.comment", + lenient_factory as Arc, + ); + + let mut reader = ReaderBuilder::new(schema.clone()) + .with_decoder_factory(Arc::new(PathBasedDecoderFactory::new( + schema.fields(), + path_routes2, + ))) + .build(Cursor::new(json_lenient_only)) + .unwrap(); + + let batch = reader.next().unwrap().unwrap(); + + // Verify all fields + assert_eq!(batch.num_rows(), 3); + assert_eq!(batch.num_columns(), 4); + + // ID column: all valid + 
let ids = batch + .column(0) + .as_primitive::(); + assert_eq!(ids.value(0), 1); + assert_eq!(ids.value(1), 2); + assert_eq!(ids.value(2), 3); + + // Name column: all valid (no type mismatches) + let names = batch.column(1).as_string::(); + assert_eq!(names.value(0), "Alice"); + assert_eq!(names.value(1), "Bob"); + assert_eq!(names.value(2), "Charlie"); + + // Email column: all valid (no type mismatches) + let emails = batch.column(2).as_string::(); + assert_eq!(emails.value(0), "alice@example.com"); + assert_eq!(emails.value(1), "bob@example.com"); + assert_eq!(emails.value(2), "charlie@example.com"); + + // Metadata struct column + let metadata = batch.column(3).as_struct(); + + // metadata.source: all valid (no lenient handling) + let sources = metadata.column(0).as_string::(); + assert_eq!(sources.value(0), "web"); + assert_eq!(sources.value(1), "api"); + assert_eq!(sources.value(2), "mobile"); + + // metadata.comment: LENIENT - non-strings become NULL + let comments = metadata.column(1).as_string::(); + assert_eq!(comments.value(0), "Good"); + assert!(comments.is_null(1)); // 100 -> NULL + assert!(comments.is_null(2)); // true -> NULL +} diff --git a/parquet-variant-compute/src/decoder.rs b/parquet-variant-compute/src/decoder.rs index 2ebc677f31e8..cdb9a556d389 100644 --- a/parquet-variant-compute/src/decoder.rs +++ b/parquet-variant-compute/src/decoder.rs @@ -79,13 +79,10 @@ impl DecoderFactory for VariantArrayDecoderFactory { _coerce_primitive: bool, _strict_mode: bool, _struct_mode: StructMode, - _decoder_factory: Option<&dyn DecoderFactory>, ) -> Result>, ArrowError> { // Check if this is a Variant extension type using metadata - if VariantType::try_from_parts(field_metadata, data_type).is_ok() { - return Ok(Some(Box::new(VariantArrayDecoder))); - } - Ok(None) + let result = VariantType::try_from_parts(field_metadata, data_type); + Ok(result.ok().map(|_| Box::new(VariantArrayDecoder) as _)) } } From a5cad3f89fe63771cdbdf7b9c33a954f9b81f856 Mon Sep 17 
00:00:00 2001 From: Ryan Johnson Date: Sat, 24 Jan 2026 05:26:09 -0800 Subject: [PATCH 12/14] Fix plumbing for nested factories --- arrow-json/src/reader/mod.rs | 22 ++++-- arrow-json/tests/custom_decoder_tests.rs | 92 +++++++++++++++++++++++- 2 files changed, 107 insertions(+), 7 deletions(-) diff --git a/arrow-json/src/reader/mod.rs b/arrow-json/src/reader/mod.rs index 7dcefbbe3a92..cf57c3e37a24 100644 --- a/arrow-json/src/reader/mod.rs +++ b/arrow-json/src/reader/mod.rs @@ -472,10 +472,11 @@ pub trait DecoderFactory: std::fmt::Debug + Send + Sync { coerce_primitive: bool, strict_mode: bool, struct_mode: StructMode, + factory_override: Option<&dyn DecoderFactory>, ) -> Result<Option<Box<dyn ArrayDecoder>>, ArrowError>; /// Create a decoder for a type, without allowing this factory to directly intercept it, - /// but still allowing the factory to intercept children of complex types. + /// but still allowing the specified factory to intercept children of complex types. /// /// Solves the delegation pattern: When a factory intercepts a type and wants to /// delegate to the default implementation, calling `make_decoder` would cause @@ -483,6 +484,10 @@ pub trait DecoderFactory: std::fmt::Debug + Send + Sync { /// skips the factory check at the current level but still passes the factory through so /// child fields are customized. 
/// + /// # Parameters + /// + /// * `decoder_factory` - The factory to use for child decoders of complex types + /// /// # Example /// /// ``` @@ -503,10 +508,15 @@ pub trait DecoderFactory: std::fmt::Debug + Send + Sync { /// coerce_primitive: bool, /// strict_mode: bool, /// struct_mode: StructMode, + /// decoder_factory: Option<&dyn DecoderFactory>, /// ) -> Result>, ArrowError> { /// if matches!(data_type, DataType::Struct(_)) { - /// // Bypass self-interception, children still use this factory - /// let delegate = self.make_delegate_decoder( + /// // Decide which factory children should use + /// let factory = decoder_factory.unwrap_or(self); + /// + /// // Bypass self-interception, children use the specified factory + /// let delegate = Self::make_delegate_decoder( + /// factory, /// data_type, /// is_nullable, /// coerce_primitive, @@ -523,7 +533,7 @@ pub trait DecoderFactory: std::fmt::Debug + Send + Sync { /// } /// ``` fn make_delegate_decoder( - &self, + decoder_factory: &dyn DecoderFactory, data_type: &DataType, is_nullable: bool, coerce_primitive: bool, @@ -539,7 +549,7 @@ pub trait DecoderFactory: std::fmt::Debug + Send + Sync { coerce_primitive, strict_mode, struct_mode, - Some(self), + Some(decoder_factory), ) } } @@ -892,6 +902,7 @@ pub fn make_decoder( coerce_primitive, strict_mode, struct_mode, + decoder_factory, )? 
{ return Ok(decoder); } @@ -3069,6 +3080,7 @@ mod tests { _coerce_primitive: bool, _strict_mode: bool, _struct_mode: StructMode, + _factory_override: Option<&dyn DecoderFactory>, ) -> Result>, ArrowError> { match data_type { DataType::Utf8 => Ok(Some(Box::new(AlwaysNullStringArrayDecoder))), diff --git a/arrow-json/tests/custom_decoder_tests.rs b/arrow-json/tests/custom_decoder_tests.rs index 2756fcb6db78..b0f12e876080 100644 --- a/arrow-json/tests/custom_decoder_tests.rs +++ b/arrow-json/tests/custom_decoder_tests.rs @@ -70,6 +70,7 @@ impl DecoderFactory for TypeBasedLenientStringFactory { _coerce_primitive: bool, _strict_mode: bool, _struct_mode: StructMode, + _factory_override: Option<&dyn DecoderFactory>, ) -> Result>, ArrowError> { match data_type { DataType::Utf8 => Ok(Some(Box::new(LenientStringDecoder))), @@ -128,6 +129,7 @@ impl DecoderFactory for AnnotatedLenientStringFactory { _coerce_primitive: bool, _strict_mode: bool, _struct_mode: StructMode, + _factory_override: Option<&dyn DecoderFactory>, ) -> Result>, ArrowError> { let config = field_metadata .get("test:decoder:config") @@ -274,13 +276,15 @@ impl DecoderFactory for TypeBasedLenientStructFactory { coerce_primitive: bool, strict_mode: bool, struct_mode: StructMode, + factory_override: Option<&dyn DecoderFactory>, ) -> Result>, ArrowError> { if !matches!(data_type, DataType::Struct(_)) { return Ok(None); } // Create standard struct decoder (using make_delegate_decoder to avoid infinite recursion) - let primary = self.make_delegate_decoder( + let primary = Self::make_delegate_decoder( + factory_override.unwrap_or(self), data_type, is_nullable, coerce_primitive, @@ -396,6 +400,7 @@ impl DecoderFactory for TypeBasedQuirkyListFactory { coerce_primitive: bool, strict_mode: bool, struct_mode: StructMode, + factory_override: Option<&dyn DecoderFactory>, ) -> Result>, ArrowError> { let field = match data_type { DataType::List(f) if *f.data_type() == DataType::Utf8 => f.clone(), @@ -406,7 +411,8 @@ impl 
DecoderFactory for TypeBasedQuirkyListFactory { // Fallback: standard list decoder for normal JSON lists Ok(Some(Box::new(InterleavedDecoder { primary: Box::new(StringToListDecoder { field }), - fallback: self.make_delegate_decoder( + fallback: Self::make_delegate_decoder( + factory_override.unwrap_or(self), data_type, is_nullable, coerce_primitive, @@ -483,6 +489,7 @@ impl DecoderFactory for ComposedDecoderFactory { coerce_primitive: bool, strict_mode: bool, struct_mode: StructMode, + factory_override: Option<&dyn DecoderFactory>, ) -> Result>, ArrowError> { // Try each child factory in order until one returns Some or Err for factory in &self.factories { @@ -493,6 +500,7 @@ impl DecoderFactory for ComposedDecoderFactory { coerce_primitive, strict_mode, struct_mode, + Some(factory_override.unwrap_or(self)), )? { return Ok(Some(decoder)); } @@ -669,6 +677,7 @@ impl DecoderFactory for PathBasedDecoderFactory { coerce_primitive: bool, strict_mode: bool, struct_mode: StructMode, + factory_override: Option<&dyn DecoderFactory>, ) -> Result>, ArrowError> { // O(1) lookup using temporary DataType variant for identity comparison let key = DataTypeIdentity::DataType(data_type); @@ -684,6 +693,7 @@ impl DecoderFactory for PathBasedDecoderFactory { coerce_primitive, strict_mode, struct_mode, + Some(factory_override.unwrap_or(self)), ) } } @@ -798,3 +808,81 @@ fn test_path_based_routing() { assert!(comments.is_null(1)); // 100 -> NULL assert!(comments.is_null(2)); // true -> NULL } + +// ============================================================================ +// Test 7: Recursive factory propagation +// ============================================================================ + +#[test] +fn test_recursive_factory_propagation() { + // Schema with nested struct containing string fields + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new( + "person", + DataType::Struct( + vec![ + Field::new("name", DataType::Utf8, true), // 
Make nullable + Field::new("email", DataType::Utf8, true), + ] + .into(), + ), + true, + ), + ])); + + let json = r#" +{"id": 1, "person": {"name": "Alice", "email": "alice@example.com"}} +{"id": 2, "person": 42} +{"id": 3, "person": {"name": 123, "email": "charlie@example.com"}} +{"id": 4, "person": {"name": "Dave", "email": true}} +"#; + + // Compose lenient struct + lenient string factories + // This tests that the factory propagates through struct decoder creation + let factories = vec![ + Arc::new(TypeBasedLenientStructFactory) as Arc, + Arc::new(TypeBasedLenientStringFactory) as Arc, + ]; + let factory = Arc::new(ComposedDecoderFactory { factories }); + + let mut reader = ReaderBuilder::new(schema) + .with_decoder_factory(factory) + .build(Cursor::new(json)) + .unwrap(); + + let batch = reader.next().unwrap().unwrap(); + + // ID column: all valid + let ids = batch + .column(0) + .as_primitive::(); + assert_eq!(ids.value(0), 1); + assert_eq!(ids.value(1), 2); + assert_eq!(ids.value(2), 3); + assert_eq!(ids.value(3), 4); + + // Person struct column + let person = batch.column(1).as_struct(); + + // Row 0: Valid struct with valid strings + assert!(!person.is_null(0)); + let names = person.column(0).as_string::(); + let emails = person.column(1).as_string::(); + assert_eq!(names.value(0), "Alice"); + assert_eq!(emails.value(0), "alice@example.com"); + + // Row 1: Not a struct (42) -> NULL from struct factory + assert!(person.is_null(1)); + + // Row 2: Valid struct but name has type mismatch (123) + // This tests that the string factory was invoked for the nested field! 
+ assert!(!person.is_null(2)); + assert!(names.is_null(2)); // 123 -> NULL from string factory + assert_eq!(emails.value(2), "charlie@example.com"); + + // Row 3: Valid struct but email has type mismatch (true) + assert!(!person.is_null(3)); + assert_eq!(names.value(3), "Dave"); + assert!(emails.is_null(3)); // true -> NULL from string factory +} From 41583ef907a1366364ee57cb2d5b9b4bf72ba473 Mon Sep 17 00:00:00 2001 From: Ryan Johnson Date: Mon, 26 Jan 2026 09:00:27 -0800 Subject: [PATCH 13/14] got it all working nice! --- arrow-json/src/reader/list_array.rs | 20 +- arrow-json/src/reader/map_array.rs | 22 +- arrow-json/src/reader/mod.rs | 285 ++++++++++------------- arrow-json/src/reader/struct_array.rs | 18 +- arrow-json/tests/custom_decoder_tests.rs | 90 ++----- parquet-variant-compute/src/decoder.rs | 9 +- 6 files changed, 159 insertions(+), 285 deletions(-) diff --git a/arrow-json/src/reader/list_array.rs b/arrow-json/src/reader/list_array.rs index ecb7c7bbaf27..3a39c40da1e6 100644 --- a/arrow-json/src/reader/list_array.rs +++ b/arrow-json/src/reader/list_array.rs @@ -15,9 +15,8 @@ // specific language governing permissions and limitations // under the License. 
-use crate::StructMode; use crate::reader::tape::{Tape, TapeElement}; -use crate::reader::{ArrayDecoder, make_decoder}; +use crate::reader::{ArrayDecoder, DecoderContext}; use arrow_array::OffsetSizeTrait; use arrow_array::builder::{BooleanBufferBuilder, BufferBuilder}; use arrow_buffer::buffer::NullBuffer; @@ -25,8 +24,6 @@ use arrow_data::{ArrayData, ArrayDataBuilder}; use arrow_schema::{ArrowError, DataType}; use std::marker::PhantomData; -use super::DecoderFactory; - pub struct ListArrayDecoder { data_type: DataType, decoder: Box, @@ -36,27 +33,16 @@ pub struct ListArrayDecoder { impl ListArrayDecoder { pub fn new( + ctx: &DecoderContext, data_type: &DataType, - coerce_primitive: bool, - strict_mode: bool, is_nullable: bool, - struct_mode: StructMode, - decoder_factory: Option<&dyn DecoderFactory>, ) -> Result { let field = match data_type { DataType::List(f) if !O::IS_LARGE => f, DataType::LargeList(f) if O::IS_LARGE => f, _ => unreachable!(), }; - let decoder = make_decoder( - field.data_type(), - field.is_nullable(), - field.metadata(), - coerce_primitive, - strict_mode, - struct_mode, - decoder_factory, - )?; + let decoder = ctx.make_decoder(field.data_type(), field.is_nullable(), field.metadata())?; Ok(Self { data_type: data_type.clone(), diff --git a/arrow-json/src/reader/map_array.rs b/arrow-json/src/reader/map_array.rs index e3fd6d6572be..8abc3916d2b3 100644 --- a/arrow-json/src/reader/map_array.rs +++ b/arrow-json/src/reader/map_array.rs @@ -15,17 +15,14 @@ // specific language governing permissions and limitations // under the License. 
-use crate::StructMode; use crate::reader::tape::{Tape, TapeElement}; -use crate::reader::{ArrayDecoder, make_decoder}; +use crate::reader::{ArrayDecoder, DecoderContext}; use arrow_array::builder::{BooleanBufferBuilder, BufferBuilder}; use arrow_buffer::ArrowNativeType; use arrow_buffer::buffer::NullBuffer; use arrow_data::{ArrayData, ArrayDataBuilder}; use arrow_schema::{ArrowError, DataType}; -use super::DecoderFactory; - pub struct MapArrayDecoder { data_type: DataType, keys: Box, @@ -35,12 +32,9 @@ pub struct MapArrayDecoder { impl MapArrayDecoder { pub fn new( + ctx: &DecoderContext, data_type: &DataType, - coerce_primitive: bool, - strict_mode: bool, is_nullable: bool, - struct_mode: StructMode, - decoder_factory: Option<&dyn DecoderFactory>, ) -> Result { let DataType::Map(f, false) = data_type else { return Err(ArrowError::NotYetImplemented( @@ -60,23 +54,15 @@ impl MapArrayDecoder { let (key_field, value_field) = (&fields[0], &fields[1]); - let keys = make_decoder( + let keys = ctx.make_decoder( key_field.data_type(), key_field.is_nullable(), key_field.metadata(), - coerce_primitive, - strict_mode, - struct_mode, - decoder_factory, )?; - let values = make_decoder( + let values = ctx.make_decoder( value_field.data_type(), value_field.is_nullable(), value_field.metadata(), - coerce_primitive, - strict_mode, - struct_mode, - decoder_factory, )?; Ok(Self { diff --git a/arrow-json/src/reader/mod.rs b/arrow-json/src/reader/mod.rs index cf57c3e37a24..443b1943e1ab 100644 --- a/arrow-json/src/reader/mod.rs +++ b/arrow-json/src/reader/mod.rs @@ -318,15 +318,14 @@ impl ReaderBuilder { (data_type, false, &empty_metadata) }; - let decoder = make_decoder( - data_type.as_ref(), - nullable, - metadata, - self.coerce_primitive, - self.strict_mode, - self.struct_mode, - self.decoder_factory.as_deref(), - )?; + let ctx = DecoderContext { + coerce_primitive: self.coerce_primitive, + strict_mode: self.strict_mode, + struct_mode: self.struct_mode, + decoder_factory: 
self.decoder_factory.as_deref(), + }; + + let decoder = ctx.make_decoder(data_type.as_ref(), nullable, metadata)?; let num_fields = self.schema.flattened_fields().len(); @@ -398,7 +397,8 @@ impl RecordBatchReader for Reader { /// # Example /// /// ``` -/// # use arrow_json::{ArrayDecoder, DecoderFactory, TapeElement, Tape, ReaderBuilder, StructMode}; +/// # use arrow_json::reader::{ArrayDecoder, DecoderContext, DecoderFactory}; +/// # use arrow_json::{ReaderBuilder, TapeElement, Tape}; /// # use arrow_schema::ArrowError; /// # use arrow_schema::{DataType, Field, FieldRef, Fields, Schema}; /// # use arrow_array::cast::AsArray; @@ -429,12 +429,10 @@ impl RecordBatchReader for Reader { /// impl DecoderFactory for IncorrectStringAsNullDecoderFactory { /// fn make_custom_decoder( /// &self, +/// _ctx: &DecoderContext, /// data_type: &DataType, /// _is_nullable: bool, /// _field_metadata: &HashMap, -/// _coerce_primitive: bool, -/// _strict_mode: bool, -/// _struct_mode: StructMode, /// ) -> Result>, ArrowError> { /// match data_type { /// DataType::Utf8 => Ok(Some(Box::new(IncorrectStringAsNullDecoder))), @@ -448,7 +446,7 @@ impl RecordBatchReader for Reader { /// {"a": 12} /// "#; /// let fields = vec![Field::new("a", DataType::Utf8, true)]; -/// let batch = ReaderBuilder::new(Arc::new(Schema::new(fields.into()))) +/// let batch = ReaderBuilder::new(Arc::new(Schema::new(fields))) /// .with_decoder_factory(Arc::new(IncorrectStringAsNullDecoderFactory)) /// .build(json.as_bytes()) /// .unwrap() @@ -466,92 +464,11 @@ pub trait DecoderFactory: std::fmt::Debug + Send + Sync { /// This can be used to override how e.g. error in decoding are handled. 
fn make_custom_decoder( &self, + ctx: &DecoderContext, data_type: &DataType, is_nullable: bool, field_metadata: &HashMap, - coerce_primitive: bool, - strict_mode: bool, - struct_mode: StructMode, - factory_override: Option<&dyn DecoderFactory>, ) -> Result>, ArrowError>; - - /// Create a decoder for a type, without allowing this factory to directly intercept it, - /// but still allowing the specified factory to intercept children of complex types. - /// - /// Solves the delegation pattern: When a factory intercepts a type and wants to - /// delegate to the default implementation, calling `make_decoder` would cause - /// infinite recursion (the factory intercepts its own `make_decoder` call). This method - /// skips the factory check at the current level but still passes the factory through so - /// child fields are customized. - /// - /// # Parameters - /// - /// * `decoder_factory` - The factory to use for child decoders of complex types - /// - /// # Example - /// - /// ``` - /// # use arrow_json::reader::{DecoderFactory, ArrayDecoder}; - /// # use arrow_json::StructMode; - /// # use arrow_schema::{DataType, ArrowError}; - /// # use std::collections::HashMap; - /// # - /// #[derive(Debug)] - /// struct NoOpFactory; - /// - /// impl DecoderFactory for NoOpFactory { - /// fn make_custom_decoder( - /// &self, - /// data_type: &DataType, - /// is_nullable: bool, - /// _field_metadata: &HashMap, - /// coerce_primitive: bool, - /// strict_mode: bool, - /// struct_mode: StructMode, - /// decoder_factory: Option<&dyn DecoderFactory>, - /// ) -> Result>, ArrowError> { - /// if matches!(data_type, DataType::Struct(_)) { - /// // Decide which factory children should use - /// let factory = decoder_factory.unwrap_or(self); - /// - /// // Bypass self-interception, children use the specified factory - /// let delegate = Self::make_delegate_decoder( - /// factory, - /// data_type, - /// is_nullable, - /// coerce_primitive, - /// strict_mode, - /// struct_mode, - /// )?; - /// - 
/// // In real usage: wrap the delegate with some custom behavior - /// Ok(Some(delegate)) - /// } else { - /// Ok(None) - /// } - /// } - /// } - /// ``` - fn make_delegate_decoder( - decoder_factory: &dyn DecoderFactory, - data_type: &DataType, - is_nullable: bool, - coerce_primitive: bool, - strict_mode: bool, - struct_mode: StructMode, - ) -> Result, ArrowError> - where - Self: Sized, - { - make_decoder_impl( - data_type, - is_nullable, - coerce_primitive, - strict_mode, - struct_mode, - Some(decoder_factory), - ) - } } /// A low-level interface for reading JSON data from a byte stream @@ -861,73 +778,126 @@ pub trait ArrayDecoder: Send { fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result; } -macro_rules! primitive_decoder { - ($t:ty, $data_type:expr) => { - Ok(Box::new(PrimitiveArrayDecoder::<$t>::new($data_type))) - }; -} - -/// Creates an [`ArrayDecoder`] for decoding JSON values into Arrow arrays. -/// -/// This function is the primary entry point for constructing decoders. It first -/// attempts to use a custom [`DecoderFactory`] if provided, then falls back to -/// default decoder implementations based on the data type. +/// Context for decoder creation, containing configuration and factory reference. /// -/// Custom decoders can use this function to recursively create child decoders -/// for complex types by passing along the `decoder_factory` parameter. 
-/// -/// # Arguments -/// -/// * `data_type` - The Arrow data type to decode into -/// * `is_nullable` - Whether the field is nullable -/// * `field_metadata` - Schema metadata for the field -/// * `coerce_primitive` - Whether to coerce primitive types (e.g., string to number) -/// * `strict_mode` - Whether to validate struct fields strictly -/// * `struct_mode` - How to decode struct fields -/// * `decoder_factory` - Optional custom decoder factory for recursion -pub fn make_decoder( - data_type: &DataType, - is_nullable: bool, - field_metadata: &HashMap, +/// This context is passed through the decoder creation process and contains +/// all the configuration needed to create decoders recursively. +pub struct DecoderContext<'a> { + /// Whether to coerce primitives to strings coerce_primitive: bool, + /// Whether to validate struct fields strictly strict_mode: bool, + /// How to decode struct fields struct_mode: StructMode, - decoder_factory: Option<&dyn DecoderFactory>, -) -> Result, ArrowError> { - if let Some(factory) = decoder_factory { - if let Some(decoder) = factory.make_custom_decoder( - data_type, - is_nullable, - field_metadata, - coerce_primitive, - strict_mode, - struct_mode, - decoder_factory, - )? { - return Ok(decoder); + /// Optional custom decoder factory + decoder_factory: Option<&'a dyn DecoderFactory>, +} + +impl DecoderContext<'_> { + /// Returns whether to coerce primitive types (e.g., number to string) + pub fn coerce_primitive(&self) -> bool { + self.coerce_primitive + } + + /// Returns whether to validate struct fields strictly + pub fn strict_mode(&self) -> bool { + self.strict_mode + } + + /// Returns how to decode struct fields + pub fn struct_mode(&self) -> StructMode { + self.struct_mode + } + + /// Create a decoder for a type, allowing the factory to intercept it. + /// + /// This is the standard way to create child decoders from within a decoder + /// implementation. 
The factory (if present) will be given the opportunity + /// to intercept and customize the decoder. + pub fn make_decoder( + &self, + data_type: &DataType, + is_nullable: bool, + field_metadata: &HashMap, + ) -> Result, ArrowError> { + if let Some(factory) = self.decoder_factory { + if let Some(decoder) = + factory.make_custom_decoder(self, data_type, is_nullable, field_metadata)? + { + return Ok(decoder); + } } + + make_decoder(self, data_type, is_nullable) + } + + /// Create a decoder for a type without allowing the factory to intercept it directly, + /// but still allowing the factory to intercept children of complex types. + /// + /// This is used by custom factories when they want to delegate to the standard + /// decoder implementation while still customizing child decoders. + /// + /// # Example + /// + /// When a factory intercepts a type and wants to wrap or modify the standard decoder, + /// calling `ctx.make_decoder()` would cause infinite recursion (the factory would + /// intercept its own call). This method bypasses the factory check at the current + /// level but still passes the context through so child fields can be customized. 
+ /// + /// ``` + /// # use arrow_json::reader::{ArrayDecoder, DecoderContext, DecoderFactory}; + /// # use arrow_json::TapeElement; + /// # use arrow_schema::{DataType, ArrowError}; + /// # use std::collections::HashMap; + /// # + /// #[derive(Debug)] + /// struct StructWrapperFactory; + /// + /// impl DecoderFactory for StructWrapperFactory { + /// fn make_custom_decoder( + /// &self, + /// ctx: &DecoderContext, + /// data_type: &DataType, + /// is_nullable: bool, + /// _field_metadata: &HashMap, + /// ) -> Result>, ArrowError> { + /// if matches!(data_type, DataType::Struct(_)) { + /// // Get standard struct decoder, bypassing self-interception + /// let delegate = ctx.make_delegate_decoder(data_type, is_nullable)?; + /// + /// // In real usage: wrap delegate with custom behavior + /// Ok(Some(delegate)) + /// } else { + /// Ok(None) + /// } + /// } + /// } + /// ``` + pub fn make_delegate_decoder( + &self, + data_type: &DataType, + is_nullable: bool, + ) -> Result, ArrowError> { + make_decoder(self, data_type, is_nullable) } +} - make_decoder_impl( - data_type, - is_nullable, - coerce_primitive, - strict_mode, - struct_mode, - decoder_factory, - ) +macro_rules! primitive_decoder { + ($t:ty, $data_type:expr) => { + Ok(Box::new(PrimitiveArrayDecoder::<$t>::new($data_type))) + }; } -/// Private implementation of decoder creation that skips factory checks. -/// Used by both `make_decoder` and `DecoderFactory::make_delegate_decoder`. -fn make_decoder_impl( +/// Private workhorse function for decoder creation. +/// +/// Constructs the appropriate decoder for the given data type without +/// checking the factory. All decoder construction logic lives here. +fn make_decoder( + ctx: &DecoderContext, data_type: &DataType, is_nullable: bool, - coerce_primitive: bool, - strict_mode: bool, - struct_mode: StructMode, - decoder_factory: Option<&dyn DecoderFactory>, ) -> Result, ArrowError> { + let coerce_primitive = ctx.coerce_primitive(); downcast_integer! 
{ *data_type => (primitive_decoder, data_type), DataType::Null => Ok(Box::::default()), @@ -980,14 +950,14 @@ fn make_decoder_impl( DataType::Utf8 => Ok(Box::new(StringArrayDecoder::::new(coerce_primitive))), DataType::Utf8View => Ok(Box::new(StringViewArrayDecoder::new(coerce_primitive))), DataType::LargeUtf8 => Ok(Box::new(StringArrayDecoder::::new(coerce_primitive))), - DataType::List(_) => Ok(Box::new(ListArrayDecoder::::new(data_type, coerce_primitive, strict_mode, is_nullable, struct_mode, decoder_factory)?)), - DataType::LargeList(_) => Ok(Box::new(ListArrayDecoder::::new(data_type, coerce_primitive, strict_mode, is_nullable, struct_mode, decoder_factory)?)), - DataType::Struct(_) => Ok(Box::new(StructArrayDecoder::new(data_type, coerce_primitive, strict_mode, is_nullable, struct_mode, decoder_factory)?)), + DataType::List(_) => Ok(Box::new(ListArrayDecoder::::new(ctx, data_type, is_nullable)?)), + DataType::LargeList(_) => Ok(Box::new(ListArrayDecoder::::new(ctx, data_type, is_nullable)?)), + DataType::Struct(_) => Ok(Box::new(StructArrayDecoder::new(ctx, data_type, is_nullable)?)), DataType::Binary => Ok(Box::new(BinaryArrayDecoder::::default())), DataType::LargeBinary => Ok(Box::new(BinaryArrayDecoder::::default())), DataType::FixedSizeBinary(len) => Ok(Box::new(FixedSizeBinaryArrayDecoder::new(len))), DataType::BinaryView => Ok(Box::new(BinaryViewDecoder::default())), - DataType::Map(_, _) => Ok(Box::new(MapArrayDecoder::new(data_type, coerce_primitive, strict_mode, is_nullable, struct_mode, decoder_factory)?)), + DataType::Map(_, _) => Ok(Box::new(MapArrayDecoder::new(ctx, data_type, is_nullable)?)), _ => Err(ArrowError::NotYetImplemented(format!("Support for {data_type} in JSON reader"))) } } @@ -3074,13 +3044,10 @@ mod tests { impl DecoderFactory for AlwaysNullStringArrayDecoderFactory { fn make_custom_decoder( &self, + _ctx: &crate::reader::DecoderContext, data_type: &DataType, _is_nullable: bool, _field_metadata: &HashMap, - _coerce_primitive: bool, 
- _strict_mode: bool, - _struct_mode: StructMode, - _factory_override: Option<&dyn DecoderFactory>, ) -> Result>, ArrowError> { match data_type { DataType::Utf8 => Ok(Some(Box::new(AlwaysNullStringArrayDecoder))), diff --git a/arrow-json/src/reader/struct_array.rs b/arrow-json/src/reader/struct_array.rs index 68e07b8bdf4e..968da5db2e9e 100644 --- a/arrow-json/src/reader/struct_array.rs +++ b/arrow-json/src/reader/struct_array.rs @@ -16,7 +16,7 @@ // under the License. use crate::reader::tape::{Tape, TapeElement}; -use crate::reader::{ArrayDecoder, StructMode, make_decoder}; +use crate::reader::{ArrayDecoder, DecoderContext, StructMode}; use arrow_array::builder::BooleanBufferBuilder; use arrow_buffer::buffer::NullBuffer; use arrow_data::{ArrayData, ArrayDataBuilder}; @@ -68,8 +68,6 @@ impl FieldTapePositions { } } -use super::DecoderFactory; - pub struct StructArrayDecoder { data_type: DataType, decoders: Vec>, @@ -82,13 +80,11 @@ pub struct StructArrayDecoder { impl StructArrayDecoder { pub fn new( + ctx: &DecoderContext, data_type: &DataType, - coerce_primitive: bool, - strict_mode: bool, is_nullable: bool, - struct_mode: StructMode, - decoder_factory: Option<&dyn DecoderFactory>, ) -> Result { + let struct_mode = ctx.struct_mode(); let fields = struct_fields(data_type); let mut decoders = Vec::with_capacity(fields.len()); @@ -96,14 +92,10 @@ impl StructArrayDecoder { // If this struct nullable, need to permit nullability in child array // StructArrayDecoder::decode verifies that if the child is not nullable // it doesn't contain any nulls not masked by its parent - let decoder = make_decoder( + let decoder = ctx.make_decoder( field.data_type(), field.is_nullable() || is_nullable, field.metadata(), - coerce_primitive, - strict_mode, - struct_mode, - decoder_factory, )?; decoders.push(decoder); } @@ -117,7 +109,7 @@ impl StructArrayDecoder { Ok(Self { data_type: data_type.clone(), decoders, - strict_mode, + strict_mode: ctx.strict_mode(), is_nullable, struct_mode, 
field_name_to_index, diff --git a/arrow-json/tests/custom_decoder_tests.rs b/arrow-json/tests/custom_decoder_tests.rs index b0f12e876080..df67b91ff4bc 100644 --- a/arrow-json/tests/custom_decoder_tests.rs +++ b/arrow-json/tests/custom_decoder_tests.rs @@ -29,8 +29,8 @@ use arrow_array::Array as _; use arrow_array::builder::StringBuilder; use arrow_array::cast::AsArray; use arrow_data::ArrayData; -use arrow_json::reader::{ArrayDecoder, DecoderFactory}; -use arrow_json::{ReaderBuilder, StructMode, Tape, TapeElement}; +use arrow_json::reader::{ArrayDecoder, DecoderContext, DecoderFactory}; +use arrow_json::{ReaderBuilder, Tape, TapeElement}; use arrow_schema::{ArrowError, DataType, Field, FieldRef, Fields, Schema}; use std::collections::HashMap; use std::hash::{Hash, Hasher}; @@ -64,13 +64,10 @@ struct TypeBasedLenientStringFactory; impl DecoderFactory for TypeBasedLenientStringFactory { fn make_custom_decoder( &self, + _ctx: &DecoderContext, data_type: &DataType, _is_nullable: bool, _field_metadata: &HashMap, - _coerce_primitive: bool, - _strict_mode: bool, - _struct_mode: StructMode, - _factory_override: Option<&dyn DecoderFactory>, ) -> Result>, ArrowError> { match data_type { DataType::Utf8 => Ok(Some(Box::new(LenientStringDecoder))), @@ -123,13 +120,10 @@ struct AnnotatedLenientStringFactory; impl DecoderFactory for AnnotatedLenientStringFactory { fn make_custom_decoder( &self, + _ctx: &DecoderContext, data_type: &DataType, _is_nullable: bool, field_metadata: &HashMap, - _coerce_primitive: bool, - _strict_mode: bool, - _struct_mode: StructMode, - _factory_override: Option<&dyn DecoderFactory>, ) -> Result>, ArrowError> { let config = field_metadata .get("test:decoder:config") @@ -270,36 +264,21 @@ struct TypeBasedLenientStructFactory; impl DecoderFactory for TypeBasedLenientStructFactory { fn make_custom_decoder( &self, + ctx: &DecoderContext, data_type: &DataType, is_nullable: bool, _field_metadata: &HashMap, - coerce_primitive: bool, - strict_mode: bool, - 
struct_mode: StructMode, - factory_override: Option<&dyn DecoderFactory>, ) -> Result>, ArrowError> { if !matches!(data_type, DataType::Struct(_)) { return Ok(None); } - // Create standard struct decoder (using make_delegate_decoder to avoid infinite recursion) - let primary = Self::make_delegate_decoder( - factory_override.unwrap_or(self), - data_type, - is_nullable, - coerce_primitive, - strict_mode, - struct_mode, - )?; - - // Create null decoder as fallback - let fallback = Box::new(NullDecoder { - data_type: data_type.clone(), - }); - + // Delegate to a standard struct decoder for objects, with a null decoder as fallback. Ok(Some(Box::new(InterleavedDecoder { - primary, - fallback, + primary: ctx.make_delegate_decoder(data_type, is_nullable)?, + fallback: Box::new(NullDecoder { + data_type: data_type.clone(), + }), filter: |elem| matches!(elem, TapeElement::StartObject(_)), }))) } @@ -394,31 +373,20 @@ struct TypeBasedQuirkyListFactory; impl DecoderFactory for TypeBasedQuirkyListFactory { fn make_custom_decoder( &self, + ctx: &DecoderContext, data_type: &DataType, is_nullable: bool, _field_metadata: &HashMap, - coerce_primitive: bool, - strict_mode: bool, - struct_mode: StructMode, - factory_override: Option<&dyn DecoderFactory>, ) -> Result>, ArrowError> { let field = match data_type { DataType::List(f) if *f.data_type() == DataType::Utf8 => f.clone(), _ => return Ok(None), }; - // Primary: parse string representations - // Fallback: standard list decoder for normal JSON lists + // Intercept and attempt to parse strings; all others fall back to standard list decoder Ok(Some(Box::new(InterleavedDecoder { primary: Box::new(StringToListDecoder { field }), - fallback: Self::make_delegate_decoder( - factory_override.unwrap_or(self), - data_type, - is_nullable, - coerce_primitive, - strict_mode, - struct_mode, - )?, + fallback: ctx.make_delegate_decoder(data_type, is_nullable)?, filter: |elem| matches!(elem, TapeElement::String(_)), }))) } @@ -483,25 +451,16 @@ 
struct ComposedDecoderFactory { impl DecoderFactory for ComposedDecoderFactory { fn make_custom_decoder( &self, + ctx: &DecoderContext, data_type: &DataType, is_nullable: bool, field_metadata: &HashMap, - coerce_primitive: bool, - strict_mode: bool, - struct_mode: StructMode, - factory_override: Option<&dyn DecoderFactory>, ) -> Result>, ArrowError> { // Try each child factory in order until one returns Some or Err for factory in &self.factories { - if let Some(decoder) = factory.make_custom_decoder( - data_type, - is_nullable, - field_metadata, - coerce_primitive, - strict_mode, - struct_mode, - Some(factory_override.unwrap_or(self)), - )? { + if let Some(decoder) = + factory.make_custom_decoder(ctx, data_type, is_nullable, field_metadata)? + { return Ok(Some(decoder)); } } @@ -671,13 +630,10 @@ impl PathBasedDecoderFactory { impl DecoderFactory for PathBasedDecoderFactory { fn make_custom_decoder( &self, + ctx: &DecoderContext, data_type: &DataType, is_nullable: bool, field_metadata: &HashMap, - coerce_primitive: bool, - strict_mode: bool, - struct_mode: StructMode, - factory_override: Option<&dyn DecoderFactory>, ) -> Result>, ArrowError> { // O(1) lookup using temporary DataType variant for identity comparison let key = DataTypeIdentity::DataType(data_type); @@ -686,15 +642,7 @@ impl DecoderFactory for PathBasedDecoderFactory { }; // Delegate to the route-specific factory - factory.make_custom_decoder( - data_type, - is_nullable, - field_metadata, - coerce_primitive, - strict_mode, - struct_mode, - Some(factory_override.unwrap_or(self)), - ) + factory.make_custom_decoder(ctx, data_type, is_nullable, field_metadata) } } diff --git a/parquet-variant-compute/src/decoder.rs b/parquet-variant-compute/src/decoder.rs index cdb9a556d389..0e51cd370f69 100644 --- a/parquet-variant-compute/src/decoder.rs +++ b/parquet-variant-compute/src/decoder.rs @@ -18,15 +18,12 @@ use crate::{VariantArrayBuilder, VariantType}; use arrow_array::{Array, StructArray}; use 
arrow_data::ArrayData; -use arrow_json::{DecoderFactory, StructMode}; +use arrow_json::reader::{ArrayDecoder, DecoderContext, DecoderFactory, Tape, TapeElement}; use arrow_schema::extension::ExtensionType; use arrow_schema::{ArrowError, DataType}; use parquet_variant::{ObjectFieldBuilder, Variant, VariantBuilderExt}; use std::collections::HashMap; -use arrow_json::reader::ArrayDecoder; -use arrow_json::reader::{Tape, TapeElement}; - /// An [`ArrayDecoder`] implementation that decodes JSON values into a Variant array. /// /// This decoder converts JSON tape elements (parsed JSON tokens) into Parquet Variant @@ -73,12 +70,10 @@ pub struct VariantArrayDecoderFactory; impl DecoderFactory for VariantArrayDecoderFactory { fn make_custom_decoder( &self, + _ctx: &DecoderContext, data_type: &DataType, _is_nullable: bool, field_metadata: &HashMap, - _coerce_primitive: bool, - _strict_mode: bool, - _struct_mode: StructMode, ) -> Result>, ArrowError> { // Check if this is a Variant extension type using metadata let result = VariantType::try_from_parts(field_metadata, data_type); From 1994cd53488bfdc9730a0b808e2405fc1f64395f Mon Sep 17 00:00:00 2001 From: Ryan Johnson Date: Mon, 26 Jan 2026 09:16:16 -0800 Subject: [PATCH 14/14] tidy up --- arrow-flight/tests/flight_sql_client_cli.rs | 2 +- arrow-json/src/reader/struct_array.rs | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/arrow-flight/tests/flight_sql_client_cli.rs b/arrow-flight/tests/flight_sql_client_cli.rs index 716c626d6adb..c161caae8ca4 100644 --- a/arrow-flight/tests/flight_sql_client_cli.rs +++ b/arrow-flight/tests/flight_sql_client_cli.rs @@ -48,7 +48,7 @@ const QUERY: &str = "SELECT * FROM table;"; /// Return a Command instance for running the `flight_sql_client` CLI fn flight_sql_client_cmd() -> Command { - Command::new(assert_cmd::cargo::cargo_bin("flight_sql_client")) + Command::new(assert_cmd::cargo::cargo_bin!("flight_sql_client")) } #[tokio::test] diff --git 
a/arrow-json/src/reader/struct_array.rs b/arrow-json/src/reader/struct_array.rs index 968da5db2e9e..f8f64f8fe166 100644 --- a/arrow-json/src/reader/struct_array.rs +++ b/arrow-json/src/reader/struct_array.rs @@ -264,10 +264,10 @@ impl ArrayDecoder for StructArrayDecoder { } fn struct_fields(data_type: &DataType) -> &Fields { - let DataType::Struct(f) = data_type else { - unreachable!() - }; - f + match &data_type { + DataType::Struct(f) => f, + _ => unreachable!(), + } } fn build_field_index(fields: &Fields) -> Option> {