diff --git a/datafusion/physical-plan/src/spill/mod.rs b/datafusion/physical-plan/src/spill/mod.rs
index c32711a96b97..1101616a4106 100644
--- a/datafusion/physical-plan/src/spill/mod.rs
+++ b/datafusion/physical-plan/src/spill/mod.rs
@@ -23,28 +23,161 @@ pub(crate) mod spill_manager;
 use std::fs::File;
 use std::io::BufReader;
 use std::path::{Path, PathBuf};
+use std::pin::Pin;
 use std::ptr::NonNull;
+use std::sync::Arc;
+use std::task::{Context, Poll};
 
 use arrow::array::ArrayData;
 use arrow::datatypes::{Schema, SchemaRef};
 use arrow::ipc::{reader::StreamReader, writer::StreamWriter};
 use arrow::record_batch::RecordBatch;
-use tokio::sync::mpsc::Sender;
-
-use datafusion_common::{exec_datafusion_err, HashSet, Result};
-
-fn read_spill(sender: Sender<Result<RecordBatch>>, path: &Path) -> Result<()> {
-    let file = BufReader::new(File::open(path)?);
-    // SAFETY: DataFusion's spill writer strictly follows Arrow IPC specifications
-    // with validated schemas and buffers. Skip redundant validation during read
-    // to speedup read operation. This is safe for DataFusion as input guaranteed to be correct when written.
-    let reader = unsafe { StreamReader::try_new(file, None)?.with_skip_validation(true) };
-    for batch in reader {
-        sender
-            .blocking_send(batch.map_err(Into::into))
-            .map_err(|e| exec_datafusion_err!("{e}"))?;
+
+use datafusion_common::{exec_datafusion_err, DataFusionError, HashSet, Result};
+use datafusion_common_runtime::SpawnedTask;
+use datafusion_execution::disk_manager::RefCountedTempFile;
+use datafusion_execution::RecordBatchStream;
+use futures::{FutureExt as _, Stream};
+
+/// Stream that reads spill files from disk, where each batch is read in a spawned blocking task.
+/// It reads one batch at a time and does no buffering; to buffer data use [`crate::common::spawn_buffered`].
+///
+/// A simpler solution would be spawning a long-running blocking task for each
+/// file read (instead of each batch). This approach does not work because when
+/// the number of concurrent reads exceeds the Tokio thread pool limit,
+/// deadlocks can occur and block progress.
+struct SpillReaderStream {
+    schema: SchemaRef,
+    state: SpillReaderStreamState,
+}
+
+/// When we poll for the next batch, we will get back both the batch and the reader,
+/// so we can call `next` again.
+type NextRecordBatchResult = Result<(StreamReader<BufReader<File>>, Option<RecordBatch>)>;
+
+enum SpillReaderStreamState {
+    /// Initial state: the stream was not initialized yet
+    /// and the file was not opened
+    Uninitialized(RefCountedTempFile),
+
+    /// A read is in progress in a spawned blocking task for which we hold the handle.
+    ReadInProgress(SpawnedTask<NextRecordBatchResult>),
+
+    /// A read has finished and we wait to be polled again in order to start reading the next batch.
+    Waiting(StreamReader<BufReader<File>>),
+
+    /// The stream has finished, successfully or not.
+    Done,
+}
+
+impl SpillReaderStream {
+    fn new(schema: SchemaRef, spill_file: RefCountedTempFile) -> Self {
+        Self {
+            schema,
+            state: SpillReaderStreamState::Uninitialized(spill_file),
+        }
+    }
+
+    fn poll_next_inner(
+        &mut self,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<Result<RecordBatch>>> {
+        match &mut self.state {
+            SpillReaderStreamState::Uninitialized(_) => {
+                // Temporarily replace with `Done` to be able to pass the file to the task.
+                let SpillReaderStreamState::Uninitialized(spill_file) =
+                    std::mem::replace(&mut self.state, SpillReaderStreamState::Done)
+                else {
+                    unreachable!()
+                };
+
+                let task = SpawnedTask::spawn_blocking(move || {
+                    let file = BufReader::new(File::open(spill_file.path())?);
+                    // SAFETY: DataFusion's spill writer strictly follows Arrow IPC specifications
+                    // with validated schemas and buffers. Skip redundant validation during read
+                    // to speedup read operation. This is safe for DataFusion as input guaranteed to be correct when written.
+                    let mut reader = unsafe {
+                        StreamReader::try_new(file, None)?.with_skip_validation(true)
+                    };
+
+                    let next_batch = reader.next().transpose()?;
+
+                    Ok((reader, next_batch))
+                });
+
+                self.state = SpillReaderStreamState::ReadInProgress(task);
+
+                // Poll again immediately so the inner task is polled and the waker is
+                // registered.
+                self.poll_next_inner(cx)
+            }
+
+            SpillReaderStreamState::ReadInProgress(task) => {
+                let result = futures::ready!(task.poll_unpin(cx))
+                    .unwrap_or_else(|err| Err(DataFusionError::External(Box::new(err))));
+
+                match result {
+                    Ok((reader, batch)) => {
+                        match batch {
+                            Some(batch) => {
+                                self.state = SpillReaderStreamState::Waiting(reader);
+
+                                Poll::Ready(Some(Ok(batch)))
+                            }
+                            None => {
+                                // Stream is done
+                                self.state = SpillReaderStreamState::Done;
+
+                                Poll::Ready(None)
+                            }
+                        }
+                    }
+                    Err(err) => {
+                        self.state = SpillReaderStreamState::Done;
+
+                        Poll::Ready(Some(Err(err)))
+                    }
+                }
+            }
+
+            SpillReaderStreamState::Waiting(_) => {
+                // Temporarily replace with `Done` to be able to pass the reader to the task.
+                let SpillReaderStreamState::Waiting(mut reader) =
+                    std::mem::replace(&mut self.state, SpillReaderStreamState::Done)
+                else {
+                    unreachable!()
+                };
+
+                let task = SpawnedTask::spawn_blocking(move || {
+                    let next_batch = reader.next().transpose()?;
+
+                    Ok((reader, next_batch))
+                });
+
+                self.state = SpillReaderStreamState::ReadInProgress(task);
+
+                // Poll again immediately so the inner task is polled and the waker is
+                // registered.
+                self.poll_next_inner(cx)
+            }
+
+            SpillReaderStreamState::Done => Poll::Ready(None),
+        }
+    }
+}
+
+impl Stream for SpillReaderStream {
+    type Item = Result<RecordBatch>;
+
+    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        self.get_mut().poll_next_inner(cx)
+    }
+}
+
+impl RecordBatchStream for SpillReaderStream {
+    fn schema(&self) -> SchemaRef {
+        Arc::clone(&self.schema)
     }
-    Ok(())
 }
 
 /// Spill the `RecordBatch` to disk as smaller batches
@@ -205,6 +338,7 @@ mod tests {
     use arrow::record_batch::RecordBatch;
     use datafusion_common::Result;
     use datafusion_execution::runtime_env::RuntimeEnv;
+    use futures::StreamExt as _;
 
     use std::sync::Arc;
 
@@ -604,4 +738,42 @@ mod tests {
 
         Ok(())
     }
+
+    #[test]
+    fn test_reading_more_spills_than_tokio_blocking_threads() -> Result<()> {
+        tokio::runtime::Builder::new_current_thread()
+            .enable_all()
+            .max_blocking_threads(1)
+            .build()
+            .unwrap()
+            .block_on(async {
+                let batch = build_table_i32(
+                    ("a2", &vec![0, 1, 2]),
+                    ("b2", &vec![3, 4, 5]),
+                    ("c2", &vec![4, 5, 6]),
+                );
+
+                let schema = batch.schema();
+
+                // Construct SpillManager
+                let env = Arc::new(RuntimeEnv::default());
+                let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
+                let spill_manager = SpillManager::new(env, metrics, Arc::clone(&schema));
+                let batches: [_; 10] = std::array::from_fn(|_| batch.clone());
+
+                let spill_file_1 = spill_manager
+                    .spill_record_batch_and_finish(&batches, "Test1")?
+                    .unwrap();
+                let spill_file_2 = spill_manager
+                    .spill_record_batch_and_finish(&batches, "Test2")?
+                    .unwrap();
+
+                let mut stream_1 = spill_manager.read_spill_as_stream(spill_file_1)?;
+                let mut stream_2 = spill_manager.read_spill_as_stream(spill_file_2)?;
+                stream_1.next().await;
+                stream_2.next().await;
+
+                Ok(())
+            })
+    }
 }
diff --git a/datafusion/physical-plan/src/spill/spill_manager.rs b/datafusion/physical-plan/src/spill/spill_manager.rs
index f2c6090f4bb0..78cd47a8bad0 100644
--- a/datafusion/physical-plan/src/spill/spill_manager.rs
+++ b/datafusion/physical-plan/src/spill/spill_manager.rs
@@ -27,10 +27,9 @@ use datafusion_common::Result;
 use datafusion_execution::disk_manager::RefCountedTempFile;
 use datafusion_execution::SendableRecordBatchStream;
 
-use crate::metrics::SpillMetrics;
-use crate::stream::RecordBatchReceiverStream;
+use crate::{common::spawn_buffered, metrics::SpillMetrics};
 
-use super::{in_progress_spill_file::InProgressSpillFile, read_spill};
+use super::{in_progress_spill_file::InProgressSpillFile, SpillReaderStream};
 
 /// The `SpillManager` is responsible for the following tasks:
 /// - Reading and writing `RecordBatch`es to raw files based on the provided configurations.
@@ -126,14 +125,11 @@ impl SpillManager {
         &self,
         spill_file_path: RefCountedTempFile,
     ) -> Result<SendableRecordBatchStream> {
-        let mut builder = RecordBatchReceiverStream::builder(
+        let stream = Box::pin(SpillReaderStream::new(
             Arc::clone(&self.schema),
-            self.batch_read_buffer_capacity,
-        );
-        let sender = builder.tx();
+            spill_file_path,
+        ));
 
-        builder.spawn_blocking(move || read_spill(sender, spill_file_path.path()));
-
-        Ok(builder.build())
+        Ok(spawn_buffered(stream, self.batch_read_buffer_capacity))
     }
 }
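
Illustrative sketch (reviewer addition, not part of the patch): the heart of this change is a poll-driven state machine that spawns one short-lived blocking task per item, so a stream never occupies a blocking-pool thread for its whole lifetime. The minimal, self-contained Rust program below shows the same pattern under stated assumptions: it uses plain `tokio::task::spawn_blocking` in place of `SpawnedTask`, and `BlockingCounter`, `State`, and the sleep standing in for disk I/O are hypothetical names, not DataFusion APIs.

use std::pin::Pin;
use std::task::{Context, Poll};

use futures::{FutureExt as _, Stream, StreamExt as _};
use tokio::task::JoinHandle;

// Yields 0..limit, computing each item in its own short-lived blocking task,
// mirroring SpillReaderStream's one-task-per-batch design.
struct BlockingCounter {
    next: u64,
    limit: u64,
    state: State,
}

enum State {
    // No blocking task is running; spawn one on the next poll.
    Idle,
    // A blocking task is computing the next item; we hold its handle.
    InProgress(JoinHandle<u64>),
    // All items have been produced.
    Done,
}

impl Stream for BlockingCounter {
    type Item = u64;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<u64>> {
        // `BlockingCounter` is `Unpin`, so taking the inner value is safe.
        let this = self.get_mut();
        loop {
            match &mut this.state {
                State::Idle => {
                    if this.next >= this.limit {
                        this.state = State::Done;
                        return Poll::Ready(None);
                    }
                    let value = this.next;
                    this.next += 1;
                    // Offload the (simulated) blocking work, then loop so the
                    // fresh handle is polled at once and the waker is registered.
                    this.state = State::InProgress(tokio::task::spawn_blocking(move || {
                        std::thread::sleep(std::time::Duration::from_millis(1));
                        value
                    }));
                }
                State::InProgress(handle) => {
                    let value = futures::ready!(handle.poll_unpin(cx))
                        .expect("blocking task panicked");
                    // The task has finished, freeing its blocking thread before
                    // the next item is requested: exactly one task per item.
                    this.state = State::Idle;
                    return Poll::Ready(Some(value));
                }
                State::Done => return Poll::Ready(None),
            }
        }
    }
}

#[tokio::main(flavor = "current_thread")]
async fn main() {
    let stream = BlockingCounter { next: 0, limit: 3, state: State::Idle };
    let values: Vec<u64> = stream.collect().await;
    assert_eq!(values, vec![0, 1, 2]);
}

As in `SpillReaderStream`, the freshly spawned handle is polled right away so the task's waker is registered, and because every task ends after a single item, any number of such streams can make progress even with `max_blocking_threads(1)`, which is exactly what the new regression test exercises.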