Skip to content

Commit 34d9d3a

Browse files
authored
feat: basic support for executing prepared statements (#13242)
* feat: basic support for executing prepared statements * Improve execute_prepared * Fix tests * Update doc * Add test * Add issue test * Respect allow_statements option
1 parent ba094e7 commit 34d9d3a

File tree

6 files changed

+304
-59
lines changed

6 files changed

+304
-59
lines changed

datafusion/core/src/execution/context/mod.rs

Lines changed: 80 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ use crate::{
4242
logical_expr::{
4343
CreateCatalog, CreateCatalogSchema, CreateExternalTable, CreateFunction,
4444
CreateMemoryTable, CreateView, DropCatalogSchema, DropFunction, DropTable,
45-
DropView, LogicalPlan, LogicalPlanBuilder, SetVariable, TableType, UNNAMED_TABLE,
45+
DropView, Execute, LogicalPlan, LogicalPlanBuilder, Prepare, SetVariable,
46+
TableType, UNNAMED_TABLE,
4647
},
4748
physical_expr::PhysicalExpr,
4849
physical_plan::ExecutionPlan,
@@ -54,9 +55,9 @@ use arrow::record_batch::RecordBatch;
5455
use arrow_schema::Schema;
5556
use datafusion_common::{
5657
config::{ConfigExtension, TableOptions},
57-
exec_err, not_impl_err, plan_datafusion_err, plan_err,
58+
exec_datafusion_err, exec_err, not_impl_err, plan_datafusion_err, plan_err,
5859
tree_node::{TreeNodeRecursion, TreeNodeVisitor},
59-
DFSchema, SchemaReference, TableReference,
60+
DFSchema, ParamValues, ScalarValue, SchemaReference, TableReference,
6061
};
6162
use datafusion_execution::registry::SerializerRegistry;
6263
use datafusion_expr::{
@@ -687,7 +688,31 @@ impl SessionContext {
687688
LogicalPlan::Statement(Statement::SetVariable(stmt)) => {
688689
self.set_variable(stmt).await
689690
}
690-
691+
LogicalPlan::Prepare(Prepare {
692+
name,
693+
input,
694+
data_types,
695+
}) => {
696+
// The number of parameters must match the specified data types length.
697+
if !data_types.is_empty() {
698+
let param_names = input.get_parameter_names()?;
699+
if param_names.len() != data_types.len() {
700+
return plan_err!(
701+
"Prepare specifies {} data types but query has {} parameters",
702+
data_types.len(),
703+
param_names.len()
704+
);
705+
}
706+
}
707+
// Store the unoptimized plan into the session state. Although storing the
708+
// optimized plan or the physical plan would be more efficient, doing so is
709+
// not currently feasible. This is because `now()` would be optimized to a
710+
// constant value, causing each EXECUTE to yield the same result, which is
711+
// incorrect behavior.
712+
self.state.write().store_prepared(name, data_types, input)?;
713+
self.return_empty_dataframe()
714+
}
715+
LogicalPlan::Execute(execute) => self.execute_prepared(execute),
691716
plan => Ok(DataFrame::new(self.state(), plan)),
692717
}
693718
}
@@ -1088,6 +1113,49 @@ impl SessionContext {
10881113
}
10891114
}
10901115

1116+
fn execute_prepared(&self, execute: Execute) -> Result<DataFrame> {
1117+
let Execute {
1118+
name, parameters, ..
1119+
} = execute;
1120+
let prepared = self.state.read().get_prepared(&name).ok_or_else(|| {
1121+
exec_datafusion_err!("Prepared statement '{}' does not exist", name)
1122+
})?;
1123+
1124+
// Only allow literals as parameters for now.
1125+
let mut params: Vec<ScalarValue> = parameters
1126+
.into_iter()
1127+
.map(|e| match e {
1128+
Expr::Literal(scalar) => Ok(scalar),
1129+
_ => not_impl_err!("Unsupported parameter type: {}", e),
1130+
})
1131+
.collect::<Result<_>>()?;
1132+
1133+
// If the prepared statement provides data types, cast the params to those types.
1134+
if !prepared.data_types.is_empty() {
1135+
if params.len() != prepared.data_types.len() {
1136+
return exec_err!(
1137+
"Prepared statement '{}' expects {} parameters, but {} provided",
1138+
name,
1139+
prepared.data_types.len(),
1140+
params.len()
1141+
);
1142+
}
1143+
params = params
1144+
.into_iter()
1145+
.zip(prepared.data_types.iter())
1146+
.map(|(e, dt)| e.cast_to(dt))
1147+
.collect::<Result<_>>()?;
1148+
}
1149+
1150+
let params = ParamValues::List(params);
1151+
let plan = prepared
1152+
.plan
1153+
.as_ref()
1154+
.clone()
1155+
.replace_params_with_values(&params)?;
1156+
Ok(DataFrame::new(self.state(), plan))
1157+
}
1158+
10911159
/// Registers a variable provider within this context.
10921160
pub fn register_variable(
10931161
&self,
@@ -1705,6 +1773,14 @@ impl<'n, 'a> TreeNodeVisitor<'n> for BadPlanVisitor<'a> {
17051773
LogicalPlan::Statement(stmt) if !self.options.allow_statements => {
17061774
plan_err!("Statement not supported: {}", stmt.name())
17071775
}
1776+
// TODO: Implement PREPARE as a LogicalPlan::Statement
1777+
LogicalPlan::Prepare(_) if !self.options.allow_statements => {
1778+
plan_err!("Statement not supported: PREPARE")
1779+
}
1780+
// TODO: Implement EXECUTE as a LogicalPlan::Statement
1781+
LogicalPlan::Execute(_) if !self.options.allow_statements => {
1782+
plan_err!("Statement not supported: EXECUTE")
1783+
}
17081784
_ => Ok(TreeNodeRecursion::Continue),
17091785
}
17101786
}

datafusion/core/src/execution/session_state.rs

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan};
4040
use datafusion_common::file_options::file_type::FileType;
4141
use datafusion_common::tree_node::TreeNode;
4242
use datafusion_common::{
43-
config_err, not_impl_err, plan_datafusion_err, DFSchema, DataFusionError,
43+
config_err, exec_err, not_impl_err, plan_datafusion_err, DFSchema, DataFusionError,
4444
ResolvedTableReference, TableReference,
4545
};
4646
use datafusion_execution::config::SessionConfig;
@@ -171,6 +171,9 @@ pub struct SessionState {
171171
/// It will be invoked on `CREATE FUNCTION` statements.
172172
/// thus, changing dialect o PostgreSql is required
173173
function_factory: Option<Arc<dyn FunctionFactory>>,
174+
/// Cache logical plans of prepared statements for later execution.
175+
/// Key is the prepared statement name.
176+
prepared_plans: HashMap<String, Arc<PreparedPlan>>,
174177
}
175178

