diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ElementAtUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ElementAtUtils.java
new file mode 100644
index 000000000000..1aece7a91b26
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ElementAtUtils.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions;
+
+import org.apache.spark.QueryContext;
+import org.apache.spark.sql.errors.QueryExecutionErrors;
+
+/**
+ * Static helpers used by {@link ElementAt} on {@code ArrayType}
+ * (codegen and eval) under ANSI mode.
+ */
+public final class ElementAtUtils {
+
+  private ElementAtUtils() {}
+
+  /**
+   * Resolves the user-supplied 1-based {@code element_at} index to a
+   * 0-based array position. Throws when the absolute index exceeds the
+   * array length (ANSI out-of-bounds) or when {@code index} is zero
+   * (always invalid). The out-of-bounds check runs first and is done in
+   * 64-bit arithmetic, so {@code Integer.MIN_VALUE} is rejected as out of
+   * bounds instead of overflowing {@code Math.abs}.
+   *
+   * @param length the array length
+   * @param index the 1-based index supplied by the user (positive or negative)
+   * @param context the query context attached to the error
+   * @return the resolved 0-based position
+   */
+  public static int resolveArrayIndex(int length, int index, QueryContext context) {
+    // Widen before taking the absolute value: Math.abs(Integer.MIN_VALUE) overflows and
+    // stays negative, which would let that index slip past this bounds check.
+    if (length < Math.abs((long) index)) {
+      throw QueryExecutionErrors.invalidElementAtIndexError(index, length, context);
+    }
+    if (index == 0) {
+      throw QueryExecutionErrors.invalidIndexOfZeroError(context);
+    }
+    // Positive indices count from the front, negative ones from the back.
+    return index > 0 ? index - 1 : length + index;
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 60966f3098ca..4f699de137c9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -2738,19 +2738,23 @@ case class ElementAt(
   override def nullSafeEval(value: Any, ordinal: Any): Any = doElementAt(value, ordinal)
 
   @transient private lazy val doElementAt: (Any, Any) => Any = left.dataType match {
+    // ArrayType is split into ANSI (failOnError) and non-ANSI branches.
+    // Order matters: the guarded case must come first.
+ case _: ArrayType if failOnError => + (value, ordinal) => { + val array = value.asInstanceOf[ArrayData] + val idx = ElementAtUtils.resolveArrayIndex( + array.numElements(), ordinal.asInstanceOf[Int], getContextOrNull()) + if (arrayElementNullable && array.isNullAt(idx)) null else array.get(idx, dataType) + } case _: ArrayType => (value, ordinal) => { val array = value.asInstanceOf[ArrayData] val index = ordinal.asInstanceOf[Int] if (array.numElements() < math.abs(index)) { - if (failOnError) { - throw QueryExecutionErrors.invalidElementAtIndexError( - index, array.numElements(), getContextOrNull()) - } else { - defaultValueOutOfBound match { - case Some(value) => value.eval() - case None => null - } + defaultValueOutOfBound match { + case Some(value) => value.eval() + case None => null } } else { val idx = if (index == 0) { @@ -2773,6 +2777,31 @@ case class ElementAt( override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { left.dataType match { + // ArrayType is split into ANSI (failOnError) and non-ANSI branches. + // Order matters: the guarded case must come first. 
+ case _: ArrayType if failOnError => + nullSafeCodeGen(ctx, ev, (eval1, eval2) => { + val index = ctx.freshName("elementAtIndex") + val errorContext = getContextOrNullCode(ctx) + val utils = classOf[ElementAtUtils].getName + val assignment = s"${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)};" + val body = if (arrayElementNullable) { + s""" + |if ($eval1.isNullAt($index)) { + | ${ev.isNull} = true; + |} else { + | $assignment + |} + """.stripMargin + } else { + assignment + } + s""" + |int $index = $utils.resolveArrayIndex( + | $eval1.numElements(), (int) $eval2, $errorContext); + |$body + """.stripMargin + }) case _: ArrayType => nullSafeCodeGen(ctx, ev, (eval1, eval2) => { val index = ctx.freshName("elementAtIndex") @@ -2786,21 +2815,15 @@ case class ElementAt( "" } val errorContext = getContextOrNullCode(ctx) - val indexOutOfBoundBranch = if (failOnError) { - // scalastyle:off line.size.limit - s"throw QueryExecutionErrors.invalidElementAtIndexError($index, $eval1.numElements(), $errorContext);" - // scalastyle:on line.size.limit - } else { - defaultValueOutOfBound match { - case Some(value) => - val defaultValueEval = value.genCode(ctx) - s""" - ${defaultValueEval.code} - ${ev.isNull} = ${defaultValueEval.isNull}; - ${ev.value} = ${defaultValueEval.value}; - """.stripMargin - case None => s"${ev.isNull} = true;" - } + val indexOutOfBoundBranch = defaultValueOutOfBound match { + case Some(value) => + val defaultValueEval = value.genCode(ctx) + s""" + ${defaultValueEval.code} + ${ev.isNull} = ${defaultValueEval.isNull}; + ${ev.value} = ${defaultValueEval.value}; + """.stripMargin + case None => s"${ev.isNull} = true;" } s"""