diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 28312fee79a7..5fc8dbcfdfb3 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -387,7 +387,7 @@ checksum = "c6fa2087f2753a7da8cc1c0dbfcf89579dd57458e36769de5ac750b4671737ca" dependencies = [ "proc-macro2", "quote", - "syn 2.0.67", + "syn 2.0.68", ] [[package]] @@ -757,9 +757,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.5.0" +version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1" +checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" [[package]] name = "blake2" @@ -875,9 +875,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.0.99" +version = "1.0.103" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96c51067fd44124faa7f870b4b1c969379ad32b2ba805aa959430ceaa384f695" +checksum = "2755ff20a1d93490d26ba33a6f092a38a508398a5320df5d4b3014fcccce9410" dependencies = [ "jobserver", "libc", @@ -981,7 +981,7 @@ version = "7.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b34115915337defe99b2aff5c2ce6771e5fbc4079f4b506301f5cf394c8452f7" dependencies = [ - "strum 0.26.2", + "strum 0.26.3", "strum_macros 0.26.4", "unicode-width", ] @@ -1099,7 +1099,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "edb49164822f3ee45b17acd4a208cfc1251410cf0cad9a833234c9890774dd9f" dependencies = [ "quote", - "syn 2.0.67", + "syn 2.0.68", ] [[package]] @@ -1262,7 +1262,7 @@ dependencies = [ "paste", "serde_json", "sqlparser", - "strum 0.26.2", + "strum 0.26.3", "strum_macros 0.26.4", ] @@ -1426,7 +1426,7 @@ dependencies = [ "log", "regex", "sqlparser", - "strum 0.26.2", + "strum 0.26.3", ] [[package]] @@ -1504,9 +1504,9 @@ checksum = 
"fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10" [[package]] name = "either" -version = "1.12.0" +version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3dca9240753cf90908d7e4aac30f630662b02aebaa1b58a3cadabdb23385b58b" +checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" [[package]] name = "endian-type" @@ -1685,7 +1685,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.67", + "syn 2.0.68", ] [[package]] @@ -2253,9 +2253,9 @@ checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" [[package]] name = "libmimalloc-sys" -version = "0.1.38" +version = "0.1.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e7bb23d733dfcc8af652a78b7bf232f0e967710d044732185e561e47c0336b6" +checksum = "23aa6811d3bd4deb8a84dde645f943476d13b248d818edcf8ce0b2f37f036b44" dependencies = [ "cc", "libc", @@ -2267,7 +2267,7 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d" dependencies = [ - "bitflags 2.5.0", + "bitflags 2.6.0", "libc", ] @@ -2289,9 +2289,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.21" +version = "0.4.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" +checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" [[package]] name = "lz4_flex" @@ -2331,9 +2331,9 @@ checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" [[package]] name = "mimalloc" -version = "0.1.42" +version = "0.1.43" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e9186d86b79b52f4a77af65604b51225e8db1d6ee7e3f41aec1e40829c71a176" +checksum = 
"68914350ae34959d83f732418d51e2427a794055d0b9529f48259ac07af65633" dependencies = [ "libmimalloc-sys", ] @@ -2406,9 +2406,9 @@ dependencies = [ [[package]] name = "num-bigint" -version = "0.4.5" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c165a9ab64cf766f73521c0dd2cfdff64f488b8f0b3e621face3462d3db536d7" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" dependencies = [ "num-integer", "num-traits", @@ -2482,9 +2482,9 @@ dependencies = [ [[package]] name = "object" -version = "0.36.0" +version = "0.36.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "576dfe1fc8f9df304abb159d767a29d0476f7750fbf8aa7ad07816004a207434" +checksum = "081b846d1d56ddfc18fdf1a922e4f6e07a11768ea1b92dec44e42b72712ccfce" dependencies = [ "memchr", ] @@ -2698,7 +2698,7 @@ checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965" dependencies = [ "proc-macro2", "quote", - "syn 2.0.67", + "syn 2.0.68", ] [[package]] @@ -2912,7 +2912,7 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c82cf8cff14456045f55ec4241383baeff27af886adb72ffb2162f99911de0fd" dependencies = [ - "bitflags 2.5.0", + "bitflags 2.6.0", ] [[package]] @@ -3095,7 +3095,7 @@ version = "0.38.34" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "70dc5ec042f7a43c4a73241207cecc9873a06d45debb38b329f8541d85c2730f" dependencies = [ - "bitflags 2.5.0", + "bitflags 2.6.0", "errno", "libc", "linux-raw-sys", @@ -3264,7 +3264,7 @@ version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c627723fd09706bacdb5cf41499e95098555af3c3c29d014dc3c458ef6be11c0" dependencies = [ - "bitflags 2.5.0", + "bitflags 2.6.0", "core-foundation", "core-foundation-sys", "libc", @@ -3310,14 +3310,14 @@ checksum = "500cbc0ebeb6f46627f50f3f5811ccf6bf00643be300b4c3eabc0ef55dc5b5ba" dependencies = [ "proc-macro2", "quote", - "syn 
2.0.67", + "syn 2.0.68", ] [[package]] name = "serde_json" -version = "1.0.117" +version = "1.0.119" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "455182ea6142b14f93f4bc5320a2b31c1f266b66a4a5c858b013302a5d8cbfc3" +checksum = "e8eddb61f0697cc3989c5d64b452f5488e2b8a60fd7d5076a3045076ffef8cb0" dependencies = [ "itoa", "ryu", @@ -3445,7 +3445,7 @@ checksum = "01b2e185515564f15375f593fb966b5718bc624ba77fe49fa4616ad619690554" dependencies = [ "proc-macro2", "quote", - "syn 2.0.67", + "syn 2.0.68", ] [[package]] @@ -3474,9 +3474,9 @@ checksum = "290d54ea6f91c969195bdbcd7442c8c2a2ba87da8bf60a7ee86a235d4bc1e125" [[package]] name = "strum" -version = "0.26.2" +version = "0.26.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d8cec3501a5194c432b2b7976db6b7d10ec95c253208b45f83f7136aa985e29" +checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06" dependencies = [ "strum_macros 0.26.4", ] @@ -3491,7 +3491,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.67", + "syn 2.0.68", ] [[package]] @@ -3504,14 +3504,14 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.67", + "syn 2.0.68", ] [[package]] name = "subtle" -version = "2.6.0" +version = "2.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d0208408ba0c3df17ed26eb06992cb1a1268d41b2c0e12e65203fbe3972cee5" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" @@ -3526,9 +3526,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.67" +version = "2.0.68" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff8655ed1d86f3af4ee3fd3263786bc14245ad17c4c7e85ba7187fb3ae028c90" +checksum = "901fa70d88b9d6c98022e23b4136f9f3e54e4662c3bc1bd1d84a42a9a0f0c1e9" dependencies = [ "proc-macro2", "quote", @@ -3591,7 +3591,7 @@ checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533" 
dependencies = [ "proc-macro2", "quote", - "syn 2.0.67", + "syn 2.0.68", ] [[package]] @@ -3646,9 +3646,9 @@ dependencies = [ [[package]] name = "tinyvec" -version = "1.6.0" +version = "1.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87cc5ceb3875bb20c2890005a4e226a4651264a5c75edb2421b52861a0a0cb50" +checksum = "c55115c6fbe2d2bef26eb09ad74bde02d8255476fc0c7b515ef09fbb35742d82" dependencies = [ "tinyvec_macros", ] @@ -3686,7 +3686,7 @@ checksum = "5f5ae998a069d4b5aba8ee9dad856af7d520c3699e6159b185c2acd48155d39a" dependencies = [ "proc-macro2", "quote", - "syn 2.0.67", + "syn 2.0.68", ] [[package]] @@ -3783,7 +3783,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.67", + "syn 2.0.68", ] [[package]] @@ -3828,7 +3828,7 @@ checksum = "f03ca4cb38206e2bef0700092660bb74d696f808514dae47fa1467cbfe26e96e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.67", + "syn 2.0.68", ] [[package]] @@ -3907,9 +3907,9 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.8.0" +version = "1.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a183cf7feeba97b4dd1c0d46788634f6221d87fa961b305bed08c851829efcc0" +checksum = "5de17fd2f7da591098415cff336e12965a28061ddace43b59cb3c430179c9439" dependencies = [ "getrandom", "serde", @@ -3982,7 +3982,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.67", + "syn 2.0.68", "wasm-bindgen-shared", ] @@ -4016,7 +4016,7 @@ checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.67", + "syn 2.0.68", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -4281,7 +4281,7 @@ checksum = "15e934569e47891f7d9411f1a451d947a60e000ab3bd24fbb970f000387d1b3b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.67", + "syn 2.0.68", ] [[package]] diff --git 
a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index 0b880ddbf81b..ac94ee61fcb2 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -605,7 +605,7 @@ impl SessionState { } } - let query = SqlToRel::new_with_options(&provider, self.get_parser_options()); + let query = self.build_sql_query_planner(&provider); query.statement_to_plan(statement) } @@ -658,8 +658,7 @@ impl SessionState { tables: HashMap::new(), }; - let query = SqlToRel::new_with_options(&provider, self.get_parser_options()); - + let query = self.build_sql_query_planner(&provider); query.sql_to_expr(sql_expr, df_schema, &mut PlannerContext::new()) } @@ -943,6 +942,31 @@ impl SessionState { let udtf = self.table_functions.remove(name); Ok(udtf.map(|x| x.function().clone())) } + + fn build_sql_query_planner<'a, S>(&self, provider: &'a S) -> SqlToRel<'a, S> + where + S: ContextProvider, + { + let query = SqlToRel::new_with_options(provider, self.get_parser_options()); + + // register crate of array expressions (if enabled) + #[cfg(feature = "array_expressions")] + { + let array_planner = + Arc::new(functions_array::planner::ArrayFunctionPlanner::default()) as _; + + let field_access_planner = + Arc::new(functions_array::planner::FieldAccessPlanner::default()) as _; + + query + .with_user_defined_planner(array_planner) + .with_user_defined_planner(field_access_planner) + } + #[cfg(not(feature = "array_expressions"))] + { + query + } + } } struct SessionContextProvider<'a> { diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 89ee94f9f845..5f1d3c9d5c6b 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -48,6 +48,7 @@ pub mod function; pub mod groups_accumulator; pub mod interval_arithmetic; pub mod logical_plan; +pub mod planner; pub mod registry; pub mod simplify; pub mod sort_properties; @@ -81,6 +82,7 @@ pub use 
partition_evaluator::PartitionEvaluator; pub use signature::{ ArrayFunctionSignature, Signature, TypeSignature, Volatility, TIMEZONE_WILDCARD, }; +pub use sqlparser; pub use table_source::{TableProviderFilterPushDown, TableSource, TableType}; pub use udaf::{AggregateExt, AggregateUDF, AggregateUDFImpl, ReversedUDAF}; pub use udf::{ScalarUDF, ScalarUDFImpl}; diff --git a/datafusion/expr/src/planner.rs b/datafusion/expr/src/planner.rs new file mode 100644 index 000000000000..1febfbec7ef0 --- /dev/null +++ b/datafusion/expr/src/planner.rs @@ -0,0 +1,146 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
[`ContextProvider`] and [`UserDefinedSQLPlanner`] APIs to customize SQL query planning + +use std::sync::Arc; + +use arrow::datatypes::{DataType, SchemaRef}; +use datafusion_common::{ + config::ConfigOptions, file_options::file_type::FileType, not_impl_err, DFSchema, + Result, TableReference, +}; + +use crate::{AggregateUDF, Expr, GetFieldAccess, ScalarUDF, TableSource, WindowUDF}; + +/// Provides the `SQL` query planner meta-data about tables and +/// functions referenced in SQL statements, without a direct dependency on other +/// DataFusion structures +pub trait ContextProvider { + /// Getter for a datasource + fn get_table_source(&self, name: TableReference) -> Result>; + + fn get_file_type(&self, _ext: &str) -> Result> { + not_impl_err!("Registered file types are not supported") + } + + /// Getter for a table function + fn get_table_function_source( + &self, + _name: &str, + _args: Vec, + ) -> Result> { + not_impl_err!("Table Functions are not supported") + } + + /// This provides a worktable (an intermediate table that is used to store the results of a CTE during execution) + /// We don't directly implement this in the logical plan's ['SqlToRel`] + /// because the sql code needs access to a table that contains execution-related types that can't be a direct dependency + /// of the sql crate (namely, the `CteWorktable`). + /// The [`ContextProvider`] provides a way to "hide" this dependency. 
+ fn create_cte_work_table( + &self, + _name: &str, + _schema: SchemaRef, + ) -> Result> { + not_impl_err!("Recursive CTE is not implemented") + } + + /// Getter for a UDF description + fn get_function_meta(&self, name: &str) -> Option>; + /// Getter for a UDAF description + fn get_aggregate_meta(&self, name: &str) -> Option>; + /// Getter for a UDWF + fn get_window_meta(&self, name: &str) -> Option>; + /// Getter for system/user-defined variable type + fn get_variable_type(&self, variable_names: &[String]) -> Option; + + /// Get configuration options + fn options(&self) -> &ConfigOptions; + + /// Get all user defined scalar function names + fn udf_names(&self) -> Vec; + + /// Get all user defined aggregate function names + fn udaf_names(&self) -> Vec; + + /// Get all user defined window function names + fn udwf_names(&self) -> Vec; +} + +/// This trait allows users to customize the behavior of the SQL planner +pub trait UserDefinedSQLPlanner { + /// Plan the binary operation between two expressions, returns OriginalBinaryExpr if not possible + fn plan_binary_op( + &self, + expr: RawBinaryExpr, + _schema: &DFSchema, + ) -> Result> { + Ok(PlannerResult::Original(expr)) + } + + /// Plan the field access expression, returns OriginalFieldAccessExpr if not possible + fn plan_field_access( + &self, + expr: RawFieldAccessExpr, + _schema: &DFSchema, + ) -> Result> { + Ok(PlannerResult::Original(expr)) + } + + // Plan the array literal, returns OriginalArray if not possible + fn plan_array_literal( + &self, + exprs: Vec, + _schema: &DFSchema, + ) -> Result>> { + Ok(PlannerResult::Original(exprs)) + } +} + +/// An operator with two arguments to plan +/// +/// Note `left` and `right` are DataFusion [`Expr`]s but the `op` is the SQL AST +/// operator. +/// +/// This structure is used by [`UserDefinedSQLPlanner`] to plan operators with +/// custom expressions. 
+#[derive(Debug, Clone)] +pub struct RawBinaryExpr { + pub op: sqlparser::ast::BinaryOperator, + pub left: Expr, + pub right: Expr, +} + +/// An expression with GetFieldAccess to plan +/// +/// This structure is used by [`UserDefinedSQLPlanner`] to plan operators with +/// custom expressions. +#[derive(Debug, Clone)] +pub struct RawFieldAccessExpr { + pub field_access: GetFieldAccess, + pub expr: Expr, +} + +/// Result of planning a raw expr with [`UserDefinedSQLPlanner`] +#[derive(Debug, Clone)] +pub enum PlannerResult { + /// The raw expression was successfully planned as a new [`Expr`] + Planned(Expr), + /// The raw expression could not be planned, and is returned unmodified + Original(T), +} diff --git a/datafusion/functions-array/src/lib.rs b/datafusion/functions-array/src/lib.rs index 543b7a60277e..814127be806b 100644 --- a/datafusion/functions-array/src/lib.rs +++ b/datafusion/functions-array/src/lib.rs @@ -39,6 +39,7 @@ pub mod extract; pub mod flatten; pub mod length; pub mod make_array; +pub mod planner; pub mod position; pub mod range; pub mod remove; @@ -50,7 +51,6 @@ pub mod set_ops; pub mod sort; pub mod string; pub mod utils; - use datafusion_common::Result; use datafusion_execution::FunctionRegistry; use datafusion_expr::ScalarUDF; diff --git a/datafusion/functions-array/src/planner.rs b/datafusion/functions-array/src/planner.rs new file mode 100644 index 000000000000..f33ee56582cf --- /dev/null +++ b/datafusion/functions-array/src/planner.rs @@ -0,0 +1,161 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! SQL planning extensions like [`ArrayFunctionPlanner`] and [`FieldAccessPlanner`] + +use datafusion_common::{utils::list_ndims, DFSchema, Result}; +use datafusion_expr::{ + planner::{PlannerResult, RawBinaryExpr, RawFieldAccessExpr, UserDefinedSQLPlanner}, + sqlparser, AggregateFunction, Expr, ExprSchemable, GetFieldAccess, +}; +use datafusion_functions::expr_fn::get_field; + +use crate::{ + array_has::array_has_all, + expr_fn::{array_append, array_concat, array_prepend}, + extract::{array_element, array_slice}, + make_array::make_array, +}; + +#[derive(Default)] +pub struct ArrayFunctionPlanner {} + +impl UserDefinedSQLPlanner for ArrayFunctionPlanner { + fn plan_binary_op( + &self, + expr: RawBinaryExpr, + schema: &DFSchema, + ) -> Result> { + let RawBinaryExpr { op, left, right } = expr; + + if op == sqlparser::ast::BinaryOperator::StringConcat { + let left_type = left.get_type(schema)?; + let right_type = right.get_type(schema)?; + let left_list_ndims = list_ndims(&left_type); + let right_list_ndims = list_ndims(&right_type); + + // Rewrite string concat operator to function based on types + // if we get list || list then we rewrite it to array_concat() + // if we get list || non-list then we rewrite it to array_append() + // if we get non-list || list then we rewrite it to array_prepend() + // if we get string || string then we rewrite it to concat() + + // We determine the target function to rewrite based on the list n-dimension, the check is not exact but sufficient. 
+ // The exact validity check is handled in the actual function, so even if there is 3d list appended with 1d list, it is also fine to rewrite. + if left_list_ndims + right_list_ndims == 0 { + // TODO: concat function ignore null, but string concat takes null into consideration + // we can rewrite it to concat if we can configure the behaviour of concat function to the one like `string concat operator` + } else if left_list_ndims == right_list_ndims { + return Ok(PlannerResult::Planned(array_concat(vec![left, right]))); + } else if left_list_ndims > right_list_ndims { + return Ok(PlannerResult::Planned(array_append(left, right))); + } else if left_list_ndims < right_list_ndims { + return Ok(PlannerResult::Planned(array_prepend(left, right))); + } + } else if matches!( + op, + sqlparser::ast::BinaryOperator::AtArrow + | sqlparser::ast::BinaryOperator::ArrowAt + ) { + let left_type = left.get_type(schema)?; + let right_type = right.get_type(schema)?; + let left_list_ndims = list_ndims(&left_type); + let right_list_ndims = list_ndims(&right_type); + // if both are list + if left_list_ndims > 0 && right_list_ndims > 0 { + if op == sqlparser::ast::BinaryOperator::AtArrow { + // array1 @> array2 -> array_has_all(array1, array2) + return Ok(PlannerResult::Planned(array_has_all(left, right))); + } else { + // array1 <@ array2 -> array_has_all(array2, array1) + return Ok(PlannerResult::Planned(array_has_all(right, left))); + } + } + } + + Ok(PlannerResult::Original(RawBinaryExpr { op, left, right })) + } + + fn plan_array_literal( + &self, + exprs: Vec, + _schema: &DFSchema, + ) -> Result>> { + Ok(PlannerResult::Planned(make_array(exprs))) + } +} + +#[derive(Default)] +pub struct FieldAccessPlanner {} + +impl UserDefinedSQLPlanner for FieldAccessPlanner { + fn plan_field_access( + &self, + expr: RawFieldAccessExpr, + _schema: &DFSchema, + ) -> Result> { + let RawFieldAccessExpr { expr, field_access } = expr; + + match field_access { + // expr["field"] => get_field(expr, 
"field") + GetFieldAccess::NamedStructField { name } => { + Ok(PlannerResult::Planned(get_field(expr, name))) + } + // expr[idx] ==> array_element(expr, idx) + GetFieldAccess::ListIndex { key: index } => { + match expr { + // Special case for array_agg(expr)[index] to NTH_VALUE(expr, index) + Expr::AggregateFunction(agg_func) if is_array_agg(&agg_func) => { + Ok(PlannerResult::Planned(Expr::AggregateFunction( + datafusion_expr::expr::AggregateFunction::new( + AggregateFunction::NthValue, + agg_func + .args + .into_iter() + .chain(std::iter::once(*index)) + .collect(), + agg_func.distinct, + agg_func.filter, + agg_func.order_by, + agg_func.null_treatment, + ), + ))) + } + _ => Ok(PlannerResult::Planned(array_element(expr, *index))), + } + } + // expr[start, stop, stride] ==> array_slice(expr, start, stop, stride) + GetFieldAccess::ListRange { + start, + stop, + stride, + } => Ok(PlannerResult::Planned(array_slice( + expr, + *start, + *stop, + Some(*stride), + ))), + } + } +} + +fn is_array_agg(agg_func: &datafusion_expr::expr::AggregateFunction) -> bool { + agg_func.func_def + == datafusion_expr::expr::AggregateFunctionDefinition::BuiltIn( + AggregateFunction::ArrayAgg, + ) +} diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index b1182b35ec95..786ea288fa0e 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -17,7 +17,8 @@ use arrow_schema::DataType; use arrow_schema::TimeUnit; -use datafusion_common::utils::list_ndims; +use datafusion_expr::planner::PlannerResult; +use datafusion_expr::planner::RawFieldAccessExpr; use sqlparser::ast::{CastKind, Expr as SQLExpr, Subscript, TrimWhereField, Value}; use datafusion_common::{ @@ -27,8 +28,8 @@ use datafusion_common::{ use datafusion_expr::expr::InList; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::{ - lit, AggregateFunction, Between, BinaryExpr, Cast, Expr, ExprSchemable, - GetFieldAccess, Like, Literal, Operator, TryCast, + lit, Between, 
BinaryExpr, Cast, Expr, ExprSchemable, GetFieldAccess, Like, Literal, + Operator, TryCast, }; use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; @@ -52,7 +53,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ) -> Result { enum StackEntry { SQLExpr(Box), - Operator(Operator), + Operator(sqlparser::ast::BinaryOperator), } // Virtual stack machine to convert SQLExpr to Expr @@ -69,7 +70,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLExpr::BinaryOp { left, op, right } => { // Note the order that we push the entries to the stack // is important. We want to visit the left node first. - let op = self.parse_sql_binary_op(op)?; stack.push(StackEntry::Operator(op)); stack.push(StackEntry::SQLExpr(right)); stack.push(StackEntry::SQLExpr(left)); @@ -100,91 +100,28 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { fn build_logical_expr( &self, - op: Operator, + op: sqlparser::ast::BinaryOperator, left: Expr, right: Expr, schema: &DFSchema, ) -> Result { - // Rewrite string concat operator to function based on types - // if we get list || list then we rewrite it to array_concat() - // if we get list || non-list then we rewrite it to array_append() - // if we get non-list || list then we rewrite it to array_prepend() - // if we get string || string then we rewrite it to concat() - if op == Operator::StringConcat { - let left_type = left.get_type(schema)?; - let right_type = right.get_type(schema)?; - let left_list_ndims = list_ndims(&left_type); - let right_list_ndims = list_ndims(&right_type); - - // We determine the target function to rewrite based on the list n-dimension, the check is not exact but sufficient. - // The exact validity check is handled in the actual function, so even if there is 3d list appended with 1d list, it is also fine to rewrite. 
- if left_list_ndims + right_list_ndims == 0 { - // TODO: concat function ignore null, but string concat takes null into consideration - // we can rewrite it to concat if we can configure the behaviour of concat function to the one like `string concat operator` - } else if left_list_ndims == right_list_ndims { - if let Some(udf) = self.context_provider.get_function_meta("array_concat") - { - return Ok(Expr::ScalarFunction(ScalarFunction::new_udf( - udf, - vec![left, right], - ))); - } else { - return internal_err!("array_concat not found"); + // try extension planers + let mut binary_expr = datafusion_expr::planner::RawBinaryExpr { op, left, right }; + for planner in self.planners.iter() { + match planner.plan_binary_op(binary_expr, schema)? { + PlannerResult::Planned(expr) => { + return Ok(expr); } - } else if left_list_ndims > right_list_ndims { - if let Some(udf) = self.context_provider.get_function_meta("array_append") - { - return Ok(Expr::ScalarFunction(ScalarFunction::new_udf( - udf, - vec![left, right], - ))); - } else { - return internal_err!("array_append not found"); - } - } else if left_list_ndims < right_list_ndims { - if let Some(udf) = - self.context_provider.get_function_meta("array_prepend") - { - return Ok(Expr::ScalarFunction(ScalarFunction::new_udf( - udf, - vec![left, right], - ))); - } else { - return internal_err!("array_prepend not found"); - } - } - } else if matches!(op, Operator::AtArrow | Operator::ArrowAt) { - let left_type = left.get_type(schema)?; - let right_type = right.get_type(schema)?; - let left_list_ndims = list_ndims(&left_type); - let right_list_ndims = list_ndims(&right_type); - // if both are list - if left_list_ndims > 0 && right_list_ndims > 0 { - if let Some(udf) = - self.context_provider.get_function_meta("array_has_all") - { - // array1 @> array2 -> array_has_all(array1, array2) - if op == Operator::AtArrow { - return Ok(Expr::ScalarFunction(ScalarFunction::new_udf( - udf, - vec![left, right], - ))); - // array1 <@ 
array2 -> array_has_all(array2, array1) - } else { - return Ok(Expr::ScalarFunction(ScalarFunction::new_udf( - udf, - vec![right, left], - ))); - } - } else { - return internal_err!("array_has_all not found"); + PlannerResult::Original(expr) => { + binary_expr = expr; } } } + let datafusion_expr::planner::RawBinaryExpr { op, left, right } = binary_expr; Ok(Expr::BinaryExpr(BinaryExpr::new( Box::new(left), - op, + self.parse_sql_binary_op(op)?, Box::new(right), ))) } @@ -272,7 +209,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let expr = self.sql_expr_to_logical_expr(*expr, schema, planner_context)?; - let get_field_access = match *subscript { + let field_access = match *subscript { Subscript::Index { index } => { // index can be a name, in which case it is a named field access match index { @@ -341,7 +278,17 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } }; - self.plan_field_access(expr, get_field_access) + let mut field_access_expr = RawFieldAccessExpr { expr, field_access }; + for planner in self.planners.iter() { + match planner.plan_field_access(field_access_expr, schema)? 
{ + PlannerResult::Planned(expr) => return Ok(expr), + PlannerResult::Original(expr) => { + field_access_expr = expr; + } + } + } + + not_impl_err!("GetFieldAccess not supported by UserDefinedExtensionPlanners: {field_access_expr:?}") } SQLExpr::CompoundIdentifier(ids) => { @@ -676,36 +623,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } - /// Simplifies an expression like `ARRAY_AGG(expr)[index]` to `NTH_VALUE(expr, index)` - /// - /// returns Some(Expr) if the expression was simplified, otherwise None - /// TODO: this should likely be done in ArrayAgg::simplify when it is moved to a UDAF - fn simplify_array_index_expr(expr: &Expr, index: &Expr) -> Option { - fn is_array_agg(agg_func: &datafusion_expr::expr::AggregateFunction) -> bool { - agg_func.func_def - == datafusion_expr::expr::AggregateFunctionDefinition::BuiltIn( - AggregateFunction::ArrayAgg, - ) - } - match expr { - Expr::AggregateFunction(agg_func) if is_array_agg(agg_func) => { - let mut new_args = agg_func.args.clone(); - new_args.push(index.clone()); - Some(Expr::AggregateFunction( - datafusion_expr::expr::AggregateFunction::new( - AggregateFunction::NthValue, - new_args, - agg_func.distinct, - agg_func.filter.clone(), - agg_func.order_by.clone(), - agg_func.null_treatment, - ), - )) - } - _ => None, - } - } - /// Parses a struct(..) 
expression fn parse_struct( &self, @@ -991,58 +908,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let args = vec![fullstr, substr]; Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, args))) } - - /// Given an expression and the field to access, creates a new expression for accessing that field - fn plan_field_access( - &self, - expr: Expr, - get_field_access: GetFieldAccess, - ) -> Result { - match get_field_access { - GetFieldAccess::NamedStructField { name } => { - if let Some(udf) = self.context_provider.get_function_meta("get_field") { - Ok(Expr::ScalarFunction(ScalarFunction::new_udf( - udf, - vec![expr, lit(name)], - ))) - } else { - internal_err!("get_field not found") - } - } - // expr[idx] ==> array_element(expr, idx) - GetFieldAccess::ListIndex { key } => { - // Special case for array_agg(expr)[index] to NTH_VALUE(expr, index) - if let Some(simplified) = Self::simplify_array_index_expr(&expr, &key) { - Ok(simplified) - } else if let Some(udf) = - self.context_provider.get_function_meta("array_element") - { - Ok(Expr::ScalarFunction(ScalarFunction::new_udf( - udf, - vec![expr, *key], - ))) - } else { - internal_err!("get_field not found") - } - } - // expr[start, stop, stride] ==> array_slice(expr, start, stop, stride) - GetFieldAccess::ListRange { - start, - stop, - stride, - } => { - if let Some(udf) = self.context_provider.get_function_meta("array_slice") - { - Ok(Expr::ScalarFunction(ScalarFunction::new_udf( - udf, - vec![expr, *start, *stop, *stride], - ))) - } else { - internal_err!("array_slice not found") - } - } - } - } } #[cfg(test)] diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs index fa95fc2e051d..5cd6ffc68788 100644 --- a/datafusion/sql/src/expr/value.rs +++ b/datafusion/sql/src/expr/value.rs @@ -20,9 +20,10 @@ use arrow::compute::kernels::cast_utils::parse_interval_month_day_nano; use arrow::datatypes::DECIMAL128_MAX_PRECISION; use arrow_schema::DataType; use datafusion_common::{ - not_impl_err, 
plan_err, DFSchema, DataFusionError, Result, ScalarValue, + internal_err, not_impl_err, plan_err, DFSchema, DataFusionError, Result, ScalarValue, }; -use datafusion_expr::expr::{BinaryExpr, Placeholder, ScalarFunction}; +use datafusion_expr::expr::{BinaryExpr, Placeholder}; +use datafusion_expr::planner::PlannerResult; use datafusion_expr::{lit, Expr, Operator}; use log::debug; use sqlparser::ast::{BinaryOperator, Expr as SQLExpr, Interval, Value}; @@ -130,6 +131,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ))) } + // IMPORTANT: Keep sql_array_literal's function body small to prevent stack overflow + // This function is recursively called, potentially leading to deep call stacks. pub(super) fn sql_array_literal( &self, elements: Vec, @@ -142,13 +145,25 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { }) .collect::>>()?; - if let Some(udf) = self.context_provider.get_function_meta("make_array") { - Ok(Expr::ScalarFunction(ScalarFunction::new_udf(udf, values))) - } else { - not_impl_err!( - "array_expression featrue is disable, So should implement make_array UDF by yourself" - ) + self.try_plan_array_literal(values, schema) + } + + fn try_plan_array_literal( + &self, + values: Vec, + schema: &DFSchema, + ) -> Result { + let mut exprs = values; + for planner in self.planners.iter() { + match planner.plan_array_literal(exprs, schema)? 
{ + PlannerResult::Planned(expr) => { + return Ok(expr); + } + PlannerResult::Original(values) => exprs = values, + } } + + internal_err!("Expected a simplified result, but none was found") } /// Convert a SQL interval expression to a DataFusion logical plan diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 00f221200624..443cd64a940c 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -21,17 +21,15 @@ use std::sync::Arc; use std::vec; use arrow_schema::*; -use datafusion_common::file_options::file_type::FileType; use datafusion_common::{ field_not_found, internal_err, plan_datafusion_err, DFSchemaRef, SchemaError, }; -use datafusion_expr::WindowUDF; +use datafusion_expr::planner::UserDefinedSQLPlanner; use sqlparser::ast::TimezoneInfo; use sqlparser::ast::{ArrayElemTypeDef, ExactNumberInfo}; use sqlparser::ast::{ColumnDef as SQLColumnDef, ColumnOption}; use sqlparser::ast::{DataType as SQLDataType, Ident, ObjectName, TableAlias}; -use datafusion_common::config::ConfigOptions; use datafusion_common::TableReference; use datafusion_common::{ not_impl_err, plan_err, unqualified_field_not_found, DFSchema, DataFusionError, @@ -39,64 +37,11 @@ use datafusion_common::{ }; use datafusion_expr::logical_plan::{LogicalPlan, LogicalPlanBuilder}; use datafusion_expr::utils::find_column_exprs; -use datafusion_expr::TableSource; -use datafusion_expr::{col, AggregateUDF, Expr, ScalarUDF}; +use datafusion_expr::{col, Expr}; use crate::utils::make_decimal_type; -/// The ContextProvider trait allows the query planner to obtain meta-data about tables and -/// functions referenced in SQL statements -pub trait ContextProvider { - /// Getter for a datasource - fn get_table_source(&self, name: TableReference) -> Result>; - - fn get_file_type(&self, _ext: &str) -> Result> { - not_impl_err!("Registered file types are not supported") - } - - /// Getter for a table function - fn get_table_function_source( - &self, - _name: &str, - 
_args: Vec, - ) -> Result> { - not_impl_err!("Table Functions are not supported") - } - - /// This provides a worktable (an intermediate table that is used to store the results of a CTE during execution) - /// We don't directly implement this in the logical plan's ['SqlToRel`] - /// because the sql code needs access to a table that contains execution-related types that can't be a direct dependency - /// of the sql crate (namely, the `CteWorktable`). - /// The [`ContextProvider`] provides a way to "hide" this dependency. - fn create_cte_work_table( - &self, - _name: &str, - _schema: SchemaRef, - ) -> Result> { - not_impl_err!("Recursive CTE is not implemented") - } - - /// Getter for a UDF description - fn get_function_meta(&self, name: &str) -> Option>; - /// Getter for a UDAF description - fn get_aggregate_meta(&self, name: &str) -> Option>; - /// Getter for a UDWF - fn get_window_meta(&self, name: &str) -> Option>; - /// Getter for system/user-defined variable type - fn get_variable_type(&self, variable_names: &[String]) -> Option; - - /// Get configuration options - fn options(&self) -> &ConfigOptions; - - /// Get all user defined scalar function names - fn udf_names(&self) -> Vec; - - /// Get all user defined aggregate function names - fn udaf_names(&self) -> Vec; - - /// Get all user defined window function names - fn udwf_names(&self) -> Vec; -} +pub use datafusion_expr::planner::ContextProvider; /// SQL parser options #[derive(Debug)] @@ -241,6 +186,8 @@ pub struct SqlToRel<'a, S: ContextProvider> { pub(crate) context_provider: &'a S, pub(crate) options: ParserOptions, pub(crate) normalizer: IdentNormalizer, + /// user-defined planner extensions + pub(crate) planners: Vec>, } impl<'a, S: ContextProvider> SqlToRel<'a, S> { @@ -249,13 +196,24 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { pub fn new(context_provider: &'a S) -> Self { Self::new_with_options(context_provider, ParserOptions::default()) } + /// add a user-defined planner + pub fn with_user_defined_planner( + mut self, + planner: Arc, + )
-> Self { + self.planners.push(planner); + self + } + /// Create a new query planner pub fn new_with_options(context_provider: &'a S, options: ParserOptions) -> Self { let normalize = options.enable_ident_normalization; + SqlToRel { context_provider, options, normalizer: IdentNormalizer::new(normalize), + planners: vec![], } }