Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -70,17 +70,30 @@ public static Function create(Method method) {
}

/*
* Finds a method in a given class by name.
* Finds a method in a given class by name. In case of overloaded methods with the same name,
* this prioritizes the overload with the maximum number of parameters. This ensures Calcite
* can resolve optional/default trailing parameters correctly when binding UDF overloads.
*
* @param clazz class to search method in
* @param name name of the method to find
* @return the first method with matching name or null when no method found
* @return the matching method with the highest parameter count or null when no method found
*/
static @Nullable Method findMethod(Class<?> clazz, String name) {
Method bestMethod = null;
for (Method method : clazz.getMethods()) {
if (method.getName().equals(name) && !method.isBridge()) {
return method;
if (bestMethod == null) {
bestMethod = method;
} else {
int cmp =
Integer.compare(
method.getParameterCount(), bestMethod.getParameterCount());
if (cmp > 0 || (cmp == 0 && method.toString().compareTo(bestMethod.toString()) < 0)) {
bestMethod = method;
}
}
}
}
return null;
return bestMethod;
}
Comment thread
damccorm marked this conversation as resolved.
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import org.apache.beam.sdk.state.ValueState;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.Flatten;
import org.apache.beam.sdk.transforms.GroupByKey;
Comment thread
damccorm marked this conversation as resolved.
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.Top;
Expand Down Expand Up @@ -74,15 +75,12 @@
* <pre>{@code
* SELECT * FROM t ORDER BY id DESC LIMIT 10;
* SELECT * FROM t ORDER BY id DESC LIMIT 10 OFFSET 5;
* }</pre>
*
* <p>but an ORDER BY without a LIMIT is NOT supported. For example, the following will throw an
* exception:
*
* <pre>{@code
* SELECT * FROM t ORDER BY id DESC;
* }</pre>
*
* <p>Note: ORDER BY without a LIMIT is supported by keying all rows to a single key and sorting
* them in memory. This can be memory-intensive and may fail for large datasets.
*
* <h3>Constraints</h3>
*
* <ul>
Expand Down Expand Up @@ -134,12 +132,12 @@ public BeamSortRel(
}

if (fetch == null) {
throw new UnsupportedOperationException("ORDER BY without a LIMIT is not supported!");
count = -1;
} else {
RexLiteral fetchLiteral = (RexLiteral) fetch;
count = ((BigDecimal) fetchLiteral.getValue()).intValue();
}

RexLiteral fetchLiteral = (RexLiteral) fetch;
count = ((BigDecimal) fetchLiteral.getValue()).intValue();

if (offset != null) {
RexLiteral offsetLiteral = (RexLiteral) offset;
startIndex = ((BigDecimal) offsetLiteral.getValue()).intValue();
Expand Down Expand Up @@ -209,6 +207,21 @@ public PCollection<Row> expand(PCollectionList<Row> pinput) {
GlobalWindows.class.getSimpleName(), windowingStrategy));
}

// When no limit is specified (count == -1), we must sort the entire dataset.
// To achieve this globally, we key all rows by a single dummy key, group them together
// using GroupByKey to ensure they are processed together, and then sort them in-memory
// via SortInMemoryFn. Note: This can be memory-intensive for large datasets. It should
// only be done as a final step when the remaining data is small
if (count == -1) {
BeamSqlRowComparator comparator =
new BeamSqlRowComparator(fieldIndices, orientation, nullsFirst);
return upstream
.apply("WithDummyKey", WithKeys.of("DummyKey"))
.apply("GroupByKey", GroupByKey.create())
.apply("SortInMemory", ParDo.of(new SortInMemoryFn(comparator)))
.setRowSchema(CalciteUtils.toSchema(getRowType()));
}
Comment thread
damccorm marked this conversation as resolved.

ReversedBeamSqlRowComparator comparator =
new ReversedBeamSqlRowComparator(fieldIndices, orientation, nullsFirst);

Expand Down Expand Up @@ -303,6 +316,31 @@ public void processElement(ProcessContext ctx) {
}
}

/**
* A {@link DoFn} that sorts all elements in-memory. Expects input grouped by a dummy key, sorts
* the iterable values, and outputs them.
*/
private static class SortInMemoryFn extends DoFn<KV<String, Iterable<Row>>, Row> {
private final BeamSqlRowComparator comparator;

public SortInMemoryFn(BeamSqlRowComparator comparator) {
this.comparator = comparator;
}

@ProcessElement
public void processElement(ProcessContext ctx) {
Iterable<Row> input = ctx.element().getValue();
List<Row> list = new ArrayList<>();
for (Row r : input) {
list.add(r);
}
list.sort(comparator);
for (Row r : list) {
ctx.output(r);
}
}
}

