diff --git a/common/ast/BUILD.bazel b/common/ast/BUILD.bazel
index 276db0322..302abfc79 100644
--- a/common/ast/BUILD.bazel
+++ b/common/ast/BUILD.bazel
@@ -11,6 +11,12 @@ java_library(
exports = ["//common/src/main/java/dev/cel/common/ast"],
)
+java_library(
+ name = "cel_block",
+ visibility = ["//:internal"],
+ exports = ["//common/src/main/java/dev/cel/common/ast:cel_block"],
+)
+
cel_android_library(
name = "ast_android",
exports = ["//common/src/main/java/dev/cel/common/ast:ast_android"],
diff --git a/common/src/main/java/dev/cel/common/ast/BUILD.bazel b/common/src/main/java/dev/cel/common/ast/BUILD.bazel
index 3fc709a07..46c235d1f 100644
--- a/common/src/main/java/dev/cel/common/ast/BUILD.bazel
+++ b/common/src/main/java/dev/cel/common/ast/BUILD.bazel
@@ -57,6 +57,20 @@ java_library(
],
)
+java_library(
+ name = "cel_block",
+ srcs = ["CelBlock.java"],
+ tags = [
+ ],
+ deps = [
+ ":ast",
+ "//common:cel_ast",
+ "//common/annotations",
+ "//common/navigation",
+ "@maven//:com_google_guava_guava",
+ ],
+)
+
java_library(
name = "expr_converter",
srcs = EXPR_CONVERTER_SOURCES,
diff --git a/common/src/main/java/dev/cel/common/ast/CelBlock.java b/common/src/main/java/dev/cel/common/ast/CelBlock.java
new file mode 100644
index 000000000..12de6d4dd
--- /dev/null
+++ b/common/src/main/java/dev/cel/common/ast/CelBlock.java
@@ -0,0 +1,144 @@
+// Copyright 2026 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package dev.cel.common.ast;
+
+import static com.google.common.collect.ImmutableList.toImmutableList;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
+import dev.cel.common.CelAbstractSyntaxTree;
+import dev.cel.common.annotations.Internal;
+import dev.cel.common.navigation.CelNavigableExpr;
+import java.util.Optional;
+
+/**
+ * Represents a {@code cel.@block} expression.
+ *
+ *
CEL Block is used by the CSE (Common Subexpression Elimination) optimizer to hoist common
+ * subexpressions into an evaluated block.
+ */
+@Internal
+public final class CelBlock {
+ public static final String FUNCTION_NAME = "cel.@block";
+ public static final String INDEX_PREFIX = "@index";
+
+ private final CelExpr blockExpr;
+
+ private CelBlock(CelExpr blockExpr) {
+ this.blockExpr = blockExpr;
+ }
+
+ public ImmutableList indices() {
+ return blockExpr.call().args().get(0).list().elements();
+ }
+
+ public CelExpr result() {
+ return blockExpr.call().args().get(1);
+ }
+
+ public CelExpr expr() {
+ return blockExpr;
+ }
+
+ /**
+ * Extracts a {@link CelBlock} from the given AST.
+ *
+ * Enforces the contract that {@code cel.@block} must only appear exactly once and at the root
+ * of the AST.
+ *
+ * @throws IllegalArgumentException if the block is malformed or its indices are invalid.
+ */
+ public static Optional extract(CelAbstractSyntaxTree ast) {
+ CelNavigableExpr celNavigableExpr = CelNavigableExpr.fromExpr(ast.getExpr());
+
+ ImmutableList allCelBlocks =
+ celNavigableExpr
+ .allNodes()
+ .map(CelNavigableExpr::expr)
+ .filter(expr -> expr.callOrDefault().function().equals(FUNCTION_NAME))
+ .collect(toImmutableList());
+ if (allCelBlocks.isEmpty()) {
+ return Optional.empty();
+ }
+
+ Preconditions.checkArgument(
+ allCelBlocks.size() == 1,
+ "Expected 1 cel.block function to be present but found %s",
+ allCelBlocks.size());
+ Preconditions.checkArgument(
+ celNavigableExpr.expr().equals(allCelBlocks.get(0)),
+ "Expected cel.block to be present at root");
+
+ return Optional.of(fromExpr(allCelBlocks.get(0)));
+ }
+
+ /**
+ * Constructs a {@link CelBlock} from a {@link CelExpr}.
+ *
+ * @throws IllegalArgumentException if the expression is not a valid block.
+ */
+ private static CelBlock fromExpr(CelExpr expr) {
+ Preconditions.checkArgument(
+ expr.exprKind().getKind() == CelExpr.ExprKind.Kind.CALL,
+ "Expected cel.@block to be a call expression");
+ Preconditions.checkArgument(
+ expr.call().function().equals(FUNCTION_NAME), "Expected function to be cel.@block");
+ Preconditions.checkArgument(
+ expr.call().args().size() == 2, "Expected exactly 2 arguments for cel.@block");
+ Preconditions.checkArgument(
+ expr.call().args().get(0).exprKind().getKind() == CelExpr.ExprKind.Kind.LIST,
+ "Expected first argument of cel.@block to be a list");
+
+ CelBlock block = new CelBlock(expr);
+
+ // Assert correctness on block indices used in subexpressions
+ ImmutableList subexprs = block.indices();
+ for (int i = 0; i < subexprs.size(); i++) {
+ verifyBlockIndex(subexprs.get(i), i, expr);
+ }
+
+ // Assert correctness on block indices used in block result
+ CelExpr blockResult = block.result();
+ verifyBlockIndex(blockResult, subexprs.size(), expr);
+ boolean resultHasAtLeastOneBlockIndex =
+ CelNavigableExpr.fromExpr(blockResult)
+ .allNodes()
+ .map(CelNavigableExpr::expr)
+ .anyMatch(e -> e.identOrDefault().name().startsWith(INDEX_PREFIX));
+ Preconditions.checkArgument(
+ resultHasAtLeastOneBlockIndex,
+ "Expected at least one reference of index in cel.block result");
+
+ return block;
+ }
+
+ private static void verifyBlockIndex(CelExpr celExpr, int maxIndexValue, CelExpr rootBlock) {
+ boolean areAllIndicesValid =
+ CelNavigableExpr.fromExpr(celExpr)
+ .allNodes()
+ .map(CelNavigableExpr::expr)
+ .filter(expr -> expr.identOrDefault().name().startsWith(INDEX_PREFIX))
+ .map(CelExpr::ident)
+ .allMatch(
+ blockIdent ->
+ Integer.parseInt(blockIdent.name().substring(INDEX_PREFIX.length()))
+ < maxIndexValue);
+ Preconditions.checkArgument(
+ areAllIndicesValid,
+ "Illegal block index found. The index value must be less than %s. Expr: %s",
+ maxIndexValue,
+ rootBlock);
+ }
+}
diff --git a/optimizer/src/main/java/dev/cel/optimizer/optimizers/BUILD.bazel b/optimizer/src/main/java/dev/cel/optimizer/optimizers/BUILD.bazel
index 155ae262d..35476a792 100644
--- a/optimizer/src/main/java/dev/cel/optimizer/optimizers/BUILD.bazel
+++ b/optimizer/src/main/java/dev/cel/optimizer/optimizers/BUILD.bazel
@@ -60,6 +60,7 @@ java_library(
"//common:mutable_ast",
"//common:mutable_source",
"//common/ast",
+ "//common/ast:cel_block",
"//common/ast:mutable_expr",
"//common/navigation",
"//common/navigation:common",
diff --git a/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java b/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java
index ce9a5dc77..5eebb1c54 100644
--- a/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java
+++ b/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java
@@ -41,6 +41,7 @@
import dev.cel.common.CelSource.Extension.Version;
import dev.cel.common.CelValidationException;
import dev.cel.common.CelVarDecl;
+import dev.cel.common.ast.CelBlock;
import dev.cel.common.ast.CelExpr;
import dev.cel.common.ast.CelExpr.CelCall;
import dev.cel.common.ast.CelExpr.CelComprehension;
@@ -238,64 +239,12 @@ private OptimizationResult optimizeUsingCelBlock(CelAbstractSyntaxTree ast, Cel
*/
@VisibleForTesting
static void verifyOptimizedAstCorrectness(CelAbstractSyntaxTree ast) {
- CelNavigableExpr celNavigableExpr = CelNavigableExpr.fromExpr(ast.getExpr());
-
- ImmutableList allCelBlocks =
- celNavigableExpr
- .allNodes()
- .map(CelNavigableExpr::expr)
- .filter(expr -> expr.callOrDefault().function().equals(CEL_BLOCK_FUNCTION))
- .collect(toImmutableList());
- if (allCelBlocks.isEmpty()) {
+ CelBlock celBlock = CelBlock.extract(ast).orElse(null);
+ if (celBlock == null) {
return;
}
- CelExpr celBlockExpr = allCelBlocks.get(0);
- Verify.verify(
- allCelBlocks.size() == 1,
- "Expected 1 cel.block function to be present but found %s",
- allCelBlocks.size());
- Verify.verify(
- celNavigableExpr.expr().equals(celBlockExpr), "Expected cel.block to be present at root");
-
- // Assert correctness on block indices used in subexpressions
- CelCall celBlockCall = celBlockExpr.call();
- ImmutableList subexprs = celBlockCall.args().get(0).list().elements();
- for (int i = 0; i < subexprs.size(); i++) {
- verifyBlockIndex(subexprs.get(i), i);
- }
-
- // Assert correctness on block indices used in block result
- CelExpr blockResult = celBlockCall.args().get(1);
- verifyBlockIndex(blockResult, subexprs.size());
- boolean resultHasAtLeastOneBlockIndex =
- CelNavigableExpr.fromExpr(blockResult)
- .allNodes()
- .map(CelNavigableExpr::expr)
- .anyMatch(expr -> expr.identOrDefault().name().startsWith(BLOCK_INDEX_PREFIX));
- Verify.verify(
- resultHasAtLeastOneBlockIndex,
- "Expected at least one reference of index in cel.block result");
-
- verifyNoInvalidScopedMangledVariables(celBlockExpr);
- }
-
- private static void verifyBlockIndex(CelExpr celExpr, int maxIndexValue) {
- boolean areAllIndicesValid =
- CelNavigableExpr.fromExpr(celExpr)
- .allNodes()
- .map(CelNavigableExpr::expr)
- .filter(expr -> expr.identOrDefault().name().startsWith(BLOCK_INDEX_PREFIX))
- .map(CelExpr::ident)
- .allMatch(
- blockIdent ->
- Integer.parseInt(blockIdent.name().substring(BLOCK_INDEX_PREFIX.length()))
- < maxIndexValue);
- Verify.verify(
- areAllIndicesValid,
- "Illegal block index found. The index value must be less than %s. Expr: %s",
- maxIndexValue,
- celExpr);
+ verifyNoInvalidScopedMangledVariables(celBlock.expr());
}
private static void verifyNoInvalidScopedMangledVariables(CelExpr celExpr) {
diff --git a/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java b/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java
index e7387d7d8..1a36bd16b 100644
--- a/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java
+++ b/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java
@@ -18,7 +18,6 @@
import static dev.cel.common.CelOverloadDecl.newGlobalOverload;
import static org.junit.Assert.assertThrows;
-import com.google.common.base.VerifyException;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.testing.junit.testparameterinjector.TestParameter;
@@ -65,6 +64,7 @@
import java.util.concurrent.atomic.AtomicInteger;
import org.junit.Before;
import org.junit.Test;
+import org.junit.function.ThrowingRunnable;
import org.junit.runner.RunWith;
@RunWith(TestParameterInjector.class)
@@ -377,18 +377,22 @@ public void celBlock_astExtensionTagged() throws Exception {
Extension.create("cel_block", Version.of(1L, 1L), Component.COMPONENT_RUNTIME));
}
+ @SuppressWarnings("Immutable") // Test only
private enum BlockTestCase {
- BOOL_LITERAL("cel.block([true, false], index0 || index1)"),
- STRING_CONCAT("cel.block(['a' + 'b', index0 + 'c'], index1 + 'd') == 'abcd'"),
+ BOOL_LITERAL("cel.block([true, false], index0 || index1)", true),
+ STRING_CONCAT("cel.block(['a' + 'b', index0 + 'c'], index1 + 'd')", "abcd"),
- BLOCK_WITH_EXISTS_TRUE("cel.block([[1, 2, 3], [3, 4, 5].exists(e, e in index0)], index1)"),
- BLOCK_WITH_EXISTS_FALSE("cel.block([[1, 2, 3], ![4, 5].exists(e, e in index0)], index1)"),
+ BLOCK_WITH_EXISTS_TRUE(
+ "cel.block([[1, 2, 3], [3, 4, 5].exists(e, e in index0)], index1)", true),
+ BLOCK_WITH_EXISTS_FALSE("cel.block([[1, 2, 3], ![4, 5].exists(e, e in index0)], index1)", true),
;
private final String source;
+ private final Object expectedResult;
- BlockTestCase(String source) {
+ BlockTestCase(String source, Object expectedResult) {
this.source = source;
+ this.expectedResult = expectedResult;
}
}
@@ -398,7 +402,7 @@ public void block_success(@TestParameter BlockTestCase testCase) throws Exceptio
Object evaluatedResult = celForEvaluatingBlock.createProgram(ast).eval();
- assertThat(evaluatedResult).isNotNull();
+ assertThat(evaluatedResult).isEqualTo(testCase.expectedResult);
}
@Test
@@ -411,7 +415,7 @@ public void block_success_parsedOnly(@TestParameter BlockTestCase testCase) thro
Object evaluatedResult = celForEvaluatingBlock.createProgram(ast).eval();
- assertThat(evaluatedResult).isNotNull();
+ assertThat(evaluatedResult).isEqualTo(testCase.expectedResult);
}
@Test
@@ -604,9 +608,10 @@ public void verifyOptimizedAstCorrectness_twoCelBlocks_throws() throws Exception
CelAbstractSyntaxTree ast =
compileUsingInternalFunctions("cel.block([1, 2], cel.block([2], 3))");
- VerifyException e =
+ IllegalArgumentException e =
assertThrows(
- VerifyException.class, () -> SubexpressionOptimizer.verifyOptimizedAstCorrectness(ast));
+ IllegalArgumentException.class,
+ () -> SubexpressionOptimizer.verifyOptimizedAstCorrectness(ast));
assertThat(e)
.hasMessageThat()
.isEqualTo("Expected 1 cel.block function to be present but found 2");
@@ -616,9 +621,10 @@ public void verifyOptimizedAstCorrectness_twoCelBlocks_throws() throws Exception
public void verifyOptimizedAstCorrectness_celBlockNotAtRoot_throws() throws Exception {
CelAbstractSyntaxTree ast = compileUsingInternalFunctions("1 + cel.block([1, 2], index0)");
- VerifyException e =
+ IllegalArgumentException e =
assertThrows(
- VerifyException.class, () -> SubexpressionOptimizer.verifyOptimizedAstCorrectness(ast));
+ IllegalArgumentException.class,
+ () -> SubexpressionOptimizer.verifyOptimizedAstCorrectness(ast));
assertThat(e).hasMessageThat().isEqualTo("Expected cel.block to be present at root");
}
@@ -626,9 +632,10 @@ public void verifyOptimizedAstCorrectness_celBlockNotAtRoot_throws() throws Exce
public void verifyOptimizedAstCorrectness_blockContainsNoIndexResult_throws() throws Exception {
CelAbstractSyntaxTree ast = compileUsingInternalFunctions("cel.block([1, index0], 2)");
- VerifyException e =
+ IllegalArgumentException e =
assertThrows(
- VerifyException.class, () -> SubexpressionOptimizer.verifyOptimizedAstCorrectness(ast));
+ IllegalArgumentException.class,
+ () -> SubexpressionOptimizer.verifyOptimizedAstCorrectness(ast));
assertThat(e)
.hasMessageThat()
.isEqualTo("Expected at least one reference of index in cel.block result");
@@ -641,9 +648,10 @@ public void verifyOptimizedAstCorrectness_indexOutOfBounds_throws(String source)
throws Exception {
CelAbstractSyntaxTree ast = compileUsingInternalFunctions(source);
- VerifyException e =
+ IllegalArgumentException e =
assertThrows(
- VerifyException.class, () -> SubexpressionOptimizer.verifyOptimizedAstCorrectness(ast));
+ IllegalArgumentException.class,
+ () -> SubexpressionOptimizer.verifyOptimizedAstCorrectness(ast));
assertThat(e)
.hasMessageThat()
.contains("Illegal block index found. The index value must be less than");
@@ -658,9 +666,10 @@ public void verifyOptimizedAstCorrectness_indexIsNotForwardReferencing_throws(St
throws Exception {
CelAbstractSyntaxTree ast = compileUsingInternalFunctions(source);
- VerifyException e =
+ IllegalArgumentException e =
assertThrows(
- VerifyException.class, () -> SubexpressionOptimizer.verifyOptimizedAstCorrectness(ast));
+ IllegalArgumentException.class,
+ () -> SubexpressionOptimizer.verifyOptimizedAstCorrectness(ast));
assertThat(e)
.hasMessageThat()
.contains("Illegal block index found. The index value must be less than");
@@ -670,9 +679,14 @@ public void verifyOptimizedAstCorrectness_indexIsNotForwardReferencing_throws(St
public void block_containsCycle_throws() throws Exception {
CelAbstractSyntaxTree ast = compileUsingInternalFunctions("cel.block([index1,index0],index0)");
- CelEvaluationException e =
- assertThrows(CelEvaluationException.class, () -> cel.createProgram(ast).eval());
- assertThat(e).hasMessageThat().contains("Cycle detected: @index0");
+ ThrowingRunnable evaluateProgram = () -> cel.createProgram(ast).eval();
+
+ CelEvaluationException e = assertThrows(CelEvaluationException.class, evaluateProgram);
+ assertThat(e)
+ .hasMessageThat()
+ .containsMatch(
+ "Cycle detected: @index0|Illegal block index found. The index value must be less than"
+ + " 0.");
}
@Test
diff --git a/runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel b/runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel
index 801e56d73..e82f77c67 100644
--- a/runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel
+++ b/runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel
@@ -49,6 +49,7 @@ java_library(
"//common:options",
"//common/annotations",
"//common/ast",
+ "//common/ast:cel_block",
"//common/exceptions:overload_not_found",
"//common/types",
"//common/types:type_providers",
diff --git a/runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java b/runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java
index e38d08f8f..6bb3d1e22 100644
--- a/runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java
+++ b/runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java
@@ -26,6 +26,7 @@
import dev.cel.common.CelOptions;
import dev.cel.common.Operator;
import dev.cel.common.annotations.Internal;
+import dev.cel.common.ast.CelBlock;
import dev.cel.common.ast.CelConstant;
import dev.cel.common.ast.CelExpr;
import dev.cel.common.ast.CelExpr.CelCall;
@@ -79,7 +80,11 @@ public Program plan(CelAbstractSyntaxTree ast) throws CelEvaluationException {
ErrorMetadata errorMetadata =
ErrorMetadata.create(ast.getSource().getPositionsMap(), ast.getSource().getDescription());
try {
- plannedInterpretable = plan(ast.getExpr(), PlannerContext.create(ast));
+ PlannerContext ctx = PlannerContext.create(ast);
+ plannedInterpretable =
+ CelBlock.extract(ast)
+ .map(celBlock -> planBlock(celBlock, ctx))
+ .orElseGet(() -> plan(ast.getExpr(), ctx));
} catch (RuntimeException e) {
throw CelEvaluationExceptionBuilder.newBuilder(e.getMessage())
.setMetadata(errorMetadata, ast.getExpr().id())
@@ -231,11 +236,6 @@ private PlannedInterpretable planCall(CelExpr expr, PlannerContext ctx) {
ResolvedFunction resolvedFunction = resolveFunction(expr, ctx.referenceMap());
String functionName = resolvedFunction.functionName();
- PlannedInterpretable blockCall = maybeInterceptBlockCall(functionName, expr, ctx).orElse(null);
- if (blockCall != null) {
- return blockCall;
- }
-
CelExpr target = resolvedFunction.target().orElse(null);
int argCount = expr.call().args().size();
if (target != null) {
@@ -331,26 +331,15 @@ private PlannedInterpretable planCall(CelExpr expr, PlannerContext ctx) {
}
}
- private Optional maybeInterceptBlockCall(
- String functionName, CelExpr expr, PlannerContext ctx) {
- if (!functionName.equals("cel.@block")) {
- return Optional.empty();
- }
-
- CelCall blockCall = expr.call();
-
- if (blockCall.args().size() != 2) {
- throw new IllegalArgumentException(
- "Expected 2 arguments for cel.@block call. Got: " + blockCall.args().size());
- }
+ private PlannedInterpretable planBlock(CelBlock celBlock, PlannerContext ctx) {
+ ImmutableList indices = celBlock.indices();
- CelList exprList = blockCall.args().get(0).list();
- PlannedInterpretable[] slotExprs = new PlannedInterpretable[exprList.elements().size()];
+ PlannedInterpretable[] slotExprs = new PlannedInterpretable[indices.size()];
for (int i = 0; i < slotExprs.length; i++) {
- slotExprs[i] = plan(exprList.elements().get(i), ctx);
+ slotExprs[i] = plan(indices.get(i), ctx);
}
- PlannedInterpretable resultExpr = plan(blockCall.args().get(1), ctx);
- return Optional.of(EvalBlock.create(expr, slotExprs, resultExpr));
+ PlannedInterpretable resultExpr = plan(celBlock.result(), ctx);
+ return EvalBlock.create(celBlock.expr(), slotExprs, resultExpr);
}
/**