From 7b5a417d2c93d2fee30e1e509114c50d1e92830c Mon Sep 17 00:00:00 2001 From: aobatact Date: Fri, 26 May 2023 15:56:44 +0900 Subject: [PATCH 1/2] Add impl for `DistString` to `Uniform` and `Slice` --- src/distributions/slice.rs | 21 +++++++++++++++++++++ src/distributions/uniform.rs | 20 ++++++++++++++++++++ 2 files changed, 41 insertions(+) diff --git a/src/distributions/slice.rs b/src/distributions/slice.rs index 398cad18b2c..78eb8d4769e 100644 --- a/src/distributions/slice.rs +++ b/src/distributions/slice.rs @@ -7,6 +7,8 @@ // except according to those terms. use crate::distributions::{Distribution, Uniform}; +#[cfg(feature = "alloc")] +use alloc::string::String; /// A distribution to sample items uniformly from a slice. /// @@ -115,3 +117,22 @@ impl core::fmt::Display for EmptySlice { #[cfg(feature = "std")] impl std::error::Error for EmptySlice {} + +/// Note: the `String` is potentially left with excess capacity; optionally the +/// user may call `string.shrink_to_fit()` afterwards. +#[cfg(feature = "alloc")] +impl<'a> super::DistString for Slice<'a, char> { + fn append_string(&self, rng: &mut R, string: &mut String, len: usize) { + let max_char_len = self + .slice + .iter() + .try_fold(1, |max_len, char| { + // When the current max_len is 4, the result max_char_len will be 4. + Some(max_len.max(char.len_utf8())).filter(|len| *len < 4) + }) + .unwrap_or(4); + + string.reserve(max_char_len * len); + string.extend(self.sample_iter(rng).take(len)) + } +} diff --git a/src/distributions/uniform.rs b/src/distributions/uniform.rs index a2664768c9c..c18f1014892 100644 --- a/src/distributions/uniform.rs +++ b/src/distributions/uniform.rs @@ -843,6 +843,26 @@ impl UniformSampler for UniformChar { } } +/// Note: the `String` is potentially left with excess capacity if the range +/// includes non ascii chars; optionally the user may call +/// `string.shrink_to_fit()` afterwards. +#[cfg(feature = "alloc")] +impl super::DistString for Uniform{ + fn append_string(&self, rng: &mut R, string: &mut alloc::string::String, len: usize) { + // Getting the hi value to assume the required length to reserve in string. + let mut hi = self.0.sampler.low + self.0.sampler.range; + if hi >= CHAR_SURROGATE_START { + hi += CHAR_SURROGATE_LEN; + } + // Get the utf8 length of hi to minimize extra space. + // SAFETY: hi used to be valid char. + // This relies on range constructors which accept char arguments. + let max_char_len = unsafe { char::from_u32_unchecked(hi).len_utf8() }; + string.reserve(max_char_len * len); + string.extend(self.sample_iter(rng).take(len)) + } +} + /// The back-end implementing [`UniformSampler`] for floating-point types. /// /// Unless you are implementing [`UniformSampler`] for your own type, this type From 4d4f34db960a54ccfbe56425edc9a29d4616caaf Mon Sep 17 00:00:00 2001 From: aobatact Date: Wed, 31 May 2023 20:39:26 +0900 Subject: [PATCH 2/2] Fix `DistString` impl. --- src/distributions/slice.rs | 33 +++++++++++++++++++++++---------- src/distributions/uniform.rs | 22 ++++++++++++++++++---- 2 files changed, 41 insertions(+), 14 deletions(-) diff --git a/src/distributions/slice.rs b/src/distributions/slice.rs index 78eb8d4769e..224bf1712c1 100644 --- a/src/distributions/slice.rs +++ b/src/distributions/slice.rs @@ -123,16 +123,29 @@ impl std::error::Error for EmptySlice {} #[cfg(feature = "alloc")] impl<'a> super::DistString for Slice<'a, char> { fn append_string(&self, rng: &mut R, string: &mut String, len: usize) { - let max_char_len = self - .slice - .iter() - .try_fold(1, |max_len, char| { - // When the current max_len is 4, the result max_char_len will be 4. - Some(max_len.max(char.len_utf8())).filter(|len| *len < 4) - }) - .unwrap_or(4); + // Get the max char length to minimize extra space. + // Limit this check to avoid searching for long slice. + let max_char_len = if self.slice.len() < 200 { + self.slice + .iter() + .try_fold(1, |max_len, char| { + // When the current max_len is 4, the result max_char_len will be 4. + Some(max_len.max(char.len_utf8())).filter(|len| *len < 4) + }) + .unwrap_or(4) + } else { + 4 + }; - string.reserve(max_char_len * len); - string.extend(self.sample_iter(rng).take(len)) + // Split the extension of string to reuse the unused capacities. + // Skip the split for small length or only ascii slice. + let mut extend_len = if max_char_len == 1 || len < 100 { len } else { len / 4 }; + let mut remain_len = len; + while extend_len > 0 { + string.reserve(max_char_len * extend_len); + string.extend(self.sample_iter(&mut *rng).take(extend_len)); + remain_len -= extend_len; + extend_len = extend_len.min(remain_len); + } } } diff --git a/src/distributions/uniform.rs b/src/distributions/uniform.rs index c18f1014892..713961e8e0c 100644 --- a/src/distributions/uniform.rs +++ b/src/distributions/uniform.rs @@ -850,14 +850,12 @@ impl UniformSampler for UniformChar { impl super::DistString for Uniform{ fn append_string(&self, rng: &mut R, string: &mut alloc::string::String, len: usize) { // Getting the hi value to assume the required length to reserve in string. - let mut hi = self.0.sampler.low + self.0.sampler.range; + let mut hi = self.0.sampler.low + self.0.sampler.range - 1; if hi >= CHAR_SURROGATE_START { hi += CHAR_SURROGATE_LEN; } // Get the utf8 length of hi to minimize extra space. - // SAFETY: hi used to be valid char. - // This relies on range constructors which accept char arguments. - let max_char_len = unsafe { char::from_u32_unchecked(hi).len_utf8() }; + let max_char_len = char::from_u32(hi).map(char::len_utf8).unwrap_or(4); string.reserve(max_char_len * len); string.extend(self.sample_iter(rng).take(len)) } @@ -1396,6 +1394,22 @@ mod tests { let c = d.sample(&mut rng); assert!((c as u32) < 0xD800 || (c as u32) > 0xDFFF); } + #[cfg(feature = "alloc")] + { + use crate::distributions::DistString; + let string1 = d.sample_string(&mut rng, 100); + assert_eq!(string1.capacity(), 300); + let string2 = Uniform::new( + core::char::from_u32(0x0000).unwrap(), + core::char::from_u32(0x0080).unwrap(), + ).unwrap().sample_string(&mut rng, 100); + assert_eq!(string2.capacity(), 100); + let string3 = Uniform::new_inclusive( + core::char::from_u32(0x0000).unwrap(), + core::char::from_u32(0x0080).unwrap(), + ).unwrap().sample_string(&mut rng, 100); + assert_eq!(string3.capacity(), 200); + } } #[test]