Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 109 additions & 0 deletions datafusion/physical-expr/benches/case_when.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use arrow::datatypes::{ArrowNativeTypeOp, Field, Schema};
use arrow::record_batch::RecordBatch;
use arrow::util::test_util::seedable_rng;
use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main};
use datafusion_common::ScalarValue;
use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::{BinaryExpr, case, col, lit};
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
Expand Down Expand Up @@ -93,6 +94,7 @@ fn criterion_benchmark(c: &mut Criterion) {
run_benchmarks(c, &make_batch(8192, 100));

benchmark_lookup_table_case_when(c, 8192);
benchmark_divide_by_zero_protection(c, 8192);
}

fn run_benchmarks(c: &mut Criterion, batch: &RecordBatch) {
Expand Down Expand Up @@ -517,5 +519,112 @@ fn benchmark_lookup_table_case_when(c: &mut Criterion, batch_size: usize) {
}
}

fn benchmark_divide_by_zero_protection(c: &mut Criterion, batch_size: usize) {
let mut group = c.benchmark_group("divide_by_zero_protection");

for zero_percentage in [0.0, 0.1, 0.5, 0.9] {
let rng = &mut seedable_rng();

let numerator: Int32Array =
(0..batch_size).map(|_| Some(rng.random::<i32>())).collect();

let divisor_values: Vec<Option<i32>> = (0..batch_size)
.map(|_| {
let roll: f32 = rng.random();
if roll < zero_percentage {
Some(0)
} else {
let mut val = rng.random::<i32>();
while val == 0 {
val = rng.random::<i32>();
}
Some(val)
}
})
.collect();

let divisor: Int32Array = divisor_values.iter().cloned().collect();
let divisor_copy: Int32Array = divisor_values.iter().cloned().collect();

let schema = Arc::new(Schema::new(vec![
Field::new("numerator", numerator.data_type().clone(), true),
Field::new("divisor", divisor.data_type().clone(), true),
Field::new("divisor_copy", divisor_copy.data_type().clone(), true),
]));

let batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(numerator),
Arc::new(divisor),
Arc::new(divisor_copy),
],
)
.unwrap();

let numerator_col = col("numerator", &batch.schema()).unwrap();
let divisor_col = col("divisor", &batch.schema()).unwrap();
let divisor_copy_col = col("divisor_copy", &batch.schema()).unwrap();

// DivideByZeroProtection: WHEN condition checks `divisor_col > 0` and division
// uses `divisor_col` as divisor. Since the checked column matches the divisor,
// this triggers the DivideByZeroProtection optimization.
group.bench_function(
format!(
"{} rows, {}% zeros: DivideByZeroProtection",
batch_size,
(zero_percentage * 100.0) as i32
),
|b| {
let when = Arc::new(BinaryExpr::new(
Arc::clone(&divisor_col),
Operator::Gt,
lit(0i32),
));
let then = Arc::new(BinaryExpr::new(
Arc::clone(&numerator_col),
Operator::Divide,
Arc::clone(&divisor_col),
));
let else_null: Arc<dyn PhysicalExpr> = lit(ScalarValue::Int32(None));
let expr =
Arc::new(case(None, vec![(when, then)], Some(else_null)).unwrap());

b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap()))
},
);

// ExpressionOrExpression: WHEN condition checks `divisor_copy_col > 0` but
// division uses `divisor_col` as divisor. Since the checked column does NOT
// match the divisor, this falls back to ExpressionOrExpression evaluation.
group.bench_function(
format!(
"{} rows, {}% zeros: ExpressionOrExpression",
batch_size,
(zero_percentage * 100.0) as i32
),
|b| {
let when = Arc::new(BinaryExpr::new(
Arc::clone(&divisor_copy_col),
Operator::Gt,
lit(0i32),
));
let then = Arc::new(BinaryExpr::new(
Arc::clone(&numerator_col),
Operator::Divide,
Arc::clone(&divisor_col),
));
let else_null: Arc<dyn PhysicalExpr> = lit(ScalarValue::Int32(None));
let expr =
Arc::new(case(None, vec![(when, then)], Some(else_null)).unwrap());

b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap()))
},
);
}

group.finish();
}

criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);