Skip to content

Commit 1360acd

Browse files
committed
Implement mask_up_to properly for vectors 128 elements and larger
1 parent 9a84478 commit 1360acd

File tree

2 files changed

+25
-8
lines changed

2 files changed

+25
-8
lines changed

crates/core_simd/src/masks.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ mod sealed {
3434
fn eq(self, other: Self) -> bool;
3535

3636
fn to_usize(self) -> usize;
37+
fn max_unsigned() -> u64;
3738

3839
type Unsigned: SimdElement;
3940

@@ -78,6 +79,11 @@ macro_rules! impl_element {
7879
self as usize
7980
}
8081

82+
#[inline]
83+
fn max_unsigned() -> u64 {
84+
<$unsigned>::MAX as u64
85+
}
86+
8187
type Unsigned = $unsigned;
8288

8389
const TRUE: Self = -1;

crates/core_simd/src/vector.rs

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use crate::simd::{
22
cmp::SimdPartialOrd,
33
ptr::{SimdConstPtr, SimdMutPtr},
4+
num::SimdUint,
45
LaneCount, Mask, MaskElement, SupportedLaneCount, Swizzle,
56
};
67

@@ -1190,14 +1191,13 @@ where
11901191
}
11911192

11921193
#[inline]
1193-
fn lane_indices<T, const N: usize>() -> Simd<T, N>
1194+
fn lane_indices<const N: usize>() -> Simd<usize, N>
11941195
where
1195-
T: MaskElement + Default + core::convert::From<i8> + core::ops::Add<T, Output = T>,
11961196
LaneCount<N>: SupportedLaneCount,
11971197
{
1198-
let mut index = [T::default(); N];
1199-
for i in 1..N {
1200-
index[i] = index[i - 1] + T::from(1);
1198+
let mut index = [0; N];
1199+
for i in 0..N {
1200+
index[i] = i;
12011201
}
12021202
Simd::from_array(index)
12031203
}
@@ -1208,7 +1208,18 @@ where
12081208
LaneCount<N>: SupportedLaneCount,
12091209
M: MaskElement,
12101210
{
1211-
let index = lane_indices::<i8, N>();
1212-
let lt = index.simd_lt(Simd::splat(i8::try_from(len).unwrap_or(i8::MAX)));
1213-
lt.cast()
1211+
let index = lane_indices::<N>();
1212+
let max_value: u64 = M::max_unsigned();
1213+
macro_rules! case {
1214+
($ty:ty) => {
1215+
if N < <$ty>::MAX as usize && max_value as $ty as u64 == max_value {
1216+
return index.cast().simd_lt(Simd::splat(len.min(N) as $ty)).cast();
1217+
}
1218+
};
1219+
}
1220+
case!(u8);
1221+
case!(u16);
1222+
case!(u32);
1223+
case!(u64);
1224+
index.simd_lt(Simd::splat(len)).cast()
12141225
}

0 commit comments

Comments
 (0)