diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/DateTimeExpressionUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/DateTimeExpressionUtils.java new file mode 100644 index 0000000000000..0413278d0cb86 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/DateTimeExpressionUtils.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions; + +import java.time.DateTimeException; +import java.time.LocalDate; + +import org.apache.spark.sql.catalyst.util.DateTimeUtils; +import org.apache.spark.sql.catalyst.util.IntervalUtils; +import org.apache.spark.sql.errors.QueryExecutionErrors; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.types.CalendarInterval; + +/** + * Static helpers shared by date/time/interval expression {@code doGenCode} + * paths (and their corresponding eval paths). The methods here wrap the + * {@code DateTimeUtils} / {@code IntervalUtils} routines whose checked + * exceptions need to be translated into the user-facing ANSI errors. + */ +public final class DateTimeExpressionUtils { + + private DateTimeExpressionUtils() {} + + /** + * Builds a day count for {@code MakeDate(year, month, day)} in ANSI mode. + * Only the {@link DateTimeException} thrown by + * {@link LocalDate#of(int, int, int)} for invalid year/month/day is caught + * and converted to {@code ansiDateTimeArgumentOutOfRange}. Any other + * exception thrown by {@code DateTimeUtils.localDateToDays} (e.g. a + * day-count overflow from its internal {@code toIntExact}) propagates to + * the caller unchanged. + */ + public static int makeDateExact(int year, int month, int day) { + try { + return DateTimeUtils.localDateToDays(LocalDate.of(year, month, day)); + } catch (DateTimeException e) { + throw QueryExecutionErrors.ansiDateTimeArgumentOutOfRange(e); + } + } + + /** + * Builds a {@link CalendarInterval} for + * {@code MakeInterval(years, months, weeks, days, hours, mins, secs)} in + * ANSI mode. Throws a {@code SparkArithmeticException} if any of the + * intermediate {@code Math.*Exact} calls overflows. + */ + public static CalendarInterval makeIntervalExact( + int years, int months, int weeks, int days, + int hours, int mins, Decimal secs) { + try { + return IntervalUtils.makeInterval(years, months, weeks, days, hours, mins, secs); + } catch (ArithmeticException e) { + throw QueryExecutionErrors.arithmeticOverflowError(e.getMessage(), "", null); + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index aa4ed692d5745..a724f02cd107e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -2572,30 +2572,36 @@ case class MakeDate( override def nullable: Boolean = if (failOnError) children.exists(_.nullable) else true override def nullSafeEval(year: Any, month: Any, day: Any): Any = { - try { - val ld = LocalDate.of(year.asInstanceOf[Int], month.asInstanceOf[Int], day.asInstanceOf[Int]) - localDateToDays(ld) - } catch { - case e: java.time.DateTimeException => - if (failOnError) throw QueryExecutionErrors.ansiDateTimeArgumentOutOfRange(e) else null + if (failOnError) { + DateTimeExpressionUtils.makeDateExact( + year.asInstanceOf[Int], month.asInstanceOf[Int], day.asInstanceOf[Int]) + } else { + try { + val ld = LocalDate.of( + year.asInstanceOf[Int], month.asInstanceOf[Int], day.asInstanceOf[Int]) + localDateToDays(ld) + } catch { + case _: java.time.DateTimeException => null + } } } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - val failOnErrorBranch = if (failOnError) { - "throw QueryExecutionErrors.ansiDateTimeArgumentOutOfRange(e);" + if (failOnError) { + val utils = classOf[DateTimeExpressionUtils].getName + defineCodeGen(ctx, ev, (year, month, day) => + s"$utils.makeDateExact($year, $month, $day)") } else { - s"${ev.isNull} = true;" + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + nullSafeCodeGen(ctx, ev, (year, month, day) => { + s""" + try { + ${ev.value} = $dtu.localDateToDays(java.time.LocalDate.of($year, $month, $day)); + } catch (java.time.DateTimeException e) { + ${ev.isNull} = true; + }""" + }) } - nullSafeCodeGen(ctx, ev, (year, month, day) => { - s""" - try { - ${ev.value} = $dtu.localDateToDays(java.time.LocalDate.of($year, $month, $day)); - } catch (java.time.DateTimeException e) { - $failOnErrorBranch - }""" - }) } override def prettyName: String = "make_date" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala index 653ee9f836edd..3e4d6772c4fc4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala @@ -399,42 +399,57 @@ case class MakeInterval( hour: Any, min: Any, sec: Option[Any]): Any = { - try { - IntervalUtils.makeInterval( + val secs = sec.map(_.asInstanceOf[Decimal]).getOrElse(Decimal(0, Decimal.MAX_LONG_DIGITS, 6)) + if (failOnError) { + DateTimeExpressionUtils.makeIntervalExact( year.asInstanceOf[Int], month.asInstanceOf[Int], week.asInstanceOf[Int], day.asInstanceOf[Int], hour.asInstanceOf[Int], min.asInstanceOf[Int], - sec.map(_.asInstanceOf[Decimal]).getOrElse(Decimal(0, Decimal.MAX_LONG_DIGITS, 6))) - } catch { - case e: ArithmeticException => - if (failOnError) { - throw QueryExecutionErrors.arithmeticOverflowError(e.getMessage) - } else { - null - } + secs) + } else { + try { + IntervalUtils.makeInterval( + year.asInstanceOf[Int], + month.asInstanceOf[Int], + week.asInstanceOf[Int], + day.asInstanceOf[Int], + hour.asInstanceOf[Int], + min.asInstanceOf[Int], + secs) + } catch { + case _: ArithmeticException => null + } } } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (year, month, week, day, hour, min, sec) => { - val iu = IntervalUtils.getClass.getName.stripSuffix("$") - val secFrac = sec.getOrElse("0") - val failOnErrorBranch = if (failOnError) { - """throw QueryExecutionErrors.arithmeticOverflowError(e.getMessage(), "", null);""" - } else { - s"${ev.isNull} = true;" - } - s""" - try { - ${ev.value} = $iu.makeInterval($year, $month, $week, $day, $hour, $min, $secFrac); - } catch (java.lang.ArithmeticException e) { - $failOnErrorBranch - } - """ - }) + if (failOnError) { + val utils = classOf[DateTimeExpressionUtils].getName + nullSafeCodeGen(ctx, ev, (year, month, week, day, hour, min, sec) => { + // `MakeInterval` always passes 7 children (auxiliary constructors fill + // missing slots with `Literal(0)` / `Literal(Decimal(0, ...))`), so the + // 7th codegen value is always present. + val secFrac = sec.get + s"${ev.value} = $utils.makeIntervalExact(" + + s"$year, $month, $week, $day, $hour, $min, $secFrac);" + }) + } else { + nullSafeCodeGen(ctx, ev, (year, month, week, day, hour, min, sec) => { + val iu = IntervalUtils.getClass.getName.stripSuffix("$") + // `MakeInterval` always passes 7 children (see note above). + val secFrac = sec.get + s""" + try { + ${ev.value} = $iu.makeInterval($year, $month, $week, $day, $hour, $min, $secFrac); + } catch (java.lang.ArithmeticException e) { + ${ev.isNull} = true; + } + """ + }) + } } override def prettyName: String = "make_interval"