Skip to content

Commit 4bd7c13

Browse files
berkaysynnadamustafasrepoozankabak
authored
CrossJoin Refactor (#9830)
* First iteration * Wrap the logic inside function * Send batches in the size of left batches * Update cross_join.rs * fuzz tests * Update cross_join_fuzz.rs * Update cross_join_fuzz.rs * Test version 2 * Minor changes * Minor changes * Stateful implementation of CJ * Adding comments * Update cross_join_fuzz.rs * Update cross_join.rs * collect until batch size * tmp * revert changes * Preserve the join strategy, clean the algorithm and states * Update cross_join.rs * Review * Update cross_join.rs --------- Co-authored-by: Mustafa Akur <[email protected]> Co-authored-by: Mehmet Ozan Kabak <[email protected]>
1 parent 2f55003 commit 4bd7c13

File tree

1 file changed

+95
-47
lines changed

1 file changed

+95
-47
lines changed

datafusion/physical-plan/src/joins/cross_join.rs

Lines changed: 95 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -22,22 +22,23 @@ use std::{any::Any, sync::Arc, task::Poll};
2222

2323
use super::utils::{
2424
adjust_right_output_partitioning, BuildProbeJoinMetrics, OnceAsync, OnceFut,
25+
StatefulStreamResult,
2526
};
2627
use crate::coalesce_batches::concat_batches;
2728
use crate::coalesce_partitions::CoalescePartitionsExec;
2829
use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
29-
use crate::ExecutionPlanProperties;
3030
use crate::{
31-
execution_mode_from_children, ColumnStatistics, DisplayAs, DisplayFormatType,
32-
Distribution, ExecutionMode, ExecutionPlan, PlanProperties, RecordBatchStream,
31+
execution_mode_from_children, handle_state, ColumnStatistics, DisplayAs,
32+
DisplayFormatType, Distribution, ExecutionMode, ExecutionPlan,
33+
ExecutionPlanProperties, PlanProperties, RecordBatchStream,
3334
SendableRecordBatchStream, Statistics,
3435
};
3536

3637
use arrow::datatypes::{Fields, Schema, SchemaRef};
3738
use arrow::record_batch::RecordBatch;
3839
use arrow_array::RecordBatchOptions;
3940
use datafusion_common::stats::Precision;
40-
use datafusion_common::{JoinType, Result, ScalarValue};
41+
use datafusion_common::{internal_err, JoinType, Result, ScalarValue};
4142
use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
4243
use datafusion_execution::TaskContext;
4344
use datafusion_physical_expr::equivalence::join_equivalence_properties;
@@ -257,9 +258,10 @@ impl ExecutionPlan for CrossJoinExec {
257258
schema: self.schema.clone(),
258259
left_fut,
259260
right: stream,
260-
right_batch: Arc::new(parking_lot::Mutex::new(None)),
261261
left_index: 0,
262262
join_metrics,
263+
state: CrossJoinStreamState::WaitBuildSide,
264+
left_data: RecordBatch::new_empty(self.left().schema()),
263265
}))
264266
}
265267

