Skip to content

Commit bd50698

Browse files
authored
Add array_distance function (#12211)
* Add `distance` aggregation function Signed-off-by: Austin Liu <[email protected]> Add `distance` aggregation function Signed-off-by: Austin Liu <[email protected]> * Add sql logic test for `distance` Signed-off-by: Austin Liu <[email protected]> * Simplify diff calculation Signed-off-by: Austin Liu <[email protected]> * Add `array_distance`/`list_distance` as list function in functions-nested Signed-off-by: Austin Liu <[email protected]> * Remove aggregate function `distance` Signed-off-by: Austin Liu <[email protected]> * format Signed-off-by: Austin Liu <[email protected]> * clean up error handling Signed-off-by: Austin Liu <[email protected]> * Add `array_distance` in scalar array functions docs Signed-off-by: Austin Liu <[email protected]> * Update bulletin Signed-off-by: Austin Liu <[email protected]> * Prettify example Signed-off-by: Austin Liu <[email protected]> --------- Signed-off-by: Austin Liu <[email protected]>
1 parent 1fce2a9 commit bd50698

File tree

4 files changed

+308
-0
lines changed

4 files changed

+308
-0
lines changed
Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! [ScalarUDFImpl] definitions for array_distance function.
19+
20+
use crate::utils::{downcast_arg, make_scalar_function};
21+
use arrow_array::{
22+
Array, ArrayRef, Float64Array, LargeListArray, ListArray, OffsetSizeTrait,
23+
};
24+
use arrow_schema::DataType;
25+
use arrow_schema::DataType::{FixedSizeList, Float64, LargeList, List};
26+
use core::any::type_name;
27+
use datafusion_common::cast::{
28+
as_float32_array, as_float64_array, as_generic_list_array, as_int32_array,
29+
as_int64_array,
30+
};
31+
use datafusion_common::DataFusionError;
32+
use datafusion_common::{exec_err, Result};
33+
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
34+
use std::any::Any;
35+
use std::sync::Arc;
36+
37+
// Wires `ArrayDistance` into the crate through the `make_udf_expr_and_func!`
// macro (defined outside this file).
// NOTE(review): presumably this generates the `array_distance` expression
// helper and the `array_distance_udf()` accessor that `lib.rs` re-exports and
// registers, using the string below as the helper's documentation — confirm
// against the macro definition.
make_udf_expr_and_func!(
38+
ArrayDistance,
39+
array_distance,
40+
array,
41+
"returns the Euclidean distance between two numeric arrays.",
42+
array_distance_udf
43+
);
44+
45+
/// Scalar UDF implementation backing the `array_distance` SQL function
/// (also reachable through the `list_distance` alias).
#[derive(Debug)]
46+
pub(super) struct ArrayDistance {
47+
/// Signature returned by `ScalarUDFImpl::signature`.
signature: Signature,
48+
/// Alternative names returned by `ScalarUDFImpl::aliases`.
aliases: Vec<String>,
49+
}
50+
51+
impl ArrayDistance {
52+
pub fn new() -> Self {
53+
Self {
54+
signature: Signature::variadic_any(Volatility::Immutable),
55+
aliases: vec!["list_distance".to_string()],
56+
}
57+
}
58+
}
59+
60+
impl ScalarUDFImpl for ArrayDistance {
61+
fn as_any(&self) -> &dyn Any {
62+
self
63+
}
64+
65+
fn name(&self) -> &str {
66+
"array_distance"
67+
}
68+
69+
fn signature(&self) -> &Signature {
70+
&self.signature
71+
}
72+
73+
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
74+
match arg_types[0] {
75+
List(_) | LargeList(_) | FixedSizeList(_, _) => Ok(Float64),
76+
_ => exec_err!("The array_distance function can only accept List/LargeList/FixedSizeList."),
77+
}
78+
}
79+
80+
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
81+
make_scalar_function(array_distance_inner)(args)
82+
}
83+
84+
fn aliases(&self) -> &[String] {
85+
&self.aliases
86+
}
87+
}
88+
89+
pub fn array_distance_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
90+
if args.len() != 2 {
91+
return exec_err!("array_distance expects exactly two arguments");
92+
}
93+
94+
match (&args[0].data_type(), &args[1].data_type()) {
95+
(List(_), List(_)) => general_array_distance::<i32>(args),
96+
(LargeList(_), LargeList(_)) => general_array_distance::<i64>(args),
97+
(array_type1, array_type2) => {
98+
exec_err!("array_distance does not support types '{array_type1:?}' and '{array_type2:?}'")
99+
}
100+
}
101+
}
102+
103+
fn general_array_distance<O: OffsetSizeTrait>(arrays: &[ArrayRef]) -> Result<ArrayRef> {
104+
let list_array1 = as_generic_list_array::<O>(&arrays[0])?;
105+
let list_array2 = as_generic_list_array::<O>(&arrays[1])?;
106+
107+
let result = list_array1
108+
.iter()
109+
.zip(list_array2.iter())
110+
.map(|(arr1, arr2)| compute_array_distance(arr1, arr2))
111+
.collect::<Result<Float64Array>>()?;
112+
113+
Ok(Arc::new(result) as ArrayRef)
114+
}
115+
116+
/// Computes the Euclidean distance between two arrays
117+
fn compute_array_distance(
118+
arr1: Option<ArrayRef>,
119+
arr2: Option<ArrayRef>,
120+
) -> Result<Option<f64>> {
121+
let value1 = match arr1 {
122+
Some(arr) => arr,
123+
None => return Ok(None),
124+
};
125+
let value2 = match arr2 {
126+
Some(arr) => arr,
127+
None => return Ok(None),
128+
};
129+
130+
let mut value1 = value1;
131+
let mut value2 = value2;
132+
133+
loop {
134+
match value1.data_type() {
135+
List(_) => {
136+
if downcast_arg!(value1, ListArray).null_count() > 0 {
137+
return Ok(None);
138+
}
139+
value1 = downcast_arg!(value1, ListArray).value(0);
140+
}
141+
LargeList(_) => {
142+
if downcast_arg!(value1, LargeListArray).null_count() > 0 {
143+
return Ok(None);
144+
}
145+
value1 = downcast_arg!(value1, LargeListArray).value(0);
146+
}
147+
_ => break,
148+
}
149+
150+
match value2.data_type() {
151+
List(_) => {
152+
if downcast_arg!(value2, ListArray).null_count() > 0 {
153+
return Ok(None);
154+
}
155+
value2 = downcast_arg!(value2, ListArray).value(0);
156+
}
157+
LargeList(_) => {
158+
if downcast_arg!(value2, LargeListArray).null_count() > 0 {
159+
return Ok(None);
160+
}
161+
value2 = downcast_arg!(value2, LargeListArray).value(0);
162+
}
163+
_ => break,
164+
}
165+
}
166+
167+
// Check for NULL values inside the arrays
168+
if value1.null_count() != 0 || value2.null_count() != 0 {
169+
return Ok(None);
170+
}
171+
172+
let values1 = convert_to_f64_array(&value1)?;
173+
let values2 = convert_to_f64_array(&value2)?;
174+
175+
if values1.len() != values2.len() {
176+
return exec_err!("Both arrays must have the same length");
177+
}
178+
179+
let sum_squares: f64 = values1
180+
.iter()
181+
.zip(values2.iter())
182+
.map(|(v1, v2)| {
183+
let diff = v1.unwrap_or(0.0) - v2.unwrap_or(0.0);
184+
diff * diff
185+
})
186+
.sum();
187+
188+
Ok(Some(sum_squares.sqrt()))
189+
}
190+
191+
/// Converts an array of any numeric type to a Float64Array.
192+
fn convert_to_f64_array(array: &ArrayRef) -> Result<Float64Array> {
193+
match array.data_type() {
194+
DataType::Float64 => Ok(as_float64_array(array)?.clone()),
195+
DataType::Float32 => {
196+
let array = as_float32_array(array)?;
197+
let converted: Float64Array =
198+
array.iter().map(|v| v.map(|v| v as f64)).collect();
199+
Ok(converted)
200+
}
201+
DataType::Int64 => {
202+
let array = as_int64_array(array)?;
203+
let converted: Float64Array =
204+
array.iter().map(|v| v.map(|v| v as f64)).collect();
205+
Ok(converted)
206+
}
207+
DataType::Int32 => {
208+
let array = as_int32_array(array)?;
209+
let converted: Float64Array =
210+
array.iter().map(|v| v.map(|v| v as f64)).collect();
211+
Ok(converted)
212+
}
213+
_ => exec_err!("Unsupported array type for conversion to Float64Array"),
214+
}
215+
}

