diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index ff3b66986ced..8c70b02a54fe 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -17,8 +17,9 @@ use std::sync::Arc; +use crate::fuzz_cases::aggregation_fuzzer::query_builder::QueryBuilder; use crate::fuzz_cases::aggregation_fuzzer::{ - AggregationFuzzerBuilder, DatasetGeneratorConfig, QueryBuilder, + AggregationFuzzerBuilder, DatasetGeneratorConfig, }; use arrow::array::{ @@ -85,6 +86,7 @@ async fn test_min() { .with_aggregate_function("min") // min works on all column types .with_aggregate_arguments(data_gen_config.all_columns()) + .with_dataset_sort_keys(data_gen_config.sort_keys_set.clone()) .set_group_by_columns(data_gen_config.all_columns()); AggregationFuzzerBuilder::from(data_gen_config) @@ -111,6 +113,7 @@ async fn test_first_val() { .with_table_name("fuzz_table") .with_aggregate_function("first_value") .with_aggregate_arguments(data_gen_config.all_columns()) + .with_dataset_sort_keys(data_gen_config.sort_keys_set.clone()) .set_group_by_columns(data_gen_config.all_columns()); AggregationFuzzerBuilder::from(data_gen_config) @@ -137,6 +140,7 @@ async fn test_last_val() { .with_table_name("fuzz_table") .with_aggregate_function("last_value") .with_aggregate_arguments(data_gen_config.all_columns()) + .with_dataset_sort_keys(data_gen_config.sort_keys_set.clone()) .set_group_by_columns(data_gen_config.all_columns()); AggregationFuzzerBuilder::from(data_gen_config) @@ -156,6 +160,7 @@ async fn test_max() { .with_aggregate_function("max") // max works on all column types .with_aggregate_arguments(data_gen_config.all_columns()) + .with_dataset_sort_keys(data_gen_config.sort_keys_set.clone()) .set_group_by_columns(data_gen_config.all_columns()); AggregationFuzzerBuilder::from(data_gen_config) @@ -176,6 +181,7 @@ async fn test_sum() { .with_distinct_aggregate_function("sum") // sum only works on numeric columns .with_aggregate_arguments(data_gen_config.numeric_columns()) + .with_dataset_sort_keys(data_gen_config.sort_keys_set.clone()) .set_group_by_columns(data_gen_config.all_columns()); AggregationFuzzerBuilder::from(data_gen_config) @@ -196,6 +202,7 @@ async fn test_count() { .with_distinct_aggregate_function("count") // count work for all arguments .with_aggregate_arguments(data_gen_config.all_columns()) + .with_dataset_sort_keys(data_gen_config.sort_keys_set.clone()) .set_group_by_columns(data_gen_config.all_columns()); AggregationFuzzerBuilder::from(data_gen_config) @@ -216,6 +223,7 @@ async fn test_median() { .with_distinct_aggregate_function("median") // median only works on numeric columns .with_aggregate_arguments(data_gen_config.numeric_columns()) + .with_dataset_sort_keys(data_gen_config.sort_keys_set.clone()) .set_group_by_columns(data_gen_config.all_columns()); AggregationFuzzerBuilder::from(data_gen_config) diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs index 53e9288ab4af..58688ce7ee8d 100644 --- a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs @@ -16,15 +16,14 @@ // under the License. 
use std::sync::Arc; -use std::{collections::HashSet, str::FromStr}; use arrow::array::RecordBatch; use arrow::util::pretty::pretty_format_batches; use datafusion_common::{DataFusionError, Result}; use datafusion_common_runtime::JoinSet; -use rand::seq::SliceRandom; use rand::{thread_rng, Rng}; +use crate::fuzz_cases::aggregation_fuzzer::query_builder::QueryBuilder; use crate::fuzz_cases::aggregation_fuzzer::{ check_equality_of_batches, context_generator::{SessionContextGenerator, SessionContextWithParams}, @@ -69,30 +68,16 @@ impl AggregationFuzzerBuilder { /// - 3 random queries /// - 3 random queries for each group by selected from the sort keys /// - 1 random query with no grouping - pub fn add_query_builder(mut self, mut query_builder: QueryBuilder) -> Self { - const NUM_QUERIES: usize = 3; - for _ in 0..NUM_QUERIES { - let sql = query_builder.generate_query(); - self.candidate_sqls.push(Arc::from(sql)); - } - // also add several queries limited to grouping on the group by columns only, if any - // So if the data is sorted on `a,b` only group by `a,b` or`a` or `b` - if let Some(data_gen_config) = &self.data_gen_config { - for sort_keys in &data_gen_config.sort_keys_set { - let group_by_columns = sort_keys.iter().map(|s| s.as_str()); - query_builder = query_builder.set_group_by_columns(group_by_columns); - for _ in 0..NUM_QUERIES { - let sql = query_builder.generate_query(); - self.candidate_sqls.push(Arc::from(sql)); - } - } - } - // also add a query with no grouping - query_builder = query_builder.set_group_by_columns(vec![]); - let sql = query_builder.generate_query(); - self.candidate_sqls.push(Arc::from(sql)); + pub fn add_query_builder(mut self, query_builder: QueryBuilder) -> Self { + self = self.table_name(query_builder.table_name()); - self.table_name(query_builder.table_name()) + let sqls = query_builder + .generate_queries() + .into_iter() + .map(|sql| Arc::from(sql.as_str())); + self.candidate_sqls.extend(sqls); + + self } pub fn table_name(mut self, table_name: &str) -> Self { @@ -371,217 +356,3 @@ fn format_batches_with_limit(batches: &[RecordBatch]) -> impl std::fmt::Display pretty_format_batches(&to_print).unwrap() } - -/// Random aggregate query builder -/// -/// Creates queries like -/// ```sql -/// SELECT AGG(..) 
FROM table_name GROUP BY
-///```
-#[derive(Debug, Default, Clone)]
-pub struct QueryBuilder {
-    /// The name of the table to query
-    table_name: String,
-    /// Aggregate functions to be used in the query
-    /// (function_name, is_distinct)
-    aggregate_functions: Vec<(String, bool)>,
-    /// Columns to be used in group by
-    group_by_columns: Vec<String>,
-    /// Possible columns for arguments in the aggregate functions
-    ///
-    /// Assumes each
-    arguments: Vec<String>,
-}
-impl QueryBuilder {
-    pub fn new() -> Self {
-        Default::default()
-    }
-
-    /// return the table name if any
-    pub fn table_name(&self) -> &str {
-        &self.table_name
-    }
-
-    /// Set the table name for the query builder
-    pub fn with_table_name(mut self, table_name: impl Into<String>) -> Self {
-        self.table_name = table_name.into();
-        self
-    }
-
-    /// Add a new possible aggregate function to the query builder
-    pub fn with_aggregate_function(
-        mut self,
-        aggregate_function: impl Into<String>,
-    ) -> Self {
-        self.aggregate_functions
-            .push((aggregate_function.into(), false));
-        self
-    }
-
-    /// Add a new possible `DISTINCT` aggregate function to the query
-    ///
-    /// This is different than `with_aggregate_function` because only certain
-    /// aggregates support `DISTINCT`
-    pub fn with_distinct_aggregate_function(
-        mut self,
-        aggregate_function: impl Into<String>,
-    ) -> Self {
-        self.aggregate_functions
-            .push((aggregate_function.into(), true));
-        self
-    }
-
-    /// Set the columns to be used in the group bys clauses
-    pub fn set_group_by_columns<'a>(
-        mut self,
-        group_by: impl IntoIterator<Item = &'a str>,
-    ) -> Self {
-        self.group_by_columns = group_by.into_iter().map(String::from).collect();
-        self
-    }
-
-    /// Add one or more columns to be used as an argument in the aggregate functions
-    pub fn with_aggregate_arguments<'a>(
-        mut self,
-        arguments: impl IntoIterator<Item = &'a str>,
-    ) -> Self {
-        let arguments = arguments.into_iter().map(String::from);
-        self.arguments.extend(arguments);
-        self
-    }
-
-    pub fn generate_query(&self) -> String {
-        let group_by = self.random_group_by();
-        let mut query = String::from("SELECT ");
-        query.push_str(&group_by.join(", "));
-        if !group_by.is_empty() {
-            query.push_str(", ");
-        }
-        query.push_str(&self.random_aggregate_functions(&group_by).join(", "));
-        query.push_str(" FROM ");
-        query.push_str(&self.table_name);
-        if !group_by.is_empty() {
-            query.push_str(" GROUP BY ");
-            query.push_str(&group_by.join(", "));
-        }
-        query
-    }
-
-    /// Generate a some random aggregate function invocations (potentially repeating).
-    ///
-    /// Each aggregate function invocation is of the form
-    ///
-    /// ```sql
-    /// function_name(<DISTINCT> argument) as alias
-    /// ```
-    ///
-    /// where
-    /// * `function_names` are randomly selected from [`Self::aggregate_functions`]
-    /// * `<DISTINCT> argument` is randomly selected from [`Self::arguments`]
-    /// * `alias` is a unique alias `colN` for the column (to avoid duplicate column names)
-    fn random_aggregate_functions(&self, group_by_cols: &[String]) -> Vec<String> {
-        const MAX_NUM_FUNCTIONS: usize = 5;
-        let mut rng = thread_rng();
-        let num_aggregate_functions = rng.gen_range(1..MAX_NUM_FUNCTIONS);
-
-        let mut alias_gen = 1;
-
-        let mut aggregate_functions = vec![];
-
-        let mut order_by_black_list: HashSet<String> =
-            group_by_cols.iter().cloned().collect();
-        // remove one random col
-        if let Some(first) = order_by_black_list.iter().next().cloned() {
-            order_by_black_list.remove(&first);
-        }
-
-        while aggregate_functions.len() < num_aggregate_functions {
-            let idx = rng.gen_range(0..self.aggregate_functions.len());
-            let (function_name, is_distinct) = &self.aggregate_functions[idx];
-            let argument = self.random_argument();
-            let alias = format!("col{}", alias_gen);
-            let distinct = if *is_distinct { "DISTINCT " } else { "" };
-            alias_gen += 1;
-
-            let (order_by, null_opt) = if function_name.eq("first_value")
-                || function_name.eq("last_value")
-            {
-                (
-                    self.order_by(&order_by_black_list), /* Among the order by columns, at most one group by column can be included to avoid all order by column values being identical */
-                    self.null_opt(),
-                )
-            } else {
-                ("".to_string(), "".to_string())
-            };
-
-            let function = format!(
-                "{function_name}({distinct}{argument}{order_by}) {null_opt} as {alias}"
-            );
-            aggregate_functions.push(function);
-        }
-        aggregate_functions
-    }
-
-    /// Pick a random aggregate function argument
-    fn random_argument(&self) -> String {
-        let mut rng = thread_rng();
-        let idx = rng.gen_range(0..self.arguments.len());
-        self.arguments[idx].clone()
-    }
-
-    fn order_by(&self, black_list: &HashSet<String>) -> String {
-        let mut available_columns: Vec<String> = self
-            .arguments
-            .iter()
-            .filter(|col| !black_list.contains(*col))
-            .cloned()
-            .collect();
-
-        available_columns.shuffle(&mut thread_rng());
-
-        let num_of_order_by_col = 12;
-        let column_count = std::cmp::min(num_of_order_by_col, available_columns.len());
-
-        let selected_columns = &available_columns[0..column_count];
-
-        let mut rng = thread_rng();
-        let mut result = String::from_str(" order by ").unwrap();
-        for col in selected_columns {
-            let order = if rng.gen_bool(0.5) { "ASC" } else { "DESC" };
-            result.push_str(&format!("{} {},", col, order));
-        }
-
-        result.strip_suffix(",").unwrap().to_string()
-    }
-
-    fn null_opt(&self) -> String {
-        if thread_rng().gen_bool(0.5) {
-            "RESPECT NULLS".to_string()
-        } else {
-            "IGNORE NULLS".to_string()
-        }
-    }
-
-    /// Pick a random number of fields to group by (non-repeating)
-    ///
-    /// Limited to 3 group by columns to ensure coverage for large groups. With
-    /// larger numbers of columns, each group has many fewer values.
-    fn random_group_by(&self) -> Vec<String> {
-        let mut rng = thread_rng();
-        const MAX_GROUPS: usize = 3;
-        let max_groups = self.group_by_columns.len().max(MAX_GROUPS);
-        let num_group_by = rng.gen_range(1..max_groups);
-
-        let mut already_used = HashSet::new();
-        let mut group_by = vec![];
-        while group_by.len() < num_group_by
-            && already_used.len() != self.group_by_columns.len()
-        {
-            let idx = rng.gen_range(0..self.group_by_columns.len());
-            if already_used.insert(idx) {
-                group_by.push(self.group_by_columns[idx].clone());
-            }
-        }
-        group_by
-    }
-}
diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/mod.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/mod.rs
index bfb3bb096326..04b764e46a96 100644
--- a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/mod.rs
+++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/mod.rs
@@ -43,6 +43,7 @@ use datafusion_common::error::Result;
 mod context_generator;
 mod data_generator;
 mod fuzzer;
+pub mod query_builder;
 
 pub use crate::fuzz_cases::record_batch_generator::ColumnDescr;
 pub use data_generator::DatasetGeneratorConfig;
diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/query_builder.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/query_builder.rs
new file mode 100644
index 000000000000..df4730214f1a
--- /dev/null
+++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/query_builder.rs
@@ -0,0 +1,384 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::{collections::HashSet, str::FromStr};
+
+use rand::{seq::SliceRandom, thread_rng, Rng};
+
+/// Random aggregate query builder
+///
+/// Creates queries like
+/// ```sql
+/// SELECT AGG(..) FROM table_name GROUP BY
+///```
+#[derive(Debug, Default, Clone)]
+pub struct QueryBuilder {
+    // ===================================
+    // Table settings
+    // ===================================
+    /// The name of the table to query
+    table_name: String,
+
+    // ===================================
+    // Grouping settings
+    // ===================================
+    /// Columns to be used in the randomly generated `groupings`
+    ///
+    /// # Example
+    ///
+    /// Columns:
+    ///
+    /// ```text
+    /// [a,b,c,d]
+    /// ```
+    ///
+    /// The randomly generated `groupings` (at least 1 column)
+    /// can be:
+    ///
+    /// ```text
+    /// [a]
+    /// [a,b]
+    /// [a,b,d]
+    /// ...
+    /// ```
+    ///
+    /// So the finally generated SQL queries will be:
+    ///
+    /// ```text
+    /// SELECT aggr FROM t GROUP BY a;
+    /// SELECT aggr FROM t GROUP BY a,b;
+    /// SELECT aggr FROM t GROUP BY a,b,d;
+    /// ...
+    /// ```
+    group_by_columns: Vec<String>,
+
+    /// Max number of columns in the randomly generated `groupings`
+    max_group_by_columns: usize,
+
+    /// Min number of columns in the randomly generated `groupings`
+    min_group_by_columns: usize,
+
+    /// The sort keys of the dataset
+    ///
+    /// Optimizations are triggered when all or some of the grouping columns
+    /// are the sort keys of the dataset, so it is necessary to randomly
+    /// generate some `groupings` based on the dataset sort keys for test coverage.
+    ///
+    /// # Example
+    ///
+    /// Take a dataset including columns [a,b,c], sorted by [a,b].
+    ///
+    /// We may generate queries to try to cover the sort-optimization cases, like:
+    ///
+    /// ```text
+    /// SELECT aggr FROM t GROUP BY b; // no permutation case
+    /// SELECT aggr FROM t GROUP BY a,c; // partial permutation case
+    /// SELECT aggr FROM t GROUP BY a,b,c; // full permutation case
+    /// ...
+    /// ```
+    ///
+    /// More details can be found in [`GroupOrdering`].
+    ///
+    /// [`GroupOrdering`]: datafusion_physical_plan::aggregates::order::GroupOrdering
+    ///
+    dataset_sort_keys: Vec<Vec<String>>,
+
+    /// Whether to also test the no-grouping case, like:
+    ///
+    /// ```text
+    /// SELECT aggr FROM t;
+    /// ```
+    ///
+    no_grouping: bool,
+
+    // ====================================
+    // Aggregation function settings
+    // ====================================
+    /// Aggregate functions to be used in the query
+    /// (function_name, is_distinct)
+    aggregate_functions: Vec<(String, bool)>,
+
+    /// Possible columns for arguments in the aggregate functions
+    ///
+    /// Assumes each
+    arguments: Vec<String>,
+}
+
+impl QueryBuilder {
+    pub fn new() -> Self {
+        Self {
+            no_grouping: true,
+            max_group_by_columns: 5,
+            min_group_by_columns: 1,
+            ..Default::default()
+        }
+    }
+
+    /// return the table name if any
+    pub fn table_name(&self) -> &str {
+        &self.table_name
+    }
+
+    /// Set the table name for the query builder
+    pub fn with_table_name(mut self, table_name: impl Into<String>) -> Self {
+        self.table_name = table_name.into();
+        self
+    }
+
+    /// Add a new possible aggregate function to the query builder
+    pub fn with_aggregate_function(
+        mut self,
+        aggregate_function: impl Into<String>,
+    ) -> Self {
+        self.aggregate_functions
+            .push((aggregate_function.into(), false));
+        self
+    }
+
+    /// Add a new possible `DISTINCT` aggregate function to the query
+    ///
+    /// This is different than `with_aggregate_function` because only certain
+    /// aggregates support `DISTINCT`
+    pub fn with_distinct_aggregate_function(
+        mut self,
+        aggregate_function: impl Into<String>,
+    ) -> Self {
+        self.aggregate_functions
+            .push((aggregate_function.into(), true));
+        self
+    }
+
+    /// Set the columns to be used in the group bys clauses
+    pub fn set_group_by_columns<'a>(
+        mut self,
+        group_by: impl IntoIterator<Item = &'a str>,
+    ) -> Self {
+        self.group_by_columns = group_by.into_iter().map(String::from).collect();
+        self
+    }
+
+    /// Add one or more columns to be used as an argument in the aggregate functions
+    pub fn with_aggregate_arguments<'a>(
+        mut self,
+        arguments: impl IntoIterator<Item = &'a str>,
+    ) -> Self {
+        let arguments = arguments.into_iter().map(String::from);
+        self.arguments.extend(arguments);
+        self
+    }
+
+    /// Set the max number of columns to group by (default: 5). For example, if it
+    /// is set to 1, the generated SQL will group by at most 1 column
+    #[allow(dead_code)]
+    pub fn with_max_group_by_columns(mut self, max_group_by_columns: usize) -> Self {
+        self.max_group_by_columns = max_group_by_columns;
+        self
+    }
+
+    /// Set the min number of columns to group by (default: 1)
+    #[allow(dead_code)]
+    pub fn with_min_group_by_columns(mut self, min_group_by_columns: usize) -> Self {
+        self.min_group_by_columns = min_group_by_columns;
+        self
+    }
+
+    /// Set the sort keys of the dataset, if any; the builder will then generate
+    /// queries based on them to cover the sort-optimization cases
+    pub fn with_dataset_sort_keys(mut self, dataset_sort_keys: Vec<Vec<String>>) -> Self {
+        self.dataset_sort_keys = dataset_sort_keys;
+        self
+    }
+
+    /// Set whether to also test the no-grouping aggregation case (default: true)
+    #[allow(dead_code)]
+    pub fn with_no_grouping(mut self, no_grouping: bool) -> Self {
+        self.no_grouping = no_grouping;
+        self
+    }
+
+    pub fn generate_queries(mut self) -> Vec<String> {
+        const NUM_QUERIES: usize = 3;
+        let mut sqls = Vec::new();
+
+        // Add several queries grouping on randomly picked columns
+        for _ in 0..NUM_QUERIES {
+            let sql = self.generate_query();
+            sqls.push(sql);
+        }
+
+        // Also add several queries limited to grouping on the dataset's sorted
+        // columns only, if any.
+        // So if the data is sorted on `a,b`, only group by `a,b`, `a`, or `b`.
+        if !self.dataset_sort_keys.is_empty() {
+            let dataset_sort_keys = self.dataset_sort_keys.clone();
+            for sort_keys in dataset_sort_keys {
+                let group_by_columns = sort_keys.iter().map(|s| s.as_str());
+                self = self.set_group_by_columns(group_by_columns);
+                for _ in 0..NUM_QUERIES {
+                    let sql = self.generate_query();
+                    sqls.push(sql);
+                }
+            }
+        }
+
+        // Also add a query with no grouping
+        if self.no_grouping {
+            self = self.set_group_by_columns(vec![]);
+            let sql = self.generate_query();
+            sqls.push(sql);
+        }
+
+        sqls
+    }
+
+    fn generate_query(&self) -> String {
+        let group_by = self.random_group_by();
+        let mut query = String::from("SELECT ");
+        query.push_str(&group_by.join(", "));
+        if !group_by.is_empty() {
+            query.push_str(", ");
+        }
+        query.push_str(&self.random_aggregate_functions(&group_by).join(", "));
+        query.push_str(" FROM ");
+        query.push_str(&self.table_name);
+        if !group_by.is_empty() {
+            query.push_str(" GROUP BY ");
+            query.push_str(&group_by.join(", "));
+        }
+        query
+    }
+
+    /// Generate some random aggregate function invocations (potentially repeating).
+    ///
+    /// Each aggregate function invocation is of the form
+    ///
+    /// ```sql
+    /// function_name(<DISTINCT> argument) as alias
+    /// ```
+    ///
+    /// where
+    /// * `function_names` are randomly selected from [`Self::aggregate_functions`]
+    /// * `<DISTINCT> argument` is randomly selected from [`Self::arguments`]
+    /// * `alias` is a unique alias `colN` for the column (to avoid duplicate column names)
+    fn random_aggregate_functions(&self, group_by_cols: &[String]) -> Vec<String> {
+        const MAX_NUM_FUNCTIONS: usize = 5;
+        let mut rng = thread_rng();
+        let num_aggregate_functions = rng.gen_range(1..=MAX_NUM_FUNCTIONS);
+
+        let mut alias_gen = 1;
+
+        let mut aggregate_functions = vec![];
+
+        let mut order_by_black_list: HashSet<String> =
+            group_by_cols.iter().cloned().collect();
+        // remove one random col
+        if let Some(first) = order_by_black_list.iter().next().cloned() {
+            order_by_black_list.remove(&first);
+        }
+
+        while aggregate_functions.len() < num_aggregate_functions {
+            let idx = rng.gen_range(0..self.aggregate_functions.len());
+            let (function_name, is_distinct) = &self.aggregate_functions[idx];
+            let argument = self.random_argument();
+            let alias = format!("col{}", alias_gen);
+            let distinct = if *is_distinct { "DISTINCT " } else { "" };
+            alias_gen += 1;
+
+            let (order_by, null_opt) = if function_name.eq("first_value")
+                || function_name.eq("last_value")
+            {
+                (
+                    self.order_by(&order_by_black_list), /* Among the order by columns, at most one group by column can be included to avoid all order by column values being identical */
+                    self.null_opt(),
+                )
+            } else {
+                ("".to_string(), "".to_string())
+            };
+
+            let function = format!(
+                "{function_name}({distinct}{argument}{order_by}) {null_opt} as {alias}"
+            );
+            aggregate_functions.push(function);
+        }
+        aggregate_functions
+    }
+
+    /// Pick a random aggregate function argument
+    fn random_argument(&self) -> String {
+        let mut rng = thread_rng();
+        let idx = rng.gen_range(0..self.arguments.len());
+        self.arguments[idx].clone()
+    }
+
+    fn order_by(&self, black_list: &HashSet<String>) -> String {
+        let mut available_columns: Vec<String> = self
+            .arguments
+            .iter()
+            .filter(|col| !black_list.contains(*col))
+            .cloned()
+            .collect();
+
+        available_columns.shuffle(&mut thread_rng());
+
+        let num_of_order_by_col = 12;
+        let column_count = std::cmp::min(num_of_order_by_col, available_columns.len());
+
+        let selected_columns = &available_columns[0..column_count];
+
+        let mut rng = thread_rng();
+        let mut result = String::from_str(" order by ").unwrap();
+        for col in selected_columns {
+            let order = if rng.gen_bool(0.5) { "ASC" } else { "DESC" };
+            result.push_str(&format!("{} {},", col, order));
+        }
+
+        result.strip_suffix(",").unwrap().to_string()
+    }
+
+    fn null_opt(&self) -> String {
+        if thread_rng().gen_bool(0.5) {
+            "RESPECT NULLS".to_string()
+        } else {
+            "IGNORE NULLS".to_string()
+        }
+    }
+
+    /// Pick a random number of fields to group by (non-repeating)
+    ///
+    /// Limited to `max_group_by_columns` group by columns to ensure coverage for large groups.
+    /// With larger numbers of columns, each group has many fewer values.
+    fn random_group_by(&self) -> Vec<String> {
+        let mut rng = thread_rng();
+        let min_groups = self.min_group_by_columns;
+        let max_groups = self.max_group_by_columns;
+        assert!(min_groups <= max_groups);
+        let num_group_by = rng.gen_range(min_groups..=max_groups);
+
+        let mut already_used = HashSet::new();
+        let mut group_by = vec![];
+        while group_by.len() < num_group_by
+            && already_used.len() != self.group_by_columns.len()
+        {
+            let idx = rng.gen_range(0..self.group_by_columns.len());
+            if already_used.insert(idx) {
+                group_by.push(self.group_by_columns[idx].clone());
+            }
+        }
+        group_by
+    }
+}
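For reference, here is a minimal sketch of how a fuzz test can drive the extracted builder end to end, mirroring the pattern used by the updated tests in `aggregate_fuzz.rs`. The table name, column names, and sort keys below are illustrative placeholders rather than values taken from this patch; a real test pulls them from `DatasetGeneratorConfig`:

```rust
use crate::fuzz_cases::aggregation_fuzzer::query_builder::QueryBuilder;

fn example_candidate_sqls() -> Vec<String> {
    // Hypothetical column names for illustration only.
    QueryBuilder::new()
        .with_table_name("fuzz_table")
        .with_aggregate_function("min")
        .with_distinct_aggregate_function("count")
        .with_aggregate_arguments(vec!["u8_low", "utf8_low"])
        // Grouping on the dataset's sort keys exercises the sort-optimized paths
        .with_dataset_sort_keys(vec![vec!["u8_low".to_string(), "utf8_low".to_string()]])
        .set_group_by_columns(vec!["u8_low", "utf8_low"])
        // Yields 3 random groupings, 3 queries per sort-key grouping,
        // plus 1 query with no grouping (no_grouping defaults to true)
        .generate_queries()
}
```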