Skip to content

Fix array::IntoIter::fold to use the optimized Range::fold #95602

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion library/core/src/array/iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ impl<T, const N: usize> Iterator for IntoIter<T, N> {
Fold: FnMut(Acc, Self::Item) -> Acc,
{
let data = &mut self.data;
self.alive.by_ref().fold(init, |acc, idx| {
iter::ByRefSized(&mut self.alive).fold(init, |acc, idx| {
// SAFETY: idx is obtained by folding over the `alive` range, which implies the
// value is currently considered alive but as the range is being consumed each value
// we read here will only be read once and then considered dead.
Expand Down Expand Up @@ -323,6 +323,20 @@ impl<T, const N: usize> DoubleEndedIterator for IntoIter<T, N> {
})
}

#[inline]
fn rfold<Acc, Fold>(mut self, init: Acc, mut rfold: Fold) -> Acc
where
Fold: FnMut(Acc, Self::Item) -> Acc,
{
let data = &mut self.data;
iter::ByRefSized(&mut self.alive).rfold(init, |acc, idx| {
// SAFETY: idx is obtained by folding over the `alive` range, which implies the
// value is currently considered alive but as the range is being consumed each value
// we read here will only be read once and then considered dead.
rfold(acc, unsafe { data.get_unchecked(idx).assume_init_read() })
})
}

fn advance_back_by(&mut self, n: usize) -> Result<(), usize> {
let len = self.len();

Expand Down
40 changes: 40 additions & 0 deletions library/core/src/iter/adapters/by_ref_sized.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,29 +9,35 @@ pub(crate) struct ByRefSized<'a, I>(pub &'a mut I);
impl<I: Iterator> Iterator for ByRefSized<'_, I> {
type Item = I::Item;

#[inline]
fn next(&mut self) -> Option<Self::Item> {
self.0.next()
}

#[inline]
fn size_hint(&self) -> (usize, Option<usize>) {
self.0.size_hint()
}

#[inline]
fn advance_by(&mut self, n: usize) -> Result<(), usize> {
self.0.advance_by(n)
}

#[inline]
fn nth(&mut self, n: usize) -> Option<Self::Item> {
self.0.nth(n)
}

#[inline]
fn fold<B, F>(self, init: B, f: F) -> B
where
F: FnMut(B, Self::Item) -> B,
{
self.0.fold(init, f)
}

#[inline]
fn try_fold<B, F, R>(&mut self, init: B, f: F) -> R
where
F: FnMut(B, Self::Item) -> R,
Expand All @@ -40,3 +46,37 @@ impl<I: Iterator> Iterator for ByRefSized<'_, I> {
self.0.try_fold(init, f)
}
}

impl<I: DoubleEndedIterator> DoubleEndedIterator for ByRefSized<'_, I> {
#[inline]
fn next_back(&mut self) -> Option<Self::Item> {
self.0.next_back()
}

#[inline]
fn advance_back_by(&mut self, n: usize) -> Result<(), usize> {
self.0.advance_back_by(n)
}

#[inline]
fn nth_back(&mut self, n: usize) -> Option<Self::Item> {
self.0.nth_back(n)
}

#[inline]
fn rfold<B, F>(self, init: B, f: F) -> B
where
F: FnMut(B, Self::Item) -> B,
{
self.0.rfold(init, f)
}

#[inline]
fn try_rfold<B, F, R>(&mut self, init: B, f: F) -> R
where
F: FnMut(B, Self::Item) -> R,
R: Try<Output = B>,
{
self.0.try_rfold(init, f)
}
}
32 changes: 32 additions & 0 deletions library/core/tests/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -668,3 +668,35 @@ fn array_mixed_equality_nans() {
assert!(!(mut3 == array3));
assert!(mut3 != array3);
}

#[test]
fn array_into_iter_fold() {
// Strings to help MIRI catch if we double-free or something
let a = ["Aa".to_string(), "Bb".to_string(), "Cc".to_string()];
let mut s = "s".to_string();
a.into_iter().for_each(|b| s += &b);
assert_eq!(s, "sAaBbCc");

let a = [1, 2, 3, 4, 5, 6];
let mut it = a.into_iter();
it.advance_by(1).unwrap();
it.advance_back_by(2).unwrap();
let s = it.fold(10, |a, b| 10 * a + b);
assert_eq!(s, 10234);
}

#[test]
fn array_into_iter_rfold() {
// Strings to help MIRI catch if we double-free or something
let a = ["Aa".to_string(), "Bb".to_string(), "Cc".to_string()];
let mut s = "s".to_string();
a.into_iter().rev().for_each(|b| s += &b);
assert_eq!(s, "sCcBbAa");

let a = [1, 2, 3, 4, 5, 6];
let mut it = a.into_iter();
it.advance_by(1).unwrap();
it.advance_back_by(2).unwrap();
let s = it.rfold(10, |a, b| 10 * a + b);
assert_eq!(s, 10432);
}
54 changes: 54 additions & 0 deletions src/test/codegen/simd-wide-sum.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// compile-flags: -C opt-level=3 --edition=2021
// only-x86_64
// ignore-debug: the debug assertions get in the way

#![crate_type = "lib"]
#![feature(portable_simd)]

use std::simd::Simd;
const N: usize = 8;

#[no_mangle]
// CHECK-LABEL: @wider_reduce_simd
pub fn wider_reduce_simd(x: Simd<u8, N>) -> u16 {
// CHECK: zext <8 x i8>
// CHECK-SAME: to <8 x i16>
// CHECK: call i16 @llvm.vector.reduce.add.v8i16(<8 x i16>
let x: Simd<u16, N> = x.cast();
x.reduce_sum()
}

#[no_mangle]
// CHECK-LABEL: @wider_reduce_loop
pub fn wider_reduce_loop(x: Simd<u8, N>) -> u16 {
// CHECK: zext <8 x i8>
// CHECK-SAME: to <8 x i16>
// CHECK: call i16 @llvm.vector.reduce.add.v8i16(<8 x i16>
let mut sum = 0_u16;
for i in 0..N {
sum += u16::from(x[i]);
}
sum
}

#[no_mangle]
// CHECK-LABEL: @wider_reduce_iter
pub fn wider_reduce_iter(x: Simd<u8, N>) -> u16 {
// CHECK: zext <8 x i8>
// CHECK-SAME: to <8 x i16>
// CHECK: call i16 @llvm.vector.reduce.add.v8i16(<8 x i16>
x.as_array().iter().copied().map(u16::from).sum()
}

// This iterator one is the most interesting, as it's the one
// which used to not auto-vectorize due to a suboptimality in the
// `<array::IntoIter as Iterator>::fold` implementation.

#[no_mangle]
// CHECK-LABEL: @wider_reduce_into_iter
pub fn wider_reduce_into_iter(x: Simd<u8, N>) -> u16 {
// CHECK: zext <8 x i8>
// CHECK-SAME: to <8 x i16>
// CHECK: call i16 @llvm.vector.reduce.add.v8i16(<8 x i16>
x.to_array().into_iter().map(u16::from).sum()
}