diff --git a/arrow-json/Cargo.toml b/arrow-json/Cargo.toml index 5fcde480eb6d..82a4b8b7a32d 100644 --- a/arrow-json/Cargo.toml +++ b/arrow-json/Cargo.toml @@ -54,6 +54,7 @@ ryu = "1.0" itoa = "1.0" [dev-dependencies] +arrow-select = { workspace = true } flate2 = { version = "1", default-features = false, features = ["rust_backend"] } serde = { version = "1.0", default-features = false, features = ["derive"] } futures = "0.3" diff --git a/arrow-json/src/lib.rs b/arrow-json/src/lib.rs index 1b18e0094708..a53ca74fd618 100644 --- a/arrow-json/src/lib.rs +++ b/arrow-json/src/lib.rs @@ -86,7 +86,7 @@ pub mod reader; pub mod writer; -pub use self::reader::{Reader, ReaderBuilder}; +pub use self::reader::{ArrayDecoder, DecoderFactory, Reader, ReaderBuilder, Tape, TapeElement}; pub use self::writer::{ ArrayWriter, Encoder, EncoderFactory, EncoderOptions, LineDelimitedWriter, Writer, WriterBuilder, diff --git a/arrow-json/src/reader/list_array.rs b/arrow-json/src/reader/list_array.rs index d363b6be9780..3a39c40da1e6 100644 --- a/arrow-json/src/reader/list_array.rs +++ b/arrow-json/src/reader/list_array.rs @@ -42,7 +42,7 @@ impl ListArrayDecoder { DataType::LargeList(f) if O::IS_LARGE => f, _ => unreachable!(), }; - let decoder = ctx.make_decoder(field.data_type(), field.is_nullable())?; + let decoder = ctx.make_decoder(field.data_type(), field.is_nullable(), field.metadata())?; Ok(Self { data_type: data_type.clone(), diff --git a/arrow-json/src/reader/map_array.rs b/arrow-json/src/reader/map_array.rs index ff22b588c510..4c87a694e5c9 100644 --- a/arrow-json/src/reader/map_array.rs +++ b/arrow-json/src/reader/map_array.rs @@ -53,8 +53,16 @@ impl MapArrayDecoder { _ => unreachable!(), }; - let keys = ctx.make_decoder(fields[0].data_type(), fields[0].is_nullable())?; - let values = ctx.make_decoder(fields[1].data_type(), fields[1].is_nullable())?; + let keys = ctx.make_decoder( + fields[0].data_type(), + fields[0].is_nullable(), + fields[0].metadata(), + )?; + let values = 
ctx.make_decoder( + fields[1].data_type(), + fields[1].is_nullable(), + fields[1].metadata(), + )?; Ok(Self { data_type: data_type.clone(), diff --git a/arrow-json/src/reader/mod.rs b/arrow-json/src/reader/mod.rs index 1a63ea75653f..db351df55eb6 100644 --- a/arrow-json/src/reader/mod.rs +++ b/arrow-json/src/reader/mod.rs @@ -138,6 +138,7 @@ use crate::reader::binary_array::{ BinaryArrayDecoder, BinaryViewDecoder, FixedSizeBinaryArrayDecoder, }; use std::borrow::Cow; +use std::collections::HashMap; use std::io::BufRead; use std::sync::Arc; @@ -150,6 +151,7 @@ use arrow_array::{RecordBatch, RecordBatchReader, StructArray, downcast_integer, use arrow_data::ArrayData; use arrow_schema::{ArrowError, DataType, FieldRef, Schema, SchemaRef, TimeUnit}; pub use schema::*; +pub use tape::*; use crate::reader::boolean_array::BooleanArrayDecoder; use crate::reader::decimal_array::DecimalArrayDecoder; @@ -160,7 +162,6 @@ use crate::reader::primitive_array::PrimitiveArrayDecoder; use crate::reader::string_array::StringArrayDecoder; use crate::reader::string_view_array::StringViewArrayDecoder; use crate::reader::struct_array::StructArrayDecoder; -use crate::reader::tape::{Tape, TapeDecoder}; use crate::reader::timestamp_array::TimestampArrayDecoder; mod binary_array; @@ -185,6 +186,7 @@ pub struct ReaderBuilder { strict_mode: bool, is_field: bool, struct_mode: StructMode, + decoder_factory: Option>, schema: SchemaRef, } @@ -206,6 +208,7 @@ impl ReaderBuilder { is_field: false, struct_mode: Default::default(), schema, + decoder_factory: None, } } @@ -219,7 +222,7 @@ impl ReaderBuilder { /// # use arrow_array::cast::AsArray; /// # use arrow_array::types::Int32Type; /// # use arrow_json::ReaderBuilder; - /// # use arrow_schema::{DataType, Field}; + /// # use arrow_schema::{DataType, Field, FieldRef}; /// // Root of JSON schema is a numeric type /// let data = "1\n2\n3\n"; /// let field = Arc::new(Field::new("int", DataType::Int32, true)); @@ -247,6 +250,7 @@ impl ReaderBuilder { 
is_field: true, struct_mode: Default::default(), schema: Arc::new(Schema::new([field.into()])), + decoder_factory: None, } } @@ -286,6 +290,14 @@ impl ReaderBuilder { } } + /// Set an optional hook for customizing decoding behavior. + pub fn with_decoder_factory(self, decoder_factory: Arc) -> Self { + Self { + decoder_factory: Some(decoder_factory), + ..self + } + } + /// Create a [`Reader`] with the provided [`BufRead`] pub fn build(self, reader: R) -> Result, ArrowError> { Ok(Reader { @@ -296,21 +308,23 @@ impl ReaderBuilder { /// Create a [`Decoder`] pub fn build_decoder(self) -> Result { - let (data_type, nullable) = if self.is_field { + let empty_metadata = HashMap::default(); + let (data_type, nullable, metadata) = if self.is_field { let field = &self.schema.fields[0]; let data_type = Cow::Borrowed(field.data_type()); - (data_type, field.is_nullable()) + (data_type, field.is_nullable(), field.metadata()) } else { let data_type = Cow::Owned(DataType::Struct(self.schema.fields.clone())); - (data_type, false) + (data_type, false, &empty_metadata) }; let ctx = DecoderContext { coerce_primitive: self.coerce_primitive, strict_mode: self.strict_mode, struct_mode: self.struct_mode, + decoder_factory: self.decoder_factory.as_deref(), }; - let decoder = ctx.make_decoder(data_type.as_ref(), nullable)?; + let decoder = ctx.make_decoder(data_type.as_ref(), nullable, metadata)?; let num_fields = self.schema.flattened_fields().len(); @@ -374,6 +388,88 @@ impl RecordBatchReader for Reader { } } +/// A trait to create custom decoders for specific data types. +/// +/// This allows overriding the default decoders for specific data types, +/// or adding new decoders for custom data types. 
+/// +/// # Example +/// +/// ``` +/// # use arrow_json::reader::{ArrayDecoder, DecoderContext, DecoderFactory}; +/// # use arrow_json::{ReaderBuilder, TapeElement, Tape}; +/// # use arrow_schema::ArrowError; +/// # use arrow_schema::{DataType, Field, FieldRef, Fields, Schema}; +/// # use arrow_array::cast::AsArray; +/// # use arrow_array::Array; +/// # use arrow_array::builder::StringBuilder; +/// # use arrow_data::ArrayData; +/// # use std::collections::HashMap; +/// # use std::sync::Arc; +/// # +/// struct IncorrectStringAsNullDecoder; +/// +/// impl ArrayDecoder for IncorrectStringAsNullDecoder { +/// fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result { +/// let mut builder = StringBuilder::new(); +/// for p in pos { +/// match tape.get(*p) { +/// TapeElement::String(idx) => builder.append_value(tape.get_string(idx)), +/// _ => builder.append_null(), +/// } +/// } +/// Ok(builder.finish().into_data()) +/// } +/// } +/// +/// #[derive(Debug)] +/// struct IncorrectStringAsNullDecoderFactory; +/// +/// impl DecoderFactory for IncorrectStringAsNullDecoderFactory { +/// fn make_custom_decoder( +/// &self, +/// _ctx: &DecoderContext, +/// data_type: &DataType, +/// _is_nullable: bool, +/// _field_metadata: &HashMap, +/// ) -> Result>, ArrowError> { +/// match data_type { +/// DataType::Utf8 => Ok(Some(Box::new(IncorrectStringAsNullDecoder))), +/// _ => Ok(None), +/// } +/// } +/// } +/// +/// let json = r#" +/// {"a": "a"} +/// {"a": 12} +/// "#; +/// let fields = vec![Field::new("a", DataType::Utf8, true)]; +/// let batch = ReaderBuilder::new(Arc::new(Schema::new(fields))) +/// .with_decoder_factory(Arc::new(IncorrectStringAsNullDecoderFactory)) +/// .build(json.as_bytes()) +/// .unwrap() +/// .next() +/// .unwrap() +/// .unwrap(); +/// +/// let values = batch.column(0).as_string::(); +/// assert_eq!(values.len(), 2); +/// assert_eq!(values.value(0), "a"); +/// assert!(values.is_null(1)); +/// ``` +pub trait DecoderFactory: std::fmt::Debug + Send + Sync { 
+ /// Make a decoder that overrides the default decoder for a specific data type. + /// This can be used to override how e.g. error in decoding are handled. + fn make_custom_decoder( + &self, + ctx: &DecoderContext, + data_type: &DataType, + is_nullable: bool, + field_metadata: &HashMap, + ) -> Result>, ArrowError>; +} + /// A low-level interface for reading JSON data from a byte stream /// /// See [`Reader`] for a higher-level interface for interface with [`BufRead`] @@ -675,25 +771,28 @@ impl Decoder { } } -trait ArrayDecoder: Send { +/// A trait to decode JSON values into arrow arrays +pub trait ArrayDecoder: Send { /// Decode elements from `tape` starting at the indexes contained in `pos` fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result; } -/// Context for decoder creation, containing configuration. +/// Context for decoder creation, containing configuration and factory reference. /// /// This context is passed through the decoder creation process and contains /// all the configuration needed to create decoders recursively. -pub struct DecoderContext { +pub struct DecoderContext<'a> { /// Whether to coerce primitives to strings coerce_primitive: bool, /// Whether to validate struct fields strictly strict_mode: bool, /// How to decode struct fields struct_mode: StructMode, + /// Optional custom decoder factory + decoder_factory: Option<&'a dyn DecoderFactory>, } -impl DecoderContext { +impl DecoderContext<'_> { /// Returns whether to coerce primitive types (e.g., number to string) pub fn coerce_primitive(&self) -> bool { self.coerce_primitive @@ -709,11 +808,71 @@ impl DecoderContext { self.struct_mode } - /// Create a decoder for a type. + /// Create a decoder for a type, allowing the factory to intercept it. /// /// This is the standard way to create child decoders from within a decoder - /// implementation. - fn make_decoder( + /// implementation. The factory (if present) will be given the opportunity + /// to intercept and customize the decoder. 
+ pub fn make_decoder( + &self, + data_type: &DataType, + is_nullable: bool, + field_metadata: &HashMap, + ) -> Result, ArrowError> { + if let Some(factory) = self.decoder_factory { + if let Some(decoder) = + factory.make_custom_decoder(self, data_type, is_nullable, field_metadata)? + { + return Ok(decoder); + } + } + + make_decoder(self, data_type, is_nullable) + } + + /// Create a decoder for a type without allowing the factory to intercept it directly, + /// but still allowing the factory to intercept children of complex types. + /// + /// This is used by custom factories when they want to delegate to the standard + /// decoder implementation while still customizing child decoders. + /// + /// # Example + /// + /// When a factory intercepts a type and wants to wrap or modify the standard decoder, + /// calling `ctx.make_decoder()` would cause infinite recursion (the factory would + /// intercept its own call). This method bypasses the factory check at the current + /// level but still passes the context through so child fields can be customized. 
+ /// + /// ``` + /// # use arrow_json::reader::{ArrayDecoder, DecoderContext, DecoderFactory}; + /// # use arrow_json::TapeElement; + /// # use arrow_schema::{DataType, ArrowError}; + /// # use std::collections::HashMap; + /// # + /// #[derive(Debug)] + /// struct StructWrapperFactory; + /// + /// impl DecoderFactory for StructWrapperFactory { + /// fn make_custom_decoder( + /// &self, + /// ctx: &DecoderContext, + /// data_type: &DataType, + /// is_nullable: bool, + /// _field_metadata: &HashMap, + /// ) -> Result>, ArrowError> { + /// if matches!(data_type, DataType::Struct(_)) { + /// // Get standard struct decoder, bypassing self-interception + /// let delegate = ctx.make_delegate_decoder(data_type, is_nullable)?; + /// + /// // In real usage: wrap delegate with custom behavior + /// Ok(Some(delegate)) + /// } else { + /// Ok(None) + /// } + /// } + /// } + /// ``` + pub fn make_delegate_decoder( &self, data_type: &DataType, is_nullable: bool, @@ -728,6 +887,10 @@ macro_rules! primitive_decoder { }; } +/// Private workhorse function for decoder creation. +/// +/// Constructs the appropriate decoder for the given data type without +/// checking the factory. All decoder construction logic lives here. 
fn make_decoder( ctx: &DecoderContext, data_type: &DataType, @@ -2857,4 +3020,67 @@ mod tests { "Json error: whilst decoding field 'a': failed to parse \"a\" as Int32".to_owned() ); } + + #[test] + fn test_decoder_factory() { + use arrow_array::builder; + + struct AlwaysNullStringArrayDecoder; + + impl ArrayDecoder for AlwaysNullStringArrayDecoder { + fn decode(&mut self, _tape: &Tape<'_>, pos: &[u32]) -> Result { + let mut builder = builder::StringBuilder::new(); + for _ in pos { + builder.append_null(); + } + Ok(builder.finish().into_data()) + } + } + + #[derive(Debug)] + struct AlwaysNullStringArrayDecoderFactory; + + impl DecoderFactory for AlwaysNullStringArrayDecoderFactory { + fn make_custom_decoder( + &self, + _ctx: &crate::reader::DecoderContext, + data_type: &DataType, + _is_nullable: bool, + _field_metadata: &HashMap, + ) -> Result>, ArrowError> { + match data_type { + DataType::Utf8 => Ok(Some(Box::new(AlwaysNullStringArrayDecoder))), + _ => Ok(None), + } + } + } + + let buf = r#" + {"a": "1", "b": 2} + {"a": "hello", "b": 23} + "#; + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, true), + Field::new("b", DataType::Int32, true), + ])); + + let batches = ReaderBuilder::new(schema.clone()) + .with_batch_size(2) + .with_decoder_factory(Arc::new(AlwaysNullStringArrayDecoderFactory)) + .build(Cursor::new(buf.as_bytes())) + .unwrap() + .collect::, _>>() + .unwrap(); + + assert_eq!(batches.len(), 1); + + let col1 = batches[0].column(0).as_string::(); + assert_eq!(col1.null_count(), 2); + assert!(col1.is_null(0)); + assert!(col1.is_null(1)); + + let col2 = batches[0].column(1).as_primitive::(); + assert_eq!(col2.value(0), 2); + assert_eq!(col2.value(1), 23); + } } diff --git a/arrow-json/src/reader/struct_array.rs b/arrow-json/src/reader/struct_array.rs index 9191afb8e639..18c5eea9cec6 100644 --- a/arrow-json/src/reader/struct_array.rs +++ b/arrow-json/src/reader/struct_array.rs @@ -92,7 +92,7 @@ impl StructArrayDecoder { // 
StructArrayDecoder::decode verifies that if the child is not nullable // it doesn't contain any nulls not masked by its parent let nullable = f.is_nullable() || is_nullable; - ctx.make_decoder(f.data_type(), nullable) + ctx.make_decoder(f.data_type(), nullable, f.metadata()) }) .collect::, ArrowError>>()?; diff --git a/arrow-json/src/reader/tape.rs b/arrow-json/src/reader/tape.rs index 89ee3f778765..fcab173ef110 100644 --- a/arrow-json/src/reader/tape.rs +++ b/arrow-json/src/reader/tape.rs @@ -338,6 +338,7 @@ impl TapeDecoder { } } + /// Decodes JSON data from the provided buffer, returning the number of bytes consumed pub fn decode(&mut self, buf: &[u8]) -> Result { let mut iter = BufIter::new(buf); diff --git a/arrow-json/tests/custom_decoder_tests.rs b/arrow-json/tests/custom_decoder_tests.rs new file mode 100644 index 000000000000..df67b91ff4bc --- /dev/null +++ b/arrow-json/tests/custom_decoder_tests.rs @@ -0,0 +1,836 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Integration tests for custom decoder functionality +//! +//! This test suite demonstrates various patterns for customizing JSON decoding: +//! 1. Type-based routing - customize all fields of a given type +//! 2. 
Annotation-based routing - customize specific fields marked with metadata +//! 3. Type-specific behavior - custom parsing logic for specific types +//! 4. Composition - combining multiple custom factories +//! 5. Delegation with interleaving - wrapping standard decoders when direct access not possible +//! 6. Path-based routing - customize specific paths in the schema tree + +use arrow_array::Array as _; +use arrow_array::builder::StringBuilder; +use arrow_array::cast::AsArray; +use arrow_data::ArrayData; +use arrow_json::reader::{ArrayDecoder, DecoderContext, DecoderFactory}; +use arrow_json::{ReaderBuilder, Tape, TapeElement}; +use arrow_schema::{ArrowError, DataType, Field, FieldRef, Fields, Schema}; +use std::collections::HashMap; +use std::hash::{Hash, Hasher}; +use std::io::Cursor; +use std::sync::Arc; + +// ============================================================================ +// Test 1: Type-based lenient string decoder +// ============================================================================ + +/// A string decoder that converts type mismatches to NULL instead of erroring +struct LenientStringDecoder; + +impl ArrayDecoder for LenientStringDecoder { + fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result { + let mut builder = StringBuilder::new(); + for p in pos { + match tape.get(*p) { + TapeElement::String(idx) => builder.append_value(tape.get_string(idx)), + _ => builder.append_null(), + } + } + Ok(builder.finish().into_data()) + } +} + +/// A factory that applies the LenientStringDecoder to ALL Utf8 fields (type-based routing) +#[derive(Debug)] +struct TypeBasedLenientStringFactory; + +impl DecoderFactory for TypeBasedLenientStringFactory { + fn make_custom_decoder( + &self, + _ctx: &DecoderContext, + data_type: &DataType, + _is_nullable: bool, + _field_metadata: &HashMap, + ) -> Result>, ArrowError> { + match data_type { + DataType::Utf8 => Ok(Some(Box::new(LenientStringDecoder))), + _ => Ok(None), + } + } +} + +#[test] +fn 
test_type_based_lenient_strings() { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + Field::new("email", DataType::Utf8, true), + ])); + + // JSON with type mismatches in BOTH string fields + let json = r#" + {"id": 1, "name": "Alice", "email": "alice@example.com"} + {"id": 2, "name": 42, "email": "bob@example.com"} + {"id": 3, "name": "Charlie", "email": true} + "#; + + let reader = ReaderBuilder::new(schema) + .with_decoder_factory(Arc::new(TypeBasedLenientStringFactory)) + .build(Cursor::new(json.as_bytes())) + .unwrap(); + + let batch = reader.into_iter().next().unwrap().unwrap(); + + let names = batch.column(1).as_string::(); + assert_eq!(names.value(0), "Alice"); + assert!(names.is_null(1)); // 42 -> NULL + assert_eq!(names.value(2), "Charlie"); + + let emails = batch.column(2).as_string::(); + assert_eq!(emails.value(0), "alice@example.com"); + assert_eq!(emails.value(1), "bob@example.com"); + assert!(emails.is_null(2)); // true -> NULL +} + +// ============================================================================ +// Test 2: Annotation-based lenient string decoder +// ============================================================================ + +/// A factory that applies LenientStringDecoder only to annotated Utf8 fields +#[derive(Debug)] +struct AnnotatedLenientStringFactory; + +impl DecoderFactory for AnnotatedLenientStringFactory { + fn make_custom_decoder( + &self, + _ctx: &DecoderContext, + data_type: &DataType, + _is_nullable: bool, + field_metadata: &HashMap, + ) -> Result>, ArrowError> { + let config = field_metadata + .get("test:decoder:config") + .map(|s| s.as_str()); + match data_type { + DataType::Utf8 if config == Some("lenient") => Ok(Some(Box::new(LenientStringDecoder))), + _ => Ok(None), + } + } +} + +#[test] +fn test_annotation_based_lenient_strings() { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + // ONLY 
this field is marked lenient + Field::new("name", DataType::Utf8, true).with_metadata(HashMap::from([( + "test:decoder:config".to_string(), + "lenient".to_string(), + )])), + // This field is NOT marked, should use standard decoder + Field::new("email", DataType::Utf8, false), + ])); + + // JSON with type mismatches in BOTH string fields + let json_valid = r#" + {"id": 1, "name": "Alice", "email": "alice@example.com"} + {"id": 2, "name": 42, "email": "bob@example.com"} + "#; + + let reader = ReaderBuilder::new(schema.clone()) + .with_decoder_factory(Arc::new(AnnotatedLenientStringFactory)) + .build(Cursor::new(json_valid.as_bytes())) + .unwrap(); + + let batch = reader.into_iter().next().unwrap().unwrap(); + + let names = batch.column(1).as_string::(); + assert_eq!(names.value(0), "Alice"); + assert!(names.is_null(1)); // 42 -> NULL (lenient behavior) + + let emails = batch.column(2).as_string::(); + assert_eq!(emails.value(0), "alice@example.com"); + assert_eq!(emails.value(1), "bob@example.com"); + + // Negative test: email field without annotation should error on type mismatch + let json_invalid = r#" + {"id": 1, "name": "Alice", "email": true} + "#; + + let reader = ReaderBuilder::new(schema) + .with_decoder_factory(Arc::new(AnnotatedLenientStringFactory)) + .build(Cursor::new(json_invalid.as_bytes())) + .unwrap(); + + let result = reader.into_iter().next().unwrap(); + assert!(result.is_err()); // email field errors on type mismatch +} + +// ============================================================================ +// Shared helpers for interleaved decoding pattern +// ============================================================================ + +/// A general-purpose interleaved decoder that routes positions to different decoders +/// based on a filter predicate, then interleaves the results back to original order. 
+/// +/// This pattern is useful when you want to customize behavior but need to delegate +/// to standard decoders (which you can't directly access the builders of). +/// +/// Note: This example uses `Fn(TapeElement) -> bool` for simplicity. A production +/// implementation might use `Fn(&Tape, u32) -> bool` to support filters that need +/// to examine list/object contents before routing (e.g., routing based on list length +/// or presence of discriminator fields). Downside is more complexity in simple cases. +struct InterleavedDecoder { + primary: Box, + fallback: Box, + filter: F, +} + +impl bool + Send> ArrayDecoder for InterleavedDecoder { + fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result { + use arrow_select::interleave::interleave; + + // Partition positions based on filter + let mut primary_pos = Vec::new(); + let mut fallback_pos = Vec::new(); + let mut indices = Vec::with_capacity(pos.len()); + + for &p in pos { + if (self.filter)(tape.get(p)) { + indices.push((0, primary_pos.len())); + primary_pos.push(p); + } else { + indices.push((1, fallback_pos.len())); + fallback_pos.push(p); + } + } + + // Decode both parts + let primary = self.primary.decode(tape, &primary_pos)?; + let fallback = self.fallback.decode(tape, &fallback_pos)?; + + // Convert to arrays for interleaving + let primary = arrow_array::make_array(primary); + let fallback = arrow_array::make_array(fallback); + + // Interleave back to original order + let result = interleave(&[primary.as_ref(), fallback.as_ref()], &indices)?; + Ok(result.into_data()) + } +} + +/// A decoder that always produces NULL values (useful as fallback in InterleavedDecoder) +struct NullDecoder { + data_type: DataType, +} + +impl ArrayDecoder for NullDecoder { + fn decode(&mut self, _tape: &Tape<'_>, pos: &[u32]) -> Result { + Ok(arrow_array::new_null_array(&self.data_type, pos.len()).into_data()) + } +} + +// ============================================================================ +// Test 3: Lenient 
struct decoder (introduces InterleavedDecoder pattern) +// ============================================================================ + +/// A factory that makes ALL Struct fields lenient (type-based routing) +/// +/// This demonstrates the InterleavedDecoder pattern: we want to customize struct +/// handling but can't directly access the StructArrayDecoder's internal builder +/// and -- unlike string decoding -- the logic is too complex to replicate easily. +/// +/// Solution: partition positions and delegate to standard decoder + null decoder. +#[derive(Debug)] +struct TypeBasedLenientStructFactory; + +impl DecoderFactory for TypeBasedLenientStructFactory { + fn make_custom_decoder( + &self, + ctx: &DecoderContext, + data_type: &DataType, + is_nullable: bool, + _field_metadata: &HashMap, + ) -> Result>, ArrowError> { + if !matches!(data_type, DataType::Struct(_)) { + return Ok(None); + } + + // Delegate to a standard struct decoder for objects, with a null decoder as fallback. + Ok(Some(Box::new(InterleavedDecoder { + primary: ctx.make_delegate_decoder(data_type, is_nullable)?, + fallback: Box::new(NullDecoder { + data_type: data_type.clone(), + }), + filter: |elem| matches!(elem, TapeElement::StartObject(_)), + }))) + } +} + +#[test] +fn test_type_based_lenient_structs() { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new( + "person", + DataType::Struct(vec![Field::new("name", DataType::Utf8, false)].into()), + true, + ), + ])); + + let json = r#" + {"id": 1, "person": {"name": "Alice"}} + {"id": 2, "person": "not a struct"} + {"id": 3, "person": 42} + {"id": 4, "person": null} + "#; + + let reader = ReaderBuilder::new(schema) + .with_decoder_factory(Arc::new(TypeBasedLenientStructFactory)) + .build(Cursor::new(json.as_bytes())) + .unwrap(); + + let batch = reader.into_iter().next().unwrap().unwrap(); + let person = batch.column(1).as_struct(); + assert!(!person.is_null(0)); // Valid struct + 
assert!(person.is_null(1)); // "not a struct" -> NULL + assert!(person.is_null(2)); // 42 -> NULL + assert!(person.is_null(3)); // null -> NULL + + // Verify first row decoded correctly + let names = person.column(0).as_string::(); + assert_eq!(names.value(0), "Alice"); +} + +// ============================================================================ +// Test 4: Quirky string list decoder (reinforces InterleavedDecoder pattern) +// ============================================================================ + +/// Parse a string like "[a, b, c]" into an iterator of strings, e.g. "a", "b", "c" +fn parse_list_string(s: &str) -> Result + '_, ArrowError> { + let trimmed = s.trim(); + if !trimmed.starts_with('[') || !trimmed.ends_with(']') { + return Err(ArrowError::JsonError(format!( + "Failed to parse list string: {}", + s + ))); + } + let inner = &trimmed[1..trimmed.len() - 1]; + Ok(inner.split(',').map(|s| s.trim())) +} + +/// A decoder that ONLY handles string representations like "[a, b, c]" and parses them as lists. +/// Designed to be used with InterleavedDecoder (standard list decoder handles normal lists). +struct StringToListDecoder { + field: FieldRef, +} + +impl ArrayDecoder for StringToListDecoder { + fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result { + // Use with_field() to ensure the builder respects the schema's field (including nullability) + let mut builder = arrow_array::builder::ListBuilder::new(StringBuilder::new()) + .with_field(self.field.clone()); + + for p in pos { + let TapeElement::String(s) = tape.get(*p) else { + unreachable!("InterleavedDecoder filter should only route String elements here"); + }; + + // Parse string representation like "[x, y, z]" + for item in parse_list_string(tape.get_string(s))? 
{ + builder.values().append_value(item); + } + builder.append(true); + } + + Ok(builder.finish().into_data()) + } +} + +/// A factory that makes ALL List fields quirky (type-based routing) +/// +/// Uses InterleavedDecoder to combine string parsing with standard list decoding. +#[derive(Debug)] +struct TypeBasedQuirkyListFactory; + +impl DecoderFactory for TypeBasedQuirkyListFactory { + fn make_custom_decoder( + &self, + ctx: &DecoderContext, + data_type: &DataType, + is_nullable: bool, + _field_metadata: &HashMap, + ) -> Result>, ArrowError> { + let field = match data_type { + DataType::List(f) if *f.data_type() == DataType::Utf8 => f.clone(), + _ => return Ok(None), + }; + + // Intercept and attempt to parse strings; all others fall back to standard list decoder + Ok(Some(Box::new(InterleavedDecoder { + primary: Box::new(StringToListDecoder { field }), + fallback: ctx.make_delegate_decoder(data_type, is_nullable)?, + filter: |elem| matches!(elem, TapeElement::String(_)), + }))) + } +} + +#[test] +fn test_type_based_quirky_lists() { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new( + "tags", + DataType::List(Arc::new(Field::new("item", DataType::Utf8, false))), + true, + ), + ])); + + let json = r#" +{"id": 1, "tags": ["a", "b", "c"]} +{"id": 2, "tags": "[x, y, z]"} +{"id": 3, "tags": null} +"#; + + let reader = ReaderBuilder::new(schema) + .with_decoder_factory(Arc::new(TypeBasedQuirkyListFactory)) + .build(Cursor::new(json.as_bytes())) + .unwrap(); + + let batch = reader.into_iter().next().unwrap().unwrap(); + + let tags = batch.column(1).as_list::(); + + // First row: normal JSON list + let row0 = tags.value(0); + let row0 = row0.as_string::(); + assert_eq!(row0.len(), 3); + assert_eq!(row0.value(0), "a"); + assert_eq!(row0.value(1), "b"); + assert_eq!(row0.value(2), "c"); + + // Second row: parsed from string representation + let row1 = tags.value(1); + let row1 = row1.as_string::(); + assert_eq!(row1.len(), 
3); + assert_eq!(row1.value(0), "x"); + assert_eq!(row1.value(1), "y"); + assert_eq!(row1.value(2), "z"); + + // Third row: null + assert!(tags.is_null(2)); +} + +// ============================================================================ +// Test 5: Composition - combining multiple type-based factories +// ============================================================================ + +/// A factory that tries multiple child factories in sequence +#[derive(Debug)] +struct ComposedDecoderFactory { + factories: Vec>, +} + +impl DecoderFactory for ComposedDecoderFactory { + fn make_custom_decoder( + &self, + ctx: &DecoderContext, + data_type: &DataType, + is_nullable: bool, + field_metadata: &HashMap, + ) -> Result>, ArrowError> { + // Try each child factory in order until one returns Some or Err + for factory in &self.factories { + if let Some(decoder) = + factory.make_custom_decoder(ctx, data_type, is_nullable, field_metadata)? + { + return Ok(Some(decoder)); + } + } + Ok(None) + } +} + +#[test] +fn test_composed_type_based_factories() { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + Field::new( + "tags", + DataType::List(Arc::new(Field::new("item", DataType::Utf8, false))), + true, + ), + ])); + + let json = r#" +{"id": 1, "name": "Alice", "tags": ["a", "b"]} +{"id": 2, "name": 42, "tags": "[x, y]"} +{"id": 3, "name": "Bob", "tags": null} +"#; + + // Compose two type-based factories + let factories = vec![ + Arc::new(TypeBasedLenientStringFactory) as _, + Arc::new(TypeBasedQuirkyListFactory) as _, + ]; + let factory = Arc::new(ComposedDecoderFactory { factories }); + + let reader = ReaderBuilder::new(schema) + .with_decoder_factory(factory) + .build(Cursor::new(json.as_bytes())) + .unwrap(); + + let batch = reader.into_iter().next().unwrap().unwrap(); + + // Verify lenient string behavior + let names = batch.column(1).as_string::(); + assert_eq!(names.value(0), "Alice"); + 
assert!(names.is_null(1)); // 42 -> NULL + assert_eq!(names.value(2), "Bob"); + + // Verify quirky list behavior + let tags = batch.column(2).as_list::(); + + // Row 0: normal JSON list + let row0 = tags.value(0); + let row0 = row0.as_string::(); + assert_eq!(row0.len(), 2); + assert_eq!(row0.value(0), "a"); + assert_eq!(row0.value(1), "b"); + + // Row 1: parsed from string representation + let row1 = tags.value(1); + let row1 = row1.as_string::(); + assert_eq!(row1.len(), 2); + assert_eq!(row1.value(0), "x"); + assert_eq!(row1.value(1), "y"); + + // Row 2: null + assert!(tags.is_null(2)); +} + +// ============================================================================ +// Test 6: Path-based routing (pointer-identity-based) +// ============================================================================ + +/// Identity wrapper for DataType that uses pointer equality for HashMap lookup. +/// +/// Supports two modes: +/// - `FieldRef` variant: Stores ownership of a Field (for HashMap storage) +/// - `DataType` variant: Borrows a DataType temporarily (for HashMap lookup) +/// +/// Both variants compare by pointer identity of the DataType they reference. +/// +/// Safety: Pointer-identity comparison relies on DataType stability guarantees: +/// - We store the owning FieldRef (Arc) which keeps the Field alive +/// - We never call any potentially-mutating methods such as `Arc::get_mut` or `Arc::make_mut` +/// - We never share a reference to the FieldRef that could allow others to mutate it +/// - Our FieldRef ensures that anyone else who might attempt potentially-mutating operations +/// of the same Field through their own FieldRef will fail because `!Arc::is_unique()` +/// - The &DataType returned by `Field::data_type` is stable -- no interior mutability +/// such as `Mutex>` that could move it to a new memory location +/// +/// NOTE: We never dereference the raw pointer values used for comparison. 
A violation +/// of the above would only produce incorrect HashMap lookups (false positives/negatives). +#[derive(Debug)] +enum DataTypeIdentity<'a> { + FieldRef(FieldRef), + DataType(&'a DataType), +} + +impl<'a> DataTypeIdentity<'a> { + /// Extract the raw DataType pointer for identity comparison. + fn as_ptr(&self) -> *const DataType { + match self { + DataTypeIdentity::FieldRef(f) => f.data_type(), + DataTypeIdentity::DataType(dt) => *dt, + } + } +} + +impl<'a> Hash for DataTypeIdentity<'a> { + fn hash(&self, state: &mut H) { + self.as_ptr().hash(state); + } +} + +impl<'a> PartialEq for DataTypeIdentity<'a> { + fn eq(&self, other: &Self) -> bool { + self.as_ptr() == other.as_ptr() + } +} + +impl<'a> Eq for DataTypeIdentity<'a> {} + +/// A factory that routes to custom decoders based on specific field paths. +/// +/// This allows fine-grained control: customize specific fields by name without +/// polluting the schema with metadata or affecting all fields of a given type. +#[derive(Debug)] +struct PathBasedDecoderFactory { + // Maps DataTypeIdentity::FieldRef to factory + routes: HashMap, Arc>, +} + +impl PathBasedDecoderFactory { + /// Create a new path-based factory by mapping field paths to factories. + /// + /// Takes a reference to the schema fields and a map of paths to factories. + /// Paths can be nested using dot notation: "field", "struct.nested", "struct.deep.nested" + fn new(fields: &Fields, path_routes: HashMap<&str, Arc>) -> Self { + // Walk the fields and associate DataTypeIdentity::FieldRef with factory for O(1) lookup + let mut routes = HashMap::new(); + for (path, factory) in path_routes { + let parts: Vec<&str> = path.split('.').collect(); + if let Some(field) = Self::find_field_by_path(fields, &parts) { + routes.insert(DataTypeIdentity::FieldRef(field), factory); + } + } + + Self { routes } + } + + /// Recursively find a Field by following a path of field names. 
+ fn find_field_by_path(fields: &Fields, path: &[&str]) -> Option { + let (first, rest) = path.split_first()?; + let field = fields.iter().find(|f| f.name() == *first)?; + + if rest.is_empty() { + // End of path - return this field + return Some(field.clone()); + } + + // Path continues - attempt to recurse into a nested struct + let DataType::Struct(children) = field.data_type() else { + return None; + }; + + Self::find_field_by_path(children, rest) + } +} + +impl DecoderFactory for PathBasedDecoderFactory { + fn make_custom_decoder( + &self, + ctx: &DecoderContext, + data_type: &DataType, + is_nullable: bool, + field_metadata: &HashMap, + ) -> Result>, ArrowError> { + // O(1) lookup using temporary DataType variant for identity comparison + let key = DataTypeIdentity::DataType(data_type); + let Some(factory) = self.routes.get(&key) else { + return Ok(None); + }; + + // Delegate to the route-specific factory + factory.make_custom_decoder(ctx, data_type, is_nullable, field_metadata) + } +} + +#[test] +fn test_path_based_routing() { + // Create schema with both flat and nested String fields + let metadata_fields = Fields::from(vec![ + Field::new("source", DataType::Utf8, true), + Field::new("comment", DataType::Utf8, true), // Nested path: metadata.comment + ]); + + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + Field::new("email", DataType::Utf8, true), + Field::new("metadata", DataType::Struct(metadata_fields.clone()), true), + ])); + + // Create a lenient string factory (will be applied to specific paths only) + let lenient_factory = Arc::new(TypeBasedLenientStringFactory); + + // Build routes: ONLY the nested "metadata.comment" field gets lenient handling + // Demonstrates nested path routing with dot notation + let mut path_routes = HashMap::new(); + path_routes.insert( + "metadata.comment", + lenient_factory.clone() as Arc, + ); + + let factory = 
Arc::new(PathBasedDecoderFactory::new(schema.fields(), path_routes)); + + // JSON with type mismatches in multiple string fields + let json = r#" +{"id": 1, "name": "Alice", "email": "alice@example.com", "metadata": {"source": "web", "comment": "Good"}} +{"id": 2, "name": 42, "email": "bob@example.com", "metadata": {"source": "api", "comment": 100}} +{"id": 3, "name": "Charlie", "email": 999, "metadata": {"source": "mobile", "comment": "Excellent"}} +"#; + + let mut reader = ReaderBuilder::new(schema.clone()) + .with_decoder_factory(factory) + .build(Cursor::new(json)) + .unwrap(); + + // The decode should FAIL because "name" and "email" don't have lenient handling + // Only "metadata.comment" is lenient, but "name" has a type mismatch in row 2 + let result = reader.next().unwrap(); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("name") || err.contains("expected string")); + + // Now let's test with data that only has issues in the lenient nested field + let json_lenient_only = r#" +{"id": 1, "name": "Alice", "email": "alice@example.com", "metadata": {"source": "web", "comment": "Good"}} +{"id": 2, "name": "Bob", "email": "bob@example.com", "metadata": {"source": "api", "comment": 100}} +{"id": 3, "name": "Charlie", "email": "charlie@example.com", "metadata": {"source": "mobile", "comment": true}} +"#; + + // Rebuild routes for the second test + let mut path_routes2 = HashMap::new(); + path_routes2.insert( + "metadata.comment", + lenient_factory as Arc, + ); + + let mut reader = ReaderBuilder::new(schema.clone()) + .with_decoder_factory(Arc::new(PathBasedDecoderFactory::new( + schema.fields(), + path_routes2, + ))) + .build(Cursor::new(json_lenient_only)) + .unwrap(); + + let batch = reader.next().unwrap().unwrap(); + + // Verify all fields + assert_eq!(batch.num_rows(), 3); + assert_eq!(batch.num_columns(), 4); + + // ID column: all valid + let ids = batch + .column(0) + .as_primitive::(); + assert_eq!(ids.value(0), 
1); + assert_eq!(ids.value(1), 2); + assert_eq!(ids.value(2), 3); + + // Name column: all valid (no type mismatches) + let names = batch.column(1).as_string::(); + assert_eq!(names.value(0), "Alice"); + assert_eq!(names.value(1), "Bob"); + assert_eq!(names.value(2), "Charlie"); + + // Email column: all valid (no type mismatches) + let emails = batch.column(2).as_string::(); + assert_eq!(emails.value(0), "alice@example.com"); + assert_eq!(emails.value(1), "bob@example.com"); + assert_eq!(emails.value(2), "charlie@example.com"); + + // Metadata struct column + let metadata = batch.column(3).as_struct(); + + // metadata.source: all valid (no lenient handling) + let sources = metadata.column(0).as_string::(); + assert_eq!(sources.value(0), "web"); + assert_eq!(sources.value(1), "api"); + assert_eq!(sources.value(2), "mobile"); + + // metadata.comment: LENIENT - non-strings become NULL + let comments = metadata.column(1).as_string::(); + assert_eq!(comments.value(0), "Good"); + assert!(comments.is_null(1)); // 100 -> NULL + assert!(comments.is_null(2)); // true -> NULL +} + +// ============================================================================ +// Test 7: Recursive factory propagation +// ============================================================================ + +#[test] +fn test_recursive_factory_propagation() { + // Schema with nested struct containing string fields + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new( + "person", + DataType::Struct( + vec![ + Field::new("name", DataType::Utf8, true), // Make nullable + Field::new("email", DataType::Utf8, true), + ] + .into(), + ), + true, + ), + ])); + + let json = r#" +{"id": 1, "person": {"name": "Alice", "email": "alice@example.com"}} +{"id": 2, "person": 42} +{"id": 3, "person": {"name": 123, "email": "charlie@example.com"}} +{"id": 4, "person": {"name": "Dave", "email": true}} +"#; + + // Compose lenient struct + lenient string factories + // This tests 
that the factory propagates through struct decoder creation + let factories = vec![ + Arc::new(TypeBasedLenientStructFactory) as Arc, + Arc::new(TypeBasedLenientStringFactory) as Arc, + ]; + let factory = Arc::new(ComposedDecoderFactory { factories }); + + let mut reader = ReaderBuilder::new(schema) + .with_decoder_factory(factory) + .build(Cursor::new(json)) + .unwrap(); + + let batch = reader.next().unwrap().unwrap(); + + // ID column: all valid + let ids = batch + .column(0) + .as_primitive::(); + assert_eq!(ids.value(0), 1); + assert_eq!(ids.value(1), 2); + assert_eq!(ids.value(2), 3); + assert_eq!(ids.value(3), 4); + + // Person struct column + let person = batch.column(1).as_struct(); + + // Row 0: Valid struct with valid strings + assert!(!person.is_null(0)); + let names = person.column(0).as_string::(); + let emails = person.column(1).as_string::(); + assert_eq!(names.value(0), "Alice"); + assert_eq!(emails.value(0), "alice@example.com"); + + // Row 1: Not a struct (42) -> NULL from struct factory + assert!(person.is_null(1)); + + // Row 2: Valid struct but name has type mismatch (123) + // This tests that the string factory was invoked for the nested field! 
+ assert!(!person.is_null(2)); + assert!(names.is_null(2)); // 123 -> NULL from string factory + assert_eq!(emails.value(2), "charlie@example.com"); + + // Row 3: Valid struct but email has type mismatch (true) + assert!(!person.is_null(3)); + assert_eq!(names.value(3), "Dave"); + assert!(emails.is_null(3)); // true -> NULL from string factory +} diff --git a/parquet-variant-compute/Cargo.toml b/parquet-variant-compute/Cargo.toml index 85d66a9cf706..8f15f99aa58d 100644 --- a/parquet-variant-compute/Cargo.toml +++ b/parquet-variant-compute/Cargo.toml @@ -30,9 +30,13 @@ rust-version = { workspace = true } [dependencies] arrow = { workspace = true , features = ["canonical_extension_types"]} +arrow-array = { workspace = true } +arrow-data = { workspace = true } +arrow-json = { workspace = true } arrow-schema = { workspace = true } half = { version = "2.1", default-features = false } indexmap = "2.10.0" +lexical-core = { version = "1.0", default-features = false} parquet-variant = { workspace = true } parquet-variant-json = { workspace = true } chrono = { workspace = true } diff --git a/parquet-variant-compute/src/decoder.rs b/parquet-variant-compute/src/decoder.rs new file mode 100644 index 000000000000..76b5bb3245cb --- /dev/null +++ b/parquet-variant-compute/src/decoder.rs @@ -0,0 +1,287 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::{VariantArrayBuilder, VariantType};
+use arrow_array::{Array, StructArray};
+use arrow_data::ArrayData;
+use arrow_json::reader::{ArrayDecoder, DecoderContext, DecoderFactory, Tape, TapeElement};
+use arrow_schema::extension::ExtensionType;
+use arrow_schema::{ArrowError, DataType};
+use parquet_variant::{ObjectFieldBuilder, Variant, VariantBuilderExt};
+use std::collections::HashMap;
+
+/// An [`ArrayDecoder`] implementation that decodes JSON values into a Variant array.
+///
+/// This decoder converts JSON tape elements (parsed JSON tokens) into Parquet Variant
+/// format, preserving the full structure of arbitrary JSON including nested objects,
+/// arrays, and primitive types.
+///
+/// This decoder is typically used indirectly via [`VariantArrayDecoderFactory`] when
+/// reading JSON data into Variant columns.
+#[derive(Default)]
+pub struct VariantArrayDecoder;
+
+impl ArrayDecoder for VariantArrayDecoder {
+    fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result<ArrayData, ArrowError> {
+        let mut array_builder = VariantArrayBuilder::new(pos.len());
+        for p in pos {
+            variant_from_tape_element(&mut array_builder, *p, tape)?;
+        }
+        let variant_struct_array = StructArray::from(array_builder.build());
+        Ok(variant_struct_array.into_data())
+    }
+}
+
+/// A [`DecoderFactory`] that integrates with the Arrow JSON reader to automatically
+/// decode JSON values into Variant arrays when the target field is registered as a
+/// [`VariantType`] extension type.
+///
+/// # Example
+///
+/// ```ignore
+/// use arrow_json::reader::ReaderBuilder;
+/// use arrow_json::StructMode;
+/// use std::sync::Arc;
+///
+/// let builder = ReaderBuilder::new(Arc::new(schema));
+/// let reader = builder
+///     .with_struct_mode(StructMode::ObjectOnly)
+///     .with_decoder_factory(Arc::new(VariantArrayDecoderFactory))
+///     .build(json_input)?;
+/// ```
+#[derive(Default, Debug)]
+#[allow(unused)]
+pub struct VariantArrayDecoderFactory;
+
+impl DecoderFactory for VariantArrayDecoderFactory {
+    fn make_custom_decoder(
+        &self,
+        _ctx: &DecoderContext,
+        data_type: &DataType,
+        _is_nullable: bool,
+        field_metadata: &HashMap<String, String>,
+    ) -> Result<Option<Box<dyn ArrayDecoder>>, ArrowError> {
+        // Check if this is a Variant extension type using metadata
+        let result = VariantType::try_new_from_field_metadata(data_type, field_metadata);
+        Ok(result.ok().map(|_| Box::new(VariantArrayDecoder) as _))
+    }
+}
+
+fn variant_from_tape_element(
+    builder: &mut impl VariantBuilderExt,
+    mut p: u32,
+    tape: &Tape,
+) -> Result<u32, ArrowError> {
+    match tape.get(p) {
+        TapeElement::StartObject(end_idx) => {
+            let mut object_builder = builder.try_new_object()?;
+            p += 1;
+            while p < end_idx {
+                // Read field name
+                let field_name = match tape.get(p) {
+                    TapeElement::String(s) => tape.get_string(s),
+                    _ => return Err(tape.error(p, "field name")),
+                };
+
+                let mut field_builder = ObjectFieldBuilder::new(field_name, &mut object_builder);
+                p = tape.next(p, "field value")?;
+                p = variant_from_tape_element(&mut field_builder, p, tape)?;
+            }
+            object_builder.finish();
+        }
+        TapeElement::EndObject(_u32) => {
+            return Err(ArrowError::JsonError(
+                "unexpected end of object".to_string(),
+            ));
+        }
+        TapeElement::StartList(end_idx) => {
+            let mut list_builder = builder.try_new_list()?;
+            p += 1;
+            while p < end_idx {
+                p = variant_from_tape_element(&mut list_builder, p, tape)?;
+            }
+            list_builder.finish();
+        }
+        TapeElement::EndList(_u32) => {
+            return Err(ArrowError::JsonError("unexpected end of list".to_string()));
+        }
+        
TapeElement::String(idx) => builder.append_value(tape.get_string(idx)),
+        TapeElement::Number(idx) => {
+            let s = tape.get_string(idx);
+            builder.append_value(parse_number(s)?)
+        }
+        TapeElement::I64(i) => {
+            return Err(ArrowError::JsonError(format!(
+                "I64 tape element not supported: {i}"
+            )));
+        }
+        TapeElement::I32(i) => {
+            return Err(ArrowError::JsonError(format!(
+                "I32 tape element not supported: {i}"
+            )));
+        }
+        TapeElement::F64(f) => {
+            return Err(ArrowError::JsonError(format!(
+                "F64 tape element not supported: {f}"
+            )));
+        }
+        TapeElement::F32(f) => {
+            return Err(ArrowError::JsonError(format!(
+                "F32 tape element not supported: {f}"
+            )));
+        }
+        TapeElement::True => builder.append_value(true),
+        TapeElement::False => builder.append_value(false),
+        TapeElement::Null => builder.append_value(Variant::Null),
+    }
+    p += 1;
+    Ok(p)
+}
+
+fn parse_number<'a, 'b>(s: &'a str) -> Result<Variant<'a, 'b>, ArrowError> {
+    if let Ok(v) = lexical_core::parse(s.as_bytes()) {
+        return Ok(Variant::Int64(v));
+    }
+
+    match lexical_core::parse(s.as_bytes()) {
+        Ok(v) => Ok(Variant::Double(v)),
+        Err(_) => Err(ArrowError::JsonError(format!(
+            "failed to parse {s} as number"
+        ))),
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::VariantArray;
+
+    use super::*;
+    use arrow_array::Int32Array;
+    use arrow_json::StructMode;
+    use arrow_json::reader::ReaderBuilder;
+    use arrow_schema::{DataType, Field, Schema};
+    use parquet_variant::VariantBuilder;
+    use std::io::Cursor;
+    use std::sync::Arc;
+
+    #[test]
+    fn test_variant() {
+        let do_test = |json_input: &str, ids: Vec<i32>, variants: Vec<Option<Variant>>| {
+            let variant_array = VariantArrayBuilder::new(0).build();
+
+            let struct_field = Schema::new(vec![
+                Field::new("id", DataType::Int32, false),
+                // call VariantArray::field to get the correct Field
+                variant_array.field("var"),
+            ]);
+
+            let builder = ReaderBuilder::new(Arc::new(struct_field.clone()));
+            let result = builder
+                .with_struct_mode(StructMode::ObjectOnly)
+                
.with_decoder_factory(Arc::new(VariantArrayDecoderFactory))
+                .build(Cursor::new(json_input.as_bytes()))
+                .unwrap()
+                .next()
+                .unwrap()
+                .unwrap();
+
+            assert_eq!(result.num_columns(), 2);
+            let int_array = arrow_array::array::Int32Array::from(ids);
+            assert_eq!(
+                result
+                    .column(0)
+                    .as_any()
+                    .downcast_ref::<Int32Array>()
+                    .unwrap(),
+                &int_array
+            );
+
+            let result_variant_array: VariantArray =
+                VariantArray::try_new(result.column(1)).unwrap();
+            let values = result_variant_array.iter().collect::<Vec<_>>();
+
+            assert_eq!(values, variants);
+        };
+
+        do_test(
+            "{\"id\": 1, \"var\": \"a\"}\n{\"id\": 2, \"var\": \"b\"}",
+            vec![1, 2],
+            vec![Some(Variant::from("a")), Some(Variant::from("b"))],
+        );
+
+        let mut builder = VariantBuilder::new();
+        let mut object_builder = builder.new_object();
+        object_builder.insert("int64", Variant::Int64(1));
+        object_builder.insert("double", Variant::Double(1.0));
+        object_builder.insert("null", Variant::Null);
+        object_builder.insert("true", Variant::BooleanTrue);
+        object_builder.insert("false", Variant::BooleanFalse);
+        object_builder.insert("string", Variant::from("a"));
+        object_builder.finish();
+        let (metadata, value) = builder.finish();
+        let variant = Variant::try_new(&metadata, &value).unwrap();
+
+        do_test(
+            "{\"id\": 1, \"var\": {\"int64\": 1, \"double\": 1.0, \"null\": null, \"true\": true, \"false\": false, \"string\": \"a\"}}",
+            vec![1],
+            vec![Some(variant)],
+        );
+
+        // nested structs
+        let mut builder = VariantBuilder::new();
+        let mut object_builder = builder.new_object();
+        {
+            let mut list_builder = object_builder.new_list("somelist");
+            {
+                let mut nested_object_builder = list_builder.new_object();
+                nested_object_builder.insert("num", Variant::Int64(2));
+                nested_object_builder.finish();
+            }
+            {
+                let mut nested_object_builder = list_builder.new_object();
+                nested_object_builder.insert("num", Variant::Int64(3));
+                nested_object_builder.finish();
+            }
+            list_builder.finish();
+            object_builder.insert("scalar",
Variant::from("a")); + } + object_builder.finish(); + + let (metadata, value) = builder.finish(); + let variant = Variant::try_new(&metadata, &value).unwrap(); + + do_test( + "{\"id\": 1, \"var\": {\"somelist\": [{\"num\": 2}, {\"num\": 3}], \"scalar\": \"a\"}}", + vec![1], + vec![Some(variant)], + ); + + let mut builder = VariantBuilder::new(); + let mut list_builder = builder.new_list(); + list_builder.append_value(Variant::Int64(1000000000000)); + list_builder.append_value(Variant::Double(std::f64::consts::E)); + list_builder.finish(); + let (metadata, value) = builder.finish(); + let variant = Variant::try_new(&metadata, &value).unwrap(); + + do_test( + "{\"id\": 1, \"var\": [1000000000000, 2.718281828459045]}", + vec![1], + vec![Some(variant)], + ); + } +} diff --git a/parquet-variant-compute/src/lib.rs b/parquet-variant-compute/src/lib.rs index b05d0e023653..d33eece28d9b 100644 --- a/parquet-variant-compute/src/lib.rs +++ b/parquet-variant-compute/src/lib.rs @@ -41,6 +41,7 @@ mod arrow_to_variant; mod cast_to_variant; +mod decoder; mod from_json; mod shred_variant; mod to_json; @@ -55,6 +56,7 @@ pub use variant_array::{BorrowedShreddingState, ShreddingState, VariantArray, Va pub use variant_array_builder::{VariantArrayBuilder, VariantValueArrayBuilder}; pub use cast_to_variant::{cast_to_variant, cast_to_variant_with_options}; +pub use decoder::VariantArrayDecoderFactory; pub use from_json::json_to_variant; pub use shred_variant::{IntoShreddingField, ShreddedSchemaBuilder, shred_variant}; pub use to_json::variant_to_json;