Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ private Count(boolean isStar, AggregateFunctionParams functionParams) {
public boolean isCountStar() {
return isStar
|| children.isEmpty()
|| (children.size() == 1 && child(0) instanceof Literal);
|| (children.size() == 1 && child(0) instanceof Literal && !child(0).isNullLiteral());
Comment thread
morrySnow marked this conversation as resolved.
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,14 @@
package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.utframe.TestWithFeService;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

/**
Expand Down Expand Up @@ -71,4 +75,32 @@ void testCountNull() {
.matches(logicalAggregate().when(agg -> agg.getExpressions().stream().noneMatch(Count.class::isInstance)))
.printlnTree();
}

@Test
void testCountNullIsNotCountStar() {
// count(null) should NOT be treated as count(*)
Count countNull = new Count(NullLiteral.INSTANCE);
Assertions.assertFalse(countNull.isCountStar(),
"count(null) should not be count star, because count(null) is always 0");

// typed null literal: count(CAST(null AS BIGINT)) should NOT be count(*)
Count countTypedNull = new Count(new NullLiteral(BigIntType.INSTANCE));
Assertions.assertFalse(countTypedNull.isCountStar(),
"count(typed null) should not be count star");

// count(distinct null) should NOT be treated as count(*)
Count countDistinctNull = new Count(true, NullLiteral.INSTANCE);
Assertions.assertFalse(countDistinctNull.isCountStar(),
"count(distinct null) should not be count star");

// count(1) should be treated as count(*)
Count countOne = new Count(new BigIntLiteral(1));
Assertions.assertTrue(countOne.isCountStar(),
"count(1) should be count star");

// count(*) should be treated as count(*)
Count countStar = new Count();
Assertions.assertTrue(countStar.isCountStar(),
"count(*) should be count star");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,36 @@ void testNotPushCountBecauseOtherAggFunc() {
);
}

@Test
void testCountNullPushedAsCountNullNotCountStar() {
// count(null) in a union-all query should be pushed down as count(null), NOT as count(*).
// CountLiteralRewrite skips count(null) when it is the sole aggregate with no GROUP BY
// (it would produce an empty aggregate, which is invalid), so PushCountIntoUnionAll
// does see count(null). With the fix, isCountStar() returns false for count(null),
// so it is pushed as count(null) into each branch — not misclassified as count(*).
String sql = "select count(null) from (select id,a from t1 union all select id,a from t1 where id>10) t;";
PlanChecker checker = PlanChecker.from(connectContext).analyze(sql).rewrite();

// PushCountIntoUnionAll fires: upper agg uses Sum0 over pushed-down counts
checker.matches(
logicalAggregate(
logicalUnion(logicalAggregate(), logicalAggregate())
).when(agg -> ExpressionUtils.containsTypes(agg.getOutputExpressions(), Sum0.class))
);

// count(null) was pushed as count(null), not promoted to count(*)
checker.nonMatch(
logicalAggregate(
logicalUnion(
logicalAggregate().when(agg -> agg.getOutputExpressions().stream()
.anyMatch(e -> e.anyMatch(
expr -> expr instanceof Count && ((Count) expr).isCountStar()))),
logicalAggregate()
)
)
);
}

@Test
void testNotPushCountBecauseUnion() {
String sql = "select count(1), sum(id) from (select id,a from t1 union select id,a from t1 where id>10) t;";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.apache.doris.nereids.trees.expressions.Multiply;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
Expand Down Expand Up @@ -165,6 +166,31 @@ void testSingleCountStarEmptyGroupBy() {
);
}

@Test
void testCountNullNotPushedDown() {
// count(null) should NOT be treated as count(*) and should NOT be pushed down.
// NullLiteral is neither isCountStar() nor instanceof Slot,
// so the rule's predicate rejects the aggregate.
Alias countNull = new Count(NullLiteral.INSTANCE).alias("countNull");
LogicalPlan plan = new LogicalPlanBuilder(scan1)
.join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
.aggGroupUsingIndex(ImmutableList.of(0),
ImmutableList.of(scan1.getOutput().get(0), countNull))
.build();

// Should NOT rewrite — aggregate stays above the original join.
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
.applyTopDown(new PushDownAggThroughJoin())
.matches(
logicalAggregate(
logicalJoin(
logicalOlapScan(),
logicalOlapScan()
)
)
);
}

@Test
void testBothSideCountAndCountStar() {
Alias leftCnt = new Count(scan1.getOutput().get(0)).alias("leftCnt");
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !count_null --
0

-- !count_star --
4

-- !count_literal --
4

-- !count_null_group_by --
1 0
2 0
3 0
4 0

-- !count_mixed --
0 4 2 2

-- !count_null_union --
0

-- !count_null_window --
1 0
2 0
3 0
4 0

-- !count_null_join_grouped --
1 0 1
2 0 1

-- !count_star_join --
2 0

-- !count_null_index --
0 2

Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

suite("count_null_not_count_star") {
sql "SET enable_nereids_planner=true"
sql "SET enable_fallback_to_original_planner=false"

sql "DROP TABLE IF EXISTS count_null_test"

sql """
create table count_null_test(pk int, a int, b int) distributed by hash(pk) buckets 10
properties('replication_num' = '1');
"""

sql """
insert into count_null_test values(1, 1, 1), (2, null, 2), (3, 3, null), (4, null, null);
"""
sql "sync"

// count(null) should always return 0 regardless of row count
order_qt_count_null """select count(null) from count_null_test"""

// count(*) should return total number of rows
order_qt_count_star """select count(*) from count_null_test"""

// count(1) should be equivalent to count(*)
order_qt_count_literal """select count(1) from count_null_test"""

// count(null) with group by should return 0 for each group
order_qt_count_null_group_by """select pk, count(null) from count_null_test group by pk order by pk"""

// mixed: count(null) vs count(*) vs count(column)
order_qt_count_mixed """select count(null), count(*), count(a), count(b) from count_null_test"""

// count(null) in subquery with union all
order_qt_count_null_union """
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

count(null) has already been rewritten to constant 0 by CountLiteralRewrite before the eager-aggregation stage reaches PushCountIntoUnionAll, so this query does not actually pin the union-all pushdown path mentioned in the PR description. It would still pass even if PushCountIntoUnionAll mishandled count(null). Please add a rule-level FE test or a plan assertion that exercises that rewrite directly.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed. Added FE unit test testCountNullPushedAsCountNullNotCountStar in PushCountIntoUnionAllTest.java.

Note: CountLiteralRewrite does not rewrite count(null) to 0 in the no-GROUP-BY case. At line 66 of CountLiteralRewrite.java, it guards with aggFuncs.isEmpty() — when count(null) is the only aggregate and there are no GROUP BY keys, the resulting aggregate would have zero output expressions, so the rule returns null (no-op). PushCountIntoUnionAll therefore does see count(null) in this case.

The new FE unit test directly verifies: (1) PushCountIntoUnionAll fires (Sum0 appears in upper agg), and (2) child agg expressions do NOT contain a Count with isCountStar=true.

select count(null) from (
select a from count_null_test
union all
select a from count_null_test
) t
"""

// count(null) as window function should still return 0
order_qt_count_null_window """
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This window case does not reach SimplifyWindowExpression. That rule only rewrites when the PARTITION BY keys are unique, and an empty partition-key set is not treated as unique. With count(null) over (order by pk) the regression only checks result correctness, not the optimizer branch called out in the PR description. Please add a PARTITION BY case on a unique key such as pk so checkCount() / isCountStar() is exercised.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed. Added a UNIQUE KEY table count_null_uniq_test with PARTITION BY pk so isUnique({pk}) returns true and SimplifyWindowExpression does fire. The test asserts count(null) over (partition by pk) = 0 for each row with explicit assertEquals(0L, row[1]) assertions. With the fix, isCountStar() returns false for count(null)checkCount() returns false → window is NOT simplified → correctly returns 0.

select pk, count(null) over (order by pk) from count_null_test order by pk
"""

// count(null) over (partition by unique pk) exercises SimplifyWindowExpression.
// That rule only fires when partition keys are proven unique; with unique pk, it calls
// checkCount() → isCountStar(). Bug: count(null) was treated as count(*) → simplified to 1.
// Fix: isCountStar() returns false → count(null) is NOT simplified → correctly returns 0.
sql "DROP TABLE IF EXISTS count_null_uniq_test"
sql """
create table count_null_uniq_test(pk int, a int)
unique key(pk)
distributed by hash(pk) buckets 10
properties('replication_num' = '1');
"""
sql """insert into count_null_uniq_test values(1, 1), (2, 2), (3, 3);"""
sql "sync"

def windowUniqResult = sql """
select pk, count(null) over (partition by pk order by pk)
from count_null_uniq_test order by pk
"""
assertEquals(3, windowUniqResult.size())
windowUniqResult.each { row -> assertEquals(0L, row[1]) }

// count(null) through join with GROUP BY: exercises PushDownAggThroughJoin path.
// With the bug, count(null) would be treated as count(*) and pushed down through join,
// producing wrong non-zero results. With the fix, it correctly returns 0.
sql "DROP TABLE IF EXISTS count_null_test2"
sql """
create table count_null_test2(pk int, c int) distributed by hash(pk) buckets 10
properties('replication_num' = '1');
"""
sql """insert into count_null_test2 values(1, 10), (2, 20);"""
sql "sync"

order_qt_count_null_join_grouped """
select t1.pk, count(null), count(*)
from count_null_test t1 inner join count_null_test2 t2 on t1.pk = t2.pk
group by t1.pk order by t1.pk
"""

// count(null) vs count(*) through join without GROUP BY for comparison
order_qt_count_star_join """
select count(*), count(null) from count_null_test t1 inner join count_null_test2 t2 on t1.pk = t2.pk
"""

// COUNT_ON_INDEX path: count(null) on a table with inverted index should NOT use COUNT_ON_MATCH
sql "DROP TABLE IF EXISTS count_null_idx_test"
sql """
create table count_null_idx_test(
pk int,
content varchar(200),
INDEX idx_content (content) USING INVERTED PROPERTIES("parser" = "english")
) duplicate key(pk)
distributed by hash(pk) buckets 1
properties('replication_num' = '1');
"""
sql """insert into count_null_idx_test values(1, 'hello world'), (2, 'doris test'), (3, 'hello doris');"""
sql "sync"

// count(*) with MATCH predicate should use COUNT_ON_INDEX optimization (via COUNT_ON_MATCH path).
// count(null) with same predicate must NOT — the fix ensures isCountStar() is false for count(null).
// Note: pushAggOp=COUNT_ON_INDEX appears in the explain output even for the COUNT_ON_MATCH code path
// because COUNT_ON_MATCH maps to TPushAggOp.COUNT_ON_INDEX in physical plan translation.
explain {
sql("select count(*) from count_null_idx_test where content match 'hello'")
contains "pushAggOp=COUNT_ON_INDEX"
}
explain {
sql("select count(null) from count_null_idx_test where content match 'hello'")
notContains "pushAggOp=COUNT_ON_INDEX"
}

// result correctness: count(null) should be 0 even with MATCH filter
order_qt_count_null_index """
select count(null), count(*) from count_null_idx_test where content match 'hello'
"""
}
Loading