diff --git a/core/src/main/java/org/apache/iceberg/TableProperties.java b/core/src/main/java/org/apache/iceberg/TableProperties.java index 71991f633d97..2a936521ddaf 100644 --- a/core/src/main/java/org/apache/iceberg/TableProperties.java +++ b/core/src/main/java/org/apache/iceberg/TableProperties.java @@ -154,6 +154,12 @@ private TableProperties() {} "write.delete.parquet.compression-level"; public static final String PARQUET_COMPRESSION_LEVEL_DEFAULT = null; + public static final String PARQUET_VARIANT_SHRED = "write.parquet.variant.shred"; + public static final boolean PARQUET_VARIANT_SHRED_DEFAULT = false; + public static final String PARQUET_VARIANT_BUFFER_SIZE = + "write.parquet.variant.inference.buffer-size"; + public static final int PARQUET_VARIANT_BUFFER_SIZE_DEFAULT = 100; + public static final String PARQUET_ROW_GROUP_CHECK_MIN_RECORD_COUNT = "write.parquet.row-group-check-min-record-count"; public static final String DELETE_PARQUET_ROW_GROUP_CHECK_MIN_RECORD_COUNT = diff --git a/core/src/main/java/org/apache/iceberg/io/BufferedFileAppender.java b/core/src/main/java/org/apache/iceberg/io/BufferedFileAppender.java new file mode 100644 index 000000000000..bcd45f9d30de --- /dev/null +++ b/core/src/main/java/org/apache/iceberg/io/BufferedFileAppender.java @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.io; + +import java.io.IOException; +import java.util.List; +import java.util.function.Function; +import java.util.function.UnaryOperator; +import org.apache.iceberg.Metrics; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; + +/** + * A FileAppender that buffers the first N rows, then creates a delegate appender via a factory. + * + *

The factory receives the buffered rows and is responsible for creating the real appender. Row + * replay is handled internally. All subsequent {@link #add} calls delegate directly to the real + * appender. + * + *
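+ * <p>A minimal usage sketch (illustrative only; {@code createParquetAppender} is a hypothetical
+ * factory method, not part of this class):
+ *
+ * <pre>{@code
+ * FileAppender<Record> appender =
+ *     new BufferedFileAppender<>(
+ *         100, // buffer the first 100 rows before creating the delegate appender
+ *         bufferedRows -> createParquetAppender(bufferedRows),
+ *         Record::copy); // copy each row in case the caller reuses row objects
+ * appender.add(record);
+ * appender.close();
+ * }</pre>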

If fewer than {@code bufferSize} rows are written before close, the factory is called with + * whatever rows were buffered. If no rows were written, the factory is not called and no file is + * created on disk. In this case, {@link #metrics()} returns {@code new Metrics(0L)} and {@link + * #length()} returns {@code 0L}. + * + * @param the row type + */ +public class BufferedFileAppender implements FileAppender { + private final int bufferRowCount; + private final Function, FileAppender> appenderFactory; + private final UnaryOperator copyFunc; + private List buffer; + private FileAppender delegate; + private boolean closed = false; + + /** + * @param bufferRowCount number of rows to buffer before creating the delegate appender + * @param appenderFactory given the buffered rows, creates the delegate appender + */ + public BufferedFileAppender( + int bufferRowCount, Function, FileAppender> appenderFactory) { + this(bufferRowCount, appenderFactory, UnaryOperator.identity()); + } + + /** + * @param bufferRowCount number of rows to buffer before creating the delegate appender + * @param appenderFactory given the buffered rows, creates the delegate appender + * @param copyFunc copies a row before buffering (needed when row objects are reused, e.g. Spark + * InternalRow) + */ + public BufferedFileAppender( + int bufferRowCount, + Function, FileAppender> appenderFactory, + UnaryOperator copyFunc) { + Preconditions.checkArgument( + bufferRowCount > 0, "bufferRowCount must be > 0, got %s", bufferRowCount); + Preconditions.checkNotNull(appenderFactory, "appenderFactory must not be null"); + Preconditions.checkNotNull(copyFunc, "copyFunc must not be null"); + this.bufferRowCount = bufferRowCount; + this.appenderFactory = appenderFactory; + this.copyFunc = copyFunc; + this.buffer = Lists.newArrayListWithCapacity(bufferRowCount); + } + + @Override + public void add(D datum) { + Preconditions.checkState(!closed, "Cannot add to a closed appender"); + if (delegate != null) { + delegate.add(datum); + } else { + buffer.add(copyFunc.apply(datum)); + if (buffer.size() >= bufferRowCount) { + initialize(); + } + } + } + + @Override + public Metrics metrics() { + Preconditions.checkState(closed, "Cannot return metrics for unclosed appender"); + if (delegate == null) { + return new Metrics(0L); + } + + return delegate.metrics(); + } + + @Override + public long length() { + if (delegate != null) { + return delegate.length(); + } + + // No bytes written to disk yet; data is buffered in memory + return 0L; + } + + @Override + public List splitOffsets() { + if (delegate != null) { + return delegate.splitOffsets(); + } + + return null; + } + + @Override + public void close() throws IOException { + if (!closed) { + if (delegate == null && buffer != null && !buffer.isEmpty()) { + initialize(); + } + + if (delegate != null) { + delegate.close(); + } + + this.closed = true; + this.buffer = null; + } + } + + private void initialize() { + delegate = appenderFactory.apply(buffer); + Preconditions.checkState(delegate != null, "appenderFactory must not return null"); + try { + buffer.forEach(delegate::add); + } finally { + buffer = null; + } + } +} diff --git a/data/src/test/java/org/apache/iceberg/data/DataTestHelpers.java b/core/src/test/java/org/apache/iceberg/data/DataTestHelpers.java similarity index 100% rename from data/src/test/java/org/apache/iceberg/data/DataTestHelpers.java rename to core/src/test/java/org/apache/iceberg/data/DataTestHelpers.java diff --git 
a/core/src/test/java/org/apache/iceberg/io/TestBufferedFileAppender.java b/core/src/test/java/org/apache/iceberg/io/TestBufferedFileAppender.java new file mode 100644 index 000000000000..9bbc0f9f8c71 --- /dev/null +++ b/core/src/test/java/org/apache/iceberg/io/TestBufferedFileAppender.java @@ -0,0 +1,227 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.io; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.io.IOException; +import java.util.List; +import java.util.function.Function; +import org.apache.iceberg.Schema; +import org.apache.iceberg.avro.Avro; +import org.apache.iceberg.avro.AvroIterable; +import org.apache.iceberg.data.DataTestHelpers; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.data.avro.DataWriter; +import org.apache.iceberg.data.avro.PlannedDataReader; +import org.apache.iceberg.exceptions.RuntimeIOException; +import org.apache.iceberg.inmemory.InMemoryOutputFile; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.types.Types; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class TestBufferedFileAppender { + + private static final Schema SCHEMA = + new Schema( + Types.NestedField.required(1, "id", Types.LongType.get()), + Types.NestedField.optional(2, "data", Types.StringType.get())); + + private InMemoryOutputFile outputFile; + private GenericRecord record; + + @BeforeEach + public void before() { + this.outputFile = new InMemoryOutputFile(); + this.record = GenericRecord.create(SCHEMA); + } + + private Function, FileAppender> avroFactory(OutputFile out) { + return bufferedRows -> { + try { + return Avro.write(out) + .createWriterFunc(DataWriter::create) + .schema(SCHEMA) + .overwrite() + .build(); + } catch (IOException e) { + throw new RuntimeIOException(e); + } + }; + } + + private BufferedFileAppender createAppender(int bufferSize) { + return new BufferedFileAppender<>(bufferSize, avroFactory(outputFile), Record::copy); + } + + private Record createRecord(long id, String data) { + return record.copy(ImmutableMap.of("id", id, "data", data)); + } + + private List readBack() throws IOException { + try (AvroIterable reader = + Avro.read(outputFile.toInputFile()) + .project(SCHEMA) + .createResolvingReader(PlannedDataReader::create) + .build()) { + return Lists.newArrayList(reader); + } + } + + @Test + public void testBufferFlushesOnThreshold() throws IOException { + BufferedFileAppender appender = createAppender(3); + + appender.add(createRecord(1L, "a")); + 
appender.add(createRecord(2L, "b")); + + // delegate not yet created, length should be 0 + assertThat(appender.length()).isEqualTo(0L); + + appender.add(createRecord(3L, "c")); + + // delegate created after 3rd row, length should be > 0 + assertThat(appender.length()).isGreaterThan(0L); + + appender.add(createRecord(4L, "d")); + appender.add(createRecord(5L, "e")); + appender.close(); + + List expected = + Lists.newArrayList( + createRecord(1L, "a"), + createRecord(2L, "b"), + createRecord(3L, "c"), + createRecord(4L, "d"), + createRecord(5L, "e")); + DataTestHelpers.assertEquals(SCHEMA.asStruct(), expected, readBack()); + } + + @Test + public void testCloseWithPartialBuffer() throws IOException { + BufferedFileAppender appender = createAppender(10); + + appender.add(createRecord(1L, "a")); + appender.add(createRecord(2L, "b")); + appender.add(createRecord(3L, "c")); + + // buffer not full yet + assertThat(appender.length()).isEqualTo(0L); + + // close flushes partial buffer through factory + appender.close(); + + List expected = + Lists.newArrayList(createRecord(1L, "a"), createRecord(2L, "b"), createRecord(3L, "c")); + DataTestHelpers.assertEquals(SCHEMA.asStruct(), expected, readBack()); + } + + @Test + public void testCopyFuncIsApplied() throws IOException { + BufferedFileAppender appender = createAppender(3); + + // use a single mutable record, relying on copyFunc to snapshot it + record.set(0, 1L); + record.set(1, "first"); + appender.add(record); + + record.set(0, 2L); + record.set(1, "second"); + appender.add(record); + + record.set(0, 3L); + record.set(1, "third"); + appender.add(record); + + appender.close(); + + List expected = + Lists.newArrayList( + createRecord(1L, "first"), createRecord(2L, "second"), createRecord(3L, "third")); + DataTestHelpers.assertEquals(SCHEMA.asStruct(), expected, readBack()); + } + + @Test + public void testMetricsAfterClose() throws IOException { + BufferedFileAppender appender = createAppender(2); + + appender.add(createRecord(1L, "a")); + appender.add(createRecord(2L, "b")); + appender.add(createRecord(3L, "c")); + appender.close(); + + assertThat(appender.metrics()).isNotNull(); + assertThat(appender.metrics().recordCount()).isEqualTo(3L); + assertThat(appender.length()).isGreaterThan(0L); + } + + @Test + public void testMetricsBeforeCloseThrows() throws IOException { + try (BufferedFileAppender appender = createAppender(10)) { + assertThatThrownBy(appender::metrics) + .isInstanceOf(IllegalStateException.class) + .hasMessage("Cannot return metrics for unclosed appender"); + } + } + + @Test + public void testAddAfterCloseThrows() throws IOException { + try (BufferedFileAppender appender = createAppender(10)) { + appender.add(createRecord(1L, "a")); + appender.close(); + + assertThatThrownBy(() -> appender.add(createRecord(2L, "b"))) + .isInstanceOf(IllegalStateException.class) + .hasMessage("Cannot add to a closed appender"); + } + } + + @Test + public void testAddAllSpanningBuffer() throws IOException { + BufferedFileAppender appender = createAppender(2); + + List records = + Lists.newArrayList( + createRecord(1L, "a"), + createRecord(2L, "b"), + createRecord(3L, "c"), + createRecord(4L, "d")); + + appender.addAll(records); + appender.close(); + + DataTestHelpers.assertEquals(SCHEMA.asStruct(), records, readBack()); + } + + @Test + public void testCloseWithNoData() throws IOException { + BufferedFileAppender appender = createAppender(10); + // close immediately with no data written + appender.close(); + // delegate was never created + 
assertThat(appender.length()).isEqualTo(0L); + assertThat(appender.metrics()).isNotNull(); + assertThat(appender.metrics().recordCount()).isEqualTo(0L); + assertThat(appender.splitOffsets()).isNull(); + } +} diff --git a/docs/docs/configuration.md b/docs/docs/configuration.md index f12bcea6afd5..94423af55496 100644 --- a/docs/docs/configuration.md +++ b/docs/docs/configuration.md @@ -49,6 +49,8 @@ Iceberg tables support table properties to configure table behavior, like the de | write.parquet.dict-size-bytes | 2097152 (2 MB) | Parquet dictionary page size | | write.parquet.compression-codec | zstd | Parquet compression codec: zstd, brotli, lz4, gzip, snappy, uncompressed | | write.parquet.compression-level | null | Parquet compression level | +| write.parquet.variant.shred | false | When true, variant columns are written with shredded Parquet encoding for improved query performance | +| write.parquet.variant.inference.buffer-size | 100 | Number of rows to buffer for schema inference when variant shredding is enabled | | write.parquet.bloom-filter-enabled.column.col1 | (not set) | Hint to parquet to write a bloom filter for the column: 'col1' | | write.parquet.bloom-filter-max-bytes | 1048576 (1 MB) | The maximum number of bytes for a bloom filter bitset | | write.parquet.bloom-filter-fpp.column.col1 | 0.01 | The false positive probability for a bloom filter applied to 'col1' (must > 0.0 and < 1.0) | diff --git a/docs/docs/spark-configuration.md b/docs/docs/spark-configuration.md index e8e4f7e3c8c1..613fe5b66554 100644 --- a/docs/docs/spark-configuration.md +++ b/docs/docs/spark-configuration.md @@ -191,6 +191,8 @@ val spark = SparkSession.builder() | spark.sql.iceberg.distribution-mode | See [Spark Writes](spark-writes.md#writing-distribution-modes) | Controls distribution strategy during writes | | spark.wap.id | null | [Write-Audit-Publish](branching.md#audit-branch) snapshot staging ID | | spark.wap.branch | null | WAP branch name for snapshot commit | +| spark.sql.iceberg.shred-variants | Table default | When true, variant columns are written with shredded Parquet encoding for improved query performance | +| spark.sql.iceberg.variant.inference.buffer-size | Table default | Number of rows to buffer for schema inference when variant shredding is enabled | | spark.sql.iceberg.compression-codec | Table default | Write compression codec (e.g., `zstd`, `snappy`) | | spark.sql.iceberg.compression-level | Table default | Compression level for Parquet/Avro | | spark.sql.iceberg.compression-strategy | Table default | Compression strategy for ORC | @@ -262,6 +264,7 @@ df.writeTo("catalog.db.table") | compression-strategy | Table write.orc.compression-strategy | Overrides this table's compression strategy for ORC tables for this write | | distribution-mode | See [Spark Writes](spark-writes.md#writing-distribution-modes) for defaults | Override this table's distribution mode for this write | | delete-granularity | file | Override this table's delete granularity for this write | +| shred-variants | false | Overrides this table's write.parquet.variant.shred for this write | CommitMetadata provides an interface to add custom metadata to a snapshot summary during a SQL execution, which can be beneficial for purposes such as auditing or change tracking. If properties start with `snapshot-property.`, then that prefix will be removed from each property. 
Here is an example: diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetFormatModel.java b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetFormatModel.java index fbd7a6e97fe2..0845c0727897 100644 --- a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetFormatModel.java +++ b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetFormatModel.java @@ -19,13 +19,16 @@ package org.apache.iceberg.parquet; import java.io.IOException; +import java.io.UncheckedIOException; import java.nio.ByteBuffer; import java.util.Map; import java.util.function.Function; +import java.util.function.UnaryOperator; import org.apache.iceberg.FileContent; import org.apache.iceberg.FileFormat; import org.apache.iceberg.MetricsConfig; import org.apache.iceberg.Schema; +import org.apache.iceberg.TableProperties; import org.apache.iceberg.data.parquet.GenericParquetWriter; import org.apache.iceberg.deletes.PositionDelete; import org.apache.iceberg.encryption.EncryptedOutputFile; @@ -33,6 +36,7 @@ import org.apache.iceberg.formats.BaseFormatModel; import org.apache.iceberg.formats.ModelWriteBuilder; import org.apache.iceberg.formats.ReadBuilder; +import org.apache.iceberg.io.BufferedFileAppender; import org.apache.iceberg.io.CloseableIterable; import org.apache.iceberg.io.DeleteSchemaUtil; import org.apache.iceberg.io.FileAppender; @@ -42,14 +46,21 @@ import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; import org.apache.parquet.column.ParquetProperties; import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.Type; public class ParquetFormatModel extends BaseFormatModel, R, MessageType> { public static final String WRITER_VERSION_KEY = "parquet.writer.version"; + public static final String SHRED_VARIANTS_KEY = TableProperties.PARQUET_VARIANT_SHRED; + public static final String VARIANT_BUFFER_SIZE_KEY = TableProperties.PARQUET_VARIANT_BUFFER_SIZE; + public static final int DEFAULT_BUFFER_SIZE = TableProperties.PARQUET_VARIANT_BUFFER_SIZE_DEFAULT; private final boolean isBatchReader; + private final VariantShreddingAnalyzer variantAnalyzer; + private final UnaryOperator copyFunc; public static ParquetFormatModel, Void, Object> forPositionDeletes() { - return new ParquetFormatModel<>(PositionDelete.deleteClass(), Void.class, null, null, false); + return new ParquetFormatModel<>( + PositionDelete.deleteClass(), Void.class, null, null, false, null, null); } public static ParquetFormatModel> create( @@ -57,14 +68,26 @@ public static ParquetFormatModel> create( Class schemaType, WriterFunction, S, MessageType> writerFunction, ReaderFunction, S, MessageType> readerFunction) { - return new ParquetFormatModel<>(type, schemaType, writerFunction, readerFunction, false); + return new ParquetFormatModel<>( + type, schemaType, writerFunction, readerFunction, false, null, null); + } + + public static ParquetFormatModel> create( + Class type, + Class schemaType, + WriterFunction, S, MessageType> writerFunction, + ReaderFunction, S, MessageType> readerFunction, + VariantShreddingAnalyzer variantAnalyzer, + UnaryOperator copyFunc) { + return new ParquetFormatModel<>( + type, schemaType, writerFunction, readerFunction, false, variantAnalyzer, copyFunc); } public static ParquetFormatModel> create( Class type, Class schemaType, ReaderFunction, S, MessageType> batchReaderFunction) { - return new ParquetFormatModel<>(type, schemaType, null, batchReaderFunction, true); + return new ParquetFormatModel<>(type, schemaType, null, batchReaderFunction, true, null, null); } 
private ParquetFormatModel( @@ -72,9 +95,13 @@ private ParquetFormatModel( Class schemaType, WriterFunction, S, MessageType> writerFunction, ReaderFunction readerFunction, - boolean isBatchReader) { + boolean isBatchReader, + VariantShreddingAnalyzer variantAnalyzer, + UnaryOperator copyFunc) { super(type, schemaType, writerFunction, readerFunction); this.isBatchReader = isBatchReader; + this.variantAnalyzer = variantAnalyzer; + this.copyFunc = copyFunc; } @Override @@ -84,7 +111,7 @@ public FileFormat format() { @Override public ModelWriteBuilder writeBuilder(EncryptedOutputFile outputFile) { - return new WriteBuilderWrapper<>(outputFile, writerFunction()); + return new WriteBuilderWrapper<>(outputFile, writerFunction(), variantAnalyzer, copyFunc); } @Override @@ -95,15 +122,23 @@ public ReadBuilder readBuilder(InputFile inputFile) { private static class WriteBuilderWrapper implements ModelWriteBuilder { private final Parquet.WriteBuilder internal; private final WriterFunction, S, MessageType> writerFunction; + private final VariantShreddingAnalyzer variantAnalyzer; + private final UnaryOperator copyFunc; private Schema schema; private S engineSchema; private FileContent content; + private boolean shreddingEnabled = false; + private int bufferSize = DEFAULT_BUFFER_SIZE; private WriteBuilderWrapper( EncryptedOutputFile outputFile, - WriterFunction, S, MessageType> writerFunction) { + WriterFunction, S, MessageType> writerFunction, + VariantShreddingAnalyzer variantAnalyzer, + UnaryOperator copyFunc) { this.internal = Parquet.write(outputFile); this.writerFunction = writerFunction; + this.variantAnalyzer = variantAnalyzer; + this.copyFunc = copyFunc; } @Override @@ -125,13 +160,21 @@ public ModelWriteBuilder set(String property, String value) { internal.writerVersion(ParquetProperties.WriterVersion.valueOf(value)); } + if (SHRED_VARIANTS_KEY.equals(property)) { + shreddingEnabled = Boolean.parseBoolean(value); + } + + if (VARIANT_BUFFER_SIZE_KEY.equals(property)) { + bufferSize = Integer.parseInt(value); + } + internal.set(property, value); return this; } @Override public ModelWriteBuilder setAll(Map properties) { - internal.setAll(properties); + properties.forEach(this::set); return this; } @@ -179,12 +222,14 @@ public ModelWriteBuilder withAADPrefix(ByteBuffer aadPrefix) { @Override public FileAppender build() throws IOException { + boolean shredVariants = false; switch (content) { case DATA: internal.createContextFunc(Parquet.WriteBuilder.Context::dataContext); internal.createWriterFunc( (icebergSchema, messageType) -> writerFunction.write(icebergSchema, messageType, engineSchema)); + shredVariants = shreddingEnabled && variantAnalyzer != null && hasVariantColumns(schema); break; case EQUALITY_DELETES: internal.createContextFunc(Parquet.WriteBuilder.Context::deleteContext); @@ -215,8 +260,45 @@ public FileAppender build() throws IOException { throw new IllegalArgumentException("Unknown file content: " + content); } + if (shredVariants) { + return buildShreddedAppender(); + } + return internal.build(); } + + /** + * Creates a {@link BufferedFileAppender} that buffers the first N rows, runs variant shredding + * analysis on them, then creates the real Parquet appender with a shredded schema. + * + *

Only top-level variant columns are shredded. Nested variants (inside structs/lists/maps) + * fall through to unshredded 2-field layout because column index resolution only applies to + * top-level fields. + */ + private FileAppender buildShreddedAppender() { + return new BufferedFileAppender<>( + bufferSize, + bufferedRows -> { + Map shreddedTypes = + variantAnalyzer.analyzeVariantColumns(bufferedRows, schema, engineSchema); + + if (!shreddedTypes.isEmpty()) { + internal.variantShreddingFunc((fieldId, name) -> shreddedTypes.get(fieldId)); + } + + try { + return internal.build(); + } catch (IOException e) { + throw new UncheckedIOException("Failed to create shredded variant writer", e); + } + }, + copyFunc); + } + + private static boolean hasVariantColumns(Schema schema) { + return schema != null + && schema.columns().stream().anyMatch(field -> field.type().isVariantType()); + } } private static class ReadBuilderWrapper implements ReadBuilder { diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetVariantWriters.java b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetVariantWriters.java index 9e94b1bbd6cd..08016667bdab 100644 --- a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetVariantWriters.java +++ b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetVariantWriters.java @@ -275,8 +275,14 @@ private ShreddedVariantWriter( @Override public void write(int repetitionLevel, VariantValue value) { if (typedWriter.types().contains(value.type())) { - typedWriter.write(repetitionLevel, value); - writeNull(valueWriter, repetitionLevel, valueDefinitionLevel); + try { + typedWriter.write(repetitionLevel, value); + writeNull(valueWriter, repetitionLevel, valueDefinitionLevel); + } catch (IllegalArgumentException e) { + // Fall back to value field if typed write fails (e.g., decimal scale mismatch) + valueWriter.write(repetitionLevel, value); + writeNull(typedWriter, repetitionLevel, typedDefinitionLevel); + } } else { valueWriter.write(repetitionLevel, value); writeNull(typedWriter, repetitionLevel, typedDefinitionLevel); diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/VariantShreddingAnalyzer.java b/parquet/src/main/java/org/apache/iceberg/parquet/VariantShreddingAnalyzer.java new file mode 100644 index 000000000000..024635939c5d --- /dev/null +++ b/parquet/src/main/java/org/apache/iceberg/parquet/VariantShreddingAnalyzer.java @@ -0,0 +1,528 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.parquet; + +import java.math.BigDecimal; +import java.util.List; +import java.util.Map; +import java.util.Set; +import org.apache.iceberg.Schema; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.types.Types.NestedField; +import org.apache.iceberg.variants.PhysicalType; +import org.apache.iceberg.variants.VariantArray; +import org.apache.iceberg.variants.VariantObject; +import org.apache.iceberg.variants.VariantPrimitive; +import org.apache.iceberg.variants.VariantValue; +import org.apache.parquet.schema.GroupType; +import org.apache.parquet.schema.LogicalTypeAnnotation; +import org.apache.parquet.schema.PrimitiveType; +import org.apache.parquet.schema.Type; +import org.apache.parquet.schema.Types; + +/** + * Analyzes variant data across buffered rows to determine an optimal shredding schema. + * + *
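+ * <p>For illustration (a non-normative sketch): if the buffered values for a variant column are
+ * objects like {@code {"a": 1, "b": "x"}}, the inferred schema is roughly of the shape
+ *
+ * <pre>
+ * optional group typed_value {
+ *   required group a { optional binary value; optional int32 typed_value; }
+ *   required group b { optional binary value; optional binary typed_value (STRING); }
+ * }
+ * </pre>
+ *
+ * where each observed field keeps an untyped {@code value} fallback alongside a
+ * {@code typed_value} of its most common primitive type.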

Determinism contract: for a given set of variant values (regardless of row arrival order), + * this analyzer produces the same shredded schema. + * + *
This contract holds within a single batch. Different batches with different distributions may + * produce different layouts; cross-batch stability requires schema pinning (not yet implemented). + * + *
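+ * <p>A minimal subclass sketch (mirrors the test-only {@code DirectAnalyzer}, where the rows are
+ * already {@link VariantValue}s and column resolution is unused):
+ *
+ * <pre>{@code
+ * class DirectAnalyzer extends VariantShreddingAnalyzer<VariantValue, Void> {
+ *   @Override
+ *   protected List<VariantValue> extractVariantValues(List<VariantValue> rows, int idx) {
+ *     return rows;
+ *   }
+ *
+ *   @Override
+ *   protected int resolveColumnIndex(Void engineSchema, String columnName) {
+ *     throw new UnsupportedOperationException("Not used");
+ *   }
+ * }
+ * }</pre>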

Subclasses implement {@link #extractVariantValues} to convert engine-specific row types into + * {@link VariantValue} instances. + * + * @param the engine-specific row type (e.g., Spark InternalRow, Flink RowData) + * @param the engine-specific schema type (e.g., Spark StructType, Flink RowType) + */ +public abstract class VariantShreddingAnalyzer { + private static final String TYPED_VALUE = "typed_value"; + private static final String VALUE = "value"; + private static final String ELEMENT = "element"; + private static final double MIN_FIELD_FREQUENCY = 0.10; + private static final int MAX_SHREDDED_FIELDS = 300; + private static final int MAX_SHREDDING_DEPTH = 50; + private static final int MAX_INTERMEDIATE_FIELDS = 1000; + + protected VariantShreddingAnalyzer() {} + + /** + * Analyzes buffered variant values to determine the optimal shredding schema. + * + * @param bufferedRows the buffered rows to analyze + * @param variantFieldIndex the index of the variant field in the rows + * @return the shredded schema type, or null if no shredding should be performed + */ + public Type analyzeAndCreateSchema(List bufferedRows, int variantFieldIndex) { + List variantValues = extractVariantValues(bufferedRows, variantFieldIndex); + if (variantValues.isEmpty()) { + return null; + } + + PathNode root = buildPathTree(variantValues); + PhysicalType rootType = root.info.getMostCommonType(); + if (rootType == null) { + return null; + } + + pruneInfrequentFields(root, root.info.observationCount); + + return buildTypedValue(root, rootType); + } + + protected abstract List extractVariantValues( + List bufferedRows, int variantFieldIndex); + + /** + * Resolves a column name to its index in the engine-specific schema. Returns -1 if the column is + * not found. + */ + protected abstract int resolveColumnIndex(S engineSchema, String columnName); + + /** + * Analyzes all variant columns in the schema, resolving column indices via the engine-specific + * {@link #resolveColumnIndex} method. 
+ * + * @param bufferedRows the buffered rows to analyze + * @param icebergSchema the Iceberg table schema + * @param engineSchema the engine-specific schema used to resolve column indices + * @return a map from Iceberg field ID to the shredded Parquet type for each variant column + */ + public Map analyzeVariantColumns( + List bufferedRows, Schema icebergSchema, S engineSchema) { + Map shreddedTypes = Maps.newHashMap(); + for (NestedField col : icebergSchema.columns()) { + if (col.type().isVariantType()) { + int rowIndex = resolveColumnIndex(engineSchema, col.name()); + if (rowIndex >= 0) { + Type typed = analyzeAndCreateSchema(bufferedRows, rowIndex); + if (typed != null) { + shreddedTypes.put(col.fieldId(), typed); + } + } + } + } + + return shreddedTypes; + } + + private static PathNode buildPathTree(List variantValues) { + PathNode root = new PathNode(null); + root.info = new FieldInfo(); + + for (VariantValue value : variantValues) { + traverse(root, value, 0); + } + + return root; + } + + private static void pruneInfrequentFields(PathNode node, int totalRows) { + if (node.objectChildren.isEmpty() && node.arrayElement == null) { + return; + } + + // Remove fields below frequency threshold + node.objectChildren + .entrySet() + .removeIf( + entry -> { + FieldInfo info = entry.getValue().info; + return info != null + && ((double) info.observationCount / totalRows) < MIN_FIELD_FREQUENCY; + }); + + // Cap at MAX_SHREDDED_FIELDS, keep the most frequently observed + if (node.objectChildren.size() > MAX_SHREDDED_FIELDS) { + List> sorted = Lists.newArrayList(node.objectChildren.entrySet()); + sorted.sort( + (a, b) -> { + int cmp = + Integer.compare( + b.getValue().info.observationCount, a.getValue().info.observationCount); + return cmp != 0 ? cmp : a.getKey().compareTo(b.getKey()); + }); + Set keep = Sets.newHashSet(); + for (int i = 0; i < MAX_SHREDDED_FIELDS; i++) { + keep.add(sorted.get(i).getKey()); + } + node.objectChildren.entrySet().removeIf(entry -> !keep.contains(entry.getKey())); + } + + // Recurse into remaining object children + for (PathNode child : node.objectChildren.values()) { + pruneInfrequentFields(child, totalRows); + } + + // Recurse into array elements (arrays of objects need pruning too) + if (node.arrayElement != null) { + pruneInfrequentFields(node.arrayElement, totalRows); + } + } + + private static void traverse(PathNode node, VariantValue value, int depth) { + if (value == null || value.type() == PhysicalType.NULL) { + return; + } + + node.info.observe(value); + + if (value.type() == PhysicalType.OBJECT && depth < MAX_SHREDDING_DEPTH) { + traverseObject(node, value.asObject(), depth); + } else if (value.type() == PhysicalType.ARRAY && depth < MAX_SHREDDING_DEPTH) { + traverseArray(node, value.asArray(), depth); + } + } + + private static void traverseObject(PathNode node, VariantObject obj, int depth) { + for (String fieldName : obj.fieldNames()) { + VariantValue fieldValue = obj.get(fieldName); + if (fieldValue != null) { + PathNode childNode = node.objectChildren.get(fieldName); + if (childNode == null) { + if (node.objectChildren.size() >= MAX_INTERMEDIATE_FIELDS) { + continue; + } + childNode = new PathNode(fieldName); + childNode.info = new FieldInfo(); + node.objectChildren.put(fieldName, childNode); + } + traverse(childNode, fieldValue, depth + 1); + } + } + } + + private static void traverseArray(PathNode node, VariantArray array, int depth) { + int numElements = array.numElements(); + if (node.arrayElement == null) { + node.arrayElement = new 
PathNode(null); + node.arrayElement.info = new FieldInfo(); + } + for (int i = 0; i < numElements; i++) { + VariantValue element = array.get(i); + if (element != null) { + traverse(node.arrayElement, element, depth + 1); + } + } + } + + private static Type buildFieldGroup(PathNode node) { + PhysicalType commonType = node.info.getMostCommonType(); + if (commonType == null) { + return null; + } + + Type typedValue = buildTypedValue(node, commonType); + if (typedValue == null) { + return null; + } + + return Types.buildGroup(Type.Repetition.REQUIRED) + .optional(PrimitiveType.PrimitiveTypeName.BINARY) + .named(VALUE) + .addField(typedValue) + .named(node.fieldName); + } + + private static Type buildTypedValue(PathNode node, PhysicalType physicalType) { + return switch (physicalType) { + case ARRAY -> createArrayTypedValue(node); + case OBJECT -> createObjectTypedValue(node); + default -> createPrimitiveTypedValue(node.info, physicalType); + }; + } + + private static Type createObjectTypedValue(PathNode node) { + if (node.objectChildren.isEmpty()) { + return null; + } + + Types.GroupBuilder builder = Types.buildGroup(Type.Repetition.OPTIONAL); + boolean hasFields = false; + for (PathNode child : node.objectChildren.values()) { + Type fieldType = buildFieldGroup(child); + if (fieldType != null) { + builder.addField(fieldType); + hasFields = true; + } + } + + return hasFields ? builder.named(TYPED_VALUE) : null; + } + + private static Type createArrayTypedValue(PathNode node) { + PathNode elementNode = node.arrayElement; + if (elementNode == null) { + return null; + } + PhysicalType elementType = elementNode.info.getMostCommonType(); + if (elementType == null) { + return null; + } + Type elementTypedValue = buildTypedValue(elementNode, elementType); + if (elementTypedValue == null) { + return null; + } + + GroupType elementGroup = + Types.buildGroup(Type.Repetition.REQUIRED) + .optional(PrimitiveType.PrimitiveTypeName.BINARY) + .named(VALUE) + .addField(elementTypedValue) + .named(ELEMENT); + + return Types.optionalList().element(elementGroup).named(TYPED_VALUE); + } + + private static class PathNode { + private final String fieldName; + private final Map objectChildren = Maps.newTreeMap(); + private PathNode arrayElement = null; + private FieldInfo info = null; + + private PathNode(String fieldName) { + this.fieldName = fieldName; + } + } + + /** Use DECIMAL with maximum precision and scale as the shredding type */ + private static Type createDecimalTypedValue(FieldInfo info) { + int maxPrecision = Math.min(info.maxDecimalIntegerDigits + info.maxDecimalScale, 38); + int maxScale = Math.min(info.maxDecimalScale, Math.max(0, 38 - info.maxDecimalIntegerDigits)); + + if (maxPrecision <= 9) { + return Types.optional(PrimitiveType.PrimitiveTypeName.INT32) + .as(LogicalTypeAnnotation.decimalType(maxScale, maxPrecision)) + .named(TYPED_VALUE); + } else if (maxPrecision <= 18) { + return Types.optional(PrimitiveType.PrimitiveTypeName.INT64) + .as(LogicalTypeAnnotation.decimalType(maxScale, maxPrecision)) + .named(TYPED_VALUE); + } else { + return Types.optional(PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY) + .length(16) + .as(LogicalTypeAnnotation.decimalType(maxScale, maxPrecision)) + .named(TYPED_VALUE); + } + } + + private static Type createPrimitiveTypedValue(FieldInfo info, PhysicalType primitiveType) { + return switch (primitiveType) { + case BOOLEAN_TRUE, BOOLEAN_FALSE -> + Types.optional(PrimitiveType.PrimitiveTypeName.BOOLEAN).named(TYPED_VALUE); + case INT8 -> + 
Types.optional(PrimitiveType.PrimitiveTypeName.INT32) + .as(LogicalTypeAnnotation.intType(8, true)) + .named(TYPED_VALUE); + case INT16 -> + Types.optional(PrimitiveType.PrimitiveTypeName.INT32) + .as(LogicalTypeAnnotation.intType(16, true)) + .named(TYPED_VALUE); + case INT32 -> + Types.optional(PrimitiveType.PrimitiveTypeName.INT32) + .as(LogicalTypeAnnotation.intType(32, true)) + .named(TYPED_VALUE); + case INT64 -> + Types.optional(PrimitiveType.PrimitiveTypeName.INT64) + .as(LogicalTypeAnnotation.intType(64, true)) + .named(TYPED_VALUE); + case FLOAT -> Types.optional(PrimitiveType.PrimitiveTypeName.FLOAT).named(TYPED_VALUE); + case DOUBLE -> Types.optional(PrimitiveType.PrimitiveTypeName.DOUBLE).named(TYPED_VALUE); + case STRING -> + Types.optional(PrimitiveType.PrimitiveTypeName.BINARY) + .as(LogicalTypeAnnotation.stringType()) + .named(TYPED_VALUE); + case BINARY -> Types.optional(PrimitiveType.PrimitiveTypeName.BINARY).named(TYPED_VALUE); + case TIME -> + Types.optional(PrimitiveType.PrimitiveTypeName.INT64) + .as(LogicalTypeAnnotation.timeType(false, LogicalTypeAnnotation.TimeUnit.MICROS)) + .named(TYPED_VALUE); + case DATE -> + Types.optional(PrimitiveType.PrimitiveTypeName.INT32) + .as(LogicalTypeAnnotation.dateType()) + .named(TYPED_VALUE); + case TIMESTAMPTZ -> + Types.optional(PrimitiveType.PrimitiveTypeName.INT64) + .as(LogicalTypeAnnotation.timestampType(true, LogicalTypeAnnotation.TimeUnit.MICROS)) + .named(TYPED_VALUE); + case TIMESTAMPNTZ -> + Types.optional(PrimitiveType.PrimitiveTypeName.INT64) + .as(LogicalTypeAnnotation.timestampType(false, LogicalTypeAnnotation.TimeUnit.MICROS)) + .named(TYPED_VALUE); + case TIMESTAMPTZ_NANOS -> + Types.optional(PrimitiveType.PrimitiveTypeName.INT64) + .as(LogicalTypeAnnotation.timestampType(true, LogicalTypeAnnotation.TimeUnit.NANOS)) + .named(TYPED_VALUE); + case TIMESTAMPNTZ_NANOS -> + Types.optional(PrimitiveType.PrimitiveTypeName.INT64) + .as(LogicalTypeAnnotation.timestampType(false, LogicalTypeAnnotation.TimeUnit.NANOS)) + .named(TYPED_VALUE); + case DECIMAL4, DECIMAL8, DECIMAL16 -> createDecimalTypedValue(info); + case UUID -> + Types.optional(PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY) + .length(16) + .as(LogicalTypeAnnotation.uuidType()) + .named(TYPED_VALUE); + default -> + throw new UnsupportedOperationException( + "Unknown primitive physical type: " + primitiveType); + }; + } + + /** Tracks occurrence count and types for a single field. 
*/ + private static class FieldInfo { + private final Map typeCounts = Maps.newHashMap(); + private int maxDecimalScale = 0; + private int maxDecimalIntegerDigits = 0; + private int observationCount = 0; + + private static final Map INTEGER_PRIORITY = + ImmutableMap.of( + PhysicalType.INT8, 0, + PhysicalType.INT16, 1, + PhysicalType.INT32, 2, + PhysicalType.INT64, 3); + + private static final Map DECIMAL_PRIORITY = + ImmutableMap.of( + PhysicalType.DECIMAL4, 0, + PhysicalType.DECIMAL8, 1, + PhysicalType.DECIMAL16, 2); + + private static final Map TIE_BREAK_PRIORITY = + ImmutableMap.builder() + .put(PhysicalType.BOOLEAN_TRUE, 0) + .put(PhysicalType.INT8, 1) + .put(PhysicalType.INT16, 2) + .put(PhysicalType.INT32, 3) + .put(PhysicalType.INT64, 4) + .put(PhysicalType.FLOAT, 5) + .put(PhysicalType.DOUBLE, 6) + .put(PhysicalType.DECIMAL4, 7) + .put(PhysicalType.DECIMAL8, 8) + .put(PhysicalType.DECIMAL16, 9) + .put(PhysicalType.DATE, 10) + .put(PhysicalType.TIME, 11) + .put(PhysicalType.TIMESTAMPTZ, 12) + .put(PhysicalType.TIMESTAMPNTZ, 13) + .put(PhysicalType.BINARY, 14) + .put(PhysicalType.STRING, 15) + .put(PhysicalType.TIMESTAMPTZ_NANOS, 16) + .put(PhysicalType.TIMESTAMPNTZ_NANOS, 17) + .put(PhysicalType.UUID, 18) + .buildOrThrow(); + + void observe(VariantValue value) { + observationCount++; + // Use BOOLEAN_TRUE for both TRUE/FALSE values + PhysicalType type = + value.type() == PhysicalType.BOOLEAN_FALSE ? PhysicalType.BOOLEAN_TRUE : value.type(); + + typeCounts.compute(type, (k, v) -> (v == null) ? 1 : v + 1); + + // Track max precision and scale for decimal types + if (type == PhysicalType.DECIMAL4 + || type == PhysicalType.DECIMAL8 + || type == PhysicalType.DECIMAL16) { + VariantPrimitive primitive = value.asPrimitive(); + Object decimalValue = primitive.get(); + if (decimalValue instanceof BigDecimal bd) { + maxDecimalIntegerDigits = Math.max(maxDecimalIntegerDigits, bd.precision() - bd.scale()); + maxDecimalScale = Math.max(maxDecimalScale, bd.scale()); + } + } + } + + PhysicalType getMostCommonType() { + Map combinedCounts = Maps.newHashMap(); + + int integerTotalCount = 0; + PhysicalType mostCapableInteger = null; + + int decimalTotalCount = 0; + PhysicalType mostCapableDecimal = null; + + for (Map.Entry entry : typeCounts.entrySet()) { + PhysicalType type = entry.getKey(); + int count = entry.getValue(); + + if (isIntegerType(type)) { + integerTotalCount += count; + if (mostCapableInteger == null + || INTEGER_PRIORITY.get(type) > INTEGER_PRIORITY.get(mostCapableInteger)) { + mostCapableInteger = type; + } + } else if (isDecimalType(type)) { + decimalTotalCount += count; + if (mostCapableDecimal == null + || DECIMAL_PRIORITY.get(type) > DECIMAL_PRIORITY.get(mostCapableDecimal)) { + mostCapableDecimal = type; + } + } else { + combinedCounts.put(type, count); + } + } + + if (mostCapableInteger != null) { + combinedCounts.put(mostCapableInteger, integerTotalCount); + } + + if (mostCapableDecimal != null) { + combinedCounts.put(mostCapableDecimal, decimalTotalCount); + } + + // Pick the most common type with tie-breaking + return combinedCounts.entrySet().stream() + .max( + Map.Entry.comparingByValue() + .thenComparingInt(entry -> TIE_BREAK_PRIORITY.getOrDefault(entry.getKey(), -1))) + .map(Map.Entry::getKey) + .orElse(null); + } + + private static boolean isIntegerType(PhysicalType type) { + return type == PhysicalType.INT8 + || type == PhysicalType.INT16 + || type == PhysicalType.INT32 + || type == PhysicalType.INT64; + } + + private static boolean isDecimalType(PhysicalType type) { + 
return type == PhysicalType.DECIMAL4 + || type == PhysicalType.DECIMAL8 + || type == PhysicalType.DECIMAL16; + } + } +} diff --git a/parquet/src/test/java/org/apache/iceberg/parquet/TestParquetDataWriter.java b/parquet/src/test/java/org/apache/iceberg/parquet/TestParquetDataWriter.java index 3918fdc63084..4893bece5a7d 100644 --- a/parquet/src/test/java/org/apache/iceberg/parquet/TestParquetDataWriter.java +++ b/parquet/src/test/java/org/apache/iceberg/parquet/TestParquetDataWriter.java @@ -42,8 +42,11 @@ import org.apache.iceberg.data.Record; import org.apache.iceberg.data.parquet.GenericParquetReaders; import org.apache.iceberg.data.parquet.GenericParquetWriter; +import org.apache.iceberg.encryption.EncryptedFiles; +import org.apache.iceberg.io.BufferedFileAppender; import org.apache.iceberg.io.CloseableIterable; import org.apache.iceberg.io.DataWriter; +import org.apache.iceberg.io.FileAppender; import org.apache.iceberg.io.OutputFile; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; @@ -52,6 +55,7 @@ import org.apache.iceberg.variants.Variant; import org.apache.iceberg.variants.VariantMetadata; import org.apache.iceberg.variants.VariantTestUtil; +import org.apache.iceberg.variants.VariantValue; import org.apache.iceberg.variants.Variants; import org.apache.parquet.hadoop.ParquetFileReader; import org.apache.parquet.schema.GroupType; @@ -331,4 +335,186 @@ public void testDataWriterWithVariantShredding() throws IOException { testDataWriter( variantSchema, (id, name) -> ParquetVariantUtil.toParquetSchema(variant.value())); } + + @Test + public void testShreddingWriteReturnsBufferedAppender() throws IOException { + Schema variantSchema = + new Schema( + Types.NestedField.required(1, "id", Types.LongType.get()), + Types.NestedField.optional(2, "v", Types.VariantType.get())); + + VariantShreddingAnalyzer testAnalyzer = + new VariantShreddingAnalyzer() { + @Override + protected List extractVariantValues(List rows, int idx) { + return java.util.Collections.emptyList(); + } + + @Override + protected int resolveColumnIndex(Void engineSchema, String columnName) { + return -1; + } + }; + + OutputFile outputFile = Files.localOutput(createTempFile(temp)); + + ParquetFormatModel> model = + ParquetFormatModel.create( + Record.class, + Void.class, + (icebergSchema, messageType, engineSchema) -> + GenericParquetWriter.create(icebergSchema, messageType), + (icebergSchema, fileSchema, engineSchema, idToConstant) -> + GenericParquetReaders.buildReader(icebergSchema, fileSchema), + testAnalyzer, + record -> record); + + try (FileAppender appender = + model + .writeBuilder(EncryptedFiles.plainAsEncryptedOutput(outputFile)) + .schema(variantSchema) + .setAll(ImmutableMap.of(ParquetFormatModel.SHRED_VARIANTS_KEY, "true")) + .content(FileContent.DATA) + .build()) { + assertThat(appender).isInstanceOf(BufferedFileAppender.class); + } + } + + @Test + public void testWriteBuilderReturnsDirectAppenderWithNullAnalyzer() throws IOException { + Schema variantSchema = + new Schema( + Types.NestedField.required(1, "id", Types.LongType.get()), + Types.NestedField.optional(2, "v", Types.VariantType.get())); + + OutputFile outputFile = Files.localOutput(createTempFile(temp)); + + ParquetFormatModel> model = + ParquetFormatModel.create( + Record.class, + Void.class, + (icebergSchema, messageType, engineSchema) -> + GenericParquetWriter.create(icebergSchema, messageType), + (icebergSchema, fileSchema, engineSchema, 
idToConstant) -> + GenericParquetReaders.buildReader(icebergSchema, fileSchema), + null, + null); + + try (FileAppender appender = + model + .writeBuilder(EncryptedFiles.plainAsEncryptedOutput(outputFile)) + .schema(variantSchema) + .setAll(ImmutableMap.of(ParquetFormatModel.SHRED_VARIANTS_KEY, "true")) + .content(FileContent.DATA) + .build()) { + // Even with shredding property set, null variantAnalyzer means no BufferedFileAppender + assertThat(appender).isNotInstanceOf(BufferedFileAppender.class); + } + } + + @Test + public void testFormatModelVariantShreddingRoundTrip() throws IOException { + Schema variantSchema = + new Schema( + Types.NestedField.required(1, "id", Types.LongType.get()), + Types.NestedField.optional(2, "v", Types.VariantType.get())); + + VariantShreddingAnalyzer analyzer = + new VariantShreddingAnalyzer() { + @Override + protected List extractVariantValues(List rows, int idx) { + List values = Lists.newArrayList(); + for (Record row : rows) { + Object obj = row.get(idx); + if (obj instanceof Variant) { + values.add(((Variant) obj).value()); + } + } + return values; + } + + @Override + protected int resolveColumnIndex(Void engineSchema, String columnName) { + // GenericRecord uses schema column order + return variantSchema.columns().indexOf(variantSchema.findField(columnName)); + } + }; + + ByteBuffer metadataBuffer = VariantTestUtil.createMetadata(ImmutableList.of("a", "b"), true); + VariantMetadata metadata = Variants.metadata(metadataBuffer); + ByteBuffer objectBuffer = + VariantTestUtil.createObject( + metadataBuffer, + ImmutableMap.of( + "a", Variants.of(42), + "b", Variants.of("hello"))); + Variant variant = Variant.of(metadata, Variants.value(metadata, objectBuffer)); + + GenericRecord record = GenericRecord.create(variantSchema); + List variantRecords = + ImmutableList.of( + record.copy(ImmutableMap.of("id", 1L, "v", variant)), + record.copy(ImmutableMap.of("id", 2L, "v", variant)), + record.copy(ImmutableMap.of("id", 3L, "v", variant))); + + OutputFile outputFile = Files.localOutput(createTempFile(temp)); + + ParquetFormatModel> model = + ParquetFormatModel.create( + Record.class, + Void.class, + (icebergSchema, messageType, engineSchema) -> + GenericParquetWriter.create(icebergSchema, messageType), + (icebergSchema, fileSchema, engineSchema, idToConstant) -> + GenericParquetReaders.buildReader(icebergSchema, fileSchema), + analyzer, + record1 -> record1); + + try (FileAppender appender = + model + .writeBuilder(EncryptedFiles.plainAsEncryptedOutput(outputFile)) + .schema(variantSchema) + .setAll( + ImmutableMap.of( + ParquetFormatModel.SHRED_VARIANTS_KEY, "true", + ParquetFormatModel.VARIANT_BUFFER_SIZE_KEY, "2")) + .content(FileContent.DATA) + .build()) { + assertThat(appender).isInstanceOf(BufferedFileAppender.class); + for (Record rec : variantRecords) { + appender.add(rec); + } + } + + // Verify shredded Parquet schema + try (ParquetFileReader reader = + ParquetFileReader.open(ParquetIO.file(outputFile.toInputFile()))) { + MessageType parquetSchema = reader.getFooter().getFileMetaData().getSchema(); + GroupType variantGroup = parquetSchema.getType("v").asGroupType(); + assertThat(variantGroup.containsField("metadata")).isTrue(); + assertThat(variantGroup.containsField("value")).isTrue(); + assertThat(variantGroup.containsField("typed_value")).isTrue(); + + GroupType typedValue = variantGroup.getType("typed_value").asGroupType(); + assertThat(typedValue.containsField("a")).isTrue(); + assertThat(typedValue.containsField("b")).isTrue(); + } + + // Verify 
data round-trips + List writtenRecords; + try (CloseableIterable reader = + Parquet.read(outputFile.toInputFile()) + .project(variantSchema) + .createReaderFunc( + fileSchema -> GenericParquetReaders.buildReader(variantSchema, fileSchema)) + .build()) { + writtenRecords = Lists.newArrayList(reader); + } + + assertThat(writtenRecords).hasSameSizeAs(variantRecords); + for (int i = 0; i < variantRecords.size(); i++) { + InternalTestHelpers.assertEquals( + variantSchema.asStruct(), variantRecords.get(i), writtenRecords.get(i)); + } + } } diff --git a/parquet/src/test/java/org/apache/iceberg/parquet/TestVariantShreddingAnalyzer.java b/parquet/src/test/java/org/apache/iceberg/parquet/TestVariantShreddingAnalyzer.java new file mode 100644 index 000000000000..797b011c7e52 --- /dev/null +++ b/parquet/src/test/java/org/apache/iceberg/parquet/TestVariantShreddingAnalyzer.java @@ -0,0 +1,441 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.parquet; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.List; +import java.util.Locale; +import java.util.function.Function; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.variants.ShreddedObject; +import org.apache.iceberg.variants.ValueArray; +import org.apache.iceberg.variants.VariantMetadata; +import org.apache.iceberg.variants.VariantValue; +import org.apache.iceberg.variants.Variants; +import org.apache.parquet.schema.GroupType; +import org.apache.parquet.schema.LogicalTypeAnnotation; +import org.apache.parquet.schema.PrimitiveType; +import org.apache.parquet.schema.Type; +import org.junit.jupiter.api.Test; + +public class TestVariantShreddingAnalyzer { + + private static class DirectAnalyzer extends VariantShreddingAnalyzer { + @Override + protected List extractVariantValues(List rows, int idx) { + return rows; + } + + @Override + protected int resolveColumnIndex(Void engineSchema, String columnName) { + throw new UnsupportedOperationException("Not used in direct tests"); + } + } + + @Test + public void testDepthLimitStopsObjectRecursion() { + DirectAnalyzer analyzer = new DirectAnalyzer(); + + // Each level has {"a": , "x": 1} so objects always have a shreddable primitive + VariantMetadata meta = Variants.metadata("a", "x"); + ShreddedObject innermost = Variants.object(meta); + innermost.put("a", Variants.of(42)); + innermost.put("x", Variants.of(1)); + + for (int i = 0; i < 54; i++) { + ShreddedObject wrapper = Variants.object(meta); + wrapper.put("a", innermost); + wrapper.put("x", Variants.of(1)); + innermost = wrapper; + } + + Type schema = analyzer.analyzeAndCreateSchema(List.of(innermost), 0); + assertThat(schema).isNotNull(); + assertThat(schema.getName()).isEqualTo("typed_value"); 
+ + int shreddedDepth = countObjectDepth(schema); + assertThat(shreddedDepth).isLessThanOrEqualTo(50).isGreaterThan(0); + } + + @Test + public void testDepthLimitStopsArrayRecursion() { + DirectAnalyzer analyzer = new DirectAnalyzer(); + + // 55-level nested arrays with a primitive only at the very bottom. + // Depth limit (50) prevents reaching the leaf, so schema is null (graceful degradation). + VariantValue innermost = Variants.of(42); + for (int i = 0; i < 55; i++) { + ValueArray wrapper = Variants.array(); + wrapper.add(innermost); + innermost = wrapper; + } + + Type schema = analyzer.analyzeAndCreateSchema(List.of(innermost), 0); + assertThat(schema).isNull(); + } + + @Test + public void testArrayWithinDepthLimit() { + DirectAnalyzer analyzer = new DirectAnalyzer(); + + // 5-level nested arrays + VariantValue innermost = Variants.of(42); + for (int i = 0; i < 5; i++) { + ValueArray wrapper = Variants.array(); + wrapper.add(innermost); + innermost = wrapper; + } + + Type schema = analyzer.analyzeAndCreateSchema(List.of(innermost), 0); + assertThat(schema).isNotNull(); + assertThat(schema.getName()).isEqualTo("typed_value"); + + int arrayDepth = countArrayDepth(schema); + assertThat(arrayDepth).isEqualTo(5); + } + + @Test + public void testIntermediateFieldCapLimitsTrackedFields() { + int numFields = 1500; + String[] fieldNames = new String[numFields]; + for (int i = 0; i < numFields; i++) { + fieldNames[i] = String.format(Locale.ROOT, "field_%04d", i); + } + + VariantMetadata meta = Variants.metadata(fieldNames); + ShreddedObject obj = Variants.object(meta); + for (String name : fieldNames) { + obj.put(name, Variants.of(42)); + } + + DirectAnalyzer analyzer = new DirectAnalyzer(); + Type schema = analyzer.analyzeAndCreateSchema(List.of(obj), 0); + + assertThat(schema).isNotNull(); + assertThat(schema).isInstanceOf(GroupType.class); + GroupType typedValue = (GroupType) schema; + assertThat(typedValue.getFieldCount()).isLessThanOrEqualTo(300).isGreaterThan(0); + } + + @Test + public void testFieldCapAllowsExistingFieldUpdates() { + int numFields = 1500; + String[] fieldNames = new String[numFields]; + for (int i = 0; i < numFields; i++) { + fieldNames[i] = String.format(Locale.ROOT, "field_%04d", i); + } + + VariantMetadata meta = Variants.metadata(fieldNames); + + ShreddedObject row1 = Variants.object(meta); + for (String name : fieldNames) { + row1.put(name, Variants.of(42)); + } + + ShreddedObject row2 = Variants.object(meta); + for (int i = 0; i < 10; i++) { + row2.put(fieldNames[i], Variants.of("text")); + } + + ShreddedObject row3 = Variants.object(meta); + for (int i = 0; i < 10; i++) { + row3.put(fieldNames[i], Variants.of(99)); + } + + DirectAnalyzer analyzer = new DirectAnalyzer(); + Type schema = analyzer.analyzeAndCreateSchema(List.of(row1, row2, row3), 0); + + assertThat(schema).isNotNull(); + assertThat(schema).isInstanceOf(GroupType.class); + GroupType typedValue = (GroupType) schema; + assertThat(typedValue.getFieldCount()).isGreaterThan(0).isLessThanOrEqualTo(300); + } + + @Test + public void testNestedObjectsWithinDepthLimit() { + VariantMetadata cityMeta = Variants.metadata("city"); + ShreddedObject city = Variants.object(cityMeta); + city.put("city", Variants.of("NYC")); + + VariantMetadata addrMeta = Variants.metadata("address"); + ShreddedObject addr = Variants.object(addrMeta); + addr.put("address", city); + + VariantMetadata rootMeta = Variants.metadata("user"); + ShreddedObject root = Variants.object(rootMeta); + root.put("user", addr); + + DirectAnalyzer 
analyzer = new DirectAnalyzer(); + Type schema = analyzer.analyzeAndCreateSchema(List.of(root), 0); + + assertThat(schema).isNotNull(); + GroupType rootTv = schema.asGroupType(); + assertThat(rootTv.getName()).isEqualTo("typed_value"); + + // user -> typed_value -> address -> typed_value -> city -> typed_value (STRING) + GroupType userGroup = rootTv.getType("user").asGroupType(); + assertThat(userGroup.containsField("value")).isTrue(); + assertThat(userGroup.containsField("typed_value")).isTrue(); + + GroupType addrTv = userGroup.getType("typed_value").asGroupType(); + GroupType addrGroup = addrTv.getType("address").asGroupType(); + assertThat(addrGroup.containsField("typed_value")).isTrue(); + + GroupType cityTv = addrGroup.getType("typed_value").asGroupType(); + GroupType cityGroup = cityTv.getType("city").asGroupType(); + assertThat(cityGroup.containsField("typed_value")).isTrue(); + + PrimitiveType cityPrimitive = cityGroup.getType("typed_value").asPrimitiveType(); + assertThat(cityPrimitive.getPrimitiveTypeName()) + .isEqualTo(PrimitiveType.PrimitiveTypeName.BINARY); + assertThat(cityPrimitive.getLogicalTypeAnnotation()) + .isEqualTo(LogicalTypeAnnotation.stringType()); + } + + @Test + public void testDecimalForExceedingPrecision() { + DirectAnalyzer analyzer = new DirectAnalyzer(); + // Value 1: 30 integer digits, 0 fractional -> precision=30, scale=0, intDigits=30 + // Value 2: 1 integer digit, 20 fractional -> precision=21, scale=20, intDigits=1 + // Combined: maxIntDigits=30, maxScale=20, raw sum=50 -> capped to precision=38, + // scale=min(20, 38-30)=8 (integer digits get priority) + VariantMetadata meta = Variants.metadata("val"); + ShreddedObject row1 = Variants.object(meta); + row1.put("val", Variants.of(new java.math.BigDecimal("123456789012345678901234567890"))); + + ShreddedObject row2 = Variants.object(meta); + row2.put("val", Variants.of(new java.math.BigDecimal("1.23456789012345678901"))); + + Type schema = analyzer.analyzeAndCreateSchema(List.of(row1, row2), 0); + assertThat(schema).isNotNull(); + + GroupType typedValue = schema.asGroupType(); + GroupType valGroup = typedValue.getType("val").asGroupType(); + PrimitiveType valPrimitive = valGroup.getType("typed_value").asPrimitiveType(); + + LogicalTypeAnnotation.DecimalLogicalTypeAnnotation decimal = + (LogicalTypeAnnotation.DecimalLogicalTypeAnnotation) + valPrimitive.getLogicalTypeAnnotation(); + assertThat(decimal).isNotNull(); + assertThat(decimal.getPrecision()).isEqualTo(38); + // With 30 integer digits, scale is capped to 38 - 30 = 8 (integer digits get priority) + assertThat(decimal.getScale()).isEqualTo(8); + assertThat(decimal.getScale()).isLessThanOrEqualTo(decimal.getPrecision()); + + // Physical type should be FIXED_LEN_BYTE_ARRAY since precision > 18 + assertThat(valPrimitive.getPrimitiveTypeName()) + .isEqualTo(PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY); + } + + @Test + public void testDecimalForExactPrecision() { + DirectAnalyzer analyzer = new DirectAnalyzer(); + + // Value with exactly precision=38: 20 integer digits + 18 scale = 38 + VariantMetadata meta = Variants.metadata("val"); + ShreddedObject row = Variants.object(meta); + row.put( + "val", Variants.of(new java.math.BigDecimal("12345678901234567890.123456789012345678"))); + + Type schema = analyzer.analyzeAndCreateSchema(List.of(row), 0); + assertThat(schema).isNotNull(); + + GroupType typedValue = schema.asGroupType(); + GroupType valGroup = typedValue.getType("val").asGroupType(); + PrimitiveType valPrimitive = 
valGroup.getType("typed_value").asPrimitiveType(); + + LogicalTypeAnnotation.DecimalLogicalTypeAnnotation decimal = + (LogicalTypeAnnotation.DecimalLogicalTypeAnnotation) + valPrimitive.getLogicalTypeAnnotation(); + assertThat(decimal.getPrecision()).isEqualTo(38); + assertThat(decimal.getScale()).isEqualTo(18); + } + + @Test + public void testInfrequentFieldsArePruned() { + DirectAnalyzer analyzer = new DirectAnalyzer(); + + // 100 rows: "common" in all, "rare" in only 5 (below MIN_FIELD_FREQUENCY = 0.10) + List rows = buildPruningTestRows(5, obj -> obj); + + Type schema = analyzer.analyzeAndCreateSchema(rows, 0); + assertThat(schema).isNotNull(); + + GroupType group = schema.asGroupType(); + assertThat(group.containsField("common")).isTrue(); + assertThat(group.containsField("rare")).isFalse(); + } + + @Test + public void testEmptyArrayReturnsNull() { + DirectAnalyzer analyzer = new DirectAnalyzer(); + + // All rows are empty arrays, no element type to infer + List rows = List.of(Variants.array(), Variants.array(), Variants.array()); + + Type schema = analyzer.analyzeAndCreateSchema(rows, 0); + assertThat(schema).isNull(); + } + + @Test + public void testRootPrimitiveProducesTypedValue() { + DirectAnalyzer analyzer = new DirectAnalyzer(); + + // root type is primitive + List rows = List.of(Variants.of("hello"), Variants.of("world"), Variants.of("x")); + + Type schema = analyzer.analyzeAndCreateSchema(rows, 0); + assertThat(schema).isNotNull(); + assertThat(schema.getName()).isEqualTo("typed_value"); + assertThat(schema.isPrimitive()).isTrue(); + assertThat(schema.asPrimitiveType().getLogicalTypeAnnotation()) + .isEqualTo(LogicalTypeAnnotation.stringType()); + } + + @Test + public void testRootArrayOfObjectsPrunesInfrequentFields() { + DirectAnalyzer analyzer = new DirectAnalyzer(); + + // 100 arrays: "common" in all, "rare" in only 3 (below MIN_FIELD_FREQUENCY = 0.10) + List rows = + buildPruningTestRows( + 3, + obj -> { + ValueArray arr = Variants.array(); + arr.add(obj); + return arr; + }); + + Type schema = analyzer.analyzeAndCreateSchema(rows, 0); + assertThat(schema).isNotNull(); + + GroupType listType = schema.asGroupType(); + assertThat(listType.getLogicalTypeAnnotation()) + .isInstanceOf(LogicalTypeAnnotation.ListLogicalTypeAnnotation.class); + GroupType repeatedGroup = listType.getType(0).asGroupType(); + GroupType elementGroup = repeatedGroup.getType(0).asGroupType(); + assertThat(elementGroup.containsField("typed_value")).isTrue(); + GroupType objectFields = elementGroup.getType("typed_value").asGroupType(); + assertThat(objectFields.containsField("common")).isTrue(); + assertThat(objectFields.containsField("rare")).isFalse(); + } + + @Test + public void testObjectWithArrayChildPrunesNestedFields() { + DirectAnalyzer analyzer = new DirectAnalyzer(); + + VariantMetadata itemMeta = Variants.metadata("name", "rare"); + VariantMetadata rootMeta = Variants.metadata("items"); + + // 100 rows, "rare" appears in only 3 rows (below MIN_FIELD_FREQUENCY = 0.10) + List rows = Lists.newArrayList(); + for (int i = 0; i < 100; i++) { + ShreddedObject item = Variants.object(itemMeta); + item.put("name", Variants.of("item_" + i)); + if (i < 3) { + item.put("rare", Variants.of(1)); + } + ValueArray arr = Variants.array(); + arr.add(item); + ShreddedObject root = Variants.object(rootMeta); + root.put("items", arr); + rows.add(root); + } + + Type schema = analyzer.analyzeAndCreateSchema(rows, 0); + assertThat(schema).isNotNull(); + + GroupType rootTv = schema.asGroupType(); + GroupType itemsGroup 
= rootTv.getType("items").asGroupType(); + assertThat(itemsGroup.containsField("typed_value")).isTrue(); + GroupType listType = itemsGroup.getType("typed_value").asGroupType(); + GroupType repeatedGroup = listType.getType(0).asGroupType(); + GroupType elementGroup = repeatedGroup.getType(0).asGroupType(); + assertThat(elementGroup.containsField("typed_value")).isTrue(); + GroupType elementFields = elementGroup.getType("typed_value").asGroupType(); + assertThat(elementFields.containsField("name")).isTrue(); + assertThat(elementFields.containsField("rare")).isFalse(); + } + + /** + * Builds 100 variant rows where "common" appears in every row and "rare" appears in only {@code + * rareCount} rows (below MIN_FIELD_FREQUENCY = 0.10 when rareCount < 10). + */ + private static List buildPruningTestRows( + int rareCount, Function wrap) { + VariantMetadata meta = Variants.metadata("common", "rare"); + List rows = Lists.newArrayList(); + for (int i = 0; i < 100; i++) { + ShreddedObject obj = Variants.object(meta); + obj.put("common", Variants.of(i)); + if (i < rareCount) { + obj.put("rare", Variants.of("text")); + } + rows.add(wrap.apply(obj)); + } + return rows; + } + + /** Count typed_value group nesting depth along field "a". */ + private static int countObjectDepth(Type type) { + int depth = 0; + Type current = type; + while (current != null && "typed_value".equals(current.getName()) && !current.isPrimitive()) { + depth++; + GroupType group = current.asGroupType(); + if (group.containsField("a")) { + GroupType fieldGroup = group.getType("a").asGroupType(); + if (fieldGroup.containsField("typed_value")) { + current = fieldGroup.getType("typed_value"); + } else { + break; + } + } else { + break; + } + } + return depth; + } + + /** Count nested array (LIST) levels in the schema. 
*/ + private static int countArrayDepth(Type type) { + int depth = 0; + Type current = type; + while (current != null && !current.isPrimitive()) { + if (!"typed_value".equals(current.getName())) { + break; + } + GroupType group = current.asGroupType(); + if (!(group.getLogicalTypeAnnotation() + instanceof LogicalTypeAnnotation.ListLogicalTypeAnnotation)) { + break; + } + depth++; + GroupType listGroup = group.getType(0).asGroupType(); + GroupType elementGroup = listGroup.getType(0).asGroupType(); + if (elementGroup.containsField("typed_value")) { + current = elementGroup.getType("typed_value"); + } else { + break; + } + } + return depth; + } +} diff --git a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java index 161f09d53e2c..d5e36c86edad 100644 --- a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java +++ b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java @@ -114,4 +114,12 @@ private SparkSQLProperties() {} public static final String ASYNC_MICRO_BATCH_PLANNING_ENABLED = "spark.sql.iceberg.async-micro-batch-planning-enabled"; public static final boolean ASYNC_MICRO_BATCH_PLANNING_ENABLED_DEFAULT = false; + + // Controls whether to shred variant columns during write operations + public static final String SHRED_VARIANTS = "spark.sql.iceberg.shred-variants"; + + // Controls the buffer size for variant schema inference during writes + // This determines how many rows are buffered before inferring shredded schema + public static final String VARIANT_INFERENCE_BUFFER_SIZE = + "spark.sql.iceberg.variant.inference.buffer-size"; } diff --git a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java index 2296c076f0c4..afc40bdcccc0 100644 --- a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java +++ b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java @@ -33,6 +33,8 @@ import static org.apache.iceberg.TableProperties.ORC_COMPRESSION_STRATEGY; import static org.apache.iceberg.TableProperties.PARQUET_COMPRESSION; import static org.apache.iceberg.TableProperties.PARQUET_COMPRESSION_LEVEL; +import static org.apache.iceberg.TableProperties.PARQUET_VARIANT_BUFFER_SIZE; +import static org.apache.iceberg.TableProperties.PARQUET_VARIANT_SHRED; import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.DELETE; import java.util.Locale; @@ -504,6 +506,14 @@ private Map dataWriteProperties() { if (parquetCompressionLevel != null) { writeProperties.put(PARQUET_COMPRESSION_LEVEL, parquetCompressionLevel); } + boolean shouldShredVariants = shredVariants(); + writeProperties.put(PARQUET_VARIANT_SHRED, String.valueOf(shouldShredVariants)); + + // Add variant shredding configuration properties + if (shouldShredVariants) { + writeProperties.put( + PARQUET_VARIANT_BUFFER_SIZE, String.valueOf(variantInferenceBufferSize())); + } break; case AVRO: @@ -724,4 +734,23 @@ public DeleteGranularity deleteGranularity() { .defaultValue(DeleteGranularity.FILE) .parse(); } + + public boolean shredVariants() { + return confParser + .booleanConf() + .option(SparkWriteOptions.SHRED_VARIANTS) + .sessionConf(SparkSQLProperties.SHRED_VARIANTS) + .tableProperty(TableProperties.PARQUET_VARIANT_SHRED) + .defaultValue(TableProperties.PARQUET_VARIANT_SHRED_DEFAULT) + .parse(); + } + + public int 
variantInferenceBufferSize() { + return confParser + .intConf() + .sessionConf(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE) + .tableProperty(TableProperties.PARQUET_VARIANT_BUFFER_SIZE) + .defaultValue(TableProperties.PARQUET_VARIANT_BUFFER_SIZE_DEFAULT) + .parse(); + } } diff --git a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java index 2b88d2bb1e44..c754bb2a6fc6 100644 --- a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java +++ b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java @@ -86,4 +86,7 @@ private SparkWriteOptions() {} // Overrides the delete granularity public static final String DELETE_GRANULARITY = "delete-granularity"; + + // Controls whether to shred variant columns during write operations + public static final String SHRED_VARIANTS = "shred-variants"; } diff --git a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkFormatModels.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkFormatModels.java index 23fbe54a4be3..5b7862116aea 100644 --- a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkFormatModels.java +++ b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkFormatModels.java @@ -51,7 +51,9 @@ public static void register() { StructType.class, SparkParquetWriters::buildWriter, (icebergSchema, fileSchema, engineSchema, idToConstant) -> - SparkParquetReaders.buildReader(icebergSchema, fileSchema, idToConstant))); + SparkParquetReaders.buildReader(icebergSchema, fileSchema, idToConstant), + new SparkVariantShreddingAnalyzer(), + InternalRow::copy)); FormatModelRegistry.register( ParquetFormatModel.create( diff --git a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkVariantShreddingAnalyzer.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkVariantShreddingAnalyzer.java new file mode 100644 index 000000000000..2c08c662c9da --- /dev/null +++ b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkVariantShreddingAnalyzer.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.List; +import org.apache.iceberg.parquet.VariantShreddingAnalyzer; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.variants.VariantMetadata; +import org.apache.iceberg.variants.VariantValue; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.unsafe.types.VariantVal; + +/** + * Spark-specific implementation that extracts variant values from {@link InternalRow} instances. + */ +class SparkVariantShreddingAnalyzer extends VariantShreddingAnalyzer { + + SparkVariantShreddingAnalyzer() {} + + @Override + protected int resolveColumnIndex(StructType sparkSchema, String columnName) { + try { + return sparkSchema.fieldIndex(columnName); + } catch (IllegalArgumentException e) { + return -1; + } + } + + @Override + protected List extractVariantValues( + List bufferedRows, int variantFieldIndex) { + List values = Lists.newArrayList(); + + for (InternalRow row : bufferedRows) { + if (!row.isNullAt(variantFieldIndex)) { + VariantVal variantVal = row.getVariant(variantFieldIndex); + if (variantVal != null) { + VariantValue variantValue = + VariantValue.from( + VariantMetadata.from( + ByteBuffer.wrap(variantVal.getMetadata()).order(ByteOrder.LITTLE_ENDIAN)), + ByteBuffer.wrap(variantVal.getValue()).order(ByteOrder.LITTLE_ENDIAN)); + values.add(variantValue); + } + } + } + + return values; + } +} diff --git a/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java b/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java index 383a21087d7f..3b73b7555d5b 100644 --- a/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java +++ b/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java @@ -34,6 +34,7 @@ import static org.apache.iceberg.TableProperties.ORC_COMPRESSION_STRATEGY; import static org.apache.iceberg.TableProperties.PARQUET_COMPRESSION; import static org.apache.iceberg.TableProperties.PARQUET_COMPRESSION_LEVEL; +import static org.apache.iceberg.TableProperties.PARQUET_VARIANT_SHRED; import static org.apache.iceberg.TableProperties.UPDATE_DISTRIBUTION_MODE; import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE; import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE_HASH; @@ -345,6 +346,8 @@ public void testSparkConfOverride() { TableProperties.DELETE_PARQUET_COMPRESSION, "snappy"), ImmutableMap.of( + PARQUET_VARIANT_SHRED, + "false", DELETE_PARQUET_COMPRESSION, "zstd", PARQUET_COMPRESSION, @@ -467,6 +470,8 @@ public void testDataPropsDefaultsAsDeleteProps() { PARQUET_COMPRESSION_LEVEL, "5"), ImmutableMap.of( + PARQUET_VARIANT_SHRED, + "false", DELETE_PARQUET_COMPRESSION, "zstd", PARQUET_COMPRESSION, @@ -538,6 +543,8 @@ public void testDeleteFileWriteConf() { DELETE_PARQUET_COMPRESSION_LEVEL, "6"), ImmutableMap.of( + PARQUET_VARIANT_SHRED, + "false", DELETE_PARQUET_COMPRESSION, "zstd", PARQUET_COMPRESSION, @@ -698,4 +705,81 @@ private void checkMode(DistributionMode expectedMode, SparkWriteConf writeConf) assertThat(writeConf.copyOnWriteDistributionMode(MERGE)).isEqualTo(expectedMode); assertThat(writeConf.positionDeltaDistributionMode(MERGE)).isEqualTo(expectedMode); } + + @TestTemplate + public void testShredVariantsDefault() { + Table table = validationCatalog.loadTable(tableIdent); + SparkWriteConf writeConf = new 
SparkWriteConf(spark, table); + assertThat(writeConf.shredVariants()).isFalse(); + } + + @TestTemplate + public void testVariantInferenceBufferSizeDefault() { + Table table = validationCatalog.loadTable(tableIdent); + SparkWriteConf writeConf = new SparkWriteConf(spark, table); + assertThat(writeConf.variantInferenceBufferSize()) + .isEqualTo(TableProperties.PARQUET_VARIANT_BUFFER_SIZE_DEFAULT); + } + + @TestTemplate + public void testVariantInferenceBufferSizeTableProperty() { + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(TableProperties.PARQUET_VARIANT_BUFFER_SIZE, "500").commit(); + + SparkWriteConf writeConf = new SparkWriteConf(spark, table); + assertThat(writeConf.variantInferenceBufferSize()).isEqualTo(500); + } + + @TestTemplate + public void testShredVariantsSessionOverridesTableProperty() { + Table table = validationCatalog.loadTable(tableIdent); + table.updateProperties().set(TableProperties.PARQUET_VARIANT_SHRED, "false").commit(); + + withSQLConf( + ImmutableMap.of(SparkSQLProperties.SHRED_VARIANTS, "true"), + () -> { + SparkWriteConf writeConf = new SparkWriteConf(spark, table); + assertThat(writeConf.shredVariants()).isTrue(); + }); + } + + @TestTemplate + public void testShredVariantsWriteOptionOverridesSessionConf() { + withSQLConf( + ImmutableMap.of(SparkSQLProperties.SHRED_VARIANTS, "false"), + () -> { + Table table = validationCatalog.loadTable(tableIdent); + SparkWriteConf writeConf = + new SparkWriteConf( + spark, + table, + new CaseInsensitiveStringMap( + ImmutableMap.of(SparkWriteOptions.SHRED_VARIANTS, "true"))); + assertThat(writeConf.shredVariants()).isTrue(); + }); + } + + @TestTemplate + public void testVariantInferenceBufferSizeSessionConf() { + withSQLConf( + ImmutableMap.of(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, "250"), + () -> { + Table table = validationCatalog.loadTable(tableIdent); + SparkWriteConf writeConf = new SparkWriteConf(spark, table); + assertThat(writeConf.variantInferenceBufferSize()).isEqualTo(250); + }); + } + + @TestTemplate + public void testWritePropertiesIncludeVariantShredding() { + Table table = validationCatalog.loadTable(tableIdent); + table.updateProperties().set(TableProperties.PARQUET_VARIANT_SHRED, "true").commit(); + table.updateProperties().set(TableProperties.PARQUET_VARIANT_BUFFER_SIZE, "200").commit(); + + SparkWriteConf writeConf = new SparkWriteConf(spark, table); + Map writeProperties = writeConf.writeProperties(); + assertThat(writeProperties).containsEntry(PARQUET_VARIANT_SHRED, "true"); + assertThat(writeProperties).containsEntry(TableProperties.PARQUET_VARIANT_BUFFER_SIZE, "200"); + } } diff --git a/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java b/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java new file mode 100644 index 000000000000..5b2b6103c683 --- /dev/null +++ b/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java @@ -0,0 +1,1069 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.variant; + +import static org.apache.hadoop.hive.conf.HiveConf.ConfVars.METASTOREURIS; +import static org.apache.iceberg.TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES; +import static org.apache.parquet.schema.Types.optional; +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.IOException; +import java.net.InetAddress; +import java.util.List; +import java.util.Map; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.spark.CatalogTestBase; +import org.apache.iceberg.spark.SparkCatalogConfig; +import org.apache.iceberg.spark.SparkSQLProperties; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.variants.Variant; +import org.apache.parquet.hadoop.ParquetFileReader; +import org.apache.parquet.hadoop.util.HadoopInputFile; +import org.apache.parquet.schema.GroupType; +import org.apache.parquet.schema.LogicalTypeAnnotation; +import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.PrimitiveType; +import org.apache.parquet.schema.Type; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.internal.SQLConf; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestTemplate; + +public class TestVariantShredding extends CatalogTestBase { + + private static final Schema SCHEMA = + new Schema( + Types.NestedField.required(1, "id", Types.IntegerType.get()), + Types.NestedField.optional(2, "address", Types.VariantType.get())); + + private static final Schema SCHEMA2 = + new Schema( + Types.NestedField.required(1, "id", Types.IntegerType.get()), + Types.NestedField.optional(2, "address", Types.VariantType.get()), + Types.NestedField.optional(3, "metadata", Types.VariantType.get())); + + @Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}") + protected static Object[][] parameters() { + return new Object[][] { + { + SparkCatalogConfig.HADOOP.catalogName(), + SparkCatalogConfig.HADOOP.implementation(), + SparkCatalogConfig.HADOOP.properties() + }, + }; + } + + @BeforeAll + public static void startMetastoreAndSpark() { + // First call parent to initialize metastore and spark with local[2] + CatalogTestBase.startMetastoreAndSpark(); + + // Now stop and recreate spark with local[1] to write all rows to a single file + if (spark != null) { + spark.stop(); + } + + spark = + SparkSession.builder() + .master("local[1]") // Use one thread to write the rows to a single parquet file + .config("spark.driver.host", InetAddress.getLoopbackAddress().getHostAddress()) + .config(SQLConf.PARTITION_OVERWRITE_MODE().key(), 
"dynamic") + .config("spark.hadoop." + METASTOREURIS.varname, hiveConf.get(METASTOREURIS.varname)) + .config("spark.sql.legacy.respectNullabilityInTextDatasetConversion", "true") + .config(DISABLE_UI) + .enableHiveSupport() + .getOrCreate(); + + sparkContext = JavaSparkContext.fromSparkContext(spark.sparkContext()); + } + + @BeforeEach + public void before() { + super.before(); + validationCatalog.createTable( + tableIdent, SCHEMA, null, Map.of(TableProperties.FORMAT_VERSION, "3")); + } + + @AfterEach + public void after() { + spark.conf().unset(SparkSQLProperties.SHRED_VARIANTS); + spark.conf().unset(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE); + validationCatalog.dropTable(tableIdent, true); + } + + @TestTemplate + public void testVariantShreddingDisabled() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "false"); + + String values = "(1, parse_json('{\"city\": \"NYC\", \"zip\": 10001}')), (2, null)"; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType address = variant("address", 2, Type.Repetition.OPTIONAL); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + @TestTemplate + public void testExcludingNullValue() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + String values = + """ + (1, parse_json('{"name": "Alice", "age": 30, "dummy": null}')),\ + (2, parse_json('{"name": "Bob", "age": 25}')),\ + (3, parse_json('{"name": "Charlie", "age": 35}'))\ + """; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType name = + field( + "name", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType age = + field( + "age", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true))); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(age, name)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + @TestTemplate + public void testInconsistentType() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + String values = + """ + (1, parse_json('{"age": "25"}')),\ + (2, parse_json('{"age": 30}')),\ + (3, parse_json('{"age": "35"}'))\ + """; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType age = + field( + "age", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(age)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + + List rows = + sql("SELECT variant_get(address, '$.age', 'int') FROM %s WHERE id = 2", tableName); + assertThat(rows).hasSize(1); + assertThat(rows.get(0)[0]).isEqualTo(30); + } + + @TestTemplate + public void testPrimitiveType() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + String values = "(1, parse_json('123')), (2, parse_json('456')), (3, parse_json('789'))"; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType address = + variant( + "address", + 2, + Type.Repetition.REQUIRED, + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(16, true))); 
+ MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + @TestTemplate + public void testPrimitiveDecimalType() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + String values = + "(1, parse_json('123.56')), (2, parse_json('\"abc\"')), (3, parse_json('12.56'))"; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType address = + variant( + "address", + 2, + Type.Repetition.REQUIRED, + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.decimalType(2, 5))); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + @TestTemplate + public void testBooleanType() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + String values = + """ + (1, parse_json('{"active": true}')),\ + (2, parse_json('{"active": false}')),\ + (3, parse_json('{"active": true}'))\ + """; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType active = field("active", shreddedPrimitive(PrimitiveType.PrimitiveTypeName.BOOLEAN)); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(active)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + @TestTemplate + public void testDecimalTypeWithInconsistentScales() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + String values = + """ + (1, parse_json('{"price": 123.456789}')),\ + (2, parse_json('{"price": 678.90}')),\ + (3, parse_json('{"price": 999.99}'))\ + """; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType price = + field( + "price", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.decimalType(6, 9))); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(price)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + @TestTemplate + public void testDecimalTypeWithConsistentScales() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + String values = + """ + (1, parse_json('{"price": 123.45}')),\ + (2, parse_json('{"price": 678.90}')),\ + (3, parse_json('{"price": 999.99}'))\ + """; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType price = + field( + "price", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.decimalType(2, 5))); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(price)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + @TestTemplate + public void testArrayType() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + String values = + """ + (1, parse_json('["java", "scala", "python"]')),\ + (2, parse_json('["rust", "go"]')),\ + (3, parse_json('["javascript"]'))\ + """; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType arr = + list( + element( + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType()))); + GroupType 
address = variant("address", 2, Type.Repetition.REQUIRED, arr); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + @TestTemplate + public void testNestedArrayType() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + String values = + """ + (1, parse_json('{"tags": ["java", "scala", "python"]}')),\ + (2, parse_json('{"tags": ["rust", "go"]}')),\ + (3, parse_json('{"tags": ["javascript"]}'))\ + """; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType tags = + field( + "tags", + list( + element( + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, + LogicalTypeAnnotation.stringType())))); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(tags)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + @TestTemplate + public void testNestedObjectType() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + String values = + """ + (1, parse_json('{"location": {"city": "Seattle", "zip": 98101}, "tags": ["java", "scala", "python"]}')),\ + (2, parse_json('{"location": {"city": "Portland", "zip": 97201}}')),\ + (3, parse_json('{"location": {"city": "NYC", "zip": 10001}}'))\ + """; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType city = + field( + "city", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType zip = + field( + "zip", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(32, true))); + GroupType location = field("location", objectFields(city, zip)); + GroupType tags = + field( + "tags", + list( + element( + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, + LogicalTypeAnnotation.stringType())))); + + GroupType address = + variant("address", 2, Type.Repetition.REQUIRED, objectFields(location, tags)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + @TestTemplate + public void testLazyInitializationWithBufferedRows() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + spark.conf().set(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, "5"); + + String values = + """ + (1, parse_json('{"name": "Alice", "age": 30}')),\ + (2, parse_json('{"name": "Bob", "age": 25}')),\ + (3, parse_json('{"name": "Charlie", "age": 35}')),\ + (4, parse_json('{"name": "David", "age": 28}')),\ + (5, parse_json('{"name": "Eve", "age": 32}')),\ + (6, parse_json('{"name": "Frank", "age": 40}')),\ + (7, parse_json('{"name": "Grace", "age": 27}'))\ + """; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType name = + field( + "name", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType age = + field( + "age", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true))); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(age, name)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + + long rowCount = 
spark.read().format("iceberg").load(tableName).count(); + assertThat(rowCount).isEqualTo(7); + } + + @TestTemplate + public void testMultipleRowGroups() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + spark.conf().set(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, "3"); + + int numRows = 1000; + StringBuilder valuesBuilder = new StringBuilder(); + for (int i = 1; i <= numRows; i++) { + if (i > 1) { + valuesBuilder.append(", "); + } + valuesBuilder.append( + String.format("(%d, parse_json('{\"name\": \"User%d\", \"age\": %d}'))", i, i, 20 + i)); + } + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%d')", + tableName, PARQUET_ROW_GROUP_SIZE_BYTES, 1024); + sql("INSERT INTO %s VALUES %s", tableName, valuesBuilder.toString()); + + GroupType name = + field( + "name", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType age = + field( + "age", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true))); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(age, name)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + + long rowCount = spark.read().format("iceberg").load(tableName).count(); + assertThat(rowCount).isEqualTo(numRows); + } + + @TestTemplate + public void testColumnIndexTruncateLength() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + spark.conf().set(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, "3"); + + int customTruncateLength = 10; + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%d')", + tableName, "parquet.columnindex.truncate.length", customTruncateLength); + + StringBuilder valuesBuilder = new StringBuilder(); + for (int i = 1; i <= 10; i++) { + if (i > 1) { + valuesBuilder.append(", "); + } + String longValue = "A".repeat(20); + valuesBuilder.append( + String.format( + "(%d, parse_json('{\"description\": \"%s\", \"id\": %d}'))", i, longValue, i)); + } + sql("INSERT INTO %s VALUES %s", tableName, valuesBuilder.toString()); + + GroupType description = + field( + "description", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType id = + field( + "id", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true))); + GroupType address = + variant("address", 2, Type.Repetition.REQUIRED, objectFields(description, id)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + + long rowCount = spark.read().format("iceberg").load(tableName).count(); + assertThat(rowCount).isEqualTo(10); + } + + @TestTemplate + public void testIntegerFamilyPromotion() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + // Mix of INT8, INT16, INT32, INT64 - should promote to INT64 + String values = + """ + (1, parse_json('{"value": 10}')),\ + (2, parse_json('{"value": 1000}')),\ + (3, parse_json('{"value": 100000}')),\ + (4, parse_json('{"value": 10000000000}'))\ + """; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType value = + field( + "value", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT64, LogicalTypeAnnotation.intType(64, true))); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, 
objectFields(value)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + @TestTemplate + public void testDecimalFamilyPromotion() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + // Test that they get promoted to the most capable decimal type observed + String values = + """ + (1, parse_json('{"value": 1.5}')),\ + (2, parse_json('{"value": 123.456789}')),\ + (3, parse_json('{"value": 123456789123456.789}'))\ + """; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType value = + field( + "value", + optional(PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY) + .length(16) + .as(LogicalTypeAnnotation.decimalType(6, 21)) + .named("typed_value")); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(value)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + @TestTemplate + public void testDataRoundTripWithShredding() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + String values = + """ + (1, parse_json('{"name": "Alice", "age": 30}')),\ + (2, parse_json('{"name": "Bob", "age": 25}')),\ + (3, parse_json('{"name": "Charlie", "age": 35}'))\ + """; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType name = + field( + "name", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType age = + field( + "age", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true))); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(age, name)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + + // Verify that we can read the data back correctly + List rows = + sql( + "SELECT id, variant_get(address, '$.name', 'string')," + + " variant_get(address, '$.age', 'int')" + + " FROM %s ORDER BY id", + tableName); + assertThat(rows).hasSize(3); + assertThat(rows.get(0)[0]).isEqualTo(1); + assertThat(rows.get(0)[1]).isEqualTo("Alice"); + assertThat(rows.get(0)[2]).isEqualTo(30); + assertThat(rows.get(1)[0]).isEqualTo(2); + assertThat(rows.get(1)[1]).isEqualTo("Bob"); + assertThat(rows.get(1)[2]).isEqualTo(25); + assertThat(rows.get(2)[0]).isEqualTo(3); + assertThat(rows.get(2)[1]).isEqualTo("Charlie"); + assertThat(rows.get(2)[2]).isEqualTo(35); + } + + @TestTemplate + public void testMultipleVariantsWithShredding() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + // Recreate table with SCHEMA2 (address + metadata variant columns) + validationCatalog.dropTable(tableIdent, true); + validationCatalog.createTable( + tableIdent, SCHEMA2, null, Map.of(TableProperties.FORMAT_VERSION, "3")); + + String values = + """ + (1, parse_json('{"city": "NYC"}'), parse_json('{"source": "web"}')),\ + (2, parse_json('{"city": "LA"}'), parse_json('{"source": "app"}')),\ + (3, parse_json('{"city": "SF"}'), parse_json('{"source": "api"}'))\ + """; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType city = + field( + "city", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType address = variant("address", 2, 
Type.Repetition.REQUIRED, objectFields(city)); + + GroupType source = + field( + "source", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType metadata = variant("metadata", 3, Type.Repetition.REQUIRED, objectFields(source)); + MessageType expectedSchema = parquetSchema(address, metadata); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + @TestTemplate + public void testVariantWithNullValues() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + String values = + """ + (1, parse_json('null')),\ + (2, parse_json('null')),\ + (3, parse_json('null'))\ + """; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType address = variant("address", 2, Type.Repetition.REQUIRED); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + @TestTemplate + public void testArrayOfNullElementsWithShredding() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + sql( + "INSERT INTO %s VALUES (1, parse_json('[null, null, null]')), " + + "(2, parse_json('[null]'))", + tableName); + + // Array elements are all null, element type is null, falls back to unshredded + GroupType address = variant("address", 2, Type.Repetition.REQUIRED); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + @TestTemplate + public void testMixedNullAndNonNullVariantValues() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + String values = + """ + (1, parse_json('{"name": "Alice", "age": 30}')),\ + (2, null),\ + (3, parse_json('{"name": "Charlie", "age": 35}'))\ + """; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType name = + field( + "name", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType age = + field( + "age", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true))); + GroupType address = variant("address", 2, Type.Repetition.OPTIONAL, objectFields(age, name)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + + long rowCount = spark.read().format("iceberg").load(tableName).count(); + assertThat(rowCount).isEqualTo(3); + } + + @TestTemplate + public void testWriteOptionOverridesSessionConfig() throws IOException, NoSuchTableException { + // Disable shredding at session level + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "false"); + + // Enable shredding via per-write option + String query = + "SELECT 1 as id, parse_json('{\"name\": \"Alice\", \"age\": 30}') as address" + + " UNION ALL SELECT 2, parse_json('{\"name\": \"Bob\", \"age\": 25}')" + + " UNION ALL SELECT 3, parse_json('{\"name\": \"Charlie\", \"age\": 35}')"; + spark.sql(query).writeTo(tableName).option("shred-variants", "true").append(); + + GroupType name = + field( + "name", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType age = + field( + "age", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true))); + GroupType address = variant("address", 2, 
Type.Repetition.REQUIRED, objectFields(age, name)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + @TestTemplate + public void testInfrequentFieldPruning() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + spark.conf().set(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, "11"); + + StringBuilder valuesBuilder = new StringBuilder(); + for (int i = 1; i <= 11; i++) { + if (i > 1) { + valuesBuilder.append(", "); + } + if (i == 1) { + // Only the first row has rare_field + valuesBuilder.append( + String.format( + "(%d, parse_json('{\"name\": \"User%d\", \"rare_field\": \"rare\"}'))", i, i)); + } else { + valuesBuilder.append(String.format("(%d, parse_json('{\"name\": \"User%d\"}'))", i, i)); + } + } + sql("INSERT INTO %s VALUES %s", tableName, valuesBuilder.toString()); + + // rare_field appears in 1/11 rows, should be pruned + // name appears in 11/11 rows and should be kept + GroupType name = + field( + "name", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(name)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + @TestTemplate + public void testMixedTypeTieBreaking() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + spark.conf().set(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, "10"); + + StringBuilder valuesBuilder = new StringBuilder(); + for (int i = 1; i <= 10; i++) { + if (i > 1) { + valuesBuilder.append(", "); + } + if (i <= 5) { + valuesBuilder.append(String.format("(%d, parse_json('{\"val\": %d}'))", i, i)); + } else { + valuesBuilder.append(String.format("(%d, parse_json('{\"val\": \"text%d\"}'))", i, i)); + } + } + sql("INSERT INTO %s VALUES %s", tableName, valuesBuilder.toString()); + + // 5 ints + 5 strings is a tie so STRING wins (higher TIE_BREAK_PRIORITY) + GroupType val = + field( + "val", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(val)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + + // Verify data round-trips correctly + List rows = + sql("SELECT id, variant_get(address, '$.val', 'string') FROM %s ORDER BY id", tableName); + assertThat(rows).hasSize(10); + assertThat(rows.get(0)[1]).isEqualTo("1"); + assertThat(rows.get(5)[1]).isEqualTo("text6"); + } + + @TestTemplate + public void testFieldOnlyAfterBuffer() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + spark.conf().set(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, "3"); + + String values = + """ + (1, parse_json('{"name": "Alice"}')),\ + (2, parse_json('{"name": "Bob"}')),\ + (3, parse_json('{"name": "Charlie"}')),\ + (4, parse_json('{"name": "David", "score": 95}')),\ + (5, parse_json('{"name": "Eve", "score": 88}')),\ + (6, parse_json('{"name": "Frank", "score": 72}')),\ + (7, parse_json('{"name": "Grace", "score": 91}'))\ + """; + sql("INSERT INTO %s VALUES %s", tableName, values); + + // Schema is determined from buffer (rows 1-3) which only has "name". 
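+ // Rows 4-7 arrive after the underlying writer has already been created, so they cannot change the inferred schema.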
+ // "score" is not shredded + GroupType name = + field( + "name", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(name)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + + // Verify all data round-trips despite "score" not being shredded + List rows = + sql( + "SELECT id, variant_get(address, '$.name', 'string')," + + " variant_get(address, '$.score', 'int')" + + " FROM %s ORDER BY id", + tableName); + assertThat(rows).hasSize(7); + assertThat(rows.get(0)[1]).isEqualTo("Alice"); + assertThat(rows.get(0)[2]).isNull(); + assertThat(rows.get(3)[1]).isEqualTo("David"); + assertThat(rows.get(3)[2]).isEqualTo(95); + assertThat(rows.get(6)[1]).isEqualTo("Grace"); + assertThat(rows.get(6)[2]).isEqualTo(91); + } + + @TestTemplate + public void testCrossFileDifferentShreddedType() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + spark.conf().set(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, "3"); + + // File 1: "score" is always integer → shredded as INT8 + String batch1 = + """ + (1, parse_json('{"score": 95}')),\ + (2, parse_json('{"score": 88}')),\ + (3, parse_json('{"score": 72}'))\ + """; + sql("INSERT INTO %s VALUES %s", tableName, batch1); + + // Verify file 1 schema: score shredded as INT8 + Table table = validationCatalog.loadTable(tableIdent); + GroupType scoreInt = + field( + "score", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true))); + MessageType expectedSchema1 = + parquetSchema(variant("address", 2, Type.Repetition.REQUIRED, objectFields(scoreInt))); + verifyParquetSchema(table, expectedSchema1); + + // File 2: "score" is always string → shredded as STRING + String batch2 = + """ + (4, parse_json('{"score": "high"}')),\ + (5, parse_json('{"score": "medium"}')),\ + (6, parse_json('{"score": "low"}'))\ + """; + sql("INSERT INTO %s VALUES %s", tableName, batch2); + + // Query across both files, reader must handle different shredded types + List rows = + sql("SELECT id, variant_get(address, '$.score', 'string') FROM %s ORDER BY id", tableName); + assertThat(rows).hasSize(6); + assertThat(rows.get(0)[1]).isEqualTo("95"); + assertThat(rows.get(1)[1]).isEqualTo("88"); + assertThat(rows.get(3)[1]).isEqualTo("high"); + assertThat(rows.get(5)[1]).isEqualTo("low"); + } + + @TestTemplate + public void testAllNullVariantColumn() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + sql("INSERT INTO %s VALUES (1, null), (2, null), (3, null)", tableName); + + // All variant values are SQL NULL, so no shredding should occur + Table table = validationCatalog.loadTable(tableIdent); + MessageType expectedSchema = parquetSchema(variant("address", 2, Type.Repetition.OPTIONAL)); + verifyParquetSchema(table, expectedSchema); + + List rows = sql("SELECT id, address FROM %s ORDER BY id", tableName); + assertThat(rows).hasSize(3); + assertThat(rows.get(0)[1]).isNull(); + assertThat(rows.get(1)[1]).isNull(); + assertThat(rows.get(2)[1]).isNull(); + } + + @TestTemplate + public void testBufferSizeOne() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + spark.conf().set(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, "1"); + + sql( + """ + INSERT INTO %s VALUES + (1, parse_json('{"name": "Alice", "age": 
30}')), + (2, parse_json('{"name": "Bob", "age": 25}')), + (3, parse_json('{"name": "Charlie", "age": 35}')) + """, + tableName); + + // Schema inferred from first row only, should still shred name and age + GroupType age = + field( + "age", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true))); + GroupType name = + field( + "name", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(age, name)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + + List rows = + sql("SELECT id, variant_get(address, '$.name', 'string') FROM %s ORDER BY id", tableName); + assertThat(rows).hasSize(3); + assertThat(rows.get(0)[1]).isEqualTo("Alice"); + assertThat(rows.get(2)[1]).isEqualTo("Charlie"); + } + + private void verifyParquetSchema(Table table, MessageType expectedSchema) throws IOException { + try (CloseableIterable tasks = table.newScan().planFiles()) { + assertThat(tasks).isNotEmpty(); + + for (FileScanTask task : tasks) { + String path = task.file().location(); + + HadoopInputFile inputFile = + HadoopInputFile.fromPath(new org.apache.hadoop.fs.Path(path), new Configuration()); + + try (ParquetFileReader reader = ParquetFileReader.open(inputFile)) { + MessageType actualSchema = reader.getFileMetaData().getSchema(); + assertThat(actualSchema).isEqualTo(expectedSchema); + } + } + } + } + + private static MessageType parquetSchema(Type... variantTypes) { + return org.apache.parquet.schema.Types.buildMessage() + .required(PrimitiveType.PrimitiveTypeName.INT32) + .id(1) + .named("id") + .addFields(variantTypes) + .named("table"); + } + + private static GroupType variant(String name, int fieldId, Type.Repetition repetition) { + return org.apache.parquet.schema.Types.buildGroup(repetition) + .id(fieldId) + .as(LogicalTypeAnnotation.variantType(Variant.VARIANT_SPEC_VERSION)) + .required(PrimitiveType.PrimitiveTypeName.BINARY) + .named("metadata") + .required(PrimitiveType.PrimitiveTypeName.BINARY) + .named("value") + .named(name); + } + + private static GroupType variant( + String name, int fieldId, Type.Repetition repetition, Type shreddedType) { + checkShreddedType(shreddedType); + return org.apache.parquet.schema.Types.buildGroup(repetition) + .id(fieldId) + .as(LogicalTypeAnnotation.variantType(Variant.VARIANT_SPEC_VERSION)) + .required(PrimitiveType.PrimitiveTypeName.BINARY) + .named("metadata") + .optional(PrimitiveType.PrimitiveTypeName.BINARY) + .named("value") + .addField(shreddedType) + .named(name); + } + + private static Type shreddedPrimitive(PrimitiveType.PrimitiveTypeName primitive) { + return optional(primitive).named("typed_value"); + } + + private static Type shreddedPrimitive( + PrimitiveType.PrimitiveTypeName primitive, LogicalTypeAnnotation annotation) { + return optional(primitive).as(annotation).named("typed_value"); + } + + private static GroupType objectFields(GroupType... 
fields) { + for (GroupType fieldType : fields) { + checkField(fieldType); + } + + return org.apache.parquet.schema.Types.buildGroup(Type.Repetition.OPTIONAL) + .addFields(fields) + .named("typed_value"); + } + + private static GroupType field(String name, Type shreddedType) { + checkShreddedType(shreddedType); + return org.apache.parquet.schema.Types.buildGroup(Type.Repetition.REQUIRED) + .optional(PrimitiveType.PrimitiveTypeName.BINARY) + .named("value") + .addField(shreddedType) + .named(name); + } + + private static GroupType element(Type shreddedType) { + return field("element", shreddedType); + } + + private static GroupType list(GroupType elementType) { + return org.apache.parquet.schema.Types.optionalList().element(elementType).named("typed_value"); + } + + private static void checkShreddedType(Type shreddedType) { + Preconditions.checkArgument( + shreddedType.getName().equals("typed_value"), + "Invalid shredded type name: %s should be typed_value", + shreddedType.getName()); + Preconditions.checkArgument( + shreddedType.isRepetition(Type.Repetition.OPTIONAL), + "Invalid shredded type repetition: %s should be OPTIONAL", + shreddedType.getRepetition()); + } + + private static void checkField(GroupType fieldType) { + Preconditions.checkArgument( + fieldType.isRepetition(Type.Repetition.REQUIRED), + "Invalid field type repetition: %s should be REQUIRED", + fieldType.getRepetition()); + } +}
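For reviewers, a minimal usage sketch of the configuration surfaces introduced by this change. The table name db.events, its schema, and the sample row are illustrative assumptions and not part of the patch; the property, option, and session-conf names are the ones added above, and the precedence shown (write option over session conf over table property) mirrors the SparkWriteConf tests.

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class VariantShreddingUsageSketch {
  public static void main(String[] args) throws Exception {
    SparkSession spark = SparkSession.builder().getOrCreate();

    // Session-level switches (defaults: shredding disabled, 100-row inference buffer).
    spark.conf().set("spark.sql.iceberg.shred-variants", "true");
    spark.conf().set("spark.sql.iceberg.variant.inference.buffer-size", "100");

    // Table-level equivalents; db.events is assumed to be a format-version 3 table
    // with a variant column named address.
    spark.sql(
        "ALTER TABLE db.events SET TBLPROPERTIES ("
            + "'write.parquet.variant.shred' = 'true', "
            + "'write.parquet.variant.inference.buffer-size' = '100')");

    // Per-write override: the write option takes precedence over the session conf
    // and the table property.
    Dataset<Row> df =
        spark.sql(
            "SELECT 1 AS id, parse_json('{\"city\": \"NYC\", \"zip\": 10001}') AS address");
    df.writeTo("db.events").option("shred-variants", "true").append();
  }
}

Only the rows buffered for inference (100 by default) determine the shredded schema; fields that first appear after the buffer is flushed are still written, but they remain in the binary value column, as testFieldOnlyAfterBuffer exercises.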