Skip to content

Commit 881f850

Browse files
committed
implement simd_reduce_min/max
1 parent 77cd19c commit 881f850

File tree

2 files changed

+52
-10
lines changed

2 files changed

+52
-10
lines changed

src/shims/intrinsics.rs

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -433,7 +433,9 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
433433
| "simd_reduce_or"
434434
| "simd_reduce_xor"
435435
| "simd_reduce_any"
436-
| "simd_reduce_all" => {
436+
| "simd_reduce_all"
437+
| "simd_reduce_max"
438+
| "simd_reduce_min" => {
437439
use mir::BinOp;
438440

439441
let &[ref op] = check_arg_count(args)?;
@@ -445,19 +447,27 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
445447
enum Op {
446448
MirOp(BinOp),
447449
MirOpBool(BinOp),
450+
Max,
451+
Min,
448452
}
449-
// The initial value is the neutral element.
450-
let (which, init) = match intrinsic_name {
451-
"simd_reduce_and" => (Op::MirOp(BinOp::BitAnd), ImmTy::from_int(-1, dest.layout)),
452-
"simd_reduce_or" => (Op::MirOp(BinOp::BitOr), ImmTy::from_int(0, dest.layout)),
453-
"simd_reduce_xor" => (Op::MirOp(BinOp::BitXor), ImmTy::from_int(0, dest.layout)),
454-
"simd_reduce_any" => (Op::MirOpBool(BinOp::BitOr), imm_from_bool(false)),
455-
"simd_reduce_all" => (Op::MirOpBool(BinOp::BitAnd), imm_from_bool(true)),
453+
let which = match intrinsic_name {
454+
"simd_reduce_and" => Op::MirOp(BinOp::BitAnd),
455+
"simd_reduce_or" => Op::MirOp(BinOp::BitOr),
456+
"simd_reduce_xor" => Op::MirOp(BinOp::BitXor),
457+
"simd_reduce_any" => Op::MirOpBool(BinOp::BitOr),
458+
"simd_reduce_all" => Op::MirOpBool(BinOp::BitAnd),
459+
"simd_reduce_max" => Op::Max,
460+
"simd_reduce_min" => Op::Min,
456461
_ => unreachable!(),
457462
};
458463

459-
let mut res = init;
460-
for i in 0..op_len {
464+
// Initialize with first lane, then proceed with the rest.
465+
let mut res = this.read_immediate(&this.mplace_index(&op, 0)?.into())?;
466+
if matches!(which, Op::MirOpBool(_)) {
467+
// Convert to `bool` scalar.
468+
res = imm_from_bool(simd_element_to_bool(res)?);
469+
}
470+
for i in 1..op_len {
461471
let op = this.read_immediate(&this.mplace_index(&op, i)?.into())?;
462472
res = match which {
463473
Op::MirOp(mir_op) => {
@@ -467,6 +477,26 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
467477
let op = imm_from_bool(simd_element_to_bool(op)?);
468478
this.binary_op(mir_op, &res, &op)?
469479
}
480+
Op::Max => {
481+
// if `op > res`...
482+
if this.binary_op(BinOp::Gt, &op, &res)?.to_scalar()?.to_bool()? {
483+
// update accumulator
484+
op
485+
} else {
486+
// no change
487+
res
488+
}
489+
}
490+
Op::Min => {
491+
// if `op < res`...
492+
if this.binary_op(BinOp::Lt, &op, &res)?.to_scalar()?.to_bool()? {
493+
// update accumulator
494+
op
495+
} else {
496+
// no change
497+
res
498+
}
499+
}
470500
};
471501
}
472502
this.write_immediate(*res, dest)?;

tests/run-pass/portable-simd.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@ fn simd_ops_f32() {
2424
assert_eq!(b.horizontal_sum(), 2.0);
2525
assert_eq!(a.horizontal_product(), 100.0 * 100.0);
2626
assert_eq!(b.horizontal_product(), -24.0);
27+
assert_eq!(a.horizontal_max(), 10.0);
28+
assert_eq!(b.horizontal_max(), 3.0);
29+
assert_eq!(a.horizontal_min(), 10.0);
30+
assert_eq!(b.horizontal_min(), -4.0);
2731
}
2832

2933
fn simd_ops_f64() {
@@ -49,6 +53,10 @@ fn simd_ops_f64() {
4953
assert_eq!(b.horizontal_sum(), 2.0);
5054
assert_eq!(a.horizontal_product(), 100.0 * 100.0);
5155
assert_eq!(b.horizontal_product(), -24.0);
56+
assert_eq!(a.horizontal_max(), 10.0);
57+
assert_eq!(b.horizontal_max(), 3.0);
58+
assert_eq!(a.horizontal_min(), 10.0);
59+
assert_eq!(b.horizontal_min(), -4.0);
5260
}
5361

5462
fn simd_ops_i32() {
@@ -86,6 +94,10 @@ fn simd_ops_i32() {
8694
assert_eq!(b.horizontal_sum(), 2);
8795
assert_eq!(a.horizontal_product(), 100 * 100);
8896
assert_eq!(b.horizontal_product(), -24);
97+
assert_eq!(a.horizontal_max(), 10);
98+
assert_eq!(b.horizontal_max(), 3);
99+
assert_eq!(a.horizontal_min(), 10);
100+
assert_eq!(b.horizontal_min(), -4);
89101
}
90102

91103
fn simd_mask() {

0 commit comments

Comments
 (0)