datafusion/functions-nested/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ pub mod array_has;
3434
pub mod cardinality;
3535
pub mod concat;
3636
pub mod dimension;
37+
pub mod distance;
3738
pub mod empty;
3839
pub mod except;
3940
pub mod expr_ext;
@@ -73,6 +74,7 @@ pub mod expr_fn {
7374
pub use super::concat::array_prepend;
7475
pub use super::dimension::array_dims;
7576
pub use super::dimension::array_ndims;
77+
pub use super::distance::array_distance;
7678
pub use super::empty::array_empty;
7779
pub use super::except::array_except;
7880
pub use super::extract::array_element;
@@ -128,6 +130,7 @@ pub fn all_default_nested_functions() -> Vec<Arc<ScalarUDF>> {
128130
array_has::array_has_any_udf(),
129131
empty::array_empty_udf(),
130132
length::array_length_udf(),
133+
distance::array_distance_udf(),
131134
flatten::flatten_udf(),
132135
sort::array_sort_udf(),
133136
repeat::array_repeat_udf(),

datafusion/sqllogictest/test_files/array.slt

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4715,6 +4715,60 @@ NULL 10
47154715
NULL 10
47164716
NULL 10
47174717

4718+
query RRR
4719+
select array_distance([2], [3]), list_distance([1], [2]), list_distance([1], [-2]);
4720+
----
4721+
1 1 3
4722+
4723+
query error
4724+
select list_distance([1], [1, 2]);
4725+
4726+
query R
4727+
select array_distance([[1, 1]], [1, 2]);
4728+
----
4729+
1
4730+
4731+
query R
4732+
select array_distance([[1, 1]], [[1, 2]]);
4733+
----
4734+
1
4735+
4736+
query R
4737+
select array_distance([[1, 1]], [[1, 2]]);
4738+
----
4739+
1
4740+
4741+
query RR
4742+
select array_distance([1, 1, 0, 0], [2, 2, 1, 1]), list_distance([1, 2, 3], [1, 2, 3]);
4743+
----
4744+
2 0
4745+
4746+
query RR
4747+
select array_distance([1.0, 1, 0, 0], [2, 2.0, 1, 1]), list_distance([1, 2.0, 3], [1, 2, 3]);
4748+
----
4749+
2 0
4750+
4751+
query R
4752+
select list_distance([1, 1, NULL, 0], [2, 2, NULL, NULL]);
4753+
----
4754+
NULL
4755+
4756+
query R
4757+
select list_distance([NULL, NULL], [NULL, NULL]);
4758+
----
4759+
NULL
4760+
4761+
query R
4762+
select list_distance([1.0, 2.0, 3.0], [1.0, 2.0, 3.5]) AS distance;
4763+
----
4764+
0.5
4765+
4766+
query R
4767+
select list_distance([1, 2, 3], [1, 2, 3]) AS distance;
4768+
----
4769+
0
4770+
4771+
47184772
## array_dims (aliases: `list_dims`)
47194773

47204774
# array dims error

docs/source/user-guide/sql/scalar_functions.md

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2093,6 +2093,7 @@ to_unixtime(expression[, ..., format_n])
20932093
- [array_concat](#array_concat)
20942094
- [array_contains](#array_contains)
20952095
- [array_dims](#array_dims)
2096+
- [array_distance](#array_distance)
20962097
- [array_distinct](#array_distinct)
20972098
- [array_has](#array_has)
20982099
- [array_has_all](#array_has_all)
@@ -2135,6 +2136,7 @@ to_unixtime(expression[, ..., format_n])
21352136
- [list_cat](#list_cat)
21362137
- [list_concat](#list_concat)
21372138
- [list_dims](#list_dims)
2139+
- [list_distance](#list_distance)
21382140
- [list_distinct](#list_distinct)
21392141
- [list_element](#list_element)
21402142
- [list_except](#list_except)
@@ -2388,6 +2390,36 @@ array_dims(array)
23882390

23892391
- list_dims
23902392

2393+
### `array_distance`
2394+
2395+
Returns the Euclidean distance between two input arrays of equal length.
2396+
2397+
```
2398+
array_distance(array1, array2)
2399+
```
2400+
2401+
#### Arguments
2402+
2403+
- **array1**: Array expression.
2404+
Can be a constant, column, or function, and any combination of array operators.
2405+
- **array2**: Array expression.
2406+
Can be a constant, column, or function, and any combination of array operators.
2407+
2408+
#### Example
2409+
2410+
```
2411+
> select array_distance([1, 2], [1, 4]);
2412+
+------------------------------------+
2413+
| array_distance(List([1,2], [1,4])) |
2414+
+------------------------------------+
2415+
| 2.0 |
2416+
+------------------------------------+
2417+
```
2418+
2419+
#### Aliases
2420+
2421+
- list_distance
2422+
23912423
### `array_distinct`
23922424

23932425
Returns distinct values from the array after removing duplicates.
@@ -3224,6 +3256,10 @@ _Alias of [array_concat](#array_concat)._
32243256

32253257
_Alias of [array_dims](#array_dims)._
32263258

3259+
### `list_distance`
3260+
3261+
_Alias of [array_distance](#array_distance)._
3262+
32273263
### `list_distinct`
32283264

32293265
_Alias of [array_distinct](#array_distinct)._

0 commit comments

Comments
 (0)