diff --git a/api/src/main/java/org/opensearch/sql/api/spec/datetime/DatetimeExtension.java b/api/src/main/java/org/opensearch/sql/api/spec/datetime/DatetimeExtension.java index 944ac4a4bf..f317e9960c 100644 --- a/api/src/main/java/org/opensearch/sql/api/spec/datetime/DatetimeExtension.java +++ b/api/src/main/java/org/opensearch/sql/api/spec/datetime/DatetimeExtension.java @@ -25,6 +25,11 @@ public List postAnalysisRules() { return List.of(DatetimeUdtNormalizeRule.INSTANCE, DatetimeOutputCastRule.INSTANCE); } + @Override + public List preCompilationRules() { + return List.of(DatetimeUdfCompilationAdapterRule.INSTANCE); + } + /** Maps datetime UDT types to their standard Calcite equivalents. */ @Getter @RequiredArgsConstructor diff --git a/api/src/main/java/org/opensearch/sql/api/spec/datetime/DatetimeUdfCompilationAdapterRule.java b/api/src/main/java/org/opensearch/sql/api/spec/datetime/DatetimeUdfCompilationAdapterRule.java new file mode 100644 index 0000000000..daf9ef4fdf --- /dev/null +++ b/api/src/main/java/org/opensearch/sql/api/spec/datetime/DatetimeUdfCompilationAdapterRule.java @@ -0,0 +1,85 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.api.spec.datetime; + +import static org.opensearch.sql.api.spec.datetime.DatetimeExtension.UdtMapping.isDatetimeType; + +import java.util.ArrayList; +import java.util.List; +import lombok.AccessLevel; +import lombok.NoArgsConstructor; +import org.apache.calcite.rel.RelHomogeneousShuttle; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexShuttle; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.sql.validate.SqlUserDefinedFunction; + +/** + * Adapts datetime UDF calls for 
Enumerable compilation. PPL UDF implementors expect String + * input/output, but after normalization the plan uses standard DATE/TIME/TIMESTAMP types + * (int/long). This rule inserts CASTs to bridge the mismatch: + * + * <pre>
+ *   Before: LAST_DAY($2:DATE) : DATE
+ *   After:  CAST(LAST_DAY(CAST($2 AS VARCHAR)):VARCHAR AS DATE)
+ * </pre>
+ */ +@NoArgsConstructor(access = AccessLevel.PACKAGE) +class DatetimeUdfCompilationAdapterRule extends RelHomogeneousShuttle { + + static final DatetimeUdfCompilationAdapterRule INSTANCE = new DatetimeUdfCompilationAdapterRule(); + + @Override + public RelNode visit(RelNode other) { + RelNode visited = super.visit(other); + RexBuilder rexBuilder = visited.getCluster().getRexBuilder(); + RelDataTypeFactory typeFactory = rexBuilder.getTypeFactory(); + return visited.accept( + new RexShuttle() { + @Override + public RexNode visitCall(RexCall call) { + call = (RexCall) super.visitCall(call); + if (!(call.getOperator() instanceof SqlUserDefinedFunction)) { + return call; + } + + // Adapt operands: CAST(datetime_operand AS VARCHAR) for UDF implementor + List adapted = new ArrayList<>(call.getOperands().size()); + boolean operandsChanged = false; + for (RexNode operand : call.getOperands()) { + if (isDatetimeType(operand.getType().getSqlTypeName())) { + RelDataType varcharType = + typeFactory.createTypeWithNullability( + typeFactory.createSqlType(SqlTypeName.VARCHAR), + operand.getType().isNullable()); + adapted.add(rexBuilder.makeCast(varcharType, operand)); + operandsChanged = true; + } else { + adapted.add(operand); + } + } + + // Adapt result: if return type is datetime, wrap call with VARCHAR return + CAST back + if (isDatetimeType(call.getType().getSqlTypeName())) { + RelDataType declaredType = call.getType(); + RelDataType varcharType = + typeFactory.createTypeWithNullability( + typeFactory.createSqlType(SqlTypeName.VARCHAR), declaredType.isNullable()); + RexCall varcharCall = + call.clone(varcharType, operandsChanged ? adapted : call.getOperands()); + return rexBuilder.makeCast(declaredType, varcharCall); + } + + return operandsChanged ? 
call.clone(call.getType(), adapted) : call; + } + }); + } +} diff --git a/api/src/main/java/org/opensearch/sql/api/spec/datetime/DatetimeUdtNormalizeRule.java b/api/src/main/java/org/opensearch/sql/api/spec/datetime/DatetimeUdtNormalizeRule.java index b15d830d41..8a7760e6d7 100644 --- a/api/src/main/java/org/opensearch/sql/api/spec/datetime/DatetimeUdtNormalizeRule.java +++ b/api/src/main/java/org/opensearch/sql/api/spec/datetime/DatetimeUdtNormalizeRule.java @@ -5,15 +5,21 @@ package org.opensearch.sql.api.spec.datetime; +import java.util.ArrayList; +import java.util.List; import java.util.Optional; import lombok.AccessLevel; import lombok.NoArgsConstructor; import org.apache.calcite.rel.RelHomogeneousShuttle; import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.logical.LogicalAggregate; +import org.apache.calcite.rel.logical.LogicalProject; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexShuttle; import org.apache.calcite.sql.type.SqlTypeName; @@ -30,10 +36,82 @@ class DatetimeUdtNormalizeRule extends RelHomogeneousShuttle { @Override public RelNode visit(RelNode other) { - RelNode visited = super.visit(other); - RexBuilder rexBuilder = visited.getCluster().getRexBuilder(); + // Visit children first + List newInputs = new ArrayList<>(); + boolean childChanged = false; + for (RelNode input : other.getInputs()) { + RelNode newInput = input.accept(this); + newInputs.add(newInput); + if (newInput != input) { + childChanged = true; + } + } + + // Rebuild current node if children changed + RelNode current = other; + if (childChanged) { + if (current instanceof LogicalAggregate agg) { + // Aggregate needs AggregateCall types rebuilt + RelNode newInput = 
newInputs.get(0); + List newAggCalls = + agg.getAggCallList().stream() + .map( + call -> + AggregateCall.create( + call.getAggregation(), + call.isDistinct(), + call.isApproximate(), + call.ignoreNulls(), + call.rexList, + call.getArgList(), + call.filterArg, + call.distinctKeys, + call.collation, + agg.getGroupCount(), + newInput, + null, + call.getName())) + .toList(); + current = + agg.copy( + agg.getTraitSet(), newInput, agg.getGroupSet(), agg.getGroupSets(), newAggCalls); + } else if (current instanceof LogicalProject proj) { + // Project needs RexInputRef types refreshed from new child + RelNode newInput = newInputs.get(0); + RexBuilder rexBuilder = proj.getCluster().getRexBuilder(); + List newProjects = + proj.getProjects().stream() + .map( + expr -> + expr.accept( + new RexShuttle() { + @Override + public RexNode visitInputRef(RexInputRef ref) { + RelDataType newType = + newInput + .getRowType() + .getFieldList() + .get(ref.getIndex()) + .getType(); + if (!newType.equals(ref.getType())) { + return rexBuilder.makeInputRef(newType, ref.getIndex()); + } + return ref; + } + })) + .toList(); + current = + LogicalProject.create( + newInput, proj.getHints(), newProjects, proj.getRowType().getFieldNames()); + } else { + current = current.copy(current.getTraitSet(), newInputs); + } + } + + // Apply RexShuttle to normalize UDT types in this node's expressions + RexBuilder rexBuilder = current.getCluster().getRexBuilder(); RelDataTypeFactory typeFactory = rexBuilder.getTypeFactory(); - return visited.accept( + return current.accept( new RexShuttle() { @Override public RexNode visitCall(RexCall call) { @@ -43,7 +121,6 @@ public RexNode visitCall(RexCall call) { return call; } - // Normalize UDT return type to standard Calcite DATE/TIME/TIMESTAMP UdtMapping m = mapping.get(); SqlTypeName stdTypeName = m.getStdType(); RelDataType baseType = diff --git a/api/src/test/java/org/opensearch/sql/api/spec/datetime/DatetimeExtensionTest.java 
b/api/src/test/java/org/opensearch/sql/api/spec/datetime/DatetimeExtensionTest.java index fc08915010..dea6b32b09 100644 --- a/api/src/test/java/org/opensearch/sql/api/spec/datetime/DatetimeExtensionTest.java +++ b/api/src/test/java/org/opensearch/sql/api/spec/datetime/DatetimeExtensionTest.java @@ -71,114 +71,97 @@ private Table createEventsTable() { } @Test - public void testUdfResultNormalizedAndCastToVarchar() { + public void testUdfOnLiteralsNormalizedAndExecutable() throws Exception { var plan = givenQuery( """ source = catalog.events \ - | eval d = DATE(name), t = TIME(name), ts = TIMESTAMP(name) \ + | eval d = DATE('2024-01-01'), t = TIME('12:30:00'), ts = TIMESTAMP('2024-01-01 12:30:00') \ | fields d, t, ts\ """) .assertPlan( """ LogicalProject(d=[CAST($0):VARCHAR], t=[CAST($1):VARCHAR], ts=[CAST($2):VARCHAR]) - LogicalProject(d=[DATE($1)], t=[TIME($1)], ts=[TIMESTAMP($1)]) + LogicalProject(d=[DATE('2024-01-01':VARCHAR)], t=[TIME('12:30:00':VARCHAR)], ts=[TIMESTAMP('2024-01-01 12:30:00':VARCHAR)]) LogicalTableScan(table=[[catalog, events]]) """) .plan(); assertCallType(plan, "DATE", DATE); assertCallType(plan, "TIME", TIME, 9); assertCallType(plan, "TIMESTAMP", TIMESTAMP, 9); + try (PreparedStatement stmt = compiler.compile(plan)) { + ResultSet rs = stmt.executeQuery(); + verify(rs) + .expectSchema( + col("d", java.sql.Types.VARCHAR), + col("t", java.sql.Types.VARCHAR), + col("ts", java.sql.Types.VARCHAR)) + .expectData( + row("2024-01-01", "12:30:00", "2024-01-01 12:30:00"), + row("2024-01-01", "12:30:00", "2024-01-01 12:30:00")); + } } @Test - public void testNestedUdfCallsNormalized() { + public void testNestedUdfCallsExecutable() throws Exception { var plan = - givenQuery("source = catalog.events | eval d = DATEDIFF(DATE(name), DATE(name)) | fields d") - .assertPlan( + givenQuery( """ - LogicalProject(d=[DATEDIFF(DATE($1), DATE($1))]) - LogicalTableScan(table=[[catalog, events]]) + source = catalog.events \ + | eval d = DATEDIFF(DATE('2025-01-01'), 
DATE('2024-01-01')) \ + | fields d\ """) - .plan(); - assertCallType(plan, "DATE", DATE); - assertCallType(plan, "DATEDIFF", BIGINT); - } - - @Test - public void testDateLiteralCastToVarchar() { - var plan = - givenQuery("source = catalog.events | eval d = DATE('2024-01-01') | fields d") .assertPlan( """ - LogicalProject(d=[CAST($0):VARCHAR]) - LogicalProject(d=[DATE('2024-01-01':VARCHAR)]) - LogicalTableScan(table=[[catalog, events]]) + LogicalProject(d=[DATEDIFF(DATE('2025-01-01':VARCHAR), DATE('2024-01-01':VARCHAR))]) + LogicalTableScan(table=[[catalog, events]]) """) .plan(); assertCallType(plan, "DATE", DATE); + assertCallType(plan, "DATEDIFF", BIGINT); + try (PreparedStatement stmt = compiler.compile(plan)) { + ResultSet rs = stmt.executeQuery(); + verify(rs).expectSchema(col("d", java.sql.Types.BIGINT)).expectData(row(366L), row(366L)); + } } @Test - public void testFilterWithTimestampLiteral() { + public void testFilterWithTimestampUdf() throws Exception { var plan = givenQuery( """ - source = catalog.events | where created_at > "2024-01-01T00:00:00Z" | fields id\ + source = catalog.events \ + | where created_at < TIMESTAMP('2024-06-01 00:00:00') \ + | fields id\ """) .assertPlan( """ LogicalProject(id=[$0]) - LogicalFilter(condition=[>($4, TIMESTAMP('2024-01-01T00:00:00Z':VARCHAR))]) + LogicalFilter(condition=[<($4, TIMESTAMP('2024-06-01 00:00:00':VARCHAR))]) LogicalTableScan(table=[[catalog, events]]) """) .plan(); assertCallType(plan, "TIMESTAMP", TIMESTAMP, 9); + try (PreparedStatement stmt = compiler.compile(plan)) { + ResultSet rs = stmt.executeQuery(); + verify(rs).expectSchema(col("id", java.sql.Types.INTEGER)).expectData(row(1)); + } } @Test - public void testComparisonWithDatetimeUdf() { + public void testStandardDatetimeFieldsCastToVarchar() throws Exception { var plan = - givenQuery("source = catalog.events | where created_at < DATE(name) | fields id") + givenQuery("source = catalog.events | fields hire_date, start_time, created_at") 
.assertPlan( """ - LogicalProject(id=[$0]) - LogicalFilter(condition=[<($4, TIMESTAMP(DATE($1)))]) + LogicalProject(hire_date=[CAST($0):VARCHAR NOT NULL], start_time=[CAST($1):VARCHAR NOT NULL], created_at=[CAST($2):VARCHAR NOT NULL]) + LogicalProject(hire_date=[$2], start_time=[$3], created_at=[$4]) LogicalTableScan(table=[[catalog, events]]) """) .plan(); - assertCallType(plan, "DATE", DATE); - assertCallType(plan, "TIMESTAMP", TIMESTAMP, 9); - } - - @Test - public void testAllStandardDatetimeTypesCastToVarchar() { - givenQuery("source = catalog.events | fields hire_date, start_time, created_at") - .assertPlan( - """ - LogicalProject(hire_date=[CAST($0):VARCHAR NOT NULL], start_time=[CAST($1):VARCHAR NOT NULL], created_at=[CAST($2):VARCHAR NOT NULL]) - LogicalProject(hire_date=[$2], start_time=[$3], created_at=[$4]) - LogicalTableScan(table=[[catalog, events]]) - """); - } - - @Test - public void testNonDatetimeFieldsNotWrapped() { - givenQuery("source = catalog.events | fields id, name") - .assertPlan( - """ - LogicalProject(id=[$0], name=[$1]) - LogicalTableScan(table=[[catalog, events]]) - """); - } - - @Test - public void testOutputCastCanCompileAndExecute() throws Exception { - RelNode plan = - planner.plan("source = catalog.events | fields hire_date, start_time, created_at"); - try (PreparedStatement statement = compiler.compile(plan)) { - ResultSet resultSet = statement.executeQuery(); - verify(resultSet) + try (PreparedStatement stmt = compiler.compile(plan)) { + ResultSet rs = stmt.executeQuery(); + verify(rs) .expectSchema( col("hire_date", java.sql.Types.VARCHAR), col("start_time", java.sql.Types.VARCHAR), @@ -189,6 +172,42 @@ public void testOutputCastCanCompileAndExecute() throws Exception { } } + @Test + public void testNonDatetimeFieldsNotWrapped() throws Exception { + var plan = + givenQuery("source = catalog.events | fields id, name") + .assertPlan( + """ + LogicalProject(id=[$0], name=[$1]) + LogicalTableScan(table=[[catalog, events]]) + """) + 
.plan(); + try (PreparedStatement stmt = compiler.compile(plan)) { + ResultSet rs = stmt.executeQuery(); + verify(rs) + .expectSchema(col("id", java.sql.Types.INTEGER), col("name", java.sql.Types.VARCHAR)) + .expectData(row(1, "Alice"), row(2, "Bob")); + } + } + + @Test + public void testNonDatetimeUdfUnaffected() throws Exception { + var plan = + givenQuery("source = catalog.events | eval s = CONCAT(name, ' test') | fields s") + .assertPlan( + """ + LogicalProject(s=[CONCAT($1, ' test':VARCHAR)]) + LogicalTableScan(table=[[catalog, events]]) + """) + .plan(); + try (PreparedStatement stmt = compiler.compile(plan)) { + ResultSet rs = stmt.executeQuery(); + verify(rs) + .expectSchema(col("s", java.sql.Types.VARCHAR)) + .expectData(row("Alice test"), row("Bob test")); + } + } + private static void assertCallType(RelNode plan, String operatorName, SqlTypeName expectedType) { assertCallType(plan, operatorName, expectedType, -1); } @@ -222,4 +241,35 @@ public RexNode visitCall(RexCall call) { operatorName + " precision", expectedPrecision, ref.get().getType().getPrecision()); } } + + @Test + public void testAggMaxOnDatetimeUdf() throws Exception { + var plan = planner.plan("source = catalog.events | stats max(DATE('2024-01-01')) as m"); + try (PreparedStatement stmt = compiler.compile(plan)) { + ResultSet rs = stmt.executeQuery(); + verify(rs).expectSchema(col("m", java.sql.Types.VARCHAR)).expectData(row("2024-01-01")); + } + } + + @Test + public void testAggGroupByDatetimeUdf() throws Exception { + var plan = + planner.plan( + "source = catalog.events | eval d = DATE('2024-01-01') | stats count() as c by d"); + try (PreparedStatement stmt = compiler.compile(plan)) { + ResultSet rs = stmt.executeQuery(); + verify(rs) + .expectSchema(col("c", java.sql.Types.BIGINT), col("d", java.sql.Types.VARCHAR)) + .expectData(row(2L, "2024-01-01")); + } + } + + @Test + public void testAggOnDatetimeColumnWorks() throws Exception { + var plan = planner.plan("source = catalog.events | 
stats max(hire_date) as m"); + try (PreparedStatement stmt = compiler.compile(plan)) { + ResultSet rs = stmt.executeQuery(); + verify(rs).expectSchema(col("m", java.sql.Types.VARCHAR)).expectData(row("2024-06-20")); + } + } }