Skip to content

Commit 32ecce4

Browse files
committed
Support User Defined Table Function
Signed-off-by: veeupup <[email protected]>
1 parent 393e48f commit 32ecce4

File tree

6 files changed

+381
-20
lines changed

6 files changed

+381
-20
lines changed
Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::csv::reader::Format;
19+
use arrow::csv::ReaderBuilder;
20+
use async_trait::async_trait;
21+
use datafusion::arrow::datatypes::SchemaRef;
22+
use datafusion::arrow::record_batch::RecordBatch;
23+
use datafusion::datasource::function::TableFunctionImpl;
24+
use datafusion::datasource::streaming::StreamingTable;
25+
use datafusion::datasource::TableProvider;
26+
use datafusion::error::Result;
27+
use datafusion::execution::context::SessionState;
28+
use datafusion::execution::TaskContext;
29+
use datafusion::physical_plan::memory::MemoryExec;
30+
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
31+
use datafusion::physical_plan::streaming::PartitionStream;
32+
use datafusion::physical_plan::{collect, ExecutionPlan};
33+
use datafusion::prelude::SessionContext;
34+
use datafusion_common::{DFSchema, ScalarValue};
35+
use datafusion_expr::{EmptyRelation, Expr, LogicalPlan, Projection, TableType};
36+
use std::fs::File;
37+
use std::io::Seek;
38+
use std::path::Path;
39+
use std::sync::Arc;
40+
41+
// To define your own table function, you only need to do the following 3 things:
42+
// 1. Implement your own TableProvider
43+
// 2. Implement your own TableFunctionImpl and return your TableProvider
44+
// 3. Register the function using ctx.register_udtf
45+
46+
/// This example demonstrates how to register a TableFunction
47+
#[tokio::main]
48+
async fn main() -> Result<()> {
49+
// create local execution context
50+
let ctx = SessionContext::new();
51+
52+
ctx.register_udtf("read_csv", Arc::new(LocalCsvTableFunc {}));
53+
ctx.register_udtf("read_csv_stream", Arc::new(LocalStreamCsvTable {}));
54+
55+
let testdata = datafusion::test_util::arrow_test_data();
56+
let csv_file = format!("{testdata}/csv/aggregate_test_100.csv");
57+
58+
// run it with println now()
59+
let df = ctx
60+
.sql(format!("SELECT * FROM read_csv('{csv_file}', now());").as_str())
61+
.await?;
62+
df.show().await?;
63+
64+
// just run
65+
let df = ctx
66+
.sql(format!("SELECT * FROM read_csv('{csv_file}');").as_str())
67+
.await?;
68+
df.show().await?;
69+
70+
// stream csv table
71+
let df2 = ctx
72+
.sql(format!("SELECT * FROM read_csv_stream('{csv_file}');").as_str())
73+
.await?;
74+
df2.show().await?;
75+
76+
Ok(())
77+
}
78+
79+
// Option1: (full implmentation of a TableProvider)
80+
struct LocalCsvTable {
81+
schema: SchemaRef,
82+
exprs: Vec<Expr>,
83+
batches: Vec<RecordBatch>,
84+
}
85+
86+
#[async_trait]
87+
impl TableProvider for LocalCsvTable {
88+
fn as_any(&self) -> &dyn std::any::Any {
89+
self
90+
}
91+
92+
fn schema(&self) -> SchemaRef {
93+
self.schema.clone()
94+
}
95+
96+
fn table_type(&self) -> TableType {
97+
TableType::Base
98+
}
99+
100+
async fn scan(
101+
&self,
102+
state: &SessionState,
103+
projection: Option<&Vec<usize>>,
104+
_filters: &[Expr],
105+
_limit: Option<usize>,
106+
) -> Result<Arc<dyn ExecutionPlan>> {
107+
if !self.exprs.is_empty() {
108+
self.interpreter_expr(state).await?;
109+
}
110+
Ok(Arc::new(MemoryExec::try_new(
111+
&[self.batches.clone()],
112+
TableProvider::schema(self),
113+
projection.cloned(),
114+
)?))
115+
}
116+
}
117+
118+
impl LocalCsvTable {
119+
// TODO(veeupup): maybe we can make interpreter Expr this more simpler for users
120+
// TODO(veeupup): maybe we can support more type of exprs
121+
async fn interpreter_expr(&self, state: &SessionState) -> Result<()> {
122+
use datafusion::logical_expr::expr_rewriter::normalize_col;
123+
use datafusion::logical_expr::utils::columnize_expr;
124+
let plan = LogicalPlan::EmptyRelation(EmptyRelation {
125+
produce_one_row: true,
126+
schema: Arc::new(DFSchema::empty()),
127+
});
128+
let logical_plan = Projection::try_new(
129+
vec![columnize_expr(
130+
normalize_col(self.exprs[0].clone(), &plan)?,
131+
plan.schema(),
132+
)],
133+
Arc::new(plan),
134+
)
135+
.map(LogicalPlan::Projection)?;
136+
let rbs = collect(
137+
state.create_physical_plan(&logical_plan).await?,
138+
Arc::new(TaskContext::from(state)),
139+
)
140+
.await?;
141+
println!("time now: {:?}", rbs[0].column(0));
142+
Ok(())
143+
}
144+
}
145+
146+
struct LocalCsvTableFunc {}
147+
148+
impl TableFunctionImpl for LocalCsvTableFunc {
149+
fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
150+
let mut new_exprs = vec![];
151+
let mut filepath = String::new();
152+
for expr in exprs {
153+
match expr {
154+
Expr::Literal(ScalarValue::Utf8(Some(ref path))) => {
155+
filepath = path.clone()
156+
}
157+
expr => new_exprs.push(expr.clone()),
158+
}
159+
}
160+
let (schema, batches) = read_csv_batches(filepath)?;
161+
let table = LocalCsvTable {
162+
schema,
163+
exprs: new_exprs.clone(),
164+
batches,
165+
};
166+
Ok(Arc::new(table))
167+
}
168+
}
169+
170+
// Option2: (use StreamingTable to make it simpler)
171+
// Implement PartitionStream and Use StreamTable to return streaming table
172+
impl PartitionStream for LocalCsvTable {
173+
fn schema(&self) -> &SchemaRef {
174+
&self.schema
175+
}
176+
177+
fn execute(
178+
&self,
179+
_ctx: Arc<datafusion::execution::TaskContext>,
180+
) -> datafusion::physical_plan::SendableRecordBatchStream {
181+
Box::pin(RecordBatchStreamAdapter::new(
182+
self.schema.clone(),
183+
// You can even read data from network or else anywhere, using async is also ok
184+
// In Fact, you can even implement your own SendableRecordBatchStream
185+
// by implementing Stream<Item = ArrowResult<RecordBatch>> + Send + Sync + 'static
186+
futures::stream::iter(self.batches.clone().into_iter().map(Ok)),
187+
))
188+
}
189+
}
190+
191+
struct LocalStreamCsvTable {}
192+
193+
impl TableFunctionImpl for LocalStreamCsvTable {
194+
fn call(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>> {
195+
let filepath = match args[0] {
196+
Expr::Literal(ScalarValue::Utf8(Some(ref path))) => path.clone(),
197+
_ => unimplemented!(),
198+
};
199+
let (schema, batches) = read_csv_batches(filepath)?;
200+
let stream = LocalCsvTable {
201+
schema: schema.clone(),
202+
batches,
203+
exprs: vec![],
204+
};
205+
let table = StreamingTable::try_new(schema, vec![Arc::new(stream)])?;
206+
Ok(Arc::new(table))
207+
}
208+
}
209+
210+
fn read_csv_batches(csv_path: impl AsRef<Path>) -> Result<(SchemaRef, Vec<RecordBatch>)> {
211+
let mut file = File::open(csv_path)?;
212+
let (schema, _) = Format::default().infer_schema(&mut file, None)?;
213+
file.rewind()?;
214+
215+
let reader = ReaderBuilder::new(Arc::new(schema.clone()))
216+
.with_header(true)
217+
.build(file)?;
218+
let mut batches = vec![];
219+
for bacth in reader {
220+
batches.push(bacth?);
221+
}
222+
let schema = Arc::new(schema);
223+
Ok((schema, batches))
224+
}
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! A table that uses a function to generate data
19+
20+
use super::TableProvider;
21+
22+
use datafusion_common::Result;
23+
use datafusion_expr::Expr;
24+
25+
use std::sync::Arc;
26+
27+
/// A trait for table function implementations
28+
pub trait TableFunctionImpl: Sync + Send {
29+
/// Create a table provider
30+
fn call(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>>;
31+
}
32+
33+
/// A table that uses a function to generate data
34+
pub struct TableFunction {
35+
/// Name of the table function
36+
name: String,
37+
/// Function implementation
38+
fun: Arc<dyn TableFunctionImpl>,
39+
}
40+
41+
impl TableFunction {
42+
/// Create a new table function
43+
pub fn new(name: String, fun: Arc<dyn TableFunctionImpl>) -> Self {
44+
Self { name, fun }
45+
}
46+
47+
/// Get the name of the table function
48+
pub fn name(&self) -> String {
49+
self.name.clone()
50+
}
51+
52+
/// Get the function implementation and generate a table
53+
pub fn create_table_provider(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>> {
54+
self.fun.call(args)
55+
}
56+
}

datafusion/core/src/datasource/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ pub mod avro_to_arrow;
2323
pub mod default_table_source;
2424
pub mod empty;
2525
pub mod file_format;
26+
pub mod function;
2627
pub mod listing;
2728
pub mod listing_table_factory;
2829
pub mod memory;

datafusion/core/src/execution/context/mod.rs

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ mod parquet;
2626
use crate::{
2727
catalog::{CatalogList, MemoryCatalogList},
2828
datasource::{
29+
function::{TableFunction, TableFunctionImpl},
2930
listing::{ListingOptions, ListingTable},
3031
provider::TableProviderFactory,
3132
},
@@ -42,7 +43,7 @@ use datafusion_common::{
4243
use datafusion_execution::registry::SerializerRegistry;
4344
use datafusion_expr::{
4445
logical_plan::{DdlStatement, Statement},
45-
StringifiedPlan, UserDefinedLogicalNode, WindowUDF,
46+
Expr, StringifiedPlan, UserDefinedLogicalNode, WindowUDF,
4647
};
4748
pub use datafusion_physical_expr::execution_props::ExecutionProps;
4849
use datafusion_physical_expr::var_provider::is_system_variables;
@@ -795,6 +796,14 @@ impl SessionContext {
795796
.add_var_provider(variable_type, provider);
796797
}
797798

799+
/// Register a table UDF with this context
800+
pub fn register_udtf(&self, name: &str, fun: Arc<dyn TableFunctionImpl>) {
801+
self.state.write().table_functions.insert(
802+
name.to_owned(),
803+
Arc::new(TableFunction::new(name.to_owned(), fun)),
804+
);
805+
}
806+
798807
/// Registers a scalar UDF within this context.
799808
///
800809
/// Note in SQL queries, function names are looked up using
@@ -1224,6 +1233,8 @@ pub struct SessionState {
12241233
query_planner: Arc<dyn QueryPlanner + Send + Sync>,
12251234
/// Collection of catalogs containing schemas and ultimately TableProviders
12261235
catalog_list: Arc<dyn CatalogList>,
1236+
/// Table Functions
1237+
table_functions: HashMap<String, Arc<TableFunction>>,
12271238
/// Scalar functions that are registered with the context
12281239
scalar_functions: HashMap<String, Arc<ScalarUDF>>,
12291240
/// Aggregate functions registered in the context
@@ -1322,6 +1333,7 @@ impl SessionState {
13221333
physical_optimizers: PhysicalOptimizer::new(),
13231334
query_planner: Arc::new(DefaultQueryPlanner {}),
13241335
catalog_list,
1336+
table_functions: HashMap::new(),
13251337
scalar_functions: HashMap::new(),
13261338
aggregate_functions: HashMap::new(),
13271339
window_functions: HashMap::new(),
@@ -1860,6 +1872,22 @@ impl<'a> ContextProvider for SessionContextProvider<'a> {
18601872
.ok_or_else(|| plan_datafusion_err!("table '{name}' not found"))
18611873
}
18621874

1875+
fn get_table_function_source(
1876+
&self,
1877+
name: &str,
1878+
args: Vec<Expr>,
1879+
) -> Result<Arc<dyn TableSource>> {
1880+
let tbl_func = self
1881+
.state
1882+
.table_functions
1883+
.get(name)
1884+
.cloned()
1885+
.ok_or_else(|| plan_datafusion_err!("table function '{name}' not found"))?;
1886+
let provider = tbl_func.create_table_provider(&args)?;
1887+
1888+
Ok(provider_as_source(provider))
1889+
}
1890+
18631891
fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
18641892
self.state.scalar_functions().get(name).cloned()
18651893
}

datafusion/sql/src/planner.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,15 @@ pub trait ContextProvider {
5151
}
5252
/// Getter for a datasource
5353
fn get_table_source(&self, name: TableReference) -> Result<Arc<dyn TableSource>>;
54+
/// Getter for a table function
55+
fn get_table_function_source(
56+
&self,
57+
_name: &str,
58+
_args: Vec<Expr>,
59+
) -> Result<Arc<dyn TableSource>> {
60+
unimplemented!()
61+
}
62+
5463
/// Getter for a UDF description
5564
fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>>;
5665
/// Getter for a UDAF description

0 commit comments

Comments
 (0)