Skip to content

Commit 1f237b3

Browse files
committed
implement SIMD float rounding functions
1 parent 2029641 commit 1f237b3

File tree

2 files changed

+109
-6
lines changed

2 files changed

+109
-6
lines changed

src/shims/intrinsics.rs

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -325,20 +325,37 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
325325
// SIMD operations
326326
#[rustfmt::skip]
327327
| "simd_neg"
328-
| "simd_fabs" => {
328+
| "simd_fabs"
329+
| "simd_ceil"
330+
| "simd_floor"
331+
| "simd_round"
332+
| "simd_trunc" => {
329333
let &[ref op] = check_arg_count(args)?;
330334
let (op, op_len) = this.operand_to_simd(op)?;
331335
let (dest, dest_len) = this.place_to_simd(dest)?;
332336

333337
assert_eq!(dest_len, op_len);
334338

339+
/// Float rounding operations that this shim evaluates with the host's
/// `f32`/`f64` methods (see the `Op::HostOp` arm).
#[derive(Copy, Clone)]
340+
enum HostFloatOp {
341+
/// Selected for `"simd_ceil"`; evaluated with `ceil()`.
Ceil,
342+
/// Selected for `"simd_floor"`; evaluated with `floor()`.
Floor,
343+
/// Selected for `"simd_round"`; evaluated with `round()`.
Round,
344+
/// Selected for `"simd_trunc"`; evaluated with `trunc()`.
Trunc,
345+
}
346+
/// How the operand is transformed; the variant is chosen from `intrinsic_name`.
#[derive(Copy, Clone)]
335347
enum Op {
336348
/// A MIR unary operator (used for `"simd_neg"` with `mir::UnOp::Neg`).
MirOp(mir::UnOp),
337349
/// Float absolute value (used for `"simd_fabs"`).
Abs,
350+
/// A float rounding operation computed with host floats
/// (used for `"simd_ceil"`, `"simd_floor"`, `"simd_round"`, `"simd_trunc"`).
HostOp(HostFloatOp),
338351
}
339352
let which = match intrinsic_name {
340353
"simd_neg" => Op::MirOp(mir::UnOp::Neg),
341354
"simd_fabs" => Op::Abs,
355+
"simd_ceil" => Op::HostOp(HostFloatOp::Ceil),
356+
"simd_floor" => Op::HostOp(HostFloatOp::Floor),
357+
"simd_round" => Op::HostOp(HostFloatOp::Round),
358+
"simd_trunc" => Op::HostOp(HostFloatOp::Trunc),
342359
_ => unreachable!(),
343360
};
344361

@@ -350,14 +367,43 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
350367
Op::Abs => {
351368
// Works for f32 and f64.
352369
let ty::Float(float_ty) = op.layout.ty.kind() else {
353-
bug!("simd_fabs operand is not a float")
370+
bug!("{} operand is not a float", intrinsic_name)
354371
};
355372
let op = op.to_scalar()?;
356373
match float_ty {
357374
FloatTy::F32 => Scalar::from_f32(op.to_f32()?.abs()),
358375
FloatTy::F64 => Scalar::from_f64(op.to_f64()?.abs()),
359376
}
360377
}
378+
Op::HostOp(host_op) => {
379+
let ty::Float(float_ty) = op.layout.ty.kind() else {
380+
bug!("{} operand is not a float", intrinsic_name)
381+
};
382+
// FIXME: this is computed with the host's `f32`/`f64` operations instead of a host-independent float implementation.
383+
match float_ty {
384+
FloatTy::F32 => {
385+
let f = f32::from_bits(op.to_scalar()?.to_u32()?);
386+
let res = match host_op {
387+
HostFloatOp::Ceil => f.ceil(),
388+
HostFloatOp::Floor => f.floor(),
389+
HostFloatOp::Round => f.round(),
390+
HostFloatOp::Trunc => f.trunc(),
391+
};
392+
Scalar::from_u32(res.to_bits())
393+
}
394+
FloatTy::F64 => {
395+
let f = f64::from_bits(op.to_scalar()?.to_u64()?);
396+
let res = match host_op {
397+
HostFloatOp::Ceil => f.ceil(),
398+
HostFloatOp::Floor => f.floor(),
399+
HostFloatOp::Round => f.round(),
400+
HostFloatOp::Trunc => f.trunc(),
401+
};
402+
Scalar::from_u64(res.to_bits())
403+
}
404+
}
405+
406+
}
361407
};
362408
this.write_scalar(val, &dest.into())?;
363409
}

tests/run-pass/portable-simd.rs

