diff --git a/docs/src/config.md b/docs/src/config.md index 302cf761e..77d2f8e64 100644 --- a/docs/src/config.md +++ b/docs/src/config.md @@ -499,4 +499,63 @@ Lance Spark maintains index and metadata caches to minimize redundant I/O. Cache | `LANCE_INDEX_CACHE_SIZE` | 6GB | Index cache size in bytes. | | `LANCE_METADATA_CACHE_SIZE`| 1GB | Metadata cache size in bytes. | -For details on how caching works and tuning recommendations, see [Performance Tuning - Caching](performance.md#caching). \ No newline at end of file +For details on how caching works and tuning recommendations, see [Performance Tuning - Caching](performance.md#caching). + +## Blob v2 Reads + +Lance datasets that contain a blob v2 column expose that column to Spark as the native 5-field descriptor struct: `struct`. Querying the descriptor never fetches the blob bytes, so `SELECT payload.size` and `SELECT payload.blob_uri` are cheap. + +```sql +-- Query metadata only (no byte fetch): +SELECT id, payload.size, payload.kind FROM lance.ns.tbl; +``` + +A column is treated as blob v2 when the Arrow field carries `ARROW:extension:name = lance.blob.v2`, matching lance-core's blob v2 extension type. + +Filter pushdown for SQL `WHERE` is disabled on blob v2 tables; Spark evaluates predicates after the scan. Zonemap-based fragment pruning still runs. + +The connector does not materialize blob bytes on read; queries against descriptor fields fetch metadata only. See [Blob v2 Writes](#blob-v2-writes) below for the write path. + +## Blob v2 Writes + +To write blob v2 columns, set `file_format_version` to `2.2` or higher and set +`.lance.encoding = blob` in `TBLPROPERTIES`. + +Spark still sees the column as `BINARY` when writing. The connector converts that binary +value into the Arrow blob write struct during encoding. + +On reads, blob v2 columns are exposed as descriptor structs. See +[Blob v2 Reads](#blob-v2-reads). For writes, `INSERT` and DataFrame append still take +`BINARY`. + +```sql +CREATE TABLE lance.mydb.users ( + id INT NOT NULL, + content BINARY +) USING lance +TBLPROPERTIES ( + 'content.lance.encoding' = 'blob', + 'file_format_version' = '2.2' +); +``` + +With `file_format_version = '2.2'` or higher, blob columns are written using blob v2 +encoding and `ARROW:extension:name = lance.blob.v2 metadata`. + +With an older version, or when `file_format_version` is not set, blob columns use the +legacy v1 encoding with `lance-encoding:blob = true` metadata. + +Blob encoding requires a numeric `file_format_version`, such as `2.2`. + +Blob v2 writes must go through the catalog path. Use SQL DDL with `TBLPROPERTIES`, as +shown above, or use the `DataFrameWriterV2` API: + +```python +df.writeTo("lance.ns.users") \ + .tableProperty("content.lance.encoding", "blob") \ + .tableProperty("file_format_version", "2.2") \ + .create() +``` + +Setting only `file_format_version` does not enable blob encoding. Without +`.lance.encoding = blob`, the column is written as plain `BINARY`. diff --git a/docs/src/operations/ddl/create-table.md b/docs/src/operations/ddl/create-table.md index 5468fabc5..21cf9b6ca 100644 --- a/docs/src/operations/ddl/create-table.md +++ b/docs/src/operations/ddl/create-table.md @@ -338,6 +338,33 @@ To create a table with blob columns, use the table property pattern ` str: + return "X'" + data.hex().upper() + "'" + + def _table_location(spark, table_name): rows = ( spark.sql(f"DESCRIBE EXTENDED {table_name}") @@ -705,6 +709,72 @@ def test_compression_metadata_reaches_lance_file(self, spark): assert b"lance-encoding:compression" not in (id_meta or {}) +class TestDDLBlobV2: + def test_blob_v2_table_reads_content_as_descriptor(self, spark): + spark.sql(""" + CREATE TABLE default.test_blob_v2 ( + id INT NOT NULL, + content BINARY + ) USING lance + TBLPROPERTIES ( + 'content.lance.encoding' = 'blob', + 'file_format_version' = '2.2' + ) + """) + + first_content = b"SQL insert content 1" + second_content = b"SQL insert content 2" + + spark.sql( + f"INSERT INTO default.test_blob_v2 VALUES (1, {_sql_binary_literal(first_content)})" + ) + spark.sql( + f"INSERT INTO default.test_blob_v2 VALUES (2, {_sql_binary_literal(second_content)})" + ) + + describe_rows = spark.sql("DESCRIBE default.test_blob_v2").collect() + content_field = next(row for row in describe_rows if row.col_name == "content") + content_type = content_field.data_type.lower() + + assert "struct" in content_type + assert "kind" in content_type + assert "blob_uri" in content_type + + rows = spark.sql(""" + SELECT id, content.size, content.kind, content.blob_id, content.blob_uri + FROM default.test_blob_v2 + ORDER BY id + """).collect() + + assert len(rows) == 2 + + assert rows[0].id == 1 + assert rows[0].size == len(first_content) + assert rows[0].kind == 0 + + assert rows[1].id == 2 + assert rows[1].size == len(second_content) + assert rows[1].kind == 0 + + def test_blob_v2_insert_rejects_non_binary_content(self, spark): + spark.sql(""" + CREATE TABLE default.test_blob_v2_bad_insert ( + id INT NOT NULL, + content BINARY + ) USING lance + TBLPROPERTIES ( + 'content.lance.encoding' = 'blob', + 'file_format_version' = '2.2' + ) + """) + + with pytest.raises(Exception, match="got string"): + spark.sql(""" + INSERT INTO default.test_blob_v2_bad_insert + VALUES (1, 'not-binary') + """) + + class TestDDLIndex: """Test DDL index operations: CREATE INDEX (BTree, FTS).""" diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/BaseLanceNamespaceSparkCatalog.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/BaseLanceNamespaceSparkCatalog.java index 73bddc9ac..760d4908e 100644 --- a/lance-spark-base_2.12/src/main/java/org/lance/spark/BaseLanceNamespaceSparkCatalog.java +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/BaseLanceNamespaceSparkCatalog.java @@ -599,7 +599,9 @@ public Table createTable( // Build the table ID for credential vending List tableIdList = buildTableId(actualIdent); - StructType processedSchema = SchemaConverter.processSchemaWithProperties(schema, properties); + String fileFormatVersion = catalogConfig.getFileFormatVersion(properties); + StructType processedSchema = + SchemaConverter.processSchemaWithProperties(schema, properties, fileFormatVersion); // Create dataset using namespace - WriteDatasetBuilder handles declareTable internally // and properly leverages namespace client for credential vending @@ -613,7 +615,6 @@ public Table createTable( .mode(WriteParams.WriteMode.CREATE) .enableStableRowIds(catalogConfig.isEnableStableRowIds(properties)) .storageOptions(catalogConfig.getStorageOptions()); - String fileFormatVersion = catalogConfig.getFileFormatVersion(properties); if (fileFormatVersion != null) { writeBuilder.dataStorageVersion(fileFormatVersion); } @@ -672,12 +673,13 @@ private Table createTableAtPath( throws TableAlreadyExistsException { String datasetUri = getDatasetUri(ident); - StructType processedSchema = SchemaConverter.processSchemaWithProperties(schema, properties); + String fileFormatVersion = catalogConfig.getFileFormatVersion(properties); + StructType processedSchema = + SchemaConverter.processSchemaWithProperties(schema, properties, fileFormatVersion); LanceSparkReadOptions readOptions = createReadOptions( datasetUri, catalogConfig, Optional.empty(), Optional.empty(), Optional.empty(), name); - String fileFormatVersion = catalogConfig.getFileFormatVersion(properties); Map tableProperties = copyUserTableProperties(properties); try { WriteDatasetBuilder writeBuilder = @@ -895,7 +897,9 @@ public StagedTable stageCreate( Identifier actualIdent = transformIdentifierForApi(ident); List tableIdList = buildTableId(actualIdent); - StructType processedSchema = SchemaConverter.processSchemaWithProperties(schema, properties); + String fileFormatVersion = catalogConfig.getFileFormatVersion(properties); + StructType processedSchema = + SchemaConverter.processSchemaWithProperties(schema, properties, fileFormatVersion); DeclareTableRequest declareRequest = new DeclareTableRequest(); tableIdList.forEach(declareRequest::addIdItem); @@ -925,7 +929,6 @@ public StagedTable stageCreate( managedVersioning); StagedCommit stagedCommit = StagedCommit.forNewTable(arrowSchema, location, commitOptions); stagedCommit.setShardingSpec(shardingSpec); - String fileFormatVersion = catalogConfig.getFileFormatVersion(properties); return createStagedDataset( readOptions, processedSchema, @@ -946,7 +949,9 @@ private StagedTable stageCreateAtPath( Map properties, ShardingSpec shardingSpec) { String datasetUri = getDatasetUri(ident); - StructType processedSchema = SchemaConverter.processSchemaWithProperties(schema, properties); + String fileFormatVersion = catalogConfig.getFileFormatVersion(properties); + StructType processedSchema = + SchemaConverter.processSchemaWithProperties(schema, properties, fileFormatVersion); LanceSparkReadOptions readOptions = createReadOptions( @@ -958,7 +963,6 @@ private StagedTable stageCreateAtPath( catalogConfig.getStorageOptions(), catalogConfig.isEnableStableRowIds(properties)); StagedCommit stagedCommit = StagedCommit.forNewTable(arrowSchema, datasetUri, commitOptions); stagedCommit.setShardingSpec(shardingSpec); - String fileFormatVersion = catalogConfig.getFileFormatVersion(properties); return createStagedDataset( readOptions, processedSchema, @@ -986,7 +990,9 @@ public StagedTable stageReplace( ResolvedTable resolved = resolveIdentifier(ident); DescribeTableResponse describeResponse = resolved.describeResponse; - StructType processedSchema = SchemaConverter.processSchemaWithProperties(schema, properties); + String fileFormatVersion = catalogConfig.getFileFormatVersion(properties); + StructType processedSchema = + SchemaConverter.processSchemaWithProperties(schema, properties, fileFormatVersion); Map initialStorageOptions = describeResponse.getStorageOptions(); boolean managedVersioning = Boolean.TRUE.equals(describeResponse.getManagedVersioning()); @@ -1004,7 +1010,6 @@ public StagedTable stageReplace( StagedCommit stagedCommit = StagedCommit.forExistingTable(ds, arrowSchema, commitOptions); stagedCommit.setShardingSpec(shardingSpec); // Use specified file format version, or fall back to existing table's version - String fileFormatVersion = catalogConfig.getFileFormatVersion(properties); if (fileFormatVersion == null) { fileFormatVersion = ds.getLanceFileFormatVersion(); } @@ -1029,7 +1034,9 @@ private StagedTable stageReplaceAtPath( ShardingSpec shardingSpec) throws NoSuchTableException { String datasetUri = getDatasetUri(ident); - StructType processedSchema = SchemaConverter.processSchemaWithProperties(schema, properties); + String fileFormatVersion = catalogConfig.getFileFormatVersion(properties); + StructType processedSchema = + SchemaConverter.processSchemaWithProperties(schema, properties, fileFormatVersion); LanceSparkReadOptions readOptions = createReadOptions( @@ -1049,7 +1056,6 @@ private StagedTable stageReplaceAtPath( StagedCommit stagedCommit = StagedCommit.forExistingTable(ds, arrowSchema, commitOptions); stagedCommit.setShardingSpec(shardingSpec); // Use specified file format version, or fall back to existing table's version - String fileFormatVersion = catalogConfig.getFileFormatVersion(properties); if (fileFormatVersion == null) { fileFormatVersion = ds.getLanceFileFormatVersion(); } @@ -1086,7 +1092,9 @@ public StagedTable stageCreateOrReplace( Identifier actualIdent = transformIdentifierForApi(ident); List tableIdList = buildTableId(actualIdent); - StructType processedSchema = SchemaConverter.processSchemaWithProperties(schema, properties); + String fileFormatVersion = catalogConfig.getFileFormatVersion(properties); + StructType processedSchema = + SchemaConverter.processSchemaWithProperties(schema, properties, fileFormatVersion); boolean exists = tableExists(ident); String location; @@ -1120,7 +1128,6 @@ public StagedTable stageCreateOrReplace( Schema arrowSchema = LanceArrowUtils.toArrowSchema(processedSchema, "UTC", true); // Use specified file format version, or fall back to existing table's version - String fileFormatVersion = catalogConfig.getFileFormatVersion(properties); Map merged = LanceRuntime.mergeStorageOptions(catalogConfig.getStorageOptions(), initialStorageOptions); final StagedCommitOptions commitOptions = @@ -1161,7 +1168,9 @@ private StagedTable stageCreateOrReplaceAtPath( Map properties, ShardingSpec shardingSpec) { String datasetUri = getDatasetUri(ident); - StructType processedSchema = SchemaConverter.processSchemaWithProperties(schema, properties); + String fileFormatVersion = catalogConfig.getFileFormatVersion(properties); + StructType processedSchema = + SchemaConverter.processSchemaWithProperties(schema, properties, fileFormatVersion); LanceSparkReadOptions readOptions = createReadOptions( @@ -1174,7 +1183,6 @@ private StagedTable stageCreateOrReplaceAtPath( catalogConfig.getStorageOptions(), catalogConfig.isEnableStableRowIds(properties)); StagedCommit stagedCommit; // Use specified file format version, or fall back to existing table's version - String fileFormatVersion = catalogConfig.getFileFormatVersion(properties); if (exists) { Dataset ds = Utils.openDatasetBuilder(readOptions).build(); stagedCommit = StagedCommit.forExistingTable(ds, arrowSchema, commitOptions); diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/LanceDataset.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/LanceDataset.java index c3e817dc9..ddea43d5e 100644 --- a/lance-spark-base_2.12/src/main/java/org/lance/spark/LanceDataset.java +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/LanceDataset.java @@ -18,6 +18,7 @@ import org.lance.spark.utils.BlobSourceContext; import org.lance.spark.utils.BlobUtils; import org.lance.spark.write.AddColumnsBackfillWrite; +import org.lance.spark.write.LanceWriteSchemaValidator; import org.lance.spark.write.SparkWrite; import org.lance.spark.write.StagedCommit; import org.lance.spark.write.UpdateColumnsBackfillWrite; @@ -60,6 +61,14 @@ public class LanceDataset ImmutableSet.of( TableCapability.BATCH_READ, TableCapability.BATCH_WRITE, TableCapability.TRUNCATE); + // Blob v2 is read as descriptor structs, but written as BINARY from sparkSchema. + private static final Set CAPABILITIES_WITH_BLOB_V2 = + ImmutableSet.of( + TableCapability.BATCH_READ, + TableCapability.BATCH_WRITE, + TableCapability.TRUNCATE, + TableCapability.ACCEPT_ANY_SCHEMA); + public static final MetadataColumn FRAGMENT_ID_COLUMN = new MetadataColumn() { @Override @@ -303,7 +312,7 @@ public String name() { @Override public StructType schema() { - return sparkSchema; + return BlobUtils.applyBlobV2DescriptorSchema(sparkSchema); } @Override @@ -318,11 +327,14 @@ public Map properties() { @Override public Set capabilities() { - return CAPABILITIES; + return BlobUtils.hasBlobV2Fields(sparkSchema) ? CAPABILITIES_WITH_BLOB_V2 : CAPABILITIES; } @Override public WriteBuilder newWriteBuilder(LogicalWriteInfo logicalWriteInfo) { + if (capabilities().contains(TableCapability.ACCEPT_ANY_SCHEMA)) { + LanceWriteSchemaValidator.validate(sparkSchema, logicalWriteInfo.schema()); + } // Merge write-time options with the base options from read options CaseInsensitiveStringMap sparkWriteOptions = logicalWriteInfo.options(); Map blobSourceContexts = decodeBlobSourceContexts(sparkWriteOptions); diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/read/LanceScanBuilder.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/read/LanceScanBuilder.java index 4df430989..834e9dd0b 100644 --- a/lance-spark-base_2.12/src/main/java/org/lance/spark/read/LanceScanBuilder.java +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/read/LanceScanBuilder.java @@ -26,6 +26,7 @@ import org.lance.schema.LanceSchema; import org.lance.spark.LanceSparkReadOptions; import org.lance.spark.sharding.SparkLanceShardingUtils; +import org.lance.spark.utils.BlobUtils; import org.lance.spark.utils.Optional; import org.lance.spark.utils.Utils; @@ -117,8 +118,8 @@ public LanceScanBuilder( String namespaceImpl, java.util.Map namespaceProperties, ShardingSpec shardingSpec) { - this.fullSchema = schema; - this.schema = schema; + this.fullSchema = BlobUtils.applyBlobV2DescriptorSchema(schema); + this.schema = this.fullSchema; this.readOptions = readOptions; this.initialStorageOptions = initialStorageOptions; this.namespaceImpl = namespaceImpl; diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/utils/BlobUtils.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/utils/BlobUtils.java index 2a94557a3..af387dd5b 100644 --- a/lance-spark-base_2.12/src/main/java/org/lance/spark/utils/BlobUtils.java +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/utils/BlobUtils.java @@ -13,13 +13,36 @@ */ package org.lance.spark.utils; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +import java.util.List; +import java.util.Map; public class BlobUtils { public static final String LANCE_ENCODING_BLOB_KEY = "lance-encoding:blob"; public static final String LANCE_ENCODING_BLOB_VALUE = "true"; + public static final String ARROW_EXTENSION_NAME_KEY = "ARROW:extension:name"; + public static final String ARROW_EXTENSION_BLOB_V2 = "lance.blob.v2"; + + /** + * Spark struct type for a Lance blob v2 descriptor: {@code kind, position, size, blob_id, + * blob_uri}. + */ + public static final StructType BLOB_DESCRIPTOR_STRUCT = + new StructType() + .add("kind", DataTypes.ShortType) + .add("position", DataTypes.LongType) + .add("size", DataTypes.LongType) + .add("blob_id", DataTypes.LongType) + .add("blob_uri", DataTypes.StringType); + /** * Check if a Spark field is a blob field based on its metadata. * @@ -66,4 +89,83 @@ public static boolean isBlobArrowField(org.apache.arrow.vector.types.pojo.Field String value = metadata.get(LANCE_ENCODING_BLOB_KEY); return LANCE_ENCODING_BLOB_VALUE.equalsIgnoreCase(value); } + + /** Returns true when a Spark field carries the lance-core blob v2 Arrow extension. */ + public static boolean isBlobV2SparkField(StructField field) { + return field != null && isBlobV2SparkMetadata(field.metadata()); + } + + public static boolean isBlobV2SparkMetadata(Metadata metadata) { + if (metadata == null) { + return false; + } + + return metadata.contains(ARROW_EXTENSION_NAME_KEY) + && ARROW_EXTENSION_BLOB_V2.equals(metadata.getString(ARROW_EXTENSION_NAME_KEY)); + } + + /** + * Arrow-side counterpart of {@link #isBlobV2SparkField} used inside the columnar batch scanner. + */ + public static boolean isBlobV2ArrowField(Field field) { + if (field == null) { + return false; + } + + Map metadata = field.getMetadata(); + if (metadata != null + && ARROW_EXTENSION_BLOB_V2.equals(metadata.get(ARROW_EXTENSION_NAME_KEY))) { + return true; + } + + // lance-core scan batches expose the unloaded descriptor struct (no extension metadata). + return isBlobV2DescriptorArrowField(field); + } + + private static boolean isBlobV2DescriptorArrowField(Field field) { + if (!(field.getType() instanceof ArrowType.Struct)) { + return false; + } + List children = field.getChildren(); + if (children == null || children.size() != BLOB_DESCRIPTOR_STRUCT.fields().length) { + return false; + } + StructField[] expected = BLOB_DESCRIPTOR_STRUCT.fields(); + for (int i = 0; i < expected.length; i++) { + if (!expected[i].name().equals(children.get(i).getName())) { + return false; + } + } + return true; + } + + /** Returns true if any field in {@code schema} is a blob v2 column. */ + public static boolean hasBlobV2Fields(StructType schema) { + for (StructField field : schema.fields()) { + if (isBlobV2SparkField(field)) { + return true; + } + } + + return false; + } + + /** Rewrites blob v2 columns to the descriptor struct returned by Lance. */ + public static StructType applyBlobV2DescriptorSchema(StructType schema) { + StructField[] fields = new StructField[schema.fields().length]; + boolean changed = false; + for (int i = 0; i < schema.fields().length; i++) { + StructField field = schema.fields()[i]; + if (!isBlobV2SparkField(field)) { + fields[i] = field; + continue; + } + + fields[i] = + new StructField(field.name(), BLOB_DESCRIPTOR_STRUCT, field.nullable(), field.metadata()); + changed = true; + } + + return changed ? new StructType(fields) : schema; + } } diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/utils/SchemaConverter.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/utils/SchemaConverter.java index 4a6b046d4..918f88586 100644 --- a/lance-spark-base_2.12/src/main/java/org/lance/spark/utils/SchemaConverter.java +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/utils/SchemaConverter.java @@ -26,6 +26,8 @@ import java.util.Map; +import static org.lance.spark.utils.BlobUtils.ARROW_EXTENSION_BLOB_V2; +import static org.lance.spark.utils.BlobUtils.ARROW_EXTENSION_NAME_KEY; import static org.lance.spark.utils.BlobUtils.LANCE_ENCODING_BLOB_KEY; import static org.lance.spark.utils.BlobUtils.LANCE_ENCODING_BLOB_VALUE; import static org.lance.spark.utils.Float16Utils.ARROW_FLOAT16_KEY; @@ -41,8 +43,11 @@ */ public class SchemaConverter { - private SchemaConverter() { - // Utility class + private SchemaConverter() {} + + public static StructType processSchemaWithProperties( + StructType sparkSchema, Map properties) { + return processSchemaWithProperties(sparkSchema, properties, null); } /** @@ -51,13 +56,16 @@ private SchemaConverter() { * * @param sparkSchema the original Spark StructType * @param properties table properties that may contain column metadata + * @param fileFormatVersion the file format version (e.g. "2.1", "2.2") used to choose blob v1 vs + * v2 encoding metadata on {@code BinaryType} columns. Pass {@code null} to default blob + * columns to v1. * @return StructType with metadata added for matching columns */ public static StructType processSchemaWithProperties( - StructType sparkSchema, Map properties) { + StructType sparkSchema, Map properties, String fileFormatVersion) { StructType schemaWithVectors = addVectorMetadata(sparkSchema, properties); StructType schemaWithFloat16 = addFloat16Metadata(schemaWithVectors, properties); - StructType schemaWithBlobs = addBlobMetadata(schemaWithFloat16, properties); + StructType schemaWithBlobs = addBlobMetadata(schemaWithFloat16, properties, fileFormatVersion); StructType schemaWithLargeVarChar = addLargeVarCharMetadata(schemaWithBlobs, properties); return addCompressionMetadata(schemaWithLargeVarChar, properties); } @@ -192,10 +200,11 @@ private static StructType addFloat16Metadata( * * @param sparkSchema the original Spark StructType * @param properties table properties that may contain blob column metadata + * @param fileFormatVersion the file format version used to choose blob v1 or v2 * @return StructType with metadata added for blob columns */ private static StructType addBlobMetadata( - StructType sparkSchema, Map properties) { + StructType sparkSchema, Map properties, String fileFormatVersion) { if (properties == null || properties.isEmpty()) { return sparkSchema; } @@ -211,10 +220,13 @@ private static StructType addBlobMetadata( if ("blob".equalsIgnoreCase(encodingValue)) { if (field.dataType() instanceof BinaryType) { // Add metadata for blob encoding + boolean useV2 = enablesBlobV2(fileFormatVersion); + String metaKey = useV2 ? ARROW_EXTENSION_NAME_KEY : LANCE_ENCODING_BLOB_KEY; + String metaVal = useV2 ? ARROW_EXTENSION_BLOB_V2 : LANCE_ENCODING_BLOB_VALUE; Metadata newMetadata = new MetadataBuilder() .withMetadata(field.metadata()) - .putString(LANCE_ENCODING_BLOB_KEY, LANCE_ENCODING_BLOB_VALUE) + .putString(metaKey, metaVal) .build(); newFields[i] = new StructField(field.name(), field.dataType(), field.nullable(), newMetadata); @@ -238,6 +250,26 @@ private static StructType addBlobMetadata( return new StructType(newFields); } + private static boolean enablesBlobV2(String fileFormatVersion) { + if (fileFormatVersion == null) { + return false; + } + + String[] parts = fileFormatVersion.split("\\."); + try { + int major = Integer.parseInt(parts[0]); + int minor = parts.length > 1 ? Integer.parseInt(parts[1]) : 0; + + return major > 2 || (major == 2 && minor >= 2); + } catch (NumberFormatException | ArrayIndexOutOfBoundsException e) { + throw new IllegalArgumentException( + String.format( + "Blob columns require a numeric file_format_version like '2.2'. Got: '%s'.", + fileFormatVersion), + e); + } + } + /** * Adds metadata to StringType fields based on table properties for large varchar columns. * Properties with pattern ".arrow.large_var_char" = "true" are applied to matching diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/vectorized/LanceArrowColumnVector.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/vectorized/LanceArrowColumnVector.java index 8d30f9783..2414666cb 100644 --- a/lance-spark-base_2.12/src/main/java/org/lance/spark/vectorized/LanceArrowColumnVector.java +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/vectorized/LanceArrowColumnVector.java @@ -38,6 +38,7 @@ import org.apache.arrow.vector.complex.ListVector; import org.apache.arrow.vector.complex.MapVector; import org.apache.arrow.vector.complex.StructVector; +import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.Decimal; import org.apache.spark.sql.util.LanceArrowUtils; import org.apache.spark.sql.vectorized.ArrowColumnVector; @@ -71,7 +72,7 @@ public LanceArrowColumnVector(ValueVector vector) { } public LanceArrowColumnVector(ValueVector vector, boolean closeVectorOnClose) { - super(LanceArrowUtils.fromArrowField(vector.getField())); + super(computeDataType(vector)); this.closeVectorOnClose = closeVectorOnClose; if (vector instanceof UInt1Vector) { @@ -86,6 +87,8 @@ public LanceArrowColumnVector(ValueVector vector, boolean closeVectorOnClose) { fixedSizeBinaryAccessor = new FixedSizeBinaryAccessor((FixedSizeBinaryVector) vector); } else if (vector instanceof FixedSizeListVector) { fixedSizeListAccessor = new FixedSizeListAccessor((FixedSizeListVector) vector); + } else if (vector instanceof StructVector && BlobUtils.isBlobV2ArrowField(vector.getField())) { + structAccessor = new LanceStructAccessor((StructVector) vector); } else if (vector instanceof StructVector && BlobUtils.isBlobArrowField(vector.getField())) { blobStructAccessor = new BlobStructAccessor((StructVector) vector); } else if (vector instanceof StructVector) { @@ -522,4 +525,11 @@ public ColumnVector getChild(int ordinal) { public BlobStructAccessor getBlobStructAccessor() { return blobStructAccessor; } + + private static DataType computeDataType(ValueVector vector) { + if (vector instanceof StructVector && BlobUtils.isBlobV2ArrowField(vector.getField())) { + return BlobUtils.BLOB_DESCRIPTOR_STRUCT; + } + return LanceArrowUtils.fromArrowField(vector.getField()); + } } diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/write/LanceWriteSchemaValidator.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/write/LanceWriteSchemaValidator.java new file mode 100644 index 000000000..dd71e2f03 --- /dev/null +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/write/LanceWriteSchemaValidator.java @@ -0,0 +1,185 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark.write; + +import org.lance.spark.utils.BlobUtils; + +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public final class LanceWriteSchemaValidator { + + private LanceWriteSchemaValidator() {} + + public static void validate(StructType tableSchema, StructType inputSchema) { + StructField[] tableFields = tableSchema.fields(); + StructField[] inputFields = inputSchema.fields(); + + if (tableFields.length != inputFields.length) { + throw new IllegalArgumentException( + String.format( + "Cannot write to Lance table because the column count is different. " + + "Table has %d columns, input has %d.", + tableFields.length, inputFields.length)); + } + + StructField[] resolved = resolveFields(tableFields, inputFields); + + List errors = new ArrayList<>(); + for (int i = 0; i < inputFields.length; i++) { + checkField(resolved[i].name(), resolved[i], inputFields[i], errors); + } + + if (errors.isEmpty()) { + return; + } + + throw new IllegalArgumentException( + String.format( + "Cannot write to Lance table because the schemas do not match: %s", + String.join("; ", errors))); + } + + private static void checkField( + String path, StructField tableField, StructField inputField, List errors) { + if (!tableField.nullable() && inputField.nullable()) { + errors.add( + String.format( + "column '%s': the table column does not allow nulls, but the input does", path)); + } + + DataType tableType = tableField.dataType(); + DataType inputType = inputField.dataType(); + + if (BlobUtils.isBlobV2SparkField(tableField)) { + if (!DataTypes.BinaryType.equals(inputType)) { + errors.add( + String.format( + "column '%s': blob v2 columns accept binary data; got %s", + path, inputType.simpleString())); + } + return; + } + + if (tableType instanceof StructType && inputType instanceof StructType) { + checkStruct(path, (StructType) tableType, (StructType) inputType, errors); + return; + } + + if (!tableType.equals(inputType)) { + errors.add( + String.format( + "column '%s': expected %s, got %s", + path, tableType.simpleString(), inputType.simpleString())); + } + } + + private static void checkStruct( + String path, StructType tableStruct, StructType inputStruct, List errors) { + StructField[] tableFields = tableStruct.fields(); + StructField[] inputFields = inputStruct.fields(); + + if (tableFields.length != inputFields.length) { + errors.add( + String.format( + "column '%s': the struct has %d fields in the table but %d in the input", + path, tableFields.length, inputFields.length)); + return; + } + + for (int i = 0; i < tableFields.length; i++) { + StructField tableField = tableFields[i]; + StructField inputField = inputFields[i]; + if (!tableField.name().equals(inputField.name())) { + errors.add( + String.format( + "column '%s': struct fields are in the wrong order; expected '%s' at position %d," + + " got '%s'", + path, tableField.name(), i, inputField.name())); + return; + } + checkField(path + "." + tableField.name(), tableField, inputField, errors); + } + } + + // LanceArrowWriter maps input column i to table column i. Match by name in table order, + // or accept Spark's SQL VALUES column names {col1, col2, ...}. + private static StructField[] resolveFields(StructField[] tableFields, StructField[] inputFields) { + Map tableByName = new HashMap<>(tableFields.length); + for (StructField field : tableFields) { + tableByName.put(field.name(), field); + } + + StructField[] resolved = new StructField[inputFields.length]; + int matchedByName = 0; + for (int i = 0; i < inputFields.length; i++) { + resolved[i] = tableByName.get(inputFields[i].name()); + if (resolved[i] != null) { + matchedByName++; + } + } + + if (matchedByName == inputFields.length) { + for (int i = 0; i < inputFields.length; i++) { + if (!tableFields[i].name().equals(inputFields[i].name())) { + throw new IllegalArgumentException( + String.format( + "Cannot write to Lance table because column names match, but the " + + "order is different. Expected column order: %s.", + quotedNames(tableFields))); + } + } + return tableFields; + } + if (matchedByName == 0 && isSparkValuesColumns(inputFields)) { + return tableFields; + } + + List unknownFields = new ArrayList<>(); + for (int i = 0; i < inputFields.length; i++) { + if (resolved[i] == null) { + unknownFields.add(String.format("'%s'", inputFields[i].name())); + } + } + throw new IllegalArgumentException( + String.format( + "Cannot write to Lance table because input columns %s are not in the table. Table" + + " columns: %s", + String.join(", ", unknownFields), quotedNames(tableFields))); + } + + private static boolean isSparkValuesColumns(StructField[] inputFields) { + for (int i = 0; i < inputFields.length; i++) { + if (!String.format("col%d", i + 1).equals(inputFields[i].name())) { + return false; + } + } + return true; + } + + private static String quotedNames(StructField[] fields) { + List names = new ArrayList<>(fields.length); + for (StructField field : fields) { + names.add(String.format("\"%s\"", field.name())); + } + return String.join(", ", names); + } +} diff --git a/lance-spark-base_2.12/src/main/scala/org/apache/spark/sql/util/LanceArrowUtils.scala b/lance-spark-base_2.12/src/main/scala/org/apache/spark/sql/util/LanceArrowUtils.scala index 3a07a3106..17c6abf3e 100644 --- a/lance-spark-base_2.12/src/main/scala/org/apache/spark/sql/util/LanceArrowUtils.scala +++ b/lance-spark-base_2.12/src/main/scala/org/apache/spark/sql/util/LanceArrowUtils.scala @@ -44,6 +44,8 @@ object LanceArrowUtils { val ARROW_FIXED_SIZE_LIST_SIZE_KEY = VectorUtils.ARROW_FIXED_SIZE_LIST_SIZE_KEY val ARROW_FLOAT16_KEY = Float16Utils.ARROW_FLOAT16_KEY val ENCODING_BLOB = BlobUtils.LANCE_ENCODING_BLOB_KEY + val ARROW_EXT_NAME_KEY = BlobUtils.ARROW_EXTENSION_NAME_KEY + val BLOB_V2_EXT_NAME = BlobUtils.ARROW_EXTENSION_BLOB_V2 val ARROW_LARGE_VAR_CHAR_KEY = LargeVarCharUtils.ARROW_LARGE_VAR_CHAR_KEY val ARROW_DATE_MILLISECOND_KEY = DateMilliUtils.ARROW_DATE_MILLISECOND_KEY @@ -82,6 +84,8 @@ object LanceArrowUtils { val elementType = fromArrowField(elementField) val containsNull = elementField.isNullable ArrayType(elementType, containsNull) + case _: ArrowType.Struct if isBlobField(field) => + BinaryType case _: ArrowType.Struct => // Always recurse through LanceArrowUtils for struct children so special cases // like Date(MILLISECOND), FixedSizeBinary, etc. are applied in nested schemas too. @@ -406,6 +410,19 @@ object LanceArrowUtils { null, meta.asJava) new Field(name, fieldType, Seq.empty[Field].asJava) + case _: BinaryType + if BLOB_V2_EXT_NAME.equals(meta.getOrElse(ARROW_EXT_NAME_KEY, "")) => + // Blob v2 writes the struct lance-core expects: data, uri, position, size. + val structFieldType = + new FieldType(nullable, ArrowType.Struct.INSTANCE, null, meta.asJava) + new Field( + name, + structFieldType, + Seq( + toArrowField("data", BinaryType, nullable = true, timeZoneId, largeVarTypes = true), + toArrowField("uri", StringType, nullable = true, timeZoneId), + arrowUInt64Field("position"), + arrowUInt64Field("size")).asJava) case dataType => val fieldType = new FieldType(nullable, toArrowType(dataType, timeZoneId, large, name), null, meta.asJava) @@ -529,9 +546,17 @@ object LanceArrowUtils { private def isBlobField(field: Field): Boolean = { val metadata = field.getMetadata - metadata != null && metadata.containsKey(ENCODING_BLOB) && - "true".equalsIgnoreCase(metadata.get(ENCODING_BLOB)) + if (metadata == null) return false + (metadata.containsKey(ENCODING_BLOB) && + "true".equalsIgnoreCase(metadata.get(ENCODING_BLOB))) || + BLOB_V2_EXT_NAME.equals(metadata.get(ARROW_EXT_NAME_KEY)) } + + private def arrowUInt64Field(name: String): Field = + new Field( + name, + new FieldType(true, new ArrowType.Int(64, false), null, Map.empty[String, String].asJava), + Seq.empty[Field].asJava) } /** diff --git a/lance-spark-base_2.12/src/main/scala/org/lance/spark/arrow/LanceArrowWriter.scala b/lance-spark-base_2.12/src/main/scala/org/lance/spark/arrow/LanceArrowWriter.scala index 637261bba..74b3e663b 100644 --- a/lance-spark-base_2.12/src/main/scala/org/lance/spark/arrow/LanceArrowWriter.scala +++ b/lance-spark-base_2.12/src/main/scala/org/lance/spark/arrow/LanceArrowWriter.scala @@ -20,7 +20,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecializedGetters import org.apache.spark.sql.types._ import org.apache.spark.sql.util.LanceArrowUtils -import org.lance.spark.utils.{BlobReferenceResolver, Float16Utils} +import org.lance.spark.utils.{BlobReferenceResolver, BlobUtils, Float16Utils} import scala.collection.JavaConverters._ @@ -126,6 +126,8 @@ object LanceArrowWriter { null, resolver) new MapWriter(vector, structVector, keyWriter, valueWriter) + case (BinaryType, vector: StructVector) if BlobUtils.isBlobV2SparkMetadata(metadata) => + new BlobV2StructWriter(vector, createFieldWriter(vector.getChild("data"), BinaryType)) case (StructType(fields), vector: StructVector) => val children = fields.zipWithIndex.map { case (field, ordinal) => createFieldWriter( @@ -460,6 +462,55 @@ private[arrow] class MapWriter( } } +private[arrow] class BlobV2StructWriter( + val valueVector: StructVector, + val dataWriter: LanceArrowFieldWriter) extends LanceArrowFieldWriter { + + override def write(input: SpecializedGetters, ordinal: Int): Unit = { + if (input.isNullAt(ordinal)) { + dataWriter.write(input, ordinal) + valueVector.setNull(count) + } else { + valueVector.setIndexDefined(count) + dataWriter.write(input, ordinal) + } + count += 1 + } + + // Rows are written through write(). setNull and setValue are no-ops so child counts stay + // owned by dataWriter.write(). + override def setNull(): Unit = () + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = () + + override def finish(): Unit = { + super.finish() + dataWriter.finish() + // The sibling descriptor fields (uri, position, size) are never populated on write — + // Lance synthesises them on read. Skip per-row null setting: setValueCount allocates the + // validity buffer and a fresh buffer's bits are all zero, which is "null". Saves + // O(rows * 3) Arrow validity writes per batch. + var i = 0 + while (i < nullChildren.length) { + nullChildren(i).setValueCount(count) + i += 1 + } + } + + override def reset(): Unit = { + super.reset() + dataWriter.reset() + var i = 0 + while (i < nullChildren.length) { + nullChildren(i).reset() + i += 1 + } + } + + private val nullChildren: Array[FieldVector] = + Array("uri", "position", "size").map(valueVector.getChild).toArray +} + private[arrow] class StructWriter( val valueVector: StructVector, val children: Array[LanceArrowFieldWriter]) extends LanceArrowFieldWriter { diff --git a/lance-spark-base_2.12/src/test/java/org/lance/spark/BaseBlobCreateTableTest.java b/lance-spark-base_2.12/src/test/java/org/lance/spark/BaseBlobCreateTableTest.java index c13317822..da8f91047 100644 --- a/lance-spark-base_2.12/src/test/java/org/lance/spark/BaseBlobCreateTableTest.java +++ b/lance-spark-base_2.12/src/test/java/org/lance/spark/BaseBlobCreateTableTest.java @@ -71,7 +71,7 @@ void tearDown() { // These tests verify that when a table is created with blob encoding property, // subsequent DataFrame writes don't need to set the blob encoding metadata. - /** Helper method to verify a field has blob encoding metadata set. */ + /** Helper method to verify a field has blob v1 encoding metadata set. */ private void assertBlobMetadata(StructType schema, String fieldName) { StructField field = schema.apply(fieldName); assertNotNull(field, fieldName + " field should exist in schema"); @@ -82,6 +82,9 @@ private void assertBlobMetadata(StructType schema, String fieldName) { BlobUtils.LANCE_ENCODING_BLOB_VALUE, field.metadata().getString(BlobUtils.LANCE_ENCODING_BLOB_KEY), BlobUtils.LANCE_ENCODING_BLOB_KEY + " metadata should be 'true'"); + assertFalse( + BlobUtils.isBlobV2SparkField(field), + fieldName + " should not be tagged as blob v2 without file_format_version >= 2.2"); } @Test @@ -489,12 +492,198 @@ public void testBlobVirtualColumns() { spark.sql("DROP TABLE IF EXISTS " + catalogName + ".default." + tableName); } - private String bytesToHex(byte[] bytes) { - StringBuilder hexString = new StringBuilder(); - for (byte b : bytes) { - hexString.append(String.format("%02X", b)); + @Test + public void testBlobV2SupportsSqlInsertAndDataFrameAppend() { + String tableName = "blob_v2_sql_" + System.currentTimeMillis(); + + spark.sql( + "CREATE TABLE IF NOT EXISTS " + + catalogName + + ".default." + + tableName + + " (" + + "id INT NOT NULL, " + + "data BINARY" + + ") USING lance " + + "TBLPROPERTIES (" + + "'data.lance.encoding' = 'blob', " + + "'file_format_version' = '2.2'" + + ")"); + + String testData1 = "SQL insert content 1"; + String testData2 = "SQL insert content 2"; + spark.sql( + "INSERT INTO " + + catalogName + + ".default." + + tableName + + " VALUES " + + "(1, X'" + + bytesToHex(testData1.getBytes(StandardCharsets.UTF_8)) + + "'), " + + "(2, X'" + + bytesToHex(testData2.getBytes(StandardCharsets.UTF_8)) + + "')"); + + List rows = new ArrayList<>(); + Random random = new Random(42); + for (int i = 10; i < 13; i++) { + byte[] largeData = new byte[100000]; // 100KB + random.nextBytes(largeData); + rows.add(RowFactory.create(i, largeData)); } - return hexString.toString(); + StructType plainSchema = + new StructType( + new StructField[] { + DataTypes.createStructField("id", DataTypes.IntegerType, false), + DataTypes.createStructField("data", DataTypes.BinaryType, true) + }); + Dataset df = spark.createDataFrame(rows, plainSchema); + assertDoesNotThrow(() -> df.writeTo(catalogName + ".default." + tableName).append()); + + Dataset count = spark.sql("SELECT COUNT(*) FROM " + catalogName + ".default." + tableName); + assertEquals(5L, count.collectAsList().get(0).getLong(0)); + + List descriptors = + spark + .sql( + "SELECT id, data.size AS sz FROM " + + catalogName + + ".default." + + tableName + + " ORDER BY id") + .collectAsList(); + + assertEquals(5, descriptors.size()); + assertEquals(testData1.getBytes(StandardCharsets.UTF_8).length, descriptors.get(0).getLong(1)); + assertEquals(testData2.getBytes(StandardCharsets.UTF_8).length, descriptors.get(1).getLong(1)); + assertEquals(100000L, descriptors.get(2).getLong(1)); + assertBlobV2Metadata(spark.table(catalogName + ".default." + tableName).schema(), "data"); + spark.sql("DROP TABLE IF EXISTS " + catalogName + ".default." + tableName); + } + + @Test + public void testBlobV2SupportsSqlValuesInsert() { + String tableName = "blob_v2_empty_table_" + System.currentTimeMillis(); + + spark.sql( + "CREATE TABLE IF NOT EXISTS " + + catalogName + + ".default." + + tableName + + " (" + + "id INT NOT NULL, " + + "text STRING, " + + "blob_data BINARY" + + ") USING lance " + + "TBLPROPERTIES (" + + "'blob_data.lance.encoding' = 'blob', " + + "'file_format_version' = '2.2'" + + ")"); + + assertBlobV2Metadata(spark.table(catalogName + ".default." + tableName).schema(), "blob_data"); + + String testData1 = "This is test blob data 1"; + String testData2 = "This is test blob data 2"; + spark.sql( + "INSERT INTO " + + catalogName + + ".default." + + tableName + + " VALUES " + + "(1, 'first text', X'" + + bytesToHex(testData1.getBytes(StandardCharsets.UTF_8)) + + "'), " + + "(2, 'second text', X'" + + bytesToHex(testData2.getBytes(StandardCharsets.UTF_8)) + + "')"); + + Dataset count = spark.sql("SELECT COUNT(*) FROM " + catalogName + ".default." + tableName); + + assertEquals(2L, count.collectAsList().get(0).getLong(0)); + + List projection = + spark + .sql("SELECT id, text FROM " + catalogName + ".default." + tableName + " ORDER BY id") + .collectAsList(); + + assertEquals(2, projection.size()); + assertEquals(1, projection.get(0).getInt(0)); + assertEquals("first text", projection.get(0).getString(1)); + assertEquals(2, projection.get(1).getInt(0)); + assertEquals("second text", projection.get(1).getString(1)); + + List descriptors = + spark + .sql( + "SELECT id, blob_data.kind AS kind, blob_data.size AS sz FROM " + + catalogName + + ".default." + + tableName + + " ORDER BY id") + .collectAsList(); + + assertEquals(2, descriptors.size()); + assertEquals(0, descriptors.get(0).getShort(1)); + assertEquals(testData1.getBytes(StandardCharsets.UTF_8).length, descriptors.get(0).getLong(2)); + assertEquals(testData2.getBytes(StandardCharsets.UTF_8).length, descriptors.get(1).getLong(2)); + + spark.sql("DROP TABLE IF EXISTS " + catalogName + ".default." + tableName); + } + + @Test + public void testBlobColumnsStayBlobV1WithFileFormatVersion21() { + String tableName = "blob_v1_ffv_21_" + System.currentTimeMillis(); + + spark.sql( + "CREATE TABLE IF NOT EXISTS " + + catalogName + + ".default." + + tableName + + " (" + + "id INT NOT NULL, " + + "data BINARY" + + ") USING lance " + + "TBLPROPERTIES (" + + "'data.lance.encoding' = 'blob', " + + "'file_format_version' = '2.1'" + + ")"); + + StructType schema = spark.table(catalogName + ".default." + tableName).schema(); + + assertBlobMetadata(schema, "data"); + assertEquals(DataTypes.BinaryType, schema.apply("data").dataType()); + + spark.sql("DROP TABLE IF EXISTS " + catalogName + ".default." + tableName); + } + + @Test + public void testCreateTableSupportsMultipleBlobV2Columns() { + String tableName = "blob_v2_table_multi_" + System.currentTimeMillis(); + + spark.sql( + "CREATE TABLE IF NOT EXISTS " + + catalogName + + ".default." + + tableName + + " (" + + "id INT NOT NULL, " + + "blob1 BINARY, " + + "regular_binary BINARY, " + + "blob2 BINARY" + + ") USING lance " + + "TBLPROPERTIES (" + + "'blob1.lance.encoding' = 'blob', " + + "'blob2.lance.encoding' = 'blob', " + + "'file_format_version' = '2.2'" + + ")"); + + StructType schema = spark.table(catalogName + ".default." + tableName).schema(); + assertBlobV2Metadata(schema, "blob1"); + assertBlobV2Metadata(schema, "blob2"); + assertFalse(BlobUtils.isBlobV2SparkField(schema.apply("regular_binary"))); + + spark.sql("DROP TABLE IF EXISTS " + catalogName + ".default." + tableName); } // ==================== Large VarChar Tests ==================== @@ -708,4 +897,19 @@ public void testLargeVarCharWithTablePropertyAPI() { // Clean up spark.sql("DROP TABLE IF EXISTS " + catalogName + ".default." + tableName); } + + private static String bytesToHex(byte[] bytes) { + StringBuilder hexString = new StringBuilder(); + for (byte b : bytes) { + hexString.append(String.format("%02X", b)); + } + return hexString.toString(); + } + + private void assertBlobV2Metadata(StructType schema, String fieldName) { + StructField field = schema.apply(fieldName); + assertNotNull(field); + assertEquals(BlobUtils.BLOB_DESCRIPTOR_STRUCT, field.dataType()); + assertTrue(BlobUtils.isBlobV2SparkField(field)); + } } diff --git a/lance-spark-base_2.12/src/test/java/org/lance/spark/arrow/BlobV2StructWriterTest.java b/lance-spark-base_2.12/src/test/java/org/lance/spark/arrow/BlobV2StructWriterTest.java new file mode 100644 index 000000000..031f704e6 --- /dev/null +++ b/lance-spark-base_2.12/src/test/java/org/lance/spark/arrow/BlobV2StructWriterTest.java @@ -0,0 +1,144 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark.arrow; + +import org.lance.spark.utils.SchemaConverter; + +import com.google.common.collect.ImmutableMap; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.LargeVarBinaryVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.LanceArrowUtils; +import org.junit.jupiter.api.Test; + +import java.nio.charset.StandardCharsets; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class BlobV2StructWriterTest { + + @Test + public void testWritesBinaryToDataChild() { + StructType sparkSchema = blobV2Schema(); + Schema arrowSchema = LanceArrowUtils.toArrowSchema(sparkSchema, "UTC", true); + try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE); + VectorSchemaRoot root = VectorSchemaRoot.create(arrowSchema, allocator)) { + LanceArrowWriter writer = LanceArrowWriter.create(root, sparkSchema); + + byte[][] payloads = { + "row0".getBytes(StandardCharsets.UTF_8), + "row1".getBytes(StandardCharsets.UTF_8), + "row2".getBytes(StandardCharsets.UTF_8), + }; + + for (int i = 0; i < payloads.length; i++) { + writer.write(new GenericInternalRow(new Object[] {i, payloads[i]})); + } + + writer.finish(); + StructVector content = (StructVector) root.getVector("content"); + assertEquals(payloads.length, content.getValueCount()); + LargeVarBinaryVector data = (LargeVarBinaryVector) content.getChild("data"); + assertEquals(payloads.length, data.getValueCount()); + for (int i = 0; i < payloads.length; i++) { + assertFalse(content.isNull(i)); + assertArrayEquals(payloads[i], data.getObject(i)); + } + + for (String sibling : new String[] {"uri", "position", "size"}) { + FieldVector child = content.getChild(sibling); + for (int i = 0; i < payloads.length; i++) { + assertTrue(child.isNull(i)); + } + } + } + } + + @Test + public void testNullRowNullsStruct() { + StructType sparkSchema = blobV2Schema(); + Schema arrowSchema = LanceArrowUtils.toArrowSchema(sparkSchema, "UTC", true); + try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE); + VectorSchemaRoot root = VectorSchemaRoot.create(arrowSchema, allocator)) { + LanceArrowWriter writer = LanceArrowWriter.create(root, sparkSchema); + + writer.write( + new GenericInternalRow(new Object[] {0, "first".getBytes(StandardCharsets.UTF_8)})); + writer.write(new GenericInternalRow(new Object[] {1, null})); + writer.write( + new GenericInternalRow(new Object[] {2, "third".getBytes(StandardCharsets.UTF_8)})); + writer.finish(); + + StructVector content = (StructVector) root.getVector("content"); + assertEquals(3, content.getValueCount()); + assertFalse(content.isNull(0)); + assertTrue(content.isNull(1)); + assertFalse(content.isNull(2)); + + LargeVarBinaryVector data = (LargeVarBinaryVector) content.getChild("data"); + assertEquals(3, data.getValueCount()); + assertArrayEquals("first".getBytes(StandardCharsets.UTF_8), data.getObject(0)); + assertTrue(data.isNull(1)); + assertArrayEquals("third".getBytes(StandardCharsets.UTF_8), data.getObject(2)); + } + } + + @Test + public void testResetClearsBatch() { + StructType sparkSchema = blobV2Schema(); + Schema arrowSchema = LanceArrowUtils.toArrowSchema(sparkSchema, "UTC", true); + try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE); + VectorSchemaRoot root = VectorSchemaRoot.create(arrowSchema, allocator)) { + LanceArrowWriter writer = LanceArrowWriter.create(root, sparkSchema); + writer.write( + new GenericInternalRow(new Object[] {0, "discard".getBytes(StandardCharsets.UTF_8)})); + writer.write( + new GenericInternalRow( + new Object[] {1, "also-discard".getBytes(StandardCharsets.UTF_8)})); + writer.reset(); + + writer.write( + new GenericInternalRow(new Object[] {7, "keep".getBytes(StandardCharsets.UTF_8)})); + writer.finish(); + + StructVector content = (StructVector) root.getVector("content"); + assertEquals(1, content.getValueCount()); + LargeVarBinaryVector data = (LargeVarBinaryVector) content.getChild("data"); + assertEquals(1, data.getValueCount()); + assertArrayEquals("keep".getBytes(StandardCharsets.UTF_8), data.getObject(0)); + } + } + + private static StructType blobV2Schema() { + StructType raw = + new StructType( + new StructField[] { + DataTypes.createStructField("id", DataTypes.IntegerType, false), + DataTypes.createStructField("content", DataTypes.BinaryType, true), + }); + Map properties = ImmutableMap.of("content.lance.encoding", "blob"); + return SchemaConverter.processSchemaWithProperties(raw, properties, "2.2"); + } +} diff --git a/lance-spark-base_2.12/src/test/java/org/lance/spark/utils/BlobUtilsTest.java b/lance-spark-base_2.12/src/test/java/org/lance/spark/utils/BlobUtilsTest.java new file mode 100644 index 000000000..24edf3735 --- /dev/null +++ b/lance-spark-base_2.12/src/test/java/org/lance/spark/utils/BlobUtilsTest.java @@ -0,0 +1,152 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark.utils; + +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.MetadataBuilder; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.Collections; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class BlobUtilsTest { + + @Test + public void testBlobV2FieldWithArrowExtensionName() { + assertTrue(BlobUtils.isBlobV2SparkField(blobV2Field())); + } + + @Test + public void testBlobV2FieldNullSafety() { + assertFalse(BlobUtils.isBlobV2SparkField(null)); + } + + @Test + public void testV1Field() { + assertTrue(BlobUtils.isBlobSparkField(blobV1Field())); + } + + @Test + public void testBlobV2ArrowFieldRejectsUnrelated() { + Field f = + new Field( + "payload", + new FieldType(true, ArrowType.Binary.INSTANCE, null, Collections.emptyMap()), + null); + assertFalse(BlobUtils.isBlobV2ArrowField(f)); + assertFalse(BlobUtils.isBlobV2ArrowField(null)); + } + + @Test + public void testHasBlobV2FieldsInSchema() { + StructType schema = + new StructType( + new StructField[] { + field("id", DataTypes.IntegerType), blobV2Field(), + }); + assertTrue(BlobUtils.hasBlobV2Fields(schema)); + } + + @Test + public void testDescriptorStructShape() { + StructType s = BlobUtils.BLOB_DESCRIPTOR_STRUCT; + assertEquals(5, s.fields().length); + assertEquals(DataTypes.ShortType, s.apply("kind").dataType()); + assertEquals(DataTypes.LongType, s.apply("position").dataType()); + assertEquals(DataTypes.LongType, s.apply("size").dataType()); + assertEquals(DataTypes.LongType, s.apply("blob_id").dataType()); + assertEquals(DataTypes.StringType, s.apply("blob_uri").dataType()); + } + + @Test + public void testBlobV2DescriptorSchemaRewrite() { + StructType schema = + new StructType( + new StructField[] { + field("id", DataTypes.IntegerType), blobV2Field(), + }); + StructType rewritten = BlobUtils.applyBlobV2DescriptorSchema(schema); + assertEquals(DataTypes.IntegerType, rewritten.apply("id").dataType()); + assertEquals(BlobUtils.BLOB_DESCRIPTOR_STRUCT, rewritten.apply("payload").dataType()); + } + + @Test + public void testV1FieldsPreservedInRewrite() { + StructType schema = + new StructType( + new StructField[] { + field("id", DataTypes.IntegerType), blobV1Field(), + }); + StructType rewritten = BlobUtils.applyBlobV2DescriptorSchema(schema); + assertEquals(DataTypes.BinaryType, rewritten.apply("payload").dataType()); + } + + @Test + public void testUnloadedDescriptorStructRecognizedAsBlobV2() { + Field f = + new Field( + "payload", + new FieldType(true, ArrowType.Struct.INSTANCE, null, Collections.emptyMap()), + Arrays.asList( + intChild("kind"), + intChild("position"), + intChild("size"), + intChild("blob_id"), + utf8Child("blob_uri"))); + assertTrue(BlobUtils.isBlobV2ArrowField(f)); + assertFalse(BlobUtils.isBlobArrowField(f)); + } + + private static Field intChild(String name) { + return new Field( + name, + new FieldType(true, new ArrowType.Int(64, false), null, Collections.emptyMap()), + null); + } + + private static Field utf8Child(String name) { + return new Field( + name, new FieldType(true, ArrowType.Utf8.INSTANCE, null, Collections.emptyMap()), null); + } + + private static StructField field(String name, org.apache.spark.sql.types.DataType dt) { + return new StructField(name, dt, true, Metadata.empty()); + } + + private static StructField blobV2Field() { + Metadata md = + new MetadataBuilder() + .putString(BlobUtils.ARROW_EXTENSION_NAME_KEY, BlobUtils.ARROW_EXTENSION_BLOB_V2) + .build(); + return new StructField("payload", DataTypes.BinaryType, true, md); + } + + private static StructField blobV1Field() { + Metadata md = + new MetadataBuilder() + .putString(BlobUtils.LANCE_ENCODING_BLOB_KEY, BlobUtils.LANCE_ENCODING_BLOB_VALUE) + .build(); + return new StructField("payload", DataTypes.BinaryType, true, md); + } +} diff --git a/lance-spark-base_2.12/src/test/java/org/lance/spark/utils/SchemaConverterTest.java b/lance-spark-base_2.12/src/test/java/org/lance/spark/utils/SchemaConverterTest.java index c6b58a08f..441ea6bed 100644 --- a/lance-spark-base_2.12/src/test/java/org/lance/spark/utils/SchemaConverterTest.java +++ b/lance-spark-base_2.12/src/test/java/org/lance/spark/utils/SchemaConverterTest.java @@ -13,9 +13,14 @@ */ package org.lance.spark.utils; +import com.google.common.collect.ImmutableMap; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.LanceArrowUtils; import org.junit.jupiter.api.Test; import java.util.Collections; @@ -138,8 +143,8 @@ public void testBlobMetadataRejectsNonBinaryType() { }); Map properties = new HashMap<>(); properties.put("text.lance.encoding", "blob"); - assertThrows( - IllegalArgumentException.class, + assertValidationFailure( + "must have BINARY type", () -> SchemaConverter.processSchemaWithProperties(schema, properties)); } @@ -501,4 +506,110 @@ public void testNegativeCompressionLevelThrows() { IllegalArgumentException.class, () -> SchemaConverter.processSchemaWithProperties(schema, properties)); } + + @Test + public void testBlobV2ArrowSchemaUsesWriteStruct() { + StructType sparkSchema = blobSchemaWithVersion("2.2"); + + assertBlobV2Field(sparkSchema.apply("data")); + + Schema arrowSchema = LanceArrowUtils.toArrowSchema(sparkSchema, "UTC", true); + Field contentField = arrowSchema.findField("data"); + + assertNotNull(contentField); + assertInstanceOf(ArrowType.Struct.class, contentField.getType()); + assertEquals( + BlobUtils.ARROW_EXTENSION_BLOB_V2, + contentField.getMetadata().get(BlobUtils.ARROW_EXTENSION_NAME_KEY)); + } + + @Test + public void testBlobV1WithUnsupportedVersion() { + // v2 only supported on 2.2 or higher + assertBlobV1Field(blobSchemaWithVersion("2.1").apply("data")); + } + + @Test + public void testBlobV1WhenVersionNull() { + assertBlobV1Field(blobSchemaWithVersion(null).apply("data")); + } + + @Test + public void testBlobEncodingRequiresColumnProperty() { + StructType schema = + new StructType( + new StructField[] { + DataTypes.createStructField("data", DataTypes.BinaryType, true), + }); + Map properties = ImmutableMap.of("file_format_version", "2.2"); + StructField field = + SchemaConverter.processSchemaWithProperties(schema, properties, "2.2").apply("data"); + + assertFalse(field.metadata().contains(BlobUtils.LANCE_ENCODING_BLOB_KEY)); + assertFalse(field.metadata().contains(BlobUtils.ARROW_EXTENSION_NAME_KEY)); + } + + @Test + public void testBlobEncodingRejectsUnknownFileFormatVersion() { + StructType schema = + new StructType( + new StructField[] { + DataTypes.createStructField("data", DataTypes.BinaryType, true), + }); + Map properties = ImmutableMap.of("data.lance.encoding", "blob"); + + assertValidationFailure( + "Blob columns require a numeric file_format_version like '2.2'. Got: 'stable'.", + () -> SchemaConverter.processSchemaWithProperties(schema, properties, "stable")); + } + + @Test + public void testBlobEncodingRejectsMalformedFileFormatVersion() { + StructType schema = + new StructType( + new StructField[] { + DataTypes.createStructField("data", DataTypes.BinaryType, true), + }); + Map properties = ImmutableMap.of("data.lance.encoding", "blob"); + + assertValidationFailure( + "numeric file_format_version", + () -> SchemaConverter.processSchemaWithProperties(schema, properties, "2.x")); + } + + private static StructType blobSchemaWithVersion(String fileFormatVersion) { + StructType schema = + new StructType( + new StructField[] { + DataTypes.createStructField("id", DataTypes.IntegerType, false), + DataTypes.createStructField("data", DataTypes.BinaryType, true), + }); + Map properties = ImmutableMap.of("data.lance.encoding", "blob"); + + return SchemaConverter.processSchemaWithProperties(schema, properties, fileFormatVersion); + } + + private static void assertBlobV1Field(StructField field) { + assertEquals(DataTypes.BinaryType, field.dataType()); + assertTrue(field.metadata().contains(BlobUtils.LANCE_ENCODING_BLOB_KEY)); + assertEquals( + BlobUtils.LANCE_ENCODING_BLOB_VALUE, + field.metadata().getString(BlobUtils.LANCE_ENCODING_BLOB_KEY)); + assertFalse(field.metadata().contains(BlobUtils.ARROW_EXTENSION_NAME_KEY)); + } + + private static void assertBlobV2Field(StructField field) { + assertEquals(DataTypes.BinaryType, field.dataType()); + assertEquals( + BlobUtils.ARROW_EXTENSION_BLOB_V2, + field.metadata().getString(BlobUtils.ARROW_EXTENSION_NAME_KEY)); + assertFalse(field.metadata().contains(BlobUtils.LANCE_ENCODING_BLOB_KEY)); + } + + private static void assertValidationFailure(String expectedFragment, Runnable action) { + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, action::run); + assertTrue( + e.getMessage().contains(expectedFragment), + () -> "expected message to contain '" + expectedFragment + "': " + e.getMessage()); + } } diff --git a/lance-spark-base_2.12/src/test/java/org/lance/spark/vectorized/BlobV2DescriptorColumnVectorTest.java b/lance-spark-base_2.12/src/test/java/org/lance/spark/vectorized/BlobV2DescriptorColumnVectorTest.java new file mode 100644 index 000000000..17451ec66 --- /dev/null +++ b/lance-spark-base_2.12/src/test/java/org/lance/spark/vectorized/BlobV2DescriptorColumnVectorTest.java @@ -0,0 +1,121 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark.vectorized; + +import org.lance.spark.utils.BlobUtils; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.UInt1Vector; +import org.apache.arrow.vector.UInt4Vector; +import org.apache.arrow.vector.UInt8Vector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.junit.jupiter.api.Test; + +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +public class BlobV2DescriptorColumnVectorTest { + + @Test + public void testDescriptorChildrenRead() { + try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); + StructVector struct = buildDescriptor(allocator, (short) 3, 12345L, 678L, 99L, "s3://b/o"); + LanceArrowColumnVector wrapper = new LanceArrowColumnVector(struct)) { + + assertEquals(1, struct.getValueCount()); + assertNotNull(wrapper.getChild(0)); + assertEquals((short) 3, wrapper.getChild(0).getShort(0)); + assertEquals(12345L, wrapper.getChild(1).getLong(0)); + assertEquals(678L, wrapper.getChild(2).getLong(0)); + assertEquals(99L, wrapper.getChild(3).getLong(0)); + assertEquals("s3://b/o", wrapper.getChild(4).getUTF8String(0).toString()); + } + } + + @Test + public void testDescriptorMaxUnsignedValues() { + try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); + StructVector struct = + buildDescriptor(allocator, (short) 255, 1L << 40, 1L << 30, 4294967295L, "x"); + LanceArrowColumnVector wrapper = new LanceArrowColumnVector(struct)) { + + assertEquals((short) 255, wrapper.getChild(0).getShort(0)); + assertEquals(1L << 40, wrapper.getChild(1).getLong(0)); + assertEquals(1L << 30, wrapper.getChild(2).getLong(0)); + assertEquals(4294967295L, wrapper.getChild(3).getLong(0)); + } + } + + private static StructVector buildDescriptor( + BufferAllocator allocator, + short kind, + long position, + long size, + long blobId, + String blobUri) { + Map structMd = new HashMap<>(); + structMd.put(BlobUtils.ARROW_EXTENSION_NAME_KEY, BlobUtils.ARROW_EXTENSION_BLOB_V2); + Field structField = + new Field( + "payload", + new FieldType(true, ArrowType.Struct.INSTANCE, null, structMd), + Arrays.asList( + intChild("kind", 8), + intChild("position", 64), + intChild("size", 64), + intChild("blob_id", 32), + utf8Child("blob_uri"))); + StructVector struct = (StructVector) structField.createVector(allocator); + struct.allocateNew(); + + UInt1Vector kindV = (UInt1Vector) struct.getChild("kind"); + UInt8Vector posV = (UInt8Vector) struct.getChild("position"); + UInt8Vector sizeV = (UInt8Vector) struct.getChild("size"); + UInt4Vector idV = (UInt4Vector) struct.getChild("blob_id"); + VarCharVector uriV = (VarCharVector) struct.getChild("blob_uri"); + + struct.setIndexDefined(0); + kindV.setSafe(0, kind); + posV.setSafe(0, position); + sizeV.setSafe(0, size); + idV.setSafe(0, (int) blobId); + uriV.setSafe(0, blobUri.getBytes(StandardCharsets.UTF_8)); + + struct.setValueCount(1); + return struct; + } + + private static Field intChild(String name, int bitWidth) { + return new Field( + name, + new FieldType(true, new ArrowType.Int(bitWidth, false), null, Collections.emptyMap()), + null); + } + + private static Field utf8Child(String name) { + return new Field( + name, new FieldType(true, ArrowType.Utf8.INSTANCE, null, Collections.emptyMap()), null); + } +} diff --git a/lance-spark-base_2.12/src/test/java/org/lance/spark/write/LanceWriteSchemaValidatorTest.java b/lance-spark-base_2.12/src/test/java/org/lance/spark/write/LanceWriteSchemaValidatorTest.java new file mode 100644 index 000000000..e96c7cf68 --- /dev/null +++ b/lance-spark-base_2.12/src/test/java/org/lance/spark/write/LanceWriteSchemaValidatorTest.java @@ -0,0 +1,275 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark.write; + +import org.lance.spark.utils.BlobUtils; + +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.MetadataBuilder; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class LanceWriteSchemaValidatorTest { + + @Test + public void testValidateAcceptsBinaryInputForBlobV2Column() { + assertDoesNotThrow( + () -> + LanceWriteSchemaValidator.validate( + writeSchemaWithIdAndBlob(), inputSchemaWithIdAndBinaryContent())); + } + + @Test + public void testValidateRejectsColumnCountMismatch() { + StructType input = + new StructType( + new StructField[] { + DataTypes.createStructField("id", DataTypes.IntegerType, false), + }); + + assertValidationFailure( + "column count is different", + () -> LanceWriteSchemaValidator.validate(writeSchemaWithIdAndBlob(), input)); + } + + @Test + public void testValidateRejectsNonBinaryInputForBlobV2Column() { + StructType writeSchema = new StructType(new StructField[] {blobV2Field("content")}); + StructType input = + new StructType( + new StructField[] { + DataTypes.createStructField("content", DataTypes.StringType, true), + }); + + assertValidationFailure( + "accept binary", () -> LanceWriteSchemaValidator.validate(writeSchema, input)); + } + + @Test + public void testValidateRejectsOutOfOrderColumns() { + StructType reordered = + new StructType( + new StructField[] { + DataTypes.createStructField("content", DataTypes.BinaryType, true), + DataTypes.createStructField("id", DataTypes.IntegerType, false), + }); + + assertValidationFailure( + "order is different", + () -> LanceWriteSchemaValidator.validate(writeSchemaWithIdAndBlob(), reordered)); + } + + @Test + public void testValidateRejectsUnknownColumnNames() { + StructType foreignNames = + new StructType( + new StructField[] { + DataTypes.createStructField("identifier", DataTypes.IntegerType, false), + DataTypes.createStructField("payload", DataTypes.BinaryType, true), + }); + + assertValidationFailure( + "'identifier'", + () -> LanceWriteSchemaValidator.validate(writeSchemaWithIdAndBlob(), foreignNames)); + } + + @Test + public void testValidateAcceptsSqlValuesColumnNames() { + StructType sqlValuesNames = + new StructType( + new StructField[] { + DataTypes.createStructField("col1", DataTypes.IntegerType, false), + DataTypes.createStructField("col2", DataTypes.BinaryType, true), + }); + + assertDoesNotThrow( + () -> LanceWriteSchemaValidator.validate(writeSchemaWithIdAndBlob(), sqlValuesNames)); + } + + @Test + public void testValidateRejectsPartiallyMatchingColumnNames() { + StructType partial = + new StructType( + new StructField[] { + DataTypes.createStructField("id", DataTypes.IntegerType, false), + DataTypes.createStructField("payload", DataTypes.BinaryType, true), + }); + + assertValidationFailure( + "'payload'", () -> LanceWriteSchemaValidator.validate(writeSchemaWithIdAndBlob(), partial)); + } + + @Test + public void testValidateRejectsNullableInputForNonNullableColumn() { + StructType writeSchema = + new StructType( + new StructField[] {DataTypes.createStructField("id", DataTypes.IntegerType, false)}); + StructType input = + new StructType( + new StructField[] {DataTypes.createStructField("id", DataTypes.IntegerType, true)}); + + assertValidationFailure( + "does not allow nulls", () -> LanceWriteSchemaValidator.validate(writeSchema, input)); + } + + @Test + public void testValidateAcceptsNonNullableInputForNullableColumn() { + StructType writeSchema = + new StructType( + new StructField[] {DataTypes.createStructField("id", DataTypes.IntegerType, true)}); + StructType input = + new StructType( + new StructField[] {DataTypes.createStructField("id", DataTypes.IntegerType, false)}); + + assertDoesNotThrow(() -> LanceWriteSchemaValidator.validate(writeSchema, input)); + } + + @Test + public void testValidateAcceptsMatchingNestedStruct() { + StructType nested = + new StructType( + new StructField[] { + DataTypes.createStructField("city", DataTypes.StringType, true), + DataTypes.createStructField("zip", DataTypes.IntegerType, false), + }); + StructType writeSchema = + new StructType(new StructField[] {DataTypes.createStructField("address", nested, false)}); + StructType input = + new StructType( + new StructField[] { + DataTypes.createStructField( + "address", + new StructType( + new StructField[] { + DataTypes.createStructField("city", DataTypes.StringType, true), + DataTypes.createStructField("zip", DataTypes.IntegerType, false), + }), + false) + }); + + assertDoesNotThrow(() -> LanceWriteSchemaValidator.validate(writeSchema, input)); + } + + @Test + public void testValidateRejectsNestedStructFieldTypeMismatch() { + StructType nested = + new StructType( + new StructField[] { + DataTypes.createStructField("city", DataTypes.StringType, true), + DataTypes.createStructField("zip", DataTypes.IntegerType, false), + }); + StructType writeSchema = + new StructType(new StructField[] {DataTypes.createStructField("address", nested, true)}); + StructType input = + new StructType( + new StructField[] { + DataTypes.createStructField( + "address", + new StructType( + new StructField[] { + DataTypes.createStructField("city", DataTypes.StringType, true), + DataTypes.createStructField("zip", DataTypes.StringType, true), + }), + true) + }); + + assertValidationFailure( + "address.zip", () -> LanceWriteSchemaValidator.validate(writeSchema, input)); + } + + @Test + public void testValidateRejectsNestedStructNullabilityMismatch() { + StructType nested = + new StructType( + new StructField[] {DataTypes.createStructField("zip", DataTypes.IntegerType, false)}); + StructType writeSchema = + new StructType(new StructField[] {DataTypes.createStructField("address", nested, true)}); + StructType input = + new StructType( + new StructField[] { + DataTypes.createStructField( + "address", + new StructType( + new StructField[] { + DataTypes.createStructField("zip", DataTypes.IntegerType, true) + }), + true) + }); + + assertValidationFailure( + "address.zip", () -> LanceWriteSchemaValidator.validate(writeSchema, input)); + } + + @Test + public void testValidateRejectsNestedStructFieldCountMismatch() { + StructType nested = + new StructType( + new StructField[] { + DataTypes.createStructField("city", DataTypes.StringType, true), + DataTypes.createStructField("zip", DataTypes.IntegerType, false), + }); + StructType writeSchema = + new StructType(new StructField[] {DataTypes.createStructField("address", nested, true)}); + StructType input = + new StructType( + new StructField[] { + DataTypes.createStructField( + "address", + new StructType( + new StructField[] { + DataTypes.createStructField("city", DataTypes.StringType, true), + }), + true) + }); + + assertValidationFailure( + "fields in the table", () -> LanceWriteSchemaValidator.validate(writeSchema, input)); + } + + private static StructField blobV2Field(String name) { + Metadata metadata = + new MetadataBuilder() + .putString(BlobUtils.ARROW_EXTENSION_NAME_KEY, BlobUtils.ARROW_EXTENSION_BLOB_V2) + .build(); + return DataTypes.createStructField(name, DataTypes.BinaryType, true, metadata); + } + + private static StructType writeSchemaWithIdAndBlob() { + return new StructType( + new StructField[] { + DataTypes.createStructField("id", DataTypes.IntegerType, false), blobV2Field("content"), + }); + } + + private static StructType inputSchemaWithIdAndBinaryContent() { + return new StructType( + new StructField[] { + DataTypes.createStructField("id", DataTypes.IntegerType, false), + DataTypes.createStructField("content", DataTypes.BinaryType, true), + }); + } + + private static void assertValidationFailure(String expectedFragment, Runnable action) { + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, action::run); + assertTrue( + e.getMessage().contains(expectedFragment), + () -> "expected message to contain '" + expectedFragment + "': " + e.getMessage()); + } +}