Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,24 @@

package org.opensearch.sql.api.spec.search;

import static org.apache.calcite.sql.fun.SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR;
import static org.apache.calcite.sql.fun.SqlStdOperatorTable.CAST;
import static org.apache.calcite.sql.fun.SqlStdOperatorTable.MAP_VALUE_CONSTRUCTOR;
import static org.apache.calcite.sql.type.SqlTypeName.VARCHAR;

import java.util.ArrayList;
import java.util.List;
import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import org.apache.calcite.sql.SqlBasicTypeNameSpec;
import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlDataTypeSpec;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlLiteral;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.util.SqlShuttle;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.opensearch.sql.api.spec.UnifiedFunctionSpec;
Expand Down Expand Up @@ -42,34 +50,69 @@ public final class NamedArgRewriter extends SqlShuttle {

/**
* Rewrites each argument into a MAP entry. For match(name, 'John', operator='AND'):
* <li>Positional arg: name → MAP('field', name)
* <li>Named arg: operator='AND' → MAP('operator', 'AND')
* <li>Positional arg: name → MAP('field', name)
* <li>ARRAY arg: ARRAY['f1','f2'] → MAP('fields', MAP(CAST('f1' AS VARCHAR), 1, ...))
*/
private static SqlCall rewriteToMaps(SqlCall call, List<String> paramNames) {
List<SqlNode> operands = call.getOperandList();
SqlNode[] maps = new SqlNode[operands.size()];
for (int i = 0; i < operands.size(); i++) {
SqlNode op = operands.get(i);
if (op instanceof SqlCall eq && op.getKind() == SqlKind.EQUALS) {
SqlNode key = eq.operand(0);
String name =
key instanceof SqlIdentifier ident
? ident.getSimple()
: key.toString(); // avoid backtick-decorated keys for reserved words
maps[i] = toMap(name, eq.operand(1));
} else {
if (isNamedArg(op)) {
maps[i] = namedArgToMap((SqlCall) op);
} else { // Positional arg
if (i >= paramNames.size()) {
throw new IllegalArgumentException(
String.format("Invalid arguments for function '%s'", call.getOperator().getName()));
} else if (isArrayArg(op)) {
maps[i] = map(paramNames.get(i), arrayArgToMap((SqlCall) op));
} else {
maps[i] = map(paramNames.get(i), op);
}
maps[i] = toMap(paramNames.get(i), op);
}
}
return call.getOperator().createCall(call.getParserPosition(), maps);
}

private static SqlNode toMap(String key, SqlNode value) {
return SqlStdOperatorTable.MAP_VALUE_CONSTRUCTOR.createCall(
private static boolean isNamedArg(SqlNode node) {
return node instanceof SqlCall && node.getKind() == SqlKind.EQUALS;
}

private static boolean isArrayArg(SqlNode node) {
return node instanceof SqlCall call && call.getOperator() == ARRAY_VALUE_CONSTRUCTOR;
}

private static SqlNode namedArgToMap(SqlCall eq) {
SqlNode key = eq.operand(0);
String name =
key instanceof SqlIdentifier ident
? ident.getSimple()
: key.toString(); // avoid backtick-decorated keys for reserved words
return map(name, eq.operand(1));
}

private static SqlNode arrayArgToMap(SqlCall arrayCall) {
List<SqlNode> mapArgs = new ArrayList<>();
for (SqlNode element : arrayCall.getOperandList()) {
mapArgs.add(cast(element, VARCHAR));
mapArgs.add(SqlLiteral.createApproxNumeric("1.0", SqlParserPos.ZERO));
}
return map(mapArgs);
}

private static SqlNode cast(SqlNode node, SqlTypeName type) {
SqlDataTypeSpec typeSpec =
new SqlDataTypeSpec(new SqlBasicTypeNameSpec(type, SqlParserPos.ZERO), SqlParserPos.ZERO);
return CAST.createCall(SqlParserPos.ZERO, node, typeSpec);
}

private static SqlNode map(String key, SqlNode value) {
return MAP_VALUE_CONSTRUCTOR.createCall(
SqlParserPos.ZERO, SqlLiteral.createCharString(key, SqlParserPos.ZERO), value);
}

private static SqlNode map(List<SqlNode> kvPairs) {
return MAP_VALUE_CONSTRUCTOR.createCall(SqlParserPos.ZERO, kvPairs.toArray(SqlNode[]::new));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,42 @@ SELECT upper(name) FROM catalog.employees\
// FIXME: Calcite's SQL parser does not support V2 bracket field list syntax ['field1', 'field2'].
// Multi-field relevance functions only accept a single column reference in the Calcite SQL path.

@Test
public void testMultiMatchArraySyntax() {
givenQuery(
"""
SELECT * FROM catalog.employees
WHERE multi_match(ARRAY['name', 'department'], 'John')\
""")
.assertPlanContains(
"multi_match(MAP('fields', MAP('name':VARCHAR, 1.0E0:DOUBLE,"
+ " 'department':VARCHAR, 1.0E0:DOUBLE)), MAP('query', 'John'))");
}

@Test
public void testSimpleQueryStringArraySyntax() {
givenQuery(
"""
SELECT * FROM catalog.employees
WHERE simple_query_string(ARRAY['name', 'department'], 'John')\
""")
.assertPlanContains(
"simple_query_string(MAP('fields', MAP('name':VARCHAR, 1.0E0:DOUBLE,"
+ " 'department':VARCHAR, 1.0E0:DOUBLE)), MAP('query', 'John'))");
}

@Test
public void testQueryStringArraySyntax() {
givenQuery(
"""
SELECT * FROM catalog.employees
WHERE query_string(ARRAY['name', 'department'], 'John')\
""")
.assertPlanContains(
"query_string(MAP('fields', MAP('name':VARCHAR, 1.0E0:DOUBLE,"
+ " 'department':VARCHAR, 1.0E0:DOUBLE)), MAP('query', 'John'))");
}

@Test
public void testMultiMatchBracketSyntaxNotSupported() {
givenInvalidQuery(
Expand Down
Loading