Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 183 additions & 2 deletions datafusion/physical-expr/src/expressions/case.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@

mod literal_lookup_table;

use super::{Column, Literal};
use super::{CastExpr, Column, Literal};
use crate::PhysicalExpr;
use crate::expressions::{lit, try_cast};
use crate::expressions::{BinaryExpr, lit, try_cast};
use arrow::array::*;
use arrow::compute::kernels::numeric::div;
use arrow::compute::kernels::zip::zip;
use arrow::compute::{
FilterBuilder, FilterPredicate, is_not_null, not, nullif, prep_null_mask_filter,
Expand All @@ -33,6 +34,7 @@ use datafusion_common::{
internal_datafusion_err, internal_err,
};
use datafusion_expr::ColumnarValue;
use datafusion_expr_common::operator::Operator;
use indexmap::{IndexMap, IndexSet};
use std::borrow::Cow;
use std::hash::Hash;
Expand Down Expand Up @@ -86,6 +88,14 @@ enum EvalMethod {
///
/// See [`LiteralLookupTable`] for more details
WithExprScalarLookupTable(LiteralLookupTable),

/// This is a specialization for divide-by-zero protection pattern:
/// CASE WHEN y > 0 THEN x / y ELSE NULL END
/// CASE WHEN y != 0 THEN x / y ELSE NULL END
///
/// Instead of evaluating the full CASE expression, it is preferred to directly perform division
/// that return NULL when the divisor is zero.
DivideByZeroProtection,
}

/// Implementing hash so we can use `derive` on [`EvalMethod`].
Expand Down Expand Up @@ -652,6 +662,20 @@ impl CaseExpr {
return Ok(EvalMethod::WithExpression(body.project()?));
}

// Check for divide-by-zero protection pattern:
// CASE WHEN y > 0 THEN x / y ELSE NULL END
if body.when_then_expr.len() == 1 && body.else_expr.is_none() {
let (when_expr, then_expr) = &body.when_then_expr[0];

if let Some(checked_operand) = Self::extract_non_zero_operand(when_expr)
&& let Some((_numerator, divisor)) =
Self::extract_division_operands(then_expr)
&& divisor.eq(&checked_operand)
{
return Ok(EvalMethod::DivideByZeroProtection);
}
}

Ok(
if body.when_then_expr.len() == 1
&& is_cheap_and_infallible(&(body.when_then_expr[0].1))
Expand Down Expand Up @@ -686,6 +710,67 @@ impl CaseExpr {
pub fn else_expr(&self) -> Option<&Arc<dyn PhysicalExpr>> {
self.body.else_expr.as_ref()
}

/// Extract the operand being checked for non-zero from a comparison expression.
/// Return Some(operand) for patterns like `y > 0`, `y != 0`, `0 < y`, `0 != y`.
fn extract_non_zero_operand(
expr: &Arc<dyn PhysicalExpr>,
) -> Option<Arc<dyn PhysicalExpr>> {
let binary = expr.as_any().downcast_ref::<BinaryExpr>()?;

match binary.op() {
// y > 0 or y != 0
Operator::Gt | Operator::NotEq if Self::is_literal_zero(binary.right()) => {
Some(Arc::clone(binary.left()))
}
// 0 < y or 0 != y
Operator::Lt | Operator::NotEq if Self::is_literal_zero(binary.left()) => {
Some(Arc::clone(binary.right()))
}
_ => None,
}
}

/// Extract (numerator, divisor) from a division expression.
fn extract_division_operands(
expr: &Arc<dyn PhysicalExpr>,
) -> Option<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)> {
let binary = expr.as_any().downcast_ref::<BinaryExpr>()?;

if binary.op() == &Operator::Divide {
let divisor =
if let Some(cast) = binary.right().as_any().downcast_ref::<CastExpr>() {
Arc::clone(cast.expr())
} else {
Arc::clone(binary.right())
};
Some((Arc::clone(binary.left()), divisor))
} else {
None
}
}

/// Check if an expression is a literal zero value
fn is_literal_zero(expr: &Arc<dyn PhysicalExpr>) -> bool {
if let Some(lit) = expr.as_any().downcast_ref::<Literal>() {
match lit.value() {
ScalarValue::Int8(Some(0))
| ScalarValue::Int16(Some(0))
| ScalarValue::Int32(Some(0))
| ScalarValue::Int64(Some(0))
| ScalarValue::UInt8(Some(0))
| ScalarValue::UInt16(Some(0))
| ScalarValue::UInt32(Some(0))
| ScalarValue::UInt64(Some(0)) => true,
ScalarValue::Float16(Some(v)) if v.to_f32() == 0.0 => true,
ScalarValue::Float32(Some(v)) if *v == 0.0 => true,
ScalarValue::Float64(Some(v)) if *v == 0.0 => true,
_ => false,
}
} else {
false
}
}
}

impl CaseBody {
Expand Down Expand Up @@ -1191,6 +1276,30 @@ impl CaseExpr {

Ok(result)
}

fn divide_by_zero_protection(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
let (when_expr, then_expr) = &self.body.when_then_expr[0];

let when_value = when_expr.evaluate(batch)?;

let binary = then_expr
.as_any()
.downcast_ref::<BinaryExpr>()
.expect("then expression should be a binary expression");

let numerator = binary.left().evaluate(batch)?;
let divisor = binary.right().evaluate(batch)?;

let num_rows = batch.num_rows();
let num_array = numerator.into_array(num_rows)?;
let div_array = divisor.into_array(num_rows)?;
let condition_array = when_value.into_array(num_rows)?;
let condition = as_boolean_array(&condition_array)?;

let result = safe_divide_with_mask(&num_array, &div_array, condition)?;

Ok(ColumnarValue::Array(result))
}
}

impl PhysicalExpr for CaseExpr {
Expand Down Expand Up @@ -1289,6 +1398,7 @@ impl PhysicalExpr for CaseExpr {
EvalMethod::WithExprScalarLookupTable(lookup_table) => {
self.with_lookup_table(batch, lookup_table)
}
EvalMethod::DivideByZeroProtection => self.divide_by_zero_protection(batch),
}
}

Expand Down Expand Up @@ -1410,6 +1520,18 @@ fn replace_with_null(
Ok(with_null)
}

fn safe_divide_with_mask(
numerator: &ArrayRef,
divisor: &ArrayRef,
condition: &BooleanArray,
) -> Result<ArrayRef> {
let not_condition = not(condition)?;
let ones = ScalarValue::new_one(divisor.data_type())?.to_scalar()?;
let safe_divisor = zip(&not_condition, &ones, &divisor)?;
let result = div(numerator, &safe_divisor)?;
Ok(nullif(&result, &not_condition)?)
}

/// Create a CASE expression
pub fn case(
expr: Option<Arc<dyn PhysicalExpr>>,
Expand Down Expand Up @@ -2319,6 +2441,65 @@ mod tests {
Ok(())
}

#[test]
fn test_divide_by_zero_protection_specialization() -> Result<()> {
let batch = case_test_batch1()?;
let schema = batch.schema();

// CASE WHEN a > 0 THEN 25.0 / cast(a, float64) ELSE NULL END
let when = binary(col("a", &schema)?, Operator::Gt, lit(0i32), &schema)?;
let then = binary(
lit(25.0f64),
Operator::Divide,
cast(col("a", &schema)?, &schema, Float64)?,
&schema,
)?;

let expr = CaseExpr::try_new(None, vec![(when, then)], None)?;

assert!(
matches!(expr.eval_method, EvalMethod::DivideByZeroProtection),
"Expected DivideByZeroProtection, got {:?}",
expr.eval_method
);

let result = expr
.evaluate(&batch)?
.into_array(batch.num_rows())
.expect("Failed to convert to array");
let result = as_float64_array(&result)?;

let expected = &Float64Array::from(vec![Some(25.0), None, None, Some(5.0)]);
assert_eq!(expected, result);

Ok(())
}

#[test]
fn test_divide_by_zero_protection_specialization_not_applied() -> Result<()> {
let batch = case_test_batch1()?;
let schema = batch.schema();

// CASE WHEN a > 0 THEN b / c ELSE NULL END
// Divisor (c) != checked operand (a), should NOT use specialization
let when = binary(col("a", &schema)?, Operator::Gt, lit(0i32), &schema)?;
let then = binary(
col("b", &schema)?,
Operator::Divide,
col("c", &schema)?,
&schema,
)?;

let expr = CaseExpr::try_new(None, vec![(when, then)], None)?;

assert!(
!matches!(expr.eval_method, EvalMethod::DivideByZeroProtection),
"Should NOT use DivideByZeroProtection when divisor doesn't match"
);

Ok(())
}

fn make_col(name: &str, index: usize) -> Arc<dyn PhysicalExpr> {
Arc::new(Column::new(name, index))
}
Expand Down