Skip to content

Commit 2a0fc77

Browse files
authored
Add support of sorting dictionary of other primitive arrays (#2701)
* Add support of sorting dictionary of other primitive arrays
* Collapse match statements
* Add one helper to match primitive types
1 parent 5146663 commit 2a0fc77

File tree

2 files changed

+179
-104
lines changed

2 files changed

+179
-104
lines changed

arrow/src/compute/kernels/sort.rs

Lines changed: 155 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -314,119 +314,32 @@ pub fn sort_to_indices(
314314
}
315315
},
316316
DataType::Dictionary(_, _) => {
317+
let value_null_first = if options.descending {
318+
// When sorting a dictionary in descending order, we invert the null ordering used
319+
// when sorting the values. If `nulls_first` is true, nulls must come before all
320+
// non-null values. Because the sorted order of the value array is used to rank the
321+
// dictionary keys, null values would be ranked smallest and so end up at the end
322+
// of a descending result. We therefore set `nulls_first` to false when sorting the
323+
// value array so that nulls rank largest and are placed at the beginning of the
324+
// sorted dictionary result.
325+
!options.nulls_first
326+
} else {
327+
options.nulls_first
328+
};
329+
let value_options = Some(SortOptions {
330+
descending: false,
331+
nulls_first: value_null_first,
332+
});
317333
downcast_dictionary_array!(
318334
values => match values.values().data_type() {
319-
DataType::Int8 => {
320-
let dict_values = values.values();
321-
let value_null_first = if options.descending {
322-
// When sorting dictionary in descending order, we take inverse of null ordering
323-
// when sorting the values. Because if `nulls_first` is true, null must be in front
324-
// of non-null value. As we take the sorted order of value array to sort dictionary
325-
// keys, these null values will be treated as smallest ones and be sorted to the end
326-
// of sorted result. So we set `nulls_first` to false when sorting dictionary value
327-
// array to make them as largest ones, then null values will be put at the beginning
328-
// of sorted dictionary result.
329-
!options.nulls_first
330-
} else {
331-
options.nulls_first
332-
};
333-
let value_options = Some(SortOptions { descending: false, nulls_first: value_null_first });
334-
let sorted_value_indices = sort_to_indices(dict_values, value_options, None)?;
335-
let value_indices_map = prepare_indices_map(&sorted_value_indices);
336-
sort_primitive_dictionary::<_, _>(values, &value_indices_map, v, n, options, limit, cmp)
337-
},
338-
DataType::Int16 => {
335+
dt if DataType::is_primitive(dt) => {
339336
let dict_values = values.values();
340-
let value_null_first = if options.descending {
341-
!options.nulls_first
342-
} else {
343-
options.nulls_first
344-
};
345-
let value_options = Some(SortOptions { descending: false, nulls_first: value_null_first });
346-
let sorted_value_indices = sort_to_indices(dict_values, value_options, None)?;
347-
let value_indices_map = prepare_indices_map(&sorted_value_indices);
348-
sort_primitive_dictionary::<_, _>(values, &value_indices_map, v, n, options, limit, cmp)
349-
},
350-
DataType::Int32 => {
351-
let dict_values = values.values();
352-
let value_null_first = if options.descending {
353-
!options.nulls_first
354-
} else {
355-
options.nulls_first
356-
};
357-
let value_options = Some(SortOptions { descending: false, nulls_first: value_null_first });
358-
let sorted_value_indices = sort_to_indices(dict_values, value_options, None)?;
359-
let value_indices_map = prepare_indices_map(&sorted_value_indices);
360-
sort_primitive_dictionary::<_, _>(values, &value_indices_map, v, n, options, limit, cmp)
361-
},
362-
DataType::Int64 => {
363-
let dict_values = values.values();
364-
let value_null_first = if options.descending {
365-
!options.nulls_first
366-
} else {
367-
options.nulls_first
368-
};
369-
let value_options = Some(SortOptions { descending: false, nulls_first: value_null_first });
370-
let sorted_value_indices = sort_to_indices(dict_values, value_options, None)?;
371-
let value_indices_map = prepare_indices_map(&sorted_value_indices);
372-
sort_primitive_dictionary::<_, _>(values, &value_indices_map,v, n, options, limit, cmp)
373-
},
374-
DataType::UInt8 => {
375-
let dict_values = values.values();
376-
let value_null_first = if options.descending {
377-
!options.nulls_first
378-
} else {
379-
options.nulls_first
380-
};
381-
let value_options = Some(SortOptions { descending: false, nulls_first: value_null_first });
382-
let sorted_value_indices = sort_to_indices(dict_values, value_options, None)?;
383-
let value_indices_map = prepare_indices_map(&sorted_value_indices);
384-
sort_primitive_dictionary::<_, _>(values, &value_indices_map,v, n, options, limit, cmp)
385-
},
386-
DataType::UInt16 => {
387-
let dict_values = values.values();
388-
let value_null_first = if options.descending {
389-
!options.nulls_first
390-
} else {
391-
options.nulls_first
392-
};
393-
let value_options = Some(SortOptions { descending: false, nulls_first: value_null_first });
394-
let sorted_value_indices = sort_to_indices(dict_values, value_options, None)?;
395-
let value_indices_map = prepare_indices_map(&sorted_value_indices);
396-
sort_primitive_dictionary::<_, _>(values, &value_indices_map,v, n, options, limit, cmp)
397-
},
398-
DataType::UInt32 => {
399-
let dict_values = values.values();
400-
let value_null_first = if options.descending {
401-
!options.nulls_first
402-
} else {
403-
options.nulls_first
404-
};
405-
let value_options = Some(SortOptions { descending: false, nulls_first: value_null_first });
406-
let sorted_value_indices = sort_to_indices(dict_values, value_options, None)?;
407-
let value_indices_map = prepare_indices_map(&sorted_value_indices);
408-
sort_primitive_dictionary::<_, _>(values, &value_indices_map,v, n, options, limit, cmp)
409-
},
410-
DataType::UInt64 => {
411-
let dict_values = values.values();
412-
let value_null_first = if options.descending {
413-
!options.nulls_first
414-
} else {
415-
options.nulls_first
416-
};
417-
let value_options = Some(SortOptions { descending: false, nulls_first: value_null_first });
418337
let sorted_value_indices = sort_to_indices(dict_values, value_options, None)?;
419338
let value_indices_map = prepare_indices_map(&sorted_value_indices);
420339
sort_primitive_dictionary::<_, _>(values, &value_indices_map, v, n, options, limit, cmp)
421340
},
422341
DataType::Utf8 => {
423342
let dict_values = values.values();
424-
let value_null_first = if options.descending {
425-
!options.nulls_first
426-
} else {
427-
options.nulls_first
428-
};
429-
let value_options = Some(SortOptions { descending: false, nulls_first: value_null_first });
430343
let sorted_value_indices = sort_to_indices(dict_values, value_options, None)?;
431344
let value_indices_map = prepare_indices_map(&sorted_value_indices);
432345
sort_string_dictionary::<_>(values, &value_indices_map, v, n, &options, limit)
@@ -3552,4 +3465,142 @@ mod tests {
35523465
vec![None, None, None, Some(5), Some(5), Some(3), Some(1)],
35533466
);
35543467
}
3468+
3469+
#[test]
3470+
fn test_sort_f32_dicts() {
3471+
let keys =
3472+
Int8Array::from(vec![Some(1_i8), None, Some(2), None, Some(2), Some(0)]);
3473+
let values = Float32Array::from(vec![1.2, 3.0, 5.1]);
3474+
test_sort_primitive_dict_arrays::<Int8Type, Float32Type>(
3475+
keys,
3476+
values,
3477+
None,
3478+
None,
3479+
vec![None, None, Some(1.2), Some(3.0), Some(5.1), Some(5.1)],
3480+
);
3481+
3482+
let keys =
3483+
Int8Array::from(vec![Some(1_i8), None, Some(2), None, Some(2), Some(0)]);
3484+
let values = Float32Array::from(vec![1.2, 3.0, 5.1]);
3485+
test_sort_primitive_dict_arrays::<Int8Type, Float32Type>(
3486+
keys,
3487+
values,
3488+
Some(SortOptions {
3489+
descending: true,
3490+
nulls_first: false,
3491+
}),
3492+
None,
3493+
vec![Some(5.1), Some(5.1), Some(3.0), Some(1.2), None, None],
3494+
);
3495+
3496+
let keys =
3497+
Int8Array::from(vec![Some(1_i8), None, Some(2), None, Some(2), Some(0)]);
3498+
let values = Float32Array::from(vec![1.2, 3.0, 5.1]);
3499+
test_sort_primitive_dict_arrays::<Int8Type, Float32Type>(
3500+
keys,
3501+
values,
3502+
Some(SortOptions {
3503+
descending: false,
3504+
nulls_first: false,
3505+
}),
3506+
None,
3507+
vec![Some(1.2), Some(3.0), Some(5.1), Some(5.1), None, None],
3508+
);
3509+
3510+
let keys =
3511+
Int8Array::from(vec![Some(1_i8), None, Some(2), None, Some(2), Some(0)]);
3512+
let values = Float32Array::from(vec![1.2, 3.0, 5.1]);
3513+
test_sort_primitive_dict_arrays::<Int8Type, Float32Type>(
3514+
keys,
3515+
values,
3516+
Some(SortOptions {
3517+
descending: true,
3518+
nulls_first: true,
3519+
}),
3520+
Some(3),
3521+
vec![None, None, Some(5.1)],
3522+
);
3523+
3524+
// Values have `None`.
3525+
let keys = Int8Array::from(vec![
3526+
Some(1_i8),
3527+
None,
3528+
Some(3),
3529+
None,
3530+
Some(2),
3531+
Some(3),
3532+
Some(0),
3533+
]);
3534+
let values = Float32Array::from(vec![Some(1.2), Some(3.0), None, Some(5.1)]);
3535+
test_sort_primitive_dict_arrays::<Int8Type, Float32Type>(
3536+
keys,
3537+
values,
3538+
None,
3539+
None,
3540+
vec![None, None, None, Some(1.2), Some(3.0), Some(5.1), Some(5.1)],
3541+
);
3542+
3543+
let keys = Int8Array::from(vec![
3544+
Some(1_i8),
3545+
None,
3546+
Some(3),
3547+
None,
3548+
Some(2),
3549+
Some(3),
3550+
Some(0),
3551+
]);
3552+
let values = Float32Array::from(vec![Some(1.2), Some(3.0), None, Some(5.1)]);
3553+
test_sort_primitive_dict_arrays::<Int8Type, Float32Type>(
3554+
keys,
3555+
values,
3556+
Some(SortOptions {
3557+
descending: false,
3558+
nulls_first: false,
3559+
}),
3560+
None,
3561+
vec![Some(1.2), Some(3.0), Some(5.1), Some(5.1), None, None, None],
3562+
);
3563+
3564+
let keys = Int8Array::from(vec![
3565+
Some(1_i8),
3566+
None,
3567+
Some(3),
3568+
None,
3569+
Some(2),
3570+
Some(3),
3571+
Some(0),
3572+
]);
3573+
let values = Float32Array::from(vec![Some(1.2), Some(3.0), None, Some(5.1)]);
3574+
test_sort_primitive_dict_arrays::<Int8Type, Float32Type>(
3575+
keys,
3576+
values,
3577+
Some(SortOptions {
3578+
descending: true,
3579+
nulls_first: false,
3580+
}),
3581+
None,
3582+
vec![Some(5.1), Some(5.1), Some(3.0), Some(1.2), None, None, None],
3583+
);
3584+
3585+
let keys = Int8Array::from(vec![
3586+
Some(1_i8),
3587+
None,
3588+
Some(3),
3589+
None,
3590+
Some(2),
3591+
Some(3),
3592+
Some(0),
3593+
]);
3594+
let values = Float32Array::from(vec![Some(1.2), Some(3.0), None, Some(5.1)]);
3595+
test_sort_primitive_dict_arrays::<Int8Type, Float32Type>(
3596+
keys,
3597+
values,
3598+
Some(SortOptions {
3599+
descending: true,
3600+
nulls_first: true,
3601+
}),
3602+
None,
3603+
vec![None, None, None, Some(5.1), Some(5.1), Some(3.0), Some(1.2)],
3604+
);
3605+
}
35553606
}

arrow/src/datatypes/datatype.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1070,6 +1070,30 @@ impl DataType {
10701070
)
10711071
}
10721072

1073+
/// Returns true if this type is primitive: (Int*, UInt*, Float32, Float64, Date*, Time*, Timestamp, Interval, or Duration).
1074+
pub fn is_primitive(t: &DataType) -> bool {
1075+
use DataType::*;
1076+
matches!(
1077+
t,
1078+
Int8 | Int16
1079+
| Int32
1080+
| Int64
1081+
| UInt8
1082+
| UInt16
1083+
| UInt32
1084+
| UInt64
1085+
| Float32
1086+
| Float64
1087+
| Date32
1088+
| Date64
1089+
| Time32(_)
1090+
| Time64(_)
1091+
| Timestamp(_, _)
1092+
| Interval(_)
1093+
| Duration(_)
1094+
)
1095+
}
1096+
10731097
/// Returns true if this type is temporal: (Date*, Time*, Duration, or Interval).
10741098
pub fn is_temporal(t: &DataType) -> bool {
10751099
use DataType::*;

0 commit comments

Comments
 (0)