Skip to content

Commit 146f16a

Browse files
authored
Move filtered SMJ Left Anti filtered join out of join_partial phase (#13111)
* Move filtered SMJ Left Anti filtered join out of `join_partial` phase
1 parent 62b063c commit 146f16a

File tree

3 files changed

+414
-220
lines changed

3 files changed

+414
-220
lines changed

datafusion/core/tests/fuzz_cases/join_fuzz.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ use datafusion::physical_plan::joins::{
4141
};
4242
use datafusion::physical_plan::memory::MemoryExec;
4343

44+
use crate::fuzz_cases::join_fuzz::JoinTestType::NljHj;
4445
use datafusion::prelude::{SessionConfig, SessionContext};
4546
use test_utils::stagger_batch_with_seed;
4647

@@ -223,17 +224,14 @@ async fn test_anti_join_1k() {
223224
}
224225

225226
#[tokio::test]
226-
// flaky for HjSmj case, giving 1 rows difference sometimes
227-
// https://github.com/apache/datafusion/issues/11555
228-
#[ignore]
229227
async fn test_anti_join_1k_filtered() {
230228
JoinFuzzTestCase::new(
231229
make_staggered_batches(1000),
232230
make_staggered_batches(1000),
233231
JoinType::LeftAnti,
234232
Some(Box::new(col_lt_col_filter)),
235233
)
236-
.run_test(&[JoinTestType::NljHj], false)
234+
.run_test(&[JoinTestType::HjSmj, NljHj], false)
237235
.await
238236
}
239237

datafusion/physical-plan/src/joins/sort_merge_join.rs

Lines changed: 227 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -802,6 +802,32 @@ fn get_corrected_filter_mask(
802802

803803
Some(corrected_mask.finish())
804804
}
805+
JoinType::LeftAnti => {
806+
for i in 0..row_indices_length {
807+
let last_index =
808+
last_index_for_row(i, row_indices, batch_ids, row_indices_length);
809+
810+
if filter_mask.value(i) {
811+
seen_true = true;
812+
}
813+
814+
if last_index {
815+
if !seen_true {
816+
corrected_mask.append_value(true);
817+
} else {
818+
corrected_mask.append_null();
819+
}
820+
821+
seen_true = false;
822+
} else {
823+
corrected_mask.append_null();
824+
}
825+
}
826+
827+
let null_matched = expected_size - corrected_mask.len();
828+
corrected_mask.extend(vec![Some(true); null_matched]);
829+
Some(corrected_mask.finish())
830+
}
805831
// Only outer joins needs to keep track of processed rows and apply corrected filter mask
806832
_ => None,
807833
}
@@ -835,15 +861,18 @@ impl Stream for SMJStream {
835861
JoinType::Left
836862
| JoinType::LeftSemi
837863
| JoinType::Right
864+
| JoinType::LeftAnti
838865
)
839866
{
840867
self.freeze_all()?;
841868

842869
if !self.output_record_batches.batches.is_empty()
843-
&& self.buffered_data.scanning_finished()
844870
{
845-
let out_batch = self.filter_joined_batch()?;
846-
return Poll::Ready(Some(Ok(out_batch)));
871+
let out_filtered_batch =
872+
self.filter_joined_batch()?;
873+
return Poll::Ready(Some(Ok(
874+
out_filtered_batch,
875+
)));
847876
}
848877
}
849878

@@ -907,15 +936,17 @@ impl Stream for SMJStream {
907936
// because target output batch size can be hit in the middle of
908937
// filtering causing the filtering to be incomplete and causing
909938
// correctness issues
910-
let record_batch = if !(self.filter.is_some()
939+
if self.filter.is_some()
911940
&& matches!(
912941
self.join_type,
913-
JoinType::Left | JoinType::LeftSemi | JoinType::Right
914-
)) {
915-
record_batch
916-
} else {
942+
JoinType::Left
943+
| JoinType::LeftSemi
944+
| JoinType::Right
945+
| JoinType::LeftAnti
946+
)
947+
{
917948
continue;
918-
};
949+
}
919950

