Skip to content

Commit 8de6063

Browse files
committed
always require a filter in topk
1 parent 4084894 commit 8de6063

File tree

2 files changed

+155
-58
lines changed

2 files changed

+155
-58
lines changed

datafusion/physical-plan/src/sorts/sort.rs

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ use crate::spill::get_record_batch_memory_size;
3838
use crate::spill::in_progress_spill_file::InProgressSpillFile;
3939
use crate::spill::spill_manager::SpillManager;
4040
use crate::stream::RecordBatchStreamAdapter;
41-
use crate::topk::TopK;
41+
use crate::topk::{TopK, TopKDynamicFilters};
4242
use crate::{
4343
DisplayAs, DisplayFormatType, Distribution, EmptyRecordBatchStream, ExecutionPlan,
4444
ExecutionPlanProperties, Partitioning, PlanProperties, SendableRecordBatchStream,
@@ -846,8 +846,10 @@ pub struct SortExec {
846846
common_sort_prefix: Vec<PhysicalSortExpr>,
847847
/// Cache holding plan properties like equivalences, output partitioning etc.
848848
cache: PlanProperties,
849-
/// Filter matching the state of the sort for dynamic filter pushdown
850-
filter: Option<Arc<DynamicFilterPhysicalExpr>>,
849+
/// Filter matching the state of the sort for dynamic filter pushdown.
850+
/// If `fetch` is `Some`, this will also be set and a TopK operator may be used.
851+
/// If `fetch` is `None`, this will be `None`.
852+
filter: Option<TopKDynamicFilters>,
851853
}
852854

853855
impl SortExec {
@@ -920,7 +922,10 @@ impl SortExec {
920922
.iter()
921923
.map(|sort_expr| Arc::clone(&sort_expr.expr))
922924
.collect::<Vec<_>>();
923-
Arc::new(DynamicFilterPhysicalExpr::new(children, lit(true)))
925+
TopKDynamicFilters::new(Arc::new(DynamicFilterPhysicalExpr::new(
926+
children,
927+
lit(true),
928+
)))
924929
})
925930
});
926931
SortExec {
@@ -935,11 +940,6 @@ impl SortExec {
935940
}
936941
}
937942

938-
pub fn with_filter(mut self, filter: Arc<DynamicFilterPhysicalExpr>) -> Self {
939-
self.filter = Some(filter);
940-
self
941-
}
942-
943943
/// Input execution plan that this operator reads from
944944
pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
945945
&self.input
@@ -955,11 +955,6 @@ impl SortExec {
955955
self.fetch
956956
}
957957

