diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Count.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Count.java index 175c25c7ce4394..8f486bfc2ef2e4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Count.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Count.java @@ -77,7 +77,7 @@ private Count(boolean isStar, AggregateFunctionParams functionParams) { public boolean isCountStar() { return isStar || children.isEmpty() - || (children.size() == 1 && child(0) instanceof Literal); + || (children.size() == 1 && child(0) instanceof Literal && !child(0).isNullLiteral()); } @Override diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/CountLiteralRewriteTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/CountLiteralRewriteTest.java index 89da282b0c97a1..840b7fc0fea708 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/CountLiteralRewriteTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/CountLiteralRewriteTest.java @@ -18,10 +18,14 @@ package org.apache.doris.nereids.rules.rewrite; import org.apache.doris.nereids.trees.expressions.functions.agg.Count; +import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral; +import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; +import org.apache.doris.nereids.types.BigIntType; import org.apache.doris.nereids.util.MemoPatternMatchSupported; import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.utframe.TestWithFeService; +import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; /** @@ -71,4 +75,32 @@ void testCountNull() { .matches(logicalAggregate().when(agg -> agg.getExpressions().stream().noneMatch(Count.class::isInstance))) .printlnTree(); } + + @Test + void testCountNullIsNotCountStar() 
{ + // count(null) should NOT be treated as count(*) + Count countNull = new Count(NullLiteral.INSTANCE); + Assertions.assertFalse(countNull.isCountStar(), + "count(null) should not be count star, because count(null) is always 0"); + + // typed null literal: a NullLiteral carrying BIGINT type (not a Cast) should NOT be count(*) + Count countTypedNull = new Count(new NullLiteral(BigIntType.INSTANCE)); + Assertions.assertFalse(countTypedNull.isCountStar(), + "count(typed null) should not be count star"); + + // count(distinct null) should NOT be treated as count(*) + Count countDistinctNull = new Count(true, NullLiteral.INSTANCE); + Assertions.assertFalse(countDistinctNull.isCountStar(), + "count(distinct null) should not be count star"); + + // count(1) should be treated as count(*) + Count countOne = new Count(new BigIntLiteral(1)); + Assertions.assertTrue(countOne.isCountStar(), + "count(1) should be count star"); + + // the no-argument Count(), i.e. count(*), must still be reported as count star + Count countStar = new Count(); + Assertions.assertTrue(countStar.isCountStar(), + "count(*) should be count star"); + } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushCountIntoUnionAllTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushCountIntoUnionAllTest.java index 396cc155aceafa..631ab98209febc 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushCountIntoUnionAllTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushCountIntoUnionAllTest.java @@ -145,6 +145,36 @@ void testNotPushCountBecauseOtherAggFunc() { ); } + @Test + void testCountNullPushedAsCountNullNotCountStar() { + // count(null) in a union-all query should be pushed down as count(null), NOT as count(*). + // CountLiteralRewrite skips count(null) when it is the sole aggregate with no GROUP BY + // (it would produce an empty aggregate, which is invalid), so PushCountIntoUnionAll + // does see count(null). 
With the fix, isCountStar() returns false for count(null), + // so it is pushed as count(null) into each branch — not misclassified as count(*). + String sql = "select count(null) from (select id,a from t1 union all select id,a from t1 where id>10) t;"; + PlanChecker checker = PlanChecker.from(connectContext).analyze(sql).rewrite(); + + // PushCountIntoUnionAll fires: upper agg uses Sum0 over pushed-down counts + checker.matches( + logicalAggregate( + logicalUnion(logicalAggregate(), logicalAggregate()) + ).when(agg -> ExpressionUtils.containsTypes(agg.getOutputExpressions(), Sum0.class)) + ); + + // count(null) was pushed as count(null), not promoted to count(*) + checker.nonMatch( + logicalAggregate( + logicalUnion( + logicalAggregate().when(agg -> agg.getOutputExpressions().stream() + .anyMatch(e -> e.anyMatch( + expr -> expr instanceof Count && ((Count) expr).isCountStar()))), + logicalAggregate() + ) + ) + ); + } + @Test void testNotPushCountBecauseUnion() { String sql = "select count(1), sum(id) from (select id,a from t1 union select id,a from t1 where id>10) t;"; diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownCountThroughJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownCountThroughJoinTest.java index 05ea0c713e27ac..6f86c63aeffaba 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownCountThroughJoinTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownCountThroughJoinTest.java @@ -23,6 +23,7 @@ import org.apache.doris.nereids.trees.expressions.Multiply; import org.apache.doris.nereids.trees.expressions.functions.agg.Count; import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; +import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; import 
org.apache.doris.nereids.trees.plans.logical.LogicalPlan; @@ -165,6 +166,31 @@ void testSingleCountStarEmptyGroupBy() { ); } + @Test + void testCountNullNotPushedDown() { + // count(null) should NOT be treated as count(*) and should NOT be pushed down. + // Its argument is a NullLiteral, so the Count is neither isCountStar() nor a count over a Slot, + // and the rule's predicate rejects the aggregate. + Alias countNull = new Count(NullLiteral.INSTANCE).alias("countNull"); + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) + .aggGroupUsingIndex(ImmutableList.of(0), + ImmutableList.of(scan1.getOutput().get(0), countNull)) + .build(); + + // Should NOT rewrite — aggregate stays above the original join. + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyTopDown(new PushDownAggThroughJoin()) + .matches( + logicalAggregate( + logicalJoin( + logicalOlapScan(), + logicalOlapScan() + ) + ) + ); + } + @Test void testBothSideCountAndCountStar() { Alias leftCnt = new Count(scan1.getOutput().get(0)).alias("leftCnt"); diff --git a/regression-test/data/nereids_rules_p0/count_null_not_count_star/count_null_not_count_star.out b/regression-test/data/nereids_rules_p0/count_null_not_count_star/count_null_not_count_star.out new file mode 100644 index 00000000000000..d029412d681395 --- /dev/null +++ b/regression-test/data/nereids_rules_p0/count_null_not_count_star/count_null_not_count_star.out @@ -0,0 +1,38 @@ +-- This file is automatically generated. 
You should know what you did if you want to edit this +-- !count_null -- +0 + +-- !count_star -- +4 + +-- !count_literal -- +4 + +-- !count_null_group_by -- +1 0 +2 0 +3 0 +4 0 + +-- !count_mixed -- +0 4 2 2 + +-- !count_null_union -- +0 + +-- !count_null_window -- +1 0 +2 0 +3 0 +4 0 + +-- !count_null_join_grouped -- +1 0 1 +2 0 1 + +-- !count_star_join -- +2 0 + +-- !count_null_index -- +0 2 + diff --git a/regression-test/suites/nereids_rules_p0/count_null_not_count_star/count_null_not_count_star.groovy b/regression-test/suites/nereids_rules_p0/count_null_not_count_star/count_null_not_count_star.groovy new file mode 100644 index 00000000000000..b25481e0e1350d --- /dev/null +++ b/regression-test/suites/nereids_rules_p0/count_null_not_count_star/count_null_not_count_star.groovy @@ -0,0 +1,137 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +suite("count_null_not_count_star") { + sql "SET enable_nereids_planner=true" + sql "SET enable_fallback_to_original_planner=false" + + sql "DROP TABLE IF EXISTS count_null_test" + + sql """ + create table count_null_test(pk int, a int, b int) distributed by hash(pk) buckets 10 + properties('replication_num' = '1'); + """ + + sql """ + insert into count_null_test values(1, 1, 1), (2, null, 2), (3, 3, null), (4, null, null); + """ + sql "sync" + + // count(null) should always return 0 regardless of row count + order_qt_count_null """select count(null) from count_null_test""" + + // count(*) should return total number of rows + order_qt_count_star """select count(*) from count_null_test""" + + // count(1) should be equivalent to count(*) + order_qt_count_literal """select count(1) from count_null_test""" + + // count(null) with group by should return 0 for each group + order_qt_count_null_group_by """select pk, count(null) from count_null_test group by pk order by pk""" + + // mixed: count(null) vs count(*) vs count(column) + order_qt_count_mixed """select count(null), count(*), count(a), count(b) from count_null_test""" + + // count(null) in subquery with union all + order_qt_count_null_union """ + select count(null) from ( + select a from count_null_test + union all + select a from count_null_test + ) t + """ + + // count(null) as window function should still return 0 + order_qt_count_null_window """ + select pk, count(null) over (order by pk) from count_null_test order by pk + """ + + // count(null) over (partition by unique pk) exercises SimplifyWindowExpression. + // That rule only fires when partition keys are proven unique; with unique pk, it calls + // checkCount() → isCountStar(). Bug: count(null) was treated as count(*) → simplified to 1. + // Fix: isCountStar() returns false → count(null) is NOT simplified → correctly returns 0. 
+ sql "DROP TABLE IF EXISTS count_null_uniq_test" + sql """ + create table count_null_uniq_test(pk int, a int) + unique key(pk) + distributed by hash(pk) buckets 10 + properties('replication_num' = '1'); + """ + sql """insert into count_null_uniq_test values(1, 1), (2, 2), (3, 3);""" + sql "sync" + + def windowUniqResult = sql """ + select pk, count(null) over (partition by pk order by pk) + from count_null_uniq_test order by pk + """ + assertEquals(3, windowUniqResult.size()) + windowUniqResult.each { row -> assertEquals(0L, row[1]) } + + // count(null) through join with GROUP BY: exercises PushDownAggThroughJoin path. + // With the bug, count(null) would be treated as count(*) and pushed down through join, + // producing wrong non-zero results. With the fix, it correctly returns 0. + sql "DROP TABLE IF EXISTS count_null_test2" + sql """ + create table count_null_test2(pk int, c int) distributed by hash(pk) buckets 10 + properties('replication_num' = '1'); + """ + sql """insert into count_null_test2 values(1, 10), (2, 20);""" + sql "sync" + + order_qt_count_null_join_grouped """ + select t1.pk, count(null), count(*) + from count_null_test t1 inner join count_null_test2 t2 on t1.pk = t2.pk + group by t1.pk order by t1.pk + """ + + // count(null) vs count(*) through join without GROUP BY for comparison + order_qt_count_star_join """ + select count(*), count(null) from count_null_test t1 inner join count_null_test2 t2 on t1.pk = t2.pk + """ + + // COUNT_ON_INDEX path: count(null) on a table with inverted index should NOT use COUNT_ON_MATCH + sql "DROP TABLE IF EXISTS count_null_idx_test" + sql """ + create table count_null_idx_test( + pk int, + content varchar(200), + INDEX idx_content (content) USING INVERTED PROPERTIES("parser" = "english") + ) duplicate key(pk) + distributed by hash(pk) buckets 1 + properties('replication_num' = '1'); + """ + sql """insert into count_null_idx_test values(1, 'hello world'), (2, 'doris test'), (3, 'hello doris');""" + sql "sync" + + 
// count(*) with MATCH predicate should use COUNT_ON_INDEX optimization (via COUNT_ON_MATCH path). + // count(null) with same predicate must NOT — the fix ensures isCountStar() is false for count(null). + // Note: pushAggOp=COUNT_ON_INDEX appears in the explain output even for the COUNT_ON_MATCH code path + // because COUNT_ON_MATCH maps to TPushAggOp.COUNT_ON_INDEX in physical plan translation. + explain { + sql("select count(*) from count_null_idx_test where content match 'hello'") + contains "pushAggOp=COUNT_ON_INDEX" + } + explain { + sql("select count(null) from count_null_idx_test where content match 'hello'") + notContains "pushAggOp=COUNT_ON_INDEX" + } + + // result correctness: count(null) should be 0 even with MATCH filter + order_qt_count_null_index """ + select count(null), count(*) from count_null_idx_test where content match 'hello' + """ +}