diff --git a/datafusion/physical-expr/benches/case_when.rs b/datafusion/physical-expr/benches/case_when.rs index eb0886a31e8df..d9b1b565721cd 100644 --- a/datafusion/physical-expr/benches/case_when.rs +++ b/datafusion/physical-expr/benches/case_when.rs @@ -20,6 +20,7 @@ use arrow::datatypes::{ArrowNativeTypeOp, Field, Schema}; use arrow::record_batch::RecordBatch; use arrow::util::test_util::seedable_rng; use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{BinaryExpr, case, col, lit}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; @@ -93,6 +94,7 @@ fn criterion_benchmark(c: &mut Criterion) { run_benchmarks(c, &make_batch(8192, 100)); benchmark_lookup_table_case_when(c, 8192); + benchmark_divide_by_zero_protection(c, 8192); } fn run_benchmarks(c: &mut Criterion, batch: &RecordBatch) { @@ -517,5 +519,112 @@ fn benchmark_lookup_table_case_when(c: &mut Criterion, batch_size: usize) { } } +fn benchmark_divide_by_zero_protection(c: &mut Criterion, batch_size: usize) { + let mut group = c.benchmark_group("divide_by_zero_protection"); + + for zero_percentage in [0.0, 0.1, 0.5, 0.9] { + let rng = &mut seedable_rng(); + + let numerator: Int32Array = + (0..batch_size).map(|_| Some(rng.random::())).collect(); + + let divisor_values: Vec> = (0..batch_size) + .map(|_| { + let roll: f32 = rng.random(); + if roll < zero_percentage { + Some(0) + } else { + let mut val = rng.random::(); + while val == 0 { + val = rng.random::(); + } + Some(val) + } + }) + .collect(); + + let divisor: Int32Array = divisor_values.iter().cloned().collect(); + let divisor_copy: Int32Array = divisor_values.iter().cloned().collect(); + + let schema = Arc::new(Schema::new(vec![ + Field::new("numerator", numerator.data_type().clone(), true), + Field::new("divisor", divisor.data_type().clone(), true), + Field::new("divisor_copy", divisor_copy.data_type().clone(), true), + ])); + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(numerator), + Arc::new(divisor), + Arc::new(divisor_copy), + ], + ) + .unwrap(); + + let numerator_col = col("numerator", &batch.schema()).unwrap(); + let divisor_col = col("divisor", &batch.schema()).unwrap(); + let divisor_copy_col = col("divisor_copy", &batch.schema()).unwrap(); + + // DivideByZeroProtection: WHEN condition checks `divisor_col > 0` and division + // uses `divisor_col` as divisor. Since the checked column matches the divisor, + // this triggers the DivideByZeroProtection optimization. + group.bench_function( + format!( + "{} rows, {}% zeros: DivideByZeroProtection", + batch_size, + (zero_percentage * 100.0) as i32 + ), + |b| { + let when = Arc::new(BinaryExpr::new( + Arc::clone(&divisor_col), + Operator::Gt, + lit(0i32), + )); + let then = Arc::new(BinaryExpr::new( + Arc::clone(&numerator_col), + Operator::Divide, + Arc::clone(&divisor_col), + )); + let else_null: Arc = lit(ScalarValue::Int32(None)); + let expr = + Arc::new(case(None, vec![(when, then)], Some(else_null)).unwrap()); + + b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + }, + ); + + // ExpressionOrExpression: WHEN condition checks `divisor_copy_col > 0` but + // division uses `divisor_col` as divisor. Since the checked column does NOT + // match the divisor, this falls back to ExpressionOrExpression evaluation. + group.bench_function( + format!( + "{} rows, {}% zeros: ExpressionOrExpression", + batch_size, + (zero_percentage * 100.0) as i32 + ), + |b| { + let when = Arc::new(BinaryExpr::new( + Arc::clone(&divisor_copy_col), + Operator::Gt, + lit(0i32), + )); + let then = Arc::new(BinaryExpr::new( + Arc::clone(&numerator_col), + Operator::Divide, + Arc::clone(&divisor_col), + )); + let else_null: Arc = lit(ScalarValue::Int32(None)); + let expr = + Arc::new(case(None, vec![(when, then)], Some(else_null)).unwrap()); + + b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + }, + ); + } + + group.finish(); +} + criterion_group!(benches, criterion_benchmark); criterion_main!(benches);