Skip to content

Commit 26ddf6a

Browse files
committed
Support User Defined Table Function
Signed-off-by: veeupup <[email protected]>
1 parent 393e48f commit 26ddf6a

File tree

6 files changed

+300
-19
lines changed

6 files changed

+300
-19
lines changed
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::csv::reader::Format;
19+
use arrow::csv::ReaderBuilder;
20+
use async_trait::async_trait;
21+
use datafusion::arrow::datatypes::SchemaRef;
22+
use datafusion::arrow::record_batch::RecordBatch;
23+
use datafusion::datasource::streaming::StreamingTable;
24+
use datafusion::datasource::TableProvider;
25+
use datafusion::error::Result;
26+
use datafusion::execution::context::SessionState;
27+
use datafusion::physical_plan::memory::MemoryExec;
28+
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
29+
use datafusion::physical_plan::streaming::PartitionStream;
30+
use datafusion::physical_plan::ExecutionPlan;
31+
use datafusion::prelude::SessionContext;
32+
use datafusion_expr::{Expr, TableType};
33+
use std::fs::File;
34+
use std::io::Seek;
35+
use std::path::Path;
36+
use std::sync::Arc;
37+
38+
// To define your own table function, you only need to do the following 2 things:
39+
// 1. Define a function that returns a Result<Arc<dyn TableProvider>>
40+
// maybe you can just implement your own TableProvider
41+
// and you can even implement your own PartitionStream and return it wrapped in a StreamingTable
42+
// 2. Register the function using ctx.register_udtf
43+
44+
/// This example demonstrates how to register a TableFunction
45+
#[tokio::main]
46+
async fn main() -> Result<()> {
47+
// create local execution context
48+
let ctx = SessionContext::new();
49+
50+
ctx.register_udtf("read_csv", Arc::new(read_csv));
51+
ctx.register_udtf("read_csv_stream", Arc::new(read_csv_stream));
52+
53+
let testdata = datafusion::test_util::arrow_test_data();
54+
let csv_file = format!("{testdata}/csv/aggregate_test_100.csv");
55+
56+
let df = ctx
57+
.sql(format!("SELECT * FROM read_csv('{csv_file}');").as_str())
58+
.await?;
59+
df.show().await?;
60+
61+
let df2 = ctx
62+
.sql(format!("SELECT * FROM read_csv_stream('{csv_file}');").as_str())
63+
.await?;
64+
df2.show().await?;
65+
66+
Ok(())
67+
}
68+
69+
// Option1: (full implementation of a TableProvider)
70+
// Define your own TableProvider and make a function return Arc<dyn TableProvider>
71+
struct LocalCsvTable {
72+
schema: SchemaRef,
73+
batches: Vec<RecordBatch>,
74+
}
75+
76+
#[async_trait]
77+
impl TableProvider for LocalCsvTable {
78+
fn as_any(&self) -> &dyn std::any::Any {
79+
self
80+
}
81+
82+
fn schema(&self) -> SchemaRef {
83+
self.schema.clone()
84+
}
85+
86+
fn table_type(&self) -> TableType {
87+
TableType::Base
88+
}
89+
90+
async fn scan(
91+
&self,
92+
_state: &SessionState,
93+
projection: Option<&Vec<usize>>,
94+
_filters: &[Expr],
95+
_limit: Option<usize>,
96+
) -> Result<Arc<dyn ExecutionPlan>> {
97+
Ok(Arc::new(MemoryExec::try_new(
98+
&[self.batches.clone()],
99+
TableProvider::schema(self),
100+
projection.cloned(),
101+
)?))
102+
}
103+
}
104+
105+
fn read_csv(args: &[String]) -> Result<Arc<dyn TableProvider>> {
106+
let (schema, batches) = read_csv_batches(args[0].clone())?;
107+
let table = LocalCsvTable { schema, batches };
108+
Ok(Arc::new(table))
109+
}
110+
111+
// Option2: (use StreamingTable to make it simpler)
112+
// Implement PartitionStream and make a function return it wrapped in a StreamingTable
113+
impl PartitionStream for LocalCsvTable {
114+
fn schema(&self) -> &SchemaRef {
115+
&self.schema
116+
}
117+
118+
fn execute(
119+
&self,
120+
_ctx: Arc<datafusion::execution::TaskContext>,
121+
) -> datafusion::physical_plan::SendableRecordBatchStream {
122+
Box::pin(RecordBatchStreamAdapter::new(
123+
self.schema.clone(),
124+
// You can even read data from network or else anywhere, using async is also ok
125+
// In Fact, you can even implement your own SendableRecordBatchStream
126+
// by implementing Stream<Item = ArrowResult<RecordBatch>> + Send + Sync + 'static
127+
futures::stream::iter(self.batches.clone().into_iter().map(Ok)),
128+
))
129+
}
130+
}
131+
132+
fn read_csv_stream(args: &[String]) -> Result<Arc<dyn TableProvider>> {
133+
let (schema, batches) = read_csv_batches(args[0].clone())?;
134+
let stream = LocalCsvTable {
135+
schema: schema.clone(),
136+
batches,
137+
};
138+
let table = StreamingTable::try_new(schema, vec![Arc::new(stream)])?;
139+
Ok(Arc::new(table))
140+
}
141+
142+
fn read_csv_batches(csv_path: impl AsRef<Path>) -> Result<(SchemaRef, Vec<RecordBatch>)> {
143+
let mut file = File::open(csv_path)?;
144+
let (schema, _) = Format::default().infer_schema(&mut file, None)?;
145+
file.rewind()?;
146+
147+
let reader = ReaderBuilder::new(Arc::new(schema.clone()))
148+
.with_header(true)
149+
.build(file)?;
150+
let mut batches = vec![];
151+
for bacth in reader {
152+
batches.push(bacth?);
153+
}
154+
let schema = Arc::new(schema);
155+
Ok((schema, batches))
156+
}
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! A table that uses a function to generate data
19+
20+
use super::TableProvider;
21+
use datafusion_common::Result;
22+
use std::sync::Arc;
23+
24+
/// Table Function implementation
25+
pub type TableFunctionImplementation =
26+
Arc<dyn Fn(&[String]) -> Result<Arc<dyn TableProvider>> + Send + Sync>;
27+
28+
/// A table that uses a function to generate data
29+
pub struct TableFunction {
30+
/// Name of the table function
31+
name: String,
32+
/// Function implementation
33+
fun: TableFunctionImplementation,
34+
}
35+
36+
impl TableFunction {
37+
/// Create a new table function
38+
pub fn new(name: String, fun: TableFunctionImplementation) -> Self {
39+
Self { name, fun }
40+
}
41+
42+
/// Get the name of the table function
43+
pub fn name(&self) -> String {
44+
self.name.clone()
45+
}
46+
47+
/// Get the function implementation and generate a table
48+
pub fn create_table_provider(
49+
&self,
50+
args: &[String],
51+
) -> Result<Arc<dyn TableProvider>> {
52+
let table = (self.fun)(args)?;
53+
Ok(table)
54+
}
55+
}

