16
16
// under the License.
17
17
18
18
//! [`EliminateCrossJoin`] converts `CROSS JOIN` to `INNER JOIN` if join predicates are available.
19
- use std:: collections:: HashSet ;
20
19
use std:: sync:: Arc ;
21
20
22
21
use crate :: { utils, OptimizerConfig , OptimizerRule } ;
23
22
23
+ use crate :: join_key_set:: JoinKeySet ;
24
24
use datafusion_common:: { plan_err, Result } ;
25
25
use datafusion_expr:: expr:: { BinaryExpr , Expr } ;
26
26
use datafusion_expr:: logical_plan:: {
@@ -55,7 +55,7 @@ impl OptimizerRule for EliminateCrossJoin {
55
55
plan : & LogicalPlan ,
56
56
config : & dyn OptimizerConfig ,
57
57
) -> Result < Option < LogicalPlan > > {
58
- let mut possible_join_keys: Vec < ( Expr , Expr ) > = vec ! [ ] ;
58
+ let mut possible_join_keys = JoinKeySet :: new ( ) ;
59
59
let mut all_inputs: Vec < LogicalPlan > = vec ! [ ] ;
60
60
let parent_predicate = match plan {
61
61
LogicalPlan :: Filter ( filter) => {
@@ -76,7 +76,7 @@ impl OptimizerRule for EliminateCrossJoin {
76
76
extract_possible_join_keys (
77
77
& filter. predicate ,
78
78
& mut possible_join_keys,
79
- ) ? ;
79
+ ) ;
80
80
Some ( & filter. predicate )
81
81
}
82
82
_ => {
@@ -101,7 +101,7 @@ impl OptimizerRule for EliminateCrossJoin {
101
101
} ;
102
102
103
103
// Join keys are handled locally:
104
- let mut all_join_keys = HashSet :: < ( Expr , Expr ) > :: new ( ) ;
104
+ let mut all_join_keys = JoinKeySet :: new ( ) ;
105
105
let mut left = all_inputs. remove ( 0 ) ;
106
106
while !all_inputs. is_empty ( ) {
107
107
left = find_inner_join (
@@ -131,7 +131,7 @@ impl OptimizerRule for EliminateCrossJoin {
131
131
. map ( |f| Some ( LogicalPlan :: Filter ( f) ) )
132
132
} else {
133
133
// Remove join expressions from filter:
134
- match remove_join_expressions ( predicate, & all_join_keys) ? {
134
+ match remove_join_expressions ( predicate. clone ( ) , & all_join_keys) {
135
135
Some ( filter_expr) => Filter :: try_new ( filter_expr, Arc :: new ( left) )
136
136
. map ( |f| Some ( LogicalPlan :: Filter ( f) ) ) ,
137
137
_ => Ok ( Some ( left) ) ,
@@ -150,7 +150,7 @@ impl OptimizerRule for EliminateCrossJoin {
150
150
/// Returns a boolean indicating whether the flattening was successful.
151
151
fn try_flatten_join_inputs (
152
152
plan : & LogicalPlan ,
153
- possible_join_keys : & mut Vec < ( Expr , Expr ) > ,
153
+ possible_join_keys : & mut JoinKeySet ,
154
154
all_inputs : & mut Vec < LogicalPlan > ,
155
155
) -> Result < bool > {
156
156
let children = match plan {
@@ -160,7 +160,7 @@ fn try_flatten_join_inputs(
160
160
// issue: https://github.com/apache/datafusion/issues/4844
161
161
return Ok ( false ) ;
162
162
}
163
- possible_join_keys. extend ( join. on . clone ( ) ) ;
163
+ possible_join_keys. insert_many ( join. on . iter ( ) ) ;
164
164
vec ! [ & join. left, & join. right]
165
165
}
166
166
LogicalPlan :: CrossJoin ( join) => {
@@ -204,8 +204,8 @@ fn try_flatten_join_inputs(
204
204
fn find_inner_join (
205
205
left_input : & LogicalPlan ,
206
206
rights : & mut Vec < LogicalPlan > ,
207
- possible_join_keys : & [ ( Expr , Expr ) ] ,
208
- all_join_keys : & mut HashSet < ( Expr , Expr ) > ,
207
+ possible_join_keys : & JoinKeySet ,
208
+ all_join_keys : & mut JoinKeySet ,
209
209
) -> Result < LogicalPlan > {
210
210
for ( i, right_input) in rights. iter ( ) . enumerate ( ) {
211
211
let mut join_keys = vec ! [ ] ;
@@ -228,7 +228,7 @@ fn find_inner_join(
228
228
229
229
// Found one or more matching join keys
230
230
if !join_keys. is_empty ( ) {
231
- all_join_keys. extend ( join_keys. clone ( ) ) ;
231
+ all_join_keys. insert_many ( join_keys. iter ( ) ) ;
232
232
let right_input = rights. remove ( i) ;
233
233
let join_schema = Arc :: new ( build_join_schema (
234
234
left_input. schema ( ) ,
@@ -265,90 +265,67 @@ fn find_inner_join(
265
265
} ) )
266
266
}
267
267
268
- fn intersect (
269
- accum : & mut Vec < ( Expr , Expr ) > ,
270
- vec1 : & [ ( Expr , Expr ) ] ,
271
- vec2 : & [ ( Expr , Expr ) ] ,
272
- ) {
273
- if !( vec1. is_empty ( ) || vec2. is_empty ( ) ) {
274
- for x1 in vec1. iter ( ) {
275
- for x2 in vec2. iter ( ) {
276
- if x1. 0 == x2. 0 && x1. 1 == x2. 1 || x1. 1 == x2. 0 && x1. 0 == x2. 1 {
277
- accum. push ( ( x1. 0 . clone ( ) , x1. 1 . clone ( ) ) ) ;
278
- }
279
- }
280
- }
281
- }
282
- }
283
-
284
268
/// Extract join keys from a WHERE clause
285
- fn extract_possible_join_keys ( expr : & Expr , accum : & mut Vec < ( Expr , Expr ) > ) -> Result < ( ) > {
269
+ fn extract_possible_join_keys ( expr : & Expr , join_keys : & mut JoinKeySet ) {
286
270
if let Expr :: BinaryExpr ( BinaryExpr { left, op, right } ) = expr {
287
271
match op {
288
272
Operator :: Eq => {
289
- // Ensure that we don't add the same Join keys multiple times
290
- if !( accum. contains ( & ( * left. clone ( ) , * right. clone ( ) ) )
291
- || accum. contains ( & ( * right. clone ( ) , * left. clone ( ) ) ) )
292
- {
293
- accum. push ( ( * left. clone ( ) , * right. clone ( ) ) ) ;
294
- }
273
+ // insert handles ensuring we don't add the same Join keys multiple times
274
+ join_keys. insert ( left, right) ;
295
275
}
296
276
Operator :: And => {
297
- extract_possible_join_keys ( left, accum ) ? ;
298
- extract_possible_join_keys ( right, accum ) ?
277
+ extract_possible_join_keys ( left, join_keys ) ;
278
+ extract_possible_join_keys ( right, join_keys )
299
279
}
300
280
// Fix for issue#78 join predicates from inside of OR expr also pulled up properly.
301
281
Operator :: Or => {
302
- let mut left_join_keys = vec ! [ ] ;
303
- let mut right_join_keys = vec ! [ ] ;
282
+ let mut left_join_keys = JoinKeySet :: new ( ) ;
283
+ let mut right_join_keys = JoinKeySet :: new ( ) ;
304
284
305
- extract_possible_join_keys ( left, & mut left_join_keys) ? ;
306
- extract_possible_join_keys ( right, & mut right_join_keys) ? ;
285
+ extract_possible_join_keys ( left, & mut left_join_keys) ;
286
+ extract_possible_join_keys ( right, & mut right_join_keys) ;
307
287
308
- intersect ( accum , & left_join_keys, & right_join_keys)
288
+ join_keys . insert_intersection ( left_join_keys, right_join_keys)
309
289
}
310
290
_ => ( ) ,
311
291
} ;
312
292
}
313
- Ok ( ( ) )
314
293
}
315
294
316
295
/// Remove join expressions from a filter expression
317
- /// Returns Some() when there are few remaining predicates in filter_expr
318
- /// Returns None otherwise
319
- fn remove_join_expressions (
320
- expr : & Expr ,
321
- join_keys : & HashSet < ( Expr , Expr ) > ,
322
- ) -> Result < Option < Expr > > {
296
+ ///
297
+ /// # Returns
298
+ /// * `Some()` when there are few remaining predicates in filter_expr
299
+ /// * `None` otherwise
300
+ fn remove_join_expressions ( expr : Expr , join_keys : & JoinKeySet ) -> Option < Expr > {
323
301
match expr {
324
- Expr :: BinaryExpr ( BinaryExpr { left, op, right } ) => {
325
- match op {
326
- Operator :: Eq => {
327
- if join_keys. contains ( & ( * left. clone ( ) , * right. clone ( ) ) )
328
- || join_keys. contains ( & ( * right. clone ( ) , * left. clone ( ) ) )
329
- {
330
- Ok ( None )
331
- } else {
332
- Ok ( Some ( expr. clone ( ) ) )
333
- }
334
- }
335
- // Fix for issue#78 join predicates from inside of OR expr also pulled up properly.
336
- Operator :: And | Operator :: Or => {
337
- let l = remove_join_expressions ( left, join_keys) ?;
338
- let r = remove_join_expressions ( right, join_keys) ?;
339
- match ( l, r) {
340
- ( Some ( ll) , Some ( rr) ) => Ok ( Some ( Expr :: BinaryExpr (
341
- BinaryExpr :: new ( Box :: new ( ll) , * op, Box :: new ( rr) ) ,
342
- ) ) ) ,
343
- ( Some ( ll) , _) => Ok ( Some ( ll) ) ,
344
- ( _, Some ( rr) ) => Ok ( Some ( rr) ) ,
345
- _ => Ok ( None ) ,
346
- }
347
- }
348
- _ => Ok ( Some ( expr. clone ( ) ) ) ,
302
+ Expr :: BinaryExpr ( BinaryExpr {
303
+ left,
304
+ op : Operator :: Eq ,
305
+ right,
306
+ } ) if join_keys. contains ( & left, & right) => {
307
+ // was a join key, so remove it
308
+ None
309
+ }
310
+ // Fix for issue#78 join predicates from inside of OR expr also pulled up properly.
311
+ Expr :: BinaryExpr ( BinaryExpr { left, op, right } )
312
+ if matches ! ( op, Operator :: And | Operator :: Or ) =>
313
+ {
314
+ let l = remove_join_expressions ( * left, join_keys) ;
315
+ let r = remove_join_expressions ( * right, join_keys) ;
316
+ match ( l, r) {
317
+ ( Some ( ll) , Some ( rr) ) => Some ( Expr :: BinaryExpr ( BinaryExpr :: new (
318
+ Box :: new ( ll) ,
319
+ op,
320
+ Box :: new ( rr) ,
321
+ ) ) ) ,
322
+ ( Some ( ll) , _) => Some ( ll) ,
323
+ ( _, Some ( rr) ) => Some ( rr) ,
324
+ _ => None ,
349
325
}
350
326
}
351
- _ => Ok ( Some ( expr. clone ( ) ) ) ,
327
+
328
+ _ => Some ( expr) ,
352
329
}
353
330
}
354
331
0 commit comments