
Commit 1878f59

Add hook for per-file filter pushdown rewrites
This enables use cases like data shredding and pre-computed filter expressions
1 parent 9382add commit 1878f59
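
At a glance, the new hook is the FileExpressionRewriter trait plus a matching builder method on ParquetSource; both appear in the example diff below. A minimal sketch of an implementation, assuming only those two APIs from this commit (the NoopRewriter name and the helper function are illustrative, not part of the change), is a rewriter that hands every filter expression back unchanged:

use std::sync::Arc;

use arrow::datatypes::SchemaRef;
use datafusion::common::Result;
use datafusion::datasource::file_expr_rewriter::FileExpressionRewriter;
use datafusion::datasource::physical_plan::ParquetSource;
use datafusion::physical_expr::PhysicalExpr;

/// Illustrative rewriter that performs no per-file rewrite at all.
#[derive(Debug)]
struct NoopRewriter;

impl FileExpressionRewriter for NoopRewriter {
    fn rewrite(
        &self,
        _file_schema: SchemaRef,
        expr: Arc<dyn PhysicalExpr>,
    ) -> Result<Arc<dyn PhysicalExpr>> {
        // A real implementation would consult `_file_schema` here and return a
        // per-file variant of `expr`; this sketch returns it untouched.
        Ok(expr)
    }
}

/// Illustrative helper showing where the hook attaches to a Parquet scan.
fn parquet_source_with_rewriter() -> ParquetSource {
    ParquetSource::default().with_filter_expression_rewriter(Arc::new(NoopRewriter) as _)
}

The full example below implements the same trait to rewrite struct field accesses such as user_info['age'] into shredded columns like _user_info.age when a file's schema contains them.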

8 files changed: +657, -38 lines

Lines changed: 374 additions & 0 deletions
@@ -0,0 +1,374 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use std::any::Any;
use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray, StructArray};
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use arrow_schema::Fields;
use async_trait::async_trait;

use datafusion::assert_batches_eq;
use datafusion::catalog::{Session, TableProvider};
use datafusion::common::tree_node::{
    Transformed, TransformedResult, TreeNode, TreeNodeRewriter,
};
use datafusion::common::{assert_contains, DFSchema, Result};
use datafusion::datasource::file_expr_rewriter::FileExpressionRewriter;
use datafusion::datasource::listing::PartitionedFile;
use datafusion::datasource::physical_plan::{FileScanConfig, ParquetSource};
use datafusion::execution::context::SessionContext;
use datafusion::execution::object_store::ObjectStoreUrl;
use datafusion::logical_expr::utils::conjunction;
use datafusion::logical_expr::{Expr, TableProviderFilterPushDown, TableType};
use datafusion::parquet::arrow::ArrowWriter;
use datafusion::parquet::file::properties::WriterProperties;
use datafusion::physical_expr::PhysicalExpr;
use datafusion::physical_expr::{expressions, ScalarFunctionExpr};
use datafusion::physical_plan::ExecutionPlan;
use datafusion::prelude::lit;
use futures::StreamExt;
use object_store::memory::InMemory;
use object_store::path::Path;
use object_store::{ObjectStore, PutPayload};

// Example showing how to implement custom filter rewriting for struct fields.
//
// In this example, we have a table with a struct column like:
// struct_col: {"a": 1, "b": "foo"}
//
// Our custom TableProvider will use a FileExpressionRewriter to rewrite
// expressions like `struct_col['a'] = 10` to use a flattened column name
// `_struct_col.a` if it exists in the file schema.
#[tokio::main]
async fn main() -> Result<()> {
    println!("=== Creating example data with structs and flattened fields ===");

    // Create sample data with both struct columns and flattened fields
    let (table_schema, batch) = create_sample_data();

    let store = InMemory::new();
    let buf = {
        let mut buf = vec![];

        let props = WriterProperties::builder()
            .set_max_row_group_size(1)
            .build();

        let mut writer = ArrowWriter::try_new(&mut buf, batch.schema(), Some(props))
            .expect("creating writer");

        writer.write(&batch).expect("Writing batch");
        writer.close().unwrap();
        buf
    };
    let path = Path::from("example.parquet");
    let payload = PutPayload::from_bytes(buf.into());
    store.put(&path, payload).await?;

    // Create a custom table provider that rewrites struct field access
    let table_provider = Arc::new(ExampleTableProvider::new(table_schema));

    // Set up query execution
    let ctx = SessionContext::new();

    // Register our table
    ctx.register_table("structs", table_provider)?;

    ctx.runtime_env().register_object_store(
        ObjectStoreUrl::parse("memory://")?.as_ref(),
        Arc::new(store),
    );

    println!("\n=== Showing all data ===");
    let batches = ctx.sql("SELECT * FROM structs").await?.collect().await?;
    arrow::util::pretty::print_batches(&batches)?;

    println!("\n=== Running query with struct field access and filter < 30 ===");
    println!("Query: SELECT user_info['name'] FROM structs WHERE user_info['age'] < 30");

    let batches = ctx
        .sql("SELECT user_info['name'] FROM structs WHERE user_info['age'] < 30 ORDER BY user_info['name']")
        .await?
        .collect()
        .await?;

    #[rustfmt::skip]
    let expected = [
        "+-------------------------+",
        "| structs.user_info[name] |",
        "+-------------------------+",
        "| Bob                     |",
        "| Dave                    |",
        "+-------------------------+",
    ];
    arrow::util::pretty::print_batches(&batches)?;
    assert_batches_eq!(expected, &batches);

    println!("\n=== Running explain analyze to confirm row group pruning ===");

    let batches = ctx
        .sql("EXPLAIN ANALYZE SELECT user_info['name'] FROM structs WHERE user_info['age'] < 30")
        .await?
        .collect()
        .await?;
    let plan = format!("{}", arrow::util::pretty::pretty_format_batches(&batches)?);
    println!("{plan}");
    assert_contains!(&plan, "row_groups_pruned_statistics=2");

    Ok(())
}

/// Create the example data with both struct fields and flattened fields
fn create_sample_data() -> (SchemaRef, RecordBatch) {
    // Create a schema with a struct column
    let user_info_fields = Fields::from(vec![
        Field::new("name", DataType::Utf8, false),
        Field::new("age", DataType::Int32, false),
    ]);

    let file_schema = Schema::new(vec![
        Field::new(
            "user_info",
            DataType::Struct(user_info_fields.clone()),
            false,
        ),
        // Include flattened fields (in real scenarios these might be in some files but not others)
        Field::new("_user_info.age", DataType::Int32, true),
    ]);

    let table_schema = Schema::new(vec![Field::new(
        "user_info",
        DataType::Struct(user_info_fields.clone()),
        false,
    )]);

    // Create struct array for user_info
    let names = StringArray::from(vec!["Alice", "Bob", "Charlie", "Dave"]);
    let ages = Int32Array::from(vec![30, 25, 35, 22]);

    let user_info = StructArray::from(vec![
        (
            Arc::new(Field::new("name", DataType::Utf8, false)),
            Arc::new(names.clone()) as ArrayRef,
        ),
        (
            Arc::new(Field::new("age", DataType::Int32, false)),
            Arc::new(ages.clone()) as ArrayRef,
        ),
    ]);

    // Create a record batch with the data
    let batch = RecordBatch::try_new(
        Arc::new(file_schema.clone()),
        vec![
            Arc::new(user_info),
            Arc::new(ages), // Shredded age field
        ],
    )
    .unwrap();

    (Arc::new(table_schema), batch)
}

/// Custom TableProvider that uses a StructFieldRewriter
#[derive(Debug)]
struct ExampleTableProvider {
    schema: SchemaRef,
}

impl ExampleTableProvider {
    fn new(schema: SchemaRef) -> Self {
        Self { schema }
    }
}

#[async_trait]
impl TableProvider for ExampleTableProvider {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

    fn table_type(&self) -> TableType {
        TableType::Base
    }

    fn supports_filters_pushdown(
        &self,
        filters: &[&Expr],
    ) -> Result<Vec<TableProviderFilterPushDown>> {
        // Implementers can choose to mark these filters as exact or inexact.
        // If marked as exact they cannot have false positives and must always be applied.
        // If marked as Inexact they can have false positives and at runtime the rewriter
        // can decide to not rewrite / ignore some filters since they will be re-evaluated upstream.
        // For the purposes of this example we mark them as Exact to demonstrate the rewriter is working and the filtering is not being re-evaluated upstream.
        Ok(vec![TableProviderFilterPushDown::Exact; filters.len()])
    }

    async fn scan(
        &self,
        state: &dyn Session,
        projection: Option<&Vec<usize>>,
        filters: &[Expr],
        limit: Option<usize>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let schema = self.schema.clone();
        let df_schema = DFSchema::try_from(schema.clone())?;
        let filter = state.create_physical_expr(
            conjunction(filters.iter().cloned()).unwrap_or_else(|| lit(true)),
            &df_schema,
        )?;

        let parquet_source = ParquetSource::default()
            .with_predicate(self.schema.clone(), filter)
            .with_pushdown_filters(true)
            // if the rewriter needs a reference to the table schema you can bind self.schema() here
            .with_filter_expression_rewriter(Arc::new(StructFieldRewriter) as _);

        let object_store_url = ObjectStoreUrl::parse("memory://")?;

        let store = state.runtime_env().object_store(object_store_url)?;

        let mut files = vec![];
        let mut listing = store.list(None);
        while let Some(file) = listing.next().await {
            if let Ok(file) = file {
                files.push(file);
            }
        }

        let file_group = files
            .iter()
            .map(|file| {
                PartitionedFile::new(
                    file.location.clone(),
                    u64::try_from(file.size).expect("fits in a u64"),
                )
            })
            .collect();

        let file_scan_config = FileScanConfig::new(
            ObjectStoreUrl::parse("memory://")?,
            schema,
            Arc::new(parquet_source),
        )
        .with_projection(projection.cloned())
        .with_limit(limit)
        .with_file_group(file_group);

        Ok(file_scan_config.build())
    }
}

/// Rewriter that converts struct field access to flattened column references
#[derive(Debug)]
struct StructFieldRewriter;

impl FileExpressionRewriter for StructFieldRewriter {
    fn rewrite(
        &self,
        file_schema: SchemaRef,
        expr: Arc<dyn PhysicalExpr>,
    ) -> Result<Arc<dyn PhysicalExpr>> {
        let mut rewrite = StructFieldRewriterImpl { file_schema };
        expr.rewrite(&mut rewrite).data()
    }
}

struct StructFieldRewriterImpl {
    file_schema: SchemaRef,
}

impl TreeNodeRewriter for StructFieldRewriterImpl {
    type Node = Arc<dyn PhysicalExpr>;

    fn f_down(
        &mut self,
        expr: Arc<dyn PhysicalExpr>,
    ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
        if let Some(scalar_function) = expr.as_any().downcast_ref::<ScalarFunctionExpr>()
        {
            if scalar_function.name() == "get_field" && scalar_function.args().len() == 2
            {
                // First argument is the column, second argument is the field name
                let column = scalar_function.args()[0].clone();
                let field_name = scalar_function.args()[1].clone();
                if let Some(literal) =
                    field_name.as_any().downcast_ref::<expressions::Literal>()
                {
                    if let Some(field_name) = literal.value().try_as_str().flatten() {
                        if let Some(column) =
                            column.as_any().downcast_ref::<expressions::Column>()
                        {
                            let column_name = column.name();
                            let source_field =
                                self.file_schema.field_with_name(column_name)?;
                            let expected_flattened_column_name =
                                format!("_{}.{}", column_name, field_name);
                            if let DataType::Struct(struct_fields) =
                                source_field.data_type()
                            {
                                // Check if the flattened column exists in the file schema and has the same type
                                if let Ok(shredded_field) = self
                                    .file_schema
                                    .field_with_name(&expected_flattened_column_name)
                                {
                                    if let Some((_, struct_field)) =
                                        struct_fields.find(field_name)
                                    {
                                        if struct_field.data_type()
                                            == shredded_field.data_type()
                                        {
                                            // Rewrite the expression to use the flattened column
                                            let rewritten_expr = expressions::col(
                                                &expected_flattened_column_name,
                                                &self.file_schema,
                                            )?;
                                            return Ok(Transformed::yes(rewritten_expr));
                                        }
                                    }
                                }
                            }
                            // Check if the flattened column exists in the file schema and has the same type
                            if let Ok(shredded_field) = self
                                .file_schema
                                .field_with_name(&expected_flattened_column_name)
                            {
                                if source_field.data_type() == shredded_field.data_type()
                                {
                                    // Rewrite the expression to use the flattened column
                                    let rewritten_expr = expressions::col(
                                        &expected_flattened_column_name,
                                        &self.file_schema,
                                    )?;
                                    return Ok(Transformed::yes(rewritten_expr));
                                }
                            }
                        }
                    }
                }
            }
        }

        Ok(Transformed::no(expr))
    }
}

datafusion-testing

Submodule datafusion-testing updated 259 files

datafusion/core/src/datasource/mod.rs

Lines changed: 1 addition & 0 deletions
@@ -33,6 +33,7 @@ mod statistics;
 pub mod stream;
 pub mod view;

+pub use datafusion_datasource::file_expr_rewriter;
 pub use datafusion_datasource::schema_adapter;
 pub use datafusion_datasource::source;
