Commit dd242b9

refactor: change some hashbrown RawTable uses to HashTable (round 2) (#13524)
* refactor: migrate `GroupValuesRows` to `HashTable`. For #13433.
* refactor: migrate `GroupValuesPrimitive` to `HashTable`. For #13433.
* refactor: migrate `GroupValuesColumn` to `HashTable`. For #13433.
1 parent 55e56c4 commit dd242b9

File tree

3 files changed, +89 −96 lines

datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs

Lines changed: 53 additions & 57 deletions
@@ -42,11 +42,11 @@ use arrow_array::{Array, ArrayRef};
 use arrow_schema::{DataType, Schema, SchemaRef, TimeUnit};
 use datafusion_common::hash_utils::create_hashes;
 use datafusion_common::{not_impl_err, DataFusionError, Result};
-use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt};
+use datafusion_execution::memory_pool::proxy::{HashTableAllocExt, VecAllocExt};
 use datafusion_expr::EmitTo;
 use datafusion_physical_expr::binary_map::OutputType;

-use hashbrown::raw::RawTable;
+use hashbrown::hash_table::HashTable;

 const NON_INLINED_FLAG: u64 = 0x8000000000000000;
 const VALUE_MASK: u64 = 0x7FFFFFFFFFFFFFFF;
@@ -180,7 +180,7 @@ pub struct GroupValuesColumn<const STREAMING: bool> {
     /// And we use [`GroupIndexView`] to represent such `group indices` in table.
     ///
     ///
-    map: RawTable<(u64, GroupIndexView)>,
+    map: HashTable<(u64, GroupIndexView)>,

     /// The size of `map` in bytes
     map_size: usize,
@@ -261,7 +261,7 @@ impl<const STREAMING: bool> GroupValuesColumn<STREAMING> {

     /// Create a new instance of GroupValuesColumn if supported for the specified schema
     pub fn try_new(schema: SchemaRef) -> Result<Self> {
-        let map = RawTable::with_capacity(0);
+        let map = HashTable::with_capacity(0);
         Ok(Self {
             schema,
             map,
@@ -338,7 +338,7 @@ impl<const STREAMING: bool> GroupValuesColumn<STREAMING> {
         for (row, &target_hash) in batch_hashes.iter().enumerate() {
             let entry = self
                 .map
-                .get_mut(target_hash, |(exist_hash, group_idx_view)| {
+                .find_mut(target_hash, |(exist_hash, group_idx_view)| {
                     // It is ensured to be inlined in `scalarized_intern`
                     debug_assert!(!group_idx_view.is_non_inlined());

@@ -506,7 +506,7 @@ impl<const STREAMING: bool> GroupValuesColumn<STREAMING> {
         for (row, &target_hash) in batch_hashes.iter().enumerate() {
             let entry = self
                 .map
-                .get(target_hash, |(exist_hash, _)| target_hash == *exist_hash);
+                .find(target_hash, |(exist_hash, _)| target_hash == *exist_hash);

             let Some((_, group_index_view)) = entry else {
                 // 1. Bucket not found case
@@ -733,7 +733,7 @@ impl<const STREAMING: bool> GroupValuesColumn<STREAMING> {

         for &row in &self.vectorized_operation_buffers.remaining_row_indices {
             let target_hash = batch_hashes[row];
-            let entry = map.get_mut(target_hash, |(exist_hash, _)| {
+            let entry = map.find_mut(target_hash, |(exist_hash, _)| {
                 // Somewhat surprisingly, this closure can be called even if the
                 // hash doesn't match, so check the hash first with an integer
                 // comparison first avoid the more expensive comparison with
@@ -852,7 +852,7 @@ impl<const STREAMING: bool> GroupValuesColumn<STREAMING> {
     /// Return group indices of the hash, also if its `group_index_view` is non-inlined
     #[cfg(test)]
    fn get_indices_by_hash(&self, hash: u64) -> Option<(Vec<usize>, GroupIndexView)> {
-        let entry = self.map.get(hash, |(exist_hash, _)| hash == *exist_hash);
+        let entry = self.map.find(hash, |(exist_hash, _)| hash == *exist_hash);

         match entry {
             Some((_, group_index_view)) => {
@@ -1091,67 +1091,63 @@ impl<const STREAMING: bool> GroupValues for GroupValuesColumn<STREAMING> {
                     .collect::<Vec<_>>();
                 let mut next_new_list_offset = 0;

-                // SAFETY: self.map outlives iterator and is not modified concurrently
-                unsafe {
-                    for bucket in self.map.iter() {
-                        // In non-streaming case, we need to check if the `group index view`
-                        // is `inlined` or `non-inlined`
-                        if !STREAMING && bucket.as_ref().1.is_non_inlined() {
-                            // Non-inlined case
-                            // We take `group_index_list` from `old_group_index_lists`
-
-                            // list_offset is incrementally
-                            self.emit_group_index_list_buffer.clear();
-                            let list_offset = bucket.as_ref().1.value() as usize;
-                            for group_index in self.group_index_lists[list_offset].iter()
-                            {
-                                if let Some(remaining) = group_index.checked_sub(n) {
-                                    self.emit_group_index_list_buffer.push(remaining);
+                self.map.retain(|(_exist_hash, group_idx_view)| {
+                    // In non-streaming case, we need to check if the `group index view`
+                    // is `inlined` or `non-inlined`
+                    if !STREAMING && group_idx_view.is_non_inlined() {
+                        // Non-inlined case
+                        // We take `group_index_list` from `old_group_index_lists`
+
+                        // list_offset is incrementally
+                        self.emit_group_index_list_buffer.clear();
+                        let list_offset = group_idx_view.value() as usize;
+                        for group_index in self.group_index_lists[list_offset].iter() {
+                            if let Some(remaining) = group_index.checked_sub(n) {
+                                self.emit_group_index_list_buffer.push(remaining);
                             }
-
-                            // The possible results:
-                            // - `new_group_index_list` is empty, we should erase this bucket
-                            // - only one value in `new_group_index_list`, switch the `view` to `inlined`
-                            // - still multiple values in `new_group_index_list`, build and set the new `unlined view`
-                            if self.emit_group_index_list_buffer.is_empty() {
-                                self.map.erase(bucket);
-                            } else if self.emit_group_index_list_buffer.len() == 1 {
-                                let group_index =
-                                    self.emit_group_index_list_buffer.first().unwrap();
-                                bucket.as_mut().1 =
-                                    GroupIndexView::new_inlined(*group_index as u64);
-                            } else {
-                                let group_index_list =
-                                    &mut self.group_index_lists[next_new_list_offset];
-                                group_index_list.clear();
-                                group_index_list
-                                    .extend(self.emit_group_index_list_buffer.iter());
-                                bucket.as_mut().1 = GroupIndexView::new_non_inlined(
-                                    next_new_list_offset as u64,
-                                );
-                                next_new_list_offset += 1;
-                            }
-
-                            continue;
                         }

+                        // The possible results:
+                        // - `new_group_index_list` is empty, we should erase this bucket
+                        // - only one value in `new_group_index_list`, switch the `view` to `inlined`
+                        // - still multiple values in `new_group_index_list`, build and set the new `unlined view`
+                        if self.emit_group_index_list_buffer.is_empty() {
+                            false
+                        } else if self.emit_group_index_list_buffer.len() == 1 {
+                            let group_index =
+                                self.emit_group_index_list_buffer.first().unwrap();
+                            *group_idx_view =
+                                GroupIndexView::new_inlined(*group_index as u64);
+                            true
+                        } else {
+                            let group_index_list =
+                                &mut self.group_index_lists[next_new_list_offset];
+                            group_index_list.clear();
+                            group_index_list
+                                .extend(self.emit_group_index_list_buffer.iter());
+                            *group_idx_view = GroupIndexView::new_non_inlined(
+                                next_new_list_offset as u64,
+                            );
+                            next_new_list_offset += 1;
+                            true
+                        }
+                    } else {
                         // In `streaming case`, the `group index view` is ensured to be `inlined`
-                        debug_assert!(!bucket.as_ref().1.is_non_inlined());
+                        debug_assert!(!group_idx_view.is_non_inlined());

                         // Inlined case, we just decrement group index by n)
-                        let group_index = bucket.as_ref().1.value() as usize;
+                        let group_index = group_idx_view.value() as usize;
                         match group_index.checked_sub(n) {
                             // Group index was >= n, shift value down
                             Some(sub) => {
-                                bucket.as_mut().1 =
-                                    GroupIndexView::new_inlined(sub as u64)
+                                *group_idx_view = GroupIndexView::new_inlined(sub as u64);
+                                true
                             }
                             // Group index was < n, so remove from table
-                            None => self.map.erase(bucket),
+                            None => false,
                         }
                     }
-                }
+                });

                 if !STREAMING {
                     self.group_index_lists.truncate(next_new_list_offset);
@@ -1243,7 +1239,7 @@ mod tests {
     use arrow::{compute::concat_batches, util::pretty::pretty_format_batches};
     use arrow_array::{ArrayRef, Int64Array, RecordBatch, StringArray, StringViewArray};
     use arrow_schema::{DataType, Field, Schema, SchemaRef};
-    use datafusion_common::utils::proxy::RawTableAllocExt;
+    use datafusion_common::utils::proxy::HashTableAllocExt;
    use datafusion_expr::EmitTo;

    use crate::aggregates::group_values::{
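
The lookup changes above swap `RawTable::get`/`get_mut` for `HashTable::find`/`find_mut`: the table never hashes for you, so every probe takes the precomputed hash plus an equality closure, and that closure can still be invoked for colliding buckets, which is why the diff keeps the cheap hash comparison inside it. A minimal sketch of that pattern, not taken from DataFusion (the toy `(hash, group_index)` payload and key name are made up; assumes `hashbrown = "0.15"` and std's `RandomState`):

```rust
use hashbrown::hash_table::HashTable;
use std::collections::hash_map::RandomState;
use std::hash::BuildHasher;

fn main() {
    let state = RandomState::new();
    // Entries are (hash, group_index) pairs, mirroring the map in `GroupValuesColumn`.
    let mut map: HashTable<(u64, usize)> = HashTable::with_capacity(16);

    let key = "group-a";
    let hash = state.hash_one(key);
    map.insert_unique(hash, (hash, 0), |(h, _)| *h);

    // `find` replaces `RawTable::get`: compare the stored hash first, since the
    // equality closure may also run for colliding buckets.
    let found = map.find(hash, |(h, _)| *h == hash);
    assert_eq!(found.map(|(_, idx)| *idx), Some(0));

    // `find_mut` replaces `RawTable::get_mut` and lets us rewrite the payload in place.
    if let Some((_, idx)) = map.find_mut(hash, |(h, _)| *h == hash) {
        *idx = 42;
    }
    assert_eq!(map.find(hash, |(h, _)| *h == hash).map(|(_, idx)| *idx), Some(42));
}
```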

datafusion/physical-plan/src/aggregates/group_values/row.rs

Lines changed: 15 additions & 15 deletions
@@ -24,9 +24,9 @@ use arrow_array::{Array, ArrayRef, ListArray, StructArray};
 use arrow_schema::{DataType, SchemaRef};
 use datafusion_common::hash_utils::create_hashes;
 use datafusion_common::Result;
-use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt};
+use datafusion_execution::memory_pool::proxy::{HashTableAllocExt, VecAllocExt};
 use datafusion_expr::EmitTo;
-use hashbrown::raw::RawTable;
+use hashbrown::hash_table::HashTable;
 use log::debug;
 use std::mem::size_of;
 use std::sync::Arc;
@@ -54,7 +54,7 @@ pub struct GroupValuesRows {
     ///
     /// keys: u64 hashes of the GroupValue
     /// values: (hash, group_index)
-    map: RawTable<(u64, usize)>,
+    map: HashTable<(u64, usize)>,

     /// The size of `map` in bytes
     map_size: usize,
@@ -92,7 +92,7 @@ impl GroupValuesRows {
                 .collect(),
         )?;

-        let map = RawTable::with_capacity(0);
+        let map = HashTable::with_capacity(0);

         let starting_rows_capacity = 1000;

@@ -135,7 +135,7 @@ impl GroupValues for GroupValuesRows {
         create_hashes(cols, &self.random_state, batch_hashes)?;

         for (row, &target_hash) in batch_hashes.iter().enumerate() {
-            let entry = self.map.get_mut(target_hash, |(exist_hash, group_idx)| {
+            let entry = self.map.find_mut(target_hash, |(exist_hash, group_idx)| {
                 // Somewhat surprisingly, this closure can be called even if the
                 // hash doesn't match, so check the hash first with an integer
                 // comparison first avoid the more expensive comparison with
@@ -216,18 +216,18 @@ impl GroupValues for GroupValuesRows {
                 }
                 std::mem::swap(&mut new_group_values, &mut group_values);

-                // SAFETY: self.map outlives iterator and is not modified concurrently
-                unsafe {
-                    for bucket in self.map.iter() {
-                        // Decrement group index by n
-                        match bucket.as_ref().1.checked_sub(n) {
-                            // Group index was >= n, shift value down
-                            Some(sub) => bucket.as_mut().1 = sub,
-                            // Group index was < n, so remove from table
-                            None => self.map.erase(bucket),
+                self.map.retain(|(_exists_hash, group_idx)| {
+                    // Decrement group index by n
+                    match group_idx.checked_sub(n) {
+                        // Group index was >= n, shift value down
+                        Some(sub) => {
+                            *group_idx = sub;
+                            true
                         }
+                        // Group index was < n, so remove from table
+                        None => false,
                     }
-                }
+                });
                 output
             }
         };
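
In the `EmitTo::First(n)` branch above, the unsafe `RawTable::iter`/`erase` loop becomes a safe `HashTable::retain`: the closure gets `&mut` access to each payload and returns `false` to drop entries whose group index was just emitted. A standalone sketch of the same pattern, not DataFusion code (the `shift_group_indices` helper and the toy hashes are illustrative; assumes `hashbrown = "0.15"`):

```rust
use hashbrown::hash_table::HashTable;

// Keep entries whose group index is >= n, shifting them down by n;
// drop the rest, as `self.map.erase(bucket)` did in the old code.
fn shift_group_indices(map: &mut HashTable<(u64, usize)>, n: usize) {
    map.retain(|(_hash, group_idx)| match group_idx.checked_sub(n) {
        Some(sub) => {
            *group_idx = sub;
            true
        }
        None => false,
    });
}

fn main() {
    let mut map: HashTable<(u64, usize)> = HashTable::with_capacity(8);
    for idx in 0..5usize {
        let hash = idx as u64; // toy hashes, just for the demo
        map.insert_unique(hash, (hash, idx), |(h, _)| *h);
    }
    shift_group_indices(&mut map, 2);
    assert_eq!(map.len(), 3); // group indices 0 and 1 were emitted and removed
}
```

Because `retain` visits each entry exactly once and hands the closure a mutable reference, the old SAFETY comments about the iterator outliving the table are no longer needed.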

datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs

Lines changed: 21 additions & 24 deletions
@@ -29,7 +29,7 @@ use datafusion_common::Result;
 use datafusion_execution::memory_pool::proxy::VecAllocExt;
 use datafusion_expr::EmitTo;
 use half::f16;
-use hashbrown::raw::RawTable;
+use hashbrown::hash_table::HashTable;
 use std::mem::size_of;
 use std::sync::Arc;

@@ -86,7 +86,7 @@ pub struct GroupValuesPrimitive<T: ArrowPrimitiveType> {
     ///
    /// We don't store the hashes as hashing fixed width primitives
    /// is fast enough for this not to benefit performance
-    map: RawTable<usize>,
+    map: HashTable<usize>,
    /// The group index of the null value if any
    null_group: Option<usize>,
    /// The values for each group index
@@ -100,7 +100,7 @@ impl<T: ArrowPrimitiveType> GroupValuesPrimitive<T> {
         assert!(PrimitiveArray::<T>::is_compatible(&data_type));
         Self {
             data_type,
-            map: RawTable::with_capacity(128),
+            map: HashTable::with_capacity(128),
             values: Vec::with_capacity(128),
             null_group: None,
             random_state: Default::default(),
@@ -126,22 +126,19 @@
                 Some(key) => {
                     let state = &self.random_state;
                     let hash = key.hash(state);
-                    let insert = self.map.find_or_find_insert_slot(
+                    let insert = self.map.entry(
                         hash,
                         |g| unsafe { self.values.get_unchecked(*g).is_eq(key) },
                         |g| unsafe { self.values.get_unchecked(*g).hash(state) },
                     );

-                    // SAFETY: No mutation occurred since find_or_find_insert_slot
-                    unsafe {
-                        match insert {
-                            Ok(v) => *v.as_ref(),
-                            Err(slot) => {
-                                let g = self.values.len();
-                                self.map.insert_in_slot(hash, slot, g);
-                                self.values.push(key);
-                                g
-                            }
+                    match insert {
+                        hashbrown::hash_table::Entry::Occupied(o) => *o.get(),
+                        hashbrown::hash_table::Entry::Vacant(v) => {
+                            let g = self.values.len();
+                            v.insert(g);
+                            self.values.push(key);
+                            g
                         }
                     }
                 }
@@ -183,18 +180,18 @@ where
                 build_primitive(std::mem::take(&mut self.values), self.null_group.take())
             }
             EmitTo::First(n) => {
-                // SAFETY: self.map outlives iterator and is not modified concurrently
-                unsafe {
-                    for bucket in self.map.iter() {
-                        // Decrement group index by n
-                        match bucket.as_ref().checked_sub(n) {
-                            // Group index was >= n, shift value down
-                            Some(sub) => *bucket.as_mut() = sub,
-                            // Group index was < n, so remove from table
-                            None => self.map.erase(bucket),
+                self.map.retain(|group_idx| {
+                    // Decrement group index by n
+                    match group_idx.checked_sub(n) {
+                        // Group index was >= n, shift value down
+                        Some(sub) => {
+                            *group_idx = sub;
+                            true
                         }
+                        // Group index was < n, so remove from table
+                        None => false,
                     }
-                }
+                });
                 let null_group = match &mut self.null_group {
                     Some(v) if *v >= n => {
                         *v -= n;
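
The hunk above replaces `RawTable::find_or_find_insert_slot` plus `insert_in_slot` with the safe `HashTable::entry` API, matching on `Entry::Occupied`/`Entry::Vacant` instead of `Ok`/`Err` inside an `unsafe` block. A hedged sketch of the same intern pattern outside DataFusion (the `intern_key` helper, `i64` keys, and `RandomState` hashing are illustrative assumptions; `hashbrown = "0.15"`):

```rust
use hashbrown::hash_table::{Entry, HashTable};
use std::collections::hash_map::RandomState;
use std::hash::BuildHasher;

/// Interns `key` into `values` and returns its group index.
/// The table stores indices into `values`, like `GroupValuesPrimitive::map`.
fn intern_key(
    map: &mut HashTable<usize>,
    values: &mut Vec<i64>,
    state: &RandomState,
    key: i64,
) -> usize {
    let hash = state.hash_one(key);
    match map.entry(
        hash,
        |g| values[*g] == key,          // equality against the stored value
        |g| state.hash_one(values[*g]), // used if the table has to grow and rehash
    ) {
        Entry::Occupied(o) => *o.get(),
        Entry::Vacant(v) => {
            let g = values.len();
            v.insert(g);
            values.push(key);
            g
        }
    }
}

fn main() {
    let state = RandomState::new();
    let mut map = HashTable::with_capacity(16);
    let mut values = Vec::new();
    assert_eq!(intern_key(&mut map, &mut values, &state, 7), 0);
    assert_eq!(intern_key(&mut map, &mut values, &state, 9), 1);
    assert_eq!(intern_key(&mut map, &mut values, &state, 7), 0); // existing key reuses its index
}
```

`entry` hashes nothing itself: the caller supplies both the lookup hash and the rehash closure, which is why the DataFusion version can keep hashing its primitive keys with its own `RandomState`.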
