Skip to content

Commit eccbf0f

Browse files
committed
Add support for masked loads & stores
1 parent fbc9efa commit eccbf0f

File tree

2 files changed

+217
-0
lines changed

2 files changed

+217
-0
lines changed

crates/core_simd/src/vector.rs

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,142 @@ where
314314
unsafe { self.store(slice.as_mut_ptr().cast()) }
315315
}
316316

317+
/// Reads contiguous elements from `slice`. Elements are read so long as they're in-bounds for
318+
/// the `slice`. Otherwise, the default value for the element type is returned.
319+
///
320+
/// # Examples
321+
/// ```
322+
/// # #![feature(portable_simd)]
323+
/// # use core::simd::{Simd, Mask};
324+
/// let vec: Vec<i32> = vec![10, 11];
325+
///
326+
/// let result = Simd::<i32, 4>::load_or_default(&vec);
327+
/// assert_eq!(result, Simd::from_array([10, 11, 0, 0]));
328+
/// ```
329+
#[must_use]
330+
#[inline]
331+
pub fn load_or_default(slice: &[T]) -> Self
332+
where
333+
T: Default,
334+
<T as SimdElement>::Mask: Default
335+
+ core::convert::From<i8>
336+
+ core::ops::Add<<T as SimdElement>::Mask, Output = <T as SimdElement>::Mask>,
337+
Simd<<T as SimdElement>::Mask, N>: SimdPartialOrd,
338+
Mask<<T as SimdElement>::Mask, N>: core::ops::BitAnd<Output = Mask<<T as SimdElement>::Mask, N>>
339+
+ core::convert::From<Mask<i8, N>>,
340+
{
341+
Self::load_or(slice, Default::default())
342+
}
343+
344+
/// Reads contiguous elements from `slice`. Elements are read so long as they're in-bounds for
345+
/// the `slice`. Otherwise, the corresponding value from `or` is passed through.
346+
///
347+
/// # Examples
348+
/// ```
349+
/// # #![feature(portable_simd)]
350+
/// # use core::simd::{Simd, Mask};
351+
/// let vec: Vec<i32> = vec![10, 11];
352+
/// let or = Simd::from_array([-5, -4, -3, -2]);
353+
///
354+
/// let result = Simd::load_or(&vec, or);
355+
/// assert_eq!(result, Simd::from_array([10, 11, -3, -2]));
356+
/// ```
357+
#[must_use]
358+
#[inline]
359+
pub fn load_or(slice: &[T], or: Self) -> Self
360+
where
361+
<T as SimdElement>::Mask: Default
362+
+ core::convert::From<i8>
363+
+ core::ops::Add<<T as SimdElement>::Mask, Output = <T as SimdElement>::Mask>,
364+
Simd<<T as SimdElement>::Mask, N>: SimdPartialOrd,
365+
Mask<<T as SimdElement>::Mask, N>: core::ops::BitAnd<Output = Mask<<T as SimdElement>::Mask, N>>
366+
+ core::convert::From<Mask<i8, N>>,
367+
{
368+
Self::load_select(slice, Mask::splat(true), or)
369+
}
370+
371+
/// Reads contiguous elements from `slice`. Each lane is read from memory if its
372+
/// corresponding lane in `enable` is `true`.
373+
///
374+
/// When the lane is disabled or out of bounds for the slice, that memory location
375+
/// is not accessed and the corresponding value from `or` is passed through.
376+
///
377+
/// # Examples
378+
/// ```
379+
/// # #![feature(portable_simd)]
380+
/// # use core::simd::{Simd, Mask};
381+
/// let vec: Vec<i32> = vec![10, 11, 12, 13, 14, 15, 16, 17, 18];
382+
/// let enable = Mask::from_array([true, true, false, true]);
383+
/// let or = Simd::from_array([-5, -4, -3, -2]);
384+
///
385+
/// let result = Simd::load_select(&vec, enable, or);
386+
/// assert_eq!(result, Simd::from_array([10, 11, -3, 14]));
387+
/// ```
388+
#[must_use]
389+
#[inline]
390+
pub fn load_select_or_default(slice: &[T], enable: Mask<<T as SimdElement>::Mask, N>) -> Self
391+
where
392+
T: Default,
393+
{
394+
Self::load_select(slice, enable, Default::default())
395+
}
396+
397+
/// Reads contiguous elements from `slice`. Each lane is read from memory if its
398+
/// corresponding lane in `enable` is `true`.
399+
///
400+
/// When the lane is disabled or out of bounds for the slice, that memory location
401+
/// is not accessed and the corresponding value from `or` is passed through.
402+
///
403+
/// # Examples
404+
/// ```
405+
/// # #![feature(portable_simd)]
406+
/// # use core::simd::{Simd, Mask};
407+
/// let vec: Vec<i32> = vec![10, 11, 12, 13, 14, 15, 16, 17, 18];
408+
/// let enable = Mask::from_array([true, true, false, true]);
409+
/// let or = Simd::from_array([-5, -4, -3, -2]);
410+
///
411+
/// let result = Simd::load_select(&vec, enable, or);
412+
/// assert_eq!(result, Simd::from_array([10, 11, -3, 14]));
413+
/// ```
414+
#[must_use]
415+
#[inline]
416+
pub fn load_select(slice: &[T], mut enable: Mask<<T as SimdElement>::Mask, N>, or: Self) -> Self
417+
{
418+
enable &= mask_up_to(enable, slice.len());
419+
unsafe { Self::load_select_ptr(slice.as_ptr(), enable, or) }
420+
}
421+
422+
/// Reads contiguous elements from `slice`. Each lane is read from memory if its
423+
/// corresponding lane in `enable` is `true`.
424+
///
425+
/// When the lane is disabled, that memory location is not accessed and the corresponding
426+
/// value from `or` is passed through.
427+
#[must_use]
428+
#[inline]
429+
pub unsafe fn load_select_unchecked(
430+
slice: &[T],
431+
enable: Mask<<T as SimdElement>::Mask, N>,
432+
or: Self,
433+
) -> Self {
434+
let ptr = slice.as_ptr();
435+
unsafe { Self::load_select_ptr(ptr, enable, or) }
436+
}
437+
438+
/// Reads contiguous elements starting at `ptr`. Each lane is read from memory if its
439+
/// corresponding lane in `enable` is `true`.
440+
///
441+
/// When the lane is disabled, that memory location is not accessed and the corresponding
442+
/// value from `or` is passed through.
443+
#[must_use]
444+
#[inline]
445+
pub unsafe fn load_select_ptr(
446+
ptr: *const T,
447+
enable: Mask<<T as SimdElement>::Mask, N>,
448+
or: Self,
449+
) -> Self {
450+
unsafe { core::intrinsics::simd::simd_masked_load(enable.to_int(), ptr, or) }
451+
}
452+
317453
/// Reads from potentially discontiguous indices in `slice` to construct a SIMD vector.
318454
/// If an index is out-of-bounds, the element is instead selected from the `or` vector.
319455
///
@@ -492,6 +628,28 @@ where
492628
unsafe { core::intrinsics::simd::simd_gather(or, source, enable.to_int()) }
493629
}
494630