@Override
public Sort copy(
RelTraitSet traitSet,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,25 @@ public static boolean isStringType(FieldType fieldType) {
FieldType.DATETIME, SqlTypeName.TIMESTAMP,
FieldType.STRING, SqlTypeName.VARCHAR);

private static final Map<Class<?>, SqlTypeName> JAVA_TO_SQL_TYPE_MAPPING =
ImmutableMap.<Class<?>, SqlTypeName>builder()
.put(String.class, SqlTypeName.VARCHAR)
.put(Integer.class, SqlTypeName.INTEGER)
.put(int.class, SqlTypeName.INTEGER)
.put(Long.class, SqlTypeName.BIGINT)
.put(long.class, SqlTypeName.BIGINT)
.put(Double.class, SqlTypeName.DOUBLE)
.put(double.class, SqlTypeName.DOUBLE)
.put(Float.class, SqlTypeName.FLOAT)
.put(float.class, SqlTypeName.FLOAT)
.put(Short.class, SqlTypeName.SMALLINT)
.put(short.class, SqlTypeName.SMALLINT)
.put(Byte.class, SqlTypeName.TINYINT)
.put(byte.class, SqlTypeName.TINYINT)
.put(Boolean.class, SqlTypeName.BOOLEAN)
.put(boolean.class, SqlTypeName.BOOLEAN)
.build();

// Associating FieldType to generated RelDataType objects for Beam logical types. Used for
// recovering the original type in output schema after full Beam FieldType->Calcite Type->Beam
// FieldType trip
Expand Down Expand Up @@ -365,7 +384,9 @@ private static RelDataType toRelDataType(
* SQL-Java type mapping, with specified Beam rules: <br>
* 1. redirect {@link AbstractInstant} to {@link Date} so Calcite can recognize it. <br>
* 2. For a list, the component type is needed to create a Sql array type. <br>
* 3. For a Map, the component type is needed to create a Sql map type.
* 3. For a Map, the component type is needed to create a Sql map type. <br>
* 4. For standard Java classes (String, Integer, etc.), map them to corresponding Calcite SQL
* type with appropriate nullability.
*
* @param type
* @return Calcite RelDataType
Expand Down Expand Up @@ -396,6 +417,14 @@ public static RelDataType sqlTypeWithAutoCast(RelDataTypeFactory typeFactory, Ty
+ ". This is currently unsupported, use List instead "
+ "of Array.");
}
if (type instanceof Class) {
Class<?> clazz = (Class<?>) type;
SqlTypeName sqlTypeName = JAVA_TO_SQL_TYPE_MAPPING.get(clazz);
if (sqlTypeName != null) {
return typeFactory.createTypeWithNullability(
typeFactory.createSqlType(sqlTypeName), !clazz.isPrimitive());
}
}
Comment thread
damccorm marked this conversation as resolved.
Comment thread
damccorm marked this conversation as resolved.
return typeFactory.createJavaType((Class) type);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.apache.beam.vendor.calcite.v1_40_0.org.apache.calcite.rel.type.RelDataTypeSystem;
import org.apache.beam.vendor.calcite.v1_40_0.org.apache.calcite.schema.AggregateFunction;
import org.apache.beam.vendor.calcite.v1_40_0.org.apache.calcite.schema.FunctionParameter;
import org.apache.beam.vendor.calcite.v1_40_0.org.apache.calcite.sql.type.SqlTypeName;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.junit.Rule;
Expand Down Expand Up @@ -77,7 +78,9 @@ public void subclassGetUdafImpl() {
LazyAggregateCombineFn<?, ?, ?> combiner = new LazyAggregateCombineFn<>(aggregateFn);
AggregateFunction aggregateFunction = combiner.getUdafImpl();
RelDataTypeFactory typeFactory = new JavaTypeFactoryImpl(RelDataTypeSystem.DEFAULT);
RelDataType expectedType = typeFactory.createJavaType(Long.class);
RelDataType expectedType =
typeFactory.createTypeWithNullability(
typeFactory.createSqlType(SqlTypeName.BIGINT), true);

List<FunctionParameter> params = aggregateFunction.getParameters();
assertThat(params, hasSize(1));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.extensions.sql.impl;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;

import java.lang.reflect.Method;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

/** Tests for {@link UdfImpl}. */
@RunWith(JUnit4.class)
public class UdfImplTest {

@Test
public void testFindMethod_overloaded_prioritizesMaxParams() {
Method method = UdfImpl.findMethod(OverloadedFn.class, "eval");
assertNotNull(method);
assertEquals(3, method.getParameterTypes().length);
}

@Test
public void testFindMethod_singleMethod() {
Method method = UdfImpl.findMethod(SingleFn.class, "eval");
assertNotNull(method);
assertEquals(1, method.getParameterTypes().length);
}

public static class OverloadedFn {
public String eval(String a) {
return a;
}

public String eval(String a, String b) {
return a + b;
}

public String eval(String a, String b, String c) {
return a + b + c;
}
}

public static class SingleFn {
public String eval(String a) {
return a;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,16 @@
*/
package org.apache.beam.sdk.extensions.sql.impl.rel;

import java.util.ArrayList;
import java.util.List;
import org.apache.beam.sdk.extensions.sql.TestUtils;
import org.apache.beam.sdk.extensions.sql.impl.planner.BeamRelMetadataQuery;
import org.apache.beam.sdk.extensions.sql.impl.planner.NodeStats;
import org.apache.beam.sdk.extensions.sql.meta.provider.test.TestBoundedTable;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.calcite.v1_40_0.org.apache.calcite.rel.RelNode;
Expand Down Expand Up @@ -316,4 +319,44 @@ public void testNodeStatsEstimation() {
Assert.assertEquals(10., estimate.getRowCount(), 0.01);
Assert.assertEquals(10., estimate.getWindow(), 0.01);
}

@Test
public void testOrderBy_noLimit() {
String sql =
"SELECT order_id, site_id, price "
+ "FROM ORDER_DETAILS "
+ "ORDER BY order_id asc, site_id desc";

PCollection<Row> rows = compilePipeline(sql, pipeline);
PAssert.that(rows).satisfies(new AssertSorted());
pipeline.run().waitUntilFinish();
}

private static class AssertSorted implements SerializableFunction<Iterable<Row>, Void> {
@Override
public Void apply(Iterable<Row> input) {
List<Row> list = new ArrayList<>();
for (Row r : input) {
list.add(r);
}
Assert.assertEquals(10, list.size());
for (int i = 0; i < list.size() - 1; i++) {
Row r1 = list.get(i);
Row r2 = list.get(i + 1);
Long id1 = r1.getInt64("order_id");
Long id2 = r2.getInt64("order_id");
int comp = id1.compareTo(id2);
if (comp > 0) {
Assert.fail("Rows not sorted by order_id asc: " + list);
} else if (comp == 0) {
Integer site1 = r1.getInt32("site_id");
Integer site2 = r2.getInt32("site_id");
if (site1 < site2) {
Assert.fail("Rows not sorted by site_id desc when order_id is equal: " + list);
}
}
}
return null;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -198,4 +198,67 @@ public void testToRelDataTypeWithRowBackedLogicalType() {
assertEquals(1, relDataType.getFieldCount());
assertEquals("nested_f1", relDataType.getFieldList().get(0).getName());
}

@Test
public void testSqlTypeWithAutoCast() {
RelDataType type = CalciteUtils.sqlTypeWithAutoCast(dataTypeFactory, String.class);
assertEquals(SqlTypeName.VARCHAR, type.getSqlTypeName());
assertTrue(type.isNullable());

type = CalciteUtils.sqlTypeWithAutoCast(dataTypeFactory, Integer.class);
assertEquals(SqlTypeName.INTEGER, type.getSqlTypeName());
assertTrue(type.isNullable());

type = CalciteUtils.sqlTypeWithAutoCast(dataTypeFactory, int.class);
assertEquals(SqlTypeName.INTEGER, type.getSqlTypeName());
assertFalse(type.isNullable());

type = CalciteUtils.sqlTypeWithAutoCast(dataTypeFactory, Long.class);
assertEquals(SqlTypeName.BIGINT, type.getSqlTypeName());
assertTrue(type.isNullable());

type = CalciteUtils.sqlTypeWithAutoCast(dataTypeFactory, long.class);
assertEquals(SqlTypeName.BIGINT, type.getSqlTypeName());
assertFalse(type.isNullable());

type = CalciteUtils.sqlTypeWithAutoCast(dataTypeFactory, Double.class);
assertEquals(SqlTypeName.DOUBLE, type.getSqlTypeName());
assertTrue(type.isNullable());

type = CalciteUtils.sqlTypeWithAutoCast(dataTypeFactory, double.class);
assertEquals(SqlTypeName.DOUBLE, type.getSqlTypeName());
assertFalse(type.isNullable());

type = CalciteUtils.sqlTypeWithAutoCast(dataTypeFactory, Float.class);
assertEquals(SqlTypeName.FLOAT, type.getSqlTypeName());
assertTrue(type.isNullable());

type = CalciteUtils.sqlTypeWithAutoCast(dataTypeFactory, float.class);
assertEquals(SqlTypeName.FLOAT, type.getSqlTypeName());
assertFalse(type.isNullable());

type = CalciteUtils.sqlTypeWithAutoCast(dataTypeFactory, Short.class);
assertEquals(SqlTypeName.SMALLINT, type.getSqlTypeName());
assertTrue(type.isNullable());

type = CalciteUtils.sqlTypeWithAutoCast(dataTypeFactory, short.class);
assertEquals(SqlTypeName.SMALLINT, type.getSqlTypeName());
assertFalse(type.isNullable());

type = CalciteUtils.sqlTypeWithAutoCast(dataTypeFactory, Byte.class);
assertEquals(SqlTypeName.TINYINT, type.getSqlTypeName());
assertTrue(type.isNullable());

type = CalciteUtils.sqlTypeWithAutoCast(dataTypeFactory, byte.class);
assertEquals(SqlTypeName.TINYINT, type.getSqlTypeName());
assertFalse(type.isNullable());

type = CalciteUtils.sqlTypeWithAutoCast(dataTypeFactory, Boolean.class);
assertEquals(SqlTypeName.BOOLEAN, type.getSqlTypeName());
assertTrue(type.isNullable());

type = CalciteUtils.sqlTypeWithAutoCast(dataTypeFactory, boolean.class);
assertEquals(SqlTypeName.BOOLEAN, type.getSqlTypeName());
assertFalse(type.isNullable());
}
}
Loading