From 72b01d5cca0604e1ea5818d90b9feefae92a0093 Mon Sep 17 00:00:00 2001 From: The 8472 Date: Fri, 25 Aug 2023 23:29:18 +0200 Subject: [PATCH 1/3] Optimize Take::{fold, for_each} when wrapping TrustedRandomAccess iterators --- library/core/src/iter/adapters/take.rs | 102 ++++++++++++++++---- tests/codegen/lib-optimizations/iter-sum.rs | 14 +++ 2 files changed, 97 insertions(+), 19 deletions(-) create mode 100644 tests/codegen/lib-optimizations/iter-sum.rs diff --git a/library/core/src/iter/adapters/take.rs b/library/core/src/iter/adapters/take.rs index ce18bffe7146f..70252e075b9f6 100644 --- a/library/core/src/iter/adapters/take.rs +++ b/library/core/src/iter/adapters/take.rs @@ -1,5 +1,7 @@ use crate::cmp; -use crate::iter::{adapters::SourceIter, FusedIterator, InPlaceIterable, TrustedLen}; +use crate::iter::{ + adapters::SourceIter, FusedIterator, InPlaceIterable, TrustedLen, TrustedRandomAccess, +}; use crate::num::NonZeroUsize; use crate::ops::{ControlFlow, Try}; @@ -98,26 +100,18 @@ where } } - impl_fold_via_try_fold! { fold -> try_fold } - #[inline] - fn for_each(mut self, f: F) { - // The default implementation would use a unit accumulator, so we can - // avoid a stateful closure by folding over the remaining number - // of items we wish to return instead. - fn check<'a, Item>( - mut action: impl FnMut(Item) + 'a, - ) -> impl FnMut(usize, Item) -> Option + 'a { - move |more, x| { - action(x); - more.checked_sub(1) - } - } + fn fold(self, init: B, f: F) -> B + where + Self: Sized, + F: FnMut(B, Self::Item) -> B, + { + Self::spec_fold(self, init, f) + } - let remaining = self.n; - if remaining > 0 { - self.iter.try_fold(remaining - 1, check(f)); - } + #[inline] + fn for_each(self, f: F) { + Self::spec_for_each(self, f) } #[inline] @@ -249,3 +243,73 @@ impl FusedIterator for Take where I: FusedIterator {} #[unstable(feature = "trusted_len", issue = "37572")] unsafe impl TrustedLen for Take {} + +trait SpecTake: Iterator { + fn spec_fold(self, init: B, f: F) -> B + where + Self: Sized, + F: FnMut(B, Self::Item) -> B; + + fn spec_for_each(self, f: F); +} + +impl SpecTake for Take { + #[inline] + default fn spec_fold(mut self, init: B, f: F) -> B + where + Self: Sized, + F: FnMut(B, Self::Item) -> B, + { + use crate::ops::NeverShortCircuit; + self.try_fold(init, NeverShortCircuit::wrap_mut_2(f)).0 + } + + #[inline] + default fn spec_for_each(mut self, f: F) { + // The default implementation would use a unit accumulator, so we can + // avoid a stateful closure by folding over the remaining number + // of items we wish to return instead. + fn check<'a, Item>( + mut action: impl FnMut(Item) + 'a, + ) -> impl FnMut(usize, Item) -> Option + 'a { + move |more, x| { + action(x); + more.checked_sub(1) + } + } + + let remaining = self.n; + if remaining > 0 { + self.iter.try_fold(remaining - 1, check(f)); + } + } +} + +impl SpecTake for Take { + #[inline] + fn spec_fold(mut self, init: B, mut f: F) -> B + where + Self: Sized, + F: FnMut(B, Self::Item) -> B, + { + let mut acc = init; + let end = self.n.min(self.iter.size()); + for i in 0..end { + // SAFETY: i < end <= self.iter.size() and we discard the iterator at the end + let val = unsafe { self.iter.__iterator_get_unchecked(i) }; + acc = f(acc, val); + } + acc + } + + #[inline] + fn spec_for_each(self, f: F) { + // Based on the the Iterator trait default impl. + #[inline] + fn call(mut f: impl FnMut(T)) -> impl FnMut((), T) { + move |(), item| f(item) + } + + self.spec_fold((), call(f)); + } +} diff --git a/tests/codegen/lib-optimizations/iter-sum.rs b/tests/codegen/lib-optimizations/iter-sum.rs new file mode 100644 index 0000000000000..d6ea4cd74d558 --- /dev/null +++ b/tests/codegen/lib-optimizations/iter-sum.rs @@ -0,0 +1,14 @@ +// ignore-debug: the debug assertions get in the way +// compile-flags: -O +#![crate_type = "lib"] + + +// Ensure that slice + take + sum gets vectorized. +// Currently this relies on the slice::Iter::try_fold implementation +// CHECK-LABEL: @slice_take_sum +#[no_mangle] +pub fn slice_take_sum(s: &[u64], l: usize) -> u64 { + // CHECK: vector.body: + // CHECK: ret + s.iter().take(l).sum() +} From 07a1d5f0273fdcce5ec7ec3cba9a6db232d9124b Mon Sep 17 00:00:00 2001 From: The 8472 Date: Mon, 28 Aug 2023 14:35:51 +0200 Subject: [PATCH 2/3] reduce indirection in for_each specialization --- library/core/src/iter/adapters/take.rs | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/library/core/src/iter/adapters/take.rs b/library/core/src/iter/adapters/take.rs index 70252e075b9f6..c1d8cc4ff57bd 100644 --- a/library/core/src/iter/adapters/take.rs +++ b/library/core/src/iter/adapters/take.rs @@ -303,13 +303,12 @@ impl SpecTake for Take { } #[inline] - fn spec_for_each(self, f: F) { - // Based on the the Iterator trait default impl. - #[inline] - fn call(mut f: impl FnMut(T)) -> impl FnMut((), T) { - move |(), item| f(item) + fn spec_for_each(mut self, mut f: F) { + let end = self.n.min(self.iter.size()); + for i in 0..end { + // SAFETY: i < end <= self.iter.size() and we discard the iterator at the end + let val = unsafe { self.iter.__iterator_get_unchecked(i) }; + f(val); } - - self.spec_fold((), call(f)); } } From f93e1258287e189aaf080d7e5336bac75633eb58 Mon Sep 17 00:00:00 2001 From: The 8472 Date: Sat, 2 Sep 2023 13:42:58 +0200 Subject: [PATCH 3/3] restrict test to x86-64 --- tests/codegen/lib-optimizations/iter-sum.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/codegen/lib-optimizations/iter-sum.rs b/tests/codegen/lib-optimizations/iter-sum.rs index d6ea4cd74d558..ff7ca6ef6c11e 100644 --- a/tests/codegen/lib-optimizations/iter-sum.rs +++ b/tests/codegen/lib-optimizations/iter-sum.rs @@ -1,5 +1,6 @@ // ignore-debug: the debug assertions get in the way // compile-flags: -O +// only-x86_64 (vectorization varies between architectures) #![crate_type = "lib"]