feat: add multi level merge for sorting #15608

Closed
8 changes: 3 additions & 5 deletions datafusion/physical-plan/src/aggregates/row_hash.rs
@@ -19,7 +19,7 @@

use std::sync::Arc;
use std::task::{Context, Poll};
-use std::vec;
+use std::{mem, vec};

use crate::aggregates::group_values::{new_group_values, GroupValues};
use crate::aggregates::order::GroupOrderingFull;
@@ -1066,13 +1066,11 @@ impl GroupedHashAggregateStream
sort_batch(&batch, expr.as_ref(), None)
})),
)));
-for spill in self.spill_state.spills.drain(..) {
-let stream = self.spill_state.spill_manager.read_spill_as_stream(spill)?;
-streams.push(stream);
-}
self.spill_state.is_stream_merging = true;
self.input = StreamingMergeBuilder::new()
.with_streams(streams)
+.with_sorted_spill_files(mem::take(&mut self.spill_state.spills))
+.with_spill_manager(self.spill_state.spill_manager.clone())
.with_schema(schema)
.with_expressions(self.spill_state.spill_expr.as_ref())
.with_metrics(self.baseline_metrics.clone())
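Both call sites in this PR (here and in sort.rs below) replace the manual read_spill_as_stream loop with the two new builder methods. A minimal sketch of the resulting pattern, with placeholder variable names, assembled from the builder calls visible in this diff and in the new stream below:

let input = StreamingMergeBuilder::new()
    // spill files are handed over as files, not open streams, so only a
    // bounded number need to be read at once
    .with_sorted_spill_files(mem::take(&mut spills))
    .with_spill_manager(spill_manager.clone())
    .with_schema(Arc::clone(&schema))
    .with_expressions(sort_expr.as_ref())
    .with_metrics(baseline_metrics.clone())
    .with_batch_size(batch_size)
    .with_reservation(reservation.new_empty())
    .build()?;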
1 change: 1 addition & 0 deletions datafusion/physical-plan/src/sorts/mod.rs
@@ -20,6 +20,7 @@
mod builder;
mod cursor;
mod merge;
+pub mod multi_level_sort_preserving_merge_stream;
pub mod partial_sort;
pub mod sort;
pub mod sort_preserving_merge;
244 changes: 244 additions & 0 deletions datafusion/physical-plan/src/sorts/multi_level_sort_preserving_merge_stream.rs
@@ -0,0 +1,244 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Order-preserving merge that can handle an arbitrary number of spilled files.
//! When there are more sorted spill files than available blocking threads, the
//! files are merged in multiple passes: each pass merges a bounded subset into
//! a new sorted spill file, until a single final merge can stream the result.

use crate::metrics::BaselineMetrics;
use crate::{EmptyRecordBatchStream, SpillManager};
use arrow::array::RecordBatch;
use arrow::datatypes::SchemaRef;
use datafusion_common::{internal_err, Result};
use datafusion_execution::memory_pool::MemoryReservation;
use std::mem;
use std::ops::Deref;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use crate::sorts::streaming_merge::StreamingMergeBuilder;
use crate::spill::in_progress_spill_file::InProgressSpillFile;
use crate::stream::RecordBatchStreamAdapter;
use datafusion_execution::disk_manager::RefCountedTempFile;
use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream};
use datafusion_physical_expr_common::sort_expr::LexOrdering;
use futures::{Stream, StreamExt};

enum State {
/// Had an error
Aborted,

/// Stream has not started yet, or is between merge passes
Uninitialized,

/// Merging multiple sorted streams while writing the result to a new spill file
MultiLevel {
stream: SendableRecordBatchStream,
in_progress_file: InProgressSpillFile,
},

/// This is the last level of the merge, just pass through the stream
Passthrough(SendableRecordBatchStream),
}

/// Sort-preserving merge over an arbitrary number of sorted spill files and
/// in-memory sorted streams, falling back to multiple merge passes when the
/// inputs exceed the available blocking threads
pub struct MultiLevelSortPreservingMergeStream {
schema: SchemaRef,
spill_manager: SpillManager,
sorted_spill_files: Vec<RefCountedTempFile>,
sorted_streams: Vec<SendableRecordBatchStream>,
expr: Arc<LexOrdering>,
metrics: BaselineMetrics,
batch_size: usize,
reservation: MemoryReservation,
fetch: Option<usize>,
enable_round_robin_tie_breaker: bool,

/// The number of blocking threads to use for merging sorted streams
max_blocking_threads: usize,

/// The current state of the stream
state: State,
}

impl MultiLevelSortPreservingMergeStream {
#[allow(clippy::too_many_arguments)]
pub fn new(
spill_manager: SpillManager,
schema: SchemaRef,
sorted_spill_files: Vec<RefCountedTempFile>,
sorted_streams: Vec<SendableRecordBatchStream>,
expr: LexOrdering,
metrics: BaselineMetrics,
batch_size: usize,
reservation: MemoryReservation,
max_blocking_threads: Option<usize>,
fetch: Option<usize>,
enable_round_robin_tie_breaker: bool,
) -> Result<Self> {
// TODO - add a check to see the actual number of available blocking threads
let max_blocking_threads = max_blocking_threads.unwrap_or(128);
Comment on lines +94 to +95
Contributor Author

Currently tokio doesn't expose the configured maximum number of blocking threads, and there is no way to know whether the limit has been reached, since spawn_blocking gives no indication that the pool is full.
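A minimal sketch, not part of this PR, of why the limit is opaque: tokio fixes the blocking-pool cap when the runtime is built (its own default is 512) and offers no API to read it back or to detect saturation, hence the hard-coded fallback of 128 here:

use tokio::runtime::Builder;

fn build_runtime() -> std::io::Result<tokio::runtime::Runtime> {
    Builder::new_multi_thread()
        // the cap is set here; nothing exposes it to running code afterwards
        .max_blocking_threads(128)
        .enable_all()
        .build()
}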


if max_blocking_threads <= 1 {
return internal_err!("max_blocking_threads must be greater than 1");
}

Ok(Self {
spill_manager,
schema,
sorted_spill_files,
sorted_streams,
expr: Arc::new(expr),
metrics,
batch_size,
reservation,
fetch,
state: State::Uninitialized,
enable_round_robin_tie_breaker,
max_blocking_threads,
})
}

fn create_sorted_stream(&mut self) -> Result<SendableRecordBatchStream> {
let mut sorted_streams = mem::take(&mut self.sorted_streams);

match (self.sorted_spill_files.len(), sorted_streams.len()) {
// No input at all, return an empty stream
(0, 0) => Ok(Box::pin(EmptyRecordBatchStream::new(Arc::clone(
&self.schema,
)))),

// Only in-memory stream
(0, 1) => Ok(sorted_streams.into_iter().next().unwrap()),

// Only a single sorted spill file, stream it directly
(1, 0) => self
.spill_manager
.read_spill_as_stream(self.sorted_spill_files.drain(..).next().unwrap()),
Comment on lines +126 to +132
Contributor Author

Need to add metrics in these cases
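A sketch of one way the TODO could be addressed for these two fast paths, assuming a hypothetical MetricsStream wrapper (reusing this file's imports); BaselineMetrics::record_poll is the existing hook for counting output rows:

struct MetricsStream {
    inner: SendableRecordBatchStream,
    metrics: BaselineMetrics,
}

impl Stream for MetricsStream {
    type Item = Result<RecordBatch>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        let poll = self.inner.poll_next_unpin(cx);
        // record_poll counts the rows in each produced batch
        self.metrics.record_poll(poll)
    }
}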


// Need to merge multiple streams
(_, _) => {
let sorted_spill_files_to_read =
self.sorted_spill_files.len().min(self.max_blocking_threads);

for spill in self.sorted_spill_files.drain(..sorted_spill_files_to_read) {
let stream = self.spill_manager.read_spill_as_stream(spill)?;
sorted_streams.push(Box::pin(RecordBatchStreamAdapter::new(
Arc::clone(&self.schema),
stream,
)));
}

StreamingMergeBuilder::new()
.with_schema(Arc::clone(&self.schema))
.with_expressions(self.expr.deref())
.with_batch_size(self.batch_size)
.with_fetch(self.fetch)
.with_metrics(if self.sorted_spill_files.is_empty() {
// Only add the metrics to the last run
self.metrics.clone()
} else {
self.metrics.intermediate()
})
.with_round_robin_tie_breaker(self.enable_round_robin_tie_breaker)
.with_streams(sorted_streams)
.with_reservation(self.reservation.new_empty())
.build()
}
}
}

fn poll_next_inner(
&mut self,
cx: &mut Context<'_>,
) -> Poll<Option<Result<RecordBatch>>> {
loop {
match &mut self.state {
State::Aborted => return Poll::Ready(None),
State::Uninitialized => {
let stream = self.create_sorted_stream()?;

if self.sorted_spill_files.is_empty() {
self.state = State::Passthrough(stream);
} else {
let in_progress_file =
self.spill_manager.create_in_progress_file("spill")?;

self.state = State::MultiLevel {
stream,
in_progress_file,
};
}
}
State::MultiLevel {
stream,
in_progress_file,
} => {
'write_sorted_run: loop {
match futures::ready!(stream.poll_next_unpin(cx)) {
// This stream is finished.
None => {
// finish the file and add it to the sorted spill files
if let Some(sorted_spill_file) =
in_progress_file.finish()?
{
self.sorted_spill_files.push(sorted_spill_file);
}

// Reset the state to create a stream from the current sorted spill files
self.state = State::Uninitialized;

break 'write_sorted_run;
}
Some(Err(e)) => {
self.state = State::Aborted;

// Abort
return Poll::Ready(Some(Err(e)));
}
Some(Ok(batch)) => {
// Got a batch, write it to file
in_progress_file.append_batch(&batch)?;
}
}
}
}

// Last level: stream the final merge directly
State::Passthrough(s) => return s.poll_next_unpin(cx),
}
}
}
}

impl Stream for MultiLevelSortPreservingMergeStream {
type Item = Result<RecordBatch>;

fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
self.poll_next_inner(cx)
}
}

impl RecordBatchStream for MultiLevelSortPreservingMergeStream {
fn schema(&self) -> SchemaRef {
Arc::clone(&self.schema)
}
}
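To make the multi-pass behavior concrete, a toy model (an assumption-level sketch mirroring create_sorted_stream's draining logic, ignoring in-memory streams): every intermediate pass merges max_blocking_threads spill files into one new sorted spill file, shrinking the file count by max_blocking_threads - 1 per pass until one final merge remains:

// Toy pass counter; `max_threads` stands in for max_blocking_threads.
fn passes_needed(mut files: usize, max_threads: usize) -> usize {
    let mut passes = 0;
    while files > max_threads {
        files = files - max_threads + 1; // merge `max_threads` files into 1
        passes += 1;
    }
    passes + 1 // plus the final merge, which is streamed directly
}

fn main() {
    // 700 spill files with 128 blocking threads: five intermediate passes
    // (700 -> 573 -> 446 -> 319 -> 192 -> 65), then the final merge.
    assert_eq!(passes_needed(700, 128), 6);
}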
19 changes: 5 additions & 14 deletions datafusion/physical-plan/src/sorts/sort.rs
@@ -20,9 +20,9 @@
//! but spills to disk if needed.

use std::any::Any;
-use std::fmt;
use std::fmt::{Debug, Formatter};
use std::sync::Arc;
+use std::{fmt, mem};

use crate::common::spawn_buffered;
use crate::execution_plan::{Boundedness, CardinalityEffect, EmissionType};
@@ -355,27 +355,18 @@ impl ExternalSorter
self.merge_reservation.free();

if self.spilled_before() {
-let mut streams = vec![];
-
// Sort `in_mem_batches` and spill it first. If there are many
// `in_mem_batches` and the memory limit is almost reached, merging
// them with the spilled files at the same time might cause OOM.
if !self.in_mem_batches.is_empty() {
self.sort_and_spill_in_mem_batches().await?;
}

-for spill in self.finished_spill_files.drain(..) {
-if !spill.path().exists() {
-return internal_err!("Spill file {:?} does not exist", spill.path());
-}
-let stream = self.spill_manager.read_spill_as_stream(spill)?;
-streams.push(stream);
-}
-
let expressions: LexOrdering = self.expr.iter().cloned().collect();

StreamingMergeBuilder::new()
-.with_streams(streams)
+.with_spill_manager(self.spill_manager.clone())
+.with_sorted_spill_files(self.finished_spill_files.drain(..).collect())
.with_schema(Arc::clone(&self.schema))
.with_expressions(expressions.as_ref())
.with_metrics(self.metrics.baseline.clone())
@@ -428,7 +419,7 @@ impl ExternalSorter

debug!("Spilling sort data of ExternalSorter to disk whilst inserting");

-let batches_to_spill = std::mem::take(globally_sorted_batches);
+let batches_to_spill = mem::take(globally_sorted_batches);
self.reservation.free();

let in_progress_file = self.in_progress_spill_file.as_mut().ok_or_else(|| {
@@ -683,7 +674,7 @@ impl ExternalSorter
return self.sort_batch_stream(batch, metrics, reservation);
}

-let streams = std::mem::take(&mut self.in_mem_batches)
+let streams = mem::take(&mut self.in_mem_batches)
.into_iter()
.map(|batch| {
let metrics = self.metrics.baseline.intermediate();