Skip to content

Simplify table function example, add some comments #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 30, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 41 additions & 58 deletions datafusion-examples/examples/simple_udtf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
// specific language governing permissions and limitations
// under the License.

use arrow::array::Int64Array;
use arrow::csv::reader::Format;
use arrow::csv::ReaderBuilder;
use async_trait::async_trait;
Expand All @@ -24,37 +23,38 @@ use datafusion::arrow::record_batch::RecordBatch;
use datafusion::datasource::function::TableFunctionImpl;
use datafusion::datasource::TableProvider;
use datafusion::error::Result;
use datafusion::execution::context::SessionState;
use datafusion::execution::TaskContext;
use datafusion::execution::context::{ExecutionProps, SessionState};
use datafusion::physical_plan::memory::MemoryExec;
use datafusion::physical_plan::{collect, ExecutionPlan};
use datafusion::physical_plan::ExecutionPlan;
use datafusion::prelude::SessionContext;
use datafusion_common::{DFSchema, ScalarValue};
use datafusion_expr::{EmptyRelation, Expr, LogicalPlan, Projection, TableType};
use datafusion_common::{plan_err, DataFusionError, ScalarValue};
use datafusion_expr::{Expr, TableType};
use datafusion_optimizer::simplify_expressions::{ExprSimplifier, SimplifyContext};
use std::fs::File;
use std::io::Seek;
use std::path::Path;
use std::sync::Arc;

// To define your own table function, you only need to do the following 3 things:
// 1. Implement your own TableProvider
// 2. Implement your own TableFunctionImpl and return your TableProvider
// 3. Register the function using ctx.register_udtf
// 1. Implement your own [`TableProvider`]
// 2. Implement your own [`TableFunctionImpl`] and return your [`TableProvider`]
// 3. Register the function using [`SessionContext::register_udtf`]

/// This example demonstrates how to register a TableFunction
#[tokio::main]
async fn main() -> Result<()> {
// create local execution context
let ctx = SessionContext::new();

// register the table function that will be called in SQL statements by `read_csv`
ctx.register_udtf("read_csv", Arc::new(LocalCsvTableFunc {}));

let testdata = datafusion::test_util::arrow_test_data();
let csv_file = format!("{testdata}/csv/aggregate_test_100.csv");

// read csv with at most 2 rows
// Pass 2 arguments, read csv with at most 2 rows (simplify logic makes 1+1 --> 2)
let df = ctx
.sql(format!("SELECT * FROM read_csv('{csv_file}', 2);").as_str())
.sql(format!("SELECT * FROM read_csv('{csv_file}', 1 + 1);").as_str())
.await?;
df.show().await?;

Expand All @@ -67,9 +67,14 @@ async fn main() -> Result<()> {
Ok(())
}

/// Table provider returned by the `read_csv` table function, which mimics the
/// [`read_csv`] function in DuckDB.
///
/// Usage: `read_csv(filename, [limit])`
///
/// [`read_csv`]: https://duckdb.org/docs/data/csv/overview.html
struct LocalCsvTable {
/// Schema returned by `read_csv_batches` together with `batches`.
schema: SchemaRef,
/// Function arguments other than the file path. `scan` evaluates the first
/// one through `interpreter_expr` to obtain the maximum number of rows.
/// NOTE(review): overlaps in purpose with `limit` — confirm both are needed.
exprs: Vec<Expr>,
/// Optional cap on the number of rows `scan` returns; `None` when the
/// function was called without a second argument.
limit: Option<usize>,
/// Record batches returned by `read_csv_batches` for the requested file.
batches: Vec<RecordBatch>,
}

Expand All @@ -89,13 +94,12 @@ impl TableProvider for LocalCsvTable {

async fn scan(
&self,
state: &SessionState,
_state: &SessionState,
projection: Option<&Vec<usize>>,
_filters: &[Expr],
_limit: Option<usize>,
) -> Result<Arc<dyn ExecutionPlan>> {
let batches = if !self.exprs.is_empty() {
let max_return_lines = self.interpreter_expr(state).await?;
let batches = if let Some(max_return_lines) = self.limit {
// get max return rows from self.batches
let mut batches = vec![];
let mut lines = 0;
Expand All @@ -121,56 +125,35 @@ impl TableProvider for LocalCsvTable {
)?))
}
}

impl LocalCsvTable {
async fn interpreter_expr(&self, state: &SessionState) -> Result<i64> {
use datafusion::logical_expr::expr_rewriter::normalize_col;
use datafusion::logical_expr::utils::columnize_expr;
let plan = LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: true,
schema: Arc::new(DFSchema::empty()),
});
let logical_plan = Projection::try_new(
vec![columnize_expr(
normalize_col(self.exprs[0].clone(), &plan)?,
plan.schema(),
)],
Arc::new(plan),
)
.map(LogicalPlan::Projection)?;
let rbs = collect(
state.create_physical_plan(&logical_plan).await?,
Arc::new(TaskContext::from(state)),
)
.await?;
let limit = rbs[0]
.column(0)
.as_any()
.downcast_ref::<Int64Array>()
.unwrap()
.value(0);
Ok(limit)
}
}

/// Table function registered under the name `read_csv` in `main`; its
/// `TableFunctionImpl::call` builds a `LocalCsvTable` from the SQL arguments.
struct LocalCsvTableFunc {}

impl TableFunctionImpl for LocalCsvTableFunc {
fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
let mut new_exprs = vec![];
let mut filepath = String::new();
for expr in exprs {
match expr {
Expr::Literal(ScalarValue::Utf8(Some(ref path))) => {
filepath = path.clone()
let Some(Expr::Literal(ScalarValue::Utf8(Some(ref path)))) = exprs.get(0) else {
return plan_err!("read_csv requires at least one string argument");
};

let limit = exprs
.get(1)
.map(|expr| {
// try to simplify the expression, so 1+2 becomes 3, for example
let execution_props = ExecutionProps::new();
let info = SimplifyContext::new(&execution_props);
let expr = ExprSimplifier::new(info).simplify(expr.clone())?;

if let Expr::Literal(ScalarValue::Int64(Some(limit))) = expr {
Ok(limit as usize)
} else {
plan_err!("Limit must be an integer")
}
expr => new_exprs.push(expr.clone()),
}
}
let (schema, batches) = read_csv_batches(filepath)?;
})
.transpose()?;

let (schema, batches) = read_csv_batches(path)?;

let table = LocalCsvTable {
schema,
exprs: new_exprs.clone(),
limit,
batches,
};
Ok(Arc::new(table))
Expand Down