diff --git a/dev/diffs/3.4.3.diff b/dev/diffs/3.4.3.diff index 9577aa1131..e7ee62e547 100644 --- a/dev/diffs/3.4.3.diff +++ b/dev/diffs/3.4.3.diff @@ -260,7 +260,7 @@ index cf40e944c09..bdd5be4f462 100644 test("A cached table preserves the partitioning and ordering of its cached SparkPlan") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala -index 1cc09c3d7fc..f031fa45c33 100644 +index 1cc09c3d7fc..2a6b073ca7a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.SparkException @@ -272,7 +272,21 @@ index 1cc09c3d7fc..f031fa45c33 100644 import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf -@@ -755,7 +755,7 @@ class DataFrameAggregateSuite extends QueryTest +@@ -733,7 +733,12 @@ class DataFrameAggregateSuite extends QueryTest + } + + test("SPARK-22223: ObjectHashAggregate should not introduce unnecessary shuffle") { +- withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "true") { ++ withSQLConf( ++ SQLConf.USE_OBJECT_HASH_AGG.key -> "true", ++ // Comet runs collect_list/collect_set natively (CometHashAggregateExec), so the ++ // ObjectHashAggregateExec this test asserts on is no longer present. Disable Comet ++ // to preserve the original plan-shape assertions. 
++ "spark.comet.enabled" -> "false") { + val df = Seq(("1", "2", 1), ("1", "2", 2), ("2", "3", 3), ("2", "3", 4)).toDF("a", "b", "c") + .repartition(col("a")) + +@@ -755,7 +760,7 @@ class DataFrameAggregateSuite extends QueryTest assert(objHashAggPlans.nonEmpty) val exchangePlans = collect(aggPlan) { diff --git a/dev/diffs/3.5.8.diff b/dev/diffs/3.5.8.diff index f9152a1159..5eba683376 100644 --- a/dev/diffs/3.5.8.diff +++ b/dev/diffs/3.5.8.diff @@ -241,7 +241,7 @@ index e5494726695..00937f025c2 100644 test("A cached table preserves the partitioning and ordering of its cached SparkPlan") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala -index 6f3090d8908..c08a60fb0c2 100644 +index 6f3090d8908..4f2e8970be8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Expand @@ -253,7 +253,21 @@ index 6f3090d8908..c08a60fb0c2 100644 import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf -@@ -793,7 +793,7 @@ class DataFrameAggregateSuite extends QueryTest +@@ -771,7 +771,12 @@ class DataFrameAggregateSuite extends QueryTest + } + + test("SPARK-22223: ObjectHashAggregate should not introduce unnecessary shuffle") { +- withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "true") { ++ withSQLConf( ++ SQLConf.USE_OBJECT_HASH_AGG.key -> "true", ++ // Comet runs collect_list/collect_set natively (CometHashAggregateExec), so the ++ // ObjectHashAggregateExec this test asserts on is no longer present. Disable Comet ++ // to preserve the original plan-shape assertions. 
++ "spark.comet.enabled" -> "false") { + val df = Seq(("1", "2", 1), ("1", "2", 2), ("2", "3", 3), ("2", "3", 4)).toDF("a", "b", "c") + .repartition(col("a")) + +@@ -793,7 +798,7 @@ class DataFrameAggregateSuite extends QueryTest assert(objHashAggPlans.nonEmpty) val exchangePlans = collect(aggPlan) { diff --git a/dev/diffs/4.0.2.diff b/dev/diffs/4.0.2.diff index 1c6684f1f8..1ad6b5f533 100644 --- a/dev/diffs/4.0.2.diff +++ b/dev/diffs/4.0.2.diff @@ -378,7 +378,7 @@ index 0f42502f1d9..e9ff802141f 100644 withTempView("t0", "t1", "t2") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala -index 9db406ff12f..245e4caa319 100644 +index 9db406ff12f..066b03283e2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.errors.DataTypeErrors.toSQLId @@ -390,7 +390,21 @@ index 9db406ff12f..245e4caa319 100644 import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf -@@ -855,7 +855,7 @@ class DataFrameAggregateSuite extends QueryTest +@@ -833,7 +833,12 @@ class DataFrameAggregateSuite extends QueryTest + } + + test("SPARK-22223: ObjectHashAggregate should not introduce unnecessary shuffle") { +- withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "true") { ++ withSQLConf( ++ SQLConf.USE_OBJECT_HASH_AGG.key -> "true", ++ // Comet runs collect_list/collect_set natively (CometHashAggregateExec), so the ++ // ObjectHashAggregateExec this test asserts on is no longer present. Disable Comet ++ // to preserve the original plan-shape assertions. 
++ "spark.comet.enabled" -> "false") { + val df = Seq(("1", "2", 1), ("1", "2", 2), ("2", "3", 3), ("2", "3", 4)).toDF("a", "b", "c") + .repartition(col("a")) + +@@ -855,7 +860,7 @@ class DataFrameAggregateSuite extends QueryTest assert(objHashAggPlans.nonEmpty) val exchangePlans = collect(aggPlan) { diff --git a/dev/diffs/4.1.2.diff b/dev/diffs/4.1.2.diff index 8b58070f21..a050335887 100644 --- a/dev/diffs/4.1.2.diff +++ b/dev/diffs/4.1.2.diff @@ -392,7 +392,7 @@ index 0d807aeae4d..6d7744e771b 100644 withTempView("t0", "t1", "t2") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala -index bfe15b33768..55c23a38ccc 100644 +index bfe15b33768..c391a9e1790 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.errors.DataTypeErrors.toSQLId @@ -404,7 +404,21 @@ index bfe15b33768..55c23a38ccc 100644 import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf -@@ -856,7 +856,7 @@ class DataFrameAggregateSuite extends QueryTest +@@ -834,7 +834,12 @@ class DataFrameAggregateSuite extends QueryTest + } + + test("SPARK-22223: ObjectHashAggregate should not introduce unnecessary shuffle") { +- withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "true") { ++ withSQLConf( ++ SQLConf.USE_OBJECT_HASH_AGG.key -> "true", ++ // Comet runs collect_list/collect_set natively (CometHashAggregateExec), so the ++ // ObjectHashAggregateExec this test asserts on is no longer present. Disable Comet ++ // to preserve the original plan-shape assertions. 
++ "spark.comet.enabled" -> "false") { + val df = Seq(("1", "2", 1), ("1", "2", 2), ("2", "3", 3), ("2", "3", 4)).toDF("a", "b", "c") + .repartition(col("a")) + +@@ -856,7 +861,7 @@ class DataFrameAggregateSuite extends QueryTest assert(objHashAggPlans.nonEmpty) val exchangePlans = collect(aggPlan) { diff --git a/docs/source/contributor-guide/expression-audits/agg_funcs.md b/docs/source/contributor-guide/expression-audits/agg_funcs.md index fb27662b49..deaf1a99dc 100644 --- a/docs/source/contributor-guide/expression-audits/agg_funcs.md +++ b/docs/source/contributor-guide/expression-audits/agg_funcs.md @@ -39,6 +39,15 @@ - Spark 3.5.8 (2026-05-26) - Spark 4.0.1 (2026-05-26) +## collect_list + +- Spark 3.4.3 (audited 2026-06-24): `CollectList` extends `Collect[ArrayBuffer[Any]]`, returns `ArrayType(child.dataType, containsNull = false)`, ignores NULL inputs in `update()` (Hive-compatible semantics), and yields an empty array as `defaultResult`. `nullable = false`. No `checkInputDataTypes` override, so any input type is accepted (including STRUCT, ARRAY, MAP). Registered as both `collect_list` and `array_agg` aliases in `FunctionRegistry`. +- Spark 3.5.8 (audited 2026-06-24): identical to 3.4.3. +- Spark 4.0.1 (audited 2026-06-24): only structural change is adding `with UnaryLike[Expression]` to the case class (no behavior change). +- Spark 4.1.1 (audited 2026-06-24): identical to 4.0.1. +- Comet implementation: native side delegates to `datafusion_spark::function::aggregate::collect::SparkCollectList`, which wraps `ArrayAggAccumulator` with `ignore_nulls = true` and converts a final NULL accumulator state to an empty array (matching Spark's `defaultResult`). The native return type is `List(Field, containsNull = true)`, while Spark uses `containsNull = false`. 
Because nulls are filtered before insertion, no nulls actually appear in the array, so this is a schema-shape difference only and tests using `checkSparkAnswerAndOperator` accept it (same pattern already in use for `collect_set`). +- Spark 4.2 (preview): `CollectList` and `CollectSet` gain an `ignoreNulls` field (default `true`); `RESPECT NULLS` sets it to `false` and keeps null elements. The native path always drops nulls, so `CometCollectShim` reads the field per Spark version (always `true` on 3.4-4.1) and `CometCollectList` / `CometCollectSet` report `Unsupported` when it is `false`, falling back to Spark. + ## median - Spark 3.4.3 (audited 2026-06-24): `Median(child)` is a `RuntimeReplaceableAggregate` with `replacement = Percentile(child, Literal(0.5))`. Catalyst rewrites `median(x)` to `percentile(x, 0.5)` before Comet sees the plan, so it is served by `CometPercentile`. diff --git a/docs/source/user-guide/latest/expressions.md b/docs/source/user-guide/latest/expressions.md index 680422992c..3ea4da66ab 100644 --- a/docs/source/user-guide/latest/expressions.md +++ b/docs/source/user-guide/latest/expressions.md @@ -71,14 +71,14 @@ The tables below list every Spark built-in expression with its current status. 
| `any` | ✅ | | | `any_value` | ✅ | | | `approx_count_distinct` | 🔜 | tracking [#4098](https://github.com/apache/datafusion-comet/issues/4098) | -| `array_agg` | 🔜 | Array aggregate (related to `collect_list`, [#2524](https://github.com/apache/datafusion-comet/issues/2524)) | +| `array_agg` | ✅ | Alias for `collect_list` | | `avg` | ✅ | Interval types fall back | | `bit_and` | ✅ | | | `bit_or` | ✅ | | | `bit_xor` | ✅ | | | `bool_and` | ✅ | | | `bool_or` | ✅ | | -| `collect_list` | 🔜 | [#2524](https://github.com/apache/datafusion-comet/issues/2524) | +| `collect_list` | ✅ | | | `collect_set` | ✅ | | | `corr` | ✅ | | | `count` | ✅ | | diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 25162332fd..feb1730382 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -76,7 +76,7 @@ use datafusion_comet_spark_expr::{ BloomFilterAgg, BloomFilterMightContain, CsvWriteOptions, EvalMode, SparkArraysZipFunc, SparkBloomFilterVersion, SumInteger, ToCsv, }; -use datafusion_spark::function::aggregate::collect::SparkCollectSet; +use datafusion_spark::function::aggregate::collect::{SparkCollectList, SparkCollectSet}; use iceberg::expr::Bind; use crate::execution::operators::ExecutionError::GeneralError; @@ -2653,6 +2653,11 @@ impl PhysicalPlanner { let func = AggregateUDF::new_from_impl(SparkCollectSet::new()); Self::create_aggr_func_expr("collect_set", schema, vec![child], func) } + AggExprStruct::CollectList(expr) => { + let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?; + let func = AggregateUDF::new_from_impl(SparkCollectList::new()); + Self::create_aggr_func_expr("collect_list", schema, vec![child], func) + } } } diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index 5b2a6ce9ee..f811617e87 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -145,6 +145,7 @@ message AggExpr { BloomFilterAgg 
bloomFilterAgg = 16; CollectSet collectSet = 17; Percentile percentile = 18; + CollectList collectList = 19; } // Optional filter expression for SQL FILTER (WHERE ...) clause. @@ -276,6 +277,11 @@ message CollectSet { DataType datatype = 2; } +message CollectList { + Expr child = 1; + DataType datatype = 2; +} + enum EvalMode { LEGACY = 0; TRY = 1; diff --git a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala index 52f39c59ad..137ba4adfe 100644 --- a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala +++ b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala @@ -23,7 +23,7 @@ import scala.collection.mutable.ListBuffer import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{Divide, DoubleLiteral, EqualNullSafe, EqualTo, Expression, FloatLiteral, GreaterThan, GreaterThanOrEqual, KnownFloatingPointNormalized, LessThan, LessThanOrEqual, NamedExpression, Remainder} -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateMode, Final, Partial} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateMode, Final, Partial, PartialMerge} import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreeNodeTag @@ -858,6 +858,24 @@ case class CometExecRule(session: SparkSession) } } } + + // CollectList/CollectSet round-trip an ArrayType buffer that Spark declares as BinaryType. + // In a multi-stage aggregate with a PartialMerge stage (e.g. Spark's distinct-aggregate + // rewrite), Comet cannot represent that buffer consistently across the intermediate stages + // (issue #4724), so a fully-native pipeline crashes. Force the whole chain to fall back to + // Spark by tagging the feeding pure-Partial; the PartialMerge/Final stages then fall back + // via the buffer-source check in doConvert. 
+ if (agg.aggregateExpressions.exists(_.mode == PartialMerge) && + QueryPlanSerde.hasNativeArrayBufferAgg(agg.aggregateExpressions)) { + findPartialAggInPlan(agg.child).foreach { partial => + if (canAggregateBeConverted(partial, Partial)) { + partial.setTagValue( + CometExecRule.COMET_UNSAFE_PARTIAL, + "Partial aggregate disabled: part of a multi-stage CollectList/CollectSet " + + "aggregate whose intermediate buffer cannot round-trip in Comet (issue #4724)") + } + } + } case _ => } } diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index b752f41d74..6043b93155 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -390,6 +390,7 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim { classOf[BitOrAgg] -> CometBitOrAgg, classOf[BitXorAgg] -> CometBitXOrAgg, classOf[BloomFilterAggregate] -> CometBloomFilterAggregate, + classOf[CollectList] -> CometCollectList, classOf[CollectSet] -> CometCollectSet, classOf[Corr] -> CometCorr, classOf[Count] -> CometCount, @@ -424,6 +425,22 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim { } } + /** + * Returns true if any aggregate function produces a native intermediate buffer whose Arrow type + * (e.g. ArrayType for CollectList/CollectSet) differs from the BinaryType that Spark declares + * for its serialized TypedImperativeAggregate buffer. Comet cannot interpret Spark's Binary + * buffer for these functions, and cannot yet represent the buffer consistently across the + * intermediate PartialMerge stages of a multi-stage aggregate (issue #4724). Such aggregates + * are therefore only safe to run natively when every stage runs in Comet and there are at most + * two stages (Partial + Final). 
+ */ + def hasNativeArrayBufferAgg(aggExprs: Seq[AggregateExpression]): Boolean = { + aggExprs.exists(_.aggregateFunction match { + case _: CollectList | _: CollectSet => true + case _ => false + }) + } + // A unique id for each expression. ~used to look up QueryContext during error creation. private val exprIdCounter = new AtomicLong(0) diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala index 5710232cb4..57f98bf02c 100644 --- a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala +++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala @@ -22,14 +22,14 @@ package org.apache.comet.serde import scala.jdk.CollectionConverters._ import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal} -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, CentralMomentAgg, CollectSet, Corr, Count, Covariance, CovPopulation, CovSample, First, Last, Max, Min, Percentile, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, CentralMomentAgg, CollectList, CollectSet, Corr, Count, Covariance, CovPopulation, CovSample, First, Last, Max, Min, Percentile, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{ByteType, DecimalType, DoubleType, IntegerType, LongType, NumericType, ShortType, StringType} import org.apache.comet.CometConf.COMET_EXEC_STRICT_FLOATING_POINT import org.apache.comet.CometSparkSessionExtensions.{isSpark41Plus, withFallbackReason} import org.apache.comet.serde.QueryPlanSerde.{evalModeToProto, exprToProto, serializeDataType} -import org.apache.comet.shims.CometEvalModeUtil +import org.apache.comet.shims.{CometCollectShim, 
CometEvalModeUtil} object CometMin extends CometAggregateExpressionSerde[Min] { @@ -784,13 +784,19 @@ object CometCollectSet extends CometAggregateExpressionSerde[CollectSet] { " `spark.comet.expression.CollectSet.allowIncompatible=true` is set.") override def getSupportLevel(expr: CollectSet): SupportLevel = { - SupportLevel - .strictFloatingPointReason( - expr.children.head.dataType, - "collect_set on floating-point types " + - "(Comet deduplicates NaN values while Spark treats each NaN as distinct)") - .map(reason => Incompatible(Some(reason))) - .getOrElse(Compatible()) + // The native path always drops null inputs. Spark 4.2 adds `RESPECT NULLS` + // (`ignoreNulls = false`), which keeps nulls, so fall back there. + if (!CometCollectShim.ignoreNulls(expr)) { + Unsupported(Some("collect_set with RESPECT NULLS (ignoreNulls = false) is not supported")) + } else { + SupportLevel + .strictFloatingPointReason( + expr.children.head.dataType, + "collect_set on floating-point types " + + "(Comet deduplicates NaN values while Spark treats each NaN as distinct)") + .map(reason => Incompatible(Some(reason))) + .getOrElse(Compatible()) + } } override def convert( @@ -823,6 +829,48 @@ object CometCollectSet extends CometAggregateExpressionSerde[CollectSet] { } } +object CometCollectList extends CometAggregateExpressionSerde[CollectList] { + + override def getSupportLevel(expr: CollectList): SupportLevel = { + // The native path delegates to SparkCollectList, which always drops null inputs. Spark 4.2 + // adds `RESPECT NULLS` (`ignoreNulls = false`), which keeps nulls, so fall back there. 
+ if (!CometCollectShim.ignoreNulls(expr)) { + Unsupported(Some("collect_list with RESPECT NULLS (ignoreNulls = false) is not supported")) + } else { + Compatible() + } + } + + override def convert( + aggExpr: AggregateExpression, + expr: CollectList, + inputs: Seq[Attribute], + binding: Boolean, + conf: SQLConf): Option[ExprOuterClass.AggExpr] = { + val child = expr.children.head + val childExpr = exprToProto(child, inputs, binding) + val dataType = serializeDataType(expr.dataType) + + if (childExpr.isDefined && dataType.isDefined) { + val builder = ExprOuterClass.CollectList.newBuilder() + builder.setChild(childExpr.get) + builder.setDatatype(dataType.get) + + Some( + ExprOuterClass.AggExpr + .newBuilder() + .setCollectList(builder) + .build()) + } else if (dataType.isEmpty) { + withFallbackReason(aggExpr, s"datatype ${expr.dataType} is not supported", child) + None + } else { + withFallbackReason(aggExpr, child) + None + } + } +} + object AggSerde { import org.apache.spark.sql.types._ diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index e4d6b53770..4115f2c67e 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -30,7 +30,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeSet, Expression, ExpressionSet, Generator, NamedExpression, SortOrder} -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateMode, CollectSet, Final, First, Last, Partial, PartialMerge, Percentile} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateMode, CollectList, CollectSet, Final, First, Last, Partial, PartialMerge, Percentile} import 
org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ @@ -1547,6 +1547,21 @@ trait CometBaseAggregate { return None } + // CollectList/CollectSet declare their buffer as BinaryType in Spark but produce a native + // ArrayType state. A PartialMerge (or multi-mode) stage consuming that buffer can only read it + // when an upstream Comet partial produced it; if the buffer came from a Spark partial (no Comet + // partial in the child subtree), Comet cannot interpret Spark's serialized Binary buffer and + // must fall back. The matching Partial is forced to fall back too (see CometExecRule), so the + // whole multi-stage chain runs consistently in one engine. See issue #4724. + if (modes.contains(PartialMerge) && + QueryPlanSerde.hasNativeArrayBufferAgg(aggregate.aggregateExpressions) && + findCometPartialAgg(aggregate.child).isEmpty) { + withFallbackReason( + aggregate, + "CollectList/CollectSet PartialMerge cannot read a Spark-produced intermediate buffer") + return None + } + // Check if this aggregate has been tagged as unsafe for mixed execution // (Comet partial + Spark final with incompatible intermediate buffers) val unsafeReason = aggregate.getTagValue(CometExecRule.COMET_UNSAFE_PARTIAL) @@ -1892,14 +1907,15 @@ object CometObjectHashAggregateExec } /** - * For Partial mode aggregates containing TypedImperativeAggregate functions (like CollectSet), - * the Spark-side output declares buffer columns as BinaryType (since Spark serializes state to - * binary). However, the native Comet aggregate produces the actual state type (e.g., - * ArrayType(elementType) for CollectSet). This method corrects the output schema to match the - * native state types so the shuffle exchange schema is consistent with the actual data. 
+ * For Partial mode aggregates containing TypedImperativeAggregate functions (like CollectSet or + * CollectList), the Spark-side output declares buffer columns as BinaryType (since Spark + * serializes state to binary). However, the native Comet aggregate produces the actual state + * type (e.g., ArrayType(elementType) for CollectSet/CollectList). This method corrects the + * output schema to match the native state types so the shuffle exchange schema is consistent + * with the actual data. * - * NOTE: If a new TypedImperativeAggregate function (e.g., CollectList) is added natively, add a - * case branch here mapping it to the native state type. + * NOTE: If a new TypedImperativeAggregate function is added natively, add a case branch here + * mapping it to the native state type. */ private def adjustOutputForNativeState(op: ObjectHashAggregateExec): Seq[Attribute] = { // This adjustment only applies to pure-Partial aggregates (checked below). @@ -1916,8 +1932,8 @@ object CometObjectHashAggregateExec val aggFunc = aggExpr.aggregateFunction val bufferAttrs = aggFunc.aggBufferAttributes aggFunc match { - case cs: CollectSet => - val elementType = cs.children.head.dataType + case _: CollectSet | _: CollectList => + val elementType = aggFunc.children.head.dataType val nativeStateType = ArrayType(elementType, containsNull = true) output(bufferIdx) = output(bufferIdx).withDataType(nativeStateType) case _: Percentile => diff --git a/spark/src/main/spark-3.4/org/apache/comet/shims/CometCollectShim.scala b/spark/src/main/spark-3.4/org/apache/comet/shims/CometCollectShim.scala new file mode 100644 index 0000000000..69f33a1b29 --- /dev/null +++ b/spark/src/main/spark-3.4/org/apache/comet/shims/CometCollectShim.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.shims + +import org.apache.spark.sql.catalyst.expressions.aggregate.{CollectList, CollectSet} + +/** + * Shim for the `ignoreNulls` flag on `CollectList` / `CollectSet`. Spark 3.4 through 4.1 have no + * such field: these aggregates always drop null inputs, so this shim reports `true`. Spark 4.2 + * added `ignoreNulls` (settable to `false` via `RESPECT NULLS`), handled by the spark-4.2 shim. + */ +object CometCollectShim { + def ignoreNulls(agg: CollectList): Boolean = true + def ignoreNulls(agg: CollectSet): Boolean = true +} diff --git a/spark/src/main/spark-3.5/org/apache/comet/shims/CometCollectShim.scala b/spark/src/main/spark-3.5/org/apache/comet/shims/CometCollectShim.scala new file mode 100644 index 0000000000..69f33a1b29 --- /dev/null +++ b/spark/src/main/spark-3.5/org/apache/comet/shims/CometCollectShim.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.shims + +import org.apache.spark.sql.catalyst.expressions.aggregate.{CollectList, CollectSet} + +/** + * Shim for the `ignoreNulls` flag on `CollectList` / `CollectSet`. Spark 3.4 through 4.1 have no + * such field: these aggregates always drop null inputs, so this shim reports `true`. Spark 4.2 + * added `ignoreNulls` (settable to `false` via `RESPECT NULLS`), handled by the spark-4.2 shim. + */ +object CometCollectShim { + def ignoreNulls(agg: CollectList): Boolean = true + def ignoreNulls(agg: CollectSet): Boolean = true +} diff --git a/spark/src/main/spark-4.0/org/apache/comet/shims/CometCollectShim.scala b/spark/src/main/spark-4.0/org/apache/comet/shims/CometCollectShim.scala new file mode 100644 index 0000000000..69f33a1b29 --- /dev/null +++ b/spark/src/main/spark-4.0/org/apache/comet/shims/CometCollectShim.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.shims + +import org.apache.spark.sql.catalyst.expressions.aggregate.{CollectList, CollectSet} + +/** + * Shim for the `ignoreNulls` flag on `CollectList` / `CollectSet`. Spark 3.4 through 4.1 have no + * such field: these aggregates always drop null inputs, so this shim reports `true`. Spark 4.2 + * added `ignoreNulls` (settable to `false` via `RESPECT NULLS`), handled by the spark-4.2 shim. + */ +object CometCollectShim { + def ignoreNulls(agg: CollectList): Boolean = true + def ignoreNulls(agg: CollectSet): Boolean = true +} diff --git a/spark/src/main/spark-4.1/org/apache/comet/shims/CometCollectShim.scala b/spark/src/main/spark-4.1/org/apache/comet/shims/CometCollectShim.scala new file mode 100644 index 0000000000..69f33a1b29 --- /dev/null +++ b/spark/src/main/spark-4.1/org/apache/comet/shims/CometCollectShim.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.shims + +import org.apache.spark.sql.catalyst.expressions.aggregate.{CollectList, CollectSet} + +/** + * Shim for the `ignoreNulls` flag on `CollectList` / `CollectSet`. Spark 3.4 through 4.1 have no + * such field: these aggregates always drop null inputs, so this shim reports `true`. Spark 4.2 + * added `ignoreNulls` (settable to `false` via `RESPECT NULLS`), handled by the spark-4.2 shim. + */ +object CometCollectShim { + def ignoreNulls(agg: CollectList): Boolean = true + def ignoreNulls(agg: CollectSet): Boolean = true +} diff --git a/spark/src/main/spark-4.2/org/apache/comet/shims/CometCollectShim.scala b/spark/src/main/spark-4.2/org/apache/comet/shims/CometCollectShim.scala new file mode 100644 index 0000000000..defae750ae --- /dev/null +++ b/spark/src/main/spark-4.2/org/apache/comet/shims/CometCollectShim.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.shims + +import org.apache.spark.sql.catalyst.expressions.aggregate.{CollectList, CollectSet} + +/** + * Shim for the `ignoreNulls` flag on `CollectList` / `CollectSet`. Spark 4.2 added the field + * (`ignoreNulls = false` via `RESPECT NULLS`), so this shim reports the actual value. The serde + * falls back to Spark when it is `false`, since the native path always drops nulls. + */ +object CometCollectShim { + def ignoreNulls(agg: CollectList): Boolean = agg.ignoreNulls + def ignoreNulls(agg: CollectSet): Boolean = agg.ignoreNulls +} diff --git a/spark/src/test/resources/sql-tests/expressions/aggregate/collect_list.sql b/spark/src/test/resources/sql-tests/expressions/aggregate/collect_list.sql new file mode 100644 index 0000000000..507944419a --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/aggregate/collect_list.sql @@ -0,0 +1,405 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. 
You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- ConfigMatrix: parquet.enable.dictionary=false,true + +-- collect_list result order is non-deterministic across partitions, so +-- every query wraps the result in sort_array to make comparisons stable. + +-- ============================================================ +-- Setup: tables +-- ============================================================ + +statement +CREATE TABLE cl_src_int(i int, grp string) USING parquet + +statement +INSERT INTO cl_src_int VALUES + (1, 'a'), (2, 'a'), (1, 'a'), (3, 'a'), + (4, 'b'), (4, 'b'), (NULL, 'b'), (5, 'b'), + (NULL, 'c'), (NULL, 'c') + +statement +CREATE TABLE cl_src_nulls(val int, grp string) USING parquet + +statement +INSERT INTO cl_src_nulls VALUES + (NULL, 'a'), (NULL, 'a'), (NULL, 'b'), (1, 'b') + +statement +CREATE TABLE cl_src_empty(val int) USING parquet + +statement +CREATE TABLE cl_src_single(val int) USING parquet + +statement +INSERT INTO cl_src_single VALUES (42) + +statement +CREATE TABLE cl_src_dupes(val int, grp string) USING parquet + +statement +INSERT INTO cl_src_dupes VALUES (7, 'a'), (7, 'a'), (7, 'a'), (8, 'b'), (9, 'b') + +-- ============================================================ +-- Basic: integer (global aggregate, no GROUP BY) — duplicates kept +-- ============================================================ + +query +SELECT sort_array(collect_list(i)) FROM cl_src_int + +-- ============================================================ +-- GROUP BY: integer per group +-- ============================================================ + +query +SELECT grp, 
sort_array(collect_list(i)) FROM cl_src_int GROUP BY grp ORDER BY grp + +-- ============================================================ +-- NULLs: nulls are dropped; all-NULL group returns empty array +-- ============================================================ + +query +SELECT grp, sort_array(collect_list(val)) FROM cl_src_nulls GROUP BY grp ORDER BY grp + +-- ============================================================ +-- Empty table: returns empty array +-- ============================================================ + +query +SELECT sort_array(collect_list(val)) FROM cl_src_empty + +-- ============================================================ +-- Single value +-- ============================================================ + +query +SELECT sort_array(collect_list(val)) FROM cl_src_single + +-- ============================================================ +-- All duplicates in a group — collect_list keeps repeats +-- ============================================================ + +query +SELECT grp, sort_array(collect_list(val)) FROM cl_src_dupes GROUP BY grp ORDER BY grp + +-- ============================================================ +-- Boolean (with NULLs) +-- ============================================================ + +statement +CREATE TABLE cl_src_bool(v boolean, grp string) USING parquet + +statement +INSERT INTO cl_src_bool VALUES + (true, 'a'), (false, 'a'), (true, 'a'), (NULL, 'a'), + (NULL, 'b'), (true, 'b') + +query +SELECT grp, sort_array(collect_list(v)) FROM cl_src_bool GROUP BY grp ORDER BY grp + +-- ============================================================ +-- Byte / Short (with NULLs) +-- ============================================================ + +statement +CREATE TABLE cl_src_small(b tinyint, s smallint, grp string) USING parquet + +statement +INSERT INTO cl_src_small VALUES + (1, 100, 'a'), (2, 200, 'a'), (1, 100, 'a'), (NULL, NULL, 'a'), + (3, 300, 'b'), (NULL, 300, 'b') + +query +SELECT grp, sort_array(collect_list(b)) 
FROM cl_src_small GROUP BY grp ORDER BY grp + +query +SELECT grp, sort_array(collect_list(s)) FROM cl_src_small GROUP BY grp ORDER BY grp + +-- ============================================================ +-- Int / BigInt (with NULLs) +-- ============================================================ + +statement +CREATE TABLE cl_src_intbig(i int, bi bigint, grp string) USING parquet + +statement +INSERT INTO cl_src_intbig VALUES + (10, 1000000000000, 'a'), (20, 2000000000000, 'a'), + (10, 1000000000000, 'a'), (NULL, NULL, 'a'), + (30, 3000000000000, 'b'), (30, NULL, 'b') + +query +SELECT grp, sort_array(collect_list(i)) FROM cl_src_intbig GROUP BY grp ORDER BY grp + +query +SELECT grp, sort_array(collect_list(bi)) FROM cl_src_intbig GROUP BY grp ORDER BY grp + +-- ============================================================ +-- Float (with NULLs, NaN, Inf, -Inf, +0, -0) +-- collect_list keeps duplicates verbatim, so floating-point is fine. +-- ============================================================ + +statement +CREATE TABLE cl_src_float(v float, grp string) USING parquet + +statement +INSERT INTO cl_src_float VALUES + (1.5, 'a'), (2.5, 'a'), (1.5, 'a'), (NULL, 'a'), + (CAST('NaN' AS FLOAT), 'b'), (CAST('NaN' AS FLOAT), 'b'), (1.0, 'b'), + (CAST('Infinity' AS FLOAT), 'c'), (CAST('-Infinity' AS FLOAT), 'c'), + (CAST('Infinity' AS FLOAT), 'c'), + (CAST(0.0 AS FLOAT), 'd'), (CAST(-0.0 AS FLOAT), 'd'), (1.0, 'd'), (NULL, 'd') + +query +SELECT grp, sort_array(collect_list(v)) FROM cl_src_float GROUP BY grp ORDER BY grp + +-- ============================================================ +-- Double (with NULLs, NaN, Inf, -Inf, +0, -0) +-- ============================================================ + +statement +CREATE TABLE cl_src_double(v double, grp string) USING parquet + +statement +INSERT INTO cl_src_double VALUES + (1.1, 'a'), (2.2, 'a'), (1.1, 'a'), (NULL, 'a'), + (CAST('NaN' AS DOUBLE), 'b'), (CAST('NaN' AS DOUBLE), 'b'), (1.0, 'b'), + (CAST('Infinity' AS 
DOUBLE), 'c'), (CAST('-Infinity' AS DOUBLE), 'c'), + (CAST('Infinity' AS DOUBLE), 'c'), + (0.0, 'd'), (-0.0, 'd'), (1.0, 'd'), (NULL, 'd') + +query +SELECT grp, sort_array(collect_list(v)) FROM cl_src_double GROUP BY grp ORDER BY grp + +-- ============================================================ +-- String (with NULLs and empty string) +-- ============================================================ + +statement +CREATE TABLE cl_src_string(v string, grp string) USING parquet + +statement +INSERT INTO cl_src_string VALUES + ('hello', 'a'), ('world', 'a'), ('hello', 'a'), (NULL, 'a'), + ('', 'b'), ('x', 'b'), ('', 'b'), (NULL, 'b') + +query +SELECT grp, sort_array(collect_list(v)) FROM cl_src_string GROUP BY grp ORDER BY grp + +-- ============================================================ +-- Binary (with NULLs) +-- ============================================================ + +statement +CREATE TABLE cl_src_binary(v binary, grp string) USING parquet + +statement +INSERT INTO cl_src_binary VALUES + (X'CAFE', 'a'), (X'BABE', 'a'), (X'CAFE', 'a'), (NULL, 'a'), + (X'', 'b'), (X'FF', 'b'), (NULL, 'b') + +query +SELECT grp, sort_array(collect_list(v)) FROM cl_src_binary GROUP BY grp ORDER BY grp + +-- ============================================================ +-- Decimal (with NULLs) +-- ============================================================ + +statement +CREATE TABLE cl_src_decimal(v decimal(10,2), grp string) USING parquet + +statement +INSERT INTO cl_src_decimal VALUES + (1.50, 'a'), (2.50, 'a'), (1.50, 'a'), (NULL, 'a'), + (0.00, 'b'), (99999999.99, 'b'), (NULL, 'b') + +query +SELECT grp, sort_array(collect_list(v)) FROM cl_src_decimal GROUP BY grp ORDER BY grp + +-- ============================================================ +-- Date (with NULLs) +-- ============================================================ + +statement +CREATE TABLE cl_src_date(v date, grp string) USING parquet + +statement +INSERT INTO cl_src_date VALUES + (DATE '2024-01-01', 
'a'), (DATE '2024-06-15', 'a'), (DATE '2024-01-01', 'a'), (NULL, 'a'), + (DATE '1970-01-01', 'b'), (NULL, 'b') + +query +SELECT grp, sort_array(collect_list(v)) FROM cl_src_date GROUP BY grp ORDER BY grp + +-- ============================================================ +-- Timestamp (with NULLs) +-- ============================================================ + +statement +CREATE TABLE cl_src_ts(v timestamp, grp string) USING parquet + +statement +INSERT INTO cl_src_ts VALUES + (TIMESTAMP '2024-01-01 00:00:00', 'a'), (TIMESTAMP '2024-06-15 12:30:00', 'a'), + (TIMESTAMP '2024-01-01 00:00:00', 'a'), (NULL, 'a'), + (TIMESTAMP '1970-01-01 00:00:00', 'b'), (NULL, 'b') + +query +SELECT grp, sort_array(collect_list(v)) FROM cl_src_ts GROUP BY grp ORDER BY grp + +-- ============================================================ +-- Mixed with other aggregates +-- ============================================================ + +query +SELECT grp, sort_array(collect_list(i)), count(*), sum(i) +FROM cl_src_int GROUP BY grp ORDER BY grp + +-- ============================================================ +-- Multiple collect_list in the same query +-- ============================================================ + +statement +CREATE TABLE cl_src_multi(a int, b string, grp string) USING parquet + +statement +INSERT INTO cl_src_multi VALUES + (1, 'x', 'g1'), (2, 'y', 'g1'), (1, 'x', 'g1'), + (3, 'z', 'g2'), (NULL, NULL, 'g2') + +query +SELECT grp, sort_array(collect_list(a)), sort_array(collect_list(b)) +FROM cl_src_multi GROUP BY grp ORDER BY grp + +-- ============================================================ +-- DISTINCT: deduplicates before collecting (different planner path) +-- ============================================================ + +query +SELECT grp, sort_array(collect_list(DISTINCT i)) FROM cl_src_int GROUP BY grp ORDER BY grp + +-- ============================================================ +-- HAVING clause with collect_list +-- 
============================================================ + +query +SELECT grp, sort_array(collect_list(i)) +FROM cl_src_int GROUP BY grp HAVING size(collect_list(i)) > 1 ORDER BY grp + +-- ============================================================ +-- Result size matches count of non-null values per group +-- (collect_list ignores NULL inputs, like Hive) +-- ============================================================ + +query +SELECT grp, size(collect_list(val)) FROM cl_src_nulls GROUP BY grp ORDER BY grp + +-- ============================================================ +-- array_agg alias (registered as alias of CollectList in +-- FunctionRegistry: `expression[CollectList]("array_agg")`) +-- ============================================================ + +query +SELECT grp, sort_array(array_agg(i)) FROM cl_src_int GROUP BY grp ORDER BY grp + +-- ============================================================ +-- Struct input (Spark DataFrameAggregateSuite "collect functions structs") +-- ============================================================ + +statement +CREATE TABLE cl_src_struct(s struct<x: int, y: string>, grp string) USING parquet + +statement +INSERT INTO cl_src_struct VALUES + (named_struct('x', 1, 'y', 'a'), 'g1'), + (named_struct('x', 2, 'y', 'b'), 'g1'), + (named_struct('x', 1, 'y', 'a'), 'g1'), + (NULL, 'g1'), + (named_struct('x', 3, 'y', 'c'), 'g2'), + (NULL, 'g2') + +query +SELECT grp, sort_array(collect_list(s)) FROM cl_src_struct GROUP BY grp ORDER BY grp + +-- ============================================================ +-- Nested array input +-- ============================================================ + +statement +CREATE TABLE cl_src_array(a array<int>, grp string) USING parquet + +statement +INSERT INTO cl_src_array VALUES + (array(1, 2), 'g1'), + (array(3, 4, 5), 'g1'), + (NULL, 'g1'), + (array(), 'g2'), + (array(NULL), 'g2'), + (NULL, 'g2') + +query +SELECT grp, sort_array(collect_list(a)) FROM cl_src_array GROUP BY grp ORDER BY grp + +-- 
============================================================ +-- DECIMAL boundary precisions +-- ============================================================ + +statement +CREATE TABLE cl_src_decimal38(v decimal(38,0), grp string) USING parquet + +statement +INSERT INTO cl_src_decimal38 VALUES + (CAST('99999999999999999999999999999999999999' AS DECIMAL(38,0)), 'a'), + (CAST('-99999999999999999999999999999999999999' AS DECIMAL(38,0)), 'a'), + (CAST(0 AS DECIMAL(38,0)), 'a'), + (NULL, 'a') + +query +SELECT grp, sort_array(collect_list(v)) FROM cl_src_decimal38 GROUP BY grp ORDER BY grp + +-- ============================================================ +-- INT / BIGINT boundary values +-- ============================================================ + +statement +CREATE TABLE cl_src_bounds(i int, bi bigint) USING parquet + +statement +INSERT INTO cl_src_bounds VALUES + (-2147483648, -9223372036854775808), + (2147483647, 9223372036854775807), + (0, 0), + (NULL, NULL) + +query +SELECT sort_array(collect_list(i)), sort_array(collect_list(bi)) FROM cl_src_bounds + +-- ============================================================ +-- Spark SPARK-17641 regression: collect functions should not +-- collect null values. Verifies the absolute size matches the +-- number of non-null inputs across mixed types. 
+-- ============================================================ + +statement +CREATE TABLE cl_src_17641(a string, b int) USING parquet + +statement +INSERT INTO cl_src_17641 VALUES ('1', 2), (NULL, 2), ('1', 4) + +query +SELECT sort_array(collect_list(a)), sort_array(collect_list(b)) FROM cl_src_17641 diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index ae14c68207..b689a6e2eb 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala @@ -48,6 +48,33 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { override protected def sparkConf: SparkConf = super.sparkConf.set(SQLConf.ANSI_ENABLED.key, "false") + test("collect_list/collect_set combined with distinct aggregate falls back safely") { + // SPARK-17616: a distinct aggregate combined with collect_list/collect_set produces a + // multi-stage plan where the buffer-producing Partial may run in Spark (e.g. over a + // non-native LocalTableScan). Comet cannot read Spark's serialized Binary buffer, so the + // dependent PartialMerge/Final stages must also fall back rather than crash. See issue #4724 + // for enabling the fully-native distinct path. + import org.apache.spark.sql.functions.{collect_list, collect_set, sort_array} + // Non-native source (LocalTableScan): the buffer-producing Partial runs in Spark. 
+ val df = Seq((1, 3, "a"), (1, 2, "b"), (3, 4, "c"), (3, 4, "c"), (3, 5, "d")) + .toDF("x", "y", "z") + checkSparkAnswer( + df.groupBy(col("x")).agg(count_distinct(col("y")), sort_array(collect_list(col("z"))))) + checkSparkAnswer( + df.groupBy(col("x")).agg(count_distinct(col("y")), sort_array(collect_set(col("z"))))) + + // Native source (Parquet): the whole multi-stage distinct chain must still fall back to + // Spark consistently (issue #4724), rather than running a fully-native pipeline that crashes. + withParquetTable( + Seq((1, 3, "a"), (1, 2, "b"), (3, 4, "c"), (3, 4, "c"), (3, 5, "d")), + "t17616") { + for (fn <- Seq("collect_list", "collect_set")) { + checkSparkAnswer( + sql(s"SELECT _1, count(distinct _2), sort_array($fn(_3)) FROM t17616 GROUP BY _1")) + } + } + } + test("min/max floating point with negative zero") { val r = new Random(42) val schema = StructType(