@@ -196,14 +196,17 @@ impl PhysicalOptimizerRule for HashBuildProbeOrder {
 #[cfg(test)]
 mod tests {
     use crate::{
-        physical_plan::{hash_join::PartitionMode, Statistics},
+        physical_plan::{
+            displayable, hash_join::PartitionMode, ColumnStatistics, Statistics,
+        },
         test::exec::StatisticsExec,
     };

     use super::*;
     use std::sync::Arc;

     use arrow::datatypes::{DataType, Field, Schema};
+    use datafusion_common::ScalarValue;

     fn create_big_and_small() -> (Arc<dyn ExecutionPlan>, Arc<dyn ExecutionPlan>) {
         let big = Arc::new(StatisticsExec::new(
@@ -226,6 +229,75 @@ mod tests {
         (big, small)
     }

+    /// Create a column statistics vector for a single column
+    /// that has the given min/max/distinct_count properties.
+    ///
+    /// Given min/max will be mapped to a [`ScalarValue`] if
+    /// they are not `None`.
+    fn create_column_stats(
+        min: Option<u64>,
+        max: Option<u64>,
+        distinct_count: Option<usize>,
+    ) -> Option<Vec<ColumnStatistics>> {
+        Some(vec![ColumnStatistics {
+            distinct_count,
+            min_value: min.map(|size| ScalarValue::UInt64(Some(size))),
+            max_value: max.map(|size| ScalarValue::UInt64(Some(size))),
+            ..Default::default()
+        }])
+    }
+
+    /// Returns three plans with statistics of (min, max, distinct_count)
+    /// * big 100K rows @ (0, 50k, 50k)
+    /// * medium 10K rows @ (1k, 5k, 1k)
+    /// * small 1K rows @ (0, 100k, 1k)
+    fn create_nested_with_min_max() -> (
+        Arc<dyn ExecutionPlan>,
+        Arc<dyn ExecutionPlan>,
+        Arc<dyn ExecutionPlan>,
+    ) {
+        let big = Arc::new(StatisticsExec::new(
+            Statistics {
+                num_rows: Some(100_000),
+                column_statistics: create_column_stats(
+                    Some(0),
+                    Some(50_000),
+                    Some(50_000),
+                ),
+                ..Default::default()
+            },
+            Schema::new(vec![Field::new("big_col", DataType::Int32, false)]),
+        ));
+
+        let medium = Arc::new(StatisticsExec::new(
+            Statistics {
+                num_rows: Some(10_000),
+                column_statistics: create_column_stats(
+                    Some(1000),
+                    Some(5000),
+                    Some(1000),
+                ),
+                ..Default::default()
+            },
+            Schema::new(vec![Field::new("medium_col", DataType::Int32, false)]),
+        ));
+
+        let small = Arc::new(StatisticsExec::new(
+            Statistics {
+                num_rows: Some(1000),
+                column_statistics: create_column_stats(
+                    Some(0),
+                    Some(100_000),
+                    Some(1000),
+                ),
+                ..Default::default()
+            },
+            Schema::new(vec![Field::new("small_col", DataType::Int32, false)]),
+        ));
+
+        (big, medium, small)
+    }
+
     #[tokio::test]
     async fn test_join_with_swap() {
         let (big, small) = create_big_and_small();
@@ -274,6 +346,82 @@ mod tests {
         );
     }

+    /// Compare the input plan with the plan after running the probe order optimizer.
+    macro_rules! assert_optimized {
+        ($EXPECTED_LINES: expr, $PLAN: expr) => {
+            let expected_lines =
+                $EXPECTED_LINES.iter().map(|s| *s).collect::<Vec<&str>>();
+
+            let optimized = HashBuildProbeOrder::new()
+                .optimize(Arc::new($PLAN), &SessionConfig::new())
+                .unwrap();
+
+            let plan = displayable(optimized.as_ref()).indent().to_string();
+            let actual_lines = plan.split("\n").collect::<Vec<&str>>();
+
+            assert_eq!(
+                &expected_lines, &actual_lines,
+                "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+                expected_lines, actual_lines
+            );
+        };
+    }
+
+    #[tokio::test]
+    async fn test_nested_join_swap() {
+        let (big, medium, small) = create_nested_with_min_max();
+
+        // Form the inner join: big JOIN small
+        let child_join = HashJoinExec::try_new(
+            Arc::clone(&big),
+            Arc::clone(&small),
+            vec![(
+                Column::new_with_schema("big_col", &big.schema()).unwrap(),
+                Column::new_with_schema("small_col", &small.schema()).unwrap(),
+            )],
+            None,
+            &JoinType::Inner,
+            PartitionMode::CollectLeft,
+            &false,
+        )
+        .unwrap();
+        let child_schema = child_join.schema();
+
+        // Form join tree `medium LEFT JOIN (big JOIN small)`
+        let join = HashJoinExec::try_new(
+            Arc::clone(&medium),
+            Arc::new(child_join),
+            vec![(
+                Column::new_with_schema("medium_col", &medium.schema()).unwrap(),
+                Column::new_with_schema("small_col", &child_schema).unwrap(),
+            )],
+            None,
+            &JoinType::Left,
+            PartitionMode::CollectLeft,
+            &false,
+        )
+        .unwrap();
+
+        // Hash join uses the left side to build the hash table, and the right side to probe it.
+        // We want to keep the left side as small as possible, so if we can estimate (with a
+        // reasonable margin of error) that the left side is larger than the right side, we should
+        // swap the sides.
+        //
+        // The first hash join's left is the 'small' table (with 1000 rows), and the second hash
+        // join's left is F(small IJ big), which has an estimated cardinality of 2000 rows (vs
+        // medium, which has an exact cardinality of 10_000 rows).
+        let expected = [
+            "ProjectionExec: expr=[medium_col@2 as medium_col, big_col@0 as big_col, small_col@1 as small_col]",
+            "  HashJoinExec: mode=CollectLeft, join_type=Right, on=[(Column { name: \"small_col\", index: 1 }, Column { name: \"medium_col\", index: 0 })]",
+            "    ProjectionExec: expr=[big_col@1 as big_col, small_col@0 as small_col]",
+            "      HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(Column { name: \"small_col\", index: 0 }, Column { name: \"big_col\", index: 0 })]",
+            "        StatisticsExec: col_count=1, row_count=Some(1000)",
+            "        StatisticsExec: col_count=1, row_count=Some(100000)",
+            "    StatisticsExec: col_count=1, row_count=Some(10000)",
+            ""
+        ];
+        assert_optimized!(expected, join);
+    }
+
     #[tokio::test]
     async fn test_join_no_swap() {
         let (big, small) = create_big_and_small();
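
For reference, the 2000-row figure quoted in the comment inside test_nested_join_swap is consistent with the classic equi-join cardinality estimate |R JOIN S| ~= |R| * |S| / max(ndv(R.key), ndv(S.key)) applied to the test's statistics. The standalone sketch below only illustrates that formula; the helper name is hypothetical, it is not DataFusion code from this diff, and DataFusion's actual estimator may differ in details (e.g. it also accounts for min/max range overlap).

// Hypothetical sketch: reproduce the 2000-row estimate for `big JOIN small`
// using the textbook equi-join cardinality formula (not DataFusion code).
fn estimate_inner_join_rows(
    left_rows: u64,
    right_rows: u64,
    left_distinct: u64,
    right_distinct: u64,
) -> u64 {
    // |R JOIN S| ~= |R| * |S| / max(ndv(R.key), ndv(S.key))
    left_rows * right_rows / left_distinct.max(right_distinct)
}

fn main() {
    // big: 100_000 rows, 50_000 distinct keys; small: 1_000 rows, 1_000 distinct keys
    let estimate = estimate_inner_join_rows(100_000, 1_000, 50_000, 1_000);
    assert_eq!(estimate, 2_000);
    // 2_000 is smaller than medium's exact 10_000 rows, so the optimizer keeps
    // F(small IJ big) on the build (left) side of the outer join, as the test expects.
    println!("estimated rows for big JOIN small: {}", estimate);
}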