diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/CastUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/CastUtils.java index 700f7e41d2336..a2e427b4a4ce2 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/CastUtils.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/CastUtils.java @@ -17,19 +17,15 @@ package org.apache.spark.sql.catalyst.expressions; +import org.apache.spark.QueryContext; import org.apache.spark.sql.errors.QueryExecutionErrors; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Decimal; /** - * Static helpers used by {@code Cast.doGenCode} (and corresponding eval - * paths) for ANSI overflow-checked narrowing to {@code byte} / {@code short}. - * - *
Narrowing to {@code int} / {@code long} is handled by calling the existing - * {@code LongExactNumeric} / {@code FloatExactNumeric} / {@code DoubleExactNumeric} - * Scala objects directly from codegen (see SPARK-56909). The helpers below - * cover {@code byte} / {@code short} only, since {@code ByteExactNumeric} / - * {@code ShortExactNumeric} don't expose a cross-type narrowing API. + * Static helpers used by {@code Cast.doGenCode} (and corresponding eval paths) + * for ANSI overflow-checked casts. * *
The source and target {@link DataType} objects referenced by the overflow * error message are held in {@code private static final} fields so the happy @@ -47,6 +43,9 @@ private CastUtils() {} private static final DataType DOUBLE = DataTypes.DoubleType; // ----- integral narrowing (ANSI: throw on overflow) ----- + // byte / short narrowing only; int / long narrowing is handled by calling the existing + // LongExactNumeric Scala object directly from codegen (see SPARK-56909). ByteExactNumeric / + // ShortExactNumeric don't expose a cross-type narrowing API, so a Java helper is the fit here. public static byte shortToByteExact(short v) { if (v == (byte) v) return (byte) v; @@ -95,4 +94,22 @@ public static short doubleToShortExact(double v) { if (Math.floor(v) <= Short.MAX_VALUE && Math.ceil(v) >= Short.MIN_VALUE) return (short) v; throw QueryExecutionErrors.castingCauseOverflowError(v, DOUBLE, SHORT); } + + // ----- decimal precision adjustment ----- + // Mutates the input Decimal in place to avoid the per-row clone() done by + // Decimal.toPrecision, since these helpers are called on the per-row hot path. + // On overflow, Decimal.changePrecision returns false before writing back any of + // decimalVal / longVal / _precision / _scale, so `d` is still in its original + // externally-visible state when changePrecisionExact throws -- the error message + // therefore cites the original (pre-cast) value. + + public static Decimal changePrecisionExact( + Decimal d, int precision, int scale, QueryContext context) { + if (d.changePrecision(precision, scale)) return d; + throw QueryExecutionErrors.cannotChangeDecimalPrecisionError(d, precision, scale, context); + } + + public static Decimal changePrecisionOrNull(Decimal d, int precision, int scale) { + return d.changePrecision(precision, scale) ? d : null; + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 0611c3e9bfb3b..66501ebe7d5c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -1097,15 +1097,11 @@ case class Cast( value: Decimal, decimalType: DecimalType, nullOnOverflow: Boolean): Decimal = { - if (value.changePrecision(decimalType.precision, decimalType.scale)) { - value + if (nullOnOverflow) { + CastUtils.changePrecisionOrNull(value, decimalType.precision, decimalType.scale) } else { - if (nullOnOverflow) { - null - } else { - throw QueryExecutionErrors.cannotChangeDecimalPrecisionError( - value, decimalType.precision, decimalType.scale, getContextOrNull()) - } + CastUtils.changePrecisionExact( + value, decimalType.precision, decimalType.scale, getContextOrNull()) } } @@ -1558,23 +1554,21 @@ case class Cast( |$d.changePrecision(${decimalType.precision}, ${decimalType.scale}); |$evPrim = $d; """.stripMargin - } else { - val errorContextCode = getContextOrNullCode(ctx, !nullOnOverflow) - val overflowCode = if (nullOnOverflow) { - s"$evNull = true;" - } else { - s""" - |throw QueryExecutionErrors.cannotChangeDecimalPrecisionError( - | $d, ${decimalType.precision}, ${decimalType.scale}, $errorContextCode); - """.stripMargin - } + } else if (nullOnOverflow) { code""" |if ($d.changePrecision(${decimalType.precision}, ${decimalType.scale})) { | $evPrim = $d; |} else { - | $overflowCode + | $evNull = true; |} """.stripMargin + } else { + val errorContextCode = getContextOrNullCode(ctx) + val castUtils = classOf[CastUtils].getName + code""" + |$evPrim = $castUtils.changePrecisionExact( + | $d, ${decimalType.precision}, ${decimalType.scale}, $errorContextCode); + """.stripMargin } }