631+
#[inline]
632+
pub fn masked_store(self, slice: &mut [T], mut enable: Mask<<T as SimdElement>::Mask, N>)
633+
{
634+
enable &= mask_up_to(enable, slice.len());
635+
unsafe { self.masked_store_ptr(slice.as_mut_ptr(), enable) }
636+
}
637+
638+
#[inline]
639+
pub unsafe fn masked_store_unchecked(
640+
self,
641+
slice: &mut [T],
642+
enable: Mask<<T as SimdElement>::Mask, N>,
643+
) {
644+
let ptr = slice.as_mut_ptr();
645+
unsafe { self.masked_store_ptr(ptr, enable) }
646+
}
647+
648+
#[inline]
649+
pub unsafe fn masked_store_ptr(self, ptr: *mut T, enable: Mask<<T as SimdElement>::Mask, N>) {
650+
unsafe { core::intrinsics::simd::simd_masked_store(enable.to_int(), ptr, self) }
651+
}
652+
495653
/// Writes the values in a SIMD vector to potentially discontiguous indices in `slice`.
496654
/// If an index is out-of-bounds, the write is suppressed without panicking.
497655
/// If two elements in the scattered vector would write to the same index
@@ -979,3 +1137,27 @@ where
9791137
{
9801138
type Mask = isize;
9811139
}
1140+
1141+
#[inline]
1142+
fn lane_indices<T, const N: usize>() -> Simd<T, N>
1143+
where
1144+
T: MaskElement + Default + core::convert::From<i8> + core::ops::Add<T, Output = T>,
1145+
LaneCount<N>: SupportedLaneCount,
1146+
{
1147+
let mut index = [T::default(); N];
1148+
for i in 1..N {
1149+
index[i] = index[i - 1] + T::from(1);
1150+
}
1151+
Simd::from_array(index)
1152+
}
1153+
1154+
#[inline]
1155+
fn mask_up_to<M, const N: usize>(enable: Mask<M, N>, len: usize) -> Mask<M, N>
1156+
where
1157+
LaneCount<N>: SupportedLaneCount,
1158+
M: MaskElement,
1159+
{
1160+
let index = lane_indices::<i8, N>();
1161+
let lt = index.simd_lt(Simd::splat(i8::try_from(len).unwrap_or(i8::MAX)));
1162+
enable & Mask::<M, N>::from_bitmask_vector(lt.to_bitmask_vector())
1163+
}
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
#![feature(portable_simd)]
2+
use core_simd::simd::prelude::*;
3+
4+
#[cfg(target_arch = "wasm32")]
5+
use wasm_bindgen_test::*;
6+
7+
#[cfg(target_arch = "wasm32")]
8+
wasm_bindgen_test_configure!(run_in_browser);
9+
10+
#[test]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
fn masked_load_store() {
    let mut arr = [u8::MAX; 7];

    // Storing into arr[5..]: lane 1 lands on arr[6]; lane 3 would land on index 8, which is
    // out of bounds, so that write is dropped.
    let zeros = u8x4::splat(0);
    zeros.masked_store(&mut arr[5..], Mask::from_array([false, true, false, true]));
    assert_eq!(arr, [255u8, 255, 255, 255, 255, 255, 0]);

    // Storing into arr[1..] with every lane enabled: all four lanes are in bounds.
    let ramp = u8x4::from_array([0, 1, 2, 3]);
    ramp.masked_store(&mut arr[1..], Mask::splat(true));
    assert_eq!(arr, [255u8, 0, 1, 2, 3, 255, 0]);

    // Loading from arr[4..]: lane 3 would read index 7, which is out of bounds, so the
    // fallback value is used for that lane instead.
    let tail = &arr[4..];
    let fallback = u8x4::splat(42);
    assert_eq!(
        u8x4::load_or(tail, fallback),
        u8x4::from_array([3, 255, 0, 42])
    );

    // Same load, with lane 1 additionally disabled through the mask.
    let enable = Mask::from_array([true, false, true, true]);
    assert_eq!(
        u8x4::load_select(tail, enable, fallback),
        u8x4::from_array([3, 42, 0, 42])
    );
}

0 commit comments

Comments
 (0)