diff --git a/parquet-avro/pom.xml b/parquet-avro/pom.xml index 43ed92c539..27cabb757f 100644 --- a/parquet-avro/pom.xml +++ b/parquet-avro/pom.xml @@ -48,6 +48,11 @@ parquet-common ${project.version} + + org.apache.parquet + parquet-variant + ${project.version} + org.apache.avro avro diff --git a/parquet-avro/src/main/java/org/apache/parquet/avro/AvroConverters.java b/parquet-avro/src/main/java/org/apache/parquet/avro/AvroConverters.java index 4594490858..830bc2f91b 100644 --- a/parquet-avro/src/main/java/org/apache/parquet/avro/AvroConverters.java +++ b/parquet-avro/src/main/java/org/apache/parquet/avro/AvroConverters.java @@ -31,8 +31,10 @@ import org.apache.parquet.io.api.Binary; import org.apache.parquet.io.api.GroupConverter; import org.apache.parquet.io.api.PrimitiveConverter; +import org.apache.parquet.schema.GroupType; import org.apache.parquet.schema.PrimitiveStringifier; import org.apache.parquet.schema.PrimitiveType; +import org.apache.parquet.variant.VariantColumnConverter; public class AvroConverters { @@ -363,4 +365,26 @@ public String convert(Binary binary) { return stringifier.stringify(binary); } } + + static final class FieldVariantConverter extends VariantColumnConverter { + protected final ParentValueContainer parent; + private final Schema avroSchema; + private final GenericData model; + + public FieldVariantConverter( + ParentValueContainer parent, GroupType schema, Schema avroSchema, GenericData model) { + super(schema); + this.avroSchema = avroSchema; + this.model = model; + this.parent = parent; + } + + @Override + public void addVariant(Binary value, Binary metadata) { + T currentRecord = (T) model.newRecord(null, avroSchema); + model.setField(currentRecord, "metadata", 0, metadata.toByteBuffer()); + model.setField(currentRecord, "value", 1, value.toByteBuffer()); + parent.add(currentRecord); + } + } } diff --git a/parquet-avro/src/main/java/org/apache/parquet/avro/AvroIndexedRecordConverter.java 
b/parquet-avro/src/main/java/org/apache/parquet/avro/AvroIndexedRecordConverter.java index ff77d44408..e7989b1fa9 100644 --- a/parquet-avro/src/main/java/org/apache/parquet/avro/AvroIndexedRecordConverter.java +++ b/parquet-avro/src/main/java/org/apache/parquet/avro/AvroIndexedRecordConverter.java @@ -35,6 +35,7 @@ import org.apache.parquet.io.api.GroupConverter; import org.apache.parquet.io.api.PrimitiveConverter; import org.apache.parquet.schema.GroupType; +import org.apache.parquet.schema.LogicalTypeAnnotation; import org.apache.parquet.schema.MessageType; import org.apache.parquet.schema.Type; @@ -168,7 +169,12 @@ private static Converter newConverter(Schema schema, Type type, GenericData mode case MAP: return new MapConverter(parent, type.asGroupType(), schema, model); case RECORD: - return new AvroIndexedRecordConverter<>(parent, type.asGroupType(), schema, model); + if (type.getLogicalTypeAnnotation() + instanceof LogicalTypeAnnotation.VariantLogicalTypeAnnotation) { + return new AvroConverters.FieldVariantConverter(parent, type.asGroupType(), schema, model); + } else { + return new AvroIndexedRecordConverter<>(parent, type.asGroupType(), schema, model); + } case STRING: return new AvroConverters.FieldStringConverter(parent); case UNION: diff --git a/parquet-avro/src/main/java/org/apache/parquet/avro/AvroRecordConverter.java b/parquet-avro/src/main/java/org/apache/parquet/avro/AvroRecordConverter.java index a98deabf6f..d325d65eaa 100644 --- a/parquet-avro/src/main/java/org/apache/parquet/avro/AvroRecordConverter.java +++ b/parquet-avro/src/main/java/org/apache/parquet/avro/AvroRecordConverter.java @@ -62,6 +62,7 @@ import org.apache.parquet.io.api.Converter; import org.apache.parquet.io.api.GroupConverter; import org.apache.parquet.schema.GroupType; +import org.apache.parquet.schema.LogicalTypeAnnotation; import org.apache.parquet.schema.MessageType; import org.apache.parquet.schema.Type; import org.slf4j.Logger; @@ -383,7 +384,12 @@ private static Converter 
newConverter( } return newStringConverter(schema, model, parent); case RECORD: - return new AvroRecordConverter(parent, type.asGroupType(), schema, model); + if (type.getLogicalTypeAnnotation() + instanceof LogicalTypeAnnotation.VariantLogicalTypeAnnotation) { + return new AvroConverters.FieldVariantConverter(parent, type.asGroupType(), schema, model); + } else { + return new AvroRecordConverter(parent, type.asGroupType(), schema, model); + } case ENUM: return new AvroConverters.FieldEnumConverter(parent, schema, model); case ARRAY: diff --git a/parquet-avro/src/main/java/org/apache/parquet/avro/AvroSchemaConverter.java b/parquet-avro/src/main/java/org/apache/parquet/avro/AvroSchemaConverter.java index 9632fc1754..5563b4194c 100644 --- a/parquet-avro/src/main/java/org/apache/parquet/avro/AvroSchemaConverter.java +++ b/parquet-avro/src/main/java/org/apache/parquet/avro/AvroSchemaConverter.java @@ -466,6 +466,21 @@ public Optional visit( LogicalTypeAnnotation.EnumLogicalTypeAnnotation enumLogicalType) { return of(Schema.create(Schema.Type.STRING)); } + + @Override + public Optional visit( + LogicalTypeAnnotation.VariantLogicalTypeAnnotation variantLogicalType) { + String name = parquetGroupType.getName(); + SchemaBuilder.FieldAssembler builder = SchemaBuilder.builder( + namespace(name, names)) + .record(name) + .fields(); + builder.name("metadata") + .type(Schema.create(Schema.Type.BYTES)) + .noDefault(); + builder.name("value").type().optional().type(Schema.create(Schema.Type.BYTES)); + return of(builder.endRecord()); + } }) .orElseThrow( () -> new UnsupportedOperationException("Cannot convert Parquet type " + parquetType)); diff --git a/parquet-avro/src/test/java/org/apache/parquet/avro/TestVariant.java b/parquet-avro/src/test/java/org/apache/parquet/avro/TestVariant.java new file mode 100644 index 0000000000..a359643ee3 --- /dev/null +++ b/parquet-avro/src/test/java/org/apache/parquet/avro/TestVariant.java @@ -0,0 +1,2343 @@ +/** + * Licensed to the Apache 
Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + *

+ * http://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.parquet.avro; + +import static org.apache.parquet.avro.AvroTestUtil.*; +import static org.junit.Assert.*; + +import java.io.File; +import java.io.IOException; +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.*; +import java.util.concurrent.Callable; +import java.util.function.Consumer; +import com.google.common.collect.ImmutableMap; +import org.apache.avro.Schema; +import org.apache.avro.generic.GenericData; +import org.apache.avro.generic.GenericRecord; +import org.apache.avro.generic.IndexedRecord; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.parquet.DirectWriterTest; +import org.apache.parquet.Preconditions; +import org.apache.parquet.conf.ParquetConfiguration; +import org.apache.parquet.conf.PlainParquetConfiguration; +import org.apache.parquet.hadoop.ParquetWriter; +import org.apache.parquet.hadoop.api.WriteSupport; +import org.apache.parquet.io.ParquetDecodingException; +import org.apache.parquet.io.api.Binary; +import org.apache.parquet.io.api.RecordConsumer; +import org.apache.parquet.schema.*; +import org.apache.parquet.schema.LogicalTypeAnnotation.TimeUnit; +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName; +import org.apache.parquet.variant.Variant; +import org.apache.parquet.variant.VariantBuilder; +import org.apache.parquet.variant.VariantDuplicateKeyException; +import org.apache.parquet.variant.VariantUtil; +import org.junit.Test; + +public class TestVariant extends DirectWriterTest { + + private static final LogicalTypeAnnotation STRING = LogicalTypeAnnotation.stringType(); + + 
// Construct a variant, and return the value binary, dropping metadata. + private static Variant fullVariant(Consumer appendValue) { + VariantBuilder builder = new VariantBuilder(false); + appendValue.accept(builder); + return builder.result(); + } + + // Return only the byte[], which is usually all we want. + private static byte[] variant(Consumer appendValue) { + return fullVariant(appendValue).getValue(); + } + + // Returns a value based on building with fixed metadata. + private static byte[] variant(byte[] metadata, Consumer appendValue) { + VariantBuilder builder = new VariantBuilder(false); + builder.setFixedMetadata(VariantUtil.getMetadataMap(metadata)); + appendValue.accept(builder); + return builder.valueWithoutMetadata(); + } + + private static byte[] variant(int val) { + return variant(b -> b.appendLong(val)); + } + + private static byte[] variant(long val) { + return variant(b -> b.appendLong(val)); + } + + private static byte[] variant(String s) { + return variant(b -> b.appendString(s)); + } + + private static class PrimitiveCase { + Object avroValue; + byte[] value; + + PrimitiveCase(Object avroValue, byte[] value) { + this.avroValue = avroValue; + this.value = value; + } + } + + // Convert a string to a Decimal that can be written using Avro. 
+ private static Object avroDecimalValue(String s) { + BigDecimal v = new BigDecimal(s); + int precision = v.precision(); + if (precision <= 9) { + return v.unscaledValue().intValueExact(); + } else if (precision <= 18) { + return v.unscaledValue().longValueExact(); + } else { + return v.unscaledValue().toByteArray(); + } + } + + private static final PrimitiveCase[] PRIMITIVES = + new PrimitiveCase[] { + new PrimitiveCase(null, variant(b -> b.appendNull())), + new PrimitiveCase(true, variant(b -> b.appendBoolean(true))), + new PrimitiveCase(false, variant(b -> b.appendBoolean(false))), + // TODO: fix types + new PrimitiveCase(34, variant(b -> b.appendLong(34))), + new PrimitiveCase(-34, variant(b -> b.appendLong(-34))), + new PrimitiveCase(1234, variant(b -> b.appendLong(1234))), + new PrimitiveCase(-1234, variant(b -> b.appendLong(-1234))), + new PrimitiveCase(12345, variant(b -> b.appendLong(12345))), + new PrimitiveCase(-12345, variant(b -> b.appendLong(-12345))), + new PrimitiveCase(9876543210L, variant(b -> b.appendLong(9876543210L))), + new PrimitiveCase(-9876543210L, variant(b -> b.appendLong(-9876543210L))), + new PrimitiveCase(10.11F, variant(b -> b.appendFloat(10.11F))), + new PrimitiveCase(-10.11F, variant(b -> b.appendFloat(-10.11F))), + new PrimitiveCase(14.3D, variant(b -> b.appendDouble(14.3D))), + new PrimitiveCase(-14.3D, variant(b -> b.appendDouble(-14.3D))), + // Dates and timestamps aren't very interesting in Variant tests, since they are passed + // to and from the API as integers. So just test arbitrary integer values. 
+ new PrimitiveCase(12345, variant(b -> b.appendDate(12345))), + new PrimitiveCase(-12345, variant(b -> b.appendDate(-12345))), + new PrimitiveCase(9876543210L, variant(b -> b.appendTimestamp(9876543210L))), + new PrimitiveCase(-9876543210L, variant(b -> b.appendTimestamp(-9876543210L))), + new PrimitiveCase(9876543210L, variant(b -> b.appendTimestampNtz(9876543210L))), + new PrimitiveCase(-9876543210L, variant(b -> b.appendTimestampNtz(-9876543210L))), + new PrimitiveCase(9876543210L, variant(b -> b.appendTimestampNanos(9876543210L))), + new PrimitiveCase(-9876543210L, variant(b -> b.appendTimestampNanos(-9876543210L))), + new PrimitiveCase(9876543210L, variant(b -> b.appendTimestampNanosNtz(9876543210L))), + new PrimitiveCase(-9876543210L, variant(b -> b.appendTimestampNanosNtz(-9876543210L))), + new PrimitiveCase(avroDecimalValue("123456.7890"), variant(b -> b.appendDecimal(new BigDecimal("123456.7890")))), // decimal4 + new PrimitiveCase(avroDecimalValue("-123456.7890"), variant(b -> b.appendDecimal(new BigDecimal("-123456.7890")))), // decimal4 + new PrimitiveCase(avroDecimalValue("1234567890.987654321"), variant(b -> b.appendDecimal(new BigDecimal("1234567890.987654321")))), // decimal8 + new PrimitiveCase(avroDecimalValue("-1234567890.987654321"), variant(b -> b.appendDecimal(new BigDecimal("-1234567890.987654321")))), // decimal8 + new PrimitiveCase(avroDecimalValue("9876543210.123456789"), variant(b -> b.appendDecimal(new BigDecimal("9876543210.123456789")))), // decimal16 + new PrimitiveCase(avroDecimalValue("-9876543210.123456789"), variant(b -> b.appendDecimal(new BigDecimal("-9876543210.123456789")))), // decimal16 + new PrimitiveCase(new byte[] {0x0a, 0x0b, 0x0c, 0x0d}, variant(b -> b.appendBinary(new byte[] {0x0a, 0x0b, 0x0c, 0x0d}))), + new PrimitiveCase("parquet", variant(b -> b.appendString("parquet"))), + new PrimitiveCase(UUID.fromString("f24f9b64-81fa-49d1-b74e-8c09a6e31c56"), variant(b -> 
b.appendUUID(UUID.fromString("f24f9b64-81fa-49d1-b74e-8c09a6e31c56")))) + }; + + private byte[] EMPTY_METADATA = fullVariant(b -> b.appendNull()).getMetadata(); + private byte[] NULL_VALUE = PRIMITIVES[0].value; + + private byte[] TEST_METADATA; + private byte[] TEST_OBJECT; + + private static class TestCase { + byte[] value; + byte[] metadata; + + public TestCase(byte[] value, byte[] metadata) { + this.value = value; + this.metadata = metadata; + } + } + + private ArrayList testCases; + + public TestVariant() throws Exception { + TEST_METADATA = VariantBuilder.parseJson( + "{\"a\": 0, \"b\": 0, \"c\": 0, \"d\": 0, \"e\": 0}").getMetadata(); + + TEST_OBJECT = variant(b -> { + int startWritePos = b.getWritePos(); + ArrayList entries = new ArrayList<>(); + entries.add(new VariantBuilder.FieldEntry("a", 0, b.getWritePos() - startWritePos)); + b.appendNull(); + entries.add(new VariantBuilder.FieldEntry("d", 3, b.getWritePos() - startWritePos)); + b.appendString("iceberg"); + b.finishWritingObject(startWritePos, entries); + }); + + testCases = new ArrayList<>(); + for (PrimitiveCase p : PRIMITIVES) { + testCases.add(new TestCase(p.value, EMPTY_METADATA)); + } + testCases.add(new TestCase(TEST_OBJECT, TEST_METADATA)); + } + + @Test + public void testUnshredded() throws Exception { + // Unshredded Variant should produce exactly the same value and metadata. 
+ Variant testValue = VariantBuilder.parseJson("{\"a\": 123, \"b\": [\"a\", 2, true, null]}"); + Binary expectedValue = Binary.fromConstantByteArray(testValue.getValue()); + Binary expectedMetadata = Binary.fromConstantByteArray(testValue.getMetadata()); + Path test = writeDirect( + "message VariantMessage {" + " required group v (VARIANT(1)) {" + + " required binary value;" + + " required binary metadata;" + + " }" + + "}", + rc -> { + rc.startMessage(); + rc.startField("v", 0); + + rc.startGroup(); + rc.startField("value", 0); + rc.addBinary(expectedValue); + rc.endField("value", 0); + rc.startField("metadata", 1); + rc.addBinary(expectedMetadata); + rc.endField("metadata", 1); + rc.endGroup(); + + rc.endField("v", 0); + rc.endMessage(); + }); + + Schema variantSchema = record( + "v", + field("metadata", Schema.create(Schema.Type.BYTES)), + optionalField("value", Schema.create(Schema.Type.BYTES))); + Schema expectedSchema = record("VariantMessage", field("v", variantSchema)); + + GenericRecord expectedRecord = instance( + expectedSchema, + "v", + instance( + variantSchema, + "metadata", + expectedMetadata.toByteBuffer(), + "value", + expectedValue.toByteBuffer())); + + // both should behave the same way + assertReaderContains(new AvroParquetReader(new Configuration(), test), expectedSchema, expectedRecord); + } + + /** + * Construct a Variant with a single scalar value, and write the same value to the typed_value column + * of a shredded file, verifying that the reconstructed value is bit-for-bit identical to the original value. + * and a lambda to append the same corresponding value to the + * @param type Type of the shredded field. E.g. int64" + * @param annotation Logical annotation of the shredded field, or empty string if none. E.g. 
"UTF8" + * @param appendValue Lambda to append a value to a VariantBuilder + * @param addValue Lambda to append the logically equivalent value to a RecordConsumer + * @throws Exception + */ + public void runOneScalarTest( + String type, String annotation, Consumer appendValue, Consumer addValue) + throws Exception { + VariantBuilder builder = new VariantBuilder(false); + appendValue.accept(builder); + Variant testValue = builder.result(); + Binary expectedValue = Binary.fromConstantByteArray(testValue.getValue()); + Binary expectedMetadata = Binary.fromConstantByteArray(testValue.getMetadata()); + Path test = writeDirect( + "message VariantMessage {" + " required group v (VARIANT(1)) {" + + " optional binary value;" + + " required binary metadata;" + + " optional " + type + " typed_value " + annotation + ";" + + " }" + + "}", + rc -> { + rc.startMessage(); + rc.startField("v", 0); + + rc.startGroup(); + rc.startField("typed_value", 2); + addValue.accept(rc); + rc.endField("typed_value", 2); + rc.startField("metadata", 1); + rc.addBinary(expectedMetadata); + rc.endField("metadata", 1); + rc.endGroup(); + + rc.endField("v", 0); + rc.endMessage(); + }); + + Schema variantSchema = record( + "v", + field("metadata", Schema.create(Schema.Type.BYTES)), + optionalField("value", Schema.create(Schema.Type.BYTES))); + Schema expectedSchema = record("VariantMessage", field("v", variantSchema)); + + GenericRecord expectedRecord = instance( + expectedSchema, + "v", + instance( + variantSchema, + "metadata", + expectedMetadata.toByteBuffer(), + "value", + expectedValue.toByteBuffer())); + + // both should behave the same way + assertReaderContains(new AvroParquetReader(new Configuration(), test), expectedSchema, expectedRecord); + } + + @Test + public void testShreddedScalar() throws Exception { + runOneScalarTest("boolean", "", b -> b.appendBoolean(true), rc -> rc.addBoolean(true)); + // Test true and false, since they have different types in Variant. 
+ runOneScalarTest("boolean", "", b -> b.appendBoolean(false), rc -> rc.addBoolean(false)); + runOneScalarTest("boolean", "", b -> b.appendBoolean(false), rc -> rc.addBoolean(false)); + runOneScalarTest("int32", "(INT_8)", b -> b.appendLong(123), rc -> rc.addInteger(123)); + runOneScalarTest("int32", "(INT_16)", b -> b.appendLong(-12345), rc -> rc.addInteger(-12345)); + runOneScalarTest("int32", "(INT_32)", b -> b.appendLong(1234567890), rc -> rc.addInteger(1234567890)); + runOneScalarTest("int64", "", b -> b.appendLong(1234567890123L), rc -> rc.addLong(1234567890123L)); + runOneScalarTest("double", "", b -> b.appendDouble(1.2e34), rc -> rc.addDouble(1.2e34)); + runOneScalarTest("float", "", b -> b.appendFloat(1.2e34f), rc -> rc.addFloat(1.2e34f)); + runOneScalarTest( + "int32", "(DECIMAL(9, 2))", b -> b.appendDecimal(new BigDecimal("1.23")), rc -> rc.addInteger(123)); + runOneScalarTest( + "int64", + "(DECIMAL(18, 5))", + b -> b.appendDecimal(new BigDecimal("123456789.12345")), + rc -> rc.addLong(12345678912345L)); + BigDecimal decimalVal = new BigDecimal("0.12345678901234567890123456789012345678"); + runOneScalarTest( + "fixed_len_byte_array(16)", + "(DECIMAL(38, 38))", + b -> b.appendDecimal(decimalVal), + rc -> rc.addBinary( + Binary.fromConstantByteArray(decimalVal.unscaledValue().toByteArray()))); + // Verify that the parquet type's scale is used when shredding, and not the scale implied by the value. 
+ runOneScalarTest( + "int32", + "(DECIMAL(9, 2))", + b -> b.appendDecimal(new BigDecimal("1.2").setScale(2)), + rc -> rc.addInteger(120)); + runOneScalarTest( + "int64", + "(DECIMAL(18, 5))", + b -> b.appendDecimal(new BigDecimal("123456789").setScale(5)), + rc -> rc.addLong(12345678900000L)); + BigDecimal decimalVal2 = new BigDecimal("9.12345678901234567890123456789").setScale(37); + runOneScalarTest( + "fixed_len_byte_array(16)", + "(DECIMAL(38, 37))", + b -> b.appendDecimal(decimalVal2), + rc -> rc.addBinary( + Binary.fromConstantByteArray(decimalVal2.unscaledValue().toByteArray()))); + runOneScalarTest("int64", "(TIMESTAMP(MICROS, true))", b -> b.appendTimestamp(123), rc -> rc.addLong(123)); + runOneScalarTest("int64", "(TIMESTAMP(MICROS, false))", b -> b.appendTimestampNtz(123), rc -> rc.addLong(123)); + runOneScalarTest( + "binary", + "", + b -> b.appendBinary(Binary.fromString("hello").toByteBuffer().array()), + rc -> rc.addBinary(Binary.fromString("hello"))); + runOneScalarTest( + "binary", "(UTF8)", b -> b.appendString("hello"), rc -> rc.addBinary(Binary.fromString("hello"))); + runOneScalarTest("int64", "(TIME(MICROS, false))", b -> b.appendTime(123), rc -> rc.addLong(123)); + runOneScalarTest("int64", "(TIMESTAMP(NANOS, true))", b -> b.appendTimestampNanos(123), rc -> rc.addLong(123)); + runOneScalarTest( + "int64", "(TIMESTAMP(NANOS, false))", b -> b.appendTimestampNanosNtz(123), rc -> rc.addLong(123)); + UUID uuid = UUID.randomUUID(); + byte[] uuidBytes = new byte[16]; + ByteBuffer bb = ByteBuffer.wrap(uuidBytes, 0, 16).order(ByteOrder.BIG_ENDIAN); + bb.putLong(uuid.getMostSignificantBits()); + bb.putLong(uuid.getLeastSignificantBits()); + runOneScalarTest( + "fixed_len_byte_array(16)", + "(UUID)", + b -> b.appendUUID(uuid), + rc -> rc.addBinary(Binary.fromConstantByteArray(uuidBytes))); + } + + @Test + public void testArray() throws Exception { + Variant testValue = VariantBuilder.parseJson("[123, \"hello\", 456]"); + // The string value will be 
stored in value, not typed_value, so we need to write its binary representation + // to parquet. + Variant stringVal = testValue.getElementAtIndex(1); + byte[] stringValue = stringVal.getValue(); + + Binary expectedValue = Binary.fromConstantByteArray(testValue.getValue()); + Binary expectedMetadata = Binary.fromConstantByteArray(testValue.getMetadata()); + Path test = writeDirect( + "message VariantMessage {" + " required group v (VARIANT(1)) {" + + " required binary metadata;" + + " optional binary value;" + + " optional group typed_value (LIST) {" + + " repeated group list {" + + " required group element {" + + " optional int64 typed_value;" + + " optional binary value;" + + " }" + + " }" + + " }" + + " }" + + "}", + rc -> { + rc.startMessage(); + rc.startField("v", 0); + rc.startGroup(); + rc.startField("typed_value", 2); + rc.startGroup(); + rc.startField("list", 0); + + rc.startGroup(); + rc.startField("element", 0); + rc.startGroup(); + rc.startField("typed_value", 0); + rc.addLong(123); + rc.endField("typed_value", 0); + rc.endGroup(); + rc.endField("element", 0); + rc.endGroup(); + + rc.startGroup(); + rc.startField("element", 0); + rc.startGroup(); + rc.startField("value", 1); + rc.addBinary(Binary.fromConstantByteArray(stringValue)); + rc.endField("value", 1); + rc.endGroup(); + rc.endField("element", 0); + rc.endGroup(); + + rc.startGroup(); + rc.startField("element", 0); + rc.startGroup(); + rc.startField("typed_value", 0); + rc.addLong(456); + rc.endField("typed_value", 0); + rc.endGroup(); + rc.endField("element", 0); + rc.endGroup(); + + rc.endField("list", 0); + rc.endGroup(); + rc.endField("typed_value", 2); + + rc.startField("metadata", 0); + rc.addBinary(expectedMetadata); + rc.endField("metadata", 0); + + rc.endGroup(); + rc.endField("v", 0); + rc.endMessage(); + }); + + Schema variantSchema = record( + "v", + field("metadata", Schema.create(Schema.Type.BYTES)), + optionalField("value", Schema.create(Schema.Type.BYTES))); + Schema 
expectedSchema = record("VariantMessage", field("v", variantSchema)); + + GenericRecord expectedRecord = instance( + expectedSchema, + "v", + instance( + variantSchema, + "metadata", + expectedMetadata.toByteBuffer(), + "value", + expectedValue.toByteBuffer())); + + // both should behave the same way + assertReaderContains(new AvroParquetReader(new Configuration(), test), expectedSchema, expectedRecord); + } + + @Test + public void testObject() throws Exception { + Variant testValue = VariantBuilder.parseJson("{\"a\": 123, \"b\": \"string_val\", \"c\": 456}"); + // Column c will be omitted from the schema and stored in the value column as the object {c: 456}. + // It's a bit tricky to construct this, since we need to ensure that it ends up with the same ID in metadata. + // The value below should do the trick, but it makes assumptions about parseJson behavior that is not guaranteed, + // so is a bit fragile. + Variant cValue = VariantBuilder.parseJson("{\"x\": 1, \"dummy\": {\"c\": 456}}").getFieldByKey("dummy"); + + Binary expectedValue = Binary.fromConstantByteArray(testValue.getValue()); + Binary expectedMetadata = Binary.fromConstantByteArray(testValue.getMetadata()); + Path test = writeDirect( + "message VariantMessage {" + " required group v (VARIANT(1)) {" + + " required binary metadata;" + + " optional binary value;" + + " optional group typed_value {" + + " required group a {" + + " optional int64 typed_value;" + + " optional binary value;" + + " }" + + " required group b {" + + " optional binary typed_value (UTF8);" + + " optional binary value;" + + " }" + + " required group missing {" + + " optional int64 typed_value;" + + " optional binary value;" + + " }" + + " }" + + " }" + + "}", + rc -> { + rc.startMessage(); + rc.startField("v", 0); + rc.startGroup(); + rc.startField("typed_value", 2); + rc.startGroup(); + + rc.startField("a", 0); + rc.startGroup(); + rc.startField("typed_value", 0); + rc.addLong(123); + rc.endField("typed_value", 0); + 
rc.endGroup(); + rc.endField("a", 0); + + rc.startField("b", 1); + rc.startGroup(); + rc.startField("typed_value", 0); + rc.addBinary(Binary.fromString("string_val")); + rc.endField("typed_value", 0); + rc.endGroup(); + rc.endField("b", 1); + + // Spec requires missing fields to be non-null. They are identified as missing by + // not having a non-null value or typed_value. + rc.startField("missing", 2); + rc.startGroup(); + rc.endGroup(); + rc.endField("missing", 2); + + rc.endGroup(); + rc.endField("typed_value", 2); + + rc.startField("value", 1); + rc.addBinary(Binary.fromConstantByteArray(cValue.getValue())); + rc.endField("value", 1); + + rc.startField("metadata", 0); + rc.addBinary(expectedMetadata); + rc.endField("metadata", 0); + + rc.endGroup(); + rc.endField("v", 0); + rc.endMessage(); + }); + + Schema variantSchema = record( + "v", + field("metadata", Schema.create(Schema.Type.BYTES)), + optionalField("value", Schema.create(Schema.Type.BYTES))); + Schema expectedSchema = record("VariantMessage", field("v", variantSchema)); + + GenericRecord expectedRecord = instance( + expectedSchema, + "v", + instance( + variantSchema, + "metadata", + expectedMetadata.toByteBuffer(), + "value", + expectedValue.toByteBuffer())); + + // both should behave the same way + assertReaderContains(new AvroParquetReader(new Configuration(), test), expectedSchema, expectedRecord); + } + + public void assertReaderContains( + AvroParquetReader reader, Schema expectedSchema, T... expectedRecords) throws IOException { + for (T expectedRecord : expectedRecords) { + T actualRecord = reader.read(); + assertEquals("Should match expected schema", expectedSchema, actualRecord.getSchema()); + assertEquals("Should match the expected record", expectedRecord, actualRecord); + } + assertNull( + "Should only contain " + expectedRecords.length + " record" + (expectedRecords.length == 1 ? 
"" : "s"), + reader.read()); + } + + // We need to store two copies of the schema: one without the Variant type annotation that is used to construct the + // Avro schema for writing, and one with type annotation that is used in the actual written parquet schema, and when + // reading. + private static class TestSchema { + MessageType parquetSchema; + MessageType unannotatedParquetSchema; + GroupType variantType; + GroupType unannotatedVariantType; + + TestSchema(GroupType variantType, GroupType unannotatedVariantType) { + this.variantType = variantType; + this.unannotatedVariantType = unannotatedVariantType; + this.parquetSchema = parquetSchema(variantType); + this.unannotatedParquetSchema = parquetSchema(unannotatedVariantType); + } + + TestSchema(Type shreddedType) { + variantType = variant("var", 2, shreddedType); + unannotatedVariantType = unannotatedVariant("var", 2, shreddedType); + parquetSchema = parquetSchema(variantType); + unannotatedParquetSchema = parquetSchema(unannotatedVariantType); + } + + TestSchema() { + variantType = variant("var", 2); + unannotatedVariantType = unannotatedVariant("var", 2); + parquetSchema = parquetSchema(variantType); + unannotatedParquetSchema = parquetSchema(unannotatedVariantType); + } + } + + // The remaining tests in this file are based on Iceberg's TestVariantReaders suite. 
+ @Test + public void testUnshreddedVariants() throws Exception { + for (TestCase t : testCases) { + TestSchema schema = new TestSchema(); + + GenericRecord variant = recordFromMap(schema.unannotatedVariantType, + ImmutableMap.of("metadata", t.metadata, "value", t.value)); + GenericRecord record = recordFromMap(schema.unannotatedParquetSchema, ImmutableMap.of("id", 1, "var", variant)); + GenericRecord actual = writeAndRead(schema, record); + assertEquals(actual.get("id"), 1); + + GenericRecord actualVariant = (GenericRecord) actual.get("var"); + assertEquivalent(t.metadata, t.value, actualVariant); + } + } + + @Test + public void testUnshreddedVariantsWithShreddingSchema() throws Exception { + for (TestCase t : testCases) { + // Parquet schema has a shredded field, but it is unused by the data. + TestSchema schema = new TestSchema(shreddedPrimitive(PrimitiveTypeName.BINARY, STRING)); + + GenericRecord variant = recordFromMap(schema.unannotatedVariantType, + ImmutableMap.of("metadata", t.metadata, "value", t.value)); + GenericRecord record = recordFromMap(schema.unannotatedParquetSchema, ImmutableMap.of("id", 1, "var", variant)); + GenericRecord actual = writeAndRead(schema, record); + assertEquals(actual.get("id"), 1); + + GenericRecord actualVariant = (GenericRecord) actual.get("var"); + assertEquivalent(t.metadata, t.value, actualVariant); + } + } + + @Test + public void testShreddedVariantPrimitives() throws IOException { + for (PrimitiveCase p : PRIMITIVES) { + if (p.avroValue == null) { + // null isn't a valid type for shredding. 
+ continue; + } + TestSchema schema = new TestSchema(shreddedType(new Variant(p.value, EMPTY_METADATA))); + + GenericRecord variant = + recordFromMap( + schema.unannotatedVariantType, + ImmutableMap.of( + "metadata", + EMPTY_METADATA, + "typed_value", + p.avroValue)); + GenericRecord record = recordFromMap(schema.unannotatedParquetSchema, ImmutableMap.of("id", 1, "var", variant)); + + GenericRecord actual = writeAndRead(schema, record); + assertEquals(actual.get("id"), 1); + + GenericRecord actualVariant = (GenericRecord) actual.get("var"); + assertEquivalent(EMPTY_METADATA, p.value, actualVariant); + } + } + + @Test + public void testNullValueAndNullTypedValue() throws IOException { + TestSchema schema = new TestSchema(shreddedPrimitive(PrimitiveTypeName.INT32)); + + GenericRecord variant = + recordFromMap(schema.unannotatedVariantType, ImmutableMap.of("metadata", EMPTY_METADATA)); + GenericRecord record = recordFromMap(schema.unannotatedParquetSchema, ImmutableMap.of("id", 1, "var", variant)); + + GenericRecord actual = writeAndRead(schema, record); + assertEquals(actual.get("id"), 1); + + GenericRecord actualVariant = (GenericRecord) actual.get("var"); + assertEquivalent(EMPTY_METADATA, NULL_VALUE, actualVariant); + } + + @Test + public void testMissingValueColumn() throws IOException { + GroupType variantType = + Types.buildGroup(Type.Repetition.REQUIRED) + .as(LogicalTypeAnnotation.variantType((byte) 1)) + .id(2) + .required(PrimitiveTypeName.BINARY) + .named("metadata") + .addField(shreddedPrimitive(PrimitiveTypeName.INT32)) + .named("var"); + + GroupType unannotatedVariantType = Types.buildGroup(Type.Repetition.REQUIRED) + .id(2) + .required(PrimitiveTypeName.BINARY) + .named("metadata") + .addField(shreddedPrimitive(PrimitiveTypeName.INT32)) + .named("var"); + + TestSchema schema = new TestSchema(variantType, unannotatedVariantType); + + GenericRecord variant = + recordFromMap(unannotatedVariantType, ImmutableMap.of("metadata", EMPTY_METADATA, 
"typed_value", 34)); + GenericRecord record = recordFromMap(schema.unannotatedParquetSchema, ImmutableMap.of("id", 1, "var", variant)); + + GenericRecord actual = writeAndRead(schema, record); + assertEquals(actual.get("id"), 1); + + GenericRecord actualVariant = (GenericRecord) actual.get("var"); + assertEquivalent(EMPTY_METADATA, variant(34), actualVariant); + } + + @Test + public void testValueAndTypedValueConflict() throws IOException { + TestSchema schema = new TestSchema(shreddedPrimitive(PrimitiveTypeName.INT32)); + + GenericRecord variant = + recordFromMap( + schema.unannotatedVariantType, + ImmutableMap.of( + "metadata", + EMPTY_METADATA, + "value", + variant("str"), + "typed_value", + 34)); + GenericRecord record = recordFromMap(schema.unannotatedParquetSchema, ImmutableMap.of("id", 1, "var", variant)); + + assertThrows(() -> writeAndRead(schema, record), IllegalArgumentException.class, + "Invalid variant, conflicting value and typed_value"); + } + + @Test + public void testUnsignedInteger() { + TestSchema schema = new TestSchema(shreddedPrimitive(PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(32, false))); + + GenericRecord variant = + recordFromMap(schema.unannotatedVariantType, ImmutableMap.of("metadata", EMPTY_METADATA)); + GenericRecord record = recordFromMap(schema.unannotatedParquetSchema, ImmutableMap.of("id", 1, "var", variant)); + + assertThrows(() -> writeAndRead(schema, record), + UnsupportedOperationException.class, "Unsupported shredded value type: INTEGER(32,false)"); + } + + @Test + public void testFixedLengthByteArray() { + Type shreddedType = Types.optional(PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY).length(4).named("typed_value"); + TestSchema schema = new TestSchema(shreddedType); + + GenericRecord variant = + recordFromMap(schema.unannotatedVariantType, ImmutableMap.of("metadata", EMPTY_METADATA)); + GenericRecord record = recordFromMap(schema.unannotatedParquetSchema, ImmutableMap.of("id", 1, "var", variant)); + + assertThrows(() 
-> writeAndRead(schema, record), + UnsupportedOperationException.class, + "Unsupported shredded value type: optional fixed_len_byte_array(4) typed_value"); + } + + @Test + public void testShreddedObject() throws IOException { + GroupType fieldA = shreddedField("a", shreddedPrimitive(PrimitiveTypeName.INT32)); + GroupType fieldB = shreddedField("b", shreddedPrimitive(PrimitiveTypeName.BINARY, STRING)); + GroupType objectFields = objectFields(fieldA, fieldB); + TestSchema schema = new TestSchema(objectFields); + + GenericRecord recordA = recordFromMap(fieldA, ImmutableMap.of("value", NULL_VALUE)); + GenericRecord recordB = recordFromMap(fieldB, ImmutableMap.of("typed_value", "")); + GenericRecord fields = recordFromMap(objectFields, ImmutableMap.of("a", recordA, "b", recordB)); + GenericRecord variant = + recordFromMap(schema.unannotatedVariantType, ImmutableMap.of("metadata", TEST_METADATA, "typed_value", fields)); + GenericRecord record = recordFromMap(schema.unannotatedParquetSchema, ImmutableMap.of("id", 1, "var", variant)); + + GenericRecord actual = writeAndRead(schema, record); + assertEquals(actual.get("id"), 1); + + Variant expectedValue = VariantBuilder.parseJson( + "{\"a\": null, \"b\": \"\"}"); + + GenericRecord actualVariant = (GenericRecord) actual.get("var"); + assertEquivalent(TEST_METADATA, expectedValue.getValue(), actualVariant); + } + + @Test + public void testShreddedObjectMissingValueColumn() throws IOException { + GroupType fieldA = shreddedField("a", shreddedPrimitive(PrimitiveTypeName.INT32)); + GroupType fieldB = shreddedField("b", shreddedPrimitive(PrimitiveTypeName.BINARY, STRING)); + GroupType objectFields = objectFields(fieldA, fieldB); + GroupType variantType = Types.buildGroup(Type.Repetition.REQUIRED) + .id(2) + .as(LogicalTypeAnnotation.variantType((byte) 1)) + .required(PrimitiveTypeName.BINARY) + .named("metadata") + .addField(objectFields) + .named("var"); + + GroupType unannotatedVariantType = 
Types.buildGroup(Type.Repetition.REQUIRED) + .id(2) + .required(PrimitiveTypeName.BINARY) + .named("metadata") + .addField(objectFields) + .named("var"); + + TestSchema schema = new TestSchema(variantType, unannotatedVariantType); + + GenericRecord recordA = recordFromMap(fieldA, ImmutableMap.of("value", + variant(1234))); + GenericRecord recordB = recordFromMap(fieldB, ImmutableMap.of("typed_value", "iceberg")); + GenericRecord fields = recordFromMap(objectFields, ImmutableMap.of("a", recordA, "b", recordB)); + GenericRecord variant = + recordFromMap(schema.unannotatedVariantType, ImmutableMap.of("metadata", TEST_METADATA, "typed_value", fields)); + GenericRecord record = recordFromMap(schema.unannotatedParquetSchema, ImmutableMap.of("id", 1, "var", variant)); + + GenericRecord actual = writeAndRead(schema, record); + assertEquals(actual.get("id"), 1); + + Variant expectedValue = VariantBuilder.parseJson( + "{\"a\": 1234, \"b\": \"iceberg\"}"); + + GenericRecord actualVariant = (GenericRecord) actual.get("var"); + assertEquivalent(TEST_METADATA, expectedValue.getValue(), actualVariant); + } + + @Test + public void testShreddedObjectMissingField() throws IOException { + GroupType fieldA = shreddedField("a", shreddedPrimitive(PrimitiveTypeName.INT32)); + GroupType fieldB = shreddedField("b", shreddedPrimitive(PrimitiveTypeName.BINARY, STRING)); + GroupType objectFields = objectFields(fieldA, fieldB); + TestSchema schema = new TestSchema(objectFields); + + GenericRecord recordA = recordFromMap(fieldA, + ImmutableMap.of("value", variant(b -> b.appendBoolean(false)))); + // value and typed_value are null, but a struct for b is required + GenericRecord recordB = recordFromMap(fieldB, ImmutableMap.of()); + GenericRecord fields = recordFromMap(objectFields, ImmutableMap.of("a", recordA, "b", recordB)); + GenericRecord variant = + recordFromMap(schema.unannotatedVariantType, ImmutableMap.of("metadata", TEST_METADATA, "typed_value", fields)); + GenericRecord record = 
recordFromMap(schema.unannotatedParquetSchema, ImmutableMap.of("id", 1, "var", variant)); + + GenericRecord actual = writeAndRead(schema, record); + + assertEquals(actual.get("id"), 1); + + Variant expectedValue = VariantBuilder.parseJson( + "{\"a\": false}"); + + GenericRecord actualVariant = (GenericRecord) actual.get("var"); + assertEquivalent(TEST_METADATA, expectedValue.getValue(), actualVariant); + } + + @Test + public void testEmptyShreddedObject() throws IOException { + GroupType fieldA = shreddedField("a", shreddedPrimitive(PrimitiveTypeName.INT32)); + GroupType fieldB = shreddedField("b", shreddedPrimitive(PrimitiveTypeName.BINARY, STRING)); + GroupType objectFields = objectFields(fieldA, fieldB); + TestSchema schema = new TestSchema(objectFields); + + GenericRecord recordA = recordFromMap(fieldA, ImmutableMap.of()); // missing + GenericRecord recordB = recordFromMap(fieldB, ImmutableMap.of()); // missing + GenericRecord fields = recordFromMap(objectFields, ImmutableMap.of("a", recordA, "b", recordB)); + GenericRecord variant = + recordFromMap(schema.unannotatedVariantType, ImmutableMap.of("metadata", TEST_METADATA, "typed_value", fields)); + GenericRecord record = recordFromMap(schema.unannotatedParquetSchema, ImmutableMap.of("id", 1, "var", variant)); + + GenericRecord actual = writeAndRead(schema, record); + + assertEquals(actual.get("id"), 1); + + Variant expectedValue = VariantBuilder.parseJson("{}"); + + GenericRecord actualVariant = (GenericRecord) actual.get("var"); + assertEquivalent(TEST_METADATA, expectedValue.getValue(), actualVariant); + } + + @Test + public void testShreddedObjectMissingFieldValueColumn() throws IOException { + // field groups do not have value + GroupType fieldA = + Types.buildGroup(Type.Repetition.REQUIRED) + .addField(shreddedPrimitive(PrimitiveTypeName.INT32)) + .named("a"); + GroupType fieldB = + Types.buildGroup(Type.Repetition.REQUIRED) + .addField(shreddedPrimitive(PrimitiveTypeName.BINARY, STRING)) + .named("b"); + 
GroupType objectFields = + Types.buildGroup(Type.Repetition.OPTIONAL).addFields(fieldA, fieldB).named("typed_value"); + TestSchema schema = new TestSchema(objectFields); + + GenericRecord recordA = recordFromMap(fieldA, ImmutableMap.of()); // typed_value=null + GenericRecord recordB = recordFromMap(fieldB, ImmutableMap.of("typed_value", "iceberg")); + GenericRecord fields = recordFromMap(objectFields, ImmutableMap.of("a", recordA, "b", recordB)); + GenericRecord variant = + recordFromMap(schema.unannotatedVariantType, ImmutableMap.of("metadata", TEST_METADATA, "typed_value", fields)); + GenericRecord record = recordFromMap(schema.unannotatedParquetSchema, ImmutableMap.of("id", 1, "var", variant)); + + GenericRecord actual = writeAndRead(schema, record); + + assertEquals(actual.get("id"), 1); + + byte[] expectedValue = variant(TEST_METADATA, b -> { + int startWritePos = b.getWritePos(); + ArrayList entries = new ArrayList<>(); + entries.add(new VariantBuilder.FieldEntry("b", 1, 0)); + b.appendString("iceberg"); + b.finishWritingObject(startWritePos, entries); + }); + + GenericRecord actualVariant = (GenericRecord) actual.get("var"); + assertEquivalent(TEST_METADATA, expectedValue, actualVariant); + } + + + @Test + public void testShreddedObjectMissingTypedValue() throws IOException { + // field groups do not have typed_value + GroupType fieldA = + Types.buildGroup(Type.Repetition.REQUIRED) + .optional(PrimitiveTypeName.BINARY) + .named("value") + .named("a"); + GroupType fieldB = + Types.buildGroup(Type.Repetition.REQUIRED) + .optional(PrimitiveTypeName.BINARY) + .named("value") + .named("b"); + GroupType objectFields = + Types.buildGroup(Type.Repetition.OPTIONAL).addFields(fieldA, fieldB).named("typed_value"); + TestSchema schema = new TestSchema(objectFields); + + GenericRecord recordA = recordFromMap(fieldA, ImmutableMap.of()); // value=null + GenericRecord recordB = recordFromMap(fieldB, ImmutableMap.of("value", variant("iceberg"))); + GenericRecord fields = 
recordFromMap(objectFields, ImmutableMap.of("a", recordA, "b", recordB)); + GenericRecord variant = + recordFromMap(schema.unannotatedVariantType, ImmutableMap.of("metadata", TEST_METADATA, "typed_value", fields)); + GenericRecord record = recordFromMap(schema.unannotatedParquetSchema, ImmutableMap.of("id", 1, "var", variant)); + + GenericRecord actual = writeAndRead(schema, record); + + assertEquals(actual.get("id"), 1); + + byte[] expectedValue = variant(TEST_METADATA, b -> { + int startWritePos = b.getWritePos(); + ArrayList entries = new ArrayList<>(); + entries.add(new VariantBuilder.FieldEntry("b", 1, 0)); + b.appendString("iceberg"); + b.finishWritingObject(startWritePos, entries); + }); + + GenericRecord actualVariant = (GenericRecord) actual.get("var"); + assertEquivalent(TEST_METADATA, expectedValue, actualVariant); + } + + @Test + public void testShreddedObjectWithinShreddedObject() throws IOException { + GroupType fieldA = shreddedField("a", shreddedPrimitive(PrimitiveTypeName.INT32)); + GroupType fieldB = shreddedField("b", shreddedPrimitive(PrimitiveTypeName.BINARY, STRING)); + GroupType innerFields = objectFields(fieldA, fieldB); + GroupType fieldC = shreddedField("c", innerFields); + GroupType fieldD = shreddedField("d", shreddedPrimitive(PrimitiveTypeName.DOUBLE)); + GroupType outerFields = objectFields(fieldC, fieldD); + TestSchema schema = new TestSchema(outerFields); + + GenericRecord recordA = recordFromMap(fieldA, ImmutableMap.of("typed_value", 34)); + GenericRecord recordB = recordFromMap(fieldB, ImmutableMap.of("typed_value", "iceberg")); + GenericRecord inner = recordFromMap(innerFields, ImmutableMap.of("a", recordA, "b", recordB)); + GenericRecord recordC = recordFromMap(fieldC, ImmutableMap.of("typed_value", inner)); + GenericRecord recordD = recordFromMap(fieldD, ImmutableMap.of("typed_value", -0.0D)); + GenericRecord outer = recordFromMap(outerFields, ImmutableMap.of("c", recordC, "d", recordD)); + GenericRecord variant = + 
recordFromMap(schema.unannotatedVariantType, ImmutableMap.of("metadata", TEST_METADATA, "typed_value", outer)); + GenericRecord record = recordFromMap(schema.unannotatedParquetSchema, ImmutableMap.of("id", 1, "var", variant)); + + GenericRecord actual = writeAndRead(schema, record); + + assertEquals(actual.get("id"), 1); + + byte[] expectedValue = variant(TEST_METADATA, b -> { + int startWritePos = b.getWritePos(); + ArrayList outerEntries = new ArrayList<>(); + + outerEntries.add(new VariantBuilder.FieldEntry("c", 2, b.getWritePos() - startWritePos)); + ArrayList innerEntries = new ArrayList<>(); + innerEntries.add(new VariantBuilder.FieldEntry("a", 0, b.getWritePos() - startWritePos)); + b.appendLong(34); + innerEntries.add(new VariantBuilder.FieldEntry("b", 1, b.getWritePos() - startWritePos)); + b.appendString("iceberg"); + b.finishWritingObject(startWritePos, innerEntries); + + outerEntries.add(new VariantBuilder.FieldEntry("d", 3, b.getWritePos() - startWritePos)); + b.appendDouble(-0.0D); + b.finishWritingObject(startWritePos, outerEntries); + }); + + GenericRecord actualVariant = (GenericRecord) actual.get("var"); + assertEquivalent(TEST_METADATA, expectedValue, actualVariant); + } + + @Test + public void testShreddedObjectWithOptionalFieldStructs() throws IOException { + // fields use an incorrect OPTIONAL struct of value and typed_value to test definition levels + GroupType fieldA = + Types.buildGroup(Type.Repetition.OPTIONAL) + .optional(PrimitiveTypeName.BINARY) + .named("value") + .addField(shreddedPrimitive(PrimitiveTypeName.INT32)) + .named("a"); + GroupType fieldB = + Types.buildGroup(Type.Repetition.OPTIONAL) + .optional(PrimitiveTypeName.BINARY) + .named("value") + .addField(shreddedPrimitive(PrimitiveTypeName.BINARY, STRING)) + .named("b"); + GroupType fieldC = + Types.buildGroup(Type.Repetition.OPTIONAL) + .optional(PrimitiveTypeName.BINARY) + .named("value") + .addField(shreddedPrimitive(PrimitiveTypeName.DOUBLE)) + .named("c"); + GroupType 
fieldD = + Types.buildGroup(Type.Repetition.OPTIONAL) + .optional(PrimitiveTypeName.BINARY) + .named("value") + .addField(shreddedPrimitive(PrimitiveTypeName.BOOLEAN)) + .named("d"); + GroupType objectFields = + Types.buildGroup(Type.Repetition.OPTIONAL) + .addFields(fieldA, fieldB, fieldC, fieldD) + .named("typed_value"); + TestSchema schema = new TestSchema(objectFields); + + GenericRecord recordA = recordFromMap(fieldA, ImmutableMap.of("value", variant(34))); + GenericRecord recordB = recordFromMap(fieldB, ImmutableMap.of("typed_value", "iceberg")); + GenericRecord recordC = recordFromMap(fieldC, ImmutableMap.of()); // c.value and c.typed_value are missing + GenericRecord fields = + recordFromMap(objectFields, ImmutableMap.of("a", recordA, "b", recordB, "c", recordC)); // d is missing + GenericRecord variant = + recordFromMap(schema.unannotatedVariantType, ImmutableMap.of("metadata", TEST_METADATA, "typed_value", fields)); + GenericRecord record = recordFromMap(schema.unannotatedParquetSchema, ImmutableMap.of("id", 1, "var", variant)); + + GenericRecord actual = writeAndRead(schema, record); + + assertEquals(actual.get("id"), 1); + + byte[] expectedValue = variant(TEST_METADATA, b -> { + int startWritePos = b.getWritePos(); + ArrayList entries = new ArrayList<>(); + entries.add(new VariantBuilder.FieldEntry("a", 0, b.getWritePos() - startWritePos)); + b.appendLong(34); + entries.add(new VariantBuilder.FieldEntry("b", 1, b.getWritePos() - startWritePos)); + b.appendString("iceberg"); + b.finishWritingObject(startWritePos, entries); + }); + + GenericRecord actualVariant = (GenericRecord) actual.get("var"); + assertEquivalent(TEST_METADATA, expectedValue, actualVariant); + } + + @Test + public void testPartiallyShreddedObjectOutOfOrder() throws IOException { + // The schema is not in alphabetical order, and the unshredded field is also not. + // The resulting object should be logically the same (i.e. 
the offset list must be in + // alphabetical order), but the layout of the values in the binary may differ. + GroupType fieldD = shreddedField("d", shreddedPrimitive(PrimitiveTypeName.INT32)); + GroupType fieldA = shreddedField("a", shreddedPrimitive(PrimitiveTypeName.BINARY, STRING)); + GroupType objectFields = objectFields(fieldD, fieldA); + TestSchema schema = new TestSchema(objectFields); + + byte[] baseObject = variant(TEST_METADATA, b -> { + int startWritePos = b.getWritePos(); + ArrayList entries = new ArrayList<>(); + entries.add(new VariantBuilder.FieldEntry("b", 1, b.getWritePos() - startWritePos)); + b.appendDate(12345); + b.finishWritingObject(startWritePos, entries); + }); + + GenericRecord recordA = recordFromMap(fieldD, ImmutableMap.of("value", NULL_VALUE)); + GenericRecord recordB = recordFromMap(fieldA, ImmutableMap.of("typed_value", "iceberg")); + GenericRecord fields = recordFromMap(objectFields, ImmutableMap.of("d", recordA, "a", recordB)); + GenericRecord variant = + recordFromMap( + schema.unannotatedVariantType, + ImmutableMap.of( + "metadata", + TEST_METADATA, + "value", + baseObject, + "typed_value", + fields)); + GenericRecord record = recordFromMap(schema.unannotatedParquetSchema, ImmutableMap.of("id", 1, "var", variant)); + + GenericRecord actual = writeAndRead(schema, record); + + assertEquals(actual.get("id"), 1); + + byte[] expectedValue = variant(TEST_METADATA, b -> { + int startWritePos = b.getWritePos(); + ArrayList entries = new ArrayList<>(); + entries.add(new VariantBuilder.FieldEntry("a", 0, b.getWritePos() - startWritePos)); + b.appendString("iceberg"); + entries.add(new VariantBuilder.FieldEntry("b", 1, b.getWritePos() - startWritePos)); + b.appendDate(12345); + entries.add(new VariantBuilder.FieldEntry("d", 3, b.getWritePos() - startWritePos)); + b.appendNull(); + b.finishWritingObject(startWritePos, entries); + }); + + GenericRecord actualVariant = (GenericRecord) actual.get("var"); + assertEquivalent(TEST_METADATA, 
expectedValue, actualVariant); + } + + @Test + public void testPartiallyShreddedObject() throws IOException { + GroupType fieldA = shreddedField("a", shreddedPrimitive(PrimitiveTypeName.INT32)); + GroupType fieldB = shreddedField("b", shreddedPrimitive(PrimitiveTypeName.BINARY, STRING)); + GroupType objectFields = objectFields(fieldA, fieldB); + TestSchema schema = new TestSchema(objectFields); + + byte[] baseObject = variant(TEST_METADATA, b -> { + int startWritePos = b.getWritePos(); + ArrayList entries = new ArrayList<>(); + entries.add(new VariantBuilder.FieldEntry("d", 3, b.getWritePos() - startWritePos)); + b.appendDate(12345); + b.finishWritingObject(startWritePos, entries); + }); + + GenericRecord recordA = recordFromMap(fieldA, ImmutableMap.of("value", NULL_VALUE)); + GenericRecord recordB = recordFromMap(fieldB, ImmutableMap.of("typed_value", "iceberg")); + GenericRecord fields = recordFromMap(objectFields, ImmutableMap.of("a", recordA, "b", recordB)); + GenericRecord variant = + recordFromMap( + schema.unannotatedVariantType, + ImmutableMap.of( + "metadata", + TEST_METADATA, + "value", + baseObject, + "typed_value", + fields)); + GenericRecord record = recordFromMap(schema.unannotatedParquetSchema, ImmutableMap.of("id", 1, "var", variant)); + + GenericRecord actual = writeAndRead(schema, record); + + assertEquals(actual.get("id"), 1); + + byte[] expectedValue = variant(TEST_METADATA, b -> { + int startWritePos = b.getWritePos(); + ArrayList entries = new ArrayList<>(); + entries.add(new VariantBuilder.FieldEntry("a", 0, b.getWritePos() - startWritePos)); + b.appendNull(); + entries.add(new VariantBuilder.FieldEntry("b", 1, b.getWritePos() - startWritePos)); + b.appendString("iceberg"); + entries.add(new VariantBuilder.FieldEntry("d", 3, b.getWritePos() - startWritePos)); + b.appendDate(12345); + b.finishWritingObject(startWritePos, entries); + }); + + GenericRecord actualVariant = (GenericRecord) actual.get("var"); + assertEquivalent(TEST_METADATA, 
expectedValue, actualVariant); + } + @Test + public void testPartiallyShreddedObjectFieldConflict() throws IOException { + GroupType fieldA = shreddedField("a", shreddedPrimitive(PrimitiveTypeName.INT32)); + GroupType fieldB = shreddedField("b", shreddedPrimitive(PrimitiveTypeName.BINARY, STRING)); + GroupType objectFields = objectFields(fieldA, fieldB); + TestSchema schema = new TestSchema(objectFields); + + byte[] baseObject = variant(TEST_METADATA, b -> { + int startWritePos = b.getWritePos(); + ArrayList entries = new ArrayList<>(); + entries.add(new VariantBuilder.FieldEntry("b", 1, b.getWritePos() - startWritePos)); + b.appendDate(12345); + b.finishWritingObject(startWritePos, entries); + }); + + GenericRecord recordA = recordFromMap(fieldA, ImmutableMap.of("value", NULL_VALUE)); + GenericRecord recordB = recordFromMap(fieldB, ImmutableMap.of("typed_value", "iceberg")); + GenericRecord fields = recordFromMap(objectFields, ImmutableMap.of("a", recordA, "b", recordB)); + GenericRecord variant = + recordFromMap( + schema.unannotatedVariantType, + ImmutableMap.of( + "metadata", + TEST_METADATA, + "value", + baseObject, + "typed_value", + fields)); + GenericRecord record = recordFromMap(schema.unannotatedParquetSchema, ImmutableMap.of("id", 1, "var", variant)); + + GenericRecord actual = writeAndRead(schema, record); + + assertEquals(actual.get("id"), 1); + + // The reader is expected to ignore fields in value that are present in the typed_value schema. + // This matches Iceberg behaviour, but we could also consider failing with an error. 
+ byte[] expectedValue = variant(TEST_METADATA, b -> { + int startWritePos = b.getWritePos(); + ArrayList entries = new ArrayList<>(); + entries.add(new VariantBuilder.FieldEntry("a", 0, b.getWritePos() - startWritePos)); + b.appendNull(); + entries.add(new VariantBuilder.FieldEntry("b", 1, b.getWritePos() - startWritePos)); + b.appendString("iceberg"); + b.finishWritingObject(startWritePos, entries); + }); + + GenericRecord actualVariant = (GenericRecord) actual.get("var"); + assertEquivalent(TEST_METADATA, expectedValue, actualVariant); + } + + @Test + public void testPartiallyShreddedObjectMissingFieldConflict() throws IOException { + GroupType fieldA = shreddedField("a", shreddedPrimitive(PrimitiveTypeName.INT32)); + GroupType fieldB = shreddedField("b", shreddedPrimitive(PrimitiveTypeName.BINARY, STRING)); + GroupType objectFields = objectFields(fieldA, fieldB); + TestSchema schema = new TestSchema(objectFields); + + byte[] baseObject = variant(TEST_METADATA, b -> { + int startWritePos = b.getWritePos(); + ArrayList entries = new ArrayList<>(); + entries.add(new VariantBuilder.FieldEntry("b", 1, b.getWritePos() - startWritePos)); + b.appendDate(12345); + b.finishWritingObject(startWritePos, entries); + }); + + GenericRecord recordA = recordFromMap(fieldA, ImmutableMap.of("value", NULL_VALUE)); + // value and typed_value are null, but a struct for b is required + GenericRecord recordB = recordFromMap(fieldB, ImmutableMap.of()); + GenericRecord fields = recordFromMap(objectFields, ImmutableMap.of("a", recordA, "b", recordB)); + GenericRecord variant = + recordFromMap( + schema.unannotatedVariantType, + ImmutableMap.of( + "metadata", + TEST_METADATA, + "value", + baseObject, + "typed_value", + fields)); + GenericRecord record = recordFromMap(schema.unannotatedParquetSchema, ImmutableMap.of("id", 1, "var", variant)); + + GenericRecord actual = writeAndRead(schema, record); + + assertEquals(actual.get("id"), 1); + + // The reader is expected to ignore fields in 
value that are present in the typed_value schema, even if they are + // missing in typed_value. + byte[] expectedValue = variant(TEST_METADATA, b -> { + int startWritePos = b.getWritePos(); + ArrayList entries = new ArrayList<>(); + entries.add(new VariantBuilder.FieldEntry("a", 0, b.getWritePos() - startWritePos)); + b.appendNull(); + b.finishWritingObject(startWritePos, entries); + }); + + GenericRecord actualVariant = (GenericRecord) actual.get("var"); + assertEquivalent(TEST_METADATA, expectedValue, actualVariant); + } + + @Test + public void testNonObjectWithNullShreddedFields() throws IOException { + GroupType fieldA = shreddedField("a", shreddedPrimitive(PrimitiveTypeName.INT32)); + GroupType fieldB = shreddedField("b", shreddedPrimitive(PrimitiveTypeName.BINARY, STRING)); + GroupType objectFields = objectFields(fieldA, fieldB); + TestSchema schema = new TestSchema(objectFields); + + GenericRecord variant = + recordFromMap( + schema.unannotatedVariantType, + ImmutableMap.of("metadata", TEST_METADATA, "value", variant(34))); + GenericRecord record = recordFromMap(schema.unannotatedParquetSchema, ImmutableMap.of("id", 1, "var", variant)); + + GenericRecord actual = writeAndRead(schema, record); + assertEquals(actual.get("id"), 1); + + GenericRecord actualVariant = (GenericRecord) actual.get("var"); + assertEquivalent(TEST_METADATA, variant(34), actualVariant); + } + + @Test + public void testNonObjectWithNonNullShreddedFields() { + GroupType fieldA = shreddedField("a", shreddedPrimitive(PrimitiveTypeName.INT32)); + GroupType fieldB = shreddedField("b", shreddedPrimitive(PrimitiveTypeName.BINARY, STRING)); + GroupType objectFields = objectFields(fieldA, fieldB); + TestSchema schema = new TestSchema(objectFields); + + GenericRecord recordA = recordFromMap(fieldA, ImmutableMap.of("value", NULL_VALUE)); + GenericRecord recordB = recordFromMap(fieldB, ImmutableMap.of("value", variant(9876543210L))); + GenericRecord fields = recordFromMap(objectFields, 
ImmutableMap.of("a", recordA, "b", recordB)); + GenericRecord variant = + recordFromMap( + schema.unannotatedVariantType, + ImmutableMap.of( + "metadata", + TEST_METADATA, + "value", + variant(34), + "typed_value", + fields)); + GenericRecord record = recordFromMap(schema.unannotatedParquetSchema, ImmutableMap.of("id", 1, "var", variant)); + + assertThrows(() -> writeAndRead(schema, record), + IllegalArgumentException.class, + "Invalid variant, conflicting value and typed_value"); + } + + @Test + public void testEmptyPartiallyShreddedObjectConflict() { + GroupType fieldA = shreddedField("a", shreddedPrimitive(PrimitiveTypeName.INT32)); + GroupType fieldB = shreddedField("b", shreddedPrimitive(PrimitiveTypeName.BINARY, STRING)); + GroupType objectFields = objectFields(fieldA, fieldB); + TestSchema schema = new TestSchema(objectFields); + + GenericRecord recordA = recordFromMap(fieldA, ImmutableMap.of()); // missing + GenericRecord recordB = recordFromMap(fieldB, ImmutableMap.of()); // missing + GenericRecord fields = recordFromMap(objectFields, ImmutableMap.of("a", recordA, "b", recordB)); + GenericRecord variant = + recordFromMap( + schema.unannotatedVariantType, + ImmutableMap.of( + "metadata", + TEST_METADATA, + "value", + NULL_VALUE, // conflicting non-object + "typed_value", + fields)); + GenericRecord record = recordFromMap(schema.unannotatedParquetSchema, ImmutableMap.of("id", 1, "var", variant)); + + assertThrows(() -> writeAndRead(schema, record), + IllegalArgumentException.class, + "Invalid variant, conflicting value and typed_value"); + } + + @Test + public void testMixedRecords() throws IOException { + // tests multiple rows to check that Parquet columns are correctly advanced + GroupType fieldA = shreddedField("a", shreddedPrimitive(PrimitiveTypeName.INT32)); + GroupType fieldB = shreddedField("b", shreddedPrimitive(PrimitiveTypeName.BINARY, STRING)); + GroupType innerFields = objectFields(fieldA, fieldB); + GroupType fieldC = shreddedField("c", 
innerFields); + GroupType fieldD = shreddedField("d", shreddedPrimitive(PrimitiveTypeName.DOUBLE)); + GroupType outerFields = objectFields(fieldC, fieldD); + TestSchema schema = new TestSchema(outerFields); + + GenericRecord zero = recordFromMap(schema.unannotatedParquetSchema, ImmutableMap.of("id", 0)); + + GenericRecord a1 = recordFromMap(fieldA, ImmutableMap.of()); // missing + GenericRecord b1 = recordFromMap(fieldB, ImmutableMap.of("typed_value", "iceberg")); + GenericRecord inner1 = recordFromMap(innerFields, ImmutableMap.of("a", a1, "b", b1)); + GenericRecord c1 = recordFromMap(fieldC, ImmutableMap.of("typed_value", inner1)); + GenericRecord d1 = recordFromMap(fieldD, ImmutableMap.of()); // missing + GenericRecord outer1 = recordFromMap(outerFields, ImmutableMap.of("c", c1, "d", d1)); + GenericRecord variant1 = + recordFromMap(schema.unannotatedVariantType, ImmutableMap.of("metadata", TEST_METADATA, "typed_value", outer1)); + GenericRecord one = recordFromMap(schema.unannotatedParquetSchema, ImmutableMap.of("id", 1, "var", variant1)); + + byte[] expectedOne = variant(TEST_METADATA, b -> { + int startWritePos = b.getWritePos(); + ArrayList outerEntries = new ArrayList<>(); + outerEntries.add(new VariantBuilder.FieldEntry("c", 2, b.getWritePos() - startWritePos)); + ArrayList innerEntries = new ArrayList<>(); + innerEntries.add(new VariantBuilder.FieldEntry("b", 1, b.getWritePos() - startWritePos)); + b.appendString("iceberg"); + b.finishWritingObject(startWritePos, innerEntries); + b.finishWritingObject(startWritePos, outerEntries); + }); + + GenericRecord c2 = recordFromMap(fieldC, ImmutableMap.of("value", variant(8))); + GenericRecord d2 = recordFromMap(fieldD, ImmutableMap.of("typed_value", -0.0D)); + GenericRecord outer2 = recordFromMap(outerFields, ImmutableMap.of("c", c2, "d", d2)); + GenericRecord variant2 = + recordFromMap(schema.unannotatedVariantType, ImmutableMap.of("metadata", TEST_METADATA, "typed_value", outer2)); + GenericRecord two = 
recordFromMap(schema.unannotatedParquetSchema, ImmutableMap.of("id", 2, "var", variant2)); + + byte[] expectedTwo = variant(TEST_METADATA, b -> { + int startWritePos = b.getWritePos(); + ArrayList outerEntries = new ArrayList<>(); + outerEntries.add(new VariantBuilder.FieldEntry("c", 2, b.getWritePos() - startWritePos)); + b.appendLong(8); + outerEntries.add(new VariantBuilder.FieldEntry("d", 3, b.getWritePos() - startWritePos)); + b.appendDouble(-0.0D); + b.finishWritingObject(startWritePos, outerEntries); + }); + + GenericRecord a3 = recordFromMap(fieldA, ImmutableMap.of("typed_value", 34)); + GenericRecord b3 = recordFromMap(fieldB, ImmutableMap.of("value", variant(""))); + GenericRecord inner3 = recordFromMap(innerFields, ImmutableMap.of("a", a3, "b", b3)); + GenericRecord c3 = recordFromMap(fieldC, ImmutableMap.of("typed_value", inner3)); + GenericRecord d3 = recordFromMap(fieldD, ImmutableMap.of("typed_value", 0.0D)); + GenericRecord outer3 = recordFromMap(outerFields, ImmutableMap.of("c", c3, "d", d3)); + GenericRecord variant3 = + recordFromMap(schema.unannotatedVariantType, ImmutableMap.of("metadata", TEST_METADATA, "typed_value", outer3)); + GenericRecord three = recordFromMap(schema.unannotatedParquetSchema, ImmutableMap.of("id", 3, "var", variant3)); + + byte[] expectedThree = variant(TEST_METADATA, b -> { + int startWritePos = b.getWritePos(); + ArrayList outerEntries = new ArrayList<>(); + outerEntries.add(new VariantBuilder.FieldEntry("c", 2, b.getWritePos() - startWritePos)); + ArrayList innerEntries = new ArrayList<>(); + innerEntries.add(new VariantBuilder.FieldEntry("a", 0, b.getWritePos() - startWritePos)); + b.appendLong(34); + innerEntries.add(new VariantBuilder.FieldEntry("b", 1, b.getWritePos() - startWritePos)); + b.appendString(""); + b.finishWritingObject(startWritePos, innerEntries); + outerEntries.add(new VariantBuilder.FieldEntry("d", 3, b.getWritePos() - startWritePos)); + b.appendDouble(0.0D); + b.finishWritingObject(startWritePos, 
outerEntries); + }); + + List records = writeAndRead(schema, Arrays.asList(zero, one, two, three)); + assertEquals(records.size(), 4); + + GenericRecord actualZero = records.get(0); + assertEquals(actualZero.get("id"), 0); + assertEquals(actualZero.get("var"), null); + + GenericRecord actualOne = records.get(1); + assertEquals(actualOne.get("id"), 1); + GenericRecord actualOneVariant = (GenericRecord) actualOne.get("var"); + assertEquivalent(TEST_METADATA, expectedOne, actualOneVariant); + + GenericRecord actualTwo = records.get(2); + assertEquals(actualTwo.get("id"), 2); + GenericRecord actualTwoVariant = (GenericRecord) actualTwo.get("var"); + assertEquivalent(TEST_METADATA, expectedTwo, actualTwoVariant); + + GenericRecord actualThree = records.get(3); + assertEquals(actualThree.get("id"), 3); + GenericRecord actualThreeVariant = (GenericRecord) actualThree.get("var"); + assertEquivalent(TEST_METADATA, expectedThree, actualThreeVariant); + } + + @Test + public void testSimpleArray() throws IOException { + Type shreddedType = shreddedPrimitive(PrimitiveTypeName.BINARY, STRING); + GroupType elementType = element(shreddedType); + TestSchema schema = new TestSchema(list(elementType)); + + List arr = + Arrays.asList( + recordFromMap(elementType, ImmutableMap.of("typed_value", "comedy")), + recordFromMap(elementType, ImmutableMap.of("typed_value", "drama"))); + + GenericRecord var = + recordFromMap( + schema.unannotatedVariantType, ImmutableMap.of("metadata", EMPTY_METADATA, "typed_value", arr)); + GenericRecord row = recordFromMap(schema.unannotatedParquetSchema, ImmutableMap.of("id", 1, "var", var)); + + byte[] expected = variant(TEST_METADATA, b -> { + int startWritePos = b.getWritePos(); + ArrayList entries= new ArrayList<>(); + entries.add(b.getWritePos() - startWritePos); + b.appendString("comedy"); + entries.add(b.getWritePos() - startWritePos); + b.appendString("drama"); + b.finishWritingArray(startWritePos, entries); + }); + + GenericRecord actual = 
writeAndRead(schema, row); + assertEquals(actual.get("id"), 1); + GenericRecord actualVariant = (GenericRecord) actual.get("var"); + assertEquivalent(EMPTY_METADATA, expected, actualVariant); + } + + @Test + public void testNullArray() throws IOException { + Type shreddedType = shreddedPrimitive(PrimitiveTypeName.BINARY, STRING); + TestSchema schema = new TestSchema(list(element(shreddedType))); + + GenericRecord var = + recordFromMap( + schema.unannotatedVariantType, + ImmutableMap.of( + "metadata", + EMPTY_METADATA, + "value", + NULL_VALUE)); + GenericRecord row = recordFromMap(schema.unannotatedParquetSchema, ImmutableMap.of("id", 1, "var", var)); + + GenericRecord actual = writeAndRead(schema, row); + assertEquals(actual.get("id"), 1); + GenericRecord actualVariant = (GenericRecord) actual.get("var"); + assertEquivalent(EMPTY_METADATA, NULL_VALUE, actualVariant); + } + + @Test + public void testEmptyArray() throws IOException { + Type shreddedType = shreddedPrimitive(PrimitiveTypeName.BINARY, STRING); + TestSchema schema = new TestSchema(list(element(shreddedType))); + + List arr = Arrays.asList(); + GenericRecord var = + recordFromMap( + schema.unannotatedVariantType, ImmutableMap.of("metadata", EMPTY_METADATA, "typed_value", arr)); + GenericRecord row = recordFromMap(schema.unannotatedParquetSchema, ImmutableMap.of("id", 1, "var", var)); + + GenericRecord actual = writeAndRead(schema, row); + + byte[] emptyArray = variant(TEST_METADATA, b -> { + int startWritePos = b.getWritePos(); + ArrayList entries= new ArrayList<>(); + b.finishWritingArray(startWritePos, entries); + }); + + assertEquals(actual.get("id"), 1); + GenericRecord actualVariant = (GenericRecord) actual.get("var"); + assertEquivalent(EMPTY_METADATA, emptyArray, actualVariant); + } + + @Test + public void testArrayWithNull() throws IOException { + Type shreddedType = shreddedPrimitive(PrimitiveTypeName.BINARY, STRING); + GroupType elementType = element(shreddedType); + TestSchema schema = new 
TestSchema(list(elementType)); + + List arr = + Arrays.asList( + recordFromMap(elementType, ImmutableMap.of("typed_value", "comedy")), + recordFromMap(elementType, ImmutableMap.of("value", NULL_VALUE)), + recordFromMap(elementType, ImmutableMap.of("typed_value", "drama"))); + + GenericRecord var = + recordFromMap( + schema.unannotatedVariantType, ImmutableMap.of("metadata", EMPTY_METADATA, "typed_value", arr)); + GenericRecord row = recordFromMap(schema.unannotatedParquetSchema, ImmutableMap.of("id", 1, "var", var)); + + byte[] expected = variant(TEST_METADATA, b -> { + int startWritePos = b.getWritePos(); + ArrayList entries= new ArrayList<>(); + entries.add(b.getWritePos() - startWritePos); + b.appendString("comedy"); + entries.add(b.getWritePos() - startWritePos); + b.appendNull(); + entries.add(b.getWritePos() - startWritePos); + b.appendString("drama"); + b.finishWritingArray(startWritePos, entries); + }); + + GenericRecord actual = writeAndRead(schema, row); + assertEquals(actual.get("id"), 1); + GenericRecord actualVariant = (GenericRecord) actual.get("var"); + assertEquivalent(EMPTY_METADATA, expected, actualVariant); + } + + @Test + public void testNestedArray() throws IOException { + Type shreddedType = shreddedPrimitive(PrimitiveTypeName.BINARY, STRING); + GroupType elementType = element(shreddedType); + GroupType outerElementType = element(list(elementType)); + TestSchema schema = new TestSchema(list(outerElementType)); + + List inner1 = + Arrays.asList( + recordFromMap(elementType, ImmutableMap.of("typed_value", "comedy")), + recordFromMap(elementType, ImmutableMap.of("typed_value", "drama"))); + List outer1 = + Arrays.asList( + recordFromMap(outerElementType, ImmutableMap.of("typed_value", inner1)), + recordFromMap(outerElementType, ImmutableMap.of("typed_value", Arrays.asList()))); + GenericRecord var = + recordFromMap( + schema.unannotatedVariantType, + ImmutableMap.of("metadata", EMPTY_METADATA, "typed_value", outer1)); + GenericRecord row = 
recordFromMap(schema.unannotatedParquetSchema, ImmutableMap.of("id", 1, "var", var)); + + byte[] expected = variant(EMPTY_METADATA, b -> { + int startWritePos = b.getWritePos(); + ArrayList outerEntries= new ArrayList<>(); + outerEntries.add(b.getWritePos() - startWritePos); + ArrayList entries1 = new ArrayList<>(); + entries1.add(b.getWritePos() - startWritePos); + b.appendString("comedy"); + entries1.add(b.getWritePos() - startWritePos); + b.appendString("drama"); + b.finishWritingArray(startWritePos, entries1); + outerEntries.add(b.getWritePos() - startWritePos); + ArrayList entries2 = new ArrayList<>(); + b.finishWritingArray(b.getWritePos(), entries2); + b.finishWritingArray(startWritePos, outerEntries); + }); + + GenericRecord actual = writeAndRead(schema, row); + assertEquals(actual.get("id"), 1); + GenericRecord actualVariant = (GenericRecord) actual.get("var"); + assertEquivalent(EMPTY_METADATA, expected, actualVariant); + } + + @Test + public void testArrayWithNestedObject() throws IOException { + GroupType fieldA = shreddedField("a", shreddedPrimitive(PrimitiveTypeName.INT32)); + GroupType fieldB = shreddedField("b", shreddedPrimitive(PrimitiveTypeName.BINARY, STRING)); + GroupType shreddedFields = objectFields(fieldA, fieldB); + GroupType elementType = element(shreddedFields); + GroupType listType = list(elementType); + TestSchema schema = new TestSchema(listType); + + // Row 1 with nested fully shredded object + GenericRecord shredded1 = + recordFromMap( + shreddedFields, + ImmutableMap.of( + "a", + recordFromMap(fieldA, ImmutableMap.of("typed_value", 1)), + "b", + recordFromMap(fieldB, ImmutableMap.of("typed_value", "comedy")))); + GenericRecord shredded2 = + recordFromMap( + shreddedFields, + ImmutableMap.of( + "a", + recordFromMap(fieldA, ImmutableMap.of("typed_value", 2)), + "b", + recordFromMap(fieldB, ImmutableMap.of("typed_value", "drama")))); + List arr1 = + Arrays.asList( + recordFromMap(elementType, ImmutableMap.of("typed_value", shredded1)), 
+ recordFromMap(elementType, ImmutableMap.of("typed_value", shredded2))); + GenericRecord var1 = + recordFromMap(schema.unannotatedVariantType, ImmutableMap.of("metadata", TEST_METADATA, "typed_value", arr1)); + GenericRecord row1 = recordFromMap(schema.unannotatedParquetSchema, ImmutableMap.of("id", 1, "var", var1)); + + byte[] expected1 = variant(TEST_METADATA, b -> { + int startWritePos = b.getWritePos(); + ArrayList arrayEntries = new ArrayList<>(); + arrayEntries.add(b.getWritePos() - startWritePos); + ArrayList objEntries1 = new ArrayList<>(); + objEntries1.add(new VariantBuilder.FieldEntry("a", 0, b.getWritePos() - startWritePos)); + b.appendLong(1); + objEntries1.add(new VariantBuilder.FieldEntry("b", 1, b.getWritePos() - startWritePos)); + b.appendString("comedy"); + b.finishWritingObject(startWritePos, objEntries1); + + int startWritePos2 = b.getWritePos(); + arrayEntries.add(b.getWritePos() - startWritePos); + ArrayList objEntries2 = new ArrayList<>(); + objEntries2.add(new VariantBuilder.FieldEntry("a", 0, b.getWritePos() - startWritePos2)); + b.appendLong(2); + objEntries2.add(new VariantBuilder.FieldEntry("b", 1, b.getWritePos() - startWritePos2)); + b.appendString("drama"); + b.finishWritingObject(startWritePos2, objEntries2); + + b.finishWritingArray(startWritePos, arrayEntries); + }); + + // Row 2 with nested partially shredded object + GenericRecord shredded3 = + recordFromMap( + shreddedFields, + ImmutableMap.of( + "a", + recordFromMap(fieldA, ImmutableMap.of("typed_value", 3)), + "b", + recordFromMap(fieldB, ImmutableMap.of("typed_value", "action")))); + + byte[] baseObject3 = variant(TEST_METADATA, b -> { + int startWritePos = b.getWritePos(); + ArrayList entries = new ArrayList<>(); + entries.add(new VariantBuilder.FieldEntry("c", 2, b.getWritePos() - startWritePos)); + b.appendString("str"); + b.finishWritingObject(startWritePos, entries); + }); + + GenericRecord shredded4 = + recordFromMap( + shreddedFields, + ImmutableMap.of( + "a", + 
recordFromMap(fieldA, ImmutableMap.of("typed_value", 4)), + "b", + recordFromMap(fieldB, ImmutableMap.of("typed_value", "horror")))); + + byte[] baseObject4 = variant(TEST_METADATA, b -> { + int startWritePos = b.getWritePos(); + ArrayList entries = new ArrayList<>(); + entries.add(new VariantBuilder.FieldEntry("d", 3, b.getWritePos() - startWritePos)); + b.appendDate(12345); + b.finishWritingObject(startWritePos, entries); + }); + + List arr2 = + Arrays.asList( + recordFromMap(elementType, ImmutableMap.of("value", baseObject3, "typed_value", shredded3)), + recordFromMap(elementType, ImmutableMap.of("value", baseObject4, "typed_value", shredded4))); + GenericRecord var2 = + recordFromMap(schema.unannotatedVariantType, ImmutableMap.of("metadata", TEST_METADATA, "typed_value", arr2)); + GenericRecord row2 = recordFromMap(schema.unannotatedParquetSchema, ImmutableMap.of("id", 2, "var", var2)); + + byte[] expected2 = variant(TEST_METADATA, b -> { + int startWritePos = b.getWritePos(); + ArrayList arrayEntries = new ArrayList<>(); + arrayEntries.add(b.getWritePos() - startWritePos); + ArrayList objEntries1 = new ArrayList<>(); + objEntries1.add(new VariantBuilder.FieldEntry("a", 0, b.getWritePos() - startWritePos)); + b.appendLong(3); + objEntries1.add(new VariantBuilder.FieldEntry("b", 1, b.getWritePos() - startWritePos)); + b.appendString("action"); + objEntries1.add(new VariantBuilder.FieldEntry("c", 2, b.getWritePos() - startWritePos)); + b.appendString("str"); + b.finishWritingObject(startWritePos, objEntries1); + + int startWritePos2 = b.getWritePos(); + arrayEntries.add(b.getWritePos() - startWritePos); + ArrayList objEntries2 = new ArrayList<>(); + objEntries2.add(new VariantBuilder.FieldEntry("a", 0, b.getWritePos() - startWritePos2)); + b.appendLong(4); + objEntries2.add(new VariantBuilder.FieldEntry("b", 1, b.getWritePos() - startWritePos2)); + b.appendString("horror"); + objEntries2.add(new VariantBuilder.FieldEntry("d", 3, b.getWritePos() - 
startWritePos2)); + b.appendDate(12345); + b.finishWritingObject(startWritePos2, objEntries2); + + b.finishWritingArray(startWritePos, arrayEntries); + }); + + + // verify + List actual = writeAndRead(schema, Arrays.asList(row1, row2)); + GenericRecord actual1 = actual.get(0); + assertEquals(actual1.get("id"), 1); + GenericRecord actualVariant1 = (GenericRecord) actual1.get("var"); + assertEquivalent(TEST_METADATA, expected1, actualVariant1); + + GenericRecord actual2 = actual.get(1); + assertEquals(actual2.get("id"), 2); + GenericRecord actualVariant2 = (GenericRecord) actual2.get("var"); + assertEquivalent(TEST_METADATA, expected2, actualVariant2); + } + + @Test + public void testArrayWithNonArray() throws IOException { + Type shreddedType = shreddedPrimitive(PrimitiveTypeName.BINARY, STRING); + GroupType elementType = element(shreddedType); + TestSchema schema = new TestSchema(list(elementType)); + + List arr1 = + Arrays.asList( + recordFromMap(elementType, ImmutableMap.of("typed_value", "comedy")), + recordFromMap(elementType, ImmutableMap.of("typed_value", "drama"))); + GenericRecord var1 = + recordFromMap( + schema.unannotatedVariantType, ImmutableMap.of("metadata", EMPTY_METADATA, "typed_value", arr1)); + GenericRecord row1 = recordFromMap(schema.unannotatedParquetSchema, ImmutableMap.of("id", 1, "var", var1)); + + byte[] expectedArray1 = variant(TEST_METADATA, b -> { + int startWritePos = b.getWritePos(); + ArrayList entries= new ArrayList<>(); + entries.add(b.getWritePos() - startWritePos); + b.appendString("comedy"); + entries.add(b.getWritePos() - startWritePos); + b.appendString("drama"); + b.finishWritingArray(startWritePos, entries); + }); + + GenericRecord var2 = + recordFromMap( + schema.unannotatedVariantType, + ImmutableMap.of( + "metadata", EMPTY_METADATA, "value", variant(34))); + GenericRecord row2 = recordFromMap(schema.unannotatedParquetSchema, ImmutableMap.of("id", 2, "var", var2)); + + byte[] expectedValue2 = variant(34); + + GenericRecord 
var3 = + recordFromMap(schema.unannotatedVariantType, ImmutableMap.of("metadata", TEST_METADATA, "value", TEST_OBJECT)); + GenericRecord row3 = recordFromMap(schema.unannotatedParquetSchema, ImmutableMap.of("id", 3, "var", var3)); + + byte[] expectedObject3 = variant(TEST_METADATA, b -> { + int startWritePos = b.getWritePos(); + ArrayList entries = new ArrayList<>(); + entries.add(new VariantBuilder.FieldEntry("a", 0, b.getWritePos() - startWritePos)); + b.appendNull(); + entries.add(new VariantBuilder.FieldEntry("d", 3, b.getWritePos() - startWritePos)); + b.appendString("iceberg"); + b.finishWritingObject(startWritePos, entries); + }); + + // Test array is read properly after a non-array + List arr4 = + Arrays.asList( + recordFromMap(elementType, ImmutableMap.of("typed_value", "action")), + recordFromMap(elementType, ImmutableMap.of("typed_value", "horror"))); + GenericRecord var4 = + recordFromMap(schema.unannotatedVariantType, ImmutableMap.of("metadata", TEST_METADATA, "typed_value", arr4)); + GenericRecord row4 = recordFromMap(schema.unannotatedParquetSchema, ImmutableMap.of("id", 4, "var", var4)); + + byte[] expectedArray4 = variant(TEST_METADATA, b -> { + int startWritePos = b.getWritePos(); + ArrayList entries= new ArrayList<>(); + entries.add(b.getWritePos() - startWritePos); + b.appendString("action"); + entries.add(b.getWritePos() - startWritePos); + b.appendString("horror"); + b.finishWritingArray(startWritePos, entries); + }); + + List actual = writeAndRead(schema, Arrays.asList(row1, row2, row3, row4)); + GenericRecord actual1 = actual.get(0); + assertEquals(actual1.get("id"), 1); + GenericRecord actualVariant1 = (GenericRecord) actual1.get("var"); + assertEquivalent(EMPTY_METADATA, expectedArray1, actualVariant1); + + GenericRecord actual2 = actual.get(1); + assertEquals(actual2.get("id"), 2); + GenericRecord actualVariant2 = (GenericRecord) actual2.get("var"); + assertEquivalent(EMPTY_METADATA, expectedValue2, actualVariant2); + + GenericRecord 
actual3 = actual.get(2); + assertEquals(actual3.get("id"), 3); + GenericRecord actualVariant3 = (GenericRecord) actual3.get("var"); + assertEquivalent(TEST_METADATA, expectedObject3, actualVariant3); + + GenericRecord actual4 = actual.get(3); + assertEquals(actual4.get("id"), 4); + GenericRecord actualVariant4 = (GenericRecord) actual4.get("var"); + assertEquivalent(TEST_METADATA, expectedArray4, actualVariant4); + } + + @Test + public void testArrayMissingValueColumn() throws IOException { + Type shreddedType = shreddedPrimitive(PrimitiveTypeName.BINARY, STRING); + GroupType elementType = element(shreddedType); + GroupType unannotatedVariantType = + Types.buildGroup(Type.Repetition.OPTIONAL) + .id(2) + .required(PrimitiveTypeName.BINARY) + .named("metadata") + .addField(list(elementType)) + .named("var"); + + GroupType variantType = + Types.buildGroup(Type.Repetition.OPTIONAL) + .id(2) + .as(LogicalTypeAnnotation.variantType((byte) 1)) + .required(PrimitiveTypeName.BINARY) + .named("metadata") + .addField(list(elementType)) + .named("var"); + + TestSchema schema = new TestSchema(variantType, unannotatedVariantType); + + List arr = + Arrays.asList( + recordFromMap(elementType, ImmutableMap.of("typed_value", "comedy")), + recordFromMap(elementType, ImmutableMap.of("typed_value", "drama"))); + GenericRecord var = + recordFromMap( + schema.unannotatedVariantType, ImmutableMap.of("metadata", EMPTY_METADATA, "typed_value", arr)); + GenericRecord row = recordFromMap(schema.unannotatedParquetSchema, ImmutableMap.of("id", 1, "var", var)); + + byte[] expectedArray = variant(TEST_METADATA, b -> { + int startWritePos = b.getWritePos(); + ArrayList entries= new ArrayList<>(); + entries.add(b.getWritePos() - startWritePos); + b.appendString("comedy"); + entries.add(b.getWritePos() - startWritePos); + b.appendString("drama"); + b.finishWritingArray(startWritePos, entries); + }); + + GenericRecord actual = writeAndRead(schema, row); + assertEquals(actual.get("id"), 1); + 
GenericRecord actualVariant = (GenericRecord) actual.get("var"); + assertEquivalent(EMPTY_METADATA, expectedArray, actualVariant); + } + + @Test + public void testArrayMissingElementValueColumn() throws IOException { + Type shreddedType = shreddedPrimitive(PrimitiveTypeName.BINARY, STRING); + GroupType elementType = + Types.buildGroup(Type.Repetition.REQUIRED).addField(shreddedType).named("element"); + + TestSchema schema = new TestSchema(list(elementType)); + + List arr = + Arrays.asList( + recordFromMap(elementType, ImmutableMap.of("typed_value", "comedy")), + recordFromMap(elementType, ImmutableMap.of("typed_value", "drama"))); + GenericRecord var = + recordFromMap( + schema.unannotatedVariantType, ImmutableMap.of("metadata", EMPTY_METADATA, "typed_value", arr)); + GenericRecord row = recordFromMap(schema.unannotatedParquetSchema, ImmutableMap.of("id", 1, "var", var)); + + byte[] expectedArray = variant(TEST_METADATA, b -> { + int startWritePos = b.getWritePos(); + ArrayList entries= new ArrayList<>(); + entries.add(b.getWritePos() - startWritePos); + b.appendString("comedy"); + entries.add(b.getWritePos() - startWritePos); + b.appendString("drama"); + b.finishWritingArray(startWritePos, entries); + }); + + GenericRecord actual = writeAndRead(schema, row); + assertEquals(actual.get("id"), 1); + GenericRecord actualVariant = (GenericRecord) actual.get("var"); + assertEquivalent(EMPTY_METADATA, expectedArray, actualVariant); + } + + @Test + public void testArrayWithElementNullValueAndNullTypedValue() throws IOException { + // Test the invalid case that both value and typed_value of an element are null + Type shreddedType = shreddedPrimitive(PrimitiveTypeName.BINARY, STRING); + GroupType elementType = element(shreddedType); + + TestSchema schema = new TestSchema(list(elementType)); + + GenericRecord element = recordFromMap(elementType, ImmutableMap.of()); + GenericRecord variant = + recordFromMap( + schema.unannotatedVariantType, + ImmutableMap.of("metadata", 
EMPTY_METADATA, "typed_value", Arrays.asList(element))); + GenericRecord record = recordFromMap(schema.unannotatedParquetSchema, ImmutableMap.of("id", 1, "var", variant)); + + byte[] expectedArray = variant(TEST_METADATA, b -> { + int startWritePos = b.getWritePos(); + ArrayList entries= new ArrayList<>(); + entries.add(b.getWritePos() - startWritePos); + b.appendNull(); + b.finishWritingArray(startWritePos, entries); + }); + + GenericRecord actual = writeAndRead(schema, record); + assertEquals(actual.get("id"), 1); + GenericRecord actualVariant = (GenericRecord) actual.get("var"); + assertEquivalent(EMPTY_METADATA, expectedArray, actualVariant); + } + + @Test + public void testArrayWithElementValueTypedValueConflict() { + // Test the invalid case that both value and typed_value of an element are not null + Type shreddedType = shreddedPrimitive(PrimitiveTypeName.BINARY, STRING); + GroupType elementType = element(shreddedType); + TestSchema schema = new TestSchema(list(elementType)); + + GenericRecord element = + recordFromMap(elementType, ImmutableMap.of("value", variant(3), "typed_value", "comedy")); + GenericRecord variant = + recordFromMap( + schema.unannotatedVariantType, + ImmutableMap.of("metadata", EMPTY_METADATA, "typed_value", Arrays.asList(element))); + GenericRecord record = recordFromMap(schema.unannotatedParquetSchema, ImmutableMap.of("id", 1, "var", variant)); + + assertThrows(() -> writeAndRead(schema, record), + IllegalArgumentException.class, + "Invalid variant, conflicting value and typed_value"); + } + + /** + * This is a custom Parquet writer builder that injects a specific Parquet schema and then uses + * the Avro object model. This ensures that the Parquet file's schema is exactly what was passed. 
+ */ + private static class TestWriterBuilder + extends ParquetWriter.Builder { + private TestSchema schema = null; + + protected TestWriterBuilder(Path path) { + super(path); + } + + TestWriterBuilder withFileType(TestSchema schema) { + this.schema = schema; + return self(); + } + + @Override + protected TestWriterBuilder self() { + return this; + } + + @Override + protected WriteSupport getWriteSupport(Configuration conf) { + return new AvroWriteSupport<>(schema.parquetSchema, avroSchema(schema.unannotatedParquetSchema), GenericData.get()); + } + } + + GenericRecord writeAndRead(TestSchema testSchema, GenericRecord record) + throws IOException { + List result = writeAndRead(testSchema, Arrays.asList(record)); + assert(result.size() == 1); + return result.get(0); + } + + List writeAndRead(TestSchema testSchema, List records) + throws IOException { + // Copied from TestSpecificReadWrite.java. Why does it do these weird things? + File tmp = File.createTempFile(getClass().getSimpleName(), ".tmp"); + tmp.deleteOnExit(); + tmp.delete(); + Path path = new Path(tmp.getPath()); + + try (ParquetWriter writer = + new TestWriterBuilder(path).withFileType(testSchema).withConf(CONF).build()) { + for (GenericRecord record : records) { + writer.write(record); + } + } + + Configuration conf = new Configuration(); + // We need to set an explicit read schema, because Avro wrote the shredding schema as the Avro schema in the + // write, and it will use that by default. If we write using a proper shredding writer, the Avro schema + // should just contain a record, and we won't need this. 
+ AvroReadSupport.setAvroReadSchema(conf, avroSchema(testSchema.parquetSchema)); + AvroParquetReader reader = new AvroParquetReader(conf, path); + + ArrayList result = new ArrayList<>(); + GenericRecord next = reader.read(); + while (next != null) { + result.add(next); + next = reader.read(); + } + return result; + } + + private static MessageType parquetSchema(Type variantType) { + return Types.buildMessage() + .required(PrimitiveTypeName.INT32) + .id(1) + .named("id") + .addField(variantType) + .named("table"); + } + + private static Type shreddedType(Variant value) { + switch (value.getType()) { + case BOOLEAN: + return shreddedPrimitive(PrimitiveTypeName.BOOLEAN); + case BYTE: + return shreddedPrimitive(PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8)); + case SHORT: + return shreddedPrimitive(PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(16)); + case INT: + return shreddedPrimitive(PrimitiveTypeName.INT32); + case LONG: + return shreddedPrimitive(PrimitiveTypeName.INT64); + case FLOAT: + return shreddedPrimitive(PrimitiveTypeName.FLOAT); + case DOUBLE: + return shreddedPrimitive(PrimitiveTypeName.DOUBLE); + case DECIMAL: + int precision = value.getDecimal().precision(); + int scale = value.getDecimal().scale (); + if (precision <= 9) { + return shreddedPrimitive( + PrimitiveTypeName.INT32, LogicalTypeAnnotation.decimalType(scale, 9)); + } else if (precision <= 18) { + return shreddedPrimitive( + PrimitiveTypeName.INT64, LogicalTypeAnnotation.decimalType(scale, 18)); + } else { + return shreddedPrimitive( + PrimitiveTypeName.BINARY, LogicalTypeAnnotation.decimalType(scale, 38)); + } + case DATE: + return shreddedPrimitive(PrimitiveTypeName.INT32, LogicalTypeAnnotation.dateType()); + case TIMESTAMP: + return shreddedPrimitive( + PrimitiveTypeName.INT64, LogicalTypeAnnotation.timestampType(true, TimeUnit.MICROS)); + case TIMESTAMP_NTZ: + return shreddedPrimitive( + PrimitiveTypeName.INT64, LogicalTypeAnnotation.timestampType(false, 
TimeUnit.MICROS)); + case BINARY: + return shreddedPrimitive(PrimitiveTypeName.BINARY); + case STRING: + return shreddedPrimitive(PrimitiveTypeName.BINARY, STRING); + case TIME: + return shreddedPrimitive( + PrimitiveTypeName.INT64, LogicalTypeAnnotation.timeType(false, TimeUnit.MICROS)); + case TIMESTAMP_NANOS: + return shreddedPrimitive( + PrimitiveTypeName.INT64, LogicalTypeAnnotation.timestampType(true, TimeUnit.NANOS)); + case TIMESTAMP_NANOS_NTZ: + return shreddedPrimitive( + PrimitiveTypeName.INT64, LogicalTypeAnnotation.timestampType(false, TimeUnit.NANOS)); + case UUID: + return shreddedPrimitive( + PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY, LogicalTypeAnnotation.uuidType()); + } + + throw new UnsupportedOperationException("Unsupported shredding type: " + value.getType()); + } + + private static Object toAvroValue(Variant v) { + switch (v.getType()) { + case BOOLEAN: + return v.getBoolean(); + case BYTE: + return v.getByte(); + case SHORT: + return v.getShort(); + case INT: + return v.getInt(); + case LONG: + return v.getLong(); + case FLOAT: + return v.getFloat(); + case DOUBLE: + return v.getDouble(); + case DECIMAL: + int precision = v.getDecimal().precision(); + int scale = v.getDecimal().scale (); + if (precision <= 9) { + return v.getDecimal().unscaledValue().intValueExact(); + } else if (precision <= 18) { + return v.getDecimal().unscaledValue().longValueExact(); + } else { + return v.getDecimal().unscaledValue().toByteArray(); + } + case DATE: + return v.getInt(); + case TIMESTAMP: + return v.getLong(); + case TIMESTAMP_NTZ: + return v.getLong(); + case BINARY: + return v.getBinary(); + case STRING: + return v.getString(); + case TIME: + return v.getLong(); + case TIMESTAMP_NANOS: + return v.getLong(); + case TIMESTAMP_NANOS_NTZ: + return v.getLong(); + case UUID: + return v.getUUID(); + default: + throw new UnsupportedOperationException("Unsupported shredding type: " + v.getType()); + } + } + + private static GroupType variant(String name, int 
fieldId) { + return Types.buildGroup(Type.Repetition.REQUIRED) + .id(fieldId) + .as(LogicalTypeAnnotation.variantType((byte) 1)) + .required(PrimitiveTypeName.BINARY) + .named("metadata") + .required(PrimitiveTypeName.BINARY) + .named("value") + .named(name); + } + + private static GroupType unannotatedVariant(String name, int fieldId) { + return Types.buildGroup(Type.Repetition.REQUIRED) + .id(fieldId) + .required(PrimitiveTypeName.BINARY) + .named("metadata") + .required(PrimitiveTypeName.BINARY) + .named("value") + .named(name); + } + + private static GroupType variant(String name, int fieldId, Type shreddedType) { + checkShreddedType(shreddedType); + return Types.buildGroup(Type.Repetition.OPTIONAL) + .id(fieldId) + .as(LogicalTypeAnnotation.variantType((byte) 1)) + .required(PrimitiveTypeName.BINARY) + .named("metadata") + .optional(PrimitiveTypeName.BINARY) + .named("value") + .addField(shreddedType) + .named(name); + } + + // Shredding schema with no Variant logical annotation. Needed in order to construct the Avro schema. + private static GroupType unannotatedVariant(String name, int fieldId, Type shreddedType) { + checkShreddedType(shreddedType); + return Types.buildGroup(Type.Repetition.OPTIONAL) + .id(fieldId) + .required(PrimitiveTypeName.BINARY) + .named("metadata") + .optional(PrimitiveTypeName.BINARY) + .named("value") + .addField(shreddedType) + .named(name); + } + + private static void checkField(GroupType fieldType) { + Preconditions.checkArgument( + fieldType.isRepetition(Type.Repetition.REQUIRED), + "Invalid field type repetition: %s should be REQUIRED", + fieldType.getRepetition()); + } + + private static GroupType objectFields(GroupType... 
fields) { + for (GroupType fieldType : fields) { + checkField(fieldType); + } + + return Types.buildGroup(Type.Repetition.OPTIONAL).addFields(fields).named("typed_value"); + } + private static Type shreddedPrimitive(PrimitiveTypeName primitive) { + return Types.optional(primitive).named("typed_value"); + } + + private static Type shreddedPrimitive( + PrimitiveTypeName primitive, LogicalTypeAnnotation annotation) { + return Types.optional(primitive).as(annotation).named("typed_value"); + } + + /** Creates an Avro record from a map of field name to value. */ + private static GenericRecord recordFromMap(GroupType type, Map fields) { + GenericRecord record = new GenericData.Record(avroSchema(type)); + for (Map.Entry entry : fields.entrySet()) { + record.put(entry.getKey(), entry.getValue()); + } + return record; + } + + // Required configuration to convert between Avro and Parquet schemas with 3-level list structure + private static final ParquetConfiguration CONF = + new PlainParquetConfiguration( + ImmutableMap.of( + AvroWriteSupport.WRITE_OLD_LIST_STRUCTURE, + "false", + AvroSchemaConverter.ADD_LIST_ELEMENT_RECORDS, + "false")); + + + private static GroupType shreddedField(String name, Type shreddedType) { + checkShreddedType(shreddedType); + return Types.buildGroup(Type.Repetition.REQUIRED) + .optional(PrimitiveTypeName.BINARY) + .named("value") + .addField(shreddedType) + .named(name); + } + + private static GroupType element(Type shreddedType) { + return shreddedField("element", shreddedType); + } + + private static GroupType list(GroupType elementType) { + return Types.optionalList().element(elementType).named("typed_value"); + } + + private static void checkListType(GroupType listType) { + // Check the list is a 3-level structure + Preconditions.checkArgument( + listType.getFieldCount() == 1 + && listType.getFields().get(0).isRepetition(Type.Repetition.REPEATED), + "Invalid list type: does not contain single repeated field: %s", + listType); + + GroupType 
repeated = listType.getFields().get(0).asGroupType(); + Preconditions.checkArgument( + repeated.getFieldCount() == 1 + && repeated.getFields().get(0).isRepetition(Type.Repetition.REQUIRED), + "Invalid list type: does not contain single required subfield: %s", + listType); + } + + private static org.apache.avro.Schema avroSchema(GroupType schema) { + if (schema instanceof MessageType) { + return new AvroSchemaConverter(CONF).convert((MessageType) schema); + + } else { + MessageType wrapped = Types.buildMessage().addField(schema).named("table"); + org.apache.avro.Schema avro = + new AvroSchemaConverter(CONF).convert(wrapped).getFields().get(0).schema(); + switch (avro.getType()) { + case RECORD: + return avro; + case UNION: + return avro.getTypes().get(1); + } + + throw new IllegalArgumentException("Invalid converted type: " + avro); + } + } + + private static void checkShreddedType(Type shreddedType) { + Preconditions.checkArgument( + shreddedType.getName().equals("typed_value"), + "Invalid shredded type name: %s should be typed_value", + shreddedType.getName()); + Preconditions.checkArgument( + shreddedType.isRepetition(Type.Repetition.OPTIONAL), + "Invalid shredded type repetition: %s should be OPTIONAL", + shreddedType.getRepetition()); + } + + // Check for the given exception with message, possibly wrapped by a ParquetDecodingException + void assertThrows(Callable callable, Class exception, String msg) { + try { + callable.call(); + fail("No exception was thrown. 
Expected: " + exception.getName()); + } catch (Exception actual) { + try { + if (actual.getClass().equals(ParquetDecodingException.class)) { + assertTrue(actual.getCause().getMessage().contains(msg)); + assertEquals(actual.getCause().getClass(), exception); + } else { + assertTrue(actual.getMessage().contains(msg)); + assertEquals(actual.getClass(), exception); + } + } catch (AssertionError e) { + e.addSuppressed(actual); + throw e; + } + } + } + + // Assert that metadata contains identical bytes to expected, and value is logically equivalent. + // E.g. object fields may be ordered differently in the binary. + void assertEquivalent(byte[] expectedMetadata, byte[] expectedValue, GenericRecord actual) { + assertEquals(ByteBuffer.wrap(expectedMetadata), (ByteBuffer) actual.get("metadata")); + assertEquals(ByteBuffer.wrap(expectedMetadata), (ByteBuffer) actual.get("metadata")); + assertEquivalent(new Variant(expectedValue, expectedMetadata), + new Variant(((ByteBuffer) actual.get("value")).array(), expectedMetadata)); + } + + void assertEquivalent(Variant expected, Variant actual) { + assertEquals(expected.getType(), actual.getType()); + switch (expected.getType()) { + case STRING: + // Short strings may use the compact or extended representation. 
+ assertEquals(expected.getString(), actual.getString()); + break; + case ARRAY: + assertEquals(expected.numArrayElements(), actual.numArrayElements()); + for (int i = 0; i < expected.numArrayElements(); ++i) { + assertEquivalent(expected.getElementAtIndex(i), actual.getElementAtIndex(i)); + } + break; + case OBJECT: + assertEquals(expected.numObjectElements(), actual.numObjectElements()); + for (int i = 0; i < expected.numObjectElements(); ++i) { + Variant.ObjectField expectedField = expected.getFieldAtIndex(i); + Variant.ObjectField actualField = actual.getFieldAtIndex(i); + assertEquals(expectedField.key, actualField.key); + assertEquivalent(expectedField.value, actualField.value); + } + break; + default: + // All other types have a single representation, and must be bit-for-bit identical. + assertArrayEquals(expected.getValue(), actual.getValue()); + } + } +} diff --git a/parquet-column/src/main/java/org/apache/parquet/column/ParquetProperties.java b/parquet-column/src/main/java/org/apache/parquet/column/ParquetProperties.java index 9aaef4b3cf..cb5931581f 100644 --- a/parquet-column/src/main/java/org/apache/parquet/column/ParquetProperties.java +++ b/parquet-column/src/main/java/org/apache/parquet/column/ParquetProperties.java @@ -708,6 +708,7 @@ public Builder withStatisticsEnabled(String columnPath, boolean enabled) { } public Builder withStatisticsEnabled(boolean enabled) { + this.statistics.withDefaultValue(enabled); this.statisticsEnabled = enabled; return this; } diff --git a/parquet-column/src/main/java/org/apache/parquet/schema/LogicalTypeAnnotation.java b/parquet-column/src/main/java/org/apache/parquet/schema/LogicalTypeAnnotation.java index 78b0f9a0c1..749beaa95e 100644 --- a/parquet-column/src/main/java/org/apache/parquet/schema/LogicalTypeAnnotation.java +++ b/parquet-column/src/main/java/org/apache/parquet/schema/LogicalTypeAnnotation.java @@ -56,6 +56,14 @@ protected LogicalTypeAnnotation fromString(List params) { return listType(); } }, + VARIANT 
{ + @Override + protected LogicalTypeAnnotation fromString(List params) { + Preconditions.checkArgument( + params.size() == 1, "Expecting only spec version for variant annotation args: %s", params); + return variantType(Byte.parseByte(params.get(0))); + } + }, STRING { @Override protected LogicalTypeAnnotation fromString(List params) { @@ -269,6 +277,10 @@ public static ListLogicalTypeAnnotation listType() { return ListLogicalTypeAnnotation.INSTANCE; } + public static VariantLogicalTypeAnnotation variantType(byte specVersion) { + return new VariantLogicalTypeAnnotation(specVersion); + } + public static EnumLogicalTypeAnnotation enumType() { return EnumLogicalTypeAnnotation.INSTANCE; } @@ -1128,6 +1140,49 @@ public int hashCode() { } } + public static class VariantLogicalTypeAnnotation extends LogicalTypeAnnotation { + private byte specVersion; + + private VariantLogicalTypeAnnotation(byte specVersion) { + this.specVersion = specVersion; + } + + @Override + public OriginalType toOriginalType() { + // No OriginalType for Variant + return null; + } + + @Override + public Optional accept(LogicalTypeAnnotationVisitor logicalTypeAnnotationVisitor) { + return logicalTypeAnnotationVisitor.visit(this); + } + + @Override + LogicalTypeToken getType() { + return LogicalTypeToken.VARIANT; + } + + public byte getSpecVersion() { + return this.specVersion; + } + + @Override + protected String typeParametersAsString() { + return "(" + specVersion + ")"; + } + + @Override + public boolean equals(Object obj) { + if (!(obj instanceof VariantLogicalTypeAnnotation)) { + return false; + } + + VariantLogicalTypeAnnotation other = (VariantLogicalTypeAnnotation) obj; + return specVersion == other.specVersion; + } + } + /** * Implement this interface to visit a logical type annotation in the schema. * The default implementation for each logical type specific visitor method is empty. 
@@ -1152,6 +1207,10 @@ default Optional visit(ListLogicalTypeAnnotation listLogicalType) { return empty(); } + default Optional visit(VariantLogicalTypeAnnotation variantLogicalType) { + return empty(); + } + default Optional visit(EnumLogicalTypeAnnotation enumLogicalType) { return empty(); } diff --git a/parquet-column/src/main/java/org/apache/parquet/schema/MessageTypeParser.java b/parquet-column/src/main/java/org/apache/parquet/schema/MessageTypeParser.java index 2e6cb20963..c7f4b900b8 100644 --- a/parquet-column/src/main/java/org/apache/parquet/schema/MessageTypeParser.java +++ b/parquet-column/src/main/java/org/apache/parquet/schema/MessageTypeParser.java @@ -118,12 +118,36 @@ private static void addGroupType(Tokenizer st, Repetition r, GroupBuilder bui String name = st.nextToken(); // Read annotation, if any. + String annotation = null; t = st.nextToken(); - OriginalType originalType = null; if (t.equalsIgnoreCase("(")) { - originalType = OriginalType.valueOf(st.nextToken()); - childBuilder.as(originalType); - check(st.nextToken(), ")", "original type ended by )", st); + t = st.nextToken(); + if (isLogicalType(t)) { + LogicalTypeAnnotation.LogicalTypeToken logicalType = LogicalTypeAnnotation.LogicalTypeToken.valueOf(t); + t = st.nextToken(); + List tokens = new ArrayList<>(); + if ("(".equals(t)) { + while (!")".equals(t)) { + if (!(",".equals(t) || "(".equals(t) || ")".equals(t))) { + tokens.add(t); + } + t = st.nextToken(); + } + t = st.nextToken(); + } + + LogicalTypeAnnotation logicalTypeAnnotation = logicalType.fromString(tokens); + childBuilder.as(logicalTypeAnnotation); + annotation = logicalTypeAnnotation.toString(); + } else { + // Try to parse as OriginalType + OriginalType originalType = OriginalType.valueOf(t); + childBuilder.as(originalType); + annotation = originalType.toString(); + t = st.nextToken(); + } + + check(t, ")", "logical type ended by )", st); t = st.nextToken(); } if (t.equals("=")) { @@ -134,7 +158,7 @@ private static void 
addGroupType(Tokenizer st, Repetition r, GroupBuilder bui addGroupTypeFields(t, st, childBuilder); } catch (IllegalArgumentException e) { throw new IllegalArgumentException( - "problem reading type: type = group, name = " + name + ", original type = " + originalType, e); + "problem reading type: type = group, name = " + name + ", annotation = " + annotation, e); } childBuilder.named(name); diff --git a/parquet-column/src/test/java/org/apache/parquet/parser/TestParquetParser.java b/parquet-column/src/test/java/org/apache/parquet/parser/TestParquetParser.java index 04b4a9432a..5172e788b5 100644 --- a/parquet-column/src/test/java/org/apache/parquet/parser/TestParquetParser.java +++ b/parquet-column/src/test/java/org/apache/parquet/parser/TestParquetParser.java @@ -55,6 +55,7 @@ import static org.junit.Assert.assertEquals; import org.apache.parquet.schema.GroupType; +import org.apache.parquet.schema.LogicalTypeAnnotation; import org.apache.parquet.schema.MessageType; import org.apache.parquet.schema.MessageTypeParser; import org.apache.parquet.schema.OriginalType; @@ -447,4 +448,30 @@ public void testEmbeddedAnnotations() { MessageType reparsed = MessageTypeParser.parseMessageType(parsed.toString()); assertEquals(expected, reparsed); } + + @Test + public void testVARIANTAnnotation() { + String message = "message Message {\n" + + " required group aVariant (VARIANT(2)) {\n" + + " required binary metadata;\n" + + " required binary value;\n" + + " }\n" + + "}\n"; + + MessageType expected = buildMessage() + .requiredGroup() + .as(LogicalTypeAnnotation.variantType((byte) 2)) + .required(BINARY) + .named("metadata") + .required(BINARY) + .named("value") + .named("aVariant") + .named("Message"); + + MessageType parsed = parseMessageType(message); + + assertEquals(expected, parsed); + MessageType reparsed = parseMessageType(parsed.toString()); + assertEquals(expected, reparsed); + } } diff --git a/parquet-column/src/test/java/org/apache/parquet/schema/TestTypeBuilders.java 
b/parquet-column/src/test/java/org/apache/parquet/schema/TestTypeBuilders.java index 579077897f..71886d1208 100644 --- a/parquet-column/src/test/java/org/apache/parquet/schema/TestTypeBuilders.java +++ b/parquet-column/src/test/java/org/apache/parquet/schema/TestTypeBuilders.java @@ -50,6 +50,7 @@ import static org.apache.parquet.schema.Type.Repetition.OPTIONAL; import static org.apache.parquet.schema.Type.Repetition.REPEATED; import static org.apache.parquet.schema.Type.Repetition.REQUIRED; +import static org.junit.Assert.assertEquals; import java.util.ArrayList; import java.util.List; @@ -1414,6 +1415,52 @@ public void testTimestampLogicalTypeWithUTCParameter() { Assert.assertEquals(nonUtcMicrosExpected, nonUtcMicrosActual); } + @Test + public void testVariantLogicalType() { + byte specVersion = 1; + String name = "variant_field"; + GroupType variantExpected = new GroupType( + REQUIRED, + name, + LogicalTypeAnnotation.variantType(specVersion), + new PrimitiveType(REQUIRED, BINARY, "metadata"), + new PrimitiveType(REQUIRED, BINARY, "value")); + + GroupType variantActual = Types.buildGroup(REQUIRED) + .addFields( + Types.required(BINARY).named("metadata"), + Types.required(BINARY).named("value")) + .as(LogicalTypeAnnotation.variantType(specVersion)) + .named(name); + + assertEquals(variantExpected, variantActual); + } + + @Test + public void testVariantLogicalTypeWithShredded() { + byte specVersion = 1; + String name = "variant_field"; + GroupType variantExpected = new GroupType( + REQUIRED, + name, + LogicalTypeAnnotation.variantType(specVersion), + new PrimitiveType(REQUIRED, BINARY, "metadata"), + new PrimitiveType(OPTIONAL, BINARY, "value"), + new PrimitiveType(OPTIONAL, BINARY, "typed_value", LogicalTypeAnnotation.stringType())); + + GroupType variantActual = Types.buildGroup(REQUIRED) + .addFields( + Types.required(BINARY).named("metadata"), + Types.optional(BINARY).named("value"), + Types.optional(BINARY) + .as(LogicalTypeAnnotation.stringType()) + 
.named("typed_value")) + .as(LogicalTypeAnnotation.variantType(specVersion)) + .named(name); + + assertEquals(variantExpected, variantActual); + } + @Test(expected = IllegalArgumentException.class) public void testDecimalLogicalTypeWithDeprecatedScaleMismatch() { Types.required(BINARY) diff --git a/parquet-column/src/test/java/org/apache/parquet/schema/TestTypeBuildersWithLogicalTypes.java b/parquet-column/src/test/java/org/apache/parquet/schema/TestTypeBuildersWithLogicalTypes.java index 54853e8138..61fe3065e1 100644 --- a/parquet-column/src/test/java/org/apache/parquet/schema/TestTypeBuildersWithLogicalTypes.java +++ b/parquet-column/src/test/java/org/apache/parquet/schema/TestTypeBuildersWithLogicalTypes.java @@ -41,6 +41,8 @@ import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.INT96; import static org.apache.parquet.schema.Type.Repetition.REQUIRED; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; import java.util.concurrent.Callable; import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName; @@ -473,6 +475,59 @@ public void testFloat16LogicalType() { .toString()); } + @Test + public void testVariantLogicalType() { + byte specVersion = 1; + String name = "variant_field"; + GroupType variant = new GroupType( + REQUIRED, + name, + LogicalTypeAnnotation.variantType(specVersion), + Types.required(BINARY).named("metadata"), + Types.required(BINARY).named("value")); + + assertEquals( + "required group variant_field (VARIANT(1)) {\n" + + " required binary metadata;\n" + + " required binary value;\n" + + "}", + variant.toString()); + + LogicalTypeAnnotation annotation = variant.getLogicalTypeAnnotation(); + assertEquals(LogicalTypeAnnotation.LogicalTypeToken.VARIANT, annotation.getType()); + assertNull(annotation.toOriginalType()); + assertTrue(annotation instanceof LogicalTypeAnnotation.VariantLogicalTypeAnnotation); + assertEquals(specVersion, 
((LogicalTypeAnnotation.VariantLogicalTypeAnnotation) annotation).getSpecVersion()); + } + + @Test + public void testVariantLogicalTypeWithShredded() { + byte specVersion = 1; + + String name = "variant_field"; + GroupType variant = new GroupType( + REQUIRED, + name, + LogicalTypeAnnotation.variantType(specVersion), + Types.required(BINARY).named("metadata"), + Types.optional(BINARY).named("value"), + Types.optional(BINARY).as(LogicalTypeAnnotation.stringType()).named("typed_value")); + + assertEquals( + "required group variant_field (VARIANT(1)) {\n" + + " required binary metadata;\n" + + " optional binary value;\n" + + " optional binary typed_value (STRING);\n" + + "}", + variant.toString()); + + LogicalTypeAnnotation annotation = variant.getLogicalTypeAnnotation(); + assertEquals(LogicalTypeAnnotation.LogicalTypeToken.VARIANT, annotation.getType()); + assertNull(annotation.toOriginalType()); + assertTrue(annotation instanceof LogicalTypeAnnotation.VariantLogicalTypeAnnotation); + assertEquals(specVersion, ((LogicalTypeAnnotation.VariantLogicalTypeAnnotation) annotation).getSpecVersion()); + } + /** * A convenience method to avoid a large number of @Test(expected=...) 
tests * diff --git a/parquet-format-structures/src/main/java/org/apache/parquet/format/LogicalTypes.java b/parquet-format-structures/src/main/java/org/apache/parquet/format/LogicalTypes.java index b2d70c9247..8956d3944e 100644 --- a/parquet-format-structures/src/main/java/org/apache/parquet/format/LogicalTypes.java +++ b/parquet-format-structures/src/main/java/org/apache/parquet/format/LogicalTypes.java @@ -32,6 +32,12 @@ public static LogicalType DECIMAL(int scale, int precision) { return LogicalType.DECIMAL(new DecimalType(scale, precision)); } + public static LogicalType VARIANT(byte specificationVersion) { + VariantType type = new VariantType(); + type.setSpecification_version(specificationVersion); + return LogicalType.VARIANT(type); + } + public static final LogicalType UTF8 = LogicalType.STRING(new StringType()); public static final LogicalType MAP = LogicalType.MAP(new MapType()); public static final LogicalType LIST = LogicalType.LIST(new ListType()); @@ -53,4 +59,5 @@ public static LogicalType DECIMAL(int scale, int precision) { public static final LogicalType JSON = LogicalType.JSON(new JsonType()); public static final LogicalType BSON = LogicalType.BSON(new BsonType()); public static final LogicalType FLOAT16 = LogicalType.FLOAT16(new Float16Type()); + public static final LogicalType UUID = LogicalType.UUID(new UUIDType()); } diff --git a/parquet-hadoop/pom.xml b/parquet-hadoop/pom.xml index d4aa4b42a7..adfebfbd05 100644 --- a/parquet-hadoop/pom.xml +++ b/parquet-hadoop/pom.xml @@ -230,6 +230,14 @@ !aarch64 + + + jitpack.io + https://jitpack.io + Jitpack.io repository + + + com.github.rdblue diff --git a/parquet-hadoop/src/main/java/org/apache/parquet/format/converter/ParquetMetadataConverter.java b/parquet-hadoop/src/main/java/org/apache/parquet/format/converter/ParquetMetadataConverter.java index 87797d1fa5..5759be234f 100644 --- a/parquet-hadoop/src/main/java/org/apache/parquet/format/converter/ParquetMetadataConverter.java +++ 
b/parquet-hadoop/src/main/java/org/apache/parquet/format/converter/ParquetMetadataConverter.java @@ -65,7 +65,6 @@ import org.apache.parquet.format.BloomFilterHash; import org.apache.parquet.format.BloomFilterHeader; import org.apache.parquet.format.BoundaryOrder; -import org.apache.parquet.format.BsonType; import org.apache.parquet.format.ColumnChunk; import org.apache.parquet.format.ColumnCryptoMetaData; import org.apache.parquet.format.ColumnIndex; @@ -75,25 +74,19 @@ import org.apache.parquet.format.ConvertedType; import org.apache.parquet.format.DataPageHeader; import org.apache.parquet.format.DataPageHeaderV2; -import org.apache.parquet.format.DateType; import org.apache.parquet.format.DecimalType; import org.apache.parquet.format.DictionaryPageHeader; import org.apache.parquet.format.Encoding; import org.apache.parquet.format.EncryptionWithColumnKey; -import org.apache.parquet.format.EnumType; import org.apache.parquet.format.FieldRepetitionType; import org.apache.parquet.format.FileMetaData; -import org.apache.parquet.format.Float16Type; import org.apache.parquet.format.IntType; -import org.apache.parquet.format.JsonType; import org.apache.parquet.format.KeyValue; -import org.apache.parquet.format.ListType; import org.apache.parquet.format.LogicalType; -import org.apache.parquet.format.MapType; +import org.apache.parquet.format.LogicalTypes; import org.apache.parquet.format.MicroSeconds; import org.apache.parquet.format.MilliSeconds; import org.apache.parquet.format.NanoSeconds; -import org.apache.parquet.format.NullType; import org.apache.parquet.format.OffsetIndex; import org.apache.parquet.format.PageEncodingStats; import org.apache.parquet.format.PageHeader; @@ -104,14 +97,13 @@ import org.apache.parquet.format.SizeStatistics; import org.apache.parquet.format.SplitBlockAlgorithm; import org.apache.parquet.format.Statistics; -import org.apache.parquet.format.StringType; import org.apache.parquet.format.TimeType; import org.apache.parquet.format.TimeUnit; 
import org.apache.parquet.format.TimestampType; import org.apache.parquet.format.Type; import org.apache.parquet.format.TypeDefinedOrder; -import org.apache.parquet.format.UUIDType; import org.apache.parquet.format.Uncompressed; +import org.apache.parquet.format.VariantType; import org.apache.parquet.format.XxHash; import org.apache.parquet.hadoop.metadata.BlockMetaData; import org.apache.parquet.hadoop.metadata.ColumnChunkMetaData; @@ -449,33 +441,32 @@ private static class LogicalTypeConverterVisitor implements LogicalTypeAnnotation.LogicalTypeAnnotationVisitor { @Override public Optional visit(LogicalTypeAnnotation.StringLogicalTypeAnnotation stringLogicalType) { - return of(LogicalType.STRING(new StringType())); + return of(LogicalTypes.UTF8); } @Override public Optional visit(LogicalTypeAnnotation.MapLogicalTypeAnnotation mapLogicalType) { - return of(LogicalType.MAP(new MapType())); + return of(LogicalTypes.MAP); } @Override public Optional visit(LogicalTypeAnnotation.ListLogicalTypeAnnotation listLogicalType) { - return of(LogicalType.LIST(new ListType())); + return of(LogicalTypes.LIST); } @Override public Optional visit(LogicalTypeAnnotation.EnumLogicalTypeAnnotation enumLogicalType) { - return of(LogicalType.ENUM(new EnumType())); + return of(LogicalTypes.ENUM); } @Override public Optional visit(LogicalTypeAnnotation.DecimalLogicalTypeAnnotation decimalLogicalType) { - return of(LogicalType.DECIMAL( - new DecimalType(decimalLogicalType.getScale(), decimalLogicalType.getPrecision()))); + return of(LogicalTypes.DECIMAL(decimalLogicalType.getScale(), decimalLogicalType.getPrecision())); } @Override public Optional visit(LogicalTypeAnnotation.DateLogicalTypeAnnotation dateLogicalType) { - return of(LogicalType.DATE(new DateType())); + return of(LogicalTypes.DATE); } @Override @@ -497,32 +488,37 @@ public Optional visit(LogicalTypeAnnotation.IntLogicalTypeAnnotatio @Override public Optional visit(LogicalTypeAnnotation.JsonLogicalTypeAnnotation jsonLogicalType) 
{ - return of(LogicalType.JSON(new JsonType())); + return of(LogicalTypes.JSON); } @Override public Optional visit(LogicalTypeAnnotation.BsonLogicalTypeAnnotation bsonLogicalType) { - return of(LogicalType.BSON(new BsonType())); + return of(LogicalTypes.BSON); } @Override public Optional visit(UUIDLogicalTypeAnnotation uuidLogicalType) { - return of(LogicalType.UUID(new UUIDType())); + return of(LogicalTypes.UUID); } @Override public Optional visit(LogicalTypeAnnotation.Float16LogicalTypeAnnotation float16LogicalType) { - return of(LogicalType.FLOAT16(new Float16Type())); + return of(LogicalTypes.FLOAT16); } @Override - public Optional visit(LogicalTypeAnnotation.UnknownLogicalTypeAnnotation intervalLogicalType) { - return of(LogicalType.UNKNOWN(new NullType())); + public Optional visit(LogicalTypeAnnotation.UnknownLogicalTypeAnnotation unknownLogicalType) { + return of(LogicalTypes.UNKNOWN); } @Override public Optional visit(LogicalTypeAnnotation.IntervalLogicalTypeAnnotation intervalLogicalType) { - return of(LogicalType.UNKNOWN(new NullType())); + return of(LogicalTypes.UNKNOWN); + } + + @Override + public Optional visit(LogicalTypeAnnotation.VariantLogicalTypeAnnotation variantLogicalType) { + return of(LogicalTypes.VARIANT(variantLogicalType.getSpecVersion())); } } @@ -1187,6 +1183,9 @@ LogicalTypeAnnotation getLogicalTypeAnnotation(LogicalType type) { return LogicalTypeAnnotation.uuidType(); case FLOAT16: return LogicalTypeAnnotation.float16Type(); + case VARIANT: + VariantType variant = type.getVARIANT(); + return LogicalTypeAnnotation.variantType(variant.getSpecification_version()); default: throw new RuntimeException("Unknown logical type " + type); } diff --git a/parquet-hadoop/src/test/java/org/apache/parquet/format/converter/TestParquetMetadataConverter.java b/parquet-hadoop/src/test/java/org/apache/parquet/format/converter/TestParquetMetadataConverter.java index 6b3259070e..322d4c4abc 100644 --- 
a/parquet-hadoop/src/test/java/org/apache/parquet/format/converter/TestParquetMetadataConverter.java +++ b/parquet-hadoop/src/test/java/org/apache/parquet/format/converter/TestParquetMetadataConverter.java @@ -41,6 +41,7 @@ import static org.apache.parquet.schema.LogicalTypeAnnotation.timeType; import static org.apache.parquet.schema.LogicalTypeAnnotation.timestampType; import static org.apache.parquet.schema.LogicalTypeAnnotation.uuidType; +import static org.apache.parquet.schema.LogicalTypeAnnotation.variantType; import static org.apache.parquet.schema.MessageTypeParser.parseMessageType; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; @@ -1589,6 +1590,28 @@ public void testMapConvertedTypeReadWrite() throws Exception { verifyMapMessageType(messageType, "map"); } + @Test + public void testVariantLogicalType() { + byte specVersion = 1; + MessageType expected = Types.buildMessage() + .requiredGroup() + .as(variantType(specVersion)) + .required(PrimitiveTypeName.BINARY) + .named("metadata") + .required(PrimitiveTypeName.BINARY) + .named("value") + .named("v") + .named("example"); + + ParquetMetadataConverter parquetMetadataConverter = new ParquetMetadataConverter(); + List parquetSchema = parquetMetadataConverter.toParquetSchema(expected); + MessageType schema = parquetMetadataConverter.fromParquetSchema(parquetSchema, null); + assertEquals(expected, schema); + LogicalTypeAnnotation logicalType = schema.getType("v").getLogicalTypeAnnotation(); + assertEquals(LogicalTypeAnnotation.variantType(specVersion), logicalType); + assertEquals(specVersion, ((LogicalTypeAnnotation.VariantLogicalTypeAnnotation) logicalType).getSpecVersion()); + } + private void verifyMapMessageType(final MessageType messageType, final String keyValueName) throws IOException { Path file = new Path(temporaryFolder.newFolder("verifyMapMessageType").getPath(), keyValueName + ".parquet"); diff --git 
a/parquet-hadoop/src/test/java/org/apache/parquet/hadoop/TestParquetWriter.java b/parquet-hadoop/src/test/java/org/apache/parquet/hadoop/TestParquetWriter.java index c8e8f71a91..739aa85d2c 100644 --- a/parquet-hadoop/src/test/java/org/apache/parquet/hadoop/TestParquetWriter.java +++ b/parquet-hadoop/src/test/java/org/apache/parquet/hadoop/TestParquetWriter.java @@ -543,7 +543,7 @@ private void testParquetFileNumberOfBlocks( } @Test - public void testSizeStatisticsControl() throws Exception { + public void testSizeStatisticsAndStatisticsControl() throws Exception { MessageType schema = Types.buildMessage() .required(BINARY) .named("string_field") @@ -568,6 +568,7 @@ public void testSizeStatisticsControl() throws Exception { try (ParquetWriter writer = ExampleParquetWriter.builder(path) .withType(schema) .withSizeStatisticsEnabled(false) + .withStatisticsEnabled(false) // Disable column statistics globally .build()) { writer.write(group); } @@ -576,6 +577,7 @@ public void testSizeStatisticsControl() throws Exception { // Verify size statistics are disabled globally for (BlockMetaData block : reader.getFooter().getBlocks()) { for (ColumnChunkMetaData column : block.getColumns()) { + assertTrue(column.getStatistics().isEmpty()); // Make sure there is no column statistics assertNull(column.getSizeStatistics()); } } @@ -589,6 +591,7 @@ public void testSizeStatisticsControl() throws Exception { .withType(schema) .withSizeStatisticsEnabled(true) // enable globally .withSizeStatisticsEnabled("boolean_field", false) // disable for specific column + .withStatisticsEnabled("boolean_field", false) // disable column statistics .build()) { writer.write(group); } @@ -599,8 +602,10 @@ public void testSizeStatisticsControl() throws Exception { for (ColumnChunkMetaData column : block.getColumns()) { if (column.getPath().toDotString().equals("boolean_field")) { assertNull(column.getSizeStatistics()); + assertTrue(column.getStatistics().isEmpty()); } else { 
assertTrue(column.getSizeStatistics().isValid()); + assertFalse(column.getStatistics().isEmpty()); } } } diff --git a/parquet-protobuf/pom.xml b/parquet-protobuf/pom.xml index 155d19feb3..c58de2fe24 100644 --- a/parquet-protobuf/pom.xml +++ b/parquet-protobuf/pom.xml @@ -32,7 +32,7 @@ 4.4 3.25.6 - 2.51.0 + 2.54.1 1.4.4 diff --git a/parquet-variant/pom.xml b/parquet-variant/pom.xml new file mode 100644 index 0000000000..9d7d2a7ce9 --- /dev/null +++ b/parquet-variant/pom.xml @@ -0,0 +1,82 @@ + + + + org.apache.parquet + parquet + ../pom.xml + 1.16.0-SNAPSHOT + + + 4.0.0 + + parquet-variant + jar + + Apache Parquet Variant + https://parquet.apache.org + + + + + + + org.apache.parquet + parquet-jackson + ${project.version} + runtime + + + org.apache.parquet + parquet-column + ${project.version} + + + ${jackson.groupId} + jackson-core + ${jackson.version} + + + ${jackson.groupId} + jackson-databind + ${jackson-databind.version} + test + + + org.slf4j + slf4j-api + ${slf4j.version} + test + + + + + + + org.apache.maven.plugins + maven-jar-plugin + + + org.apache.maven.plugins + maven-shade-plugin + + + + + diff --git a/parquet-variant/src/main/java/org/apache/parquet/variant/MalformedVariantException.java b/parquet-variant/src/main/java/org/apache/parquet/variant/MalformedVariantException.java new file mode 100644 index 0000000000..3ecc707a11 --- /dev/null +++ b/parquet-variant/src/main/java/org/apache/parquet/variant/MalformedVariantException.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.parquet.variant; + +/** + * An exception indicating that the Variant is malformed. + */ +public class MalformedVariantException extends RuntimeException { + public MalformedVariantException(String message) { + super(message); + } +} diff --git a/parquet-variant/src/main/java/org/apache/parquet/variant/UnknownVariantTypeException.java b/parquet-variant/src/main/java/org/apache/parquet/variant/UnknownVariantTypeException.java new file mode 100644 index 0000000000..3cdacb5d99 --- /dev/null +++ b/parquet-variant/src/main/java/org/apache/parquet/variant/UnknownVariantTypeException.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.parquet.variant; + +/** + * An exception indicating that the Variant contains an unknown type. 
+ */ +public class UnknownVariantTypeException extends RuntimeException { + public final int typeId; + + /** + * @param typeId the type id that was unknown + */ + public UnknownVariantTypeException(int typeId) { + super("Unknown type in Variant. id: " + typeId); + this.typeId = typeId; + } + + /** + * @return the type id that was unknown + */ + public int typeId() { + return typeId; + } +} diff --git a/parquet-variant/src/main/java/org/apache/parquet/variant/Variant.java b/parquet-variant/src/main/java/org/apache/parquet/variant/Variant.java new file mode 100644 index 0000000000..38e505afb3 --- /dev/null +++ b/parquet-variant/src/main/java/org/apache/parquet/variant/Variant.java @@ -0,0 +1,587 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.parquet.variant; + +import com.fasterxml.jackson.core.JsonFactory; +import com.fasterxml.jackson.core.JsonGenerator; +import java.io.CharArrayWriter; +import java.io.IOException; +import java.math.BigDecimal; +import java.time.*; +import java.time.format.DateTimeFormatter; +import java.time.format.DateTimeFormatterBuilder; +import java.time.temporal.ChronoField; +import java.time.temporal.ChronoUnit; +import java.util.Arrays; +import java.util.Base64; +import java.util.Locale; +import java.util.UUID; + +/** + * This Variant class holds the Variant-encoded value and metadata binary values. + */ +public final class Variant { + final byte[] value; + final byte[] metadata; + /** + * The starting index into `value` where the variant value starts. This is used to avoid copying + * the value binary when reading a sub-variant in the array/object element. + */ + final int pos; + + /** + * The threshold to switch from linear search to binary search when looking up a field by key in + * an object. This is a performance optimization to avoid the overhead of binary search for a + * short list. + */ + static final int BINARY_SEARCH_THRESHOLD = 32; + + static final ZoneId UTC = ZoneId.of("UTC"); + + public Variant(byte[] value, byte[] metadata) { + this(value, metadata, 0); + } + + Variant(byte[] value, byte[] metadata, int pos) { + this.value = value; + this.metadata = metadata; + this.pos = pos; + // There is currently only one allowed version. + if (metadata.length < 1 || (metadata[0] & VariantUtil.VERSION_MASK) != VariantUtil.VERSION) { + throw new MalformedVariantException(String.format( + "Unsupported variant metadata version: %02X", metadata[0] & VariantUtil.VERSION_MASK)); + } + } + + public byte[] getValue() { + if (pos == 0) { + // Position 0 means the entire value is used. Return the original value. 
+ return value; + } + int size = VariantUtil.valueSize(value, pos); + VariantUtil.checkIndex(pos + size - 1, value.length); + return Arrays.copyOfRange(value, pos, pos + size); + } + + public byte[] getMetadata() { + return metadata; + } + + /** + * @return the boolean value + */ + public boolean getBoolean() { + return VariantUtil.getBoolean(value, pos); + } + + /** + * @return the byte value + */ + public byte getByte() { + long longValue = VariantUtil.getLong(value, pos); + if (longValue < Byte.MIN_VALUE || longValue > Byte.MAX_VALUE) { + throw new IllegalStateException("Value out of range for byte: " + longValue); + } + return (byte) longValue; + } + + /** + * @return the short value + */ + public short getShort() { + long longValue = VariantUtil.getLong(value, pos); + if (longValue < Short.MIN_VALUE || longValue > Short.MAX_VALUE) { + throw new IllegalStateException("Value out of range for short: " + longValue); + } + return (short) longValue; + } + + /** + * @return the int value + */ + public int getInt() { + long longValue = VariantUtil.getLong(value, pos); + if (longValue < Integer.MIN_VALUE || longValue > Integer.MAX_VALUE) { + throw new IllegalStateException("Value out of range for int: " + longValue); + } + return (int) longValue; + } + + /** + * @return the long value + */ + public long getLong() { + return VariantUtil.getLong(value, pos); + } + + /** + * @return the double value + */ + public double getDouble() { + return VariantUtil.getDouble(value, pos); + } + + /** + * @return the decimal value + */ + public BigDecimal getDecimal() { + return VariantUtil.getDecimal(value, pos); + } + + /** + * @return the float value + */ + public float getFloat() { + return VariantUtil.getFloat(value, pos); + } + + /** + * @return the binary value + */ + public byte[] getBinary() { + return VariantUtil.getBinary(value, pos); + } + + /** + * @return the UUID value + */ + public UUID getUUID() { + return VariantUtil.getUUID(value, pos); + } + + /** + * @return the 
string value + */ + public String getString() { + return VariantUtil.getString(value, pos); + } + + /** + * @return the primitive type id from a variant value + */ + public int getPrimitiveTypeId() { + return VariantUtil.getPrimitiveTypeId(value, pos); + } + + /** + * @return the type of the variant value + */ + public VariantUtil.Type getType() { + return VariantUtil.getType(value, pos); + } + + /** + * @return the number of object fields in the variant. `getType()` must be `Type.OBJECT`. + */ + public int numObjectElements() { + return VariantUtil.handleObject(value, pos, (info) -> info.numElements); + } + + /** + * Returns the object field Variant value whose key is equal to `key`. + * Return null if the key is not found. `getType()` must be `Type.OBJECT`. + * @param key the key to look up + * @return the field value whose key is equal to `key`, or null if key is not found + */ + public Variant getFieldByKey(String key) { + return VariantUtil.handleObject(value, pos, (info) -> { + // Use linear search for a short list. Switch to binary search when the length reaches + // `BINARY_SEARCH_THRESHOLD`. + if (info.numElements < BINARY_SEARCH_THRESHOLD) { + for (int i = 0; i < info.numElements; ++i) { + ObjectField field = getFieldAtIndex( + i, + value, + metadata, + info.idSize, + info.offsetSize, + info.idStart, + info.offsetStart, + info.dataStart); + if (field.key.equals(key)) { + return field.value; + } + } + } else { + int low = 0; + int high = info.numElements - 1; + while (low <= high) { + // Use unsigned right shift to compute the middle of `low` and `high`. This is not only a + // performance optimization, because it can properly handle the case where `low + high` + // overflows int. 
+ int mid = (low + high) >>> 1; + ObjectField field = getFieldAtIndex( + mid, + value, + metadata, + info.idSize, + info.offsetSize, + info.idStart, + info.offsetStart, + info.dataStart); + int cmp = field.key.compareTo(key); + if (cmp < 0) { + low = mid + 1; + } else if (cmp > 0) { + high = mid - 1; + } else { + return field.value; + } + } + } + return null; + }); + } + + /** + * A field in a Variant object. + */ + public static final class ObjectField { + public final String key; + public final Variant value; + + public ObjectField(String key, Variant value) { + this.key = key; + this.value = value; + } + } + + /** + * Returns the ObjectField at the `index` slot. Return null if `index` is out of the bound of + * `[0, objectSize())`. `getType()` must be `Type.OBJECT`. + * @param index the index of the object field to get + * @return the ObjectField at the `index` slot, or null if `index` is out of bounds + */ + public ObjectField getFieldAtIndex(int index) { + return VariantUtil.handleObject(value, pos, (info) -> { + if (index < 0 || index >= info.numElements) { + return null; + } + return getFieldAtIndex( + index, + value, + metadata, + info.idSize, + info.offsetSize, + info.idStart, + info.offsetStart, + info.dataStart); + }); + } + + private static ObjectField getFieldAtIndex( + int index, + byte[] value, + byte[] metadata, + int idSize, + int offsetSize, + int idStart, + int offsetStart, + int dataStart) { + int id = VariantUtil.readUnsigned(value, idStart + idSize * index, idSize); + int offset = VariantUtil.readUnsigned(value, offsetStart + offsetSize * index, offsetSize); + String key = VariantUtil.getMetadataKey(metadata, id); + Variant v = new Variant(value, metadata, dataStart + offset); + return new ObjectField(key, v); + } + + /** + * @return the number of array elements. `getType()` must be `Type.ARRAY`. 
+ */ + public int numArrayElements() { + return VariantUtil.handleArray(value, pos, (info) -> info.numElements); + } + + /** + * Returns the array element Variant value at the `index` slot. Returns null if `index` is + * out of the bound of `[0, arraySize())`. `getType()` must be `Type.ARRAY`. + * @param index the index of the array element to get + * @return the array element Variant at the `index` slot, or null if `index` is out of bounds + */ + public Variant getElementAtIndex(int index) { + return VariantUtil.handleArray(value, pos, (info) -> { + if (index < 0 || index >= info.numElements) { + return null; + } + return getElementAtIndex(index, value, metadata, info.offsetSize, info.offsetStart, info.dataStart); + }); + } + + private static Variant getElementAtIndex( + int index, byte[] value, byte[] metadata, int offsetSize, int offsetStart, int dataStart) { + int offset = VariantUtil.readUnsigned(value, offsetStart + offsetSize * index, offsetSize); + return new Variant(value, metadata, dataStart + offset); + } + + /** + * @return the JSON representation of the variant + * @throws MalformedVariantException if the variant is malformed + */ + public String toJson() { + return toJson(UTC, false); + } + + /** + * @param zoneId The ZoneId to use for formatting timestamps + * @return the JSON representation of the variant + * @throws MalformedVariantException if the variant is malformed + */ + public String toJson(ZoneId zoneId) { + return toJson(zoneId, false); + } + + /** + * @param zoneId The ZoneId to use for formatting timestamps + * @param truncateTrailingZeros Whether to truncate trailing zeros in decimal values or timestamps + * @return the JSON representation of the variant + * @throws MalformedVariantException if the variant is malformed + */ + public String toJson(ZoneId zoneId, boolean truncateTrailingZeros) { + try (CharArrayWriter writer = new CharArrayWriter(); + JsonGenerator gen = new JsonFactory().createGenerator(writer)) { + toJsonImpl(value, 
metadata, pos, gen, zoneId, truncateTrailingZeros); + gen.flush(); + return writer.toString(); + } catch (IOException e) { + // TODO Fix exception + // throw new RuntimeIOException("Failed to convert variant to json", e); + throw new IllegalArgumentException("Failed to convert variant to json", e); + } + } + + /** The format for a timestamp without time zone. */ + private static final DateTimeFormatter TIMESTAMP_NTZ_FORMATTER = new DateTimeFormatterBuilder() + .append(DateTimeFormatter.ISO_LOCAL_DATE) + .appendLiteral('T') + .appendPattern("HH:mm:ss") + .appendFraction(ChronoField.MICRO_OF_SECOND, 6, 6, true) + .toFormatter(Locale.US); + + /** The format for a timestamp without time zone, with nanosecond precision. */ + private static final DateTimeFormatter TIMESTAMP_NANOS_NTZ_FORMATTER = new DateTimeFormatterBuilder() + .append(DateTimeFormatter.ISO_LOCAL_DATE) + .appendLiteral('T') + .appendPattern("HH:mm:ss") + .appendFraction(ChronoField.NANO_OF_SECOND, 9, 9, true) + .toFormatter(Locale.US); + + /** The format for a timestamp with time zone. */ + private static final DateTimeFormatter TIMESTAMP_FORMATTER = new DateTimeFormatterBuilder() + .append(TIMESTAMP_NTZ_FORMATTER) + .appendOffset("+HH:MM", "+00:00") + .toFormatter(Locale.US); + + /** The format for a timestamp with time zone, with nanosecond precision. */ + private static final DateTimeFormatter TIMESTAMP_NANOS_FORMATTER = new DateTimeFormatterBuilder() + .append(TIMESTAMP_NANOS_NTZ_FORMATTER) + .appendOffset("+HH:MM", "+00:00") + .toFormatter(Locale.US); + + /** The format for a time. */ + private static final DateTimeFormatter TIME_FORMATTER = new DateTimeFormatterBuilder() + .appendPattern("HH:mm:ss") + .appendFraction(ChronoField.MICRO_OF_SECOND, 6, 6, true) + .toFormatter(Locale.US); + + /** The format for a timestamp without time zone, truncating trailing microsecond zeros. 
*/ + private static final DateTimeFormatter TIMESTAMP_NTZ_TRUNC_FORMATTER = new DateTimeFormatterBuilder() + .append(DateTimeFormatter.ISO_LOCAL_DATE) + .appendLiteral('T') + .appendPattern("HH:mm:ss") + .optionalStart() + .appendFraction(ChronoField.MICRO_OF_SECOND, 0, 6, true) + .optionalEnd() + .toFormatter(Locale.US); + + /** + * The format for a timestamp without time zone, with nanosecond precision, truncating + * trailing nanosecond zeros. + */ + private static final DateTimeFormatter TIMESTAMP_NANOS_NTZ_TRUNC_FORMATTER = new DateTimeFormatterBuilder() + .append(DateTimeFormatter.ISO_LOCAL_DATE) + .appendLiteral('T') + .appendPattern("HH:mm:ss") + .optionalStart() + .appendFraction(ChronoField.NANO_OF_SECOND, 0, 9, true) + .optionalEnd() + .toFormatter(Locale.US); + + /** The format for a timestamp with time zone, truncating trailing microsecond zeros. */ + private static final DateTimeFormatter TIMESTAMP_TRUNC_FORMATTER = new DateTimeFormatterBuilder() + .append(TIMESTAMP_NTZ_TRUNC_FORMATTER) + .appendOffset("+HH:MM", "+00:00") + .toFormatter(Locale.US); + + /** + * The format for a timestamp with time zone, with nanosecond precision, truncating trailing + * nanosecond zeros. + */ + private static final DateTimeFormatter TIMESTAMP_NANOS_TRUNC_FORMATTER = new DateTimeFormatterBuilder() + .append(TIMESTAMP_NANOS_NTZ_TRUNC_FORMATTER) + .appendOffset("+HH:MM", "+00:00") + .toFormatter(Locale.US); + + /** The format for a time, truncating trailing microsecond zeros. 
*/ + private static final DateTimeFormatter TIME_TRUNC_FORMATTER = new DateTimeFormatterBuilder() + .appendPattern("HH:mm:ss") + .optionalStart() + .appendFraction(ChronoField.MICRO_OF_SECOND, 0, 6, true) + .optionalEnd() + .toFormatter(Locale.US); + + private static Instant microsToInstant(long microsSinceEpoch) { + return Instant.EPOCH.plus(microsSinceEpoch, ChronoUnit.MICROS); + } + + private static Instant nanosToInstant(long timestampNanos) { + return Instant.EPOCH.plus(timestampNanos, ChronoUnit.NANOS); + } + + private static void toJsonImpl( + byte[] value, byte[] metadata, int pos, JsonGenerator gen, ZoneId zoneId, boolean truncateTrailingZeros) + throws IOException { + switch (VariantUtil.getType(value, pos)) { + case OBJECT: + VariantUtil.handleObjectException(value, pos, (info) -> { + gen.writeStartObject(); + for (int i = 0; i < info.numElements; ++i) { + ObjectField field = getFieldAtIndex( + i, + value, + metadata, + info.idSize, + info.offsetSize, + info.idStart, + info.offsetStart, + info.dataStart); + gen.writeFieldName(field.key); + toJsonImpl( + field.value.value, + field.value.metadata, + field.value.pos, + gen, + zoneId, + truncateTrailingZeros); + } + gen.writeEndObject(); + return null; + }); + break; + case ARRAY: + VariantUtil.handleArrayException(value, pos, (info) -> { + gen.writeStartArray(); + for (int i = 0; i < info.numElements; ++i) { + Variant v = getElementAtIndex( + i, value, metadata, info.offsetSize, info.offsetStart, info.dataStart); + toJsonImpl(v.value, v.metadata, v.pos, gen, zoneId, truncateTrailingZeros); + } + gen.writeEndArray(); + return null; + }); + break; + case NULL: + gen.writeNull(); + break; + case BOOLEAN: + gen.writeBoolean(VariantUtil.getBoolean(value, pos)); + break; + case BYTE: + case SHORT: + case INT: + case LONG: + gen.writeNumber(VariantUtil.getLong(value, pos)); + break; + case STRING: + gen.writeString(VariantUtil.getString(value, pos)); + break; + case DOUBLE: + 
gen.writeNumber(VariantUtil.getDouble(value, pos)); + break; + case DECIMAL: + if (truncateTrailingZeros) { + gen.writeNumber(VariantUtil.getDecimal(value, pos) + .stripTrailingZeros() + .toPlainString()); + } else { + gen.writeNumber(VariantUtil.getDecimal(value, pos).toPlainString()); + } + break; + case DATE: + gen.writeString(LocalDate.ofEpochDay((int) VariantUtil.getLong(value, pos)) + .toString()); + break; + case TIMESTAMP: + if (truncateTrailingZeros) { + gen.writeString(TIMESTAMP_TRUNC_FORMATTER.format( + microsToInstant(VariantUtil.getLong(value, pos)).atZone(zoneId))); + } else { + gen.writeString(TIMESTAMP_FORMATTER.format( + microsToInstant(VariantUtil.getLong(value, pos)).atZone(zoneId))); + } + break; + case TIMESTAMP_NTZ: + if (truncateTrailingZeros) { + gen.writeString(TIMESTAMP_NTZ_TRUNC_FORMATTER.format( + microsToInstant(VariantUtil.getLong(value, pos)).atZone(ZoneOffset.UTC))); + } else { + gen.writeString(TIMESTAMP_NTZ_FORMATTER.format( + microsToInstant(VariantUtil.getLong(value, pos)).atZone(ZoneOffset.UTC))); + } + break; + case FLOAT: + gen.writeNumber(VariantUtil.getFloat(value, pos)); + break; + case BINARY: + gen.writeString(Base64.getEncoder().encodeToString(VariantUtil.getBinary(value, pos))); + break; + case TIME: + if (truncateTrailingZeros) { + gen.writeString(TIME_TRUNC_FORMATTER.format( + LocalTime.ofNanoOfDay(VariantUtil.getLong(value, pos) * 1_000))); + } else { + gen.writeString( + TIME_FORMATTER.format(LocalTime.ofNanoOfDay(VariantUtil.getLong(value, pos) * 1_000))); + } + break; + case TIMESTAMP_NANOS: + if (truncateTrailingZeros) { + gen.writeString(TIMESTAMP_NANOS_TRUNC_FORMATTER.format( + nanosToInstant(VariantUtil.getLong(value, pos)).atZone(zoneId))); + } else { + gen.writeString(TIMESTAMP_NANOS_FORMATTER.format( + nanosToInstant(VariantUtil.getLong(value, pos)).atZone(zoneId))); + } + break; + case TIMESTAMP_NANOS_NTZ: + if (truncateTrailingZeros) { + gen.writeString(TIMESTAMP_NANOS_NTZ_TRUNC_FORMATTER.format( + 
nanosToInstant(VariantUtil.getLong(value, pos)).atZone(ZoneOffset.UTC))); + } else { + gen.writeString(TIMESTAMP_NANOS_NTZ_FORMATTER.format( + nanosToInstant(VariantUtil.getLong(value, pos)).atZone(ZoneOffset.UTC))); + } + break; + case UUID: + gen.writeString(VariantUtil.getUUID(value, pos).toString()); + break; + default: + throw new IllegalArgumentException("Unsupported type: " + VariantUtil.getType(value, pos)); + } + } +} diff --git a/parquet-variant/src/main/java/org/apache/parquet/variant/VariantBuilder.java b/parquet-variant/src/main/java/org/apache/parquet/variant/VariantBuilder.java new file mode 100644 index 0000000000..bd959aabe1 --- /dev/null +++ b/parquet-variant/src/main/java/org/apache/parquet/variant/VariantBuilder.java @@ -0,0 +1,717 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.parquet.variant; + +import com.fasterxml.jackson.core.JsonFactory; +import com.fasterxml.jackson.core.JsonParseException; +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.core.JsonToken; +import com.fasterxml.jackson.core.exc.InputCoercionException; +import java.io.IOException; +import java.math.BigDecimal; +import java.math.BigInteger; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; + +/** + * Builder for creating Variant value and metadata. + */ +public class VariantBuilder { + /** + * Creates a VariantBuilder. + * @param allowDuplicateKeys if true, only the last occurrence of a duplicate key will be kept. + * Otherwise, an exception will be thrown. + */ + public VariantBuilder(boolean allowDuplicateKeys) { + this(allowDuplicateKeys, VariantUtil.DEFAULT_SIZE_LIMIT); + } + + /** + * Creates a VariantBuilder. + * @param allowDuplicateKeys if true, only the last occurrence of a duplicate key will be kept. + * Otherwise, an exception will be thrown. + * @param sizeLimitBytes the maximum size (in bytes) of the resulting Variant value or metadata + */ + public VariantBuilder(boolean allowDuplicateKeys, int sizeLimitBytes) { + fixedMetadata = false; + this.dictionary = new HashMap<>(); + this.dictionaryKeys = new ArrayList<>(); + this.allowDuplicateKeys = allowDuplicateKeys; + this.sizeLimitBytes = sizeLimitBytes; + } + + /** + * Set the metadata. May only be called if the builder has not yet added anything to the metadata. 
+ * @param metadata + */ + public void setFixedMetadata(HashMap metadata) { + if (!this.dictionaryKeys.isEmpty()) { + throw new IllegalStateException("Cannot fix metadata once values have been added to it"); + } + this.dictionary = metadata; + this.fixedMetadata = true; + // We don't need the dictionaryKeys list when metadata is fixed, and setting to null ensures that we'll + // fail if we accidentally try to use it. However, uses should be guarded by a cleaner exception. + this.dictionaryKeys = null; + } + + /** + * Parse a JSON string as a Variant value. + * @param json the JSON string to parse + * @return the Variant value + * @throws IOException if any JSON parsing error happens + * @throws VariantSizeLimitException if the resulting variant value or metadata would exceed + * the size limit + */ + public static Variant parseJson(String json) throws IOException { + return parseJson(json, new VariantBuilder(false)); + } + + /** + * Parse a JSON string as a Variant value. + * @param json the JSON string to parse + * @param builder the VariantBuilder to use for building the Variant + * @return the Variant value + * @throws IOException if any JSON parsing error happens + * @throws VariantSizeLimitException if the resulting variant value or metadata would exceed + * the size limit + */ + public static Variant parseJson(String json, VariantBuilder builder) throws IOException { + try (JsonParser parser = new JsonFactory().createParser(json)) { + parser.nextToken(); + return parseJson(parser, builder); + } + } + + /** + * Parse a JSON parser as a Variant value. 
+ * @param parser the JSON parser to use + * @param builder the VariantBuilder to use for building the Variant + * @return the Variant value + * @throws IOException if any JSON parsing error happens + * @throws VariantSizeLimitException if the resulting variant value or metadata would exceed + * the size limit + */ + public static Variant parseJson(JsonParser parser, VariantBuilder builder) throws IOException { + builder.buildFromJsonParser(parser); + return builder.result(); + } + + /** + * @return the Variant value + * @throws VariantSizeLimitException if the resulting variant value or metadata would exceed + * the size limit + */ + public Variant result() { + if (fixedMetadata) { + throw new IllegalStateException("Cannot reconstruct metadata when using fixed metadata"); + } + int numKeys = dictionaryKeys.size(); + // Use long to avoid overflow in accumulating lengths. + long dictionaryStringSize = 0; + for (byte[] key : dictionaryKeys) { + dictionaryStringSize += key.length; + } + // Determine the number of bytes required per offset entry. + // The largest offset is the one-past-the-end value, which is total string size. It's very + // unlikely that the number of keys could be larger, but incorporate that into the calculation + // in case of pathological data. 
+ long maxSize = Math.max(dictionaryStringSize, numKeys); + if (maxSize > sizeLimitBytes) { + throw new VariantSizeLimitException(sizeLimitBytes, maxSize); + } + int offsetSize = getMinIntegerSize((int) maxSize); + + int offsetStart = 1 + offsetSize; + int stringStart = offsetStart + (numKeys + 1) * offsetSize; + long metadataSize = stringStart + dictionaryStringSize; + + if (metadataSize > sizeLimitBytes) { + throw new VariantSizeLimitException(sizeLimitBytes, metadataSize); + } + byte[] metadata = new byte[(int) metadataSize]; + int headerByte = VariantUtil.VERSION | ((offsetSize - 1) << 6); + VariantUtil.writeLong(metadata, 0, headerByte, 1); + VariantUtil.writeLong(metadata, 1, numKeys, offsetSize); + int currentOffset = 0; + for (int i = 0; i < numKeys; ++i) { + VariantUtil.writeLong(metadata, offsetStart + i * offsetSize, currentOffset, offsetSize); + byte[] key = dictionaryKeys.get(i); + System.arraycopy(key, 0, metadata, stringStart + currentOffset, key.length); + currentOffset += key.length; + } + VariantUtil.writeLong(metadata, offsetStart + numKeys * offsetSize, currentOffset, offsetSize); + return new Variant(Arrays.copyOfRange(writeBuffer, 0, writePos), metadata); + } + + // Return the variant value only, without metadata. + // Used in shredding to produce a final value, where all shredded values refer to a common + // metadata. It should be called instead of `result()` when fixedMetadata is true, although it is valid to + // call it if fixedMetadata is false. + public byte[] valueWithoutMetadata() { + return Arrays.copyOfRange(writeBuffer, 0, writePos); + } + + public void appendString(String str) { + byte[] text = str.getBytes(StandardCharsets.UTF_8); + boolean longStr = text.length > VariantUtil.MAX_SHORT_STR_SIZE; + checkCapacity((longStr ? 
1 + VariantUtil.U32_SIZE : 1) + text.length); + if (longStr) { + writeBuffer[writePos++] = VariantUtil.primitiveHeader(VariantUtil.LONG_STR); + VariantUtil.writeLong(writeBuffer, writePos, text.length, VariantUtil.U32_SIZE); + writePos += VariantUtil.U32_SIZE; + } else { + writeBuffer[writePos++] = VariantUtil.shortStrHeader(text.length); + } + System.arraycopy(text, 0, writeBuffer, writePos, text.length); + writePos += text.length; + } + + public void appendNull() { + checkCapacity(1); + writeBuffer[writePos++] = VariantUtil.primitiveHeader(VariantUtil.NULL); + } + + public void appendBoolean(boolean b) { + checkCapacity(1); + writeBuffer[writePos++] = VariantUtil.primitiveHeader(b ? VariantUtil.TRUE : VariantUtil.FALSE); + } + + /** + * Appends a long value to the variant builder. The actual encoded integer type depends on the + * value range of the long value. + * @param l the long value to append + */ + public void appendLong(long l) { + if (l == (byte) l) { + checkCapacity(1 + 1); + writeBuffer[writePos++] = VariantUtil.primitiveHeader(VariantUtil.INT8); + VariantUtil.writeLong(writeBuffer, writePos, l, 1); + writePos += 1; + } else if (l == (short) l) { + checkCapacity(1 + 2); + writeBuffer[writePos++] = VariantUtil.primitiveHeader(VariantUtil.INT16); + VariantUtil.writeLong(writeBuffer, writePos, l, 2); + writePos += 2; + } else if (l == (int) l) { + checkCapacity(1 + 4); + writeBuffer[writePos++] = VariantUtil.primitiveHeader(VariantUtil.INT32); + VariantUtil.writeLong(writeBuffer, writePos, l, 4); + writePos += 4; + } else { + checkCapacity(1 + 8); + writeBuffer[writePos++] = VariantUtil.primitiveHeader(VariantUtil.INT64); + VariantUtil.writeLong(writeBuffer, writePos, l, 8); + writePos += 8; + } + } + + public void appendDouble(double d) { + checkCapacity(1 + 8); + writeBuffer[writePos++] = VariantUtil.primitiveHeader(VariantUtil.DOUBLE); + VariantUtil.writeLong(writeBuffer, writePos, Double.doubleToLongBits(d), 8); + writePos += 8; + } + + /** + * 
Appends a decimal value to the variant builder. The actual encoded decimal type depends on the + * precision and scale of the decimal value. + * @param d the decimal value to append + */ + public void appendDecimal(BigDecimal d) { + BigInteger unscaled = d.unscaledValue(); + if (d.scale() <= VariantUtil.MAX_DECIMAL4_PRECISION && d.precision() <= VariantUtil.MAX_DECIMAL4_PRECISION) { + checkCapacity(2 + 4); + writeBuffer[writePos++] = VariantUtil.primitiveHeader(VariantUtil.DECIMAL4); + writeBuffer[writePos++] = (byte) d.scale(); + VariantUtil.writeLong(writeBuffer, writePos, unscaled.intValueExact(), 4); + writePos += 4; + } else if (d.scale() <= VariantUtil.MAX_DECIMAL8_PRECISION + && d.precision() <= VariantUtil.MAX_DECIMAL8_PRECISION) { + checkCapacity(2 + 8); + writeBuffer[writePos++] = VariantUtil.primitiveHeader(VariantUtil.DECIMAL8); + writeBuffer[writePos++] = (byte) d.scale(); + VariantUtil.writeLong(writeBuffer, writePos, unscaled.longValueExact(), 8); + writePos += 8; + } else { + assert d.scale() <= VariantUtil.MAX_DECIMAL16_PRECISION + && d.precision() <= VariantUtil.MAX_DECIMAL16_PRECISION; + checkCapacity(2 + 16); + writeBuffer[writePos++] = VariantUtil.primitiveHeader(VariantUtil.DECIMAL16); + writeBuffer[writePos++] = (byte) d.scale(); + // `toByteArray` returns a big-endian representation. We need to copy it reversely and sign + // extend it to 16 bytes. + byte[] bytes = unscaled.toByteArray(); + for (int i = 0; i < bytes.length; ++i) { + writeBuffer[writePos + i] = bytes[bytes.length - 1 - i]; + } + byte sign = (byte) (bytes[0] < 0 ? 
-1 : 0); + for (int i = bytes.length; i < 16; ++i) { + writeBuffer[writePos + i] = sign; + } + writePos += 16; + } + } + + public void appendDate(int daysSinceEpoch) { + checkCapacity(1 + 4); + writeBuffer[writePos++] = VariantUtil.primitiveHeader(VariantUtil.DATE); + VariantUtil.writeLong(writeBuffer, writePos, daysSinceEpoch, 4); + writePos += 4; + } + + public void appendTimestamp(long microsSinceEpoch) { + checkCapacity(1 + 8); + writeBuffer[writePos++] = VariantUtil.primitiveHeader(VariantUtil.TIMESTAMP); + VariantUtil.writeLong(writeBuffer, writePos, microsSinceEpoch, 8); + writePos += 8; + } + + public void appendTimestampNtz(long microsSinceEpoch) { + checkCapacity(1 + 8); + writeBuffer[writePos++] = VariantUtil.primitiveHeader(VariantUtil.TIMESTAMP_NTZ); + VariantUtil.writeLong(writeBuffer, writePos, microsSinceEpoch, 8); + writePos += 8; + } + + public void appendTime(long microsSinceMidnight) { + checkCapacity(1 + 8); + writeBuffer[writePos++] = VariantUtil.primitiveHeader(VariantUtil.TIME); + VariantUtil.writeLong(writeBuffer, writePos, microsSinceMidnight, 8); + writePos += 8; + } + + public void appendTimestampNanos(long nanosSinceEpoch) { + checkCapacity(1 + 8); + writeBuffer[writePos++] = VariantUtil.primitiveHeader(VariantUtil.TIMESTAMP_NANOS); + VariantUtil.writeLong(writeBuffer, writePos, nanosSinceEpoch, 8); + writePos += 8; + } + + public void appendTimestampNanosNtz(long nanosSinceEpoch) { + checkCapacity(1 + 8); + writeBuffer[writePos++] = VariantUtil.primitiveHeader(VariantUtil.TIMESTAMP_NANOS_NTZ); + VariantUtil.writeLong(writeBuffer, writePos, nanosSinceEpoch, 8); + writePos += 8; + } + + public void appendFloat(float f) { + checkCapacity(1 + 4); + writeBuffer[writePos++] = VariantUtil.primitiveHeader(VariantUtil.FLOAT); + VariantUtil.writeLong(writeBuffer, writePos, Float.floatToIntBits(f), 8); + writePos += 4; + } + + public void appendBinary(byte[] binary) { + checkCapacity(1 + VariantUtil.U32_SIZE + binary.length); + 
writeBuffer[writePos++] = VariantUtil.primitiveHeader(VariantUtil.BINARY); + VariantUtil.writeLong(writeBuffer, writePos, binary.length, VariantUtil.U32_SIZE); + writePos += VariantUtil.U32_SIZE; + System.arraycopy(binary, 0, writeBuffer, writePos, binary.length); + writePos += binary.length; + } + + public void appendUUID(java.util.UUID uuid) { + checkCapacity(1 + VariantUtil.UUID_SIZE); + writeBuffer[writePos++] = VariantUtil.primitiveHeader(VariantUtil.UUID); + + ByteBuffer bb = + ByteBuffer.wrap(writeBuffer, writePos, VariantUtil.UUID_SIZE).order(ByteOrder.BIG_ENDIAN); + bb.putLong(uuid.getMostSignificantBits()); + bb.putLong(uuid.getLeastSignificantBits()); + writePos += VariantUtil.UUID_SIZE; + } + + // Append raw bytes, already in the form required for storage in Variant. + public void appendUUIDBytes(byte[] bytes) { + checkCapacity(1 + VariantUtil.UUID_SIZE); + writeBuffer[writePos++] = VariantUtil.primitiveHeader(VariantUtil.UUID); + // TODO Throw a better exception if this is violated. + assert (bytes.length == VariantUtil.UUID_SIZE); + System.arraycopy(bytes, 0, writeBuffer, writePos, bytes.length); + writePos += bytes.length; + } + + /** + * Adds a key to the Variant dictionary. If the key already exists, the dictionary is unmodified. + * @param key the key to add + * @return the id of the key + */ + public int addKey(String key) { + return dictionary.computeIfAbsent(key, newKey -> { + if (fixedMetadata) { + // TODO: Better exception. + throw new IllegalArgumentException("Value in shredding refers to non-existent metadata string"); + } + int id = dictionaryKeys.size(); + dictionaryKeys.add(newKey.getBytes(StandardCharsets.UTF_8)); + return id; + }); + } + + /** + * @return the current write position of the variant builder + */ + public int getWritePos() { + return writePos; + } + + /** + * Finish writing a Variant object after all of its fields have already been written. The process + * is as follows: + * 1. 
The caller calls `getWritePos()` before writing any fields to obtain the `start` parameter. + * 2. The caller appends all the object fields to the builder. In the meantime, it should maintain + * the `fields` parameter. Before appending each field, it should append an entry to `fields` to + * record the offset of the field. The offset is computed as `getWritePos() - start`. + * 3. The caller calls `finishWritingObject` to finish writing the Variant object. + * + * This method will sort the fields by key. If there are duplicate field keys: + * - when `allowDuplicateKeys` is true, the field with the greatest offset value (the last + * appended one) is kept. + * - otherwise, throw an exception. + * @param start the start position of the object in the write buffer + * @param fields the list of `FieldEntry` in the object + * @throws VariantDuplicateKeyException if there are duplicate keys and `allowDuplicateKeys` is + * false + */ + public void finishWritingObject(int start, ArrayList fields) { + int size = fields.size(); + Collections.sort(fields); + int maxId = size == 0 ? 0 : fields.get(0).id; + if (allowDuplicateKeys) { + int distinctPos = 0; + // Maintain a list of distinct keys in-place. + for (int i = 1; i < size; ++i) { + maxId = Math.max(maxId, fields.get(i).id); + if (fields.get(i).id == fields.get(i - 1).id) { + // Found a duplicate key. Keep the field with the greater offset. + if (fields.get(distinctPos).offset < fields.get(i).offset) { + fields.set(distinctPos, fields.get(distinctPos).withNewOffset(fields.get(i).offset)); + } + } else { + // Found a distinct key. Add the field to the list. + ++distinctPos; + fields.set(distinctPos, fields.get(i)); + } + } + if (distinctPos + 1 < fields.size()) { + size = distinctPos + 1; + // Resize `fields` to `size`. + fields.subList(size, fields.size()).clear(); + // Sort the fields by offsets so that we can move the value data of each field to the new + // offset without overwriting the fields after it. 
+ fields.sort(Comparator.comparingInt(f -> f.offset)); + int currentOffset = 0; + for (int i = 0; i < size; ++i) { + int oldOffset = fields.get(i).offset; + int fieldSize = VariantUtil.valueSize(writeBuffer, start + oldOffset); + System.arraycopy(writeBuffer, start + oldOffset, writeBuffer, start + currentOffset, fieldSize); + fields.set(i, fields.get(i).withNewOffset(currentOffset)); + currentOffset += fieldSize; + } + writePos = start + currentOffset; + // Change back to the sort order by field keys, required by the Variant specification. + Collections.sort(fields); + } + } else { + for (int i = 1; i < size; ++i) { + maxId = Math.max(maxId, fields.get(i).id); + String key = fields.get(i).key; + if (key.equals(fields.get(i - 1).key)) { + throw new VariantDuplicateKeyException(key); + } + } + } + int dataSize = writePos - start; + boolean largeSize = size > VariantUtil.U8_MAX; + int sizeBytes = largeSize ? VariantUtil.U32_SIZE : 1; + int idSize = getMinIntegerSize(maxId); + int offsetSize = getMinIntegerSize(dataSize); + // The space for header byte, object size, id list, and offset list. + int headerSize = 1 + sizeBytes + size * idSize + (size + 1) * offsetSize; + checkCapacity(headerSize); + // Shift the just-written field data to make room for the object header section. 
+ System.arraycopy(writeBuffer, start, writeBuffer, start + headerSize, dataSize); + writePos += headerSize; + writeBuffer[start] = VariantUtil.objectHeader(largeSize, idSize, offsetSize); + VariantUtil.writeLong(writeBuffer, start + 1, size, sizeBytes); + int idStart = start + 1 + sizeBytes; + int offsetStart = idStart + size * idSize; + for (int i = 0; i < size; ++i) { + VariantUtil.writeLong(writeBuffer, idStart + i * idSize, fields.get(i).id, idSize); + VariantUtil.writeLong(writeBuffer, offsetStart + i * offsetSize, fields.get(i).offset, offsetSize); + } + VariantUtil.writeLong(writeBuffer, offsetStart + size * offsetSize, dataSize, offsetSize); + } + + /** + * Finish writing a Variant array after all of its elements have already been written. The process + * is similar to that of `finishWritingObject`. + * @param start the start position of the array in the write buffer + * @param offsets the list of offsets of the array elements + */ + public void finishWritingArray(int start, ArrayList offsets) { + int dataSize = writePos - start; + int size = offsets.size(); + boolean largeSize = size > VariantUtil.U8_MAX; + int sizeBytes = largeSize ? VariantUtil.U32_SIZE : 1; + int offsetSize = getMinIntegerSize(dataSize); + // The space for header byte, object size, and offset list. + int headerSize = 1 + sizeBytes + (size + 1) * offsetSize; + checkCapacity(headerSize); + // Shift the just-written field data to make room for the header section. 
+ System.arraycopy(writeBuffer, start, writeBuffer, start + headerSize, dataSize); + writePos += headerSize; + writeBuffer[start] = VariantUtil.arrayHeader(largeSize, offsetSize); + VariantUtil.writeLong(writeBuffer, start + 1, size, sizeBytes); + int offsetStart = start + 1 + sizeBytes; + for (int i = 0; i < size; ++i) { + VariantUtil.writeLong(writeBuffer, offsetStart + i * offsetSize, offsets.get(i), offsetSize); + } + VariantUtil.writeLong(writeBuffer, offsetStart + size * offsetSize, dataSize, offsetSize); + } + + /** + * Appends a Variant value to the Variant builder. The input Variant keys must be inserted into + * the builder dictionary and rebuilt with new field ids. For scalar values in the input + * Variant, we can directly copy the binary slice. + * @param v the Variant value to append + */ + public void appendVariant(Variant v) { + appendVariantImpl(v.value, v.metadata, v.pos); + } + + private void appendVariantImpl(byte[] value, byte[] metadata, int pos) { + VariantUtil.checkIndex(pos, value.length); + int basicType = value[pos] & VariantUtil.BASIC_TYPE_MASK; + switch (basicType) { + case VariantUtil.OBJECT: + VariantUtil.handleObject(value, pos, (info) -> { + ArrayList fields = new ArrayList<>(info.numElements); + int start = writePos; + for (int i = 0; i < info.numElements; ++i) { + int id = VariantUtil.readUnsigned(value, info.idStart + info.idSize * i, info.idSize); + int offset = VariantUtil.readUnsigned( + value, info.offsetStart + info.offsetSize * i, info.offsetSize); + int elementPos = info.dataStart + offset; + String key = VariantUtil.getMetadataKey(metadata, id); + int newId = addKey(key); + fields.add(new FieldEntry(key, newId, writePos - start)); + appendVariantImpl(value, metadata, elementPos); + } + finishWritingObject(start, fields); + return null; + }); + break; + case VariantUtil.ARRAY: + VariantUtil.handleArray(value, pos, (info) -> { + ArrayList offsets = new ArrayList<>(info.numElements); + int start = writePos; + for (int i = 0; 
i < info.numElements; ++i) { + int offset = VariantUtil.readUnsigned( + value, info.offsetStart + info.offsetSize * i, info.offsetSize); + int elementPos = info.dataStart + offset; + offsets.add(writePos - start); + appendVariantImpl(value, metadata, elementPos); + } + finishWritingArray(start, offsets); + return null; + }); + break; + default: + shallowAppendVariant(value, pos); + break; + } + } + + public void shallowAppendVariant(byte[] value, int pos) { + int size = VariantUtil.valueSize(value, pos); + VariantUtil.checkIndex(pos + size - 1, value.length); + checkCapacity(size); + System.arraycopy(value, pos, writeBuffer, writePos, size); + writePos += size; + } + + private void checkCapacity(int additionalBytes) { + int requiredBytes = writePos + additionalBytes; + if (requiredBytes > writeBuffer.length) { + // Allocate a new buffer with a capacity of the next power of 2 of `requiredBytes`. + int newCapacity = Integer.highestOneBit(requiredBytes); + newCapacity = newCapacity < requiredBytes ? newCapacity * 2 : newCapacity; + if (newCapacity > sizeLimitBytes) { + throw new VariantSizeLimitException(sizeLimitBytes, newCapacity); + } + byte[] newValue = new byte[newCapacity]; + System.arraycopy(writeBuffer, 0, newValue, 0, writePos); + writeBuffer = newValue; + } + } + + /** + * Class to store the information of a Variant object field. We need to collect all fields of + * an object, sort them by their keys, and build the Variant object in sorted order. 
+ */ + public static final class FieldEntry implements Comparable { + final String key; + final int id; + final int offset; + + public FieldEntry(String key, int id, int offset) { + this.key = key; + this.id = id; + this.offset = offset; + } + + FieldEntry withNewOffset(int newOffset) { + return new FieldEntry(key, id, newOffset); + } + + @Override + public int compareTo(FieldEntry other) { + return key.compareTo(other.key); + } + } + + private void buildFromJsonParser(JsonParser parser) throws IOException { + JsonToken token = parser.currentToken(); + if (token == null) { + throw new JsonParseException(parser, "Unexpected null token"); + } + switch (token) { + case START_OBJECT: { + ArrayList fields = new ArrayList<>(); + int start = writePos; + while (parser.nextToken() != JsonToken.END_OBJECT) { + String key = parser.currentName(); + parser.nextToken(); + int id = addKey(key); + fields.add(new FieldEntry(key, id, writePos - start)); + buildFromJsonParser(parser); + } + finishWritingObject(start, fields); + break; + } + case START_ARRAY: { + ArrayList offsets = new ArrayList<>(); + int start = writePos; + while (parser.nextToken() != JsonToken.END_ARRAY) { + offsets.add(writePos - start); + buildFromJsonParser(parser); + } + finishWritingArray(start, offsets); + break; + } + case VALUE_STRING: + appendString(parser.getText()); + break; + case VALUE_NUMBER_INT: + try { + appendLong(parser.getLongValue()); + } catch (InputCoercionException ignored) { + // If the value doesn't fit any integer type, parse it as decimal or floating instead. 
+ parseAndAppendFloatingPoint(parser); + } + break; + case VALUE_NUMBER_FLOAT: + parseAndAppendFloatingPoint(parser); + break; + case VALUE_TRUE: + appendBoolean(true); + break; + case VALUE_FALSE: + appendBoolean(false); + break; + case VALUE_NULL: + appendNull(); + break; + default: + throw new JsonParseException(parser, "Unexpected token " + token); + } + } + + /** + * Returns the size (number of bytes) of the smallest unsigned integer type that can store + * `value`. It must be within `[0, U24_MAX]`. + * @param value the value to get the size for + * @return the size (number of bytes) of the smallest unsigned integer type that can store `value` + */ + private int getMinIntegerSize(int value) { + assert value >= 0 && value <= VariantUtil.U24_MAX; + if (value <= VariantUtil.U8_MAX) { + return VariantUtil.U8_SIZE; + } + if (value <= VariantUtil.U16_MAX) { + return VariantUtil.U16_SIZE; + } + return VariantUtil.U24_SIZE; + } + + /** + * Parse a JSON number as a floating point value. If the number can be parsed as a decimal, it + * will be appended as a decimal value. Otherwise, it will be appended as a double value. + * @param parser the JSON parser to use + */ + private void parseAndAppendFloatingPoint(JsonParser parser) throws IOException { + if (!tryParseDecimal(parser.getText())) { + appendDouble(parser.getDoubleValue()); + } + } + + /** + * Try to parse a JSON number as a decimal. The input must only use the decimal format + * (an integer value with an optional '.' in it) and must not use scientific notation. It also + * must fit into the precision limitation of decimal types. + * @param input the input string to parse as decimal + * @return whether the parsing succeeds + */ + private boolean tryParseDecimal(String input) { + for (int i = 0; i < input.length(); ++i) { + char ch = input.charAt(i); + if (ch != '-' && ch != '.' 
&& !(ch >= '0' && ch <= '9')) { + return false; + } + } + BigDecimal d = new BigDecimal(input); + if (d.scale() <= VariantUtil.MAX_DECIMAL16_PRECISION && d.precision() <= VariantUtil.MAX_DECIMAL16_PRECISION) { + appendDecimal(d); + return true; + } + return false; + } + + /** The buffer for building the Variant value. The first `writePos` bytes have been written. */ + private byte[] writeBuffer = new byte[128]; + + private int writePos = 0; + /** The dictionary for mapping keys to monotonically increasing ids. */ + private HashMap dictionary; + /** The keys in the dictionary, in id order. */ + private ArrayList dictionaryKeys; + + private final boolean allowDuplicateKeys; + private final int sizeLimitBytes; + + // If true, metadata is provided in the constructor, and may not be modified - added values must refer to + // strings that already exist in metadata. + private boolean fixedMetadata; +} diff --git a/parquet-variant/src/main/java/org/apache/parquet/variant/VariantColumnConverter.java b/parquet-variant/src/main/java/org/apache/parquet/variant/VariantColumnConverter.java new file mode 100644 index 0000000000..0574e321ff --- /dev/null +++ b/parquet-variant/src/main/java/org/apache/parquet/variant/VariantColumnConverter.java @@ -0,0 +1,799 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.parquet.variant; + +import static org.apache.parquet.schema.LogicalTypeAnnotation.TimeUnit.MICROS; +import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.BINARY; +import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.BOOLEAN; +import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.DOUBLE; +import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.FLOAT; +import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.INT32; +import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.INT64; +import static org.apache.parquet.schema.Type.Repetition.REPEATED; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.util.*; +import org.apache.parquet.column.Dictionary; +import org.apache.parquet.io.api.Binary; +import org.apache.parquet.io.api.Converter; +import org.apache.parquet.io.api.GroupConverter; +import org.apache.parquet.io.api.PrimitiveConverter; +import org.apache.parquet.schema.GroupType; +import org.apache.parquet.schema.LogicalTypeAnnotation; +import org.apache.parquet.schema.PrimitiveType; +import org.apache.parquet.schema.Type; + +/** + * Stores the Variant builder and metadata used to rebuild a single Variant value from its shredded representation. + */ +class VariantBuilderHolder { + VariantBuilder builder = null; + Binary metadata = null; + // Maps metadata entries to their index in the metadata binary. + HashMap metadataMap = null; + + void startNewVariant() { + builder = new VariantBuilder(false); + } + + Binary getMetadata() { + return metadata; + } + + /** + * Sets the metadata. May only be called after startNewVariant. 
We allow the `value` column to + * be added to the builder before metadata has been set, since it does not depend on metadata, but + * typed_value (specifically, if typed_value is or contains an object) must be added after setting + * the metadata. + */ + void setMetadata(Binary metadata) { + // If the metadata hasn't changed, we don't need to rebuild the map. + // When metadata is dictionary encoded, we could consider keeping the map + // around for every dictionary value, but that could be expensive, and handling adjacent + // rows with identical metadata should be the most common case. + if (this.metadata != metadata) { + metadataMap = VariantUtil.getMetadataMap(metadata.getBytes()); + } + this.metadata = metadata; + builder.setFixedMetadata(metadataMap); + } +} + +interface VariantConverter { + void init(VariantBuilderHolder builderHolder); +} + +/** + * Converter for a shredded Variant containing a value and/or typed_value field: either a top-level + * Variant column, or a nested array element or object field. + * The top-level converter is handled by a subclass (VariantColumnConverter) that also reads + * metadata. + * + * All converters for a Variant column shared the same VariantBuilder, and append their results to + * it as values are read from Parquet. + * + * Values in `typed_value` are appended by the child converter. Values in `value` are stored by a + * child converter, but only appended when completing this group. Additionally, object fields are + * appended by the `typed_value` converter, but because residual values are stored in `value`, this + * converter is responsible for finalizing the object. + */ +class VariantElementConverter extends GroupConverter implements VariantConverter { + + // startWritePos has two uses: + // 1) If typed_value is an object, we gather fields from value and typed_value and write the final + // object in end(), so we need to remember the start position. 
+ // 2) If this is the field of an object, we use startWritePos to tell our parent the field's + // offset within the encoded parent object. + private int startWritePos; + private boolean typedValueIsObject = false; + private int valueIdx = -1; + private int typedValueIdx = -1; + protected VariantBuilderHolder builder; + protected Converter[] converters; + + // The following are only used if this is an object field. + private String objectFieldName = null; + private int objectFieldId = -1; + private VariantObjectConverter parent = null; + + // Only used if typedValueIsObject is true. + private Set shreddedObjectKeys; + + @Override + public void init(VariantBuilderHolder builder) { + this.builder = builder; + for (Converter converter : converters) { + if (converter != null) { + ((VariantConverter) converter).init(builder); + } + } + } + + public VariantElementConverter(GroupType variantSchema, String objectFieldName, VariantObjectConverter parent) { + this(variantSchema); + this.objectFieldName = objectFieldName; + this.parent = parent; + } + + public VariantElementConverter(GroupType variantSchema) { + converters = new Converter[3]; + + List fields = variantSchema.getFields(); + + for (int i = 0; i < fields.size(); i++) { + Type field = fields.get(i); + String fieldName = field.getName(); + if (fieldName.equals("value")) { + this.valueIdx = i; + if (!field.isPrimitive() || field.asPrimitiveType().getPrimitiveTypeName() != BINARY) { + throw new IllegalArgumentException("Value must be a binary value"); + } + } else if (fieldName.equals("typed_value")) { + this.typedValueIdx = i; + } + } + + if (valueIdx >= 0) { + converters[valueIdx] = new VariantValueConverter(this); + } + + if (typedValueIdx >= 0) { + Converter typedConverter = null; + Type field = fields.get(typedValueIdx); + LogicalTypeAnnotation annotation = field.getLogicalTypeAnnotation(); + if (annotation instanceof LogicalTypeAnnotation.ListLogicalTypeAnnotation) { + typedConverter = new 
VariantArrayConverter(field.asGroupType()); + } else if (!field.isPrimitive()) { + GroupType typed_value = field.asGroupType(); + typedConverter = new VariantObjectConverter(typed_value); + typedValueIsObject = true; + shreddedObjectKeys = new HashSet<>(); + for (Type f : typed_value.getFields()) { + shreddedObjectKeys.add(f.getName()); + } + } else { + typedConverter = VariantScalarConverter.create(field.asPrimitiveType()); + } + + assert (typedConverter != null); + converters[typedValueIdx] = typedConverter; + } + } + + @Override + public Converter getConverter(int fieldIndex) { + return converters[fieldIndex]; + } + + /** runtime calls **/ + + @Override + public void start() { + startWritePos = builder.builder.getWritePos(); + if (valueIdx >= 0) { + ((VariantValueConverter) converters[valueIdx]).reset(); + } + } + + @Override + public void end() { + VariantBuilder builder = this.builder.builder; + + Binary variantValue = null; + ArrayList fields = null; + if (typedValueIsObject) { + // Get the array that the child typed_value has been adding its fields to. We need to possibly add + // more values from the `value` field, then finalize. If the value was not an object, fields will be null. + fields = ((VariantObjectConverter) converters[typedValueIdx]).getFieldsAndReset(); + } + if (valueIdx >= 0) { + variantValue = ((VariantValueConverter) converters[valueIdx]).getValue(); + } + if (variantValue != null) { + // The first check guards against an invalid shredding where value and typed_value are both non-null, and + // typed_value is not an object. It is not sufficient, because a non-null but empty object in typed_value + // will leave the write position unchanged. + if (startWritePos == builder.getWritePos() && fields == null) { + // Nothing else was added. We can directly append this value. + builder.shallowAppendVariant( + variantValue.toByteBuffer().array(), + variantValue.toByteBuffer().position()); + } else { + // Both value and typed_value were non-null. 
This is only valid for an object. + byte[] value = variantValue.getBytes(); + int basicType = value[0] & VariantUtil.BASIC_TYPE_MASK; + if (basicType != VariantUtil.OBJECT || fields == null) { + throw new IllegalArgumentException("Invalid variant, conflicting value and typed_value"); + } + + // Copy needed to satisfy compiler due to lambda. + ArrayList finalFields = fields; + VariantUtil.handleObject(value, 0, (info) -> { + for (int i = 0; i < info.numElements; ++i) { + int id = VariantUtil.readUnsigned(value, info.idStart + info.idSize * i, info.idSize); + String key = VariantUtil.getMetadataKey(this.builder.getMetadata().getBytes(), id); + if (shreddedObjectKeys.contains(key)) { + // Skip any field ID that is also in the typed schema. This check is needed because readers with + // pushdown may not look at the value column, causing inconsistent results if a writer puth a given key + // only in the value column when it was present in the typed_value schema. + // Alternatively, we could fail at this point, since the shredding is invalid according to the spec. + continue; + } + int offset = VariantUtil.readUnsigned( + value, info.offsetStart + info.offsetSize * i, info.offsetSize); + int elementPos = info.dataStart + offset; + finalFields.add(new VariantBuilder.FieldEntry(key, id, builder.getWritePos() - startWritePos)); + builder.shallowAppendVariant(value, elementPos); + } + return null; + }); + builder.finishWritingObject(startWritePos, finalFields); + } + } else if (typedValueIsObject && fields != null) { + // We wrote an object, and there's nothing left to append. + builder.finishWritingObject(startWritePos, fields); + } + + if (startWritePos == builder.getWritePos() && objectFieldName == null) { + // If startWritePos == builder.getWritePos(), and this is an array element or top-level field, the + // spec considers this invalid, but suggests writing a VariantNull to the resulting variant. + // We could also consider failing with an error. 
+ builder.appendNull(); + } + + if (startWritePos != builder.getWritePos() && objectFieldName != null) { + if (objectFieldId == -1) { + // metadata isn't available in the constructor, so we look up the field lazily. + objectFieldId = builder.addKey(objectFieldName); + } + // Record that we added a field. + parent.addField(objectFieldName, objectFieldId, startWritePos); + } + } +} + +/** + * Converter for shredded Variant values. Connectors should implement the addVariant method, similar to + * the add* methods on PrimitiveConverter. + */ +public abstract class VariantColumnConverter extends VariantElementConverter { + + private int topLevelMetadataIdx = -1; + + public VariantColumnConverter(GroupType variantSchema) { + super(variantSchema); + + List fields = variantSchema.getFields(); + for (int i = 0; i < fields.size(); i++) { + Type field = fields.get(i); + String fieldName = field.getName(); + if (fieldName.equals("metadata")) { + this.topLevelMetadataIdx = i; + if (!field.isPrimitive() || field.asPrimitiveType().getPrimitiveTypeName() != BINARY) { + throw new IllegalArgumentException("Metadata must be a binary value"); + } + } + } + if (topLevelMetadataIdx < 0) { + throw new IllegalArgumentException("Metadata missing from schema"); + } + converters[topLevelMetadataIdx] = new VariantMetadataConverter(); + builder = new VariantBuilderHolder(); + init(builder); + } + + /** + * Set the final shredded value. + */ + public abstract void addVariant(Binary value, Binary metadata); + + /** + * called at the beginning of the group managed by this converter + */ + @Override + public void start() { + builder.startNewVariant(); + super.start(); + } + + /** + * call at the end of the group + */ + @Override + public void end() { + super.end(); + byte[] value = builder.builder.valueWithoutMetadata(); + addVariant(Binary.fromConstantByteArray(value), builder.getMetadata()); + } +} + +/** + * Converter for the metadata column. 
It sets the current metadata in the parent converter, + * so that it can be used by the typed_value converter on the same row. + */ +class VariantMetadataConverter extends PrimitiveConverter implements VariantConverter { + private VariantBuilderHolder builder; + Binary[] dict; + + public VariantMetadataConverter() { + dict = null; + } + + @Override + public void init(VariantBuilderHolder builderHolder) { + builder = builderHolder; + } + + @Override + public void addBinary(Binary value) { + builder.setMetadata(value); + } + + @Override + public boolean hasDictionarySupport() { + return true; + } + + @Override + public void setDictionary(Dictionary dictionary) { + dict = new Binary[dictionary.getMaxId() + 1]; + for (int i = 0; i <= dictionary.getMaxId(); i++) { + dict[i] = dictionary.decodeToBinary(i); + } + } + + @Override + public void addValueFromDictionary(int dictionaryId) { + builder.setMetadata(dict[dictionaryId]); + } +} + +// Converter for the `value` field. It does not append to VariantBuilder directly: it simply holds onto +// its value for the parent converter to append. 
+class VariantValueConverter extends PrimitiveConverter implements VariantConverter { + private VariantElementConverter parent; + Binary[] dict; + Binary currentValue; + + public VariantValueConverter(VariantElementConverter parent) { + this.parent = parent; + this.currentValue = null; + dict = null; + } + + @Override + public void init(VariantBuilderHolder builderHolder) {} + + void reset() { + currentValue = null; + } + + Binary getValue() { + return currentValue; + } + + @Override + public void addBinary(Binary value) { + currentValue = value; + } + + @Override + public boolean hasDictionarySupport() { + return true; + } + + @Override + public void setDictionary(Dictionary dictionary) { + dict = new Binary[dictionary.getMaxId() + 1]; + for (int i = 0; i <= dictionary.getMaxId(); i++) { + dict[i] = dictionary.decodeToBinary(i); + } + } + + @Override + public void addValueFromDictionary(int dictionaryId) { + currentValue = dict[dictionaryId]; + } +} + +// Base class for converting primitive typed_value fields. +class VariantScalarConverter extends PrimitiveConverter implements VariantConverter { + protected VariantBuilderHolder builder; + private GroupType scalarType; + + @Override + public void init(VariantBuilderHolder builderHolder) { + builder = builderHolder; + } + + // Return an appropriate converter for the given Parquet type. 
+ static VariantScalarConverter create(PrimitiveType primitive) { + VariantScalarConverter typedConverter = null; + LogicalTypeAnnotation annotation = primitive.getLogicalTypeAnnotation(); + PrimitiveType.PrimitiveTypeName primitiveType = primitive.getPrimitiveTypeName(); + if (primitiveType == BOOLEAN) { + typedConverter = new VariantBooleanConverter(); + } else if (annotation instanceof LogicalTypeAnnotation.IntLogicalTypeAnnotation) { + LogicalTypeAnnotation.IntLogicalTypeAnnotation intAnnotation = + (LogicalTypeAnnotation.IntLogicalTypeAnnotation) annotation; + if (!intAnnotation.isSigned()) { + throw new UnsupportedOperationException("Unsupported shredded value type: " + + intAnnotation); + } + int width = intAnnotation.getBitWidth(); + if (width == 8) { + typedConverter = new VariantByteConverter(); + } else if (width == 16) { + typedConverter = new VariantShortConverter(); + } else if (width == 32) { + typedConverter = new VariantIntConverter(); + } else if (width == 64) { + typedConverter = new VariantLongConverter(); + } else { + throw new UnsupportedOperationException("Unsupported shredded value type: " + + intAnnotation); + } + } else if (annotation == null && primitiveType == INT32) { + typedConverter = new VariantIntConverter(); + } else if (annotation == null && primitiveType == INT64) { + typedConverter = new VariantLongConverter(); + } else if (primitiveType == FLOAT) { + typedConverter = new VariantFloatConverter(); + } else if (primitiveType == DOUBLE) { + typedConverter = new VariantDoubleConverter(); + } else if (annotation instanceof LogicalTypeAnnotation.DecimalLogicalTypeAnnotation) { + LogicalTypeAnnotation.DecimalLogicalTypeAnnotation decimalType = + (LogicalTypeAnnotation.DecimalLogicalTypeAnnotation) annotation; + typedConverter = new VariantDecimalConverter(decimalType.getScale()); + } else if (annotation instanceof LogicalTypeAnnotation.DateLogicalTypeAnnotation) { + typedConverter = new VariantDateConverter(); + } else if (annotation 
instanceof LogicalTypeAnnotation.TimestampLogicalTypeAnnotation) { + LogicalTypeAnnotation.TimestampLogicalTypeAnnotation timestampType = + (LogicalTypeAnnotation.TimestampLogicalTypeAnnotation) annotation; + if (timestampType.isAdjustedToUTC()) { + switch (timestampType.getUnit()) { + case MILLIS: + throw new UnsupportedOperationException("MILLIS not supported by Variant"); + case MICROS: + typedConverter = new VariantTimestampConverter(); + break; + case NANOS: + typedConverter = new VariantTimestampNanosConverter(); + break; + } + } else { + switch (timestampType.getUnit()) { + case MILLIS: + throw new UnsupportedOperationException("MILLIS not supported by Variant"); + case MICROS: + typedConverter = new VariantTimestampNtzConverter(); + break; + case NANOS: + typedConverter = new VariantTimestampNanosNtzConverter(); + break; + } + } + } else if (annotation instanceof LogicalTypeAnnotation.TimeLogicalTypeAnnotation) { + LogicalTypeAnnotation.TimeLogicalTypeAnnotation timeType = + (LogicalTypeAnnotation.TimeLogicalTypeAnnotation) annotation; + if (timeType.isAdjustedToUTC() || timeType.getUnit() != MICROS) { + throw new UnsupportedOperationException(timeType + " not supported by Variant"); + } else { + typedConverter = new VariantTimeConverter(); + } + } else if (annotation instanceof LogicalTypeAnnotation.UUIDLogicalTypeAnnotation) { + typedConverter = new VariantUuidConverter(); + } else if (annotation instanceof LogicalTypeAnnotation.StringLogicalTypeAnnotation) { + typedConverter = new VariantStringConverter(); + } else if (primitiveType == BINARY) { + typedConverter = new VariantBinaryConverter(); + } else { + throw new UnsupportedOperationException("Unsupported shredded value type: " + primitive); + } + return typedConverter; + } +} + +class VariantStringConverter extends VariantScalarConverter { + @Override + public void addBinary(Binary value) { + builder.builder.appendString(value.toStringUsingUTF8()); + } +} + +class VariantBinaryConverter extends 
VariantScalarConverter { + @Override + public void addBinary(Binary value) { + builder.builder.appendBinary(value.getBytes()); + } +} + +class VariantDecimalConverter extends VariantScalarConverter { + private int scale; + + VariantDecimalConverter(int scale) { + super(); + this.scale = scale; + } + + @Override + public void addBinary(Binary value) { + builder.builder.appendDecimal( + new BigDecimal(new BigInteger(value.getBytes()), scale)); + } + + @Override + public void addLong(long value) { + BigDecimal decimal = BigDecimal.valueOf(value, scale); + builder.builder.appendDecimal(decimal); + } + + @Override + public void addInt(int value) { + BigDecimal decimal = BigDecimal.valueOf(value, scale); + builder.builder.appendDecimal(decimal); + } +} + +class VariantUuidConverter extends VariantScalarConverter { + @Override + public void addBinary(Binary value) { + builder.builder.appendUUIDBytes(value.getBytes()); + } +} + +class VariantBooleanConverter extends VariantScalarConverter { + @Override + public void addBoolean(boolean value) { + builder.builder.appendBoolean(value); + } +} + +class VariantDoubleConverter extends VariantScalarConverter { + @Override + public void addDouble(double value) { + builder.builder.appendDouble(value); + } +} + +class VariantFloatConverter extends VariantScalarConverter { + @Override + public void addFloat(float value) { + builder.builder.appendFloat(value); + } +} + +class VariantByteConverter extends VariantScalarConverter { + @Override + public void addInt(int value) { + // TODO: Fix + builder.builder.appendLong(value); + } +} + +class VariantShortConverter extends VariantScalarConverter { + @Override + public void addInt(int value) { + // TODO: Fix + builder.builder.appendLong(value); + } +} + +class VariantIntConverter extends VariantScalarConverter { + @Override + public void addInt(int value) { + // TODO: Fix + builder.builder.appendLong(value); + } +} + +class VariantLongConverter extends VariantScalarConverter { + @Override 
+ public void addLong(long value) { + builder.builder.appendLong(value); + } +} + +class VariantDateConverter extends VariantScalarConverter { + @Override + public void addInt(int value) { + builder.builder.appendDate(value); + } +} + +class VariantTimeConverter extends VariantScalarConverter { + @Override + public void addLong(long value) { + builder.builder.appendTime(value); + } +} + +class VariantTimestampConverter extends VariantScalarConverter { + @Override + public void addLong(long value) { + builder.builder.appendTimestamp(value); + } +} + +class VariantTimestampNtzConverter extends VariantScalarConverter { + @Override + public void addLong(long value) { + builder.builder.appendTimestampNtz(value); + } +} + +class VariantTimestampNanosConverter extends VariantScalarConverter { + @Override + public void addLong(long value) { + builder.builder.appendTimestampNanos(value); + } +} + +class VariantTimestampNanosNtzConverter extends VariantScalarConverter { + @Override + public void addLong(long value) { + builder.builder.appendTimestampNanosNtz(value); + } +} + +/** + * Converter for a LIST typed_value. 
+ */ +class VariantArrayConverter extends GroupConverter implements VariantConverter { + private VariantBuilderHolder builder; + private VariantArrayRepeatedConverter repeatedConverter; + private ArrayList offsets; + private int startPos; + + public VariantArrayConverter(GroupType listType) { + if (listType.getFieldCount() != 1) { + throw new IllegalArgumentException("LIST must have one field"); + } + Type middleLevel = listType.getType(0); + if (!middleLevel.isRepetition(REPEATED) + || middleLevel.isPrimitive() + || middleLevel.asGroupType().getFieldCount() != 1) { + throw new IllegalArgumentException("LIST must have one repeated field"); + } + this.repeatedConverter = new VariantArrayRepeatedConverter(middleLevel.asGroupType(), this); + } + + @Override + public void init(VariantBuilderHolder builderHolder) { + builder = builderHolder; + repeatedConverter.init(builderHolder); + } + + @Override + public Converter getConverter(int fieldIndex) { + return repeatedConverter; + } + + public void addElement() { + offsets.add(builder.builder.getWritePos() - startPos); + } + + @Override + public void start() { + offsets = new ArrayList<>(); + startPos = builder.builder.getWritePos(); + } + + @Override + public void end() { + builder.builder.finishWritingArray(startPos, offsets); + } +} + +/** + * Converter for the repeated field of a LIST typed_value. 
+ */ +class VariantArrayRepeatedConverter extends GroupConverter implements VariantConverter { + private VariantElementConverter elementConverter; + private VariantArrayConverter parentConverter; + + public VariantArrayRepeatedConverter(GroupType repeatedType, VariantArrayConverter parentaConverter) { + this.elementConverter = new VariantElementConverter(repeatedType.getType(0).asGroupType()); + this.parentConverter = parentaConverter; + } + + @Override + public void init(VariantBuilderHolder builderHolder) { + elementConverter.init(builderHolder); + } + + @Override + public Converter getConverter(int fieldIndex) { + return elementConverter; + } + + @Override + public void start() { + // Record the offset of this element in the binary. + parentConverter.addElement(); + } + + @Override + public void end() {} +} + +class VariantObjectConverter extends GroupConverter implements VariantConverter { + private VariantBuilderHolder builder; + private VariantElementConverter[] converters; + private ArrayList fieldEntries = new ArrayList<>(); + // hasValue is used to distinguish a null typed_value (which may be a missing field of another object) from + // an empty object, since both will have an empty fieldEntries list at the end, and the parent converter + // will need to know if the field is missing or empty. + private boolean hasValue = false; + // The write position in the buffer for the start of this object. 
+ private int startWritePos; + + public VariantObjectConverter(GroupType typed_value) { + List fields = typed_value.getFields(); + converters = new VariantElementConverter[fields.size()]; + for (int i = 0; i < fields.size(); i++) { + GroupType field = fields.get(i).asGroupType(); + String name = fields.get(i).getName(); + converters[i] = new VariantElementConverter(field, name, this); + } + }; + + void addField(String fieldName, int fieldId, int fieldStartPos) { + fieldEntries.add(new VariantBuilder.FieldEntry(fieldName, fieldId, fieldStartPos - startWritePos)); + } + + // Return fieldEntries, and reset the the state for reading the next object. + // If there was no object, return null. + // It must be called after each call to end() to ensure that hasValue is reset. + ArrayList getFieldsAndReset() { + if (!hasValue) { + return null; + } + hasValue = false; + return fieldEntries; + } + + @Override + public void init(VariantBuilderHolder builderHolder) { + builder = builderHolder; + for (VariantElementConverter c: converters) { + c.init(builderHolder); + } + } + + @Override + public Converter getConverter(int fieldIndex) { + return converters[fieldIndex]; + } + + @Override + public void start() { + fieldEntries.clear(); + startWritePos = builder.builder.getWritePos(); + } + + @Override + public void end() { + // We can't finish writing the object here, because there might be residual entries in our + // parent's value column. The parent converter calls getFields to finalize the object. + // However, we need to indicate to our parent that the object is non-null. 
+ hasValue = true; + } +} diff --git a/parquet-variant/src/main/java/org/apache/parquet/variant/VariantDuplicateKeyException.java b/parquet-variant/src/main/java/org/apache/parquet/variant/VariantDuplicateKeyException.java new file mode 100644 index 0000000000..12e94416c4 --- /dev/null +++ b/parquet-variant/src/main/java/org/apache/parquet/variant/VariantDuplicateKeyException.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.parquet.variant; + +/** + * An exception indicating that the Variant contains a duplicate key. 
+ */ +public class VariantDuplicateKeyException extends RuntimeException { + public final String key; + + /** + * @param key the key that was duplicated + */ + public VariantDuplicateKeyException(String key) { + super("Failed to build Variant because of duplicate object key: " + key); + this.key = key; + } + + /** + * @return the key that was duplicated + */ + public String getKey() { + return key; + } +} diff --git a/parquet-variant/src/main/java/org/apache/parquet/variant/VariantSizeLimitException.java b/parquet-variant/src/main/java/org/apache/parquet/variant/VariantSizeLimitException.java new file mode 100644 index 0000000000..a86a41ad6e --- /dev/null +++ b/parquet-variant/src/main/java/org/apache/parquet/variant/VariantSizeLimitException.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.parquet.variant; + +/** + * An exception indicating that the metadata or data size of the Variant exceeds the + * configured size limit. + */ +public class VariantSizeLimitException extends RuntimeException { + public VariantSizeLimitException(long sizeLimitBytes, long estimatedSizeBytes) { + super(String.format( + "Variant size exceeds the limit of %d bytes. 
Estimated size: %d bytes", + sizeLimitBytes, estimatedSizeBytes)); + } +} diff --git a/parquet-variant/src/main/java/org/apache/parquet/variant/VariantUtil.java b/parquet-variant/src/main/java/org/apache/parquet/variant/VariantUtil.java new file mode 100644 index 0000000000..ca33e62122 --- /dev/null +++ b/parquet-variant/src/main/java/org/apache/parquet/variant/VariantUtil.java @@ -0,0 +1,871 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.parquet.variant; + +import java.io.IOException; +import java.math.BigDecimal; +import java.math.BigInteger; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.Arrays; +import java.util.HashMap; + +/** + * This class defines constants related to the Variant format and provides functions for + * manipulating Variant binaries. + * + * A Variant is made up of 2 binaries: value and metadata. A Variant value consists of a one-byte + * header and a number of content bytes (can be zero). The header byte is divided into upper 6 bits + * (called "type info") and lower 2 bits (called "basic type"). The content format is explained in + * the below constants for all possible basic type and type info values. 
+ * + * The Variant metadata includes a version id and a dictionary of distinct strings (case-sensitive). + * Its binary format is: + * - Version: 1-byte unsigned integer. The only acceptable value is 1 currently. + * - Dictionary size: 4-byte little-endian unsigned integer. The number of keys in the + * dictionary. + * - Offsets: (size + 1) * 4-byte little-endian unsigned integers. `offsets[i]` represents the + * starting position of string i, counting starting from the address of `offsets[0]`. Strings + * must be stored contiguously, so we don’t need to store the string size, instead, we compute it + * with `offset[i + 1] - offset[i]`. + * - UTF-8 string data. + */ +public class VariantUtil { + public static final int BASIC_TYPE_BITS = 2; + public static final int BASIC_TYPE_MASK = 0b00000011; + public static final int PRIMITIVE_TYPE_MASK = 0b00111111; + /** The inclusive maximum value of the type info value. It is the size limit of `SHORT_STR`. */ + public static final int MAX_SHORT_STR_SIZE = 0b00111111; + + // The basic types + + /** + * Primitive value. + * The type info value must be one of the values in the "Primitive" section below. + */ + public static final int PRIMITIVE = 0; + /** + * Short string value. + * The type info value is the string size, which must be in `[0, MAX_SHORT_STR_SIZE]`. + * The string content bytes directly follow the header byte. + */ + public static final int SHORT_STR = 1; + /** + * Object value. + * The content contains a size, a list of field ids, a list of field offsets, and + * the actual field values. The list of field ids has `size` ids, while the list of field offsets + * has `size + 1` offsets, where the last offset represents the total size of the field values + * data. The list of fields ids must be sorted by the field name in alphabetical order. + * Duplicate field names within one object are not allowed. + * 5 bits in the type info are used to specify the integer type of the object header. 
It is + * 0_b4_b3b2_b1b0 (MSB is 0), where: + * - b4: the integer type of size. When it is 0/1, `size` is a little-endian 1/4-byte + * unsigned integer. + * - b3b2: the integer type of ids. When the 2 bits are 0/1/2, the id list contains + * 1/2/3-byte little-endian unsigned integers. + * - b1b0: the integer type of offset. When the 2 bits are 0/1/2, the offset list contains + * 1/2/3-byte little-endian unsigned integers. + */ + public static final int OBJECT = 2; + /** + * Array value. + * The content contains a size, a list of field offsets, and the actual element values. + * It is similar to an object without the id list. The length of the offset list + * is `size + 1`, where the last offset represent the total size of the element data. + * Its type info is: 000_b2_b1b0: + * - b2: the type of size. + * - b1b0: the integer type of offset. + */ + public static final int ARRAY = 3; + + // The primitive types + + /** JSON Null value. Empty content. */ + public static final int NULL = 0; + /** True value. Empty content. */ + public static final int TRUE = 1; + /** False value. Empty content. */ + public static final int FALSE = 2; + /** 1-byte little-endian signed integer. */ + public static final int INT8 = 3; + /** 2-byte little-endian signed integer. */ + public static final int INT16 = 4; + /** 4-byte little-endian signed integer. */ + public static final int INT32 = 5; + /** 8-byte little-endian signed integer. */ + public static final int INT64 = 6; + /** 8-byte IEEE double. */ + public static final int DOUBLE = 7; + /** 4-byte decimal. Content is 1-byte scale + 4-byte little-endian signed integer. */ + public static final int DECIMAL4 = 8; + /** 8-byte decimal. Content is 1-byte scale + 8-byte little-endian signed integer. */ + public static final int DECIMAL8 = 9; + /** 16-byte decimal. Content is 1-byte scale + 16-byte little-endian signed integer. */ + public static final int DECIMAL16 = 10; + /** + * Date value. 
Content is 4-byte little-endian signed integer that represents the + * number of days from the Unix epoch. + */ + public static final int DATE = 11; + /** + * Timestamp value. Content is 8-byte little-endian signed integer that represents the number of + * microseconds elapsed since the Unix epoch, 1970-01-01 00:00:00 UTC. It is displayed to users in + * their local time zones and may be displayed differently depending on the execution environment. + */ + public static final int TIMESTAMP = 12; + /** + * Timestamp_ntz value. It has the same content as `TIMESTAMP` but should always be interpreted + * as if the local time zone is UTC. + */ + public static final int TIMESTAMP_NTZ = 13; + /** 4-byte IEEE float. */ + public static final int FLOAT = 14; + /** + * Binary value. The content is (4-byte little-endian unsigned integer representing the binary + * size) + (size bytes of binary content). + */ + public static final int BINARY = 15; + /** + * Long string value. The content is (4-byte little-endian unsigned integer representing the + * string size) + (size bytes of string content). + */ + public static final int LONG_STR = 16; + /** + * Time value. Values can be from 00:00:00 to 23:59:59.999999. + * Content is 8-byte little-endian unsigned integer that represents the number of microseconds + * since midnight. + */ + public static final int TIME = 17; + /** + * Timestamp nanos value. Similar to `TIMESTAMP`, but represents the number of nanoseconds + * elapsed since the Unix epoch, 1970-01-01 00:00:00 UTC. + */ + public static final int TIMESTAMP_NANOS = 18; + /** + * Timestamp nanos (without timezone) value. It has the same content as `TIMESTAMP_NANOS` but + * should always be interpreted as if the local time zone is UTC. + */ + public static final int TIMESTAMP_NANOS_NTZ = 19; + /** + * UUID value. The content is a 16-byte binary, encoded using big-endian. 
+ * For example, UUID 00112233-4455-6677-8899-aabbccddeeff is encoded as the bytes + * 00 11 22 33 44 55 66 77 88 99 aa bb cc dd ee ff. + */ + public static final int UUID = 20; + + // The metadata version. + public static final byte VERSION = 1; + // The lower 4 bits of the first metadata byte contain the version. + public static final byte VERSION_MASK = 0x0F; + + // Constants for various unsigned integer sizes. + public static final int U8_MAX = 0xFF; + public static final int U16_MAX = 0xFFFF; + public static final int U24_MAX = 0xFFFFFF; + public static final int U8_SIZE = 1; + public static final int U16_SIZE = 2; + public static final int U24_SIZE = 3; + public static final int U32_SIZE = 4; + + // Max decimal precision for each decimal type. + public static final int MAX_DECIMAL4_PRECISION = 9; + public static final int MAX_DECIMAL8_PRECISION = 18; + public static final int MAX_DECIMAL16_PRECISION = 38; + + // The size (in bytes) of a UUID. + public static final int UUID_SIZE = 16; + + // Default size limit for both variant value and variant metadata. + public static final int DEFAULT_SIZE_LIMIT = U24_MAX + 1; + + /** + * Write the least significant `numBytes` bytes in `value` into `bytes[pos, pos + numBytes)` in + * little endian. + * @param bytes The byte array to write into + * @param pos The starting index of the byte array to write into + * @param value The value to write + * @param numBytes The number of bytes to write + */ + public static void writeLong(byte[] bytes, int pos, long value, int numBytes) { + for (int i = 0; i < numBytes; ++i) { + bytes[pos + i] = (byte) ((value >>> (8 * i)) & 0xFF); + } + } + + public static byte primitiveHeader(int type) { + return (byte) (type << 2 | PRIMITIVE); + } + + public static byte shortStrHeader(int size) { + return (byte) (size << 2 | SHORT_STR); + } + + public static byte objectHeader(boolean largeSize, int idSize, int offsetSize) { + return (byte) (((largeSize ? 
1 : 0) << (BASIC_TYPE_BITS + 4)) + | ((idSize - 1) << (BASIC_TYPE_BITS + 2)) + | ((offsetSize - 1) << BASIC_TYPE_BITS) + | OBJECT); + } + + public static byte arrayHeader(boolean largeSize, int offsetSize) { + return (byte) (((largeSize ? 1 : 0) << (BASIC_TYPE_BITS + 2)) | ((offsetSize - 1) << BASIC_TYPE_BITS) | ARRAY); + } + + /** + * Check the validity of an array index `pos`. + * @param pos The index to check + * @param length The length of the array + * @throws MalformedVariantException if the index is out of bound + */ + public static void checkIndex(int pos, int length) { + if (pos < 0 || pos >= length) { + throw new IllegalArgumentException( + String.format("Invalid byte-array offset (%d). length: %d", pos, length)); + } + } + + /** + * Reads a little-endian signed long value from `bytes[pos, pos + numBytes)`. + * @param bytes The byte array to read from + * @param pos The starting index of the byte array to read from + * @param numBytes The number of bytes to read + * @return The long value + */ + static long readLong(byte[] bytes, int pos, int numBytes) { + checkIndex(pos, bytes.length); + checkIndex(pos + numBytes - 1, bytes.length); + long result = 0; + // All bytes except the most significant byte should be unsigned-extended and shifted + // (so we need & 0xFF`). The most significant byte should be sign-extended and is handled + // after the loop. + for (int i = 0; i < numBytes - 1; ++i) { + long unsignedByteValue = bytes[pos + i] & 0xFF; + result |= unsignedByteValue << (8 * i); + } + long signedByteValue = bytes[pos + numBytes - 1]; + result |= signedByteValue << (8 * (numBytes - 1)); + return result; + } + + /** + * Read a little-endian unsigned int value from `bytes[pos, pos + numBytes)`. The value must fit + * into a non-negative int (`[0, Integer.MAX_VALUE]`). 
+ */ + static int readUnsigned(byte[] bytes, int pos, int numBytes) { + checkIndex(pos, bytes.length); + checkIndex(pos + numBytes - 1, bytes.length); + int result = 0; + // Similar to the `readLong` loop, but all bytes should be unsigned-extended. + for (int i = 0; i < numBytes; ++i) { + int unsignedByteValue = bytes[pos + i] & 0xFF; + result |= unsignedByteValue << (8 * i); + } + if (result < 0) { + throw new MalformedVariantException(String.format("Failed to read unsigned int. numBytes: %d", numBytes)); + } + return result; + } + + /** + * The value type of Variant value. It is determined by the header byte but not a 1:1 mapping + * (for example, INT1/2/4/8 all maps to `Type.LONG`). + */ + public enum Type { + OBJECT, + ARRAY, + NULL, + BOOLEAN, + BYTE, + SHORT, + INT, + LONG, + STRING, + DOUBLE, + DECIMAL, + DATE, + TIMESTAMP, + TIMESTAMP_NTZ, + FLOAT, + BINARY, + TIME, + TIMESTAMP_NANOS, + TIMESTAMP_NANOS_NTZ, + UUID + } + + public static int getPrimitiveTypeId(byte[] value, int pos) { + checkIndex(pos, value.length); + return (value[pos] >> BASIC_TYPE_BITS) & PRIMITIVE_TYPE_MASK; + } + + /** + * Returns the value type of Variant value `value[pos...]`. It is only legal to call `get*` if + * `getType` returns the corresponding type. For example, it is only legal to call + * `getLong` if this method returns `Type.Long`. 
+ * @param value The Variant value to get the type from + * @param pos The starting index of the Variant value + * @return The type of the Variant value + */ + public static Type getType(byte[] value, int pos) { + checkIndex(pos, value.length); + int basicType = value[pos] & BASIC_TYPE_MASK; + int typeInfo = (value[pos] >> BASIC_TYPE_BITS) & PRIMITIVE_TYPE_MASK; + switch (basicType) { + case SHORT_STR: + return Type.STRING; + case OBJECT: + return Type.OBJECT; + case ARRAY: + return Type.ARRAY; + default: + switch (typeInfo) { + case NULL: + return Type.NULL; + case TRUE: + case FALSE: + return Type.BOOLEAN; + case INT8: + return Type.BYTE; + case INT16: + return Type.SHORT; + case INT32: + return Type.INT; + case INT64: + return Type.LONG; + case DOUBLE: + return Type.DOUBLE; + case DECIMAL4: + case DECIMAL8: + case DECIMAL16: + return Type.DECIMAL; + case DATE: + return Type.DATE; + case TIMESTAMP: + return Type.TIMESTAMP; + case TIMESTAMP_NTZ: + return Type.TIMESTAMP_NTZ; + case FLOAT: + return Type.FLOAT; + case BINARY: + return Type.BINARY; + case LONG_STR: + return Type.STRING; + case TIME: + return Type.TIME; + case TIMESTAMP_NANOS: + return Type.TIMESTAMP_NANOS; + case TIMESTAMP_NANOS_NTZ: + return Type.TIMESTAMP_NANOS_NTZ; + case UUID: + return Type.UUID; + default: + throw new UnknownVariantTypeException(typeInfo); + } + } + } + + /** + * Computes the actual size (in bytes) of the Variant value at `value[pos...]`. + * `value.length - pos` is an upper bound of the size, but the actual size may be smaller. 
+ * @param value The Variant value + * @param pos The starting index of the Variant value + * @return The actual size of the Variant value + * @throws MalformedVariantException if the Variant is malformed + */ + public static int valueSize(byte[] value, int pos) { + checkIndex(pos, value.length); + int basicType = value[pos] & BASIC_TYPE_MASK; + int typeInfo = (value[pos] >> BASIC_TYPE_BITS) & PRIMITIVE_TYPE_MASK; + switch (basicType) { + case SHORT_STR: + return 1 + typeInfo; + case OBJECT: + return handleObject( + value, + pos, + (info) -> info.dataStart + - pos + + readUnsigned( + value, info.offsetStart + info.numElements * info.offsetSize, info.offsetSize)); + case ARRAY: + return handleArray( + value, + pos, + (info) -> info.dataStart + - pos + + readUnsigned( + value, info.offsetStart + info.numElements * info.offsetSize, info.offsetSize)); + default: + switch (typeInfo) { + case NULL: + case TRUE: + case FALSE: + return 1; + case INT8: + return 2; + case INT16: + return 3; + case INT32: + case DATE: + case FLOAT: + return 5; + case INT64: + case DOUBLE: + case TIMESTAMP: + case TIMESTAMP_NTZ: + case TIME: + case TIMESTAMP_NANOS: + case TIMESTAMP_NANOS_NTZ: + return 9; + case DECIMAL4: + return 6; + case DECIMAL8: + return 10; + case DECIMAL16: + return 18; + case BINARY: + case LONG_STR: + return 1 + U32_SIZE + readUnsigned(value, pos + 1, U32_SIZE); + case UUID: + return 1 + UUID_SIZE; + default: + throw new UnknownVariantTypeException(typeInfo); + } + } + } + + private static MalformedVariantException unexpectedType(Type type) { + return new MalformedVariantException("Expected type to be " + type); + } + + public static boolean getBoolean(byte[] value, int pos) { + checkIndex(pos, value.length); + int basicType = value[pos] & BASIC_TYPE_MASK; + int typeInfo = (value[pos] >> BASIC_TYPE_BITS) & PRIMITIVE_TYPE_MASK; + if (basicType != PRIMITIVE || (typeInfo != TRUE && typeInfo != FALSE)) { + throw unexpectedType(Type.BOOLEAN); + } + return typeInfo == TRUE; 
+ } + + /** + * Returns a long value from Variant value `value[pos...]`. + * It is only legal to call it if `getType` returns one of Type.BYTE, SHORT, INT, LONG, + * DATE, TIMESTAMP, TIMESTAMP_NTZ, TIME, TIMESTAMP_NANOS, TIMESTAMP_NANOS_NTZ. + * If the type is `DATE`, the return value is guaranteed to fit into an int and + * represents the number of days from the Unix epoch. + * If the type is `TIMESTAMP/TIMESTAMP_NTZ`, the return value represents the number of + * microseconds from the Unix epoch. + * If the type is `TIME`, the return value represents the number of microseconds since midnight. + * If the type is `TIMESTAMP_NANOS/TIMESTAMP_NANOS_NTZ`, the return value represents the number of + * nanoseconds from the Unix epoch. + * @param value The Variant value + * @param pos The starting index of the Variant value + * @return The long value + */ + public static long getLong(byte[] value, int pos) { + checkIndex(pos, value.length); + int basicType = value[pos] & BASIC_TYPE_MASK; + int typeInfo = (value[pos] >> BASIC_TYPE_BITS) & PRIMITIVE_TYPE_MASK; + String exceptionMessage = + "Expect type to be one of: BYTE, SHORT, INT, LONG, TIMESTAMP, TIMESTAMP_NTZ, TIME, TIMESTAMP_NANOS, TIMESTAMP_NANOS_NTZ"; + if (basicType != PRIMITIVE) { + throw new IllegalStateException(exceptionMessage); + } + switch (typeInfo) { + case INT8: + return readLong(value, pos + 1, 1); + case INT16: + return readLong(value, pos + 1, 2); + case INT32: + case DATE: + return readLong(value, pos + 1, 4); + case INT64: + case TIMESTAMP: + case TIMESTAMP_NTZ: + case TIME: + case TIMESTAMP_NANOS: + case TIMESTAMP_NANOS_NTZ: + return readLong(value, pos + 1, 8); + default: + throw new IllegalStateException(exceptionMessage); + } + } + + public static double getDouble(byte[] value, int pos) { + checkIndex(pos, value.length); + int basicType = value[pos] & BASIC_TYPE_MASK; + int typeInfo = (value[pos] >> BASIC_TYPE_BITS) & PRIMITIVE_TYPE_MASK; + if (basicType != PRIMITIVE || typeInfo != DOUBLE) { + 
throw unexpectedType(Type.DOUBLE); + } + return Double.longBitsToDouble(readLong(value, pos + 1, 8)); + } + + /** + * Checks whether the precision and scale of the decimal are within the limit. + * @param d The decimal value to check + * @param maxPrecision The maximum precision allowed + * @throws MalformedVariantException if the decimal is malformed + */ + private static void checkDecimal(BigDecimal d, int maxPrecision) { + if (d.precision() > maxPrecision || d.scale() > maxPrecision) { + throw new MalformedVariantException(String.format( + "Decimal (precision: %d, scale: %d) exceeds max precision %d", + d.precision(), d.scale(), maxPrecision)); + } + } + + public static BigDecimal getDecimalWithOriginalScale(byte[] value, int pos) { + checkIndex(pos, value.length); + int basicType = value[pos] & BASIC_TYPE_MASK; + int typeInfo = (value[pos] >> BASIC_TYPE_BITS) & PRIMITIVE_TYPE_MASK; + if (basicType != PRIMITIVE) { + throw unexpectedType(Type.DECIMAL); + } + // Interpret the scale byte as unsigned. If it is a negative byte, the unsigned value must be + // greater than `MAX_DECIMAL16_PRECISION` and will trigger an error in `checkDecimal`. + int scale = value[pos + 1] & 0xFF; + BigDecimal result; + switch (typeInfo) { + case DECIMAL4: + result = BigDecimal.valueOf(readLong(value, pos + 2, 4), scale); + checkDecimal(result, MAX_DECIMAL4_PRECISION); + break; + case DECIMAL8: + result = BigDecimal.valueOf(readLong(value, pos + 2, 8), scale); + checkDecimal(result, MAX_DECIMAL8_PRECISION); + break; + case DECIMAL16: + checkIndex(pos + 17, value.length); + byte[] bytes = new byte[16]; + // Copy the bytes reversely because the `BigInteger` constructor expects a big-endian + // representation. 
+ for (int i = 0; i < 16; ++i) { + bytes[i] = value[pos + 17 - i]; + } + result = new BigDecimal(new BigInteger(bytes), scale); + checkDecimal(result, MAX_DECIMAL16_PRECISION); + break; + default: + throw unexpectedType(Type.DECIMAL); + } + return result; + } + + public static BigDecimal getDecimal(byte[] value, int pos) { + return getDecimalWithOriginalScale(value, pos); + } + + public static float getFloat(byte[] value, int pos) { + checkIndex(pos, value.length); + int basicType = value[pos] & BASIC_TYPE_MASK; + int typeInfo = (value[pos] >> BASIC_TYPE_BITS) & PRIMITIVE_TYPE_MASK; + if (basicType != PRIMITIVE || typeInfo != FLOAT) { + throw unexpectedType(Type.FLOAT); + } + return Float.intBitsToFloat((int) readLong(value, pos + 1, 4)); + } + + public static byte[] getBinary(byte[] value, int pos) { + checkIndex(pos, value.length); + int basicType = value[pos] & BASIC_TYPE_MASK; + int typeInfo = (value[pos] >> BASIC_TYPE_BITS) & PRIMITIVE_TYPE_MASK; + if (basicType != PRIMITIVE || typeInfo != BINARY) { + throw unexpectedType(Type.BINARY); + } + int start = pos + 1 + U32_SIZE; + int length = readUnsigned(value, pos + 1, U32_SIZE); + checkIndex(start + length - 1, value.length); + return Arrays.copyOfRange(value, start, start + length); + } + + public static String getString(byte[] value, int pos) { + checkIndex(pos, value.length); + int basicType = value[pos] & BASIC_TYPE_MASK; + int typeInfo = (value[pos] >> BASIC_TYPE_BITS) & PRIMITIVE_TYPE_MASK; + if (basicType == SHORT_STR || (basicType == PRIMITIVE && typeInfo == LONG_STR)) { + int start; + int length; + if (basicType == SHORT_STR) { + start = pos + 1; + length = typeInfo; + } else { + start = pos + 1 + U32_SIZE; + length = readUnsigned(value, pos + 1, U32_SIZE); + } + checkIndex(start + length - 1, value.length); + return new String(value, start, length); + } + throw unexpectedType(Type.STRING); + } + + public static java.util.UUID getUUID(byte[] value, int pos) { + checkIndex(pos, value.length); + int 
basicType = value[pos] & BASIC_TYPE_MASK; + int typeInfo = (value[pos] >> BASIC_TYPE_BITS) & PRIMITIVE_TYPE_MASK; + if (basicType != PRIMITIVE || typeInfo != UUID) { + throw unexpectedType(Type.UUID); + } + int start = pos + 1; + checkIndex(start + UUID_SIZE - 1, value.length); + ByteBuffer bb = ByteBuffer.wrap(value, start, UUID_SIZE).order(ByteOrder.BIG_ENDIAN); + return new java.util.UUID(bb.getLong(), bb.getLong()); + } + + /** + * A helper class representing the details of a Variant object, used for `ObjectHandler`. + */ + public static class ObjectInfo { + /** Number of object fields. */ + public final int numElements; + /** The integer size of the field id list. */ + public final int idSize; + /** The integer size of the offset list. */ + public final int offsetSize; + /** The starting index of the field id list in the variant value array. */ + public final int idStart; + /** The starting index of the offset list in the variant value array. */ + public final int offsetStart; + /** The starting index of field data in the variant value array. */ + public final int dataStart; + + public ObjectInfo(int numElements, int idSize, int offsetSize, int idStart, int offsetStart, int dataStart) { + this.numElements = numElements; + this.idSize = idSize; + this.offsetSize = offsetSize; + this.idStart = idStart; + this.offsetStart = offsetStart; + this.dataStart = dataStart; + } + } + + /** + * An interface for the Variant object handler. + * @param The return type of the handler + */ + public interface ObjectHandler { + /** + * @param objectInfo The details of the Variant object + */ + T apply(ObjectInfo objectInfo); + } + + /** + * An interface for the Variant object handler. + * @param The return type of the handler + */ + public interface ObjectHandlerException { + /** + * @param objectInfo The details of the Variant object + */ + T apply(ObjectInfo objectInfo) throws IOException; + } + + /** + * A helper function to access a Variant object, at `value[pos...]`. 
+ * @param value The Variant value + * @param pos The starting index of the Variant value + * @param handler The handler to process the object + * @return The result of the handler + * @param The return type of the handler + */ + public static T handleObject(byte[] value, int pos, ObjectHandler handler) { + ObjectInfo info = parseObject(value, pos); + return handler.apply(info); + } + + /** + * Same as `handleObject` but handler can throw IOException. + */ + public static T handleObjectException(byte[] value, int pos, ObjectHandlerException handler) + throws IOException { + ObjectInfo info = parseObject(value, pos); + return handler.apply(info); + } + + /** + * Parses the object at `value[pos...]`, and returns the object details. + */ + private static ObjectInfo parseObject(byte[] value, int pos) { + checkIndex(pos, value.length); + int basicType = value[pos] & BASIC_TYPE_MASK; + int typeInfo = (value[pos] >> BASIC_TYPE_BITS) & PRIMITIVE_TYPE_MASK; + if (basicType != OBJECT) { + throw unexpectedType(Type.OBJECT); + } + // Refer to the comment of the `OBJECT` constant for the details of the object header encoding. + // Suppose `typeInfo` has a bit representation of 0_b4_b3b2_b1b0, the following line extracts + // b4 to determine whether the object uses a 1/4-byte size. + boolean largeSize = ((typeInfo >> 4) & 0x1) != 0; + int sizeBytes = (largeSize ? U32_SIZE : 1); + int numElements = readUnsigned(value, pos + 1, sizeBytes); + // Extracts b3b2 to determine the integer size of the field id list. + int idSize = ((typeInfo >> 2) & 0x3) + 1; + // Extracts b1b0 to determine the integer size of the offset list. 
+ int offsetSize = (typeInfo & 0x3) + 1; + int idStart = pos + 1 + sizeBytes; + int offsetStart = idStart + numElements * idSize; + int dataStart = offsetStart + (numElements + 1) * offsetSize; + return new ObjectInfo(numElements, idSize, offsetSize, idStart, offsetStart, dataStart); + } + + /** + * A helper class representing the details of a Variant array, used for `ArrayHandler`. + */ + public static class ArrayInfo { + /** Number of object fields. */ + public final int numElements; + /** The integer size of the offset list. */ + public final int offsetSize; + /** The starting index of the offset list in the variant value array. */ + public final int offsetStart; + /** The starting index of field data in the variant value array. */ + public final int dataStart; + + public ArrayInfo(int numElements, int offsetSize, int offsetStart, int dataStart) { + this.numElements = numElements; + this.offsetSize = offsetSize; + this.offsetStart = offsetStart; + this.dataStart = dataStart; + } + } + + /** + * An interface for the Variant array handler. + * @param The return type of the handler + */ + public interface ArrayHandler { + /** + * @param arrayInfo The details of the Variant array + */ + T apply(ArrayInfo arrayInfo); + } + + /** + * An interface for the Variant array handler. + * @param The return type of the handler + */ + public interface ArrayHandlerException { + /** + * @param arrayInfo The details of the Variant array + */ + T apply(ArrayInfo arrayInfo) throws IOException; + } + + /** + * A helper function to access a Variant array, at `value[pos...]`. 
+ * @param value The Variant value + * @param pos The starting index of the Variant value + * @param handler The handler to process the array + * @return The result of the handler + * @param The return type of the handler + */ + public static T handleArray(byte[] value, int pos, ArrayHandler handler) { + ArrayInfo info = parseArray(value, pos); + return handler.apply(info); + } + + /** + * Same as `handleArray` but handler can throw IOException. + */ + public static T handleArrayException(byte[] value, int pos, ArrayHandlerException handler) + throws IOException { + ArrayInfo info = parseArray(value, pos); + return handler.apply(info); + } + + /** + * Parses the array at `value[pos...]`, and returns the array details. + */ + private static ArrayInfo parseArray(byte[] value, int pos) { + checkIndex(pos, value.length); + int basicType = value[pos] & BASIC_TYPE_MASK; + int typeInfo = (value[pos] >> BASIC_TYPE_BITS) & PRIMITIVE_TYPE_MASK; + if (basicType != ARRAY) { + throw unexpectedType(Type.ARRAY); + } + // Refer to the comment of the `ARRAY` constant for the details of the object header encoding. + // Suppose `typeInfo` has a bit representation of 000_b2_b1b0, the following line extracts + // b2 to determine whether the object uses a 1/4-byte size. + boolean largeSize = ((typeInfo >> 2) & 0x1) != 0; + int sizeBytes = (largeSize ? U32_SIZE : 1); + int numElements = readUnsigned(value, pos + 1, sizeBytes); + // Extracts b1b0 to determine the integer size of the offset list. + int offsetSize = (typeInfo & 0x3) + 1; + int offsetStart = pos + 1 + sizeBytes; + int dataStart = offsetStart + (numElements + 1) * offsetSize; + return new ArrayInfo(numElements, offsetSize, offsetStart, dataStart); + } + + /** + * Returns a key at `id` in the Variant metadata. 
+ * @param metadata The Variant metadata + * @param id The key id + * @return The key + * @throws MalformedVariantException if the Variant is malformed or if the id is out of bounds + */ + public static String getMetadataKey(byte[] metadata, int id) { + checkIndex(0, metadata.length); + // Extracts the highest 2 bits in the metadata header to determine the integer size of the + // offset list. + int offsetSize = ((metadata[0] >> 6) & 0x3) + 1; + int dictSize = readUnsigned(metadata, 1, offsetSize); + if (id >= dictSize) { + throw new MalformedVariantException( + String.format("Invalid dictionary id: %d. dictionary size: %d", id, dictSize)); + } + // There are a header byte, a `dictSize` with `offsetSize` bytes, and `(dictSize + 1)` offsets + // before the string data. + int stringStart = 1 + (dictSize + 2) * offsetSize; + int offset = readUnsigned(metadata, 1 + (id + 1) * offsetSize, offsetSize); + int nextOffset = readUnsigned(metadata, 1 + (id + 2) * offsetSize, offsetSize); + if (offset > nextOffset) { + throw new MalformedVariantException( + String.format("Invalid offset: %d. next offset: %d", offset, nextOffset)); + } + checkIndex(stringStart + nextOffset - 1, metadata.length); + return new String(metadata, stringStart + offset, nextOffset - offset); + } + + /** + * Returns a map from each string to its ID in the Variant metadata. + * @param metadata The Variant metadata + * @return A map from metadata key to its position. + */ + public static HashMap getMetadataMap(byte[] metadata) { + checkIndex(0, metadata.length); + // Extracts the highest 2 bits in the metadata header to determine the integer size of the + // offset list. 
+ int offsetSize = ((metadata[0] >> 6) & 0x3) + 1; + int dictSize = readUnsigned(metadata, 1, offsetSize); + HashMap result = new HashMap<>(); + int offset = readUnsigned(metadata, 1 + offsetSize, offsetSize); + for (int id = 0; id < dictSize; id++) { + int stringStart = 1 + (dictSize + 2) * offsetSize; + int nextOffset = readUnsigned(metadata, 1 + (id + 2) * offsetSize, offsetSize); + if (offset > nextOffset) { + throw new MalformedVariantException( + String.format("Invalid offset: %d. next offset: %d", offset, nextOffset)); + } + checkIndex(stringStart + nextOffset - 1, metadata.length); + result.put(new String(metadata, stringStart + offset, nextOffset - offset), id); + offset = nextOffset; + } + return result; + } +} diff --git a/parquet-variant/src/test/java/org/apache/parquet/variant/TestVariantEncoding.java b/parquet-variant/src/test/java/org/apache/parquet/variant/TestVariantEncoding.java new file mode 100644 index 0000000000..9735f7afc0 --- /dev/null +++ b/parquet-variant/src/test/java/org/apache/parquet/variant/TestVariantEncoding.java @@ -0,0 +1,717 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.parquet.variant; + +import com.fasterxml.jackson.core.*; +import com.fasterxml.jackson.databind.ObjectMapper; +import java.io.IOException; +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.security.SecureRandom; +import java.time.Instant; +import java.time.LocalDate; +import java.time.LocalTime; +import java.time.ZoneId; +import java.time.format.DateTimeFormatter; +import java.util.Arrays; +import java.util.Base64; +import java.util.List; +import java.util.UUID; +import java.util.concurrent.TimeUnit; +import java.util.stream.IntStream; +import org.junit.Assert; +import org.junit.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class TestVariantEncoding { + private static final Logger LOG = LoggerFactory.getLogger(TestVariantEncoding.class); + private static final String RANDOM_CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"; + private static final List SAMPLE_JSON_VALUES = Arrays.asList( + "null", + "true", + "false", + "12", + "-9876543210", + "4.5678E123", + "8.765E-2", + "\"string value\"", + "-9876.543", + "234.456789", + "{\"a\": 1, \"b\": {\"e\": -4, \"f\": 5.5}, \"c\": true}", + "[1, -2, 4.5, -6.7, \"str\", true]"); + + /** Random number generator for generating random strings */ + private static SecureRandom random = new SecureRandom(); + /** Object mapper for comparing json values */ + private final ObjectMapper mapper = new ObjectMapper(); + + private void checkJson(String expected, String actual) { + try { + StreamReadConstraints.overrideDefaultStreamReadConstraints( + StreamReadConstraints.builder().maxNestingDepth(100000).build()); + Assert.assertEquals(mapper.readTree(expected), mapper.readTree(actual)); + } catch (IOException e) { + Assert.fail("Failed to parse json: " + e); + } + } + + private void checkJson(String jsonValue) { + try { + StreamReadConstraints.overrideDefaultStreamReadConstraints( + 
StreamReadConstraints.builder().maxNestingDepth(100000).build()); + Variant v = VariantBuilder.parseJson(jsonValue); + checkJson(jsonValue, v.toJson()); + } catch (IOException e) { + Assert.fail("Failed to parse json: " + jsonValue + " " + e); + } + } + + private void checkType(Variant v, int expectedBasicType, int expectedPrimitiveTypeId) { + Assert.assertEquals(expectedBasicType, v.value[v.pos] & VariantUtil.BASIC_TYPE_MASK); + Assert.assertEquals(expectedPrimitiveTypeId, v.getPrimitiveTypeId()); + } + + private long microsSinceEpoch(Instant instant) { + return TimeUnit.SECONDS.toMicros(instant.getEpochSecond()) + instant.getNano() / 1000; + } + + private long nanosSinceEpoch(Instant instant) { + return TimeUnit.SECONDS.toNanos(instant.getEpochSecond()) + instant.getNano(); + } + + private String randomString(int len) { + StringBuilder sb = new StringBuilder(len); + for (int i = 0; i < len; i++) { + sb.append(RANDOM_CHARS.charAt(random.nextInt(RANDOM_CHARS.length()))); + } + return sb.toString(); + } + + @Test + public void testNullJson() { + checkJson("null"); + } + + @Test + public void testBooleanJson() { + Arrays.asList("true", "false").forEach(this::checkJson); + } + + @Test + public void testIntegerJson() { + Arrays.asList( + "0", + Byte.toString(Byte.MIN_VALUE), + Byte.toString(Byte.MAX_VALUE), + Short.toString(Short.MIN_VALUE), + Short.toString(Short.MAX_VALUE), + Integer.toString(Integer.MIN_VALUE), + Integer.toString(Integer.MAX_VALUE), + Long.toString(Long.MIN_VALUE), + Long.toString(Long.MAX_VALUE)) + .forEach(this::checkJson); + } + + @Test + public void testFloatJson() { + Arrays.asList( + Float.toString(Float.MIN_VALUE), Float.toString(Float.MAX_VALUE), + Double.toString(Double.MIN_VALUE), Double.toString(Double.MAX_VALUE)) + .forEach(this::checkJson); + } + + @Test + public void testStringJson() { + Arrays.asList("\"short string\"", "\"long string: " + new String(new char[1000]).replace("\0", "x") + "\"") + .forEach(this::checkJson); + } + + @Test 
+ public void testDecimalJson() { + Arrays.asList( + "12.34", "-43.21", + "10.2147483647", "-1021474836.47", + "109223372036854775.807", "-109.223372036854775807") + .forEach(this::checkJson); + } + + @Test + public void testNullBuilder() { + VariantBuilder vb = new VariantBuilder(false); + vb.appendNull(); + checkType(vb.result(), VariantUtil.NULL, 0); + } + + @Test + public void testBooleanBuilder() { + Arrays.asList(true, false).forEach(b -> { + VariantBuilder vb2 = new VariantBuilder(false); + vb2.appendBoolean(b); + checkType(vb2.result(), VariantUtil.PRIMITIVE, b ? VariantUtil.TRUE : VariantUtil.FALSE); + }); + } + + @Test + public void testIntegerBuilder() { + Arrays.asList( + 0L, + (long) Byte.MIN_VALUE, + (long) Byte.MAX_VALUE, + (long) Short.MIN_VALUE, + (long) Short.MAX_VALUE, + (long) Integer.MIN_VALUE, + (long) Integer.MAX_VALUE, + Long.MIN_VALUE, + Long.MAX_VALUE) + .forEach(l -> { + VariantBuilder vb2 = new VariantBuilder(false); + vb2.appendLong(l); + Variant v = vb2.result(); + if (Byte.MIN_VALUE <= l && l <= Byte.MAX_VALUE) { + checkType(v, VariantUtil.PRIMITIVE, VariantUtil.INT8); + } else if (Short.MIN_VALUE <= l && l <= Short.MAX_VALUE) { + checkType(v, VariantUtil.PRIMITIVE, VariantUtil.INT16); + } else if (Integer.MIN_VALUE <= l && l <= Integer.MAX_VALUE) { + checkType(v, VariantUtil.PRIMITIVE, VariantUtil.INT32); + } else { + checkType(v, VariantUtil.PRIMITIVE, VariantUtil.INT64); + } + Assert.assertEquals((long) l, v.getLong()); + }); + + Arrays.asList( + 0, + (int) Byte.MIN_VALUE, + (int) Byte.MAX_VALUE, + (int) Short.MIN_VALUE, + (int) Short.MAX_VALUE, + Integer.MIN_VALUE, + Integer.MAX_VALUE) + .forEach(i -> { + VariantBuilder vb2 = new VariantBuilder(false); + vb2.appendLong((long) i); + Variant v = vb2.result(); + if (Byte.MIN_VALUE <= i && i <= Byte.MAX_VALUE) { + checkType(v, VariantUtil.PRIMITIVE, VariantUtil.INT8); + } else if (Short.MIN_VALUE <= i && i <= Short.MAX_VALUE) { + checkType(v, VariantUtil.PRIMITIVE, VariantUtil.INT16); 
+ } else { + checkType(v, VariantUtil.PRIMITIVE, VariantUtil.INT32); + } + Assert.assertEquals((int) i, v.getInt()); + }); + + Arrays.asList((short) 0, (short) Byte.MIN_VALUE, (short) Byte.MAX_VALUE, Short.MIN_VALUE, Short.MAX_VALUE) + .forEach(s -> { + VariantBuilder vb2 = new VariantBuilder(false); + vb2.appendLong(s); + Variant v = vb2.result(); + if (Byte.MIN_VALUE <= s && s <= Byte.MAX_VALUE) { + checkType(v, VariantUtil.PRIMITIVE, VariantUtil.INT8); + } else { + checkType(v, VariantUtil.PRIMITIVE, VariantUtil.INT16); + } + Assert.assertEquals((short) s, v.getShort()); + }); + + Arrays.asList((byte) 0, Byte.MIN_VALUE, Byte.MAX_VALUE).forEach(b -> { + VariantBuilder vb2 = new VariantBuilder(false); + vb2.appendLong(b); + Variant v = vb2.result(); + checkType(v, VariantUtil.PRIMITIVE, VariantUtil.INT8); + Assert.assertEquals((byte) b, v.getByte()); + }); + } + + @Test + public void testFloatBuilder() { + Arrays.asList(Float.MIN_VALUE, Float.MAX_VALUE).forEach(f -> { + VariantBuilder vb2 = new VariantBuilder(false); + vb2.appendFloat(f); + Variant v = vb2.result(); + checkType(v, VariantUtil.PRIMITIVE, VariantUtil.FLOAT); + Assert.assertEquals(f, v.getFloat(), 0.000001); + }); + } + + @Test + public void testDoubleBuilder() { + Arrays.asList(Double.MIN_VALUE, Double.MAX_VALUE).forEach(d -> { + VariantBuilder vb2 = new VariantBuilder(false); + vb2.appendDouble(d); + Variant v = vb2.result(); + checkType(v, VariantUtil.PRIMITIVE, VariantUtil.DOUBLE); + Assert.assertEquals(d, v.getDouble(), 0.000001); + }); + } + + @Test + public void testStringBuilder() { + IntStream.range(VariantUtil.MAX_SHORT_STR_SIZE - 3, VariantUtil.MAX_SHORT_STR_SIZE + 3) + .forEach(len -> { + VariantBuilder vb2 = new VariantBuilder(false); + String s = randomString(len); + vb2.appendString(s); + Variant v = vb2.result(); + if (len <= VariantUtil.MAX_SHORT_STR_SIZE) { + checkType(v, VariantUtil.SHORT_STR, len); + } else { + checkType(v, VariantUtil.PRIMITIVE, VariantUtil.LONG_STR); + } + 
Assert.assertEquals(s, v.getString()); + }); + } + + @Test + public void testDecimalBuilder() { + // decimal4 + Arrays.asList(new BigDecimal("123.456"), new BigDecimal("-987.654")).forEach(d -> { + VariantBuilder vb2 = new VariantBuilder(false); + vb2.appendDecimal(d); + Variant v = vb2.result(); + checkType(v, VariantUtil.PRIMITIVE, VariantUtil.DECIMAL4); + Assert.assertEquals(d, v.getDecimal()); + }); + + // decimal8 + Arrays.asList(new BigDecimal("10.2147483647"), new BigDecimal("-1021474836.47")) + .forEach(d -> { + VariantBuilder vb2 = new VariantBuilder(false); + vb2.appendDecimal(d); + Variant v = vb2.result(); + checkType(v, VariantUtil.PRIMITIVE, VariantUtil.DECIMAL8); + Assert.assertEquals(d, v.getDecimal()); + }); + + // decimal16 + Arrays.asList(new BigDecimal("109223372036854775.807"), new BigDecimal("-109.223372036854775807")) + .forEach(d -> { + VariantBuilder vb2 = new VariantBuilder(false); + vb2.appendDecimal(d); + Variant v = vb2.result(); + checkType(v, VariantUtil.PRIMITIVE, VariantUtil.DECIMAL16); + Assert.assertEquals(d, v.getDecimal()); + }); + } + + @Test + public void testDate() { + VariantBuilder vb = new VariantBuilder(false); + int days = Math.toIntExact(LocalDate.of(2024, 12, 16).toEpochDay()); + vb.appendDate(days); + Assert.assertEquals("\"2024-12-16\"", vb.result().toJson()); + Assert.assertEquals(days, vb.result().getInt()); + } + + @Test + public void testTimestamp() { + DateTimeFormatter dtf = DateTimeFormatter.ISO_DATE_TIME; + VariantBuilder vb = new VariantBuilder(false); + long micros = microsSinceEpoch(Instant.from(dtf.parse("2024-12-16T10:23:45.321456-08:00"))); + vb.appendTimestamp(micros); + Assert.assertEquals("\"2024-12-16T18:23:45.321456+00:00\"", vb.result().toJson()); + Assert.assertEquals("\"2024-12-16T10:23:45.321456-08:00\"", vb.result().toJson(ZoneId.of("-08:00"))); + Assert.assertEquals("\"2024-12-16T19:23:45.321456+01:00\"", vb.result().toJson(ZoneId.of("+01:00"))); + Assert.assertEquals(micros, 
vb.result().getLong()); + } + + @Test + public void testTimestampNtz() { + DateTimeFormatter dtf = DateTimeFormatter.ISO_DATE_TIME; + VariantBuilder vb = new VariantBuilder(false); + long micros = microsSinceEpoch(Instant.from(dtf.parse("2024-01-01T23:00:00.000001Z"))); + vb.appendTimestampNtz(micros); + Assert.assertEquals("\"2024-01-01T23:00:00.000001\"", vb.result().toJson()); + Assert.assertEquals("\"2024-01-01T23:00:00.000001\"", vb.result().toJson(ZoneId.of("-08:00"))); + Assert.assertEquals(vb.result().toJson(ZoneId.of("-08:00")), vb.result().toJson(ZoneId.of("+02:00"))); + Assert.assertEquals(micros, vb.result().getLong()); + } + + @Test + public void testTime() { + for (String timeStr : Arrays.asList( + "00:00:00.000000", "00:00:00.000120", "12:00:00.000000", "12:00:00.002300", "23:59:59.999999")) { + VariantBuilder vb = new VariantBuilder(false); + long micros = LocalTime.parse(timeStr).toNanoOfDay() / 1_000; + vb.appendTime(micros); + Assert.assertEquals(String.format("\"%s\"", timeStr), vb.result().toJson()); + Assert.assertEquals(micros, vb.result().getLong()); + } + } + + @Test + public void testTimestampNanos() { + DateTimeFormatter dtf = DateTimeFormatter.ISO_DATE_TIME; + VariantBuilder vb = new VariantBuilder(false); + long nanos = nanosSinceEpoch(Instant.from(dtf.parse("2024-12-16T10:23:45.321456987-08:00"))); + vb.appendTimestampNanos(nanos); + Assert.assertEquals( + "\"2024-12-16T18:23:45.321456987+00:00\"", vb.result().toJson()); + Assert.assertEquals( + "\"2024-12-16T10:23:45.321456987-08:00\"", vb.result().toJson(ZoneId.of("-08:00"))); + Assert.assertEquals( + "\"2024-12-16T19:23:45.321456987+01:00\"", vb.result().toJson(ZoneId.of("+01:00"))); + Assert.assertEquals(nanos, vb.result().getLong()); + } + + @Test + public void testTimestampNanosNtz() { + DateTimeFormatter dtf = DateTimeFormatter.ISO_DATE_TIME; + VariantBuilder vb = new VariantBuilder(false); + long nanos = 
nanosSinceEpoch(Instant.from(dtf.parse("2024-01-01T23:00:00.839280983Z"))); + vb.appendTimestampNanosNtz(nanos); + Assert.assertEquals("\"2024-01-01T23:00:00.839280983\"", vb.result().toJson()); + Assert.assertEquals("\"2024-01-01T23:00:00.839280983\"", vb.result().toJson(ZoneId.of("-08:00"))); + Assert.assertEquals(vb.result().toJson(ZoneId.of("-08:00")), vb.result().toJson(ZoneId.of("+02:00"))); + Assert.assertEquals(nanos, vb.result().getLong()); + } + + @Test + public void testBinary() { + VariantBuilder vb = new VariantBuilder(false); + byte[] binary = new byte[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + vb.appendBinary(binary); + Assert.assertEquals( + "\"" + Base64.getEncoder().encodeToString(binary) + "\"", + vb.result().toJson()); + Assert.assertArrayEquals(binary, vb.result().getBinary()); + } + + @Test + public void testUUID() { + VariantBuilder vb = new VariantBuilder(false); + byte[] uuid = new byte[] {0, 17, 34, 51, 68, 85, 102, 119, -120, -103, -86, -69, -52, -35, -18, -1}; + long msb = ByteBuffer.wrap(uuid, 0, 8).order(ByteOrder.BIG_ENDIAN).getLong(); + long lsb = ByteBuffer.wrap(uuid, 8, 8).order(ByteOrder.BIG_ENDIAN).getLong(); + UUID expected = new UUID(msb, lsb); + + vb.appendUUID(expected); + Assert.assertEquals( + "\"00112233-4455-6677-8899-aabbccddeeff\"", vb.result().toJson()); + Assert.assertEquals(expected, vb.result().getUUID()); + } + + @Test + public void testObject() { + // simple object + StringBuilder sb = new StringBuilder(); + sb.append("{"); + for (int i = 0; i < SAMPLE_JSON_VALUES.size(); i++) { + if (i > 0) sb.append(", "); + sb.append("\"field" + i + "\": ").append(SAMPLE_JSON_VALUES.get(i)); + } + sb.append("}"); + checkJson(sb.toString()); + + // wide object + sb = new StringBuilder(); + sb.append("{"); + for (int i = 0; i < 50000; i++) { + if (i > 0) sb.append(", "); + sb.append("\"field" + i + "\": ").append(SAMPLE_JSON_VALUES.get(i % SAMPLE_JSON_VALUES.size())); + } + sb.append("}"); + checkJson(sb.toString()); + + // deep object 
+ sb = new StringBuilder(); + // Jackson object mapper hit a stack overflow if json is too deep + for (int i = 0; i < 500; i++) { + sb.append("{").append("\"field" + i + "\": "); + } + sb.append("{"); + for (int i = 0; i < SAMPLE_JSON_VALUES.size(); i++) { + if (i > 0) sb.append(", "); + sb.append("\"field" + i + "\": ").append(SAMPLE_JSON_VALUES.get(i)); + } + sb.append("}"); + for (int i = 0; i < 500; i++) { + sb.append("}"); + } + checkJson(sb.toString()); + } + + @Test + public void testGetObjectFields() throws IOException { + // Create small object for linear search + StringBuilder sb = new StringBuilder(); + sb.append("{"); + for (int i = 0; i < Variant.BINARY_SEARCH_THRESHOLD / 2; i++) { + if (i > 0) sb.append(", "); + sb.append("\"field" + i + "\": ").append(i); + } + sb.append("}"); + Variant v = VariantBuilder.parseJson(sb.toString()); + Assert.assertEquals(Variant.BINARY_SEARCH_THRESHOLD / 2, v.numObjectElements()); + for (int i = 0; i < Variant.BINARY_SEARCH_THRESHOLD / 2; i++) { + String actual = v.getFieldByKey("field" + i).toJson(); + Assert.assertEquals(String.valueOf(i), actual); + // check by index + Variant.ObjectField field = v.getFieldAtIndex(i); + Assert.assertTrue(field.key.startsWith("field")); + Assert.assertEquals(field.key.substring("field".length()), field.value.toJson()); + } + + // Create larger object for binary search + sb = new StringBuilder(); + sb.append("{"); + for (int i = 0; i < 2 * Variant.BINARY_SEARCH_THRESHOLD; i++) { + if (i > 0) sb.append(", "); + sb.append("\"field" + i + "\": ").append(i); + } + sb.append("}"); + v = VariantBuilder.parseJson(sb.toString()); + Assert.assertEquals(2 * Variant.BINARY_SEARCH_THRESHOLD, v.numObjectElements()); + for (int i = 0; i < 2 * Variant.BINARY_SEARCH_THRESHOLD; i++) { + String actual = v.getFieldByKey("field" + i).toJson(); + Assert.assertEquals(String.valueOf(i), actual); + // check by index + Variant.ObjectField field = v.getFieldAtIndex(i); + 
Assert.assertTrue(field.key.startsWith("field")); + Assert.assertEquals(field.key.substring("field".length()), field.value.toJson()); + } + } + + @Test + public void testArray() throws IOException { + // simple array + StringBuilder sb = new StringBuilder(); + sb.append("["); + for (int i = 0; i < SAMPLE_JSON_VALUES.size(); i++) { + if (i > 0) sb.append(", "); + sb.append(SAMPLE_JSON_VALUES.get(i)); + } + sb.append("]"); + checkJson(sb.toString()); + // Check array elements + Variant v = VariantBuilder.parseJson(sb.toString()); + Assert.assertEquals(SAMPLE_JSON_VALUES.size(), v.numArrayElements()); + for (int i = 0; i < SAMPLE_JSON_VALUES.size(); i++) { + String actual = v.getElementAtIndex(i).toJson(); + checkJson(SAMPLE_JSON_VALUES.get(i), actual); + } + + // large array + sb = new StringBuilder(); + sb.append("["); + for (int i = 0; i < 50000; i++) { + if (i > 0) sb.append(", "); + sb.append(SAMPLE_JSON_VALUES.get(i % SAMPLE_JSON_VALUES.size())); + } + sb.append("]"); + checkJson(sb.toString()); + // Check array elements + v = VariantBuilder.parseJson(sb.toString()); + Assert.assertEquals(50000, v.numArrayElements()); + for (int i = 0; i < 50000; i++) { + String actual = v.getElementAtIndex(i).toJson(); + checkJson(SAMPLE_JSON_VALUES.get(i % SAMPLE_JSON_VALUES.size()), actual); + } + } + + @Test + public void testSizeLimit() { + // large metadata size + try { + VariantBuilder.parseJson( + "{\"12345678901234567890\": 1, \"123456789012345678901\": 2}", new VariantBuilder(false, 20)); + Assert.fail("Expected VariantSizeLimitException with large metadata"); + } catch (IOException e) { + Assert.fail("Expected VariantSizeLimitException with large metadata"); + } catch (VariantSizeLimitException e) { + // Expected + } + + // large data size + try { + StringBuilder sb = new StringBuilder(); + sb.append("["); + for (int i = 0; i < 100; i++) { + if (i > 0) sb.append(", "); + sb.append("{\"a\":1}"); + } + sb.append("]"); + VariantBuilder.parseJson(sb.toString(), new 
VariantBuilder(false, 100)); + Assert.fail("Expected VariantSizeLimitException with large data"); + } catch (IOException e) { + Assert.fail("Expected VariantSizeLimitException with large data"); + } catch (VariantSizeLimitException e) { + // Expected + } + } + + @Test + public void testAllowDuplicateKeys() { + // disallow duplicate keys + try { + VariantBuilder.parseJson("{\"a\": 1, \"a\": 2}"); + Assert.fail("Expected VariantDuplicateKeyException with duplicate keys"); + } catch (IOException e) { + Assert.fail("Expected VariantDuplicateKeyException with duplicate keys"); + } catch (VariantDuplicateKeyException e) { + // Expected + } + + // allow duplicate keys + try { + Variant v = VariantBuilder.parseJson( + "{\"a\": 1, \"a\": 2}", new VariantBuilder(true, VariantUtil.DEFAULT_SIZE_LIMIT)); + Assert.assertEquals(1, v.numObjectElements()); + Assert.assertEquals(VariantUtil.Type.BYTE, v.getFieldByKey("a").getType()); + Assert.assertEquals(2, v.getFieldByKey("a").getLong()); + } catch (Exception e) { + Assert.fail("Unexpected exception: " + e); + } + } + + @Test + public void testTruncateTrailingZeroDecimal() { + for (String[] strings : Arrays.asList( + // decimal4 + // truncate all trailing zeros + new String[] {"1234.0000", "1234"}, + // truncate some trailing zeros + new String[] {"1234.5600", "1234.56"}, + // truncate no trailing zeros + new String[] {"1234.5678", "1234.5678"}, + // decimal8 + // truncate all trailing zeros + new String[] {"-10.0000000000", "-10"}, + // truncate some trailing zeros + new String[] {"-10.2147000000", "-10.2147"}, + // truncate no trailing zeros + new String[] {"-10.2147483647", "-10.2147483647"}, + // decimal16 + // truncate all trailing zeros + new String[] {"1092233720368547.00000", "1092233720368547"}, + // truncate some trailing zeros + new String[] {"1092233720368547.75800", "1092233720368547.758"}, + // truncate no trailing zeros + new String[] {"1092233720368547.75807", "1092233720368547.75807"})) { + VariantBuilder vb = new 
VariantBuilder(false); + BigDecimal d = new BigDecimal(strings[0]); + vb.appendDecimal(d); + Variant v = vb.result(); + Assert.assertEquals(strings[0], v.toJson()); + Assert.assertEquals(strings[1], v.toJson(ZoneId.of("UTC"), true)); + } + } + + @Test + public void testTruncateTrailingZeroTimestamp() { + DateTimeFormatter dtf = DateTimeFormatter.ISO_DATE_TIME; + for (String[] strings : Arrays.asList( + // truncate all trailing zeros + new String[] {"2024-12-16T10:23:45.000000-08:00", "2024-12-16T10:23:45-08:00"}, + // truncate all trailing zeros + new String[] {"2024-12-16T10:23:45.123000-08:00", "2024-12-16T10:23:45.123-08:00"}, + // truncate no trailing zeros + new String[] {"2024-12-16T10:23:45.123456-08:00", "2024-12-16T10:23:45.123456-08:00"})) { + VariantBuilder vb = new VariantBuilder(false); + long micros = microsSinceEpoch(Instant.from(dtf.parse(strings[0]))); + vb.appendTimestamp(micros); + Variant v = vb.result(); + Assert.assertEquals(String.format("\"%s\"", strings[0]), v.toJson(ZoneId.of("-08:00"))); + Assert.assertEquals(String.format("\"%s\"", strings[1]), v.toJson(ZoneId.of("-08:00"), true)); + } + } + + @Test + public void testTruncateTrailingZeroTimestampNtz() { + DateTimeFormatter dtf = DateTimeFormatter.ISO_DATE_TIME; + for (String[] strings : Arrays.asList( + // truncate all trailing zeros + new String[] {"2024-12-16T10:23:45.000000", "2024-12-16T10:23:45"}, + // truncate all trailing zeros + new String[] {"2024-12-16T10:23:45.123000", "2024-12-16T10:23:45.123"}, + // truncate no trailing zeros + new String[] {"2024-12-16T10:23:45.123456", "2024-12-16T10:23:45.123456"})) { + VariantBuilder vb = new VariantBuilder(false); + + long micros = microsSinceEpoch(Instant.from(dtf.parse(String.format("%sZ", strings[0])))); + vb.appendTimestampNtz(micros); + Variant v = vb.result(); + Assert.assertEquals(String.format("\"%s\"", strings[0]), v.toJson(ZoneId.of("-08:00"))); + Assert.assertEquals(String.format("\"%s\"", strings[1]), 
v.toJson(ZoneId.of("-08:00"), true)); + Assert.assertEquals(micros, vb.result().getLong()); + } + } + + @Test + public void testTruncateTrailingZeroTime() { + for (String[] strings : Arrays.asList( + // truncate all trailing zeros + new String[] {"10:23:45.000000", "10:23:45"}, + // truncate some trailing zeros + new String[] {"10:23:45.123000", "10:23:45.123"}, + // truncate no trailing zeros + new String[] {"10:23:45.123456", "10:23:45.123456"})) { + VariantBuilder vb = new VariantBuilder(false); + + long micros = LocalTime.parse(strings[0]).toNanoOfDay() / 1_000; + vb.appendTime(micros); + Variant v = vb.result(); + Assert.assertEquals(String.format("\"%s\"", strings[0]), v.toJson(ZoneId.of("-08:00"))); + Assert.assertEquals(String.format("\"%s\"", strings[1]), v.toJson(ZoneId.of("-08:00"), true)); + Assert.assertEquals(micros, vb.result().getLong()); + } + } + + @Test + public void testTruncateTrailingZeroTimestampNanos() { + DateTimeFormatter dtf = DateTimeFormatter.ISO_DATE_TIME; + for (String[] strings : Arrays.asList( + // truncate all trailing zeros + new String[] {"2024-12-16T10:23:45.000000000-08:00", "2024-12-16T10:23:45-08:00"}, + // truncate some trailing zeros + new String[] {"2024-12-16T10:23:45.123450000-08:00", "2024-12-16T10:23:45.12345-08:00"}, + // truncate no trailing zeros + new String[] {"2024-12-16T10:23:45.123456789-08:00", "2024-12-16T10:23:45.123456789-08:00"})) { + VariantBuilder vb = new VariantBuilder(false); + long nanos = nanosSinceEpoch(Instant.from(dtf.parse(strings[0]))); + vb.appendTimestampNanos(nanos); + Variant v = vb.result(); + Assert.assertEquals(String.format("\"%s\"", strings[0]), v.toJson(ZoneId.of("-08:00"))); + Assert.assertEquals(String.format("\"%s\"", strings[1]), v.toJson(ZoneId.of("-08:00"), true)); + } + } + + @Test + public void testTruncateTrailingZeroTimestampNanosNtz() { + DateTimeFormatter dtf = DateTimeFormatter.ISO_DATE_TIME; + for (String[] strings : Arrays.asList( + // truncate all trailing zeros + new 
String[] {"2024-12-16T10:23:45.000000000", "2024-12-16T10:23:45"}, + // truncate some trailing zeros + new String[] {"2024-12-16T10:23:45.123450000", "2024-12-16T10:23:45.12345"}, + // truncate no trailing zeros + new String[] {"2024-12-16T10:23:45.123456789", "2024-12-16T10:23:45.123456789"})) { + VariantBuilder vb = new VariantBuilder(false); + + long nanos = nanosSinceEpoch(Instant.from(dtf.parse(String.format("%sZ", strings[0])))); + vb.appendTimestampNanosNtz(nanos); + Variant v = vb.result(); + Assert.assertEquals(String.format("\"%s\"", strings[0]), v.toJson(ZoneId.of("-08:00"))); + Assert.assertEquals(String.format("\"%s\"", strings[1]), v.toJson(ZoneId.of("-08:00"), true)); + Assert.assertEquals(nanos, vb.result().getLong()); + } + } +} diff --git a/pom.xml b/pom.xml index c81f6f9af5..22436729fa 100644 --- a/pom.xml +++ b/pom.xml @@ -69,15 +69,6 @@ - - - jitpack.io - https://jitpack.io - Jitpack.io repository - - - - 1.8 1.8 @@ -94,7 +85,7 @@ shaded.parquet 3.3.0 - 2.10.0 + 2.11.0 1.15.1 thrift ${thrift.executable} @@ -163,6 +154,7 @@ parquet-protobuf parquet-thrift parquet-hadoop-bundle + parquet-variant