Lines changed: 61 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -106,19 +106,39 @@ fn simd_ops_i32() {
106106
assert_eq!(a.min(b * i32x4::splat(4)), i32x4::from_array([4, 8, 10, -16]));
107107

108108
assert_eq!(
109-
i8x4::from_array([i8::MAX, -23, 23, i8::MIN]).saturating_add(i8x4::from_array([1, i8::MIN, i8::MAX, 28])),
109+
i8x4::from_array([i8::MAX, -23, 23, i8::MIN]).saturating_add(i8x4::from_array([
110+
1,
111+
i8::MIN,
112+
i8::MAX,
113+
28
114+
])),
110115
i8x4::from_array([i8::MAX, i8::MIN, i8::MAX, -100])
111116
);
112117
assert_eq!(
113-
i8x4::from_array([i8::MAX, -28, 27, 42]).saturating_sub(i8x4::from_array([1, i8::MAX, i8::MAX, -80])),
118+
i8x4::from_array([i8::MAX, -28, 27, 42]).saturating_sub(i8x4::from_array([
119+
1,
120+
i8::MAX,
121+
i8::MAX,
122+
-80
123+
])),
114124
i8x4::from_array([126, i8::MIN, -100, 122])
115125
);
116126
assert_eq!(
117-
u8x4::from_array([u8::MAX, 0, 23, 42]).saturating_add(u8x4::from_array([1, 1, u8::MAX, 200])),
127+
u8x4::from_array([u8::MAX, 0, 23, 42]).saturating_add(u8x4::from_array([
128+
1,
129+
1,
130+
u8::MAX,
131+
200
132+
])),
118133
u8x4::from_array([u8::MAX, 1, u8::MAX, 242])
119134
);
120135
assert_eq!(
121-
u8x4::from_array([u8::MAX, 0, 23, 42]).saturating_sub(u8x4::from_array([1, 1, u8::MAX, 200])),
136+
u8x4::from_array([u8::MAX, 0, 23, 42]).saturating_sub(u8x4::from_array([
137+
1,
138+
1,
139+
u8::MAX,
140+
200
141+
])),
122142
u8x4::from_array([254, 0, 0, 0])
123143
);
124144

@@ -259,6 +279,42 @@ fn simd_gather_scatter() {
259279
assert_eq!(vec, vec![124, 11, 12, 82, 14, 15, 16, 17, 18]);
260280
}
261281

282+
/// Exercises `ceil`, `floor`, `round` and `trunc` on `f32x4` and `f64x4`,
/// comparing each result with hand-written expected lanes.
fn simd_round() {
283+
// The input mixes a value just below 1, one just above 1, an exact integer and a
// negative half (-4.5), so the four operations produce distinguishable results.
assert_eq!(
284+
f32x4::from_array([0.9, 1.001, 2.0, -4.5]).ceil(),
285+
f32x4::from_array([1.0, 2.0, 2.0, -4.0])
286+
);
287+
assert_eq!(
288+
f32x4::from_array([0.9, 1.001, 2.0, -4.5]).floor(),
289+
f32x4::from_array([0.0, 1.0, 2.0, -5.0])
290+
);
291+
// `round` is expected to take the tie -4.5 to -5.0, unlike `ceil`/`trunc` (-4.0).
assert_eq!(
292+
f32x4::from_array([0.9, 1.001, 2.0, -4.5]).round(),
293+
f32x4::from_array([1.0, 1.0, 2.0, -5.0])
294+
);
295+
assert_eq!(
296+
f32x4::from_array([0.9, 1.001, 2.0, -4.5]).trunc(),
297+
f32x4::from_array([0.0, 1.0, 2.0, -4.0])
298+
);
299+
300+
// Same inputs and expected lanes, this time with f64 elements.
assert_eq!(
301+
f64x4::from_array([0.9, 1.001, 2.0, -4.5]).ceil(),
302+
f64x4::from_array([1.0, 2.0, 2.0, -4.0])
303+
);
304+
assert_eq!(
305+
f64x4::from_array([0.9, 1.001, 2.0, -4.5]).floor(),
306+
f64x4::from_array([0.0, 1.0, 2.0, -5.0])
307+
);
308+
assert_eq!(
309+
f64x4::from_array([0.9, 1.001, 2.0, -4.5]).round(),
310+
f64x4::from_array([1.0, 1.0, 2.0, -5.0])
311+
);
312+
assert_eq!(
313+
f64x4::from_array([0.9, 1.001, 2.0, -4.5]).trunc(),
314+
f64x4::from_array([0.0, 1.0, 2.0, -4.0])
315+
);
316+
}
317+
262318
fn simd_intrinsics() {
263319
extern "platform-intrinsic" {
264320
fn simd_eq<T, U>(x: T, y: T) -> U;
@@ -299,5 +355,6 @@ fn main() {
299355
simd_cast();
300356
simd_swizzle();
301357
simd_gather_scatter();
358+
simd_round();
302359
simd_intrinsics();
303360
}

0 commit comments

Comments
 (0)