diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index dac208be534c..4e62c94cd325 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -17,10 +17,11 @@ mod literal_lookup_table; -use super::{Column, Literal}; +use super::{CastExpr, Column, Literal}; use crate::PhysicalExpr; -use crate::expressions::{lit, try_cast}; +use crate::expressions::{BinaryExpr, lit, try_cast}; use arrow::array::*; +use arrow::compute::kernels::numeric::div; use arrow::compute::kernels::zip::zip; use arrow::compute::{ FilterBuilder, FilterPredicate, is_not_null, not, nullif, prep_null_mask_filter, @@ -33,6 +34,7 @@ use datafusion_common::{ internal_datafusion_err, internal_err, }; use datafusion_expr::ColumnarValue; +use datafusion_expr_common::operator::Operator; use indexmap::{IndexMap, IndexSet}; use std::borrow::Cow; use std::hash::Hash; @@ -86,6 +88,14 @@ enum EvalMethod { /// /// See [`LiteralLookupTable`] for more details WithExprScalarLookupTable(LiteralLookupTable), + + /// This is a specialization for divide-by-zero protection pattern: + /// CASE WHEN y > 0 THEN x / y ELSE NULL END + /// CASE WHEN y != 0 THEN x / y ELSE NULL END + /// + /// Instead of evaluating the full CASE expression, it is preferred to directly perform division + /// that return NULL when the divisor is zero. + DivideByZeroProtection, } /// Implementing hash so we can use `derive` on [`EvalMethod`]. @@ -652,6 +662,20 @@ impl CaseExpr { return Ok(EvalMethod::WithExpression(body.project()?)); } + // Check for divide-by-zero protection pattern: + // CASE WHEN y > 0 THEN x / y ELSE NULL END + if body.when_then_expr.len() == 1 && body.else_expr.is_none() { + let (when_expr, then_expr) = &body.when_then_expr[0]; + + if let Some(checked_operand) = Self::extract_non_zero_operand(when_expr) + && let Some((_numerator, divisor)) = + Self::extract_division_operands(then_expr) + && divisor.eq(&checked_operand) + { + return Ok(EvalMethod::DivideByZeroProtection); + } + } + Ok( if body.when_then_expr.len() == 1 && is_cheap_and_infallible(&(body.when_then_expr[0].1)) @@ -686,6 +710,67 @@ impl CaseExpr { pub fn else_expr(&self) -> Option<&Arc> { self.body.else_expr.as_ref() } + + /// Extract the operand being checked for non-zero from a comparison expression. + /// Return Some(operand) for patterns like `y > 0`, `y != 0`, `0 < y`, `0 != y`. + fn extract_non_zero_operand( + expr: &Arc, + ) -> Option> { + let binary = expr.as_any().downcast_ref::()?; + + match binary.op() { + // y > 0 or y != 0 + Operator::Gt | Operator::NotEq if Self::is_literal_zero(binary.right()) => { + Some(Arc::clone(binary.left())) + } + // 0 < y or 0 != y + Operator::Lt | Operator::NotEq if Self::is_literal_zero(binary.left()) => { + Some(Arc::clone(binary.right())) + } + _ => None, + } + } + + /// Extract (numerator, divisor) from a division expression. + fn extract_division_operands( + expr: &Arc, + ) -> Option<(Arc, Arc)> { + let binary = expr.as_any().downcast_ref::()?; + + if binary.op() == &Operator::Divide { + let divisor = + if let Some(cast) = binary.right().as_any().downcast_ref::() { + Arc::clone(cast.expr()) + } else { + Arc::clone(binary.right()) + }; + Some((Arc::clone(binary.left()), divisor)) + } else { + None + } + } + + /// Check if an expression is a literal zero value + fn is_literal_zero(expr: &Arc) -> bool { + if let Some(lit) = expr.as_any().downcast_ref::() { + match lit.value() { + ScalarValue::Int8(Some(0)) + | ScalarValue::Int16(Some(0)) + | ScalarValue::Int32(Some(0)) + | ScalarValue::Int64(Some(0)) + | ScalarValue::UInt8(Some(0)) + | ScalarValue::UInt16(Some(0)) + | ScalarValue::UInt32(Some(0)) + | ScalarValue::UInt64(Some(0)) => true, + ScalarValue::Float16(Some(v)) if v.to_f32() == 0.0 => true, + ScalarValue::Float32(Some(v)) if *v == 0.0 => true, + ScalarValue::Float64(Some(v)) if *v == 0.0 => true, + _ => false, + } + } else { + false + } + } } impl CaseBody { @@ -1191,6 +1276,30 @@ impl CaseExpr { Ok(result) } + + fn divide_by_zero_protection(&self, batch: &RecordBatch) -> Result { + let (when_expr, then_expr) = &self.body.when_then_expr[0]; + + let when_value = when_expr.evaluate(batch)?; + + let binary = then_expr + .as_any() + .downcast_ref::() + .expect("then expression should be a binary expression"); + + let numerator = binary.left().evaluate(batch)?; + let divisor = binary.right().evaluate(batch)?; + + let num_rows = batch.num_rows(); + let num_array = numerator.into_array(num_rows)?; + let div_array = divisor.into_array(num_rows)?; + let condition_array = when_value.into_array(num_rows)?; + let condition = as_boolean_array(&condition_array)?; + + let result = safe_divide_with_mask(&num_array, &div_array, condition)?; + + Ok(ColumnarValue::Array(result)) + } } impl PhysicalExpr for CaseExpr { @@ -1289,6 +1398,7 @@ impl PhysicalExpr for CaseExpr { EvalMethod::WithExprScalarLookupTable(lookup_table) => { self.with_lookup_table(batch, lookup_table) } + EvalMethod::DivideByZeroProtection => self.divide_by_zero_protection(batch), } } @@ -1410,6 +1520,18 @@ fn replace_with_null( Ok(with_null) } +fn safe_divide_with_mask( + numerator: &ArrayRef, + divisor: &ArrayRef, + condition: &BooleanArray, +) -> Result { + let not_condition = not(condition)?; + let ones = ScalarValue::new_one(divisor.data_type())?.to_scalar()?; + let safe_divisor = zip(¬_condition, &ones, &divisor)?; + let result = div(numerator, &safe_divisor)?; + Ok(nullif(&result, ¬_condition)?) +} + /// Create a CASE expression pub fn case( expr: Option>, @@ -2319,6 +2441,65 @@ mod tests { Ok(()) } + #[test] + fn test_divide_by_zero_protection_specialization() -> Result<()> { + let batch = case_test_batch1()?; + let schema = batch.schema(); + + // CASE WHEN a > 0 THEN 25.0 / cast(a, float64) ELSE NULL END + let when = binary(col("a", &schema)?, Operator::Gt, lit(0i32), &schema)?; + let then = binary( + lit(25.0f64), + Operator::Divide, + cast(col("a", &schema)?, &schema, Float64)?, + &schema, + )?; + + let expr = CaseExpr::try_new(None, vec![(when, then)], None)?; + + assert!( + matches!(expr.eval_method, EvalMethod::DivideByZeroProtection), + "Expected DivideByZeroProtection, got {:?}", + expr.eval_method + ); + + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); + let result = as_float64_array(&result)?; + + let expected = &Float64Array::from(vec![Some(25.0), None, None, Some(5.0)]); + assert_eq!(expected, result); + + Ok(()) + } + + #[test] + fn test_divide_by_zero_protection_specialization_not_applied() -> Result<()> { + let batch = case_test_batch1()?; + let schema = batch.schema(); + + // CASE WHEN a > 0 THEN b / c ELSE NULL END + // Divisor (c) != checked operand (a), should NOT use specialization + let when = binary(col("a", &schema)?, Operator::Gt, lit(0i32), &schema)?; + let then = binary( + col("b", &schema)?, + Operator::Divide, + col("c", &schema)?, + &schema, + )?; + + let expr = CaseExpr::try_new(None, vec![(when, then)], None)?; + + assert!( + !matches!(expr.eval_method, EvalMethod::DivideByZeroProtection), + "Should NOT use DivideByZeroProtection when divisor doesn't match" + ); + + Ok(()) + } + fn make_col(name: &str, index: usize) -> Arc { Arc::new(Column::new(name, index)) }