Skip to content

feat: Support FixedSizeList in array_distance function #12381

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion datafusion/functions-nested/src/distance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use datafusion_common::cast::{
as_float32_array, as_float64_array, as_generic_list_array, as_int32_array,
as_int64_array,
};
use datafusion_common::utils::coerced_fixed_size_list_to_list;
use datafusion_common::DataFusionError;
use datafusion_common::{exec_err, Result};
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
Expand All @@ -51,7 +52,7 @@ pub(super) struct ArrayDistance {
impl ArrayDistance {
pub fn new() -> Self {
Self {
signature: Signature::variadic_any(Volatility::Immutable),
signature: Signature::user_defined(Volatility::Immutable),
aliases: vec!["list_distance".to_string()],
}
}
Expand All @@ -77,6 +78,21 @@ impl ScalarUDFImpl for ArrayDistance {
}
}

/// Validates and coerces the argument types of `array_distance`.
///
/// Exactly two arguments are required. Each must be a `List`, `LargeList`
/// or `FixedSizeList`; every accepted type is passed through
/// `coerced_fixed_size_list_to_list` and the results are returned in
/// argument order.
///
/// # Errors
///
/// Returns an execution error when the argument count is not two, or when
/// any argument is not one of the list types above.
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
    if arg_types.len() != 2 {
        return exec_err!("array_distance expects exactly two arguments");
    }

    // Collecting into `Result` stops at the first argument that is rejected.
    arg_types
        .iter()
        .map(|data_type| match data_type {
            List(_) | LargeList(_) | FixedSizeList(_, _) => {
                Ok(coerced_fixed_size_list_to_list(data_type))
            }
            _ => exec_err!("The array_distance function can only accept List/LargeList/FixedSizeList."),
        })
        .collect()
}

/// Evaluates `array_distance` on `args` by wrapping `array_distance_inner`
/// with `make_scalar_function` and applying the resulting function.
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
    let distance_fn = make_scalar_function(array_distance_inner);
    distance_fn(args)
}
Expand Down
57 changes: 57 additions & 0 deletions datafusion/sqllogictest/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,38 @@ AS VALUES
(arrow_cast(make_array([28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]), 'FixedSizeList(10, List(Int64))'), [28, 29, 30], [28, 29, 30], 10)
;

# Base table for the array_distance column tests, built with make_array:
# column1/column2 hold integer arrays, column3 holds float arrays and
# column4 holds float arrays whose middle element is NULL.
statement ok
CREATE TABLE arrays_distance_table
AS VALUES
(make_array(1, 2, 3), make_array(1, 2, 3), make_array(1.1, 2.2, 3.3) , make_array(1.1, NULL, 3.3)),
(make_array(1, 2, 3), make_array(4, 5, 6), make_array(4.4, 5.5, 6.6), make_array(4.4, NULL, 6.6)),
(make_array(1, 2, 3), make_array(7, 8, 9), make_array(7.7, 8.8, 9.9), make_array(7.7, NULL, 9.9)),
(make_array(1, 2, 3), make_array(10, 11, 12), make_array(10.1, 11.2, 12.3), make_array(10.1, NULL, 12.3))
;

# Same rows as arrays_distance_table, with every column cast to LargeList
# (Int64 for column1/column2, Float64 for column3/column4).
statement ok
CREATE TABLE large_arrays_distance_table
AS
SELECT
arrow_cast(column1, 'LargeList(Int64)') AS column1,
arrow_cast(column2, 'LargeList(Int64)') AS column2,
arrow_cast(column3, 'LargeList(Float64)') AS column3,
arrow_cast(column4, 'LargeList(Float64)') AS column4
FROM arrays_distance_table
;

# Same rows as arrays_distance_table, with every column cast to a
# FixedSizeList of length 3 (Int64 for column1/column2, Float64 for
# column3/column4).
statement ok
CREATE TABLE fixed_size_arrays_distance_table
AS
SELECT
arrow_cast(column1, 'FixedSizeList(3, Int64)') AS column1,
arrow_cast(column2, 'FixedSizeList(3, Int64)') AS column2,
arrow_cast(column3, 'FixedSizeList(3, Float64)') AS column3,
arrow_cast(column4, 'FixedSizeList(3, Float64)') AS column4
FROM arrays_distance_table
;


# Array literal

## boolean coercion is not supported
Expand Down Expand Up @@ -4768,6 +4800,31 @@ select list_distance([1, 2, 3], [1, 2, 3]) AS distance;
----
0

# array_distance with columns from arrays_distance_table; the expected
# result is NULL for the column1/column4 pairs, where column4 contains a
# NULL element
query RRR
select array_distance(column1, column2), array_distance(column1, column3), array_distance(column1, column4) from arrays_distance_table;
----
0 0.374165738677 NULL
5.196152422707 6.063827174318 NULL
10.392304845413 11.778794505381 NULL
15.58845726812 15.935494971917 NULL

# array_distance with LargeList columns (large_arrays_distance_table);
# expected values match the arrays_distance_table query
query RRR
select array_distance(column1, column2), array_distance(column1, column3), array_distance(column1, column4) from large_arrays_distance_table;
----
0 0.374165738677 NULL
5.196152422707 6.063827174318 NULL
10.392304845413 11.778794505381 NULL
15.58845726812 15.935494971917 NULL

# array_distance with FixedSizeList columns (fixed_size_arrays_distance_table);
# expected values match the arrays_distance_table query
query RRR
select array_distance(column1, column2), array_distance(column1, column3), array_distance(column1, column4) from fixed_size_arrays_distance_table;
----
0 0.374165738677 NULL
5.196152422707 6.063827174318 NULL
10.392304845413 11.778794505381 NULL
15.58845726812 15.935494971917 NULL


## array_dims (aliases: `list_dims`)

Expand Down
Loading