diff --git a/datafusion/physical-expr/src/expressions/cast.rs b/datafusion/physical-expr/src/expressions/cast.rs
index 24e486f8050fe..7c4ac60419ec5 100644
--- a/datafusion/physical-expr/src/expressions/cast.rs
+++ b/datafusion/physical-expr/src/expressions/cast.rs
@@ -151,14 +151,8 @@ impl CastExpr {
         &self.cast_options
     }
 
-    fn is_default_target_field(&self) -> bool {
-        self.target_field.name().is_empty()
-            && self.target_field.is_nullable()
-            && self.target_field.metadata().is_empty()
-    }
-
     fn resolved_target_field(&self, input_schema: &Schema) -> Result<FieldRef> {
-        if self.is_default_target_field() {
+        if is_default_target_field(&self.target_field) {
             self.expr.return_field(input_schema).map(|field| {
                 Arc::new(
                     field
@@ -201,6 +195,12 @@ impl CastExpr {
     }
 }
 
+fn is_default_target_field(target_field: &FieldRef) -> bool {
+    target_field.name().is_empty()
+        && target_field.is_nullable()
+        && target_field.metadata().is_empty()
+}
+
 pub(crate) fn is_order_preserving_cast_family(
     source_type: &DataType,
     target_type: &DataType,
@@ -315,26 +315,55 @@ pub fn cast_with_options(
     input_schema: &Schema,
     cast_type: DataType,
     cast_options: Option<CastOptions<'static>>,
+) -> Result<Arc<dyn PhysicalExpr>> {
+    cast_with_target_field(
+        expr,
+        input_schema,
+        cast_type.into_nullable_field_ref(),
+        cast_options,
+    )
+}
+
+/// Return a PhysicalExpression representing `expr` casted to `target_field`,
+/// preserving any explicit field semantics such as name, nullability, and
+/// metadata.
+///
+/// If the input expression already has the same data type, this helper still
+/// preserves an explicit `target_field` by constructing a field-aware
+/// [`CastExpr`]. Only the default synthesized field created by the legacy
+/// type-only API is elided back to the original child expression.
+pub fn cast_with_target_field(
+    expr: Arc<dyn PhysicalExpr>,
+    input_schema: &Schema,
+    target_field: FieldRef,
+    cast_options: Option<CastOptions<'static>>,
 ) -> Result<Arc<dyn PhysicalExpr>> {
     let expr_type = expr.data_type(input_schema)?;
-    if expr_type == cast_type {
-        Ok(Arc::clone(&expr))
-    } else if requires_nested_struct_cast(&expr_type, &cast_type) {
-        if can_cast_named_struct_types(&expr_type, &cast_type) {
-            // Allow casts involving structs (including nested inside Lists, Dictionaries,
-            // etc.) that pass name-based compatibility validation. This validation is
-            // applied at planning time (now) to fail fast, rather than deferring errors
-            // to execution time. The name-based casting logic will be executed at runtime
-            // via ColumnarValue::cast_to.
-            Ok(Arc::new(CastExpr::new(expr, cast_type, cast_options)))
-        } else {
-            not_impl_err!("Unsupported CAST from {expr_type} to {cast_type}")
-        }
-    } else if can_cast_types(&expr_type, &cast_type) {
-        Ok(Arc::new(CastExpr::new(expr, cast_type, cast_options)))
+    let cast_type = target_field.data_type();
+    if expr_type == *cast_type && is_default_target_field(&target_field) {
+        return Ok(Arc::clone(&expr));
+    }
+
+    let can_build_cast = if requires_nested_struct_cast(&expr_type, cast_type) {
+        // Allow casts involving structs (including nested inside Lists, Dictionaries,
+        // etc.) that pass name-based compatibility validation. This validation is
+        // applied at planning time (now) to fail fast, rather than deferring errors
+        // to execution time. The name-based casting logic will be executed at runtime
+        // via ColumnarValue::cast_to.
+        can_cast_named_struct_types(&expr_type, cast_type)
     } else {
-        not_impl_err!("Unsupported CAST from {expr_type} to {cast_type}")
+        can_cast_types(&expr_type, cast_type)
+    };
+
+    if !can_build_cast {
+        return not_impl_err!("Unsupported CAST from {expr_type} to {cast_type}");
     }
+
+    Ok(Arc::new(CastExpr::new_with_target_field(
+        expr,
+        target_field,
+        cast_options,
+    )))
 }
 
 /// Return a PhysicalExpression representing `expr` casted to
diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs
index c9e02708d6c28..8a27009280b0e 100644
--- a/datafusion/physical-expr/src/expressions/mod.rs
+++ b/datafusion/physical-expr/src/expressions/mod.rs
@@ -56,3 +56,5 @@ pub use no_op::NoOp;
 pub use not::{NotExpr, not};
 pub use try_cast::{TryCastExpr, try_cast};
 pub use unknown_column::UnKnownColumn;
+
+pub(crate) use cast::cast_with_target_field;
diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs
index fd2de812e4664..3227e20f68def 100644
--- a/datafusion/physical-expr/src/planner.rs
+++ b/datafusion/physical-expr/src/planner.rs
@@ -288,25 +288,12 @@ pub fn create_physical_expr(
             };
             Ok(expressions::case(expr, when_then_expr, else_expr)?)
         }
-        Expr::Cast(Cast { expr, field }) => {
-            if !field.metadata().is_empty() {
-                let (_, src_field) = expr.to_field(input_dfschema)?;
-                return plan_err!(
-                    "Cast from {} to {} is not supported",
-                    format_type_and_metadata(
-                        src_field.data_type(),
-                        Some(src_field.metadata()),
-                    ),
-                    format_type_and_metadata(field.data_type(), Some(field.metadata()))
-                );
-            }
-
-            expressions::cast(
-                create_physical_expr(expr, input_dfschema, execution_props)?,
-                input_schema,
-                field.data_type().clone(),
-            )
-        }
+        Expr::Cast(Cast { expr, field }) => expressions::cast_with_target_field(
+            create_physical_expr(expr, input_dfschema, execution_props)?,
+            input_schema,
+            Arc::clone(field),
+            None,
+        ),
         Expr::TryCast(TryCast { expr, field }) => {
             if !field.metadata().is_empty() {
                 let (_, src_field) = expr.to_field(input_dfschema)?;
@@ -445,11 +432,26 @@ pub fn logical2physical(expr: &Expr, schema: &Schema) -> Arc<dyn PhysicalExpr> {
 mod tests {
     use arrow::array::{ArrayRef, BooleanArray, RecordBatch, StringArray};
     use arrow::datatypes::{DataType, Field};
-    use datafusion_common::datatype::DataTypeExt;
     use datafusion_expr::col;
 
     use super::*;
 
+    fn test_cast_schema() -> Schema {
+        Schema::new(vec![Field::new("a", DataType::Int32, false)])
+    }
+
+    fn lower_cast_expr(expr: &Expr, schema: &Schema) -> Result<Arc<dyn PhysicalExpr>> {
+        let df_schema = DFSchema::try_from(schema.clone())?;
+        create_physical_expr(expr, &df_schema, &ExecutionProps::new())
+    }
+
+    fn as_planner_cast(physical: &Arc<dyn PhysicalExpr>) -> &expressions::CastExpr {
+        physical
+            .as_any()
+            .downcast_ref::<expressions::CastExpr>()
+            .expect("planner should lower logical CAST to CastExpr")
+    }
+
     #[test]
     fn test_create_physical_expr_scalar_input_output() -> Result<()> {
         let expr = col("letter").eq(lit("A"));
@@ -476,36 +478,63 @@
     }
 
     #[test]
-    fn test_cast_to_extension_type() -> Result<()> {
-        let extension_field_type = Arc::new(
-            DataType::FixedSizeBinary(16)
-                .into_nullable_field()
-                .with_metadata(
-                    [("ARROW:extension:name".to_string(), "arrow.uuid".to_string())]
-                        .into(),
-                ),
+    fn test_cast_lowering_preserves_target_field_metadata() -> Result<()> {
+        let schema = test_cast_schema();
+        let target_field = Arc::new(
+            Field::new("cast_target", DataType::Int64, true)
+                .with_metadata([("target_meta".to_string(), "1".to_string())].into()),
         );
-        let expr = lit("3230e5d4-888e-408b-b09b-831f44aa0c58");
         let cast_expr = Expr::Cast(Cast::new_from_field(
-            Box::new(expr.clone()),
-            Arc::clone(&extension_field_type),
+            Box::new(col("a")),
+            Arc::clone(&target_field),
         ));
-        let err =
-            create_physical_expr(&cast_expr, &DFSchema::empty(), &ExecutionProps::new())
-                .unwrap_err();
-        assert!(err.message().contains("arrow.uuid"));
-
-        let try_cast_expr = Expr::TryCast(TryCast::new_from_field(
-            Box::new(expr.clone()),
-            Arc::clone(&extension_field_type),
+
+        let physical = lower_cast_expr(&cast_expr, &schema)?;
+        let cast = as_planner_cast(&physical);
+
+        assert_eq!(cast.target_field(), &target_field);
+        assert_eq!(physical.return_field(&schema)?, target_field);
+        assert!(physical.nullable(&schema)?);
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_cast_lowering_preserves_standard_cast_semantics() -> Result<()> {
+        let schema = test_cast_schema();
+        let cast_expr = Expr::Cast(Cast::new(Box::new(col("a")), DataType::Int64));
+
+        let physical = lower_cast_expr(&cast_expr, &schema)?;
+        let cast = as_planner_cast(&physical);
+        let returned_field = physical.return_field(&schema)?;
+
+        assert_eq!(cast.cast_type(), &DataType::Int64);
+        assert_eq!(returned_field.name(), "a");
+        assert_eq!(returned_field.data_type(), &DataType::Int64);
+        assert!(!physical.nullable(&schema)?);
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_cast_lowering_preserves_same_type_field_semantics() -> Result<()> {
+        let schema = test_cast_schema();
+        let target_field = Arc::new(
+            Field::new("same_type_cast", DataType::Int32, true).with_metadata(
+                [("target_meta".to_string(), "same-type".to_string())].into(),
+            ),
+        );
+        let cast_expr = Expr::Cast(Cast::new_from_field(
+            Box::new(col("a")),
+            Arc::clone(&target_field),
         ));
-        let err = create_physical_expr(
-            &try_cast_expr,
-            &DFSchema::empty(),
-            &ExecutionProps::new(),
-        )
-        .unwrap_err();
-        assert!(err.message().contains("arrow.uuid"));
+
+        let physical = lower_cast_expr(&cast_expr, &schema)?;
+        let cast = as_planner_cast(&physical);
+
+        assert_eq!(cast.target_field(), &target_field);
+        assert_eq!(physical.return_field(&schema)?, target_field);
+        assert!(physical.nullable(&schema)?);
 
         Ok(())
     }
diff --git a/datafusion/sqllogictest/src/test_context.rs b/datafusion/sqllogictest/src/test_context.rs
index 773f61655e41f..23bbf1c951446 100644
--- a/datafusion/sqllogictest/src/test_context.rs
+++ b/datafusion/sqllogictest/src/test_context.rs
@@ -28,7 +28,9 @@ use arrow::array::{
     LargeStringArray, StringArray, TimestampNanosecondArray, UnionArray,
 };
 use arrow::buffer::ScalarBuffer;
-use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit, UnionFields};
+use arrow::datatypes::{
+    DataType, Field, FieldRef, Schema, SchemaRef, TimeUnit, UnionFields,
+};
 use arrow::record_batch::RecordBatch;
 use datafusion::catalog::{
     CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider, SchemaProvider, Session,
@@ -36,6 +38,7 @@ use datafusion::catalog::{
 use datafusion::common::{DataFusionError, Result, not_impl_err};
 use datafusion::functions::math::abs;
 use datafusion::logical_expr::async_udf::{AsyncScalarUDF, AsyncScalarUDFImpl};
+use datafusion::logical_expr::planner::TypePlanner;
 use datafusion::logical_expr::{
     ColumnarValue, Expr, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
     Volatility, create_udf,
@@ -54,6 +57,7 @@ use datafusion::common::cast::as_float64_array;
 use datafusion::execution::SessionStateBuilder;
 use datafusion::execution::runtime_env::RuntimeEnv;
 use log::info;
+use sqlparser::ast;
 use tempfile::TempDir;
 
 /// Context for running tests
@@ -64,6 +68,23 @@ pub struct TestContext {
     test_dir: Option<TempDir>,
 }
 
+#[derive(Debug)]
+struct SqlLogicTestTypePlanner;
+
+impl TypePlanner for SqlLogicTestTypePlanner {
+    fn plan_type_field(&self, sql_type: &ast::DataType) -> Result<Option<FieldRef>> {
+        match sql_type {
+            ast::DataType::Uuid => Ok(Some(Arc::new(
+                Field::new("", DataType::FixedSizeBinary(16), true).with_metadata(
+                    [("ARROW:extension:name".to_string(), "arrow.uuid".to_string())]
+                        .into(),
+                ),
+            ))),
+            _ => Ok(None),
+        }
+    }
+}
+
 impl TestContext {
     pub fn new(ctx: SessionContext) -> Self {
         Self {
@@ -92,6 +113,14 @@ impl TestContext {
             state_builder = state_builder.with_spark_features();
         }
 
+        if matches!(
+            relative_path.file_name().and_then(|name| name.to_str()),
+            Some("cast_extension_type_metadata.slt")
+        ) {
+            state_builder =
+                state_builder.with_type_planner(Arc::new(SqlLogicTestTypePlanner));
+        }
+
         let state = state_builder.build();
 
         let mut test_ctx = TestContext::new(SessionContext::new_with_state(state));
diff --git a/datafusion/sqllogictest/test_files/cast_extension_type_metadata.slt b/datafusion/sqllogictest/test_files/cast_extension_type_metadata.slt
new file mode 100644
index 0000000000000..425d8ac16eaee
--- /dev/null
+++ b/datafusion/sqllogictest/test_files/cast_extension_type_metadata.slt
@@ -0,0 +1,49 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Regression tests for logical CAST targets that carry explicit field metadata.
+
+query ?T
+SELECT
+  CAST(
+    arrow_cast(X'00010203040506070809000102030506', 'FixedSizeBinary(16)')
+    AS UUID
+  ),
+  arrow_metadata(
+    CAST(
+      arrow_cast(X'00010203040506070809000102030506', 'FixedSizeBinary(16)')
+      AS UUID
+    ),
+    'ARROW:extension:name'
+  );
+----
+00010203040506070809000102030506 arrow.uuid
+
+query ?T
+SELECT
+  CAST(raw AS UUID),
+  arrow_metadata(CAST(raw AS UUID), 'ARROW:extension:name')
+FROM (
+  VALUES (
+    arrow_cast(X'00010203040506070809000102030506', 'FixedSizeBinary(16)')
+  )
+) AS uuids(raw);
+----
+00010203040506070809000102030506 arrow.uuid
+
+statement error DataFusion error: Optimizer rule 'simplify_expressions' failed[\s\S]*TryCast from FixedSizeBinary\(16\) to FixedSizeBinary\(16\)<\{"ARROW:extension:name": "arrow\.uuid"\}> is not supported
+SELECT TRY_CAST(arrow_cast(X'00010203040506070809000102030506', 'FixedSizeBinary(16)') AS UUID);