datafusion/core/src/datasource/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ pub mod avro_to_arrow;
2323
pub mod default_table_source;
2424
pub mod empty;
2525
pub mod file_format;
26+
pub mod function;
2627
pub mod listing;
2728
pub mod listing_table_factory;
2829
pub mod memory;

datafusion/core/src/execution/context/mod.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ mod parquet;
2626
use crate::{
2727
catalog::{CatalogList, MemoryCatalogList},
2828
datasource::{
29+
function::{TableFunction, TableFunctionImplementation},
2930
listing::{ListingOptions, ListingTable},
3031
provider::TableProviderFactory,
3132
},
@@ -795,6 +796,14 @@ impl SessionContext {
795796
.add_var_provider(variable_type, provider);
796797
}
797798

799+
/// Register a table UDF with this context
800+
pub fn register_udtf(&self, name: &str, fun: TableFunctionImplementation) {
801+
self.state.write().table_functions.insert(
802+
name.to_owned(),
803+
Arc::new(TableFunction::new(name.to_owned(), fun)),
804+
);
805+
}
806+
798807
/// Registers a scalar UDF within this context.
799808
///
800809
/// Note in SQL queries, function names are looked up using
@@ -1224,6 +1233,8 @@ pub struct SessionState {
12241233
query_planner: Arc<dyn QueryPlanner + Send + Sync>,
12251234
/// Collection of catalogs containing schemas and ultimately TableProviders
12261235
catalog_list: Arc<dyn CatalogList>,
1236+
/// Table Functions
1237+
table_functions: HashMap<String, Arc<TableFunction>>,
12271238
/// Scalar functions that are registered with the context
12281239
scalar_functions: HashMap<String, Arc<ScalarUDF>>,
12291240
/// Aggregate functions registered in the context
@@ -1322,6 +1333,7 @@ impl SessionState {
13221333
physical_optimizers: PhysicalOptimizer::new(),
13231334
query_planner: Arc::new(DefaultQueryPlanner {}),
13241335
catalog_list,
1336+
table_functions: HashMap::new(),
13251337
scalar_functions: HashMap::new(),
13261338
aggregate_functions: HashMap::new(),
13271339
window_functions: HashMap::new(),
@@ -1860,6 +1872,22 @@ impl<'a> ContextProvider for SessionContextProvider<'a> {
18601872
.ok_or_else(|| plan_datafusion_err!("table '{name}' not found"))
18611873
}
18621874

1875+
fn get_table_function_source(
1876+
&self,
1877+
name: &str,
1878+
args: Vec<String>,
1879+
) -> Result<Arc<dyn TableSource>> {
1880+
let tbl_func = self
1881+
.state
1882+
.table_functions
1883+
.get(name)
1884+
.cloned()
1885+
.ok_or_else(|| plan_datafusion_err!("table function '{name}' not found"))?;
1886+
let provider = tbl_func.create_table_provider(&args)?;
1887+
1888+
Ok(provider_as_source(provider))
1889+
}
1890+
18631891
fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
18641892
self.state.scalar_functions().get(name).cloned()
18651893
}

datafusion/sql/src/planner.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,15 @@ pub trait ContextProvider {
5151
}
5252
/// Getter for a datasource
5353
fn get_table_source(&self, name: TableReference) -> Result<Arc<dyn TableSource>>;
54+
/// Getter for a table function
55+
fn get_table_function_source(
56+
&self,
57+
_name: &str,
58+
_args: Vec<String>,
59+
) -> Result<Arc<dyn TableSource>> {
60+
unimplemented!()
61+
}
62+
5463
/// Getter for a UDF description
5564
fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>>;
5665
/// Getter for a UDAF description

datafusion/sql/src/relation/mod.rs

Lines changed: 51 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616
// under the License.
1717

1818
use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
19-
use datafusion_common::{not_impl_err, DataFusionError, Result};
19+
use datafusion_common::{not_impl_err, DataFusionError, Result, TableReference};
2020
use datafusion_expr::{LogicalPlan, LogicalPlanBuilder};
21-
use sqlparser::ast::TableFactor;
21+
use sqlparser::ast::{Expr, FunctionArgExpr, TableFactor, Value};
2222

2323
mod join;
2424

@@ -30,24 +30,56 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
3030
planner_context: &mut PlannerContext,
3131
) -> Result<LogicalPlan> {
3232
let (plan, alias) = match relation {
33-
TableFactor::Table { name, alias, .. } => {
34-
// normalize name and alias
35-
let table_ref = self.object_name_to_table_reference(name)?;
36-
let table_name = table_ref.to_string();
37-
let cte = planner_context.get_cte(&table_name);
38-
(
39-
match (
40-
cte,
41-
self.context_provider.get_table_source(table_ref.clone()),
42-
) {
43-
(Some(cte_plan), _) => Ok(cte_plan.clone()),
44-
(_, Ok(provider)) => {
45-
LogicalPlanBuilder::scan(table_ref, provider, None)?.build()
33+
TableFactor::Table {
34+
name, alias, args, ..
35+
} => {
36+
if let Some(func_args) = args {
37+
let tbl_func_name = name.0.get(0).unwrap().value.to_string();
38+
let mut args = vec![];
39+
for arg in func_args {
40+
if let sqlparser::ast::FunctionArg::Unnamed(
41+
FunctionArgExpr::Expr(Expr::Value(
42+
Value::SingleQuotedString(val),
43+
)),
44+
) = arg
45+
{
46+
args.push(val);
47+
} else {
48+
unimplemented!("Unsupported function argument type")
4649
}
47-
(None, Err(e)) => Err(e),
48-
}?,
49-
alias,
50-
)
50+
}
51+
let provider = self
52+
.context_provider
53+
.get_table_function_source(&tbl_func_name, args)?;
54+
let plan = LogicalPlanBuilder::scan(
55+
TableReference::Bare {
56+
table: std::borrow::Cow::Borrowed("tmp_table"),
57+
},
58+
provider,
59+
None,
60+
)?
61+
.build()?;
62+
(plan, alias)
63+
} else {
64+
// normalize name and alias
65+
let table_ref = self.object_name_to_table_reference(name)?;
66+
let table_name = table_ref.to_string();
67+
let cte = planner_context.get_cte(&table_name);
68+
(
69+
match (
70+
cte,
71+
self.context_provider.get_table_source(table_ref.clone()),
72+
) {
73+
(Some(cte_plan), _) => Ok(cte_plan.clone()),
74+
(_, Ok(provider)) => {
75+
LogicalPlanBuilder::scan(table_ref, provider, None)?
76+
.build()
77+
}
78+
(None, Err(e)) => Err(e),
79+
}?,
80+
alias,
81+
)
82+
}
5183
}
5284
TableFactor::Derived {
5385
subquery, alias, ..

0 commit comments

Comments
 (0)