Skip to content

Commit 4089a60

Browse files
committed
foo
1 parent a6bb966 commit 4089a60

5 files changed

Lines changed: 221 additions & 12 deletions

File tree

extensions/src/test/java/dev/cel/extensions/CelBindingsExtensionsTest.java

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,4 +280,39 @@ public void lazyBinding_boundAttributeInComprehension() throws Exception {
280280
assertThat(result).containsExactly(true, true, true);
281281
assertThat(invocation.get()).isEqualTo(1);
282282
}
283+
284+
@Test
285+
@SuppressWarnings("Immutable") // Test only
286+
public void foo() throws Exception {
287+
CelCompiler celCompiler =
288+
CelCompilerFactory.standardCelCompilerBuilder()
289+
.setStandardMacros(CelStandardMacro.MAP)
290+
.addLibraries(CelExtensions.bindings())
291+
.addFunctionDeclarations(
292+
CelFunctionDecl.newFunctionDeclaration(
293+
"get_true",
294+
CelOverloadDecl.newGlobalOverload("get_true_overload", SimpleType.BOOL)))
295+
.build();
296+
AtomicInteger invocation = new AtomicInteger();
297+
CelRuntime celRuntime =
298+
CelRuntimeFactory.standardCelRuntimeBuilder()
299+
.addFunctionBindings(
300+
CelFunctionBinding.from(
301+
"get_true_overload",
302+
ImmutableList.of(),
303+
arg -> {
304+
invocation.getAndIncrement();
305+
return true;
306+
}))
307+
.build();
308+
309+
CelAbstractSyntaxTree ast = celCompiler.compile("cel.bind(x, get_true(), [x, false].map(c0, [c0].map(c1, [c0, x])))").getAst();
310+
311+
Object foo = celRuntime.createProgram(ast).eval();
312+
System.out.println(foo);
313+
List<Boolean> result = (List<Boolean>) celRuntime.createProgram(ast).eval();
314+
315+
assertThat(result).containsExactly(true, true, true);
316+
assertThat(invocation.get()).isEqualTo(1);
317+
}
283318
}

optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java

Lines changed: 116 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@
1616

1717
import static com.google.common.base.Preconditions.checkNotNull;
1818
import static com.google.common.collect.ImmutableList.toImmutableList;
19+
import static com.google.common.collect.ImmutableSet.toImmutableSet;
1920
import static java.util.stream.Collectors.toCollection;
2021

2122
import com.google.auto.value.AutoValue;
2223
import com.google.common.annotations.VisibleForTesting;
2324
import com.google.common.base.Preconditions;
25+
import com.google.common.base.Strings;
2426
import com.google.common.base.Verify;
2527
import com.google.common.collect.ImmutableList;
2628
import com.google.common.collect.ImmutableSet;
@@ -41,8 +43,11 @@
4143
import dev.cel.common.CelVarDecl;
4244
import dev.cel.common.ast.CelExpr;
4345
import dev.cel.common.ast.CelExpr.CelCall;
46+
import dev.cel.common.ast.CelExpr.CelComprehension;
47+
import dev.cel.common.ast.CelExpr.CelList;
4448
import dev.cel.common.ast.CelExpr.ExprKind.Kind;
4549
import dev.cel.common.ast.CelMutableExpr;
50+
import dev.cel.common.ast.CelMutableExpr.CelMutableComprehension;
4651
import dev.cel.common.ast.CelMutableExprConverter;
4752
import dev.cel.common.navigation.CelNavigableExpr;
4853
import dev.cel.common.navigation.CelNavigableMutableAst;
@@ -59,7 +64,9 @@
5964
import java.util.Comparator;
6065
import java.util.HashSet;
6166
import java.util.List;
67+
import java.util.Objects;
6268
import java.util.Set;
69+
import java.util.stream.Stream;
6370

