1616
1717import static com .google .common .base .Preconditions .checkNotNull ;
1818import static com .google .common .collect .ImmutableList .toImmutableList ;
19+ import static com .google .common .collect .ImmutableSet .toImmutableSet ;
1920import static java .util .stream .Collectors .toCollection ;
2021
2122import com .google .auto .value .AutoValue ;
2223import com .google .common .annotations .VisibleForTesting ;
2324import com .google .common .base .Preconditions ;
25+ import com .google .common .base .Strings ;
2426import com .google .common .base .Verify ;
2527import com .google .common .collect .ImmutableList ;
2628import com .google .common .collect .ImmutableSet ;
4143import dev .cel .common .CelVarDecl ;
4244import dev .cel .common .ast .CelExpr ;
4345import dev .cel .common .ast .CelExpr .CelCall ;
46+ import dev .cel .common .ast .CelExpr .CelComprehension ;
47+ import dev .cel .common .ast .CelExpr .CelList ;
4448import dev .cel .common .ast .CelExpr .ExprKind .Kind ;
4549import dev .cel .common .ast .CelMutableExpr ;
50+ import dev .cel .common .ast .CelMutableExpr .CelMutableComprehension ;
4651import dev .cel .common .ast .CelMutableExprConverter ;
4752import dev .cel .common .navigation .CelNavigableExpr ;
4853import dev .cel .common .navigation .CelNavigableMutableAst ;
5964import java .util .Comparator ;
6065import java .util .HashSet ;
6166import java .util .List ;
67+ import java .util .Objects ;
6268import java .util .Set ;
69+ import java .util .stream .Stream ;
6370
6471/**
6572 * Performs Common Subexpression Elimination.
@@ -90,14 +97,18 @@ public class SubexpressionOptimizer implements CelAstOptimizer {
9097 private static final SubexpressionOptimizer INSTANCE =
9198 new SubexpressionOptimizer (SubexpressionOptimizerOptions .newBuilder ().build ());
9299 private static final String BIND_IDENTIFIER_PREFIX = "@r" ;
93- private static final String MANGLED_COMPREHENSION_ITER_VAR_PREFIX = "@it" ;
94- private static final String MANGLED_COMPREHENSION_ITER_VAR2_PREFIX = "@it2" ;
95- private static final String MANGLED_COMPREHENSION_ACCU_VAR_PREFIX = "@ac" ;
96100 private static final String CEL_BLOCK_FUNCTION = "cel.@block" ;
97101 private static final String BLOCK_INDEX_PREFIX = "@index" ;
98102 private static final Extension CEL_BLOCK_AST_EXTENSION_TAG =
99103 Extension .create ("cel_block" , Version .of (1L , 1L ), Component .COMPONENT_RUNTIME );
100104
105+ @ VisibleForTesting
106+ static final String MANGLED_COMPREHENSION_ITER_VAR_PREFIX = "@it" ;
107+ @ VisibleForTesting
108+ static final String MANGLED_COMPREHENSION_ITER_VAR2_PREFIX = "@it2" ;
109+ @ VisibleForTesting
110+ static final String MANGLED_COMPREHENSION_ACCU_VAR_PREFIX = "@ac" ;
111+
101112 private final SubexpressionOptimizerOptions cseOptions ;
102113 private final AstMutator astMutator ;
103114 private final ImmutableSet <String > cseEliminableFunctions ;
@@ -269,6 +280,8 @@ static void verifyOptimizedAstCorrectness(CelAbstractSyntaxTree ast) {
269280 Verify .verify (
270281 resultHasAtLeastOneBlockIndex ,
271282 "Expected at least one reference of index in cel.block result" );
283+
284+ verifyNoInvalidScopedMangledVariables (celBlockExpr );
272285 }
273286
274287 private static void verifyBlockIndex (CelExpr celExpr , int maxIndexValue ) {
@@ -289,6 +302,69 @@ private static void verifyBlockIndex(CelExpr celExpr, int maxIndexValue) {
289302 celExpr );
290303 }
291304
305+ private static void verifyNoInvalidScopedMangledVariables (CelExpr celExpr ) {
306+ CelCall celBlockCall = celExpr .call ();
307+ CelExpr blockBody = celBlockCall .args ().get (1 );
308+
309+ ImmutableSet <String > allMangledVariablesInBlockBody =
310+ CelNavigableExpr .fromExpr (blockBody )
311+ .allNodes ()
312+ .map (CelNavigableExpr ::expr )
313+ .flatMap (SubexpressionOptimizer ::extractMangledNames )
314+ .collect (toImmutableSet ());
315+
316+ CelList blockIndices = celBlockCall .args ().get (0 ).list ();
317+ for (CelExpr blockIndex : blockIndices .elements ()) {
318+ ImmutableSet <String > indexDeclaredCompVariables =
319+ CelNavigableExpr .fromExpr (blockIndex )
320+ .allNodes ()
321+ .map (CelNavigableExpr ::expr )
322+ .filter (expr -> expr .getKind () == Kind .COMPREHENSION )
323+ .map (CelExpr ::comprehension )
324+ .flatMap (comp -> Stream .of (
325+ comp .iterVar (),
326+ comp .iterVar2 ()
327+ ))
328+ .filter (iter -> !Strings .isNullOrEmpty (iter ))
329+ .collect (toImmutableSet ());
330+
331+ boolean containsIllegalDeclaration =
332+ CelNavigableExpr .fromExpr (blockIndex )
333+ .allNodes ()
334+ .map (CelNavigableExpr ::expr )
335+ .filter (expr -> expr .getKind () == Kind .IDENT )
336+ .map (expr -> expr .ident ().name ())
337+ .filter (SubexpressionOptimizer ::isMangled )
338+ .anyMatch (ident ->
339+ !indexDeclaredCompVariables .contains (ident ) &&
340+ allMangledVariablesInBlockBody .contains (ident ));
341+
342+ Verify .verify (
343+ !containsIllegalDeclaration ,
344+ "Illegal declared reference to a comprehension variable found in block indices. Expr: %s" ,
345+ celExpr );
346+ }
347+ }
348+
349+ private static Stream <String > extractMangledNames (CelExpr expr ) {
350+ if (expr .getKind () == Kind .IDENT ) {
351+ String name = expr .ident ().name ();
352+ return isMangled (name ) ? Stream .of (name ) : Stream .empty ();
353+ }
354+ if (expr .getKind () == Kind .COMPREHENSION ) {
355+ CelComprehension comp = expr .comprehension ();
356+ return Stream .of (comp .iterVar (), comp .iterVar2 (), comp .accuVar ())
357+ .filter (Objects ::nonNull ) // Handle potential null/empty iterVar2
358+ .filter (SubexpressionOptimizer ::isMangled );
359+ }
360+ return Stream .empty ();
361+ }
362+
363+ private static boolean isMangled (String name ) {
364+ return name .startsWith (MANGLED_COMPREHENSION_ITER_VAR_PREFIX )
365+ || name .startsWith (MANGLED_COMPREHENSION_ITER_VAR2_PREFIX );
366+ }
367+
292368 private static CelAbstractSyntaxTree tagAstExtension (CelAbstractSyntaxTree ast ) {
293369 // Tag the extension
294370 CelSource .Builder celSourceBuilder =
@@ -355,8 +431,8 @@ private List<CelMutableExpr> getCseCandidatesWithRecursionDepth(
355431 navAst
356432 .getRoot ()
357433 .descendants (TraversalOrder .PRE_ORDER )
358- .filter (node -> canEliminate (node , ineligibleExprs ))
359434 .filter (node -> node .height () <= recursionLimit )
435+ .filter (node -> canEliminate (node , ineligibleExprs ))
360436 .sorted (Comparator .comparingInt (CelNavigableMutableExpr ::height ).reversed ())
361437 .collect (toImmutableList ());
362438 if (descendants .isEmpty ()) {
@@ -441,9 +517,44 @@ private boolean canEliminate(
441517 && navigableExpr .expr ().list ().elements ().isEmpty ())
442518 && containsEliminableFunctionOnly (navigableExpr )
443519 && !ineligibleExprs .contains (navigableExpr .expr ())
444- && containsComprehensionIdentInSubexpr (navigableExpr );
520+ && containsComprehensionIdentInSubexpr (navigableExpr )
521+ && containsProperScopedComprehensionIdents (navigableExpr );
522+ }
523+
524+ private boolean containsProperScopedComprehensionIdents (CelNavigableMutableExpr navExpr ) {
525+ if (!navExpr .getKind ().equals (Kind .COMPREHENSION )) {
526+ return true ;
527+ }
528+
529+ // For nested comprehensions of form [1].exists(x, [2].exists(y, x == y)), the inner comprehension [2].exists(y, x == y)
530+ // should not be extracted out into a block index, as it causes issues with scoping.
531+ ImmutableSet <String > mangledIterVars = navExpr .descendants ()
532+ .filter (x -> x .getKind ().equals (Kind .IDENT ))
533+ .map (x -> x .expr ().ident ().name ())
534+ .filter (name ->
535+ name .startsWith (MANGLED_COMPREHENSION_ITER_VAR_PREFIX ) ||
536+ name .startsWith (MANGLED_COMPREHENSION_ITER_VAR2_PREFIX )
537+ ).collect (toImmutableSet ());
538+
539+ CelNavigableMutableExpr parent = navExpr .parent ().orElse (null );
540+ while (parent != null ) {
541+ if (parent .getKind ().equals (Kind .COMPREHENSION )) {
542+ CelMutableComprehension comp = parent .expr ().comprehension ();
543+ boolean containsParentIterReferences =
544+ mangledIterVars .contains (comp .iterVar ()) || mangledIterVars .contains (comp .iterVar2 ());
545+
546+ if (containsParentIterReferences ) {
547+ return false ;
548+ }
549+ }
550+
551+ parent = parent .parent ().orElse (null );
552+ }
553+
554+ return true ;
445555 }
446556
557+
447558 private boolean containsComprehensionIdentInSubexpr (CelNavigableMutableExpr navExpr ) {
448559 if (navExpr .getKind ().equals (Kind .COMPREHENSION )) {
449560 return true ;
0 commit comments