176179
impl Debug for SessionState {
@@ -197,6 +200,7 @@ impl Debug for SessionState {
197200
.field("scalar_functions", &self.scalar_functions)
198201
.field("aggregate_functions", &self.aggregate_functions)
199202
.field("window_functions", &self.window_functions)
203+
.field("prepared_plans", &self.prepared_plans)
200204
.finish()
201205
}
202206
}
@@ -906,6 +910,29 @@ impl SessionState {
906910
let udtf = self.table_functions.remove(name);
907911
Ok(udtf.map(|x| x.function().clone()))
908912
}
913+
914+
/// Store the logical plan and the parameter types of a prepared statement.
915+
pub(crate) fn store_prepared(
916+
&mut self,
917+
name: String,
918+
data_types: Vec<DataType>,
919+
plan: Arc<LogicalPlan>,
920+
) -> datafusion_common::Result<()> {
921+
match self.prepared_plans.entry(name) {
922+
Entry::Vacant(e) => {
923+
e.insert(Arc::new(PreparedPlan { data_types, plan }));
924+
Ok(())
925+
}
926+
Entry::Occupied(e) => {
927+
exec_err!("Prepared statement '{}' already exists", e.key())
928+
}
929+
}
930+
}
931+
932+
/// Get the prepared plan with the given name.
933+
pub(crate) fn get_prepared(&self, name: &str) -> Option<Arc<PreparedPlan>> {
934+
self.prepared_plans.get(name).map(Arc::clone)
935+
}
909936
}
910937

911938
/// A builder to be used for building [`SessionState`]'s. Defaults will
@@ -1327,6 +1354,7 @@ impl SessionStateBuilder {
13271354
table_factories: table_factories.unwrap_or_default(),
13281355
runtime_env,
13291356
function_factory,
1357+
prepared_plans: HashMap::new(),
13301358
};
13311359

13321360
if let Some(file_formats) = file_formats {
@@ -1876,6 +1904,14 @@ impl<'a> SimplifyInfo for SessionSimplifyProvider<'a> {
18761904
}
18771905
}
18781906

