@@ -2,6 +2,7 @@ use super::masks::{ToBitMask, ToBitMaskArray};
 use crate::simd::{
     cmp::SimdPartialOrd,
     intrinsics,
+    prelude::SimdPartialEq,
     ptr::{SimdConstPtr, SimdMutPtr},
     LaneCount, Mask, MaskElement, SupportedLaneCount, Swizzle,
 };
@@ -314,48 +315,95 @@ where
     #[must_use]
     #[inline]
-    pub fn masked_load_or(slice: &[T], or: Self) -> Self
+    pub fn load_or_default(slice: &[T]) -> Self
     where
         Mask<<T as SimdElement>::Mask, N>: ToBitMask + ToBitMaskArray,
+        T: Default,
+        <T as SimdElement>::Mask: Default
+            + core::convert::From<i8>
+            + core::ops::Add<<T as SimdElement>::Mask, Output = <T as SimdElement>::Mask>,
+        Simd<<T as SimdElement>::Mask, N>: SimdPartialOrd,
+        Mask<<T as SimdElement>::Mask, N>: core::ops::BitAnd<Output = Mask<<T as SimdElement>::Mask, N>>
+            + core::convert::From<<Simd<<T as SimdElement>::Mask, N> as SimdPartialEq>::Mask>,
     {
-        Self::masked_load_select(slice, Mask::splat(true), or)
+        Self::load_or(slice, Default::default())
     }

     #[must_use]
     #[inline]
-    pub fn masked_load_select(
-        slice: &[T],
-        mut enable: Mask<<T as SimdElement>::Mask, N>,
-        or: Self,
-    ) -> Self
+    pub fn load_or(slice: &[T], or: Self) -> Self
     where
         Mask<<T as SimdElement>::Mask, N>: ToBitMask + ToBitMaskArray,
+        <T as SimdElement>::Mask: Default
+            + core::convert::From<i8>
+            + core::ops::Add<<T as SimdElement>::Mask, Output = <T as SimdElement>::Mask>,
+        Simd<<T as SimdElement>::Mask, N>: SimdPartialOrd,
+        Mask<<T as SimdElement>::Mask, N>: core::ops::BitAnd<Output = Mask<<T as SimdElement>::Mask, N>>
+            + core::convert::From<<Simd<<T as SimdElement>::Mask, N> as SimdPartialEq>::Mask>,
     {
-        enable &= {
+        Self::load_select(slice, Mask::splat(true), or)
+    }
+
+    #[must_use]
+    #[inline]
+    pub fn load_select_or_default(slice: &[T], enable: Mask<<T as SimdElement>::Mask, N>) -> Self
+    where
+        Mask<<T as SimdElement>::Mask, N>: ToBitMask + ToBitMaskArray,
+        T: Default,
+        <T as SimdElement>::Mask: Default
+            + core::convert::From<i8>
+            + core::ops::Add<<T as SimdElement>::Mask, Output = <T as SimdElement>::Mask>,
+        Simd<<T as SimdElement>::Mask, N>: SimdPartialOrd,
+        Mask<<T as SimdElement>::Mask, N>: core::ops::BitAnd<Output = Mask<<T as SimdElement>::Mask, N>>
+            + core::convert::From<<Simd<<T as SimdElement>::Mask, N> as SimdPartialEq>::Mask>,
+    {
+        Self::load_select(slice, enable, Default::default())
+    }
+
+    #[must_use]
+    #[inline]
+    pub fn load_select(slice: &[T], mut enable: Mask<<T as SimdElement>::Mask, N>, or: Self) -> Self
+    where
+        Mask<<T as SimdElement>::Mask, N>: ToBitMask + ToBitMaskArray,
+        <T as SimdElement>::Mask: Default
+            + core::convert::From<i8>
+            + core::ops::Add<<T as SimdElement>::Mask, Output = <T as SimdElement>::Mask>,
+        Simd<<T as SimdElement>::Mask, N>: SimdPartialOrd,
+        Mask<<T as SimdElement>::Mask, N>: core::ops::BitAnd<Output = Mask<<T as SimdElement>::Mask, N>>
+            + core::convert::From<<Simd<<T as SimdElement>::Mask, N> as SimdPartialEq>::Mask>,
+    {
+        if USE_BRANCH {
+            if core::intrinsics::likely(enable.all() && slice.len() > N) {
+                return Self::from_slice(slice);
+            }
+        }
+        enable &= if USE_BITMASK {
             let mask = bzhi_u64(u64::MAX, core::cmp::min(N, slice.len()) as u32);
             let mask_bytes: [u8; 8] = unsafe { core::mem::transmute(mask) };
             let mut in_bounds_arr = Mask::splat(true).to_bitmask_array();
             let len = in_bounds_arr.as_ref().len();
             in_bounds_arr.as_mut().copy_from_slice(&mask_bytes[..len]);
             Mask::from_bitmask_array(in_bounds_arr)
+        } else {
+            mask_up_to(enable, slice.len())
         };
-        unsafe { Self::masked_load_select_ptr(slice.as_ptr(), enable, or) }
+        unsafe { Self::load_select_ptr(slice.as_ptr(), enable, or) }
     }

     #[must_use]
     #[inline]
-    pub unsafe fn masked_load_select_unchecked(
+    pub unsafe fn load_select_unchecked(
         slice: &[T],
         enable: Mask<<T as SimdElement>::Mask, N>,
         or: Self,
     ) -> Self {
         let ptr = slice.as_ptr();
-        unsafe { Self::masked_load_select_ptr(ptr, enable, or) }
+        unsafe { Self::load_select_ptr(ptr, enable, or) }
     }

     #[must_use]
     #[inline]
-    pub unsafe fn masked_load_select_ptr(
+    pub unsafe fn load_select_ptr(
         ptr: *const T,
         enable: Mask<<T as SimdElement>::Mask, N>,
         or: Self,
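
For readers skimming the rename: `masked_load_or` and `masked_load_select` become `load_or` / `load_select`, plus `*_or_default` convenience wrappers. A minimal call-site sketch, not part of the diff — it assumes a nightly toolchain with `#![feature(portable_simd)]` and this branch, and `sum_tail` / `first_two` are invented names:

```rust
#![feature(portable_simd)]
use core::simd::{Mask, Simd};

// Lanes past `data.len()` are filled from `or` instead of being read,
// so a short tail slice needs no scalar fallback loop.
fn sum_tail(data: &[f32]) -> f32 {
    let v = Simd::<f32, 4>::load_or(data, Simd::splat(0.0));
    v.to_array().iter().sum()
}

// Disabled lanes also take their values from `or`.
fn first_two(data: &[i32]) -> Simd<i32, 4> {
    let enable = Mask::from_array([true, true, false, false]);
    Simd::load_select(data, enable, Simd::splat(-1))
}
```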
@@ -545,14 +593,28 @@ where
545
593
pub fn masked_store ( self , slice : & mut [ T ] , mut enable : Mask < <T as SimdElement >:: Mask , N > )
546
594
where
547
595
Mask < <T as SimdElement >:: Mask , N > : ToBitMask + ToBitMaskArray ,
596
+ Mask < <T as SimdElement >:: Mask , N > : ToBitMask + ToBitMaskArray ,
597
+ <T as SimdElement >:: Mask : Default
598
+ + core:: convert:: From < i8 >
599
+ + core:: ops:: Add < <T as SimdElement >:: Mask , Output = <T as SimdElement >:: Mask > ,
600
+ Simd < <T as SimdElement >:: Mask , N > : SimdPartialOrd ,
601
+ Mask < <T as SimdElement >:: Mask , N > : core:: ops:: BitAnd < Output = Mask < <T as SimdElement >:: Mask , N > >
602
+ + core:: convert:: From < <Simd < <T as SimdElement >:: Mask , N > as SimdPartialEq >:: Mask > ,
548
603
{
549
- enable &= {
604
+ if USE_BRANCH {
605
+ if core:: intrinsics:: likely ( enable. all ( ) && slice. len ( ) > N ) {
606
+ return self . copy_to_slice ( slice) ;
607
+ }
608
+ }
609
+ enable &= if USE_BITMASK {
550
610
let mask = bzhi_u64 ( u64:: MAX , core:: cmp:: min ( N , slice. len ( ) ) as u32 ) ;
551
611
let mask_bytes: [ u8 ; 8 ] = unsafe { core:: mem:: transmute ( mask) } ;
552
612
let mut in_bounds_arr = Mask :: splat ( true ) . to_bitmask_array ( ) ;
553
613
let len = in_bounds_arr. as_ref ( ) . len ( ) ;
554
614
in_bounds_arr. as_mut ( ) . copy_from_slice ( & mask_bytes[ ..len] ) ;
555
615
Mask :: from_bitmask_array ( in_bounds_arr)
616
+ } else {
617
+ mask_up_to ( enable, slice. len ( ) )
556
618
} ;
557
619
unsafe { self . masked_store_ptr ( slice. as_mut_ptr ( ) , enable) }
558
620
}
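
`masked_store` is the mirror image: lanes that are disabled, or that fall past the end of the destination slice, are simply not written. A small sketch under the same assumptions as above (`write_alternate` is an invented name):

```rust
#![feature(portable_simd)]
use core::simd::{Mask, Simd};

// Writes only the even lanes, and never past `out.len()`,
// so `out` may be shorter than the vector.
fn write_alternate(out: &mut [u8]) {
    let v = Simd::<u8, 8>::splat(0xFF);
    let enable = Mask::from_array([true, false, true, false, true, false, true, false]);
    v.masked_store(out, enable);
}
```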
@@ -1058,9 +1120,43 @@ where
     type Mask = isize;
 }

+const USE_BRANCH: bool = false;
+const USE_BITMASK: bool = false;
+
+#[inline]
+fn index<T, const N: usize>() -> Simd<T, N>
+where
+    T: MaskElement + Default + core::convert::From<i8> + core::ops::Add<T, Output = T>,
+    LaneCount<N>: SupportedLaneCount,
+{
+    let mut index = [T::default(); N];
+    for i in 1..N {
+        index[i] = index[i - 1] + T::from(1);
+    }
+    Simd::from_array(index)
+}
+
+#[inline]
+fn mask_up_to<M, const N: usize>(enable: Mask<M, N>, len: usize) -> Mask<M, N>
+where
+    LaneCount<N>: SupportedLaneCount,
+    M: MaskElement + Default + core::convert::From<i8> + core::ops::Add<M, Output = M>,
+    Simd<M, N>: SimdPartialOrd,
+    Mask<M, N>: core::ops::BitAnd<Output = Mask<M, N>>
+        + core::convert::From<<Simd<M, N> as SimdPartialEq>::Mask>,
+{
+    let index = index::<M, N>();
+    enable
+        & Mask::<M, N>::from(
+            index.simd_lt(Simd::splat(M::from(i8::try_from(len).unwrap_or(i8::MAX)))),
+        )
+}
+
 // This function matches the semantics of the `bzhi` instruction on x86 BMI2
 // TODO: optimize it further if possible
 // https://stackoverflow.com/questions/75179720/how-to-get-rust-compiler-to-emit-bzhi-instruction-without-resorting-to-platform
+#[inline(always)]
 fn bzhi_u64(a: u64, ix: u32) -> u64 {
     if ix > 63 {
         a
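
The `mask_up_to` fallback is easier to see in scalar form: build the lane indices 0, 1, 2, … and keep a lane enabled only while its index is below the slice length. A plain-Rust model of that computation, for illustration only (`mask_up_to_model` is not part of the diff):

```rust
// Scalar model of `mask_up_to` for N = 4 lanes: lane i survives
// only if it was enabled AND its index is in bounds (i < len).
fn mask_up_to_model(enable: [bool; 4], len: usize) -> [bool; 4] {
    core::array::from_fn(|i| enable[i] && i < len)
}

fn main() {
    // A 2-element slice keeps only the first two lanes live.
    assert_eq!(mask_up_to_model([true; 4], 2), [true, true, false, false]);
}
```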
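
`bzhi_u64` zeroes every bit of `a` at position `ix` and above. The hunk cuts off mid-body, so the model below follows the documented BMI2 semantics rather than the elided lines (`bzhi_model` is a stand-in name):

```rust
// Model of the BMI2 `bzhi` semantics the comment above cites:
// keep the low `ix` bits, clear the rest; ix > 63 passes `a` through.
fn bzhi_model(a: u64, ix: u32) -> u64 {
    if ix > 63 { a } else { a & ((1u64 << ix) - 1) }
}

fn main() {
    let a: u64 = 0b1011_0110;
    assert_eq!(bzhi_model(a, 4), 0b0110); // low 4 bits kept
    assert_eq!(bzhi_model(a, 0), 0); // ix = 0 clears everything
    assert_eq!(bzhi_model(a, 64), a); // ix > 63: unchanged
}
```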