diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 3487c7e70968..f26eefb8d6e6 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -29,7 +29,7 @@ use arrow::array::*; use arrow::buffer::{BooleanBuffer, NullBuffer}; use arrow::compute::kernels::boolean::{not, or_kleene}; use arrow::compute::kernels::cmp::eq as arrow_eq; -use arrow::compute::{SortOptions, take}; +use arrow::compute::{SortOptions, cast, take}; use arrow::datatypes::*; use arrow::util::bit_iterator::BitIndexIterator; use datafusion_common::hash_utils::with_hashes; @@ -43,11 +43,21 @@ use datafusion_common::HashMap; use datafusion_common::hash_utils::RandomState; use hashbrown::hash_map::RawEntryMut; -/// Trait for InList static filters +/// Trait for InList static filters. +/// +/// Static filters store a pre-computed set of values (the haystack) and check +/// whether needle values are contained in that set. The haystack is always +/// represented in its non-dictionary (value) type. Dictionary haystacks are +/// flattened via `cast()` before construction. +/// +/// Dictionary-encoded needles are unwrapped inside `contains()` and +/// evaluated against the dictionary's values. trait StaticFilter { fn null_count(&self) -> usize; - /// Checks if values in `v` are contained in the filter + /// Checks if values in `v` (needle) are contained in this filter's + /// haystack. `v` may be dictionary-encoded, in which case the + /// implementation unwraps the dictionary and operates on its values. fn contains(&self, v: &dyn Array, negated: bool) -> Result; } @@ -164,6 +174,13 @@ fn supports_arrow_eq(dt: &DataType) -> bool { fn instantiate_static_filter( in_array: ArrayRef, ) -> Result> { + // Flatten dictionary-encoded haystacks to their value type so that + // specialized filters (e.g. Int32StaticFilter) are used instead of + // falling through to the generic ArrayStaticFilter. + let in_array = match in_array.data_type() { + DataType::Dictionary(_, value_type) => cast(&in_array, value_type.as_ref())?, + _ => in_array, + }; match in_array.data_type() { // Integer primitive types DataType::Int8 => Ok(Arc::new(Int8StaticFilter::try_new(&in_array)?)), @@ -642,20 +659,34 @@ impl InListExpr { /// Create a new InList expression directly from an array, bypassing expression evaluation. /// - /// This is more efficient than `in_list()` when you already have the list as an array, - /// as it avoids the conversion: `ArrayRef -> Vec -> ArrayRef -> StaticFilter`. - /// Instead it goes directly: `ArrayRef -> StaticFilter`. + /// This is more efficient than [`InListExpr::try_new`] when you already have the list + /// as an array, as it builds the static filter directly from the array instead of + /// reconstructing an intermediate array from literal expressions. + /// + /// The `list` field is populated with literal expressions extracted from + /// the array, and the array is used to build a static filter for + /// efficient set membership evaluation. /// - /// The `list` field will be empty when using this constructor, as the array is stored - /// directly in the static filter. + /// The `array` may be dictionary-encoded — it will be flattened to its + /// value type such that specialized filters are used. /// - /// This does not make the expression any more performant at runtime, but it does make it slightly - /// cheaper to build. + /// Returns an error if the expression's data type and the array's data type + /// are not logically equal. Null arrays are always accepted. pub fn try_new_from_array( expr: Arc, array: ArrayRef, negated: bool, + schema: &Schema, ) -> Result { + let expr_data_type = expr.data_type(schema)?; + let array_data_type = array.data_type(); + if *array_data_type != DataType::Null { + assert_or_internal_err!( + DFSchema::datatype_is_logically_equal(&expr_data_type, array_data_type), + "The data type inlist should be same, the value type is {expr_data_type}, one of list expr type is {array_data_type}" + ); + } + let list = (0..array.len()) .map(|i| { let scalar = ScalarValue::try_from_array(array.as_ref(), i)?; @@ -2318,6 +2349,7 @@ mod tests { Arc::clone(&col_a), array, false, + &schema, )?) as Arc; // Create test data: [1, 2, 3, 4, null] @@ -2447,6 +2479,7 @@ mod tests { Arc::clone(&col_a), null_array, false, + &schema, )?) as Arc; let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; @@ -2475,6 +2508,7 @@ mod tests { Arc::clone(&col_a), null_array, false, + &schema, )?) as Arc; let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; @@ -3911,8 +3945,9 @@ mod tests { let schema = Schema::new(vec![Field::new("a", needle.data_type().clone(), false)]); let col_a = col("a", &schema)?; - let expr = Arc::new(InListExpr::try_new_from_array(col_a, in_array, false)?) - as Arc; + let expr = Arc::new(InListExpr::try_new_from_array( + col_a, in_array, false, &schema, + )?) as Arc; let batch = RecordBatch::try_new(Arc::new(schema), vec![needle])?; let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; Ok(as_boolean_array(&result).clone()) @@ -4045,43 +4080,182 @@ mod tests { Ok(()) } + fn make_int32_dict_array(values: Vec>) -> ArrayRef { + let mut builder = PrimitiveDictionaryBuilder::::new(); + for v in values { + match v { + Some(val) => builder.append_value(val), + None => builder.append_null(), + } + } + Arc::new(builder.finish()) + } + + fn make_f64_dict_array(values: Vec>) -> ArrayRef { + let mut builder = PrimitiveDictionaryBuilder::::new(); + for v in values { + match v { + Some(val) => builder.append_value(val), + None => builder.append_null(), + } + } + Arc::new(builder.finish()) + } + + #[test] + fn test_try_new_from_array_dict_haystack_int32() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let needle = Int32Array::from(vec![1, 2, 3, 4]); + let batch = + RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(needle)])?; + + let haystack = make_int32_dict_array(vec![Some(1), None, Some(3)]); + + let col_a = col("a", &schema)?; + let expr = InListExpr::try_new_from_array(col_a, haystack, false, &schema)?; + let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; + let result = as_boolean_array(&result); + assert_eq!( + result, + &BooleanArray::from(vec![Some(true), None, Some(true), None]) + ); + + Ok(()) + } + #[test] fn test_in_list_from_array_type_mismatch_errors() -> Result<()> { - // Utf8 needle, Dict(Utf8) in_array - let err = eval_in_list_from_array( - Arc::new(StringArray::from(vec!["a", "d", "b"])), - wrap_in_dict(Arc::new(StringArray::from(vec!["a", "b", "c"]))), - ) - .unwrap_err() - .to_string(); - assert!( - err.contains("Can't compare arrays of different types"), - "{err}" + // Utf8 needle, Dict(Utf8) in_array: now works with dict haystack support + assert_eq!( + BooleanArray::from(vec![Some(true), Some(false), Some(true)]), + eval_in_list_from_array( + Arc::new(StringArray::from(vec!["a", "d", "b"])), + wrap_in_dict(Arc::new(StringArray::from(vec!["a", "b", "c"]))), + )? ); - // Dict(Utf8) needle, Int64 in_array: specialized Int64StaticFilter - // rejects the Utf8 dictionary values at construction time + // Dict(Utf8) needle, Int64 in_array: type validation rejects at construction let err = eval_in_list_from_array( wrap_in_dict(Arc::new(StringArray::from(vec!["a", "d", "b"]))), Arc::new(Int64Array::from(vec![1, 2, 3])), ) .unwrap_err() .to_string(); - assert!(err.contains("Failed to downcast"), "{err}"); + assert!(err.contains("The data type inlist should be same"), "{err}"); // Dict(Int64) needle, Dict(Utf8) in_array: both Dict but different - // value types, make_comparator rejects the comparison + // value types, type validation rejects at construction let err = eval_in_list_from_array( wrap_in_dict(Arc::new(Int64Array::from(vec![1, 4, 2]))), wrap_in_dict(Arc::new(StringArray::from(vec!["a", "b", "c"]))), ) .unwrap_err() .to_string(); - assert!( - err.contains("Can't compare arrays of different types"), - "{err}" + assert!(err.contains("The data type inlist should be same"), "{err}"); + + Ok(()) + } + + #[test] + fn test_try_new_from_array_dict_haystack_negated() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let needle = Int32Array::from(vec![1, 2, 3, 4]); + let batch = + RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(needle)])?; + + let haystack = make_int32_dict_array(vec![Some(1), None, Some(3)]); + + let col_a = col("a", &schema)?; + let expr = InListExpr::try_new_from_array(col_a, haystack, true, &schema)?; + let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; + let result = as_boolean_array(&result); + assert_eq!( + result, + &BooleanArray::from(vec![Some(false), None, Some(false), None]) + ); + + Ok(()) + } + + #[test] + fn test_try_new_from_array_dict_haystack_utf8() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]); + let needle = StringArray::from(vec!["a", "b", "c"]); + let batch = + RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(needle)])?; + + let dict_builder = StringDictionaryBuilder::::new(); + let mut builder = dict_builder; + builder.append_value("a"); + builder.append_value("c"); + let haystack: ArrayRef = Arc::new(builder.finish()); + + let col_a = col("a", &schema)?; + let expr = InListExpr::try_new_from_array(col_a, haystack, false, &schema)?; + let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; + let result = as_boolean_array(&result); + assert_eq!( + result, + &BooleanArray::from(vec![Some(true), Some(false), Some(true)]) + ); + + Ok(()) + } + + #[test] + fn test_try_new_from_array_dict_needle_and_plain_haystack() -> Result<()> { + let schema = Schema::new(vec![Field::new( + "a", + DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)), + false, + )]); + + let needle = make_int32_dict_array(vec![Some(1), Some(2), Some(3), Some(4)]); + let batch = + RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::clone(&needle)])?; + + let haystack: ArrayRef = Arc::new(Int32Array::from(vec![1, 3])); + let col_a = col("a", &schema)?; + let expr = InListExpr::try_new_from_array(col_a, haystack, false, &schema)?; + let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; + let result = as_boolean_array(&result); + assert_eq!( + result, + &BooleanArray::from(vec![Some(true), Some(false), Some(true), Some(false)]) + ); + + Ok(()) + } + + #[test] + fn test_try_new_from_array_dict_haystack_float64() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]); + let needle = Float64Array::from(vec![1.0, 2.0, 3.0]); + let batch = + RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(needle)])?; + + let haystack = make_f64_dict_array(vec![Some(1.0), Some(3.0)]); + + let col_a = col("a", &schema)?; + let expr = InListExpr::try_new_from_array(col_a, haystack, false, &schema)?; + let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; + let result = as_boolean_array(&result); + assert_eq!( + result, + &BooleanArray::from(vec![Some(true), Some(false), Some(true)]) ); Ok(()) } + + #[test] + fn test_try_new_from_array_type_mismatch_rejects() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let col_a = col("a", &schema)?; + let haystack: ArrayRef = Arc::new(Float64Array::from(vec![1.0, 2.0])); + + let result = InListExpr::try_new_from_array(col_a, haystack, false, &schema); + assert!(result.is_err()); + Ok(()) + } } diff --git a/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs b/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs index f32dc7fa8026..33df5136afda 100644 --- a/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs +++ b/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs @@ -125,6 +125,7 @@ fn create_membership_predicate( expr, in_list_array, false, + schema, )?))) } // Use hash table lookup for large build sides