920951
return Poll::Ready(Some(Ok(record_batch)));
921952
}
@@ -929,7 +960,10 @@ impl Stream for SMJStream {
929960
if self.filter.is_some()
930961
&& matches!(
931962
self.join_type,
932-
JoinType::Left | JoinType::LeftSemi | JoinType::Right
963+
JoinType::Left
964+
| JoinType::LeftSemi
965+
| JoinType::Right
966+
| JoinType::LeftAnti
933967
)
934968
{
935969
let out = self.filter_joined_batch()?;
@@ -1273,11 +1307,7 @@ impl SMJStream {
12731307
};
12741308

12751309
if matches!(self.join_type, JoinType::LeftAnti) && self.filter.is_some() {
1276-
join_streamed = !self
1277-
.streamed_batch
1278-
.join_filter_matched_idxs
1279-
.contains(&(self.streamed_batch.idx as u64))
1280-
&& !self.streamed_joined;
1310+
join_streamed = !self.streamed_joined;
12811311
join_buffered = join_streamed;
12821312
}
12831313
}
@@ -1519,7 +1549,10 @@ impl SMJStream {
15191549
// Push the filtered batch which contains rows passing join filter to the output
15201550
if matches!(
15211551
self.join_type,
1522-
JoinType::Left | JoinType::LeftSemi | JoinType::Right
1552+
JoinType::Left
1553+
| JoinType::LeftSemi
1554+
| JoinType::Right
1555+
| JoinType::LeftAnti
15231556
) {
15241557
self.output_record_batches
15251558
.batches
@@ -1654,7 +1687,10 @@ impl SMJStream {
16541687
if !(self.filter.is_some()
16551688
&& matches!(
16561689
self.join_type,
1657-
JoinType::Left | JoinType::LeftSemi | JoinType::Right
1690+
JoinType::Left
1691+
| JoinType::LeftSemi
1692+
| JoinType::Right
1693+
| JoinType::LeftAnti
16581694
))
16591695
{
16601696
self.output_record_batches.batches.clear();
@@ -1727,7 +1763,7 @@ impl SMJStream {
17271763
&self.schema,
17281764
&[filtered_record_batch, null_joined_streamed_batch],
17291765
)?;
1730-
} else if matches!(self.join_type, JoinType::LeftSemi) {
1766+
} else if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftAnti) {
17311767
let output_column_indices = (0..streamed_columns_length).collect::<Vec<_>>();
17321768
filtered_record_batch =
17331769
filtered_record_batch.project(&output_column_indices)?;
@@ -3349,6 +3385,7 @@ mod tests {
33493385
batch_ids: vec![],
33503386
};
33513387

3388+
// Insert already prejoined non-filtered rows
33523389
batches.batches.push(RecordBatch::try_new(
33533390
Arc::clone(&schema),
33543391
vec![
@@ -3835,6 +3872,178 @@ mod tests {
38353872
Ok(())
38363873
}
38373874

3875+
#[tokio::test]
3876+
async fn test_left_anti_join_filtered_mask() -> Result<()> {
3877+
let mut joined_batches = build_joined_record_batches()?;
3878+
let schema = joined_batches.batches.first().unwrap().schema();
3879+
3880+
let output = concat_batches(&schema, &joined_batches.batches)?;
3881+
let out_mask = joined_batches.filter_mask.finish();
3882+
let out_indices = joined_batches.row_indices.finish();
3883+
3884+
assert_eq!(
3885+
get_corrected_filter_mask(
3886+
LeftAnti,
3887+
&UInt64Array::from(vec![0]),
3888+
&[0usize],
3889+
&BooleanArray::from(vec![true]),
3890+
1
3891+
)
3892+
.unwrap(),
3893+
BooleanArray::from(vec![None])
3894+
);
3895+
3896+
assert_eq!(
3897+
get_corrected_filter_mask(
3898+
LeftAnti,
3899+
&UInt64Array::from(vec![0]),
3900+
&[0usize],
3901+
&BooleanArray::from(vec![false]),
3902+
1
3903+
)
3904+
.unwrap(),
3905+
BooleanArray::from(vec![Some(true)])
3906+
);
3907+
3908+
assert_eq!(
3909+
get_corrected_filter_mask(
3910+
LeftAnti,
3911+
&UInt64Array::from(vec![0, 0]),
3912+
&[0usize; 2],
3913+
&BooleanArray::from(vec![true, true]),
3914+
2
3915+
)
3916+
.unwrap(),
3917+
BooleanArray::from(vec![None, None])
3918+
);
3919+
3920+
assert_eq!(
3921+
get_corrected_filter_mask(
3922+
LeftAnti,
3923+
&UInt64Array::from(vec![0, 0, 0]),
3924+
&[0usize; 3],
3925+
&BooleanArray::from(vec![true, true, true]),
3926+
3
3927+
)
3928+
.unwrap(),
3929+
BooleanArray::from(vec![None, None, None])
3930+
);
3931+
3932+
assert_eq!(
3933+
get_corrected_filter_mask(
3934+
LeftAnti,
3935+
&UInt64Array::from(vec![0, 0, 0]),
3936+
&[0usize; 3],
3937+
&BooleanArray::from(vec![true, false, true]),
3938+
3
3939+
)
3940+
.unwrap(),
3941+
BooleanArray::from(vec![None, None, None])
3942+
);
3943+
3944+
assert_eq!(
3945+
get_corrected_filter_mask(
3946+
LeftAnti,
3947+
&UInt64Array::from(vec![0, 0, 0]),
3948+
&[0usize; 3],
3949+
&BooleanArray::from(vec![false, false, true]),
3950+
3
3951+
)
3952+
.unwrap(),
3953+
BooleanArray::from(vec![None, None, None])
3954+
);
3955+
3956+
assert_eq!(
3957+
get_corrected_filter_mask(
3958+
LeftAnti,
3959+
&UInt64Array::from(vec![0, 0, 0]),
3960+
&[0usize; 3],
3961+
&BooleanArray::from(vec![false, true, true]),
3962+
3
3963+
)
3964+
.unwrap(),
3965+
BooleanArray::from(vec![None, None, None])
3966+
);
3967+
3968+
assert_eq!(
3969+
get_corrected_filter_mask(
3970+
LeftAnti,
3971+
&UInt64Array::from(vec![0, 0, 0]),
3972+
&[0usize; 3],
3973+
&BooleanArray::from(vec![false, false, false]),
3974+
3
3975+
)
3976+
.unwrap(),
3977+
BooleanArray::from(vec![None, None, Some(true)])
3978+
);
3979+
3980+
let corrected_mask = get_corrected_filter_mask(
3981+
LeftAnti,
3982+
&out_indices,
3983+
&joined_batches.batch_ids,
3984+
&out_mask,
3985+
output.num_rows(),
3986+
)
3987+
.unwrap();
3988+
3989+
assert_eq!(
3990+
corrected_mask,
3991+
BooleanArray::from(vec![
3992+
None,
3993+
None,
3994+
None,
3995+
None,
3996+
None,
3997+
Some(true),
3998+
None,
3999+
Some(true)
4000+
])
4001+
);
4002+
4003+
let filtered_rb = filter_record_batch(&output, &corrected_mask)?;
4004+
4005+
assert_batches_eq!(
4006+
&[
4007+
"+---+----+---+----+",
4008+
"| a | b | x | y |",
4009+
"+---+----+---+----+",
4010+
"| 1 | 13 | 1 | 12 |",
4011+
"| 1 | 14 | 1 | 11 |",
4012+
"+---+----+---+----+",
4013+
],
4014+
&[filtered_rb]
4015+
);
4016+
4017+
// output null rows
4018+
let null_mask = arrow::compute::not(&corrected_mask)?;
4019+
assert_eq!(
4020+
null_mask,
4021+
BooleanArray::from(vec![
4022+
None,
4023+
None,
4024+
None,
4025+
None,
4026+
None,
4027+
Some(false),
4028+
None,
4029+
Some(false),
4030+
])
4031+
);
4032+
4033+
let null_joined_batch = filter_record_batch(&output, &null_mask)?;
4034+
4035+
assert_batches_eq!(
4036+
&[
4037+
"+---+---+---+---+",
4038+
"| a | b | x | y |",
4039+
"+---+---+---+---+",
4040+
"+---+---+---+---+",
4041+
],
4042+
&[null_joined_batch]
4043+
);
4044+
Ok(())
4045+
}
4046+
38384047
/// Returns the column names on the schema
38394048
fn columns(schema: &Schema) -> Vec<String> {
38404049
schema.fields().iter().map(|f| f.name().clone()).collect()

0 commit comments

Comments
 (0)