Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions dev/diffs/3.4.3.diff
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ index cf40e944c09..bdd5be4f462 100644

test("A cached table preserves the partitioning and ordering of its cached SparkPlan") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 1cc09c3d7fc..f031fa45c33 100644
index 1cc09c3d7fc..2a6b073ca7a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -27,7 +27,7 @@ import org.apache.spark.SparkException
Expand All @@ -272,7 +272,21 @@ index 1cc09c3d7fc..f031fa45c33 100644
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
@@ -755,7 +755,7 @@ class DataFrameAggregateSuite extends QueryTest
@@ -733,7 +733,12 @@ class DataFrameAggregateSuite extends QueryTest
}

test("SPARK-22223: ObjectHashAggregate should not introduce unnecessary shuffle") {
- withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "true") {
+ withSQLConf(
+ SQLConf.USE_OBJECT_HASH_AGG.key -> "true",
+ // Comet runs collect_list/collect_set natively (CometHashAggregateExec), so the
+ // ObjectHashAggregateExec this test asserts on is no longer present. Disable Comet
+ // to preserve the original plan-shape assertions.
+ "spark.comet.enabled" -> "false") {
val df = Seq(("1", "2", 1), ("1", "2", 2), ("2", "3", 3), ("2", "3", 4)).toDF("a", "b", "c")
.repartition(col("a"))

@@ -755,7 +760,7 @@ class DataFrameAggregateSuite extends QueryTest
assert(objHashAggPlans.nonEmpty)

val exchangePlans = collect(aggPlan) {
Expand Down Expand Up @@ -2977,7 +2991,7 @@ index dd55fcfe42c..d9a3f2df535 100644

spark.internalCreateDataFrame(withoutFilters.execute(), schema)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala
index ed2e309fa07..0658bfe9e12 100644
index ed2e309fa07..863868646a8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala
@@ -74,6 +74,20 @@ trait SharedSparkSessionBase
Expand Down
18 changes: 16 additions & 2 deletions dev/diffs/3.5.8.diff
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ index e5494726695..00937f025c2 100644

test("A cached table preserves the partitioning and ordering of its cached SparkPlan") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 6f3090d8908..c08a60fb0c2 100644
index 6f3090d8908..4f2e8970be8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Expand
Expand All @@ -253,7 +253,21 @@ index 6f3090d8908..c08a60fb0c2 100644
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
@@ -793,7 +793,7 @@ class DataFrameAggregateSuite extends QueryTest
@@ -771,7 +771,12 @@ class DataFrameAggregateSuite extends QueryTest
}

test("SPARK-22223: ObjectHashAggregate should not introduce unnecessary shuffle") {
- withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "true") {
+ withSQLConf(
+ SQLConf.USE_OBJECT_HASH_AGG.key -> "true",
+ // Comet runs collect_list/collect_set natively (CometHashAggregateExec), so the
+ // ObjectHashAggregateExec this test asserts on is no longer present. Disable Comet
+ // to preserve the original plan-shape assertions.
+ "spark.comet.enabled" -> "false") {
val df = Seq(("1", "2", 1), ("1", "2", 2), ("2", "3", 3), ("2", "3", 4)).toDF("a", "b", "c")
.repartition(col("a"))

@@ -793,7 +798,7 @@ class DataFrameAggregateSuite extends QueryTest
assert(objHashAggPlans.nonEmpty)

val exchangePlans = collect(aggPlan) {
Expand Down
18 changes: 16 additions & 2 deletions dev/diffs/4.0.2.diff
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ index 0f42502f1d9..e9ff802141f 100644

withTempView("t0", "t1", "t2") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 9db406ff12f..245e4caa319 100644
index 9db406ff12f..066b03283e2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.errors.DataTypeErrors.toSQLId
Expand All @@ -390,7 +390,21 @@ index 9db406ff12f..245e4caa319 100644
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
@@ -855,7 +855,7 @@ class DataFrameAggregateSuite extends QueryTest
@@ -833,7 +833,12 @@ class DataFrameAggregateSuite extends QueryTest
}

test("SPARK-22223: ObjectHashAggregate should not introduce unnecessary shuffle") {
- withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "true") {
+ withSQLConf(
+ SQLConf.USE_OBJECT_HASH_AGG.key -> "true",
+ // Comet runs collect_list/collect_set natively (CometHashAggregateExec), so the
+ // ObjectHashAggregateExec this test asserts on is no longer present. Disable Comet
+ // to preserve the original plan-shape assertions.
+ "spark.comet.enabled" -> "false") {
val df = Seq(("1", "2", 1), ("1", "2", 2), ("2", "3", 3), ("2", "3", 4)).toDF("a", "b", "c")
.repartition(col("a"))

@@ -855,7 +860,7 @@ class DataFrameAggregateSuite extends QueryTest
assert(objHashAggPlans.nonEmpty)

val exchangePlans = collect(aggPlan) {
Expand Down
18 changes: 16 additions & 2 deletions dev/diffs/4.1.2.diff
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ index 0d807aeae4d..6d7744e771b 100644

withTempView("t0", "t1", "t2") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index bfe15b33768..55c23a38ccc 100644
index bfe15b33768..c391a9e1790 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.errors.DataTypeErrors.toSQLId
Expand All @@ -404,7 +404,21 @@ index bfe15b33768..55c23a38ccc 100644
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
@@ -856,7 +856,7 @@ class DataFrameAggregateSuite extends QueryTest
@@ -834,7 +834,12 @@ class DataFrameAggregateSuite extends QueryTest
}

test("SPARK-22223: ObjectHashAggregate should not introduce unnecessary shuffle") {
- withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "true") {
+ withSQLConf(
+ SQLConf.USE_OBJECT_HASH_AGG.key -> "true",
+ // Comet runs collect_list/collect_set natively (CometHashAggregateExec), so the
+ // ObjectHashAggregateExec this test asserts on is no longer present. Disable Comet
+ // to preserve the original plan-shape assertions.
+ "spark.comet.enabled" -> "false") {
val df = Seq(("1", "2", 1), ("1", "2", 2), ("2", "3", 3), ("2", "3", 4)).toDF("a", "b", "c")
.repartition(col("a"))

@@ -856,7 +861,7 @@ class DataFrameAggregateSuite extends QueryTest
assert(objHashAggPlans.nonEmpty)

val exchangePlans = collect(aggPlan) {
Expand Down
8 changes: 8 additions & 0 deletions docs/source/contributor-guide/expression-audits/agg_funcs.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,12 @@
- Spark 3.5.8 (2026-05-26)
- Spark 4.0.1 (2026-05-26)

## collect_list

- Spark 3.4.3 (audited 2026-06-24): `CollectList` extends `Collect[ArrayBuffer[Any]]`, returns `ArrayType(child.dataType, containsNull = false)`, ignores NULL inputs in `update()` (Hive-compatible semantics), and yields an empty array as `defaultResult`. `nullable = false`. No `checkInputDataTypes` override, so any input type is accepted (including STRUCT, ARRAY, MAP). Registered as both `collect_list` and `array_agg` aliases in `FunctionRegistry`.
- Spark 3.5.8 (audited 2026-06-24): identical to 3.4.3.
- Spark 4.0.1 (audited 2026-06-24): only structural change is adding `with UnaryLike[Expression]` to the case class (no behavior change).
- Spark 4.1.1 (audited 2026-06-24): identical to 4.0.1.
- Comet implementation: native side delegates to `datafusion_spark::function::aggregate::collect::SparkCollectList`, which wraps `ArrayAggAccumulator` with `ignore_nulls = true` and converts a final NULL accumulator state to an empty array (matching Spark's `defaultResult`). The native return type is `List(Field, containsNull = true)`, while Spark uses `containsNull = false`. Because nulls are filtered before insertion, no nulls actually appear in the array, so this is a schema-shape difference only and tests using `checkSparkAnswerAndOperator` accept it (same pattern already in use for `collect_set`).

[Spark Expression Support]: ../../user-guide/latest/expressions.md
4 changes: 2 additions & 2 deletions docs/source/user-guide/latest/expressions.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,14 @@ The tables below list every Spark built-in expression with its current status.
| `any` | ✅ | |
| `any_value` | ✅ | |
| `approx_count_distinct` | 🔜 | tracking [#4098](https://github.com/apache/datafusion-comet/issues/4098) |
| `array_agg` | 🔜 | Array aggregate (related to `collect_list`, [#2524](https://github.com/apache/datafusion-comet/issues/2524)) |
| `array_agg` | ✅ | Alias for `collect_list` |
| `avg` | ✅ | Interval types fall back |
| `bit_and` | ✅ | |
| `bit_or` | ✅ | |
| `bit_xor` | ✅ | |
| `bool_and` | ✅ | |
| `bool_or` | ✅ | |
| `collect_list` | 🔜 | [#2524](https://github.com/apache/datafusion-comet/issues/2524) |
| `collect_list` | ✅ | |
| `collect_set` | ✅ | |
| `corr` | ✅ | |
| `count` | ✅ | |
Expand Down
7 changes: 6 additions & 1 deletion native/core/src/execution/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ use datafusion_comet_spark_expr::{
BloomFilterAgg, BloomFilterMightContain, CsvWriteOptions, EvalMode, SparkArraysZipFunc,
SparkBloomFilterVersion, SumInteger, ToCsv,
};
use datafusion_spark::function::aggregate::collect::SparkCollectSet;
use datafusion_spark::function::aggregate::collect::{SparkCollectList, SparkCollectSet};
use iceberg::expr::Bind;

use crate::execution::operators::ExecutionError::GeneralError;
Expand Down Expand Up @@ -2558,6 +2558,11 @@ impl PhysicalPlanner {
let func = AggregateUDF::new_from_impl(SparkCollectSet::new());
Self::create_aggr_func_expr("collect_set", schema, vec![child], func)
}
AggExprStruct::CollectList(expr) => {
let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
let func = AggregateUDF::new_from_impl(SparkCollectList::new());
Self::create_aggr_func_expr("collect_list", schema, vec![child], func)
}
}
}

Expand Down
6 changes: 6 additions & 0 deletions native/proto/src/proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ message AggExpr {
Correlation correlation = 15;
BloomFilterAgg bloomFilterAgg = 16;
CollectSet collectSet = 17;
CollectList collectList = 18;
}

// Optional filter expression for SQL FILTER (WHERE ...) clause.
Expand Down Expand Up @@ -267,6 +268,11 @@ message CollectSet {
DataType datatype = 2;
}

// Payload for the `collectList` variant of AggExpr (Spark's collect_list aggregate).
message CollectList {
// Serialized input expression whose values are collected.
Expr child = 1;
// Serialized data type of the aggregate; the Scala serde fills it from the Spark
// expression's `dataType`.
DataType datatype = 2;
}

enum EvalMode {
LEGACY = 0;
TRY = 1;
Expand Down
20 changes: 19 additions & 1 deletion spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import scala.collection.mutable.ListBuffer

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.{Divide, DoubleLiteral, EqualNullSafe, EqualTo, Expression, FloatLiteral, GreaterThan, GreaterThanOrEqual, KnownFloatingPointNormalized, LessThan, LessThanOrEqual, NamedExpression, Remainder}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateMode, Final, Partial}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateMode, Final, Partial, PartialMerge}
import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
Expand Down Expand Up @@ -836,6 +836,24 @@ case class CometExecRule(session: SparkSession)
}
}
}

// CollectList/CollectSet round-trip an ArrayType buffer that Spark declares as BinaryType.
// In a multi-stage aggregate with a PartialMerge stage (e.g. Spark's distinct-aggregate
// rewrite), Comet cannot represent that buffer consistently across the intermediate stages
// (issue #4724), so a fully-native pipeline crashes. Force the whole chain to fall back to
// Spark by tagging the feeding pure-Partial; the PartialMerge/Final stages then fall back
// via the buffer-source check in doConvert.
if (agg.aggregateExpressions.exists(_.mode == PartialMerge) &&
QueryPlanSerde.hasIncompatibleBufferAgg(agg.aggregateExpressions)) {
findPartialAggInPlan(agg.child).foreach { partial =>
if (canAggregateBeConverted(partial, Partial)) {
partial.setTagValue(
CometExecRule.COMET_UNSAFE_PARTIAL,
"Partial aggregate disabled: part of a multi-stage CollectList/CollectSet " +
"aggregate whose intermediate buffer cannot round-trip in Comet (issue #4724)")
}
}
}
case _ =>
}
}
Expand Down
17 changes: 17 additions & 0 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,7 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim {
classOf[BitOrAgg] -> CometBitOrAgg,
classOf[BitXorAgg] -> CometBitXOrAgg,
classOf[BloomFilterAggregate] -> CometBloomFilterAggregate,
classOf[CollectList] -> CometCollectList,
classOf[CollectSet] -> CometCollectSet,
classOf[Corr] -> CometCorr,
classOf[Count] -> CometCount,
Expand Down Expand Up @@ -413,6 +414,22 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim {
}
}

/**
 * Reports whether `aggExprs` contains a CollectList or CollectSet aggregate function.
 *
 * These functions keep a native intermediate buffer whose Arrow type (e.g. ArrayType) does
 * not match the BinaryType that Spark declares for its serialized TypedImperativeAggregate
 * buffer. Comet cannot interpret Spark's Binary buffer for them, and cannot yet represent
 * the buffer consistently across the intermediate PartialMerge stages of a multi-stage
 * aggregate (issue #4724). They are therefore only safe to run natively when every stage
 * runs in Comet and there are at most two stages (Partial + Final).
 *
 * @param aggExprs
 *   aggregate expressions to inspect
 * @return
 *   true if at least one expression's aggregate function is CollectList or CollectSet
 */
def hasIncompatibleBufferAgg(aggExprs: Seq[AggregateExpression]): Boolean = {
  val functions = aggExprs.iterator.map(_.aggregateFunction)
  functions.exists {
    case _: CollectList | _: CollectSet => true
    case _ => false
  }
}

// A unique id for each expression. Used to look up QueryContext during error creation.
private val exprIdCounter = new AtomicLong(0)

Expand Down
34 changes: 33 additions & 1 deletion spark/src/main/scala/org/apache/comet/serde/aggregates.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ package org.apache.comet.serde
import scala.jdk.CollectionConverters._

import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, CentralMomentAgg, CollectSet, Corr, Count, Covariance, CovPopulation, CovSample, First, Last, Max, Min, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, CentralMomentAgg, CollectList, CollectSet, Corr, Count, Covariance, CovPopulation, CovSample, First, Last, Max, Min, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{ByteType, DecimalType, IntegerType, LongType, ShortType, StringType}

Expand Down Expand Up @@ -740,6 +740,38 @@ object CometCollectSet extends CometAggregateExpressionSerde[CollectSet] {
}
}

/**
 * Serde that turns Spark's `CollectList` aggregate into a `CollectList` protobuf message
 * wrapped in an `AggExpr`.
 */
object CometCollectList extends CometAggregateExpressionSerde[CollectList] {

  /**
   * Serializes `expr` to its protobuf form.
   *
   * Yields `None` (after calling `withFallbackReason` on `aggExpr`) when either the input
   * expression or the aggregate's data type cannot be serialized.
   */
  override def convert(
      aggExpr: AggregateExpression,
      expr: CollectList,
      inputs: Seq[Attribute],
      binding: Boolean,
      conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
    val input = expr.children.head
    val serializedInput = exprToProto(input, inputs, binding)
    val serializedType = serializeDataType(expr.dataType)

    (serializedInput, serializedType) match {
      case (Some(inputProto), Some(typeProto)) =>
        val collectList = ExprOuterClass.CollectList
          .newBuilder()
          .setChild(inputProto)
          .setDatatype(typeProto)
        Some(
          ExprOuterClass.AggExpr
            .newBuilder()
            .setCollectList(collectList)
            .build())

      case (_, None) =>
        // An unsupported data type is reported in preference to an unsupported child.
        withFallbackReason(aggExpr, s"datatype ${expr.dataType} is not supported", input)
        None

      case _ =>
        // Data type is fine, so the child expression is what failed to serialize.
        withFallbackReason(aggExpr, input)
        None
    }
  }
}

object AggSerde {
import org.apache.spark.sql.types._

Expand Down
Loading
Loading