Skip to content

Commit

Permalink
Use one less argument in folds, mapAccum, scan and filter to allow more inlining (#345)
Browse files Browse the repository at this point in the history
  • Loading branch information
Boarders authored May 19, 2021
1 parent 062b5b1 commit cdf6ebb
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 66 deletions.
5 changes: 5 additions & 0 deletions .hlint.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@
within:
- Data.ByteString.Builder.Internal
- Data.ByteString.Builder.Prim
- ignore:
name: Reduce duplication
within:
- Data.ByteString
- ignore:
name: Redundant lambda
within:
- Data.ByteString.Builder.Internal
- Data.ByteString
159 changes: 93 additions & 66 deletions Data/ByteString.hs
Original file line number Diff line number Diff line change
Expand Up @@ -462,8 +462,8 @@ transpose = P.map pack . List.transpose . P.map unpack
-- ByteString using the binary operator, from left to right.
--
foldl :: (a -> Word8 -> a) -> a -> ByteString -> a
foldl f z (BS fp len) = go (end `plusPtr` len)
where
foldl f z = \(BS fp len) ->
let
end = unsafeForeignPtrToPtr fp `plusPtr` (-1)
-- not tail recursive; traverses array right to left
go !p | p == end = z
Expand All @@ -472,29 +472,44 @@ foldl f z (BS fp len) = go (end `plusPtr` len)
touchForeignPtr fp
return x'
in f (go (p `plusPtr` (-1))) x

in
go (end `plusPtr` len)
{-# INLINE foldl #-}

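{-
Note [fold inlining]:
GHC will only inline a function marked INLINE
if it is fully saturated (meaning the number of
arguments provided at the call site is at least
equal to the number of lhs arguments).

For that reason the folds, mapAccums, scans and filter in this
module bind one argument fewer on the left-hand side and receive
the ByteString through a lambda instead. A call that supplies only
the leading arguments (for example @foldl' f z@) then already
counts as saturated and remains a candidate for inlining.
-}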
-- | 'foldl'' is like 'foldl', but strict in the accumulator.
--
foldl' :: (a -> Word8 -> a) -> a -> ByteString -> a
-- Diff, removed side: the earlier equation bound all three arguments on
-- the left-hand side and kept its helper in a trailing @where@ clause.
foldl' f v (BS fp len) =
accursedUnutterablePerformIO $ unsafeWithForeignPtr fp g
where
-- Diff, added side: only @f@ and @v@ are bound on the left-hand side; the
-- ByteString is taken by the lambda and the helper moves into a @let@.
foldl' f v = \(BS fp len) ->
-- see Note [fold inlining]
let
-- @g@ receives the raw pointer and starts the loop there with the
-- initial accumulator @v@.
g ptr = go v ptr
where
-- Address @len@ bytes past @ptr@; reaching it ends the loop.
end = ptr `plusPtr` len
-- tail recursive; traverses array left to right
-- The bang on @z@ forces the accumulator on every step, which is what
-- makes this fold strict.
go !z !p | p == end = return z
| otherwise = do x <- peek p
go (f z x) (p `plusPtr` 1)
in
accursedUnutterablePerformIO $ unsafeWithForeignPtr fp g
{-# INLINE foldl' #-}

-- | 'foldr', applied to a binary operator, a starting value
-- (typically the right-identity of the operator), and a ByteString,
-- reduces the ByteString using the binary operator, from right to left.
foldr :: (Word8 -> a -> a) -> a -> ByteString -> a
foldr k z (BS fp len) = go ptr
where
foldr k z = \(BS fp len) ->
-- see fold inlining
let
ptr = unsafeForeignPtrToPtr fp
end = ptr `plusPtr` len
-- not tail recursive; traverses array left to right
Expand All @@ -504,20 +519,25 @@ foldr k z (BS fp len) = go ptr
touchForeignPtr fp
return x'
in k x (go (p `plusPtr` 1))
in
go ptr
{-# INLINE foldr #-}

-- | 'foldr'' is like 'foldr', but strict in the accumulator.
foldr' :: (Word8 -> a -> a) -> a -> ByteString -> a
-- Diff, removed side: the earlier equation bound all three arguments on
-- the left-hand side and kept its helper in a trailing @where@ clause.
foldr' k v (BS fp len) =
accursedUnutterablePerformIO $ unsafeWithForeignPtr fp g
where
-- Diff, added side: only @k@ and @v@ are bound on the left-hand side; the
-- ByteString is taken by the lambda and the helper moves into a @let@.
foldr' k v = \(BS fp len) ->
-- see Note [fold inlining]
let
-- The loop starts at @end `plusPtr` len@, which is @ptr@ advanced by
-- @len - 1@ bytes, and walks backwards from there.
g ptr = go v (end `plusPtr` len)
where
-- One byte before @ptr@. It only serves as the stop address: the loop
-- compares against it before any 'peek', so it is never read.
end = ptr `plusPtr` (-1)
-- tail recursive; traverses array right to left
-- The bang on @z@ forces the accumulator on every step.
go !z !p | p == end = return z
| otherwise = do x <- peek p
go (k x z) (p `plusPtr` (-1))
in
accursedUnutterablePerformIO $ unsafeWithForeignPtr fp g

{-# INLINE foldr' #-}

-- | 'foldl1' is a variant of 'foldl' that has no starting value
Expand Down Expand Up @@ -650,40 +670,42 @@ minimum xs@(BS x l)
-- passing an accumulating parameter from left to right, and returning a
-- final value of this accumulator together with the new list.
mapAccumL :: (acc -> Word8 -> (acc, Word8)) -> acc -> ByteString -> (acc, ByteString)
-- Diff, removed side: header that bound all three arguments on the left.
mapAccumL f acc (BS fp len) = unsafeDupablePerformIO $ unsafeWithForeignPtr fp $ \a -> do
-- Diff, added side: only @f@ and @acc@ are bound on the left; the
-- ByteString is taken by the lambda.
mapAccumL f acc = \(BS fp len) -> unsafeDupablePerformIO $ unsafeWithForeignPtr fp $ \a -> do
-- see Note [fold inlining]
-- @gp@ receives the mapped bytes and backs the returned ByteString.
gp <- mallocByteString len
-- Diff, added side: the loop is now a local @let@ binding.
let
go src dst = mapAccumL_ acc 0
where
-- @s@ is the running accumulator and @n@ the byte offset, both strict.
-- Offsets run upwards from 0 and the loop stops once @n >= len@.
mapAccumL_ !s !n
| n >= len = return s
| otherwise = do
x <- peekByteOff src n
let (s', y) = f s x
pokeByteOff dst n y
mapAccumL_ s' (n+1)
-- Run the loop with the input pointer @a@ as source and @gp@ as destination.
acc' <- unsafeWithForeignPtr gp (go a)
return (acc', BS gp len)
-- Diff, removed side: the same loop as a trailing @where@ clause.
where
go src dst = mapAccumL_ acc 0
where
mapAccumL_ !s !n
| n >= len = return s
| otherwise = do
x <- peekByteOff src n
let (s', y) = f s x
pokeByteOff dst n y
mapAccumL_ s' (n+1)
{-# INLINE mapAccumL #-}

-- | The 'mapAccumR' function behaves like a combination of 'map' and
-- 'foldr'; it applies a function to each element of a ByteString,
-- passing an accumulating parameter from right to left, and returning a
-- final value of this accumulator together with the new ByteString.
mapAccumR :: (acc -> Word8 -> (acc, Word8)) -> acc -> ByteString -> (acc, ByteString)
-- Diff, removed side: header that bound all three arguments on the left.
mapAccumR f acc (BS fp len) = unsafeDupablePerformIO $ unsafeWithForeignPtr fp $ \a -> do
-- Diff, added side: only @f@ and @acc@ are bound on the left; the
-- ByteString is taken by the lambda.
mapAccumR f acc = \(BS fp len) -> unsafeDupablePerformIO $ unsafeWithForeignPtr fp $ \a -> do
-- see Note [fold inlining]
-- @gp@ receives the mapped bytes and backs the returned ByteString.
gp <- mallocByteString len
-- Diff, added side: the loop is now a local @let@ binding.
let
-- Offsets run downwards, starting at @len-1@.
go src dst = mapAccumR_ acc (len-1)
where
-- Offset -1 means every byte has been processed; return the accumulator.
mapAccumR_ !s (-1) = return s
mapAccumR_ !s !n = do
x <- peekByteOff src n
let (s', y) = f s x
pokeByteOff dst n y
mapAccumR_ s' (n-1)
-- Run the loop with the input pointer @a@ as source and @gp@ as destination.
acc' <- unsafeWithForeignPtr gp (go a)
return (acc', BS gp len)
-- Diff, removed side: the same loop as a trailing @where@ clause.
where
go src dst = mapAccumR_ acc (len-1)
where
mapAccumR_ !s (-1) = return s
mapAccumR_ !s !n = do
x <- peekByteOff src n
let (s', y) = f s x
pokeByteOff dst n y
mapAccumR_ s' (n-1)
{-# INLINE mapAccumR #-}

-- ---------------------------------------------------------------------
Expand All @@ -708,20 +730,21 @@ scanl
-- ^ input of length n
-> ByteString
-- ^ output of length n+1
-- Diff, removed side: header that bound all three arguments on the left.
scanl f v (BS fp len) = unsafeDupablePerformIO $ unsafeWithForeignPtr fp $ \a ->
-- Diff, added side: only @f@ and @v@ are bound on the left; the
-- ByteString is taken by the lambda.
scanl f v = \(BS fp len) -> unsafeDupablePerformIO $ unsafeWithForeignPtr fp $ \a ->
-- see Note [fold inlining]
create (len+1) $ \q -> do
-- The initial value goes to offset 0 of the output.
poke q v
-- Diff, added side: the loop is now a local @let@ binding.
let
go src dst = scanl_ v 0
where
-- @z@ is the running value and @n@ the input offset, both strict.
scanl_ !z !n
| n >= len = return ()
| otherwise = do
x <- peekByteOff src n
let z' = f z x
pokeByteOff dst n z'
scanl_ z' (n+1)
-- @dst@ is @q@ advanced by one byte, so the value computed from input
-- offset n is stored at output offset n+1.
go a (q `plusPtr` 1)
-- Diff, removed side: the same loop as a trailing @where@ clause.
where
go src dst = scanl_ v 0
where
scanl_ !z !n
| n >= len = return ()
| otherwise = do
x <- peekByteOff src n
let z' = f z x
pokeByteOff dst n z'
scanl_ z' (n+1)
{-# INLINE scanl #-}

-- n.b. haskell's List scan returns a list one bigger than the
Expand Down Expand Up @@ -757,20 +780,21 @@ scanr
-- ^ input of length n
-> ByteString
-- ^ output of length n+1
-- Diff, removed side: header that bound all three arguments on the left.
scanr f v (BS fp len) = unsafeDupablePerformIO $ unsafeWithForeignPtr fp $ \a ->
-- Diff, added side: only @f@ and @v@ are bound on the left; the
-- ByteString is taken by the lambda.
scanr f v = \(BS fp len) -> unsafeDupablePerformIO $ unsafeWithForeignPtr fp $ \a ->
-- see Note [fold inlining]
create (len+1) $ \q -> do
-- The initial value goes to the last output offset, @len@.
poke (q `plusPtr` len) v
-- Diff, added side: the loop is now a local @let@ binding.
let
-- Offsets run downwards, starting at @len-1@.
go p q = scanr_ v (len-1)
where
-- @z@ is the running value and @n@ the offset, both strict.
-- The result for input offset n is stored at the same output offset n.
scanr_ !z !n
| n < 0 = return ()
| otherwise = do
x <- peekByteOff p n
let z' = f x z
pokeByteOff q n z'
scanr_ z' (n-1)
go a q
-- Diff, removed side: the same loop as a trailing @where@ clause.
where
go p q = scanr_ v (len-1)
where
scanr_ !z !n
| n < 0 = return ()
| otherwise = do
x <- peekByteOff p n
let z' = f x z
pokeByteOff q n z'
scanr_ z' (n-1)
{-# INLINE scanr #-}

-- | 'scanr1' is a variant of 'scanr' that has no starting value argument.
Expand Down Expand Up @@ -1327,21 +1351,24 @@ notElem c ps = not (c `elem` ps)
-- returns a ByteString containing those characters that satisfy the
-- predicate.
filter :: (Word8 -> Bool) -> ByteString -> ByteString
-- Diff, removed side: equation that bound both arguments on the left and
-- chose between the empty and non-empty case with guards.
filter k ps@(BS x l)
| null ps = ps
| otherwise = unsafePerformIO $ createAndTrim l $ \p -> withForeignPtr x $ \f -> do
-- Diff, added side: only @k@ is bound on the left; the ByteString is taken
-- by the lambda, so the guards become an if/then/else.
filter k = \ps@(BS x l) ->
-- see Note [fold inlining]
if null ps
then ps
else
unsafePerformIO $ createAndTrim l $ \p -> withForeignPtr x $ \f -> do
-- Diff, added side: the loop is now a local @let@ binding.
let
go' pf pt = go pf pt
where
-- Address @l@ bytes past the source start; reaching it ends the loop.
end = pf `plusPtr` l
-- @f@ is the read pointer and @t@ the write pointer. @f@ advances on
-- every step; @t@ advances only when a byte satisfies @k@ and is copied.
go !f !t | f == end = return t
| otherwise = do
w <- peek f
if k w
then poke t w >> go (f `plusPtr` 1) (t `plusPtr` 1)
else go (f `plusPtr` 1) t
-- @t@ is the final write pointer; its distance from @p@ is the number of
-- bytes kept.
t <- go' f p
return $! t `minusPtr` p -- actual length
-- Diff, removed side: the same loop as a trailing @where@ clause.
where
go' pf pt = go pf pt
where
end = pf `plusPtr` l
go !f !t | f == end = return t
| otherwise = do
w <- peek f
if k w
then poke t w >> go (f `plusPtr` 1) (t `plusPtr` 1)
else go (f `plusPtr` 1) t
{-# INLINE filter #-}

{-
Expand Down

0 comments on commit cdf6ebb

Please sign in to comment.