Skip to content

Commit 878d7a1

Browse files
l46kokcopybara-github
authored andcommitted
Create CelBlock abstraction to centralize cel.@block logic
PiperOrigin-RevId: 940123116
1 parent 3e7dea1 commit 878d7a1

8 files changed

Lines changed: 196 additions & 79 deletions

File tree

common/ast/BUILD.bazel

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,11 @@ java_library(
1111
exports = ["//common/src/main/java/dev/cel/common/ast"],
1212
)
1313

14+
java_library(
15+
name = "cel_block",
16+
exports = ["//common/src/main/java/dev/cel/common/ast:cel_block"],
17+
)
18+
1419
cel_android_library(
1520
name = "ast_android",
1621
exports = ["//common/src/main/java/dev/cel/common/ast:ast_android"],

common/src/main/java/dev/cel/common/ast/BUILD.bazel

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,20 @@ java_library(
5757
],
5858
)
5959

60+
java_library(
61+
name = "cel_block",
62+
srcs = ["CelBlock.java"],
63+
tags = [
64+
],
65+
deps = [
66+
":ast",
67+
"//common:cel_ast",
68+
"//common/annotations",
69+
"//common/navigation",
70+
"@maven//:com_google_guava_guava",
71+
],
72+
)
73+
6074
java_library(
6175
name = "expr_converter",
6276
srcs = EXPR_CONVERTER_SOURCES,
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
// Copyright 2026 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package dev.cel.common.ast;
16+
17+
import com.google.common.base.Preconditions;
18+
import com.google.common.collect.ImmutableList;
19+
import dev.cel.common.CelAbstractSyntaxTree;
20+
import dev.cel.common.annotations.Internal;
21+
import dev.cel.common.navigation.CelNavigableExpr;
22+
import java.util.Optional;
23+
24+
/**
25+
* Represents a {@code cel.@block} expression.
26+
*
27+
* <p>CEL Block is used by the CSE (Common Subexpression Elimination) optimizer to hoist common
28+
* subexpressions into an evaluated block.
29+
*/
30+
@Internal
31+
public final class CelBlock {
32+
public static final String FUNCTION_NAME = "cel.@block";
33+
public static final String INDEX_PREFIX = "@index";
34+
35+
private final CelExpr blockExpr;
36+
37+
private CelBlock(CelExpr blockExpr) {
38+
this.blockExpr = blockExpr;
39+
}
40+
41+
public ImmutableList<CelExpr> indices() {
42+
return blockExpr.call().args().get(0).list().elements();
43+
}
44+
45+
public CelExpr result() {
46+
return blockExpr.call().args().get(1);
47+
}
48+
49+
public CelExpr expr() {
50+
return blockExpr;
51+
}
52+
53+
/**
54+
* Extracts a {@link CelBlock} from the given AST.
55+
*
56+
* <p>Enforces the contract that {@code cel.@block} must only appear exactly once and at the root
57+
* of the AST.
58+
*
59+
* @throws IllegalArgumentException if the block is malformed or its indices are invalid.
60+
*/
61+
public static Optional<CelBlock> extract(CelAbstractSyntaxTree ast) {
62+
CelNavigableExpr celNavigableExpr = CelNavigableExpr.fromExpr(ast.getExpr());
63+
64+
ImmutableList<CelExpr> allCelBlocks =
65+
celNavigableExpr
66+
.allNodes()
67+
.map(CelNavigableExpr::expr)
68+
.filter(expr -> expr.callOrDefault().function().equals(FUNCTION_NAME))
69+
.collect(ImmutableList.toImmutableList());
70+
if (allCelBlocks.isEmpty()) {
71+
return Optional.empty();
72+
}
73+
74+
Preconditions.checkArgument(
75+
allCelBlocks.size() == 1,
76+
"Expected 1 cel.block function to be present but found %s",
77+
allCelBlocks.size());
78+
Preconditions.checkArgument(
79+
celNavigableExpr.expr().equals(allCelBlocks.get(0)),
80+
"Expected cel.block to be present at root");
81+
82+
return Optional.of(fromExpr(allCelBlocks.get(0)));
83+
}
84+
85+
/**
86+
* Constructs a {@link CelBlock} from a {@link CelExpr}.
87+
*
88+
* @throws IllegalArgumentException if the expression is not a valid block.
89+
*/
90+
public static CelBlock fromExpr(CelExpr expr) {
91+
Preconditions.checkArgument(
92+
expr.exprKind().getKind() == CelExpr.ExprKind.Kind.CALL,
93+
"Expected cel.@block to be a call expression");
94+
Preconditions.checkArgument(
95+
expr.call().function().equals(FUNCTION_NAME), "Expected function to be cel.@block");
96+
Preconditions.checkArgument(
97+
expr.call().args().size() == 2, "Expected exactly 2 arguments for cel.@block");
98+
Preconditions.checkArgument(
99+
expr.call().args().get(0).exprKind().getKind() == CelExpr.ExprKind.Kind.LIST,
100+
"Expected first argument of cel.@block to be a list");
101+
102+
CelBlock block = new CelBlock(expr);
103+
104+
// Assert correctness on block indices used in subexpressions
105+
ImmutableList<CelExpr> subexprs = block.indices();
106+
for (int i = 0; i < subexprs.size(); i++) {
107+
verifyBlockIndex(subexprs.get(i), i, expr);
108+
}
109+
110+
// Assert correctness on block indices used in block result
111+
CelExpr blockResult = block.result();
112+
verifyBlockIndex(blockResult, subexprs.size(), expr);
113+
boolean resultHasAtLeastOneBlockIndex =
114+
CelNavigableExpr.fromExpr(blockResult)
115+
.allNodes()
116+
.map(CelNavigableExpr::expr)
117+
.anyMatch(e -> e.identOrDefault().name().startsWith(INDEX_PREFIX));
118+
Preconditions.checkArgument(
119+
resultHasAtLeastOneBlockIndex,
120+
"Expected at least one reference of index in cel.block result");
121+
122+
return block;
123+
}
124+
125+
private static void verifyBlockIndex(CelExpr celExpr, int maxIndexValue, CelExpr rootBlock) {
126+
boolean areAllIndicesValid =
127+
CelNavigableExpr.fromExpr(celExpr)
128+
.allNodes()
129+
.map(CelNavigableExpr::expr)
130+
.filter(expr -> expr.identOrDefault().name().startsWith(INDEX_PREFIX))
131+
.map(CelExpr::ident)
132+
.allMatch(
133+
blockIdent ->
134+
Integer.parseInt(blockIdent.name().substring(INDEX_PREFIX.length()))
135+
< maxIndexValue);
136+
Preconditions.checkArgument(
137+
areAllIndicesValid,
138+
"Illegal block index found. The index value must be less than %s. Expr: %s",
139+
maxIndexValue,
140+
rootBlock);
141+
}
142+
}

optimizer/src/main/java/dev/cel/optimizer/optimizers/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ java_library(
6060
"//common:mutable_ast",
6161
"//common:mutable_source",
6262
"//common/ast",
63+
"//common/ast:cel_block",
6364
"//common/ast:mutable_expr",
6465
"//common/navigation",
6566
"//common/navigation:common",

optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java

Lines changed: 6 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
import dev.cel.common.CelSource.Extension.Version;
4242
import dev.cel.common.CelValidationException;
4343
import dev.cel.common.CelVarDecl;
44+
import dev.cel.common.ast.CelBlock;
4445
import dev.cel.common.ast.CelExpr;
4546
import dev.cel.common.ast.CelExpr.CelCall;
4647
import dev.cel.common.ast.CelExpr.CelComprehension;
@@ -64,6 +65,7 @@
6465
import java.util.Comparator;
6566
import java.util.HashSet;
6667
import java.util.List;
68+
import java.util.Optional;
6769
import java.util.Set;
6870
import java.util.stream.Stream;
6971

@@ -238,65 +240,15 @@ private OptimizationResult optimizeUsingCelBlock(CelAbstractSyntaxTree ast, Cel
238240
*/
239241
@VisibleForTesting
240242
static void verifyOptimizedAstCorrectness(CelAbstractSyntaxTree ast) {
241-
CelNavigableExpr celNavigableExpr = CelNavigableExpr.fromExpr(ast.getExpr());
242-
243-
ImmutableList<CelExpr> allCelBlocks =
244-
celNavigableExpr
245-
.allNodes()
246-
.map(CelNavigableExpr::expr)
247-
.filter(expr -> expr.callOrDefault().function().equals(CEL_BLOCK_FUNCTION))
248-
.collect(toImmutableList());
249-
if (allCelBlocks.isEmpty()) {
243+
Optional<CelBlock> celBlockOpt = CelBlock.extract(ast);
244+
if (!celBlockOpt.isPresent()) {
250245
return;
251246
}
252247

253-
CelExpr celBlockExpr = allCelBlocks.get(0);
254-
Verify.verify(
255-
allCelBlocks.size() == 1,
256-
"Expected 1 cel.block function to be present but found %s",
257-
allCelBlocks.size());
258-
Verify.verify(
259-
celNavigableExpr.expr().equals(celBlockExpr), "Expected cel.block to be present at root");
260-
261-
// Assert correctness on block indices used in subexpressions
262-
CelCall celBlockCall = celBlockExpr.call();
263-
ImmutableList<CelExpr> subexprs = celBlockCall.args().get(0).list().elements();
264-
for (int i = 0; i < subexprs.size(); i++) {
265-
verifyBlockIndex(subexprs.get(i), i);
266-
}
267-
268-
// Assert correctness on block indices used in block result
269-
CelExpr blockResult = celBlockCall.args().get(1);
270-
verifyBlockIndex(blockResult, subexprs.size());
271-
boolean resultHasAtLeastOneBlockIndex =
272-
CelNavigableExpr.fromExpr(blockResult)
273-
.allNodes()
274-
.map(CelNavigableExpr::expr)
275-
.anyMatch(expr -> expr.identOrDefault().name().startsWith(BLOCK_INDEX_PREFIX));
276-
Verify.verify(
277-
resultHasAtLeastOneBlockIndex,
278-
"Expected at least one reference of index in cel.block result");
279-
280-
verifyNoInvalidScopedMangledVariables(celBlockExpr);
248+
verifyNoInvalidScopedMangledVariables(celBlockOpt.get().expr());
281249
}
282250

283-
private static void verifyBlockIndex(CelExpr celExpr, int maxIndexValue) {
284-
boolean areAllIndicesValid =
285-
CelNavigableExpr.fromExpr(celExpr)
286-
.allNodes()
287-
.map(CelNavigableExpr::expr)
288-
.filter(expr -> expr.identOrDefault().name().startsWith(BLOCK_INDEX_PREFIX))
289-
.map(CelExpr::ident)
290-
.allMatch(
291-
blockIdent ->
292-
Integer.parseInt(blockIdent.name().substring(BLOCK_INDEX_PREFIX.length()))
293-
< maxIndexValue);
294-
Verify.verify(
295-
areAllIndicesValid,
296-
"Illegal block index found. The index value must be less than %s. Expr: %s",
297-
maxIndexValue,
298-
celExpr);
299-
}
251+
300252

301253
private static void verifyNoInvalidScopedMangledVariables(CelExpr celExpr) {
302254
CelCall celBlockCall = celExpr.call();

optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import static dev.cel.common.CelOverloadDecl.newGlobalOverload;
1919
import static org.junit.Assert.assertThrows;
2020

21-
import com.google.common.base.VerifyException;
2221
import com.google.common.collect.ImmutableList;
2322
import com.google.common.collect.ImmutableMap;
2423
import com.google.testing.junit.testparameterinjector.TestParameter;
@@ -604,9 +603,10 @@ public void verifyOptimizedAstCorrectness_twoCelBlocks_throws() throws Exception
604603
CelAbstractSyntaxTree ast =
605604
compileUsingInternalFunctions("cel.block([1, 2], cel.block([2], 3))");
606605

607-
VerifyException e =
606+
IllegalArgumentException e =
608607
assertThrows(
609-
VerifyException.class, () -> SubexpressionOptimizer.verifyOptimizedAstCorrectness(ast));
608+
IllegalArgumentException.class,
609+
() -> SubexpressionOptimizer.verifyOptimizedAstCorrectness(ast));
610610
assertThat(e)
611611
.hasMessageThat()
612612
.isEqualTo("Expected 1 cel.block function to be present but found 2");
@@ -616,19 +616,21 @@ public void verifyOptimizedAstCorrectness_twoCelBlocks_throws() throws Exception
616616
public void verifyOptimizedAstCorrectness_celBlockNotAtRoot_throws() throws Exception {
617617
CelAbstractSyntaxTree ast = compileUsingInternalFunctions("1 + cel.block([1, 2], index0)");
618618

619-
VerifyException e =
619+
IllegalArgumentException e =
620620
assertThrows(
621-
VerifyException.class, () -> SubexpressionOptimizer.verifyOptimizedAstCorrectness(ast));
621+
IllegalArgumentException.class,
622+
() -> SubexpressionOptimizer.verifyOptimizedAstCorrectness(ast));
622623
assertThat(e).hasMessageThat().isEqualTo("Expected cel.block to be present at root");
623624
}
624625

625626
@Test
626627
public void verifyOptimizedAstCorrectness_blockContainsNoIndexResult_throws() throws Exception {
627628
CelAbstractSyntaxTree ast = compileUsingInternalFunctions("cel.block([1, index0], 2)");
628629

629-
VerifyException e =
630+
IllegalArgumentException e =
630631
assertThrows(
631-
VerifyException.class, () -> SubexpressionOptimizer.verifyOptimizedAstCorrectness(ast));
632+
IllegalArgumentException.class,
633+
() -> SubexpressionOptimizer.verifyOptimizedAstCorrectness(ast));
632634
assertThat(e)
633635
.hasMessageThat()
634636
.isEqualTo("Expected at least one reference of index in cel.block result");
@@ -641,9 +643,10 @@ public void verifyOptimizedAstCorrectness_indexOutOfBounds_throws(String source)
641643
throws Exception {
642644
CelAbstractSyntaxTree ast = compileUsingInternalFunctions(source);
643645

644-
VerifyException e =
646+
IllegalArgumentException e =
645647
assertThrows(
646-
VerifyException.class, () -> SubexpressionOptimizer.verifyOptimizedAstCorrectness(ast));
648+
IllegalArgumentException.class,
649+
() -> SubexpressionOptimizer.verifyOptimizedAstCorrectness(ast));
647650
assertThat(e)
648651
.hasMessageThat()
649652
.contains("Illegal block index found. The index value must be less than");
@@ -658,9 +661,10 @@ public void verifyOptimizedAstCorrectness_indexIsNotForwardReferencing_throws(St
658661
throws Exception {
659662
CelAbstractSyntaxTree ast = compileUsingInternalFunctions(source);
660663

661-
VerifyException e =
664+
IllegalArgumentException e =
662665
assertThrows(
663-
VerifyException.class, () -> SubexpressionOptimizer.verifyOptimizedAstCorrectness(ast));
666+
IllegalArgumentException.class,
667+
() -> SubexpressionOptimizer.verifyOptimizedAstCorrectness(ast));
664668
assertThat(e)
665669
.hasMessageThat()
666670
.contains("Illegal block index found. The index value must be less than");
@@ -670,9 +674,11 @@ public void verifyOptimizedAstCorrectness_indexIsNotForwardReferencing_throws(St
670674
public void block_containsCycle_throws() throws Exception {
671675
CelAbstractSyntaxTree ast = compileUsingInternalFunctions("cel.block([index1,index0],index0)");
672676

673-
CelEvaluationException e =
674-
assertThrows(CelEvaluationException.class, () -> cel.createProgram(ast).eval());
675-
assertThat(e).hasMessageThat().contains("Cycle detected: @index0");
677+
IllegalArgumentException e =
678+
assertThrows(IllegalArgumentException.class, () -> cel.createProgram(ast));
679+
assertThat(e)
680+
.hasMessageThat()
681+
.contains("Illegal block index found. The index value must be less than 0. Got: 1");
676682
}
677683

678684
@Test

runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ java_library(
4949
"//common:options",
5050
"//common/annotations",
5151
"//common/ast",
52+
"//common/ast:cel_block",
5253
"//common/exceptions:overload_not_found",
5354
"//common/types",
5455
"//common/types:type_providers",

0 commit comments

Comments
 (0)