@@ -319,16 +321,18 @@ fn stats_cartesian_product(
319321
struct CrossJoinStream {
320322
/// Input schema
321323
schema: Arc<Schema>,
322-
/// future for data from left side
324+
/// Future for data from left side
323325
left_fut: OnceFut<JoinLeftData>,
324-
/// right
326+
/// Right side stream
325327
right: SendableRecordBatchStream,
326328
/// Current value on the left
327329
left_index: usize,
328-
/// Current batch being processed from the right side
329-
right_batch: Arc<parking_lot::Mutex<Option<RecordBatch>>>,
330-
/// join execution metrics
330+
/// Join execution metrics
331331
join_metrics: BuildProbeJoinMetrics,
332+
/// State of the stream
333+
state: CrossJoinStreamState,
334+
/// Left data
335+
left_data: RecordBatch,
332336
}
333337

334338
impl RecordBatchStream for CrossJoinStream {
@@ -337,6 +341,25 @@ impl RecordBatchStream for CrossJoinStream {
337341
}
338342
}
339343

344+
/// Represents states of CrossJoinStream
345+
enum CrossJoinStreamState {
346+
WaitBuildSide,
347+
FetchProbeBatch,
348+
/// Holds the currently processed right side batch
349+
BuildBatches(RecordBatch),
350+
}
351+
352+
impl CrossJoinStreamState {
353+
/// Tries to extract RecordBatch from CrossJoinStreamState enum.
354+
/// Returns an error if state is not BuildBatches state.
355+
fn try_as_record_batch(&mut self) -> Result<&RecordBatch> {
356+
match self {
357+
CrossJoinStreamState::BuildBatches(rb) => Ok(rb),
358+
_ => internal_err!("Expected RecordBatch in BuildBatches state"),
359+
}
360+
}
361+
}
362+
340363
fn build_batch(
341364
left_index: usize,
342365
batch: &RecordBatch,
@@ -384,58 +407,83 @@ impl CrossJoinStream {
384407
&mut self,
385408
cx: &mut std::task::Context<'_>,
386409
) -> std::task::Poll<Option<Result<RecordBatch>>> {
410+
loop {
411+
return match self.state {
412+
CrossJoinStreamState::WaitBuildSide => {
413+
handle_state!(ready!(self.collect_build_side(cx)))
414+
}
415+
CrossJoinStreamState::FetchProbeBatch => {
416+
handle_state!(ready!(self.fetch_probe_batch(cx)))
417+
}
418+
CrossJoinStreamState::BuildBatches(_) => {
419+
handle_state!(self.build_batches())
420+
}
421+
};
422+
}
423+
}
424+
425+
/// Collects build (left) side of the join into the state. In case of an empty build batch,
426+
/// the execution terminates. Otherwise, the state is updated to fetch probe (right) batch.
427+
fn collect_build_side(
428+
&mut self,
429+
cx: &mut std::task::Context<'_>,
430+
) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
387431
let build_timer = self.join_metrics.build_time.timer();
388432
let (left_data, _) = match ready!(self.left_fut.get(cx)) {
389433
Ok(left_data) => left_data,
390-
Err(e) => return Poll::Ready(Some(Err(e))),
434+
Err(e) => return Poll::Ready(Err(e)),
391435
};
392436
build_timer.done();
393437

394-
if left_data.num_rows() == 0 {
395-
return Poll::Ready(None);
396-
}
438+
let result = if left_data.num_rows() == 0 {
439+
StatefulStreamResult::Ready(None)
440+
} else {
441+
self.left_data = left_data.clone();
442+
self.state = CrossJoinStreamState::FetchProbeBatch;
443+
StatefulStreamResult::Continue
444+
};
445+
Poll::Ready(Ok(result))
446+
}
447+
448+
/// Fetches the probe (right) batch, updates the metrics, and save the batch in the state.
449+
/// Then, the state is updated to build result batches.
450+
fn fetch_probe_batch(
451+
&mut self,
452+
cx: &mut std::task::Context<'_>,
453+
) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
454+
self.left_index = 0;
455+
let right_data = match ready!(self.right.poll_next_unpin(cx)) {
456+
Some(Ok(right_data)) => right_data,
457+
Some(Err(e)) => return Poll::Ready(Err(e)),
458+
None => return Poll::Ready(Ok(StatefulStreamResult::Ready(None))),
459+
};
460+
self.join_metrics.input_batches.add(1);
461+
self.join_metrics.input_rows.add(right_data.num_rows());
462+
463+
self.state = CrossJoinStreamState::BuildBatches(right_data);
464+
Poll::Ready(Ok(StatefulStreamResult::Continue))
465+
}
397466

398-
if self.left_index > 0 && self.left_index < left_data.num_rows() {
467+
/// Joins the the indexed row of left data with the current probe batch.
468+
/// If all the results are produced, the state is set to fetch new probe batch.
469+
fn build_batches(&mut self) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
470+
let right_batch = self.state.try_as_record_batch()?;
471+
if self.left_index < self.left_data.num_rows() {
399472
let join_timer = self.join_metrics.join_time.timer();
400-
let right_batch = {
401-
let right_batch = self.right_batch.lock();
402-
right_batch.clone().unwrap()
403-
};
404473
let result =
405-
build_batch(self.left_index, &right_batch, left_data, &self.schema);
406-
self.join_metrics.input_rows.add(right_batch.num_rows());
474+
build_batch(self.left_index, right_batch, &self.left_data, &self.schema);
475+
join_timer.done();
476+
407477
if let Ok(ref batch) = result {
408-
join_timer.done();
409478
self.join_metrics.output_batches.add(1);
410479
self.join_metrics.output_rows.add(batch.num_rows());
411480
}
412481
self.left_index += 1;
413-
return Poll::Ready(Some(result));
482+
result.map(|r| StatefulStreamResult::Ready(Some(r)))
483+
} else {
484+
self.state = CrossJoinStreamState::FetchProbeBatch;
485+
Ok(StatefulStreamResult::Continue)
414486
}
415-
self.left_index = 0;
416-
self.right
417-
.poll_next_unpin(cx)
418-
.map(|maybe_batch| match maybe_batch {
419-
Some(Ok(batch)) => {
420-
let join_timer = self.join_metrics.join_time.timer();
421-
let result =
422-
build_batch(self.left_index, &batch, left_data, &self.schema);
423-
self.join_metrics.input_batches.add(1);
424-
self.join_metrics.input_rows.add(batch.num_rows());
425-
if let Ok(ref batch) = result {
426-
join_timer.done();
427-
self.join_metrics.output_batches.add(1);
428-
self.join_metrics.output_rows.add(batch.num_rows());
429-
}
430-
self.left_index = 1;
431-
432-
let mut right_batch = self.right_batch.lock();
433-
*right_batch = Some(batch);
434-
435-
Some(result)
436-
}
437-
other => other,
438-
})
439487
}
440488
}
441489

0 commit comments

Comments
 (0)