Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
* eager aggregation
Expand Down Expand Up @@ -505,6 +506,28 @@ public Plan visitLogicalAggregate(LogicalAggregate<? extends Plan> agg, PushDown

@Override
public Plan visitLogicalFilter(LogicalFilter<? extends Plan> filter, PushDownAggContext context) {
if (filter.child() instanceof LogicalRelation) {
return genAggregate(filter, context);
}
if (filter.getConjuncts().stream().anyMatch(Expression::containsUniqueFunction)) {
return genAggregate(filter, context);
}
List<SlotReference> filterInputSlots = filter.getInputSlots().stream()
.map(slot -> (SlotReference) slot)
.collect(Collectors.toList());
List<SlotReference> childGroupKeys = Stream.concat(
context.getGroupKeys().stream(),
filterInputSlots.stream())
.distinct()
.collect(Collectors.toList());
PushDownAggContext childContext = context.withGroupKeys(childGroupKeys);
if (!childContext.isValid()) {
return genAggregate(filter, context);
}
Plan newChild = filter.child().accept(this, childContext);
if (newChild != filter.child()) {
return filter.withChildren(newChild);
}
return genAggregate(filter, context);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,14 @@

package org.apache.doris.nereids.rules.rewrite.eageraggregation;

import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.utframe.TestWithFeService;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

class EagerAggRewriterTest extends TestWithFeService implements MemoPatternMatchSupported {
Expand Down Expand Up @@ -311,4 +315,74 @@ void testAsofJoinNotPushAgg() {
connectContext.getSessionVariable().setDisableJoinReorder(false);
}
}


@Test
void testUniqueFunctionFilterBlocksPushDownThroughFilter() {
connectContext.getSessionVariable().setEagerAggregationMode(1);
connectContext.getSessionVariable().setDisableJoinReorder(true);
try {
String sql = "select count(s.name1), t2.id2"
+ " from (select * from (select id1, name as name1 from t1) s1 where random() < 0.5) s"
+ " join t2 on s.id1 = t2.id2 group by t2.id2";
Plan plan = PlanChecker.from(connectContext)
.analyze(sql)
.rewrite()
.getPlan();
Assertions.assertEquals(2, countPlans(plan, LogicalAggregate.class), plan.treeString());
LogicalFilter<?> filter = findFirstPlan(plan, LogicalFilter.class);
Assertions.assertNotNull(filter, plan.treeString());
Assertions.assertFalse(containsPlan(filter.child(), LogicalAggregate.class), plan.treeString());
} finally {
connectContext.getSessionVariable().setEagerAggregationMode(0);
connectContext.getSessionVariable().setDisableJoinReorder(false);
}
}

@Test
void testInvalidFilterContextFallsBackToCurrentFilter() {
connectContext.getSessionVariable().setEagerAggregationMode(1);
connectContext.getSessionVariable().setDisableJoinReorder(true);
try {
String sql = "select count(s.name1), t2.id2"
+ " from (select * from (select id1, name as name1 from t1) s1 where s1.name1 is not null) s"
+ " join t2 on s.id1 = t2.id2 group by t2.id2";
Plan plan = PlanChecker.from(connectContext)
.analyze(sql)
.rewrite()
.getPlan();
Assertions.assertEquals(2, countPlans(plan, LogicalAggregate.class), plan.treeString());
LogicalFilter<?> filter = findFirstPlan(plan, LogicalFilter.class);
Assertions.assertNotNull(filter, plan.treeString());
Assertions.assertFalse(containsPlan(filter.child(), LogicalAggregate.class), plan.treeString());
} finally {
connectContext.getSessionVariable().setEagerAggregationMode(0);
connectContext.getSessionVariable().setDisableJoinReorder(false);
}
}

private int countPlans(Plan plan, Class<? extends Plan> clazz) {
int count = clazz.isInstance(plan) ? 1 : 0;
for (Plan child : plan.children()) {
count += countPlans(child, clazz);
}
return count;
}

private boolean containsPlan(Plan plan, Class<? extends Plan> clazz) {
return countPlans(plan, clazz) > 0;
}

private <T extends Plan> T findFirstPlan(Plan plan, Class<T> clazz) {
if (clazz.isInstance(plan)) {
return clazz.cast(plan);
}
for (Plan child : plan.children()) {
T matched = findFirstPlan(child, clazz);
if (matched != null) {
return matched;
}
}
return null;
}
}
149 changes: 111 additions & 38 deletions regression-test/data/nereids_p0/eager_agg/eager_agg.out
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ PhysicalResultSink
--------hashAgg[GLOBAL]
----------hashAgg[LOCAL]
------------hashJoin[INNER_JOIN] hashCondition=((ss.ss_item_sk = ws.ws_item_sk)) otherCondition=()
--------------PhysicalOlapScan[store_sales]
--------------PhysicalOlapScan[web_sales]
--------PhysicalOlapScan[date_dim]
--------------PhysicalOlapScan[store_sales(ss)]
--------------PhysicalOlapScan[web_sales(ws)]
--------PhysicalOlapScan[date_dim(dt)]

Hint log:
Used: leading({ ss broadcast ws } broadcast dt )
Expand All @@ -28,9 +28,9 @@ PhysicalResultSink
--------hashAgg[GLOBAL]
----------hashAgg[LOCAL]
------------hashJoin[INNER_JOIN] hashCondition=((ss.ss_item_sk = ws.ws_item_sk)) otherCondition=()
--------------PhysicalOlapScan[store_sales]
--------------PhysicalOlapScan[web_sales]
--------PhysicalOlapScan[date_dim]
--------------PhysicalOlapScan[store_sales(ss)]
--------------PhysicalOlapScan[web_sales(ws)]
--------PhysicalOlapScan[date_dim(dt)]

Hint log:
Used: leading({ ss broadcast ws } broadcast dt )
Expand All @@ -49,9 +49,9 @@ PhysicalResultSink
--------hashAgg[GLOBAL]
----------hashAgg[LOCAL]
------------hashJoin[INNER_JOIN] hashCondition=((ss.ss_item_sk = ws.ws_item_sk)) otherCondition=()
--------------PhysicalOlapScan[store_sales]
--------------PhysicalOlapScan[web_sales]
--------PhysicalOlapScan[date_dim]
--------------PhysicalOlapScan[store_sales(ss)]
--------------PhysicalOlapScan[web_sales(ws)]
--------PhysicalOlapScan[date_dim(dt)]

Hint log:
Used: leading({ ss broadcast ws } broadcast dt )
Expand All @@ -68,9 +68,9 @@ PhysicalResultSink
----hashAgg[LOCAL]
------hashJoin[INNER_JOIN] hashCondition=((dt.d_date_sk = ss.ss_sold_date_sk)) otherCondition=()
--------hashJoin[INNER_JOIN] hashCondition=((ss.ss_item_sk = ws.ws_item_sk)) otherCondition=()
----------PhysicalOlapScan[store_sales]
----------PhysicalOlapScan[web_sales]
--------PhysicalOlapScan[date_dim]
----------PhysicalOlapScan[store_sales(ss)]
----------PhysicalOlapScan[web_sales(ws)]
--------PhysicalOlapScan[date_dim(dt)]

Hint log:
Used: leading({ ss broadcast ws } broadcast dt )
Expand All @@ -87,11 +87,11 @@ PhysicalResultSink
----hashAgg[LOCAL]
------hashJoin[INNER_JOIN] hashCondition=((dt.d_date_sk = ss.ss_sold_date_sk)) otherCondition=()
--------hashJoin[INNER_JOIN] hashCondition=((ss.ss_item_sk = ws.ws_item_sk)) otherCondition=()
----------PhysicalOlapScan[store_sales]
----------PhysicalOlapScan[store_sales(ss)]
----------hashAgg[GLOBAL]
------------hashAgg[LOCAL]
--------------PhysicalOlapScan[web_sales]
--------PhysicalOlapScan[date_dim]
--------------PhysicalOlapScan[web_sales(ws)]
--------PhysicalOlapScan[date_dim(dt)]

Hint log:
Used: leading({ ss broadcast ws } broadcast dt )
Expand All @@ -110,9 +110,9 @@ PhysicalResultSink
--------hashJoin[INNER_JOIN] hashCondition=((ss.ss_item_sk = ws.ws_item_sk)) otherCondition=()
----------hashAgg[GLOBAL]
------------hashAgg[LOCAL]
--------------PhysicalOlapScan[store_sales]
----------PhysicalOlapScan[web_sales]
--------PhysicalOlapScan[date_dim]
--------------PhysicalOlapScan[store_sales(ss)]
----------PhysicalOlapScan[web_sales(ws)]
--------PhysicalOlapScan[date_dim(dt)]

Hint log:
Used: leading({ ss broadcast ws } broadcast dt )
Expand All @@ -131,9 +131,9 @@ PhysicalResultSink
--------hashAgg[GLOBAL]
----------hashAgg[LOCAL]
------------hashJoin[INNER_JOIN] hashCondition=((ss.ss_item_sk = ws.ws_item_sk)) otherCondition=()
--------------PhysicalOlapScan[store_sales]
--------------PhysicalOlapScan[web_sales]
--------PhysicalOlapScan[date_dim]
--------------PhysicalOlapScan[store_sales(ss)]
--------------PhysicalOlapScan[web_sales(ws)]
--------PhysicalOlapScan[date_dim(dt)]

Hint log:
Used: leading({ ss broadcast ws } broadcast dt )
Expand All @@ -152,9 +152,9 @@ PhysicalResultSink
--------hashJoin[INNER_JOIN] hashCondition=((ss.ss_item_sk = ws.ws_item_sk)) otherCondition=()
----------hashAgg[GLOBAL]
------------hashAgg[LOCAL]
--------------PhysicalOlapScan[store_sales]
----------PhysicalOlapScan[web_sales]
--------PhysicalOlapScan[date_dim]
--------------PhysicalOlapScan[store_sales(ss)]
----------PhysicalOlapScan[web_sales(ws)]
--------PhysicalOlapScan[date_dim(dt)]

Hint log:
Used: leading({ ss broadcast ws } broadcast dt )
Expand All @@ -173,9 +173,9 @@ PhysicalResultSink
--------hashAgg[GLOBAL]
----------hashAgg[LOCAL]
------------hashJoin[INNER_JOIN] hashCondition=((dt.d_date_sk = ss.ss_sold_date_sk)) otherCondition=()
--------------PhysicalOlapScan[store_sales]
--------------PhysicalOlapScan[date_dim]
--------PhysicalOlapScan[web_sales]
--------------PhysicalOlapScan[store_sales(ss)]
--------------PhysicalOlapScan[date_dim(dt)]
--------PhysicalOlapScan[web_sales(ws)]

Hint log:
Used: leading({ ss broadcast dt } broadcast ws )
Expand All @@ -197,9 +197,9 @@ PhysicalResultSink
--------hashJoin[INNER_JOIN] hashCondition=((dt.d_date_sk = ss.ss_sold_date_sk)) otherCondition=()
----------hashAgg[GLOBAL]
------------hashAgg[LOCAL]
--------------PhysicalOlapScan[store_sales]
----------PhysicalOlapScan[date_dim]
--------PhysicalOlapScan[web_sales]
--------------PhysicalOlapScan[store_sales(ss)]
----------PhysicalOlapScan[date_dim(dt)]
--------PhysicalOlapScan[web_sales(ws)]

Hint log:
Used: leading({ ss broadcast dt } broadcast ws )
Expand Down Expand Up @@ -266,11 +266,11 @@ PhysicalResultSink
------hashAgg[LOCAL]
--------hashJoin[INNER_JOIN] hashCondition=((ss.ss_item_sk = ws.ws_item_sk)) otherCondition=()
----------hashJoin[INNER_JOIN] hashCondition=((dt.d_date_sk = ss.ss_sold_date_sk)) otherCondition=()
------------PhysicalOlapScan[store_sales]
------------PhysicalOlapScan[store_sales(ss)]
------------hashAgg[GLOBAL]
--------------hashAgg[LOCAL]
----------------PhysicalOlapScan[date_dim]
----------PhysicalOlapScan[web_sales]
----------------PhysicalOlapScan[date_dim(dt)]
----------PhysicalOlapScan[web_sales(ws)]

Hint log:
Used: leading({ ss broadcast dt } broadcast ws )
Expand All @@ -287,9 +287,9 @@ PhysicalResultSink
----hashAgg[LOCAL]
------hashJoin[INNER_JOIN] hashCondition=((ss.ss_item_sk = ws.ws_item_sk)) otherCondition=()
--------hashJoin[INNER_JOIN] hashCondition=((dt.d_date_sk = ss.ss_sold_date_sk)) otherCondition=()
----------PhysicalOlapScan[store_sales]
----------PhysicalOlapScan[date_dim]
--------PhysicalOlapScan[web_sales]
----------PhysicalOlapScan[store_sales(ss)]
----------PhysicalOlapScan[date_dim(dt)]
--------PhysicalOlapScan[web_sales(ws)]

Hint log:
Used: leading({ ss broadcast dt } broadcast ws )
Expand All @@ -302,12 +302,85 @@ PhysicalResultSink
----hashAgg[LOCAL]
------PhysicalUnion
--------hashJoin[INNER_JOIN] hashCondition=((dt.d_date_sk = ss.ss_sold_date_sk)) otherCondition=()
----------PhysicalOlapScan[store_sales]
----------PhysicalOlapScan[date_dim]
----------PhysicalOlapScan[store_sales(ss)]
----------PhysicalOlapScan[date_dim(dt)]
--------PhysicalOlapScan[date_dim]

Hint log:
Used:
UnUsed:
SyntaxError: leading({ ss broadcast dt } broadcast ws) Msg:can not find table: ws

-- !check_sum_literal_right_join_not_push --
PhysicalResultSink
--hashAgg[GLOBAL]
----hashAgg[LOCAL]
------hashJoin[RIGHT_OUTER_JOIN] hashCondition=((a.val = c.val) and (b.id2 = c.id2)) otherCondition=()
--------hashJoin[RIGHT_OUTER_JOIN] hashCondition=((a.id = b.id)) otherCondition=()
----------PhysicalOlapScan[eager_agg_t1(a)]
----------PhysicalOlapScan[eager_agg_t2(b)]
--------PhysicalOlapScan[eager_agg_t3(c)]

-- !check_sum_literal_left_join_not_push --
PhysicalResultSink
--hashAgg[GLOBAL]
----hashAgg[LOCAL]
------hashJoin[LEFT_OUTER_JOIN] hashCondition=((date_dim.d_date_sk = store_sales.ss_sold_date_sk)) otherCondition=()
--------hashAgg[GLOBAL]
----------hashAgg[LOCAL]
------------PhysicalOlapScan[store_sales]
--------PhysicalOlapScan[date_dim]

-- !check_min_literal_right_join_not_push --
PhysicalResultSink
--hashAgg[GLOBAL]
----hashAgg[LOCAL]
------hashJoin[RIGHT_OUTER_JOIN] hashCondition=((a.val = c.val) and (b.id2 = c.id2)) otherCondition=()
--------hashJoin[RIGHT_OUTER_JOIN] hashCondition=((a.id = b.id)) otherCondition=()
----------PhysicalOlapScan[eager_agg_t1(a)]
----------PhysicalOlapScan[eager_agg_t2(b)]
--------PhysicalOlapScan[eager_agg_t3(c)]

-- !check_max_literal_left_join_not_push --
PhysicalResultSink
--hashAgg[GLOBAL]
----hashAgg[LOCAL]
------hashJoin[LEFT_OUTER_JOIN] hashCondition=((date_dim.d_date_sk = store_sales.ss_sold_date_sk)) otherCondition=()
--------hashAgg[GLOBAL]
----------hashAgg[LOCAL]
------------PhysicalOlapScan[store_sales]
--------PhysicalOlapScan[date_dim]

-- !sum_literal_right_join_eager_off --
\N 4
10 2

-- !sum_literal_right_join_eager_on --
\N 4
10 2

-- !min_literal_right_join_eager_on --
\N 1
10 1

-- !max_literal_right_join_eager_on --
\N 3
10 3

-- !check_filter_slots_preserved_pushdown --
PhysicalResultSink
--hashAgg[GLOBAL]
----filter(OR[( not (id = 1)),id IS NULL])
------hashJoin[LEFT_OUTER_JOIN] hashCondition=((a.id = b.id)) otherCondition=()
--------hashAgg[GLOBAL]
----------PhysicalOlapScan[eager_agg_filter_t1(a)]
--------PhysicalOlapScan[eager_agg_filter_t2(b)]

Hint log:
Used: [broadcast]_1
UnUsed:
SyntaxError:

-- !filter_slots_preserved_eager_on --
2 20

Loading
Loading