From d3728a1b78e4839779306bde610c76b80f3bb632 Mon Sep 17 00:00:00 2001 From: Aihua Xu Date: Sun, 24 Aug 2025 16:06:00 -0700 Subject: [PATCH 01/17] Spark shredded variant implementation --- .../iceberg/parquet/ParquetVariantUtil.java | 4 +- .../apache/iceberg/parquet/ParquetWriter.java | 41 +- .../parquet/WriterLazyInitializable.java | 87 +++++ .../iceberg/spark/SparkSQLProperties.java | 10 + .../apache/iceberg/spark/SparkWriteConf.java | 10 + .../iceberg/spark/SparkWriteOptions.java | 3 + .../spark/source/SchemaInferenceVisitor.java | 198 ++++++++++ ...parkParquetWriterWithVariantShredding.java | 181 +++++++++ .../iceberg/spark/TestSparkWriteConf.java | 7 + .../spark/variant/TestVariantShredding.java | 363 ++++++++++++++++++ 10 files changed, 901 insertions(+), 3 deletions(-) create mode 100644 parquet/src/main/java/org/apache/iceberg/parquet/WriterLazyInitializable.java create mode 100644 spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SchemaInferenceVisitor.java create mode 100644 spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkParquetWriterWithVariantShredding.java create mode 100644 spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetVariantUtil.java b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetVariantUtil.java index ac418a1127bd..d94760773e51 100644 --- a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetVariantUtil.java +++ b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetVariantUtil.java @@ -57,7 +57,7 @@ import org.apache.parquet.schema.Type; import org.apache.parquet.schema.Types; -class ParquetVariantUtil { +public class ParquetVariantUtil { private ParquetVariantUtil() {} /** @@ -212,7 +212,7 @@ static int scale(PrimitiveType primitive) { * @param value a variant value * @return a Parquet schema that can fully shred the value */ - static Type toParquetSchema(VariantValue value) { + public static Type toParquetSchema(VariantValue value) { return VariantVisitor.visit(value, new ParquetSchemaProducer()); } diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetWriter.java b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetWriter.java index 2334e75532be..88dbad6fb6e8 100644 --- a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetWriter.java +++ b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetWriter.java @@ -51,7 +51,7 @@ class ParquetWriter implements FileAppender, Closeable { private final Map metadata; private final ParquetProperties props; private final CompressionCodecFactory.BytesInputCompressor compressor; - private final MessageType parquetSchema; + private MessageType parquetSchema; private final ParquetValueWriter model; private final MetricsConfig metricsConfig; private final int columnIndexTruncateLength; @@ -134,6 +134,30 @@ private void ensureWriterInitialized() { @Override public void add(T value) { + if (model instanceof WriterLazyInitializable) { + WriterLazyInitializable lazy = (WriterLazyInitializable) model; + if (lazy.needsInitialization()) { + model.write(0, value); + recordCount += 1; + + if (!lazy.needsInitialization()) { + WriterLazyInitializable.InitializationResult result = + lazy.initialize(props, compressor, rowGroupOrdinal); + this.parquetSchema = result.getSchema(); + this.pageStore = result.getPageStore(); + this.writeStore = result.getWriteStore(); + + // Re-initialize the file writer with the new schema + ensureWriterInitialized(); + + // Buffered 
rows were already written with endRecord() calls + // in the lazy writer's initialization, so we don't call endRecord() here + checkSize(); + } + return; + } + } + recordCount += 1; model.write(0, value); writeStore.endRecord(); @@ -255,6 +279,21 @@ private void startRowGroup() { public void close() throws IOException { if (!closed) { this.closed = true; + + // Force initialization if lazy writer still has buffered data + if (model instanceof WriterLazyInitializable) { + WriterLazyInitializable lazy = (WriterLazyInitializable) model; + if (lazy.needsInitialization()) { + WriterLazyInitializable.InitializationResult result = + lazy.initialize(props, compressor, rowGroupOrdinal); + this.parquetSchema = result.getSchema(); + this.pageStore = result.getPageStore(); + this.writeStore = result.getWriteStore(); + + ensureWriterInitialized(); + } + } + flushRowGroup(true); writeStore.close(); if (writer != null) { diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/WriterLazyInitializable.java b/parquet/src/main/java/org/apache/iceberg/parquet/WriterLazyInitializable.java new file mode 100644 index 000000000000..9c5913d7bd9b --- /dev/null +++ b/parquet/src/main/java/org/apache/iceberg/parquet/WriterLazyInitializable.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.parquet; + +import org.apache.parquet.column.ColumnWriteStore; +import org.apache.parquet.column.ParquetProperties; +import org.apache.parquet.compression.CompressionCodecFactory; +import org.apache.parquet.hadoop.ColumnChunkPageWriteStore; +import org.apache.parquet.schema.MessageType; + +/** + * Interface for ParquetValueWriters that need to defer initialization until they can analyze the + * data. This is useful for scenarios like variant shredding where the schema needs to be inferred + * from the actual data before creating the writer structures. + * + *
+ * <p>
Writers implementing this interface can buffer initial rows and perform schema inference + * before committing to a final Parquet schema. + */ +public interface WriterLazyInitializable { + /** + * Result returned by lazy initialization of a ParquetValueWriter required by ParquetWriter. + * Contains the finalized schema and write stores after schema inference or other initialization + * logic. + */ + class InitializationResult { + private final MessageType schema; + private final ColumnChunkPageWriteStore pageStore; + private final ColumnWriteStore writeStore; + + public InitializationResult( + MessageType schema, ColumnChunkPageWriteStore pageStore, ColumnWriteStore writeStore) { + this.schema = schema; + this.pageStore = pageStore; + this.writeStore = writeStore; + } + + public MessageType getSchema() { + return schema; + } + + public ColumnChunkPageWriteStore getPageStore() { + return pageStore; + } + + public ColumnWriteStore getWriteStore() { + return writeStore; + } + } + + /** + * Checks if this writer still needs initialization. This will return true until the writer has + * buffered enough data to perform initialization (e.g., schema inference). + * + * @return true if initialization is still needed, false if already initialized + */ + boolean needsInitialization(); + + /** + * Performs initialization and returns the result containing updated schema and write stores. This + * method should only be called when {@link #needsInitialization()} returns true. + * + * @param props Parquet properties needed for creating write stores + * @param compressor Bytes compressor for compression + * @param rowGroupOrdinal The ordinal number of the current row group + * @return InitializationResult containing the finalized schema and write stores + */ + InitializationResult initialize( + ParquetProperties props, + CompressionCodecFactory.BytesInputCompressor compressor, + int rowGroupOrdinal); +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java index 81139969f746..b12606d23948 100644 --- a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java @@ -109,4 +109,14 @@ private SparkSQLProperties() {} // Prefix for custom snapshot properties public static final String SNAPSHOT_PROPERTY_PREFIX = "spark.sql.iceberg.snapshot-property."; + + // Controls whether to shred variant columns during write operations + public static final String SHRED_VARIANTS = "spark.sql.iceberg.shred-variants"; + public static final boolean SHRED_VARIANTS_DEFAULT = true; + + // Controls the buffer size for variant schema inference during writes + // This determines how many rows are buffered before inferring shredded schema + public static final String VARIANT_INFERENCE_BUFFER_SIZE = + "spark.sql.iceberg.variant.inference.buffer-size"; + public static final int VARIANT_INFERENCE_BUFFER_SIZE_DEFAULT = 10; } diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java index 96131e0e56dd..4baf5585b220 100644 --- a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java @@ -509,6 +509,7 @@ private Map dataWriteProperties() { if (parquetCompressionLevel != null) { 
writeProperties.put(PARQUET_COMPRESSION_LEVEL, parquetCompressionLevel); } + writeProperties.put(SparkSQLProperties.SHRED_VARIANTS, String.valueOf(shredVariants())); break; case AVRO: @@ -729,4 +730,13 @@ public DeleteGranularity deleteGranularity() { .defaultValue(DeleteGranularity.FILE) .parse(); } + + public boolean shredVariants() { + return confParser + .booleanConf() + .option(SparkWriteOptions.SHRED_VARIANTS) + .sessionConf(SparkSQLProperties.SHRED_VARIANTS) + .defaultValue(SparkSQLProperties.SHRED_VARIANTS_DEFAULT) + .parse(); + } } diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java index 33db70bae587..f8fb41696f76 100644 --- a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java @@ -85,4 +85,7 @@ private SparkWriteOptions() {} // Overrides the delete granularity public static final String DELETE_GRANULARITY = "delete-granularity"; + + // Controls whether to shred variant columns during write operations + public static final String SHRED_VARIANTS = "shred-variants"; } diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SchemaInferenceVisitor.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SchemaInferenceVisitor.java new file mode 100644 index 000000000000..0eed88a8eb66 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SchemaInferenceVisitor.java @@ -0,0 +1,198 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.BINARY; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.List; +import org.apache.iceberg.parquet.ParquetVariantUtil; +import org.apache.iceberg.spark.data.ParquetWithSparkSchemaVisitor; +import org.apache.iceberg.variants.Variant; +import org.apache.iceberg.variants.VariantMetadata; +import org.apache.iceberg.variants.VariantValue; +import org.apache.parquet.schema.GroupType; +import org.apache.parquet.schema.LogicalTypeAnnotation; +import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.PrimitiveType; +import org.apache.parquet.schema.Type; +import org.apache.parquet.schema.Types; +import org.apache.parquet.schema.Types.MessageTypeBuilder; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.ArrayType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.MapType; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.VariantType; +import org.apache.spark.unsafe.types.VariantVal; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A visitor that infers variant shredding schemas by analyzing buffered rows of data. This visitor + * can be plugged into ParquetWithSparkSchemaVisitor.visit() to create a shredded MessageType based + * on actual variant data content. + * + *
+ * <p>
The visitor uses the field names tracked during traversal to look up the correct field index + * in the Spark schema, allowing it to access the corresponding value in the rows for schema + * inference. It searches through all buffered rows to find the first non-null variant value for + * schema inference. + */ +public class SchemaInferenceVisitor extends ParquetWithSparkSchemaVisitor { + private static final Logger LOG = LoggerFactory.getLogger(SchemaInferenceVisitor.class); + + private final List bufferedRows; + private final StructType sparkSchema; + + public SchemaInferenceVisitor(List bufferedRows, StructType sparkSchema) { + this.bufferedRows = bufferedRows; + this.sparkSchema = sparkSchema; + } + + @Override + public Type message(StructType sStruct, MessageType message, List fields) { + MessageTypeBuilder builder = Types.buildMessage(); + + for (Type field : fields) { + if (field != null) { + builder.addField(field); + } + } + + return builder.named(message.getName()); + } + + @Override + public Type struct(StructType sStruct, GroupType struct, List fields) { + Types.GroupBuilder builder = Types.buildGroup(struct.getRepetition()); + + if (struct.getId() != null) { + builder = builder.id(struct.getId().intValue()); + } + + for (Type field : fields) { + if (field != null) { + builder = builder.addField(field); + } + } + + return builder.named(struct.getName()); + } + + @Override + public Type primitive(DataType sPrimitive, PrimitiveType primitive) { + return primitive; + } + + @Override + public Type list(ArrayType sArray, GroupType array, Type element) { + Types.GroupBuilder builder = + Types.buildGroup(array.getRepetition()).as(LogicalTypeAnnotation.listType()); + + if (array.getId() != null) { + builder = builder.id(array.getId().intValue()); + } + + if (element != null) { + builder = builder.addField(element); + } + + return builder.named(array.getName()); + } + + @Override + public Type map(MapType sMap, GroupType map, Type key, Type value) { + Types.GroupBuilder builder = + Types.buildGroup(map.getRepetition()).as(LogicalTypeAnnotation.mapType()); + + if (map.getId() != null) { + builder = builder.id(map.getId().intValue()); + } + + if (key != null) { + builder = builder.addField(key); + } + if (value != null) { + builder = builder.addField(value); + } + + return builder.named(map.getName()); + } + + @Override + public Type variant(VariantType sVariant, GroupType variant) { + int variantFieldIndex = getFieldIndex(currentPath()); + + // Find the first non-null variant value from buffered rows for schema inference + // This ensures we can infer a schema even if the first rows has null variant values + if (!bufferedRows.isEmpty() && variantFieldIndex >= 0) { + for (InternalRow row : bufferedRows) { + if (!row.isNullAt(variantFieldIndex)) { + VariantVal variantVal = row.getVariant(variantFieldIndex); + if (variantVal != null) { + VariantValue variantValue = + VariantValue.from( + VariantMetadata.from( + ByteBuffer.wrap(variantVal.getMetadata()).order(ByteOrder.LITTLE_ENDIAN)), + ByteBuffer.wrap(variantVal.getValue()).order(ByteOrder.LITTLE_ENDIAN)); + + Type shreddedType = ParquetVariantUtil.toParquetSchema(variantValue); + if (shreddedType != null) { + return Types.buildGroup(variant.getRepetition()) + .as(LogicalTypeAnnotation.variantType(Variant.VARIANT_SPEC_VERSION)) + .id(variant.getId().intValue()) + .required(BINARY) + .named("metadata") + .optional(BINARY) + .named("value") + .addField(shreddedType) + .named(variant.getName()); + } + } + } + } + } + + return variant; + } + + 
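For illustration, here is a minimal sketch of the group the `variant()` method above returns, assuming a top-level variant column `address` with field ID 2 whose sampled value is `{"zip": 10001}` — the same shape the tests later in this patch expect. The wrapper class name is hypothetical.

```java
import org.apache.iceberg.variants.Variant;
import org.apache.parquet.schema.GroupType;
import org.apache.parquet.schema.LogicalTypeAnnotation;
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName;
import org.apache.parquet.schema.Type;
import org.apache.parquet.schema.Types;

class ShreddedSchemaExample { // hypothetical, for illustration only
  static GroupType addressGroup() {
    // Each shredded object field keeps an optional "value" fallback next to its
    // strongly typed "typed_value", so non-conforming data can still be stored.
    GroupType zip =
        Types.buildGroup(Type.Repetition.REQUIRED)
            .optional(PrimitiveTypeName.BINARY)
            .named("value")
            .addField(
                Types.optional(PrimitiveTypeName.INT32)
                    .as(LogicalTypeAnnotation.intType(16, true))
                    .named("typed_value"))
            .named("zip");

    // Mirrors the group built in variant() above: metadata is always required,
    // the top-level value is the residual, and typed_value holds shredded fields.
    return Types.buildGroup(Type.Repetition.OPTIONAL)
        .as(LogicalTypeAnnotation.variantType(Variant.VARIANT_SPEC_VERSION))
        .id(2)
        .required(PrimitiveTypeName.BINARY)
        .named("metadata")
        .optional(PrimitiveTypeName.BINARY)
        .named("value")
        .addField(Types.buildGroup(Type.Repetition.OPTIONAL).addField(zip).named("typed_value"))
        .named("address");
  }
}
```

Because the `metadata` and `value` fields always remain in the group, values that do not conform to the inferred `typed_value` type can still be written losslessly.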
private int getFieldIndex(String[] path) { + if (path == null || path.length == 0) { + return -1; + } + + // TODO: For now, we only support top-level variant fields. To support nested variants, we would + // need to navigate the struct hierarchy + if (path.length == 1) { + String fieldName = path[0]; + for (int i = 0; i < sparkSchema.fields().length; i++) { + if (sparkSchema.fields()[i].name().equals(fieldName)) { + return i; + } + } + } else { + LOG.warn( + "Nested variant fields are not yet supported for schema inference. Path: {}", + String.join(".", path)); + } + + return -1; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkParquetWriterWithVariantShredding.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkParquetWriterWithVariantShredding.java new file mode 100644 index 000000000000..8f1a61d60c6f --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkParquetWriterWithVariantShredding.java @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.stream.Stream; +import org.apache.iceberg.FieldMetrics; +import org.apache.iceberg.Schema; +import org.apache.iceberg.parquet.ParquetValueWriter; +import org.apache.iceberg.parquet.TripleWriter; +import org.apache.iceberg.parquet.WriterLazyInitializable; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.SparkSQLProperties; +import org.apache.iceberg.spark.data.ParquetWithSparkSchemaVisitor; +import org.apache.iceberg.spark.data.SparkParquetWriters; +import org.apache.iceberg.types.Types; +import org.apache.parquet.column.ColumnWriteStore; +import org.apache.parquet.column.ParquetProperties; +import org.apache.parquet.compression.CompressionCodecFactory; +import org.apache.parquet.hadoop.ColumnChunkPageWriteStore; +import org.apache.parquet.schema.MessageType; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.StructType; + +/** + * A Parquet output writer that performs variant shredding with schema inference. This is similar to + * Spark's ParquetOutputWriterWithVariantShredding but adapted for Iceberg. + * + *
+ * <p>
The writer works in two phases: 1. Schema inference phase: Buffers initial rows and analyzes + * variant data to infer schemas 2. Writing phase: Creates the actual Parquet writer with inferred + * schemas and writes all data + */ +public class SparkParquetWriterWithVariantShredding + implements ParquetValueWriter, WriterLazyInitializable { + private final StructType sparkSchema; + private final MessageType parquetType; + + private final List bufferedRows; + private ParquetValueWriter actualWriter; + private boolean writerInitialized = false; + private final int bufferSize; + + private static class BufferedRow { + private final int repetitionLevel; + private final InternalRow row; + + BufferedRow(int repetitionLevel, InternalRow row) { + this.repetitionLevel = repetitionLevel; + this.row = row; + } + } + + public SparkParquetWriterWithVariantShredding( + StructType sparkSchema, MessageType parquetType, Map properties) { + this.sparkSchema = sparkSchema; + this.parquetType = parquetType; + + this.bufferSize = + Integer.parseInt( + properties.getOrDefault( + SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, + String.valueOf(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE_DEFAULT))); + this.bufferedRows = Lists.newArrayList(); + } + + @Override + public void write(int repetitionLevel, InternalRow row) { + if (!writerInitialized) { + bufferedRows.add( + new BufferedRow( + repetitionLevel, row.copy())); /* Make a copy of the object since row gets reused */ + + if (bufferedRows.size() >= bufferSize) { + writerInitialized = true; + } + } else { + actualWriter.write(repetitionLevel, row); + } + } + + @Override + public List> columns() { + if (actualWriter != null) { + return actualWriter.columns(); + } + return Collections.emptyList(); + } + + @Override + public void setColumnStore(ColumnWriteStore columnStore) { + // Ignored for lazy initialization - will be set on actualWriter after initialization + } + + @Override + public Stream> metrics() { + if (actualWriter != null) { + return actualWriter.metrics(); + } + return Stream.empty(); + } + + @Override + public boolean needsInitialization() { + return !writerInitialized; + } + + @Override + public InitializationResult initialize( + ParquetProperties props, + CompressionCodecFactory.BytesInputCompressor compressor, + int rowGroupOrdinal) { + if (bufferedRows.isEmpty()) { + throw new IllegalStateException("No buffered rows available for schema inference"); + } + + List rows = Lists.newLinkedList(); + for (BufferedRow bufferedRow : bufferedRows) { + rows.add(bufferedRow.row); + } + + MessageType shreddedSchema = + (MessageType) + ParquetWithSparkSchemaVisitor.visit( + sparkSchema, parquetType, new SchemaInferenceVisitor(rows, sparkSchema)); + + actualWriter = SparkParquetWriters.buildWriter(sparkSchema, shreddedSchema); + + ColumnChunkPageWriteStore pageStore = + new ColumnChunkPageWriteStore( + compressor, + shreddedSchema, + props.getAllocator(), + 64, + ParquetProperties.DEFAULT_PAGE_WRITE_CHECKSUM_ENABLED, + null, + rowGroupOrdinal); + + ColumnWriteStore columnStore = props.newColumnWriteStore(shreddedSchema, pageStore, pageStore); + + actualWriter.setColumnStore(columnStore); + + for (BufferedRow bufferedRow : bufferedRows) { + actualWriter.write(bufferedRow.repetitionLevel, bufferedRow.row); + columnStore.endRecord(); + } + + bufferedRows.clear(); + writerInitialized = true; + + return new InitializationResult(shreddedSchema, pageStore, columnStore); + } + + public static boolean shouldUseVariantShredding(Map properties, Schema schema) { + 
boolean shreddingEnabled = + properties.containsKey(SparkSQLProperties.SHRED_VARIANTS) + && Boolean.parseBoolean(properties.get(SparkSQLProperties.SHRED_VARIANTS)); + + boolean hasVariantFields = + schema.columns().stream().anyMatch(field -> field.type() instanceof Types.VariantType); + + return shreddingEnabled && hasVariantFields; + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java index 61aacfa4589d..d97579f29e86 100644 --- a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java @@ -41,6 +41,7 @@ import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE_RANGE; import static org.apache.iceberg.spark.SparkSQLProperties.COMPRESSION_CODEC; import static org.apache.iceberg.spark.SparkSQLProperties.COMPRESSION_LEVEL; +import static org.apache.iceberg.spark.SparkSQLProperties.SHRED_VARIANTS; import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.DELETE; import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.MERGE; import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.UPDATE; @@ -339,6 +340,8 @@ public void testSparkConfOverride() { TableProperties.DELETE_PARQUET_COMPRESSION, "snappy"), ImmutableMap.of( + SHRED_VARIANTS, + "true", DELETE_PARQUET_COMPRESSION, "zstd", PARQUET_COMPRESSION, @@ -460,6 +463,8 @@ public void testDataPropsDefaultsAsDeleteProps() { PARQUET_COMPRESSION_LEVEL, "5"), ImmutableMap.of( + SHRED_VARIANTS, + "true", DELETE_PARQUET_COMPRESSION, "zstd", PARQUET_COMPRESSION, @@ -531,6 +536,8 @@ public void testDeleteFileWriteConf() { DELETE_PARQUET_COMPRESSION_LEVEL, "6"), ImmutableMap.of( + SHRED_VARIANTS, + "true", DELETE_PARQUET_COMPRESSION, "zstd", PARQUET_COMPRESSION, diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java new file mode 100644 index 000000000000..d82a241ba148 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java @@ -0,0 +1,363 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.variant; + +import static org.apache.hadoop.hive.conf.HiveConf.ConfVars.METASTOREURIS; +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.IOException; +import java.net.InetAddress; +import java.util.Map; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.spark.CatalogTestBase; +import org.apache.iceberg.spark.SparkCatalogConfig; +import org.apache.iceberg.spark.SparkSQLProperties; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.variants.Variant; +import org.apache.parquet.hadoop.ParquetFileReader; +import org.apache.parquet.hadoop.util.HadoopInputFile; +import org.apache.parquet.schema.GroupType; +import org.apache.parquet.schema.LogicalTypeAnnotation; +import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.PrimitiveType; +import org.apache.parquet.schema.Type; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.internal.SQLConf; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestTemplate; + +public class TestVariantShredding extends CatalogTestBase { + + private static final Schema SCHEMA = + new Schema( + Types.NestedField.required(1, "id", Types.IntegerType.get()), + Types.NestedField.optional(2, "address", Types.VariantType.get())); + + private static final Schema SCHEMA2 = + new Schema( + Types.NestedField.required(1, "id", Types.IntegerType.get()), + Types.NestedField.optional(2, "address", Types.VariantType.get()), + Types.NestedField.optional(3, "metadata", Types.VariantType.get())); + + @Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}") + protected static Object[][] parameters() { + return new Object[][] { + { + SparkCatalogConfig.HADOOP.catalogName(), + SparkCatalogConfig.HADOOP.implementation(), + SparkCatalogConfig.HADOOP.properties() + }, + }; + } + + @BeforeAll + public static void startMetastoreAndSpark() { + // First call parent to initialize metastore and spark with local[2] + CatalogTestBase.startMetastoreAndSpark(); + + // Now stop and recreate spark with local[1] to write all rows to a single file + if (spark != null) { + spark.stop(); + } + + spark = + SparkSession.builder() + .master("local[1]") // Use one thread to write the rows to a single parquet file + .config("spark.driver.host", InetAddress.getLoopbackAddress().getHostAddress()) + .config(SQLConf.PARTITION_OVERWRITE_MODE().key(), "dynamic") + .config("spark.hadoop." 
+ METASTOREURIS.varname, hiveConf.get(METASTOREURIS.varname)) + .config("spark.sql.legacy.respectNullabilityInTextDatasetConversion", "true") + .enableHiveSupport() + .getOrCreate(); + + sparkContext = JavaSparkContext.fromSparkContext(spark.sparkContext()); + } + + @BeforeEach + public void before() { + super.before(); + validationCatalog.createTable( + tableIdent, SCHEMA, null, Map.of(TableProperties.FORMAT_VERSION, "3")); + } + + @AfterEach + public void after() { + validationCatalog.dropTable(tableIdent, true); + } + + @TestTemplate + public void testVariantShreddingWrite() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + String values = + "(1, parse_json('{\"name\": \"Joe\", \"streets\": [\"Apt #3\", \"1234 Ave\"], \"zip\": 10001}')), (2, null)"; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType name = + field( + "name", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType streets = + field( + "streets", + list( + element( + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, + LogicalTypeAnnotation.stringType())))); + GroupType zip = + field( + "zip", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(16))); + GroupType address = variant("address", 2, objectFields(name, streets, zip)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + @TestTemplate + public void testVariantShreddingWithNullFirstRow() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + String values = "(1, null), (2, parse_json('{\"city\": \"Seattle\", \"state\": \"WA\"}'))"; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType city = + field( + "city", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType state = + field( + "state", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType address = variant("address", 2, objectFields(city, state)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + @TestTemplate + public void testVariantShreddingWithTwoVariantColumns() throws IOException { + validationCatalog.dropTable(tableIdent, true); + validationCatalog.createTable( + tableIdent, SCHEMA2, null, Map.of(TableProperties.FORMAT_VERSION, "3")); + + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + String values = + "(1, parse_json('{\"city\": \"NYC\", \"zip\": 10001}'), parse_json('{\"type\": \"home\", \"verified\": true}')), " + + "(2, null, null)"; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType city = + field( + "city", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType zip = + field( + "zip", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(16, true))); + GroupType address = variant("address", 2, objectFields(city, zip)); + + GroupType type = + field( + "type", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType verified = + field("verified", shreddedPrimitive(PrimitiveType.PrimitiveTypeName.BOOLEAN)); + GroupType metadata = variant("metadata", 3, 
objectFields(type, verified)); + + MessageType expectedSchema = parquetSchema(address, metadata); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + @TestTemplate + public void testVariantShreddingWithTwoVariantColumnsOneNull() throws IOException { + validationCatalog.dropTable(tableIdent, true); + validationCatalog.createTable( + tableIdent, SCHEMA2, null, Map.of(TableProperties.FORMAT_VERSION, "3")); + + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + // First row: address is null, metadata has value + // Second row: address has value, metadata is null + String values = + "(1, null, parse_json('{\"label\": \"primary\"}'))," + + " (2, parse_json('{\"street\": \"Main St\"}'), null)"; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType street = + field( + "street", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType address = variant("address", 2, objectFields(street)); + + GroupType label = + field( + "label", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType metadata = variant("metadata", 3, objectFields(label)); + + MessageType expectedSchema = parquetSchema(address, metadata); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + @TestTemplate + public void testVariantShreddingDisabled() throws IOException { + // Test with shredding explicitly disabled + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "false"); + + String values = "(1, parse_json('{\"city\": \"NYC\", \"zip\": 10001}')), (2, null)"; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType address = variant("address", 2); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + private void verifyParquetSchema(Table table, MessageType expectedSchema) throws IOException { + try (CloseableIterable tasks = table.newScan().planFiles()) { + assertThat(tasks).isNotEmpty(); + + FileScanTask task = tasks.iterator().next(); + String path = task.file().location(); + + HadoopInputFile inputFile = + HadoopInputFile.fromPath(new org.apache.hadoop.fs.Path(path), new Configuration()); + + try (ParquetFileReader reader = ParquetFileReader.open(inputFile)) { + MessageType actualSchema = reader.getFileMetaData().getSchema(); + assertThat(actualSchema).isEqualTo(expectedSchema); + } + } + } + + private static MessageType parquetSchema(Type... 
variantTypes) { + return org.apache.parquet.schema.Types.buildMessage() + .required(PrimitiveType.PrimitiveTypeName.INT32) + .id(1) + .named("id") + .addFields(variantTypes) + .named("table"); + } + + private static GroupType variant(String name, int fieldId) { + return org.apache.parquet.schema.Types.buildGroup(Type.Repetition.OPTIONAL) + .id(fieldId) + .as(LogicalTypeAnnotation.variantType(Variant.VARIANT_SPEC_VERSION)) + .required(PrimitiveType.PrimitiveTypeName.BINARY) + .named("metadata") + .required(PrimitiveType.PrimitiveTypeName.BINARY) + .named("value") + .named(name); + } + + private static GroupType variant(String name, int fieldId, Type shreddedType) { + checkShreddedType(shreddedType); + return org.apache.parquet.schema.Types.buildGroup(Type.Repetition.OPTIONAL) + .id(fieldId) + .as(LogicalTypeAnnotation.variantType(Variant.VARIANT_SPEC_VERSION)) + .required(PrimitiveType.PrimitiveTypeName.BINARY) + .named("metadata") + .optional(PrimitiveType.PrimitiveTypeName.BINARY) + .named("value") + .addField(shreddedType) + .named(name); + } + + private static Type shreddedPrimitive(PrimitiveType.PrimitiveTypeName primitive) { + return org.apache.parquet.schema.Types.optional(primitive).named("typed_value"); + } + + private static Type shreddedPrimitive( + PrimitiveType.PrimitiveTypeName primitive, LogicalTypeAnnotation annotation) { + return org.apache.parquet.schema.Types.optional(primitive).as(annotation).named("typed_value"); + } + + private static GroupType objectFields(GroupType... fields) { + for (GroupType fieldType : fields) { + checkField(fieldType); + } + + return org.apache.parquet.schema.Types.buildGroup(Type.Repetition.OPTIONAL) + .addFields(fields) + .named("typed_value"); + } + + private static GroupType field(String name, Type shreddedType) { + checkShreddedType(shreddedType); + return org.apache.parquet.schema.Types.buildGroup(Type.Repetition.REQUIRED) + .optional(PrimitiveType.PrimitiveTypeName.BINARY) + .named("value") + .addField(shreddedType) + .named(name); + } + + private static GroupType element(Type shreddedType) { + return field("element", shreddedType); + } + + private static GroupType list(GroupType elementType) { + return org.apache.parquet.schema.Types.optionalList().element(elementType).named("typed_value"); + } + + private static void checkShreddedType(Type shreddedType) { + Preconditions.checkArgument( + shreddedType.getName().equals("typed_value"), + "Invalid shredded type name: %s should be typed_value", + shreddedType.getName()); + Preconditions.checkArgument( + shreddedType.isRepetition(Type.Repetition.OPTIONAL), + "Invalid shredded type repetition: %s should be OPTIONAL", + shreddedType.getRepetition()); + } + + private static void checkField(GroupType fieldType) { + Preconditions.checkArgument( + fieldType.isRepetition(Type.Repetition.REQUIRED), + "Invalid field type repetition: %s should be REQUIRED", + fieldType.getRepetition()); + } +} From edf772fbe53b9762b3964dbadffef0b0bb064943 Mon Sep 17 00:00:00 2001 From: Aihua Xu Date: Sat, 1 Nov 2025 21:26:35 -0700 Subject: [PATCH 02/17] Add heuristics to determine the shredding schema --- .../parquet/ParquetVariantWriters.java | 58 +- .../iceberg/parquet/VariantWriterBuilder.java | 16 +- .../iceberg/spark/SparkSQLProperties.java | 12 + .../apache/iceberg/spark/SparkWriteConf.java | 18 + .../spark/source/SchemaInferenceVisitor.java | 83 +-- ...parkParquetWriterWithVariantShredding.java | 9 +- .../source/VariantShreddingAnalyzer.java | 545 ++++++++++++++++++ .../spark/variant/TestVariantShredding.java | 
396 +++++++++++++++++++++++++++++++++-
 8 files changed, 1077 insertions(+), 60 deletions(-)
 create mode 100644 spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/VariantShreddingAnalyzer.java

diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetVariantWriters.java b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetVariantWriters.java
index 9e94b1bbd6cd..42cdee7a1a5c 100644
--- a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetVariantWriters.java
+++ b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetVariantWriters.java
@@ -65,6 +65,11 @@ static ParquetValueWriter primitive(
     return new PrimitiveWriter<>(writer, Sets.immutableEnumSet(Arrays.asList(types)));
   }

+  static ParquetValueWriter decimal(
+      ParquetValueWriter writer, int expectedScale, PhysicalType... types) {
+    return new DecimalWriter(writer, expectedScale, Sets.immutableEnumSet(Arrays.asList(types)));
+  }
+
   @SuppressWarnings("unchecked")
   static ParquetValueWriter shredded(
       int valueDefinitionLevel,
@@ -253,6 +258,49 @@ public void setColumnStore(ColumnWriteStore columnStore) {
     }
   }

+  /**
+   * A TypedWriter for decimals that validates scale before writing. If the scale doesn't match,
+   * write() throws IllegalArgumentException, which ShreddedVariantWriter catches to fall back to
+   * the untyped value field.
+   */
+  private static class DecimalWriter implements TypedWriter {
+    private final Set types;
+    private final ParquetValueWriter writer;
+    private final int expectedScale;
+
+    private DecimalWriter(
+        ParquetValueWriter writer, int expectedScale, Set types) {
+      this.types = types;
+      this.writer = (ParquetValueWriter) writer;
+      this.expectedScale = expectedScale;
+    }
+
+    @Override
+    public Set types() {
+      return types;
+    }
+
+    @Override
+    public void write(int repetitionLevel, VariantValue value) {
+      java.math.BigDecimal decimal = (java.math.BigDecimal) value.asPrimitive().get();
+      // Validate scale matches before writing
+      if (decimal.scale() != expectedScale) {
+        throw new IllegalArgumentException(
+            "Cannot write decimal with scale "
+                + decimal.scale()
+                + " to schema expecting scale "
+                + expectedScale);
+      }
+      writer.write(repetitionLevel, decimal);
+    }
+
+    @Override
+    public List> columns() {
+      return writer.columns();
+    }
+
+    @Override
+    public void setColumnStore(ColumnWriteStore columnStore) {
+      writer.setColumnStore(columnStore);
+    }
+  }
+
   private static class ShreddedVariantWriter implements ParquetValueWriter {
     private final int valueDefinitionLevel;
     private final ParquetValueWriter valueWriter;
@@ -275,8 +323,14 @@ private ShreddedVariantWriter(
     @Override
     public void write(int repetitionLevel, VariantValue value) {
       if (typedWriter.types().contains(value.type())) {
-        typedWriter.write(repetitionLevel, value);
-        writeNull(valueWriter, repetitionLevel, valueDefinitionLevel);
+        try {
+          typedWriter.write(repetitionLevel, value);
+          writeNull(valueWriter, repetitionLevel, valueDefinitionLevel);
+        } catch (IllegalArgumentException e) {
+          // Fall back to the value field if the typed write fails (e.g., decimal scale mismatch)
+          valueWriter.write(repetitionLevel, value);
+          writeNull(typedWriter, repetitionLevel, typedDefinitionLevel);
+        }
       } else {
         valueWriter.write(repetitionLevel, value);
         writeNull(typedWriter, repetitionLevel, typedDefinitionLevel);
diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/VariantWriterBuilder.java b/parquet/src/main/java/org/apache/iceberg/parquet/VariantWriterBuilder.java
index a447a102690a..53cf5d9933d6 100644
--- a/parquet/src/main/java/org/apache/iceberg/parquet/VariantWriterBuilder.java
+++
b/parquet/src/main/java/org/apache/iceberg/parquet/VariantWriterBuilder.java @@ -198,27 +198,31 @@ public Optional> visit(StringLogicalTypeAnnotation ignored @Override public Optional> visit(DecimalLogicalTypeAnnotation decimal) { ParquetValueWriter writer; + int scale = decimal.getScale(); switch (desc.getPrimitiveType().getPrimitiveTypeName()) { case FIXED_LEN_BYTE_ARRAY: case BINARY: writer = - ParquetVariantWriters.primitive( + ParquetVariantWriters.decimal( ParquetValueWriters.decimalAsFixed( - desc, decimal.getPrecision(), decimal.getScale()), + desc, decimal.getPrecision(), scale), + scale, PhysicalType.DECIMAL16); return Optional.of(writer); case INT64: writer = - ParquetVariantWriters.primitive( + ParquetVariantWriters.decimal( ParquetValueWriters.decimalAsLong( - desc, decimal.getPrecision(), decimal.getScale()), + desc, decimal.getPrecision(), scale), + scale, PhysicalType.DECIMAL8); return Optional.of(writer); case INT32: writer = - ParquetVariantWriters.primitive( + ParquetVariantWriters.decimal( ParquetValueWriters.decimalAsInteger( - desc, decimal.getPrecision(), decimal.getScale()), + desc, decimal.getPrecision(), scale), + scale, PhysicalType.DECIMAL4); return Optional.of(writer); } diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java index b12606d23948..e111becad89e 100644 --- a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java @@ -119,4 +119,16 @@ private SparkSQLProperties() {} public static final String VARIANT_INFERENCE_BUFFER_SIZE = "spark.sql.iceberg.variant.inference.buffer-size"; public static final int VARIANT_INFERENCE_BUFFER_SIZE_DEFAULT = 10; + + // Controls the minimum occurrence threshold for variant fields during shredding + // Fields that appear in fewer than this percentage of rows will be dropped + public static final String VARIANT_MIN_OCCURRENCE_THRESHOLD = + "spark.sql.iceberg.variant.min-occurrence-threshold"; + public static final double VARIANT_MIN_OCCURRENCE_THRESHOLD_DEFAULT = 0.1; // 10% + + // Controls the maximum number of fields to shred in a variant column + // This prevents creating overly wide Parquet schemas + public static final String VARIANT_MAX_SHREDDED_FIELDS = + "spark.sql.iceberg.variant.max-shredded-fields"; + public static final int VARIANT_MAX_SHREDDED_FIELDS_DEFAULT = 300; } diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java index 4baf5585b220..34fcd2f1e467 100644 --- a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java @@ -510,6 +510,24 @@ private Map dataWriteProperties() { writeProperties.put(PARQUET_COMPRESSION_LEVEL, parquetCompressionLevel); } writeProperties.put(SparkSQLProperties.SHRED_VARIANTS, String.valueOf(shredVariants())); + + // Add variant shredding configuration properties + if (shredVariants()) { + String variantMaxFields = sessionConf.get(SparkSQLProperties.VARIANT_MAX_SHREDDED_FIELDS, null); + if (variantMaxFields != null) { + writeProperties.put(SparkSQLProperties.VARIANT_MAX_SHREDDED_FIELDS, variantMaxFields); + } + + String variantMinOccurrence = sessionConf.get(SparkSQLProperties.VARIANT_MIN_OCCURRENCE_THRESHOLD, null); + if (variantMinOccurrence != 
null) { + writeProperties.put(SparkSQLProperties.VARIANT_MIN_OCCURRENCE_THRESHOLD, variantMinOccurrence); + } + + String variantBufferSize = sessionConf.get(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, null); + if (variantBufferSize != null) { + writeProperties.put(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, variantBufferSize); + } + } break; case AVRO: diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SchemaInferenceVisitor.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SchemaInferenceVisitor.java index 0eed88a8eb66..c03fc74f00e3 100644 --- a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SchemaInferenceVisitor.java +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SchemaInferenceVisitor.java @@ -23,7 +23,9 @@ import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.util.List; +import java.util.Map; import org.apache.iceberg.parquet.ParquetVariantUtil; +import org.apache.iceberg.spark.SparkSQLProperties; import org.apache.iceberg.spark.data.ParquetWithSparkSchemaVisitor; import org.apache.iceberg.variants.Variant; import org.apache.iceberg.variants.VariantMetadata; @@ -46,24 +48,33 @@ import org.slf4j.LoggerFactory; /** - * A visitor that infers variant shredding schemas by analyzing buffered rows of data. This visitor - * can be plugged into ParquetWithSparkSchemaVisitor.visit() to create a shredded MessageType based - * on actual variant data content. - * - *
- * <p>
The visitor uses the field names tracked during traversal to look up the correct field index - * in the Spark schema, allowing it to access the corresponding value in the rows for schema - * inference. It searches through all buffered rows to find the first non-null variant value for - * schema inference. + * A visitor that infers variant shredding schemas by analyzing buffered rows of data. */ public class SchemaInferenceVisitor extends ParquetWithSparkSchemaVisitor { private static final Logger LOG = LoggerFactory.getLogger(SchemaInferenceVisitor.class); private final List bufferedRows; private final StructType sparkSchema; + private final VariantShreddingAnalyzer analyzer; - public SchemaInferenceVisitor(List bufferedRows, StructType sparkSchema) { + public SchemaInferenceVisitor( + List bufferedRows, StructType sparkSchema, Map properties) { this.bufferedRows = bufferedRows; this.sparkSchema = sparkSchema; + + double minOccurrenceThreshold = + Double.parseDouble( + properties.getOrDefault( + SparkSQLProperties.VARIANT_MIN_OCCURRENCE_THRESHOLD, + String.valueOf(SparkSQLProperties.VARIANT_MIN_OCCURRENCE_THRESHOLD_DEFAULT))); + + int maxFields = + Integer.parseInt( + properties.getOrDefault( + SparkSQLProperties.VARIANT_MAX_SHREDDED_FIELDS, + String.valueOf(SparkSQLProperties.VARIANT_MAX_SHREDDED_FIELDS_DEFAULT))); + + this.analyzer = new VariantShreddingAnalyzer(minOccurrenceThreshold, maxFields); } @Override @@ -140,33 +151,22 @@ public Type map(MapType sMap, GroupType map, Type key, Type value) { public Type variant(VariantType sVariant, GroupType variant) { int variantFieldIndex = getFieldIndex(currentPath()); - // Find the first non-null variant value from buffered rows for schema inference - // This ensures we can infer a schema even if the first rows has null variant values + // Apply heuristics to determine the shredding schema: + // - Fields must appear in at least the configured percentage of rows + // - Type consistency determines if typed_value is created + // - Maximum field count to avoid overly wide schemas if (!bufferedRows.isEmpty() && variantFieldIndex >= 0) { - for (InternalRow row : bufferedRows) { - if (!row.isNullAt(variantFieldIndex)) { - VariantVal variantVal = row.getVariant(variantFieldIndex); - if (variantVal != null) { - VariantValue variantValue = - VariantValue.from( - VariantMetadata.from( - ByteBuffer.wrap(variantVal.getMetadata()).order(ByteOrder.LITTLE_ENDIAN)), - ByteBuffer.wrap(variantVal.getValue()).order(ByteOrder.LITTLE_ENDIAN)); - - Type shreddedType = ParquetVariantUtil.toParquetSchema(variantValue); - if (shreddedType != null) { - return Types.buildGroup(variant.getRepetition()) - .as(LogicalTypeAnnotation.variantType(Variant.VARIANT_SPEC_VERSION)) - .id(variant.getId().intValue()) - .required(BINARY) - .named("metadata") - .optional(BINARY) - .named("value") - .addField(shreddedType) - .named(variant.getName()); - } - } - } + Type shreddedType = analyzer.analyzeAndCreateSchema(bufferedRows, variantFieldIndex); + if (shreddedType != null) { + return Types.buildGroup(variant.getRepetition()) + .as(LogicalTypeAnnotation.variantType(Variant.VARIANT_SPEC_VERSION)) + .id(variant.getId().intValue()) + .required(BINARY) + .named("metadata") + .optional(BINARY) + .named("value") + .addField(shreddedType) + .named(variant.getName()); } } @@ -178,9 +178,9 @@ private int getFieldIndex(String[] path) { return -1; } - // TODO: For now, we only support top-level variant fields. 
To support nested variants, we would - // need to navigate the struct hierarchy + // Support nested variant fields by navigating the struct hierarchy if (path.length == 1) { + // Top-level field - direct lookup String fieldName = path[0]; for (int i = 0; i < sparkSchema.fields().length; i++) { if (sparkSchema.fields()[i].name().equals(fieldName)) { @@ -188,8 +188,15 @@ private int getFieldIndex(String[] path) { } } } else { + // Nested field - navigate through struct hierarchy + // For now, we only support direct struct nesting (not arrays/maps) + LOG.debug( + "Attempting to resolve nested variant field path: {}", String.join(".", path)); + // TODO: Implement full nested field resolution when needed + // This would require tracking the current struct context during traversal + // and maintaining a stack of field indices LOG.warn( - "Nested variant fields are not yet supported for schema inference. Path: {}", + "Multi-level nested variant fields require struct context tracking. Path: {}", String.join(".", path)); } diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkParquetWriterWithVariantShredding.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkParquetWriterWithVariantShredding.java index 8f1a61d60c6f..6a2ed1e85324 100644 --- a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkParquetWriterWithVariantShredding.java +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkParquetWriterWithVariantShredding.java @@ -41,8 +41,7 @@ import org.apache.spark.sql.types.StructType; /** - * A Parquet output writer that performs variant shredding with schema inference. This is similar to - * Spark's ParquetOutputWriterWithVariantShredding but adapted for Iceberg. + * A Parquet output writer that performs variant shredding with schema inference. * *
* <p>
The writer works in two phases: 1. Schema inference phase: Buffers initial rows and analyzes * variant data to infer schemas 2. Writing phase: Creates the actual Parquet writer with inferred @@ -52,6 +51,7 @@ public class SparkParquetWriterWithVariantShredding implements ParquetValueWriter, WriterLazyInitializable { private final StructType sparkSchema; private final MessageType parquetType; + private final Map properties; private final List bufferedRows; private ParquetValueWriter actualWriter; @@ -72,6 +72,7 @@ public SparkParquetWriterWithVariantShredding( StructType sparkSchema, MessageType parquetType, Map properties) { this.sparkSchema = sparkSchema; this.parquetType = parquetType; + this.properties = properties; this.bufferSize = Integer.parseInt( @@ -139,7 +140,9 @@ public InitializationResult initialize( MessageType shreddedSchema = (MessageType) ParquetWithSparkSchemaVisitor.visit( - sparkSchema, parquetType, new SchemaInferenceVisitor(rows, sparkSchema)); + sparkSchema, + parquetType, + new SchemaInferenceVisitor(rows, sparkSchema, properties)); actualWriter = SparkParquetWriters.buildWriter(sparkSchema, shreddedSchema); diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/VariantShreddingAnalyzer.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/VariantShreddingAnalyzer.java new file mode 100644 index 000000000000..581043cd802e --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/VariantShreddingAnalyzer.java @@ -0,0 +1,545 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.List; +import java.util.Map; +import java.util.Set; +import org.apache.iceberg.parquet.ParquetVariantUtil; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.variants.PhysicalType; +import org.apache.iceberg.variants.VariantArray; +import org.apache.iceberg.variants.VariantMetadata; +import org.apache.iceberg.variants.VariantObject; +import org.apache.iceberg.variants.VariantPrimitive; +import org.apache.iceberg.variants.VariantValue; +import org.apache.parquet.schema.GroupType; +import org.apache.parquet.schema.LogicalTypeAnnotation; +import org.apache.parquet.schema.PrimitiveType; +import org.apache.parquet.schema.Type; +import org.apache.parquet.schema.Types; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.unsafe.types.VariantVal; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Analyzes variant data across buffered rows to determine an optimal shredding schema. + ** + *
+ * <ul>
+ *   <li>If a field appears with a consistent type → create both {@code value} and
+ *       {@code typed_value}
+ *   <li>If a field appears with inconsistent types → only create {@code value}
+ *   <li>Drop fields that occur in fewer than the configured threshold of sampled rows
+ *   <li>Cap the maximum number of fields to shred
+ * </ul>
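+ *
+ * <p>For illustration, a sketch of how the schema-inference visitor might drive this analyzer.
+ * The 0.1 and 300 values match the SparkSQLProperties defaults (10% minimum occurrence, 300
+ * max shredded fields); {@code bufferedRows} and {@code variantFieldIndex} stand for the
+ * buffered row sample and the variant column's ordinal:
+ *
+ * <pre>
+ *   VariantShreddingAnalyzer analyzer = new VariantShreddingAnalyzer(0.1, 300);
+ *   Type typedValue = analyzer.analyzeAndCreateSchema(bufferedRows, variantFieldIndex);
+ *   if (typedValue != null) {
+ *     // embed the returned typed_value group next to the variant's value/metadata fields
+ *   }
+ * </pre>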
+ */ +public class VariantShreddingAnalyzer { + private static final Logger LOG = LoggerFactory.getLogger(VariantShreddingAnalyzer.class); + + private final double minOccurrenceThreshold; + private final int maxFields; + + /** + * Creates a new analyzer with the specified configuration. + * + * @param minOccurrenceThreshold minimum occurrence threshold (e.g., 0.1 for 10%) + * @param maxFields maximum number of fields to shred + */ + public VariantShreddingAnalyzer(double minOccurrenceThreshold, int maxFields) { + this.minOccurrenceThreshold = minOccurrenceThreshold; + this.maxFields = maxFields; + } + + /** + * Analyzes buffered variant values to determine the optimal shredding schema. + * + * @param bufferedRows the buffered rows to analyze + * @param variantFieldIndex the index of the variant field in the rows + * @return the shredded schema type, or null if no shredding should be performed + */ + public Type analyzeAndCreateSchema(List bufferedRows, int variantFieldIndex) { + if (bufferedRows.isEmpty()) { + return null; + } + + List variantValues = extractVariantValues(bufferedRows, variantFieldIndex); + if (variantValues.isEmpty()) { + return null; + } + + FieldStats stats = analyzeFields(variantValues); + return buildShreddedSchema(stats, variantValues.size()); + } + + private static List extractVariantValues( + List bufferedRows, int variantFieldIndex) { + List values = new java.util.ArrayList<>(); + + for (InternalRow row : bufferedRows) { + if (!row.isNullAt(variantFieldIndex)) { + VariantVal variantVal = row.getVariant(variantFieldIndex); + if (variantVal != null) { + VariantValue variantValue = + VariantValue.from( + VariantMetadata.from( + ByteBuffer.wrap(variantVal.getMetadata()).order(ByteOrder.LITTLE_ENDIAN)), + ByteBuffer.wrap(variantVal.getValue()).order(ByteOrder.LITTLE_ENDIAN)); + values.add(variantValue); + } + } + } + + return values; + } + + private static FieldStats analyzeFields(List variantValues) { + FieldStats stats = new FieldStats(); + + for (VariantValue value : variantValues) { + if (value.type() == PhysicalType.OBJECT) { + VariantObject obj = value.asObject(); + for (String fieldName : obj.fieldNames()) { + VariantValue fieldValue = obj.get(fieldName); + if (fieldValue != null) { + stats.recordField(fieldName, fieldValue); + } + } + } + } + + return stats; + } + + private Type buildShreddedSchema(FieldStats stats, int totalRows) { + int minOccurrences = (int) Math.ceil(totalRows * minOccurrenceThreshold); + + // Get fields that meet the occurrence threshold + Set candidateFields = Sets.newTreeSet(); + for (Map.Entry entry : stats.fieldInfoMap.entrySet()) { + String fieldName = entry.getKey(); + FieldInfo info = entry.getValue(); + + if (info.occurrenceCount >= minOccurrences) { + candidateFields.add(fieldName); + } else { + LOG.debug( + "Field '{}' appears only {} times out of {} (< {}%), dropping", + fieldName, + info.occurrenceCount, + totalRows, + (int) (minOccurrenceThreshold * 100)); + } + } + + if (candidateFields.isEmpty()) { + return null; + } + + // Build the typed_value struct with field count limit + Types.GroupBuilder objectBuilder = Types.buildGroup(Type.Repetition.OPTIONAL); + int fieldCount = 0; + + for (String fieldName : candidateFields) { + FieldInfo info = stats.fieldInfoMap.get(fieldName); + + if (info.hasConsistentType()) { + Type shreddedFieldType = createShreddedFieldType(fieldName, info); + if (shreddedFieldType != null) { + if (fieldCount + 2 > maxFields) { + LOG.debug( + "Reached maximum field limit ({}) while processing field '{}', 
stopping", + maxFields, + fieldName); + break; + } + objectBuilder.addField(shreddedFieldType); + fieldCount += 2; + } + } else { + Type valueOnlyField = createValueOnlyField(fieldName); + if (fieldCount + 1 > maxFields) { + LOG.debug( + "Reached maximum field limit ({}) while processing field '{}', stopping", + maxFields, + fieldName); + break; + } + objectBuilder.addField(valueOnlyField); + fieldCount += 1; + LOG.debug( + "Field '{}' has inconsistent types ({}), creating value-only field", + fieldName, + info.observedTypes); + } + } + + if (fieldCount == 0) { + return null; + } + + LOG.info("Created shredded schema with {} fields for {} candidate fields", fieldCount, candidateFields.size()); + return objectBuilder.named("typed_value"); + } + + private static Type createShreddedFieldType(String fieldName, FieldInfo info) { + PhysicalType physicalType = info.getConsistentType(); + if (physicalType == null) { + return null; + } + + // For array types, analyze the first value to determine element type + Type typedValue; + if (physicalType == PhysicalType.ARRAY) { + typedValue = createArrayTypedValue(info); + } else if (physicalType == PhysicalType.DECIMAL4 + || physicalType == PhysicalType.DECIMAL8 + || physicalType == PhysicalType.DECIMAL16) { + // For decimals, infer precision and scale from actual values + typedValue = createDecimalTypedValue(info, physicalType); + } else if (physicalType == PhysicalType.OBJECT) { + // For nested objects, attempt recursive shredding + typedValue = createNestedObjectTypedValue(info); + } else { + // Convert the physical type to a Parquet type for typed_value + typedValue = convertPhysicalTypeToParquet(physicalType); + } + + if (typedValue == null) { + // If we can't create a typed_value (e.g., inconsistent decimal scales), + // create a value-only field instead of skipping the field entirely + return Types.buildGroup(Type.Repetition.REQUIRED) + .optional(PrimitiveType.PrimitiveTypeName.BINARY) + .named("value") + .named(fieldName); + } + + return Types.buildGroup(Type.Repetition.REQUIRED) + .optional(PrimitiveType.PrimitiveTypeName.BINARY) + .named("value") + .addField(typedValue) + .named(fieldName); + } + + private static Type createDecimalTypedValue(FieldInfo info, PhysicalType decimalType) { + // Analyze decimal values to determine precision and scale + // All values must have the same scale to be considered consistent + Integer consistentScale = null; + int maxPrecision = 0; + + for (VariantValue value : info.observedValues) { + if (value.type() == decimalType) { + try { + VariantPrimitive primitive = value.asPrimitive(); + Object decimalValue = primitive.get(); + if (decimalValue instanceof BigDecimal) { + BigDecimal bd = (BigDecimal) decimalValue; + int precision = bd.precision(); + int scale = bd.scale(); + + // Check scale consistency + if (consistentScale == null) { + consistentScale = scale; + } else if (consistentScale != scale) { + // Different scales mean inconsistent types - no typed_value + LOG.debug( + "Decimal values have inconsistent scales ({} vs {}), skipping typed_value", + consistentScale, + scale); + return null; + } + + maxPrecision = Math.max(maxPrecision, precision); + } + } catch (Exception e) { + LOG.debug("Failed to analyze decimal value", e); + } + } + } + + if (maxPrecision == 0 || consistentScale == null) { + LOG.debug("Could not determine decimal precision/scale, skipping typed_value"); + return null; + } + + // Determine the appropriate Parquet type based on precision + PrimitiveType.PrimitiveTypeName primitiveType; + if 
(maxPrecision <= 9) { + primitiveType = PrimitiveType.PrimitiveTypeName.INT32; + } else if (maxPrecision <= 18) { + primitiveType = PrimitiveType.PrimitiveTypeName.INT64; + } else { + primitiveType = PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY; + } + + return Types.optional(primitiveType) + .as(LogicalTypeAnnotation.decimalType(consistentScale, maxPrecision)) + .named("typed_value"); + } + + private static Type createNestedObjectTypedValue(FieldInfo info) { + // For nested objects, we can recursively analyze their fields + // For now, we'll create a simpler representation + // A full implementation would recursively build the object structure + + // Get a sample object to analyze its fields + for (VariantValue value : info.observedValues) { + if (value.type() == PhysicalType.OBJECT) { + try { + VariantObject obj = value.asObject(); + int numFields = obj.numFields(); + + // Only shred simple nested objects (not too many fields) + if (numFields > 0 && numFields <= 20) { + // Analyze fields in the nested object + Map> nestedFieldTypes = Maps.newHashMap(); + + for (String fieldName : obj.fieldNames()) { + VariantValue fieldValue = obj.get(fieldName); + if (fieldValue != null) { + nestedFieldTypes + .computeIfAbsent(fieldName, k -> Sets.newHashSet()) + .add(fieldValue.type()); + } + } + + // Build nested struct with fields that have consistent types + Types.GroupBuilder nestedBuilder = + Types.buildGroup(Type.Repetition.OPTIONAL); + int fieldCount = 0; + + for (Map.Entry> entry : nestedFieldTypes.entrySet()) { + String fieldName = entry.getKey(); + Set types = entry.getValue(); + + // Only include fields with consistent types + if (types.size() == 1) { + PhysicalType fieldType = types.iterator().next(); + Type fieldParquetType = convertPhysicalTypeToParquet(fieldType); + if (fieldParquetType != null) { + GroupType nestedField = + Types.buildGroup(Type.Repetition.REQUIRED) + .optional(PrimitiveType.PrimitiveTypeName.BINARY) + .named("value") + .addField(fieldParquetType) + .named(fieldName); + nestedBuilder.addField(nestedField); + fieldCount++; + } + } + } + + if (fieldCount > 0) { + return nestedBuilder.named("typed_value"); + } + } + } catch (Exception e) { + LOG.debug("Failed to analyze nested object", e); + } + break; + } + } + + LOG.debug("Skipping nested object - complex structure or analysis failed"); + return null; + } + + private static Type createArrayTypedValue(FieldInfo info) { + // Get a sample array value to analyze element types + for (VariantValue value : info.observedValues) { + if (value.type() == PhysicalType.ARRAY) { + try { + VariantArray array = value.asArray(); + int numElements = array.numElements(); + if (numElements > 0) { + // Analyze elements to determine if they have consistent type + Set elementTypes = Sets.newHashSet(); + for (int i = 0; i < numElements; i++) { + elementTypes.add(array.get(i).type()); + } + + // If all elements have consistent type, create typed array + if (elementTypes.size() == 1 + || (elementTypes.size() == 2 + && elementTypes.contains(PhysicalType.BOOLEAN_TRUE) + && elementTypes.contains(PhysicalType.BOOLEAN_FALSE))) { + PhysicalType elementType = elementTypes.iterator().next(); + if (elementType == PhysicalType.BOOLEAN_FALSE + || elementType == PhysicalType.BOOLEAN_TRUE) { + elementType = PhysicalType.BOOLEAN_TRUE; + } + Type elementParquetType = convertPhysicalTypeToParquet(elementType); + if (elementParquetType != null) { + // Create list with typed element + GroupType element = + Types.buildGroup(Type.Repetition.REQUIRED) + 
.optional(PrimitiveType.PrimitiveTypeName.BINARY) + .named("value") + .addField(elementParquetType) + .named("element"); + return Types.optionalList().element(element).named("typed_value"); + } + } + } + } catch (Exception e) { + LOG.debug("Failed to analyze array elements", e); + } + break; + } + } + return null; + } + + private static Type createValueOnlyField(String fieldName) { + // Create a field with only the value field (no typed_value) + return Types.buildGroup(Type.Repetition.REQUIRED) + .optional(PrimitiveType.PrimitiveTypeName.BINARY) + .named("value") + .named(fieldName); + } + + private static Type convertPhysicalTypeToParquet(PhysicalType physicalType) { + switch (physicalType) { + case BOOLEAN_TRUE: + case BOOLEAN_FALSE: + return Types.optional(PrimitiveType.PrimitiveTypeName.BOOLEAN).named("typed_value"); + + case INT8: + return Types.optional(PrimitiveType.PrimitiveTypeName.INT32) + .as(LogicalTypeAnnotation.intType(8, true)) + .named("typed_value"); + + case INT16: + return Types.optional(PrimitiveType.PrimitiveTypeName.INT32) + .as(LogicalTypeAnnotation.intType(16, true)) + .named("typed_value"); + + case INT32: + return Types.optional(PrimitiveType.PrimitiveTypeName.INT32) + .as(LogicalTypeAnnotation.intType(32, true)) + .named("typed_value"); + + case INT64: + return Types.optional(PrimitiveType.PrimitiveTypeName.INT64).named("typed_value"); + + case FLOAT: + return Types.optional(PrimitiveType.PrimitiveTypeName.FLOAT).named("typed_value"); + + case DOUBLE: + return Types.optional(PrimitiveType.PrimitiveTypeName.DOUBLE).named("typed_value"); + + case STRING: + return Types.optional(PrimitiveType.PrimitiveTypeName.BINARY) + .as(LogicalTypeAnnotation.stringType()) + .named("typed_value"); + + case BINARY: + return Types.optional(PrimitiveType.PrimitiveTypeName.BINARY).named("typed_value"); + + case DATE: + return Types.optional(PrimitiveType.PrimitiveTypeName.INT32) + .as(LogicalTypeAnnotation.dateType()) + .named("typed_value"); + + case TIMESTAMPTZ: + return Types.optional(PrimitiveType.PrimitiveTypeName.INT64) + .as(LogicalTypeAnnotation.timestampType(true, LogicalTypeAnnotation.TimeUnit.MICROS)) + .named("typed_value"); + + case TIMESTAMPNTZ: + return Types.optional(PrimitiveType.PrimitiveTypeName.INT64) + .as(LogicalTypeAnnotation.timestampType(false, LogicalTypeAnnotation.TimeUnit.MICROS)) + .named("typed_value"); + + case DECIMAL4: + case DECIMAL8: + case DECIMAL16: + // Decimals are now handled in createDecimalTypedValue() + // This case should not be reached for consistent decimal types + LOG.debug("Decimal type {} should be handled by createDecimalTypedValue()", physicalType); + return null; + + case ARRAY: + // Arrays are now handled in createArrayTypedValue() + LOG.debug("Array type should be handled by createArrayTypedValue()"); + return null; + + case OBJECT: + // Nested objects are now handled in createNestedObjectTypedValue() + LOG.debug("Object type should be handled by createNestedObjectTypedValue()"); + return null; + + default: + LOG.debug("Unknown physical type: {}", physicalType); + return null; + } + } + + /** Tracks statistics about fields across multiple variant values. */ + private static class FieldStats { + private final Map fieldInfoMap = Maps.newHashMap(); + + void recordField(String fieldName, VariantValue value) { + FieldInfo info = fieldInfoMap.computeIfAbsent(fieldName, k -> new FieldInfo()); + info.observe(value); + } + } + + /** Tracks occurrence count and type consistency for a single field. 
*/ + private static class FieldInfo { + private int occurrenceCount = 0; + private final Set observedTypes = Sets.newHashSet(); + private final List observedValues = new java.util.ArrayList<>(); + + void observe(VariantValue value) { + occurrenceCount++; + observedTypes.add(value.type()); + observedValues.add(value); + } + + boolean hasConsistentType() { + // Handle boolean types specially - both TRUE and FALSE map to BOOLEAN + if (observedTypes.size() == 2 + && observedTypes.contains(PhysicalType.BOOLEAN_TRUE) + && observedTypes.contains(PhysicalType.BOOLEAN_FALSE)) { + return true; + } + return observedTypes.size() == 1; + } + + PhysicalType getConsistentType() { + if (!hasConsistentType()) { + return null; + } + + // Handle boolean types + if (observedTypes.contains(PhysicalType.BOOLEAN_TRUE) + || observedTypes.contains(PhysicalType.BOOLEAN_FALSE)) { + return PhysicalType.BOOLEAN_TRUE; // Use TRUE as canonical boolean type + } + + return observedTypes.iterator().next(); + } + } +} + diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java index d82a241ba148..083242c6b743 100644 --- a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java @@ -23,6 +23,7 @@ import java.io.IOException; import java.net.InetAddress; +import java.util.List; import java.util.Map; import org.apache.hadoop.conf.Configuration; import org.apache.iceberg.FileScanTask; @@ -136,7 +137,7 @@ public void testVariantShreddingWrite() throws IOException { "zip", shreddedPrimitive( PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(16))); - GroupType address = variant("address", 2, objectFields(name, streets, zip)); + GroupType address = variant("address", 2, Type.Repetition.OPTIONAL, objectFields(name, streets, zip)); MessageType expectedSchema = parquetSchema(address); Table table = validationCatalog.loadTable(tableIdent); @@ -160,7 +161,7 @@ public void testVariantShreddingWithNullFirstRow() throws IOException { "state", shreddedPrimitive( PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); - GroupType address = variant("address", 2, objectFields(city, state)); + GroupType address = variant("address", 2, Type.Repetition.OPTIONAL, objectFields(city, state)); MessageType expectedSchema = parquetSchema(address); Table table = validationCatalog.loadTable(tableIdent); @@ -190,7 +191,7 @@ public void testVariantShreddingWithTwoVariantColumns() throws IOException { "zip", shreddedPrimitive( PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(16, true))); - GroupType address = variant("address", 2, objectFields(city, zip)); + GroupType address = variant("address", 2, Type.Repetition.OPTIONAL, objectFields(city, zip)); GroupType type = field( @@ -199,7 +200,7 @@ public void testVariantShreddingWithTwoVariantColumns() throws IOException { PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); GroupType verified = field("verified", shreddedPrimitive(PrimitiveType.PrimitiveTypeName.BOOLEAN)); - GroupType metadata = variant("metadata", 3, objectFields(type, verified)); + GroupType metadata = variant("metadata", 3, Type.Repetition.OPTIONAL, objectFields(type, verified)); MessageType expectedSchema = parquetSchema(address, metadata); @@ -227,14 +228,14 @@ public void 
testVariantShreddingWithTwoVariantColumnsOneNull() throws IOExceptio "street", shreddedPrimitive( PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); - GroupType address = variant("address", 2, objectFields(street)); + GroupType address = variant("address", 2, Type.Repetition.OPTIONAL, objectFields(street)); GroupType label = field( "label", shreddedPrimitive( PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); - GroupType metadata = variant("metadata", 3, objectFields(label)); + GroupType metadata = variant("metadata", 3, Type.Repetition.OPTIONAL, objectFields(label)); MessageType expectedSchema = parquetSchema(address, metadata); @@ -250,13 +251,385 @@ public void testVariantShreddingDisabled() throws IOException { String values = "(1, parse_json('{\"city\": \"NYC\", \"zip\": 10001}')), (2, null)"; sql("INSERT INTO %s VALUES %s", tableName, values); - GroupType address = variant("address", 2); + GroupType address = variant("address", 2, Type.Repetition.OPTIONAL); MessageType expectedSchema = parquetSchema(address); Table table = validationCatalog.loadTable(tableIdent); verifyParquetSchema(table, expectedSchema); } + @TestTemplate + public void testConsistentTypeCreatesTypedValue() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + String values = + "(1, parse_json('{\"name\": \"Alice\", \"age\": 30}'))," + + " (2, parse_json('{\"name\": \"Bob\", \"age\": 25}'))," + + " (3, parse_json('{\"name\": \"Charlie\", \"age\": 35}'))"; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType name = + field( + "name", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType age = + field("age", shreddedPrimitive(PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true))); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(age, name)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + /** + * Test Heuristic 2: Inconsistent Type → Value Only + * + *
<p>
When a field appears with different types across rows, only the "value" field should be + * created (no "typed_value"). + */ + @TestTemplate + public void testInconsistentTypeCreatesValueOnly() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + // "age" appears as both string and int - inconsistent type + String values = + "(1, parse_json('{\"age\": \"25\"}'))," + + " (2, parse_json('{\"age\": 30}'))," + + " (3, parse_json('{\"age\": \"35\"}'))"; + sql("INSERT INTO %s VALUES %s", tableName, values); + + // "age" should have only "value" field, no "typed_value" + GroupType age = valueOnlyField("age"); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(age)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + /** + * Test Heuristic 3: Rare Fields Are Dropped + * + *
<p>
Fields that appear in less than the configured threshold percentage of rows should be + * dropped from the shredded schema. + */ + @TestTemplate + public void testRareFieldIsDropped() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + // Set threshold to 20% (0.2) + spark.conf().set(SparkSQLProperties.VARIANT_MIN_OCCURRENCE_THRESHOLD, "0.2"); + + // "common" appears in all 10 rows (100%), "rare" appears in 1 row (10%) + String values = + "(1, parse_json('{\"common\": 1, \"rare\": 100}'))," + + " (2, parse_json('{\"common\": 2}'))," + + " (3, parse_json('{\"common\": 3}'))," + + " (4, parse_json('{\"common\": 4}'))," + + " (5, parse_json('{\"common\": 5}'))," + + " (6, parse_json('{\"common\": 6}'))," + + " (7, parse_json('{\"common\": 7}'))," + + " (8, parse_json('{\"common\": 8}'))," + + " (9, parse_json('{\"common\": 9}'))," + + " (10, parse_json('{\"common\": 10}'))"; + sql("INSERT INTO %s VALUES %s", tableName, values); + + // Only "common" should be present (appears in 100% of rows) + // "rare" should be dropped (appears in only 10% of rows, below 20% threshold) + GroupType common = + field("common", shreddedPrimitive(PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true))); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(common)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + + // Reset threshold to default to avoid interference with other tests + spark.conf().unset(SparkSQLProperties.VARIANT_MIN_OCCURRENCE_THRESHOLD); + } + + /** + * Test Heuristic 4: Boolean Type Handling + * + *
<p>
Both "true" and "false" values should be treated as the same consistent boolean type, and a + * typed_value field should be created. + */ + @TestTemplate + public void testBooleanTypeHandling() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + // "active" field has both true and false values - should be treated as consistent boolean + String values = + "(1, parse_json('{\"active\": true}'))," + + " (2, parse_json('{\"active\": false}'))," + + " (3, parse_json('{\"active\": true}'))"; + sql("INSERT INTO %s VALUES %s", tableName, values); + + // "active" should have typed_value with boolean type + GroupType active = field("active", shreddedPrimitive(PrimitiveType.PrimitiveTypeName.BOOLEAN)); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(active)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + + // Reset field limit to default to avoid interference from previous tests + spark.conf().unset(SparkSQLProperties.VARIANT_MAX_SHREDDED_FIELDS); + } + + /** + * Test Heuristic 5: Mixed Fields (Consistent and Inconsistent) + * + *
<p>
Tests a realistic scenario with multiple fields where some have consistent types and others + * don't. + */ + @TestTemplate + public void testMixedFieldsConsistentAndInconsistent() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + // "name": always string (consistent) + // "age": mixed int/string (inconsistent) + // "active": boolean (consistent) + String values = + "(1, parse_json('{\"name\": \"Alice\", \"age\": 30, \"active\": true}'))," + + " (2, parse_json('{\"name\": \"Bob\", \"age\": \"25\", \"active\": false}'))," + + " (3, parse_json('{\"name\": \"Charlie\", \"age\": \"35\", \"active\": true}'))"; + sql("INSERT INTO %s VALUES %s", tableName, values); + + // "name" should have typed_value (consistent string) + GroupType name = + field( + "name", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + + // "age" should NOT have typed_value (inconsistent types) + GroupType age = valueOnlyField("age"); + + // "active" should have typed_value (consistent boolean) + GroupType active = field("active", shreddedPrimitive(PrimitiveType.PrimitiveTypeName.BOOLEAN)); + + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(active, age, name)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + /** + * Test Heuristic 6: Field Limit Enforcement + * + *
<p>
Verify that the analyzer respects the maximum field limit and stops adding fields once the + * limit is reached. + */ + @TestTemplate + public void testMaxFieldLimitEnforcement() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + // Set very low field limit + spark.conf().set(SparkSQLProperties.VARIANT_MAX_SHREDDED_FIELDS, "4"); + + // Create rows with many fields (a, b, c, d, e, f) + String values = + "(1, parse_json('{\"a\": 1, \"b\": 2, \"c\": 3, \"d\": 4, \"e\": 5, \"f\": 6}'))," + + " (2, parse_json('{\"a\": 1, \"b\": 2, \"c\": 3, \"d\": 4, \"e\": 5, \"f\": 6}'))"; + sql("INSERT INTO %s VALUES %s", tableName, values); + + // With limit 4: field "a" (2 fields: value + typed_value) + field "b" (2 fields) = 4 total + // Fields are added alphabetically, so only "a" and "b" should be present + GroupType a = + field("a", shreddedPrimitive(PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true))); + GroupType b = + field("b", shreddedPrimitive(PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true))); + + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(a, b)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + + // Reset field limit to default to avoid interference from previous tests + spark.conf().unset(SparkSQLProperties.VARIANT_MAX_SHREDDED_FIELDS); + } + + /** + * Test Heuristic 7: Decimal Type Handling - Inconsistent Scales + * + *
<p>
Verify that decimal fields with different scales are treated as inconsistent types + * and only get a value field (no typed_value). + */ + @TestTemplate + public void testDecimalTypeHandlingInconsistentScales() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + // Decimal values with different scales: scale 6, 2, 2 + // 123.456789 → precision 9, scale 6 + // 678.90 → precision 5, scale 2 + // 999.99 → precision 5, scale 2 + // These are treated as inconsistent types due to different scales + String values = + "(1, parse_json('{\"price\": 123.456789}'))," + + " (2, parse_json('{\"price\": 678.90}'))," + + " (3, parse_json('{\"price\": 999.99}'))"; + sql("INSERT INTO %s VALUES %s", tableName, values); + + // "price" has inconsistent scales, so only "value" field (no typed_value) + GroupType price = valueOnlyField("price"); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(price)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + /** + * Test Heuristic 7b: Decimal Type Handling - Consistent Scales + * + *
<p>
Verify that decimal fields with the same scale get proper typed_value with inferred + * precision/scale. + */ + @TestTemplate + public void testDecimalTypeHandlingConsistentScales() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + // Decimal values with consistent scale (all 2 decimal places) + String values = + "(1, parse_json('{\"price\": 123.45}'))," + + " (2, parse_json('{\"price\": 678.90}'))," + + " (3, parse_json('{\"price\": 999.99}'))"; + sql("INSERT INTO %s VALUES %s", tableName, values); + + // "price" should have typed_value with inferred DECIMAL(5,2) type + GroupType price = + field( + "price", + org.apache.parquet.schema.Types.optional(PrimitiveType.PrimitiveTypeName.INT32) + .as(LogicalTypeAnnotation.decimalType(2, 5)) + .named("typed_value")); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(price)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + /** + * Test Heuristic 7c: Decimal Type Handling - Inconsistent After Buffering + * + *
<p>
Verify that when buffered rows have consistent decimal scales but subsequent unbuffered rows + * have inconsistent scales, the inconsistent values are written to the value field only. + * The schema is inferred from buffered rows and should include typed_value for the consistent type. + */ + @TestTemplate + public void testDecimalTypeHandlingInconsistentAfterBuffering() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + // Set a small buffer size to test the scenario + spark.conf().set(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, "3"); + + // First 3 rows (buffered): consistent scale (2 decimal places) + // 4th row onwards (unbuffered): different scale (6 decimal places) + // Schema should be inferred from buffered rows with DECIMAL(5,2) + // The unbuffered row with different scale should still write successfully to value field + String values = + "(1, parse_json('{\"price\": 123.45}'))," + + " (2, parse_json('{\"price\": 678.90}'))," + + " (3, parse_json('{\"price\": 999.99}'))," + + " (4, parse_json('{\"price\": 111.111111}'))"; // Different scale - should write to value only + sql("INSERT INTO %s VALUES %s", tableName, values); + + // Schema should have typed_value with DECIMAL(5,2) based on buffered rows + GroupType price = + field( + "price", + org.apache.parquet.schema.Types.optional(PrimitiveType.PrimitiveTypeName.INT32) + .as(LogicalTypeAnnotation.decimalType(2, 5)) + .named("typed_value")); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(price)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + + // Verify all rows were written successfully + List result = sql("SELECT id, address FROM %s ORDER BY id", tableName); + assertThat(result).hasSize(4); + + // Reset buffer size to default + spark.conf().unset(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE); + } + + /** + * Test Heuristic 8: Array Type Handling + * + *
<p>
Verify that array fields with consistent element types get proper typed_value. + */ + @TestTemplate + public void testArrayTypeHandling() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + // Arrays with consistent element types (all strings) + String values = + "(1, parse_json('{\"tags\": [\"java\", \"scala\", \"python\"]}'))," + + " (2, parse_json('{\"tags\": [\"rust\", \"go\"]}'))," + + " (3, parse_json('{\"tags\": [\"javascript\"]}'))"; + sql("INSERT INTO %s VALUES %s", tableName, values); + + // "tags" should have typed_value with list of strings + GroupType tags = + field( + "tags", + list( + element( + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, + LogicalTypeAnnotation.stringType())))); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(tags)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + /** + * Test Heuristic 9: Nested Object Handling + * + *
<p>
Verify that simple nested objects are recursively shredded. + */ + @TestTemplate + public void testNestedObjectHandling() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + // Nested objects with consistent structure + String values = + "(1, parse_json('{\"location\": {\"city\": \"Seattle\", \"zip\": 98101}}'))," + + " (2, parse_json('{\"location\": {\"city\": \"Portland\", \"zip\": 97201}}'))," + + " (3, parse_json('{\"location\": {\"city\": \"NYC\", \"zip\": 10001}}'))"; + sql("INSERT INTO %s VALUES %s", tableName, values); + + // Nested "location" object should be shredded with its fields + GroupType city = + field( + "city", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType zip = + field("zip", shreddedPrimitive(PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(32, true))); + GroupType location = field("location", objectFields(zip, city)); + + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(location)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + /** Helper method to create a value-only field (no typed_value) for inconsistent types. */ + private static GroupType valueOnlyField(String name) { + return org.apache.parquet.schema.Types.buildGroup(Type.Repetition.REQUIRED) + .optional(PrimitiveType.PrimitiveTypeName.BINARY) + .named("value") + .named(name); + } + private void verifyParquetSchema(Table table, MessageType expectedSchema) throws IOException { try (CloseableIterable tasks = table.newScan().planFiles()) { assertThat(tasks).isNotEmpty(); @@ -283,8 +656,8 @@ private static MessageType parquetSchema(Type... 
variantTypes) { .named("table"); } - private static GroupType variant(String name, int fieldId) { - return org.apache.parquet.schema.Types.buildGroup(Type.Repetition.OPTIONAL) + private static GroupType variant(String name, int fieldId, Type.Repetition repetition) { + return org.apache.parquet.schema.Types.buildGroup(repetition) .id(fieldId) .as(LogicalTypeAnnotation.variantType(Variant.VARIANT_SPEC_VERSION)) .required(PrimitiveType.PrimitiveTypeName.BINARY) @@ -294,9 +667,10 @@ private static GroupType variant(String name, int fieldId) { .named(name); } - private static GroupType variant(String name, int fieldId, Type shreddedType) { + private static GroupType variant( + String name, int fieldId, Type.Repetition repetition, Type shreddedType) { checkShreddedType(shreddedType); - return org.apache.parquet.schema.Types.buildGroup(Type.Repetition.OPTIONAL) + return org.apache.parquet.schema.Types.buildGroup(repetition) .id(fieldId) .as(LogicalTypeAnnotation.variantType(Variant.VARIANT_SPEC_VERSION)) .required(PrimitiveType.PrimitiveTypeName.BINARY) From 07b17224589a7a9c8f76aae063b0f65ee452d1c6 Mon Sep 17 00:00:00 2001 From: Aihua Xu Date: Fri, 9 Jan 2026 12:20:07 -0800 Subject: [PATCH 03/17] Simplify heuristics to most common type --- .../iceberg/parquet/ParquetVariantUtil.java | 4 +- .../parquet/ParquetVariantWriters.java | 48 -- .../iceberg/parquet/VariantWriterBuilder.java | 16 +- .../iceberg/spark/SparkSQLProperties.java | 12 - .../apache/iceberg/spark/SparkWriteConf.java | 18 +- .../spark/source/SchemaInferenceVisitor.java | 33 +- .../source/VariantShreddingAnalyzer.java | 513 +++++------------- .../spark/variant/TestVariantShredding.java | 418 +++----------- 8 files changed, 238 insertions(+), 824 deletions(-) diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetVariantUtil.java b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetVariantUtil.java index d94760773e51..ac418a1127bd 100644 --- a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetVariantUtil.java +++ b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetVariantUtil.java @@ -57,7 +57,7 @@ import org.apache.parquet.schema.Type; import org.apache.parquet.schema.Types; -public class ParquetVariantUtil { +class ParquetVariantUtil { private ParquetVariantUtil() {} /** @@ -212,7 +212,7 @@ static int scale(PrimitiveType primitive) { * @param value a variant value * @return a Parquet schema that can fully shred the value */ - public static Type toParquetSchema(VariantValue value) { + static Type toParquetSchema(VariantValue value) { return VariantVisitor.visit(value, new ParquetSchemaProducer()); } diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetVariantWriters.java b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetVariantWriters.java index 42cdee7a1a5c..08016667bdab 100644 --- a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetVariantWriters.java +++ b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetVariantWriters.java @@ -65,11 +65,6 @@ static ParquetValueWriter primitive( return new PrimitiveWriter<>(writer, Sets.immutableEnumSet(Arrays.asList(types))); } - static ParquetValueWriter decimal( - ParquetValueWriter writer, int expectedScale, PhysicalType... 
types) { - return new DecimalWriter(writer, expectedScale, Sets.immutableEnumSet(Arrays.asList(types))); - } - @SuppressWarnings("unchecked") static ParquetValueWriter shredded( int valueDefinitionLevel, @@ -258,49 +253,6 @@ public void setColumnStore(ColumnWriteStore columnStore) { } } - /** - * A TypedWriter for decimals that validates scale before writing. - * If the scale doesn't match, it returns false from canWrite() to trigger fallback to value field. - */ - private static class DecimalWriter implements TypedWriter { - private final Set types; - private final ParquetValueWriter writer; - private final int expectedScale; - - private DecimalWriter( - ParquetValueWriter writer, int expectedScale, Set types) { - this.types = types; - this.writer = (ParquetValueWriter) writer; - this.expectedScale = expectedScale; - } - - @Override - public Set types() { - return types; - } - - @Override - public void write(int repetitionLevel, VariantValue value) { - java.math.BigDecimal decimal = (java.math.BigDecimal) value.asPrimitive().get(); - // Validate scale matches before writing - if (decimal.scale() != expectedScale) { - throw new IllegalArgumentException( - "Cannot write decimal with scale " + decimal.scale() + " to schema expecting scale " + expectedScale); - } - writer.write(repetitionLevel, decimal); - } - - @Override - public List> columns() { - return writer.columns(); - } - - @Override - public void setColumnStore(ColumnWriteStore columnStore) { - writer.setColumnStore(columnStore); - } - } - private static class ShreddedVariantWriter implements ParquetValueWriter { private final int valueDefinitionLevel; private final ParquetValueWriter valueWriter; diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/VariantWriterBuilder.java b/parquet/src/main/java/org/apache/iceberg/parquet/VariantWriterBuilder.java index 53cf5d9933d6..a447a102690a 100644 --- a/parquet/src/main/java/org/apache/iceberg/parquet/VariantWriterBuilder.java +++ b/parquet/src/main/java/org/apache/iceberg/parquet/VariantWriterBuilder.java @@ -198,31 +198,27 @@ public Optional> visit(StringLogicalTypeAnnotation ignored @Override public Optional> visit(DecimalLogicalTypeAnnotation decimal) { ParquetValueWriter writer; - int scale = decimal.getScale(); switch (desc.getPrimitiveType().getPrimitiveTypeName()) { case FIXED_LEN_BYTE_ARRAY: case BINARY: writer = - ParquetVariantWriters.decimal( + ParquetVariantWriters.primitive( ParquetValueWriters.decimalAsFixed( - desc, decimal.getPrecision(), scale), - scale, + desc, decimal.getPrecision(), decimal.getScale()), PhysicalType.DECIMAL16); return Optional.of(writer); case INT64: writer = - ParquetVariantWriters.decimal( + ParquetVariantWriters.primitive( ParquetValueWriters.decimalAsLong( - desc, decimal.getPrecision(), scale), - scale, + desc, decimal.getPrecision(), decimal.getScale()), PhysicalType.DECIMAL8); return Optional.of(writer); case INT32: writer = - ParquetVariantWriters.decimal( + ParquetVariantWriters.primitive( ParquetValueWriters.decimalAsInteger( - desc, decimal.getPrecision(), scale), - scale, + desc, decimal.getPrecision(), decimal.getScale()), PhysicalType.DECIMAL4); return Optional.of(writer); } diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java index e111becad89e..b12606d23948 100644 --- a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java +++ 
b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java @@ -119,16 +119,4 @@ private SparkSQLProperties() {} public static final String VARIANT_INFERENCE_BUFFER_SIZE = "spark.sql.iceberg.variant.inference.buffer-size"; public static final int VARIANT_INFERENCE_BUFFER_SIZE_DEFAULT = 10; - - // Controls the minimum occurrence threshold for variant fields during shredding - // Fields that appear in fewer than this percentage of rows will be dropped - public static final String VARIANT_MIN_OCCURRENCE_THRESHOLD = - "spark.sql.iceberg.variant.min-occurrence-threshold"; - public static final double VARIANT_MIN_OCCURRENCE_THRESHOLD_DEFAULT = 0.1; // 10% - - // Controls the maximum number of fields to shred in a variant column - // This prevents creating overly wide Parquet schemas - public static final String VARIANT_MAX_SHREDDED_FIELDS = - "spark.sql.iceberg.variant.max-shredded-fields"; - public static final int VARIANT_MAX_SHREDDED_FIELDS_DEFAULT = 300; } diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java index 34fcd2f1e467..80d245712e6b 100644 --- a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java @@ -510,22 +510,14 @@ private Map dataWriteProperties() { writeProperties.put(PARQUET_COMPRESSION_LEVEL, parquetCompressionLevel); } writeProperties.put(SparkSQLProperties.SHRED_VARIANTS, String.valueOf(shredVariants())); - + // Add variant shredding configuration properties if (shredVariants()) { - String variantMaxFields = sessionConf.get(SparkSQLProperties.VARIANT_MAX_SHREDDED_FIELDS, null); - if (variantMaxFields != null) { - writeProperties.put(SparkSQLProperties.VARIANT_MAX_SHREDDED_FIELDS, variantMaxFields); - } - - String variantMinOccurrence = sessionConf.get(SparkSQLProperties.VARIANT_MIN_OCCURRENCE_THRESHOLD, null); - if (variantMinOccurrence != null) { - writeProperties.put(SparkSQLProperties.VARIANT_MIN_OCCURRENCE_THRESHOLD, variantMinOccurrence); - } - - String variantBufferSize = sessionConf.get(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, null); + String variantBufferSize = + sessionConf.get(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, null); if (variantBufferSize != null) { - writeProperties.put(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, variantBufferSize); + writeProperties.put( + SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, variantBufferSize); } } break; diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SchemaInferenceVisitor.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SchemaInferenceVisitor.java index c03fc74f00e3..6903f1f03353 100644 --- a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SchemaInferenceVisitor.java +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SchemaInferenceVisitor.java @@ -20,16 +20,10 @@ import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.BINARY; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; import java.util.List; import java.util.Map; -import org.apache.iceberg.parquet.ParquetVariantUtil; -import org.apache.iceberg.spark.SparkSQLProperties; import org.apache.iceberg.spark.data.ParquetWithSparkSchemaVisitor; import org.apache.iceberg.variants.Variant; -import org.apache.iceberg.variants.VariantMetadata; -import org.apache.iceberg.variants.VariantValue; import 
org.apache.parquet.schema.GroupType; import org.apache.parquet.schema.LogicalTypeAnnotation; import org.apache.parquet.schema.MessageType; @@ -43,13 +37,10 @@ import org.apache.spark.sql.types.MapType; import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.types.VariantType; -import org.apache.spark.unsafe.types.VariantVal; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** - * A visitor that infers variant shredding schemas by analyzing buffered rows of data. - */ +/** A visitor that infers variant shredding schemas by analyzing buffered rows of data. */ public class SchemaInferenceVisitor extends ParquetWithSparkSchemaVisitor { private static final Logger LOG = LoggerFactory.getLogger(SchemaInferenceVisitor.class); @@ -61,20 +52,7 @@ public SchemaInferenceVisitor( List bufferedRows, StructType sparkSchema, Map properties) { this.bufferedRows = bufferedRows; this.sparkSchema = sparkSchema; - - double minOccurrenceThreshold = - Double.parseDouble( - properties.getOrDefault( - SparkSQLProperties.VARIANT_MIN_OCCURRENCE_THRESHOLD, - String.valueOf(SparkSQLProperties.VARIANT_MIN_OCCURRENCE_THRESHOLD_DEFAULT))); - - int maxFields = - Integer.parseInt( - properties.getOrDefault( - SparkSQLProperties.VARIANT_MAX_SHREDDED_FIELDS, - String.valueOf(SparkSQLProperties.VARIANT_MAX_SHREDDED_FIELDS_DEFAULT))); - - this.analyzer = new VariantShreddingAnalyzer(minOccurrenceThreshold, maxFields); + this.analyzer = new VariantShreddingAnalyzer(); } @Override @@ -151,10 +129,6 @@ public Type map(MapType sMap, GroupType map, Type key, Type value) { public Type variant(VariantType sVariant, GroupType variant) { int variantFieldIndex = getFieldIndex(currentPath()); - // Apply heuristics to determine the shredding schema: - // - Fields must appear in at least the configured percentage of rows - // - Type consistency determines if typed_value is created - // - Maximum field count to avoid overly wide schemas if (!bufferedRows.isEmpty() && variantFieldIndex >= 0) { Type shreddedType = analyzer.analyzeAndCreateSchema(bufferedRows, variantFieldIndex); if (shreddedType != null) { @@ -190,8 +164,7 @@ private int getFieldIndex(String[] path) { } else { // Nested field - navigate through struct hierarchy // For now, we only support direct struct nesting (not arrays/maps) - LOG.debug( - "Attempting to resolve nested variant field path: {}", String.join(".", path)); + LOG.debug("Attempting to resolve nested variant field path: {}", String.join(".", path)); // TODO: Implement full nested field resolution when needed // This would require tracking the current struct context during traversal // and maintaining a stack of field indices diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/VariantShreddingAnalyzer.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/VariantShreddingAnalyzer.java index 581043cd802e..27b526134737 100644 --- a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/VariantShreddingAnalyzer.java +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/VariantShreddingAnalyzer.java @@ -24,7 +24,6 @@ import java.util.List; import java.util.Map; import java.util.Set; -import org.apache.iceberg.parquet.ParquetVariantUtil; import org.apache.iceberg.relocated.com.google.common.collect.Maps; import org.apache.iceberg.relocated.com.google.common.collect.Sets; import org.apache.iceberg.variants.PhysicalType; @@ -40,36 +39,20 @@ import org.apache.parquet.schema.Types; import org.apache.spark.sql.catalyst.InternalRow; 
import org.apache.spark.unsafe.types.VariantVal; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; /** * Analyzes variant data across buffered rows to determine an optimal shredding schema. - ** + * *
<ul>
- *   <li>If a field appears with a consistent type → create both {@code value} and
- *       {@code typed_value}
- *   <li>If a field appears with inconsistent types → only create {@code value}
- *   <li>Drop fields that occur in fewer than the configured threshold of sampled rows
- *   <li>Cap the maximum number of fields to shred
+ *   <li>Shred each field to the most common type observed in the sampled rows
 * </ul>
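+ *
+ * <p>For illustration, a sketch of the simplified entry point (no thresholds to configure;
+ * {@code bufferedRows} and {@code variantFieldIndex} as in the schema-inference visitor):
+ *
+ * <pre>
+ *   VariantShreddingAnalyzer analyzer = new VariantShreddingAnalyzer();
+ *   Type typedValue = analyzer.analyzeAndCreateSchema(bufferedRows, variantFieldIndex);
+ * </pre>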
*/ public class VariantShreddingAnalyzer { - private static final Logger LOG = LoggerFactory.getLogger(VariantShreddingAnalyzer.class); + private static final String TYPED_VALUE = "typed_value"; + private static final String VALUE = "value"; + private static final String ELEMENT = "element"; - private final double minOccurrenceThreshold; - private final int maxFields; - - /** - * Creates a new analyzer with the specified configuration. - * - * @param minOccurrenceThreshold minimum occurrence threshold (e.g., 0.1 for 10%) - * @param maxFields maximum number of fields to shred - */ - public VariantShreddingAnalyzer(double minOccurrenceThreshold, int maxFields) { - this.minOccurrenceThreshold = minOccurrenceThreshold; - this.maxFields = maxFields; - } + public VariantShreddingAnalyzer() {} /** * Analyzes buffered variant values to determine the optimal shredding schema. @@ -79,17 +62,13 @@ public VariantShreddingAnalyzer(double minOccurrenceThreshold, int maxFields) { * @return the shredded schema type, or null if no shredding should be performed */ public Type analyzeAndCreateSchema(List bufferedRows, int variantFieldIndex) { - if (bufferedRows.isEmpty()) { - return null; - } - List variantValues = extractVariantValues(bufferedRows, variantFieldIndex); if (variantValues.isEmpty()) { return null; } - FieldStats stats = analyzeFields(variantValues); - return buildShreddedSchema(stats, variantValues.size()); + PathNode root = buildPathTree(variantValues); + return buildTypedValue(root, root.info.getMostCommonType()); } private static List extractVariantValues( @@ -100,12 +79,12 @@ private static List extractVariantValues( if (!row.isNullAt(variantFieldIndex)) { VariantVal variantVal = row.getVariant(variantFieldIndex); if (variantVal != null) { - VariantValue variantValue = - VariantValue.from( - VariantMetadata.from( - ByteBuffer.wrap(variantVal.getMetadata()).order(ByteOrder.LITTLE_ENDIAN)), - ByteBuffer.wrap(variantVal.getValue()).order(ByteOrder.LITTLE_ENDIAN)); - values.add(variantValue); + VariantValue variantValue = + VariantValue.from( + VariantMetadata.from( + ByteBuffer.wrap(variantVal.getMetadata()).order(ByteOrder.LITTLE_ENDIAN)), + ByteBuffer.wrap(variantVal.getValue()).order(ByteOrder.LITTLE_ENDIAN)); + values.add(variantValue); } } } @@ -113,433 +92,231 @@ private static List extractVariantValues( return values; } - private static FieldStats analyzeFields(List variantValues) { - FieldStats stats = new FieldStats(); + private static PathNode buildPathTree(List variantValues) { + PathNode root = new PathNode(null); + root.info = new FieldInfo(); for (VariantValue value : variantValues) { - if (value.type() == PhysicalType.OBJECT) { - VariantObject obj = value.asObject(); - for (String fieldName : obj.fieldNames()) { - VariantValue fieldValue = obj.get(fieldName); - if (fieldValue != null) { - stats.recordField(fieldName, fieldValue); - } - } - } + traverse(root, value); } - return stats; + return root; } - private Type buildShreddedSchema(FieldStats stats, int totalRows) { - int minOccurrences = (int) Math.ceil(totalRows * minOccurrenceThreshold); - - // Get fields that meet the occurrence threshold - Set candidateFields = Sets.newTreeSet(); - for (Map.Entry entry : stats.fieldInfoMap.entrySet()) { - String fieldName = entry.getKey(); - FieldInfo info = entry.getValue(); - - if (info.occurrenceCount >= minOccurrences) { - candidateFields.add(fieldName); - } else { - LOG.debug( - "Field '{}' appears only {} times out of {} (< {}%), dropping", - fieldName, - info.occurrenceCount, - 
totalRows, - (int) (minOccurrenceThreshold * 100)); - } + private static void traverse(PathNode node, VariantValue value) { + if (value == null) { + return; } - if (candidateFields.isEmpty()) { - return null; - } + node.info.observe(value); - // Build the typed_value struct with field count limit - Types.GroupBuilder objectBuilder = Types.buildGroup(Type.Repetition.OPTIONAL); - int fieldCount = 0; - - for (String fieldName : candidateFields) { - FieldInfo info = stats.fieldInfoMap.get(fieldName); - - if (info.hasConsistentType()) { - Type shreddedFieldType = createShreddedFieldType(fieldName, info); - if (shreddedFieldType != null) { - if (fieldCount + 2 > maxFields) { - LOG.debug( - "Reached maximum field limit ({}) while processing field '{}', stopping", - maxFields, - fieldName); - break; + if (value.type() == PhysicalType.OBJECT) { + VariantObject obj = value.asObject(); + for (String fieldName : obj.fieldNames()) { + VariantValue fieldValue = obj.get(fieldName); + if (fieldValue != null) { + PathNode childNode = node.objectChildren.computeIfAbsent(fieldName, PathNode::new); + if (childNode.info == null) { + childNode.info = new FieldInfo(); } - objectBuilder.addField(shreddedFieldType); - fieldCount += 2; + traverse(childNode, fieldValue); } - } else { - Type valueOnlyField = createValueOnlyField(fieldName); - if (fieldCount + 1 > maxFields) { - LOG.debug( - "Reached maximum field limit ({}) while processing field '{}', stopping", - maxFields, - fieldName); - break; + } + } else if (value.type() == PhysicalType.ARRAY) { + VariantArray array = value.asArray(); + int numElements = array.numElements(); + if (node.arrayElement == null) { + node.arrayElement = new PathNode(null); + node.arrayElement.info = new FieldInfo(); + } + for (int i = 0; i < numElements; i++) { + VariantValue element = array.get(i); + if (element != null) { + traverse(node.arrayElement, element); } - objectBuilder.addField(valueOnlyField); - fieldCount += 1; - LOG.debug( - "Field '{}' has inconsistent types ({}), creating value-only field", - fieldName, - info.observedTypes); } } - - if (fieldCount == 0) { - return null; - } - - LOG.info("Created shredded schema with {} fields for {} candidate fields", fieldCount, candidateFields.size()); - return objectBuilder.named("typed_value"); } - private static Type createShreddedFieldType(String fieldName, FieldInfo info) { - PhysicalType physicalType = info.getConsistentType(); - if (physicalType == null) { - return null; - } + private static Type buildFieldGroup(PathNode node) { + Type typedValue = buildTypedValue(node, node.info.getMostCommonType()); + return Types.buildGroup(Type.Repetition.REQUIRED) + .optional(PrimitiveType.PrimitiveTypeName.BINARY) + .named(VALUE) + .addField(typedValue) + .named(node.fieldName); + } - // For array types, analyze the first value to determine element type + private static Type buildTypedValue(PathNode node, PhysicalType physicalType) { Type typedValue; if (physicalType == PhysicalType.ARRAY) { - typedValue = createArrayTypedValue(info); - } else if (physicalType == PhysicalType.DECIMAL4 - || physicalType == PhysicalType.DECIMAL8 - || physicalType == PhysicalType.DECIMAL16) { - // For decimals, infer precision and scale from actual values - typedValue = createDecimalTypedValue(info, physicalType); + typedValue = createArrayTypedValue(node); } else if (physicalType == PhysicalType.OBJECT) { - // For nested objects, attempt recursive shredding - typedValue = createNestedObjectTypedValue(info); + typedValue = createObjectTypedValue(node); 
} else { - // Convert the physical type to a Parquet type for typed_value - typedValue = convertPhysicalTypeToParquet(physicalType); + typedValue = createPrimitiveTypedValue(node.info, physicalType); } - if (typedValue == null) { - // If we can't create a typed_value (e.g., inconsistent decimal scales), - // create a value-only field instead of skipping the field entirely - return Types.buildGroup(Type.Repetition.REQUIRED) - .optional(PrimitiveType.PrimitiveTypeName.BINARY) - .named("value") - .named(fieldName); - } - - return Types.buildGroup(Type.Repetition.REQUIRED) - .optional(PrimitiveType.PrimitiveTypeName.BINARY) - .named("value") - .addField(typedValue) - .named(fieldName); + return typedValue; } - private static Type createDecimalTypedValue(FieldInfo info, PhysicalType decimalType) { - // Analyze decimal values to determine precision and scale - // All values must have the same scale to be considered consistent - Integer consistentScale = null; - int maxPrecision = 0; - - for (VariantValue value : info.observedValues) { - if (value.type() == decimalType) { - try { - VariantPrimitive primitive = value.asPrimitive(); - Object decimalValue = primitive.get(); - if (decimalValue instanceof BigDecimal) { - BigDecimal bd = (BigDecimal) decimalValue; - int precision = bd.precision(); - int scale = bd.scale(); - - // Check scale consistency - if (consistentScale == null) { - consistentScale = scale; - } else if (consistentScale != scale) { - // Different scales mean inconsistent types - no typed_value - LOG.debug( - "Decimal values have inconsistent scales ({} vs {}), skipping typed_value", - consistentScale, - scale); - return null; - } - - maxPrecision = Math.max(maxPrecision, precision); - } - } catch (Exception e) { - LOG.debug("Failed to analyze decimal value", e); - } - } - } - - if (maxPrecision == 0 || consistentScale == null) { - LOG.debug("Could not determine decimal precision/scale, skipping typed_value"); + private static Type createObjectTypedValue(PathNode node) { + if (node.objectChildren.isEmpty()) { return null; } - // Determine the appropriate Parquet type based on precision - PrimitiveType.PrimitiveTypeName primitiveType; - if (maxPrecision <= 9) { - primitiveType = PrimitiveType.PrimitiveTypeName.INT32; - } else if (maxPrecision <= 18) { - primitiveType = PrimitiveType.PrimitiveTypeName.INT64; - } else { - primitiveType = PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY; + Types.GroupBuilder builder = Types.buildGroup(Type.Repetition.OPTIONAL); + for (PathNode child : node.objectChildren.values()) { + builder.addField(buildFieldGroup(child)); } - return Types.optional(primitiveType) - .as(LogicalTypeAnnotation.decimalType(consistentScale, maxPrecision)) - .named("typed_value"); + return builder.named(TYPED_VALUE); } - private static Type createNestedObjectTypedValue(FieldInfo info) { - // For nested objects, we can recursively analyze their fields - // For now, we'll create a simpler representation - // A full implementation would recursively build the object structure - - // Get a sample object to analyze its fields - for (VariantValue value : info.observedValues) { - if (value.type() == PhysicalType.OBJECT) { - try { - VariantObject obj = value.asObject(); - int numFields = obj.numFields(); - - // Only shred simple nested objects (not too many fields) - if (numFields > 0 && numFields <= 20) { - // Analyze fields in the nested object - Map> nestedFieldTypes = Maps.newHashMap(); - - for (String fieldName : obj.fieldNames()) { - VariantValue fieldValue = 
obj.get(fieldName); - if (fieldValue != null) { - nestedFieldTypes - .computeIfAbsent(fieldName, k -> Sets.newHashSet()) - .add(fieldValue.type()); - } - } - - // Build nested struct with fields that have consistent types - Types.GroupBuilder nestedBuilder = - Types.buildGroup(Type.Repetition.OPTIONAL); - int fieldCount = 0; - - for (Map.Entry> entry : nestedFieldTypes.entrySet()) { - String fieldName = entry.getKey(); - Set types = entry.getValue(); - - // Only include fields with consistent types - if (types.size() == 1) { - PhysicalType fieldType = types.iterator().next(); - Type fieldParquetType = convertPhysicalTypeToParquet(fieldType); - if (fieldParquetType != null) { - GroupType nestedField = - Types.buildGroup(Type.Repetition.REQUIRED) - .optional(PrimitiveType.PrimitiveTypeName.BINARY) - .named("value") - .addField(fieldParquetType) - .named(fieldName); - nestedBuilder.addField(nestedField); - fieldCount++; - } - } - } - - if (fieldCount > 0) { - return nestedBuilder.named("typed_value"); - } - } - } catch (Exception e) { - LOG.debug("Failed to analyze nested object", e); - } - break; - } - } + private static Type createArrayTypedValue(PathNode node) { + PathNode elementNode = node.arrayElement; + PhysicalType elementType = elementNode.info.getMostCommonType(); + Type elementTypedValue = buildTypedValue(elementNode, elementType); - LOG.debug("Skipping nested object - complex structure or analysis failed"); - return null; + GroupType elementGroup = + Types.buildGroup(Type.Repetition.REQUIRED) + .optional(PrimitiveType.PrimitiveTypeName.BINARY) + .named(VALUE) + .addField(elementTypedValue) + .named(ELEMENT); + + return Types.optionalList().element(elementGroup).named(TYPED_VALUE); } - private static Type createArrayTypedValue(FieldInfo info) { - // Get a sample array value to analyze element types - for (VariantValue value : info.observedValues) { - if (value.type() == PhysicalType.ARRAY) { - try { - VariantArray array = value.asArray(); - int numElements = array.numElements(); - if (numElements > 0) { - // Analyze elements to determine if they have consistent type - Set elementTypes = Sets.newHashSet(); - for (int i = 0; i < numElements; i++) { - elementTypes.add(array.get(i).type()); - } - - // If all elements have consistent type, create typed array - if (elementTypes.size() == 1 - || (elementTypes.size() == 2 - && elementTypes.contains(PhysicalType.BOOLEAN_TRUE) - && elementTypes.contains(PhysicalType.BOOLEAN_FALSE))) { - PhysicalType elementType = elementTypes.iterator().next(); - if (elementType == PhysicalType.BOOLEAN_FALSE - || elementType == PhysicalType.BOOLEAN_TRUE) { - elementType = PhysicalType.BOOLEAN_TRUE; - } - Type elementParquetType = convertPhysicalTypeToParquet(elementType); - if (elementParquetType != null) { - // Create list with typed element - GroupType element = - Types.buildGroup(Type.Repetition.REQUIRED) - .optional(PrimitiveType.PrimitiveTypeName.BINARY) - .named("value") - .addField(elementParquetType) - .named("element"); - return Types.optionalList().element(element).named("typed_value"); - } - } - } - } catch (Exception e) { - LOG.debug("Failed to analyze array elements", e); - } - break; - } + private static class PathNode { + private final String fieldName; + private final Map objectChildren = Maps.newTreeMap(); + private PathNode arrayElement = null; + private FieldInfo info = null; + + private PathNode(String fieldName) { + this.fieldName = fieldName; } - return null; } - private static Type createValueOnlyField(String fieldName) { - // Create a 
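For reference, a small sketch (assuming parquet-column on the classpath) of the value/typed_value group shape that the builders above emit for a single string field; the field name is arbitrary.

import org.apache.parquet.schema.LogicalTypeAnnotation;
import org.apache.parquet.schema.PrimitiveType;
import org.apache.parquet.schema.Type;
import org.apache.parquet.schema.Types;

public class ShreddedFieldShape {
  public static void main(String[] args) {
    // typed_value carries the shredded representation when the type matches
    Type typedValue =
        Types.optional(PrimitiveType.PrimitiveTypeName.BINARY)
            .as(LogicalTypeAnnotation.stringType())
            .named("typed_value");
    // value is the binary fallback for rows that do not fit the shredded type
    Type field =
        Types.buildGroup(Type.Repetition.REQUIRED)
            .optional(PrimitiveType.PrimitiveTypeName.BINARY)
            .named("value")
            .addField(typedValue)
            .named("name");
    // prints the required group with its optional value and typed_value children
    System.out.println(field);
  }
}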
field with only the value field (no typed_value) - return Types.buildGroup(Type.Repetition.REQUIRED) - .optional(PrimitiveType.PrimitiveTypeName.BINARY) - .named("value") - .named(fieldName); + /** Use DECIMAL with maximum precision and scale as the shredding type */ + private static Type createDecimalTypedValue(FieldInfo info) { + int maxPrecision = info.maxDecimalPrecision; + int maxScale = info.maxDecimalScale; + + if (maxPrecision <= 9) { + return Types.optional(PrimitiveType.PrimitiveTypeName.INT32) + .as(LogicalTypeAnnotation.decimalType(maxScale, maxPrecision)) + .named(TYPED_VALUE); + } else if (maxPrecision <= 18) { + return Types.optional(PrimitiveType.PrimitiveTypeName.INT64) + .as(LogicalTypeAnnotation.decimalType(maxScale, maxPrecision)) + .named(TYPED_VALUE); + } else { + return Types.optional(PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY) + .length(16) + .as(LogicalTypeAnnotation.decimalType(maxScale, maxPrecision)) + .named(TYPED_VALUE); + } } - private static Type convertPhysicalTypeToParquet(PhysicalType physicalType) { - switch (physicalType) { + private static Type createPrimitiveTypedValue(FieldInfo info, PhysicalType primitiveType) { + switch (primitiveType) { case BOOLEAN_TRUE: case BOOLEAN_FALSE: - return Types.optional(PrimitiveType.PrimitiveTypeName.BOOLEAN).named("typed_value"); + return Types.optional(PrimitiveType.PrimitiveTypeName.BOOLEAN).named(TYPED_VALUE); case INT8: return Types.optional(PrimitiveType.PrimitiveTypeName.INT32) .as(LogicalTypeAnnotation.intType(8, true)) - .named("typed_value"); + .named(TYPED_VALUE); case INT16: return Types.optional(PrimitiveType.PrimitiveTypeName.INT32) .as(LogicalTypeAnnotation.intType(16, true)) - .named("typed_value"); + .named(TYPED_VALUE); case INT32: return Types.optional(PrimitiveType.PrimitiveTypeName.INT32) .as(LogicalTypeAnnotation.intType(32, true)) - .named("typed_value"); + .named(TYPED_VALUE); case INT64: - return Types.optional(PrimitiveType.PrimitiveTypeName.INT64).named("typed_value"); + return Types.optional(PrimitiveType.PrimitiveTypeName.INT64).named(TYPED_VALUE); case FLOAT: - return Types.optional(PrimitiveType.PrimitiveTypeName.FLOAT).named("typed_value"); + return Types.optional(PrimitiveType.PrimitiveTypeName.FLOAT).named(TYPED_VALUE); case DOUBLE: - return Types.optional(PrimitiveType.PrimitiveTypeName.DOUBLE).named("typed_value"); + return Types.optional(PrimitiveType.PrimitiveTypeName.DOUBLE).named(TYPED_VALUE); case STRING: return Types.optional(PrimitiveType.PrimitiveTypeName.BINARY) .as(LogicalTypeAnnotation.stringType()) - .named("typed_value"); + .named(TYPED_VALUE); case BINARY: - return Types.optional(PrimitiveType.PrimitiveTypeName.BINARY).named("typed_value"); + return Types.optional(PrimitiveType.PrimitiveTypeName.BINARY).named(TYPED_VALUE); case DATE: return Types.optional(PrimitiveType.PrimitiveTypeName.INT32) .as(LogicalTypeAnnotation.dateType()) - .named("typed_value"); + .named(TYPED_VALUE); case TIMESTAMPTZ: return Types.optional(PrimitiveType.PrimitiveTypeName.INT64) .as(LogicalTypeAnnotation.timestampType(true, LogicalTypeAnnotation.TimeUnit.MICROS)) - .named("typed_value"); + .named(TYPED_VALUE); case TIMESTAMPNTZ: return Types.optional(PrimitiveType.PrimitiveTypeName.INT64) .as(LogicalTypeAnnotation.timestampType(false, LogicalTypeAnnotation.TimeUnit.MICROS)) - .named("typed_value"); + .named(TYPED_VALUE); case DECIMAL4: case DECIMAL8: case DECIMAL16: - // Decimals are now handled in createDecimalTypedValue() - // This case should not be reached for consistent decimal 
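A small sketch of how the maximum precision and scale feeding createDecimalTypedValue are derived from BigDecimal values, and how precision selects the physical storage; the thresholds mirror Parquet's DECIMAL encoding rules.

import java.math.BigDecimal;

public class DecimalStorageChoice {
  static String storageFor(int precision) {
    if (precision <= 9) {
      return "INT32"; // DECIMAL backed by a 4-byte int
    } else if (precision <= 18) {
      return "INT64"; // DECIMAL backed by an 8-byte int
    } else {
      return "FIXED_LEN_BYTE_ARRAY(16)"; // holds up to 38 digits
    }
  }

  public static void main(String[] args) {
    BigDecimal first = new BigDecimal("123.456789"); // precision 9, scale 6
    BigDecimal second = new BigDecimal("678.90");    // precision 5, scale 2
    int maxPrecision = Math.max(first.precision(), second.precision()); // 9
    int maxScale = Math.max(first.scale(), second.scale());             // 6
    System.out.printf("DECIMAL(%d,%d) -> %s%n", maxPrecision, maxScale, storageFor(maxPrecision));
    // DECIMAL(9,6) -> INT32
  }
}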
types - LOG.debug("Decimal type {} should be handled by createDecimalTypedValue()", physicalType); - return null; - - case ARRAY: - // Arrays are now handled in createArrayTypedValue() - LOG.debug("Array type should be handled by createArrayTypedValue()"); - return null; - - case OBJECT: - // Nested objects are now handled in createNestedObjectTypedValue() - LOG.debug("Object type should be handled by createNestedObjectTypedValue()"); - return null; + return createDecimalTypedValue(info); default: - LOG.debug("Unknown physical type: {}", physicalType); - return null; - } - } - - /** Tracks statistics about fields across multiple variant values. */ - private static class FieldStats { - private final Map fieldInfoMap = Maps.newHashMap(); - - void recordField(String fieldName, VariantValue value) { - FieldInfo info = fieldInfoMap.computeIfAbsent(fieldName, k -> new FieldInfo()); - info.observe(value); + throw new UnsupportedOperationException( + "Unknown primitive physical type: " + primitiveType); } } - /** Tracks occurrence count and type consistency for a single field. */ + /** Tracks occurrence count and types for a single field. */ private static class FieldInfo { - private int occurrenceCount = 0; private final Set observedTypes = Sets.newHashSet(); - private final List observedValues = new java.util.ArrayList<>(); + private final Map typeCounts = Maps.newHashMap(); + private int maxDecimalPrecision = 0; + private int maxDecimalScale = 0; void observe(VariantValue value) { - occurrenceCount++; - observedTypes.add(value.type()); - observedValues.add(value); - } - - boolean hasConsistentType() { - // Handle boolean types specially - both TRUE and FALSE map to BOOLEAN - if (observedTypes.size() == 2 - && observedTypes.contains(PhysicalType.BOOLEAN_TRUE) - && observedTypes.contains(PhysicalType.BOOLEAN_FALSE)) { - return true; + // Use BOOLEAN_TRUE for both TRUE/FALSE values + PhysicalType type = + value.type() == PhysicalType.BOOLEAN_FALSE ? PhysicalType.BOOLEAN_TRUE : value.type(); + observedTypes.add(type); + typeCounts.compute(type, (k, v) -> (v == null) ? 
1 : v + 1); + + // Track max precision and scale for decimal types + if (type == PhysicalType.DECIMAL4 + || type == PhysicalType.DECIMAL8 + || type == PhysicalType.DECIMAL16) { + VariantPrimitive primitive = value.asPrimitive(); + Object decimalValue = primitive.get(); + if (decimalValue instanceof BigDecimal) { + BigDecimal bd = (BigDecimal) decimalValue; + maxDecimalPrecision = Math.max(maxDecimalPrecision, bd.precision()); + maxDecimalScale = Math.max(maxDecimalScale, bd.scale()); + } } - return observedTypes.size() == 1; } - PhysicalType getConsistentType() { - if (!hasConsistentType()) { - return null; - } - - // Handle boolean types - if (observedTypes.contains(PhysicalType.BOOLEAN_TRUE) - || observedTypes.contains(PhysicalType.BOOLEAN_FALSE)) { - return PhysicalType.BOOLEAN_TRUE; // Use TRUE as canonical boolean type - } - - return observedTypes.iterator().next(); + PhysicalType getMostCommonType() { + return typeCounts.entrySet().stream() + .max(Map.Entry.comparingByValue()) + .map(Map.Entry::getKey) + .orElse(null); } } } - diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java index 083242c6b743..5f4eb2a2732f 100644 --- a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java @@ -19,11 +19,11 @@ package org.apache.iceberg.spark.variant; import static org.apache.hadoop.hive.conf.HiveConf.ConfVars.METASTOREURIS; +import static org.apache.parquet.schema.Types.optional; import static org.assertj.core.api.Assertions.assertThat; import java.io.IOException; import java.net.InetAddress; -import java.util.List; import java.util.Map; import org.apache.hadoop.conf.Configuration; import org.apache.iceberg.FileScanTask; @@ -112,140 +112,8 @@ public void after() { validationCatalog.dropTable(tableIdent, true); } - @TestTemplate - public void testVariantShreddingWrite() throws IOException { - spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); - String values = - "(1, parse_json('{\"name\": \"Joe\", \"streets\": [\"Apt #3\", \"1234 Ave\"], \"zip\": 10001}')), (2, null)"; - sql("INSERT INTO %s VALUES %s", tableName, values); - - GroupType name = - field( - "name", - shreddedPrimitive( - PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); - GroupType streets = - field( - "streets", - list( - element( - shreddedPrimitive( - PrimitiveType.PrimitiveTypeName.BINARY, - LogicalTypeAnnotation.stringType())))); - GroupType zip = - field( - "zip", - shreddedPrimitive( - PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(16))); - GroupType address = variant("address", 2, Type.Repetition.OPTIONAL, objectFields(name, streets, zip)); - MessageType expectedSchema = parquetSchema(address); - - Table table = validationCatalog.loadTable(tableIdent); - verifyParquetSchema(table, expectedSchema); - } - - @TestTemplate - public void testVariantShreddingWithNullFirstRow() throws IOException { - spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); - - String values = "(1, null), (2, parse_json('{\"city\": \"Seattle\", \"state\": \"WA\"}'))"; - sql("INSERT INTO %s VALUES %s", tableName, values); - - GroupType city = - field( - "city", - shreddedPrimitive( - PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); - GroupType state = - field( - "state", - shreddedPrimitive( 
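The FieldInfo bookkeeping above reduces to a frequency map over observed types; a minimal sketch, with a stand-in enum in place of the variant PhysicalType. Note that a plain max() leaves ties to map iteration order, which a later patch in this series addresses with an explicit tie break.

import java.util.HashMap;
import java.util.Map;

public class MostCommonTypeSketch {
  enum PhysicalType { BOOLEAN_TRUE, INT8, STRING } // stand-in for the variant PhysicalType

  public static void main(String[] args) {
    Map<PhysicalType, Integer> typeCounts = new HashMap<>();
    // Simulate observing "25" (STRING), 30 (INT8), "35" (STRING) for one field
    PhysicalType[] observed = {PhysicalType.STRING, PhysicalType.INT8, PhysicalType.STRING};
    for (PhysicalType type : observed) {
      typeCounts.merge(type, 1, Integer::sum);
    }
    PhysicalType mostCommon =
        typeCounts.entrySet().stream()
            .max(Map.Entry.comparingByValue())
            .map(Map.Entry::getKey)
            .orElse(null);
    System.out.println(mostCommon); // STRING: 2 of 3 observations, so shred as string
  }
}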
- PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); - GroupType address = variant("address", 2, Type.Repetition.OPTIONAL, objectFields(city, state)); - MessageType expectedSchema = parquetSchema(address); - - Table table = validationCatalog.loadTable(tableIdent); - verifyParquetSchema(table, expectedSchema); - } - - @TestTemplate - public void testVariantShreddingWithTwoVariantColumns() throws IOException { - validationCatalog.dropTable(tableIdent, true); - validationCatalog.createTable( - tableIdent, SCHEMA2, null, Map.of(TableProperties.FORMAT_VERSION, "3")); - - spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); - - String values = - "(1, parse_json('{\"city\": \"NYC\", \"zip\": 10001}'), parse_json('{\"type\": \"home\", \"verified\": true}')), " - + "(2, null, null)"; - sql("INSERT INTO %s VALUES %s", tableName, values); - - GroupType city = - field( - "city", - shreddedPrimitive( - PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); - GroupType zip = - field( - "zip", - shreddedPrimitive( - PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(16, true))); - GroupType address = variant("address", 2, Type.Repetition.OPTIONAL, objectFields(city, zip)); - - GroupType type = - field( - "type", - shreddedPrimitive( - PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); - GroupType verified = - field("verified", shreddedPrimitive(PrimitiveType.PrimitiveTypeName.BOOLEAN)); - GroupType metadata = variant("metadata", 3, Type.Repetition.OPTIONAL, objectFields(type, verified)); - - MessageType expectedSchema = parquetSchema(address, metadata); - - Table table = validationCatalog.loadTable(tableIdent); - verifyParquetSchema(table, expectedSchema); - } - - @TestTemplate - public void testVariantShreddingWithTwoVariantColumnsOneNull() throws IOException { - validationCatalog.dropTable(tableIdent, true); - validationCatalog.createTable( - tableIdent, SCHEMA2, null, Map.of(TableProperties.FORMAT_VERSION, "3")); - - spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); - - // First row: address is null, metadata has value - // Second row: address has value, metadata is null - String values = - "(1, null, parse_json('{\"label\": \"primary\"}'))," - + " (2, parse_json('{\"street\": \"Main St\"}'), null)"; - sql("INSERT INTO %s VALUES %s", tableName, values); - - GroupType street = - field( - "street", - shreddedPrimitive( - PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); - GroupType address = variant("address", 2, Type.Repetition.OPTIONAL, objectFields(street)); - - GroupType label = - field( - "label", - shreddedPrimitive( - PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); - GroupType metadata = variant("metadata", 3, Type.Repetition.OPTIONAL, objectFields(label)); - - MessageType expectedSchema = parquetSchema(address, metadata); - - Table table = validationCatalog.loadTable(tableIdent); - verifyParquetSchema(table, expectedSchema); - } - @TestTemplate public void testVariantShreddingDisabled() throws IOException { - // Test with shredding explicitly disabled spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "false"); String values = "(1, parse_json('{\"city\": \"NYC\", \"zip\": 10001}')), (2, null)"; @@ -259,7 +127,7 @@ public void testVariantShreddingDisabled() throws IOException { } @TestTemplate - public void testConsistentTypeCreatesTypedValue() throws IOException { + public void testConsistentType() throws IOException { 
spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); String values = @@ -274,7 +142,10 @@ public void testConsistentTypeCreatesTypedValue() throws IOException { shreddedPrimitive( PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); GroupType age = - field("age", shreddedPrimitive(PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true))); + field( + "age", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true))); GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(age, name)); MessageType expectedSchema = parquetSchema(address); @@ -282,25 +153,21 @@ public void testConsistentTypeCreatesTypedValue() throws IOException { verifyParquetSchema(table, expectedSchema); } - /** - * Test Heuristic 2: Inconsistent Type → Value Only - * - *
When a field appears with different types across rows, only the "value" field should be - * created (no "typed_value"). - */ @TestTemplate - public void testInconsistentTypeCreatesValueOnly() throws IOException { + public void testInconsistentType() throws IOException { spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); - // "age" appears as both string and int - inconsistent type String values = "(1, parse_json('{\"age\": \"25\"}'))," + " (2, parse_json('{\"age\": 30}'))," + " (3, parse_json('{\"age\": \"35\"}'))"; sql("INSERT INTO %s VALUES %s", tableName, values); - // "age" should have only "value" field, no "typed_value" - GroupType age = valueOnlyField("age"); + GroupType age = + field( + "age", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(age)); MessageType expectedSchema = parquetSchema(address); @@ -308,172 +175,80 @@ public void testInconsistentTypeCreatesValueOnly() throws IOException { verifyParquetSchema(table, expectedSchema); } - /** - * Test Heuristic 3: Rare Fields Are Dropped - * - *
Fields that appear in less than the configured threshold percentage of rows should be - * dropped from the shredded schema. - */ - @TestTemplate - public void testRareFieldIsDropped() throws IOException { - spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); - // Set threshold to 20% (0.2) - spark.conf().set(SparkSQLProperties.VARIANT_MIN_OCCURRENCE_THRESHOLD, "0.2"); - - // "common" appears in all 10 rows (100%), "rare" appears in 1 row (10%) - String values = - "(1, parse_json('{\"common\": 1, \"rare\": 100}'))," - + " (2, parse_json('{\"common\": 2}'))," - + " (3, parse_json('{\"common\": 3}'))," - + " (4, parse_json('{\"common\": 4}'))," - + " (5, parse_json('{\"common\": 5}'))," - + " (6, parse_json('{\"common\": 6}'))," - + " (7, parse_json('{\"common\": 7}'))," - + " (8, parse_json('{\"common\": 8}'))," - + " (9, parse_json('{\"common\": 9}'))," - + " (10, parse_json('{\"common\": 10}'))"; - sql("INSERT INTO %s VALUES %s", tableName, values); - - // Only "common" should be present (appears in 100% of rows) - // "rare" should be dropped (appears in only 10% of rows, below 20% threshold) - GroupType common = - field("common", shreddedPrimitive(PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true))); - GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(common)); - MessageType expectedSchema = parquetSchema(address); - - Table table = validationCatalog.loadTable(tableIdent); - verifyParquetSchema(table, expectedSchema); - - // Reset threshold to default to avoid interference with other tests - spark.conf().unset(SparkSQLProperties.VARIANT_MIN_OCCURRENCE_THRESHOLD); - } - - /** - * Test Heuristic 4: Boolean Type Handling - * - *
Both "true" and "false" values should be treated as the same consistent boolean type, and a - * typed_value field should be created. - */ @TestTemplate - public void testBooleanTypeHandling() throws IOException { + public void testPrimitiveType() throws IOException { spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); - // "active" field has both true and false values - should be treated as consistent boolean - String values = - "(1, parse_json('{\"active\": true}'))," - + " (2, parse_json('{\"active\": false}'))," - + " (3, parse_json('{\"active\": true}'))"; + String values = "(1, parse_json('123')), (2, parse_json('\"abc\"')), (3, parse_json('12'))"; sql("INSERT INTO %s VALUES %s", tableName, values); - // "active" should have typed_value with boolean type - GroupType active = field("active", shreddedPrimitive(PrimitiveType.PrimitiveTypeName.BOOLEAN)); - GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(active)); + GroupType address = + variant( + "address", + 2, + Type.Repetition.REQUIRED, + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true))); MessageType expectedSchema = parquetSchema(address); Table table = validationCatalog.loadTable(tableIdent); verifyParquetSchema(table, expectedSchema); - - // Reset field limit to default to avoid interference from previous tests - spark.conf().unset(SparkSQLProperties.VARIANT_MAX_SHREDDED_FIELDS); } - /** - * Test Heuristic 5: Mixed Fields (Consistent and Inconsistent) - * - *
Tests a realistic scenario with multiple fields where some have consistent types and others - * don't. - */ @TestTemplate - public void testMixedFieldsConsistentAndInconsistent() throws IOException { + public void testPrimitiveDecimalType() throws IOException { spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); - // "name": always string (consistent) - // "age": mixed int/string (inconsistent) - // "active": boolean (consistent) String values = - "(1, parse_json('{\"name\": \"Alice\", \"age\": 30, \"active\": true}'))," - + " (2, parse_json('{\"name\": \"Bob\", \"age\": \"25\", \"active\": false}'))," - + " (3, parse_json('{\"name\": \"Charlie\", \"age\": \"35\", \"active\": true}'))"; + "(1, parse_json('123.56')), (2, parse_json('\"abc\"')), (3, parse_json('12.56'))"; sql("INSERT INTO %s VALUES %s", tableName, values); - // "name" should have typed_value (consistent string) - GroupType name = - field( - "name", + GroupType address = + variant( + "address", + 2, + Type.Repetition.REQUIRED, shreddedPrimitive( - PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); - - // "age" should NOT have typed_value (inconsistent types) - GroupType age = valueOnlyField("age"); - - // "active" should have typed_value (consistent boolean) - GroupType active = field("active", shreddedPrimitive(PrimitiveType.PrimitiveTypeName.BOOLEAN)); - - GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(active, age, name)); + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.decimalType(2, 5))); MessageType expectedSchema = parquetSchema(address); Table table = validationCatalog.loadTable(tableIdent); verifyParquetSchema(table, expectedSchema); } - /** - * Test Heuristic 6: Field Limit Enforcement - * - *
Verify that the analyzer respects the maximum field limit and stops adding fields once the - * limit is reached. - */ @TestTemplate - public void testMaxFieldLimitEnforcement() throws IOException { + public void testBooleanType() throws IOException { spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); - // Set very low field limit - spark.conf().set(SparkSQLProperties.VARIANT_MAX_SHREDDED_FIELDS, "4"); - // Create rows with many fields (a, b, c, d, e, f) String values = - "(1, parse_json('{\"a\": 1, \"b\": 2, \"c\": 3, \"d\": 4, \"e\": 5, \"f\": 6}'))," - + " (2, parse_json('{\"a\": 1, \"b\": 2, \"c\": 3, \"d\": 4, \"e\": 5, \"f\": 6}'))"; + "(1, parse_json('{\"active\": true}'))," + + " (2, parse_json('{\"active\": false}'))," + + " (3, parse_json('{\"active\": true}'))"; sql("INSERT INTO %s VALUES %s", tableName, values); - // With limit 4: field "a" (2 fields: value + typed_value) + field "b" (2 fields) = 4 total - // Fields are added alphabetically, so only "a" and "b" should be present - GroupType a = - field("a", shreddedPrimitive(PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true))); - GroupType b = - field("b", shreddedPrimitive(PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true))); - - GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(a, b)); + GroupType active = field("active", shreddedPrimitive(PrimitiveType.PrimitiveTypeName.BOOLEAN)); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(active)); MessageType expectedSchema = parquetSchema(address); Table table = validationCatalog.loadTable(tableIdent); verifyParquetSchema(table, expectedSchema); - - // Reset field limit to default to avoid interference from previous tests - spark.conf().unset(SparkSQLProperties.VARIANT_MAX_SHREDDED_FIELDS); } - /** - * Test Heuristic 7: Decimal Type Handling - Inconsistent Scales - * - *
Verify that decimal fields with different scales are treated as inconsistent types - * and only get a value field (no typed_value). - */ @TestTemplate - public void testDecimalTypeHandlingInconsistentScales() throws IOException { + public void testDecimalTypeWithInconsistentScales() throws IOException { spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); - // Decimal values with different scales: scale 6, 2, 2 - // 123.456789 → precision 9, scale 6 - // 678.90 → precision 5, scale 2 - // 999.99 → precision 5, scale 2 - // These are treated as inconsistent types due to different scales String values = "(1, parse_json('{\"price\": 123.456789}'))," + " (2, parse_json('{\"price\": 678.90}'))," + " (3, parse_json('{\"price\": 999.99}'))"; sql("INSERT INTO %s VALUES %s", tableName, values); - // "price" has inconsistent scales, so only "value" field (no typed_value) - GroupType price = valueOnlyField("price"); + GroupType price = + field( + "price", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.decimalType(6, 9))); GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(price)); MessageType expectedSchema = parquetSchema(address); @@ -481,30 +256,21 @@ public void testDecimalTypeHandlingInconsistentScales() throws IOException { verifyParquetSchema(table, expectedSchema); } - /** - * Test Heuristic 7b: Decimal Type Handling - Consistent Scales - * - *
Verify that decimal fields with the same scale get proper typed_value with inferred - * precision/scale. - */ @TestTemplate - public void testDecimalTypeHandlingConsistentScales() throws IOException { + public void testDecimalTypeWithConsistentScales() throws IOException { spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); - // Decimal values with consistent scale (all 2 decimal places) String values = "(1, parse_json('{\"price\": 123.45}'))," + " (2, parse_json('{\"price\": 678.90}'))," + " (3, parse_json('{\"price\": 999.99}'))"; sql("INSERT INTO %s VALUES %s", tableName, values); - // "price" should have typed_value with inferred DECIMAL(5,2) type GroupType price = field( "price", - org.apache.parquet.schema.Types.optional(PrimitiveType.PrimitiveTypeName.INT32) - .as(LogicalTypeAnnotation.decimalType(2, 5)) - .named("typed_value")); + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.decimalType(2, 5))); GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(price)); MessageType expectedSchema = parquetSchema(address); @@ -512,68 +278,38 @@ public void testDecimalTypeHandlingConsistentScales() throws IOException { verifyParquetSchema(table, expectedSchema); } - /** - * Test Heuristic 7c: Decimal Type Handling - Inconsistent After Buffering - * - *
Verify that when buffered rows have consistent decimal scales but subsequent unbuffered rows - * have inconsistent scales, the inconsistent values are written to the value field only. - * The schema is inferred from buffered rows and should include typed_value for the consistent type. - */ @TestTemplate - public void testDecimalTypeHandlingInconsistentAfterBuffering() throws IOException { + public void testArrayType() throws IOException { spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); - // Set a small buffer size to test the scenario - spark.conf().set(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, "3"); - // First 3 rows (buffered): consistent scale (2 decimal places) - // 4th row onwards (unbuffered): different scale (6 decimal places) - // Schema should be inferred from buffered rows with DECIMAL(5,2) - // The unbuffered row with different scale should still write successfully to value field String values = - "(1, parse_json('{\"price\": 123.45}'))," - + " (2, parse_json('{\"price\": 678.90}'))," - + " (3, parse_json('{\"price\": 999.99}'))," - + " (4, parse_json('{\"price\": 111.111111}'))"; // Different scale - should write to value only + "(1, parse_json('[\"java\", \"scala\", \"python\"]'))," + + " (2, parse_json('[\"rust\", \"go\"]'))," + + " (3, parse_json('[\"javascript\"]'))"; sql("INSERT INTO %s VALUES %s", tableName, values); - // Schema should have typed_value with DECIMAL(5,2) based on buffered rows - GroupType price = - field( - "price", - org.apache.parquet.schema.Types.optional(PrimitiveType.PrimitiveTypeName.INT32) - .as(LogicalTypeAnnotation.decimalType(2, 5)) - .named("typed_value")); - GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(price)); + GroupType arr = + list( + element( + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType()))); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, arr); MessageType expectedSchema = parquetSchema(address); Table table = validationCatalog.loadTable(tableIdent); verifyParquetSchema(table, expectedSchema); - - // Verify all rows were written successfully - List result = sql("SELECT id, address FROM %s ORDER BY id", tableName); - assertThat(result).hasSize(4); - - // Reset buffer size to default - spark.conf().unset(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE); } - /** - * Test Heuristic 8: Array Type Handling - * - *
Verify that array fields with consistent element types get proper typed_value. - */ @TestTemplate - public void testArrayTypeHandling() throws IOException { + public void testNestedArrayType() throws IOException { spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); - // Arrays with consistent element types (all strings) String values = "(1, parse_json('{\"tags\": [\"java\", \"scala\", \"python\"]}'))," + " (2, parse_json('{\"tags\": [\"rust\", \"go\"]}'))," + " (3, parse_json('{\"tags\": [\"javascript\"]}'))"; sql("INSERT INTO %s VALUES %s", tableName, values); - // "tags" should have typed_value with list of strings GroupType tags = field( "tags", @@ -589,47 +325,44 @@ public void testArrayTypeHandling() throws IOException { verifyParquetSchema(table, expectedSchema); } - /** - * Test Heuristic 9: Nested Object Handling - * - *
Verify that simple nested objects are recursively shredded. - */ @TestTemplate - public void testNestedObjectHandling() throws IOException { + public void testNestedObjectType() throws IOException { spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); - // Nested objects with consistent structure String values = - "(1, parse_json('{\"location\": {\"city\": \"Seattle\", \"zip\": 98101}}'))," + "(1, parse_json('{\"location\": {\"city\": \"Seattle\", \"zip\": 98101}, \"tags\": [\"java\", \"scala\", \"python\"]}'))," + " (2, parse_json('{\"location\": {\"city\": \"Portland\", \"zip\": 97201}}'))," + " (3, parse_json('{\"location\": {\"city\": \"NYC\", \"zip\": 10001}}'))"; sql("INSERT INTO %s VALUES %s", tableName, values); - // Nested "location" object should be shredded with its fields GroupType city = field( "city", shreddedPrimitive( PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); GroupType zip = - field("zip", shreddedPrimitive(PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(32, true))); - GroupType location = field("location", objectFields(zip, city)); + field( + "zip", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(32, true))); + GroupType location = field("location", objectFields(city, zip)); + GroupType tags = + field( + "tags", + list( + element( + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, + LogicalTypeAnnotation.stringType())))); - GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(location)); + GroupType address = + variant("address", 2, Type.Repetition.REQUIRED, objectFields(location, tags)); MessageType expectedSchema = parquetSchema(address); Table table = validationCatalog.loadTable(tableIdent); verifyParquetSchema(table, expectedSchema); } - /** Helper method to create a value-only field (no typed_value) for inconsistent types. */ - private static GroupType valueOnlyField(String name) { - return org.apache.parquet.schema.Types.buildGroup(Type.Repetition.REQUIRED) - .optional(PrimitiveType.PrimitiveTypeName.BINARY) - .named("value") - .named(name); - } - private void verifyParquetSchema(Table table, MessageType expectedSchema) throws IOException { try (CloseableIterable tasks = table.newScan().planFiles()) { assertThat(tasks).isNotEmpty(); @@ -644,6 +377,9 @@ private void verifyParquetSchema(Table table, MessageType expectedSchema) throws MessageType actualSchema = reader.getFileMetaData().getSchema(); assertThat(actualSchema).isEqualTo(expectedSchema); } + + // Print the result + spark.read().format("iceberg").load(tableName).orderBy("id").show(false); } } @@ -682,12 +418,12 @@ private static GroupType variant( } private static Type shreddedPrimitive(PrimitiveType.PrimitiveTypeName primitive) { - return org.apache.parquet.schema.Types.optional(primitive).named("typed_value"); + return optional(primitive).named("typed_value"); } private static Type shreddedPrimitive( PrimitiveType.PrimitiveTypeName primitive, LogicalTypeAnnotation annotation) { - return org.apache.parquet.schema.Types.optional(primitive).as(annotation).named("typed_value"); + return optional(primitive).as(annotation).named("typed_value"); } private static GroupType objectFields(GroupType... 
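The verifyParquetSchema helper above reads the footer of each written data file; a minimal standalone sketch of the same check, assuming parquet-hadoop on the classpath and a local file path passed as the first argument.

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.parquet.hadoop.ParquetFileReader;
import org.apache.parquet.hadoop.util.HadoopInputFile;
import org.apache.parquet.schema.MessageType;

public class InspectParquetSchema {
  public static void main(String[] args) throws Exception {
    Path path = new Path(args[0]); // e.g. a data file written by the tests above
    try (ParquetFileReader reader =
        ParquetFileReader.open(HadoopInputFile.fromPath(path, new Configuration()))) {
      MessageType schema = reader.getFileMetaData().getSchema();
      System.out.println(schema); // shredded variants show value/typed_value groups
    }
  }
}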
fields) { From bb55257ba6f6b585609b610de4b55614ef0e2323 Mon Sep 17 00:00:00 2001 From: Aihua Xu Date: Thu, 15 Jan 2026 11:23:33 -0800 Subject: [PATCH 04/17] Add to 4.1 --- .../apache/iceberg/parquet/ParquetWriter.java | 10 +- .../iceberg/spark/SparkSQLProperties.java | 10 -- .../apache/iceberg/spark/SparkWriteConf.java | 20 --- .../iceberg/spark/SparkWriteOptions.java | 3 - .../iceberg/spark/TestSparkWriteConf.java | 7 - .../iceberg/spark/SparkSQLProperties.java | 10 ++ .../apache/iceberg/spark/SparkWriteConf.java | 20 +++ .../iceberg/spark/SparkWriteOptions.java | 3 + .../spark/source/SchemaInferenceVisitor.java | 12 +- ...parkParquetWriterWithVariantShredding.java | 0 .../source/VariantShreddingAnalyzer.java | 129 +++++++++--------- .../iceberg/spark/TestSparkWriteConf.java | 7 + .../spark/variant/TestVariantShredding.java | 62 +++++++++ 13 files changed, 175 insertions(+), 118 deletions(-) rename spark/{v4.0 => v4.1}/spark/src/main/java/org/apache/iceberg/spark/source/SchemaInferenceVisitor.java (89%) rename spark/{v4.0 => v4.1}/spark/src/main/java/org/apache/iceberg/spark/source/SparkParquetWriterWithVariantShredding.java (100%) rename spark/{v4.0 => v4.1}/spark/src/main/java/org/apache/iceberg/spark/source/VariantShreddingAnalyzer.java (74%) rename spark/{v4.0 => v4.1}/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java (87%) diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetWriter.java b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetWriter.java index 88dbad6fb6e8..6a42ae440ff6 100644 --- a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetWriter.java +++ b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetWriter.java @@ -134,8 +134,7 @@ private void ensureWriterInitialized() { @Override public void add(T value) { - if (model instanceof WriterLazyInitializable) { - WriterLazyInitializable lazy = (WriterLazyInitializable) model; + if (model instanceof WriterLazyInitializable lazy) { if (lazy.needsInitialization()) { model.write(0, value); recordCount += 1; @@ -144,7 +143,9 @@ public void add(T value) { WriterLazyInitializable.InitializationResult result = lazy.initialize(props, compressor, rowGroupOrdinal); this.parquetSchema = result.getSchema(); + this.pageStore.close(); this.pageStore = result.getPageStore(); + this.writeStore.close(); this.writeStore = result.getWriteStore(); // Re-initialize the file writer with the new schema @@ -281,13 +282,14 @@ public void close() throws IOException { this.closed = true; // Force initialization if lazy writer still has buffered data - if (model instanceof WriterLazyInitializable) { - WriterLazyInitializable lazy = (WriterLazyInitializable) model; + if (model instanceof WriterLazyInitializable lazy) { if (lazy.needsInitialization()) { WriterLazyInitializable.InitializationResult result = lazy.initialize(props, compressor, rowGroupOrdinal); this.parquetSchema = result.getSchema(); + this.pageStore.close(); this.pageStore = result.getPageStore(); + this.writeStore.close(); this.writeStore = result.getWriteStore(); ensureWriterInitialized(); diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java index b12606d23948..81139969f746 100644 --- a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java @@ -109,14 +109,4 @@ private SparkSQLProperties() {} // 
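A hypothetical, heavily condensed sketch of the handshake the ParquetWriter changes above implement: rows are buffered until the model has seen enough data to infer a schema, initialization happens exactly once, and close() forces it if the row count never reached the buffer size. These interfaces are simplified stand-ins for WriterLazyInitializable, not its actual API.

interface LazyModel<T> {
  void write(T value);           // buffers rows while needsInitialization() is true
  boolean needsInitialization();
  String initialize();           // stand-in: returns the inferred schema, flushes buffer
}

final class LazyWriterFlow<T> {
  private final LazyModel<T> model;
  private String schema;         // swapped in once, like parquetSchema above

  LazyWriterFlow(LazyModel<T> model) {
    this.model = model;
  }

  void add(T value) {
    model.write(value);
    if (schema == null && !model.needsInitialization()) {
      schema = model.initialize(); // buffered rows are written with the final schema
    }
  }

  void close() {
    if (schema == null) {
      schema = model.initialize(); // too few rows to trigger inference: force it now
    }
  }
}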
Prefix for custom snapshot properties public static final String SNAPSHOT_PROPERTY_PREFIX = "spark.sql.iceberg.snapshot-property."; - - // Controls whether to shred variant columns during write operations - public static final String SHRED_VARIANTS = "spark.sql.iceberg.shred-variants"; - public static final boolean SHRED_VARIANTS_DEFAULT = true; - - // Controls the buffer size for variant schema inference during writes - // This determines how many rows are buffered before inferring shredded schema - public static final String VARIANT_INFERENCE_BUFFER_SIZE = - "spark.sql.iceberg.variant.inference.buffer-size"; - public static final int VARIANT_INFERENCE_BUFFER_SIZE_DEFAULT = 10; } diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java index 80d245712e6b..96131e0e56dd 100644 --- a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java @@ -509,17 +509,6 @@ private Map dataWriteProperties() { if (parquetCompressionLevel != null) { writeProperties.put(PARQUET_COMPRESSION_LEVEL, parquetCompressionLevel); } - writeProperties.put(SparkSQLProperties.SHRED_VARIANTS, String.valueOf(shredVariants())); - - // Add variant shredding configuration properties - if (shredVariants()) { - String variantBufferSize = - sessionConf.get(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, null); - if (variantBufferSize != null) { - writeProperties.put( - SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, variantBufferSize); - } - } break; case AVRO: @@ -740,13 +729,4 @@ public DeleteGranularity deleteGranularity() { .defaultValue(DeleteGranularity.FILE) .parse(); } - - public boolean shredVariants() { - return confParser - .booleanConf() - .option(SparkWriteOptions.SHRED_VARIANTS) - .sessionConf(SparkSQLProperties.SHRED_VARIANTS) - .defaultValue(SparkSQLProperties.SHRED_VARIANTS_DEFAULT) - .parse(); - } } diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java index f8fb41696f76..33db70bae587 100644 --- a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java @@ -85,7 +85,4 @@ private SparkWriteOptions() {} // Overrides the delete granularity public static final String DELETE_GRANULARITY = "delete-granularity"; - - // Controls whether to shred variant columns during write operations - public static final String SHRED_VARIANTS = "shred-variants"; } diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java index d97579f29e86..61aacfa4589d 100644 --- a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java @@ -41,7 +41,6 @@ import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE_RANGE; import static org.apache.iceberg.spark.SparkSQLProperties.COMPRESSION_CODEC; import static org.apache.iceberg.spark.SparkSQLProperties.COMPRESSION_LEVEL; -import static org.apache.iceberg.spark.SparkSQLProperties.SHRED_VARIANTS; import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.DELETE; import static 
org.apache.spark.sql.connector.write.RowLevelOperation.Command.MERGE; import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.UPDATE; @@ -340,8 +339,6 @@ public void testSparkConfOverride() { TableProperties.DELETE_PARQUET_COMPRESSION, "snappy"), ImmutableMap.of( - SHRED_VARIANTS, - "true", DELETE_PARQUET_COMPRESSION, "zstd", PARQUET_COMPRESSION, @@ -463,8 +460,6 @@ public void testDataPropsDefaultsAsDeleteProps() { PARQUET_COMPRESSION_LEVEL, "5"), ImmutableMap.of( - SHRED_VARIANTS, - "true", DELETE_PARQUET_COMPRESSION, "zstd", PARQUET_COMPRESSION, @@ -536,8 +531,6 @@ public void testDeleteFileWriteConf() { DELETE_PARQUET_COMPRESSION_LEVEL, "6"), ImmutableMap.of( - SHRED_VARIANTS, - "true", DELETE_PARQUET_COMPRESSION, "zstd", PARQUET_COMPRESSION, diff --git a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java index 81139969f746..b12606d23948 100644 --- a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java +++ b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java @@ -109,4 +109,14 @@ private SparkSQLProperties() {} // Prefix for custom snapshot properties public static final String SNAPSHOT_PROPERTY_PREFIX = "spark.sql.iceberg.snapshot-property."; + + // Controls whether to shred variant columns during write operations + public static final String SHRED_VARIANTS = "spark.sql.iceberg.shred-variants"; + public static final boolean SHRED_VARIANTS_DEFAULT = true; + + // Controls the buffer size for variant schema inference during writes + // This determines how many rows are buffered before inferring shredded schema + public static final String VARIANT_INFERENCE_BUFFER_SIZE = + "spark.sql.iceberg.variant.inference.buffer-size"; + public static final int VARIANT_INFERENCE_BUFFER_SIZE_DEFAULT = 10; } diff --git a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java index 1e937863c3c0..d531fa063bb2 100644 --- a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java +++ b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java @@ -485,6 +485,17 @@ private Map dataWriteProperties() { if (parquetCompressionLevel != null) { writeProperties.put(PARQUET_COMPRESSION_LEVEL, parquetCompressionLevel); } + writeProperties.put(SparkSQLProperties.SHRED_VARIANTS, String.valueOf(shredVariants())); + + // Add variant shredding configuration properties + if (shredVariants()) { + String variantBufferSize = + sessionConf.get(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, null); + if (variantBufferSize != null) { + writeProperties.put( + SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, variantBufferSize); + } + } break; case AVRO: @@ -705,4 +716,13 @@ public DeleteGranularity deleteGranularity() { .defaultValue(DeleteGranularity.FILE) .parse(); } + + public boolean shredVariants() { + return confParser + .booleanConf() + .option(SparkWriteOptions.SHRED_VARIANTS) + .sessionConf(SparkSQLProperties.SHRED_VARIANTS) + .defaultValue(SparkSQLProperties.SHRED_VARIANTS_DEFAULT) + .parse(); + } } diff --git a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java index 86c27acd88cf..9bf6ba843aa9 100644 --- a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java +++ 
b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java @@ -85,4 +85,7 @@ private SparkWriteOptions() {} // Overrides the delete granularity public static final String DELETE_GRANULARITY = "delete-granularity"; + + // Controls whether to shred variant columns during write operations + public static final String SHRED_VARIANTS = "shred-variants"; } diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SchemaInferenceVisitor.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SchemaInferenceVisitor.java similarity index 89% rename from spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SchemaInferenceVisitor.java rename to spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SchemaInferenceVisitor.java index 6903f1f03353..06a79b8dcef0 100644 --- a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SchemaInferenceVisitor.java +++ b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SchemaInferenceVisitor.java @@ -152,7 +152,6 @@ private int getFieldIndex(String[] path) { return -1; } - // Support nested variant fields by navigating the struct hierarchy if (path.length == 1) { // Top-level field - direct lookup String fieldName = path[0]; @@ -162,15 +161,8 @@ private int getFieldIndex(String[] path) { } } } else { - // Nested field - navigate through struct hierarchy - // For now, we only support direct struct nesting (not arrays/maps) - LOG.debug("Attempting to resolve nested variant field path: {}", String.join(".", path)); - // TODO: Implement full nested field resolution when needed - // This would require tracking the current struct context during traversal - // and maintaining a stack of field indices - LOG.warn( - "Multi-level nested variant fields require struct context tracking. Path: {}", - String.join(".", path)); + // TODO: Implement full nested field resolution + LOG.warn("Nested variant shredding is not supported. 
Path: {}", String.join(".", path)); } return -1; diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkParquetWriterWithVariantShredding.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkParquetWriterWithVariantShredding.java similarity index 100% rename from spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkParquetWriterWithVariantShredding.java rename to spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkParquetWriterWithVariantShredding.java diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/VariantShreddingAnalyzer.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/VariantShreddingAnalyzer.java similarity index 74% rename from spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/VariantShreddingAnalyzer.java rename to spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/VariantShreddingAnalyzer.java index 27b526134737..9487c2dc0141 100644 --- a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/VariantShreddingAnalyzer.java +++ b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/VariantShreddingAnalyzer.java @@ -104,7 +104,7 @@ private static PathNode buildPathTree(List variantValues) { } private static void traverse(PathNode node, VariantValue value) { - if (value == null) { + if (value == null || value.type() == PhysicalType.NULL) { return; } @@ -139,7 +139,12 @@ private static void traverse(PathNode node, VariantValue value) { } private static Type buildFieldGroup(PathNode node) { - Type typedValue = buildTypedValue(node, node.info.getMostCommonType()); + PhysicalType commonType = node.info.getMostCommonType(); + if (commonType == null) { + return null; + } + + Type typedValue = buildTypedValue(node, commonType); return Types.buildGroup(Type.Repetition.REQUIRED) .optional(PrimitiveType.PrimitiveTypeName.BINARY) .named(VALUE) @@ -167,7 +172,12 @@ private static Type createObjectTypedValue(PathNode node) { Types.GroupBuilder builder = Types.buildGroup(Type.Repetition.OPTIONAL); for (PathNode child : node.objectChildren.values()) { - builder.addField(buildFieldGroup(child)); + Type fieldType = buildFieldGroup(child); + if (fieldType == null) { + continue; + } + + builder.addField(fieldType); } return builder.named(TYPED_VALUE); @@ -221,67 +231,58 @@ private static Type createDecimalTypedValue(FieldInfo info) { } private static Type createPrimitiveTypedValue(FieldInfo info, PhysicalType primitiveType) { - switch (primitiveType) { - case BOOLEAN_TRUE: - case BOOLEAN_FALSE: - return Types.optional(PrimitiveType.PrimitiveTypeName.BOOLEAN).named(TYPED_VALUE); - - case INT8: - return Types.optional(PrimitiveType.PrimitiveTypeName.INT32) - .as(LogicalTypeAnnotation.intType(8, true)) - .named(TYPED_VALUE); - - case INT16: - return Types.optional(PrimitiveType.PrimitiveTypeName.INT32) - .as(LogicalTypeAnnotation.intType(16, true)) - .named(TYPED_VALUE); - - case INT32: - return Types.optional(PrimitiveType.PrimitiveTypeName.INT32) - .as(LogicalTypeAnnotation.intType(32, true)) - .named(TYPED_VALUE); - - case INT64: - return Types.optional(PrimitiveType.PrimitiveTypeName.INT64).named(TYPED_VALUE); - - case FLOAT: - return Types.optional(PrimitiveType.PrimitiveTypeName.FLOAT).named(TYPED_VALUE); - - case DOUBLE: - return Types.optional(PrimitiveType.PrimitiveTypeName.DOUBLE).named(TYPED_VALUE); - - case STRING: - return Types.optional(PrimitiveType.PrimitiveTypeName.BINARY) - .as(LogicalTypeAnnotation.stringType()) - 
.named(TYPED_VALUE); - - case BINARY: - return Types.optional(PrimitiveType.PrimitiveTypeName.BINARY).named(TYPED_VALUE); - - case DATE: - return Types.optional(PrimitiveType.PrimitiveTypeName.INT32) - .as(LogicalTypeAnnotation.dateType()) - .named(TYPED_VALUE); - - case TIMESTAMPTZ: - return Types.optional(PrimitiveType.PrimitiveTypeName.INT64) - .as(LogicalTypeAnnotation.timestampType(true, LogicalTypeAnnotation.TimeUnit.MICROS)) - .named(TYPED_VALUE); - - case TIMESTAMPNTZ: - return Types.optional(PrimitiveType.PrimitiveTypeName.INT64) - .as(LogicalTypeAnnotation.timestampType(false, LogicalTypeAnnotation.TimeUnit.MICROS)) - .named(TYPED_VALUE); - - case DECIMAL4: - case DECIMAL8: - case DECIMAL16: - return createDecimalTypedValue(info); - - default: - throw new UnsupportedOperationException( - "Unknown primitive physical type: " + primitiveType); - } + return switch (primitiveType) { + case BOOLEAN_TRUE, BOOLEAN_FALSE -> + Types.optional(PrimitiveType.PrimitiveTypeName.BOOLEAN).named(TYPED_VALUE); + case INT8 -> + Types.optional(PrimitiveType.PrimitiveTypeName.INT32) + .as(LogicalTypeAnnotation.intType(8, true)) + .named(TYPED_VALUE); + case INT16 -> + Types.optional(PrimitiveType.PrimitiveTypeName.INT32) + .as(LogicalTypeAnnotation.intType(16, true)) + .named(TYPED_VALUE); + case INT32 -> + Types.optional(PrimitiveType.PrimitiveTypeName.INT32) + .as(LogicalTypeAnnotation.intType(32, true)) + .named(TYPED_VALUE); + case INT64 -> Types.optional(PrimitiveType.PrimitiveTypeName.INT64).named(TYPED_VALUE); + case FLOAT -> Types.optional(PrimitiveType.PrimitiveTypeName.FLOAT).named(TYPED_VALUE); + case DOUBLE -> Types.optional(PrimitiveType.PrimitiveTypeName.DOUBLE).named(TYPED_VALUE); + case STRING -> + Types.optional(PrimitiveType.PrimitiveTypeName.BINARY) + .as(LogicalTypeAnnotation.stringType()) + .named(TYPED_VALUE); + case BINARY -> Types.optional(PrimitiveType.PrimitiveTypeName.BINARY).named(TYPED_VALUE); + case TIME -> + Types.optional(PrimitiveType.PrimitiveTypeName.INT64) + .as(LogicalTypeAnnotation.timeType(false, LogicalTypeAnnotation.TimeUnit.MICROS)) + .named(TYPED_VALUE); + case DATE -> + Types.optional(PrimitiveType.PrimitiveTypeName.INT32) + .as(LogicalTypeAnnotation.dateType()) + .named(TYPED_VALUE); + case TIMESTAMPTZ -> + Types.optional(PrimitiveType.PrimitiveTypeName.INT64) + .as(LogicalTypeAnnotation.timestampType(true, LogicalTypeAnnotation.TimeUnit.MICROS)) + .named(TYPED_VALUE); + case TIMESTAMPNTZ -> + Types.optional(PrimitiveType.PrimitiveTypeName.INT64) + .as(LogicalTypeAnnotation.timestampType(false, LogicalTypeAnnotation.TimeUnit.MICROS)) + .named(TYPED_VALUE); + case TIMESTAMPTZ_NANOS -> + Types.optional(PrimitiveType.PrimitiveTypeName.INT64) + .as(LogicalTypeAnnotation.timestampType(true, LogicalTypeAnnotation.TimeUnit.NANOS)) + .named(TYPED_VALUE); + case TIMESTAMPNTZ_NANOS -> + Types.optional(PrimitiveType.PrimitiveTypeName.INT64) + .as(LogicalTypeAnnotation.timestampType(false, LogicalTypeAnnotation.TimeUnit.NANOS)) + .named(TYPED_VALUE); + case DECIMAL4, DECIMAL8, DECIMAL16 -> createDecimalTypedValue(info); + default -> + throw new UnsupportedOperationException( + "Unknown primitive physical type: " + primitiveType); + }; } /** Tracks occurrence count and types for a single field. 
*/ diff --git a/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java b/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java index 227b93dfa478..fbd04fae1c98 100644 --- a/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java +++ b/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java @@ -41,6 +41,7 @@ import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE_RANGE; import static org.apache.iceberg.spark.SparkSQLProperties.COMPRESSION_CODEC; import static org.apache.iceberg.spark.SparkSQLProperties.COMPRESSION_LEVEL; +import static org.apache.iceberg.spark.SparkSQLProperties.SHRED_VARIANTS; import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.DELETE; import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.MERGE; import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.UPDATE; @@ -344,6 +345,8 @@ public void testSparkConfOverride() { TableProperties.DELETE_PARQUET_COMPRESSION, "snappy"), ImmutableMap.of( + SHRED_VARIANTS, + "true", DELETE_PARQUET_COMPRESSION, "zstd", PARQUET_COMPRESSION, @@ -466,6 +469,8 @@ public void testDataPropsDefaultsAsDeleteProps() { PARQUET_COMPRESSION_LEVEL, "5"), ImmutableMap.of( + SHRED_VARIANTS, + "true", DELETE_PARQUET_COMPRESSION, "zstd", PARQUET_COMPRESSION, @@ -537,6 +542,8 @@ public void testDeleteFileWriteConf() { DELETE_PARQUET_COMPRESSION_LEVEL, "6"), ImmutableMap.of( + SHRED_VARIANTS, + "true", DELETE_PARQUET_COMPRESSION, "zstd", PARQUET_COMPRESSION, diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java b/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java similarity index 87% rename from spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java rename to spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java index 5f4eb2a2732f..df239a674ba7 100644 --- a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java +++ b/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java @@ -153,6 +153,33 @@ public void testConsistentType() throws IOException { verifyParquetSchema(table, expectedSchema); } + @TestTemplate + public void testExcludingNullValue() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + String values = + "(1, parse_json('{\"name\": \"Alice\", \"age\": 30, \"dummy\": null}'))," + + " (2, parse_json('{\"name\": \"Bob\", \"age\": 25}'))," + + " (3, parse_json('{\"name\": \"Charlie\", \"age\": 35}'))"; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType name = + field( + "name", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType age = + field( + "age", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true))); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(age, name)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + @TestTemplate public void testInconsistentType() throws IOException { spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); @@ -363,6 +390,41 @@ public void testNestedObjectType() throws IOException { 
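The buffered-rows test that follows, like several tests above, toggles the shredding properties per session; a hypothetical standalone sketch of the same setup in a writer job, assuming a SparkSession wired to an Iceberg catalog named local and a v3 table with a variant column. The same switch is also available per write through the shred-variants write option.

import org.apache.spark.sql.SparkSession;

public class ShreddingConfigSketch {
  public static void main(String[] args) {
    SparkSession spark =
        SparkSession.builder().master("local[1]").appName("variant-shredding").getOrCreate();
    // Session-level switch and inference buffer size from SparkSQLProperties
    spark.conf().set("spark.sql.iceberg.shred-variants", "true");
    spark.conf().set("spark.sql.iceberg.variant.inference.buffer-size", "5");
    spark.sql(
        "INSERT INTO local.db.events VALUES (1, parse_json('{\"name\": \"Alice\", \"age\": 30}'))");
    spark.stop();
  }
}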
verifyParquetSchema(table, expectedSchema); } + @TestTemplate + public void testLazyInitializationWithBufferedRows() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + spark.conf().set(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, "5"); + + String values = + "(1, parse_json('{\"name\": \"Alice\", \"age\": 30}'))," + + " (2, parse_json('{\"name\": \"Bob\", \"age\": 25}'))," + + " (3, parse_json('{\"name\": \"Charlie\", \"age\": 35}'))," + + " (4, parse_json('{\"name\": \"David\", \"age\": 28}'))," + + " (5, parse_json('{\"name\": \"Eve\", \"age\": 32}'))," + + " (6, parse_json('{\"name\": \"Frank\", \"age\": 40}'))," + + " (7, parse_json('{\"name\": \"Grace\", \"age\": 27}'))"; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType name = + field( + "name", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType age = + field( + "age", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true))); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(age, name)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + + long rowCount = spark.read().format("iceberg").load(tableName).count(); + assertThat(rowCount).isEqualTo(7); + } + private void verifyParquetSchema(Table table, MessageType expectedSchema) throws IOException { try (CloseableIterable tasks = table.newScan().planFiles()) { assertThat(tasks).isNotEmpty(); From 110c8021e3d1e98e5b6754f93c870205ebdad4a7 Mon Sep 17 00:00:00 2001 From: Aihua Xu Date: Mon, 26 Jan 2026 22:40:09 -0800 Subject: [PATCH 05/17] Add tie break and INT/DECIMAL promotion --- .../apache/iceberg/parquet/ParquetWriter.java | 11 +- .../parquet/WriterLazyInitializable.java | 8 +- .../apache/iceberg/spark/SparkWriteConf.java | 5 +- ...parkParquetWriterWithVariantShredding.java | 11 +- .../source/VariantShreddingAnalyzer.java | 61 ++++++- .../spark/variant/TestVariantShredding.java | 159 +++++++++++++++++- 6 files changed, 239 insertions(+), 16 deletions(-) diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetWriter.java b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetWriter.java index 6a42ae440ff6..bdcfca7b2f94 100644 --- a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetWriter.java +++ b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetWriter.java @@ -51,7 +51,6 @@ class ParquetWriter implements FileAppender, Closeable { private final Map metadata; private final ParquetProperties props; private final CompressionCodecFactory.BytesInputCompressor compressor; - private MessageType parquetSchema; private final ParquetValueWriter model; private final MetricsConfig metricsConfig; private final int columnIndexTruncateLength; @@ -60,6 +59,7 @@ class ParquetWriter implements FileAppender, Closeable { private final Configuration conf; private final InternalFileEncryptor fileEncryptor; + private MessageType parquetSchema; private ColumnChunkPageWriteStore pageStore = null; private ColumnWriteStore writeStore; private long recordCount = 0; @@ -141,7 +141,8 @@ public void add(T value) { if (!lazy.needsInitialization()) { WriterLazyInitializable.InitializationResult result = - lazy.initialize(props, compressor, rowGroupOrdinal); + lazy.initialize( + props, compressor, rowGroupOrdinal, columnIndexTruncateLength, fileEncryptor); this.parquetSchema = 
result.getSchema(); this.pageStore.close(); this.pageStore = result.getPageStore(); @@ -281,11 +282,13 @@ public void close() throws IOException { if (!closed) { this.closed = true; - // Force initialization if lazy writer still has buffered data if (model instanceof WriterLazyInitializable lazy) { + // If initialization is not triggered with few data, lazy writer needs to initialize and + // process remaining buffered data if (lazy.needsInitialization()) { WriterLazyInitializable.InitializationResult result = - lazy.initialize(props, compressor, rowGroupOrdinal); + lazy.initialize( + props, compressor, rowGroupOrdinal, columnIndexTruncateLength, fileEncryptor); this.parquetSchema = result.getSchema(); this.pageStore.close(); this.pageStore = result.getPageStore(); diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/WriterLazyInitializable.java b/parquet/src/main/java/org/apache/iceberg/parquet/WriterLazyInitializable.java index 9c5913d7bd9b..f7b6c591fa49 100644 --- a/parquet/src/main/java/org/apache/iceberg/parquet/WriterLazyInitializable.java +++ b/parquet/src/main/java/org/apache/iceberg/parquet/WriterLazyInitializable.java @@ -21,6 +21,7 @@ import org.apache.parquet.column.ColumnWriteStore; import org.apache.parquet.column.ParquetProperties; import org.apache.parquet.compression.CompressionCodecFactory; +import org.apache.parquet.crypto.InternalFileEncryptor; import org.apache.parquet.hadoop.ColumnChunkPageWriteStore; import org.apache.parquet.schema.MessageType; @@ -78,10 +79,15 @@ public ColumnWriteStore getWriteStore() { * @param props Parquet properties needed for creating write stores * @param compressor Bytes compressor for compression * @param rowGroupOrdinal The ordinal number of the current row group + * @param columnIndexTruncateLength The column index truncate length from ParquetWriter config + * @param fileEncryptor The file encryptor from ParquetWriter, may be null if encryption is + * disabled * @return InitializationResult containing the finalized schema and write stores */ InitializationResult initialize( ParquetProperties props, CompressionCodecFactory.BytesInputCompressor compressor, - int rowGroupOrdinal); + int rowGroupOrdinal, + int columnIndexTruncateLength, + InternalFileEncryptor fileEncryptor); } diff --git a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java index d531fa063bb2..d0ae46541c3d 100644 --- a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java +++ b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java @@ -485,10 +485,11 @@ private Map dataWriteProperties() { if (parquetCompressionLevel != null) { writeProperties.put(PARQUET_COMPRESSION_LEVEL, parquetCompressionLevel); } - writeProperties.put(SparkSQLProperties.SHRED_VARIANTS, String.valueOf(shredVariants())); + boolean shouldShredVariants = shredVariants(); + writeProperties.put(SparkSQLProperties.SHRED_VARIANTS, String.valueOf(shouldShredVariants)); // Add variant shredding configuration properties - if (shredVariants()) { + if (shouldShredVariants) { String variantBufferSize = sessionConf.get(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, null); if (variantBufferSize != null) { diff --git a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkParquetWriterWithVariantShredding.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkParquetWriterWithVariantShredding.java index 6a2ed1e85324..5b9c10ff548f 
100644 --- a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkParquetWriterWithVariantShredding.java +++ b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkParquetWriterWithVariantShredding.java @@ -108,6 +108,9 @@ public List> columns() { @Override public void setColumnStore(ColumnWriteStore columnStore) { // Ignored for lazy initialization - will be set on actualWriter after initialization + if (actualWriter != null) { + actualWriter.setColumnStore(columnStore); + } } @Override @@ -127,7 +130,9 @@ public boolean needsInitialization() { public InitializationResult initialize( ParquetProperties props, CompressionCodecFactory.BytesInputCompressor compressor, - int rowGroupOrdinal) { + int rowGroupOrdinal, + int columnIndexTruncateLength, + org.apache.parquet.crypto.InternalFileEncryptor fileEncryptor) { if (bufferedRows.isEmpty()) { throw new IllegalStateException("No buffered rows available for schema inference"); } @@ -151,9 +156,9 @@ public InitializationResult initialize( compressor, shreddedSchema, props.getAllocator(), - 64, + columnIndexTruncateLength, ParquetProperties.DEFAULT_PAGE_WRITE_CHECKSUM_ENABLED, - null, + fileEncryptor, rowGroupOrdinal); ColumnWriteStore columnStore = props.newColumnWriteStore(shreddedSchema, pageStore, pageStore); diff --git a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/VariantShreddingAnalyzer.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/VariantShreddingAnalyzer.java index 9487c2dc0141..fba3d258995b 100644 --- a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/VariantShreddingAnalyzer.java +++ b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/VariantShreddingAnalyzer.java @@ -246,7 +246,10 @@ private static Type createPrimitiveTypedValue(FieldInfo info, PhysicalType primi Types.optional(PrimitiveType.PrimitiveTypeName.INT32) .as(LogicalTypeAnnotation.intType(32, true)) .named(TYPED_VALUE); - case INT64 -> Types.optional(PrimitiveType.PrimitiveTypeName.INT64).named(TYPED_VALUE); + case INT64 -> + Types.optional(PrimitiveType.PrimitiveTypeName.INT64) + .as(LogicalTypeAnnotation.intType(64, true)) + .named(TYPED_VALUE); case FLOAT -> Types.optional(PrimitiveType.PrimitiveTypeName.FLOAT).named(TYPED_VALUE); case DOUBLE -> Types.optional(PrimitiveType.PrimitiveTypeName.DOUBLE).named(TYPED_VALUE); case STRING -> @@ -296,6 +299,7 @@ void observe(VariantValue value) { // Use BOOLEAN_TRUE for both TRUE/FALSE values PhysicalType type = value.type() == PhysicalType.BOOLEAN_FALSE ? PhysicalType.BOOLEAN_TRUE : value.type(); + observedTypes.add(type); typeCounts.compute(type, (k, v) -> (v == null) ? 
1 : v + 1); @@ -314,10 +318,61 @@ void observe(VariantValue value) { } PhysicalType getMostCommonType() { - return typeCounts.entrySet().stream() - .max(Map.Entry.comparingByValue()) + Map combinedCounts = Maps.newHashMap(); + + int integerTotalCount = 0; + PhysicalType mostCapableInteger = null; + + int decimalTotalCount = 0; + PhysicalType mostCapableDecimal = null; + + for (Map.Entry entry : typeCounts.entrySet()) { + PhysicalType type = entry.getKey(); + int count = entry.getValue(); + + if (isIntegerType(type)) { + integerTotalCount += count; + if (mostCapableInteger == null || type.ordinal() > mostCapableInteger.ordinal()) { + mostCapableInteger = type; + } + } else if (isDecimalType(type)) { + decimalTotalCount += count; + if (mostCapableDecimal == null || type.ordinal() > mostCapableDecimal.ordinal()) { + mostCapableDecimal = type; + } + } else { + combinedCounts.put(type, count); + } + } + + if (mostCapableInteger != null) { + combinedCounts.put(mostCapableInteger, integerTotalCount); + } + + if (mostCapableDecimal != null) { + combinedCounts.put(mostCapableDecimal, decimalTotalCount); + } + + // Pick the most common type with tie-breaking + return combinedCounts.entrySet().stream() + .max( + Map.Entry.comparingByValue() + .thenComparingInt(entry -> entry.getKey().ordinal())) .map(Map.Entry::getKey) .orElse(null); } + + private boolean isIntegerType(PhysicalType type) { + return type == PhysicalType.INT8 + || type == PhysicalType.INT16 + || type == PhysicalType.INT32 + || type == PhysicalType.INT64; + } + + private boolean isDecimalType(PhysicalType type) { + return type == PhysicalType.DECIMAL4 + || type == PhysicalType.DECIMAL8 + || type == PhysicalType.DECIMAL16; + } } } diff --git a/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java b/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java index df239a674ba7..ec668c2043f8 100644 --- a/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java +++ b/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java @@ -19,6 +19,7 @@ package org.apache.iceberg.spark.variant; import static org.apache.hadoop.hive.conf.HiveConf.ConfVars.METASTOREURIS; +import static org.apache.iceberg.TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES; import static org.apache.parquet.schema.Types.optional; import static org.assertj.core.api.Assertions.assertThat; @@ -109,6 +110,8 @@ public void before() { @AfterEach public void after() { + spark.conf().unset(SparkSQLProperties.SHRED_VARIANTS); + spark.conf().unset(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE); validationCatalog.dropTable(tableIdent, true); } @@ -425,6 +428,159 @@ public void testLazyInitializationWithBufferedRows() throws IOException { assertThat(rowCount).isEqualTo(7); } + @TestTemplate + public void testTieBreakingWithEqualCounts() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + String values = + "(1, parse_json('{\"value\": 10}'))," + + " (2, parse_json('{\"value\": 20}'))," + + " (3, parse_json('{\"value\": \"hello\"}'))," + + " (4, parse_json('{\"value\": \"world\"}'))"; + sql("INSERT INTO %s VALUES %s", tableName, values); + + // When counts are tied, sort the types in order and choose the last one + GroupType value = + field( + "value", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, 
objectFields(value)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + @TestTemplate + public void testMultipleRowGroups() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + spark.conf().set(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, "3"); + + int numRows = 1000; + StringBuilder valuesBuilder = new StringBuilder(); + for (int i = 1; i <= numRows; i++) { + if (i > 1) { + valuesBuilder.append(", "); + } + valuesBuilder.append( + String.format("(%d, parse_json('{\"name\": \"User%d\", \"age\": %d}'))", i, i, 20 + i)); + } + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%d')", + tableName, PARQUET_ROW_GROUP_SIZE_BYTES, 1024); + sql("INSERT INTO %s VALUES %s", tableName, valuesBuilder.toString()); + + GroupType name = + field( + "name", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType age = + field( + "age", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true))); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(age, name)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + + long rowCount = spark.read().format("iceberg").load(tableName).count(); + assertThat(rowCount).isEqualTo(numRows); + } + + @TestTemplate + public void testColumnIndexTruncateLength() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + spark.conf().set(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, "3"); + + int customTruncateLength = 10; + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%d')", + tableName, "parquet.columnindex.truncate.length", customTruncateLength); + + StringBuilder valuesBuilder = new StringBuilder(); + for (int i = 1; i <= 10; i++) { + if (i > 1) { + valuesBuilder.append(", "); + } + String longValue = "A".repeat(20); + valuesBuilder.append( + String.format( + "(%d, parse_json('{\"description\": \"%s\", \"id\": %d}'))", i, longValue, i)); + } + sql("INSERT INTO %s VALUES %s", tableName, valuesBuilder.toString()); + + GroupType description = + field( + "description", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType id = + field( + "id", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true))); + GroupType address = + variant("address", 2, Type.Repetition.REQUIRED, objectFields(description, id)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + + long rowCount = spark.read().format("iceberg").load(tableName).count(); + assertThat(rowCount).isEqualTo(10); + } + + @TestTemplate + public void testIntegerFamilyPromotion() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + // Mix of INT8, INT16, INT32, INT64 - should promote to INT64 + String values = + "(1, parse_json('{\"value\": 10}'))," + + " (2, parse_json('{\"value\": 1000}'))," + + " (3, parse_json('{\"value\": 100000}'))," + + " (4, parse_json('{\"value\": 10000000000}'))"; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType value = + field( + "value", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT64, 
LogicalTypeAnnotation.intType(64, true))); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(value)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + @TestTemplate + public void testDecimalFamilyPromotion() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + // Test that they get promoted to the most capable decimal type observed + String values = + "(1, parse_json('{\"value\": 1.5}'))," + + " (2, parse_json('{\"value\": 123.456789}'))," + + " (3, parse_json('{\"value\": 123456789123456.789}'))"; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType value = + field( + "value", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT64, LogicalTypeAnnotation.decimalType(6, 18))); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(value)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + private void verifyParquetSchema(Table table, MessageType expectedSchema) throws IOException { try (CloseableIterable tasks = table.newScan().planFiles()) { assertThat(tasks).isNotEmpty(); @@ -439,9 +595,6 @@ private void verifyParquetSchema(Table table, MessageType expectedSchema) throws MessageType actualSchema = reader.getFileMetaData().getSchema(); assertThat(actualSchema).isEqualTo(expectedSchema); } - - // Print the result - spark.read().format("iceberg").load(tableName).orderBy("id").show(false); } } From 4f00355a0ee16d01ab124f974ab193a16d380ae4 Mon Sep 17 00:00:00 2001 From: Neelesh Salian Date: Fri, 13 Mar 2026 10:39:15 -0700 Subject: [PATCH 06/17] Wire shredding writer through WriterFunction API --- .../iceberg/formats/BaseFormatModel.java | 18 ++++++++++ .../iceberg/parquet/ParquetFormatModel.java | 10 ++++-- .../spark/source/SparkFormatModels.java | 35 ++++++++++++++++++- 3 files changed, 60 insertions(+), 3 deletions(-) diff --git a/core/src/main/java/org/apache/iceberg/formats/BaseFormatModel.java b/core/src/main/java/org/apache/iceberg/formats/BaseFormatModel.java index 7cba465670d9..9ddfefc2c9d9 100644 --- a/core/src/main/java/org/apache/iceberg/formats/BaseFormatModel.java +++ b/core/src/main/java/org/apache/iceberg/formats/BaseFormatModel.java @@ -105,6 +105,24 @@ public interface WriterFunction { * @return a writer configured for the given schemas */ W write(Schema icebergSchema, F fileSchema, S engineSchema); + + /** + * Creates a writer for the given schemas and write properties. Implementations can use + * properties to customize writer behavior, such as enabling variant shredding. + * + *
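 * <p>Illustrative sketch (not part of this patch): an engine-side implementation can key off a
 * write property to choose a specialized writer and fall back to the plain writer otherwise. The
 * property key and {@code MySpecializedWriter} below are hypothetical placeholders.
 *
 * <pre>{@code
 * @Override
 * public ParquetValueWriter<InternalRow> write(
 *     Schema icebergSchema,
 *     MessageType fileSchema,
 *     StructType engineSchema,
 *     Map<String, String> writeProperties) {
 *   // hypothetical flag; real callers pass whatever properties the engine collected
 *   if (Boolean.parseBoolean(writeProperties.getOrDefault("my-engine.special-writer", "false"))) {
 *     return new MySpecializedWriter(engineSchema, fileSchema, writeProperties);
 *   }
 *   // otherwise delegate to the three-argument writer
 *   return write(icebergSchema, fileSchema, engineSchema);
 * }
 * }</pre>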

The default implementation ignores properties and delegates to {@link #write(Schema, + * Object, Object)}. + * + * @param icebergSchema the Iceberg schema defining the table structure + * @param fileSchema the file format specific target schema for the output files + * @param engineSchema the engine-specific schema for the input data (optional) + * @param writeProperties writer configuration properties + * @return a writer configured for the given schemas + */ + default W write( + Schema icebergSchema, F fileSchema, S engineSchema, Map writeProperties) { + return write(icebergSchema, fileSchema, engineSchema); + } } /** diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetFormatModel.java b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetFormatModel.java index 90d6e3ef41ac..66913b66b128 100644 --- a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetFormatModel.java +++ b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetFormatModel.java @@ -40,6 +40,7 @@ import org.apache.iceberg.mapping.NameMapping; import org.apache.iceberg.relocated.com.google.common.base.Preconditions; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; import org.apache.parquet.column.ParquetProperties; import org.apache.parquet.schema.MessageType; @@ -98,6 +99,7 @@ private static class WriteBuilderWrapper implements ModelWriteBuilder collectedProperties = Maps.newHashMap(); private WriteBuilderWrapper( EncryptedOutputFile outputFile, @@ -121,6 +123,7 @@ public ModelWriteBuilder engineSchema(S newSchema) { @Override public ModelWriteBuilder set(String property, String value) { + collectedProperties.put(property, value); if (WRITER_VERSION_KEY.equals(property)) { internal.writerVersion(ParquetProperties.WriterVersion.valueOf(value)); } @@ -131,6 +134,7 @@ public ModelWriteBuilder set(String property, String value) { @Override public ModelWriteBuilder setAll(Map properties) { + collectedProperties.putAll(properties); internal.setAll(properties); return this; } @@ -184,13 +188,15 @@ public FileAppender build() throws IOException { internal.createContextFunc(Parquet.WriteBuilder.Context::dataContext); internal.createWriterFunc( (icebergSchema, messageType) -> - writerFunction.write(icebergSchema, messageType, engineSchema)); + writerFunction.write( + icebergSchema, messageType, engineSchema, collectedProperties)); break; case EQUALITY_DELETES: internal.createContextFunc(Parquet.WriteBuilder.Context::deleteContext); internal.createWriterFunc( (icebergSchema, messageType) -> - writerFunction.write(icebergSchema, messageType, engineSchema)); + writerFunction.write( + icebergSchema, messageType, engineSchema, collectedProperties)); break; case POSITION_DELETES: Preconditions.checkState( diff --git a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkFormatModels.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkFormatModels.java index 677f2e950b44..199a7ae40f8e 100644 --- a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkFormatModels.java +++ b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkFormatModels.java @@ -18,10 +18,14 @@ */ package org.apache.iceberg.spark.source; +import java.util.Map; +import org.apache.iceberg.Schema; import org.apache.iceberg.avro.AvroFormatModel; +import org.apache.iceberg.formats.BaseFormatModel; import org.apache.iceberg.formats.FormatModelRegistry; import 
org.apache.iceberg.orc.ORCFormatModel; import org.apache.iceberg.parquet.ParquetFormatModel; +import org.apache.iceberg.parquet.ParquetValueWriter; import org.apache.iceberg.spark.data.SparkAvroWriter; import org.apache.iceberg.spark.data.SparkOrcReader; import org.apache.iceberg.spark.data.SparkOrcWriter; @@ -30,6 +34,7 @@ import org.apache.iceberg.spark.data.SparkPlannedAvroReader; import org.apache.iceberg.spark.data.vectorized.VectorizedSparkOrcReaders; import org.apache.iceberg.spark.data.vectorized.VectorizedSparkParquetReaders; +import org.apache.parquet.schema.MessageType; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.vectorized.ColumnarBatch; @@ -48,7 +53,7 @@ public static void register() { ParquetFormatModel.create( InternalRow.class, StructType.class, - SparkParquetWriters::buildWriter, + new SparkParquetWriterFunction(), (icebergSchema, fileSchema, engineSchema, idToConstant) -> SparkParquetReaders.buildReader(icebergSchema, fileSchema, idToConstant))); @@ -86,4 +91,32 @@ public static void register() { } private SparkFormatModels() {} + + /** + * Writer function that checks for variant shredding conditions and returns a writer that performs + * variant shredding if needed. + */ + private static class SparkParquetWriterFunction + implements BaseFormatModel.WriterFunction, StructType, MessageType> { + + @Override + public ParquetValueWriter write( + Schema icebergSchema, MessageType fileSchema, StructType engineSchema) { + return SparkParquetWriters.buildWriter(icebergSchema, fileSchema, engineSchema); + } + + @Override + public ParquetValueWriter write( + Schema icebergSchema, + MessageType fileSchema, + StructType engineSchema, + Map writeProperties) { + if (SparkParquetWriterWithVariantShredding.shouldUseVariantShredding( + writeProperties, icebergSchema)) { + return new SparkParquetWriterWithVariantShredding( + engineSchema, fileSchema, writeProperties); + } + return SparkParquetWriters.buildWriter(icebergSchema, fileSchema, engineSchema); + } + } } From 12e2e8855dc6574c0bc226fc712138e2f95b781e Mon Sep 17 00:00:00 2001 From: Neelesh Salian Date: Fri, 13 Mar 2026 10:39:52 -0700 Subject: [PATCH 07/17] Fix decimal issue, null handling, heuristics and adding more tests --- .../source/VariantShreddingAnalyzer.java | 121 +++++++- .../spark/variant/TestVariantShredding.java | 269 +++++++++++++++++- 2 files changed, 380 insertions(+), 10 deletions(-) diff --git a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/VariantShreddingAnalyzer.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/VariantShreddingAnalyzer.java index fba3d258995b..bea8a6318e2f 100644 --- a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/VariantShreddingAnalyzer.java +++ b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/VariantShreddingAnalyzer.java @@ -43,14 +43,28 @@ /** * Analyzes variant data across buffered rows to determine an optimal shredding schema. * + *

<p>Determinism contract: for a given set of variant values (regardless of row arrival order),
+ * this analyzer produces the same shredded schema.
+ *
 * <ul>
- *   <li>shred to the most common type
+ *   <li>Object fields use a TreeMap, so field ordering is alphabetical and deterministic.
+ *   <li>Type selection picks the most common type with explicit tie-break priority (see
+ *       TIE_BREAK_PRIORITY), not enum ordinal.
+ *   <li>Integer types (INT8/16/32/64) and decimal types (DECIMAL4/8/16) are each promoted to the
+ *       widest observed before competing with other types.
+ *   <li>Fields observed below the MIN_FIELD_FREQUENCY threshold are pruned. Above
+ *       MAX_SHREDDED_FIELDS fields, the most frequent are kept with alphabetical tie-breaking.
 * </ul>
+ *
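 * <p>For example (illustrative, mirroring the tests in this patch): a field observed as 10, 1000,
 * 100000, and 10000000000 is promoted to INT64; five INT8 values against five STRING values is a
 * tie that resolves to STRING via TIE_BREAK_PRIORITY; and a field present in only 1 of 11 buffered
 * rows (about 9%) falls below MIN_FIELD_FREQUENCY and is dropped from the shredded schema.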

This contract holds within a single batch. Different batches with different distributions may + * produce different layouts; cross-batch stability requires schema pinning (not yet implemented). */ public class VariantShreddingAnalyzer { private static final String TYPED_VALUE = "typed_value"; private static final String VALUE = "value"; private static final String ELEMENT = "element"; + private static final double MIN_FIELD_FREQUENCY = 0.10; + private static final int MAX_SHREDDED_FIELDS = 300; public VariantShreddingAnalyzer() {} @@ -68,7 +82,16 @@ public Type analyzeAndCreateSchema(List bufferedRows, int variantFi } PathNode root = buildPathTree(variantValues); - return buildTypedValue(root, root.info.getMostCommonType()); + PhysicalType rootType = root.info.getMostCommonType(); + if (rootType == null) { + return null; + } + + if (rootType == PhysicalType.OBJECT) { + pruneInfrequentFields(root, variantValues.size()); + } + + return buildTypedValue(root, rootType); } private static List extractVariantValues( @@ -103,6 +126,45 @@ private static PathNode buildPathTree(List variantValues) { return root; } + private static void pruneInfrequentFields(PathNode node, int totalRows) { + if (node.objectChildren.isEmpty()) { + return; + } + + // Remove fields below frequency threshold + node.objectChildren + .entrySet() + .removeIf( + entry -> { + FieldInfo info = entry.getValue().info; + return info != null + && ((double) info.observationCount / totalRows) < MIN_FIELD_FREQUENCY; + }); + + // Cap at MAX_SHREDDED_FIELDS, keep the most frequently observed + if (node.objectChildren.size() > MAX_SHREDDED_FIELDS) { + List> sorted = + new java.util.ArrayList<>(node.objectChildren.entrySet()); + sorted.sort( + (a, b) -> { + int cmp = + Integer.compare( + b.getValue().info.observationCount, a.getValue().info.observationCount); + return cmp != 0 ? 
cmp : a.getKey().compareTo(b.getKey()); + }); + Set keep = Sets.newHashSet(); + for (int i = 0; i < MAX_SHREDDED_FIELDS; i++) { + keep.add(sorted.get(i).getKey()); + } + node.objectChildren.entrySet().removeIf(entry -> !keep.contains(entry.getKey())); + } + + // Recurse into remaining children + for (PathNode child : node.objectChildren.values()) { + pruneInfrequentFields(child, totalRows); + } + } + private static void traverse(PathNode node, VariantValue value) { if (value == null || value.type() == PhysicalType.NULL) { return; @@ -145,6 +207,10 @@ private static Type buildFieldGroup(PathNode node) { } Type typedValue = buildTypedValue(node, commonType); + if (typedValue == null) { + return null; + } + return Types.buildGroup(Type.Repetition.REQUIRED) .optional(PrimitiveType.PrimitiveTypeName.BINARY) .named(VALUE) @@ -186,6 +252,9 @@ private static Type createObjectTypedValue(PathNode node) { private static Type createArrayTypedValue(PathNode node) { PathNode elementNode = node.arrayElement; PhysicalType elementType = elementNode.info.getMostCommonType(); + if (elementType == null) { + return null; + } Type elementTypedValue = buildTypedValue(elementNode, elementType); GroupType elementGroup = @@ -211,7 +280,7 @@ private PathNode(String fieldName) { /** Use DECIMAL with maximum precision and scale as the shredding type */ private static Type createDecimalTypedValue(FieldInfo info) { - int maxPrecision = info.maxDecimalPrecision; + int maxPrecision = info.maxDecimalIntegerDigits + info.maxDecimalScale; int maxScale = info.maxDecimalScale; if (maxPrecision <= 9) { @@ -292,10 +361,44 @@ private static Type createPrimitiveTypedValue(FieldInfo info, PhysicalType primi private static class FieldInfo { private final Set observedTypes = Sets.newHashSet(); private final Map typeCounts = Maps.newHashMap(); - private int maxDecimalPrecision = 0; private int maxDecimalScale = 0; + private int maxDecimalIntegerDigits = 0; + private int observationCount = 0; + + private static final Map INTEGER_PRIORITY = + Map.of( + PhysicalType.INT8, 0, + PhysicalType.INT16, 1, + PhysicalType.INT32, 2, + PhysicalType.INT64, 3); + + private static final Map DECIMAL_PRIORITY = + Map.of( + PhysicalType.DECIMAL4, 0, + PhysicalType.DECIMAL8, 1, + PhysicalType.DECIMAL16, 2); + + private static final Map TIE_BREAK_PRIORITY = + Map.ofEntries( + Map.entry(PhysicalType.BOOLEAN_TRUE, 0), + Map.entry(PhysicalType.INT8, 1), + Map.entry(PhysicalType.INT16, 2), + Map.entry(PhysicalType.INT32, 3), + Map.entry(PhysicalType.INT64, 4), + Map.entry(PhysicalType.FLOAT, 5), + Map.entry(PhysicalType.DOUBLE, 6), + Map.entry(PhysicalType.DECIMAL4, 7), + Map.entry(PhysicalType.DECIMAL8, 8), + Map.entry(PhysicalType.DECIMAL16, 9), + Map.entry(PhysicalType.DATE, 10), + Map.entry(PhysicalType.TIME, 11), + Map.entry(PhysicalType.TIMESTAMPTZ, 12), + Map.entry(PhysicalType.TIMESTAMPNTZ, 13), + Map.entry(PhysicalType.BINARY, 14), + Map.entry(PhysicalType.STRING, 15)); void observe(VariantValue value) { + observationCount++; // Use BOOLEAN_TRUE for both TRUE/FALSE values PhysicalType type = value.type() == PhysicalType.BOOLEAN_FALSE ? 
PhysicalType.BOOLEAN_TRUE : value.type(); @@ -311,7 +414,7 @@ void observe(VariantValue value) { Object decimalValue = primitive.get(); if (decimalValue instanceof BigDecimal) { BigDecimal bd = (BigDecimal) decimalValue; - maxDecimalPrecision = Math.max(maxDecimalPrecision, bd.precision()); + maxDecimalIntegerDigits = Math.max(maxDecimalIntegerDigits, bd.precision() - bd.scale()); maxDecimalScale = Math.max(maxDecimalScale, bd.scale()); } } @@ -332,12 +435,14 @@ PhysicalType getMostCommonType() { if (isIntegerType(type)) { integerTotalCount += count; - if (mostCapableInteger == null || type.ordinal() > mostCapableInteger.ordinal()) { + if (mostCapableInteger == null + || INTEGER_PRIORITY.get(type) > INTEGER_PRIORITY.get(mostCapableInteger)) { mostCapableInteger = type; } } else if (isDecimalType(type)) { decimalTotalCount += count; - if (mostCapableDecimal == null || type.ordinal() > mostCapableDecimal.ordinal()) { + if (mostCapableDecimal == null + || DECIMAL_PRIORITY.get(type) > DECIMAL_PRIORITY.get(mostCapableDecimal)) { mostCapableDecimal = type; } } else { @@ -357,7 +462,7 @@ PhysicalType getMostCommonType() { return combinedCounts.entrySet().stream() .max( Map.Entry.comparingByValue() - .thenComparingInt(entry -> entry.getKey().ordinal())) + .thenComparingInt(entry -> TIE_BREAK_PRIORITY.getOrDefault(entry.getKey(), -1))) .map(Map.Entry::getKey) .orElse(null); } diff --git a/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java b/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java index ec668c2043f8..e63630cfe3ad 100644 --- a/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java +++ b/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java @@ -25,6 +25,7 @@ import java.io.IOException; import java.net.InetAddress; +import java.util.List; import java.util.Map; import org.apache.hadoop.conf.Configuration; import org.apache.iceberg.FileScanTask; @@ -48,6 +49,7 @@ import org.apache.parquet.schema.Type; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; import org.apache.spark.sql.internal.SQLConf; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeAll; @@ -572,8 +574,10 @@ public void testDecimalFamilyPromotion() throws IOException { GroupType value = field( "value", - shreddedPrimitive( - PrimitiveType.PrimitiveTypeName.INT64, LogicalTypeAnnotation.decimalType(6, 18))); + optional(PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY) + .length(16) + .as(LogicalTypeAnnotation.decimalType(6, 21)) + .named("typed_value")); GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(value)); MessageType expectedSchema = parquetSchema(address); @@ -581,6 +585,267 @@ public void testDecimalFamilyPromotion() throws IOException { verifyParquetSchema(table, expectedSchema); } + @TestTemplate + public void testDataRoundTripWithShredding() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + String values = + "(1, parse_json('{\"name\": \"Alice\", \"age\": 30}'))," + + " (2, parse_json('{\"name\": \"Bob\", \"age\": 25}'))," + + " (3, parse_json('{\"name\": \"Charlie\", \"age\": 35}'))"; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType name = + field( + "name", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, 
LogicalTypeAnnotation.stringType())); + GroupType age = + field( + "age", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true))); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(age, name)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + + // Verify that we can read the data back correctly + List rows = + sql( + "SELECT id, variant_get(address, '$.name', 'string')," + + " variant_get(address, '$.age', 'int')" + + " FROM %s ORDER BY id", + tableName); + assertThat(rows).hasSize(3); + assertThat(rows.get(0)[0]).isEqualTo(1); + assertThat(rows.get(0)[1]).isEqualTo("Alice"); + assertThat(rows.get(0)[2]).isEqualTo(30); + assertThat(rows.get(1)[0]).isEqualTo(2); + assertThat(rows.get(1)[1]).isEqualTo("Bob"); + assertThat(rows.get(1)[2]).isEqualTo(25); + assertThat(rows.get(2)[0]).isEqualTo(3); + assertThat(rows.get(2)[1]).isEqualTo("Charlie"); + assertThat(rows.get(2)[2]).isEqualTo(35); + } + + @TestTemplate + public void testMultipleVariantsWithShredding() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + // Recreate table with SCHEMA2 (address + metadata variant columns) + validationCatalog.dropTable(tableIdent, true); + validationCatalog.createTable( + tableIdent, SCHEMA2, null, Map.of(TableProperties.FORMAT_VERSION, "3")); + + String values = + "(1, parse_json('{\"city\": \"NYC\"}'), parse_json('{\"source\": \"web\"}'))," + + " (2, parse_json('{\"city\": \"LA\"}'), parse_json('{\"source\": \"app\"}'))," + + " (3, parse_json('{\"city\": \"SF\"}'), parse_json('{\"source\": \"api\"}'))"; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType city = + field( + "city", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(city)); + + GroupType source = + field( + "source", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType metadata = variant("metadata", 3, Type.Repetition.REQUIRED, objectFields(source)); + MessageType expectedSchema = parquetSchema(address, metadata); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + @TestTemplate + public void testVariantWithNullValues() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + String values = + "(1, parse_json('null'))," + " (2, parse_json('null'))," + " (3, parse_json('null'))"; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType address = variant("address", 2, Type.Repetition.REQUIRED); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + @TestTemplate + public void testArrayOfNullElementsWithShredding() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + sql( + "INSERT INTO %s VALUES (1, parse_json('[null, null, null]')), " + + "(2, parse_json('[null]'))", + tableName); + + // Array elements are all null, element type is null, falls back to unshredded + GroupType address = variant("address", 2, Type.Repetition.REQUIRED); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + 
verifyParquetSchema(table, expectedSchema); + } + + @TestTemplate + public void testMixedNullAndNonNullVariantValues() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + String values = + "(1, parse_json('{\"name\": \"Alice\", \"age\": 30}'))," + + " (2, null)," + + " (3, parse_json('{\"name\": \"Charlie\", \"age\": 35}'))"; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType name = + field( + "name", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType age = + field( + "age", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true))); + GroupType address = variant("address", 2, Type.Repetition.OPTIONAL, objectFields(age, name)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + + long rowCount = spark.read().format("iceberg").load(tableName).count(); + assertThat(rowCount).isEqualTo(3); + } + + @TestTemplate + public void testWriteOptionOverridesSessionConfig() throws IOException, NoSuchTableException { + // Disable shredding at session level + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "false"); + + // Enable shredding via per-write option + String query = + "SELECT 1 as id, parse_json('{\"name\": \"Alice\", \"age\": 30}') as address" + + " UNION ALL SELECT 2, parse_json('{\"name\": \"Bob\", \"age\": 25}')" + + " UNION ALL SELECT 3, parse_json('{\"name\": \"Charlie\", \"age\": 35}')"; + spark.sql(query).writeTo(tableName).option("shred-variants", "true").append(); + + GroupType name = + field( + "name", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType age = + field( + "age", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true))); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(age, name)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + @TestTemplate + public void testDefaultShreddingEnabled() throws IOException { + // Not setting SHRED_VARIANTS - default (true) should activate shredding + String values = + "(1, parse_json('{\"name\": \"Alice\", \"age\": 30}'))," + + " (2, parse_json('{\"name\": \"Bob\", \"age\": 25}'))"; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType name = + field( + "name", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType age = + field( + "age", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true))); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(age, name)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + @TestTemplate + public void testInfrequentFieldPruning() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + spark.conf().set(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, "11"); + + StringBuilder valuesBuilder = new StringBuilder(); + for (int i = 1; i <= 11; i++) { + if (i > 1) { + valuesBuilder.append(", "); + } + if (i == 1) { + // Only the first row has rare_field + valuesBuilder.append( + String.format( + 
"(%d, parse_json('{\"name\": \"User%d\", \"rare_field\": \"rare\"}'))", i, i)); + } else { + valuesBuilder.append(String.format("(%d, parse_json('{\"name\": \"User%d\"}'))", i, i)); + } + } + sql("INSERT INTO %s VALUES %s", tableName, valuesBuilder.toString()); + + // rare_field appears in 1/11 rows, should be pruned + // name appears in 11/11 rows and should be kept + GroupType name = + field( + "name", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(name)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + @TestTemplate + public void testMixedTypeTieBreaking() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + spark.conf().set(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, "10"); + + StringBuilder valuesBuilder = new StringBuilder(); + for (int i = 1; i <= 10; i++) { + if (i > 1) { + valuesBuilder.append(", "); + } + if (i <= 5) { + valuesBuilder.append(String.format("(%d, parse_json('{\"val\": %d}'))", i, i)); + } else { + valuesBuilder.append(String.format("(%d, parse_json('{\"val\": \"text%d\"}'))", i, i)); + } + } + sql("INSERT INTO %s VALUES %s", tableName, valuesBuilder.toString()); + + // 5 ints + 5 strings is a tie so STRING wins (higher TIE_BREAK_PRIORITY) + GroupType val = + field( + "val", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(val)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + private void verifyParquetSchema(Table table, MessageType expectedSchema) throws IOException { try (CloseableIterable tasks = table.newScan().planFiles()) { assertThat(tasks).isNotEmpty(); From 24a1b0c751275c933c1a5130aac2a1bc1765fe41 Mon Sep 17 00:00:00 2001 From: Neelesh Salian Date: Mon, 23 Mar 2026 08:21:32 -0700 Subject: [PATCH 08/17] Adding BufferedFileAppender for deferred writer init --- .../iceberg/io/BufferedFileAppender.java | 132 +++++++++++ .../iceberg/io/TestBufferedFileAppender.java | 217 ++++++++++++++++++ .../parquet/WriterLazyInitializable.java | 93 -------- ...parkParquetWriterWithVariantShredding.java | 189 --------------- 4 files changed, 349 insertions(+), 282 deletions(-) create mode 100644 core/src/main/java/org/apache/iceberg/io/BufferedFileAppender.java create mode 100644 core/src/test/java/org/apache/iceberg/io/TestBufferedFileAppender.java delete mode 100644 parquet/src/main/java/org/apache/iceberg/parquet/WriterLazyInitializable.java delete mode 100644 spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkParquetWriterWithVariantShredding.java diff --git a/core/src/main/java/org/apache/iceberg/io/BufferedFileAppender.java b/core/src/main/java/org/apache/iceberg/io/BufferedFileAppender.java new file mode 100644 index 000000000000..a798da8007ac --- /dev/null +++ b/core/src/main/java/org/apache/iceberg/io/BufferedFileAppender.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.io; + +import java.io.IOException; +import java.util.List; +import java.util.function.Function; +import java.util.function.UnaryOperator; +import org.apache.iceberg.Metrics; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; + +/** + * A FileAppender that buffers the first N rows, then creates a delegate appender via a factory. + * + *

<p>The factory receives the buffered rows and is responsible for creating the real appender and
+ * replaying the buffered rows into it before returning. All subsequent {@link #add} calls delegate
+ * directly to the real appender.
+ *
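 * <p>Illustrative usage sketch (not part of this patch); {@code newRealAppender()} stands in for
 * however the caller actually builds the delegate appender:
 *
 * <pre>{@code
 * FileAppender<Record> appender =
 *     new BufferedFileAppender<>(
 *         100,                     // buffer the first 100 rows before creating the delegate
 *         bufferedRows -> {
 *           FileAppender<Record> real = newRealAppender();  // hypothetical factory call
 *           bufferedRows.forEach(real::add);                // replay the buffered rows
 *           return real;
 *         },
 *         Record::copy);           // snapshot rows in case the caller reuses row objects
 * }</pre>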

If fewer than N rows are written before {@link #close}, the factory is called at close time. + * + * @param the row type + */ +public class BufferedFileAppender implements FileAppender { + private final int bufferRowCount; + private final Function, FileAppender> appenderFactory; + private final UnaryOperator copyFunc; + private List buffer; + private FileAppender delegate; + private boolean closed = false; + + /** + * @param bufferRowCount number of rows to buffer before creating the delegate appender + * @param appenderFactory given the buffered rows, creates the delegate appender and replays them + * @param copyFunc copies a row before buffering (needed when row objects are reused, e.g. Spark + * InternalRow) + */ + public BufferedFileAppender( + int bufferRowCount, + Function, FileAppender> appenderFactory, + UnaryOperator copyFunc) { + Preconditions.checkArgument( + bufferRowCount > 0, "bufferRowCount must be > 0, got %s", bufferRowCount); + Preconditions.checkNotNull(appenderFactory, "appenderFactory must not be null"); + Preconditions.checkNotNull(copyFunc, "copyFunc must not be null"); + this.bufferRowCount = bufferRowCount; + this.appenderFactory = appenderFactory; + this.copyFunc = copyFunc; + this.buffer = Lists.newArrayList(); + } + + @Override + public void add(D datum) { + Preconditions.checkState(!closed, "Cannot add to a closed appender"); + if (delegate != null) { + delegate.add(datum); + } else { + buffer.add(copyFunc.apply(datum)); + if (buffer.size() >= bufferRowCount) { + initialize(); + } + } + } + + @Override + public Metrics metrics() { + Preconditions.checkState(closed, "Cannot return metrics for unclosed appender"); + Preconditions.checkState(delegate != null, "Delegate appender was never created"); + return delegate.metrics(); + } + + @Override + public long length() { + if (delegate != null) { + return delegate.length(); + } + return 0L; + } + + @Override + public List splitOffsets() { + if (delegate != null) { + return delegate.splitOffsets(); + } + return null; + } + + @Override + public void close() throws IOException { + if (!closed) { + this.closed = true; + try { + if (delegate == null) { + initialize(); + } + } catch (RuntimeException e) { + // If initialize fails, attempt to close the delegate if it was partially created + closeDelegate(); + throw e; + } + closeDelegate(); + } + } + + private void closeDelegate() throws IOException { + if (delegate != null) { + delegate.close(); + } + } + + private void initialize() { + delegate = appenderFactory.apply(buffer); + Preconditions.checkState(delegate != null, "appenderFactory must not return null"); + buffer = null; + } +} diff --git a/core/src/test/java/org/apache/iceberg/io/TestBufferedFileAppender.java b/core/src/test/java/org/apache/iceberg/io/TestBufferedFileAppender.java new file mode 100644 index 000000000000..7c0f8c401d86 --- /dev/null +++ b/core/src/test/java/org/apache/iceberg/io/TestBufferedFileAppender.java @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.io; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.io.IOException; +import java.util.List; +import java.util.function.Function; +import org.apache.iceberg.Schema; +import org.apache.iceberg.avro.Avro; +import org.apache.iceberg.avro.AvroIterable; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.data.avro.DataWriter; +import org.apache.iceberg.data.avro.PlannedDataReader; +import org.apache.iceberg.inmemory.InMemoryOutputFile; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.types.Types; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class TestBufferedFileAppender { + + private static final Schema SCHEMA = + new Schema( + Types.NestedField.required(1, "id", Types.LongType.get()), + Types.NestedField.optional(2, "data", Types.StringType.get())); + + private InMemoryOutputFile outputFile; + private GenericRecord record; + + @BeforeEach + public void before() { + this.outputFile = new InMemoryOutputFile(); + this.record = GenericRecord.create(SCHEMA); + } + + private Function, FileAppender> avroFactory(OutputFile out) { + return bufferedRows -> { + try { + FileAppender appender = + Avro.write(out).createWriterFunc(DataWriter::create).schema(SCHEMA).overwrite().build(); + for (Record row : bufferedRows) { + appender.add(row); + } + return appender; + } catch (IOException e) { + throw new org.apache.iceberg.exceptions.RuntimeIOException(e); + } + }; + } + + private BufferedFileAppender createAppender(int bufferSize) { + return new BufferedFileAppender<>(bufferSize, avroFactory(outputFile), Record::copy); + } + + private Record createRecord(long id, String data) { + return record.copy(ImmutableMap.of("id", id, "data", data)); + } + + private List readBack() throws IOException { + try (AvroIterable reader = + Avro.read(outputFile.toInputFile()) + .project(SCHEMA) + .createResolvingReader(PlannedDataReader::create) + .build()) { + return Lists.newArrayList(reader); + } + } + + @Test + public void testBufferFlushesOnThreshold() throws IOException { + BufferedFileAppender appender = createAppender(3); + + appender.add(createRecord(1L, "a")); + appender.add(createRecord(2L, "b")); + + // delegate not yet created, length should be 0 + assertThat(appender.length()).isEqualTo(0L); + + appender.add(createRecord(3L, "c")); + + // delegate created after 3rd row, length should be > 0 + assertThat(appender.length()).isGreaterThan(0L); + + appender.add(createRecord(4L, "d")); + appender.add(createRecord(5L, "e")); + appender.close(); + + List actual = readBack(); + assertThat(actual).hasSize(5); + assertThat(actual.get(0).getField("id")).isEqualTo(1L); + assertThat(actual.get(4).getField("id")).isEqualTo(5L); + } + + @Test + public void testCloseWithPartialBuffer() throws IOException { + BufferedFileAppender appender = createAppender(10); + 
+ appender.add(createRecord(1L, "a")); + appender.add(createRecord(2L, "b")); + appender.add(createRecord(3L, "c")); + + // buffer not full yet + assertThat(appender.length()).isEqualTo(0L); + + // close flushes partial buffer through factory + appender.close(); + + List actual = readBack(); + assertThat(actual).hasSize(3); + assertThat(actual.get(0).getField("data")).isEqualTo("a"); + assertThat(actual.get(2).getField("data")).isEqualTo("c"); + } + + @Test + public void testCopyFuncIsApplied() throws IOException { + BufferedFileAppender appender = createAppender(3); + + // use a single mutable record, relying on copyFunc to snapshot it + record.set(0, 1L); + record.set(1, "first"); + appender.add(record); + + record.set(0, 2L); + record.set(1, "second"); + appender.add(record); + + record.set(0, 3L); + record.set(1, "third"); + appender.add(record); + + appender.close(); + + List actual = readBack(); + assertThat(actual).hasSize(3); + // without copyFunc, all 3 rows would have the last values (3, "third") + assertThat(actual.get(0).getField("id")).isEqualTo(1L); + assertThat(actual.get(0).getField("data")).isEqualTo("first"); + assertThat(actual.get(1).getField("id")).isEqualTo(2L); + assertThat(actual.get(1).getField("data")).isEqualTo("second"); + } + + @Test + public void testMetricsAfterClose() throws IOException { + BufferedFileAppender appender = createAppender(2); + + appender.add(createRecord(1L, "a")); + appender.add(createRecord(2L, "b")); + appender.add(createRecord(3L, "c")); + appender.close(); + + assertThat(appender.metrics()).isNotNull(); + assertThat(appender.metrics().recordCount()).isEqualTo(3L); + assertThat(appender.length()).isGreaterThan(0L); + } + + @Test + public void testMetricsBeforeCloseThrows() throws IOException { + try (BufferedFileAppender appender = createAppender(10)) { + assertThatThrownBy(appender::metrics) + .isInstanceOf(IllegalStateException.class) + .hasMessage("Cannot return metrics for unclosed appender"); + } + } + + @Test + public void testAddAfterCloseThrows() throws IOException { + try (BufferedFileAppender appender = createAppender(10)) { + appender.add(createRecord(1L, "a")); + appender.close(); + + assertThatThrownBy(() -> appender.add(createRecord(2L, "b"))) + .isInstanceOf(IllegalStateException.class) + .hasMessage("Cannot add to a closed appender"); + } + } + + @Test + public void testAddAllSpanningBuffer() throws IOException { + BufferedFileAppender appender = createAppender(2); + + List records = + Lists.newArrayList( + createRecord(1L, "a"), + createRecord(2L, "b"), + createRecord(3L, "c"), + createRecord(4L, "d")); + + appender.addAll(records); + appender.close(); + + List actual = readBack(); + assertThat(actual).hasSize(4); + assertThat(actual.get(0).getField("id")).isEqualTo(1L); + assertThat(actual.get(3).getField("id")).isEqualTo(4L); + } +} diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/WriterLazyInitializable.java b/parquet/src/main/java/org/apache/iceberg/parquet/WriterLazyInitializable.java deleted file mode 100644 index f7b6c591fa49..000000000000 --- a/parquet/src/main/java/org/apache/iceberg/parquet/WriterLazyInitializable.java +++ /dev/null @@ -1,93 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.iceberg.parquet; - -import org.apache.parquet.column.ColumnWriteStore; -import org.apache.parquet.column.ParquetProperties; -import org.apache.parquet.compression.CompressionCodecFactory; -import org.apache.parquet.crypto.InternalFileEncryptor; -import org.apache.parquet.hadoop.ColumnChunkPageWriteStore; -import org.apache.parquet.schema.MessageType; - -/** - * Interface for ParquetValueWriters that need to defer initialization until they can analyze the - * data. This is useful for scenarios like variant shredding where the schema needs to be inferred - * from the actual data before creating the writer structures. - * - *

Writers implementing this interface can buffer initial rows and perform schema inference - * before committing to a final Parquet schema. - */ -public interface WriterLazyInitializable { - /** - * Result returned by lazy initialization of a ParquetValueWriter required by ParquetWriter. - * Contains the finalized schema and write stores after schema inference or other initialization - * logic. - */ - class InitializationResult { - private final MessageType schema; - private final ColumnChunkPageWriteStore pageStore; - private final ColumnWriteStore writeStore; - - public InitializationResult( - MessageType schema, ColumnChunkPageWriteStore pageStore, ColumnWriteStore writeStore) { - this.schema = schema; - this.pageStore = pageStore; - this.writeStore = writeStore; - } - - public MessageType getSchema() { - return schema; - } - - public ColumnChunkPageWriteStore getPageStore() { - return pageStore; - } - - public ColumnWriteStore getWriteStore() { - return writeStore; - } - } - - /** - * Checks if this writer still needs initialization. This will return true until the writer has - * buffered enough data to perform initialization (e.g., schema inference). - * - * @return true if initialization is still needed, false if already initialized - */ - boolean needsInitialization(); - - /** - * Performs initialization and returns the result containing updated schema and write stores. This - * method should only be called when {@link #needsInitialization()} returns true. - * - * @param props Parquet properties needed for creating write stores - * @param compressor Bytes compressor for compression - * @param rowGroupOrdinal The ordinal number of the current row group - * @param columnIndexTruncateLength The column index truncate length from ParquetWriter config - * @param fileEncryptor The file encryptor from ParquetWriter, may be null if encryption is - * disabled - * @return InitializationResult containing the finalized schema and write stores - */ - InitializationResult initialize( - ParquetProperties props, - CompressionCodecFactory.BytesInputCompressor compressor, - int rowGroupOrdinal, - int columnIndexTruncateLength, - InternalFileEncryptor fileEncryptor); -} diff --git a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkParquetWriterWithVariantShredding.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkParquetWriterWithVariantShredding.java deleted file mode 100644 index 5b9c10ff548f..000000000000 --- a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkParquetWriterWithVariantShredding.java +++ /dev/null @@ -1,189 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ -package org.apache.iceberg.spark.source; - -import java.util.Collections; -import java.util.List; -import java.util.Map; -import java.util.stream.Stream; -import org.apache.iceberg.FieldMetrics; -import org.apache.iceberg.Schema; -import org.apache.iceberg.parquet.ParquetValueWriter; -import org.apache.iceberg.parquet.TripleWriter; -import org.apache.iceberg.parquet.WriterLazyInitializable; -import org.apache.iceberg.relocated.com.google.common.collect.Lists; -import org.apache.iceberg.spark.SparkSQLProperties; -import org.apache.iceberg.spark.data.ParquetWithSparkSchemaVisitor; -import org.apache.iceberg.spark.data.SparkParquetWriters; -import org.apache.iceberg.types.Types; -import org.apache.parquet.column.ColumnWriteStore; -import org.apache.parquet.column.ParquetProperties; -import org.apache.parquet.compression.CompressionCodecFactory; -import org.apache.parquet.hadoop.ColumnChunkPageWriteStore; -import org.apache.parquet.schema.MessageType; -import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.types.StructType; - -/** - * A Parquet output writer that performs variant shredding with schema inference. - * - *

The writer works in two phases: 1. Schema inference phase: Buffers initial rows and analyzes - * variant data to infer schemas 2. Writing phase: Creates the actual Parquet writer with inferred - * schemas and writes all data - */ -public class SparkParquetWriterWithVariantShredding - implements ParquetValueWriter, WriterLazyInitializable { - private final StructType sparkSchema; - private final MessageType parquetType; - private final Map properties; - - private final List bufferedRows; - private ParquetValueWriter actualWriter; - private boolean writerInitialized = false; - private final int bufferSize; - - private static class BufferedRow { - private final int repetitionLevel; - private final InternalRow row; - - BufferedRow(int repetitionLevel, InternalRow row) { - this.repetitionLevel = repetitionLevel; - this.row = row; - } - } - - public SparkParquetWriterWithVariantShredding( - StructType sparkSchema, MessageType parquetType, Map properties) { - this.sparkSchema = sparkSchema; - this.parquetType = parquetType; - this.properties = properties; - - this.bufferSize = - Integer.parseInt( - properties.getOrDefault( - SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, - String.valueOf(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE_DEFAULT))); - this.bufferedRows = Lists.newArrayList(); - } - - @Override - public void write(int repetitionLevel, InternalRow row) { - if (!writerInitialized) { - bufferedRows.add( - new BufferedRow( - repetitionLevel, row.copy())); /* Make a copy of the object since row gets reused */ - - if (bufferedRows.size() >= bufferSize) { - writerInitialized = true; - } - } else { - actualWriter.write(repetitionLevel, row); - } - } - - @Override - public List> columns() { - if (actualWriter != null) { - return actualWriter.columns(); - } - return Collections.emptyList(); - } - - @Override - public void setColumnStore(ColumnWriteStore columnStore) { - // Ignored for lazy initialization - will be set on actualWriter after initialization - if (actualWriter != null) { - actualWriter.setColumnStore(columnStore); - } - } - - @Override - public Stream> metrics() { - if (actualWriter != null) { - return actualWriter.metrics(); - } - return Stream.empty(); - } - - @Override - public boolean needsInitialization() { - return !writerInitialized; - } - - @Override - public InitializationResult initialize( - ParquetProperties props, - CompressionCodecFactory.BytesInputCompressor compressor, - int rowGroupOrdinal, - int columnIndexTruncateLength, - org.apache.parquet.crypto.InternalFileEncryptor fileEncryptor) { - if (bufferedRows.isEmpty()) { - throw new IllegalStateException("No buffered rows available for schema inference"); - } - - List rows = Lists.newLinkedList(); - for (BufferedRow bufferedRow : bufferedRows) { - rows.add(bufferedRow.row); - } - - MessageType shreddedSchema = - (MessageType) - ParquetWithSparkSchemaVisitor.visit( - sparkSchema, - parquetType, - new SchemaInferenceVisitor(rows, sparkSchema, properties)); - - actualWriter = SparkParquetWriters.buildWriter(sparkSchema, shreddedSchema); - - ColumnChunkPageWriteStore pageStore = - new ColumnChunkPageWriteStore( - compressor, - shreddedSchema, - props.getAllocator(), - columnIndexTruncateLength, - ParquetProperties.DEFAULT_PAGE_WRITE_CHECKSUM_ENABLED, - fileEncryptor, - rowGroupOrdinal); - - ColumnWriteStore columnStore = props.newColumnWriteStore(shreddedSchema, pageStore, pageStore); - - actualWriter.setColumnStore(columnStore); - - for (BufferedRow bufferedRow : bufferedRows) { - 
actualWriter.write(bufferedRow.repetitionLevel, bufferedRow.row); - columnStore.endRecord(); - } - - bufferedRows.clear(); - writerInitialized = true; - - return new InitializationResult(shreddedSchema, pageStore, columnStore); - } - - public static boolean shouldUseVariantShredding(Map properties, Schema schema) { - boolean shreddingEnabled = - properties.containsKey(SparkSQLProperties.SHRED_VARIANTS) - && Boolean.parseBoolean(properties.get(SparkSQLProperties.SHRED_VARIANTS)); - - boolean hasVariantFields = - schema.columns().stream().anyMatch(field -> field.type() instanceof Types.VariantType); - - return shreddingEnabled && hasVariantFields; - } -} From 53c6125c718231546e16992276d311738befc649 Mon Sep 17 00:00:00 2001 From: Neelesh Salian Date: Mon, 23 Mar 2026 08:23:49 -0700 Subject: [PATCH 09/17] Adding VariantShreddingAnalyzer and withFileSchema support --- .../iceberg/formats/BaseFormatModel.java | 18 - .../org/apache/iceberg/parquet/Parquet.java | 29 +- .../iceberg/parquet/ParquetFormatModel.java | 10 +- .../apache/iceberg/parquet/ParquetWriter.java | 46 +- .../parquet/VariantShreddingAnalyzer.java | 487 ++++++++++++++++++ .../parquet/TestParquetDataWriter.java | 86 ++++ .../parquet/TestVariantShreddingAnalyzer.java | 309 +++++++++++ 7 files changed, 913 insertions(+), 72 deletions(-) create mode 100644 parquet/src/main/java/org/apache/iceberg/parquet/VariantShreddingAnalyzer.java create mode 100644 parquet/src/test/java/org/apache/iceberg/parquet/TestVariantShreddingAnalyzer.java diff --git a/core/src/main/java/org/apache/iceberg/formats/BaseFormatModel.java b/core/src/main/java/org/apache/iceberg/formats/BaseFormatModel.java index 9ddfefc2c9d9..7cba465670d9 100644 --- a/core/src/main/java/org/apache/iceberg/formats/BaseFormatModel.java +++ b/core/src/main/java/org/apache/iceberg/formats/BaseFormatModel.java @@ -105,24 +105,6 @@ public interface WriterFunction { * @return a writer configured for the given schemas */ W write(Schema icebergSchema, F fileSchema, S engineSchema); - - /** - * Creates a writer for the given schemas and write properties. Implementations can use - * properties to customize writer behavior, such as enabling variant shredding. - * - *

The default implementation ignores properties and delegates to {@link #write(Schema, - * Object, Object)}. - * - * @param icebergSchema the Iceberg schema defining the table structure - * @param fileSchema the file format specific target schema for the output files - * @param engineSchema the engine-specific schema for the input data (optional) - * @param writeProperties writer configuration properties - * @return a writer configured for the given schemas - */ - default W write( - Schema icebergSchema, F fileSchema, S engineSchema, Map writeProperties) { - return write(icebergSchema, fileSchema, engineSchema); - } } /** diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/Parquet.java b/parquet/src/main/java/org/apache/iceberg/parquet/Parquet.java index 2387d52edf2f..5d725213fa82 100644 --- a/parquet/src/main/java/org/apache/iceberg/parquet/Parquet.java +++ b/parquet/src/main/java/org/apache/iceberg/parquet/Parquet.java @@ -163,6 +163,7 @@ public static class WriteBuilder implements InternalData.WriteBuilder { private final Map config = Maps.newLinkedHashMap(); private Schema schema = null; private VariantShreddingFunction variantShreddingFunc = null; + private MessageType fileSchema = null; private String name = "table"; private WriteSupport writeSupport = null; private BiFunction> createWriterFunc = null; @@ -208,6 +209,21 @@ public WriteBuilder variantShreddingFunc(VariantShreddingFunction func) { return this; } + /** + * Set a pre-computed Parquet {@link MessageType} to use as the file schema, bypassing the + * default conversion from the Iceberg schema. + * + *
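                                        A rough usage sketch (names are illustrative, not part of this API):
                                        + *
                                        + * <pre>
                                        + *   Parquet.write(outputFile)
                                        + *       .schema(icebergSchema)
                                        + *       .withFileSchema(precomputedMessageType)
                                        + *       .build();
                                        + * </pre>
                                        + *
                                        + * <p>
                                        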

The provided schema must have Parquet field IDs that match the Iceberg schema's field IDs. + * This method is mutually exclusive with {@link #variantShreddingFunc}. + * + * @param newFileSchema the Parquet message type to write + * @return this for method chaining + */ + public WriteBuilder withFileSchema(MessageType newFileSchema) { + this.fileSchema = newFileSchema; + return this; + } + @Override public WriteBuilder named(String newName) { this.name = newName; @@ -395,7 +411,13 @@ public FileAppender build() throws IOException { } set("parquet.avro.write-old-list-structure", "false"); - MessageType type = ParquetSchemaUtil.convert(schema, name, variantShreddingFunc); + Preconditions.checkArgument( + fileSchema == null || variantShreddingFunc == null, + "Cannot set both withFileSchema and variantShreddingFunc"); + MessageType type = + fileSchema != null + ? fileSchema + : ParquetSchemaUtil.convert(schema, name, variantShreddingFunc); FileEncryptionProperties fileEncryptionProperties = null; if (fileEncryptionKey != null) { @@ -851,6 +873,11 @@ public DataWriteBuilder variantShreddingFunc(VariantShreddingFunction func) { return this; } + public DataWriteBuilder withFileSchema(MessageType newFileSchema) { + appenderBuilder.withFileSchema(newFileSchema); + return this; + } + public DataWriteBuilder withSpec(PartitionSpec newSpec) { this.spec = newSpec; return this; diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetFormatModel.java b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetFormatModel.java index 66913b66b128..90d6e3ef41ac 100644 --- a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetFormatModel.java +++ b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetFormatModel.java @@ -40,7 +40,6 @@ import org.apache.iceberg.mapping.NameMapping; import org.apache.iceberg.relocated.com.google.common.base.Preconditions; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; -import org.apache.iceberg.relocated.com.google.common.collect.Maps; import org.apache.parquet.column.ParquetProperties; import org.apache.parquet.schema.MessageType; @@ -99,7 +98,6 @@ private static class WriteBuilderWrapper implements ModelWriteBuilder collectedProperties = Maps.newHashMap(); private WriteBuilderWrapper( EncryptedOutputFile outputFile, @@ -123,7 +121,6 @@ public ModelWriteBuilder engineSchema(S newSchema) { @Override public ModelWriteBuilder set(String property, String value) { - collectedProperties.put(property, value); if (WRITER_VERSION_KEY.equals(property)) { internal.writerVersion(ParquetProperties.WriterVersion.valueOf(value)); } @@ -134,7 +131,6 @@ public ModelWriteBuilder set(String property, String value) { @Override public ModelWriteBuilder setAll(Map properties) { - collectedProperties.putAll(properties); internal.setAll(properties); return this; } @@ -188,15 +184,13 @@ public FileAppender build() throws IOException { internal.createContextFunc(Parquet.WriteBuilder.Context::dataContext); internal.createWriterFunc( (icebergSchema, messageType) -> - writerFunction.write( - icebergSchema, messageType, engineSchema, collectedProperties)); + writerFunction.write(icebergSchema, messageType, engineSchema)); break; case EQUALITY_DELETES: internal.createContextFunc(Parquet.WriteBuilder.Context::deleteContext); internal.createWriterFunc( (icebergSchema, messageType) -> - writerFunction.write( - icebergSchema, messageType, engineSchema, collectedProperties)); + writerFunction.write(icebergSchema, messageType, engineSchema)); break; case 
POSITION_DELETES: Preconditions.checkState( diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetWriter.java b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetWriter.java index bdcfca7b2f94..2334e75532be 100644 --- a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetWriter.java +++ b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetWriter.java @@ -51,6 +51,7 @@ class ParquetWriter implements FileAppender, Closeable { private final Map metadata; private final ParquetProperties props; private final CompressionCodecFactory.BytesInputCompressor compressor; + private final MessageType parquetSchema; private final ParquetValueWriter model; private final MetricsConfig metricsConfig; private final int columnIndexTruncateLength; @@ -59,7 +60,6 @@ class ParquetWriter implements FileAppender, Closeable { private final Configuration conf; private final InternalFileEncryptor fileEncryptor; - private MessageType parquetSchema; private ColumnChunkPageWriteStore pageStore = null; private ColumnWriteStore writeStore; private long recordCount = 0; @@ -134,32 +134,6 @@ private void ensureWriterInitialized() { @Override public void add(T value) { - if (model instanceof WriterLazyInitializable lazy) { - if (lazy.needsInitialization()) { - model.write(0, value); - recordCount += 1; - - if (!lazy.needsInitialization()) { - WriterLazyInitializable.InitializationResult result = - lazy.initialize( - props, compressor, rowGroupOrdinal, columnIndexTruncateLength, fileEncryptor); - this.parquetSchema = result.getSchema(); - this.pageStore.close(); - this.pageStore = result.getPageStore(); - this.writeStore.close(); - this.writeStore = result.getWriteStore(); - - // Re-initialize the file writer with the new schema - ensureWriterInitialized(); - - // Buffered rows were already written with endRecord() calls - // in the lazy writer's initialization, so we don't call endRecord() here - checkSize(); - } - return; - } - } - recordCount += 1; model.write(0, value); writeStore.endRecord(); @@ -281,24 +255,6 @@ private void startRowGroup() { public void close() throws IOException { if (!closed) { this.closed = true; - - if (model instanceof WriterLazyInitializable lazy) { - // If initialization is not triggered with few data, lazy writer needs to initialize and - // process remaining buffered data - if (lazy.needsInitialization()) { - WriterLazyInitializable.InitializationResult result = - lazy.initialize( - props, compressor, rowGroupOrdinal, columnIndexTruncateLength, fileEncryptor); - this.parquetSchema = result.getSchema(); - this.pageStore.close(); - this.pageStore = result.getPageStore(); - this.writeStore.close(); - this.writeStore = result.getWriteStore(); - - ensureWriterInitialized(); - } - } - flushRowGroup(true); writeStore.close(); if (writer != null) { diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/VariantShreddingAnalyzer.java b/parquet/src/main/java/org/apache/iceberg/parquet/VariantShreddingAnalyzer.java new file mode 100644 index 000000000000..2659c5f2aaee --- /dev/null +++ b/parquet/src/main/java/org/apache/iceberg/parquet/VariantShreddingAnalyzer.java @@ -0,0 +1,487 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.parquet; + +import java.math.BigDecimal; +import java.util.List; +import java.util.Map; +import java.util.Set; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.variants.PhysicalType; +import org.apache.iceberg.variants.VariantArray; +import org.apache.iceberg.variants.VariantObject; +import org.apache.iceberg.variants.VariantPrimitive; +import org.apache.iceberg.variants.VariantValue; +import org.apache.parquet.schema.GroupType; +import org.apache.parquet.schema.LogicalTypeAnnotation; +import org.apache.parquet.schema.PrimitiveType; +import org.apache.parquet.schema.Type; +import org.apache.parquet.schema.Types; + +/** + * Analyzes variant data across buffered rows to determine an optimal shredding schema. + * + *

Determinism contract: for a given set of variant values (regardless of row arrival order), + * this analyzer produces the same shredded schema. + * + *

    + *
  • Object fields use a TreeMap, so field ordering is alphabetical and deterministic. + *
  • Type selection picks the most common type with explicit tie-break priority (see + * TIE_BREAK_PRIORITY), not enum ordinal. + *
  • Integer types (INT8/16/32/64) and decimal types (DECIMAL4/8/16) are each promoted to the + * widest observed before competing with other types. + *
  • Fields below {@code MIN_FIELD_FREQUENCY} are pruned. Above {@code MAX_SHREDDED_FIELDS}, the + * most frequent are kept with alphabetical tie-breaking. + *
                                          • Recursion into nested objects/arrays stops at {@code MAX_SHREDDING_DEPTH} (a fixed limit of 50). +
                                          • New struct fields are no longer tracked once a node reaches {@code MAX_INTERMEDIATE_FIELDS} + * (a fixed limit of 1000), bounding memory use during inference. +
                                        
+ * + *

This contract holds within a single batch. Different batches with different distributions may + * produce different layouts; cross-batch stability requires schema pinning (not yet implemented). + * + *
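                                        As an illustration of the promotion and tie-break rules: a field observed as
                                        + * INT8 twice, INT32 once, and STRING twice is shredded as an integer, because the
                                        + * integer counts combine (3 vs. 2) and promote to the widest observed width,
                                        + * INT32.
                                        + *
                                        + * <p>
                                        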

Subclasses implement {@link #extractVariantValues} to convert engine-specific row types into + * {@link VariantValue} instances. + * + * @param the engine-specific row type (e.g., Spark InternalRow, Flink RowData) + */ +public abstract class VariantShreddingAnalyzer { + private static final String TYPED_VALUE = "typed_value"; + private static final String VALUE = "value"; + private static final String ELEMENT = "element"; + private static final double MIN_FIELD_FREQUENCY = 0.10; + private static final int MAX_SHREDDED_FIELDS = 300; + private static final int MAX_SHREDDING_DEPTH = 50; + private static final int MAX_INTERMEDIATE_FIELDS = 1000; + + protected VariantShreddingAnalyzer() {} + + /** + * Analyzes buffered variant values to determine the optimal shredding schema. + * + * @param bufferedRows the buffered rows to analyze + * @param variantFieldIndex the index of the variant field in the rows + * @return the shredded schema type, or null if no shredding should be performed + */ + public Type analyzeAndCreateSchema(List bufferedRows, int variantFieldIndex) { + List variantValues = extractVariantValues(bufferedRows, variantFieldIndex); + if (variantValues.isEmpty()) { + return null; + } + + PathNode root = buildPathTree(variantValues); + PhysicalType rootType = root.info.getMostCommonType(); + if (rootType == null) { + return null; + } + + if (rootType == PhysicalType.OBJECT) { + pruneInfrequentFields(root, root.info.observationCount); + } + + return buildTypedValue(root, rootType); + } + + protected abstract List extractVariantValues( + List bufferedRows, int variantFieldIndex); + + private static PathNode buildPathTree(List variantValues) { + PathNode root = new PathNode(null); + root.info = new FieldInfo(); + + for (VariantValue value : variantValues) { + traverse(root, value, 0); + } + + return root; + } + + private static void pruneInfrequentFields(PathNode node, int totalRows) { + if (node.objectChildren.isEmpty()) { + return; + } + + // Remove fields below frequency threshold + node.objectChildren + .entrySet() + .removeIf( + entry -> { + FieldInfo info = entry.getValue().info; + return info != null + && ((double) info.observationCount / totalRows) < MIN_FIELD_FREQUENCY; + }); + + // Cap at MAX_SHREDDED_FIELDS, keep the most frequently observed + if (node.objectChildren.size() > MAX_SHREDDED_FIELDS) { + List> sorted = Lists.newArrayList(node.objectChildren.entrySet()); + sorted.sort( + (a, b) -> { + int cmp = + Integer.compare( + b.getValue().info.observationCount, a.getValue().info.observationCount); + return cmp != 0 ? 
cmp : a.getKey().compareTo(b.getKey()); + }); + Set keep = Sets.newHashSet(); + for (int i = 0; i < MAX_SHREDDED_FIELDS; i++) { + keep.add(sorted.get(i).getKey()); + } + node.objectChildren.entrySet().removeIf(entry -> !keep.contains(entry.getKey())); + } + + // Recurse into remaining children + for (PathNode child : node.objectChildren.values()) { + pruneInfrequentFields(child, totalRows); + } + } + + private static void traverse(PathNode node, VariantValue value, int depth) { + if (value == null || value.type() == PhysicalType.NULL) { + return; + } + + node.info.observe(value); + + if (value.type() == PhysicalType.OBJECT && depth < MAX_SHREDDING_DEPTH) { + traverseObject(node, value.asObject(), depth); + } else if (value.type() == PhysicalType.ARRAY && depth < MAX_SHREDDING_DEPTH) { + traverseArray(node, value.asArray(), depth); + } + } + + private static void traverseObject(PathNode node, VariantObject obj, int depth) { + for (String fieldName : obj.fieldNames()) { + VariantValue fieldValue = obj.get(fieldName); + if (fieldValue != null) { + PathNode childNode = node.objectChildren.get(fieldName); + if (childNode == null) { + if (node.objectChildren.size() >= MAX_INTERMEDIATE_FIELDS) { + continue; + } + childNode = new PathNode(fieldName); + childNode.info = new FieldInfo(); + node.objectChildren.put(fieldName, childNode); + } + traverse(childNode, fieldValue, depth + 1); + } + } + } + + private static void traverseArray(PathNode node, VariantArray array, int depth) { + int numElements = array.numElements(); + if (node.arrayElement == null) { + node.arrayElement = new PathNode(null); + node.arrayElement.info = new FieldInfo(); + } + for (int i = 0; i < numElements; i++) { + VariantValue element = array.get(i); + if (element != null) { + traverse(node.arrayElement, element, depth + 1); + } + } + } + + private static Type buildFieldGroup(PathNode node) { + PhysicalType commonType = node.info.getMostCommonType(); + if (commonType == null) { + return null; + } + + Type typedValue = buildTypedValue(node, commonType); + if (typedValue == null) { + return null; + } + + return Types.buildGroup(Type.Repetition.REQUIRED) + .optional(PrimitiveType.PrimitiveTypeName.BINARY) + .named(VALUE) + .addField(typedValue) + .named(node.fieldName); + } + + private static Type buildTypedValue(PathNode node, PhysicalType physicalType) { + return switch (physicalType) { + case ARRAY -> createArrayTypedValue(node); + case OBJECT -> createObjectTypedValue(node); + default -> createPrimitiveTypedValue(node.info, physicalType); + }; + } + + private static Type createObjectTypedValue(PathNode node) { + if (node.objectChildren.isEmpty()) { + return null; + } + + Types.GroupBuilder builder = Types.buildGroup(Type.Repetition.OPTIONAL); + boolean hasFields = false; + for (PathNode child : node.objectChildren.values()) { + Type fieldType = buildFieldGroup(child); + if (fieldType != null) { + builder.addField(fieldType); + hasFields = true; + } + } + + return hasFields ? 
builder.named(TYPED_VALUE) : null; + } + + private static Type createArrayTypedValue(PathNode node) { + PathNode elementNode = node.arrayElement; + if (elementNode == null) { + return null; + } + PhysicalType elementType = elementNode.info.getMostCommonType(); + if (elementType == null) { + return null; + } + Type elementTypedValue = buildTypedValue(elementNode, elementType); + if (elementTypedValue == null) { + return null; + } + + GroupType elementGroup = + Types.buildGroup(Type.Repetition.REQUIRED) + .optional(PrimitiveType.PrimitiveTypeName.BINARY) + .named(VALUE) + .addField(elementTypedValue) + .named(ELEMENT); + + return Types.optionalList().element(elementGroup).named(TYPED_VALUE); + } + + private static class PathNode { + private final String fieldName; + private final Map objectChildren = Maps.newTreeMap(); + private PathNode arrayElement = null; + private FieldInfo info = null; + + private PathNode(String fieldName) { + this.fieldName = fieldName; + } + } + + /** Use DECIMAL with maximum precision and scale as the shredding type */ + private static Type createDecimalTypedValue(FieldInfo info) { + int maxPrecision = Math.min(info.maxDecimalIntegerDigits + info.maxDecimalScale, 38); + int maxScale = Math.min(info.maxDecimalScale, maxPrecision); + + if (maxPrecision <= 9) { + return Types.optional(PrimitiveType.PrimitiveTypeName.INT32) + .as(LogicalTypeAnnotation.decimalType(maxScale, maxPrecision)) + .named(TYPED_VALUE); + } else if (maxPrecision <= 18) { + return Types.optional(PrimitiveType.PrimitiveTypeName.INT64) + .as(LogicalTypeAnnotation.decimalType(maxScale, maxPrecision)) + .named(TYPED_VALUE); + } else { + return Types.optional(PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY) + .length(16) + .as(LogicalTypeAnnotation.decimalType(maxScale, maxPrecision)) + .named(TYPED_VALUE); + } + } + + private static Type createPrimitiveTypedValue(FieldInfo info, PhysicalType primitiveType) { + return switch (primitiveType) { + case BOOLEAN_TRUE, BOOLEAN_FALSE -> + Types.optional(PrimitiveType.PrimitiveTypeName.BOOLEAN).named(TYPED_VALUE); + case INT8 -> + Types.optional(PrimitiveType.PrimitiveTypeName.INT32) + .as(LogicalTypeAnnotation.intType(8, true)) + .named(TYPED_VALUE); + case INT16 -> + Types.optional(PrimitiveType.PrimitiveTypeName.INT32) + .as(LogicalTypeAnnotation.intType(16, true)) + .named(TYPED_VALUE); + case INT32 -> + Types.optional(PrimitiveType.PrimitiveTypeName.INT32) + .as(LogicalTypeAnnotation.intType(32, true)) + .named(TYPED_VALUE); + case INT64 -> + Types.optional(PrimitiveType.PrimitiveTypeName.INT64) + .as(LogicalTypeAnnotation.intType(64, true)) + .named(TYPED_VALUE); + case FLOAT -> Types.optional(PrimitiveType.PrimitiveTypeName.FLOAT).named(TYPED_VALUE); + case DOUBLE -> Types.optional(PrimitiveType.PrimitiveTypeName.DOUBLE).named(TYPED_VALUE); + case STRING -> + Types.optional(PrimitiveType.PrimitiveTypeName.BINARY) + .as(LogicalTypeAnnotation.stringType()) + .named(TYPED_VALUE); + case BINARY -> Types.optional(PrimitiveType.PrimitiveTypeName.BINARY).named(TYPED_VALUE); + case TIME -> + Types.optional(PrimitiveType.PrimitiveTypeName.INT64) + .as(LogicalTypeAnnotation.timeType(false, LogicalTypeAnnotation.TimeUnit.MICROS)) + .named(TYPED_VALUE); + case DATE -> + Types.optional(PrimitiveType.PrimitiveTypeName.INT32) + .as(LogicalTypeAnnotation.dateType()) + .named(TYPED_VALUE); + case TIMESTAMPTZ -> + Types.optional(PrimitiveType.PrimitiveTypeName.INT64) + .as(LogicalTypeAnnotation.timestampType(true, LogicalTypeAnnotation.TimeUnit.MICROS)) + 
.named(TYPED_VALUE); + case TIMESTAMPNTZ -> + Types.optional(PrimitiveType.PrimitiveTypeName.INT64) + .as(LogicalTypeAnnotation.timestampType(false, LogicalTypeAnnotation.TimeUnit.MICROS)) + .named(TYPED_VALUE); + case TIMESTAMPTZ_NANOS -> + Types.optional(PrimitiveType.PrimitiveTypeName.INT64) + .as(LogicalTypeAnnotation.timestampType(true, LogicalTypeAnnotation.TimeUnit.NANOS)) + .named(TYPED_VALUE); + case TIMESTAMPNTZ_NANOS -> + Types.optional(PrimitiveType.PrimitiveTypeName.INT64) + .as(LogicalTypeAnnotation.timestampType(false, LogicalTypeAnnotation.TimeUnit.NANOS)) + .named(TYPED_VALUE); + case DECIMAL4, DECIMAL8, DECIMAL16 -> createDecimalTypedValue(info); + case UUID -> + Types.optional(PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY) + .length(16) + .as(LogicalTypeAnnotation.uuidType()) + .named(TYPED_VALUE); + default -> + throw new UnsupportedOperationException( + "Unknown primitive physical type: " + primitiveType); + }; + } + + /** Tracks occurrence count and types for a single field. */ + private static class FieldInfo { + private final Map typeCounts = Maps.newHashMap(); + private int maxDecimalScale = 0; + private int maxDecimalIntegerDigits = 0; + private int observationCount = 0; + + private static final Map INTEGER_PRIORITY = + Map.of( + PhysicalType.INT8, 0, + PhysicalType.INT16, 1, + PhysicalType.INT32, 2, + PhysicalType.INT64, 3); + + private static final Map DECIMAL_PRIORITY = + Map.of( + PhysicalType.DECIMAL4, 0, + PhysicalType.DECIMAL8, 1, + PhysicalType.DECIMAL16, 2); + + private static final Map TIE_BREAK_PRIORITY = + Map.ofEntries( + Map.entry(PhysicalType.BOOLEAN_TRUE, 0), + Map.entry(PhysicalType.INT8, 1), + Map.entry(PhysicalType.INT16, 2), + Map.entry(PhysicalType.INT32, 3), + Map.entry(PhysicalType.INT64, 4), + Map.entry(PhysicalType.FLOAT, 5), + Map.entry(PhysicalType.DOUBLE, 6), + Map.entry(PhysicalType.DECIMAL4, 7), + Map.entry(PhysicalType.DECIMAL8, 8), + Map.entry(PhysicalType.DECIMAL16, 9), + Map.entry(PhysicalType.DATE, 10), + Map.entry(PhysicalType.TIME, 11), + Map.entry(PhysicalType.TIMESTAMPTZ, 12), + Map.entry(PhysicalType.TIMESTAMPNTZ, 13), + Map.entry(PhysicalType.BINARY, 14), + Map.entry(PhysicalType.STRING, 15), + Map.entry(PhysicalType.TIMESTAMPTZ_NANOS, 16), + Map.entry(PhysicalType.TIMESTAMPNTZ_NANOS, 17), + Map.entry(PhysicalType.UUID, 18)); + + void observe(VariantValue value) { + observationCount++; + // Use BOOLEAN_TRUE for both TRUE/FALSE values + PhysicalType type = + value.type() == PhysicalType.BOOLEAN_FALSE ? PhysicalType.BOOLEAN_TRUE : value.type(); + + typeCounts.compute(type, (k, v) -> (v == null) ? 
1 : v + 1); + + // Track max precision and scale for decimal types + if (type == PhysicalType.DECIMAL4 + || type == PhysicalType.DECIMAL8 + || type == PhysicalType.DECIMAL16) { + VariantPrimitive primitive = value.asPrimitive(); + Object decimalValue = primitive.get(); + if (decimalValue instanceof BigDecimal bd) { + maxDecimalIntegerDigits = Math.max(maxDecimalIntegerDigits, bd.precision() - bd.scale()); + maxDecimalScale = Math.max(maxDecimalScale, bd.scale()); + } + } + } + + PhysicalType getMostCommonType() { + Map combinedCounts = Maps.newHashMap(); + + int integerTotalCount = 0; + PhysicalType mostCapableInteger = null; + + int decimalTotalCount = 0; + PhysicalType mostCapableDecimal = null; + + for (Map.Entry entry : typeCounts.entrySet()) { + PhysicalType type = entry.getKey(); + int count = entry.getValue(); + + if (isIntegerType(type)) { + integerTotalCount += count; + if (mostCapableInteger == null + || INTEGER_PRIORITY.get(type) > INTEGER_PRIORITY.get(mostCapableInteger)) { + mostCapableInteger = type; + } + } else if (isDecimalType(type)) { + decimalTotalCount += count; + if (mostCapableDecimal == null + || DECIMAL_PRIORITY.get(type) > DECIMAL_PRIORITY.get(mostCapableDecimal)) { + mostCapableDecimal = type; + } + } else { + combinedCounts.put(type, count); + } + } + + if (mostCapableInteger != null) { + combinedCounts.put(mostCapableInteger, integerTotalCount); + } + + if (mostCapableDecimal != null) { + combinedCounts.put(mostCapableDecimal, decimalTotalCount); + } + + // Pick the most common type with tie-breaking + return combinedCounts.entrySet().stream() + .max( + Map.Entry.comparingByValue() + .thenComparingInt(entry -> TIE_BREAK_PRIORITY.getOrDefault(entry.getKey(), -1))) + .map(Map.Entry::getKey) + .orElse(null); + } + + private static boolean isIntegerType(PhysicalType type) { + return type == PhysicalType.INT8 + || type == PhysicalType.INT16 + || type == PhysicalType.INT32 + || type == PhysicalType.INT64; + } + + private static boolean isDecimalType(PhysicalType type) { + return type == PhysicalType.DECIMAL4 + || type == PhysicalType.DECIMAL8 + || type == PhysicalType.DECIMAL16; + } + } +} diff --git a/parquet/src/test/java/org/apache/iceberg/parquet/TestParquetDataWriter.java b/parquet/src/test/java/org/apache/iceberg/parquet/TestParquetDataWriter.java index 3918fdc63084..74e8e62fea0f 100644 --- a/parquet/src/test/java/org/apache/iceberg/parquet/TestParquetDataWriter.java +++ b/parquet/src/test/java/org/apache/iceberg/parquet/TestParquetDataWriter.java @@ -55,7 +55,10 @@ import org.apache.iceberg.variants.Variants; import org.apache.parquet.hadoop.ParquetFileReader; import org.apache.parquet.schema.GroupType; +import org.apache.parquet.schema.LogicalTypeAnnotation; import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.PrimitiveType; +import org.apache.parquet.schema.Type; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; @@ -331,4 +334,87 @@ public void testDataWriterWithVariantShredding() throws IOException { testDataWriter( variantSchema, (id, name) -> ParquetVariantUtil.toParquetSchema(variant.value())); } + + @Test + public void testWithFileSchemaOverride() throws IOException { + Schema variantSchema = + new Schema( + ImmutableList.builder() + .addAll(SCHEMA.columns()) + .add(Types.NestedField.optional(4, "variant", Types.VariantType.get())) + .build()); + MessageType customSchema = + org.apache.parquet.schema.Types.buildMessage() + 
.required(PrimitiveType.PrimitiveTypeName.INT64) + .id(1) + .named("id") + .optional(PrimitiveType.PrimitiveTypeName.BINARY) + .id(2) + .named("data") + .optional(PrimitiveType.PrimitiveTypeName.BINARY) + .id(3) + .named("binary") + .addField( + org.apache.parquet.schema.Types.buildGroup(Type.Repetition.OPTIONAL) + .id(4) + .as(LogicalTypeAnnotation.variantType(Variant.VARIANT_SPEC_VERSION)) + .required(PrimitiveType.PrimitiveTypeName.BINARY) + .named("metadata") + .optional(PrimitiveType.PrimitiveTypeName.BINARY) + .named("value") + .optional(PrimitiveType.PrimitiveTypeName.INT32) + .named("typed_value") + .named("variant")) + .named("table"); + ByteBuffer metadataBuffer = VariantTestUtil.createMetadata(ImmutableList.of("a"), true); + VariantMetadata metadata = Variants.metadata(metadataBuffer); + ByteBuffer objectBuffer = + VariantTestUtil.createObject(metadataBuffer, ImmutableMap.of("a", Variants.of(42))); + Variant variant = Variant.of(metadata, Variants.value(metadata, objectBuffer)); + + GenericRecord record = GenericRecord.create(variantSchema); + List variantRecords = + ImmutableList.of( + record.copy(ImmutableMap.of("id", 1L, "variant", variant)), + record.copy(ImmutableMap.of("id", 2L, "variant", variant))); + + OutputFile file = Files.localOutput(createTempFile(temp)); + DataWriter dataWriter = + Parquet.writeData(file) + .schema(variantSchema) + .withFileSchema(customSchema) + .createWriterFunc(GenericParquetWriter::create) + .overwrite() + .withSpec(PartitionSpec.unpartitioned()) + .build(); + + try (dataWriter) { + for (Record rec : variantRecords) { + dataWriter.write(rec); + } + } + + // Verify physical schema matches the override + try (ParquetFileReader reader = ParquetFileReader.open(ParquetIO.file(file.toInputFile()))) { + MessageType actualSchema = reader.getFooter().getFileMetaData().getSchema(); + assertThat(actualSchema).isEqualTo(customSchema); + } + + // Verify data round-trips correctly + List writtenRecords; + try (CloseableIterable reader = + Parquet.read(file.toInputFile()) + .project(variantSchema) + .createReaderFunc( + fileSchema -> GenericParquetReaders.buildReader(variantSchema, fileSchema)) + .build()) { + writtenRecords = Lists.newArrayList(reader); + } + + assertThat(writtenRecords).hasSameSizeAs(variantRecords); + for (int i = 0; i < variantRecords.size(); i++) { + InternalTestHelpers.assertEquals( + variantSchema.asStruct(), variantRecords.get(i), writtenRecords.get(i)); + } + } } diff --git a/parquet/src/test/java/org/apache/iceberg/parquet/TestVariantShreddingAnalyzer.java b/parquet/src/test/java/org/apache/iceberg/parquet/TestVariantShreddingAnalyzer.java new file mode 100644 index 000000000000..38996c31a638 --- /dev/null +++ b/parquet/src/test/java/org/apache/iceberg/parquet/TestVariantShreddingAnalyzer.java @@ -0,0 +1,309 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.parquet; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.List; +import java.util.Locale; +import org.apache.iceberg.variants.ShreddedObject; +import org.apache.iceberg.variants.ValueArray; +import org.apache.iceberg.variants.VariantMetadata; +import org.apache.iceberg.variants.VariantValue; +import org.apache.iceberg.variants.Variants; +import org.apache.parquet.schema.GroupType; +import org.apache.parquet.schema.LogicalTypeAnnotation; +import org.apache.parquet.schema.PrimitiveType; +import org.apache.parquet.schema.Type; +import org.junit.jupiter.api.Test; + +public class TestVariantShreddingAnalyzer { + + private static class DirectAnalyzer extends VariantShreddingAnalyzer { + @Override + protected List extractVariantValues(List rows, int idx) { + return rows; + } + } + + @Test + public void testDepthLimitStopsObjectRecursion() { + DirectAnalyzer analyzer = new DirectAnalyzer(); + + // Each level has {"a": , "x": 1} so objects always have a shreddable primitive + VariantMetadata meta = Variants.metadata("a", "x"); + ShreddedObject innermost = Variants.object(meta); + innermost.put("a", Variants.of(42)); + innermost.put("x", Variants.of(1)); + + for (int i = 0; i < 54; i++) { + ShreddedObject wrapper = Variants.object(meta); + wrapper.put("a", innermost); + wrapper.put("x", Variants.of(1)); + innermost = wrapper; + } + + Type schema = analyzer.analyzeAndCreateSchema(List.of(innermost), 0); + assertThat(schema).isNotNull(); + assertThat(schema.getName()).isEqualTo("typed_value"); + + int shreddedDepth = countObjectDepth(schema); + assertThat(shreddedDepth).isLessThanOrEqualTo(50).isGreaterThan(0); + } + + @Test + public void testDepthLimitStopsArrayRecursion() { + DirectAnalyzer analyzer = new DirectAnalyzer(); + + // 55-level nested arrays with a primitive only at the very bottom. + // Depth limit (50) prevents reaching the leaf, so schema is null (graceful degradation). 
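                                        +    // (a null schema is not an error: the caller keeps the default unshredded
                                        +    // metadata/value variant layout when no typed_value can be inferred)
                                        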
+ VariantValue innermost = Variants.of(42); + for (int i = 0; i < 55; i++) { + ValueArray wrapper = Variants.array(); + wrapper.add(innermost); + innermost = wrapper; + } + + Type schema = analyzer.analyzeAndCreateSchema(List.of(innermost), 0); + assertThat(schema).isNull(); + } + + @Test + public void testArrayWithinDepthLimit() { + DirectAnalyzer analyzer = new DirectAnalyzer(); + + // 5-level nested arrays + VariantValue innermost = Variants.of(42); + for (int i = 0; i < 5; i++) { + ValueArray wrapper = Variants.array(); + wrapper.add(innermost); + innermost = wrapper; + } + + Type schema = analyzer.analyzeAndCreateSchema(List.of(innermost), 0); + assertThat(schema).isNotNull(); + assertThat(schema.getName()).isEqualTo("typed_value"); + + int arrayDepth = countArrayDepth(schema); + assertThat(arrayDepth).isEqualTo(5); + } + + @Test + public void testIntermediateFieldCapLimitsTrackedFields() { + int numFields = 1500; + String[] fieldNames = new String[numFields]; + for (int i = 0; i < numFields; i++) { + fieldNames[i] = String.format(Locale.ROOT, "field_%04d", i); + } + + VariantMetadata meta = Variants.metadata(fieldNames); + ShreddedObject obj = Variants.object(meta); + for (String name : fieldNames) { + obj.put(name, Variants.of(42)); + } + + DirectAnalyzer analyzer = new DirectAnalyzer(); + Type schema = analyzer.analyzeAndCreateSchema(List.of(obj), 0); + + assertThat(schema).isNotNull(); + assertThat(schema).isInstanceOf(GroupType.class); + GroupType typedValue = (GroupType) schema; + assertThat(typedValue.getFieldCount()).isLessThanOrEqualTo(300).isGreaterThan(0); + } + + @Test + public void testFieldCapAllowsExistingFieldUpdates() { + int numFields = 1500; + String[] fieldNames = new String[numFields]; + for (int i = 0; i < numFields; i++) { + fieldNames[i] = String.format(Locale.ROOT, "field_%04d", i); + } + + VariantMetadata meta = Variants.metadata(fieldNames); + + ShreddedObject row1 = Variants.object(meta); + for (String name : fieldNames) { + row1.put(name, Variants.of(42)); + } + + ShreddedObject row2 = Variants.object(meta); + for (int i = 0; i < 10; i++) { + row2.put(fieldNames[i], Variants.of("text")); + } + + ShreddedObject row3 = Variants.object(meta); + for (int i = 0; i < 10; i++) { + row3.put(fieldNames[i], Variants.of(99)); + } + + DirectAnalyzer analyzer = new DirectAnalyzer(); + Type schema = analyzer.analyzeAndCreateSchema(List.of(row1, row2, row3), 0); + + assertThat(schema).isNotNull(); + assertThat(schema).isInstanceOf(GroupType.class); + GroupType typedValue = (GroupType) schema; + assertThat(typedValue.getFieldCount()).isGreaterThan(0).isLessThanOrEqualTo(300); + } + + @Test + public void testNestedObjectsWithinDepthLimit() { + VariantMetadata cityMeta = Variants.metadata("city"); + ShreddedObject city = Variants.object(cityMeta); + city.put("city", Variants.of("NYC")); + + VariantMetadata addrMeta = Variants.metadata("address"); + ShreddedObject addr = Variants.object(addrMeta); + addr.put("address", city); + + VariantMetadata rootMeta = Variants.metadata("user"); + ShreddedObject root = Variants.object(rootMeta); + root.put("user", addr); + + DirectAnalyzer analyzer = new DirectAnalyzer(); + Type schema = analyzer.analyzeAndCreateSchema(List.of(root), 0); + + assertThat(schema).isNotNull(); + GroupType rootTv = schema.asGroupType(); + assertThat(rootTv.getName()).isEqualTo("typed_value"); + + // user -> typed_value -> address -> typed_value -> city -> typed_value (STRING) + GroupType userGroup = rootTv.getType("user").asGroupType(); + 
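                                        +    // each shredded field group pairs an optional binary "value" fallback with
                                        +    // the inferred "typed_value" (see buildFieldGroup), so both are expected here
                                        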
assertThat(userGroup.containsField("value")).isTrue(); + assertThat(userGroup.containsField("typed_value")).isTrue(); + + GroupType addrTv = userGroup.getType("typed_value").asGroupType(); + GroupType addrGroup = addrTv.getType("address").asGroupType(); + assertThat(addrGroup.containsField("typed_value")).isTrue(); + + GroupType cityTv = addrGroup.getType("typed_value").asGroupType(); + GroupType cityGroup = cityTv.getType("city").asGroupType(); + assertThat(cityGroup.containsField("typed_value")).isTrue(); + + PrimitiveType cityPrimitive = cityGroup.getType("typed_value").asPrimitiveType(); + assertThat(cityPrimitive.getPrimitiveTypeName()) + .isEqualTo(PrimitiveType.PrimitiveTypeName.BINARY); + assertThat(cityPrimitive.getLogicalTypeAnnotation()) + .isEqualTo(LogicalTypeAnnotation.stringType()); + } + + @Test + public void testDecimalForExceedingPrecision() { + DirectAnalyzer analyzer = new DirectAnalyzer(); + // Value 1: 30 integer digits, 0 fractional -> precision=30, scale=0, intDigits=30 + // Value 2: 1 integer digit, 20 fractional -> precision=21, scale=20, intDigits=1 + // Combined: maxIntDigits=30, maxScale=20, raw sum=50 -> capped to precision=38, + // scale=min(20,38)=20 + VariantMetadata meta = Variants.metadata("val"); + ShreddedObject row1 = Variants.object(meta); + row1.put("val", Variants.of(new java.math.BigDecimal("123456789012345678901234567890"))); + + ShreddedObject row2 = Variants.object(meta); + row2.put("val", Variants.of(new java.math.BigDecimal("1.23456789012345678901"))); + + Type schema = analyzer.analyzeAndCreateSchema(List.of(row1, row2), 0); + assertThat(schema).isNotNull(); + + GroupType typedValue = schema.asGroupType(); + GroupType valGroup = typedValue.getType("val").asGroupType(); + PrimitiveType valPrimitive = valGroup.getType("typed_value").asPrimitiveType(); + + LogicalTypeAnnotation.DecimalLogicalTypeAnnotation decimal = + (LogicalTypeAnnotation.DecimalLogicalTypeAnnotation) + valPrimitive.getLogicalTypeAnnotation(); + assertThat(decimal).isNotNull(); + assertThat(decimal.getPrecision()).isEqualTo(38); + assertThat(decimal.getScale()).isLessThanOrEqualTo(38); + // Scale must not exceed precision + assertThat(decimal.getScale()).isLessThanOrEqualTo(decimal.getPrecision()); + + // Physical type should be FIXED_LEN_BYTE_ARRAY since precision > 18 + assertThat(valPrimitive.getPrimitiveTypeName()) + .isEqualTo(PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY); + } + + @Test + public void testDecimalForExactPrecision() { + DirectAnalyzer analyzer = new DirectAnalyzer(); + + // Value with exactly precision=38: 20 integer digits + 18 scale = 38 + VariantMetadata meta = Variants.metadata("val"); + ShreddedObject row = Variants.object(meta); + row.put( + "val", Variants.of(new java.math.BigDecimal("12345678901234567890.123456789012345678"))); + + Type schema = analyzer.analyzeAndCreateSchema(List.of(row), 0); + assertThat(schema).isNotNull(); + + GroupType typedValue = schema.asGroupType(); + GroupType valGroup = typedValue.getType("val").asGroupType(); + PrimitiveType valPrimitive = valGroup.getType("typed_value").asPrimitiveType(); + + LogicalTypeAnnotation.DecimalLogicalTypeAnnotation decimal = + (LogicalTypeAnnotation.DecimalLogicalTypeAnnotation) + valPrimitive.getLogicalTypeAnnotation(); + assertThat(decimal.getPrecision()).isEqualTo(38); + assertThat(decimal.getScale()).isEqualTo(18); + } + + /** Count typed_value group nesting depth along field "a". 
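                                        +   * Descent stops at the first level that lacks an {@code a} field or a nested {@code typed_value}.
                                        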
*/ + private static int countObjectDepth(Type type) { + int depth = 0; + Type current = type; + while (current != null && "typed_value".equals(current.getName()) && !current.isPrimitive()) { + depth++; + GroupType group = current.asGroupType(); + if (group.containsField("a")) { + GroupType fieldGroup = group.getType("a").asGroupType(); + if (fieldGroup.containsField("typed_value")) { + current = fieldGroup.getType("typed_value"); + } else { + break; + } + } else { + break; + } + } + return depth; + } + + /** Count nested array (LIST) levels in the schema. */ + private static int countArrayDepth(Type type) { + int depth = 0; + Type current = type; + while (current != null && !current.isPrimitive()) { + if (!"typed_value".equals(current.getName())) { + break; + } + GroupType group = current.asGroupType(); + if (!(group.getLogicalTypeAnnotation() + instanceof LogicalTypeAnnotation.ListLogicalTypeAnnotation)) { + break; + } + depth++; + GroupType listGroup = group.getType(0).asGroupType(); + GroupType elementGroup = listGroup.getType(0).asGroupType(); + if (elementGroup.containsField("typed_value")) { + current = elementGroup.getType("typed_value"); + } else { + break; + } + } + return depth; + } +} From 50b01a476a98480d736d25c17e4002d7160baaf7 Mon Sep 17 00:00:00 2001 From: Neelesh Salian Date: Mon, 23 Mar 2026 08:24:30 -0700 Subject: [PATCH 10/17] Wiring the variant shredding write path via BufferedFileAppender --- .../iceberg/spark/SparkSQLProperties.java | 4 +- .../spark/source/SchemaInferenceVisitor.java | 54 +- .../spark/source/SparkFileWriterFactory.java | 111 ++++ .../spark/source/SparkFormatModels.java | 35 +- .../source/SparkVariantShreddingAnalyzer.java | 59 +++ .../source/VariantShreddingAnalyzer.java | 483 ------------------ .../iceberg/spark/TestSparkWriteConf.java | 6 +- .../spark/variant/TestVariantShredding.java | 214 +++++--- 8 files changed, 347 insertions(+), 619 deletions(-) create mode 100644 spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkVariantShreddingAnalyzer.java delete mode 100644 spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/VariantShreddingAnalyzer.java diff --git a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java index b527bda5aea2..a1f299d3dc30 100644 --- a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java +++ b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java @@ -117,11 +117,11 @@ private SparkSQLProperties() {} // Controls whether to shred variant columns during write operations public static final String SHRED_VARIANTS = "spark.sql.iceberg.shred-variants"; - public static final boolean SHRED_VARIANTS_DEFAULT = true; + public static final boolean SHRED_VARIANTS_DEFAULT = false; // Controls the buffer size for variant schema inference during writes // This determines how many rows are buffered before inferring shredded schema public static final String VARIANT_INFERENCE_BUFFER_SIZE = "spark.sql.iceberg.variant.inference.buffer-size"; - public static final int VARIANT_INFERENCE_BUFFER_SIZE_DEFAULT = 10; + public static final int VARIANT_INFERENCE_BUFFER_SIZE_DEFAULT = 100; } diff --git a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SchemaInferenceVisitor.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SchemaInferenceVisitor.java index 06a79b8dcef0..8b7d83972d68 100644 --- 
a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SchemaInferenceVisitor.java +++ b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SchemaInferenceVisitor.java @@ -21,7 +21,6 @@ import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.BINARY; import java.util.List; -import java.util.Map; import org.apache.iceberg.spark.data.ParquetWithSparkSchemaVisitor; import org.apache.iceberg.variants.Variant; import org.apache.parquet.schema.GroupType; @@ -41,18 +40,17 @@ import org.slf4j.LoggerFactory; /** A visitor that infers variant shredding schemas by analyzing buffered rows of data. */ -public class SchemaInferenceVisitor extends ParquetWithSparkSchemaVisitor { +class SchemaInferenceVisitor extends ParquetWithSparkSchemaVisitor { private static final Logger LOG = LoggerFactory.getLogger(SchemaInferenceVisitor.class); private final List bufferedRows; private final StructType sparkSchema; - private final VariantShreddingAnalyzer analyzer; + private final SparkVariantShreddingAnalyzer analyzer; - public SchemaInferenceVisitor( - List bufferedRows, StructType sparkSchema, Map properties) { + public SchemaInferenceVisitor(List bufferedRows, StructType sparkSchema) { this.bufferedRows = bufferedRows; this.sparkSchema = sparkSchema; - this.analyzer = new VariantShreddingAnalyzer(); + this.analyzer = new SparkVariantShreddingAnalyzer(); } @Override @@ -92,36 +90,46 @@ public Type primitive(DataType sPrimitive, PrimitiveType primitive) { @Override public Type list(ArrayType sArray, GroupType array, Type element) { + if (element == null) { + return array; + } + + GroupType repeatedGroup = array.getType(0).asGroupType(); + Types.GroupBuilder repeatedBuilder = + Types.buildGroup(repeatedGroup.getRepetition()).addField(element); + Types.GroupBuilder builder = Types.buildGroup(array.getRepetition()).as(LogicalTypeAnnotation.listType()); - if (array.getId() != null) { builder = builder.id(array.getId().intValue()); } - - if (element != null) { - builder = builder.addField(element); - } + builder = builder.addField(repeatedBuilder.named(repeatedGroup.getName())); return builder.named(array.getName()); } @Override public Type map(MapType sMap, GroupType map, Type key, Type value) { - Types.GroupBuilder builder = - Types.buildGroup(map.getRepetition()).as(LogicalTypeAnnotation.mapType()); - - if (map.getId() != null) { - builder = builder.id(map.getId().intValue()); + if (key == null && value == null) { + return map; } + GroupType repeatedGroup = map.getType(0).asGroupType(); + Types.GroupBuilder repeatedBuilder = Types.buildGroup(repeatedGroup.getRepetition()); if (key != null) { - builder = builder.addField(key); + repeatedBuilder = repeatedBuilder.addField(key); } if (value != null) { - builder = builder.addField(value); + repeatedBuilder = repeatedBuilder.addField(value); } + Types.GroupBuilder builder = + Types.buildGroup(map.getRepetition()).as(LogicalTypeAnnotation.mapType()); + if (map.getId() != null) { + builder = builder.id(map.getId().intValue()); + } + builder = builder.addField(repeatedBuilder.named(repeatedGroup.getName())); + return builder.named(map.getName()); } @@ -132,9 +140,13 @@ public Type variant(VariantType sVariant, GroupType variant) { if (!bufferedRows.isEmpty() && variantFieldIndex >= 0) { Type shreddedType = analyzer.analyzeAndCreateSchema(bufferedRows, variantFieldIndex); if (shreddedType != null) { - return Types.buildGroup(variant.getRepetition()) - .as(LogicalTypeAnnotation.variantType(Variant.VARIANT_SPEC_VERSION)) - 
.id(variant.getId().intValue()) + Types.GroupBuilder builder = + Types.buildGroup(variant.getRepetition()) + .as(LogicalTypeAnnotation.variantType(Variant.VARIANT_SPEC_VERSION)); + if (variant.getId() != null) { + builder = builder.id(variant.getId().intValue()); + } + return builder .required(BINARY) .named("metadata") .optional(BINARY) diff --git a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkFileWriterFactory.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkFileWriterFactory.java index 39110f0b0597..055173a435ab 100644 --- a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkFileWriterFactory.java +++ b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkFileWriterFactory.java @@ -25,7 +25,9 @@ import java.io.IOException; import java.io.UncheckedIOException; +import java.util.List; import java.util.Map; +import java.util.function.Function; import org.apache.iceberg.FileFormat; import org.apache.iceberg.MetricsConfig; import org.apache.iceberg.PartitionSpec; @@ -37,15 +39,23 @@ import org.apache.iceberg.data.RegistryBasedFileWriterFactory; import org.apache.iceberg.deletes.PositionDeleteWriter; import org.apache.iceberg.encryption.EncryptedOutputFile; +import org.apache.iceberg.io.BufferedFileAppender; +import org.apache.iceberg.io.DataWriter; import org.apache.iceberg.io.DeleteSchemaUtil; +import org.apache.iceberg.io.FileAppender; import org.apache.iceberg.orc.ORC; import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.parquet.ParquetSchemaUtil; import org.apache.iceberg.relocated.com.google.common.base.Preconditions; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.spark.SparkSQLProperties; import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.spark.data.ParquetWithSparkSchemaVisitor; import org.apache.iceberg.spark.data.SparkAvroWriter; import org.apache.iceberg.spark.data.SparkOrcWriter; import org.apache.iceberg.spark.data.SparkParquetWriters; +import org.apache.iceberg.types.Types; +import org.apache.parquet.schema.MessageType; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.types.StructType; import org.apache.spark.unsafe.types.UTF8String; @@ -62,6 +72,10 @@ class SparkFileWriterFactory extends RegistryBasedFileWriterFactory writeProperties; + private final Schema dataSchema; + private final StructType dataSparkType; + private final FileFormat dataFileFormat; + private final SortOrder dataSortOrder; /** * @deprecated This constructor is deprecated as of version 1.11.0 and will be removed in 1.12.0. @@ -102,6 +116,10 @@ class SparkFileWriterFactory extends RegistryBasedFileWriterFactory newPositionDeleteWriter( } } + @Override + public DataWriter newDataWriter( + EncryptedOutputFile file, PartitionSpec spec, StructLike partition) { + if (!shouldUseVariantShredding()) { + return super.newDataWriter(file, spec, partition); + } + + int bufferSize = + Integer.parseInt( + writeProperties.getOrDefault( + SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, + String.valueOf(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE_DEFAULT))); + + Map tableProperties = table != null ? table.properties() : ImmutableMap.of(); + MetricsConfig metricsConfig = + table != null ? 
MetricsConfig.forTable(table) : MetricsConfig.getDefault(); + + Function, FileAppender> appenderFactory = + bufferedRows -> { + Preconditions.checkNotNull(bufferedRows, "bufferedRows must not be null"); + MessageType originalSchema = ParquetSchemaUtil.convert(dataSchema, "table"); + + MessageType shreddedSchema = + (MessageType) + ParquetWithSparkSchemaVisitor.visit( + dataSparkType, + originalSchema, + new SchemaInferenceVisitor(bufferedRows, dataSparkType)); + + try { + FileAppender appender = + Parquet.write(file) + .schema(dataSchema) + .withFileSchema(shreddedSchema) + .createWriterFunc( + msgType -> SparkParquetWriters.buildWriter(dataSparkType, msgType)) + .setAll(tableProperties) + .setAll(writeProperties) + .metricsConfig(metricsConfig) + .overwrite() + .build(); + + try { + for (InternalRow row : bufferedRows) { + appender.add(row); + } + } catch (RuntimeException e) { + try { + appender.close(); + } catch (IOException suppressed) { + e.addSuppressed(suppressed); + } + throw e; + } + + return appender; + } catch (IOException e) { + throw new UncheckedIOException("Failed to create shredded variant writer", e); + } + }; + + BufferedFileAppender bufferedAppender = + new BufferedFileAppender<>(bufferSize, appenderFactory, InternalRow::copy); + + return new DataWriter<>( + bufferedAppender, + dataFileFormat, + file.encryptingOutputFile().location(), + spec, + partition, + file.keyMetadata(), + dataSortOrder); + } + static class Builder { private final Table table; private FileFormat dataFileFormat; @@ -361,4 +457,19 @@ private static StructType useOrConvert(StructType sparkType, Schema schema) { return null; } } + + private boolean shouldUseVariantShredding() { + // Variant shredding is currently only supported for Parquet files + if (dataFileFormat != FileFormat.PARQUET) { + return false; + } + + boolean shreddingEnabled = + Boolean.parseBoolean(writeProperties.get(SparkSQLProperties.SHRED_VARIANTS)); + + return shreddingEnabled + && dataSchema != null + && dataSchema.columns().stream() + .anyMatch(field -> field.type() instanceof Types.VariantType); + } } diff --git a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkFormatModels.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkFormatModels.java index 69f51be12158..a1e8a82b4d80 100644 --- a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkFormatModels.java +++ b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkFormatModels.java @@ -18,14 +18,10 @@ */ package org.apache.iceberg.spark.source; -import java.util.Map; -import org.apache.iceberg.Schema; import org.apache.iceberg.avro.AvroFormatModel; -import org.apache.iceberg.formats.BaseFormatModel; import org.apache.iceberg.formats.FormatModelRegistry; import org.apache.iceberg.orc.ORCFormatModel; import org.apache.iceberg.parquet.ParquetFormatModel; -import org.apache.iceberg.parquet.ParquetValueWriter; import org.apache.iceberg.spark.data.SparkAvroWriter; import org.apache.iceberg.spark.data.SparkOrcReader; import org.apache.iceberg.spark.data.SparkOrcWriter; @@ -34,7 +30,6 @@ import org.apache.iceberg.spark.data.SparkPlannedAvroReader; import org.apache.iceberg.spark.data.vectorized.VectorizedSparkOrcReaders; import org.apache.iceberg.spark.data.vectorized.VectorizedSparkParquetReaders; -import org.apache.parquet.schema.MessageType; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.vectorized.ColumnarBatch; @@ -53,7 +48,7 @@ 
public static void register() { ParquetFormatModel.create( InternalRow.class, StructType.class, - new SparkParquetWriterFunction(), + SparkParquetWriters::buildWriter, (icebergSchema, fileSchema, engineSchema, idToConstant) -> SparkParquetReaders.buildReader(icebergSchema, fileSchema, idToConstant))); @@ -83,32 +78,4 @@ public static void register() { } private SparkFormatModels() {} - - /** - * Writer function that checks for variant shredding conditions and returns a writer that performs - * variant shredding if needed. - */ - private static class SparkParquetWriterFunction - implements BaseFormatModel.WriterFunction, StructType, MessageType> { - - @Override - public ParquetValueWriter write( - Schema icebergSchema, MessageType fileSchema, StructType engineSchema) { - return SparkParquetWriters.buildWriter(icebergSchema, fileSchema, engineSchema); - } - - @Override - public ParquetValueWriter write( - Schema icebergSchema, - MessageType fileSchema, - StructType engineSchema, - Map writeProperties) { - if (SparkParquetWriterWithVariantShredding.shouldUseVariantShredding( - writeProperties, icebergSchema)) { - return new SparkParquetWriterWithVariantShredding( - engineSchema, fileSchema, writeProperties); - } - return SparkParquetWriters.buildWriter(icebergSchema, fileSchema, engineSchema); - } - } } diff --git a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkVariantShreddingAnalyzer.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkVariantShreddingAnalyzer.java new file mode 100644 index 000000000000..19e0237ee28d --- /dev/null +++ b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkVariantShreddingAnalyzer.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.List; +import org.apache.iceberg.parquet.VariantShreddingAnalyzer; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.variants.VariantMetadata; +import org.apache.iceberg.variants.VariantValue; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.unsafe.types.VariantVal; + +/** + * Spark-specific implementation that extracts variant values from {@link InternalRow} instances. 
+ */ +class SparkVariantShreddingAnalyzer extends VariantShreddingAnalyzer { + + SparkVariantShreddingAnalyzer() {} + + @Override + protected List extractVariantValues( + List bufferedRows, int variantFieldIndex) { + List values = Lists.newArrayList(); + + for (InternalRow row : bufferedRows) { + if (!row.isNullAt(variantFieldIndex)) { + VariantVal variantVal = row.getVariant(variantFieldIndex); + if (variantVal != null) { + VariantValue variantValue = + VariantValue.from( + VariantMetadata.from( + ByteBuffer.wrap(variantVal.getMetadata()).order(ByteOrder.LITTLE_ENDIAN)), + ByteBuffer.wrap(variantVal.getValue()).order(ByteOrder.LITTLE_ENDIAN)); + values.add(variantValue); + } + } + } + + return values; + } +} diff --git a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/VariantShreddingAnalyzer.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/VariantShreddingAnalyzer.java deleted file mode 100644 index bea8a6318e2f..000000000000 --- a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/VariantShreddingAnalyzer.java +++ /dev/null @@ -1,483 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.iceberg.spark.source; - -import java.math.BigDecimal; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import java.util.List; -import java.util.Map; -import java.util.Set; -import org.apache.iceberg.relocated.com.google.common.collect.Maps; -import org.apache.iceberg.relocated.com.google.common.collect.Sets; -import org.apache.iceberg.variants.PhysicalType; -import org.apache.iceberg.variants.VariantArray; -import org.apache.iceberg.variants.VariantMetadata; -import org.apache.iceberg.variants.VariantObject; -import org.apache.iceberg.variants.VariantPrimitive; -import org.apache.iceberg.variants.VariantValue; -import org.apache.parquet.schema.GroupType; -import org.apache.parquet.schema.LogicalTypeAnnotation; -import org.apache.parquet.schema.PrimitiveType; -import org.apache.parquet.schema.Type; -import org.apache.parquet.schema.Types; -import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.unsafe.types.VariantVal; - -/** - * Analyzes variant data across buffered rows to determine an optimal shredding schema. - * - *

<p>Determinism contract: for a given set of variant values (regardless of row arrival order),
- * this analyzer produces the same shredded schema.
- *
- * <ul>
- *   <li>Object fields use a TreeMap, so field ordering is alphabetical and deterministic.
- *   <li>Type selection picks the most common type with explicit tie-break priority (see
- *       TIE_BREAK_PRIORITY), not enum ordinal.
- *   <li>Integer types (INT8/16/32/64) and decimal types (DECIMAL4/8/16) are each promoted to the
- *       widest observed before competing with other types.
- *   <li>Fields below MIN_FIELD_FREQUENCY frequency are pruned. Above MAX_SHREDDED_FIELDS fields,
- *       the most frequent are kept with alphabetical tie-breaking.
- * </ul>
- *
- * <p>
This contract holds within a single batch. Different batches with different distributions may - * produce different layouts; cross-batch stability requires schema pinning (not yet implemented). - */ -public class VariantShreddingAnalyzer { - private static final String TYPED_VALUE = "typed_value"; - private static final String VALUE = "value"; - private static final String ELEMENT = "element"; - private static final double MIN_FIELD_FREQUENCY = 0.10; - private static final int MAX_SHREDDED_FIELDS = 300; - - public VariantShreddingAnalyzer() {} - - /** - * Analyzes buffered variant values to determine the optimal shredding schema. - * - * @param bufferedRows the buffered rows to analyze - * @param variantFieldIndex the index of the variant field in the rows - * @return the shredded schema type, or null if no shredding should be performed - */ - public Type analyzeAndCreateSchema(List bufferedRows, int variantFieldIndex) { - List variantValues = extractVariantValues(bufferedRows, variantFieldIndex); - if (variantValues.isEmpty()) { - return null; - } - - PathNode root = buildPathTree(variantValues); - PhysicalType rootType = root.info.getMostCommonType(); - if (rootType == null) { - return null; - } - - if (rootType == PhysicalType.OBJECT) { - pruneInfrequentFields(root, variantValues.size()); - } - - return buildTypedValue(root, rootType); - } - - private static List extractVariantValues( - List bufferedRows, int variantFieldIndex) { - List values = new java.util.ArrayList<>(); - - for (InternalRow row : bufferedRows) { - if (!row.isNullAt(variantFieldIndex)) { - VariantVal variantVal = row.getVariant(variantFieldIndex); - if (variantVal != null) { - VariantValue variantValue = - VariantValue.from( - VariantMetadata.from( - ByteBuffer.wrap(variantVal.getMetadata()).order(ByteOrder.LITTLE_ENDIAN)), - ByteBuffer.wrap(variantVal.getValue()).order(ByteOrder.LITTLE_ENDIAN)); - values.add(variantValue); - } - } - } - - return values; - } - - private static PathNode buildPathTree(List variantValues) { - PathNode root = new PathNode(null); - root.info = new FieldInfo(); - - for (VariantValue value : variantValues) { - traverse(root, value); - } - - return root; - } - - private static void pruneInfrequentFields(PathNode node, int totalRows) { - if (node.objectChildren.isEmpty()) { - return; - } - - // Remove fields below frequency threshold - node.objectChildren - .entrySet() - .removeIf( - entry -> { - FieldInfo info = entry.getValue().info; - return info != null - && ((double) info.observationCount / totalRows) < MIN_FIELD_FREQUENCY; - }); - - // Cap at MAX_SHREDDED_FIELDS, keep the most frequently observed - if (node.objectChildren.size() > MAX_SHREDDED_FIELDS) { - List> sorted = - new java.util.ArrayList<>(node.objectChildren.entrySet()); - sorted.sort( - (a, b) -> { - int cmp = - Integer.compare( - b.getValue().info.observationCount, a.getValue().info.observationCount); - return cmp != 0 ? 
cmp : a.getKey().compareTo(b.getKey()); - }); - Set keep = Sets.newHashSet(); - for (int i = 0; i < MAX_SHREDDED_FIELDS; i++) { - keep.add(sorted.get(i).getKey()); - } - node.objectChildren.entrySet().removeIf(entry -> !keep.contains(entry.getKey())); - } - - // Recurse into remaining children - for (PathNode child : node.objectChildren.values()) { - pruneInfrequentFields(child, totalRows); - } - } - - private static void traverse(PathNode node, VariantValue value) { - if (value == null || value.type() == PhysicalType.NULL) { - return; - } - - node.info.observe(value); - - if (value.type() == PhysicalType.OBJECT) { - VariantObject obj = value.asObject(); - for (String fieldName : obj.fieldNames()) { - VariantValue fieldValue = obj.get(fieldName); - if (fieldValue != null) { - PathNode childNode = node.objectChildren.computeIfAbsent(fieldName, PathNode::new); - if (childNode.info == null) { - childNode.info = new FieldInfo(); - } - traverse(childNode, fieldValue); - } - } - } else if (value.type() == PhysicalType.ARRAY) { - VariantArray array = value.asArray(); - int numElements = array.numElements(); - if (node.arrayElement == null) { - node.arrayElement = new PathNode(null); - node.arrayElement.info = new FieldInfo(); - } - for (int i = 0; i < numElements; i++) { - VariantValue element = array.get(i); - if (element != null) { - traverse(node.arrayElement, element); - } - } - } - } - - private static Type buildFieldGroup(PathNode node) { - PhysicalType commonType = node.info.getMostCommonType(); - if (commonType == null) { - return null; - } - - Type typedValue = buildTypedValue(node, commonType); - if (typedValue == null) { - return null; - } - - return Types.buildGroup(Type.Repetition.REQUIRED) - .optional(PrimitiveType.PrimitiveTypeName.BINARY) - .named(VALUE) - .addField(typedValue) - .named(node.fieldName); - } - - private static Type buildTypedValue(PathNode node, PhysicalType physicalType) { - Type typedValue; - if (physicalType == PhysicalType.ARRAY) { - typedValue = createArrayTypedValue(node); - } else if (physicalType == PhysicalType.OBJECT) { - typedValue = createObjectTypedValue(node); - } else { - typedValue = createPrimitiveTypedValue(node.info, physicalType); - } - - return typedValue; - } - - private static Type createObjectTypedValue(PathNode node) { - if (node.objectChildren.isEmpty()) { - return null; - } - - Types.GroupBuilder builder = Types.buildGroup(Type.Repetition.OPTIONAL); - for (PathNode child : node.objectChildren.values()) { - Type fieldType = buildFieldGroup(child); - if (fieldType == null) { - continue; - } - - builder.addField(fieldType); - } - - return builder.named(TYPED_VALUE); - } - - private static Type createArrayTypedValue(PathNode node) { - PathNode elementNode = node.arrayElement; - PhysicalType elementType = elementNode.info.getMostCommonType(); - if (elementType == null) { - return null; - } - Type elementTypedValue = buildTypedValue(elementNode, elementType); - - GroupType elementGroup = - Types.buildGroup(Type.Repetition.REQUIRED) - .optional(PrimitiveType.PrimitiveTypeName.BINARY) - .named(VALUE) - .addField(elementTypedValue) - .named(ELEMENT); - - return Types.optionalList().element(elementGroup).named(TYPED_VALUE); - } - - private static class PathNode { - private final String fieldName; - private final Map objectChildren = Maps.newTreeMap(); - private PathNode arrayElement = null; - private FieldInfo info = null; - - private PathNode(String fieldName) { - this.fieldName = fieldName; - } - } - - /** Use DECIMAL with maximum precision and 
scale as the shredding type */ - private static Type createDecimalTypedValue(FieldInfo info) { - int maxPrecision = info.maxDecimalIntegerDigits + info.maxDecimalScale; - int maxScale = info.maxDecimalScale; - - if (maxPrecision <= 9) { - return Types.optional(PrimitiveType.PrimitiveTypeName.INT32) - .as(LogicalTypeAnnotation.decimalType(maxScale, maxPrecision)) - .named(TYPED_VALUE); - } else if (maxPrecision <= 18) { - return Types.optional(PrimitiveType.PrimitiveTypeName.INT64) - .as(LogicalTypeAnnotation.decimalType(maxScale, maxPrecision)) - .named(TYPED_VALUE); - } else { - return Types.optional(PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY) - .length(16) - .as(LogicalTypeAnnotation.decimalType(maxScale, maxPrecision)) - .named(TYPED_VALUE); - } - } - - private static Type createPrimitiveTypedValue(FieldInfo info, PhysicalType primitiveType) { - return switch (primitiveType) { - case BOOLEAN_TRUE, BOOLEAN_FALSE -> - Types.optional(PrimitiveType.PrimitiveTypeName.BOOLEAN).named(TYPED_VALUE); - case INT8 -> - Types.optional(PrimitiveType.PrimitiveTypeName.INT32) - .as(LogicalTypeAnnotation.intType(8, true)) - .named(TYPED_VALUE); - case INT16 -> - Types.optional(PrimitiveType.PrimitiveTypeName.INT32) - .as(LogicalTypeAnnotation.intType(16, true)) - .named(TYPED_VALUE); - case INT32 -> - Types.optional(PrimitiveType.PrimitiveTypeName.INT32) - .as(LogicalTypeAnnotation.intType(32, true)) - .named(TYPED_VALUE); - case INT64 -> - Types.optional(PrimitiveType.PrimitiveTypeName.INT64) - .as(LogicalTypeAnnotation.intType(64, true)) - .named(TYPED_VALUE); - case FLOAT -> Types.optional(PrimitiveType.PrimitiveTypeName.FLOAT).named(TYPED_VALUE); - case DOUBLE -> Types.optional(PrimitiveType.PrimitiveTypeName.DOUBLE).named(TYPED_VALUE); - case STRING -> - Types.optional(PrimitiveType.PrimitiveTypeName.BINARY) - .as(LogicalTypeAnnotation.stringType()) - .named(TYPED_VALUE); - case BINARY -> Types.optional(PrimitiveType.PrimitiveTypeName.BINARY).named(TYPED_VALUE); - case TIME -> - Types.optional(PrimitiveType.PrimitiveTypeName.INT64) - .as(LogicalTypeAnnotation.timeType(false, LogicalTypeAnnotation.TimeUnit.MICROS)) - .named(TYPED_VALUE); - case DATE -> - Types.optional(PrimitiveType.PrimitiveTypeName.INT32) - .as(LogicalTypeAnnotation.dateType()) - .named(TYPED_VALUE); - case TIMESTAMPTZ -> - Types.optional(PrimitiveType.PrimitiveTypeName.INT64) - .as(LogicalTypeAnnotation.timestampType(true, LogicalTypeAnnotation.TimeUnit.MICROS)) - .named(TYPED_VALUE); - case TIMESTAMPNTZ -> - Types.optional(PrimitiveType.PrimitiveTypeName.INT64) - .as(LogicalTypeAnnotation.timestampType(false, LogicalTypeAnnotation.TimeUnit.MICROS)) - .named(TYPED_VALUE); - case TIMESTAMPTZ_NANOS -> - Types.optional(PrimitiveType.PrimitiveTypeName.INT64) - .as(LogicalTypeAnnotation.timestampType(true, LogicalTypeAnnotation.TimeUnit.NANOS)) - .named(TYPED_VALUE); - case TIMESTAMPNTZ_NANOS -> - Types.optional(PrimitiveType.PrimitiveTypeName.INT64) - .as(LogicalTypeAnnotation.timestampType(false, LogicalTypeAnnotation.TimeUnit.NANOS)) - .named(TYPED_VALUE); - case DECIMAL4, DECIMAL8, DECIMAL16 -> createDecimalTypedValue(info); - default -> - throw new UnsupportedOperationException( - "Unknown primitive physical type: " + primitiveType); - }; - } - - /** Tracks occurrence count and types for a single field. 
*/ - private static class FieldInfo { - private final Set observedTypes = Sets.newHashSet(); - private final Map typeCounts = Maps.newHashMap(); - private int maxDecimalScale = 0; - private int maxDecimalIntegerDigits = 0; - private int observationCount = 0; - - private static final Map INTEGER_PRIORITY = - Map.of( - PhysicalType.INT8, 0, - PhysicalType.INT16, 1, - PhysicalType.INT32, 2, - PhysicalType.INT64, 3); - - private static final Map DECIMAL_PRIORITY = - Map.of( - PhysicalType.DECIMAL4, 0, - PhysicalType.DECIMAL8, 1, - PhysicalType.DECIMAL16, 2); - - private static final Map TIE_BREAK_PRIORITY = - Map.ofEntries( - Map.entry(PhysicalType.BOOLEAN_TRUE, 0), - Map.entry(PhysicalType.INT8, 1), - Map.entry(PhysicalType.INT16, 2), - Map.entry(PhysicalType.INT32, 3), - Map.entry(PhysicalType.INT64, 4), - Map.entry(PhysicalType.FLOAT, 5), - Map.entry(PhysicalType.DOUBLE, 6), - Map.entry(PhysicalType.DECIMAL4, 7), - Map.entry(PhysicalType.DECIMAL8, 8), - Map.entry(PhysicalType.DECIMAL16, 9), - Map.entry(PhysicalType.DATE, 10), - Map.entry(PhysicalType.TIME, 11), - Map.entry(PhysicalType.TIMESTAMPTZ, 12), - Map.entry(PhysicalType.TIMESTAMPNTZ, 13), - Map.entry(PhysicalType.BINARY, 14), - Map.entry(PhysicalType.STRING, 15)); - - void observe(VariantValue value) { - observationCount++; - // Use BOOLEAN_TRUE for both TRUE/FALSE values - PhysicalType type = - value.type() == PhysicalType.BOOLEAN_FALSE ? PhysicalType.BOOLEAN_TRUE : value.type(); - - observedTypes.add(type); - typeCounts.compute(type, (k, v) -> (v == null) ? 1 : v + 1); - - // Track max precision and scale for decimal types - if (type == PhysicalType.DECIMAL4 - || type == PhysicalType.DECIMAL8 - || type == PhysicalType.DECIMAL16) { - VariantPrimitive primitive = value.asPrimitive(); - Object decimalValue = primitive.get(); - if (decimalValue instanceof BigDecimal) { - BigDecimal bd = (BigDecimal) decimalValue; - maxDecimalIntegerDigits = Math.max(maxDecimalIntegerDigits, bd.precision() - bd.scale()); - maxDecimalScale = Math.max(maxDecimalScale, bd.scale()); - } - } - } - - PhysicalType getMostCommonType() { - Map combinedCounts = Maps.newHashMap(); - - int integerTotalCount = 0; - PhysicalType mostCapableInteger = null; - - int decimalTotalCount = 0; - PhysicalType mostCapableDecimal = null; - - for (Map.Entry entry : typeCounts.entrySet()) { - PhysicalType type = entry.getKey(); - int count = entry.getValue(); - - if (isIntegerType(type)) { - integerTotalCount += count; - if (mostCapableInteger == null - || INTEGER_PRIORITY.get(type) > INTEGER_PRIORITY.get(mostCapableInteger)) { - mostCapableInteger = type; - } - } else if (isDecimalType(type)) { - decimalTotalCount += count; - if (mostCapableDecimal == null - || DECIMAL_PRIORITY.get(type) > DECIMAL_PRIORITY.get(mostCapableDecimal)) { - mostCapableDecimal = type; - } - } else { - combinedCounts.put(type, count); - } - } - - if (mostCapableInteger != null) { - combinedCounts.put(mostCapableInteger, integerTotalCount); - } - - if (mostCapableDecimal != null) { - combinedCounts.put(mostCapableDecimal, decimalTotalCount); - } - - // Pick the most common type with tie-breaking - return combinedCounts.entrySet().stream() - .max( - Map.Entry.comparingByValue() - .thenComparingInt(entry -> TIE_BREAK_PRIORITY.getOrDefault(entry.getKey(), -1))) - .map(Map.Entry::getKey) - .orElse(null); - } - - private boolean isIntegerType(PhysicalType type) { - return type == PhysicalType.INT8 - || type == PhysicalType.INT16 - || type == PhysicalType.INT32 - || type == PhysicalType.INT64; - } - - private 
boolean isDecimalType(PhysicalType type) { - return type == PhysicalType.DECIMAL4 - || type == PhysicalType.DECIMAL8 - || type == PhysicalType.DECIMAL16; - } - } -} diff --git a/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java b/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java index fbd04fae1c98..87f03b9fb051 100644 --- a/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java +++ b/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java @@ -346,7 +346,7 @@ public void testSparkConfOverride() { "snappy"), ImmutableMap.of( SHRED_VARIANTS, - "true", + "false", DELETE_PARQUET_COMPRESSION, "zstd", PARQUET_COMPRESSION, @@ -470,7 +470,7 @@ public void testDataPropsDefaultsAsDeleteProps() { "5"), ImmutableMap.of( SHRED_VARIANTS, - "true", + "false", DELETE_PARQUET_COMPRESSION, "zstd", PARQUET_COMPRESSION, @@ -543,7 +543,7 @@ public void testDeleteFileWriteConf() { "6"), ImmutableMap.of( SHRED_VARIANTS, - "true", + "false", DELETE_PARQUET_COMPRESSION, "zstd", PARQUET_COMPRESSION, diff --git a/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java b/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java index e63630cfe3ad..65e3894ccc71 100644 --- a/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java +++ b/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java @@ -131,33 +131,6 @@ public void testVariantShreddingDisabled() throws IOException { verifyParquetSchema(table, expectedSchema); } - @TestTemplate - public void testConsistentType() throws IOException { - spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); - - String values = - "(1, parse_json('{\"name\": \"Alice\", \"age\": 30}'))," - + " (2, parse_json('{\"name\": \"Bob\", \"age\": 25}'))," - + " (3, parse_json('{\"name\": \"Charlie\", \"age\": 35}'))"; - sql("INSERT INTO %s VALUES %s", tableName, values); - - GroupType name = - field( - "name", - shreddedPrimitive( - PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); - GroupType age = - field( - "age", - shreddedPrimitive( - PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true))); - GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(age, name)); - MessageType expectedSchema = parquetSchema(address); - - Table table = validationCatalog.loadTable(tableIdent); - verifyParquetSchema(table, expectedSchema); - } - @TestTemplate public void testExcludingNullValue() throws IOException { spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); @@ -430,30 +403,6 @@ public void testLazyInitializationWithBufferedRows() throws IOException { assertThat(rowCount).isEqualTo(7); } - @TestTemplate - public void testTieBreakingWithEqualCounts() throws IOException { - spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); - - String values = - "(1, parse_json('{\"value\": 10}'))," - + " (2, parse_json('{\"value\": 20}'))," - + " (3, parse_json('{\"value\": \"hello\"}'))," - + " (4, parse_json('{\"value\": \"world\"}'))"; - sql("INSERT INTO %s VALUES %s", tableName, values); - - // When counts are tied, sort the types in order and choose the last one - GroupType value = - field( - "value", - shreddedPrimitive( - PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); - GroupType address = variant("address", 2, Type.Repetition.REQUIRED, 
objectFields(value)); - MessageType expectedSchema = parquetSchema(address); - - Table table = validationCatalog.loadTable(tableIdent); - verifyParquetSchema(table, expectedSchema); - } - @TestTemplate public void testMultipleRowGroups() throws IOException { spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); @@ -755,31 +704,6 @@ public void testWriteOptionOverridesSessionConfig() throws IOException, NoSuchTa verifyParquetSchema(table, expectedSchema); } - @TestTemplate - public void testDefaultShreddingEnabled() throws IOException { - // Not setting SHRED_VARIANTS - default (true) should activate shredding - String values = - "(1, parse_json('{\"name\": \"Alice\", \"age\": 30}'))," - + " (2, parse_json('{\"name\": \"Bob\", \"age\": 25}'))"; - sql("INSERT INTO %s VALUES %s", tableName, values); - - GroupType name = - field( - "name", - shreddedPrimitive( - PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); - GroupType age = - field( - "age", - shreddedPrimitive( - PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true))); - GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(age, name)); - MessageType expectedSchema = parquetSchema(address); - - Table table = validationCatalog.loadTable(tableIdent); - verifyParquetSchema(table, expectedSchema); - } - @TestTemplate public void testInfrequentFieldPruning() throws IOException { spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); @@ -846,6 +770,144 @@ public void testMixedTypeTieBreaking() throws IOException { verifyParquetSchema(table, expectedSchema); } + @TestTemplate + public void testFieldOnlyAfterBuffer() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + spark.conf().set(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, "3"); + + String values = + "(1, parse_json('{\"name\": \"Alice\"}'))," + + " (2, parse_json('{\"name\": \"Bob\"}'))," + + " (3, parse_json('{\"name\": \"Charlie\"}'))," + + " (4, parse_json('{\"name\": \"David\", \"score\": 95}'))," + + " (5, parse_json('{\"name\": \"Eve\", \"score\": 88}'))," + + " (6, parse_json('{\"name\": \"Frank\", \"score\": 72}'))," + + " (7, parse_json('{\"name\": \"Grace\", \"score\": 91}'))"; + sql("INSERT INTO %s VALUES %s", tableName, values); + + // Schema is determined from buffer (rows 1-3) which only has "name". 
+ // "score" is not shredded + GroupType name = + field( + "name", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(name)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + + // Verify all data round-trips despite "score" not being shredded + List rows = + sql( + "SELECT id, variant_get(address, '$.name', 'string')," + + " variant_get(address, '$.score', 'int')" + + " FROM %s ORDER BY id", + tableName); + assertThat(rows).hasSize(7); + assertThat(rows.get(0)[1]).isEqualTo("Alice"); + assertThat(rows.get(0)[2]).isNull(); + assertThat(rows.get(3)[1]).isEqualTo("David"); + assertThat(rows.get(3)[2]).isEqualTo(95); + assertThat(rows.get(6)[1]).isEqualTo("Grace"); + assertThat(rows.get(6)[2]).isEqualTo(91); + } + + @TestTemplate + public void testCrossFileDifferentShreddedType() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + spark.conf().set(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, "3"); + + // File 1: "score" is always integer → shredded as INT8 + String batch1 = + "(1, parse_json('{\"score\": 95}'))," + + " (2, parse_json('{\"score\": 88}'))," + + " (3, parse_json('{\"score\": 72}'))"; + sql("INSERT INTO %s VALUES %s", tableName, batch1); + + // Verify file 1 schema: score shredded as INT8 + Table table = validationCatalog.loadTable(tableIdent); + GroupType scoreInt = + field( + "score", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true))); + MessageType expectedSchema1 = + parquetSchema(variant("address", 2, Type.Repetition.REQUIRED, objectFields(scoreInt))); + verifyParquetSchema(table, expectedSchema1); + + // File 2: "score" is always string → shredded as STRING + String batch2 = + "(4, parse_json('{\"score\": \"high\"}'))," + + " (5, parse_json('{\"score\": \"medium\"}'))," + + " (6, parse_json('{\"score\": \"low\"}'))"; + sql("INSERT INTO %s VALUES %s", tableName, batch2); + + // Query across both files, reader must handle different shredded types + List rows = + sql("SELECT id, variant_get(address, '$.score', 'string') FROM %s ORDER BY id", tableName); + assertThat(rows).hasSize(6); + assertThat(rows.get(0)[1]).isEqualTo("95"); + assertThat(rows.get(1)[1]).isEqualTo("88"); + assertThat(rows.get(3)[1]).isEqualTo("high"); + assertThat(rows.get(5)[1]).isEqualTo("low"); + } + + @TestTemplate + public void testAllNullVariantColumn() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + sql("INSERT INTO %s VALUES (1, null), (2, null), (3, null)", tableName); + + // All variant values are SQL NULL, so no shredding should occur + Table table = validationCatalog.loadTable(tableIdent); + MessageType expectedSchema = parquetSchema(variant("address", 2, Type.Repetition.OPTIONAL)); + verifyParquetSchema(table, expectedSchema); + + List rows = sql("SELECT id, address FROM %s ORDER BY id", tableName); + assertThat(rows).hasSize(3); + assertThat(rows.get(0)[1]).isNull(); + assertThat(rows.get(1)[1]).isNull(); + assertThat(rows.get(2)[1]).isNull(); + } + + @TestTemplate + public void testBufferSizeOne() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + spark.conf().set(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, "1"); + + sql( + "INSERT INTO %s VALUES " + + "(1, parse_json('{\"name\": 
\"Alice\", \"age\": 30}'))," + + " (2, parse_json('{\"name\": \"Bob\", \"age\": 25}'))," + + " (3, parse_json('{\"name\": \"Charlie\", \"age\": 35}'))", + tableName); + + // Schema inferred from first row only, should still shred name and age + GroupType age = + field( + "age", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true))); + GroupType name = + field( + "name", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(age, name)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + + List rows = + sql("SELECT id, variant_get(address, '$.name', 'string') FROM %s ORDER BY id", tableName); + assertThat(rows).hasSize(3); + assertThat(rows.get(0)[1]).isEqualTo("Alice"); + assertThat(rows.get(2)[1]).isEqualTo("Charlie"); + } + private void verifyParquetSchema(Table table, MessageType expectedSchema) throws IOException { try (CloseableIterable tasks = table.newScan().planFiles()) { assertThat(tasks).isNotEmpty(); From b598b5c77851b90b01b3616753063c8759c7019c Mon Sep 17 00:00:00 2001 From: Neelesh Salian Date: Mon, 23 Mar 2026 09:59:58 -0700 Subject: [PATCH 11/17] Fix checkstyle violations in SchemaInferenceVisitor and SparkFileWriterFactory --- .../spark/source/SchemaInferenceVisitor.java | 2 +- .../spark/source/SparkFileWriterFactory.java | 13 ++----------- 2 files changed, 3 insertions(+), 12 deletions(-) diff --git a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SchemaInferenceVisitor.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SchemaInferenceVisitor.java index 8b7d83972d68..3b02f282ed93 100644 --- a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SchemaInferenceVisitor.java +++ b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SchemaInferenceVisitor.java @@ -47,7 +47,7 @@ class SchemaInferenceVisitor extends ParquetWithSparkSchemaVisitor { private final StructType sparkSchema; private final SparkVariantShreddingAnalyzer analyzer; - public SchemaInferenceVisitor(List bufferedRows, StructType sparkSchema) { + SchemaInferenceVisitor(List bufferedRows, StructType sparkSchema) { this.bufferedRows = bufferedRows; this.sparkSchema = sparkSchema; this.analyzer = new SparkVariantShreddingAnalyzer(); diff --git a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkFileWriterFactory.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkFileWriterFactory.java index 055173a435ab..769f53b21624 100644 --- a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkFileWriterFactory.java +++ b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkFileWriterFactory.java @@ -293,17 +293,8 @@ public DataWriter newDataWriter( .overwrite() .build(); - try { - for (InternalRow row : bufferedRows) { - appender.add(row); - } - } catch (RuntimeException e) { - try { - appender.close(); - } catch (IOException suppressed) { - e.addSuppressed(suppressed); - } - throw e; + for (InternalRow row : bufferedRows) { + appender.add(row); } return appender; From 9df8b4329934ca0c40ca9143a4388f3c2073016f Mon Sep 17 00:00:00 2001 From: Neelesh Salian Date: Sun, 29 Mar 2026 20:06:45 -0700 Subject: [PATCH 12/17] Wire variant shredding write path through FormatModel API as per PR feedback --- 
.../org/apache/iceberg/TableProperties.java | 6 + .../iceberg/io/BufferedFileAppender.java | 54 +++-- .../iceberg/io/TestBufferedFileAppender.java | 23 +- docs/docs/configuration.md | 2 + docs/docs/spark-configuration.md | 2 + .../org/apache/iceberg/parquet/Parquet.java | 29 +-- .../iceberg/parquet/ParquetFormatModel.java | 98 ++++++++- .../parquet/VariantShreddingAnalyzer.java | 82 +++++-- .../parquet/TestParquetDataWriter.java | 206 +++++++++++++----- .../parquet/TestVariantShreddingAnalyzer.java | 46 +++- .../iceberg/spark/SparkSQLProperties.java | 2 - .../apache/iceberg/spark/SparkWriteConf.java | 24 +- .../spark/source/SchemaInferenceVisitor.java | 182 ---------------- .../spark/source/SparkFileWriterFactory.java | 102 --------- .../spark/source/SparkFormatModels.java | 4 +- .../source/SparkVariantShreddingAnalyzer.java | 12 +- .../iceberg/spark/TestSparkWriteConf.java | 85 +++++++- .../spark/variant/TestVariantShredding.java | 201 ++++++++++------- 18 files changed, 640 insertions(+), 520 deletions(-) delete mode 100644 spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SchemaInferenceVisitor.java diff --git a/core/src/main/java/org/apache/iceberg/TableProperties.java b/core/src/main/java/org/apache/iceberg/TableProperties.java index 1f778984af17..08ff65b51825 100644 --- a/core/src/main/java/org/apache/iceberg/TableProperties.java +++ b/core/src/main/java/org/apache/iceberg/TableProperties.java @@ -154,6 +154,12 @@ private TableProperties() {} "write.delete.parquet.compression-level"; public static final String PARQUET_COMPRESSION_LEVEL_DEFAULT = null; + public static final String PARQUET_VARIANT_SHRED = "write.parquet.variant.shred"; + public static final boolean PARQUET_VARIANT_SHRED_DEFAULT = false; + public static final String PARQUET_VARIANT_BUFFER_SIZE = + "write.parquet.variant.inference.buffer-size"; + public static final int PARQUET_VARIANT_BUFFER_SIZE_DEFAULT = 100; + public static final String PARQUET_ROW_GROUP_CHECK_MIN_RECORD_COUNT = "write.parquet.row-group-check-min-record-count"; public static final String DELETE_PARQUET_ROW_GROUP_CHECK_MIN_RECORD_COUNT = diff --git a/core/src/main/java/org/apache/iceberg/io/BufferedFileAppender.java b/core/src/main/java/org/apache/iceberg/io/BufferedFileAppender.java index a798da8007ac..15ecf7328ae4 100644 --- a/core/src/main/java/org/apache/iceberg/io/BufferedFileAppender.java +++ b/core/src/main/java/org/apache/iceberg/io/BufferedFileAppender.java @@ -29,11 +29,14 @@ /** * A FileAppender that buffers the first N rows, then creates a delegate appender via a factory. * - *

<p>The factory receives the buffered rows, is responsible for creating the real appender and
- * writing the buffered rows into it before returning. All subsequent {@link #add} calls delegate
- * directly to the real appender.
+ * <p>The factory receives the buffered rows and is responsible for creating the real appender. Row
+ * replay is handled internally. All subsequent {@link #add} calls delegate directly to the real
+ * appender.
 *
- * <p>If fewer than N rows are written before {@link #close}, the factory is called at close time.
+ * <p>

If fewer than {@code bufferSize} rows are written before close, the factory is called with + * whatever rows were buffered. If no rows were written, the factory is not called and no file is + * created on disk. In this case, {@link #metrics()} returns {@code new Metrics(0L)} and {@link + * #length()} returns {@code 0L}. * * @param the row type */ @@ -47,7 +50,16 @@ public class BufferedFileAppender implements FileAppender { /** * @param bufferRowCount number of rows to buffer before creating the delegate appender - * @param appenderFactory given the buffered rows, creates the delegate appender and replays them + * @param appenderFactory given the buffered rows, creates the delegate appender + */ + public BufferedFileAppender( + int bufferRowCount, Function, FileAppender> appenderFactory) { + this(bufferRowCount, appenderFactory, UnaryOperator.identity()); + } + + /** + * @param bufferRowCount number of rows to buffer before creating the delegate appender + * @param appenderFactory given the buffered rows, creates the delegate appender * @param copyFunc copies a row before buffering (needed when row objects are reused, e.g. Spark * InternalRow) */ @@ -62,7 +74,7 @@ public BufferedFileAppender( this.bufferRowCount = bufferRowCount; this.appenderFactory = appenderFactory; this.copyFunc = copyFunc; - this.buffer = Lists.newArrayList(); + this.buffer = Lists.newArrayListWithCapacity(bufferRowCount); } @Override @@ -81,7 +93,9 @@ public void add(D datum) { @Override public Metrics metrics() { Preconditions.checkState(closed, "Cannot return metrics for unclosed appender"); - Preconditions.checkState(delegate != null, "Delegate appender was never created"); + if (delegate == null) { + return new Metrics(0L); + } return delegate.metrics(); } @@ -90,6 +104,8 @@ public long length() { if (delegate != null) { return delegate.length(); } + + // No bytes written to disk yet; data is buffered in memory return 0L; } @@ -98,6 +114,7 @@ public List splitOffsets() { if (delegate != null) { return delegate.splitOffsets(); } + return null; } @@ -105,28 +122,21 @@ public List splitOffsets() { public void close() throws IOException { if (!closed) { this.closed = true; - try { - if (delegate == null) { - initialize(); - } - } catch (RuntimeException e) { - // If initialize fails, attempt to close the delegate if it was partially created - closeDelegate(); - throw e; + if (delegate == null && buffer != null && !buffer.isEmpty()) { + initialize(); + } + if (delegate != null) { + delegate.close(); } - closeDelegate(); - } - } - - private void closeDelegate() throws IOException { - if (delegate != null) { - delegate.close(); } } private void initialize() { delegate = appenderFactory.apply(buffer); Preconditions.checkState(delegate != null, "appenderFactory must not return null"); + for (D row : buffer) { + delegate.add(row); + } buffer = null; } } diff --git a/core/src/test/java/org/apache/iceberg/io/TestBufferedFileAppender.java b/core/src/test/java/org/apache/iceberg/io/TestBufferedFileAppender.java index 7c0f8c401d86..8157800d07f8 100644 --- a/core/src/test/java/org/apache/iceberg/io/TestBufferedFileAppender.java +++ b/core/src/test/java/org/apache/iceberg/io/TestBufferedFileAppender.java @@ -57,12 +57,11 @@ public void before() { private Function, FileAppender> avroFactory(OutputFile out) { return bufferedRows -> { try { - FileAppender appender = - Avro.write(out).createWriterFunc(DataWriter::create).schema(SCHEMA).overwrite().build(); - for (Record row : bufferedRows) { - appender.add(row); - } - return 
appender; + return Avro.write(out) + .createWriterFunc(DataWriter::create) + .schema(SCHEMA) + .overwrite() + .build(); } catch (IOException e) { throw new org.apache.iceberg.exceptions.RuntimeIOException(e); } @@ -214,4 +213,16 @@ public void testAddAllSpanningBuffer() throws IOException { assertThat(actual.get(0).getField("id")).isEqualTo(1L); assertThat(actual.get(3).getField("id")).isEqualTo(4L); } + + @Test + public void testCloseWithNoData() throws IOException { + BufferedFileAppender appender = createAppender(10); + // close immediately with no data written + appender.close(); + // delegate was never created + assertThat(appender.length()).isEqualTo(0L); + assertThat(appender.metrics()).isNotNull(); + assertThat(appender.metrics().recordCount()).isEqualTo(0L); + assertThat(appender.splitOffsets()).isNull(); + } } diff --git a/docs/docs/configuration.md b/docs/docs/configuration.md index f12bcea6afd5..94423af55496 100644 --- a/docs/docs/configuration.md +++ b/docs/docs/configuration.md @@ -49,6 +49,8 @@ Iceberg tables support table properties to configure table behavior, like the de | write.parquet.dict-size-bytes | 2097152 (2 MB) | Parquet dictionary page size | | write.parquet.compression-codec | zstd | Parquet compression codec: zstd, brotli, lz4, gzip, snappy, uncompressed | | write.parquet.compression-level | null | Parquet compression level | +| write.parquet.variant.shred | false | When true, variant columns are written with shredded Parquet encoding for improved query performance | +| write.parquet.variant.inference.buffer-size | 100 | Number of rows to buffer for schema inference when variant shredding is enabled | | write.parquet.bloom-filter-enabled.column.col1 | (not set) | Hint to parquet to write a bloom filter for the column: 'col1' | | write.parquet.bloom-filter-max-bytes | 1048576 (1 MB) | The maximum number of bytes for a bloom filter bitset | | write.parquet.bloom-filter-fpp.column.col1 | 0.01 | The false positive probability for a bloom filter applied to 'col1' (must > 0.0 and < 1.0) | diff --git a/docs/docs/spark-configuration.md b/docs/docs/spark-configuration.md index 01bb773680fd..02c5303c943d 100644 --- a/docs/docs/spark-configuration.md +++ b/docs/docs/spark-configuration.md @@ -181,6 +181,8 @@ val spark = SparkSession.builder() | spark.sql.iceberg.distribution-mode | See [Spark Writes](spark-writes.md#writing-distribution-modes) | Controls distribution strategy during writes | | spark.wap.id | null | [Write-Audit-Publish](branching.md#audit-branch) snapshot staging ID | | spark.wap.branch | null | WAP branch name for snapshot commit | +| spark.sql.iceberg.shred-variants | Table default | When true, variant columns are written with shredded Parquet encoding for improved query performance | +| spark.sql.iceberg.variant.inference.buffer-size | Table default | Number of rows to buffer for schema inference when variant shredding is enabled | | spark.sql.iceberg.compression-codec | Table default | Write compression codec (e.g., `zstd`, `snappy`) | | spark.sql.iceberg.compression-level | Table default | Compression level for Parquet/Avro | | spark.sql.iceberg.compression-strategy | Table default | Compression strategy for ORC | diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/Parquet.java b/parquet/src/main/java/org/apache/iceberg/parquet/Parquet.java index 5d725213fa82..2387d52edf2f 100644 --- a/parquet/src/main/java/org/apache/iceberg/parquet/Parquet.java +++ b/parquet/src/main/java/org/apache/iceberg/parquet/Parquet.java @@ -163,7 +163,6 @@ public 
static class WriteBuilder implements InternalData.WriteBuilder { private final Map config = Maps.newLinkedHashMap(); private Schema schema = null; private VariantShreddingFunction variantShreddingFunc = null; - private MessageType fileSchema = null; private String name = "table"; private WriteSupport writeSupport = null; private BiFunction> createWriterFunc = null; @@ -209,21 +208,6 @@ public WriteBuilder variantShreddingFunc(VariantShreddingFunction func) { return this; } - /** - * Set a pre-computed Parquet {@link MessageType} to use as the file schema, bypassing the - * default conversion from the Iceberg schema. - * - *

The provided schema must have Parquet field IDs that match the Iceberg schema's field IDs. - * This method is mutually exclusive with {@link #variantShreddingFunc}. - * - * @param newFileSchema the Parquet message type to write - * @return this for method chaining - */ - public WriteBuilder withFileSchema(MessageType newFileSchema) { - this.fileSchema = newFileSchema; - return this; - } - @Override public WriteBuilder named(String newName) { this.name = newName; @@ -411,13 +395,7 @@ public FileAppender build() throws IOException { } set("parquet.avro.write-old-list-structure", "false"); - Preconditions.checkArgument( - fileSchema == null || variantShreddingFunc == null, - "Cannot set both withFileSchema and variantShreddingFunc"); - MessageType type = - fileSchema != null - ? fileSchema - : ParquetSchemaUtil.convert(schema, name, variantShreddingFunc); + MessageType type = ParquetSchemaUtil.convert(schema, name, variantShreddingFunc); FileEncryptionProperties fileEncryptionProperties = null; if (fileEncryptionKey != null) { @@ -873,11 +851,6 @@ public DataWriteBuilder variantShreddingFunc(VariantShreddingFunction func) { return this; } - public DataWriteBuilder withFileSchema(MessageType newFileSchema) { - appenderBuilder.withFileSchema(newFileSchema); - return this; - } - public DataWriteBuilder withSpec(PartitionSpec newSpec) { this.spec = newSpec; return this; diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetFormatModel.java b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetFormatModel.java index fbd7a6e97fe2..f6531b5cfe60 100644 --- a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetFormatModel.java +++ b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetFormatModel.java @@ -19,13 +19,16 @@ package org.apache.iceberg.parquet; import java.io.IOException; +import java.io.UncheckedIOException; import java.nio.ByteBuffer; import java.util.Map; import java.util.function.Function; +import java.util.function.UnaryOperator; import org.apache.iceberg.FileContent; import org.apache.iceberg.FileFormat; import org.apache.iceberg.MetricsConfig; import org.apache.iceberg.Schema; +import org.apache.iceberg.TableProperties; import org.apache.iceberg.data.parquet.GenericParquetWriter; import org.apache.iceberg.deletes.PositionDelete; import org.apache.iceberg.encryption.EncryptedOutputFile; @@ -33,6 +36,7 @@ import org.apache.iceberg.formats.BaseFormatModel; import org.apache.iceberg.formats.ModelWriteBuilder; import org.apache.iceberg.formats.ReadBuilder; +import org.apache.iceberg.io.BufferedFileAppender; import org.apache.iceberg.io.CloseableIterable; import org.apache.iceberg.io.DeleteSchemaUtil; import org.apache.iceberg.io.FileAppender; @@ -42,14 +46,21 @@ import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; import org.apache.parquet.column.ParquetProperties; import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.Type; public class ParquetFormatModel extends BaseFormatModel, R, MessageType> { public static final String WRITER_VERSION_KEY = "parquet.writer.version"; + public static final String SHRED_VARIANTS_KEY = TableProperties.PARQUET_VARIANT_SHRED; + public static final String VARIANT_BUFFER_SIZE_KEY = TableProperties.PARQUET_VARIANT_BUFFER_SIZE; + public static final int DEFAULT_BUFFER_SIZE = TableProperties.PARQUET_VARIANT_BUFFER_SIZE_DEFAULT; private final boolean isBatchReader; + private final VariantShreddingAnalyzer variantAnalyzer; + private final UnaryOperator copyFunc; public static 
ParquetFormatModel, Void, Object> forPositionDeletes() { - return new ParquetFormatModel<>(PositionDelete.deleteClass(), Void.class, null, null, false); + return new ParquetFormatModel<>( + PositionDelete.deleteClass(), Void.class, null, null, false, null, null); } public static ParquetFormatModel> create( @@ -57,14 +68,26 @@ public static ParquetFormatModel> create( Class schemaType, WriterFunction, S, MessageType> writerFunction, ReaderFunction, S, MessageType> readerFunction) { - return new ParquetFormatModel<>(type, schemaType, writerFunction, readerFunction, false); + return new ParquetFormatModel<>( + type, schemaType, writerFunction, readerFunction, false, null, null); + } + + public static ParquetFormatModel> create( + Class type, + Class schemaType, + WriterFunction, S, MessageType> writerFunction, + ReaderFunction, S, MessageType> readerFunction, + VariantShreddingAnalyzer variantAnalyzer, + UnaryOperator copyFunc) { + return new ParquetFormatModel<>( + type, schemaType, writerFunction, readerFunction, false, variantAnalyzer, copyFunc); } public static ParquetFormatModel> create( Class type, Class schemaType, ReaderFunction, S, MessageType> batchReaderFunction) { - return new ParquetFormatModel<>(type, schemaType, null, batchReaderFunction, true); + return new ParquetFormatModel<>(type, schemaType, null, batchReaderFunction, true, null, null); } private ParquetFormatModel( @@ -72,9 +95,13 @@ private ParquetFormatModel( Class schemaType, WriterFunction, S, MessageType> writerFunction, ReaderFunction readerFunction, - boolean isBatchReader) { + boolean isBatchReader, + VariantShreddingAnalyzer variantAnalyzer, + UnaryOperator copyFunc) { super(type, schemaType, writerFunction, readerFunction); this.isBatchReader = isBatchReader; + this.variantAnalyzer = variantAnalyzer; + this.copyFunc = copyFunc; } @Override @@ -84,7 +111,7 @@ public FileFormat format() { @Override public ModelWriteBuilder writeBuilder(EncryptedOutputFile outputFile) { - return new WriteBuilderWrapper<>(outputFile, writerFunction()); + return new WriteBuilderWrapper<>(outputFile, writerFunction(), variantAnalyzer, copyFunc); } @Override @@ -95,15 +122,23 @@ public ReadBuilder readBuilder(InputFile inputFile) { private static class WriteBuilderWrapper implements ModelWriteBuilder { private final Parquet.WriteBuilder internal; private final WriterFunction, S, MessageType> writerFunction; + private final VariantShreddingAnalyzer variantAnalyzer; + private final UnaryOperator copyFunc; private Schema schema; private S engineSchema; private FileContent content; + private boolean shreddingEnabled = false; + private int bufferSize = DEFAULT_BUFFER_SIZE; private WriteBuilderWrapper( EncryptedOutputFile outputFile, - WriterFunction, S, MessageType> writerFunction) { + WriterFunction, S, MessageType> writerFunction, + VariantShreddingAnalyzer variantAnalyzer, + UnaryOperator copyFunc) { this.internal = Parquet.write(outputFile); this.writerFunction = writerFunction; + this.variantAnalyzer = variantAnalyzer; + this.copyFunc = copyFunc; } @Override @@ -125,13 +160,15 @@ public ModelWriteBuilder set(String property, String value) { internal.writerVersion(ParquetProperties.WriterVersion.valueOf(value)); } - internal.set(property, value); - return this; - } + if (SHRED_VARIANTS_KEY.equals(property)) { + shreddingEnabled = Boolean.parseBoolean(value); + } - @Override - public ModelWriteBuilder setAll(Map properties) { - internal.setAll(properties); + if (VARIANT_BUFFER_SIZE_KEY.equals(property)) { + bufferSize = 
Integer.parseInt(value); + } + + internal.set(property, value); return this; } @@ -179,12 +216,16 @@ public ModelWriteBuilder withAADPrefix(ByteBuffer aadPrefix) { @Override public FileAppender build() throws IOException { + Preconditions.checkState(content != null, "File content type must be set before building"); switch (content) { case DATA: internal.createContextFunc(Parquet.WriteBuilder.Context::dataContext); internal.createWriterFunc( (icebergSchema, messageType) -> writerFunction.write(icebergSchema, messageType, engineSchema)); + if (shreddingEnabled && variantAnalyzer != null && hasVariantColumns(schema)) { + return buildShreddedAppender(); + } break; case EQUALITY_DELETES: internal.createContextFunc(Parquet.WriteBuilder.Context::deleteContext); @@ -217,6 +258,39 @@ public FileAppender build() throws IOException { return internal.build(); } + + /** + * Creates a {@link BufferedFileAppender} that buffers the first N rows, runs variant shredding + * analysis on them, then creates the real Parquet appender with a shredded schema. + * + *
>
Only top-level variant columns are shredded. Nested variants (inside structs/lists/maps) + * fall through to unshredded 2-field layout because column index resolution only applies to + * top-level fields. + */ + private FileAppender buildShreddedAppender() { + return new BufferedFileAppender<>( + bufferSize, + bufferedRows -> { + Map shreddedTypes = + variantAnalyzer.analyzeVariantColumns(bufferedRows, schema, engineSchema); + + if (!shreddedTypes.isEmpty()) { + internal.variantShreddingFunc((fieldId, name) -> shreddedTypes.get(fieldId)); + } + + try { + return internal.build(); + } catch (IOException e) { + throw new UncheckedIOException("Failed to create shredded variant writer", e); + } + }, + copyFunc); + } + + private static boolean hasVariantColumns(Schema schema) { + return schema != null + && schema.columns().stream().anyMatch(field -> field.type().isVariantType()); + } } private static class ReadBuilderWrapper implements ReadBuilder { diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/VariantShreddingAnalyzer.java b/parquet/src/main/java/org/apache/iceberg/parquet/VariantShreddingAnalyzer.java index 2659c5f2aaee..2dcbf66ce283 100644 --- a/parquet/src/main/java/org/apache/iceberg/parquet/VariantShreddingAnalyzer.java +++ b/parquet/src/main/java/org/apache/iceberg/parquet/VariantShreddingAnalyzer.java @@ -22,6 +22,8 @@ import java.util.List; import java.util.Map; import java.util.Set; +import org.apache.iceberg.Schema; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; import org.apache.iceberg.relocated.com.google.common.collect.Lists; import org.apache.iceberg.relocated.com.google.common.collect.Maps; import org.apache.iceberg.relocated.com.google.common.collect.Sets; @@ -62,8 +64,9 @@ * {@link VariantValue} instances. * * @param the engine-specific row type (e.g., Spark InternalRow, Flink RowData) + * @param the engine-specific schema type (e.g., Spark StructType, Flink RowType) */ -public abstract class VariantShreddingAnalyzer { +public abstract class VariantShreddingAnalyzer { private static final String TYPED_VALUE = "typed_value"; private static final String VALUE = "value"; private static final String ELEMENT = "element"; @@ -103,6 +106,38 @@ public Type analyzeAndCreateSchema(List bufferedRows, int variantFieldIndex) protected abstract List extractVariantValues( List bufferedRows, int variantFieldIndex); + /** + * Resolves a column name to its index in the engine-specific schema. Returns -1 if the column is + * not found. + */ + protected abstract int resolveColumnIndex(S engineSchema, String columnName); + + /** + * Analyzes all variant columns in the schema, resolving column indices via the engine-specific + * {@link #resolveColumnIndex} method. 
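// A minimal sketch of the buffer-then-build flow that buildShreddedAppender
// wires up. Every name below is invented for illustration; this is not the
// Iceberg BufferedFileAppender API, only the shape of the pattern: buffer rows
// until the sample is full, build the real writer from that sample (this is
// where schema inference runs), then replay the buffer into it.
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;
import java.util.function.Function;

final class DeferredWriter<T> {
  private final int bufferSize;
  private final Function<List<T>, Consumer<T>> factory;
  private final List<T> buffer = new ArrayList<>();
  private Consumer<T> delegate;

  DeferredWriter(int bufferSize, Function<List<T>, Consumer<T>> factory) {
    this.bufferSize = bufferSize;
    this.factory = factory;
  }

  void add(T row) {
    if (delegate != null) {
      delegate.accept(row); // already initialized: write through
    } else {
      buffer.add(row);
      if (buffer.size() >= bufferSize) {
        delegate = factory.apply(buffer); // inspect the sample, create the writer
        buffer.forEach(delegate); // replay the buffered rows into the real writer
        buffer.clear();
      }
    }
  }
}
// The production path additionally copies each buffered row (copyFunc, e.g.
// InternalRow::copy) because engines such as Spark reuse row instances.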
+ * + * @param bufferedRows the buffered rows to analyze + * @param icebergSchema the Iceberg table schema + * @param engineSchema the engine-specific schema used to resolve column indices + * @return a map from Iceberg field ID to the shredded Parquet type for each variant column + */ + public Map analyzeVariantColumns( + List bufferedRows, Schema icebergSchema, S engineSchema) { + Map shreddedTypes = Maps.newHashMap(); + for (org.apache.iceberg.types.Types.NestedField col : icebergSchema.columns()) { + if (col.type().isVariantType()) { + int rowIndex = resolveColumnIndex(engineSchema, col.name()); + if (rowIndex >= 0) { + Type typed = analyzeAndCreateSchema(bufferedRows, rowIndex); + if (typed != null) { + shreddedTypes.put(col.fieldId(), typed); + } + } + } + } + return shreddedTypes; + } + private static PathNode buildPathTree(List variantValues) { PathNode root = new PathNode(null); root.info = new FieldInfo(); @@ -369,39 +404,40 @@ private static class FieldInfo { private int observationCount = 0; private static final Map INTEGER_PRIORITY = - Map.of( + ImmutableMap.of( PhysicalType.INT8, 0, PhysicalType.INT16, 1, PhysicalType.INT32, 2, PhysicalType.INT64, 3); private static final Map DECIMAL_PRIORITY = - Map.of( + ImmutableMap.of( PhysicalType.DECIMAL4, 0, PhysicalType.DECIMAL8, 1, PhysicalType.DECIMAL16, 2); private static final Map TIE_BREAK_PRIORITY = - Map.ofEntries( - Map.entry(PhysicalType.BOOLEAN_TRUE, 0), - Map.entry(PhysicalType.INT8, 1), - Map.entry(PhysicalType.INT16, 2), - Map.entry(PhysicalType.INT32, 3), - Map.entry(PhysicalType.INT64, 4), - Map.entry(PhysicalType.FLOAT, 5), - Map.entry(PhysicalType.DOUBLE, 6), - Map.entry(PhysicalType.DECIMAL4, 7), - Map.entry(PhysicalType.DECIMAL8, 8), - Map.entry(PhysicalType.DECIMAL16, 9), - Map.entry(PhysicalType.DATE, 10), - Map.entry(PhysicalType.TIME, 11), - Map.entry(PhysicalType.TIMESTAMPTZ, 12), - Map.entry(PhysicalType.TIMESTAMPNTZ, 13), - Map.entry(PhysicalType.BINARY, 14), - Map.entry(PhysicalType.STRING, 15), - Map.entry(PhysicalType.TIMESTAMPTZ_NANOS, 16), - Map.entry(PhysicalType.TIMESTAMPNTZ_NANOS, 17), - Map.entry(PhysicalType.UUID, 18)); + ImmutableMap.builder() + .put(PhysicalType.BOOLEAN_TRUE, 0) + .put(PhysicalType.INT8, 1) + .put(PhysicalType.INT16, 2) + .put(PhysicalType.INT32, 3) + .put(PhysicalType.INT64, 4) + .put(PhysicalType.FLOAT, 5) + .put(PhysicalType.DOUBLE, 6) + .put(PhysicalType.DECIMAL4, 7) + .put(PhysicalType.DECIMAL8, 8) + .put(PhysicalType.DECIMAL16, 9) + .put(PhysicalType.DATE, 10) + .put(PhysicalType.TIME, 11) + .put(PhysicalType.TIMESTAMPTZ, 12) + .put(PhysicalType.TIMESTAMPNTZ, 13) + .put(PhysicalType.BINARY, 14) + .put(PhysicalType.STRING, 15) + .put(PhysicalType.TIMESTAMPTZ_NANOS, 16) + .put(PhysicalType.TIMESTAMPNTZ_NANOS, 17) + .put(PhysicalType.UUID, 18) + .buildOrThrow(); void observe(VariantValue value) { observationCount++; diff --git a/parquet/src/test/java/org/apache/iceberg/parquet/TestParquetDataWriter.java b/parquet/src/test/java/org/apache/iceberg/parquet/TestParquetDataWriter.java index 74e8e62fea0f..4893bece5a7d 100644 --- a/parquet/src/test/java/org/apache/iceberg/parquet/TestParquetDataWriter.java +++ b/parquet/src/test/java/org/apache/iceberg/parquet/TestParquetDataWriter.java @@ -42,8 +42,11 @@ import org.apache.iceberg.data.Record; import org.apache.iceberg.data.parquet.GenericParquetReaders; import org.apache.iceberg.data.parquet.GenericParquetWriter; +import org.apache.iceberg.encryption.EncryptedFiles; +import org.apache.iceberg.io.BufferedFileAppender; import 
org.apache.iceberg.io.CloseableIterable; import org.apache.iceberg.io.DataWriter; +import org.apache.iceberg.io.FileAppender; import org.apache.iceberg.io.OutputFile; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; @@ -52,13 +55,11 @@ import org.apache.iceberg.variants.Variant; import org.apache.iceberg.variants.VariantMetadata; import org.apache.iceberg.variants.VariantTestUtil; +import org.apache.iceberg.variants.VariantValue; import org.apache.iceberg.variants.Variants; import org.apache.parquet.hadoop.ParquetFileReader; import org.apache.parquet.schema.GroupType; -import org.apache.parquet.schema.LogicalTypeAnnotation; import org.apache.parquet.schema.MessageType; -import org.apache.parquet.schema.PrimitiveType; -import org.apache.parquet.schema.Type; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; @@ -336,74 +337,173 @@ public void testDataWriterWithVariantShredding() throws IOException { } @Test - public void testWithFileSchemaOverride() throws IOException { + public void testShreddingWriteReturnsBufferedAppender() throws IOException { Schema variantSchema = new Schema( - ImmutableList.builder() - .addAll(SCHEMA.columns()) - .add(Types.NestedField.optional(4, "variant", Types.VariantType.get())) - .build()); - MessageType customSchema = - org.apache.parquet.schema.Types.buildMessage() - .required(PrimitiveType.PrimitiveTypeName.INT64) - .id(1) - .named("id") - .optional(PrimitiveType.PrimitiveTypeName.BINARY) - .id(2) - .named("data") - .optional(PrimitiveType.PrimitiveTypeName.BINARY) - .id(3) - .named("binary") - .addField( - org.apache.parquet.schema.Types.buildGroup(Type.Repetition.OPTIONAL) - .id(4) - .as(LogicalTypeAnnotation.variantType(Variant.VARIANT_SPEC_VERSION)) - .required(PrimitiveType.PrimitiveTypeName.BINARY) - .named("metadata") - .optional(PrimitiveType.PrimitiveTypeName.BINARY) - .named("value") - .optional(PrimitiveType.PrimitiveTypeName.INT32) - .named("typed_value") - .named("variant")) - .named("table"); - ByteBuffer metadataBuffer = VariantTestUtil.createMetadata(ImmutableList.of("a"), true); + Types.NestedField.required(1, "id", Types.LongType.get()), + Types.NestedField.optional(2, "v", Types.VariantType.get())); + + VariantShreddingAnalyzer testAnalyzer = + new VariantShreddingAnalyzer() { + @Override + protected List extractVariantValues(List rows, int idx) { + return java.util.Collections.emptyList(); + } + + @Override + protected int resolveColumnIndex(Void engineSchema, String columnName) { + return -1; + } + }; + + OutputFile outputFile = Files.localOutput(createTempFile(temp)); + + ParquetFormatModel> model = + ParquetFormatModel.create( + Record.class, + Void.class, + (icebergSchema, messageType, engineSchema) -> + GenericParquetWriter.create(icebergSchema, messageType), + (icebergSchema, fileSchema, engineSchema, idToConstant) -> + GenericParquetReaders.buildReader(icebergSchema, fileSchema), + testAnalyzer, + record -> record); + + try (FileAppender appender = + model + .writeBuilder(EncryptedFiles.plainAsEncryptedOutput(outputFile)) + .schema(variantSchema) + .setAll(ImmutableMap.of(ParquetFormatModel.SHRED_VARIANTS_KEY, "true")) + .content(FileContent.DATA) + .build()) { + assertThat(appender).isInstanceOf(BufferedFileAppender.class); + } + } + + @Test + public void testWriteBuilderReturnsDirectAppenderWithNullAnalyzer() throws IOException { + Schema variantSchema = + new 
Schema( + Types.NestedField.required(1, "id", Types.LongType.get()), + Types.NestedField.optional(2, "v", Types.VariantType.get())); + + OutputFile outputFile = Files.localOutput(createTempFile(temp)); + + ParquetFormatModel> model = + ParquetFormatModel.create( + Record.class, + Void.class, + (icebergSchema, messageType, engineSchema) -> + GenericParquetWriter.create(icebergSchema, messageType), + (icebergSchema, fileSchema, engineSchema, idToConstant) -> + GenericParquetReaders.buildReader(icebergSchema, fileSchema), + null, + null); + + try (FileAppender appender = + model + .writeBuilder(EncryptedFiles.plainAsEncryptedOutput(outputFile)) + .schema(variantSchema) + .setAll(ImmutableMap.of(ParquetFormatModel.SHRED_VARIANTS_KEY, "true")) + .content(FileContent.DATA) + .build()) { + // Even with shredding property set, null variantAnalyzer means no BufferedFileAppender + assertThat(appender).isNotInstanceOf(BufferedFileAppender.class); + } + } + + @Test + public void testFormatModelVariantShreddingRoundTrip() throws IOException { + Schema variantSchema = + new Schema( + Types.NestedField.required(1, "id", Types.LongType.get()), + Types.NestedField.optional(2, "v", Types.VariantType.get())); + + VariantShreddingAnalyzer analyzer = + new VariantShreddingAnalyzer() { + @Override + protected List extractVariantValues(List rows, int idx) { + List values = Lists.newArrayList(); + for (Record row : rows) { + Object obj = row.get(idx); + if (obj instanceof Variant) { + values.add(((Variant) obj).value()); + } + } + return values; + } + + @Override + protected int resolveColumnIndex(Void engineSchema, String columnName) { + // GenericRecord uses schema column order + return variantSchema.columns().indexOf(variantSchema.findField(columnName)); + } + }; + + ByteBuffer metadataBuffer = VariantTestUtil.createMetadata(ImmutableList.of("a", "b"), true); VariantMetadata metadata = Variants.metadata(metadataBuffer); ByteBuffer objectBuffer = - VariantTestUtil.createObject(metadataBuffer, ImmutableMap.of("a", Variants.of(42))); + VariantTestUtil.createObject( + metadataBuffer, + ImmutableMap.of( + "a", Variants.of(42), + "b", Variants.of("hello"))); Variant variant = Variant.of(metadata, Variants.value(metadata, objectBuffer)); GenericRecord record = GenericRecord.create(variantSchema); List variantRecords = ImmutableList.of( - record.copy(ImmutableMap.of("id", 1L, "variant", variant)), - record.copy(ImmutableMap.of("id", 2L, "variant", variant))); - - OutputFile file = Files.localOutput(createTempFile(temp)); - DataWriter dataWriter = - Parquet.writeData(file) + record.copy(ImmutableMap.of("id", 1L, "v", variant)), + record.copy(ImmutableMap.of("id", 2L, "v", variant)), + record.copy(ImmutableMap.of("id", 3L, "v", variant))); + + OutputFile outputFile = Files.localOutput(createTempFile(temp)); + + ParquetFormatModel> model = + ParquetFormatModel.create( + Record.class, + Void.class, + (icebergSchema, messageType, engineSchema) -> + GenericParquetWriter.create(icebergSchema, messageType), + (icebergSchema, fileSchema, engineSchema, idToConstant) -> + GenericParquetReaders.buildReader(icebergSchema, fileSchema), + analyzer, + record1 -> record1); + + try (FileAppender appender = + model + .writeBuilder(EncryptedFiles.plainAsEncryptedOutput(outputFile)) .schema(variantSchema) - .withFileSchema(customSchema) - .createWriterFunc(GenericParquetWriter::create) - .overwrite() - .withSpec(PartitionSpec.unpartitioned()) - .build(); - - try (dataWriter) { + .setAll( + ImmutableMap.of( + 
ParquetFormatModel.SHRED_VARIANTS_KEY, "true", + ParquetFormatModel.VARIANT_BUFFER_SIZE_KEY, "2")) + .content(FileContent.DATA) + .build()) { + assertThat(appender).isInstanceOf(BufferedFileAppender.class); for (Record rec : variantRecords) { - dataWriter.write(rec); + appender.add(rec); } } - // Verify physical schema matches the override - try (ParquetFileReader reader = ParquetFileReader.open(ParquetIO.file(file.toInputFile()))) { - MessageType actualSchema = reader.getFooter().getFileMetaData().getSchema(); - assertThat(actualSchema).isEqualTo(customSchema); + // Verify shredded Parquet schema + try (ParquetFileReader reader = + ParquetFileReader.open(ParquetIO.file(outputFile.toInputFile()))) { + MessageType parquetSchema = reader.getFooter().getFileMetaData().getSchema(); + GroupType variantGroup = parquetSchema.getType("v").asGroupType(); + assertThat(variantGroup.containsField("metadata")).isTrue(); + assertThat(variantGroup.containsField("value")).isTrue(); + assertThat(variantGroup.containsField("typed_value")).isTrue(); + + GroupType typedValue = variantGroup.getType("typed_value").asGroupType(); + assertThat(typedValue.containsField("a")).isTrue(); + assertThat(typedValue.containsField("b")).isTrue(); } - // Verify data round-trips correctly + // Verify data round-trips List writtenRecords; try (CloseableIterable reader = - Parquet.read(file.toInputFile()) + Parquet.read(outputFile.toInputFile()) .project(variantSchema) .createReaderFunc( fileSchema -> GenericParquetReaders.buildReader(variantSchema, fileSchema)) diff --git a/parquet/src/test/java/org/apache/iceberg/parquet/TestVariantShreddingAnalyzer.java b/parquet/src/test/java/org/apache/iceberg/parquet/TestVariantShreddingAnalyzer.java index 38996c31a638..d87799cf19e6 100644 --- a/parquet/src/test/java/org/apache/iceberg/parquet/TestVariantShreddingAnalyzer.java +++ b/parquet/src/test/java/org/apache/iceberg/parquet/TestVariantShreddingAnalyzer.java @@ -22,6 +22,7 @@ import java.util.List; import java.util.Locale; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; import org.apache.iceberg.variants.ShreddedObject; import org.apache.iceberg.variants.ValueArray; import org.apache.iceberg.variants.VariantMetadata; @@ -35,11 +36,16 @@ public class TestVariantShreddingAnalyzer { - private static class DirectAnalyzer extends VariantShreddingAnalyzer { + private static class DirectAnalyzer extends VariantShreddingAnalyzer { @Override protected List extractVariantValues(List rows, int idx) { return rows; } + + @Override + protected int resolveColumnIndex(Void engineSchema, String columnName) { + throw new UnsupportedOperationException("Not used in direct tests"); + } } @Test @@ -228,7 +234,7 @@ public void testDecimalForExceedingPrecision() { valPrimitive.getLogicalTypeAnnotation(); assertThat(decimal).isNotNull(); assertThat(decimal.getPrecision()).isEqualTo(38); - assertThat(decimal.getScale()).isLessThanOrEqualTo(38); + assertThat(decimal.getScale()).isEqualTo(20); // Scale must not exceed precision assertThat(decimal.getScale()).isLessThanOrEqualTo(decimal.getPrecision()); @@ -261,6 +267,42 @@ public void testDecimalForExactPrecision() { assertThat(decimal.getScale()).isEqualTo(18); } + @Test + public void testInfrequentFieldsArePruned() { + DirectAnalyzer analyzer = new DirectAnalyzer(); + + VariantMetadata meta = Variants.metadata("common", "rare"); + + // 100 rows: "common" in all 100, "rare" in only 5 (< 10% threshold) + List rows = Lists.newArrayList(); + for (int i = 0; i < 100; i++) { + ShreddedObject 
obj = Variants.object(meta); + obj.put("common", Variants.of(i)); + if (i < 5) { + obj.put("rare", Variants.of("text")); + } + rows.add(obj); + } + + Type schema = analyzer.analyzeAndCreateSchema(rows, 0); + assertThat(schema).isNotNull(); + + GroupType group = schema.asGroupType(); + assertThat(group.containsField("common")).isTrue(); + assertThat(group.containsField("rare")).isFalse(); + } + + @Test + public void testEmptyArrayReturnsNull() { + DirectAnalyzer analyzer = new DirectAnalyzer(); + + // All rows are empty arrays, no element type to infer + List rows = List.of(Variants.array(), Variants.array(), Variants.array()); + + Type schema = analyzer.analyzeAndCreateSchema(rows, 0); + assertThat(schema).isNull(); + } + /** Count typed_value group nesting depth along field "a". */ private static int countObjectDepth(Type type) { int depth = 0; diff --git a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java index a1f299d3dc30..d5e36c86edad 100644 --- a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java +++ b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java @@ -117,11 +117,9 @@ private SparkSQLProperties() {} // Controls whether to shred variant columns during write operations public static final String SHRED_VARIANTS = "spark.sql.iceberg.shred-variants"; - public static final boolean SHRED_VARIANTS_DEFAULT = false; // Controls the buffer size for variant schema inference during writes // This determines how many rows are buffered before inferring shredded schema public static final String VARIANT_INFERENCE_BUFFER_SIZE = "spark.sql.iceberg.variant.inference.buffer-size"; - public static final int VARIANT_INFERENCE_BUFFER_SIZE_DEFAULT = 100; } diff --git a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java index 4e47c38c14c4..b96b777c2a59 100644 --- a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java +++ b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java @@ -33,6 +33,8 @@ import static org.apache.iceberg.TableProperties.ORC_COMPRESSION_STRATEGY; import static org.apache.iceberg.TableProperties.PARQUET_COMPRESSION; import static org.apache.iceberg.TableProperties.PARQUET_COMPRESSION_LEVEL; +import static org.apache.iceberg.TableProperties.PARQUET_VARIANT_BUFFER_SIZE; +import static org.apache.iceberg.TableProperties.PARQUET_VARIANT_SHRED; import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.DELETE; import java.util.Locale; @@ -485,16 +487,12 @@ private Map dataWriteProperties() { writeProperties.put(PARQUET_COMPRESSION_LEVEL, parquetCompressionLevel); } boolean shouldShredVariants = shredVariants(); - writeProperties.put(SparkSQLProperties.SHRED_VARIANTS, String.valueOf(shouldShredVariants)); + writeProperties.put(PARQUET_VARIANT_SHRED, String.valueOf(shouldShredVariants)); // Add variant shredding configuration properties if (shouldShredVariants) { - String variantBufferSize = - sessionConf.get(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, null); - if (variantBufferSize != null) { - writeProperties.put( - SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, variantBufferSize); - } + writeProperties.put( + PARQUET_VARIANT_BUFFER_SIZE, String.valueOf(variantInferenceBufferSize())); } break; @@ -722,7 +720,17 @@ public boolean shredVariants() { 
.booleanConf() .option(SparkWriteOptions.SHRED_VARIANTS) .sessionConf(SparkSQLProperties.SHRED_VARIANTS) - .defaultValue(SparkSQLProperties.SHRED_VARIANTS_DEFAULT) + .tableProperty(TableProperties.PARQUET_VARIANT_SHRED) + .defaultValue(TableProperties.PARQUET_VARIANT_SHRED_DEFAULT) + .parse(); + } + + public int variantInferenceBufferSize() { + return confParser + .intConf() + .sessionConf(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE) + .tableProperty(TableProperties.PARQUET_VARIANT_BUFFER_SIZE) + .defaultValue(TableProperties.PARQUET_VARIANT_BUFFER_SIZE_DEFAULT) .parse(); } } diff --git a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SchemaInferenceVisitor.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SchemaInferenceVisitor.java deleted file mode 100644 index 3b02f282ed93..000000000000 --- a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SchemaInferenceVisitor.java +++ /dev/null @@ -1,182 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.iceberg.spark.source; - -import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.BINARY; - -import java.util.List; -import org.apache.iceberg.spark.data.ParquetWithSparkSchemaVisitor; -import org.apache.iceberg.variants.Variant; -import org.apache.parquet.schema.GroupType; -import org.apache.parquet.schema.LogicalTypeAnnotation; -import org.apache.parquet.schema.MessageType; -import org.apache.parquet.schema.PrimitiveType; -import org.apache.parquet.schema.Type; -import org.apache.parquet.schema.Types; -import org.apache.parquet.schema.Types.MessageTypeBuilder; -import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.types.ArrayType; -import org.apache.spark.sql.types.DataType; -import org.apache.spark.sql.types.MapType; -import org.apache.spark.sql.types.StructType; -import org.apache.spark.sql.types.VariantType; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** A visitor that infers variant shredding schemas by analyzing buffered rows of data. 
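// The confParser chains above encode a fixed precedence for the shredding
// settings: write option, then Spark session conf, then table property, then
// the default. An invented standalone resolver showing that order (this is
// not the Iceberg ConfParser API, just the lookup it expresses):
import java.util.Map;
import java.util.Optional;

final class PrecedenceResolver {
  static boolean resolveBoolean(
      Map<String, String> writeOptions,
      Map<String, String> sessionConf,
      Map<String, String> tableProperties,
      String optionKey,
      String sessionKey,
      String tableKey,
      boolean defaultValue) {
    return Optional.ofNullable(writeOptions.get(optionKey))
        .or(() -> Optional.ofNullable(sessionConf.get(sessionKey)))
        .or(() -> Optional.ofNullable(tableProperties.get(tableKey)))
        .map(Boolean::parseBoolean)
        .orElse(defaultValue);
  }
}
// This is the ordering the TestSparkWriteConf cases later in this patch
// assert: a session conf overrides the table property, and a write option
// overrides the session conf.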
*/ -class SchemaInferenceVisitor extends ParquetWithSparkSchemaVisitor { - private static final Logger LOG = LoggerFactory.getLogger(SchemaInferenceVisitor.class); - - private final List bufferedRows; - private final StructType sparkSchema; - private final SparkVariantShreddingAnalyzer analyzer; - - SchemaInferenceVisitor(List bufferedRows, StructType sparkSchema) { - this.bufferedRows = bufferedRows; - this.sparkSchema = sparkSchema; - this.analyzer = new SparkVariantShreddingAnalyzer(); - } - - @Override - public Type message(StructType sStruct, MessageType message, List fields) { - MessageTypeBuilder builder = Types.buildMessage(); - - for (Type field : fields) { - if (field != null) { - builder.addField(field); - } - } - - return builder.named(message.getName()); - } - - @Override - public Type struct(StructType sStruct, GroupType struct, List fields) { - Types.GroupBuilder builder = Types.buildGroup(struct.getRepetition()); - - if (struct.getId() != null) { - builder = builder.id(struct.getId().intValue()); - } - - for (Type field : fields) { - if (field != null) { - builder = builder.addField(field); - } - } - - return builder.named(struct.getName()); - } - - @Override - public Type primitive(DataType sPrimitive, PrimitiveType primitive) { - return primitive; - } - - @Override - public Type list(ArrayType sArray, GroupType array, Type element) { - if (element == null) { - return array; - } - - GroupType repeatedGroup = array.getType(0).asGroupType(); - Types.GroupBuilder repeatedBuilder = - Types.buildGroup(repeatedGroup.getRepetition()).addField(element); - - Types.GroupBuilder builder = - Types.buildGroup(array.getRepetition()).as(LogicalTypeAnnotation.listType()); - if (array.getId() != null) { - builder = builder.id(array.getId().intValue()); - } - builder = builder.addField(repeatedBuilder.named(repeatedGroup.getName())); - - return builder.named(array.getName()); - } - - @Override - public Type map(MapType sMap, GroupType map, Type key, Type value) { - if (key == null && value == null) { - return map; - } - - GroupType repeatedGroup = map.getType(0).asGroupType(); - Types.GroupBuilder repeatedBuilder = Types.buildGroup(repeatedGroup.getRepetition()); - if (key != null) { - repeatedBuilder = repeatedBuilder.addField(key); - } - if (value != null) { - repeatedBuilder = repeatedBuilder.addField(value); - } - - Types.GroupBuilder builder = - Types.buildGroup(map.getRepetition()).as(LogicalTypeAnnotation.mapType()); - if (map.getId() != null) { - builder = builder.id(map.getId().intValue()); - } - builder = builder.addField(repeatedBuilder.named(repeatedGroup.getName())); - - return builder.named(map.getName()); - } - - @Override - public Type variant(VariantType sVariant, GroupType variant) { - int variantFieldIndex = getFieldIndex(currentPath()); - - if (!bufferedRows.isEmpty() && variantFieldIndex >= 0) { - Type shreddedType = analyzer.analyzeAndCreateSchema(bufferedRows, variantFieldIndex); - if (shreddedType != null) { - Types.GroupBuilder builder = - Types.buildGroup(variant.getRepetition()) - .as(LogicalTypeAnnotation.variantType(Variant.VARIANT_SPEC_VERSION)); - if (variant.getId() != null) { - builder = builder.id(variant.getId().intValue()); - } - return builder - .required(BINARY) - .named("metadata") - .optional(BINARY) - .named("value") - .addField(shreddedType) - .named(variant.getName()); - } - } - - return variant; - } - - private int getFieldIndex(String[] path) { - if (path == null || path.length == 0) { - return -1; - } - - if (path.length == 1) { - // Top-level 
field - direct lookup - String fieldName = path[0]; - for (int i = 0; i < sparkSchema.fields().length; i++) { - if (sparkSchema.fields()[i].name().equals(fieldName)) { - return i; - } - } - } else { - // TODO: Implement full nested field resolution - LOG.warn("Nested variant shredding is not supported. Path: {}", String.join(".", path)); - } - - return -1; - } -} diff --git a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkFileWriterFactory.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkFileWriterFactory.java index 769f53b21624..39110f0b0597 100644 --- a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkFileWriterFactory.java +++ b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkFileWriterFactory.java @@ -25,9 +25,7 @@ import java.io.IOException; import java.io.UncheckedIOException; -import java.util.List; import java.util.Map; -import java.util.function.Function; import org.apache.iceberg.FileFormat; import org.apache.iceberg.MetricsConfig; import org.apache.iceberg.PartitionSpec; @@ -39,23 +37,15 @@ import org.apache.iceberg.data.RegistryBasedFileWriterFactory; import org.apache.iceberg.deletes.PositionDeleteWriter; import org.apache.iceberg.encryption.EncryptedOutputFile; -import org.apache.iceberg.io.BufferedFileAppender; -import org.apache.iceberg.io.DataWriter; import org.apache.iceberg.io.DeleteSchemaUtil; -import org.apache.iceberg.io.FileAppender; import org.apache.iceberg.orc.ORC; import org.apache.iceberg.parquet.Parquet; -import org.apache.iceberg.parquet.ParquetSchemaUtil; import org.apache.iceberg.relocated.com.google.common.base.Preconditions; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; -import org.apache.iceberg.spark.SparkSQLProperties; import org.apache.iceberg.spark.SparkSchemaUtil; -import org.apache.iceberg.spark.data.ParquetWithSparkSchemaVisitor; import org.apache.iceberg.spark.data.SparkAvroWriter; import org.apache.iceberg.spark.data.SparkOrcWriter; import org.apache.iceberg.spark.data.SparkParquetWriters; -import org.apache.iceberg.types.Types; -import org.apache.parquet.schema.MessageType; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.types.StructType; import org.apache.spark.unsafe.types.UTF8String; @@ -72,10 +62,6 @@ class SparkFileWriterFactory extends RegistryBasedFileWriterFactory writeProperties; - private final Schema dataSchema; - private final StructType dataSparkType; - private final FileFormat dataFileFormat; - private final SortOrder dataSortOrder; /** * @deprecated This constructor is deprecated as of version 1.11.0 and will be removed in 1.12.0. @@ -116,10 +102,6 @@ class SparkFileWriterFactory extends RegistryBasedFileWriterFactory newPositionDeleteWriter( } } - @Override - public DataWriter newDataWriter( - EncryptedOutputFile file, PartitionSpec spec, StructLike partition) { - if (!shouldUseVariantShredding()) { - return super.newDataWriter(file, spec, partition); - } - - int bufferSize = - Integer.parseInt( - writeProperties.getOrDefault( - SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, - String.valueOf(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE_DEFAULT))); - - Map tableProperties = table != null ? table.properties() : ImmutableMap.of(); - MetricsConfig metricsConfig = - table != null ? 
MetricsConfig.forTable(table) : MetricsConfig.getDefault(); - - Function, FileAppender> appenderFactory = - bufferedRows -> { - Preconditions.checkNotNull(bufferedRows, "bufferedRows must not be null"); - MessageType originalSchema = ParquetSchemaUtil.convert(dataSchema, "table"); - - MessageType shreddedSchema = - (MessageType) - ParquetWithSparkSchemaVisitor.visit( - dataSparkType, - originalSchema, - new SchemaInferenceVisitor(bufferedRows, dataSparkType)); - - try { - FileAppender appender = - Parquet.write(file) - .schema(dataSchema) - .withFileSchema(shreddedSchema) - .createWriterFunc( - msgType -> SparkParquetWriters.buildWriter(dataSparkType, msgType)) - .setAll(tableProperties) - .setAll(writeProperties) - .metricsConfig(metricsConfig) - .overwrite() - .build(); - - for (InternalRow row : bufferedRows) { - appender.add(row); - } - - return appender; - } catch (IOException e) { - throw new UncheckedIOException("Failed to create shredded variant writer", e); - } - }; - - BufferedFileAppender bufferedAppender = - new BufferedFileAppender<>(bufferSize, appenderFactory, InternalRow::copy); - - return new DataWriter<>( - bufferedAppender, - dataFileFormat, - file.encryptingOutputFile().location(), - spec, - partition, - file.keyMetadata(), - dataSortOrder); - } - static class Builder { private final Table table; private FileFormat dataFileFormat; @@ -448,19 +361,4 @@ private static StructType useOrConvert(StructType sparkType, Schema schema) { return null; } } - - private boolean shouldUseVariantShredding() { - // Variant shredding is currently only supported for Parquet files - if (dataFileFormat != FileFormat.PARQUET) { - return false; - } - - boolean shreddingEnabled = - Boolean.parseBoolean(writeProperties.get(SparkSQLProperties.SHRED_VARIANTS)); - - return shreddingEnabled - && dataSchema != null - && dataSchema.columns().stream() - .anyMatch(field -> field.type() instanceof Types.VariantType); - } } diff --git a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkFormatModels.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkFormatModels.java index 23fbe54a4be3..5b7862116aea 100644 --- a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkFormatModels.java +++ b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkFormatModels.java @@ -51,7 +51,9 @@ public static void register() { StructType.class, SparkParquetWriters::buildWriter, (icebergSchema, fileSchema, engineSchema, idToConstant) -> - SparkParquetReaders.buildReader(icebergSchema, fileSchema, idToConstant))); + SparkParquetReaders.buildReader(icebergSchema, fileSchema, idToConstant), + new SparkVariantShreddingAnalyzer(), + InternalRow::copy)); FormatModelRegistry.register( ParquetFormatModel.create( diff --git a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkVariantShreddingAnalyzer.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkVariantShreddingAnalyzer.java index 19e0237ee28d..2c08c662c9da 100644 --- a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkVariantShreddingAnalyzer.java +++ b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkVariantShreddingAnalyzer.java @@ -26,15 +26,25 @@ import org.apache.iceberg.variants.VariantMetadata; import org.apache.iceberg.variants.VariantValue; import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.StructType; import org.apache.spark.unsafe.types.VariantVal; /** * Spark-specific implementation that 
extracts variant values from {@link InternalRow} instances. */ -class SparkVariantShreddingAnalyzer extends VariantShreddingAnalyzer { +class SparkVariantShreddingAnalyzer extends VariantShreddingAnalyzer { SparkVariantShreddingAnalyzer() {} + @Override + protected int resolveColumnIndex(StructType sparkSchema, String columnName) { + try { + return sparkSchema.fieldIndex(columnName); + } catch (IllegalArgumentException e) { + return -1; + } + } + @Override protected List extractVariantValues( List bufferedRows, int variantFieldIndex) { diff --git a/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java b/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java index 87f03b9fb051..232f34bd9ea0 100644 --- a/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java +++ b/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java @@ -34,6 +34,7 @@ import static org.apache.iceberg.TableProperties.ORC_COMPRESSION_STRATEGY; import static org.apache.iceberg.TableProperties.PARQUET_COMPRESSION; import static org.apache.iceberg.TableProperties.PARQUET_COMPRESSION_LEVEL; +import static org.apache.iceberg.TableProperties.PARQUET_VARIANT_SHRED; import static org.apache.iceberg.TableProperties.UPDATE_DISTRIBUTION_MODE; import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE; import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE_HASH; @@ -41,7 +42,6 @@ import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE_RANGE; import static org.apache.iceberg.spark.SparkSQLProperties.COMPRESSION_CODEC; import static org.apache.iceberg.spark.SparkSQLProperties.COMPRESSION_LEVEL; -import static org.apache.iceberg.spark.SparkSQLProperties.SHRED_VARIANTS; import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.DELETE; import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.MERGE; import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.UPDATE; @@ -345,7 +345,7 @@ public void testSparkConfOverride() { TableProperties.DELETE_PARQUET_COMPRESSION, "snappy"), ImmutableMap.of( - SHRED_VARIANTS, + PARQUET_VARIANT_SHRED, "false", DELETE_PARQUET_COMPRESSION, "zstd", @@ -469,7 +469,7 @@ public void testDataPropsDefaultsAsDeleteProps() { PARQUET_COMPRESSION_LEVEL, "5"), ImmutableMap.of( - SHRED_VARIANTS, + PARQUET_VARIANT_SHRED, "false", DELETE_PARQUET_COMPRESSION, "zstd", @@ -542,7 +542,7 @@ public void testDeleteFileWriteConf() { DELETE_PARQUET_COMPRESSION_LEVEL, "6"), ImmutableMap.of( - SHRED_VARIANTS, + PARQUET_VARIANT_SHRED, "false", DELETE_PARQUET_COMPRESSION, "zstd", @@ -653,4 +653,81 @@ private void checkMode(DistributionMode expectedMode, SparkWriteConf writeConf) assertThat(writeConf.copyOnWriteDistributionMode(MERGE)).isEqualTo(expectedMode); assertThat(writeConf.positionDeltaDistributionMode(MERGE)).isEqualTo(expectedMode); } + + @TestTemplate + public void testShredVariantsDefault() { + Table table = validationCatalog.loadTable(tableIdent); + SparkWriteConf writeConf = new SparkWriteConf(spark, table); + assertThat(writeConf.shredVariants()).isFalse(); + } + + @TestTemplate + public void testVariantInferenceBufferSizeDefault() { + Table table = validationCatalog.loadTable(tableIdent); + SparkWriteConf writeConf = new SparkWriteConf(spark, table); + assertThat(writeConf.variantInferenceBufferSize()) + .isEqualTo(TableProperties.PARQUET_VARIANT_BUFFER_SIZE_DEFAULT); + } + + @TestTemplate + public void 
testVariantInferenceBufferSizeTableProperty() { + Table table = validationCatalog.loadTable(tableIdent); + + table.updateProperties().set(TableProperties.PARQUET_VARIANT_BUFFER_SIZE, "500").commit(); + + SparkWriteConf writeConf = new SparkWriteConf(spark, table); + assertThat(writeConf.variantInferenceBufferSize()).isEqualTo(500); + } + + @TestTemplate + public void testShredVariantsSessionOverridesTableProperty() { + Table table = validationCatalog.loadTable(tableIdent); + table.updateProperties().set(TableProperties.PARQUET_VARIANT_SHRED, "false").commit(); + + withSQLConf( + ImmutableMap.of(SparkSQLProperties.SHRED_VARIANTS, "true"), + () -> { + SparkWriteConf writeConf = new SparkWriteConf(spark, table); + assertThat(writeConf.shredVariants()).isTrue(); + }); + } + + @TestTemplate + public void testShredVariantsWriteOptionOverridesSessionConf() { + withSQLConf( + ImmutableMap.of(SparkSQLProperties.SHRED_VARIANTS, "false"), + () -> { + Table table = validationCatalog.loadTable(tableIdent); + SparkWriteConf writeConf = + new SparkWriteConf( + spark, + table, + new CaseInsensitiveStringMap( + ImmutableMap.of(SparkWriteOptions.SHRED_VARIANTS, "true"))); + assertThat(writeConf.shredVariants()).isTrue(); + }); + } + + @TestTemplate + public void testVariantInferenceBufferSizeSessionConf() { + withSQLConf( + ImmutableMap.of(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, "250"), + () -> { + Table table = validationCatalog.loadTable(tableIdent); + SparkWriteConf writeConf = new SparkWriteConf(spark, table); + assertThat(writeConf.variantInferenceBufferSize()).isEqualTo(250); + }); + } + + @TestTemplate + public void testWritePropertiesIncludeVariantShredding() { + Table table = validationCatalog.loadTable(tableIdent); + table.updateProperties().set(TableProperties.PARQUET_VARIANT_SHRED, "true").commit(); + table.updateProperties().set(TableProperties.PARQUET_VARIANT_BUFFER_SIZE, "200").commit(); + + SparkWriteConf writeConf = new SparkWriteConf(spark, table); + Map writeProperties = writeConf.writeProperties(); + assertThat(writeProperties).containsEntry(PARQUET_VARIANT_SHRED, "true"); + assertThat(writeProperties).containsEntry(TableProperties.PARQUET_VARIANT_BUFFER_SIZE, "200"); + } } diff --git a/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java b/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java index 65e3894ccc71..b645a5ff5e34 100644 --- a/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java +++ b/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java @@ -136,9 +136,11 @@ public void testExcludingNullValue() throws IOException { spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); String values = - "(1, parse_json('{\"name\": \"Alice\", \"age\": 30, \"dummy\": null}'))," - + " (2, parse_json('{\"name\": \"Bob\", \"age\": 25}'))," - + " (3, parse_json('{\"name\": \"Charlie\", \"age\": 35}'))"; + """ + (1, parse_json('{"name": "Alice", "age": 30, "dummy": null}')),\ + (2, parse_json('{"name": "Bob", "age": 25}')),\ + (3, parse_json('{"name": "Charlie", "age": 35}'))\ + """; sql("INSERT INTO %s VALUES %s", tableName, values); GroupType name = @@ -163,9 +165,11 @@ public void testInconsistentType() throws IOException { spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); String values = - "(1, parse_json('{\"age\": \"25\"}'))," - + " (2, parse_json('{\"age\": 30}'))," - + " (3, parse_json('{\"age\": \"35\"}'))"; + """ + (1, 
parse_json('{"age": "25"}')),\ + (2, parse_json('{"age": 30}')),\ + (3, parse_json('{"age": "35"}'))\ + """; sql("INSERT INTO %s VALUES %s", tableName, values); GroupType age = @@ -178,13 +182,18 @@ public void testInconsistentType() throws IOException { Table table = validationCatalog.loadTable(tableIdent); verifyParquetSchema(table, expectedSchema); + + List rows = + sql("SELECT variant_get(address, '$.age', 'int') FROM %s WHERE id = 2", tableName); + assertThat(rows).hasSize(1); + assertThat(rows.get(0)[0]).isEqualTo(30); } @TestTemplate public void testPrimitiveType() throws IOException { spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); - String values = "(1, parse_json('123')), (2, parse_json('\"abc\"')), (3, parse_json('12'))"; + String values = "(1, parse_json('123')), (2, parse_json('456')), (3, parse_json('789'))"; sql("INSERT INTO %s VALUES %s", tableName, values); GroupType address = @@ -193,7 +202,7 @@ public void testPrimitiveType() throws IOException { 2, Type.Repetition.REQUIRED, shreddedPrimitive( - PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true))); + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(16, true))); MessageType expectedSchema = parquetSchema(address); Table table = validationCatalog.loadTable(tableIdent); @@ -226,9 +235,11 @@ public void testBooleanType() throws IOException { spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); String values = - "(1, parse_json('{\"active\": true}'))," - + " (2, parse_json('{\"active\": false}'))," - + " (3, parse_json('{\"active\": true}'))"; + """ + (1, parse_json('{"active": true}')),\ + (2, parse_json('{"active": false}')),\ + (3, parse_json('{"active": true}'))\ + """; sql("INSERT INTO %s VALUES %s", tableName, values); GroupType active = field("active", shreddedPrimitive(PrimitiveType.PrimitiveTypeName.BOOLEAN)); @@ -244,9 +255,11 @@ public void testDecimalTypeWithInconsistentScales() throws IOException { spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); String values = - "(1, parse_json('{\"price\": 123.456789}'))," - + " (2, parse_json('{\"price\": 678.90}'))," - + " (3, parse_json('{\"price\": 999.99}'))"; + """ + (1, parse_json('{"price": 123.456789}')),\ + (2, parse_json('{"price": 678.90}')),\ + (3, parse_json('{"price": 999.99}'))\ + """; sql("INSERT INTO %s VALUES %s", tableName, values); GroupType price = @@ -266,9 +279,11 @@ public void testDecimalTypeWithConsistentScales() throws IOException { spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); String values = - "(1, parse_json('{\"price\": 123.45}'))," - + " (2, parse_json('{\"price\": 678.90}'))," - + " (3, parse_json('{\"price\": 999.99}'))"; + """ + (1, parse_json('{"price": 123.45}')),\ + (2, parse_json('{"price": 678.90}')),\ + (3, parse_json('{"price": 999.99}'))\ + """; sql("INSERT INTO %s VALUES %s", tableName, values); GroupType price = @@ -288,9 +303,11 @@ public void testArrayType() throws IOException { spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); String values = - "(1, parse_json('[\"java\", \"scala\", \"python\"]'))," - + " (2, parse_json('[\"rust\", \"go\"]'))," - + " (3, parse_json('[\"javascript\"]'))"; + """ + (1, parse_json('["java", "scala", "python"]')),\ + (2, parse_json('["rust", "go"]')),\ + (3, parse_json('["javascript"]'))\ + """; sql("INSERT INTO %s VALUES %s", tableName, values); GroupType arr = @@ -310,9 +327,11 @@ public void testNestedArrayType() throws IOException { spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); 
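// For reference, the metadata/value/typed_value triplet that the schema
// assertions in these tests expect, built with parquet-mr's Types builder.
// Sketch only: the real schema also carries the VARIANT logical type
// annotation and the Iceberg field id on the group, omitted here for brevity;
// intType(16, true) matches the testPrimitiveType expectation above.
import org.apache.parquet.schema.GroupType;
import org.apache.parquet.schema.LogicalTypeAnnotation;
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName;
import org.apache.parquet.schema.Type;
import org.apache.parquet.schema.Types;

final class ShreddedVariantSchemas {
  static GroupType shreddedInt16Variant(String name) {
    return Types.buildGroup(Type.Repetition.REQUIRED)
        .required(PrimitiveTypeName.BINARY)
        .named("metadata")
        .optional(PrimitiveTypeName.BINARY)
        .named("value")
        .optional(PrimitiveTypeName.INT32)
        .as(LogicalTypeAnnotation.intType(16, true))
        .named("typed_value")
        .named(name);
  }
}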
String values = - "(1, parse_json('{\"tags\": [\"java\", \"scala\", \"python\"]}'))," - + " (2, parse_json('{\"tags\": [\"rust\", \"go\"]}'))," - + " (3, parse_json('{\"tags\": [\"javascript\"]}'))"; + """ + (1, parse_json('{"tags": ["java", "scala", "python"]}')),\ + (2, parse_json('{"tags": ["rust", "go"]}')),\ + (3, parse_json('{"tags": ["javascript"]}'))\ + """; sql("INSERT INTO %s VALUES %s", tableName, values); GroupType tags = @@ -335,9 +354,11 @@ public void testNestedObjectType() throws IOException { spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); String values = - "(1, parse_json('{\"location\": {\"city\": \"Seattle\", \"zip\": 98101}, \"tags\": [\"java\", \"scala\", \"python\"]}'))," - + " (2, parse_json('{\"location\": {\"city\": \"Portland\", \"zip\": 97201}}'))," - + " (3, parse_json('{\"location\": {\"city\": \"NYC\", \"zip\": 10001}}'))"; + """ + (1, parse_json('{"location": {"city": "Seattle", "zip": 98101}, "tags": ["java", "scala", "python"]}')),\ + (2, parse_json('{"location": {"city": "Portland", "zip": 97201}}')),\ + (3, parse_json('{"location": {"city": "NYC", "zip": 10001}}'))\ + """; sql("INSERT INTO %s VALUES %s", tableName, values); GroupType city = @@ -374,13 +395,15 @@ public void testLazyInitializationWithBufferedRows() throws IOException { spark.conf().set(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, "5"); String values = - "(1, parse_json('{\"name\": \"Alice\", \"age\": 30}'))," - + " (2, parse_json('{\"name\": \"Bob\", \"age\": 25}'))," - + " (3, parse_json('{\"name\": \"Charlie\", \"age\": 35}'))," - + " (4, parse_json('{\"name\": \"David\", \"age\": 28}'))," - + " (5, parse_json('{\"name\": \"Eve\", \"age\": 32}'))," - + " (6, parse_json('{\"name\": \"Frank\", \"age\": 40}'))," - + " (7, parse_json('{\"name\": \"Grace\", \"age\": 27}'))"; + """ + (1, parse_json('{"name": "Alice", "age": 30}')),\ + (2, parse_json('{"name": "Bob", "age": 25}')),\ + (3, parse_json('{"name": "Charlie", "age": 35}')),\ + (4, parse_json('{"name": "David", "age": 28}')),\ + (5, parse_json('{"name": "Eve", "age": 32}')),\ + (6, parse_json('{"name": "Frank", "age": 40}')),\ + (7, parse_json('{"name": "Grace", "age": 27}'))\ + """; sql("INSERT INTO %s VALUES %s", tableName, values); GroupType name = @@ -491,10 +514,12 @@ public void testIntegerFamilyPromotion() throws IOException { // Mix of INT8, INT16, INT32, INT64 - should promote to INT64 String values = - "(1, parse_json('{\"value\": 10}'))," - + " (2, parse_json('{\"value\": 1000}'))," - + " (3, parse_json('{\"value\": 100000}'))," - + " (4, parse_json('{\"value\": 10000000000}'))"; + """ + (1, parse_json('{"value": 10}')),\ + (2, parse_json('{"value": 1000}')),\ + (3, parse_json('{"value": 100000}')),\ + (4, parse_json('{"value": 10000000000}'))\ + """; sql("INSERT INTO %s VALUES %s", tableName, values); GroupType value = @@ -515,9 +540,11 @@ public void testDecimalFamilyPromotion() throws IOException { // Test that they get promoted to the most capable decimal type observed String values = - "(1, parse_json('{\"value\": 1.5}'))," - + " (2, parse_json('{\"value\": 123.456789}'))," - + " (3, parse_json('{\"value\": 123456789123456.789}'))"; + """ + (1, parse_json('{"value": 1.5}')),\ + (2, parse_json('{"value": 123.456789}')),\ + (3, parse_json('{"value": 123456789123456.789}'))\ + """; sql("INSERT INTO %s VALUES %s", tableName, values); GroupType value = @@ -539,9 +566,11 @@ public void testDataRoundTripWithShredding() throws IOException { spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); String 
values = - "(1, parse_json('{\"name\": \"Alice\", \"age\": 30}'))," - + " (2, parse_json('{\"name\": \"Bob\", \"age\": 25}'))," - + " (3, parse_json('{\"name\": \"Charlie\", \"age\": 35}'))"; + """ + (1, parse_json('{"name": "Alice", "age": 30}')),\ + (2, parse_json('{"name": "Bob", "age": 25}')),\ + (3, parse_json('{"name": "Charlie", "age": 35}'))\ + """; sql("INSERT INTO %s VALUES %s", tableName, values); GroupType name = @@ -589,9 +618,11 @@ public void testMultipleVariantsWithShredding() throws IOException { tableIdent, SCHEMA2, null, Map.of(TableProperties.FORMAT_VERSION, "3")); String values = - "(1, parse_json('{\"city\": \"NYC\"}'), parse_json('{\"source\": \"web\"}'))," - + " (2, parse_json('{\"city\": \"LA\"}'), parse_json('{\"source\": \"app\"}'))," - + " (3, parse_json('{\"city\": \"SF\"}'), parse_json('{\"source\": \"api\"}'))"; + """ + (1, parse_json('{"city": "NYC"}'), parse_json('{"source": "web"}')),\ + (2, parse_json('{"city": "LA"}'), parse_json('{"source": "app"}')),\ + (3, parse_json('{"city": "SF"}'), parse_json('{"source": "api"}'))\ + """; sql("INSERT INTO %s VALUES %s", tableName, values); GroupType city = @@ -618,7 +649,11 @@ public void testVariantWithNullValues() throws IOException { spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); String values = - "(1, parse_json('null'))," + " (2, parse_json('null'))," + " (3, parse_json('null'))"; + """ + (1, parse_json('null')),\ + (2, parse_json('null')),\ + (3, parse_json('null'))\ + """; sql("INSERT INTO %s VALUES %s", tableName, values); GroupType address = variant("address", 2, Type.Repetition.REQUIRED); @@ -650,9 +685,11 @@ public void testMixedNullAndNonNullVariantValues() throws IOException { spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); String values = - "(1, parse_json('{\"name\": \"Alice\", \"age\": 30}'))," - + " (2, null)," - + " (3, parse_json('{\"name\": \"Charlie\", \"age\": 35}'))"; + """ + (1, parse_json('{"name": "Alice", "age": 30}')),\ + (2, null),\ + (3, parse_json('{"name": "Charlie", "age": 35}'))\ + """; sql("INSERT INTO %s VALUES %s", tableName, values); GroupType name = @@ -768,6 +805,13 @@ public void testMixedTypeTieBreaking() throws IOException { Table table = validationCatalog.loadTable(tableIdent); verifyParquetSchema(table, expectedSchema); + + // Verify data round-trips correctly + List rows = + sql("SELECT id, variant_get(address, '$.val', 'string') FROM %s ORDER BY id", tableName); + assertThat(rows).hasSize(10); + assertThat(rows.get(0)[1]).isEqualTo("1"); + assertThat(rows.get(5)[1]).isEqualTo("text6"); } @TestTemplate @@ -776,13 +820,15 @@ public void testFieldOnlyAfterBuffer() throws IOException { spark.conf().set(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, "3"); String values = - "(1, parse_json('{\"name\": \"Alice\"}'))," - + " (2, parse_json('{\"name\": \"Bob\"}'))," - + " (3, parse_json('{\"name\": \"Charlie\"}'))," - + " (4, parse_json('{\"name\": \"David\", \"score\": 95}'))," - + " (5, parse_json('{\"name\": \"Eve\", \"score\": 88}'))," - + " (6, parse_json('{\"name\": \"Frank\", \"score\": 72}'))," - + " (7, parse_json('{\"name\": \"Grace\", \"score\": 91}'))"; + """ + (1, parse_json('{"name": "Alice"}')),\ + (2, parse_json('{"name": "Bob"}')),\ + (3, parse_json('{"name": "Charlie"}')),\ + (4, parse_json('{"name": "David", "score": 95}')),\ + (5, parse_json('{"name": "Eve", "score": 88}')),\ + (6, parse_json('{"name": "Frank", "score": 72}')),\ + (7, parse_json('{"name": "Grace", "score": 91}'))\ + """; sql("INSERT INTO %s VALUES %s", 
tableName, values); // Schema is determined from buffer (rows 1-3) which only has "name". @@ -821,9 +867,11 @@ public void testCrossFileDifferentShreddedType() throws IOException { // File 1: "score" is always integer → shredded as INT8 String batch1 = - "(1, parse_json('{\"score\": 95}'))," - + " (2, parse_json('{\"score\": 88}'))," - + " (3, parse_json('{\"score\": 72}'))"; + """ + (1, parse_json('{"score": 95}')),\ + (2, parse_json('{"score": 88}')),\ + (3, parse_json('{"score": 72}'))\ + """; sql("INSERT INTO %s VALUES %s", tableName, batch1); // Verify file 1 schema: score shredded as INT8 @@ -839,9 +887,11 @@ public void testCrossFileDifferentShreddedType() throws IOException { // File 2: "score" is always string → shredded as STRING String batch2 = - "(4, parse_json('{\"score\": \"high\"}'))," - + " (5, parse_json('{\"score\": \"medium\"}'))," - + " (6, parse_json('{\"score\": \"low\"}'))"; + """ + (4, parse_json('{"score": "high"}')),\ + (5, parse_json('{"score": "medium"}')),\ + (6, parse_json('{"score": "low"}'))\ + """; sql("INSERT INTO %s VALUES %s", tableName, batch2); // Query across both files, reader must handle different shredded types @@ -878,10 +928,12 @@ public void testBufferSizeOne() throws IOException { spark.conf().set(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, "1"); sql( - "INSERT INTO %s VALUES " - + "(1, parse_json('{\"name\": \"Alice\", \"age\": 30}'))," - + " (2, parse_json('{\"name\": \"Bob\", \"age\": 25}'))," - + " (3, parse_json('{\"name\": \"Charlie\", \"age\": 35}'))", + """ + INSERT INTO %s VALUES + (1, parse_json('{"name": "Alice", "age": 30}')), + (2, parse_json('{"name": "Bob", "age": 25}')), + (3, parse_json('{"name": "Charlie", "age": 35}')) + """, tableName); // Schema inferred from first row only, should still shred name and age @@ -912,15 +964,16 @@ private void verifyParquetSchema(Table table, MessageType expectedSchema) throws try (CloseableIterable tasks = table.newScan().planFiles()) { assertThat(tasks).isNotEmpty(); - FileScanTask task = tasks.iterator().next(); - String path = task.file().location(); + for (FileScanTask task : tasks) { + String path = task.file().location(); - HadoopInputFile inputFile = - HadoopInputFile.fromPath(new org.apache.hadoop.fs.Path(path), new Configuration()); + HadoopInputFile inputFile = + HadoopInputFile.fromPath(new org.apache.hadoop.fs.Path(path), new Configuration()); - try (ParquetFileReader reader = ParquetFileReader.open(inputFile)) { - MessageType actualSchema = reader.getFileMetaData().getSchema(); - assertThat(actualSchema).isEqualTo(expectedSchema); + try (ParquetFileReader reader = ParquetFileReader.open(inputFile)) { + MessageType actualSchema = reader.getFileMetaData().getSchema(); + assertThat(actualSchema).isEqualTo(expectedSchema); + } } } } From f76348662c1c18719ab19b9ee67db5b2cc822a17 Mon Sep 17 00:00:00 2001 From: Neelesh Salian Date: Tue, 7 Apr 2026 14:30:05 -0700 Subject: [PATCH 13/17] Fix decimal overflow, array pruning, and buffer lifecycle in variant shredding --- .../iceberg/io/BufferedFileAppender.java | 13 +- .../parquet/VariantShreddingAnalyzer.java | 15 ++- .../parquet/TestVariantShreddingAnalyzer.java | 120 +++++++++++++++--- 3 files changed, 123 insertions(+), 25 deletions(-) diff --git a/core/src/main/java/org/apache/iceberg/io/BufferedFileAppender.java b/core/src/main/java/org/apache/iceberg/io/BufferedFileAppender.java index 15ecf7328ae4..bcd45f9d30de 100644 --- a/core/src/main/java/org/apache/iceberg/io/BufferedFileAppender.java +++ 
b/core/src/main/java/org/apache/iceberg/io/BufferedFileAppender.java @@ -96,6 +96,7 @@ public Metrics metrics() { if (delegate == null) { return new Metrics(0L); } + return delegate.metrics(); } @@ -121,22 +122,26 @@ public List splitOffsets() { @Override public void close() throws IOException { if (!closed) { - this.closed = true; if (delegate == null && buffer != null && !buffer.isEmpty()) { initialize(); } + if (delegate != null) { delegate.close(); } + + this.closed = true; + this.buffer = null; } } private void initialize() { delegate = appenderFactory.apply(buffer); Preconditions.checkState(delegate != null, "appenderFactory must not return null"); - for (D row : buffer) { - delegate.add(row); + try { + buffer.forEach(delegate::add); + } finally { + buffer = null; } - buffer = null; } } diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/VariantShreddingAnalyzer.java b/parquet/src/main/java/org/apache/iceberg/parquet/VariantShreddingAnalyzer.java index 2dcbf66ce283..5442b78449a2 100644 --- a/parquet/src/main/java/org/apache/iceberg/parquet/VariantShreddingAnalyzer.java +++ b/parquet/src/main/java/org/apache/iceberg/parquet/VariantShreddingAnalyzer.java @@ -96,9 +96,7 @@ public Type analyzeAndCreateSchema(List bufferedRows, int variantFieldIndex) return null; } - if (rootType == PhysicalType.OBJECT) { - pruneInfrequentFields(root, root.info.observationCount); - } + pruneInfrequentFields(root, root.info.observationCount); return buildTypedValue(root, rootType); } @@ -150,7 +148,7 @@ private static PathNode buildPathTree(List variantValues) { } private static void pruneInfrequentFields(PathNode node, int totalRows) { - if (node.objectChildren.isEmpty()) { + if (node.objectChildren.isEmpty() && node.arrayElement == null) { return; } @@ -181,10 +179,15 @@ private static void pruneInfrequentFields(PathNode node, int totalRows) { node.objectChildren.entrySet().removeIf(entry -> !keep.contains(entry.getKey())); } - // Recurse into remaining children + // Recurse into remaining object children for (PathNode child : node.objectChildren.values()) { pruneInfrequentFields(child, totalRows); } + + // Recurse into array elements (arrays of objects need pruning too) + if (node.arrayElement != null) { + pruneInfrequentFields(node.arrayElement, totalRows); + } } private static void traverse(PathNode node, VariantValue value, int depth) { @@ -315,7 +318,7 @@ private PathNode(String fieldName) { /** Use DECIMAL with maximum precision and scale as the shredding type */ private static Type createDecimalTypedValue(FieldInfo info) { int maxPrecision = Math.min(info.maxDecimalIntegerDigits + info.maxDecimalScale, 38); - int maxScale = Math.min(info.maxDecimalScale, maxPrecision); + int maxScale = Math.min(info.maxDecimalScale, Math.max(0, 38 - info.maxDecimalIntegerDigits)); if (maxPrecision <= 9) { return Types.optional(PrimitiveType.PrimitiveTypeName.INT32) diff --git a/parquet/src/test/java/org/apache/iceberg/parquet/TestVariantShreddingAnalyzer.java b/parquet/src/test/java/org/apache/iceberg/parquet/TestVariantShreddingAnalyzer.java index d87799cf19e6..797b011c7e52 100644 --- a/parquet/src/test/java/org/apache/iceberg/parquet/TestVariantShreddingAnalyzer.java +++ b/parquet/src/test/java/org/apache/iceberg/parquet/TestVariantShreddingAnalyzer.java @@ -22,6 +22,7 @@ import java.util.List; import java.util.Locale; +import java.util.function.Function; import org.apache.iceberg.relocated.com.google.common.collect.Lists; import org.apache.iceberg.variants.ShreddedObject; import 
diff --git a/parquet/src/test/java/org/apache/iceberg/parquet/TestVariantShreddingAnalyzer.java b/parquet/src/test/java/org/apache/iceberg/parquet/TestVariantShreddingAnalyzer.java
index d87799cf19e6..797b011c7e52 100644
--- a/parquet/src/test/java/org/apache/iceberg/parquet/TestVariantShreddingAnalyzer.java
+++ b/parquet/src/test/java/org/apache/iceberg/parquet/TestVariantShreddingAnalyzer.java
@@ -22,6 +22,7 @@
 import java.util.List;
 import java.util.Locale;
+import java.util.function.Function;
 import org.apache.iceberg.relocated.com.google.common.collect.Lists;
 import org.apache.iceberg.variants.ShreddedObject;
 import org.apache.iceberg.variants.ValueArray;
@@ -214,7 +215,7 @@ public void testDecimalForExceedingPrecision() {
     // Value 1: 30 integer digits, 0 fractional -> precision=30, scale=0, intDigits=30
     // Value 2: 1 integer digit, 20 fractional -> precision=21, scale=20, intDigits=1
     // Combined: maxIntDigits=30, maxScale=20, raw sum=50 -> capped to precision=38,
-    // scale=min(20,38)=20
+    // scale=min(20, 38-30)=8 (integer digits get priority)
     VariantMetadata meta = Variants.metadata("val");
     ShreddedObject row1 = Variants.object(meta);
     row1.put("val", Variants.of(new java.math.BigDecimal("123456789012345678901234567890")));
@@ -234,8 +235,8 @@ public void testDecimalForExceedingPrecision() {
         valPrimitive.getLogicalTypeAnnotation();
     assertThat(decimal).isNotNull();
     assertThat(decimal.getPrecision()).isEqualTo(38);
-    assertThat(decimal.getScale()).isEqualTo(20);
-    // Scale must not exceed precision
+    // With 30 integer digits, scale is capped to 38 - 30 = 8 (integer digits get priority)
+    assertThat(decimal.getScale()).isEqualTo(8);
     assertThat(decimal.getScale()).isLessThanOrEqualTo(decimal.getPrecision());
 
     // Physical type should be FIXED_LEN_BYTE_ARRAY since precision > 18
@@ -271,18 +272,8 @@ public void testDecimalForExactPrecision() {
   public void testInfrequentFieldsArePruned() {
     DirectAnalyzer analyzer = new DirectAnalyzer();
 
-    VariantMetadata meta = Variants.metadata("common", "rare");
-
-    // 100 rows: "common" in all 100, "rare" in only 5 (< 10% threshold)
-    List<VariantValue> rows = Lists.newArrayList();
-    for (int i = 0; i < 100; i++) {
-      ShreddedObject obj = Variants.object(meta);
-      obj.put("common", Variants.of(i));
-      if (i < 5) {
-        obj.put("rare", Variants.of("text"));
-      }
-      rows.add(obj);
-    }
+    // 100 rows: "common" in all, "rare" in only 5 (below MIN_FIELD_FREQUENCY = 0.10)
+    List<VariantValue> rows = buildPruningTestRows(5, obj -> obj);
 
     Type schema = analyzer.analyzeAndCreateSchema(rows, 0);
     assertThat(schema).isNotNull();
@@ -303,6 +294,105 @@ public void testEmptyArrayReturnsNull() {
     assertThat(schema).isNull();
   }
 
+  @Test
+  public void testRootPrimitiveProducesTypedValue() {
+    DirectAnalyzer analyzer = new DirectAnalyzer();
+
+    // root type is primitive
+    List<VariantValue> rows = List.of(Variants.of("hello"), Variants.of("world"), Variants.of("x"));
+
+    Type schema = analyzer.analyzeAndCreateSchema(rows, 0);
+    assertThat(schema).isNotNull();
+    assertThat(schema.getName()).isEqualTo("typed_value");
+    assertThat(schema.isPrimitive()).isTrue();
+    assertThat(schema.asPrimitiveType().getLogicalTypeAnnotation())
+        .isEqualTo(LogicalTypeAnnotation.stringType());
+  }
+
+  @Test
+  public void testRootArrayOfObjectsPrunesInfrequentFields() {
+    DirectAnalyzer analyzer = new DirectAnalyzer();
+
+    // 100 arrays: "common" in all, "rare" in only 3 (below MIN_FIELD_FREQUENCY = 0.10)
+    List<VariantValue> rows =
+        buildPruningTestRows(
+            3,
+            obj -> {
+              ValueArray arr = Variants.array();
+              arr.add(obj);
+              return arr;
+            });
+
+    Type schema = analyzer.analyzeAndCreateSchema(rows, 0);
+    assertThat(schema).isNotNull();
+
+    GroupType listType = schema.asGroupType();
+    assertThat(listType.getLogicalTypeAnnotation())
+        .isInstanceOf(LogicalTypeAnnotation.ListLogicalTypeAnnotation.class);
+    GroupType repeatedGroup = listType.getType(0).asGroupType();
+    GroupType elementGroup = repeatedGroup.getType(0).asGroupType();
+    assertThat(elementGroup.containsField("typed_value")).isTrue();
+    GroupType objectFields = elementGroup.getType("typed_value").asGroupType();
+    assertThat(objectFields.containsField("common")).isTrue();
+    assertThat(objectFields.containsField("rare")).isFalse();
+  }
+
+  @Test
+  public void testObjectWithArrayChildPrunesNestedFields() {
+    DirectAnalyzer analyzer = new DirectAnalyzer();
+
+    VariantMetadata itemMeta = Variants.metadata("name", "rare");
+    VariantMetadata rootMeta = Variants.metadata("items");
+
+    // 100 rows, "rare" appears in only 3 rows (below MIN_FIELD_FREQUENCY = 0.10)
+    List<VariantValue> rows = Lists.newArrayList();
+    for (int i = 0; i < 100; i++) {
+      ShreddedObject item = Variants.object(itemMeta);
+      item.put("name", Variants.of("item_" + i));
+      if (i < 3) {
+        item.put("rare", Variants.of(1));
+      }
+      ValueArray arr = Variants.array();
+      arr.add(item);
+      ShreddedObject root = Variants.object(rootMeta);
+      root.put("items", arr);
+      rows.add(root);
+    }
+
+    Type schema = analyzer.analyzeAndCreateSchema(rows, 0);
+    assertThat(schema).isNotNull();
+
+    GroupType rootTv = schema.asGroupType();
+    GroupType itemsGroup = rootTv.getType("items").asGroupType();
+    assertThat(itemsGroup.containsField("typed_value")).isTrue();
+    GroupType listType = itemsGroup.getType("typed_value").asGroupType();
+    GroupType repeatedGroup = listType.getType(0).asGroupType();
+    GroupType elementGroup = repeatedGroup.getType(0).asGroupType();
+    assertThat(elementGroup.containsField("typed_value")).isTrue();
+    GroupType elementFields = elementGroup.getType("typed_value").asGroupType();
+    assertThat(elementFields.containsField("name")).isTrue();
+    assertThat(elementFields.containsField("rare")).isFalse();
+  }
+
+  /**
+   * Builds 100 variant rows where "common" appears in every row and "rare" appears in only {@code
+   * rareCount} rows (below MIN_FIELD_FREQUENCY = 0.10 when rareCount < 10).
+   */
+  private static List<VariantValue> buildPruningTestRows(
+      int rareCount, Function<ShreddedObject, VariantValue> wrap) {
+    VariantMetadata meta = Variants.metadata("common", "rare");
+    List<VariantValue> rows = Lists.newArrayList();
+    for (int i = 0; i < 100; i++) {
+      ShreddedObject obj = Variants.object(meta);
+      obj.put("common", Variants.of(i));
+      if (i < rareCount) {
+        obj.put("rare", Variants.of("text"));
+      }
+      rows.add(wrap.apply(obj));
+    }
+    return rows;
+  }
+
   /** Count typed_value group nesting depth along field "a". */
   private static int countObjectDepth(Type type) {
     int depth = 0;
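All of the pruning tests above turn on a single ratio. Assuming the analyzer keeps a field
when its observation count over the total row count reaches the 10% threshold the test
comments call MIN_FIELD_FREQUENCY (the real constant lives in VariantShreddingAnalyzer; this
helper is illustrative only), the decision looks like:

    // Hypothetical mirror of the pruning decision exercised by the tests above.
    static boolean keepField(int observations, int totalRows) {
      double minFieldFrequency = 0.10; // MIN_FIELD_FREQUENCY per the test comments
      return (double) observations / totalRows >= minFieldFrequency;
    }

    // keepField(100, 100) -> true   "common" survives
    // keepField(5, 100)   -> false  "rare" at 5% is pruned
    // keepField(3, 100)   -> false  "rare" at 3% is pruned, even inside array elements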
From 6608235a9f9636da747e21cf7e4652faccaa519b Mon Sep 17 00:00:00 2001
From: Neelesh Salian
Date: Tue, 7 Apr 2026 18:06:16 -0700
Subject: [PATCH 14/17] Fix test and address PR comment

---
 .../org/apache/iceberg/parquet/VariantShreddingAnalyzer.java   | 3 ++-
 .../org/apache/iceberg/spark/variant/TestVariantShredding.java | 1 +
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/VariantShreddingAnalyzer.java b/parquet/src/main/java/org/apache/iceberg/parquet/VariantShreddingAnalyzer.java
index 5442b78449a2..117dcfb7851f 100644
--- a/parquet/src/main/java/org/apache/iceberg/parquet/VariantShreddingAnalyzer.java
+++ b/parquet/src/main/java/org/apache/iceberg/parquet/VariantShreddingAnalyzer.java
@@ -27,6 +27,7 @@
 import org.apache.iceberg.relocated.com.google.common.collect.Lists;
 import org.apache.iceberg.relocated.com.google.common.collect.Maps;
 import org.apache.iceberg.relocated.com.google.common.collect.Sets;
+import org.apache.iceberg.types.Types.NestedField;
 import org.apache.iceberg.variants.PhysicalType;
 import org.apache.iceberg.variants.VariantArray;
 import org.apache.iceberg.variants.VariantObject;
@@ -122,7 +123,7 @@ protected abstract List<VariantValue> extractVariantValues(
   public Map<String, Type> analyzeVariantColumns(
       List<D> bufferedRows, Schema icebergSchema, S engineSchema) {
     Map<String, Type> shreddedTypes = Maps.newHashMap();
-    for (org.apache.iceberg.types.Types.NestedField col : icebergSchema.columns()) {
+    for (NestedField col : icebergSchema.columns()) {
       if (col.type().isVariantType()) {
         int rowIndex = resolveColumnIndex(engineSchema, col.name());
         if (rowIndex >= 0) {
diff --git a/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java b/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java
index b645a5ff5e34..5b2b6103c683 100644
--- a/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java
+++ b/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java
@@ -97,6 +97,7 @@ public static void startMetastoreAndSpark() {
           .config(SQLConf.PARTITION_OVERWRITE_MODE().key(), "dynamic")
           .config("spark.hadoop." + METASTOREURIS.varname, hiveConf.get(METASTOREURIS.varname))
           .config("spark.sql.legacy.respectNullabilityInTextDatasetConversion", "true")
+          .config(DISABLE_UI)
          .enableHiveSupport()
          .getOrCreate();

From a5121ff47546456042a1fafb1578309ce7c956d2 Mon Sep 17 00:00:00 2001
From: Neelesh Salian
Date: Sat, 18 Apr 2026 16:23:15 -0700
Subject: [PATCH 15/17] Address PR review comments

---
 .../iceberg/io/TestBufferedFileAppender.java   |  3 ++-
 .../iceberg/parquet/ParquetFormatModel.java    | 16 ++++++++++++----
 .../parquet/VariantShreddingAnalyzer.java      |  1 +
 3 files changed, 15 insertions(+), 5 deletions(-)

diff --git a/core/src/test/java/org/apache/iceberg/io/TestBufferedFileAppender.java b/core/src/test/java/org/apache/iceberg/io/TestBufferedFileAppender.java
index 8157800d07f8..74b4056dba35 100644
--- a/core/src/test/java/org/apache/iceberg/io/TestBufferedFileAppender.java
+++ b/core/src/test/java/org/apache/iceberg/io/TestBufferedFileAppender.java
@@ -31,6 +31,7 @@
 import org.apache.iceberg.data.Record;
 import org.apache.iceberg.data.avro.DataWriter;
 import org.apache.iceberg.data.avro.PlannedDataReader;
+import org.apache.iceberg.exceptions.RuntimeIOException;
 import org.apache.iceberg.inmemory.InMemoryOutputFile;
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
 import org.apache.iceberg.relocated.com.google.common.collect.Lists;
@@ -63,7 +64,7 @@ private Function<List<Record>, FileAppender<Record>> avroFactory(OutputFile out)
             .overwrite()
             .build();
       } catch (IOException e) {
-        throw new org.apache.iceberg.exceptions.RuntimeIOException(e);
+        throw new RuntimeIOException(e);
       }
     };
   }
diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetFormatModel.java b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetFormatModel.java
index f6531b5cfe60..0845c0727897 100644
--- a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetFormatModel.java
+++ b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetFormatModel.java
@@ -172,6 +172,12 @@ public ModelWriteBuilder set(String property, String value) {
     return this;
   }
 
+  @Override
+  public ModelWriteBuilder setAll(Map<String, String> properties) {
+    properties.forEach(this::set);
+    return this;
+  }
+
   @Override
   public ModelWriteBuilder meta(String property, String value) {
     internal.meta(property, value);
@@ -216,16 +222,14 @@ public ModelWriteBuilder withAADPrefix(ByteBuffer aadPrefix) {
 
   @Override
   public FileAppender<D> build() throws IOException {
-    Preconditions.checkState(content != null, "File content type must be set before building");
+    boolean shredVariants = false;
     switch (content) {
       case DATA:
         internal.createContextFunc(Parquet.WriteBuilder.Context::dataContext);
         internal.createWriterFunc(
             (icebergSchema, messageType) ->
                 writerFunction.write(icebergSchema, messageType, engineSchema));
-        if (shreddingEnabled && variantAnalyzer != null && hasVariantColumns(schema)) {
-          return buildShreddedAppender();
-        }
+        shredVariants = shreddingEnabled && variantAnalyzer != null && hasVariantColumns(schema);
         break;
       case EQUALITY_DELETES:
         internal.createContextFunc(Parquet.WriteBuilder.Context::deleteContext);
@@ -256,6 +260,10 @@ public FileAppender<D> build() throws IOException {
         throw new IllegalArgumentException("Unknown file content: " + content);
     }
 
+    if (shredVariants) {
+      return buildShreddedAppender();
+    }
+
     return internal.build();
   }
diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/VariantShreddingAnalyzer.java b/parquet/src/main/java/org/apache/iceberg/parquet/VariantShreddingAnalyzer.java
index 117dcfb7851f..024635939c5d 100644
--- a/parquet/src/main/java/org/apache/iceberg/parquet/VariantShreddingAnalyzer.java
+++ b/parquet/src/main/java/org/apache/iceberg/parquet/VariantShreddingAnalyzer.java
@@ -134,6 +134,7 @@ public Map<String, Type> analyzeVariantColumns(
         }
       }
     }
+
    return shreddedTypes;
  }

From 4f104b042f216217a348149b02dd8abffd41e698 Mon Sep 17 00:00:00 2001
From: Neelesh Salian
Date: Sat, 18 Apr 2026 16:37:23 -0700
Subject: [PATCH 16/17] Update doc for spark config

---
 docs/docs/spark-configuration.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/docs/docs/spark-configuration.md b/docs/docs/spark-configuration.md
index 524890dd6528..613fe5b66554 100644
--- a/docs/docs/spark-configuration.md
+++ b/docs/docs/spark-configuration.md
@@ -264,6 +264,7 @@ df.writeTo("catalog.db.table")
 | compression-strategy | Table write.orc.compression-strategy | Overrides this table's compression strategy for ORC tables for this write |
 | distribution-mode | See [Spark Writes](spark-writes.md#writing-distribution-modes) for defaults | Override this table's distribution mode for this write |
 | delete-granularity | file | Override this table's delete granularity for this write |
+| shred-variants | false | Overrides this table's `write.parquet.variant.shred` property for this write |
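A minimal sketch of the new write option in use (Java API; the table name and data frame are
illustrative, and whether shredding actually happens still depends on the file format being
Parquet and the table having variant columns):

    // Per-write override of write.parquet.variant.shred via DataFrameWriterV2.
    // (append() declares a checked NoSuchTableException, so callers must throw or catch it.)
    df.writeTo("catalog.db.table")
        .option("shred-variants", "true")
        .append();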
 
 CommitMetadata provides an interface to add custom metadata to a snapshot summary during a SQL execution, which can be beneficial for purposes such as auditing or change tracking. If properties start with `snapshot-property.`, then that prefix will be removed from each property. Here is an example:
 
From c63155d659682b471d87b8627742554b416d588c Mon Sep 17 00:00:00 2001
From: Neelesh Salian
Date: Tue, 21 Apr 2026 13:55:10 -0700
Subject: [PATCH 17/17] Core: Move DataTestHelpers to core and use in
 TestBufferedFileAppender

Co-authored-by: Neelesh Salian
Co-authored-by: Aihua Xu
---
 .../apache/iceberg/data/DataTestHelpers.java  |  0
 .../iceberg/io/TestBufferedFileAppender.java  | 36 +++++++++----------
 2 files changed, 17 insertions(+), 19 deletions(-)
 rename {data => core}/src/test/java/org/apache/iceberg/data/DataTestHelpers.java (100%)

diff --git a/data/src/test/java/org/apache/iceberg/data/DataTestHelpers.java b/core/src/test/java/org/apache/iceberg/data/DataTestHelpers.java
similarity index 100%
rename from data/src/test/java/org/apache/iceberg/data/DataTestHelpers.java
rename to core/src/test/java/org/apache/iceberg/data/DataTestHelpers.java
diff --git a/core/src/test/java/org/apache/iceberg/io/TestBufferedFileAppender.java b/core/src/test/java/org/apache/iceberg/io/TestBufferedFileAppender.java
index 74b4056dba35..9bbc0f9f8c71 100644
--- a/core/src/test/java/org/apache/iceberg/io/TestBufferedFileAppender.java
+++ b/core/src/test/java/org/apache/iceberg/io/TestBufferedFileAppender.java
@@ -27,6 +27,7 @@
 import org.apache.iceberg.Schema;
 import org.apache.iceberg.avro.Avro;
 import org.apache.iceberg.avro.AvroIterable;
+import org.apache.iceberg.data.DataTestHelpers;
 import org.apache.iceberg.data.GenericRecord;
 import org.apache.iceberg.data.Record;
 import org.apache.iceberg.data.avro.DataWriter;
@@ -106,10 +107,14 @@ public void testBufferFlushesOnThreshold() throws IOException {
     appender.add(createRecord(5L, "e"));
     appender.close();
 
-    List<Record> actual = readBack();
-    assertThat(actual).hasSize(5);
-    assertThat(actual.get(0).getField("id")).isEqualTo(1L);
-    assertThat(actual.get(4).getField("id")).isEqualTo(5L);
+    List<Record> expected =
+        Lists.newArrayList(
+            createRecord(1L, "a"),
+            createRecord(2L, "b"),
+            createRecord(3L, "c"),
+            createRecord(4L, "d"),
+            createRecord(5L, "e"));
+    DataTestHelpers.assertEquals(SCHEMA.asStruct(), expected, readBack());
   }
 
   @Test
@@ -126,10 +131,9 @@ public void testCloseWithPartialBuffer() throws IOException {
 
     // close flushes partial buffer through factory
     appender.close();
 
-    List<Record> actual = readBack();
-    assertThat(actual).hasSize(3);
-    assertThat(actual.get(0).getField("data")).isEqualTo("a");
-    assertThat(actual.get(2).getField("data")).isEqualTo("c");
+    List<Record> expected =
+        Lists.newArrayList(createRecord(1L, "a"), createRecord(2L, "b"), createRecord(3L, "c"));
+    DataTestHelpers.assertEquals(SCHEMA.asStruct(), expected, readBack());
   }
 
   @Test
@@ -151,13 +155,10 @@ public void testCopyFuncIsApplied() throws IOException {
 
     appender.close();
 
-    List<Record> actual = readBack();
-    assertThat(actual).hasSize(3);
-    // without copyFunc, all 3 rows would have the last values (3, "third")
-    assertThat(actual.get(0).getField("id")).isEqualTo(1L);
-    assertThat(actual.get(0).getField("data")).isEqualTo("first");
-    assertThat(actual.get(1).getField("id")).isEqualTo(2L);
-    assertThat(actual.get(1).getField("data")).isEqualTo("second");
+    List<Record> expected =
+        Lists.newArrayList(
+            createRecord(1L, "first"), createRecord(2L, "second"), createRecord(3L, "third"));
+    DataTestHelpers.assertEquals(SCHEMA.asStruct(), expected, readBack());
   }
 
   @Test
@@ -209,10 +210,7 @@ public void testAddAllSpanningBuffer() throws IOException {
     appender.addAll(records);
     appender.close();
 
-    List<Record> actual = readBack();
-    assertThat(actual).hasSize(4);
-    assertThat(actual.get(0).getField("id")).isEqualTo(1L);
-    assertThat(actual.get(3).getField("id")).isEqualTo(4L);
+    DataTestHelpers.assertEquals(SCHEMA.asStruct(), records, readBack());
   }
 
   @Test