diff --git a/Data/ByteString/Internal.hs b/Data/ByteString/Internal.hs index b70829078..ab0bb31f6 100644 --- a/Data/ByteString/Internal.hs +++ b/Data/ByteString/Internal.hs @@ -108,10 +108,10 @@ import Foreign.C.Types (CInt, CSize) import Foreign.C.String (CString) #if MIN_VERSION_base(4,13,0) -import Data.Semigroup (Semigroup (sconcat)) +import Data.Semigroup (Semigroup (sconcat, stimes)) import Data.List.NonEmpty (NonEmpty ((:|))) #elif MIN_VERSION_base(4,9,0) -import Data.Semigroup (Semigroup ((<>), sconcat)) +import Data.Semigroup (Semigroup ((<>), sconcat, stimes)) import Data.List.NonEmpty (NonEmpty ((:|))) #endif @@ -241,6 +241,7 @@ instance Ord ByteString where instance Semigroup ByteString where (<>) = append sconcat (b:|bs) = concat (b:bs) + stimes = times #endif instance Monoid ByteString where @@ -648,6 +649,31 @@ concat = \bss0 -> goLen0 bss0 bss0 concat [x] = x #-} +-- | /O(log n)/ Repeats the given ByteString n times. +times :: Integral a => a -> ByteString -> ByteString +times n (BS fp len) + | n < 0 = error "stimes: non-negative multiplier expected" + | n == 0 = mempty + | n == 1 = BS fp len + | len == 0 = mempty + | len == 1 = unsafeCreate size $ \destptr -> + withForeignPtr fp $ \p -> do + byte <- peek p + memset destptr byte (fromIntegral size) >> return () + | otherwise = unsafeCreate size $ \destptr -> + withForeignPtr fp $ \p -> do + memcpy destptr p len + fillFrom destptr len + where + size = len * (fromIntegral n) + + fillFrom :: Ptr Word8 -> Int -> IO () + fillFrom destptr copied + | 2 * copied < size = do + memcpy (destptr `plusPtr` copied) destptr copied + fillFrom destptr (copied * 2) + | otherwise = memcpy (destptr `plusPtr` copied) destptr (size - copied) + -- | Add two non-negative numbers. Errors out on overflow. checkedAdd :: String -> Int -> Int -> Int checkedAdd fun x y diff --git a/Data/ByteString/Lazy/Internal.hs b/Data/ByteString/Lazy/Internal.hs index 45e5c9c4f..e1b13649c 100644 --- a/Data/ByteString/Lazy/Internal.hs +++ b/Data/ByteString/Lazy/Internal.hs @@ -57,10 +57,10 @@ import Foreign.Ptr (plusPtr) import Foreign.Storable (Storable(sizeOf)) #if MIN_VERSION_base(4,13,0) -import Data.Semigroup (Semigroup (sconcat)) +import Data.Semigroup (Semigroup (sconcat, stimes)) import Data.List.NonEmpty (NonEmpty ((:|))) #elif MIN_VERSION_base(4,9,0) -import Data.Semigroup (Semigroup ((<>), sconcat)) +import Data.Semigroup (Semigroup ((<>), sconcat, stimes)) import Data.List.NonEmpty (NonEmpty ((:|))) #endif #if !(MIN_VERSION_base(4,8,0)) @@ -98,6 +98,7 @@ instance Ord ByteString where instance Semigroup ByteString where (<>) = append sconcat (b:|bs) = concat (b:bs) + stimes = times #endif instance Monoid ByteString where @@ -275,6 +276,18 @@ concat css0 = to css0 to [] = Empty to (cs:css) = go cs css +-- | Repeats the given ByteString n times. +times :: Integral a => a -> ByteString -> ByteString +times 0 _ = Empty +times n lbs0 + | n < 0 = error "stimes: non-negative multiplier expected" + | otherwise = case lbs0 of + Empty -> Empty + Chunk bs lbs -> Chunk bs (go lbs) + where + go Empty = times (n-1) lbs0 + go (Chunk c cs) = Chunk c (go cs) + ------------------------------------------------------------------------ -- Conversions diff --git a/tests/Properties.hs b/tests/Properties.hs index 751bfb7b6..29993a917 100644 --- a/tests/Properties.hs +++ b/tests/Properties.hs @@ -1,4 +1,4 @@ -{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE CPP, BangPatterns #-} -- -- Must have rules off, otherwise the rewrite rules will replace the rhs -- with the lhs, and we only end up testing lhs == lhs @@ -26,6 +26,9 @@ import Data.Word import Data.Maybe import Data.Int (Int64) import Data.Monoid +#if MIN_VERSION_base(4,9,0) +import Data.Semigroup +#endif import Text.Printf import Data.String @@ -1376,6 +1379,13 @@ prop_packZipWithLC f xs ys = LC.pack (LC.zipWith f xs ys) == LC.packZipWith f xs prop_unzipBB x = let (xs,ys) = unzip x in (P.pack xs, P.pack ys) == P.unzip x +#if MIN_VERSION_base(4,9,0) +prop_stimesBB :: NonNegative Int -> P.ByteString -> Bool +prop_stimesBB (NonNegative i) bs = stimes i bs == mtimesDefault i bs + +prop_stimesLL :: NonNegative Int -> L.ByteString -> Bool +prop_stimesLL (NonNegative i) bs = stimes i bs == mtimesDefault i bs +#endif -- prop_zipwith_spec f p q = -- P.pack (P.zipWith f p q) == P.zipWith' f p q @@ -2368,6 +2378,10 @@ bb_tests = , testProperty "unzip" prop_unzipBB , testProperty "concatMap" prop_concatMapBB -- , testProperty "join/joinByte" prop_join_spec +#if MIN_VERSION_base(4,9,0) + , testProperty "stimes strict" prop_stimesBB + , testProperty "stimes lazy" prop_stimesLL +#endif ]