diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ArrayExpressionUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ArrayExpressionUtils.java index 5411aa684ea5f..3d7c5dccc7f2b 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ArrayExpressionUtils.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ArrayExpressionUtils.java @@ -19,10 +19,53 @@ import java.util.Arrays; import java.util.Comparator; +import org.apache.spark.QueryContext; import org.apache.spark.sql.catalyst.util.SQLOrderingUtil; +import org.apache.spark.sql.errors.QueryExecutionErrors; public class ArrayExpressionUtils { + // ANSI index helpers used by ArrayType expression codegen and eval paths. + + /** + * Resolves the user-supplied 1-based {@code element_at} index to a + * 0-based array position. Throws when the absolute index exceeds the + * array length (ANSI out-of-bounds) or when {@code index} is zero + * (always invalid). {@code Integer.MIN_VALUE} is rejected as out of bounds. + * + * @param length the array length + * @param index the 1-based index supplied by the user (positive or negative) + * @param context the query context attached to the error + * @return the resolved 0-based position + */ + public static int resolveArrayIndex(int length, int index, QueryContext context) { + if (length < Math.abs((long) index)) { // abs(Integer.MIN_VALUE) overflows int + throw QueryExecutionErrors.invalidElementAtIndexError(index, length, context); + } + if (index == 0) { + throw QueryExecutionErrors.invalidIndexOfZeroError(context); + } + return index > 0 ? index - 1 : length + index; + } + + /** + * Validates a 0-based {@code arr[idx]} index against the array length + * under ANSI mode. Throws when {@code index} is negative or + * {@code >= length}; otherwise returns {@code index} unchanged so the + * caller can chain into {@code arr.get(idx, dataType)}. 
+ * + * @param length the array length + * @param index the 0-based index supplied by the user + * @param context the query context attached to the error + * @return the validated 0-based position (== {@code index}) + */ + public static int checkArrayIndex(int length, int index, QueryContext context) { + if (index < 0 || index >= length) { + throw QueryExecutionErrors.invalidArrayIndexError(index, length, context); + } + return index; + } + // comparator // Boolean ascending nullable comparator private static final Comparator booleanComp = (o1, o2) -> { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ElementAtUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ElementAtUtils.java deleted file mode 100644 index 1aece7a91b26e..0000000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ElementAtUtils.java +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.catalyst.expressions; - -import org.apache.spark.QueryContext; -import org.apache.spark.sql.errors.QueryExecutionErrors; - -/** - * Static helpers used by {@link ElementAt} on {@code ArrayType} - * (codegen and eval) under ANSI mode. - */ -public final class ElementAtUtils { - - private ElementAtUtils() {} - - /** - * Resolves the user-supplied 1-based {@code element_at} index to a - * 0-based array position. Throws when the absolute index exceeds the - * array length (ANSI out-of-bounds) or when {@code index} is zero - * (always invalid). - * - * @param length the array length - * @param index the 1-based index supplied by the user (positive or negative) - * @param context the query context attached to the error - * @return the resolved 0-based position - */ - public static int resolveArrayIndex(int length, int index, QueryContext context) { - if (length < Math.abs(index)) { - throw QueryExecutionErrors.invalidElementAtIndexError(index, length, context); - } - if (index == 0) { - throw QueryExecutionErrors.invalidIndexOfZeroError(context); - } - return index > 0 ? 
index - 1 : length + index; - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 4f699de137c93..b0396188bcdd1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -2743,7 +2743,7 @@ case class ElementAt( case _: ArrayType if failOnError => (value, ordinal) => { val array = value.asInstanceOf[ArrayData] - val idx = ElementAtUtils.resolveArrayIndex( + val idx = ArrayExpressionUtils.resolveArrayIndex( array.numElements(), ordinal.asInstanceOf[Int], getContextOrNull()) if (arrayElementNullable && array.isNullAt(idx)) null else array.get(idx, dataType) } @@ -2783,7 +2783,7 @@ case class ElementAt( nullSafeCodeGen(ctx, ev, (eval1, eval2) => { val index = ctx.freshName("elementAtIndex") val errorContext = getContextOrNullCode(ctx) - val utils = classOf[ElementAtUtils].getName + val utils = classOf[ArrayExpressionUtils].getName val assignment = s"${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)};" val body = if (arrayElementNullable) { s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 8cc71381ddab7..022e130ec3a5e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.trees.TreePattern.{EXTRACT_VALUE, 
TreePattern} import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData, TypeUtils} -import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase, QueryExecutionErrors} +import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -356,13 +356,11 @@ case class GetArrayItem( protected override def nullSafeEval(value: Any, ordinal: Any): Any = { val baseValue = value.asInstanceOf[ArrayData] val index = ordinal.asInstanceOf[Number].intValue() - if (index >= baseValue.numElements() || index < 0) { - if (failOnError) { - throw QueryExecutionErrors.invalidArrayIndexError( - index, baseValue.numElements(), getContextOrNull()) - } else { - null - } + if (failOnError) { + ArrayExpressionUtils.checkArrayIndex(baseValue.numElements(), index, getContextOrNull()) + if (baseValue.isNullAt(index)) null else baseValue.get(index, dataType) + } else if (index >= baseValue.numElements() || index < 0) { + null } else if (baseValue.isNullAt(index)) { null } else { @@ -371,36 +369,52 @@ case class GetArrayItem( } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (eval1, eval2) => { - val index = ctx.freshName("index") - val childArrayElementNullable = child.dataType.asInstanceOf[ArrayType].containsNull - val nullCheck = if (childArrayElementNullable) { - s"""else if ($eval1.isNullAt($index)) { - ${ev.isNull} = true; - } - """ - } else { - "" - } - - val indexOutOfBoundBranch = if (failOnError) { + // ANSI path throws on a bad index via checkArrayIndex; non-ANSI path returns null for it. 
+ if (failOnError) { + nullSafeCodeGen(ctx, ev, (eval1, eval2) => { + val index = ctx.freshName("index") val errorContext = getContextOrNullCode(ctx) - // scalastyle:off line.size.limit - s"throw QueryExecutionErrors.invalidArrayIndexError($index, $eval1.numElements(), $errorContext);" - // scalastyle:on line.size.limit - } else { - s"${ev.isNull} = true;" - } - - s""" - final int $index = (int) $eval2; - if ($index >= $eval1.numElements() || $index < 0) { - $indexOutOfBoundBranch - } $nullCheck else { - ${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)}; + val utils = classOf[ArrayExpressionUtils].getName + val childArrayElementNullable = child.dataType.asInstanceOf[ArrayType].containsNull + val assignment = s"${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)};" + val body = if (childArrayElementNullable) { + s""" + |if ($eval1.isNullAt($index)) { + | ${ev.isNull} = true; + |} else { + | $assignment + |} + """.stripMargin + } else { + assignment } - """ - }) + s""" + |int $index = $utils.checkArrayIndex($eval1.numElements(), (int) $eval2, $errorContext); + |$body + """.stripMargin + }) + } else { + nullSafeCodeGen(ctx, ev, (eval1, eval2) => { + val index = ctx.freshName("index") + val childArrayElementNullable = child.dataType.asInstanceOf[ArrayType].containsNull + val nullCheck = if (childArrayElementNullable) { + s"""else if ($eval1.isNullAt($index)) { + ${ev.isNull} = true; + } + """ + } else { + "" + } + s""" + final int $index = (int) $eval2; + if ($index >= $eval1.numElements() || $index < 0) { + ${ev.isNull} = true; + } $nullCheck else { + ${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)}; + } + """ + }) + } } override protected def withNewChildrenInternal(