Skip to content

Commit 34e54b4

Browse files
committed
Rework the masking logic, rename the functions
1 parent 66a5748 commit 34e54b4

File tree

3 files changed: 112 additions and 15 deletions.

crates/core_simd/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
const_maybe_uninit_as_mut_ptr,
55
const_mut_refs,
66
convert_float_to_int,
7+
core_intrinsics,
78
decl_macro,
89
inline_const,
910
intra_doc_pointers,

crates/core_simd/src/vector.rs

Lines changed: 109 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@ use super::masks::{ToBitMask, ToBitMaskArray};
22
use crate::simd::{
33
cmp::SimdPartialOrd,
44
intrinsics,
5+
prelude::SimdPartialEq,
56
ptr::{SimdConstPtr, SimdMutPtr},
67
LaneCount, Mask, MaskElement, SupportedLaneCount, Swizzle,
78
};
@@ -314,48 +315,95 @@ where
314315

315316
#[must_use]
316317
#[inline]
317-
pub fn masked_load_or(slice: &[T], or: Self) -> Self
318+
pub fn load_or_default(slice: &[T]) -> Self
318319
where
319320
Mask<<T as SimdElement>::Mask, N>: ToBitMask + ToBitMaskArray,
321+
T: Default,
322+
<T as SimdElement>::Mask: Default
323+
+ core::convert::From<i8>
324+
+ core::ops::Add<<T as SimdElement>::Mask, Output = <T as SimdElement>::Mask>,
325+
Simd<<T as SimdElement>::Mask, N>: SimdPartialOrd,
326+
Mask<<T as SimdElement>::Mask, N>: core::ops::BitAnd<Output = Mask<<T as SimdElement>::Mask, N>>
327+
+ core::convert::From<<Simd<<T as SimdElement>::Mask, N> as SimdPartialEq>::Mask>,
320328
{
321-
Self::masked_load_select(slice, Mask::splat(true), or)
329+
Self::load_or(slice, Default::default())
322330
}
323331

