From 7be95751ef657ae93fbf7b0a7ff7ca5e28da2866 Mon Sep 17 00:00:00 2001 From: Chen Dai Date: Thu, 7 May 2026 11:14:50 -0700 Subject: [PATCH] feat(api): Add pre-compilation adapter and aggregate fix for datetime UDF Add preCompilationRules to bridge the type mismatch between normalized standard types (DATE/TIME/TIMESTAMP as int/long) and PPL UDF implementors (which expect and produce String values): 1. DatetimeUdfCompilationAdapterRule inserts CAST nodes around datetime UDFs so implementors receive String input and produce String output, with CASTs bridging int/long <-> String conversion. 2. DatetimeUdtNormalizeRule enhanced to handle LogicalAggregate (rebuild AggregateCall with re-inferred types) and LogicalProject (refresh RexInputRef types from new child row type) to prevent type mismatch assertions when datetime UDF results feed into aggregates. Both fixes are only needed for the UnifiedQueryCompiler (Enumerable) path. The Analytics Engine (Substrait/DataFusion) path is unaffected. Signed-off-by: Chen Dai --- .../api/spec/datetime/DatetimeExtension.java | 5 + .../DatetimeUdfCompilationAdapterRule.java | 85 +++++++++ .../datetime/DatetimeUdtNormalizeRule.java | 85 ++++++++- .../spec/datetime/DatetimeExtensionTest.java | 168 ++++++++++++------ 4 files changed, 280 insertions(+), 63 deletions(-) create mode 100644 api/src/main/java/org/opensearch/sql/api/spec/datetime/DatetimeUdfCompilationAdapterRule.java diff --git a/api/src/main/java/org/opensearch/sql/api/spec/datetime/DatetimeExtension.java b/api/src/main/java/org/opensearch/sql/api/spec/datetime/DatetimeExtension.java index 944ac4a4bf..f317e9960c 100644 --- a/api/src/main/java/org/opensearch/sql/api/spec/datetime/DatetimeExtension.java +++ b/api/src/main/java/org/opensearch/sql/api/spec/datetime/DatetimeExtension.java @@ -25,6 +25,11 @@ public List postAnalysisRules() { return List.of(DatetimeUdtNormalizeRule.INSTANCE, DatetimeOutputCastRule.INSTANCE); } + @Override + public List preCompilationRules() { + return List.of(DatetimeUdfCompilationAdapterRule.INSTANCE); + } + /** Maps datetime UDT types to their standard Calcite equivalents. */ @Getter @RequiredArgsConstructor diff --git a/api/src/main/java/org/opensearch/sql/api/spec/datetime/DatetimeUdfCompilationAdapterRule.java b/api/src/main/java/org/opensearch/sql/api/spec/datetime/DatetimeUdfCompilationAdapterRule.java new file mode 100644 index 0000000000..daf9ef4fdf --- /dev/null +++ b/api/src/main/java/org/opensearch/sql/api/spec/datetime/DatetimeUdfCompilationAdapterRule.java @@ -0,0 +1,85 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.api.spec.datetime; + +import static org.opensearch.sql.api.spec.datetime.DatetimeExtension.UdtMapping.isDatetimeType; + +import java.util.ArrayList; +import java.util.List; +import lombok.AccessLevel; +import lombok.NoArgsConstructor; +import org.apache.calcite.rel.RelHomogeneousShuttle; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexShuttle; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.sql.validate.SqlUserDefinedFunction; + +/** + * Adapts datetime UDF calls for Enumerable compilation. PPL UDF implementors expect String + * input/output, but after normalization the plan uses standard DATE/TIME/TIMESTAMP types + * (int/long). This rule inserts CASTs to bridge the mismatch: + * + *
+ *   Before: LAST_DAY($2:DATE) : DATE
+ *   After:  CAST(LAST_DAY(CAST($2 AS VARCHAR)):VARCHAR AS DATE)
+ * 
+ */ +@NoArgsConstructor(access = AccessLevel.PACKAGE) +class DatetimeUdfCompilationAdapterRule extends RelHomogeneousShuttle { + + static final DatetimeUdfCompilationAdapterRule INSTANCE = new DatetimeUdfCompilationAdapterRule(); + + @Override + public RelNode visit(RelNode other) { + RelNode visited = super.visit(other); + RexBuilder rexBuilder = visited.getCluster().getRexBuilder(); + RelDataTypeFactory typeFactory = rexBuilder.getTypeFactory(); + return visited.accept( + new RexShuttle() { + @Override + public RexNode visitCall(RexCall call) { + call = (RexCall) super.visitCall(call); + if (!(call.getOperator() instanceof SqlUserDefinedFunction)) { + return call; + } + + // Adapt operands: CAST(datetime_operand AS VARCHAR) for UDF implementor + List adapted = new ArrayList<>(call.getOperands().size()); + boolean operandsChanged = false; + for (RexNode operand : call.getOperands()) { + if (isDatetimeType(operand.getType().getSqlTypeName())) { + RelDataType varcharType = + typeFactory.createTypeWithNullability( + typeFactory.createSqlType(SqlTypeName.VARCHAR), + operand.getType().isNullable()); + adapted.add(rexBuilder.makeCast(varcharType, operand)); + operandsChanged = true; + } else { + adapted.add(operand); + } + } + + // Adapt result: if return type is datetime, wrap call with VARCHAR return + CAST back + if (isDatetimeType(call.getType().getSqlTypeName())) { + RelDataType declaredType = call.getType(); + RelDataType varcharType = + typeFactory.createTypeWithNullability( + typeFactory.createSqlType(SqlTypeName.VARCHAR), declaredType.isNullable()); + RexCall varcharCall = + call.clone(varcharType, operandsChanged ? adapted : call.getOperands()); + return rexBuilder.makeCast(declaredType, varcharCall); + } + + return operandsChanged ? call.clone(call.getType(), adapted) : call; + } + }); + } +} diff --git a/api/src/main/java/org/opensearch/sql/api/spec/datetime/DatetimeUdtNormalizeRule.java b/api/src/main/java/org/opensearch/sql/api/spec/datetime/DatetimeUdtNormalizeRule.java index b15d830d41..8a7760e6d7 100644 --- a/api/src/main/java/org/opensearch/sql/api/spec/datetime/DatetimeUdtNormalizeRule.java +++ b/api/src/main/java/org/opensearch/sql/api/spec/datetime/DatetimeUdtNormalizeRule.java @@ -5,15 +5,21 @@ package org.opensearch.sql.api.spec.datetime; +import java.util.ArrayList; +import java.util.List; import java.util.Optional; import lombok.AccessLevel; import lombok.NoArgsConstructor; import org.apache.calcite.rel.RelHomogeneousShuttle; import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.logical.LogicalAggregate; +import org.apache.calcite.rel.logical.LogicalProject; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexShuttle; import org.apache.calcite.sql.type.SqlTypeName; @@ -30,10 +36,82 @@ class DatetimeUdtNormalizeRule extends RelHomogeneousShuttle { @Override public RelNode visit(RelNode other) { - RelNode visited = super.visit(other); - RexBuilder rexBuilder = visited.getCluster().getRexBuilder(); + // Visit children first + List newInputs = new ArrayList<>(); + boolean childChanged = false; + for (RelNode input : other.getInputs()) { + RelNode newInput = input.accept(this); + newInputs.add(newInput); + if (newInput != input) { + childChanged = true; + } + } + + // Rebuild current node if children changed + RelNode current = other; + if (childChanged) { + if (current instanceof LogicalAggregate agg) { + // Aggregate needs AggregateCall types rebuilt + RelNode newInput = newInputs.get(0); + List newAggCalls = + agg.getAggCallList().stream() + .map( + call -> + AggregateCall.create( + call.getAggregation(), + call.isDistinct(), + call.isApproximate(), + call.ignoreNulls(), + call.rexList, + call.getArgList(), + call.filterArg, + call.distinctKeys, + call.collation, + agg.getGroupCount(), + newInput, + null, + call.getName())) + .toList(); + current = + agg.copy( + agg.getTraitSet(), newInput, agg.getGroupSet(), agg.getGroupSets(), newAggCalls); + } else if (current instanceof LogicalProject proj) { + // Project needs RexInputRef types refreshed from new child + RelNode newInput = newInputs.get(0); + RexBuilder rexBuilder = proj.getCluster().getRexBuilder(); + List newProjects = + proj.getProjects().stream() + .map( + expr -> + expr.accept( + new RexShuttle() { + @Override + public RexNode visitInputRef(RexInputRef ref) { + RelDataType newType = + newInput + .getRowType() + .getFieldList() + .get(ref.getIndex()) + .getType(); + if (!newType.equals(ref.getType())) { + return rexBuilder.makeInputRef(newType, ref.getIndex()); + } + return ref; + } + })) + .toList(); + current = + LogicalProject.create( + newInput, proj.getHints(), newProjects, proj.getRowType().getFieldNames()); + } else { + current = current.copy(current.getTraitSet(), newInputs); + } + } + + // Apply RexShuttle to normalize UDT types in this node's expressions + RexBuilder rexBuilder = current.getCluster().getRexBuilder(); RelDataTypeFactory typeFactory = rexBuilder.getTypeFactory(); - return visited.accept( + return current.accept( new RexShuttle() { @Override public RexNode visitCall(RexCall call) { @@ -43,7 +121,6 @@ public RexNode visitCall(RexCall call) { return call; } - // Normalize UDT return type to standard Calcite DATE/TIME/TIMESTAMP UdtMapping m = mapping.get(); SqlTypeName stdTypeName = m.getStdType(); RelDataType baseType = diff --git a/api/src/test/java/org/opensearch/sql/api/spec/datetime/DatetimeExtensionTest.java b/api/src/test/java/org/opensearch/sql/api/spec/datetime/DatetimeExtensionTest.java index fc08915010..dea6b32b09 100644 --- a/api/src/test/java/org/opensearch/sql/api/spec/datetime/DatetimeExtensionTest.java +++ b/api/src/test/java/org/opensearch/sql/api/spec/datetime/DatetimeExtensionTest.java @@ -71,114 +71,97 @@ private Table createEventsTable() { } @Test - public void testUdfResultNormalizedAndCastToVarchar() { + public void testUdfOnLiteralsNormalizedAndExecutable() throws Exception { var plan = givenQuery( """ source = catalog.events \ - | eval d = DATE(name), t = TIME(name), ts = TIMESTAMP(name) \ + | eval d = DATE('2024-01-01'), t = TIME('12:30:00'), ts = TIMESTAMP('2024-01-01 12:30:00') \ | fields d, t, ts\ """) .assertPlan( """ LogicalProject(d=[CAST($0):VARCHAR], t=[CAST($1):VARCHAR], ts=[CAST($2):VARCHAR]) - LogicalProject(d=[DATE($1)], t=[TIME($1)], ts=[TIMESTAMP($1)]) + LogicalProject(d=[DATE('2024-01-01':VARCHAR)], t=[TIME('12:30:00':VARCHAR)], ts=[TIMESTAMP('2024-01-01 12:30:00':VARCHAR)]) LogicalTableScan(table=[[catalog, events]]) """) .plan(); assertCallType(plan, "DATE", DATE); assertCallType(plan, "TIME", TIME, 9); assertCallType(plan, "TIMESTAMP", TIMESTAMP, 9); + try (PreparedStatement stmt = compiler.compile(plan)) { + ResultSet rs = stmt.executeQuery(); + verify(rs) + .expectSchema( + col("d", java.sql.Types.VARCHAR), + col("t", java.sql.Types.VARCHAR), + col("ts", java.sql.Types.VARCHAR)) + .expectData( + row("2024-01-01", "12:30:00", "2024-01-01 12:30:00"), + row("2024-01-01", "12:30:00", "2024-01-01 12:30:00")); + } } @Test - public void testNestedUdfCallsNormalized() { + public void testNestedUdfCallsExecutable() throws Exception { var plan = - givenQuery("source = catalog.events | eval d = DATEDIFF(DATE(name), DATE(name)) | fields d") - .assertPlan( + givenQuery( """ - LogicalProject(d=[DATEDIFF(DATE($1), DATE($1))]) - LogicalTableScan(table=[[catalog, events]]) + source = catalog.events \ + | eval d = DATEDIFF(DATE('2025-01-01'), DATE('2024-01-01')) \ + | fields d\ """) - .plan(); - assertCallType(plan, "DATE", DATE); - assertCallType(plan, "DATEDIFF", BIGINT); - } - - @Test - public void testDateLiteralCastToVarchar() { - var plan = - givenQuery("source = catalog.events | eval d = DATE('2024-01-01') | fields d") .assertPlan( """ - LogicalProject(d=[CAST($0):VARCHAR]) - LogicalProject(d=[DATE('2024-01-01':VARCHAR)]) - LogicalTableScan(table=[[catalog, events]]) + LogicalProject(d=[DATEDIFF(DATE('2025-01-01':VARCHAR), DATE('2024-01-01':VARCHAR))]) + LogicalTableScan(table=[[catalog, events]]) """) .plan(); assertCallType(plan, "DATE", DATE); + assertCallType(plan, "DATEDIFF", BIGINT); + try (PreparedStatement stmt = compiler.compile(plan)) { + ResultSet rs = stmt.executeQuery(); + verify(rs).expectSchema(col("d", java.sql.Types.BIGINT)).expectData(row(366L), row(366L)); + } } @Test - public void testFilterWithTimestampLiteral() { + public void testFilterWithTimestampUdf() throws Exception { var plan = givenQuery( """ - source = catalog.events | where created_at > "2024-01-01T00:00:00Z" | fields id\ + source = catalog.events \ + | where created_at < TIMESTAMP('2024-06-01 00:00:00') \ + | fields id\ """) .assertPlan( """ LogicalProject(id=[$0]) - LogicalFilter(condition=[>($4, TIMESTAMP('2024-01-01T00:00:00Z':VARCHAR))]) + LogicalFilter(condition=[<($4, TIMESTAMP('2024-06-01 00:00:00':VARCHAR))]) LogicalTableScan(table=[[catalog, events]]) """) .plan(); assertCallType(plan, "TIMESTAMP", TIMESTAMP, 9); + try (PreparedStatement stmt = compiler.compile(plan)) { + ResultSet rs = stmt.executeQuery(); + verify(rs).expectSchema(col("id", java.sql.Types.INTEGER)).expectData(row(1)); + } } @Test - public void testComparisonWithDatetimeUdf() { + public void testStandardDatetimeFieldsCastToVarchar() throws Exception { var plan = - givenQuery("source = catalog.events | where created_at < DATE(name) | fields id") + givenQuery("source = catalog.events | fields hire_date, start_time, created_at") .assertPlan( """ - LogicalProject(id=[$0]) - LogicalFilter(condition=[<($4, TIMESTAMP(DATE($1)))]) + LogicalProject(hire_date=[CAST($0):VARCHAR NOT NULL], start_time=[CAST($1):VARCHAR NOT NULL], created_at=[CAST($2):VARCHAR NOT NULL]) + LogicalProject(hire_date=[$2], start_time=[$3], created_at=[$4]) LogicalTableScan(table=[[catalog, events]]) """) .plan(); - assertCallType(plan, "DATE", DATE); - assertCallType(plan, "TIMESTAMP", TIMESTAMP, 9); - } - - @Test - public void testAllStandardDatetimeTypesCastToVarchar() { - givenQuery("source = catalog.events | fields hire_date, start_time, created_at") - .assertPlan( - """ - LogicalProject(hire_date=[CAST($0):VARCHAR NOT NULL], start_time=[CAST($1):VARCHAR NOT NULL], created_at=[CAST($2):VARCHAR NOT NULL]) - LogicalProject(hire_date=[$2], start_time=[$3], created_at=[$4]) - LogicalTableScan(table=[[catalog, events]]) - """); - } - - @Test - public void testNonDatetimeFieldsNotWrapped() { - givenQuery("source = catalog.events | fields id, name") - .assertPlan( - """ - LogicalProject(id=[$0], name=[$1]) - LogicalTableScan(table=[[catalog, events]]) - """); - } - - @Test - public void testOutputCastCanCompileAndExecute() throws Exception { - RelNode plan = - planner.plan("source = catalog.events | fields hire_date, start_time, created_at"); - try (PreparedStatement statement = compiler.compile(plan)) { - ResultSet resultSet = statement.executeQuery(); - verify(resultSet) + try (PreparedStatement stmt = compiler.compile(plan)) { + ResultSet rs = stmt.executeQuery(); + verify(rs) .expectSchema( col("hire_date", java.sql.Types.VARCHAR), col("start_time", java.sql.Types.VARCHAR), @@ -189,6 +172,42 @@ public void testOutputCastCanCompileAndExecute() throws Exception { } } + @Test + public void testNonDatetimeFieldsNotWrapped() throws Exception { + var plan = + givenQuery("source = catalog.events | fields id, name") + .assertPlan( + """ + LogicalProject(id=[$0], name=[$1]) + LogicalTableScan(table=[[catalog, events]]) + """) + .plan(); + try (PreparedStatement stmt = compiler.compile(plan)) { + ResultSet rs = stmt.executeQuery(); + verify(rs) + .expectSchema(col("id", java.sql.Types.INTEGER), col("name", java.sql.Types.VARCHAR)) + .expectData(row(1, "Alice"), row(2, "Bob")); + } + } + + @Test + public void testNonDatetimeUdfUnaffected() throws Exception { + var plan = + givenQuery("source = catalog.events | eval s = CONCAT(name, ' test') | fields s") + .assertPlan( + """ + LogicalProject(s=[CONCAT($1, ' test':VARCHAR)]) + LogicalTableScan(table=[[catalog, events]]) + """) + .plan(); + try (PreparedStatement stmt = compiler.compile(plan)) { + ResultSet rs = stmt.executeQuery(); + verify(rs) + .expectSchema(col("s", java.sql.Types.VARCHAR)) + .expectData(row("Alice test"), row("Bob test")); + } + } + private static void assertCallType(RelNode plan, String operatorName, SqlTypeName expectedType) { assertCallType(plan, operatorName, expectedType, -1); } @@ -222,4 +241,35 @@ public RexNode visitCall(RexCall call) { operatorName + " precision", expectedPrecision, ref.get().getType().getPrecision()); } } + + @Test + public void testAggMaxOnDatetimeUdf() throws Exception { + var plan = planner.plan("source = catalog.events | stats max(DATE('2024-01-01')) as m"); + try (PreparedStatement stmt = compiler.compile(plan)) { + ResultSet rs = stmt.executeQuery(); + verify(rs).expectSchema(col("m", java.sql.Types.VARCHAR)).expectData(row("2024-01-01")); + } + } + + @Test + public void testAggGroupByDatetimeUdf() throws Exception { + var plan = + planner.plan( + "source = catalog.events | eval d = DATE('2024-01-01') | stats count() as c by d"); + try (PreparedStatement stmt = compiler.compile(plan)) { + ResultSet rs = stmt.executeQuery(); + verify(rs) + .expectSchema(col("c", java.sql.Types.BIGINT), col("d", java.sql.Types.VARCHAR)) + .expectData(row(2L, "2024-01-01")); + } + } + + @Test + public void testAggOnDatetimeColumnWorks() throws Exception { + var plan = planner.plan("source = catalog.events | stats max(hire_date) as m"); + try (PreparedStatement stmt = compiler.compile(plan)) { + ResultSet rs = stmt.executeQuery(); + verify(rs).expectSchema(col("m", java.sql.Types.VARCHAR)).expectData(row("2024-06-20")); + } + } }