Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 8 additions & 13 deletions datafusion/functions/src/core/arrow_cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,23 +154,20 @@ impl ScalarUDFImpl for ArrowCastFunc {

fn simplify(
&self,
mut args: Vec<Expr>,
args: Vec<Expr>,
info: &SimplifyContext,
) -> Result<ExprSimplifyResult> {
// convert this into a real cast
let target_type = data_type_from_args(self.name(), &args)?;
// remove second (type) argument
args.pop().unwrap();
let arg = args.pop().unwrap();

let source_type = info.get_data_type(&arg)?;
let [source_arg, type_arg] = take_function_args(self.name(), args)?;
let target_type = data_type_from_type_arg(self.name(), &type_arg)?;
let source_type = info.get_data_type(&source_arg)?;
let new_expr = if source_type == target_type {
// the argument's data type is already the correct type
arg
source_arg
} else {
// Use an actual cast to get the correct type
Expr::Cast(datafusion_expr::Cast {
expr: Box::new(arg),
expr: Box::new(source_arg),
field: target_type.into_nullable_field_ref(),
})
};
Expand All @@ -183,10 +180,8 @@ impl ScalarUDFImpl for ArrowCastFunc {
}
}

/// Returns the requested type from the arguments
pub(crate) fn data_type_from_args(name: &str, args: &[Expr]) -> Result<DataType> {
let [_, type_arg] = take_function_args(name, args)?;

/// Returns the requested type from the type argument
pub(crate) fn data_type_from_type_arg(name: &str, type_arg: &Expr) -> Result<DataType> {
let Expr::Literal(ScalarValue::Utf8(Some(val)), _) = type_arg else {
return exec_err!(
"{name} requires its second argument to be a constant string, got {:?}",
Expand Down
16 changes: 7 additions & 9 deletions datafusion/functions/src/core/arrow_try_cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use datafusion_expr::{
};
use datafusion_macros::user_doc;

use super::arrow_cast::data_type_from_args;
use super::arrow_cast::data_type_from_type_arg;

/// Like [`arrow_cast`](super::arrow_cast::ArrowCastFunc) but returns NULL on cast failure instead of erroring.
///
Expand Down Expand Up @@ -127,20 +127,18 @@ impl ScalarUDFImpl for ArrowTryCastFunc {

fn simplify(
&self,
mut args: Vec<Expr>,
args: Vec<Expr>,
info: &SimplifyContext,
) -> Result<ExprSimplifyResult> {
let target_type = data_type_from_args(self.name(), &args)?;
// remove second (type) argument
args.pop().unwrap();
let arg = args.pop().unwrap();
let [source_arg, type_arg] = take_function_args(self.name(), args)?;
let target_type = data_type_from_type_arg(self.name(), &type_arg)?;

let source_type = info.get_data_type(&arg)?;
let source_type = info.get_data_type(&source_arg)?;
let new_expr = if source_type == target_type {
arg
source_arg
} else {
Expr::TryCast(datafusion_expr::TryCast {
expr: Box::new(arg),
expr: Box::new(source_arg),
field: target_type.into_nullable_field_ref(),
})
};
Expand Down
146 changes: 146 additions & 0 deletions datafusion/functions/src/core/cast_to_type.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! [`CastToTypeFunc`]: Implementation of the `cast_to_type` function

use arrow::datatypes::{DataType, Field, FieldRef};
use datafusion_common::{Result, internal_err, utils::take_function_args};
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext};
use datafusion_expr::{
Coercion, ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs,
ScalarUDFImpl, Signature, TypeSignatureClass, Volatility,
};
use datafusion_macros::user_doc;

/// Scalar UDF `cast_to_type(expression, reference)`: casts the first argument
/// to the data type of the second argument.
///
/// Only the *type* of the second argument is used; its value is ignored.
/// This is useful in macros or generic SQL where you need to preserve
/// or match types dynamically.
///
/// The function is never meant to be evaluated directly: its `simplify`
/// implementation rewrites the call into a plain cast expression (or into the
/// first argument unchanged when the types already match), and
/// `invoke_with_args` reports an internal error if it is reached.
///
/// For example:
/// ```sql
/// select cast_to_type('42', NULL::INTEGER);
/// ```
#[user_doc(
doc_section(label = "Other Functions"),
description = "Casts the first argument to the data type of the second argument. Only the type of the second argument is used; its value is ignored.",
syntax_example = "cast_to_type(expression, reference)",
sql_example = r#"```sql
> select cast_to_type('42', NULL::INTEGER) as a;
+----+
| a |
+----+
| 42 |
+----+

> select cast_to_type(1 + 2, NULL::DOUBLE) as b;
+-----+
| b |
+-----+
| 3.0 |
+-----+
```"#,
argument(
name = "expression",
description = "The expression to cast. It can be a constant, column, or function, and any combination of operators."
),
argument(
name = "reference",
description = "Reference expression whose data type determines the target cast type. The value is ignored."
)
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct CastToTypeFunc {
// Two-argument signature constructed in `CastToTypeFunc::new`.
signature: Signature,
}

impl Default for CastToTypeFunc {
fn default() -> Self {
Self::new()
}
}

impl CastToTypeFunc {
    /// Creates the function with an immutable, two-argument coercible
    /// signature in which both positions use `TypeSignatureClass::Any`.
    pub fn new() -> Self {
        // First position: the expression whose value gets cast.
        let expression = Coercion::new_exact(TypeSignatureClass::Any);
        // Second position: the reference expression that supplies the target type.
        let reference = Coercion::new_exact(TypeSignatureClass::Any);

        let signature =
            Signature::coercible(vec![expression, reference], Volatility::Immutable);
        Self { signature }
    }
}

impl ScalarUDFImpl for CastToTypeFunc {
fn name(&self) -> &str {
"cast_to_type"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
internal_err!("return_field_from_args should be called instead")
}

fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
let [source_field, reference_field] =
take_function_args(self.name(), args.arg_fields)?;
let target_type = reference_field.data_type().clone();
// Nullability is inherited only from the first argument (the value
// being cast). The second argument is used solely for its type, so
// its own nullability is irrelevant. The one exception is when the
// target type is Null – that type is inherently nullable.
let nullable = source_field.is_nullable() || target_type == DataType::Null;
Ok(Field::new(self.name(), target_type, nullable).into())
}

fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
internal_err!("cast_to_type should have been simplified to cast")
}

fn simplify(
&self,
args: Vec<Expr>,
info: &SimplifyContext,
) -> Result<ExprSimplifyResult> {
let [source_arg, type_arg] = take_function_args(self.name(), args)?;
let target_type = info.get_data_type(&type_arg)?;
let source_type = info.get_data_type(&source_arg)?;
let new_expr = if source_type == target_type {
// the argument's data type is already the correct type
source_arg
} else {
let nullable = info.nullable(&source_arg)? || target_type == DataType::Null;
// Use an actual cast to get the correct type
Expr::Cast(datafusion_expr::Cast {
expr: Box::new(source_arg),
field: Field::new("", target_type, nullable).into(),
})
};
Ok(ExprSimplifyResult::Simplified(new_expr))
}

fn documentation(&self) -> Option<&Documentation> {
self.doc()
}
}
14 changes: 14 additions & 0 deletions datafusion/functions/src/core/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ pub mod arrow_cast;
pub mod arrow_metadata;
pub mod arrow_try_cast;
pub mod arrowtypeof;
pub mod cast_to_type;
pub mod coalesce;
pub mod expr_ext;
pub mod getfield;
Expand All @@ -37,13 +38,16 @@ pub mod nvl2;
pub mod overlay;
pub mod planner;
pub mod r#struct;
pub mod try_cast_to_type;
pub mod union_extract;
pub mod union_tag;
pub mod version;

// create UDFs
make_udf_function!(arrow_cast::ArrowCastFunc, arrow_cast);
make_udf_function!(arrow_try_cast::ArrowTryCastFunc, arrow_try_cast);
make_udf_function!(cast_to_type::CastToTypeFunc, cast_to_type);
make_udf_function!(try_cast_to_type::TryCastToTypeFunc, try_cast_to_type);
make_udf_function!(nullif::NullIfFunc, nullif);
make_udf_function!(nvl::NVLFunc, nvl);
make_udf_function!(nvl2::NVL2Func, nvl2);
Expand Down Expand Up @@ -75,6 +79,14 @@ pub mod expr_fn {
arrow_try_cast,
"Casts a value to a specific Arrow data type, returning NULL if the cast fails",
arg1 arg2
),(
cast_to_type,
"Casts the first argument to the data type of the second argument",
arg1 arg2
),(
try_cast_to_type,
"Casts the first argument to the data type of the second argument, returning NULL on failure",
arg1 arg2
),(
nvl,
"Returns value2 if value1 is NULL; otherwise it returns value1",
Expand Down Expand Up @@ -147,6 +159,8 @@ pub fn functions() -> Vec<Arc<ScalarUDF>> {
nullif(),
arrow_cast(),
arrow_try_cast(),
cast_to_type(),
try_cast_to_type(),
arrow_metadata(),
nvl(),
nvl2(),
Expand Down
Loading
Loading