@@ -802,6 +802,32 @@ fn get_corrected_filter_mask(
802
802
803
803
Some ( corrected_mask. finish ( ) )
804
804
}
805
+ JoinType :: LeftAnti => {
806
+ for i in 0 ..row_indices_length {
807
+ let last_index =
808
+ last_index_for_row ( i, row_indices, batch_ids, row_indices_length) ;
809
+
810
+ if filter_mask. value ( i) {
811
+ seen_true = true ;
812
+ }
813
+
814
+ if last_index {
815
+ if !seen_true {
816
+ corrected_mask. append_value ( true ) ;
817
+ } else {
818
+ corrected_mask. append_null ( ) ;
819
+ }
820
+
821
+ seen_true = false ;
822
+ } else {
823
+ corrected_mask. append_null ( ) ;
824
+ }
825
+ }
826
+
827
+ let null_matched = expected_size - corrected_mask. len ( ) ;
828
+ corrected_mask. extend ( vec ! [ Some ( true ) ; null_matched] ) ;
829
+ Some ( corrected_mask. finish ( ) )
830
+ }
805
831
// Only outer joins needs to keep track of processed rows and apply corrected filter mask
806
832
_ => None ,
807
833
}
@@ -835,15 +861,18 @@ impl Stream for SMJStream {
835
861
JoinType :: Left
836
862
| JoinType :: LeftSemi
837
863
| JoinType :: Right
864
+ | JoinType :: LeftAnti
838
865
)
839
866
{
840
867
self . freeze_all ( ) ?;
841
868
842
869
if !self . output_record_batches . batches . is_empty ( )
843
- && self . buffered_data . scanning_finished ( )
844
870
{
845
- let out_batch = self . filter_joined_batch ( ) ?;
846
- return Poll :: Ready ( Some ( Ok ( out_batch) ) ) ;
871
+ let out_filtered_batch =
872
+ self . filter_joined_batch ( ) ?;
873
+ return Poll :: Ready ( Some ( Ok (
874
+ out_filtered_batch,
875
+ ) ) ) ;
847
876
}
848
877
}
849
878
@@ -907,15 +936,17 @@ impl Stream for SMJStream {
907
936
// because target output batch size can be hit in the middle of
908
937
// filtering causing the filtering to be incomplete and causing
909
938
// correctness issues
910
- let record_batch = if ! ( self . filter . is_some ( )
939
+ if self . filter . is_some ( )
911
940
&& matches ! (
912
941
self . join_type,
913
- JoinType :: Left | JoinType :: LeftSemi | JoinType :: Right
914
- ) ) {
915
- record_batch
916
- } else {
942
+ JoinType :: Left
943
+ | JoinType :: LeftSemi
944
+ | JoinType :: Right
945
+ | JoinType :: LeftAnti
946
+ )
947
+ {
917
948
continue ;
918
- } ;
949
+ }
919
950
920
951
return Poll :: Ready ( Some ( Ok ( record_batch) ) ) ;
921
952
}
@@ -929,7 +960,10 @@ impl Stream for SMJStream {
929
960
if self . filter . is_some ( )
930
961
&& matches ! (
931
962
self . join_type,
932
- JoinType :: Left | JoinType :: LeftSemi | JoinType :: Right
963
+ JoinType :: Left
964
+ | JoinType :: LeftSemi
965
+ | JoinType :: Right
966
+ | JoinType :: LeftAnti
933
967
)
934
968
{
935
969
let out = self . filter_joined_batch ( ) ?;
@@ -1273,11 +1307,7 @@ impl SMJStream {
1273
1307
} ;
1274
1308
1275
1309
if matches ! ( self . join_type, JoinType :: LeftAnti ) && self . filter . is_some ( ) {
1276
- join_streamed = !self
1277
- . streamed_batch
1278
- . join_filter_matched_idxs
1279
- . contains ( & ( self . streamed_batch . idx as u64 ) )
1280
- && !self . streamed_joined ;
1310
+ join_streamed = !self . streamed_joined ;
1281
1311
join_buffered = join_streamed;
1282
1312
}
1283
1313
}
@@ -1519,7 +1549,10 @@ impl SMJStream {
1519
1549
// Push the filtered batch which contains rows passing join filter to the output
1520
1550
if matches ! (
1521
1551
self . join_type,
1522
- JoinType :: Left | JoinType :: LeftSemi | JoinType :: Right
1552
+ JoinType :: Left
1553
+ | JoinType :: LeftSemi
1554
+ | JoinType :: Right
1555
+ | JoinType :: LeftAnti
1523
1556
) {
1524
1557
self . output_record_batches
1525
1558
. batches
@@ -1654,7 +1687,10 @@ impl SMJStream {
1654
1687
if !( self . filter . is_some ( )
1655
1688
&& matches ! (
1656
1689
self . join_type,
1657
- JoinType :: Left | JoinType :: LeftSemi | JoinType :: Right
1690
+ JoinType :: Left
1691
+ | JoinType :: LeftSemi
1692
+ | JoinType :: Right
1693
+ | JoinType :: LeftAnti
1658
1694
) )
1659
1695
{
1660
1696
self . output_record_batches . batches . clear ( ) ;
@@ -1727,7 +1763,7 @@ impl SMJStream {
1727
1763
& self . schema ,
1728
1764
& [ filtered_record_batch, null_joined_streamed_batch] ,
1729
1765
) ?;
1730
- } else if matches ! ( self . join_type, JoinType :: LeftSemi ) {
1766
+ } else if matches ! ( self . join_type, JoinType :: LeftSemi | JoinType :: LeftAnti ) {
1731
1767
let output_column_indices = ( 0 ..streamed_columns_length) . collect :: < Vec < _ > > ( ) ;
1732
1768
filtered_record_batch =
1733
1769
filtered_record_batch. project ( & output_column_indices) ?;
@@ -3349,6 +3385,7 @@ mod tests {
3349
3385
batch_ids : vec ! [ ] ,
3350
3386
} ;
3351
3387
3388
+ // Insert already prejoined non-filtered rows
3352
3389
batches. batches . push ( RecordBatch :: try_new (
3353
3390
Arc :: clone ( & schema) ,
3354
3391
vec ! [
@@ -3835,6 +3872,178 @@ mod tests {
3835
3872
Ok ( ( ) )
3836
3873
}
3837
3874
3875
+ #[ tokio:: test]
3876
+ async fn test_left_anti_join_filtered_mask ( ) -> Result < ( ) > {
3877
+ let mut joined_batches = build_joined_record_batches ( ) ?;
3878
+ let schema = joined_batches. batches . first ( ) . unwrap ( ) . schema ( ) ;
3879
+
3880
+ let output = concat_batches ( & schema, & joined_batches. batches ) ?;
3881
+ let out_mask = joined_batches. filter_mask . finish ( ) ;
3882
+ let out_indices = joined_batches. row_indices . finish ( ) ;
3883
+
3884
+ assert_eq ! (
3885
+ get_corrected_filter_mask(
3886
+ LeftAnti ,
3887
+ & UInt64Array :: from( vec![ 0 ] ) ,
3888
+ & [ 0usize ] ,
3889
+ & BooleanArray :: from( vec![ true ] ) ,
3890
+ 1
3891
+ )
3892
+ . unwrap( ) ,
3893
+ BooleanArray :: from( vec![ None ] )
3894
+ ) ;
3895
+
3896
+ assert_eq ! (
3897
+ get_corrected_filter_mask(
3898
+ LeftAnti ,
3899
+ & UInt64Array :: from( vec![ 0 ] ) ,
3900
+ & [ 0usize ] ,
3901
+ & BooleanArray :: from( vec![ false ] ) ,
3902
+ 1
3903
+ )
3904
+ . unwrap( ) ,
3905
+ BooleanArray :: from( vec![ Some ( true ) ] )
3906
+ ) ;
3907
+
3908
+ assert_eq ! (
3909
+ get_corrected_filter_mask(
3910
+ LeftAnti ,
3911
+ & UInt64Array :: from( vec![ 0 , 0 ] ) ,
3912
+ & [ 0usize ; 2 ] ,
3913
+ & BooleanArray :: from( vec![ true , true ] ) ,
3914
+ 2
3915
+ )
3916
+ . unwrap( ) ,
3917
+ BooleanArray :: from( vec![ None , None ] )
3918
+ ) ;
3919
+
3920
+ assert_eq ! (
3921
+ get_corrected_filter_mask(
3922
+ LeftAnti ,
3923
+ & UInt64Array :: from( vec![ 0 , 0 , 0 ] ) ,
3924
+ & [ 0usize ; 3 ] ,
3925
+ & BooleanArray :: from( vec![ true , true , true ] ) ,
3926
+ 3
3927
+ )
3928
+ . unwrap( ) ,
3929
+ BooleanArray :: from( vec![ None , None , None ] )
3930
+ ) ;
3931
+
3932
+ assert_eq ! (
3933
+ get_corrected_filter_mask(
3934
+ LeftAnti ,
3935
+ & UInt64Array :: from( vec![ 0 , 0 , 0 ] ) ,
3936
+ & [ 0usize ; 3 ] ,
3937
+ & BooleanArray :: from( vec![ true , false , true ] ) ,
3938
+ 3
3939
+ )
3940
+ . unwrap( ) ,
3941
+ BooleanArray :: from( vec![ None , None , None ] )
3942
+ ) ;
3943
+
3944
+ assert_eq ! (
3945
+ get_corrected_filter_mask(
3946
+ LeftAnti ,
3947
+ & UInt64Array :: from( vec![ 0 , 0 , 0 ] ) ,
3948
+ & [ 0usize ; 3 ] ,
3949
+ & BooleanArray :: from( vec![ false , false , true ] ) ,
3950
+ 3
3951
+ )
3952
+ . unwrap( ) ,
3953
+ BooleanArray :: from( vec![ None , None , None ] )
3954
+ ) ;
3955
+
3956
+ assert_eq ! (
3957
+ get_corrected_filter_mask(
3958
+ LeftAnti ,
3959
+ & UInt64Array :: from( vec![ 0 , 0 , 0 ] ) ,
3960
+ & [ 0usize ; 3 ] ,
3961
+ & BooleanArray :: from( vec![ false , true , true ] ) ,
3962
+ 3
3963
+ )
3964
+ . unwrap( ) ,
3965
+ BooleanArray :: from( vec![ None , None , None ] )
3966
+ ) ;
3967
+
3968
+ assert_eq ! (
3969
+ get_corrected_filter_mask(
3970
+ LeftAnti ,
3971
+ & UInt64Array :: from( vec![ 0 , 0 , 0 ] ) ,
3972
+ & [ 0usize ; 3 ] ,
3973
+ & BooleanArray :: from( vec![ false , false , false ] ) ,
3974
+ 3
3975
+ )
3976
+ . unwrap( ) ,
3977
+ BooleanArray :: from( vec![ None , None , Some ( true ) ] )
3978
+ ) ;
3979
+
3980
+ let corrected_mask = get_corrected_filter_mask (
3981
+ LeftAnti ,
3982
+ & out_indices,
3983
+ & joined_batches. batch_ids ,
3984
+ & out_mask,
3985
+ output. num_rows ( ) ,
3986
+ )
3987
+ . unwrap ( ) ;
3988
+
3989
+ assert_eq ! (
3990
+ corrected_mask,
3991
+ BooleanArray :: from( vec![
3992
+ None ,
3993
+ None ,
3994
+ None ,
3995
+ None ,
3996
+ None ,
3997
+ Some ( true ) ,
3998
+ None ,
3999
+ Some ( true )
4000
+ ] )
4001
+ ) ;
4002
+
4003
+ let filtered_rb = filter_record_batch ( & output, & corrected_mask) ?;
4004
+
4005
+ assert_batches_eq ! (
4006
+ & [
4007
+ "+---+----+---+----+" ,
4008
+ "| a | b | x | y |" ,
4009
+ "+---+----+---+----+" ,
4010
+ "| 1 | 13 | 1 | 12 |" ,
4011
+ "| 1 | 14 | 1 | 11 |" ,
4012
+ "+---+----+---+----+" ,
4013
+ ] ,
4014
+ & [ filtered_rb]
4015
+ ) ;
4016
+
4017
+ // output null rows
4018
+ let null_mask = arrow:: compute:: not ( & corrected_mask) ?;
4019
+ assert_eq ! (
4020
+ null_mask,
4021
+ BooleanArray :: from( vec![
4022
+ None ,
4023
+ None ,
4024
+ None ,
4025
+ None ,
4026
+ None ,
4027
+ Some ( false ) ,
4028
+ None ,
4029
+ Some ( false ) ,
4030
+ ] )
4031
+ ) ;
4032
+
4033
+ let null_joined_batch = filter_record_batch ( & output, & null_mask) ?;
4034
+
4035
+ assert_batches_eq ! (
4036
+ & [
4037
+ "+---+---+---+---+" ,
4038
+ "| a | b | x | y |" ,
4039
+ "+---+---+---+---+" ,
4040
+ "+---+---+---+---+" ,
4041
+ ] ,
4042
+ & [ null_joined_batch]
4043
+ ) ;
4044
+ Ok ( ( ) )
4045
+ }
4046
+
3838
4047
/// Returns the column names on the schema
3839
4048
fn columns ( schema : & Schema ) -> Vec < String > {
3840
4049
schema. fields ( ) . iter ( ) . map ( |f| f. name ( ) . clone ( ) ) . collect ( )
0 commit comments