diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 5884e424c781..9d20c242bbef 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -272,7 +272,7 @@ dependencies = [ "arrow-schema", "chrono", "half", - "indexmap 2.2.6", + "indexmap 2.3.0", "lexical-core", "num", "serde", @@ -375,7 +375,7 @@ dependencies = [ "tokio", "xz2", "zstd 0.13.2", - "zstd-safe 7.2.0", + "zstd-safe 7.2.1", ] [[package]] @@ -837,9 +837,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.6.1" +version = "1.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a12916984aab3fa6e39d655a33e09c0071eb36d6ab3aea5c2d78551f1df6d952" +checksum = "8318a53db07bb3f8dca91a600466bdb3f2eaadeedfdbcf02e1accbad9271ba50" [[package]] name = "bytes-utils" @@ -874,9 +874,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.1.6" +version = "1.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2aba8f4e9906c7ce3c73463f62a7f0c65183ada1a2d47e397cc8810827f9694f" +checksum = "26a5c3fd7bfa1ce3897a3a3501d362b2d87b7f2583ebcb4a949ec25911025cbc" dependencies = [ "jobserver", "libc", @@ -1161,7 +1161,7 @@ dependencies = [ "glob", "half", "hashbrown 0.14.5", - "indexmap 2.2.6", + "indexmap 2.3.0", "itertools 0.12.1", "log", "num-traits", @@ -1357,7 +1357,7 @@ dependencies = [ "datafusion-expr", "datafusion-physical-expr", "hashbrown 0.14.5", - "indexmap 2.2.6", + "indexmap 2.3.0", "itertools 0.12.1", "log", "paste", @@ -1384,7 +1384,7 @@ dependencies = [ "half", "hashbrown 0.14.5", "hex", - "indexmap 2.2.6", + "indexmap 2.3.0", "itertools 0.12.1", "log", "paste", @@ -1436,7 +1436,7 @@ dependencies = [ "futures", "half", "hashbrown 0.14.5", - "indexmap 2.2.6", + "indexmap 2.3.0", "itertools 0.12.1", "log", "once_cell", @@ -1629,9 +1629,9 @@ dependencies = [ [[package]] name = "flate2" -version = "1.0.30" +version = "1.0.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f54427cfd1c7829e2a139fcefea601bf088ebca651d2bf53ebc600eac295dae" +checksum = "7f211bbe8e69bbd0cfdea405084f128ae8b4aaa6b0b522fc8f2b009084797920" dependencies = [ "crc32fast", "miniz_oxide", @@ -1801,7 +1801,7 @@ dependencies = [ "futures-sink", "futures-util", "http 0.2.12", - "indexmap 2.2.6", + "indexmap 2.3.0", "slab", "tokio", "tokio-util", @@ -1820,7 +1820,7 @@ dependencies = [ "futures-core", "futures-sink", "http 1.1.0", - "indexmap 2.2.6", + "indexmap 2.3.0", "slab", "tokio", "tokio-util", @@ -2112,9 +2112,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.2.6" +version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "168fb715dda47215e360912c096649d23d58bf392ac62f73919e831745e40f26" +checksum = "de3fc2e30ba82dd1b3911c8de1ffc143c74a914a14e99514d7637e3099df5ea0" dependencies = [ "equivalent", "hashbrown 0.14.5", @@ -2552,7 +2552,7 @@ dependencies = [ "rand", "reqwest", "ring 0.17.8", - "rustls-pemfile 2.1.2", + "rustls-pemfile 2.1.3", "serde", "serde_json", "snafu", @@ -2682,7 +2682,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db" dependencies = [ "fixedbitset", - "indexmap 2.2.6", + "indexmap 2.3.0", ] [[package]] @@ -2769,9 +2769,12 @@ checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" [[package]] name = "ppv-lite86" -version = "0.2.17" +version = "0.2.20" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" +checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" +dependencies = [ + "zerocopy", +] [[package]] name = "predicates" @@ -2854,9 +2857,9 @@ dependencies = [ [[package]] name = "quinn" -version = "0.11.2" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e4ceeeeabace7857413798eb1ffa1e9c905a9946a57d81fb69b4b71c4d8eb3ad" +checksum = "b22d8e7369034b9a7132bc2008cac12f2013c8132b45e0554e6e20e2617f2156" dependencies = [ "bytes", "pin-project-lite", @@ -2864,6 +2867,7 @@ dependencies = [ "quinn-udp", "rustc-hash", "rustls 0.23.12", + "socket2", "thiserror", "tokio", "tracing", @@ -2871,9 +2875,9 @@ dependencies = [ [[package]] name = "quinn-proto" -version = "0.11.3" +version = "0.11.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ddf517c03a109db8100448a4be38d498df8a210a99fe0e1b9eaf39e78c640efe" +checksum = "ba92fb39ec7ad06ca2582c0ca834dfeadcaf06ddfc8e635c80aa7e1c05315fdd" dependencies = [ "bytes", "rand", @@ -2895,6 +2899,7 @@ dependencies = [ "libc", "once_cell", "socket2", + "tracing", "windows-sys 0.52.0", ] @@ -2969,9 +2974,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.10.5" +version = "1.10.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b91213439dad192326a0d7c6ee3955910425f441d7038e0d6933b0aec5c4517f" +checksum = "4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619" dependencies = [ "aho-corasick", "memchr", @@ -3029,7 +3034,7 @@ dependencies = [ "quinn", "rustls 0.23.12", "rustls-native-certs 0.7.1", - "rustls-pemfile 2.1.2", + "rustls-pemfile 2.1.3", "rustls-pki-types", "serde", "serde_json", @@ -3117,9 +3122,9 @@ checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" [[package]] name = "rustc-hash" -version = "1.1.0" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" +checksum = "583034fd73374156e66797ed8e5b0d5690409c9226b22d87cb7f19821c05d152" [[package]] name = "rustc_version" @@ -3188,7 +3193,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a88d6d420651b496bdd98684116959239430022a115c1240e6c3993be0b15fba" dependencies = [ "openssl-probe", - "rustls-pemfile 2.1.2", + "rustls-pemfile 2.1.3", "rustls-pki-types", "schannel", "security-framework", @@ -3205,9 +3210,9 @@ dependencies = [ [[package]] name = "rustls-pemfile" -version = "2.1.2" +version = "2.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "29993a25686778eb88d4189742cd713c9bce943bc54251a33509dc63cbacf73d" +checksum = "196fe16b00e106300d3e45ecfcb764fa292a535d7326a29a5875c579c7417425" dependencies = [ "base64 0.22.1", "rustls-pki-types", @@ -3356,9 +3361,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.121" +version = "1.0.122" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ab380d7d9f22ef3f21ad3e6c1ebe8e4fc7a2000ccba2e4d71fc96f15b2cb609" +checksum = "784b6203951c57ff748476b126ccb5e8e2959a5c19e5c617ab1956be3dbc68da" dependencies = [ "itoa", "memchr", @@ -3585,12 +3590,13 @@ checksum = "a7065abeca94b6a8a577f9bd45aa0867a2238b74e8eb67cf10d492bc39351394" [[package]] name = "tempfile" -version = "3.10.1" +version = "3.11.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "85b77fafb263dd9d05cbeac119526425676db3784113aa9295c88498cbf8bff1" +checksum = "b8fcd239983515c23a32fb82099f97d0b11b8c72f654ed659363a95c3dad7a53" dependencies = [ "cfg-if", "fastrand 2.1.0", + "once_cell", "rustix", "windows-sys 0.52.0", ] @@ -4119,11 +4125,11 @@ checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" [[package]] name = "winapi-util" -version = "0.1.8" +version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4d4cc384e1e73b93bafa6fb4f1df8c41695c8a91cf9c4c64358067d15a7b6c6b" +checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" dependencies = [ - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -4159,6 +4165,15 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets 0.52.6", +] + [[package]] name = "windows-targets" version = "0.48.5" @@ -4311,6 +4326,7 @@ version = "0.7.35" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" dependencies = [ + "byteorder", "zerocopy-derive", ] @@ -4346,7 +4362,7 @@ version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fcf2b778a664581e31e389454a7072dab1647606d44f7feea22cd5abb9c9f3f9" dependencies = [ - "zstd-safe 7.2.0", + "zstd-safe 7.2.1", ] [[package]] @@ -4361,9 +4377,9 @@ dependencies = [ [[package]] name = "zstd-safe" -version = "7.2.0" +version = "7.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa556e971e7b568dc775c136fc9de8c779b1c2fc3a63defaafadffdbd3181afa" +checksum = "54a3ab4db68cea366acc5c897c7b4d4d1b8994a9cd6e6f841f8964566a419059" dependencies = [ "zstd-sys", ] diff --git a/datafusion/common/src/utils/expr.rs b/datafusion/common/src/utils/expr.rs new file mode 100644 index 000000000000..0fe4546b8538 --- /dev/null +++ b/datafusion/common/src/utils/expr.rs @@ -0,0 +1,24 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Expression utilities + +use crate::ScalarValue; + +/// The value to which `COUNT(*)` is expanded to in +/// `COUNT()` expressions +pub const COUNT_STAR_EXPANSION: ScalarValue = ScalarValue::Int64(Some(1)); diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs index 8b025255f5df..58dc8f40b577 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ -17,6 +17,7 @@ //! This module provides the bisect function, which implements binary search. 
+pub mod expr; pub mod memory; pub mod proxy; diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs deleted file mode 100644 index a0f6f6a65b1f..000000000000 --- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs +++ /dev/null @@ -1,657 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements.  See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership.  The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License.  You may obtain a copy of the License at -// -//   http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied.  See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Utilizing exact statistics from sources to avoid scanning data -use std::sync::Arc; - -use crate::config::ConfigOptions; -use crate::error::Result; -use crate::physical_plan::aggregates::AggregateExec; -use crate::physical_plan::projection::ProjectionExec; -use crate::physical_plan::{expressions, AggregateExpr, ExecutionPlan, Statistics}; -use crate::scalar::ScalarValue; - -use datafusion_common::stats::Precision; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_expr::utils::COUNT_STAR_EXPANSION; -use datafusion_physical_optimizer::PhysicalOptimizerRule; -use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; -use datafusion_physical_plan::udaf::AggregateFunctionExpr; - -/// Optimizer that uses available statistics for aggregate functions -#[derive(Default)] -pub struct AggregateStatistics {} - -impl AggregateStatistics { - #[allow(missing_docs)] - pub fn new() -> Self { - Self {} - } -} - -impl PhysicalOptimizerRule for AggregateStatistics { - fn optimize( - &self, - plan: Arc<dyn ExecutionPlan>, - _config: &ConfigOptions, - ) -> Result<Arc<dyn ExecutionPlan>> { - if let Some(partial_agg_exec) = take_optimizable(&*plan) { - let partial_agg_exec = partial_agg_exec - .as_any() - .downcast_ref::<AggregateExec>() - .expect("take_optimizable() ensures that this is a AggregateExec"); - let stats = partial_agg_exec.input().statistics()?; - let mut projections = vec![]; - for expr in partial_agg_exec.aggr_expr() { - if let Some((non_null_rows, name)) = - take_optimizable_column_and_table_count(&**expr, &stats) - { - projections.push((expressions::lit(non_null_rows), name.to_owned())); - } else if let Some((min, name)) = take_optimizable_min(&**expr, &stats) { - projections.push((expressions::lit(min), name.to_owned())); - } else if let Some((max, name)) = take_optimizable_max(&**expr, &stats) { - projections.push((expressions::lit(max), name.to_owned())); - } else { - // TODO: we need all aggr_expr to be resolved (cf TODO fullres) - break; - } - } - - // TODO fullres: use statistics even if not all aggr_expr could be resolved - if projections.len() == partial_agg_exec.aggr_expr().len() { - // input can be entirely removed - Ok(Arc::new(ProjectionExec::try_new( - projections, - Arc::new(PlaceholderRowExec::new(plan.schema())), - )?)) - } else { - plan.map_children(|child| { - self.optimize(child, _config).map(Transformed::yes) - }) - .data() - } - } else 
{ - plan.map_children(|child| self.optimize(child, _config).map(Transformed::yes)) .data() } - } - - fn name(&self) -> &str { - "aggregate_statistics" - } - - /// This rule will change the nullable properties of the schema, disable the schema check. - fn schema_check(&self) -> bool { - false - } -} - -/// assert if the node passed as argument is a final `AggregateExec` node that can be optimized: -/// - its child (with possible intermediate layers) is a partial `AggregateExec` node -/// - they both have no grouping expression -/// -/// If this is the case, return a ref to the partial `AggregateExec`, else `None`. -/// We would have preferred to return a casted ref to AggregateExec but the recursion requires -/// the `ExecutionPlan.children()` method that returns an owned reference. -fn take_optimizable(node: &dyn ExecutionPlan) -> Option<Arc<dyn ExecutionPlan>> { - if let Some(final_agg_exec) = node.as_any().downcast_ref::<AggregateExec>() { - if !final_agg_exec.mode().is_first_stage() - && final_agg_exec.group_expr().is_empty() - { - let mut child = Arc::clone(final_agg_exec.input()); - loop { - if let Some(partial_agg_exec) = - child.as_any().downcast_ref::<AggregateExec>() - { - if partial_agg_exec.mode().is_first_stage() - && partial_agg_exec.group_expr().is_empty() - && partial_agg_exec.filter_expr().iter().all(|e| e.is_none()) - { - return Some(child); - } - } - if let [childrens_child] = child.children().as_slice() { - child = Arc::clone(childrens_child); - } else { - break; - } - } - } - } - None -} - -/// If this agg_expr is a count that can be exactly derived from the statistics, return it. -fn take_optimizable_column_and_table_count( - agg_expr: &dyn AggregateExpr, - stats: &Statistics, -) -> Option<(ScalarValue, String)> { - let col_stats = &stats.column_statistics; - if is_non_distinct_count(agg_expr) { - if let Precision::Exact(num_rows) = stats.num_rows { - let exprs = agg_expr.expressions(); - if exprs.len() == 1 { - // TODO optimize with exprs other than Column - if let Some(col_expr) = - exprs[0].as_any().downcast_ref::<expressions::Column>() - { - let current_val = &col_stats[col_expr.index()].null_count; - if let &Precision::Exact(val) = current_val { - return Some(( - ScalarValue::Int64(Some((num_rows - val) as i64)), - agg_expr.name().to_string(), - )); - } - } else if let Some(lit_expr) = - exprs[0].as_any().downcast_ref::<expressions::Literal>() - { - if lit_expr.value() == &COUNT_STAR_EXPANSION { - return Some(( - ScalarValue::Int64(Some(num_rows as i64)), - agg_expr.name().to_string(), - )); - } - } - } - } - } - None -} - -/// If this agg_expr is a min that is exactly defined in the statistics, return it.
-fn take_optimizable_min( - agg_expr: &dyn AggregateExpr, - stats: &Statistics, -) -> Option<(ScalarValue, String)> { - if let Precision::Exact(num_rows) = &stats.num_rows { - match *num_rows { - 0 => { - // MIN/MAX with 0 rows is always null - if is_min(agg_expr) { - if let Ok(min_data_type) = - ScalarValue::try_from(agg_expr.field().unwrap().data_type()) - { - return Some((min_data_type, agg_expr.name().to_string())); - } - } - } - value if value > 0 => { - let col_stats = &stats.column_statistics; - if is_min(agg_expr) { - let exprs = agg_expr.expressions(); - if exprs.len() == 1 { - // TODO optimize with exprs other than Column - if let Some(col_expr) = - exprs[0].as_any().downcast_ref::<expressions::Column>() - { - if let Precision::Exact(val) = - &col_stats[col_expr.index()].min_value - { - if !val.is_null() { - return Some(( - val.clone(), - agg_expr.name().to_string(), - )); - } - } - } - } - } - } - _ => {} - } - } - None -} - -/// If this agg_expr is a max that is exactly defined in the statistics, return it. -fn take_optimizable_max( - agg_expr: &dyn AggregateExpr, - stats: &Statistics, -) -> Option<(ScalarValue, String)> { - if let Precision::Exact(num_rows) = &stats.num_rows { - match *num_rows { - 0 => { - // MIN/MAX with 0 rows is always null - if is_max(agg_expr) { - if let Ok(max_data_type) = - ScalarValue::try_from(agg_expr.field().unwrap().data_type()) - { - return Some((max_data_type, agg_expr.name().to_string())); - } - } - } - value if value > 0 => { - let col_stats = &stats.column_statistics; - if is_max(agg_expr) { - let exprs = agg_expr.expressions(); - if exprs.len() == 1 { - // TODO optimize with exprs other than Column - if let Some(col_expr) = - exprs[0].as_any().downcast_ref::<expressions::Column>() - { - if let Precision::Exact(val) = - &col_stats[col_expr.index()].max_value - { - if !val.is_null() { - return Some(( - val.clone(), - agg_expr.name().to_string(), - )); - } - } - } - } - } - } - _ => {} - } - } - None -} - -// TODO: Move this check into AggregateUDFImpl -// https://github.com/apache/datafusion/issues/11153 -fn is_non_distinct_count(agg_expr: &dyn AggregateExpr) -> bool { - if let Some(agg_expr) = agg_expr.as_any().downcast_ref::<AggregateFunctionExpr>() { - if agg_expr.fun().name() == "count" && !agg_expr.is_distinct() { - return true; - } - } - false -} - -// TODO: Move this check into AggregateUDFImpl -// https://github.com/apache/datafusion/issues/11153 -fn is_min(agg_expr: &dyn AggregateExpr) -> bool { - if let Some(agg_expr) = agg_expr.as_any().downcast_ref::<AggregateFunctionExpr>() { - if agg_expr.fun().name().to_lowercase() == "min" { - return true; - } - } - false -} - -// TODO: Move this check into AggregateUDFImpl -// https://github.com/apache/datafusion/issues/11153 -fn is_max(agg_expr: &dyn AggregateExpr) -> bool { - if let Some(agg_expr) = agg_expr.as_any().downcast_ref::<AggregateFunctionExpr>() { - if agg_expr.fun().name().to_lowercase() == "max" { - return true; - } - } - false -} - -#[cfg(test)] -pub(crate) mod tests { - use super::*; - - use crate::logical_expr::Operator; - use crate::physical_plan::aggregates::PhysicalGroupBy; - use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; - use crate::physical_plan::common; - use crate::physical_plan::filter::FilterExec; - use crate::physical_plan::memory::MemoryExec; - use crate::prelude::SessionContext; - - use arrow::array::Int32Array; - use arrow::datatypes::{DataType, Field, Schema}; - use arrow::record_batch::RecordBatch; - use datafusion_common::cast::as_int64_array; - use datafusion_functions_aggregate::count::count_udaf; - use 
datafusion_physical_expr::expressions::cast; - use datafusion_physical_expr::PhysicalExpr; - use datafusion_physical_expr_common::aggregate::AggregateExprBuilder; - use datafusion_physical_plan::aggregates::AggregateMode; - - /// Mock data using a MemoryExec which has an exact count statistic - fn mock_data() -> Result<Arc<dyn ExecutionPlan>> { - let schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Int32, true), - ])); - - let batch = RecordBatch::try_new( - Arc::clone(&schema), - vec![ - Arc::new(Int32Array::from(vec![Some(1), Some(2), None])), - Arc::new(Int32Array::from(vec![Some(4), None, Some(6)])), - ], - )?; - - Ok(Arc::new(MemoryExec::try_new( - &[vec![batch]], - Arc::clone(&schema), - None, - )?)) - } - - /// Checks that the count optimization was applied and we still get the right result - async fn assert_count_optim_success( - plan: AggregateExec, - agg: TestAggregate, - ) -> Result<()> { - let session_ctx = SessionContext::new(); - let state = session_ctx.state(); - let plan: Arc<dyn ExecutionPlan> = Arc::new(plan); - - let optimized = AggregateStatistics::new() - .optimize(Arc::clone(&plan), state.config_options())?; - - // A ProjectionExec is a sign that the count optimization was applied - assert!(optimized.as_any().is::<ProjectionExec>()); - - // run both the optimized and nonoptimized plan - let optimized_result = - common::collect(optimized.execute(0, session_ctx.task_ctx())?).await?; - let nonoptimized_result = - common::collect(plan.execute(0, session_ctx.task_ctx())?).await?; - assert_eq!(optimized_result.len(), nonoptimized_result.len()); - - // and validate the results are the same and expected - assert_eq!(optimized_result.len(), 1); - check_batch(optimized_result.into_iter().next().unwrap(), &agg); - // check the non optimized one too to ensure types and names remain the same - assert_eq!(nonoptimized_result.len(), 1); - check_batch(nonoptimized_result.into_iter().next().unwrap(), &agg); - - Ok(()) - } - - fn check_batch(batch: RecordBatch, agg: &TestAggregate) { - let schema = batch.schema(); - let fields = schema.fields(); - assert_eq!(fields.len(), 1); - - let field = &fields[0]; - assert_eq!(field.name(), agg.column_name()); - assert_eq!(field.data_type(), &DataType::Int64); - // note that nullabiolity differs - - assert_eq!( - as_int64_array(batch.column(0)).unwrap().values(), - &[agg.expected_count()] - ); - } - - /// Describe the type of aggregate being tested - pub(crate) enum TestAggregate { - /// Testing COUNT(*) type aggregates - CountStar, - - /// Testing for COUNT(column) aggregate - ColumnA(Arc<Schema>), - } - - impl TestAggregate { - pub(crate) fn new_count_star() -> Self { - Self::CountStar - } - - fn new_count_column(schema: &Arc<Schema>) -> Self { - Self::ColumnA(schema.clone()) - } - - // Return appropriate expr depending if COUNT is for col or table (*) - pub(crate) fn count_expr(&self, schema: &Schema) -> Arc<dyn AggregateExpr> { - AggregateExprBuilder::new(count_udaf(), vec![self.column()]) - .schema(Arc::new(schema.clone())) - .name(self.column_name()) - .build() - .unwrap() - } - - /// what argument would this aggregate need in the plan? - fn column(&self) -> Arc<dyn PhysicalExpr> { - match self { - Self::CountStar => expressions::lit(COUNT_STAR_EXPANSION), - Self::ColumnA(s) => expressions::col("a", s).unwrap(), - } - } - - /// What name would this aggregate produce in a plan? - fn column_name(&self) -> &'static str { - match self { - Self::CountStar => "COUNT(*)", - Self::ColumnA(_) => "COUNT(a)", - } - } - - /// What is the expected count? 
- fn expected_count(&self) -> i64 { - match self { - TestAggregate::CountStar => 3, - TestAggregate::ColumnA(_) => 2, - } - } - } - - #[tokio::test] - async fn test_count_partial_direct_child() -> Result<()> { - // basic test case with the aggregation applied on a source with exact statistics - let source = mock_data()?; - let schema = source.schema(); - let agg = TestAggregate::new_count_star(); - - let partial_agg = AggregateExec::try_new( - AggregateMode::Partial, - PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], - vec![None], - source, - Arc::clone(&schema), - )?; - - let final_agg = AggregateExec::try_new( - AggregateMode::Final, - PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], - vec![None], - Arc::new(partial_agg), - Arc::clone(&schema), - )?; - - assert_count_optim_success(final_agg, agg).await?; - - Ok(()) - } - - #[tokio::test] - async fn test_count_partial_with_nulls_direct_child() -> Result<()> { - // basic test case with the aggregation applied on a source with exact statistics - let source = mock_data()?; - let schema = source.schema(); - let agg = TestAggregate::new_count_column(&schema); - - let partial_agg = AggregateExec::try_new( - AggregateMode::Partial, - PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], - vec![None], - source, - Arc::clone(&schema), - )?; - - let final_agg = AggregateExec::try_new( - AggregateMode::Final, - PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], - vec![None], - Arc::new(partial_agg), - Arc::clone(&schema), - )?; - - assert_count_optim_success(final_agg, agg).await?; - - Ok(()) - } - - #[tokio::test] - async fn test_count_partial_indirect_child() -> Result<()> { - let source = mock_data()?; - let schema = source.schema(); - let agg = TestAggregate::new_count_star(); - - let partial_agg = AggregateExec::try_new( - AggregateMode::Partial, - PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], - vec![None], - source, - Arc::clone(&schema), - )?; - - // We introduce an intermediate optimization step between the partial and final aggregtator - let coalesce = CoalescePartitionsExec::new(Arc::new(partial_agg)); - - let final_agg = AggregateExec::try_new( - AggregateMode::Final, - PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], - vec![None], - Arc::new(coalesce), - Arc::clone(&schema), - )?; - - assert_count_optim_success(final_agg, agg).await?; - - Ok(()) - } - - #[tokio::test] - async fn test_count_partial_with_nulls_indirect_child() -> Result<()> { - let source = mock_data()?; - let schema = source.schema(); - let agg = TestAggregate::new_count_column(&schema); - - let partial_agg = AggregateExec::try_new( - AggregateMode::Partial, - PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], - vec![None], - source, - Arc::clone(&schema), - )?; - - // We introduce an intermediate optimization step between the partial and final aggregtator - let coalesce = CoalescePartitionsExec::new(Arc::new(partial_agg)); - - let final_agg = AggregateExec::try_new( - AggregateMode::Final, - PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], - vec![None], - Arc::new(coalesce), - Arc::clone(&schema), - )?; - - assert_count_optim_success(final_agg, agg).await?; - - Ok(()) - } - - #[tokio::test] - async fn test_count_inexact_stat() -> Result<()> { - let source = mock_data()?; - let schema = source.schema(); - let agg = TestAggregate::new_count_star(); - - // adding a filter makes the statistics inexact - let filter = Arc::new(FilterExec::try_new( - expressions::binary( - 
expressions::col("a", &schema)?, - Operator::Gt, - cast(expressions::lit(1u32), &schema, DataType::Int32)?, - &schema, - )?, - source, - )?); - - let partial_agg = AggregateExec::try_new( - AggregateMode::Partial, - PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], - vec![None], - filter, - Arc::clone(&schema), - )?; - - let final_agg = AggregateExec::try_new( - AggregateMode::Final, - PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], - vec![None], - Arc::new(partial_agg), - Arc::clone(&schema), - )?; - - let conf = ConfigOptions::new(); - let optimized = - AggregateStatistics::new().optimize(Arc::new(final_agg), &conf)?; - - // check that the original ExecutionPlan was not replaced - assert!(optimized.as_any().is::()); - - Ok(()) - } - - #[tokio::test] - async fn test_count_with_nulls_inexact_stat() -> Result<()> { - let source = mock_data()?; - let schema = source.schema(); - let agg = TestAggregate::new_count_column(&schema); - - // adding a filter makes the statistics inexact - let filter = Arc::new(FilterExec::try_new( - expressions::binary( - expressions::col("a", &schema)?, - Operator::Gt, - cast(expressions::lit(1u32), &schema, DataType::Int32)?, - &schema, - )?, - source, - )?); - - let partial_agg = AggregateExec::try_new( - AggregateMode::Partial, - PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], - vec![None], - filter, - Arc::clone(&schema), - )?; - - let final_agg = AggregateExec::try_new( - AggregateMode::Final, - PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], - vec![None], - Arc::new(partial_agg), - Arc::clone(&schema), - )?; - - let conf = ConfigOptions::new(); - let optimized = - AggregateStatistics::new().optimize(Arc::new(final_agg), &conf)?; - - // check that the original ExecutionPlan was not replaced - assert!(optimized.as_any().is::()); - - Ok(()) - } -} diff --git a/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs b/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs index b5d3f432d84d..b181ad9051ed 100644 --- a/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs +++ b/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs @@ -193,7 +193,6 @@ impl PhysicalOptimizerRule for LimitedDistinctAggregation { mod tests { use super::*; - use crate::physical_optimizer::aggregate_statistics::tests::TestAggregate; use crate::physical_optimizer::enforce_distribution::tests::{ parquet_exec_with_sort, schema, trim_plan_display, }; @@ -201,6 +200,7 @@ mod tests { use crate::physical_plan::collect; use crate::physical_plan::memory::MemoryExec; use crate::prelude::SessionContext; + use crate::test_util::TestAggregate; use arrow::array::Int32Array; use arrow::compute::SortOptions; diff --git a/datafusion/core/src/physical_optimizer/mod.rs b/datafusion/core/src/physical_optimizer/mod.rs index 01ddab3ec97d..9291d0b84865 100644 --- a/datafusion/core/src/physical_optimizer/mod.rs +++ b/datafusion/core/src/physical_optimizer/mod.rs @@ -21,7 +21,6 @@ //! "Repartition" or "Sortedness" //! //! 
[`ExecutionPlan`]: crate::physical_plan::ExecutionPlan -pub mod aggregate_statistics; pub mod coalesce_batches; pub mod combine_partial_final_agg; pub mod enforce_distribution; diff --git a/datafusion/core/src/test_util/mod.rs b/datafusion/core/src/test_util/mod.rs index 042febf32fd1..6eb82dece31c 100644 --- a/datafusion/core/src/test_util/mod.rs +++ b/datafusion/core/src/test_util/mod.rs @@ -45,11 +45,16 @@ use crate::prelude::{CsvReadOptions, SessionContext}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_common::TableReference; +use datafusion_expr::utils::COUNT_STAR_EXPANSION; use datafusion_expr::{CreateExternalTable, Expr, TableType}; -use datafusion_physical_expr::EquivalenceProperties; +use datafusion_functions_aggregate::count::count_udaf; +use datafusion_physical_expr::{ + expressions, AggregateExpr, EquivalenceProperties, PhysicalExpr, +}; use async_trait::async_trait; use datafusion_catalog::Session; +use datafusion_physical_expr_common::aggregate::AggregateExprBuilder; use futures::Stream; use tempfile::TempDir; // backwards compatibility @@ -402,3 +407,57 @@ pub fn bounded_stream(batch: RecordBatch, limit: usize) -> SendableRecordBatchSt batch, }) } + +/// Describe the type of aggregate being tested +pub enum TestAggregate { + /// Testing COUNT(*) type aggregates + CountStar, + + /// Testing for COUNT(column) aggregate + ColumnA(Arc<Schema>), +} + +impl TestAggregate { + /// Create a new COUNT(*) aggregate + pub fn new_count_star() -> Self { + Self::CountStar + } + + /// Create a new COUNT(column) aggregate + pub fn new_count_column(schema: &Arc<Schema>) -> Self { + Self::ColumnA(schema.clone()) + } + + /// Return the appropriate expr depending on whether COUNT is for a column or the whole table (*) + pub fn count_expr(&self, schema: &Schema) -> Arc<dyn AggregateExpr> { + AggregateExprBuilder::new(count_udaf(), vec![self.column()]) + .schema(Arc::new(schema.clone())) + .name(self.column_name()) + .build() + .unwrap() + } + + /// What argument would this aggregate need in the plan? + fn column(&self) -> Arc<dyn PhysicalExpr> { + match self { + Self::CountStar => expressions::lit(COUNT_STAR_EXPANSION), + Self::ColumnA(s) => expressions::col("a", s).unwrap(), + } + } + + /// What name would this aggregate produce in a plan? + pub fn column_name(&self) -> &'static str { + match self { + Self::CountStar => "COUNT(*)", + Self::ColumnA(_) => "COUNT(a)", + } + } + + /// What is the expected count? + pub fn expected_count(&self) -> i64 { + match self { + TestAggregate::CountStar => 3, + TestAggregate::ColumnA(_) => 2, + } + } +} diff --git a/datafusion/core/tests/physical_optimizer_integration.rs b/datafusion/core/tests/physical_optimizer_integration.rs new file mode 100644 index 000000000000..bbf4dcd2b799 --- /dev/null +++ b/datafusion/core/tests/physical_optimizer_integration.rs @@ -0,0 +1,325 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements.  See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership.  The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License.  You may obtain a copy of the License at +// +//   http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Tests for the physical optimizer + +use datafusion_common::config::ConfigOptions; +use datafusion_physical_optimizer::aggregate_statistics::AggregateStatistics; +use datafusion_physical_optimizer::PhysicalOptimizerRule; +use datafusion_physical_plan::aggregates::AggregateExec; +use datafusion_physical_plan::projection::ProjectionExec; +use datafusion_physical_plan::ExecutionPlan; +use std::sync::Arc; + +use datafusion::error::Result; +use datafusion::logical_expr::Operator; +use datafusion::prelude::SessionContext; +use datafusion::test_util::TestAggregate; +use datafusion_physical_plan::aggregates::PhysicalGroupBy; +use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; +use datafusion_physical_plan::common; +use datafusion_physical_plan::filter::FilterExec; +use datafusion_physical_plan::memory::MemoryExec; + +use arrow::array::Int32Array; +use arrow::datatypes::{DataType, Field, Schema}; +use arrow::record_batch::RecordBatch; +use datafusion_common::cast::as_int64_array; +use datafusion_physical_expr::expressions::{self, cast}; +use datafusion_physical_plan::aggregates::AggregateMode; + +/// Mock data using a MemoryExec which has an exact count statistic +fn mock_data() -> Result<Arc<dyn ExecutionPlan>> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ])); + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![Some(1), Some(2), None])), + Arc::new(Int32Array::from(vec![Some(4), None, Some(6)])), + ], + )?; + + Ok(Arc::new(MemoryExec::try_new( + &[vec![batch]], + Arc::clone(&schema), + None, + )?)) +} + +/// Checks that the count optimization was applied and we still get the right result +async fn assert_count_optim_success( + plan: AggregateExec, + agg: TestAggregate, +) -> Result<()> { + let session_ctx = SessionContext::new(); + let state = session_ctx.state(); + let plan: Arc<dyn ExecutionPlan> = Arc::new(plan); + + let optimized = + AggregateStatistics::new().optimize(Arc::clone(&plan), state.config_options())?; + + // A ProjectionExec is a sign that the count optimization was applied + assert!(optimized.as_any().is::<ProjectionExec>()); + + // run both the optimized and nonoptimized plan + let optimized_result = + common::collect(optimized.execute(0, session_ctx.task_ctx())?).await?; + let nonoptimized_result = + common::collect(plan.execute(0, session_ctx.task_ctx())?).await?; + assert_eq!(optimized_result.len(), nonoptimized_result.len()); + + // and validate the results are the same and expected + assert_eq!(optimized_result.len(), 1); + check_batch(optimized_result.into_iter().next().unwrap(), &agg); + // check the non optimized one too to ensure types and names remain the same + assert_eq!(nonoptimized_result.len(), 1); + check_batch(nonoptimized_result.into_iter().next().unwrap(), &agg); + + Ok(()) +} + +fn check_batch(batch: RecordBatch, agg: &TestAggregate) { + let schema = batch.schema(); + let fields = schema.fields(); + assert_eq!(fields.len(), 1); + + let field = &fields[0]; + assert_eq!(field.name(), agg.column_name()); + assert_eq!(field.data_type(), &DataType::Int64); + // note that nullability differs + + assert_eq!( + as_int64_array(batch.column(0)).unwrap().values(), + &[agg.expected_count()] + ); +} + +#[tokio::test] +async fn test_count_partial_direct_child() -> Result<()> { + // basic test case with the aggregation applied on a source with exact statistics + 
let source = mock_data()?; + let schema = source.schema(); + let agg = TestAggregate::new_count_star(); + + let partial_agg = AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::default(), + vec![agg.count_expr(&schema)], + vec![None], + source, + Arc::clone(&schema), + )?; + + let final_agg = AggregateExec::try_new( + AggregateMode::Final, + PhysicalGroupBy::default(), + vec![agg.count_expr(&schema)], + vec![None], + Arc::new(partial_agg), + Arc::clone(&schema), + )?; + + assert_count_optim_success(final_agg, agg).await?; + + Ok(()) +} + +#[tokio::test] +async fn test_count_partial_with_nulls_direct_child() -> Result<()> { + // basic test case with the aggregation applied on a source with exact statistics + let source = mock_data()?; + let schema = source.schema(); + let agg = TestAggregate::new_count_column(&schema); + + let partial_agg = AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::default(), + vec![agg.count_expr(&schema)], + vec![None], + source, + Arc::clone(&schema), + )?; + + let final_agg = AggregateExec::try_new( + AggregateMode::Final, + PhysicalGroupBy::default(), + vec![agg.count_expr(&schema)], + vec![None], + Arc::new(partial_agg), + Arc::clone(&schema), + )?; + + assert_count_optim_success(final_agg, agg).await?; + + Ok(()) +} + +#[tokio::test] +async fn test_count_partial_indirect_child() -> Result<()> { + let source = mock_data()?; + let schema = source.schema(); + let agg = TestAggregate::new_count_star(); + + let partial_agg = AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::default(), + vec![agg.count_expr(&schema)], + vec![None], + source, + Arc::clone(&schema), + )?; + + // We introduce an intermediate optimization step between the partial and final aggregator + let coalesce = CoalescePartitionsExec::new(Arc::new(partial_agg)); + + let final_agg = AggregateExec::try_new( + AggregateMode::Final, + PhysicalGroupBy::default(), + vec![agg.count_expr(&schema)], + vec![None], + Arc::new(coalesce), + Arc::clone(&schema), + )?; + + assert_count_optim_success(final_agg, agg).await?; + + Ok(()) +} + +#[tokio::test] +async fn test_count_partial_with_nulls_indirect_child() -> Result<()> { + let source = mock_data()?; + let schema = source.schema(); + let agg = TestAggregate::new_count_column(&schema); + + let partial_agg = AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::default(), + vec![agg.count_expr(&schema)], + vec![None], + source, + Arc::clone(&schema), + )?; + + // We introduce an intermediate optimization step between the partial and final aggregator + let coalesce = CoalescePartitionsExec::new(Arc::new(partial_agg)); + + let final_agg = AggregateExec::try_new( + AggregateMode::Final, + PhysicalGroupBy::default(), + vec![agg.count_expr(&schema)], + vec![None], + Arc::new(coalesce), + Arc::clone(&schema), + )?; + + assert_count_optim_success(final_agg, agg).await?; + + Ok(()) +} + +#[tokio::test] +async fn test_count_inexact_stat() -> Result<()> { + let source = mock_data()?; + let schema = source.schema(); + let agg = TestAggregate::new_count_star(); + + // adding a filter makes the statistics inexact + let filter = Arc::new(FilterExec::try_new( + expressions::binary( + expressions::col("a", &schema)?, + Operator::Gt, + cast(expressions::lit(1u32), &schema, DataType::Int32)?, + &schema, + )?, + source, + )?); + + let partial_agg = AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::default(), + vec![agg.count_expr(&schema)], + vec![None], + filter, + Arc::clone(&schema), 
+ )?; + + let final_agg = AggregateExec::try_new( + AggregateMode::Final, + PhysicalGroupBy::default(), + vec![agg.count_expr(&schema)], + vec![None], + Arc::new(partial_agg), + Arc::clone(&schema), + )?; + + let conf = ConfigOptions::new(); + let optimized = AggregateStatistics::new().optimize(Arc::new(final_agg), &conf)?; + + // check that the original ExecutionPlan was not replaced + assert!(optimized.as_any().is::<AggregateExec>()); + + Ok(()) +} + +#[tokio::test] +async fn test_count_with_nulls_inexact_stat() -> Result<()> { + let source = mock_data()?; + let schema = source.schema(); + let agg = TestAggregate::new_count_column(&schema); + + // adding a filter makes the statistics inexact + let filter = Arc::new(FilterExec::try_new( + expressions::binary( + expressions::col("a", &schema)?, + Operator::Gt, + cast(expressions::lit(1u32), &schema, DataType::Int32)?, + &schema, + )?, + source, + )?); + + let partial_agg = AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::default(), + vec![agg.count_expr(&schema)], + vec![None], + filter, + Arc::clone(&schema), + )?; + + let final_agg = AggregateExec::try_new( + AggregateMode::Final, + PhysicalGroupBy::default(), + vec![agg.count_expr(&schema)], + vec![None], + Arc::new(partial_agg), + Arc::clone(&schema), + )?; + + let conf = ConfigOptions::new(); + let optimized = AggregateStatistics::new().optimize(Arc::new(final_agg), &conf)?; + + // check that the original ExecutionPlan was not replaced + assert!(optimized.as_any().is::<AggregateExec>()); + + Ok(()) +} diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 683a8e170ed4..65a70b673266 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -35,14 +35,14 @@ use datafusion_common::tree_node::{ use datafusion_common::utils::get_at_indices; use datafusion_common::{ internal_err, plan_datafusion_err, plan_err, Column, DFSchema, DFSchemaRef, Result, - ScalarValue, TableReference, + TableReference, }; use sqlparser::ast::{ExceptSelectItem, ExcludeSelectItem, WildcardAdditionalOptions}; /// The value to which `COUNT(*)` is expanded to in /// `COUNT(<constant>)` expressions -pub const COUNT_STAR_EXPANSION: ScalarValue = ScalarValue::Int64(Some(1)); +pub use datafusion_common::utils::expr::COUNT_STAR_EXPANSION; /// Recursively walk a list of expression trees, collecting the unique set of columns /// referenced in the expression diff --git a/datafusion/physical-optimizer/src/aggregate_statistics.rs b/datafusion/physical-optimizer/src/aggregate_statistics.rs new file mode 100644 index 000000000000..0ce92df393aa --- /dev/null +++ b/datafusion/physical-optimizer/src/aggregate_statistics.rs @@ -0,0 +1,298 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements.  See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership.  The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License.  You may obtain a copy of the License at +// +//   http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied.  See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Utilizing exact statistics from sources to avoid scanning data +use std::sync::Arc; + +use datafusion_common::config::ConfigOptions; +use datafusion_common::scalar::ScalarValue; +use datafusion_common::Result; +use datafusion_physical_plan::aggregates::AggregateExec; +use datafusion_physical_plan::projection::ProjectionExec; +use datafusion_physical_plan::{expressions, AggregateExpr, ExecutionPlan, Statistics}; + +use crate::PhysicalOptimizerRule; +use datafusion_common::stats::Precision; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::utils::expr::COUNT_STAR_EXPANSION; +use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; +use datafusion_physical_plan::udaf::AggregateFunctionExpr; + +/// Optimizer that uses available statistics for aggregate functions +#[derive(Default)] +pub struct AggregateStatistics {} + +impl AggregateStatistics { + #[allow(missing_docs)] + pub fn new() -> Self { + Self {} + } +} + +impl PhysicalOptimizerRule for AggregateStatistics { + fn optimize( + &self, + plan: Arc<dyn ExecutionPlan>, + _config: &ConfigOptions, + ) -> Result<Arc<dyn ExecutionPlan>> { + if let Some(partial_agg_exec) = take_optimizable(&*plan) { + let partial_agg_exec = partial_agg_exec + .as_any() + .downcast_ref::<AggregateExec>() + .expect("take_optimizable() ensures that this is an AggregateExec"); + let stats = partial_agg_exec.input().statistics()?; + let mut projections = vec![]; + for expr in partial_agg_exec.aggr_expr() { + if let Some((non_null_rows, name)) = + take_optimizable_column_and_table_count(&**expr, &stats) + { + projections.push((expressions::lit(non_null_rows), name.to_owned())); + } else if let Some((min, name)) = take_optimizable_min(&**expr, &stats) { + projections.push((expressions::lit(min), name.to_owned())); + } else if let Some((max, name)) = take_optimizable_max(&**expr, &stats) { + projections.push((expressions::lit(max), name.to_owned())); + } else { + // TODO: we need all aggr_expr to be resolved (cf TODO fullres) + break; + } + } + + // TODO fullres: use statistics even if not all aggr_expr could be resolved + if projections.len() == partial_agg_exec.aggr_expr().len() { + // input can be entirely removed + Ok(Arc::new(ProjectionExec::try_new( + projections, + Arc::new(PlaceholderRowExec::new(plan.schema())), + )?)) + } else { + plan.map_children(|child| { + self.optimize(child, _config).map(Transformed::yes) + }) + .data() + } + } else { + plan.map_children(|child| self.optimize(child, _config).map(Transformed::yes)) + .data() + } + } + + fn name(&self) -> &str { + "aggregate_statistics" + } + + /// This rule will change the nullable properties of the schema, so disable the schema check. + fn schema_check(&self) -> bool { + false + } +} + +/// Check whether the node passed as argument is a final `AggregateExec` node that can be optimized: +/// - its child (with possible intermediate layers) is a partial `AggregateExec` node +/// - they both have no grouping expression +/// +/// If this is the case, return a ref to the partial `AggregateExec`, else `None`. +/// We would have preferred to return a cast ref to AggregateExec but the recursion requires +/// the `ExecutionPlan.children()` method that returns an owned reference. 
+fn take_optimizable(node: &dyn ExecutionPlan) -> Option<Arc<dyn ExecutionPlan>> { + if let Some(final_agg_exec) = node.as_any().downcast_ref::<AggregateExec>() { + if !final_agg_exec.mode().is_first_stage() + && final_agg_exec.group_expr().is_empty() + { + let mut child = Arc::clone(final_agg_exec.input()); + loop { + if let Some(partial_agg_exec) = + child.as_any().downcast_ref::<AggregateExec>() + { + if partial_agg_exec.mode().is_first_stage() + && partial_agg_exec.group_expr().is_empty() + && partial_agg_exec.filter_expr().iter().all(|e| e.is_none()) + { + return Some(child); + } + } + if let [childrens_child] = child.children().as_slice() { + child = Arc::clone(childrens_child); + } else { + break; + } + } + } + } + None +} + +/// If this agg_expr is a count that can be exactly derived from the statistics, return it. +fn take_optimizable_column_and_table_count( + agg_expr: &dyn AggregateExpr, + stats: &Statistics, +) -> Option<(ScalarValue, String)> { + let col_stats = &stats.column_statistics; + if is_non_distinct_count(agg_expr) { + if let Precision::Exact(num_rows) = stats.num_rows { + let exprs = agg_expr.expressions(); + if exprs.len() == 1 { + // TODO optimize with exprs other than Column + if let Some(col_expr) = + exprs[0].as_any().downcast_ref::<expressions::Column>() + { + let current_val = &col_stats[col_expr.index()].null_count; + if let &Precision::Exact(val) = current_val { + return Some(( + ScalarValue::Int64(Some((num_rows - val) as i64)), + agg_expr.name().to_string(), + )); + } + } else if let Some(lit_expr) = + exprs[0].as_any().downcast_ref::<expressions::Literal>() + { + if lit_expr.value() == &COUNT_STAR_EXPANSION { + return Some(( + ScalarValue::Int64(Some(num_rows as i64)), + agg_expr.name().to_string(), + )); + } + } + } + } + } + None +} + +/// If this agg_expr is a min that is exactly defined in the statistics, return it. +fn take_optimizable_min( + agg_expr: &dyn AggregateExpr, + stats: &Statistics, +) -> Option<(ScalarValue, String)> { + if let Precision::Exact(num_rows) = &stats.num_rows { + match *num_rows { + 0 => { + // MIN/MAX with 0 rows is always null + if is_min(agg_expr) { + if let Ok(min_data_type) = + ScalarValue::try_from(agg_expr.field().unwrap().data_type()) + { + return Some((min_data_type, agg_expr.name().to_string())); + } + } + } + value if value > 0 => { + let col_stats = &stats.column_statistics; + if is_min(agg_expr) { + let exprs = agg_expr.expressions(); + if exprs.len() == 1 { + // TODO optimize with exprs other than Column + if let Some(col_expr) = + exprs[0].as_any().downcast_ref::<expressions::Column>() + { + if let Precision::Exact(val) = + &col_stats[col_expr.index()].min_value + { + if !val.is_null() { + return Some(( + val.clone(), + agg_expr.name().to_string(), + )); + } + } + } + } + } + } + _ => {} + } + } + None +} + +/// If this agg_expr is a max that is exactly defined in the statistics, return it. 
+fn take_optimizable_max( + agg_expr: &dyn AggregateExpr, + stats: &Statistics, +) -> Option<(ScalarValue, String)> { + if let Precision::Exact(num_rows) = &stats.num_rows { + match *num_rows { + 0 => { + // MIN/MAX with 0 rows is always null + if is_max(agg_expr) { + if let Ok(max_data_type) = + ScalarValue::try_from(agg_expr.field().unwrap().data_type()) + { + return Some((max_data_type, agg_expr.name().to_string())); + } + } + } + value if value > 0 => { + let col_stats = &stats.column_statistics; + if is_max(agg_expr) { + let exprs = agg_expr.expressions(); + if exprs.len() == 1 { + // TODO optimize with exprs other than Column + if let Some(col_expr) = + exprs[0].as_any().downcast_ref::<expressions::Column>() + { + if let Precision::Exact(val) = + &col_stats[col_expr.index()].max_value + { + if !val.is_null() { + return Some(( + val.clone(), + agg_expr.name().to_string(), + )); + } + } + } + } + } + } + _ => {} + } + } + None +} + +// TODO: Move this check into AggregateUDFImpl +// https://github.com/apache/datafusion/issues/11153 +fn is_non_distinct_count(agg_expr: &dyn AggregateExpr) -> bool { + if let Some(agg_expr) = agg_expr.as_any().downcast_ref::<AggregateFunctionExpr>() { + if agg_expr.fun().name() == "count" && !agg_expr.is_distinct() { + return true; + } + } + false +} + +// TODO: Move this check into AggregateUDFImpl +// https://github.com/apache/datafusion/issues/11153 +fn is_min(agg_expr: &dyn AggregateExpr) -> bool { + if let Some(agg_expr) = agg_expr.as_any().downcast_ref::<AggregateFunctionExpr>() { + if agg_expr.fun().name().to_lowercase() == "min" { + return true; + } + } + false +} + +// TODO: Move this check into AggregateUDFImpl +// https://github.com/apache/datafusion/issues/11153 +fn is_max(agg_expr: &dyn AggregateExpr) -> bool { + if let Some(agg_expr) = agg_expr.as_any().downcast_ref::<AggregateFunctionExpr>() { + if agg_expr.fun().name().to_lowercase() == "max" { + return true; + } + } + false +} diff --git a/datafusion/physical-optimizer/src/lib.rs b/datafusion/physical-optimizer/src/lib.rs index 6b9df7cad5c8..8108493a0d3b 100644 --- a/datafusion/physical-optimizer/src/lib.rs +++ b/datafusion/physical-optimizer/src/lib.rs @@ -17,6 +17,7 @@ // Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] +pub mod aggregate_statistics; mod optimizer; pub mod output_requirements;