Skip to content
75 changes: 52 additions & 23 deletions datafusion/physical-expr/src/expressions/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,14 +151,8 @@ impl CastExpr {
&self.cast_options
}

fn is_default_target_field(&self) -> bool {
self.target_field.name().is_empty()
&& self.target_field.is_nullable()
&& self.target_field.metadata().is_empty()
}

fn resolved_target_field(&self, input_schema: &Schema) -> Result<FieldRef> {
if self.is_default_target_field() {
if is_default_target_field(&self.target_field) {
self.expr.return_field(input_schema).map(|field| {
Arc::new(
field
Expand Down Expand Up @@ -201,6 +195,12 @@ impl CastExpr {
}
}

/// Returns `true` when `target_field` is the default field synthesized by the
/// legacy type-only cast API: anonymous (empty) name, nullable, and carrying
/// no metadata. Such a field adds no semantics beyond its data type.
fn is_default_target_field(target_field: &FieldRef) -> bool {
    let field = target_field.as_ref();
    field.name().is_empty()
        && field.is_nullable()
        && field.metadata().is_empty()
}

pub(crate) fn is_order_preserving_cast_family(
source_type: &DataType,
target_type: &DataType,
Expand Down Expand Up @@ -315,26 +315,55 @@ pub fn cast_with_options(
input_schema: &Schema,
cast_type: DataType,
cast_options: Option<CastOptions<'static>>,
) -> Result<Arc<dyn PhysicalExpr>> {
cast_with_target_field(
expr,
input_schema,
cast_type.into_nullable_field_ref(),
cast_options,
)
}

/// Return a PhysicalExpression representing `expr` casted to `target_field`,
/// preserving any explicit field semantics such as name, nullability, and
/// metadata.
///
/// If the input expression already has the same data type, this helper still
/// preserves an explicit `target_field` by constructing a field-aware
/// [`CastExpr`]. Only the default synthesized field created by the legacy
/// type-only API is elided back to the original child expression.
pub fn cast_with_target_field(
expr: Arc<dyn PhysicalExpr>,
input_schema: &Schema,
target_field: FieldRef,
cast_options: Option<CastOptions<'static>>,
) -> Result<Arc<dyn PhysicalExpr>> {
let expr_type = expr.data_type(input_schema)?;
if expr_type == cast_type {
Ok(Arc::clone(&expr))
} else if requires_nested_struct_cast(&expr_type, &cast_type) {
if can_cast_named_struct_types(&expr_type, &cast_type) {
// Allow casts involving structs (including nested inside Lists, Dictionaries,
// etc.) that pass name-based compatibility validation. This validation is
// applied at planning time (now) to fail fast, rather than deferring errors
// to execution time. The name-based casting logic will be executed at runtime
// via ColumnarValue::cast_to.
Ok(Arc::new(CastExpr::new(expr, cast_type, cast_options)))
} else {
not_impl_err!("Unsupported CAST from {expr_type} to {cast_type}")
}
} else if can_cast_types(&expr_type, &cast_type) {
Ok(Arc::new(CastExpr::new(expr, cast_type, cast_options)))
let cast_type = target_field.data_type();
if expr_type == *cast_type && is_default_target_field(&target_field) {
return Ok(Arc::clone(&expr));
}

let can_build_cast = if requires_nested_struct_cast(&expr_type, cast_type) {
// Allow casts involving structs (including nested inside Lists, Dictionaries,
// etc.) that pass name-based compatibility validation. This validation is
// applied at planning time (now) to fail fast, rather than deferring errors
// to execution time. The name-based casting logic will be executed at runtime
// via ColumnarValue::cast_to.
can_cast_named_struct_types(&expr_type, cast_type)
} else {
not_impl_err!("Unsupported CAST from {expr_type} to {cast_type}")
can_cast_types(&expr_type, cast_type)
};

if !can_build_cast {
return not_impl_err!("Unsupported CAST from {expr_type} to {cast_type}");
}

Ok(Arc::new(CastExpr::new_with_target_field(
expr,
target_field,
cast_options,
)))
}

/// Return a PhysicalExpression representing `expr` casted to
Expand Down
2 changes: 2 additions & 0 deletions datafusion/physical-expr/src/expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,5 @@ pub use no_op::NoOp;
pub use not::{NotExpr, not};
pub use try_cast::{TryCastExpr, try_cast};
pub use unknown_column::UnKnownColumn;

pub(crate) use cast::cast_with_target_field;
121 changes: 75 additions & 46 deletions datafusion/physical-expr/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -288,25 +288,12 @@ pub fn create_physical_expr(
};
Ok(expressions::case(expr, when_then_expr, else_expr)?)
}
Expr::Cast(Cast { expr, field }) => {
if !field.metadata().is_empty() {
let (_, src_field) = expr.to_field(input_dfschema)?;
return plan_err!(
"Cast from {} to {} is not supported",
format_type_and_metadata(
src_field.data_type(),
Some(src_field.metadata()),
),
format_type_and_metadata(field.data_type(), Some(field.metadata()))
);
}

expressions::cast(
create_physical_expr(expr, input_dfschema, execution_props)?,
input_schema,
field.data_type().clone(),
)
}
Expr::Cast(Cast { expr, field }) => expressions::cast_with_target_field(
create_physical_expr(expr, input_dfschema, execution_props)?,
input_schema,
Arc::clone(field),
None,
),
Expr::TryCast(TryCast { expr, field }) => {
if !field.metadata().is_empty() {
let (_, src_field) = expr.to_field(input_dfschema)?;
Expand Down Expand Up @@ -445,11 +432,26 @@ pub fn logical2physical(expr: &Expr, schema: &Schema) -> Arc<dyn PhysicalExpr> {
mod tests {
use arrow::array::{ArrayRef, BooleanArray, RecordBatch, StringArray};
use arrow::datatypes::{DataType, Field};
use datafusion_common::datatype::DataTypeExt;
use datafusion_expr::col;

use super::*;

/// Builds the single-column schema (`a: Int32`, non-nullable) shared by the
/// cast-lowering tests.
fn test_cast_schema() -> Schema {
    let column_a = Field::new("a", DataType::Int32, false);
    Schema::new(vec![column_a])
}

/// Lowers a logical expression against `schema` using default execution
/// properties, returning the planned physical expression.
fn lower_cast_expr(expr: &Expr, schema: &Schema) -> Result<Arc<dyn PhysicalExpr>> {
    let df_schema = DFSchema::try_from(schema.clone())?;
    let props = ExecutionProps::new();
    create_physical_expr(expr, &df_schema, &props)
}

/// Downcasts a planned expression to [`expressions::CastExpr`], panicking
/// with a descriptive message if the planner produced anything else.
fn as_planner_cast(physical: &Arc<dyn PhysicalExpr>) -> &expressions::CastExpr {
    let maybe_cast = physical.as_any().downcast_ref::<expressions::CastExpr>();
    maybe_cast.expect("planner should lower logical CAST to CastExpr")
}

#[test]
fn test_create_physical_expr_scalar_input_output() -> Result<()> {
let expr = col("letter").eq(lit("A"));
Expand All @@ -476,36 +478,63 @@ mod tests {
}

#[test]
fn test_cast_lowering_preserves_target_field_metadata() -> Result<()> {
    let schema = test_cast_schema();
    let target_field = Arc::new(
        Field::new("cast_target", DataType::Int64, true)
            .with_metadata([("target_meta".to_string(), "1".to_string())].into()),
    );
    let cast_expr = Expr::Cast(Cast::new_from_field(
        Box::new(col("a")),
        Arc::clone(&target_field),
    ));

    let physical = lower_cast_expr(&cast_expr, &schema)?;
    let cast = as_planner_cast(&physical);

    // The planner must thread the explicit target field (name, nullability,
    // metadata) through to the CastExpr unchanged.
    assert_eq!(cast.target_field(), &target_field);
    assert_eq!(physical.return_field(&schema)?, target_field);
    assert!(physical.nullable(&schema)?);

    Ok(())
}

#[test]
fn test_cast_lowering_preserves_standard_cast_semantics() -> Result<()> {
    let schema = test_cast_schema();
    let target_type = DataType::Int64;
    let cast_expr = Expr::Cast(Cast::new(Box::new(col("a")), target_type.clone()));

    let physical = lower_cast_expr(&cast_expr, &schema)?;
    let cast = as_planner_cast(&physical);
    let returned_field = physical.return_field(&schema)?;

    // A plain type-only CAST keeps the child column's name and nullability;
    // only the data type changes.
    assert_eq!(cast.cast_type(), &target_type);
    assert_eq!(returned_field.name(), "a");
    assert_eq!(returned_field.data_type(), &target_type);
    assert!(!physical.nullable(&schema)?);

    Ok(())
}

#[test]
fn test_cast_lowering_preserves_same_type_field_semantics() -> Result<()> {
    let schema = test_cast_schema();
    let target_field = Arc::new(
        Field::new("same_type_cast", DataType::Int32, true).with_metadata(
            [("target_meta".to_string(), "same-type".to_string())].into(),
        ),
    );
    let cast_expr = Expr::Cast(Cast::new_from_field(
        Box::new(col("a")),
        Arc::clone(&target_field),
    ));

    let physical = lower_cast_expr(&cast_expr, &schema)?;
    let cast = as_planner_cast(&physical);

    // Even though the data type is unchanged (Int32 -> Int32), the explicit
    // target field's name, nullability, and metadata must survive lowering
    // rather than being elided back to the bare column.
    assert_eq!(cast.target_field(), &target_field);
    assert_eq!(physical.return_field(&schema)?, target_field);
    assert!(physical.nullable(&schema)?);

    Ok(())
}
Expand Down
31 changes: 30 additions & 1 deletion datafusion/sqllogictest/src/test_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,17 @@ use arrow::array::{
LargeStringArray, StringArray, TimestampNanosecondArray, UnionArray,
};
use arrow::buffer::ScalarBuffer;
use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit, UnionFields};
use arrow::datatypes::{
DataType, Field, FieldRef, Schema, SchemaRef, TimeUnit, UnionFields,
};
use arrow::record_batch::RecordBatch;
use datafusion::catalog::{
CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider, SchemaProvider, Session,
};
use datafusion::common::{DataFusionError, Result, not_impl_err};
use datafusion::functions::math::abs;
use datafusion::logical_expr::async_udf::{AsyncScalarUDF, AsyncScalarUDFImpl};
use datafusion::logical_expr::planner::TypePlanner;
use datafusion::logical_expr::{
ColumnarValue, Expr, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
Volatility, create_udf,
Expand All @@ -54,6 +57,7 @@ use datafusion::common::cast::as_float64_array;
use datafusion::execution::SessionStateBuilder;
use datafusion::execution::runtime_env::RuntimeEnv;
use log::info;
use sqlparser::ast;
use tempfile::TempDir;

/// Context for running tests
Expand All @@ -64,6 +68,23 @@ pub struct TestContext {
test_dir: Option<TempDir>,
}

/// Test-only type planner that maps the SQL `UUID` type to a
/// `FixedSizeBinary(16)` field tagged with the `arrow.uuid` extension-type
/// metadata; all other SQL types fall through to the default planner.
#[derive(Debug)]
struct SqlLogicTestTypePlanner;

impl TypePlanner for SqlLogicTestTypePlanner {
    fn plan_type_field(&self, sql_type: &ast::DataType) -> Result<Option<FieldRef>> {
        // Only UUID receives special handling; returning None delegates every
        // other type to the default planning path.
        if !matches!(sql_type, ast::DataType::Uuid) {
            return Ok(None);
        }

        let metadata =
            [("ARROW:extension:name".to_string(), "arrow.uuid".to_string())].into();
        let field =
            Field::new("", DataType::FixedSizeBinary(16), true).with_metadata(metadata);
        Ok(Some(Arc::new(field)))
    }
}

impl TestContext {
pub fn new(ctx: SessionContext) -> Self {
Self {
Expand Down Expand Up @@ -92,6 +113,14 @@ impl TestContext {
state_builder = state_builder.with_spark_features();
}

if matches!(
relative_path.file_name().and_then(|name| name.to_str()),
Some("cast_extension_type_metadata.slt")
) {
state_builder =
state_builder.with_type_planner(Arc::new(SqlLogicTestTypePlanner));
}

let state = state_builder.build();

let mut test_ctx = TestContext::new(SessionContext::new_with_state(state));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# Regression tests for logical CAST targets that carry explicit field metadata.

query ?T
SELECT
CAST(
arrow_cast(X'00010203040506070809000102030506', 'FixedSizeBinary(16)')
AS UUID
),
arrow_metadata(
CAST(
arrow_cast(X'00010203040506070809000102030506', 'FixedSizeBinary(16)')
AS UUID
),
'ARROW:extension:name'
);
----
00010203040506070809000102030506 arrow.uuid

query ?T
SELECT
CAST(raw AS UUID),
arrow_metadata(CAST(raw AS UUID), 'ARROW:extension:name')
FROM (
VALUES (
arrow_cast(X'00010203040506070809000102030506', 'FixedSizeBinary(16)')
)
) AS uuids(raw);
----
00010203040506070809000102030506 arrow.uuid

statement error DataFusion error: Optimizer rule 'simplify_expressions' failed[\s\S]*TryCast from FixedSizeBinary\(16\) to FixedSizeBinary\(16\)<\{"ARROW:extension:name": "arrow\.uuid"\}> is not supported
SELECT TRY_CAST(arrow_cast(X'00010203040506070809000102030506', 'FixedSizeBinary(16)') AS UUID);
Loading