@@ -22,22 +22,23 @@ use std::{any::Any, sync::Arc, task::Poll};
22
22
23
23
use super :: utils:: {
24
24
adjust_right_output_partitioning, BuildProbeJoinMetrics , OnceAsync , OnceFut ,
25
+ StatefulStreamResult ,
25
26
} ;
26
27
use crate :: coalesce_batches:: concat_batches;
27
28
use crate :: coalesce_partitions:: CoalescePartitionsExec ;
28
29
use crate :: metrics:: { ExecutionPlanMetricsSet , MetricsSet } ;
29
- use crate :: ExecutionPlanProperties ;
30
30
use crate :: {
31
- execution_mode_from_children, ColumnStatistics , DisplayAs , DisplayFormatType ,
32
- Distribution , ExecutionMode , ExecutionPlan , PlanProperties , RecordBatchStream ,
31
+ execution_mode_from_children, handle_state, ColumnStatistics , DisplayAs ,
32
+ DisplayFormatType , Distribution , ExecutionMode , ExecutionPlan ,
33
+ ExecutionPlanProperties , PlanProperties , RecordBatchStream ,
33
34
SendableRecordBatchStream , Statistics ,
34
35
} ;
35
36
36
37
use arrow:: datatypes:: { Fields , Schema , SchemaRef } ;
37
38
use arrow:: record_batch:: RecordBatch ;
38
39
use arrow_array:: RecordBatchOptions ;
39
40
use datafusion_common:: stats:: Precision ;
40
- use datafusion_common:: { JoinType , Result , ScalarValue } ;
41
+ use datafusion_common:: { internal_err , JoinType , Result , ScalarValue } ;
41
42
use datafusion_execution:: memory_pool:: { MemoryConsumer , MemoryReservation } ;
42
43
use datafusion_execution:: TaskContext ;
43
44
use datafusion_physical_expr:: equivalence:: join_equivalence_properties;
@@ -257,9 +258,10 @@ impl ExecutionPlan for CrossJoinExec {
257
258
schema : self . schema . clone ( ) ,
258
259
left_fut,
259
260
right : stream,
260
- right_batch : Arc :: new ( parking_lot:: Mutex :: new ( None ) ) ,
261
261
left_index : 0 ,
262
262
join_metrics,
263
+ state : CrossJoinStreamState :: WaitBuildSide ,
264
+ left_data : RecordBatch :: new_empty ( self . left ( ) . schema ( ) ) ,
263
265
} ) )
264
266
}
265
267
@@ -319,16 +321,18 @@ fn stats_cartesian_product(
319
321
struct CrossJoinStream {
320
322
/// Input schema
321
323
schema : Arc < Schema > ,
322
- /// future for data from left side
324
+ /// Future that resolves to the collected left (build) side data
323
325
left_fut : OnceFut < JoinLeftData > ,
324
- /// right
326
+ /// Right (probe) side stream, polled one batch at a time
325
327
right : SendableRecordBatchStream ,
326
328
/// Index of the left row to join with the current right batch
327
329
left_index : usize ,
328
- /// Current batch being processed from the right side
329
- right_batch : Arc < parking_lot:: Mutex < Option < RecordBatch > > > ,
330
- /// join execution metrics
330
+ /// Join execution metrics
331
331
join_metrics : BuildProbeJoinMetrics ,
332
+ /// Current state of the stream (see `CrossJoinStreamState`)
333
+ state : CrossJoinStreamState ,
334
+ /// Left (build) side data; an empty batch until the build side has been collected
335
+ left_data : RecordBatch ,
332
336
}
333
337
334
338
impl RecordBatchStream for CrossJoinStream {
@@ -337,6 +341,25 @@ impl RecordBatchStream for CrossJoinStream {
337
341
}
338
342
}
339
343
344
+ /// Represents the states of CrossJoinStream: waiting for the build (left) side, fetching a probe (right) batch, or building output batches
345
+ enum CrossJoinStreamState {
346
+ WaitBuildSide ,
347
+ FetchProbeBatch ,
348
+ /// Holds the right (probe) side batch currently being joined against the left rows
349
+ BuildBatches ( RecordBatch ) ,
350
+ }
351
+
352
+ impl CrossJoinStreamState {
353
+ /// Returns a reference to the RecordBatch held by the BuildBatches variant.
354
+ /// Returns an internal error for any other variant. NOTE(review): only reads the state, so `&self` would likely suffice.
355
+ fn try_as_record_batch ( & mut self ) -> Result < & RecordBatch > {
356
+ match self {
357
+ CrossJoinStreamState :: BuildBatches ( rb) => Ok ( rb) ,
358
+ _ => internal_err ! ( "Expected RecordBatch in BuildBatches state" ) ,
359
+ }
360
+ }
361
+ }
362
+
340
363
fn build_batch (
341
364
left_index : usize ,
342
365
batch : & RecordBatch ,
@@ -384,58 +407,83 @@ impl CrossJoinStream {
384
407
& mut self ,
385
408
cx : & mut std:: task:: Context < ' _ > ,
386
409
) -> std:: task:: Poll < Option < Result < RecordBatch > > > {
410
+ loop {
411
+ return match self . state {
412
+ CrossJoinStreamState :: WaitBuildSide => {
413
+ handle_state ! ( ready!( self . collect_build_side( cx) ) )
414
+ }
415
+ CrossJoinStreamState :: FetchProbeBatch => {
416
+ handle_state ! ( ready!( self . fetch_probe_batch( cx) ) )
417
+ }
418
+ CrossJoinStreamState :: BuildBatches ( _) => {
419
+ handle_state ! ( self . build_batches( ) )
420
+ }
421
+ } ;
422
+ }
423
+ }
424
+
425
+ /// Collects the build (left) side of the join into `left_data`. In case of an empty build batch,
426
+ /// the execution terminates. Otherwise, the state is updated to fetch a probe (right) batch.
427
+ fn collect_build_side (
428
+ & mut self ,
429
+ cx : & mut std:: task:: Context < ' _ > ,
430
+ ) -> Poll < Result < StatefulStreamResult < Option < RecordBatch > > > > {
387
431
let build_timer = self . join_metrics . build_time . timer ( ) ;
388
432
let ( left_data, _) = match ready ! ( self . left_fut. get( cx) ) {
389
433
Ok ( left_data) => left_data,
390
- Err ( e) => return Poll :: Ready ( Some ( Err ( e) ) ) ,
434
+ Err ( e) => return Poll :: Ready ( Err ( e) ) ,
391
435
} ;
392
436
build_timer. done ( ) ;
393
437
394
- if left_data. num_rows ( ) == 0 {
395
- return Poll :: Ready ( None ) ;
396
- }
438
+ let result = if left_data. num_rows ( ) == 0 {
439
+ StatefulStreamResult :: Ready ( None )
440
+ } else {
441
+ self . left_data = left_data. clone ( ) ;
442
+ self . state = CrossJoinStreamState :: FetchProbeBatch ;
443
+ StatefulStreamResult :: Continue
444
+ } ;
445
+ Poll :: Ready ( Ok ( result) )
446
+ }
447
+
448
+ /// Fetches the probe (right) batch, updates the metrics, and saves the batch in the state.
449
+ /// Then, the state is updated to build result batches.
450
+ fn fetch_probe_batch (
451
+ & mut self ,
452
+ cx : & mut std:: task:: Context < ' _ > ,
453
+ ) -> Poll < Result < StatefulStreamResult < Option < RecordBatch > > > > {
454
+ self . left_index = 0 ;
455
+ let right_data = match ready ! ( self . right. poll_next_unpin( cx) ) {
456
+ Some ( Ok ( right_data) ) => right_data,
457
+ Some ( Err ( e) ) => return Poll :: Ready ( Err ( e) ) ,
458
+ None => return Poll :: Ready ( Ok ( StatefulStreamResult :: Ready ( None ) ) ) ,
459
+ } ;
460
+ self . join_metrics . input_batches . add ( 1 ) ;
461
+ self . join_metrics . input_rows . add ( right_data. num_rows ( ) ) ;
462
+
463
+ self . state = CrossJoinStreamState :: BuildBatches ( right_data) ;
464
+ Poll :: Ready ( Ok ( StatefulStreamResult :: Continue ) )
465
+ }
397
466
398
- if self . left_index > 0 && self . left_index < left_data. num_rows ( ) {
467
+ /// Joins the indexed row of left data with the current probe batch.
468
+ /// If all the results are produced, the state is set to fetch new probe batch.
469
+ fn build_batches ( & mut self ) -> Result < StatefulStreamResult < Option < RecordBatch > > > {
470
+ let right_batch = self . state . try_as_record_batch ( ) ?;
471
+ if self . left_index < self . left_data . num_rows ( ) {
399
472
let join_timer = self . join_metrics . join_time . timer ( ) ;
400
- let right_batch = {
401
- let right_batch = self . right_batch . lock ( ) ;
402
- right_batch. clone ( ) . unwrap ( )
403
- } ;
404
473
let result =
405
- build_batch ( self . left_index , & right_batch, left_data, & self . schema ) ;
406
- self . join_metrics . input_rows . add ( right_batch. num_rows ( ) ) ;
474
+ build_batch ( self . left_index , right_batch, & self . left_data , & self . schema ) ;
475
+ join_timer. done ( ) ;
476
+
407
477
if let Ok ( ref batch) = result {
408
- join_timer. done ( ) ;
409
478
self . join_metrics . output_batches . add ( 1 ) ;
410
479
self . join_metrics . output_rows . add ( batch. num_rows ( ) ) ;
411
480
}
412
481
self . left_index += 1 ;
413
- return Poll :: Ready ( Some ( result) ) ;
482
+ result. map ( |r| StatefulStreamResult :: Ready ( Some ( r) ) )
483
+ } else {
484
+ self . state = CrossJoinStreamState :: FetchProbeBatch ;
485
+ Ok ( StatefulStreamResult :: Continue )
414
486
}
415
- self . left_index = 0 ;
416
- self . right
417
- . poll_next_unpin ( cx)
418
- . map ( |maybe_batch| match maybe_batch {
419
- Some ( Ok ( batch) ) => {
420
- let join_timer = self . join_metrics . join_time . timer ( ) ;
421
- let result =
422
- build_batch ( self . left_index , & batch, left_data, & self . schema ) ;
423
- self . join_metrics . input_batches . add ( 1 ) ;
424
- self . join_metrics . input_rows . add ( batch. num_rows ( ) ) ;
425
- if let Ok ( ref batch) = result {
426
- join_timer. done ( ) ;
427
- self . join_metrics . output_batches . add ( 1 ) ;
428
- self . join_metrics . output_rows . add ( batch. num_rows ( ) ) ;
429
- }
430
- self . left_index = 1 ;
431
-
432
- let mut right_batch = self . right_batch . lock ( ) ;
433
- * right_batch = Some ( batch) ;
434
-
435
- Some ( result)
436
- }
437
- other => other,
438
- } )
439
487
}
440
488
}
441
489
0 commit comments