diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 1c93a65867615..348b45472c57d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -301,32 +301,21 @@ abstract class BinaryArithmetic extends BinaryOperator with SupportQueryContext val mathUtils = IntervalMathUtils.getClass.getCanonicalName.stripSuffix("$") defineCodeGen(ctx, ev, (eval1, eval2) => s"$mathUtils.${exactMathMethod.get}($eval1, $eval2)") // byte and short are casted into int when add, minus, times or divide + case ByteType | ShortType if failOnError => + val methodName = symbol match { + case "+" => "plus" + case "-" => "minus" + case "*" => "times" + case _ => + throw SparkException.internalError( + s"Unexpected symbol '$symbol' for Byte/Short BinaryArithmetic") + } + val numericObj = (if (dataType == ByteType) ByteExactNumeric else ShortExactNumeric) + .getClass.getCanonicalName.stripSuffix("$") + defineCodeGen(ctx, ev, (eval1, eval2) => s"$numericObj.$methodName($eval1, $eval2)") case ByteType | ShortType => - nullSafeCodeGen(ctx, ev, (eval1, eval2) => { - val tmpResult = ctx.freshName("tmpResult") - val try_suggestion = symbol match { - case "+" => "try_add" - case "-" => "try_subtract" - case "*" => "try_multiply" - case _ => "unknown_function" - } - val overflowCheck = if (failOnError) { - val javaType = CodeGenerator.boxedType(dataType) - s""" - |if ($tmpResult < $javaType.MIN_VALUE || $tmpResult > $javaType.MAX_VALUE) { - | throw QueryExecutionErrors.binaryArithmeticCauseOverflowError( - | $eval1, "$symbol", $eval2, "$try_suggestion"); - |} - """.stripMargin - } else { - "" - } - s""" - |${CodeGenerator.JAVA_INT} $tmpResult = $eval1 $symbol $eval2; - |$overflowCheck - |${ev.value} = (${CodeGenerator.javaType(dataType)})($tmpResult); - """.stripMargin - }) + defineCodeGen(ctx, ev, (eval1, eval2) => + s"(${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2)") case IntegerType | LongType if failOnError && exactMathMethod.isDefined => nullSafeCodeGen(ctx, ev, (eval1, eval2) => { val errorContext = getContextOrNullCode(ctx)