324332
#[must_use]
325333
#[inline]
326-
pub fn masked_load_select(
327-
slice: &[T],
328-
mut enable: Mask<<T as SimdElement>::Mask, N>,
329-
or: Self,
330-
) -> Self
334+
pub fn load_or(slice: &[T], or: Self) -> Self
331335
where
332336
Mask<<T as SimdElement>::Mask, N>: ToBitMask + ToBitMaskArray,
337+
<T as SimdElement>::Mask: Default
338+
+ core::convert::From<i8>
339+
+ core::ops::Add<<T as SimdElement>::Mask, Output = <T as SimdElement>::Mask>,
340+
Simd<<T as SimdElement>::Mask, N>: SimdPartialOrd,
341+
Mask<<T as SimdElement>::Mask, N>: core::ops::BitAnd<Output = Mask<<T as SimdElement>::Mask, N>>
342+
+ core::convert::From<<Simd<<T as SimdElement>::Mask, N> as SimdPartialEq>::Mask>,
333343
{
334-
enable &= {
344+
Self::load_select(slice, Mask::splat(true), or)
345+
}
346+
347+
#[must_use]
348+
#[inline]
349+
pub fn load_select_or_default(slice: &[T], enable: Mask<<T as SimdElement>::Mask, N>) -> Self
350+
where
351+
Mask<<T as SimdElement>::Mask, N>: ToBitMask + ToBitMaskArray,
352+
T: Default,
353+
<T as SimdElement>::Mask: Default
354+
+ core::convert::From<i8>
355+
+ core::ops::Add<<T as SimdElement>::Mask, Output = <T as SimdElement>::Mask>,
356+
Simd<<T as SimdElement>::Mask, N>: SimdPartialOrd,
357+
Mask<<T as SimdElement>::Mask, N>: core::ops::BitAnd<Output = Mask<<T as SimdElement>::Mask, N>>
358+
+ core::convert::From<<Simd<<T as SimdElement>::Mask, N> as SimdPartialEq>::Mask>,
359+
{
360+
Self::load_select(slice, enable, Default::default())
361+
}
362+
363+
#[must_use]
364+
#[inline]
365+
pub fn load_select(slice: &[T], mut enable: Mask<<T as SimdElement>::Mask, N>, or: Self) -> Self
366+
where
367+
Mask<<T as SimdElement>::Mask, N>: ToBitMask + ToBitMaskArray,
368+
<T as SimdElement>::Mask: Default
369+
+ core::convert::From<i8>
370+
+ core::ops::Add<<T as SimdElement>::Mask, Output = <T as SimdElement>::Mask>,
371+
Simd<<T as SimdElement>::Mask, N>: SimdPartialOrd,
372+
Mask<<T as SimdElement>::Mask, N>: core::ops::BitAnd<Output = Mask<<T as SimdElement>::Mask, N>>
373+
+ core::convert::From<<Simd<<T as SimdElement>::Mask, N> as SimdPartialEq>::Mask>,
374+
{
375+
if USE_BRANCH {
376+
if core::intrinsics::likely(enable.all() && slice.len() > N) {
377+
return Self::from_slice(slice);
378+
}
379+
}
380+
enable &= if USE_BITMASK {
335381
let mask = bzhi_u64(u64::MAX, core::cmp::min(N, slice.len()) as u32);
336382
let mask_bytes: [u8; 8] = unsafe { core::mem::transmute(mask) };
337383
let mut in_bounds_arr = Mask::splat(true).to_bitmask_array();
338384
let len = in_bounds_arr.as_ref().len();
339385
in_bounds_arr.as_mut().copy_from_slice(&mask_bytes[..len]);
340386
Mask::from_bitmask_array(in_bounds_arr)
387+
} else {
388+
mask_up_to(enable, slice.len())
341389
};
342-
unsafe { Self::masked_load_select_ptr(slice.as_ptr(), enable, or) }
390+
unsafe { Self::load_select_ptr(slice.as_ptr(), enable, or) }
343391
}
344392

345393
#[must_use]
346394
#[inline]
347-
pub unsafe fn masked_load_select_unchecked(
395+
pub unsafe fn load_select_unchecked(
348396
slice: &[T],
349397
enable: Mask<<T as SimdElement>::Mask, N>,
350398
or: Self,
351399
) -> Self {
352400
let ptr = slice.as_ptr();
353-
unsafe { Self::masked_load_select_ptr(ptr, enable, or) }
401+
unsafe { Self::load_select_ptr(ptr, enable, or) }
354402
}
355403

356404
#[must_use]
357405
#[inline]
358-
pub unsafe fn masked_load_select_ptr(
406+
pub unsafe fn load_select_ptr(
359407
ptr: *const T,
360408
enable: Mask<<T as SimdElement>::Mask, N>,
361409
or: Self,
@@ -545,14 +593,28 @@ where
545593
pub fn masked_store(self, slice: &mut [T], mut enable: Mask<<T as SimdElement>::Mask, N>)
546594
where
547595
Mask<<T as SimdElement>::Mask, N>: ToBitMask + ToBitMaskArray,
596+
Mask<<T as SimdElement>::Mask, N>: ToBitMask + ToBitMaskArray,
597+
<T as SimdElement>::Mask: Default
598+
+ core::convert::From<i8>
599+
+ core::ops::Add<<T as SimdElement>::Mask, Output = <T as SimdElement>::Mask>,
600+
Simd<<T as SimdElement>::Mask, N>: SimdPartialOrd,
601+
Mask<<T as SimdElement>::Mask, N>: core::ops::BitAnd<Output = Mask<<T as SimdElement>::Mask, N>>
602+
+ core::convert::From<<Simd<<T as SimdElement>::Mask, N> as SimdPartialEq>::Mask>,
548603
{
549-
enable &= {
604+
if USE_BRANCH {
605+
if core::intrinsics::likely(enable.all() && slice.len() > N) {
606+
return self.copy_to_slice(slice);
607+
}
608+
}
609+
enable &= if USE_BITMASK {
550610
let mask = bzhi_u64(u64::MAX, core::cmp::min(N, slice.len()) as u32);
551611
let mask_bytes: [u8; 8] = unsafe { core::mem::transmute(mask) };
552612
let mut in_bounds_arr = Mask::splat(true).to_bitmask_array();
553613
let len = in_bounds_arr.as_ref().len();
554614
in_bounds_arr.as_mut().copy_from_slice(&mask_bytes[..len]);
555615
Mask::from_bitmask_array(in_bounds_arr)
616+
} else {
617+
mask_up_to(enable, slice.len())
556618
};
557619
unsafe { self.masked_store_ptr(slice.as_mut_ptr(), enable) }
558620
}
@@ -1058,9 +1120,43 @@ where
10581120
type Mask = isize;
10591121
}
10601122

1123+
const USE_BRANCH: bool = false;
1124+
const USE_BITMASK: bool = false;
1125+
1126+
#[inline]
1127+
fn index<T, const N: usize>() -> Simd<T, N>
1128+
where
1129+
T: MaskElement + Default + core::convert::From<i8> + core::ops::Add<T, Output = T>,
1130+
LaneCount<N>: SupportedLaneCount,
1131+
{
1132+
let mut index = [T::default(); N];
1133+
for i in 1..N {
1134+
index[i] = index[i - 1] + T::from(1);
1135+
}
1136+
Simd::from_array(index)
1137+
}
1138+
1139+
#[inline]
1140+
fn mask_up_to<M, const N: usize>(enable: Mask<M, N>, len: usize) -> Mask<M, N>
1141+
where
1142+
LaneCount<N>: SupportedLaneCount,
1143+
M: MaskElement + Default + core::convert::From<i8> + core::ops::Add<M, Output = M>,
1144+
Simd<M, N>: SimdPartialOrd,
1145+
// <Simd<M, N> as SimdPartialEq>::Mask: Mask<M, N>,
1146+
Mask<M, N>: core::ops::BitAnd<Output = Mask<M, N>>
1147+
+ core::convert::From<<Simd<M, N> as SimdPartialEq>::Mask>,
1148+
{
1149+
let index = index::<M, N>();
1150+
enable
1151+
& Mask::<M, N>::from(
1152+
index.simd_lt(Simd::splat(M::from(i8::try_from(len).unwrap_or(i8::MAX)))),
1153+
)
1154+
}
1155+
10611156
// This function matches the semantics of the `bzhi` instruction on x86 BMI2
10621157
// TODO: optimize it further if possible
10631158
// https://stackoverflow.com/questions/75179720/how-to-get-rust-compiler-to-emit-bzhi-instruction-without-resorting-to-platform
1159+
#[inline(always)]
10641160
fn bzhi_u64(a: u64, ix: u32) -> u64 {
10651161
if ix > 63 {
10661162
a

crates/core_simd/tests/masked_load_store.rs

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -21,11 +21,11 @@ fn masked_load_store() {
2121

2222
// read from index 8 is OOB and dropped
2323
assert_eq!(
24-
u8x4::masked_load_or(&arr[4..], u8x4::splat(42)),
24+
u8x4::load_or(&arr[4..], u8x4::splat(42)),
2525
u8x4::from_array([3, 255, 0, 42])
2626
);
2727
assert_eq!(
28-
u8x4::masked_load_select(
28+
u8x4::load_select(
2929
&arr[4..],
3030
Mask::from_array([true, false, true, true]),
3131
u8x4::splat(42)

0 commit comments

Comments
 (0)