
Commit 70c5ffd

TopK dynamic filter pushdown attempt 2

1 parent 55ba4ca commit 70c5ffd

File tree: 14 files changed, +1405 −12 lines

datafusion/common/src/config.rs

Lines changed: 7 additions & 0 deletions

@@ -612,6 +612,13 @@ config_namespace! {
     /// during aggregations, if possible
     pub enable_topk_aggregation: bool, default = true

+    /// When set to true, attempts to push down dynamic filters generated by operators into the file scan phase.
+    /// For example, for a query such as `SELECT * FROM t ORDER BY timestamp DESC LIMIT 10`, the optimizer
+    /// will attempt to push down the current top 10 timestamps that the TopK operator references into the file scans.
+    /// This means that if we already have 10 timestamps in the year 2025,
+    /// any files that only have timestamps in the year 2024 can be skipped / pruned at various stages in the scan.
+    pub enable_dynamic_filter_pushdown: bool, default = true
+
     /// When set to true, the optimizer will insert filters before a join between
     /// a nullable and non-nullable column to filter out nulls on the nullable side. This
     /// filter can add additional overhead when the file format does not fully support
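
For reference, a minimal sketch of how a user would toggle the new option (the config key and default match the diff above; the table `t` and its `timestamp` column are hypothetical and assumed to be registered already):

use datafusion::prelude::{SessionConfig, SessionContext};

async fn topk_with_dynamic_filters() -> datafusion::error::Result<()> {
    // The option defaults to true; it is set explicitly here for illustration.
    // Setting it to false disables the TopK -> scan filter pushdown.
    let cfg = SessionConfig::new()
        .set_bool("datafusion.optimizer.enable_dynamic_filter_pushdown", true);
    let ctx = SessionContext::new_with_config(cfg);

    // Hypothetical table `t`: with the option enabled, the TopK operator's
    // current 10th-largest `timestamp` can prune files during the scan.
    ctx.sql("SELECT * FROM t ORDER BY timestamp DESC LIMIT 10")
        .await?
        .show()
        .await?;
    Ok(())
}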

datafusion/core/tests/fuzz_cases/mod.rs

Lines changed: 1 addition & 0 deletions

@@ -21,6 +21,7 @@ mod join_fuzz;
 mod merge_fuzz;
 mod sort_fuzz;
 mod sort_query_fuzz;
+mod topk_filter_pushdown;

 mod aggregation_fuzzer;
 mod equivalence;
datafusion/core/tests/fuzz_cases/topk_filter_pushdown.rs

Lines changed: 354 additions & 0 deletions

@@ -0,0 +1,354 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use std::collections::HashMap;
use std::sync::{Arc, LazyLock};

use arrow::array::{Int32Array, StringArray, StringDictionaryBuilder};
use arrow::datatypes::Int32Type;
use arrow::record_batch::RecordBatch;
use arrow::util::pretty::pretty_format_batches;
use arrow_schema::{DataType, Field, Schema};
use datafusion::datasource::listing::{ListingOptions, ListingTable, ListingTableConfig};
use datafusion::prelude::{SessionConfig, SessionContext};
use datafusion_datasource::ListingTableUrl;
use datafusion_datasource_parquet::ParquetFormat;
use datafusion_execution::object_store::ObjectStoreUrl;
use itertools::Itertools;
use object_store::memory::InMemory;
use object_store::path::Path;
use object_store::{ObjectStore, PutPayload};
use parquet::arrow::ArrowWriter;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use tokio::sync::Mutex;
use tokio::task::JoinSet;

#[derive(Clone)]
struct TestDataSet {
    store: Arc<dyn ObjectStore>,
    schema: Arc<Schema>,
}

/// List of in memory parquet files with UTF8 data
// Use a mutex rather than LazyLock to allow for async initialization
static TESTFILES: LazyLock<Mutex<Vec<TestDataSet>>> =
    LazyLock::new(|| Mutex::new(vec![]));

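// Lazily builds the shared datasets on first call: the three nested
// nullability loops yield 2^3 = 8 variants (nulls toggled independently in
// `id`, `name`, and `department`), each holding five small in-memory parquet
// files written from a seeded RNG so runs are reproducible.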
async fn test_files() -> Vec<TestDataSet> {
    let files_mutex = &TESTFILES;
    let mut files = files_mutex.lock().await;
    if !files.is_empty() {
        return (*files).clone();
    }

    let mut rng = StdRng::seed_from_u64(0);

    for nulls_in_ids in [false, true] {
        for nulls_in_names in [false, true] {
            for nulls_in_departments in [false, true] {
                let store = Arc::new(InMemory::new());

                let schema = Arc::new(Schema::new(vec![
                    Field::new("id", DataType::Int32, nulls_in_ids),
                    Field::new("name", DataType::Utf8, nulls_in_names),
                    Field::new(
                        "department",
                        DataType::Dictionary(
                            Box::new(DataType::Int32),
                            Box::new(DataType::Utf8),
                        ),
                        nulls_in_departments,
                    ),
                ]));

                let name_choices = if nulls_in_names {
                    [Some("Alice"), Some("Bob"), None, Some("David"), None]
                } else {
                    [
                        Some("Alice"),
                        Some("Bob"),
                        Some("Charlie"),
                        Some("David"),
                        Some("Eve"),
                    ]
                };

                let department_choices = if nulls_in_departments {
                    [
                        Some("Theater"),
                        Some("Engineering"),
                        None,
                        Some("Arts"),
                        None,
                    ]
                } else {
                    [
                        Some("Theater"),
                        Some("Engineering"),
                        Some("Healthcare"),
                        Some("Arts"),
                        Some("Music"),
                    ]
                };

                // Generate 5 files, some with overlapping or repeated ids, some without
                for i in 0..5 {
                    let num_batches = rng.gen_range(1..3);
                    let mut batches = Vec::with_capacity(num_batches);
                    for _ in 0..num_batches {
                        let num_rows = 25;
                        let ids = Int32Array::from_iter((0..num_rows).map(|file| {
                            if nulls_in_ids {
                                if rng.gen_bool(1.0 / 10.0) {
                                    None
                                } else {
                                    Some(rng.gen_range(file..file + 5))
                                }
                            } else {
                                Some(rng.gen_range(file..file + 5))
                            }
                        }));
                        let names = StringArray::from_iter((0..num_rows).map(|_| {
                            // randomly select a name
                            let idx = rng.gen_range(0..name_choices.len());
                            name_choices[idx].map(|s| s.to_string())
                        }));
                        let mut departments = StringDictionaryBuilder::<Int32Type>::new();
                        for _ in 0..num_rows {
                            // randomly select a department
                            let idx = rng.gen_range(0..department_choices.len());
                            departments.append_option(department_choices[idx].as_ref());
                        }
                        let batch = RecordBatch::try_new(
                            schema.clone(),
                            vec![
                                Arc::new(ids),
                                Arc::new(names),
                                Arc::new(departments.finish()),
                            ],
                        )
                        .unwrap();
                        batches.push(batch);
                    }
                    let mut buf = vec![];
                    {
                        let mut writer =
                            ArrowWriter::try_new(&mut buf, schema.clone(), None).unwrap();
                        for batch in batches {
                            writer.write(&batch).unwrap();
                            writer.flush().unwrap();
                        }
                        writer.flush().unwrap();
                        writer.finish().unwrap();
                    }
                    let payload = PutPayload::from(buf);
                    let path = Path::from(format!("file_{i}.parquet"));
                    store.put(&path, payload).await.unwrap();
                }
                files.push(TestDataSet { store, schema });
            }
        }
    }
    (*files).clone()
}

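// Runs `query` against a fresh SessionContext backed by the dataset's
// in-memory object store, with the dataset registered as `test_table`.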
async fn run_query_with_config(
    query: &str,
    config: SessionConfig,
    dataset: TestDataSet,
) -> Vec<RecordBatch> {
    let store = dataset.store;
    let schema = dataset.schema;
    let ctx = SessionContext::new_with_config(config);
    let url = ObjectStoreUrl::parse("memory://").unwrap();
    ctx.register_object_store(url.as_ref(), store.clone());

    let format = Arc::new(
        ParquetFormat::default()
            .with_options(ctx.state().table_options().parquet.clone()),
    );
    let options = ListingOptions::new(format);
    let table_path = ListingTableUrl::parse("memory:///").unwrap();
    let config = ListingTableConfig::new(table_path)
        .with_listing_options(options)
        .with_schema(schema);
    let table = Arc::new(ListingTable::try_new(config).unwrap());

    ctx.register_table("test_table", table).unwrap();

    ctx.sql(query).await.unwrap().collect().await.unwrap()
}

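// Output of one differential run; a case passes when the pretty-printed
// batches from both configurations match exactly.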
#[derive(Debug)]
struct RunQueryResult {
    query: String,
    result: Vec<RecordBatch>,
    expected: Vec<RecordBatch>,
}

impl RunQueryResult {
    fn expected_formated(&self) -> String {
        format!("{}", pretty_format_batches(&self.expected).unwrap())
    }

    fn result_formated(&self) -> String {
        format!("{}", pretty_format_batches(&self.result).unwrap())
    }

    fn is_ok(&self) -> bool {
        self.expected_formated() == self.result_formated()
    }
}

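// Executes the query twice: once with dynamic filter pushdown disabled to
// produce the expected (ground-truth) answer, and once with it enabled to
// produce the result under test.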
async fn run_query(
    query: String,
    cfg: SessionConfig,
    dataset: TestDataSet,
) -> RunQueryResult {
    let cfg_with_dynamic_filters = cfg
        .clone()
        .set_bool("datafusion.optimizer.enable_dynamic_filter_pushdown", true);
    let cfg_without_dynamic_filters = cfg
        .clone()
        .set_bool("datafusion.optimizer.enable_dynamic_filter_pushdown", false);

    let expected_result =
        run_query_with_config(&query, cfg_without_dynamic_filters, dataset.clone()).await;
    let result =
        run_query_with_config(&query, cfg_with_dynamic_filters, dataset.clone()).await;

    RunQueryResult {
        query: query.to_string(),
        result,
        expected: expected_result,
    }
}

struct TestCase {
    query: String,
    cfg: SessionConfig,
    dataset: TestDataSet,
}

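// Fuzz driver: enumerates `ORDER BY ... LIMIT` queries over every combination
// of order columns, sort directions, and null orderings, then runs each query
// against all eight datasets and compares the two configurations via `run_query`.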
#[tokio::test(flavor = "multi_thread")]
async fn test_fuzz_topk_filter_pushdown() {
    let order_columns = ["id", "name", "department"];
    let order_directions = ["ASC", "DESC"];
    let null_orders = ["NULLS FIRST", "NULLS LAST"];

    let start = datafusion_common::instant::Instant::now();
    let mut orders: HashMap<String, Vec<String>> = HashMap::new();
    for order_column in &order_columns {
        for order_direction in &order_directions {
            for null_order in &null_orders {
                // if there is a vec for this column, insert the order; otherwise create a new vec
                let ordering =
                    format!("{} {} {}", order_column, order_direction, null_order);
                match orders.get_mut(*order_column) {
                    Some(order_vec) => {
                        order_vec.push(ordering);
                    }
                    None => {
                        orders.insert(order_column.to_string(), vec![ordering]);
                    }
                }
            }
        }
    }

    let mut queries = vec![];

    for limit in [1, 10] {
        for num_order_by_columns in [1, 2, 3] {
            for order_columns in ["id", "name", "department"]
                .iter()
                .combinations(num_order_by_columns)
            {
                for orderings in order_columns
                    .iter()
                    .map(|col| orders.get(**col).unwrap())
                    .multi_cartesian_product()
                {
                    let query = format!(
                        "SELECT * FROM test_table ORDER BY {} LIMIT {}",
                        orderings.into_iter().join(", "),
                        limit
                    );
                    queries.push(query);
                }
            }
        }
    }

    queries.sort_unstable();
    println!(
        "Generated {} queries in {:?}",
        queries.len(),
        start.elapsed()
    );

    let start = datafusion_common::instant::Instant::now();
    let datasets = test_files().await;
    println!("Generated test files in {:?}", start.elapsed());

    let mut test_cases = vec![];
    for enable_filter_pushdown in [true, false] {
        for query in &queries {
            for dataset in &datasets {
                let mut cfg = SessionConfig::new();
                cfg = cfg.set_bool(
                    "datafusion.optimizer.enable_dynamic_filter_pushdown",
                    enable_filter_pushdown,
                );
                test_cases.push(TestCase {
                    query: query.to_string(),
                    cfg,
                    dataset: dataset.clone(),
                });
            }
        }
    }

    let start = datafusion_common::instant::Instant::now();
    let mut join_set = JoinSet::new();
    for tc in test_cases {
        join_set.spawn(run_query(tc.query, tc.cfg, tc.dataset));
    }
    let mut results = join_set.join_all().await;
    results.sort_unstable_by(|a, b| a.query.cmp(&b.query));
    println!("Ran {} test cases in {:?}", results.len(), start.elapsed());

    let failures = results
        .iter()
        .filter(|result| !result.is_ok())
        .collect::<Vec<_>>();

    for failure in &failures {
        println!("Failure:");
        println!("Query:\n{}", failure.query);
        println!("\nExpected:\n{}", failure.expected_formated());
        println!("\nResult:\n{}", failure.result_formated());
        println!("\n\n");
    }

    if !failures.is_empty() {
        panic!("Some test cases failed");
    } else {
        println!("All test cases passed");
    }
}
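
Design note: the test is differential. The run with `enable_dynamic_filter_pushdown = false` exercises the pre-existing code path and serves as the oracle, so any mismatch in the pretty-printed output points at the new pushdown path rather than at the query engine at large.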
