Skip to content

Commit 821e14c

Browse files
authored
Reduce code duplication in PrimitiveGroupValueBuilder with const generics (#12703)
* Reduce code duplication in `PrimitiveGroupValueBuilder` with const generics * Fix docs
1 parent c45fc41 commit 821e14c

File tree

2 files changed

+40
-74
lines changed

2 files changed

+40
-74
lines changed

datafusion/physical-plan/src/aggregates/group_values/column.rs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@
1616
// under the License.
1717

1818
use crate::aggregates::group_values::group_column::{
19-
ByteGroupValueBuilder, GroupColumn, NonNullPrimitiveGroupValueBuilder,
20-
PrimitiveGroupValueBuilder,
19+
ByteGroupValueBuilder, GroupColumn, PrimitiveGroupValueBuilder,
2120
};
2221
use crate::aggregates::group_values::GroupValues;
2322
use ahash::RandomState;
@@ -124,8 +123,7 @@ impl GroupValuesColumn {
124123
}
125124
}
126125

127-
/// instantiates a [`PrimitiveGroupValueBuilder`] or
128-
/// [`NonNullPrimitiveGroupValueBuilder`] and pushes it into $v
126+
/// instantiates a [`PrimitiveGroupValueBuilder`] and pushes it into $v
129127
///
130128
/// Arguments:
131129
/// `$v`: the vector to push the new builder into
@@ -135,10 +133,10 @@ impl GroupValuesColumn {
135133
macro_rules! instantiate_primitive {
136134
($v:expr, $nullable:expr, $t:ty) => {
137135
if $nullable {
138-
let b = PrimitiveGroupValueBuilder::<$t>::new();
136+
let b = PrimitiveGroupValueBuilder::<$t, true>::new();
139137
$v.push(Box::new(b) as _)
140138
} else {
141-
let b = NonNullPrimitiveGroupValueBuilder::<$t>::new();
139+
let b = PrimitiveGroupValueBuilder::<$t, false>::new();
142140
$v.push(Box::new(b) as _)
143141
}
144142
};

datafusion/physical-plan/src/aggregates/group_values/group_column.rs

Lines changed: 36 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -60,75 +60,25 @@ pub trait GroupColumn: Send + Sync {
6060
fn take_n(&mut self, n: usize) -> ArrayRef;
6161
}
6262

63-
/// An implementation of [`GroupColumn`] for primitive values which are known to have no nulls
64-
#[derive(Debug)]
65-
pub struct NonNullPrimitiveGroupValueBuilder<T: ArrowPrimitiveType> {
66-
group_values: Vec<T::Native>,
67-
}
68-
69-
impl<T> NonNullPrimitiveGroupValueBuilder<T>
70-
where
71-
T: ArrowPrimitiveType,
72-
{
73-
pub fn new() -> Self {
74-
Self {
75-
group_values: vec![],
76-
}
77-
}
78-
}
79-
80-
impl<T: ArrowPrimitiveType> GroupColumn for NonNullPrimitiveGroupValueBuilder<T> {
81-
fn equal_to(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool {
82-
// know input has no nulls
83-
self.group_values[lhs_row] == array.as_primitive::<T>().value(rhs_row)
84-
}
85-
86-
fn append_val(&mut self, array: &ArrayRef, row: usize) {
87-
// input can't possibly have nulls, so don't worry about them
88-
self.group_values.push(array.as_primitive::<T>().value(row))
89-
}
90-
91-
fn len(&self) -> usize {
92-
self.group_values.len()
93-
}
94-
95-
fn size(&self) -> usize {
96-
self.group_values.allocated_size()
97-
}
98-
99-
fn build(self: Box<Self>) -> ArrayRef {
100-
let Self { group_values } = *self;
101-
102-
let nulls = None;
103-
104-
Arc::new(PrimitiveArray::<T>::new(
105-
ScalarBuffer::from(group_values),
106-
nulls,
107-
))
108-
}
109-
110-
fn take_n(&mut self, n: usize) -> ArrayRef {
111-
let first_n = self.group_values.drain(0..n).collect::<Vec<_>>();
112-
let first_n_nulls = None;
113-
114-
Arc::new(PrimitiveArray::<T>::new(
115-
ScalarBuffer::from(first_n),
116-
first_n_nulls,
117-
))
118-
}
119-
}
120-
121-
/// An implementation of [`GroupColumn`] for primitive values which may have nulls
63+
/// An implementation of [`GroupColumn`] for primitive values
64+
///
65+
/// Optimized to skip null buffer construction if the input is known to be non nullable
66+
///
67+
/// # Generic parameters
68+
///
69+
/// `T`: the Arrow primitive type of the column; values are stored as `T::Native`
70+
/// `NULLABLE`: if the data can contain any nulls
12271
#[derive(Debug)]
123-
pub struct PrimitiveGroupValueBuilder<T: ArrowPrimitiveType> {
72+
pub struct PrimitiveGroupValueBuilder<T: ArrowPrimitiveType, const NULLABLE: bool> {
12473
group_values: Vec<T::Native>,
12574
nulls: MaybeNullBufferBuilder,
12675
}
12776

128-
impl<T> PrimitiveGroupValueBuilder<T>
77+
impl<T, const NULLABLE: bool> PrimitiveGroupValueBuilder<T, NULLABLE>
12978
where
13079
T: ArrowPrimitiveType,
13180
{
81+
/// Create a new `PrimitiveGroupValueBuilder`
13282
pub fn new() -> Self {
13383
Self {
13484
group_values: vec![],
@@ -137,18 +87,32 @@ where
13787
}
13888
}
13989

140-
impl<T: ArrowPrimitiveType> GroupColumn for PrimitiveGroupValueBuilder<T> {
90+
impl<T: ArrowPrimitiveType, const NULLABLE: bool> GroupColumn
91+
for PrimitiveGroupValueBuilder<T, NULLABLE>
92+
{
14193
fn equal_to(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool {
142-
self.nulls.is_null(lhs_row) == array.is_null(rhs_row)
94+
// Perf: skip null check (by short circuit) if input is not nullable
95+
let null_match = if NULLABLE {
96+
self.nulls.is_null(lhs_row) == array.is_null(rhs_row)
97+
} else {
98+
true
99+
};
100+
101+
null_match
143102
&& self.group_values[lhs_row] == array.as_primitive::<T>().value(rhs_row)
144103
}
145104

146105
fn append_val(&mut self, array: &ArrayRef, row: usize) {
147-
if array.is_null(row) {
148-
self.nulls.append(true);
149-
self.group_values.push(T::default_value());
106+
// Perf: skip null check if input can't have nulls
107+
if NULLABLE {
108+
if array.is_null(row) {
109+
self.nulls.append(true);
110+
self.group_values.push(T::default_value());
111+
} else {
112+
self.nulls.append(false);
113+
self.group_values.push(array.as_primitive::<T>().value(row));
114+
}
150115
} else {
151-
self.nulls.append(false);
152116
self.group_values.push(array.as_primitive::<T>().value(row));
153117
}
154118
}
@@ -168,6 +132,9 @@ impl<T: ArrowPrimitiveType> GroupColumn for PrimitiveGroupValueBuilder<T> {
168132
} = *self;
169133

170134
let nulls = nulls.build();
135+
if !NULLABLE {
136+
assert!(nulls.is_none(), "unexpected nulls in non nullable input");
137+
}
171138

172139
Arc::new(PrimitiveArray::<T>::new(
173140
ScalarBuffer::from(group_values),
@@ -177,7 +144,8 @@ impl<T: ArrowPrimitiveType> GroupColumn for PrimitiveGroupValueBuilder<T> {
177144

178145
fn take_n(&mut self, n: usize) -> ArrayRef {
179146
let first_n = self.group_values.drain(0..n).collect::<Vec<_>>();
180-
let first_n_nulls = self.nulls.take_n(n);
147+
148+
let first_n_nulls = if NULLABLE { self.nulls.take_n(n) } else { None };
181149

182150
Arc::new(PrimitiveArray::<T>::new(
183151
ScalarBuffer::from(first_n),

0 commit comments

Comments
 (0)