Skip to content

Commit 75ef06a

Browse files
committed
Remove some Expr clones in EliminateCrossJoin
1 parent 55323bf commit 75ef06a

File tree

3 files changed

+291
-73
lines changed

3 files changed

+291
-73
lines changed

datafusion/optimizer/src/eliminate_cross_join.rs

Lines changed: 50 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@
1616
// under the License.
1717

1818
//! [`EliminateCrossJoin`] converts `CROSS JOIN` to `INNER JOIN` if join predicates are available.
19-
use std::collections::HashSet;
2019
use std::sync::Arc;
2120

2221
use crate::{utils, OptimizerConfig, OptimizerRule};
2322

23+
use crate::join_key_set::JoinKeySet;
2424
use datafusion_common::{plan_err, Result};
2525
use datafusion_expr::expr::{BinaryExpr, Expr};
2626
use datafusion_expr::logical_plan::{
@@ -55,7 +55,7 @@ impl OptimizerRule for EliminateCrossJoin {
5555
plan: &LogicalPlan,
5656
config: &dyn OptimizerConfig,
5757
) -> Result<Option<LogicalPlan>> {
58-
let mut possible_join_keys: Vec<(Expr, Expr)> = vec![];
58+
let mut possible_join_keys = JoinKeySet::new();
5959
let mut all_inputs: Vec<LogicalPlan> = vec![];
6060
let parent_predicate = match plan {
6161
LogicalPlan::Filter(filter) => {
@@ -76,7 +76,7 @@ impl OptimizerRule for EliminateCrossJoin {
7676
extract_possible_join_keys(
7777
&filter.predicate,
7878
&mut possible_join_keys,
79-
)?;
79+
);
8080
Some(&filter.predicate)
8181
}
8282
_ => {
@@ -101,7 +101,7 @@ impl OptimizerRule for EliminateCrossJoin {
101101
};
102102

103103
// Join keys are handled locally:
104-
let mut all_join_keys = HashSet::<(Expr, Expr)>::new();
104+
let mut all_join_keys = JoinKeySet::new();
105105
let mut left = all_inputs.remove(0);
106106
while !all_inputs.is_empty() {
107107
left = find_inner_join(
@@ -131,7 +131,7 @@ impl OptimizerRule for EliminateCrossJoin {
131131
.map(|f| Some(LogicalPlan::Filter(f)))
132132
} else {
133133
// Remove join expressions from filter:
134-
match remove_join_expressions(predicate, &all_join_keys)? {
134+
match remove_join_expressions(predicate.clone(), &all_join_keys) {
135135
Some(filter_expr) => Filter::try_new(filter_expr, Arc::new(left))
136136
.map(|f| Some(LogicalPlan::Filter(f))),
137137
_ => Ok(Some(left)),
@@ -150,7 +150,7 @@ impl OptimizerRule for EliminateCrossJoin {
150150
/// Returns a boolean indicating whether the flattening was successful.
151151
fn try_flatten_join_inputs(
152152
plan: &LogicalPlan,
153-
possible_join_keys: &mut Vec<(Expr, Expr)>,
153+
possible_join_keys: &mut JoinKeySet,
154154
all_inputs: &mut Vec<LogicalPlan>,
155155
) -> Result<bool> {
156156
let children = match plan {
@@ -160,7 +160,7 @@ fn try_flatten_join_inputs(
160160
// issue: https://github.com/apache/datafusion/issues/4844
161161
return Ok(false);
162162
}
163-
possible_join_keys.extend(join.on.clone());
163+
possible_join_keys.insert_many(join.on.iter());
164164
vec![&join.left, &join.right]
165165
}
166166
LogicalPlan::CrossJoin(join) => {
@@ -204,8 +204,8 @@ fn try_flatten_join_inputs(
204204
fn find_inner_join(
205205
left_input: &LogicalPlan,
206206
rights: &mut Vec<LogicalPlan>,
207-
possible_join_keys: &[(Expr, Expr)],
208-
all_join_keys: &mut HashSet<(Expr, Expr)>,
207+
possible_join_keys: &JoinKeySet,
208+
all_join_keys: &mut JoinKeySet,
209209
) -> Result<LogicalPlan> {
210210
for (i, right_input) in rights.iter().enumerate() {
211211
let mut join_keys = vec![];
@@ -228,7 +228,7 @@ fn find_inner_join(
228228

229229
// Found one or more matching join keys
230230
if !join_keys.is_empty() {
231-
all_join_keys.extend(join_keys.clone());
231+
all_join_keys.insert_many(join_keys.iter());
232232
let right_input = rights.remove(i);
233233
let join_schema = Arc::new(build_join_schema(
234234
left_input.schema(),
@@ -265,90 +265,67 @@ fn find_inner_join(
265265
}))
266266
}
267267

268-
fn intersect(
269-
accum: &mut Vec<(Expr, Expr)>,
270-
vec1: &[(Expr, Expr)],
271-
vec2: &[(Expr, Expr)],
272-
) {
273-
if !(vec1.is_empty() || vec2.is_empty()) {
274-
for x1 in vec1.iter() {
275-
for x2 in vec2.iter() {
276-
if x1.0 == x2.0 && x1.1 == x2.1 || x1.1 == x2.0 && x1.0 == x2.1 {
277-
accum.push((x1.0.clone(), x1.1.clone()));
278-
}
279-
}
280-
}
281-
}
282-
}
283-
284268
/// Extract join keys from a WHERE clause
285-
fn extract_possible_join_keys(expr: &Expr, accum: &mut Vec<(Expr, Expr)>) -> Result<()> {
269+
fn extract_possible_join_keys(expr: &Expr, join_keys: &mut JoinKeySet) {
286270
if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr {
287271
match op {
288272
Operator::Eq => {
289-
// Ensure that we don't add the same Join keys multiple times
290-
if !(accum.contains(&(*left.clone(), *right.clone()))
291-
|| accum.contains(&(*right.clone(), *left.clone())))
292-
{
293-
accum.push((*left.clone(), *right.clone()));
294-
}
273+
// `insert` ensures we don't add the same join keys multiple times
274+
join_keys.insert(left, right);
295275
}
296276
Operator::And => {
297-
extract_possible_join_keys(left, accum)?;
298-
extract_possible_join_keys(right, accum)?
277+
extract_possible_join_keys(left, join_keys);
278+
extract_possible_join_keys(right, join_keys)
299279
}
300280
// Fix for issue#78 join predicates from inside of OR expr also pulled up properly.
301281
Operator::Or => {
302-
let mut left_join_keys = vec![];
303-
let mut right_join_keys = vec![];
282+
let mut left_join_keys = JoinKeySet::new();
283+
let mut right_join_keys = JoinKeySet::new();
304284

305-
extract_possible_join_keys(left, &mut left_join_keys)?;
306-
extract_possible_join_keys(right, &mut right_join_keys)?;
285+
extract_possible_join_keys(left, &mut left_join_keys);
286+
extract_possible_join_keys(right, &mut right_join_keys);
307287

308-
intersect(accum, &left_join_keys, &right_join_keys)
288+
join_keys.insert_intersection(left_join_keys, right_join_keys)
309289
}
310290
_ => (),
311291
};
312292
}
313-
Ok(())
314293
}
315294

316295
/// Remove join expressions from a filter expression
317-
/// Returns Some() when there are few remaining predicates in filter_expr
318-
/// Returns None otherwise
319-
fn remove_join_expressions(
320-
expr: &Expr,
321-
join_keys: &HashSet<(Expr, Expr)>,
322-
) -> Result<Option<Expr>> {
296+
///
297+
/// # Returns
298+
/// * `Some()` when any predicates remain in filter_expr
299+
/// * `None` otherwise
300+
fn remove_join_expressions(expr: Expr, join_keys: &JoinKeySet) -> Option<Expr> {
323301
match expr {
324-
Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
325-
match op {
326-
Operator::Eq => {
327-
if join_keys.contains(&(*left.clone(), *right.clone()))
328-
|| join_keys.contains(&(*right.clone(), *left.clone()))
329-
{
330-
Ok(None)
331-
} else {
332-
Ok(Some(expr.clone()))
333-
}
334-
}
335-
// Fix for issue#78 join predicates from inside of OR expr also pulled up properly.
336-
Operator::And | Operator::Or => {
337-
let l = remove_join_expressions(left, join_keys)?;
338-
let r = remove_join_expressions(right, join_keys)?;
339-
match (l, r) {
340-
(Some(ll), Some(rr)) => Ok(Some(Expr::BinaryExpr(
341-
BinaryExpr::new(Box::new(ll), *op, Box::new(rr)),
342-
))),
343-
(Some(ll), _) => Ok(Some(ll)),
344-
(_, Some(rr)) => Ok(Some(rr)),
345-
_ => Ok(None),
346-
}
347-
}
348-
_ => Ok(Some(expr.clone())),
302+
Expr::BinaryExpr(BinaryExpr {
303+
left,
304+
op: Operator::Eq,
305+
right,
306+
}) if join_keys.contains(&left, &right) => {
307+
// was a join key, so remove it
308+
None
309+
}
310+
// Fix for issue#78 join predicates from inside of OR expr also pulled up properly.
311+
Expr::BinaryExpr(BinaryExpr { left, op, right })
312+
if matches!(op, Operator::And | Operator::Or) =>
313+
{
314+
let l = remove_join_expressions(*left, join_keys);
315+
let r = remove_join_expressions(*right, join_keys);
316+
match (l, r) {
317+
(Some(ll), Some(rr)) => Some(Expr::BinaryExpr(BinaryExpr::new(
318+
Box::new(ll),
319+
op,
320+
Box::new(rr),
321+
))),
322+
(Some(ll), _) => Some(ll),
323+
(_, Some(rr)) => Some(rr),
324+
_ => None,
349325
}
350326
}
351-
_ => Ok(Some(expr.clone())),
327+
328+
_ => Some(expr),
352329
}
353330
}
354331

0 commit comments

Comments
 (0)