Skip to content

Commit

Permalink
Implement stimes more efficiently (#301)
Browse files Browse the repository at this point in the history
* Adds stimes implementation to strict ByteString

* Adds stimes to Lazy ByteString

* Actually make Lazy stimes work

* Strict stimes now calls memcpy log n instead of n times

* Correct fill_from capitalization

* Throw error on non negative n in strict

* expect positive multiplier in lazy stimes

* Make strict stimes work for n=0

* Add n=0 case to Lazy stimes again

* non-negative -> positive

Co-authored-by: Bodigrim <[email protected]>

* non-negative -> positive 2

Co-authored-by: Bodigrim <[email protected]>

* Add test cases for stimes

* positive n for stimes precondition in tests

* Use QuickCheck NonNegative

* Swap memcpy arguments

* Add semigroups to build-depends in tests

* Only use semigroups in tests when base > 490

* Restrict semigroup import to base>=490

* Use mempty and make more ledgible docs

* Optimize strict times

* Guard Lazy Empty case for greater than 0

* Handle n < 0 better

* Swap guard cases

* Added CPP to Pragmas

Co-authored-by: Bodigrim <[email protected]>
  • Loading branch information
elikoga and Bodigrim authored Oct 28, 2020
1 parent d7c9647 commit 1f97c4c
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 5 deletions.
30 changes: 28 additions & 2 deletions Data/ByteString/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,10 @@ import Foreign.C.Types (CInt, CSize)
import Foreign.C.String (CString)

#if MIN_VERSION_base(4,13,0)
import Data.Semigroup (Semigroup (sconcat))
import Data.Semigroup (Semigroup (sconcat, stimes))
import Data.List.NonEmpty (NonEmpty ((:|)))
#elif MIN_VERSION_base(4,9,0)
import Data.Semigroup (Semigroup ((<>), sconcat))
import Data.Semigroup (Semigroup ((<>), sconcat, stimes))
import Data.List.NonEmpty (NonEmpty ((:|)))
#endif

Expand Down Expand Up @@ -241,6 +241,7 @@ instance Ord ByteString where
instance Semigroup ByteString where
(<>) = append
sconcat (b:|bs) = concat (b:bs)
stimes = times
#endif

instance Monoid ByteString where
Expand Down Expand Up @@ -648,6 +649,31 @@ concat = \bss0 -> goLen0 bss0 bss0
concat [x] = x
#-}

-- | /O(log n)/ Repeats the given ByteString n times.
times :: Integral a => a -> ByteString -> ByteString
times n (BS fp len)
| n < 0 = error "stimes: non-negative multiplier expected"
| n == 0 = mempty
| n == 1 = BS fp len
| len == 0 = mempty
| len == 1 = unsafeCreate size $ \destptr ->
withForeignPtr fp $ \p -> do
byte <- peek p
memset destptr byte (fromIntegral size) >> return ()
| otherwise = unsafeCreate size $ \destptr ->
withForeignPtr fp $ \p -> do
memcpy destptr p len
fillFrom destptr len
where
size = len * (fromIntegral n)

fillFrom :: Ptr Word8 -> Int -> IO ()
fillFrom destptr copied
| 2 * copied < size = do
memcpy (destptr `plusPtr` copied) destptr copied
fillFrom destptr (copied * 2)
| otherwise = memcpy (destptr `plusPtr` copied) destptr (size - copied)

-- | Add two non-negative numbers. Errors out on overflow.
checkedAdd :: String -> Int -> Int -> Int
checkedAdd fun x y
Expand Down
17 changes: 15 additions & 2 deletions Data/ByteString/Lazy/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,10 @@ import Foreign.Ptr (plusPtr)
import Foreign.Storable (Storable(sizeOf))

#if MIN_VERSION_base(4,13,0)
import Data.Semigroup (Semigroup (sconcat))
import Data.Semigroup (Semigroup (sconcat, stimes))
import Data.List.NonEmpty (NonEmpty ((:|)))
#elif MIN_VERSION_base(4,9,0)
import Data.Semigroup (Semigroup ((<>), sconcat))
import Data.Semigroup (Semigroup ((<>), sconcat, stimes))
import Data.List.NonEmpty (NonEmpty ((:|)))
#endif
#if !(MIN_VERSION_base(4,8,0))
Expand Down Expand Up @@ -98,6 +98,7 @@ instance Ord ByteString where
instance Semigroup ByteString where
(<>) = append
sconcat (b:|bs) = concat (b:bs)
stimes = times
#endif

instance Monoid ByteString where
Expand Down Expand Up @@ -275,6 +276,18 @@ concat css0 = to css0
to [] = Empty
to (cs:css) = go cs css

-- | Repeats the given ByteString n times.
times :: Integral a => a -> ByteString -> ByteString
times 0 _ = Empty
times n lbs0
| n < 0 = error "stimes: non-negative multiplier expected"
| otherwise = case lbs0 of
Empty -> Empty
Chunk bs lbs -> Chunk bs (go lbs)
where
go Empty = times (n-1) lbs0
go (Chunk c cs) = Chunk c (go cs)

------------------------------------------------------------------------
-- Conversions

Expand Down
16 changes: 15 additions & 1 deletion tests/Properties.hs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP, BangPatterns #-}
--
-- Must have rules off, otherwise the rewrite rules will replace the rhs
-- with the lhs, and we only end up testing lhs == lhs
Expand Down Expand Up @@ -26,6 +26,9 @@ import Data.Word
import Data.Maybe
import Data.Int (Int64)
import Data.Monoid
#if MIN_VERSION_base(4,9,0)
import Data.Semigroup
#endif

import Text.Printf
import Data.String
Expand Down Expand Up @@ -1376,6 +1379,13 @@ prop_packZipWithLC f xs ys = LC.pack (LC.zipWith f xs ys) == LC.packZipWith f xs

prop_unzipBB x = let (xs,ys) = unzip x in (P.pack xs, P.pack ys) == P.unzip x

#if MIN_VERSION_base(4,9,0)
prop_stimesBB :: NonNegative Int -> P.ByteString -> Bool
prop_stimesBB (NonNegative i) bs = stimes i bs == mtimesDefault i bs

prop_stimesLL :: NonNegative Int -> L.ByteString -> Bool
prop_stimesLL (NonNegative i) bs = stimes i bs == mtimesDefault i bs
#endif

-- prop_zipwith_spec f p q =
-- P.pack (P.zipWith f p q) == P.zipWith' f p q
Expand Down Expand Up @@ -2368,6 +2378,10 @@ bb_tests =
, testProperty "unzip" prop_unzipBB
, testProperty "concatMap" prop_concatMapBB
-- , testProperty "join/joinByte" prop_join_spec
#if MIN_VERSION_base(4,9,0)
, testProperty "stimes strict" prop_stimesBB
, testProperty "stimes lazy" prop_stimesLL
#endif
]


Expand Down

0 comments on commit 1f97c4c

Please sign in to comment.