Skip to content

Commit 5dcd397

Browse files
Finish refactoring ints in ops.rs
This should perform a SIMD check for whether or not we can div/rem, so that we can panic several times faster!
1 parent 049e8ca commit 5dcd397

File tree

1 file changed

+147
-124
lines changed

1 file changed

+147
-124
lines changed

crates/core_simd/src/ops.rs

Lines changed: 147 additions & 124 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
use crate::simd::intrinsics;
2-
use crate::simd::{LaneCount, Simd, SimdElement, SupportedLaneCount};
1+
use crate::simd::{LaneCount, Mask, Simd, SimdElement, SupportedLaneCount};
32
use core::ops::{Add, Mul};
43
use core::ops::{BitAnd, BitOr, BitXor};
54
use core::ops::{Div, Rem, Sub};
@@ -284,145 +283,169 @@ float_arith! {
284283
}
285284
}
286285

287-
/// Automatically implements operators over references in addition to the provided operator.
288-
macro_rules! impl_ref_ops {
289-
// binary op
290-
{
291-
impl<const $lanes:ident: usize> core::ops::$trait:ident<$rhs:ty> for $type:ty
292-
where
293-
LaneCount<$lanes2:ident>: SupportedLaneCount,
294-
{
295-
type Output = $output:ty;
296-
297-
$(#[$attrs:meta])*
298-
fn $fn:ident($self_tok:ident, $rhs_arg:ident: $rhs_arg_ty:ty) -> Self::Output $body:tt
286+
// Division by zero is undefined behavior, according to LLVM.
287+
// So is dividing the MIN value of a signed integer by -1,
288+
// since that would return MAX + 1.
289+
// FIXME: Rust allows <SInt>::MIN / -1,
290+
// so we should probably figure out how to make that safe.
291+
macro_rules! int_divrem_guard {
292+
($(impl<const LANES: usize> $op:ident for Simd<$sint:ty, LANES> {
293+
const PANIC_ZERO: &'static str = $zero:literal;
294+
const PANIC_OVERFLOW: &'static str = $overflow:literal;
295+
fn $call:ident {
296+
unsafe { $simd_call:ident }
299297
}
300-
} => {
301-
impl<const $lanes: usize> core::ops::$trait<$rhs> for $type
298+
})*) => {
299+
$(impl<const LANES: usize> $op for Simd<$sint, LANES>
302300
where
303-
LaneCount<$lanes2>: SupportedLaneCount,
301+
$sint: SimdElement,
302+
LaneCount<LANES>: SupportedLaneCount,
304303
{
305-
type Output = $output;
306-
307-
$(#[$attrs])*
308-
fn $fn($self_tok, $rhs_arg: $rhs_arg_ty) -> Self::Output $body
309-
}
304+
type Output = Self;
305+
#[inline]
306+
#[must_use = "operator returns a new vector without mutating the inputs"]
307+
fn $call(self, rhs: Self) -> Self::Output {
308+
if rhs.lanes_eq(Simd::splat(0)).any() {
309+
panic!("attempt to calculate the remainder with a divisor of zero");
310+
} else if <$sint>::MIN != 0 && self.lanes_eq(Simd::splat(<$sint>::MIN)) & rhs.lanes_eq(Simd::splat(-1 as _))
311+
!= Mask::splat(false)
312+
{
313+
panic!("attempt to calculate the remainder with overflow");
314+
} else {
315+
unsafe { $crate::intrinsics::$simd_call(self, rhs) }
316+
}
317+
}
318+
})*
310319
};
311320
}
312321

313-
/// Automatically implements operators over vectors and scalars for a particular vector.
314-
macro_rules! impl_op {
315-
{ impl Add for $scalar:ty } => {
316-
impl_op! { @binary $scalar, Add::add, simd_add }
317-
};
318-
{ impl Sub for $scalar:ty } => {
319-
impl_op! { @binary $scalar, Sub::sub, simd_sub }
320-
};
321-
{ impl Mul for $scalar:ty } => {
322-
impl_op! { @binary $scalar, Mul::mul, simd_mul }
323-
};
324-
{ impl Div for $scalar:ty } => {
325-
impl_op! { @binary $scalar, Div::div, simd_div }
326-
};
327-
{ impl Rem for $scalar:ty } => {
328-
impl_op! { @binary $scalar, Rem::rem, simd_rem }
329-
};
322+
macro_rules! int_arith {
323+
($(impl<const LANES: usize> IntArith for Simd<$sint:ty, LANES> {
324+
fn add(self, rhs: Self) -> Self::Output;
325+
fn mul(self, rhs: Self) -> Self::Output;
326+
fn sub(self, rhs: Self) -> Self::Output;
327+
fn div(self, rhs: Self) -> Self::Output;
328+
fn rem(self, rhs: Self) -> Self::Output;
329+
})*) => {
330+
$(
331+
unsafe_base_op!{
332+
impl<const LANES: usize> Add for Simd<$sint, LANES> {
333+
fn add(self, rhs: Self) -> Self::Output {
334+
unsafe { simd_add }
335+
}
336+
}
330337

331-
// generic binary op with assignment when output is `Self`
332-
{ @binary $scalar:ty, $trait:ident :: $trait_fn:ident, $intrinsic:ident } => {
333-
impl_ref_ops! {
334-
impl<const LANES: usize> core::ops::$trait<Self> for Simd<$scalar, LANES>
335-
where
336-
LaneCount<LANES>: SupportedLaneCount,
337-
{
338-
type Output = Self;
338+
impl<const LANES: usize> Mul for Simd<$sint, LANES> {
339+
fn mul(self, rhs: Self) -> Self::Output {
340+
unsafe { simd_mul }
341+
}
342+
}
339343

340-
#[inline]
341-
fn $trait_fn(self, rhs: Self) -> Self::Output {
342-
unsafe {
343-
intrinsics::$intrinsic(self, rhs)
344-
}
344+
impl<const LANES: usize> Sub for Simd<$sint, LANES> {
345+
fn sub(self, rhs: Self) -> Self::Output {
346+
unsafe { simd_sub }
345347
}
346348
}
347349
}
348-
};
349-
}
350350

351-
/// Implements unsigned integer operators for the provided types.
352-
macro_rules! impl_unsigned_int_ops {
353-
{ $($scalar:ty),* } => {
354-
$(
355-
impl_op! { impl Add for $scalar }
356-
impl_op! { impl Sub for $scalar }
357-
impl_op! { impl Mul for $scalar }
358-
359-
// Integers panic on divide by 0
360-
impl_ref_ops! {
361-
impl<const LANES: usize> core::ops::Div<Self> for Simd<$scalar, LANES>
362-
where
363-
LaneCount<LANES>: SupportedLaneCount,
364-
{
365-
type Output = Self;
366-
367-
#[inline]
368-
fn div(self, rhs: Self) -> Self::Output {
369-
if rhs.as_array()
370-
.iter()
371-
.any(|x| *x == 0)
372-
{
373-
panic!("attempt to divide by zero");
374-
}
375-
376-
// Guards for div(MIN, -1),
377-
// this check only applies to signed ints
378-
if <$scalar>::MIN != 0 && self.as_array().iter()
379-
.zip(rhs.as_array().iter())
380-
.any(|(x,y)| *x == <$scalar>::MIN && *y == -1 as _) {
381-
panic!("attempt to divide with overflow");
382-
}
383-
unsafe { intrinsics::simd_div(self, rhs) }
384-
}
351+
int_divrem_guard!{
352+
impl<const LANES: usize> Div for Simd<$sint, LANES> {
353+
const PANIC_ZERO: &'static str = "attempt to divide by zero";
354+
const PANIC_OVERFLOW: &'static str = "attempt to divide with overflow";
355+
fn div {
356+
unsafe { simd_div }
385357
}
386358
}
387359

388-
// remainder panics on zero divisor
389-
impl_ref_ops! {
390-
impl<const LANES: usize> core::ops::Rem<Self> for Simd<$scalar, LANES>
391-
where
392-
LaneCount<LANES>: SupportedLaneCount,
393-
{
394-
type Output = Self;
395-
396-
#[inline]
397-
fn rem(self, rhs: Self) -> Self::Output {
398-
if rhs.as_array()
399-
.iter()
400-
.any(|x| *x == 0)
401-
{
402-
panic!("attempt to calculate the remainder with a divisor of zero");
403-
}
404-
405-
// Guards for rem(MIN, -1)
406-
// this branch applies the check only to signed ints
407-
if <$scalar>::MIN != 0 && self.as_array().iter()
408-
.zip(rhs.as_array().iter())
409-
.any(|(x,y)| *x == <$scalar>::MIN && *y == -1 as _) {
410-
panic!("attempt to calculate the remainder with overflow");
411-
}
412-
unsafe { intrinsics::simd_rem(self, rhs) }
413-
}
360+
impl<const LANES: usize> Rem for Simd<$sint, LANES> {
361+
const PANIC_ZERO: &'static str = "attempt to calculate the remainder with a divisor of zero";
362+
const PANIC_OVERFLOW: &'static str = "attempt to calculate the remainder with overflow";
363+
fn rem {
364+
unsafe { simd_rem }
414365
}
415366
}
416-
)*
417-
};
367+
})*
368+
}
418369
}
419370

420-
/// Implements signed integer operators for the provided types.
421-
macro_rules! impl_signed_int_ops {
422-
{ $($scalar:ty),* } => {
423-
impl_unsigned_int_ops! { $($scalar),* }
424-
};
425-
}
371+
int_arith! {
372+
impl<const LANES: usize> IntArith for Simd<i8, LANES> {
373+
fn add(self, rhs: Self) -> Self::Output;
374+
fn mul(self, rhs: Self) -> Self::Output;
375+
fn sub(self, rhs: Self) -> Self::Output;
376+
fn div(self, rhs: Self) -> Self::Output;
377+
fn rem(self, rhs: Self) -> Self::Output;
378+
}
426379

427-
impl_unsigned_int_ops! { u8, u16, u32, u64, usize }
428-
impl_signed_int_ops! { i8, i16, i32, i64, isize }
380+
impl<const LANES: usize> IntArith for Simd<i16, LANES> {
381+
fn add(self, rhs: Self) -> Self::Output;
382+
fn mul(self, rhs: Self) -> Self::Output;
383+
fn sub(self, rhs: Self) -> Self::Output;
384+
fn div(self, rhs: Self) -> Self::Output;
385+
fn rem(self, rhs: Self) -> Self::Output;
386+
}
387+
388+
impl<const LANES: usize> IntArith for Simd<i32, LANES> {
389+
fn add(self, rhs: Self) -> Self::Output;
390+
fn mul(self, rhs: Self) -> Self::Output;
391+
fn sub(self, rhs: Self) -> Self::Output;
392+
fn div(self, rhs: Self) -> Self::Output;
393+
fn rem(self, rhs: Self) -> Self::Output;
394+
}
395+
396+
impl<const LANES: usize> IntArith for Simd<i64, LANES> {
397+
fn add(self, rhs: Self) -> Self::Output;
398+
fn mul(self, rhs: Self) -> Self::Output;
399+
fn sub(self, rhs: Self) -> Self::Output;
400+
fn div(self, rhs: Self) -> Self::Output;
401+
fn rem(self, rhs: Self) -> Self::Output;
402+
}
403+
404+
impl<const LANES: usize> IntArith for Simd<isize, LANES> {
405+
fn add(self, rhs: Self) -> Self::Output;
406+
fn mul(self, rhs: Self) -> Self::Output;
407+
fn sub(self, rhs: Self) -> Self::Output;
408+
fn div(self, rhs: Self) -> Self::Output;
409+
fn rem(self, rhs: Self) -> Self::Output;
410+
}
411+
412+
impl<const LANES: usize> IntArith for Simd<u8, LANES> {
413+
fn add(self, rhs: Self) -> Self::Output;
414+
fn mul(self, rhs: Self) -> Self::Output;
415+
fn sub(self, rhs: Self) -> Self::Output;
416+
fn div(self, rhs: Self) -> Self::Output;
417+
fn rem(self, rhs: Self) -> Self::Output;
418+
}
419+
420+
impl<const LANES: usize> IntArith for Simd<u16, LANES> {
421+
fn add(self, rhs: Self) -> Self::Output;
422+
fn mul(self, rhs: Self) -> Self::Output;
423+
fn sub(self, rhs: Self) -> Self::Output;
424+
fn div(self, rhs: Self) -> Self::Output;
425+
fn rem(self, rhs: Self) -> Self::Output;
426+
}
427+
428+
impl<const LANES: usize> IntArith for Simd<u32, LANES> {
429+
fn add(self, rhs: Self) -> Self::Output;
430+
fn mul(self, rhs: Self) -> Self::Output;
431+
fn sub(self, rhs: Self) -> Self::Output;
432+
fn div(self, rhs: Self) -> Self::Output;
433+
fn rem(self, rhs: Self) -> Self::Output;
434+
}
435+
436+
impl<const LANES: usize> IntArith for Simd<u64, LANES> {
437+
fn add(self, rhs: Self) -> Self::Output;
438+
fn mul(self, rhs: Self) -> Self::Output;
439+
fn sub(self, rhs: Self) -> Self::Output;
440+
fn div(self, rhs: Self) -> Self::Output;
441+
fn rem(self, rhs: Self) -> Self::Output;
442+
}
443+
444+
impl<const LANES: usize> IntArith for Simd<usize, LANES> {
445+
fn add(self, rhs: Self) -> Self::Output;
446+
fn mul(self, rhs: Self) -> Self::Output;
447+
fn sub(self, rhs: Self) -> Self::Output;
448+
fn div(self, rhs: Self) -> Self::Output;
449+
fn rem(self, rhs: Self) -> Self::Output;
450+
}
451+
}

0 commit comments

Comments
 (0)