958-
/// If `Some(filter)`, returns the filter expression that matches the state of the sort.
959-
pub fn filter(&self) -> Option<Arc<DynamicFilterPhysicalExpr>> {
960-
self.filter.clone()
961-
}
962-
963958
fn output_partitioning_helper(
964959
input: &Arc<dyn ExecutionPlan>,
965960
preserve_partitioning: bool,
@@ -1038,7 +1033,7 @@ impl DisplayAs for SortExec {
10381033
Some(fetch) => {
10391034
write!(f, "SortExec: TopK(fetch={fetch}), expr=[{}], preserve_partitioning=[{preserve_partitioning}]", self.expr)?;
10401035
if let Some(filter) = &self.filter {
1041-
if let Ok(current) = filter.current() {
1036+
if let Ok(current) = filter.expr().current() {
10421037
if !current.eq(&lit(true)) {
10431038
write!(f, ", filter=[{current}]")?;
10441039
}
@@ -1158,7 +1153,10 @@ impl ExecutionPlan for SortExec {
11581153
context.session_config().batch_size(),
11591154
context.runtime_env(),
11601155
&self.metrics_set,
1161-
self.filter.clone(),
1156+
self.filter
1157+
.as_ref()
1158+
.expect("Filter should be set when fetch is Some")
1159+
.clone(),
11621160
)?;
11631161
Ok(Box::pin(RecordBatchStreamAdapter::new(
11641162
self.schema(),
@@ -1278,10 +1276,9 @@ impl ExecutionPlan for SortExec {
12781276
}
12791277
if let Some(filter) = &self.filter {
12801278
if config.optimizer.enable_dynamic_filter_pushdown {
1281-
let filter = Arc::clone(filter) as Arc<dyn PhysicalExpr>;
12821279
return Ok(FilterDescription::new_with_child_count(1)
12831280
.all_parent_filters_supported(parent_filters)
1284-
.with_self_filter(filter));
1281+
.with_self_filter(filter.expr()));
12851282
}
12861283
}
12871284
Ok(FilterDescription::new_with_child_count(1)

datafusion/physical-plan/src/topk/mod.rs

Lines changed: 140 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ use arrow::{
2323
row::{RowConverter, Rows, SortField},
2424
};
2525
use datafusion_expr::{ColumnarValue, Operator};
26+
use parking_lot::RwLock;
2627
use std::mem::size_of;
2728
use std::{cmp::Ordering, collections::BinaryHeap, sync::Arc};
2829

@@ -121,13 +122,36 @@ pub struct TopK {
121122
/// Common sort prefix between the input and the sort expressions to allow early exit optimization
122123
common_sort_prefix: Arc<[PhysicalSortExpr]>,
123124
/// Filter matching the state of the `TopK` heap used for dynamic filter pushdown
124-
filter: Option<Arc<DynamicFilterPhysicalExpr>>,
125+
filter: TopKDynamicFilters,
125126
/// If true, indicates that all rows of subsequent batches are guaranteed
126127
/// to be greater (by byte order, after row conversion) than the top K,
127128
/// which means the top K won't change and the computation can be finished early.
128129
pub(crate) finished: bool,
129130
}
130131

132+
#[derive(Debug, Clone)]
133+
pub struct TopKDynamicFilters {
134+
/// The current *global* threshold for the dynamic filter.
135+
/// This is shared across all partitions and is updated by any of them.
136+
thresholds: Arc<RwLock<Option<Vec<ScalarValue>>>>,
137+
/// The expression used to evaluate the dynamic filter
138+
expr: Arc<DynamicFilterPhysicalExpr>,
139+
}
140+
141+
impl TopKDynamicFilters {
142+
/// Create a new `TopKDynamicFilters` with the given expression
143+
pub fn new(expr: Arc<DynamicFilterPhysicalExpr>) -> Self {
144+
Self {
145+
thresholds: Arc::new(RwLock::new(None)),
146+
expr,
147+
}
148+
}
149+
150+
pub fn expr(&self) -> Arc<DynamicFilterPhysicalExpr> {
151+
Arc::clone(&self.expr)
152+
}
153+
}
154+
131155
// Guesstimate for memory allocation: estimated number of bytes used per row in the RowConverter
132156
const ESTIMATED_BYTES_PER_ROW: usize = 20;
133157

@@ -160,7 +184,7 @@ impl TopK {
160184
batch_size: usize,
161185
runtime: Arc<RuntimeEnv>,
162186
metrics: &ExecutionPlanMetricsSet,
163-
filter: Option<Arc<DynamicFilterPhysicalExpr>>,
187+
filter: TopKDynamicFilters,
164188
) -> Result<Self> {
165189
let reservation = MemoryConsumer::new(format!("TopK[{partition_id}]"))
166190
.register(&runtime.memory_pool);
@@ -214,41 +238,39 @@ impl TopK {
214238

215239
let mut selected_rows = None;
216240

217-
if let Some(filter) = self.filter.as_ref() {
218-
// If a filter is provided, update it with the new rows
219-
let filter = filter.current()?;
220-
let filtered = filter.evaluate(&batch)?;
221-
let num_rows = batch.num_rows();
222-
let array = filtered.into_array(num_rows)?;
223-
let mut filter = array.as_boolean().clone();
224-
let true_count = filter.true_count();
225-
if true_count == 0 {
226-
// nothing to filter, so no need to update
227-
return Ok(());
241+
// A filter is always present: evaluate its current predicate against the batch so only the rows that pass it are considered
242+
let filter = self.filter.expr.current()?;
243+
let filtered = filter.evaluate(&batch)?;
244+
let num_rows = batch.num_rows();
245+
let array = filtered.into_array(num_rows)?;
246+
let mut filter = array.as_boolean().clone();
247+
let true_count = filter.true_count();
248+
if true_count == 0 {
249+
// no row in this batch passes the filter, so there is nothing to insert
250+
return Ok(());
251+
}
252+
// only update the keys / rows if the filter does not match all rows
253+
if true_count < num_rows {
254+
// Indices in `set_indices` should be correct if filter contains nulls
255+
// So we prepare the filter here. Note this is also done in the `FilterBuilder`
256+
// so there is no overhead to do this here.
257+
if filter.nulls().is_some() {
258+
filter = prep_null_mask_filter(&filter);
228259
}
229-
// only update the keys / rows if the filter does not match all rows
230-
if true_count < num_rows {
231-
// Indices in `set_indices` should be correct if filter contains nulls
232-
// So we prepare the filter here. Note this is also done in the `FilterBuilder`
233-
// so there is no overhead to do this here.
234-
if filter.nulls().is_some() {
235-
filter = prep_null_mask_filter(&filter);
236-
}
237260

238-
let filter_predicate = FilterBuilder::new(&filter);
239-
let filter_predicate = if sort_keys.len() > 1 {
240-
// Optimize filter when it has multiple sort keys
241-
filter_predicate.optimize().build()
242-
} else {
243-
filter_predicate.build()
244-
};
245-
selected_rows = Some(filter);
246-
sort_keys = sort_keys
247-
.iter()
248-
.map(|key| filter_predicate.filter(key).map_err(|x| x.into()))
249-
.collect::<Result<Vec<_>>>()?;
250-
}
251-
};
261+
let filter_predicate = FilterBuilder::new(&filter);
262+
let filter_predicate = if sort_keys.len() > 1 {
263+
// Optimize filter when it has multiple sort keys
264+
filter_predicate.optimize().build()
265+
} else {
266+
filter_predicate.build()
267+
};
268+
selected_rows = Some(filter);
269+
sort_keys = sort_keys
270+
.iter()
271+
.map(|key| filter_predicate.filter(key).map_err(|x| x.into()))
272+
.collect::<Result<Vec<_>>>()?;
273+
}
252274
// reuse existing `Rows` to avoid reallocations
253275
let rows = &mut self.scratch_rows;
254276
rows.clear();
@@ -319,13 +341,88 @@ impl TopK {
319341
/// (a > 2 OR (a = 2 AND b < 3))
320342
/// ```
321343
fn update_filter(&mut self) -> Result<()> {
322-
let Some(filter) = &self.filter else {
323-
return Ok(());
324-
};
325344
let Some(thresholds) = self.heap.get_threshold_values(&self.expr)? else {
326345
return Ok(());
327346
};
328347

348+
// Are the new thresholds more selective than our existing ones?
349+
let should_update = {
350+
if let Some(current) = self.filter.thresholds.write().as_mut() {
351+
assert!(current.len() == thresholds.len());
352+
// Check if new thresholds are more selective than current ones
353+
let mut more_selective = false;
354+
for ((current_value, new_value), sort_expr) in
355+
current.iter().zip(thresholds.iter()).zip(self.expr.iter())
356+
{
357+
// Handle null cases
358+
let (current_is_null, new_is_null) =
359+
(current_value.is_null(), new_value.is_null());
360+
361+
match (current_is_null, new_is_null) {
362+
(true, true) => {
363+
// Both null, continue checking next values
364+
}
365+
(true, false) => {
366+
// Current is null, new is not null
367+
// For nulls_first: null < non-null, so new value is less selective
368+
// For nulls_last: null > non-null, so new value is more selective
369+
more_selective = !sort_expr.options.nulls_first;
370+
break;
371+
}
372+
(false, true) => {
373+
// Current is not null, new is null
374+
// For nulls_first: non-null > null, so new value is more selective
375+
// For nulls_last: non-null < null, so new value is less selective
376+
more_selective = sort_expr.options.nulls_first;
377+
break;
378+
}
379+
(false, false) => {
380+
// Neither is null, compare values
381+
match current_value.partial_cmp(new_value) {
382+
Some(ordering) => {
383+
match ordering {
384+
Ordering::Equal => {
385+
// Continue checking next values
386+
}
387+
Ordering::Less => {
388+
// For descending sort: new > current means more selective
389+
// For ascending sort: new > current means less selective
390+
more_selective = sort_expr.options.descending;
391+
break;
392+
}
393+
Ordering::Greater => {
394+
// For descending sort: new < current means less selective
395+
// For ascending sort: new < current means more selective
396+
more_selective =
397+
!sort_expr.options.descending;
398+
break;
399+
}
400+
}
401+
}
402+
None => {
403+
// If values can't be compared, don't update
404+
more_selective = false;
405+
break;
406+
}
407+
}
408+
}
409+
}
410+
}
411+
// If the new thresholds are more selective, update the current ones
412+
if more_selective {
413+
*current = thresholds.clone();
414+
}
415+
more_selective
416+
} else {
417+
// No current thresholds, so update with the new ones
418+
true
419+
}
420+
};
421+
422+
if !should_update {
423+
return Ok(());
424+
}
425+
329426
// Create filter expressions for each threshold
330427
let mut filters: Vec<Arc<dyn PhysicalExpr>> =
331428
Vec::with_capacity(thresholds.len());
@@ -405,7 +502,7 @@ impl TopK {
405502

406503
if let Some(predicate) = dynamic_predicate {
407504
if !predicate.eq(&lit(true)) {
408-
filter.update(predicate)?;
505+
self.filter.expr.update(predicate)?;
409506
}
410507
}
411508

@@ -1053,7 +1150,10 @@ mod tests {
10531150
2,
10541151
runtime,
10551152
&metrics,
1056-
None,
1153+
TopKDynamicFilters::new(Arc::new(DynamicFilterPhysicalExpr::new(
1154+
vec![],
1155+
lit(true),
1156+
))),
10571157
)?;
10581158

10591159
// Create the first batch with two columns:

0 commit comments

Comments
 (0)