diff --git a/datafusion/functions-nested/src/cosine_distance.rs b/datafusion/functions-nested/src/cosine_distance.rs new file mode 100644 index 0000000000000..ab616ec7942d2 --- /dev/null +++ b/datafusion/functions-nested/src/cosine_distance.rs @@ -0,0 +1,322 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ScalarUDFImpl`] definitions for cosine_distance function. + +use crate::utils::make_scalar_function; +use crate::vector_math::{convert_to_f64_array, dot_product_f64, magnitude_f64}; +use arrow::array::{ + Array, ArrayRef, Float64Array, LargeListArray, ListArray, OffsetSizeTrait, +}; +use arrow::datatypes::{ + DataType, + DataType::{FixedSizeList, LargeList, List, Null}, +}; +use datafusion_common::cast::as_generic_list_array; +use datafusion_common::utils::{ListCoercion, coerced_type_with_base_type_only}; +use datafusion_common::{Result, exec_err, plan_err, utils::take_function_args}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, +}; +use datafusion_functions::downcast_arg; +use datafusion_macros::user_doc; +use itertools::Itertools; +use std::sync::Arc; + +make_udf_expr_and_func!( + CosineDistance, + cosine_distance, + array1 array2, + "returns the cosine distance between two numeric arrays.", + cosine_distance_udf +); + +#[user_doc( + doc_section(label = "Array Functions"), + description = "Returns the cosine distance between two input arrays of equal length. The cosine distance is defined as 1 - cosine_similarity, i.e. `1 - dot(a,b) / (||a|| * ||b||)`. Returns NULL if either array is NULL or contains only zeros.", + syntax_example = "cosine_distance(array1, array2)", + sql_example = r#"```sql +> select cosine_distance([1.0, 0.0], [0.0, 1.0]); ++-----------------------------------------------+ +| cosine_distance(List([1.0,0.0]),List([0.0,1.0])) | ++-----------------------------------------------+ +| 1.0 | ++-----------------------------------------------+ +```"#, + argument( + name = "array1", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ), + argument( + name = "array2", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ) +)] +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct CosineDistance { + signature: Signature, + aliases: Vec, +} + +impl Default for CosineDistance { + fn default() -> Self { + Self::new() + } +} + +impl CosineDistance { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + aliases: vec!["list_cosine_distance".to_string()], + } + } +} + +impl ScalarUDFImpl for CosineDistance { + fn name(&self) -> &str { + "cosine_distance" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + let [_, _] = take_function_args(self.name(), arg_types)?; + let coercion = Some(&ListCoercion::FixedSizedListToList); + let arg_types = arg_types.iter().map(|arg_type| { + if matches!(arg_type, Null | List(_) | LargeList(_) | FixedSizeList(..)) { + Ok(coerced_type_with_base_type_only( + arg_type, + &DataType::Float64, + coercion, + )) + } else { + plan_err!("{} does not support type {arg_type}", self.name()) + } + }); + + arg_types.try_collect() + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(cosine_distance_inner)(&args.args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +fn cosine_distance_inner(args: &[ArrayRef]) -> Result { + let [array1, array2] = take_function_args("cosine_distance", args)?; + match (array1.data_type(), array2.data_type()) { + (List(_), List(_)) => general_cosine_distance::(args), + (LargeList(_), LargeList(_)) => general_cosine_distance::(args), + (arg_type1, arg_type2) => { + exec_err!( + "cosine_distance does not support types {arg_type1} and {arg_type2}" + ) + } + } +} + +fn general_cosine_distance(arrays: &[ArrayRef]) -> Result { + let list_array1 = as_generic_list_array::(&arrays[0])?; + let list_array2 = as_generic_list_array::(&arrays[1])?; + + let result = list_array1 + .iter() + .zip(list_array2.iter()) + .map(|(arr1, arr2)| compute_cosine_distance(arr1, arr2)) + .collect::>()?; + + Ok(Arc::new(result) as ArrayRef) +} + +/// Computes the cosine distance between two arrays: 1 - dot(a,b) / (||a|| * ||b||) +fn compute_cosine_distance( + arr1: Option, + arr2: Option, +) -> Result> { + let value1 = match arr1 { + Some(arr) => arr, + None => return Ok(None), + }; + let value2 = match arr2 { + Some(arr) => arr, + None => return Ok(None), + }; + + let mut value1 = value1; + let mut value2 = value2; + + loop { + match value1.data_type() { + List(_) => { + if downcast_arg!(value1, ListArray).null_count() > 0 { + return Ok(None); + } + value1 = downcast_arg!(value1, ListArray).value(0); + } + LargeList(_) => { + if downcast_arg!(value1, LargeListArray).null_count() > 0 { + return Ok(None); + } + value1 = downcast_arg!(value1, LargeListArray).value(0); + } + _ => break, + } + + match value2.data_type() { + List(_) => { + if downcast_arg!(value2, ListArray).null_count() > 0 { + return Ok(None); + } + value2 = downcast_arg!(value2, ListArray).value(0); + } + LargeList(_) => { + if downcast_arg!(value2, LargeListArray).null_count() > 0 { + return Ok(None); + } + value2 = downcast_arg!(value2, LargeListArray).value(0); + } + _ => break, + } + } + + if value1.null_count() != 0 || value2.null_count() != 0 { + return Ok(None); + } + + let values1 = convert_to_f64_array(&value1)?; + let values2 = convert_to_f64_array(&value2)?; + + if values1.len() != values2.len() { + return exec_err!("Both arrays must have the same length"); + } + + let dot = dot_product_f64(&values1, &values2); + let mag1 = magnitude_f64(&values1); + let mag2 = magnitude_f64(&values2); + + if mag1 == 0.0 || mag2 == 0.0 { + return Ok(None); + } + + Ok(Some(1.0 - dot / (mag1 * mag2))) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{Float64Array, ListArray}; + use arrow::buffer::OffsetBuffer; + use arrow::datatypes::{DataType, Field}; + use std::sync::Arc; + + fn make_f64_list_array(values: Vec>>>) -> ArrayRef { + let mut flat: Vec> = Vec::new(); + let mut offsets: Vec = vec![0]; + for v in &values { + match v { + Some(inner) => { + flat.extend(inner); + offsets.push(flat.len() as i32); + } + None => { + offsets.push(flat.len() as i32); + } + } + } + let values_array = Arc::new(Float64Array::from(flat)) as ArrayRef; + let field = Arc::new(Field::new_list_field(DataType::Float64, true)); + let offset_buffer = OffsetBuffer::new(offsets.into()); + let null_buffer = arrow::buffer::NullBuffer::from( + values.iter().map(|v| v.is_some()).collect::>(), + ); + Arc::new(ListArray::new( + field, + offset_buffer, + values_array, + Some(null_buffer), + )) + } + + #[test] + fn test_cosine_distance_orthogonal() { + let arr1 = make_f64_list_array(vec![Some(vec![Some(1.0), Some(0.0)])]); + let arr2 = make_f64_list_array(vec![Some(vec![Some(0.0), Some(1.0)])]); + let result = cosine_distance_inner(&[arr1, arr2]).unwrap(); + let result = result.as_any().downcast_ref::().unwrap(); + assert!((result.value(0) - 1.0).abs() < 1e-10); + } + + #[test] + fn test_cosine_distance_identical() { + let arr1 = make_f64_list_array(vec![Some(vec![Some(1.0), Some(2.0), Some(3.0)])]); + let arr2 = make_f64_list_array(vec![Some(vec![Some(1.0), Some(2.0), Some(3.0)])]); + let result = cosine_distance_inner(&[arr1, arr2]).unwrap(); + let result = result.as_any().downcast_ref::().unwrap(); + assert!(result.value(0).abs() < 1e-10); + } + + #[test] + fn test_cosine_distance_opposite() { + let arr1 = make_f64_list_array(vec![Some(vec![Some(1.0), Some(0.0)])]); + let arr2 = make_f64_list_array(vec![Some(vec![Some(-1.0), Some(0.0)])]); + let result = cosine_distance_inner(&[arr1, arr2]).unwrap(); + let result = result.as_any().downcast_ref::().unwrap(); + assert!((result.value(0) - 2.0).abs() < 1e-10); + } + + #[test] + fn test_cosine_distance_null_array() { + let arr1 = make_f64_list_array(vec![None]); + let arr2 = make_f64_list_array(vec![Some(vec![Some(1.0), Some(2.0)])]); + let result = cosine_distance_inner(&[arr1, arr2]).unwrap(); + let result = result.as_any().downcast_ref::().unwrap(); + assert!(result.is_null(0)); + } + + #[test] + fn test_cosine_distance_mismatched_lengths() { + let arr1 = make_f64_list_array(vec![Some(vec![Some(1.0), Some(2.0)])]); + let arr2 = make_f64_list_array(vec![Some(vec![Some(1.0)])]); + let result = cosine_distance_inner(&[arr1, arr2]); + assert!(result.is_err()); + } + + #[test] + fn test_cosine_distance_zero_magnitude() { + let arr1 = make_f64_list_array(vec![Some(vec![Some(0.0), Some(0.0)])]); + let arr2 = make_f64_list_array(vec![Some(vec![Some(1.0), Some(0.0)])]); + let result = cosine_distance_inner(&[arr1, arr2]).unwrap(); + let result = result.as_any().downcast_ref::().unwrap(); + assert!(result.is_null(0)); + } +} diff --git a/datafusion/functions-nested/src/lib.rs b/datafusion/functions-nested/src/lib.rs index 99b25ec96454b..716e72790b70d 100644 --- a/datafusion/functions-nested/src/lib.rs +++ b/datafusion/functions-nested/src/lib.rs @@ -41,6 +41,7 @@ pub mod array_has; pub mod arrays_zip; pub mod cardinality; pub mod concat; +pub mod cosine_distance; pub mod dimension; pub mod distance; pub mod empty; @@ -68,6 +69,7 @@ pub mod set_ops; pub mod sort; pub mod string; pub mod utils; +pub mod vector_math; use datafusion_common::Result; use datafusion_execution::FunctionRegistry; @@ -85,6 +87,7 @@ pub mod expr_fn { pub use super::concat::array_append; pub use super::concat::array_concat; pub use super::concat::array_prepend; + pub use super::cosine_distance::cosine_distance; pub use super::dimension::array_dims; pub use super::dimension::array_ndims; pub use super::distance::array_distance; @@ -150,6 +153,7 @@ pub fn all_default_nested_functions() -> Vec> { array_has::array_has_any_udf(), empty::array_empty_udf(), length::array_length_udf(), + cosine_distance::cosine_distance_udf(), distance::array_distance_udf(), flatten::flatten_udf(), min_max::array_max_udf(), diff --git a/datafusion/functions-nested/src/vector_math.rs b/datafusion/functions-nested/src/vector_math.rs new file mode 100644 index 0000000000000..02b8772cab915 --- /dev/null +++ b/datafusion/functions-nested/src/vector_math.rs @@ -0,0 +1,68 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Shared vector math primitives used by cosine_distance, inner_product, +//! array_normalize, and related functions. + +use arrow::array::{ArrayRef, Float64Array}; +use datafusion_common::cast::{ + as_float32_array, as_float64_array, as_int32_array, as_int64_array, +}; +use datafusion_common::{Result, exec_err}; + +/// Converts an array of any numeric type to a Float64Array. +pub fn convert_to_f64_array(array: &ArrayRef) -> Result { + match array.data_type() { + arrow::datatypes::DataType::Float64 => Ok(as_float64_array(array)?.clone()), + arrow::datatypes::DataType::Float32 => { + let array = as_float32_array(array)?; + Ok(array.iter().map(|v| v.map(|v| v as f64)).collect()) + } + arrow::datatypes::DataType::Int64 => { + let array = as_int64_array(array)?; + Ok(array.iter().map(|v| v.map(|v| v as f64)).collect()) + } + arrow::datatypes::DataType::Int32 => { + let array = as_int32_array(array)?; + Ok(array.iter().map(|v| v.map(|v| v as f64)).collect()) + } + _ => exec_err!("Unsupported array type for conversion to Float64Array"), + } +} + +/// Computes dot product: sum(a\[i\] * b\[i\]) +pub fn dot_product_f64(a: &Float64Array, b: &Float64Array) -> f64 { + a.iter() + .zip(b.iter()) + .map(|(v1, v2)| v1.unwrap_or(0.0) * v2.unwrap_or(0.0)) + .sum() +} + +/// Computes sum of squares: sum(a\[i\]^2) +pub fn sum_of_squares_f64(a: &Float64Array) -> f64 { + a.iter() + .map(|v| { + let val = v.unwrap_or(0.0); + val * val + }) + .sum() +} + +/// Computes magnitude (L2 norm): sqrt(sum(a\[i\]^2)) +pub fn magnitude_f64(a: &Float64Array) -> f64 { + sum_of_squares_f64(a).sqrt() +} diff --git a/datafusion/sqllogictest/test_files/cosine_distance.slt b/datafusion/sqllogictest/test_files/cosine_distance.slt new file mode 100644 index 0000000000000..4daba225250be --- /dev/null +++ b/datafusion/sqllogictest/test_files/cosine_distance.slt @@ -0,0 +1,109 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +## cosine_distance + +# Orthogonal vectors: distance = 1.0 +query R +select cosine_distance([1.0, 0.0], [0.0, 1.0]); +---- +1 + +# Identical vectors: distance = 0.0 +query R +select cosine_distance([1.0, 2.0, 3.0], [1.0, 2.0, 3.0]); +---- +0 + +# Opposite vectors: distance = 2.0 +query R +select cosine_distance([1.0, 0.0], [-1.0, 0.0]); +---- +2 + +# 45-degree angle: distance ≈ 0.293 +query R +select round(cosine_distance([1.0, 0.0], [1.0, 1.0]), 3); +---- +0.293 + +# NULL input (bare NULL is not a list type, errors at planning) +query error cosine_distance does not support type +select cosine_distance(NULL, [1.0, 2.0]); + +# NULL in second position +query error cosine_distance does not support type +select cosine_distance([1.0, 2.0], NULL); + +# Zero vector returns NULL (undefined cosine similarity) +query R +select cosine_distance([0.0, 0.0], [1.0, 2.0]); +---- +NULL + +# Mismatched lengths error +query error Both arrays must have the same length +select cosine_distance([1.0, 2.0], [1.0]); + +# LargeList support +query R +select cosine_distance( + arrow_cast([1.0, 0.0], 'LargeList(Float64)'), + arrow_cast([0.0, 1.0], 'LargeList(Float64)') +); +---- +1 + +# Integer arrays (coerced to Float64) +query R +select cosine_distance([1, 0], [0, 1]); +---- +1 + +# Multi-row query +query R +select cosine_distance(column1, column2) from (values + (make_array(1.0, 0.0), make_array(0.0, 1.0)), + (make_array(1.0, 1.0), make_array(1.0, 1.0)), + (make_array(1.0, 0.0), make_array(-1.0, 0.0)) +) as t(column1, column2); +---- +1 +0 +2 + +# list_cosine_distance alias +query R +select list_cosine_distance([1.0, 0.0], [0.0, 1.0]); +---- +1 + +# Empty arrays return NULL (magnitude = 0) +query R +select cosine_distance(arrow_cast(make_array(), 'List(Float64)'), arrow_cast(make_array(), 'List(Float64)')); +---- +NULL + +# No arguments error +query error cosine_distance function requires 2 arguments, got 0 +select cosine_distance(); + +# Return type is Float64 +query RT +select cosine_distance([1.0, 0.0], [0.0, 1.0]), arrow_typeof(cosine_distance([1.0, 0.0], [0.0, 1.0])); +---- +1 Float64 diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index d1b80f1f90b8b..84455402571ee 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -3279,6 +3279,7 @@ _Alias of [current_date](#current_date)._ - [arrays_overlap](#arrays_overlap) - [arrays_zip](#arrays_zip) - [cardinality](#cardinality) +- [cosine_distance](#cosine_distance) - [empty](#empty) - [flatten](#flatten) - [generate_series](#generate_series) @@ -3287,6 +3288,7 @@ _Alias of [current_date](#current_date)._ - [list_cat](#list_cat) - [list_concat](#list_concat) - [list_contains](#list_contains) +- [list_cosine_distance](#list_cosine_distance) - [list_dims](#list_dims) - [list_distance](#list_distance) - [list_distinct](#list_distinct) @@ -4441,6 +4443,34 @@ cardinality(array) +--------------------------------------+ ``` +### `cosine_distance` + +Returns the cosine distance between two input arrays of equal length. The cosine distance is defined as 1 - cosine_similarity, i.e. `1 - dot(a,b) / (||a|| * ||b||)`. Returns NULL if either array is NULL or contains only zeros. + +```sql +cosine_distance(array1, array2) +``` + +#### Arguments + +- **array1**: Array expression. Can be a constant, column, or function, and any combination of array operators. +- **array2**: Array expression. Can be a constant, column, or function, and any combination of array operators. + +#### Example + +```sql +> select cosine_distance([1.0, 0.0], [0.0, 1.0]); ++-----------------------------------------------+ +| cosine_distance(List([1.0,0.0]),List([0.0,1.0])) | ++-----------------------------------------------+ +| 1.0 | ++-----------------------------------------------+ +``` + +#### Aliases + +- list_cosine_distance + ### `empty` Returns 1 for an empty array or 0 for a non-empty array. @@ -4543,6 +4573,10 @@ _Alias of [array_concat](#array_concat)._ _Alias of [array_has](#array_has)._ +### `list_cosine_distance` + +_Alias of [cosine_distance](#cosine_distance)._ + ### `list_dims` _Alias of [array_dims](#array_dims)._