6471
/**
6572
* Performs Common Subexpression Elimination.
@@ -90,14 +97,18 @@ public class SubexpressionOptimizer implements CelAstOptimizer {
9097
private static final SubexpressionOptimizer INSTANCE =
9198
new SubexpressionOptimizer(SubexpressionOptimizerOptions.newBuilder().build());
9299
private static final String BIND_IDENTIFIER_PREFIX = "@r";
93-
private static final String MANGLED_COMPREHENSION_ITER_VAR_PREFIX = "@it";
94-
private static final String MANGLED_COMPREHENSION_ITER_VAR2_PREFIX = "@it2";
95-
private static final String MANGLED_COMPREHENSION_ACCU_VAR_PREFIX = "@ac";
96100
private static final String CEL_BLOCK_FUNCTION = "cel.@block";
97101
private static final String BLOCK_INDEX_PREFIX = "@index";
98102
private static final Extension CEL_BLOCK_AST_EXTENSION_TAG =
99103
Extension.create("cel_block", Version.of(1L, 1L), Component.COMPONENT_RUNTIME);
100104

105+
@VisibleForTesting
106+
static final String MANGLED_COMPREHENSION_ITER_VAR_PREFIX = "@it";
107+
@VisibleForTesting
108+
static final String MANGLED_COMPREHENSION_ITER_VAR2_PREFIX = "@it2";
109+
@VisibleForTesting
110+
static final String MANGLED_COMPREHENSION_ACCU_VAR_PREFIX = "@ac";
111+
101112
private final SubexpressionOptimizerOptions cseOptions;
102113
private final AstMutator astMutator;
103114
private final ImmutableSet<String> cseEliminableFunctions;
@@ -269,6 +280,8 @@ static void verifyOptimizedAstCorrectness(CelAbstractSyntaxTree ast) {
269280
Verify.verify(
270281
resultHasAtLeastOneBlockIndex,
271282
"Expected at least one reference of index in cel.block result");
283+
284+
verifyNoInvalidScopedMangledVariables(celBlockExpr);
272285
}
273286

274287
private static void verifyBlockIndex(CelExpr celExpr, int maxIndexValue) {
@@ -289,6 +302,69 @@ private static void verifyBlockIndex(CelExpr celExpr, int maxIndexValue) {
289302
celExpr);
290303
}
291304

305+
private static void verifyNoInvalidScopedMangledVariables(CelExpr celExpr) {
306+
CelCall celBlockCall = celExpr.call();
307+
CelExpr blockBody = celBlockCall.args().get(1);
308+
309+
ImmutableSet<String> allMangledVariablesInBlockBody =
310+
CelNavigableExpr.fromExpr(blockBody)
311+
.allNodes()
312+
.map(CelNavigableExpr::expr)
313+
.flatMap(SubexpressionOptimizer::extractMangledNames)
314+
.collect(toImmutableSet());
315+
316+
CelList blockIndices = celBlockCall.args().get(0).list();
317+
for (CelExpr blockIndex : blockIndices.elements()) {
318+
ImmutableSet<String> indexDeclaredCompVariables =
319+
CelNavigableExpr.fromExpr(blockIndex)
320+
.allNodes()
321+
.map(CelNavigableExpr::expr)
322+
.filter(expr -> expr.getKind() == Kind.COMPREHENSION)
323+
.map(CelExpr::comprehension)
324+
.flatMap(comp -> Stream.of(
325+
comp.iterVar(),
326+
comp.iterVar2()
327+
))
328+
.filter(iter -> !Strings.isNullOrEmpty(iter))
329+
.collect(toImmutableSet());
330+
331+
boolean containsIllegalDeclaration =
332+
CelNavigableExpr.fromExpr(blockIndex)
333+
.allNodes()
334+
.map(CelNavigableExpr::expr)
335+
.filter(expr -> expr.getKind() == Kind.IDENT)
336+
.map(expr -> expr.ident().name())
337+
.filter(SubexpressionOptimizer::isMangled)
338+
.anyMatch(ident ->
339+
!indexDeclaredCompVariables.contains(ident) &&
340+
allMangledVariablesInBlockBody.contains(ident));
341+
342+
Verify.verify(
343+
!containsIllegalDeclaration,
344+
"Illegal declared reference to a comprehension variable found in block indices. Expr: %s",
345+
celExpr);
346+
}
347+
}
348+
349+
private static Stream<String> extractMangledNames(CelExpr expr) {
350+
if (expr.getKind() == Kind.IDENT) {
351+
String name = expr.ident().name();
352+
return isMangled(name) ? Stream.of(name) : Stream.empty();
353+
}
354+
if (expr.getKind() == Kind.COMPREHENSION) {
355+
CelComprehension comp = expr.comprehension();
356+
return Stream.of(comp.iterVar(), comp.iterVar2(), comp.accuVar())
357+
.filter(Objects::nonNull) // Handle potential null/empty iterVar2
358+
.filter(SubexpressionOptimizer::isMangled);
359+
}
360+
return Stream.empty();
361+
}
362+
363+
private static boolean isMangled(String name) {
364+
return name.startsWith(MANGLED_COMPREHENSION_ITER_VAR_PREFIX)
365+
|| name.startsWith(MANGLED_COMPREHENSION_ITER_VAR2_PREFIX);
366+
}
367+
292368
private static CelAbstractSyntaxTree tagAstExtension(CelAbstractSyntaxTree ast) {
293369
// Tag the extension
294370
CelSource.Builder celSourceBuilder =
@@ -355,8 +431,8 @@ private List<CelMutableExpr> getCseCandidatesWithRecursionDepth(
355431
navAst
356432
.getRoot()
357433
.descendants(TraversalOrder.PRE_ORDER)
358-
.filter(node -> canEliminate(node, ineligibleExprs))
359434
.filter(node -> node.height() <= recursionLimit)
435+
.filter(node -> canEliminate(node, ineligibleExprs))
360436
.sorted(Comparator.comparingInt(CelNavigableMutableExpr::height).reversed())
361437
.collect(toImmutableList());
362438
if (descendants.isEmpty()) {
@@ -441,9 +517,44 @@ private boolean canEliminate(
441517
&& navigableExpr.expr().list().elements().isEmpty())
442518
&& containsEliminableFunctionOnly(navigableExpr)
443519
&& !ineligibleExprs.contains(navigableExpr.expr())
444-
&& containsComprehensionIdentInSubexpr(navigableExpr);
520+
&& containsComprehensionIdentInSubexpr(navigableExpr)
521+
&& containsProperScopedComprehensionIdents(navigableExpr);
522+
}
523+
524+
private boolean containsProperScopedComprehensionIdents(CelNavigableMutableExpr navExpr) {
525+
if (!navExpr.getKind().equals(Kind.COMPREHENSION)) {
526+
return true;
527+
}
528+
529+
// For nested comprehensions of form [1].exists(x, [2].exists(y, x == y)), the inner comprehension [2].exists(y, x == y)
530+
// should not be extracted out into a block index, as it causes issues with scoping.
531+
ImmutableSet<String> mangledIterVars = navExpr.descendants()
532+
.filter(x -> x.getKind().equals(Kind.IDENT))
533+
.map(x -> x.expr().ident().name())
534+
.filter(name ->
535+
name.startsWith(MANGLED_COMPREHENSION_ITER_VAR_PREFIX) ||
536+
name.startsWith(MANGLED_COMPREHENSION_ITER_VAR2_PREFIX)
537+
).collect(toImmutableSet());
538+
539+
CelNavigableMutableExpr parent = navExpr.parent().orElse(null);
540+
while (parent != null) {
541+
if (parent.getKind().equals(Kind.COMPREHENSION)) {
542+
CelMutableComprehension comp = parent.expr().comprehension();
543+
boolean containsParentIterReferences =
544+
mangledIterVars.contains(comp.iterVar()) || mangledIterVars.contains(comp.iterVar2());
545+
546+
if (containsParentIterReferences) {
547+
return false;
548+
}
549+
}
550+
551+
parent = parent.parent().orElse(null);
552+
}
553+
554+
return true;
445555
}
446556

557+
447558
private boolean containsComprehensionIdentInSubexpr(CelNavigableMutableExpr navExpr) {
448559
if (navExpr.getKind().equals(Kind.COMPREHENSION)) {
449560
return true;

optimizer/src/test/java/dev/cel/optimizer/optimizers/BUILD.bazel

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,13 @@ java_library(
1818
"//common:mutable_ast",
1919
"//common:options",
2020
"//common/ast",
21+
"//common/ast:mutable_expr",
2122
"//common/navigation:mutable_navigation",
2223
"//common/types",
2324
"//extensions",
2425
"//extensions:optional_library",
2526
"//optimizer",
27+
"//optimizer:mutable_ast",
2628
"//optimizer:optimization_exception",
2729
"//optimizer:optimizer_builder",
2830
"//optimizer/optimizers:common_subexpression_elimination",

optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java

Lines changed: 66 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,17 @@
3838
import dev.cel.common.CelValidationException;
3939
import dev.cel.common.CelVarDecl;
4040
import dev.cel.common.ast.CelExpr.ExprKind.Kind;
41+
import dev.cel.common.ast.CelMutableExpr;
4142
import dev.cel.common.navigation.CelNavigableMutableAst;
4243
import dev.cel.common.navigation.CelNavigableMutableExpr;
4344
import dev.cel.common.types.ListType;
4445
import dev.cel.common.types.SimpleType;
4546
import dev.cel.common.types.StructTypeReference;
4647
import dev.cel.expr.conformance.proto3.TestAllTypes;
4748
import dev.cel.extensions.CelExtensions;
49+
import dev.cel.optimizer.AstMutator;
50+
import dev.cel.optimizer.AstMutator.MangledComprehensionAst;
51+
import dev.cel.optimizer.AstMutator.MangledComprehensionName;
4852
import dev.cel.optimizer.CelOptimizationException;
4953
import dev.cel.optimizer.CelOptimizer;
5054
import dev.cel.optimizer.CelOptimizerFactory;
@@ -91,9 +95,11 @@ public class SubexpressionOptimizerTest {
9195
CelVarDecl.newVarDeclaration("index0", SimpleType.DYN),
9296
CelVarDecl.newVarDeclaration("index1", SimpleType.DYN),
9397
CelVarDecl.newVarDeclaration("index2", SimpleType.DYN),
98+
CelVarDecl.newVarDeclaration("it", SimpleType.DYN),
9499
CelVarDecl.newVarDeclaration("@index0", SimpleType.DYN),
95100
CelVarDecl.newVarDeclaration("@index1", SimpleType.DYN),
96-
CelVarDecl.newVarDeclaration("@index2", SimpleType.DYN))
101+
CelVarDecl.newVarDeclaration("@index2", SimpleType.DYN),
102+
CelVarDecl.newVarDeclaration("@it:0:0", SimpleType.DYN))
97103
.addMessageTypes(TestAllTypes.getDescriptor())
98104
.addVar("msg", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName()))
99105
.build();
@@ -285,6 +291,27 @@ public void iterationLimitReached_throws() throws Exception {
285291
assertThat(e).hasMessageThat().isEqualTo("Optimization failure: Max iteration count reached.");
286292
}
287293

294+
@Test
295+
public void foo() throws Exception {
296+
CelAbstractSyntaxTree ast = CEL.compile("[1].map(y, [1, 2].filter(x, x == y))").getAst();
297+
CelOptimizer optimizer =
298+
CelOptimizerFactory.standardCelOptimizerBuilder(CEL)
299+
.addAstOptimizers(
300+
SubexpressionOptimizer.newInstance(
301+
SubexpressionOptimizerOptions.newBuilder()
302+
.subexpressionMaxRecursionDepth(4)
303+
.populateMacroCalls(true).build()))
304+
.build();
305+
306+
CelAbstractSyntaxTree optimizedAst = optimizer.optimize(ast);
307+
308+
Object result = CEL.createProgram(ast).eval();
309+
System.out.println(result);
310+
311+
assertThat(CEL_UNPARSER.unparse(optimizedAst))
312+
.isEqualTo("foo");
313+
}
314+
288315
@Test
289316
public void celBlock_astExtensionTagged() throws Exception {
290317
CelAbstractSyntaxTree ast = CEL.compile("size(x) + size(x)").getAst();
@@ -478,9 +505,7 @@ public void lazyEval_nestedComprehension_indexReferencedInNestedScopes() throws
478505
// Equivalent of [true, false, true].map(c0, [c0].map(c1, [c0, c1, true]))
479506
CelAbstractSyntaxTree ast =
480507
compileUsingInternalFunctions(
481-
"cel.block([c0, c1, get_true()], [index2, false, index2].map(c0, [c0].map(c1, [index0,"
482-
+ " index1, index2]))) == [[[true, true, true]], [[false, false, true]], [[true,"
483-
+ " true, true]]]");
508+
"cel.block([true, false, get_true()], [index2, false, index2].map(c0, [c0].map(c1, [c0, c1, index2]))) == [[[true, true, true]], [[false, false, true]], [[true, true, true]]]");
484509

485510
boolean result = (boolean) celRuntime.createProgram(ast).eval();
486511

@@ -547,6 +572,18 @@ public void verifyOptimizedAstCorrectness_blockContainsNoIndexResult_throws() th
547572
.isEqualTo("Expected at least one reference of index in cel.block result");
548573
}
549574

575+
@Test
576+
public void verifyOptimizedAstCorrectness_containsForwardReferenceFromComprehensionVar_throws() throws Exception {
577+
CelAbstractSyntaxTree ast = compileUsingInternalFunctions("cel.block([it], [1].exists(it, it > 0 && index0 > 0))");
578+
579+
VerifyException e =
580+
assertThrows(
581+
VerifyException.class, () -> SubexpressionOptimizer.verifyOptimizedAstCorrectness(ast));
582+
assertThat(e)
583+
.hasMessageThat()
584+
.startsWith("Illegal declared reference to a comprehension variable found in block indices.");
585+
}
586+
550587
@Test
551588
@TestParameters("{source: 'cel.block([], index0)'}")
552589
@TestParameters("{source: 'cel.block([1, 2], index2)'}")
@@ -600,13 +637,37 @@ private static CelAbstractSyntaxTree compileUsingInternalFunctions(String expres
600637
.allNodes()
601638
.filter(node -> node.getKind().equals(Kind.IDENT))
602639
.map(CelNavigableMutableExpr::expr)
603-
.filter(expr -> expr.ident().name().startsWith("index"))
640+
.filter(expr ->
641+
expr.ident().name().startsWith("index")
642+
)
604643
.forEach(
605644
indexExpr -> {
606645
String internalIdentName = "@" + indexExpr.ident().name();
607646
indexExpr.ident().setName(internalIdentName);
608647
});
609648

649+
MangledComprehensionAst mangledComprehensionAst = AstMutator.newInstance(10000).mangleComprehensionIdentifierNames(
650+
mutableAst,
651+
SubexpressionOptimizer.MANGLED_COMPREHENSION_ITER_VAR_PREFIX,
652+
SubexpressionOptimizer.MANGLED_COMPREHENSION_ITER_VAR2_PREFIX,
653+
SubexpressionOptimizer.MANGLED_COMPREHENSION_ACCU_VAR_PREFIX
654+
);
655+
mutableAst = mangledComprehensionAst.mutableAst();
656+
657+
CelNavigableMutableAst.fromAst(mutableAst)
658+
.getRoot()
659+
.allNodes()
660+
.filter(node -> node.getKind().equals(Kind.IDENT))
661+
.map(CelNavigableMutableExpr::expr)
662+
.filter(expr ->
663+
expr.ident().name().equals("it")
664+
)
665+
.forEach(
666+
indexExpr -> {
667+
indexExpr.ident().setName("@it:0:0");
668+
});
669+
670+
610671
return CEL_FOR_EVALUATING_BLOCK.check(mutableAst.toParsedAst()).getAst();
611672
}
612673
}

runtime/src/main/java/dev/cel/runtime/RuntimeUnknownResolver.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,11 +116,11 @@ DefaultInterpreter.IntermediateResult resolveSimpleName(String name, Long exprId
116116
}
117117

118118
void cacheLazilyEvaluatedResult(String name, DefaultInterpreter.IntermediateResult result) {
119-
throw new IllegalStateException("Internal error: Lazy attributes can only be cached in ScopedResolver.");
119+
// throw new IllegalStateException("Internal error: Lazy attributes can only be cached in ScopedResolver.");
120120
}
121121

122122
void declareLazyAttribute(String attrName) {
123-
throw new IllegalStateException("Internal error: Lazy attributes can only be declared in ScopedResolver.");
123+
// throw new IllegalStateException("Internal error: Lazy attributes can only be declared in ScopedResolver.");
124124
}
125125

126126
/**

0 commit comments

Comments
 (0)