diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 66501ebe7d5c8..f190f8ca5055e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -695,6 +695,8 @@ case class Cast( // UDFToBoolean private[this] def castToBoolean(from: DataType): Any => Any = from match { + case _: StringType if ansiEnabled => + buildCast[UTF8String](_, s => UTF8StringUtils.toBooleanExact(s, getContextOrNull())) case _: StringType => buildCast[UTF8String](_, s => { if (StringUtils.isTrueString(s)) { @@ -702,11 +704,7 @@ case class Cast( } else if (StringUtils.isFalseString(s)) { false } else { - if (ansiEnabled) { - throw QueryExecutionErrors.invalidInputSyntaxForBooleanError(s, getContextOrNull()) - } else { - null - } + null } }) case TimestampType => @@ -1881,22 +1879,20 @@ case class Cast( private[this] def castToBooleanCode( from: DataType, ctx: CodegenContext): CastFunction = from match { + case _: StringType if ansiEnabled => + val stringUtils = UTF8StringUtils.getClass.getCanonicalName.stripSuffix("$") + val errorContext = getContextOrNullCode(ctx) + (c, evPrim, _) => code"$evPrim = $stringUtils.toBooleanExact($c, $errorContext);" case _: StringType => val stringUtils = inline"${StringUtils.getClass.getName.stripSuffix("$")}" (c, evPrim, evNull) => - val castFailureCode = if (ansiEnabled) { - val errorContext = getContextOrNullCode(ctx) - s"throw QueryExecutionErrors.invalidInputSyntaxForBooleanError($c, $errorContext);" - } else { - s"$evNull = true;" - } code""" if ($stringUtils.isTrueString($c)) { $evPrim = true; } else if ($stringUtils.isFalseString($c)) { $evPrim = false; } else { - $castFailureCode + $evNull = true; } """ case TimestampType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UTF8StringUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UTF8StringUtils.scala index 1c3a5075dab2c..5f9aa4695f50d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UTF8StringUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UTF8StringUtils.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.types.{ByteType, DataType, IntegerType, LongType, Sh import org.apache.spark.unsafe.types.UTF8String /** - * Helper functions for casting string to numeric values. + * Helper functions for casting string to primitive values under ANSI mode. */ object UTF8StringUtils { @@ -39,6 +39,12 @@ object UTF8StringUtils { def toByteExact(s: UTF8String, context: QueryContext): Byte = withException(s.toByteExact, context, ByteType, s) + def toBooleanExact(s: UTF8String, context: QueryContext): Boolean = { + if (StringUtils.isTrueString(s)) true + else if (StringUtils.isFalseString(s)) false + else throw QueryExecutionErrors.invalidInputSyntaxForBooleanError(s, context) + } + private def withException[A]( f: => A, context: QueryContext,