From 5ece29bda151765e2dd6b4d85794332a71537124 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Tue, 19 May 2026 05:08:49 +0000 Subject: [PATCH 1/2] [SPARK-56935][SQL] Simplify GetArrayItem codegen and consolidate ElementAtUtils into ArrayExpressionUtils ### What changes were proposed in this pull request? Two related changes: 1. Fold `ElementAtUtils.resolveArrayIndex` into the existing `ArrayExpressionUtils.java`, and remove `ElementAtUtils.java`. The per-expression naming chosen in SPARK-56916 didn't match the codebase's category-scoped utility-class convention (`ArrayExpressionUtils`, `BitmapExpressionUtils`, `ExpressionImplUtils`, ...) and there's now a natural home for any future array-expression ANSI helper. 2. Refactor `GetArrayItem`'s ANSI codegen + eval paths to use a new `ArrayExpressionUtils.checkArrayIndex(int length, int index, QueryContext context)` helper, mirroring how `ElementAt` uses `resolveArrayIndex`. The helper throws `invalidArrayIndexError` for negative / out-of-bound ANSI indices and returns the validated 0-based position so the caller chains into `arr.get(idx, dataType)`. The non-ANSI branch keeps its inline form because it must return `null` (not throw) on out-of-bound. Net effect: the existing per-expression `ElementAtUtils.java` is removed; the existing `ArrayExpressionUtils.java` grows two `*ArrayIndex` helpers used by `ElementAt` and `GetArrayItem` codegen + eval. ### Why are the changes needed? Part of SPARK-56908 (umbrella). `arr[idx]` and `element_at(arr, idx)` share the same ANSI out-of-bound error shape; collapsing both into one-line helper calls keeps the codegen size small and avoids maintaining two parallel inline forms. ### Does this PR introduce _any_ user-facing change? No. The compiled behavior is identical; only the emitted Java source text changes. ### How was this patch tested? 
``` build/sbt "catalyst/testOnly *ComplexTypeSuite *CollectionExpressionsSuite" build/sbt "sql/testOnly *QueryExecutionAnsiErrorsSuite" ``` All pass (83/83 catalyst, 21/21 sql). ### Was this patch authored or co-authored using generative AI tooling? Generated-by: Cursor 1.x --- .../expressions/ArrayExpressionUtils.java | 43 +++++++++ .../catalyst/expressions/ElementAtUtils.java | 51 ----------- .../expressions/collectionOperations.scala | 4 +- .../expressions/complexTypeExtractors.scala | 87 +++++++++++-------- 4 files changed, 96 insertions(+), 89 deletions(-) delete mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ElementAtUtils.java diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ArrayExpressionUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ArrayExpressionUtils.java index 5411aa684ea5f..3d7c5dccc7f2b 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ArrayExpressionUtils.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ArrayExpressionUtils.java @@ -19,10 +19,53 @@ import java.util.Arrays; import java.util.Comparator; +import org.apache.spark.QueryContext; import org.apache.spark.sql.catalyst.util.SQLOrderingUtil; +import org.apache.spark.sql.errors.QueryExecutionErrors; public class ArrayExpressionUtils { + // ANSI index helpers used by ArrayType expression codegen and eval paths. + + /** + * Resolves the user-supplied 1-based {@code element_at} index to a + * 0-based array position. Throws when the absolute index exceeds the + * array length (ANSI out-of-bounds) or when {@code index} is zero + * (always invalid). 
+ * + * @param length the array length + * @param index the 1-based index supplied by the user (positive or negative) + * @param context the query context attached to the error + * @return the resolved 0-based position + */ + public static int resolveArrayIndex(int length, int index, QueryContext context) { + if (length < Math.abs((long) index)) { + throw QueryExecutionErrors.invalidElementAtIndexError(index, length, context); + } + if (index == 0) { + throw QueryExecutionErrors.invalidIndexOfZeroError(context); + } + return index > 0 ? index - 1 : length + index; + } + + /** + * Validates a 0-based {@code arr[idx]} index against the array length + * under ANSI mode. Throws when {@code index} is negative or + * {@code >= length}; otherwise returns {@code index} unchanged so the + * caller can chain into {@code arr.get(idx, dataType)}. + * + * @param length the array length + * @param index the 0-based index supplied by the user + * @param context the query context attached to the error + * @return the validated 0-based position (== {@code index}) + */ + public static int checkArrayIndex(int length, int index, QueryContext context) { + if (index < 0 || index >= length) { + throw QueryExecutionErrors.invalidArrayIndexError(index, length, context); + } + return index; + } + // comparator // Boolean ascending nullable comparator private static final Comparator booleanComp = (o1, o2) -> { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ElementAtUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ElementAtUtils.java deleted file mode 100644 index 1aece7a91b26e..0000000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ElementAtUtils.java +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions; - -import org.apache.spark.QueryContext; -import org.apache.spark.sql.errors.QueryExecutionErrors; - -/** - * Static helpers used by {@link ElementAt} on {@code ArrayType} - * (codegen and eval) under ANSI mode. - */ -public final class ElementAtUtils { - - private ElementAtUtils() {} - - /** - * Resolves the user-supplied 1-based {@code element_at} index to a - * 0-based array position. Throws when the absolute index exceeds the - * array length (ANSI out-of-bounds) or when {@code index} is zero - * (always invalid). - * - * @param length the array length - * @param index the 1-based index supplied by the user (positive or negative) - * @param context the query context attached to the error - * @return the resolved 0-based position - */ - public static int resolveArrayIndex(int length, int index, QueryContext context) { - if (length < Math.abs(index)) { - throw QueryExecutionErrors.invalidElementAtIndexError(index, length, context); - } - if (index == 0) { - throw QueryExecutionErrors.invalidIndexOfZeroError(context); - } - return index > 0 ? 
index - 1 : length + index; - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 4f699de137c93..b0396188bcdd1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -2743,7 +2743,7 @@ case class ElementAt( case _: ArrayType if failOnError => (value, ordinal) => { val array = value.asInstanceOf[ArrayData] - val idx = ElementAtUtils.resolveArrayIndex( + val idx = ArrayExpressionUtils.resolveArrayIndex( array.numElements(), ordinal.asInstanceOf[Int], getContextOrNull()) if (arrayElementNullable && array.isNullAt(idx)) null else array.get(idx, dataType) } @@ -2783,7 +2783,7 @@ case class ElementAt( nullSafeCodeGen(ctx, ev, (eval1, eval2) => { val index = ctx.freshName("elementAtIndex") val errorContext = getContextOrNullCode(ctx) - val utils = classOf[ElementAtUtils].getName + val utils = classOf[ArrayExpressionUtils].getName val assignment = s"${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)};" val body = if (arrayElementNullable) { s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 8cc71381ddab7..f835ebccbf052 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.trees.TreePattern.{EXTRACT_VALUE, 
TreePattern} import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData, TypeUtils} -import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase, QueryExecutionErrors} +import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -356,13 +356,11 @@ case class GetArrayItem( protected override def nullSafeEval(value: Any, ordinal: Any): Any = { val baseValue = value.asInstanceOf[ArrayData] val index = ordinal.asInstanceOf[Number].intValue() - if (index >= baseValue.numElements() || index < 0) { - if (failOnError) { - throw QueryExecutionErrors.invalidArrayIndexError( - index, baseValue.numElements(), getContextOrNull()) - } else { - null - } + if (failOnError) { + ArrayExpressionUtils.checkArrayIndex(baseValue.numElements(), index, getContextOrNull()) + if (baseValue.isNullAt(index)) null else baseValue.get(index, dataType) + } else if (index >= baseValue.numElements() || index < 0) { + null } else if (baseValue.isNullAt(index)) { null } else { @@ -371,36 +369,53 @@ case class GetArrayItem( } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (eval1, eval2) => { - val index = ctx.freshName("index") - val childArrayElementNullable = child.dataType.asInstanceOf[ArrayType].containsNull - val nullCheck = if (childArrayElementNullable) { - s"""else if ($eval1.isNullAt($index)) { - ${ev.isNull} = true; - } - """ - } else { - "" - } - - val indexOutOfBoundBranch = if (failOnError) { + // ArrayType is split into ANSI (failOnError) and non-ANSI branches. + // Order matters: the guarded case must come first. 
+ if (failOnError) { + nullSafeCodeGen(ctx, ev, (eval1, eval2) => { + val index = ctx.freshName("index") val errorContext = getContextOrNullCode(ctx) - // scalastyle:off line.size.limit - s"throw QueryExecutionErrors.invalidArrayIndexError($index, $eval1.numElements(), $errorContext);" - // scalastyle:on line.size.limit - } else { - s"${ev.isNull} = true;" - } - - s""" - final int $index = (int) $eval2; - if ($index >= $eval1.numElements() || $index < 0) { - $indexOutOfBoundBranch - } $nullCheck else { - ${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)}; + val utils = classOf[ArrayExpressionUtils].getName + val childArrayElementNullable = child.dataType.asInstanceOf[ArrayType].containsNull + val assignment = s"${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)};" + val body = if (childArrayElementNullable) { + s""" + |if ($eval1.isNullAt($index)) { + | ${ev.isNull} = true; + |} else { + | $assignment + |} + """.stripMargin + } else { + assignment } - """ - }) + s""" + |int $index = $utils.checkArrayIndex($eval1.numElements(), (int) $eval2, $errorContext); + |$body + """.stripMargin + }) + } else { + nullSafeCodeGen(ctx, ev, (eval1, eval2) => { + val index = ctx.freshName("index") + val childArrayElementNullable = child.dataType.asInstanceOf[ArrayType].containsNull + val nullCheck = if (childArrayElementNullable) { + s"""else if ($eval1.isNullAt($index)) { + ${ev.isNull} = true; + } + """ + } else { + "" + } + s""" + final int $index = (int) $eval2; + if ($index >= $eval1.numElements() || $index < 0) { + ${ev.isNull} = true; + } $nullCheck else { + ${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)}; + } + """ + }) + } } override protected def withNewChildrenInternal( From 6fd8aad545b8b19d9ea52dbbf4ab8be1ba367e32 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Fri, 22 May 2026 05:58:50 -0700 Subject: [PATCH 2/2] Update sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala 
Co-authored-by: Wenchen Fan --- .../spark/sql/catalyst/expressions/complexTypeExtractors.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index f835ebccbf052..022e130ec3a5e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -369,8 +369,7 @@ case class GetArrayItem( } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - // ArrayType is split into ANSI (failOnError) and non-ANSI branches. - // Order matters: the guarded case must come first. + // ANSI (failOnError) and non-ANSI paths generate different codegen. if (failOnError) { nullSafeCodeGen(ctx, ev, (eval1, eval2) => { val index = ctx.freshName("index")