Implement stimes more efficiently #301

Merged
merged 24 commits into from
Oct 28, 2020
30 changes: 28 additions & 2 deletions Data/ByteString/Internal.hs
@@ -108,10 +108,10 @@ import Foreign.C.Types (CInt, CSize)
import Foreign.C.String (CString)

#if MIN_VERSION_base(4,13,0)
-import Data.Semigroup (Semigroup (sconcat))
+import Data.Semigroup (Semigroup (sconcat, stimes))
import Data.List.NonEmpty (NonEmpty ((:|)))
#elif MIN_VERSION_base(4,9,0)
-import Data.Semigroup (Semigroup ((<>), sconcat))
+import Data.Semigroup (Semigroup ((<>), sconcat, stimes))
import Data.List.NonEmpty (NonEmpty ((:|)))
#endif

@@ -241,6 +241,7 @@ instance Ord ByteString where
instance Semigroup ByteString where
    (<>) = append
    sconcat (b:|bs) = concat (b:bs)
    stimes = times
#endif

instance Monoid ByteString where
@@ -648,6 +649,31 @@ concat = \bss0 -> goLen0 bss0 bss0
concat [x] = x
#-}

-- | /O(log n)/ Repeats the given ByteString n times.
times :: Integral a => a -> ByteString -> ByteString
times n (BS fp len)
  | n < 0 = error "stimes: non-negative multiplier expected"
  | n == 0 = mempty
  | n == 1 = BS fp len
Member
Does returning `BS fp len` instead of a variable bound to the whole `BS` pattern result in additional allocations?

Contributor Author
Just checked with `-ddump-simpl` and couldn't see any differences apart from different identifiers.
`-ddump-core-stats` also reports the same number of Terms, Types and Coercions, so I think it's optimized away. I wouldn't say no to assigning it to a name, though, if you believe it would be more readable.

Contributor Author
The assembly generated with `-ddump-asm` also looks essentially the same, up to renaming.

Member
Thanks for investigating! 👍
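
For readers following the thread, here is a minimal sketch of the two forms being compared. `Buf`, `rebuild` and `reuse` are hypothetical stand-ins, not the library's `BS` constructor or `times`. Compiling both with something like `ghc -O2 -ddump-simpl` is one way to repeat the comparison described above; the sketch itself makes no claim about what GHC emits.

```haskell
-- Hypothetical stand-in for a strict, two-field buffer type.
data Buf = Buf !Int !Int
  deriving (Eq, Show)

-- Rebuilds the constructor on the right-hand side, as the diff does
-- with @BS fp len@.
rebuild :: Int -> Buf -> Buf
rebuild n (Buf off len)
  | n == 1    = Buf off len
  | otherwise = Buf off (len * n)

-- Binds the whole argument with an as-pattern and returns it unchanged.
reuse :: Int -> Buf -> Buf
reuse n b@(Buf off len)
  | n == 1    = b
  | otherwise = Buf off (len * n)
```

Both functions are observably equal; any difference would only show up in allocation behaviour, which is what the dumps above were used to check.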

  | len == 0 = mempty
  | len == 1 = unsafeCreate size $ \destptr ->
    withForeignPtr fp $ \p -> do
      byte <- peek p
      memset destptr byte (fromIntegral size) >> return ()
  | otherwise = unsafeCreate size $ \destptr ->
    withForeignPtr fp $ \p -> do
      memcpy destptr p len
      fillFrom destptr len
sjakobi marked this conversation as resolved.
  where
    size = len * (fromIntegral n)

    fillFrom :: Ptr Word8 -> Int -> IO ()
    fillFrom destptr copied
      | 2 * copied < size = do
        memcpy (destptr `plusPtr` copied) destptr copied
        fillFrom destptr (copied * 2)
      | otherwise = memcpy (destptr `plusPtr` copied) destptr (size - copied)

-- | Add two non-negative numbers. Errors out on overflow.
checkedAdd :: String -> Int -> Int -> Int
checkedAdd fun x y
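
The doubling strategy used by `fillFrom` above can be sketched with the public `Data.ByteString` API. This is an illustration only: `repeatByDoubling` and `example` are hypothetical names, and unlike the diff (which allocates the destination once and fills it in place with `memcpy`) this version allocates a new buffer on every append. It also assumes `n * length` does not overflow `Int`. What it does show is that the number of copy steps grows logarithmically in `n`, while the total number of bytes copied still grows with the size of the result.

```haskell
import qualified Data.ByteString as BS
import qualified Data.ByteString.Char8 as BC

-- | Hypothetical helper: repeat a strict ByteString @n@ times by doubling.
-- Returns the result together with the number of append steps performed.
repeatByDoubling :: Int -> BS.ByteString -> (BS.ByteString, Int)
repeatByDoubling n bs
  | n <= 0 || BS.null bs = (BS.empty, 0)
  | otherwise            = go bs 0
  where
    -- Target length; assumes the product does not overflow Int.
    size = n * BS.length bs

    go acc steps
      -- Already at the target length: nothing left to copy.
      | BS.length acc >= size     = (acc, steps)
      -- Doubling still fits: append the accumulator to itself.
      | 2 * BS.length acc <= size = go (BS.append acc acc) (steps + 1)
      -- Final partial step: append only the missing prefix.
      | otherwise                 =
          (BS.append acc (BS.take (size - BS.length acc) acc), steps + 1)

-- | Five repetitions of "ab" take three append steps (2 -> 4 -> 8 -> 10 bytes).
example :: (BS.ByteString, Int)
example = repeatByDoubling 5 (BC.pack "ab")
```

The final partial step is valid because the accumulator's length is always a multiple of the input length, so any prefix of it continues the repetition correctly.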
17 changes: 15 additions & 2 deletions Data/ByteString/Lazy/Internal.hs
@@ -57,10 +57,10 @@ import Foreign.Ptr (plusPtr)
import Foreign.Storable (Storable(sizeOf))

#if MIN_VERSION_base(4,13,0)
-import Data.Semigroup (Semigroup (sconcat))
+import Data.Semigroup (Semigroup (sconcat, stimes))
import Data.List.NonEmpty (NonEmpty ((:|)))
#elif MIN_VERSION_base(4,9,0)
-import Data.Semigroup (Semigroup ((<>), sconcat))
+import Data.Semigroup (Semigroup ((<>), sconcat, stimes))
import Data.List.NonEmpty (NonEmpty ((:|)))
#endif
#if !(MIN_VERSION_base(4,8,0))
@@ -98,6 +98,7 @@ instance Ord ByteString where
instance Semigroup ByteString where
    (<>) = append
    sconcat (b:|bs) = concat (b:bs)
    stimes = times
#endif

instance Monoid ByteString where
@@ -275,6 +276,18 @@ concat css0 = to css0
    to [] = Empty
    to (cs:css) = go cs css

-- | Repeats the given ByteString n times.
times :: Integral a => a -> ByteString -> ByteString
times 0 _ = Empty
times n lbs0
  | n < 0 = error "stimes: non-negative multiplier expected"
  | otherwise = case lbs0 of
    Empty -> Empty
    Chunk bs lbs -> Chunk bs (go lbs)
  where
    go Empty = times (n-1) lbs0
    go (Chunk c cs) = Chunk c (go cs)

------------------------------------------------------------------------
-- Conversions
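
The lazy `times` above rebuilds only the chunk spine and produces it on demand. A small sketch of how that looks from the public API, using `stimes` with positive multipliers only so that it does not depend on which `bytestring` release is installed; `repeatedChunks` and `prefixOfLarge` are hypothetical names introduced for illustration.

```haskell
import qualified Data.ByteString.Char8 as BC
import qualified Data.ByteString.Lazy as L
import Data.Semigroup (stimes)

-- | Chunks of a two-chunk lazy ByteString repeated three times.
-- The original chunk boundaries are repeated as they are.
repeatedChunks :: [BC.ByteString]
repeatedChunks = L.toChunks (stimes (3 :: Int) lbs)
  where
    lbs = L.fromChunks [BC.pack "ab", BC.pack "c"]

-- | Only the first five bytes of a million-fold repetition are demanded,
-- so the rest of the chunk list is never built.
prefixOfLarge :: L.ByteString
prefixOfLarge =
  L.take 5 (stimes (1000000 :: Int) (L.fromChunks [BC.pack "abc"]))
```

Here `repeatedChunks` is the chunk list `"ab", "c"` three times over, and `prefixOfLarge` is the five bytes `"abcab"`.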

16 changes: 15 additions & 1 deletion tests/Properties.hs
@@ -1,4 +1,4 @@
-{-# LANGUAGE BangPatterns #-}
+{-# LANGUAGE CPP, BangPatterns #-}
--
-- Must have rules off, otherwise the rewrite rules will replace the rhs
-- with the lhs, and we only end up testing lhs == lhs
@@ -26,6 +26,9 @@ import Data.Word
import Data.Maybe
import Data.Int (Int64)
import Data.Monoid
#if MIN_VERSION_base(4,9,0)
import Data.Semigroup
#endif

import Text.Printf
import Data.String
@@ -1376,6 +1379,13 @@ prop_packZipWithLC f xs ys = LC.pack (LC.zipWith f xs ys) == LC.packZipWith f xs

prop_unzipBB x = let (xs,ys) = unzip x in (P.pack xs, P.pack ys) == P.unzip x

#if MIN_VERSION_base(4,9,0)
prop_stimesBB :: NonNegative Int -> P.ByteString -> Bool
prop_stimesBB (NonNegative i) bs = stimes i bs == mtimesDefault i bs

prop_stimesLL :: NonNegative Int -> L.ByteString -> Bool
prop_stimesLL (NonNegative i) bs = stimes i bs == mtimesDefault i bs
#endif

-- prop_zipwith_spec f p q =
-- P.pack (P.zipWith f p q) == P.zipWith' f p q
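
The idea behind `prop_stimesBB` and `prop_stimesLL` can be tried without QuickCheck by enumerating a few small cases. This is a sketch rather than a replacement for the properties: `agreesStrict`, `agreesLazy` and `allAgree` are hypothetical names, and the multipliers start at 1 because the default `stimes` from `base` rejects a zero multiplier, so a `bytestring` build without this change would fail on 0 (which the properties above do cover via `NonNegative`).

```haskell
import qualified Data.ByteString.Char8 as BC
import qualified Data.ByteString.Lazy.Char8 as LC
import Data.Semigroup (mtimesDefault, stimes)

-- | Does the instance's 'stimes' agree with the generic 'mtimesDefault'?
agreesStrict :: Int -> BC.ByteString -> Bool
agreesStrict n bs = stimes n bs == mtimesDefault n bs

-- | The same comparison for lazy ByteStrings.
agreesLazy :: Int -> LC.ByteString -> Bool
agreesLazy n bs = stimes n bs == mtimesDefault n bs

-- | Exhaustive check over a handful of small multipliers and inputs,
-- including the empty and single-byte cases the strict code special-cases.
allAgree :: Bool
allAgree = and
  [ agreesStrict n (BC.pack s) && agreesLazy n (LC.pack s)
  | n <- [1 .. 8]
  , s <- ["", "a", "ab", "hello"]
  ]
```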
@@ -2368,6 +2378,10 @@ bb_tests =
, testProperty "unzip" prop_unzipBB
, testProperty "concatMap" prop_concatMapBB
-- , testProperty "join/joinByte" prop_join_spec
#if MIN_VERSION_base(4,9,0)
, testProperty "stimes strict" prop_stimesBB
, testProperty "stimes lazy" prop_stimesLL
#endif
]