1907+
#[derive(Debug)]
1908+
pub(crate) struct PreparedPlan {
1909+
/// Data types of the parameters
1910+
pub(crate) data_types: Vec<DataType>,
1911+
/// The prepared logical plan
1912+
pub(crate) plan: Arc<LogicalPlan>,
1913+
}
1914+
18791915
#[cfg(test)]
18801916
mod tests {
18811917
use super::{SessionContextProvider, SessionStateBuilder};

datafusion/core/tests/sql/select.rs

Lines changed: 4 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@ async fn test_named_query_parameters() -> Result<()> {
5757
let ctx = create_ctx_with_partition(&tmp_dir, partition_count).await?;
5858

5959
// sql to statement then to logical plan with parameters
60-
// c1 defined as UINT32, c2 defined as UInt64
6160
let results = ctx
6261
.sql("SELECT c1, c2 FROM test WHERE c1 > $coo AND c1 < $foo")
6362
.await?
@@ -106,9 +105,9 @@ async fn test_prepare_statement() -> Result<()> {
106105
let ctx = create_ctx_with_partition(&tmp_dir, partition_count).await?;
107106

108107
// sql to statement then to prepare logical plan with parameters
109-
// c1 defined as UINT32, c2 defined as UInt64 but the params are Int32 and Float64
110-
let dataframe =
111-
ctx.sql("PREPARE my_plan(INT, DOUBLE) AS SELECT c1, c2 FROM test WHERE c1 > $2 AND c1 < $1").await?;
108+
let dataframe = ctx
109+
.sql("SELECT c1, c2 FROM test WHERE c1 > $2 AND c1 < $1")
110+
.await?;
112111

113112
// prepare logical plan to logical plan without parameters
114113
let param_values = vec![ScalarValue::Int32(Some(3)), ScalarValue::Float64(Some(0.0))];
@@ -156,7 +155,7 @@ async fn prepared_statement_type_coercion() -> Result<()> {
156155
("unsigned", Arc::new(unsigned_ints) as ArrayRef),
157156
])?;
158157
ctx.register_batch("test", batch)?;
159-
let results = ctx.sql("PREPARE my_plan(BIGINT, INT, TEXT) AS SELECT signed, unsigned FROM test WHERE $1 >= signed AND signed <= $2 AND unsigned = $3")
158+
let results = ctx.sql("SELECT signed, unsigned FROM test WHERE $1 >= signed AND signed <= $2 AND unsigned = $3")
160159
.await?
161160
.with_param_values(vec![
162161
ScalarValue::from(1_i64),
@@ -176,27 +175,6 @@ async fn prepared_statement_type_coercion() -> Result<()> {
176175
Ok(())
177176
}
178177

179-
#[tokio::test]
180-
async fn prepared_statement_invalid_types() -> Result<()> {
181-
let ctx = SessionContext::new();
182-
let signed_ints: Int32Array = vec![-1, 0, 1].into();
183-
let unsigned_ints: UInt64Array = vec![1, 2, 3].into();
184-
let batch = RecordBatch::try_from_iter(vec![
185-
("signed", Arc::new(signed_ints) as ArrayRef),
186-
("unsigned", Arc::new(unsigned_ints) as ArrayRef),
187-
])?;
188-
ctx.register_batch("test", batch)?;
189-
let results = ctx
190-
.sql("PREPARE my_plan(INT) AS SELECT signed FROM test WHERE signed = $1")
191-
.await?
192-
.with_param_values(vec![ScalarValue::from("1")]);
193-
assert_eq!(
194-
results.unwrap_err().strip_backtrace(),
195-
"Error during planning: Expected parameter of type Int32, got Utf8 at index 0"
196-
);
197-
Ok(())
198-
}
199-
200178
#[tokio::test]
201179
async fn test_parameter_type_coercion() -> Result<()> {
202180
let ctx = SessionContext::new();

datafusion/core/tests/sql/sql_api.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,30 @@ async fn unsupported_statement_returns_error() {
113113
ctx.sql_with_options(sql, options).await.unwrap();
114114
}
115115

116+
// Disallow PREPARE and EXECUTE statements if `allow_statements` is false
117+
#[tokio::test]
118+
async fn disable_prepare_and_execute_statement() {
119+
let ctx = SessionContext::new();
120+
121+
let prepare_sql = "PREPARE plan(INT) AS SELECT $1";
122+
let execute_sql = "EXECUTE plan(1)";
123+
let options = SQLOptions::new().with_allow_statements(false);
124+
let df = ctx.sql_with_options(prepare_sql, options).await;
125+
assert_eq!(
126+
df.unwrap_err().strip_backtrace(),
127+
"Error during planning: Statement not supported: PREPARE"
128+
);
129+
let df = ctx.sql_with_options(execute_sql, options).await;
130+
assert_eq!(
131+
df.unwrap_err().strip_backtrace(),
132+
"Error during planning: Statement not supported: EXECUTE"
133+
);
134+
135+
let options = options.with_allow_statements(true);
136+
ctx.sql_with_options(prepare_sql, options).await.unwrap();
137+
ctx.sql_with_options(execute_sql, options).await.unwrap();
138+
}
139+
116140
#[tokio::test]
117141
async fn empty_statement_returns_error() {
118142
let ctx = SessionContext::new();

datafusion/expr/src/logical_plan/plan.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1440,6 +1440,22 @@ impl LogicalPlan {
14401440
.map(|res| res.data)
14411441
}
14421442

1443+
/// Walk the logical plan, find any `Placeholder` tokens, and return a set of their names.
1444+
pub fn get_parameter_names(&self) -> Result<HashSet<String>> {
1445+
let mut param_names = HashSet::new();
1446+
self.apply_with_subqueries(|plan| {
1447+
plan.apply_expressions(|expr| {
1448+
expr.apply(|expr| {
1449+
if let Expr::Placeholder(Placeholder { id, .. }) = expr {
1450+
param_names.insert(id.clone());
1451+
}
1452+
Ok(TreeNodeRecursion::Continue)
1453+
})
1454+
})
1455+
})
1456+
.map(|_| param_names)
1457+
}
1458+
14431459
/// Walk the logical plan, find any `Placeholder` tokens, and return a map of their IDs and DataTypes
14441460
pub fn get_parameter_types(
14451461
&self,

0 commit comments

Comments
 (0)