Skip to content

Commit 3910073

Browse files
niebayesalamb
andauthored
fix: add an "expr_planners" method to SessionState (#15119)
* add expr_planners to SessionState * minor * fix ci * add test * flatten imports --------- Co-authored-by: Andrew Lamb <[email protected]>
1 parent 396e0d5 commit 3910073

File tree

2 files changed

+161
-2
lines changed

2 files changed

+161
-2
lines changed

datafusion/core/src/execution/context/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1711,7 +1711,7 @@ impl FunctionRegistry for SessionContext {
17111711
}
17121712

17131713
fn expr_planners(&self) -> Vec<Arc<dyn ExprPlanner>> {
1714-
self.state.read().expr_planners()
1714+
self.state.read().expr_planners().to_vec()
17151715
}
17161716

17171717
fn register_expr_planner(

datafusion/core/src/execution/session_state.rs

Lines changed: 160 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,11 @@ impl SessionState {
552552
&self.optimizer
553553
}
554554

555+
/// Returns the [`ExprPlanner`]s for this session
556+
pub fn expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {
557+
&self.expr_planners
558+
}
559+
555560
/// Returns the [`QueryPlanner`] for this session
556561
pub fn query_planner(&self) -> &Arc<dyn QueryPlanner + Send + Sync> {
557562
&self.query_planner
@@ -1637,7 +1642,7 @@ struct SessionContextProvider<'a> {
16371642

16381643
impl ContextProvider for SessionContextProvider<'_> {
16391644
fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {
1640-
&self.state.expr_planners
1645+
self.state.expr_planners()
16411646
}
16421647

16431648
fn get_type_planner(&self) -> Option<Arc<dyn TypePlanner>> {
@@ -1959,8 +1964,17 @@ pub(crate) struct PreparedPlan {
19591964
#[cfg(test)]
19601965
mod tests {
19611966
use super::{SessionContextProvider, SessionStateBuilder};
1967+
use crate::common::assert_contains;
1968+
use crate::config::ConfigOptions;
1969+
use crate::datasource::empty::EmptyTable;
1970+
use crate::datasource::provider_as_source;
19621971
use crate::datasource::MemTable;
19631972
use crate::execution::context::SessionState;
1973+
use crate::logical_expr::planner::ExprPlanner;
1974+
use crate::logical_expr::{AggregateUDF, ScalarUDF, TableSource, WindowUDF};
1975+
use crate::physical_plan::ExecutionPlan;
1976+
use crate::sql::planner::ContextProvider;
1977+
use crate::sql::{ResolvedTableReference, TableReference};
19641978
use arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray};
19651979
use arrow::datatypes::{DataType, Field, Schema};
19661980
use datafusion_catalog::MemoryCatalogProviderList;
@@ -1970,6 +1984,7 @@ mod tests {
19701984
use datafusion_expr::Expr;
19711985
use datafusion_optimizer::optimizer::OptimizerRule;
19721986
use datafusion_optimizer::Optimizer;
1987+
use datafusion_physical_plan::display::DisplayableExecutionPlan;
19731988
use datafusion_sql::planner::{PlannerContext, SqlToRel};
19741989
use std::collections::HashMap;
19751990
use std::sync::Arc;
@@ -2127,4 +2142,148 @@ mod tests {
21272142

21282143
Ok(())
21292144
}
2145+
2146+
/// This test demonstrates why it's more convenient and somewhat necessary to provide
2147+
/// an `expr_planners` method for `SessionState`.
2148+
#[tokio::test]
2149+
async fn test_with_expr_planners() -> Result<()> {
2150+
// A helper method for planning count wildcard with or without expr planners.
2151+
async fn plan_count_wildcard(
2152+
with_expr_planners: bool,
2153+
) -> Result<Arc<dyn ExecutionPlan>> {
2154+
let mut context_provider = MyContextProvider::new().with_table(
2155+
"t",
2156+
provider_as_source(Arc::new(EmptyTable::new(Schema::empty().into()))),
2157+
);
2158+
if with_expr_planners {
2159+
context_provider = context_provider.with_expr_planners();
2160+
}
2161+
2162+
let state = &context_provider.state;
2163+
let statement = state.sql_to_statement("select count(*) from t", "mysql")?;
2164+
let plan = SqlToRel::new(&context_provider).statement_to_plan(statement)?;
2165+
state.create_physical_plan(&plan).await
2166+
}
2167+
2168+
// Planning count wildcard without expr planners should fail.
2169+
let got = plan_count_wildcard(false).await;
2170+
assert_contains!(
2171+
got.unwrap_err().to_string(),
2172+
"Physical plan does not support logical expression Wildcard"
2173+
);
2174+
2175+
// Planning count wildcard with expr planners should succeed.
2176+
let got = plan_count_wildcard(true).await?;
2177+
let displayable = DisplayableExecutionPlan::new(got.as_ref());
2178+
assert_eq!(
2179+
displayable.indent(false).to_string(),
2180+
"ProjectionExec: expr=[0 as count(*)]\n PlaceholderRowExec\n"
2181+
);
2182+
2183+
Ok(())
2184+
}
2185+
2186+
/// A `ContextProvider` based on `SessionState`.
2187+
///
2188+
/// Almost all planning context are retrieved from the `SessionState`.
2189+
struct MyContextProvider {
2190+
/// The session state.
2191+
state: SessionState,
2192+
/// Registered tables.
2193+
tables: HashMap<ResolvedTableReference, Arc<dyn TableSource>>,
2194+
/// Controls whether to return expression planners when called `ContextProvider::expr_planners`.
2195+
return_expr_planners: bool,
2196+
}
2197+
2198+
impl MyContextProvider {
2199+
/// Creates a new `SessionContextProvider`.
2200+
pub fn new() -> Self {
2201+
Self {
2202+
state: SessionStateBuilder::default()
2203+
.with_default_features()
2204+
.build(),
2205+
tables: HashMap::new(),
2206+
return_expr_planners: false,
2207+
}
2208+
}
2209+
2210+
/// Registers a table.
2211+
///
2212+
/// The catalog and schema are provided by default.
2213+
pub fn with_table(mut self, table: &str, source: Arc<dyn TableSource>) -> Self {
2214+
self.tables.insert(
2215+
ResolvedTableReference {
2216+
catalog: "default".to_string().into(),
2217+
schema: "public".to_string().into(),
2218+
table: table.to_string().into(),
2219+
},
2220+
source,
2221+
);
2222+
self
2223+
}
2224+
2225+
/// Sets the `return_expr_planners` flag to true.
2226+
pub fn with_expr_planners(self) -> Self {
2227+
Self {
2228+
return_expr_planners: true,
2229+
..self
2230+
}
2231+
}
2232+
}
2233+
2234+
impl ContextProvider for MyContextProvider {
2235+
fn get_table_source(&self, name: TableReference) -> Result<Arc<dyn TableSource>> {
2236+
let resolved_table_ref = ResolvedTableReference {
2237+
catalog: "default".to_string().into(),
2238+
schema: "public".to_string().into(),
2239+
table: name.table().to_string().into(),
2240+
};
2241+
let source = self.tables.get(&resolved_table_ref).cloned().unwrap();
2242+
Ok(source)
2243+
}
2244+
2245+
/// We use a `return_expr_planners` flag to demonstrate why it's necessary to
2246+
/// return the expression planners in the `SessionState`.
2247+
///
2248+
/// Note, the default implementation returns an empty slice.
2249+
fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {
2250+
if self.return_expr_planners {
2251+
self.state.expr_planners()
2252+
} else {
2253+
&[]
2254+
}
2255+
}
2256+
2257+
fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
2258+
self.state.scalar_functions().get(name).cloned()
2259+
}
2260+
2261+
fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
2262+
self.state.aggregate_functions().get(name).cloned()
2263+
}
2264+
2265+
fn get_window_meta(&self, name: &str) -> Option<Arc<WindowUDF>> {
2266+
self.state.window_functions().get(name).cloned()
2267+
}
2268+
2269+
fn get_variable_type(&self, _variable_names: &[String]) -> Option<DataType> {
2270+
None
2271+
}
2272+
2273+
fn options(&self) -> &ConfigOptions {
2274+
self.state.config_options()
2275+
}
2276+
2277+
fn udf_names(&self) -> Vec<String> {
2278+
self.state.scalar_functions().keys().cloned().collect()
2279+
}
2280+
2281+
fn udaf_names(&self) -> Vec<String> {
2282+
self.state.aggregate_functions().keys().cloned().collect()
2283+
}
2284+
2285+
fn udwf_names(&self) -> Vec<String> {
2286+
self.state.window_functions().keys().cloned().collect()
2287+
}
2288+
}
21302289
}

0 commit comments

Comments
 (0)