diff --git a/benchmarks/queries/clickbench/queries.sql b/benchmarks/queries/clickbench/queries.sql index 52e72e02e1e0..6b871afa8956 100644 --- a/benchmarks/queries/clickbench/queries.sql +++ b/benchmarks/queries/clickbench/queries.sql @@ -1,43 +1,2 @@ -SELECT COUNT(*) FROM hits; -SELECT COUNT(*) FROM hits WHERE "AdvEngineID" <> 0; -SELECT SUM("AdvEngineID"), COUNT(*), AVG("ResolutionWidth") FROM hits; -SELECT AVG("UserID") FROM hits; -SELECT COUNT(DISTINCT "UserID") FROM hits; -SELECT COUNT(DISTINCT "SearchPhrase") FROM hits; -SELECT MIN("EventDate"::INT::DATE), MAX("EventDate"::INT::DATE) FROM hits; -SELECT "AdvEngineID", COUNT(*) FROM hits WHERE "AdvEngineID" <> 0 GROUP BY "AdvEngineID" ORDER BY COUNT(*) DESC; -SELECT "RegionID", COUNT(DISTINCT "UserID") AS u FROM hits GROUP BY "RegionID" ORDER BY u DESC LIMIT 10; -SELECT "RegionID", SUM("AdvEngineID"), COUNT(*) AS c, AVG("ResolutionWidth"), COUNT(DISTINCT "UserID") FROM hits GROUP BY "RegionID" ORDER BY c DESC LIMIT 10; -SELECT "MobilePhoneModel", COUNT(DISTINCT "UserID") AS u FROM hits WHERE "MobilePhoneModel" <> '' GROUP BY "MobilePhoneModel" ORDER BY u DESC LIMIT 10; -SELECT "MobilePhone", "MobilePhoneModel", COUNT(DISTINCT "UserID") AS u FROM hits WHERE "MobilePhoneModel" <> '' GROUP BY "MobilePhone", "MobilePhoneModel" ORDER BY u DESC LIMIT 10; -SELECT "SearchPhrase", COUNT(*) AS c FROM hits WHERE "SearchPhrase" <> '' GROUP BY "SearchPhrase" ORDER BY c DESC LIMIT 10; -SELECT "SearchPhrase", COUNT(DISTINCT "UserID") AS u FROM hits WHERE "SearchPhrase" <> '' GROUP BY "SearchPhrase" ORDER BY u DESC LIMIT 10; -SELECT "SearchEngineID", "SearchPhrase", COUNT(*) AS c FROM hits WHERE "SearchPhrase" <> '' GROUP BY "SearchEngineID", "SearchPhrase" ORDER BY c DESC LIMIT 10; -SELECT "UserID", COUNT(*) FROM hits GROUP BY "UserID" ORDER BY COUNT(*) DESC LIMIT 10; SELECT "UserID", "SearchPhrase", COUNT(*) FROM hits GROUP BY "UserID", "SearchPhrase" ORDER BY COUNT(*) DESC LIMIT 10; -SELECT "UserID", "SearchPhrase", COUNT(*) FROM hits GROUP BY "UserID", "SearchPhrase" LIMIT 10; -SELECT "UserID", extract(minute FROM to_timestamp_seconds("EventTime")) AS m, "SearchPhrase", COUNT(*) FROM hits GROUP BY "UserID", m, "SearchPhrase" ORDER BY COUNT(*) DESC LIMIT 10; -SELECT "UserID" FROM hits WHERE "UserID" = 435090932899640449; -SELECT COUNT(*) FROM hits WHERE "URL" LIKE '%google%'; -SELECT "SearchPhrase", MIN("URL"), COUNT(*) AS c FROM hits WHERE "URL" LIKE '%google%' AND "SearchPhrase" <> '' GROUP BY "SearchPhrase" ORDER BY c DESC LIMIT 10; -SELECT "SearchPhrase", MIN("URL"), MIN("Title"), COUNT(*) AS c, COUNT(DISTINCT "UserID") FROM hits WHERE "Title" LIKE '%Google%' AND "URL" NOT LIKE '%.google.%' AND "SearchPhrase" <> '' GROUP BY "SearchPhrase" ORDER BY c DESC LIMIT 10; -SELECT * FROM hits WHERE "URL" LIKE '%google%' ORDER BY to_timestamp_seconds("EventTime") LIMIT 10; -SELECT "SearchPhrase" FROM hits WHERE "SearchPhrase" <> '' ORDER BY to_timestamp_seconds("EventTime") LIMIT 10; -SELECT "SearchPhrase" FROM hits WHERE "SearchPhrase" <> '' ORDER BY "SearchPhrase" LIMIT 10; -SELECT "SearchPhrase" FROM hits WHERE "SearchPhrase" <> '' ORDER BY to_timestamp_seconds("EventTime"), "SearchPhrase" LIMIT 10; -SELECT "CounterID", AVG(length("URL")) AS l, COUNT(*) AS c FROM hits WHERE "URL" <> '' GROUP BY "CounterID" HAVING COUNT(*) > 100000 ORDER BY l DESC LIMIT 25; -SELECT REGEXP_REPLACE("Referer", '^https?://(?:www\.)?([^/]+)/.*$', '\1') AS k, AVG(length("Referer")) AS l, COUNT(*) AS c, MIN("Referer") FROM hits WHERE "Referer" <> '' GROUP BY k HAVING COUNT(*) 
> 100000 ORDER BY l DESC LIMIT 25; -SELECT SUM("ResolutionWidth"), SUM("ResolutionWidth" + 1), SUM("ResolutionWidth" + 2), SUM("ResolutionWidth" + 3), SUM("ResolutionWidth" + 4), SUM("ResolutionWidth" + 5), SUM("ResolutionWidth" + 6), SUM("ResolutionWidth" + 7), SUM("ResolutionWidth" + 8), SUM("ResolutionWidth" + 9), SUM("ResolutionWidth" + 10), SUM("ResolutionWidth" + 11), SUM("ResolutionWidth" + 12), SUM("ResolutionWidth" + 13), SUM("ResolutionWidth" + 14), SUM("ResolutionWidth" + 15), SUM("ResolutionWidth" + 16), SUM("ResolutionWidth" + 17), SUM("ResolutionWidth" + 18), SUM("ResolutionWidth" + 19), SUM("ResolutionWidth" + 20), SUM("ResolutionWidth" + 21), SUM("ResolutionWidth" + 22), SUM("ResolutionWidth" + 23), SUM("ResolutionWidth" + 24), SUM("ResolutionWidth" + 25), SUM("ResolutionWidth" + 26), SUM("ResolutionWidth" + 27), SUM("ResolutionWidth" + 28), SUM("ResolutionWidth" + 29), SUM("ResolutionWidth" + 30), SUM("ResolutionWidth" + 31), SUM("ResolutionWidth" + 32), SUM("ResolutionWidth" + 33), SUM("ResolutionWidth" + 34), SUM("ResolutionWidth" + 35), SUM("ResolutionWidth" + 36), SUM("ResolutionWidth" + 37), SUM("ResolutionWidth" + 38), SUM("ResolutionWidth" + 39), SUM("ResolutionWidth" + 40), SUM("ResolutionWidth" + 41), SUM("ResolutionWidth" + 42), SUM("ResolutionWidth" + 43), SUM("ResolutionWidth" + 44), SUM("ResolutionWidth" + 45), SUM("ResolutionWidth" + 46), SUM("ResolutionWidth" + 47), SUM("ResolutionWidth" + 48), SUM("ResolutionWidth" + 49), SUM("ResolutionWidth" + 50), SUM("ResolutionWidth" + 51), SUM("ResolutionWidth" + 52), SUM("ResolutionWidth" + 53), SUM("ResolutionWidth" + 54), SUM("ResolutionWidth" + 55), SUM("ResolutionWidth" + 56), SUM("ResolutionWidth" + 57), SUM("ResolutionWidth" + 58), SUM("ResolutionWidth" + 59), SUM("ResolutionWidth" + 60), SUM("ResolutionWidth" + 61), SUM("ResolutionWidth" + 62), SUM("ResolutionWidth" + 63), SUM("ResolutionWidth" + 64), SUM("ResolutionWidth" + 65), SUM("ResolutionWidth" + 66), SUM("ResolutionWidth" + 67), SUM("ResolutionWidth" + 68), SUM("ResolutionWidth" + 69), SUM("ResolutionWidth" + 70), SUM("ResolutionWidth" + 71), SUM("ResolutionWidth" + 72), SUM("ResolutionWidth" + 73), SUM("ResolutionWidth" + 74), SUM("ResolutionWidth" + 75), SUM("ResolutionWidth" + 76), SUM("ResolutionWidth" + 77), SUM("ResolutionWidth" + 78), SUM("ResolutionWidth" + 79), SUM("ResolutionWidth" + 80), SUM("ResolutionWidth" + 81), SUM("ResolutionWidth" + 82), SUM("ResolutionWidth" + 83), SUM("ResolutionWidth" + 84), SUM("ResolutionWidth" + 85), SUM("ResolutionWidth" + 86), SUM("ResolutionWidth" + 87), SUM("ResolutionWidth" + 88), SUM("ResolutionWidth" + 89) FROM hits; -SELECT "SearchEngineID", "ClientIP", COUNT(*) AS c, SUM("IsRefresh"), AVG("ResolutionWidth") FROM hits WHERE "SearchPhrase" <> '' GROUP BY "SearchEngineID", "ClientIP" ORDER BY c DESC LIMIT 10; -SELECT "WatchID", "ClientIP", COUNT(*) AS c, SUM("IsRefresh"), AVG("ResolutionWidth") FROM hits WHERE "SearchPhrase" <> '' GROUP BY "WatchID", "ClientIP" ORDER BY c DESC LIMIT 10; -SELECT "WatchID", "ClientIP", COUNT(*) AS c, SUM("IsRefresh"), AVG("ResolutionWidth") FROM hits GROUP BY "WatchID", "ClientIP" ORDER BY c DESC LIMIT 10; -SELECT "URL", COUNT(*) AS c FROM hits GROUP BY "URL" ORDER BY c DESC LIMIT 10; -SELECT 1, "URL", COUNT(*) AS c FROM hits GROUP BY 1, "URL" ORDER BY c DESC LIMIT 10; -SELECT "ClientIP", "ClientIP" - 1, "ClientIP" - 2, "ClientIP" - 3, COUNT(*) AS c FROM hits GROUP BY "ClientIP", "ClientIP" - 1, "ClientIP" - 2, "ClientIP" - 3 ORDER BY c DESC LIMIT 10; -SELECT "URL", COUNT(*) 
AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate"::INT::DATE >= '2013-07-01' AND "EventDate"::INT::DATE <= '2013-07-31' AND "DontCountHits" = 0 AND "IsRefresh" = 0 AND "URL" <> '' GROUP BY "URL" ORDER BY PageViews DESC LIMIT 10; -SELECT "Title", COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate"::INT::DATE >= '2013-07-01' AND "EventDate"::INT::DATE <= '2013-07-31' AND "DontCountHits" = 0 AND "IsRefresh" = 0 AND "Title" <> '' GROUP BY "Title" ORDER BY PageViews DESC LIMIT 10; -SELECT "URL", COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate"::INT::DATE >= '2013-07-01' AND "EventDate"::INT::DATE <= '2013-07-31' AND "IsRefresh" = 0 AND "IsLink" <> 0 AND "IsDownload" = 0 GROUP BY "URL" ORDER BY PageViews DESC LIMIT 10 OFFSET 1000; -SELECT "TraficSourceID", "SearchEngineID", "AdvEngineID", CASE WHEN ("SearchEngineID" = 0 AND "AdvEngineID" = 0) THEN "Referer" ELSE '' END AS Src, "URL" AS Dst, COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate"::INT::DATE >= '2013-07-01' AND "EventDate"::INT::DATE <= '2013-07-31' AND "IsRefresh" = 0 GROUP BY "TraficSourceID", "SearchEngineID", "AdvEngineID", Src, Dst ORDER BY PageViews DESC LIMIT 10 OFFSET 1000; -SELECT "URLHash", "EventDate"::INT::DATE, COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate"::INT::DATE >= '2013-07-01' AND "EventDate"::INT::DATE <= '2013-07-31' AND "IsRefresh" = 0 AND "TraficSourceID" IN (-1, 6) AND "RefererHash" = 3594120000172545465 GROUP BY "URLHash", "EventDate"::INT::DATE ORDER BY PageViews DESC LIMIT 10 OFFSET 100; -SELECT "WindowClientWidth", "WindowClientHeight", COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate"::INT::DATE >= '2013-07-01' AND "EventDate"::INT::DATE <= '2013-07-31' AND "IsRefresh" = 0 AND "DontCountHits" = 0 AND "URLHash" = 2868770270353813622 GROUP BY "WindowClientWidth", "WindowClientHeight" ORDER BY PageViews DESC LIMIT 10 OFFSET 10000; -SELECT DATE_TRUNC('minute', to_timestamp_seconds("EventTime")) AS M, COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate"::INT::DATE >= '2013-07-14' AND "EventDate"::INT::DATE <= '2013-07-15' AND "IsRefresh" = 0 AND "DontCountHits" = 0 GROUP BY DATE_TRUNC('minute', to_timestamp_seconds("EventTime")) ORDER BY DATE_TRUNC('minute', M) LIMIT 10 OFFSET 1000; +SELECT "UserID", concat("SearchPhrase", repeat('hello', 20)) as s, COUNT(*) FROM hits GROUP BY "UserID", s ORDER BY COUNT(*) DESC LIMIT 10; diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 09b90a56d2aa..dc793f2e393d 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -184,6 +184,10 @@ name = "math_query_sql" harness = false name = "filter_query_sql" +[[bench]] +harness = false +name = "reuse_hash" + [[bench]] harness = false name = "window_query_sql" diff --git a/datafusion/core/benches/reuse_hash.rs b/datafusion/core/benches/reuse_hash.rs new file mode 100644 index 000000000000..4757ebe77a07 --- /dev/null +++ b/datafusion/core/benches/reuse_hash.rs @@ -0,0 +1,100 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::{
+    datatypes::{DataType, Field, Schema},
+    record_batch::RecordBatch,
+};
+use arrow_array::{Int64Array, StringArray};
+use criterion::{criterion_group, criterion_main, Criterion};
+use datafusion::prelude::SessionContext;
+use datafusion::{datasource::MemTable, error::Result};
+use futures::executor::block_on;
+use std::sync::Arc;
+use tokio::runtime::Runtime;
+
+async fn query(ctx: &mut SessionContext, sql: &str) {
+    let rt = Runtime::new().unwrap();
+
+    // execute the query
+    let df = rt.block_on(ctx.sql(sql)).unwrap();
+    criterion::black_box(rt.block_on(df.collect()).unwrap());
+}
+
+fn create_context(array_len: usize, batch_size: usize) -> Result<SessionContext> {
+    // define a schema.
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("a", DataType::Int64, false),
+        Field::new("b", DataType::Utf8, false),
+    ]));
+
+    // define data.
+    let batches = (0..array_len / batch_size)
+        .map(|_i| {
+            let data1 = (0..batch_size)
+                .into_iter()
+                .map(|x| x as i64)
+                .collect::<Vec<_>>();
+            let data2 = (0..batch_size)
+                .into_iter()
+                .map(|j| format!("a{j}"))
+                .collect::<Vec<_>>();
+
+            RecordBatch::try_new(
+                schema.clone(),
+                vec![
+                    Arc::new(Int64Array::from(data1)),
+                    Arc::new(StringArray::from(data2)),
+                ],
+            )
+            .unwrap()
+        })
+        .collect::<Vec<_>>();
+
+    let ctx = SessionContext::new();
+
+    // declare a table in memory. In the Spark API, this corresponds to createDataFrame(...).
+    let provider = MemTable::try_new(schema, vec![batches])?;
+    ctx.register_table("t", Arc::new(provider))?;
+
+    Ok(ctx)
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+    let array_len = 2000000; // 2M rows
+    let batch_size = array_len;
+
+    c.bench_function("benchmark", |b| {
+        let mut ctx = create_context(array_len, batch_size).unwrap();
+        b.iter(|| block_on(query(&mut ctx, "select a, b, count(*) from t group by a, b order by count(*) desc limit 10")))
+    });
+}
+
+criterion_group! {
+    name = benches;
+    // This can be any expression that returns a `Criterion` object.
+ config = Criterion::default().sample_size(10); + targets = criterion_benchmark +} +criterion_main!(benches); + +// reuse-hash +// benchmark time: [2.5999 s 6.3132 s 11.062 s] +// Found 1 outliers among 10 measurements (10.00%) + +// main +// benchmark time: [4.1404 s 8.4601 s 13.226 s] diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index d83a47ceb069..34b2c6af9fe5 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -29,7 +29,7 @@ use arrow::{ }, record_batch::RecordBatch, }; -use arrow_array::{Array, Float32Array, Float64Array, UnionArray}; +use arrow_array::{Array, Float32Array, Float64Array, Int64Array, UnionArray}; use arrow_buffer::ScalarBuffer; use arrow_schema::{ArrowError, UnionFields, UnionMode}; use datafusion_functions_aggregate::count::count_udaf; diff --git a/datafusion/physical-expr-common/src/binary_map.rs b/datafusion/physical-expr-common/src/binary_map.rs index edf608a2054f..18855e321321 100644 --- a/datafusion/physical-expr-common/src/binary_map.rs +++ b/datafusion/physical-expr-common/src/binary_map.rs @@ -66,10 +66,11 @@ impl ArrowBytesSet { /// Inserts each value from `values` into the set pub fn insert(&mut self, values: &ArrayRef) { - fn make_payload_fn(_value: Option<&[u8]>) {} + fn make_payload_fn(_value: Option<&[u8]>, _hash: u64) {} fn observe_payload_fn(_payload: ()) {} + self.0 - .insert_if_new(values, make_payload_fn, observe_payload_fn); + .insert_if_new(&[], values, make_payload_fn, observe_payload_fn); } /// Converts this set into a `StringArray`/`LargeStringArray` or @@ -291,11 +292,12 @@ where /// with valid values from `values`, not for the `NULL` value. pub fn insert_if_new( &mut self, + batch_hashes: &[u64], values: &ArrayRef, make_payload_fn: MP, observe_payload_fn: OP, ) where - MP: FnMut(Option<&[u8]>) -> V, + MP: FnMut(Option<&[u8]>, u64) -> V, OP: FnMut(V), { // Sanity array type @@ -306,6 +308,7 @@ where DataType::Binary | DataType::LargeBinary )); self.insert_if_new_inner::>( + batch_hashes, values, make_payload_fn, observe_payload_fn, @@ -317,6 +320,7 @@ where DataType::Utf8 | DataType::LargeUtf8 )); self.insert_if_new_inner::>( + batch_hashes, values, make_payload_fn, observe_payload_fn, @@ -336,22 +340,28 @@ where /// See comments on `insert_if_new` for more details fn insert_if_new_inner( &mut self, + batch_hashes: &[u64], values: &ArrayRef, mut make_payload_fn: MP, mut observe_payload_fn: OP, ) where - MP: FnMut(Option<&[u8]>) -> V, + MP: FnMut(Option<&[u8]>, u64) -> V, OP: FnMut(V), B: ByteArrayType, { // step 1: compute hashes - let batch_hashes = &mut self.hashes_buffer; - batch_hashes.clear(); - batch_hashes.resize(values.len(), 0); - create_hashes(&[values.clone()], &self.random_state, batch_hashes) - // hash is supported for all types and create_hashes only - // returns errors for unsupported types - .unwrap(); + let batch_hashes = if batch_hashes.is_empty() { + let batch_hashes = &mut self.hashes_buffer; + batch_hashes.clear(); + batch_hashes.resize(values.len(), 0); + create_hashes(&[values.clone()], &self.random_state, batch_hashes) + // hash is supported for all types and create_hashes only + // returns errors for unsupported types + .unwrap(); + batch_hashes + } else { + batch_hashes + }; // step 2: insert each value into the set, if not already present let values = values.as_bytes::(); @@ -365,7 +375,7 @@ where let payload = if let Some(&(payload, _offset)) = self.null.as_ref() { payload } else { - let payload = make_payload_fn(None); 
+ let payload = make_payload_fn(None, hash); let null_index = self.offsets.len() - 1; // nulls need a zero length in the offset buffer let offset = self.buffer.len(); @@ -406,7 +416,7 @@ where // comparison self.buffer.append_slice(value); self.offsets.push(O::usize_as(self.buffer.len())); - let payload = make_payload_fn(Some(value)); + let payload = make_payload_fn(Some(value), hash); let new_header = Entry { hash, len: value_len, @@ -448,7 +458,7 @@ where self.buffer.append_slice(value); self.offsets.push(O::usize_as(self.buffer.len())); - let payload = make_payload_fn(Some(value)); + let payload = make_payload_fn(Some(value), hash); let new_header = Entry { hash, len: value_len, @@ -954,8 +964,9 @@ mod tests { let mut seen_new_strings = vec![]; let mut seen_indexes = vec![]; self.map.insert_if_new( + &[], &arr, - |s| { + |s, _hash| { let value = s .map(|s| String::from_utf8(s.to_vec()).expect("Non utf8 string")); let index = next_index; diff --git a/datafusion/physical-expr-common/src/binary_view_map.rs b/datafusion/physical-expr-common/src/binary_view_map.rs index 18bc6801aa60..05317875e6a0 100644 --- a/datafusion/physical-expr-common/src/binary_view_map.rs +++ b/datafusion/physical-expr-common/src/binary_view_map.rs @@ -42,10 +42,10 @@ impl ArrowBytesViewSet { /// Inserts each value from `values` into the set pub fn insert(&mut self, values: &ArrayRef) { - fn make_payload_fn(_value: Option<&[u8]>) {} + fn make_payload_fn(_value: Option<&[u8]>, _hash: u64) {} fn observe_payload_fn(_payload: ()) {} self.0 - .insert_if_new(values, make_payload_fn, observe_payload_fn); + .insert_if_new(&[], values, make_payload_fn, observe_payload_fn); } /// Return the contents of this map and replace it with a new empty map with @@ -192,11 +192,12 @@ where /// with valid values from `values`, not for the `NULL` value. 
pub fn insert_if_new( &mut self, + batch_hashes: &[u64], values: &ArrayRef, make_payload_fn: MP, observe_payload_fn: OP, ) where - MP: FnMut(Option<&[u8]>) -> V, + MP: FnMut(Option<&[u8]>, u64) -> V, OP: FnMut(V), { // Sanity check array type @@ -204,6 +205,7 @@ where OutputType::BinaryView => { assert!(matches!(values.data_type(), DataType::BinaryView)); self.insert_if_new_inner::( + batch_hashes, values, make_payload_fn, observe_payload_fn, @@ -212,6 +214,7 @@ where OutputType::Utf8View => { assert!(matches!(values.data_type(), DataType::Utf8View)); self.insert_if_new_inner::( + batch_hashes, values, make_payload_fn, observe_payload_fn, @@ -231,22 +234,28 @@ where /// See comments on `insert_if_new` for more details fn insert_if_new_inner( &mut self, + batch_hashes: &[u64], values: &ArrayRef, mut make_payload_fn: MP, mut observe_payload_fn: OP, ) where - MP: FnMut(Option<&[u8]>) -> V, + MP: FnMut(Option<&[u8]>, u64) -> V, OP: FnMut(V), B: ByteViewType, { // step 1: compute hashes - let batch_hashes = &mut self.hashes_buffer; - batch_hashes.clear(); - batch_hashes.resize(values.len(), 0); - create_hashes(&[values.clone()], &self.random_state, batch_hashes) - // hash is supported for all types and create_hashes only - // returns errors for unsupported types - .unwrap(); + let batch_hashes = if batch_hashes.is_empty() { + let batch_hashes = &mut self.hashes_buffer; + batch_hashes.clear(); + batch_hashes.resize(values.len(), 0); + create_hashes(&[values.clone()], &self.random_state, batch_hashes) + // hash is supported for all types and create_hashes only + // returns errors for unsupported types + .unwrap(); + batch_hashes + } else { + batch_hashes + }; // step 2: insert each value into the set, if not already present let values = values.as_byte_view::(); @@ -260,7 +269,7 @@ where let payload = if let Some(&(payload, _offset)) = self.null.as_ref() { payload } else { - let payload = make_payload_fn(None); + let payload = make_payload_fn(None, hash); let null_index = self.builder.len(); self.builder.append_null(); self.null = Some((payload, null_index)); @@ -287,7 +296,7 @@ where entry.payload } else { // no existing value, make a new one. - let payload = make_payload_fn(Some(value)); + let payload = make_payload_fn(Some(value), hash); let inner_view_idx = self.builder.len(); let new_header = Entry { @@ -632,8 +641,9 @@ mod tests { let mut seen_new_strings = vec![]; let mut seen_indexes = vec![]; self.map.insert_if_new( + &[], &arr, - |s| { + |s, _hash| { let value = s .map(|s| String::from_utf8(s.to_vec()).expect("Non utf8 string")); let index = next_index; diff --git a/datafusion/physical-plan/src/aggregates/group_values/bytes.rs b/datafusion/physical-plan/src/aggregates/group_values/bytes.rs index f789af8b8a02..13aac62f9e1e 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/bytes.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/bytes.rs @@ -15,8 +15,13 @@ // specific language governing permissions and limitations // under the License. 
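Note on the `insert_if_new` change above: both `ArrowBytesMap` and `ArrowBytesViewMap` now accept a slice of precomputed row hashes (an empty slice means "compute them internally, as before"), and `make_payload_fn` is handed the hash of each newly inserted value so callers can record it per group. A minimal usage sketch, assuming a Utf8 input array; the `demo` function, variable names, and the zeroed `RandomState` seeds are illustrative, not part of the patch:

use ahash::RandomState;
use arrow_array::{Array, ArrayRef, StringArray};
use datafusion_common::hash_utils::create_hashes;
use datafusion_physical_expr_common::binary_map::{ArrowBytesMap, OutputType};
use std::sync::Arc;

fn demo() {
    let values: ArrayRef = Arc::new(StringArray::from(vec!["foo", "bar", "foo"]));
    let mut map: ArrowBytesMap<i32, usize> = ArrowBytesMap::new(OutputType::Utf8);

    // Hash every row once, up front, with the same state the rest of the pipeline uses.
    let random_state = RandomState::with_seeds(0, 0, 0, 0);
    let mut hashes = vec![0u64; values.len()];
    create_hashes(&[Arc::clone(&values)], &random_state, &mut hashes).unwrap();

    let mut next_group = 0;
    map.insert_if_new(
        &hashes, // precomputed; pass `&[]` to keep the old "hash internally" behavior
        &values,
        // called once per new distinct value, now handed its hash as well
        |_bytes, _hash| {
            let group = next_group;
            next_group += 1;
            group
        },
        // called with the payload (here: group index) for every row
        |_group_index| {},
    );
}

Whatever `RandomState` produced the supplied hashes must be used consistently for every batch fed to the same map; mixing precomputed hashes with internally computed ones from a different state would make lookups of previously inserted values miss.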
-use crate::aggregates::group_values::GroupValues; -use arrow_array::{Array, ArrayRef, OffsetSizeTrait, RecordBatch}; +use std::sync::Arc; + +use crate::aggregates::{group_values::GroupValues, AggregateMode}; +use ahash::RandomState; +use arrow::{array::AsArray, datatypes::UInt64Type}; +use arrow_array::{Array, ArrayRef, OffsetSizeTrait, RecordBatch, UInt64Array}; +use datafusion_common::hash_utils::create_hashes; use datafusion_expr::EmitTo; use datafusion_physical_expr_common::binary_map::{ArrowBytesMap, OutputType}; @@ -29,6 +34,11 @@ pub struct GroupValuesByes { map: ArrowBytesMap, /// The total number of groups so far (used to assign group_index) num_groups: usize, + /// random state used to generate hashes + random_state: RandomState, + /// buffer that stores hash values (reused across batches to save allocations) + hashes_buffer: Vec, + group_hashes: Option>, } impl GroupValuesByes { @@ -36,6 +46,9 @@ impl GroupValuesByes { Self { map: ArrowBytesMap::new(output_type), num_groups: 0, + random_state: RandomState::with_seeds(0, 0, 0, 0), + hashes_buffer: Default::default(), + group_hashes: Default::default(), } } } @@ -45,6 +58,7 @@ impl GroupValues for GroupValuesByes { &mut self, cols: &[ArrayRef], groups: &mut Vec, + hash_values: Option<&ArrayRef>, ) -> datafusion_common::Result<()> { assert_eq!(cols.len(), 1); @@ -52,13 +66,36 @@ impl GroupValues for GroupValuesByes { let arr = &cols[0]; groups.clear(); + + let mut store_gp_hashes = match self.group_hashes.take() { + Some(group_hashes) => group_hashes, + None => vec![], + }; + + let batch_hashes = if let Some(hash_values) = hash_values { + let hash_array = hash_values.as_primitive::(); + hash_array.values().as_ref() + } else { + // step 1: compute hashes + let batch_hashes = &mut self.hashes_buffer; + batch_hashes.clear(); + batch_hashes.resize(arr.len(), 0); + create_hashes(&[arr.clone()], &self.random_state, batch_hashes) + // hash is supported for all types and create_hashes only + // returns errors for unsupported types + .unwrap(); + batch_hashes + }; + self.map.insert_if_new( + batch_hashes, arr, // called for each new group - |_value| { + |_value, hash| { // assign new group index on each insert let group_idx = self.num_groups; self.num_groups += 1; + store_gp_hashes.push(hash); group_idx }, // called for each group @@ -67,6 +104,8 @@ impl GroupValues for GroupValuesByes { }, ); + self.group_hashes = Some(store_gp_hashes); + // ensure we assigned a group to for each row assert_eq!(groups.len(), arr.len()); Ok(()) @@ -84,10 +123,19 @@ impl GroupValues for GroupValuesByes { self.num_groups } - fn emit(&mut self, emit_to: EmitTo) -> datafusion_common::Result> { + fn emit( + &mut self, + emit_to: EmitTo, + mode: AggregateMode, + ) -> datafusion_common::Result> { // Reset the map to default, and convert it into a single array let map_contents = self.map.take().into_state(); + let mut group_hashes = self + .group_hashes + .take() + .expect("Can not emit from empty rows for hashes"); + let group_values = match emit_to { EmitTo::All => { self.num_groups -= map_contents.len(); @@ -106,9 +154,18 @@ impl GroupValues for GroupValuesByes { let remaining_group_values = map_contents.slice(n, map_contents.len() - n); + let remaining_group_hashes = group_hashes.split_off(n); + let hash_array = + Arc::new(UInt64Array::from(remaining_group_hashes)) as ArrayRef; + self.num_groups = 0; let mut group_indexes = vec![]; - self.intern(&[remaining_group_values], &mut group_indexes)?; + + self.intern( + &[remaining_group_values], + &mut group_indexes, 
+ Some(&hash_array), + )?; // Verify that the group indexes were assigned in the correct order assert_eq!(0, group_indexes[0]); @@ -117,7 +174,13 @@ impl GroupValues for GroupValuesByes { } }; - Ok(vec![group_values]) + let mut output = vec![group_values]; + if mode == AggregateMode::Partial { + let arr = Arc::new(UInt64Array::from(group_hashes)) as ArrayRef; + output.push(arr); + } + + Ok(output) } fn clear_shrink(&mut self, _batch: &RecordBatch) { diff --git a/datafusion/physical-plan/src/aggregates/group_values/bytes_view.rs b/datafusion/physical-plan/src/aggregates/group_values/bytes_view.rs index 1a0cb90a16d4..995ab34e4588 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/bytes_view.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/bytes_view.rs @@ -15,8 +15,13 @@ // specific language governing permissions and limitations // under the License. -use crate::aggregates::group_values::GroupValues; -use arrow_array::{Array, ArrayRef, RecordBatch}; +use std::sync::Arc; + +use crate::aggregates::{group_values::GroupValues, AggregateMode}; +use ahash::RandomState; +use arrow::{array::AsArray, datatypes::UInt64Type}; +use arrow_array::{Array, ArrayRef, RecordBatch, UInt64Array}; +use datafusion_common::hash_utils::create_hashes; use datafusion_expr::EmitTo; use datafusion_physical_expr::binary_map::OutputType; use datafusion_physical_expr_common::binary_view_map::ArrowBytesViewMap; @@ -30,6 +35,11 @@ pub struct GroupValuesBytesView { map: ArrowBytesViewMap, /// The total number of groups so far (used to assign group_index) num_groups: usize, + /// random state used to generate hashes + random_state: RandomState, + /// buffer that stores hash values (reused across batches to save allocations) + hashes_buffer: Vec, + group_hashes: Option>, } impl GroupValuesBytesView { @@ -37,6 +47,9 @@ impl GroupValuesBytesView { Self { map: ArrowBytesViewMap::new(output_type), num_groups: 0, + random_state: RandomState::with_seeds(0, 0, 0, 0), + hashes_buffer: Default::default(), + group_hashes: Default::default(), } } } @@ -46,6 +59,7 @@ impl GroupValues for GroupValuesBytesView { &mut self, cols: &[ArrayRef], groups: &mut Vec, + hash_values: Option<&ArrayRef>, ) -> datafusion_common::Result<()> { assert_eq!(cols.len(), 1); @@ -53,13 +67,36 @@ impl GroupValues for GroupValuesBytesView { let arr = &cols[0]; groups.clear(); + + let mut store_gp_hashes = match self.group_hashes.take() { + Some(group_hashes) => group_hashes, + None => vec![], + }; + + let batch_hashes = if let Some(hash_values) = hash_values { + let hash_array = hash_values.as_primitive::(); + hash_array.values().as_ref() + } else { + // step 1: compute hashes + let batch_hashes = &mut self.hashes_buffer; + batch_hashes.clear(); + batch_hashes.resize(arr.len(), 0); + create_hashes(&[arr.clone()], &self.random_state, batch_hashes) + // hash is supported for all types and create_hashes only + // returns errors for unsupported types + .unwrap(); + batch_hashes + }; + self.map.insert_if_new( + batch_hashes, arr, // called for each new group - |_value| { + |_value, hash| { // assign new group index on each insert let group_idx = self.num_groups; self.num_groups += 1; + store_gp_hashes.push(hash); group_idx }, // called for each group @@ -68,6 +105,8 @@ impl GroupValues for GroupValuesBytesView { }, ); + self.group_hashes = Some(store_gp_hashes); + // ensure we assigned a group to for each row assert_eq!(groups.len(), arr.len()); Ok(()) @@ -85,10 +124,19 @@ impl GroupValues for GroupValuesBytesView { self.num_groups } - fn 
emit(&mut self, emit_to: EmitTo) -> datafusion_common::Result> { + fn emit( + &mut self, + emit_to: EmitTo, + mode: AggregateMode, + ) -> datafusion_common::Result> { // Reset the map to default, and convert it into a single array let map_contents = self.map.take().into_state(); + let mut group_hashes = self + .group_hashes + .take() + .expect("Can not emit from empty rows for hashes"); + let group_values = match emit_to { EmitTo::All => { self.num_groups -= map_contents.len(); @@ -107,9 +155,17 @@ impl GroupValues for GroupValuesBytesView { let remaining_group_values = map_contents.slice(n, map_contents.len() - n); + let remaining_group_hashes = group_hashes.split_off(n); + let hash_array = + Arc::new(UInt64Array::from(remaining_group_hashes)) as ArrayRef; + self.num_groups = 0; let mut group_indexes = vec![]; - self.intern(&[remaining_group_values], &mut group_indexes)?; + self.intern( + &[remaining_group_values], + &mut group_indexes, + Some(&hash_array), + )?; // Verify that the group indexes were assigned in the correct order assert_eq!(0, group_indexes[0]); @@ -118,7 +174,13 @@ impl GroupValues for GroupValuesBytesView { } }; - Ok(vec![group_values]) + let mut output = vec![group_values]; + if mode == AggregateMode::Partial { + let arr = Arc::new(UInt64Array::from(group_hashes)) as ArrayRef; + output.push(arr); + } + + Ok(output) } fn clear_shrink(&mut self, _batch: &RecordBatch) { diff --git a/datafusion/physical-plan/src/aggregates/group_values/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/mod.rs index be7ac934d7bc..e471821e680a 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/mod.rs @@ -33,10 +33,17 @@ mod bytes_view; use bytes::GroupValuesByes; use datafusion_physical_expr::binary_map::OutputType; +use super::AggregateMode; + /// An interning store for group keys pub trait GroupValues: Send { /// Calculates the `groups` for each input row of `cols` - fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()>; + fn intern( + &mut self, + cols: &[ArrayRef], + groups: &mut Vec, + hash_values: Option<&ArrayRef>, + ) -> Result<()>; /// Returns the number of bytes used by this [`GroupValues`] fn size(&self) -> usize; @@ -48,7 +55,7 @@ pub trait GroupValues: Send { fn len(&self) -> usize; /// Emits the group values - fn emit(&mut self, emit_to: EmitTo) -> Result>; + fn emit(&mut self, emit_to: EmitTo, mode: AggregateMode) -> Result>; /// Clear the contents and shrink the capacity to the size of the batch (free up memory usage) fn clear_shrink(&mut self, batch: &RecordBatch); diff --git a/datafusion/physical-plan/src/aggregates/group_values/primitive.rs b/datafusion/physical-plan/src/aggregates/group_values/primitive.rs index d5b7f1b11ac5..a05d349233b8 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/primitive.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/primitive.rs @@ -16,13 +16,16 @@ // under the License. 
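The `GroupValues` trait change in `group_values/mod.rs` above reshapes the interning flow: `intern` optionally takes an already-computed hash column, and `emit` now needs the `AggregateMode` because only the partial stage appends the stored per-group hashes as an extra `UInt64Array` column. A rough sketch of how a caller inside the crate is expected to drive the new signatures; `drive` and its argument names are made up, while the "group_hash_value" literal matches the `GROUP_HASH_VALUE_COLUMN_NAME` constant added in `common.rs` later in this diff:

use arrow_array::{ArrayRef, RecordBatch};
use datafusion_common::Result;
use datafusion_expr::EmitTo;

// `GroupValues` and `AggregateMode` are the crate-internal items touched by this
// patch; this free function only illustrates the intended call sequence.
fn drive(
    group_values: &mut dyn GroupValues,
    group_cols: &[ArrayRef],
    input: &RecordBatch,
    mode: AggregateMode,
) -> Result<Vec<ArrayRef>> {
    // A Final/FinalPartitioned stage receives batches that already carry the hash
    // column emitted by the partial stage; reuse it instead of rehashing the keys.
    let reused_hashes = input.column_by_name("group_hash_value");

    let mut group_indices = vec![];
    group_values.intern(group_cols, &mut group_indices, reused_hashes)?;

    // In Partial mode the emitted columns end with one extra UInt64 hash column;
    // in every other mode the output shape is unchanged.
    group_values.emit(EmitTo::All, mode)
}

`row_hash.rs` below follows this pattern: it pulls the hash column off the input batch (present whenever the child is a partial aggregate) and passes it straight into `intern`.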
use crate::aggregates::group_values::GroupValues; +use crate::aggregates::AggregateMode; use ahash::RandomState; use arrow::array::BooleanBufferBuilder; use arrow::buffer::NullBuffer; use arrow::datatypes::i256; use arrow::record_batch::RecordBatch; use arrow_array::cast::AsArray; -use arrow_array::{ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, PrimitiveArray}; +use arrow_array::{ + ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, PrimitiveArray, UInt64Array, +}; use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; use arrow_schema::DataType; use datafusion_common::Result; @@ -92,6 +95,8 @@ pub struct GroupValuesPrimitive { values: Vec, /// The random state used to generate hashes random_state: RandomState, + + group_hashes: Option>, } impl GroupValuesPrimitive { @@ -102,7 +107,8 @@ impl GroupValuesPrimitive { map: RawTable::with_capacity(128), values: Vec::with_capacity(128), null_group: None, - random_state: Default::default(), + random_state: ahash::RandomState::with_seeds(0, 0, 0, 0), + group_hashes: Default::default(), } } } @@ -111,15 +117,27 @@ impl GroupValues for GroupValuesPrimitive where T::Native: HashValue, { - fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()> { + fn intern( + &mut self, + cols: &[ArrayRef], + groups: &mut Vec, + _hash_values: Option<&ArrayRef>, + ) -> Result<()> { assert_eq!(cols.len(), 1); groups.clear(); + let mut store_gp_hashes = match self.group_hashes.take() { + Some(group_hashes) => group_hashes, + None => vec![], + }; + for v in cols[0].as_primitive::() { let group_id = match v { None => *self.null_group.get_or_insert_with(|| { let group_id = self.values.len(); self.values.push(Default::default()); + // hash value 0 for null + store_gp_hashes.push(0); group_id }), Some(key) => { @@ -139,6 +157,7 @@ where let g = self.values.len(); self.map.insert_in_slot(hash, slot, g); self.values.push(key); + store_gp_hashes.push(hash); g } } @@ -147,6 +166,9 @@ where }; groups.push(group_id) } + + self.group_hashes = Some(store_gp_hashes); + Ok(()) } @@ -162,7 +184,7 @@ where self.values.len() } - fn emit(&mut self, emit_to: EmitTo) -> Result> { + fn emit(&mut self, emit_to: EmitTo, mode: AggregateMode) -> Result> { fn build_primitive( values: Vec, null_idx: Option, @@ -176,12 +198,21 @@ where PrimitiveArray::::new(values.into(), nulls) } + let mut group_hashes = self + .group_hashes + .take() + .expect("Can not emit from empty rows"); + + let mut remaining_group_hashes = vec![]; + let array: PrimitiveArray = match emit_to { EmitTo::All => { self.map.clear(); build_primitive(std::mem::take(&mut self.values), self.null_group.take()) } EmitTo::First(n) => { + remaining_group_hashes = group_hashes.split_off(n); + // SAFETY: self.map outlives iterator and is not modified concurrently unsafe { for bucket in self.map.iter() { @@ -207,7 +238,18 @@ where build_primitive(split, null_group) } }; - Ok(vec![Arc::new(array.with_data_type(self.data_type.clone()))]) + + let mut output = + vec![Arc::new(array.with_data_type(self.data_type.clone())) as ArrayRef]; + + if mode == AggregateMode::Partial { + let arr = Arc::new(UInt64Array::from(group_hashes)) as ArrayRef; + output.push(arr); + } + + self.group_hashes = Some(remaining_group_hashes); + + Ok(output) } fn clear_shrink(&mut self, batch: &RecordBatch) { diff --git a/datafusion/physical-plan/src/aggregates/group_values/row.rs b/datafusion/physical-plan/src/aggregates/group_values/row.rs index 8c2a4ba5c497..d65e100042e6 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/row.rs +++ 
b/datafusion/physical-plan/src/aggregates/group_values/row.rs @@ -15,12 +15,17 @@ // specific language governing permissions and limitations // under the License. +use std::sync::Arc; + use crate::aggregates::group_values::GroupValues; +use crate::aggregates::AggregateMode; use ahash::RandomState; +use arrow::array::AsArray; use arrow::compute::cast; +use arrow::datatypes::UInt64Type; use arrow::record_batch::RecordBatch; use arrow::row::{RowConverter, Rows, SortField}; -use arrow_array::{Array, ArrayRef}; +use arrow_array::{Array, ArrayRef, UInt64Array}; use arrow_schema::{DataType, SchemaRef}; use datafusion_common::hash_utils::create_hashes; use datafusion_common::{DataFusionError, Result}; @@ -58,6 +63,7 @@ pub struct GroupValuesRows { /// /// [`Row`]: arrow::row::Row group_values: Option, + group_hashes: Option>, /// reused buffer to store hashes hashes_buffer: Vec, @@ -91,15 +97,21 @@ impl GroupValuesRows { map, map_size: 0, group_values: None, + group_hashes: Default::default(), hashes_buffer: Default::default(), rows_buffer, - random_state: Default::default(), + random_state: RandomState::with_seeds(0, 0, 0, 0), }) } } impl GroupValues for GroupValuesRows { - fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()> { + fn intern( + &mut self, + cols: &[ArrayRef], + groups: &mut Vec, + hash_values: Option<&ArrayRef>, + ) -> Result<()> { // Convert the group keys into the row format let group_rows = &mut self.rows_buffer; group_rows.clear(); @@ -111,45 +123,86 @@ impl GroupValues for GroupValuesRows { None => self.row_converter.empty_rows(0, 0), }; + let mut store_gp_hashes = match self.group_hashes.take() { + Some(group_hashes) => group_hashes, + None => vec![], + }; + // tracks to which group each of the input rows belongs groups.clear(); - // 1.1 Calculate the group keys for the group values - let batch_hashes = &mut self.hashes_buffer; - batch_hashes.clear(); - batch_hashes.resize(n_rows, 0); - create_hashes(cols, &self.random_state, batch_hashes)?; - - for (row, &hash) in batch_hashes.iter().enumerate() { - let entry = self.map.get_mut(hash, |(_hash, group_idx)| { - // verify that a group that we are inserting with hash is - // actually the same key value as the group in - // existing_idx (aka group_values @ row) - group_rows.row(row) == group_values.row(*group_idx) - }); - - let group_idx = match entry { - // Existing group_index for this group value - Some((_hash, group_idx)) => *group_idx, - // 1.2 Need to create new entry for the group - None => { - // Add new entry to aggr_state and save newly created index - let group_idx = group_values.num_rows(); - group_values.push(group_rows.row(row)); - - // for hasher function, use precomputed hash value - self.map.insert_accounted( - (hash, group_idx), - |(hash, _group_index)| *hash, - &mut self.map_size, - ); - group_idx - } - }; - groups.push(group_idx); + if let Some(hash_values) = hash_values { + let hash_array = hash_values.as_primitive::(); + for (row, hash) in hash_array.iter().enumerate() { + let hash = hash.unwrap(); + let entry = self.map.get_mut(hash, |(_hash, group_idx)| { + // verify that a group that we are inserting with hash is + // actually the same key value as the group in + // existing_idx (aka group_values @ row) + group_rows.row(row) == group_values.row(*group_idx) + }); + + let group_idx = match entry { + // Existing group_index for this group value + Some((_hash, group_idx)) => *group_idx, + // 1.2 Need to create new entry for the group + None => { + // Add new entry to aggr_state and save 
newly created index + let group_idx = group_values.num_rows(); + group_values.push(group_rows.row(row)); + store_gp_hashes.push(hash); + + // for hasher function, use precomputed hash value + self.map.insert_accounted( + (hash, group_idx), + |(hash, _group_index)| *hash, + &mut self.map_size, + ); + group_idx + } + }; + groups.push(group_idx); + } + } else { + // 1.1 Calculate the group keys for the group values + let batch_hashes = &mut self.hashes_buffer; + batch_hashes.clear(); + batch_hashes.resize(n_rows, 0); + create_hashes(cols, &self.random_state, batch_hashes)?; + + for (row, &hash) in batch_hashes.iter().enumerate() { + let entry = self.map.get_mut(hash, |(_hash, group_idx)| { + // verify that a group that we are inserting with hash is + // actually the same key value as the group in + // existing_idx (aka group_values @ row) + group_rows.row(row) == group_values.row(*group_idx) + }); + + let group_idx = match entry { + // Existing group_index for this group value + Some((_hash, group_idx)) => *group_idx, + // 1.2 Need to create new entry for the group + None => { + // Add new entry to aggr_state and save newly created index + let group_idx = group_values.num_rows(); + group_values.push(group_rows.row(row)); + store_gp_hashes.push(hash); + + // for hasher function, use precomputed hash value + self.map.insert_accounted( + (hash, group_idx), + |(hash, _group_index)| *hash, + &mut self.map_size, + ); + group_idx + } + }; + groups.push(group_idx); + } } self.group_values = Some(group_values); + self.group_hashes = Some(store_gp_hashes); Ok(()) } @@ -174,21 +227,40 @@ impl GroupValues for GroupValuesRows { .unwrap_or(0) } - fn emit(&mut self, emit_to: EmitTo) -> Result> { + fn emit(&mut self, emit_to: EmitTo, mode: AggregateMode) -> Result> { let mut group_values = self .group_values .take() - .expect("Can not emit from empty rows"); + .expect("Can not emit from empty rows for values"); + + let mut group_hashes = self + .group_hashes + .take() + .expect("Can not emit from empty rows for hashes"); let mut output = match emit_to { EmitTo::All => { - let output = self.row_converter.convert_rows(&group_values)?; + let mut output = self.row_converter.convert_rows(&group_values)?; + if mode == AggregateMode::Partial { + let gh = std::mem::take(&mut group_hashes); + let arr = Arc::new(UInt64Array::from(gh)) as ArrayRef; + output.push(arr); + } group_values.clear(); + output } EmitTo::First(n) => { let groups_rows = group_values.iter().take(n); - let output = self.row_converter.convert_rows(groups_rows)?; + let mut output = self.row_converter.convert_rows(groups_rows)?; + + if mode == AggregateMode::Partial { + let remain_group_hashes = group_hashes.split_off(n); + let arr = Arc::new(UInt64Array::from(group_hashes)) as ArrayRef; + group_hashes = remain_group_hashes; + output.push(arr); + } + // Clear out first n group keys by copying them to a new Rows. // TODO file some ticket in arrow-rs to make this more efficient? 
let mut new_group_values = self.row_converter.empty_rows(0, 0); @@ -228,6 +300,7 @@ impl GroupValues for GroupValuesRows { } self.group_values = Some(group_values); + self.group_hashes = Some(group_hashes); Ok(output) } @@ -237,6 +310,10 @@ impl GroupValues for GroupValuesRows { rows.clear(); rows }); + self.group_hashes = self.group_hashes.take().map(|mut gp_hashes| { + gp_hashes.clear(); + gp_hashes + }); self.map.clear(); self.map.shrink_to(count, |_| 0); // hasher does not matter since the map is cleared self.map_size = self.map.capacity() * std::mem::size_of::<(u64, usize)>(); diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index d1152038eb2a..7ecc7edf441f 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -25,6 +25,7 @@ use crate::aggregates::{ no_grouping::AggregateStream, row_hash::GroupedHashAggregateStream, topk_stream::GroupedTopKAggregateStream, }; +use crate::common::GROUP_HASH_VALUE_COLUMN_NAME; use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; use crate::windows::get_ordered_partition_by_indices; use crate::{ @@ -35,6 +36,7 @@ use crate::{ use arrow::array::ArrayRef; use arrow::datatypes::{Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; +use arrow_schema::DataType; use datafusion_common::stats::Precision; use datafusion_common::{internal_err, not_impl_err, Result}; use datafusion_execution::TaskContext; @@ -253,6 +255,7 @@ pub struct AggregateExec { pub input: Arc, /// Schema after the aggregate is applied schema: SchemaRef, + plain_schema: SchemaRef, /// Input schema before any aggregation is applied. For partial aggregate this will be the /// same as input.schema() but for the final aggregate it will be the same as the input /// to the partial aggregate, i.e., partial and final aggregates have same `input_schema`. 
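With the `create_schema` change further down in this file, a partial aggregate's output schema gains one extra `group_hash_value: UInt64` field between the group columns and the accumulator state columns, and `AggregateExec` keeps a second `plain_schema`/`plain_cache` without it for the top-k (limit) path. A sketch of what the partial-stage schema looks like for `SELECT a, b, COUNT(*) FROM t GROUP BY a, b`; the field names and nullability are illustrative (the real state-field names come from each expression's `state_fields()`):

use arrow::datatypes::{DataType, Field, Schema};

// Illustrative partial-stage output schema after this change. The hash column
// sits between the group columns and the accumulator state columns; it is
// consumed by RepartitionExec and the final aggregate, never surfaced to users.
fn partial_schema_example() -> Schema {
    Schema::new(vec![
        // group columns (nullability follows the group expressions)
        Field::new("a", DataType::Int64, true),
        Field::new("b", DataType::Utf8, true),
        // appended only when mode == Partial, there is a GROUP BY, and no limit is set
        Field::new("group_hash_value", DataType::UInt64, true),
        // partial COUNT(*) state; placeholder name
        Field::new("count(*)[count]", DataType::Int64, true),
    ])
}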
@@ -265,6 +268,7 @@ pub struct AggregateExec { /// Describes how the input is ordered relative to the group by columns input_order_mode: InputOrderMode, cache: PlanProperties, + plain_cache: PlanProperties, } impl AggregateExec { @@ -279,12 +283,14 @@ impl AggregateExec { metrics: ExecutionPlanMetricsSet::new(), input_order_mode: self.input_order_mode.clone(), cache: self.cache.clone(), + plain_cache: self.plain_cache.clone(), mode: self.mode, group_by: self.group_by.clone(), filter_expr: self.filter_expr.clone(), limit: self.limit, input: Arc::clone(&self.input), schema: Arc::clone(&self.schema), + plain_schema: Arc::clone(&self.schema), input_schema: Arc::clone(&self.input_schema), } } @@ -308,9 +314,20 @@ impl AggregateExec { &aggr_expr, group_by.contains_null(), mode, + false, + )?; + // HACK for GroupStreamTopK + let plain_schema = create_schema( + &input.schema(), + &group_by.expr, + &aggr_expr, + group_by.contains_null(), + mode, + true, )?; let schema = Arc::new(schema); + let plain_schema = Arc::new(plain_schema); AggregateExec::try_new_with_schema( mode, group_by, @@ -319,6 +336,7 @@ impl AggregateExec { input, input_schema, schema, + plain_schema, ) } @@ -339,6 +357,7 @@ impl AggregateExec { input: Arc, input_schema: SchemaRef, schema: SchemaRef, + plain_schema: SchemaRef, ) -> Result { // Make sure arguments are consistent in size if aggr_expr.len() != filter_expr.len() { @@ -404,6 +423,13 @@ impl AggregateExec { &mode, &input_order_mode, ); + let plain_cache = Self::compute_properties( + &input, + Arc::clone(&plain_schema), + &projection_mapping, + &mode, + &input_order_mode, + ); Ok(AggregateExec { mode, @@ -412,12 +438,14 @@ impl AggregateExec { filter_expr, input, schema, + plain_schema, input_schema, metrics: ExecutionPlanMetricsSet::new(), required_input_ordering, limit: None, input_order_mode, cache, + plain_cache, }) } @@ -429,6 +457,17 @@ impl AggregateExec { /// Set the `limit` of this AggExec pub fn with_limit(mut self, limit: Option) -> Self { self.limit = limit; + + // HACK: remove hash_value in schema, since we can't identify whether is has limit while creating schema + std::mem::swap(&mut self.schema, &mut self.plain_schema); + std::mem::swap(&mut self.cache, &mut self.plain_cache); + + // revert back + if self.is_unordered_unfiltered_group_by_distinct() { + std::mem::swap(&mut self.schema, &mut self.plain_schema); + std::mem::swap(&mut self.cache, &mut self.plain_cache); + } + self } /// Grouping expressions @@ -709,6 +748,7 @@ impl ExecutionPlan for AggregateExec { Arc::clone(&children[0]), Arc::clone(&self.input_schema), Arc::clone(&self.schema), + Arc::clone(&self.plain_schema), )?; me.limit = self.limit; @@ -782,7 +822,10 @@ fn create_schema( aggr_expr: &[Arc], contains_null_expr: bool, mode: AggregateMode, + no_hash_value: bool, ) -> Result { + // let group_schema = group_schema(&input_schema, group_expr.len()); + let mut fields = Vec::with_capacity(group_expr.len() + aggr_expr.len()); for (expr, name) in group_expr { fields.push(Field::new( @@ -797,6 +840,14 @@ fn create_schema( match mode { AggregateMode::Partial => { + if !group_expr.is_empty() && !no_hash_value { + fields.push(Field::new( + GROUP_HASH_VALUE_COLUMN_NAME, + DataType::UInt64, + true, + )); + } + // in partial mode, the fields of the accumulator's state for expr in aggr_expr { fields.extend(expr.state_fields()?.iter().cloned()) diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 167ca7240750..6a15190dcbbb 100644 --- 
a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -27,7 +27,7 @@ use crate::aggregates::{ evaluate_group_by, evaluate_many, evaluate_optional, group_schema, AggregateMode, PhysicalGroupBy, }; -use crate::common::IPCWriter; +use crate::common::{IPCWriter, GROUP_HASH_VALUE_COLUMN_NAME}; use crate::metrics::{BaselineMetrics, RecordOutput}; use crate::sorts::sort::sort_batch; use crate::sorts::streaming_merge; @@ -302,13 +302,13 @@ impl GroupedHashAggregateStream { let aggregate_arguments = aggregates::aggregate_expressions( &agg.aggr_expr, &agg.mode, - agg_group_by.expr.len(), + agg_group_by.expr.len() + 1, // +1 for hash values )?; // arguments for aggregating spilled data is the same as the one for final aggregation let merging_aggregate_arguments = aggregates::aggregate_expressions( &agg.aggr_expr, &AggregateMode::Final, - agg_group_by.expr.len(), + agg_group_by.expr.len() + 1, )?; let filter_expressions = match agg.mode { @@ -537,6 +537,8 @@ impl GroupedHashAggregateStream { evaluate_many(&self.aggregate_arguments, &batch)? }; + let hash_values = batch.column_by_name(GROUP_HASH_VALUE_COLUMN_NAME); + // Evaluate the filter expressions, if any, against the inputs let filter_values = if self.spill_state.is_stream_merging { let filter_expressions = vec![None; self.accumulators.len()]; @@ -548,8 +550,12 @@ impl GroupedHashAggregateStream { for group_values in &group_by_values { // calculate the group indices for each input row let starting_num_groups = self.group_values.len(); - self.group_values - .intern(group_values, &mut self.current_group_indices)?; + + self.group_values.intern( + group_values, + &mut self.current_group_indices, + hash_values, + )?; let group_indices = &self.current_group_indices; // Update ordering information if necessary @@ -634,7 +640,7 @@ impl GroupedHashAggregateStream { return Ok(RecordBatch::new_empty(schema)); } - let mut output = self.group_values.emit(emit_to)?; + let mut output = self.group_values.emit(emit_to, self.mode)?; if let EmitTo::First(n) = emit_to { self.group_ordering.remove_groups(n); } @@ -658,6 +664,7 @@ impl GroupedHashAggregateStream { // emit reduces the memory usage. Ignore Err from update_memory_reservation. Even if it is // over the target memory size after emission, we can emit again rather than returning Err. 
let _ = self.update_memory_reservation(); + let batch = RecordBatch::try_new(schema, output)?; Ok(batch) } diff --git a/datafusion/physical-plan/src/common.rs b/datafusion/physical-plan/src/common.rs index 4b5eea6b760d..f4d661b0495b 100644 --- a/datafusion/physical-plan/src/common.rs +++ b/datafusion/physical-plan/src/common.rs @@ -40,6 +40,8 @@ use parking_lot::Mutex; /// [`MemoryReservation`] used across query execution streams pub(crate) type SharedMemoryReservation = Arc>; +pub(crate) const GROUP_HASH_VALUE_COLUMN_NAME: &str = "group_hash_value"; + /// Create a vector of record batches from a stream pub async fn collect(stream: SendableRecordBatchStream) -> Result> { stream.try_collect::>().await diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs index f09324c4019c..7b4e095314d3 100644 --- a/datafusion/physical-plan/src/repartition/mod.rs +++ b/datafusion/physical-plan/src/repartition/mod.rs @@ -29,7 +29,7 @@ use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; use super::{ DisplayAs, ExecutionPlanProperties, RecordBatchStream, SendableRecordBatchStream, }; -use crate::hash_utils::create_hashes; +use crate::common::GROUP_HASH_VALUE_COLUMN_NAME; use crate::metrics::BaselineMetrics; use crate::repartition::distributor_channels::{ channels, partition_aware_channels, DistributionReceiver, DistributionSender, @@ -38,9 +38,10 @@ use crate::sorts::streaming_merge; use crate::stream::RecordBatchStreamAdapter; use crate::{DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, Statistics}; -use arrow::array::{ArrayRef, UInt64Builder}; -use arrow::datatypes::SchemaRef; +use arrow::array::{ArrayRef, AsArray, UInt64Builder}; +use arrow::datatypes::{SchemaRef, UInt64Type}; use arrow::record_batch::RecordBatch; +use datafusion_common::hash_utils::create_hashes; use datafusion_common::utils::transpose; use datafusion_common::{arrow_datafusion_err, not_impl_err, DataFusionError, Result}; use datafusion_common_runtime::SpawnedTask; @@ -264,23 +265,37 @@ impl BatchPartitioner { // Tracking time required for distributing indexes across output partitions let timer = self.timer.timer(); - let arrays = exprs - .iter() - .map(|expr| expr.evaluate(&batch)?.into_array(batch.num_rows())) - .collect::>>()?; - - hash_buffer.clear(); - hash_buffer.resize(batch.num_rows(), 0); - - create_hashes(&arrays, random_state, hash_buffer)?; - let mut indices: Vec<_> = (0..*partitions) .map(|_| UInt64Builder::with_capacity(batch.num_rows())) .collect(); - for (index, hash) in hash_buffer.iter().enumerate() { - indices[(*hash % *partitions as u64) as usize] - .append_value(index as u64); + if let Some(hash_values) = + batch.column_by_name(GROUP_HASH_VALUE_COLUMN_NAME) + { + let hash_array = hash_values.as_primitive::(); + for (index, hash) in hash_array.iter().enumerate() { + let hash = hash.unwrap(); + indices[(hash % *partitions as u64) as usize] + .append_value(index as u64); + } + } else { + // Some queries do repartition first + let arrays = exprs + .iter() + .map(|expr| { + expr.evaluate(&batch)?.into_array(batch.num_rows()) + }) + .collect::>>()?; + + hash_buffer.clear(); + hash_buffer.resize(batch.num_rows(), 0); + + create_hashes(&arrays, random_state, hash_buffer)?; + + for (index, &hash) in hash_buffer.iter().enumerate() { + indices[(hash % *partitions as u64) as usize] + .append_value(index as u64); + } } // Finished building index-arrays for output partitions
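Finally, `BatchPartitioner` reuses the emitted hashes: if the incoming batch carries the `group_hash_value` column it assigns rows to output partitions by `hash % partitions` directly, and only falls back to evaluating the partition expressions and calling `create_hashes` when the column is absent (for example, a repartition that runs before any aggregation). A condensed sketch of that reuse branch; `partition_targets` is a hypothetical helper, and the column name is written out instead of `GROUP_HASH_VALUE_COLUMN_NAME`:

use arrow::array::AsArray;
use arrow::datatypes::UInt64Type;
use arrow_array::RecordBatch;

// Sketch of the reuse branch added to BatchPartitioner above: when the batch comes
// from a partial aggregate it already carries one hash per row, so the target
// partition is just a modulo away.
fn partition_targets(batch: &RecordBatch, partitions: u64) -> Option<Vec<u64>> {
    let hashes = batch.column_by_name("group_hash_value")?;
    let hashes = hashes.as_primitive::<UInt64Type>();
    // The hash column is never null, so iterating the raw values is fine here.
    Some(hashes.values().iter().map(|&hash| hash % partitions).collect())
}

Skipping this second round of hashing is where the numbers recorded in `reuse_hash.rs` come from: [2.5999 s 6.3132 s 11.062 s] with the reuse path versus [4.1404 s 8.4601 s 13.226 s] on main.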