From b186a2aeaf2e5100551b1af7d759d84d0d9df2fd Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Sun, 17 May 2026 23:09:16 +0000 Subject: [PATCH 1/2] [SPARK-56912][SQL] Refactor Cast to boolean codegen under ANSI mode MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? Extend `CastUtils.java` with `stringToBooleanExact(UTF8String, QueryContext)` and use it from `Cast.scala` for the ANSI `String -> Boolean` cast path (both eval and codegen). The non-ANSI path keeps the inline `if/else if/else evNull = true` form because it has no error to throw. ### Why are the changes needed? Part of SPARK-56908 (umbrella). The ANSI String->Boolean cast emits an 8-line `if (isTrueString) … else if (isFalseString) … else throw` block in codegen. This PR collapses it to a one-line `CastUtils.stringToBooleanExact(...)` call. ### Does this PR introduce _any_ user-facing change? No. The compiled behavior is identical; only the emitted Java source text changes. ### How was this patch tested? ``` build/sbt "catalyst/testOnly *CastSuite *CastWithAnsiOnSuite \ *AnsiCastSuite *TryCastSuite" ``` 204/204 pass. ### Was this patch authored or co-authored using generative AI tooling? 
Generated-by: Cursor 1.x --- .../sql/catalyst/expressions/CastUtils.java | 10 ++++++++++ .../spark/sql/catalyst/expressions/Cast.scala | 20 ++++++++----------- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/CastUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/CastUtils.java index a2e427b4a4ce2..01877de2deecd 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/CastUtils.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/CastUtils.java @@ -18,10 +18,12 @@ package org.apache.spark.sql.catalyst.expressions; import org.apache.spark.QueryContext; +import org.apache.spark.sql.catalyst.util.StringUtils; import org.apache.spark.sql.errors.QueryExecutionErrors; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.types.UTF8String; /** * Static helpers used by {@code Cast.doGenCode} (and corresponding eval paths) @@ -112,4 +114,12 @@ public static Decimal changePrecisionExact( public static Decimal changePrecisionOrNull(Decimal d, int precision, int scale) { return d.changePrecision(precision, scale) ? 
d : null; } + + // ----- string -> boolean (ANSI: throw on invalid syntax) ----- + + public static boolean stringToBooleanExact(UTF8String s, QueryContext context) { + if (StringUtils.isTrueString(s)) return true; + if (StringUtils.isFalseString(s)) return false; + throw QueryExecutionErrors.invalidInputSyntaxForBooleanError(s, context); + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 66501ebe7d5c8..b595078421d74 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -695,6 +695,8 @@ case class Cast( // UDFToBoolean private[this] def castToBoolean(from: DataType): Any => Any = from match { + case _: StringType if ansiEnabled => + buildCast[UTF8String](_, s => CastUtils.stringToBooleanExact(s, getContextOrNull())) case _: StringType => buildCast[UTF8String](_, s => { if (StringUtils.isTrueString(s)) { @@ -702,11 +704,7 @@ case class Cast( } else if (StringUtils.isFalseString(s)) { false } else { - if (ansiEnabled) { - throw QueryExecutionErrors.invalidInputSyntaxForBooleanError(s, getContextOrNull()) - } else { - null - } + null } }) case TimestampType => @@ -1881,22 +1879,20 @@ case class Cast( private[this] def castToBooleanCode( from: DataType, ctx: CodegenContext): CastFunction = from match { + case _: StringType if ansiEnabled => + val castUtils = classOf[CastUtils].getName + val errorContext = getContextOrNullCode(ctx) + (c, evPrim, _) => code"$evPrim = $castUtils.stringToBooleanExact($c, $errorContext);" case _: StringType => val stringUtils = inline"${StringUtils.getClass.getName.stripSuffix("$")}" (c, evPrim, evNull) => - val castFailureCode = if (ansiEnabled) { - val errorContext = getContextOrNullCode(ctx) - s"throw QueryExecutionErrors.invalidInputSyntaxForBooleanError($c, $errorContext);" 
- } else { - s"$evNull = true;" - } code""" if ($stringUtils.isTrueString($c)) { $evPrim = true; } else if ($stringUtils.isFalseString($c)) { $evPrim = false; } else { - $castFailureCode + $evNull = true; } """ case TimestampType => From 83a3cbb112aa17de2cf47456b6c497c72f50b2fb Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Thu, 28 May 2026 12:44:33 +0000 Subject: [PATCH 2/2] address cloud-fan review: move stringToBooleanExact to UTF8StringUtils.toBooleanExact, peer the other string ANSI helpers --- .../spark/sql/catalyst/expressions/CastUtils.java | 10 ---------- .../apache/spark/sql/catalyst/expressions/Cast.scala | 6 +++--- .../spark/sql/catalyst/util/UTF8StringUtils.scala | 8 +++++++- 3 files changed, 10 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/CastUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/CastUtils.java index 01877de2deecd..a2e427b4a4ce2 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/CastUtils.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/CastUtils.java @@ -18,12 +18,10 @@ package org.apache.spark.sql.catalyst.expressions; import org.apache.spark.QueryContext; -import org.apache.spark.sql.catalyst.util.StringUtils; import org.apache.spark.sql.errors.QueryExecutionErrors; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.Decimal; -import org.apache.spark.unsafe.types.UTF8String; /** * Static helpers used by {@code Cast.doGenCode} (and corresponding eval paths) @@ -114,12 +112,4 @@ public static Decimal changePrecisionExact( public static Decimal changePrecisionOrNull(Decimal d, int precision, int scale) { return d.changePrecision(precision, scale) ? 
d : null; } - - // ----- string -> boolean (ANSI: throw on invalid syntax) ----- - - public static boolean stringToBooleanExact(UTF8String s, QueryContext context) { - if (StringUtils.isTrueString(s)) return true; - if (StringUtils.isFalseString(s)) return false; - throw QueryExecutionErrors.invalidInputSyntaxForBooleanError(s, context); - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index b595078421d74..f190f8ca5055e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -696,7 +696,7 @@ case class Cast( // UDFToBoolean private[this] def castToBoolean(from: DataType): Any => Any = from match { case _: StringType if ansiEnabled => - buildCast[UTF8String](_, s => CastUtils.stringToBooleanExact(s, getContextOrNull())) + buildCast[UTF8String](_, s => UTF8StringUtils.toBooleanExact(s, getContextOrNull())) case _: StringType => buildCast[UTF8String](_, s => { if (StringUtils.isTrueString(s)) { @@ -1880,9 +1880,9 @@ case class Cast( from: DataType, ctx: CodegenContext): CastFunction = from match { case _: StringType if ansiEnabled => - val castUtils = classOf[CastUtils].getName + val stringUtils = UTF8StringUtils.getClass.getCanonicalName.stripSuffix("$") val errorContext = getContextOrNullCode(ctx) - (c, evPrim, _) => code"$evPrim = $castUtils.stringToBooleanExact($c, $errorContext);" + (c, evPrim, _) => code"$evPrim = $stringUtils.toBooleanExact($c, $errorContext);" case _: StringType => val stringUtils = inline"${StringUtils.getClass.getName.stripSuffix("$")}" (c, evPrim, evNull) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UTF8StringUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UTF8StringUtils.scala index 1c3a5075dab2c..5f9aa4695f50d 100644 
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UTF8StringUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UTF8StringUtils.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.types.{ByteType, DataType, IntegerType, LongType, Sh import org.apache.spark.unsafe.types.UTF8String /** - * Helper functions for casting string to numeric values. + * Helper functions for casting string to primitive values under ANSI mode. */ object UTF8StringUtils { @@ -39,6 +39,12 @@ object UTF8StringUtils { def toByteExact(s: UTF8String, context: QueryContext): Byte = withException(s.toByteExact, context, ByteType, s) + def toBooleanExact(s: UTF8String, context: QueryContext): Boolean = { + if (StringUtils.isTrueString(s)) true + else if (StringUtils.isFalseString(s)) false + else throw QueryExecutionErrors.invalidInputSyntaxForBooleanError(s, context) + } + private def withException[A]( f: => A, context: QueryContext,