diff --git a/ndarray-rand/src/lib.rs b/ndarray-rand/src/lib.rs index 19235237f..63cf1c397 100644 --- a/ndarray-rand/src/lib.rs +++ b/ndarray-rand/src/lib.rs @@ -246,7 +246,7 @@ where R: Rng + ?Sized, Sh: ShapeBuilder, { - Self::from_shape_fn(shape, |_| dist.sample(rng)) + Self::from_shape_simple_fn(shape, move || dist.sample(rng)) } fn sample_axis(&self, axis: Axis, n_samples: usize, strategy: SamplingStrategy) -> Array diff --git a/ndarray-rand/tests/tests.rs b/ndarray-rand/tests/tests.rs index 52de74595..f7860ac12 100644 --- a/ndarray-rand/tests/tests.rs +++ b/ndarray-rand/tests/tests.rs @@ -1,6 +1,8 @@ use ndarray::{Array, Array2, ArrayView1, Axis}; #[cfg(feature = "quickcheck")] use ndarray_rand::rand::{distributions::Distribution, thread_rng}; + +use ndarray::ShapeBuilder; use ndarray_rand::rand_distr::Uniform; use ndarray_rand::{RandomExt, SamplingStrategy}; use quickcheck::quickcheck; @@ -14,6 +16,21 @@ fn test_dim() { assert_eq!(a.shape(), &[m, n]); assert!(a.iter().all(|x| *x < 2.)); assert!(a.iter().all(|x| *x >= 0.)); + assert!(a.is_standard_layout()); + } + } +} + +#[test] +fn test_dim_f() { + let (mm, nn) = (5, 5); + for m in 0..mm { + for n in 0..nn { + let a = Array::random((m, n).f(), Uniform::new(0., 2.)); + assert_eq!(a.shape(), &[m, n]); + assert!(a.iter().all(|x| *x < 2.)); + assert!(a.iter().all(|x| *x >= 0.)); + assert!(a.t().is_standard_layout()); } } } diff --git a/src/impl_constructors.rs b/src/impl_constructors.rs index 4c35bbfb3..cb39f8e9b 100644 --- a/src/impl_constructors.rs +++ b/src/impl_constructors.rs @@ -305,10 +305,28 @@ where where A: Default, Sh: ShapeBuilder, + { + Self::from_shape_simple_fn(shape, A::default) + } + + /// Create an array with values created by the function `f`. + /// + /// `f` is called with no argument, and it should return the element to + /// create. If the precise index of the element to create is needed, + /// use [`from_shape_fn`](ArrayBase::from_shape_fn) instead. + /// + /// This constructor can be useful if the element order is not important, + /// for example if they are identical or random. + /// + /// **Panics** if the product of non-zero axis lengths overflows `isize`. + pub fn from_shape_simple_fn(shape: Sh, mut f: F) -> Self + where + Sh: ShapeBuilder, + F: FnMut() -> A, { let shape = shape.into_shape(); - let size = size_of_shape_checked_unwrap!(&shape.dim); - let v = to_vec((0..size).map(|_| A::default())); + let len = size_of_shape_checked_unwrap!(&shape.dim); + let v = to_vec_mapped(0..len, move |_| f()); unsafe { Self::from_shape_vec_unchecked(shape, v) } } @@ -318,6 +336,20 @@ where /// visited in arbitrary order. /// /// **Panics** if the product of non-zero axis lengths overflows `isize`. + /// + /// ``` + /// use ndarray::{Array, arr2}; + /// + /// // Create a table of i × j (with i and j from 1 to 3) + /// let ij_table = Array::from_shape_fn((3, 3), |(i, j)| (1 + i) * (1 + j)); + /// + /// assert_eq!( + /// ij_table, + /// arr2(&[[1, 2, 3], + /// [2, 4, 6], + /// [3, 6, 9]]) + /// ); + /// ``` pub fn from_shape_fn(shape: Sh, f: F) -> Self where Sh: ShapeBuilder, diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 5caa11b81..97193de36 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -2187,7 +2187,7 @@ where let view_stride = self.strides.axis(axis); if view_len == 0 { let new_dim = self.dim.remove_axis(axis); - Array::from_shape_fn(new_dim, move |_| mapping(ArrayView::from(&[]))) + Array::from_shape_simple_fn(new_dim, move || mapping(ArrayView::from(&[]))) } else { // use the 0th subview as a map to each 1d array view extended from // the 0th element. @@ -2218,7 +2218,7 @@ where let view_stride = self.strides.axis(axis); if view_len == 0 { let new_dim = self.dim.remove_axis(axis); - Array::from_shape_fn(new_dim, move |_| mapping(ArrayViewMut::from(&mut []))) + Array::from_shape_simple_fn(new_dim, move || mapping(ArrayViewMut::from(&mut []))) } else { // use the 0th subview as a map to each 1d array view extended from // the 0th element.