From cdf6ebb3766ce668ab6c7bab436ddf1e7e032321 Mon Sep 17 00:00:00 2001 From: Callan McGill Date: Wed, 19 May 2021 17:50:39 -0400 Subject: [PATCH] Use one less argument in folds, mapAccum, scan and filter to allow more inlining (#345) --- .hlint.yaml | 5 ++ Data/ByteString.hs | 159 ++++++++++++++++++++++++++------------------- 2 files changed, 98 insertions(+), 66 deletions(-) diff --git a/.hlint.yaml b/.hlint.yaml index d9f6bad97..6b9a92306 100644 --- a/.hlint.yaml +++ b/.hlint.yaml @@ -14,7 +14,12 @@ within: - Data.ByteString.Builder.Internal - Data.ByteString.Builder.Prim +- ignore: + name: Reduce duplication + within: + - Data.ByteString - ignore: name: Redundant lambda within: - Data.ByteString.Builder.Internal + - Data.ByteString diff --git a/Data/ByteString.hs b/Data/ByteString.hs index 0457eed28..099922174 100644 --- a/Data/ByteString.hs +++ b/Data/ByteString.hs @@ -462,8 +462,8 @@ transpose = P.map pack . List.transpose . P.map unpack -- ByteString using the binary operator, from left to right. -- foldl :: (a -> Word8 -> a) -> a -> ByteString -> a -foldl f z (BS fp len) = go (end `plusPtr` len) - where +foldl f z = \(BS fp len) -> + let end = unsafeForeignPtrToPtr fp `plusPtr` (-1) -- not tail recursive; traverses array right to left go !p | p == end = z @@ -472,14 +472,26 @@ foldl f z (BS fp len) = go (end `plusPtr` len) touchForeignPtr fp return x' in f (go (p `plusPtr` (-1))) x + + in + go (end `plusPtr` len) {-# INLINE foldl #-} +{- +Note [fold inlining]: + +GHC will only inline a function marked INLINE +if it is fully saturated (meaning the number of +arguments provided at the call site is at least +equal to the number of lhs arguments). + +-} -- | 'foldl'' is like 'foldl', but strict in the accumulator. -- foldl' :: (a -> Word8 -> a) -> a -> ByteString -> a -foldl' f v (BS fp len) = - accursedUnutterablePerformIO $ unsafeWithForeignPtr fp g - where +foldl' f v = \(BS fp len) -> + -- see fold inlining + let g ptr = go v ptr where end = ptr `plusPtr` len @@ -487,14 +499,17 @@ foldl' f v (BS fp len) = go !z !p | p == end = return z | otherwise = do x <- peek p go (f z x) (p `plusPtr` 1) + in + accursedUnutterablePerformIO $ unsafeWithForeignPtr fp g {-# INLINE foldl' #-} -- | 'foldr', applied to a binary operator, a starting value -- (typically the right-identity of the operator), and a ByteString, -- reduces the ByteString using the binary operator, from right to left. foldr :: (Word8 -> a -> a) -> a -> ByteString -> a -foldr k z (BS fp len) = go ptr - where +foldr k z = \(BS fp len) -> + -- see fold inlining + let ptr = unsafeForeignPtrToPtr fp end = ptr `plusPtr` len -- not tail recursive; traverses array left to right @@ -504,13 +519,15 @@ foldr k z (BS fp len) = go ptr touchForeignPtr fp return x' in k x (go (p `plusPtr` 1)) + in + go ptr {-# INLINE foldr #-} -- | 'foldr'' is like 'foldr', but strict in the accumulator. foldr' :: (Word8 -> a -> a) -> a -> ByteString -> a -foldr' k v (BS fp len) = - accursedUnutterablePerformIO $ unsafeWithForeignPtr fp g - where +foldr' k v = \(BS fp len) -> + -- see fold inlining + let g ptr = go v (end `plusPtr` len) where end = ptr `plusPtr` (-1) @@ -518,6 +535,9 @@ foldr' k v (BS fp len) = go !z !p | p == end = return z | otherwise = do x <- peek p go (k x z) (p `plusPtr` (-1)) + in + accursedUnutterablePerformIO $ unsafeWithForeignPtr fp g + {-# INLINE foldr' #-} -- | 'foldl1' is a variant of 'foldl' that has no starting value @@ -650,20 +670,21 @@ minimum xs@(BS x l) -- passing an accumulating parameter from left to right, and returning a -- final value of this accumulator together with the new list. mapAccumL :: (acc -> Word8 -> (acc, Word8)) -> acc -> ByteString -> (acc, ByteString) -mapAccumL f acc (BS fp len) = unsafeDupablePerformIO $ unsafeWithForeignPtr fp $ \a -> do +mapAccumL f acc = \(BS fp len) -> unsafeDupablePerformIO $ unsafeWithForeignPtr fp $ \a -> do + -- see fold inlining gp <- mallocByteString len + let + go src dst = mapAccumL_ acc 0 + where + mapAccumL_ !s !n + | n >= len = return s + | otherwise = do + x <- peekByteOff src n + let (s', y) = f s x + pokeByteOff dst n y + mapAccumL_ s' (n+1) acc' <- unsafeWithForeignPtr gp (go a) return (acc', BS gp len) - where - go src dst = mapAccumL_ acc 0 - where - mapAccumL_ !s !n - | n >= len = return s - | otherwise = do - x <- peekByteOff src n - let (s', y) = f s x - pokeByteOff dst n y - mapAccumL_ s' (n+1) {-# INLINE mapAccumL #-} -- | The 'mapAccumR' function behaves like a combination of 'map' and @@ -671,19 +692,20 @@ mapAccumL f acc (BS fp len) = unsafeDupablePerformIO $ unsafeWithForeignPtr fp $ -- passing an accumulating parameter from right to left, and returning a -- final value of this accumulator together with the new ByteString. mapAccumR :: (acc -> Word8 -> (acc, Word8)) -> acc -> ByteString -> (acc, ByteString) -mapAccumR f acc (BS fp len) = unsafeDupablePerformIO $ unsafeWithForeignPtr fp $ \a -> do +mapAccumR f acc = \(BS fp len) -> unsafeDupablePerformIO $ unsafeWithForeignPtr fp $ \a -> do + -- see fold inlining gp <- mallocByteString len + let + go src dst = mapAccumR_ acc (len-1) + where + mapAccumR_ !s (-1) = return s + mapAccumR_ !s !n = do + x <- peekByteOff src n + let (s', y) = f s x + pokeByteOff dst n y + mapAccumR_ s' (n-1) acc' <- unsafeWithForeignPtr gp (go a) return (acc', BS gp len) - where - go src dst = mapAccumR_ acc (len-1) - where - mapAccumR_ !s (-1) = return s - mapAccumR_ !s !n = do - x <- peekByteOff src n - let (s', y) = f s x - pokeByteOff dst n y - mapAccumR_ s' (n-1) {-# INLINE mapAccumR #-} -- --------------------------------------------------------------------- @@ -708,20 +730,21 @@ scanl -- ^ input of length n -> ByteString -- ^ output of length n+1 -scanl f v (BS fp len) = unsafeDupablePerformIO $ unsafeWithForeignPtr fp $ \a -> +scanl f v = \(BS fp len) -> unsafeDupablePerformIO $ unsafeWithForeignPtr fp $ \a -> + -- see fold inlining create (len+1) $ \q -> do poke q v + let + go src dst = scanl_ v 0 + where + scanl_ !z !n + | n >= len = return () + | otherwise = do + x <- peekByteOff src n + let z' = f z x + pokeByteOff dst n z' + scanl_ z' (n+1) go a (q `plusPtr` 1) - where - go src dst = scanl_ v 0 - where - scanl_ !z !n - | n >= len = return () - | otherwise = do - x <- peekByteOff src n - let z' = f z x - pokeByteOff dst n z' - scanl_ z' (n+1) {-# INLINE scanl #-} -- n.b. haskell's List scan returns a list one bigger than the @@ -757,20 +780,21 @@ scanr -- ^ input of length n -> ByteString -- ^ output of length n+1 -scanr f v (BS fp len) = unsafeDupablePerformIO $ unsafeWithForeignPtr fp $ \a -> +scanr f v = \(BS fp len) -> unsafeDupablePerformIO $ unsafeWithForeignPtr fp $ \a -> + -- see fold inlining create (len+1) $ \q -> do poke (q `plusPtr` len) v + let + go p q = scanr_ v (len-1) + where + scanr_ !z !n + | n < 0 = return () + | otherwise = do + x <- peekByteOff p n + let z' = f x z + pokeByteOff q n z' + scanr_ z' (n-1) go a q - where - go p q = scanr_ v (len-1) - where - scanr_ !z !n - | n < 0 = return () - | otherwise = do - x <- peekByteOff p n - let z' = f x z - pokeByteOff q n z' - scanr_ z' (n-1) {-# INLINE scanr #-} -- | 'scanr1' is a variant of 'scanr' that has no starting value argument. @@ -1327,21 +1351,24 @@ notElem c ps = not (c `elem` ps) -- returns a ByteString containing those characters that satisfy the -- predicate. filter :: (Word8 -> Bool) -> ByteString -> ByteString -filter k ps@(BS x l) - | null ps = ps - | otherwise = unsafePerformIO $ createAndTrim l $ \p -> withForeignPtr x $ \f -> do +filter k = \ps@(BS x l) -> + -- see fold inlining. + if null ps + then ps + else + unsafePerformIO $ createAndTrim l $ \p -> withForeignPtr x $ \f -> do + let + go' pf pt = go pf pt + where + end = pf `plusPtr` l + go !f !t | f == end = return t + | otherwise = do + w <- peek f + if k w + then poke t w >> go (f `plusPtr` 1) (t `plusPtr` 1) + else go (f `plusPtr` 1) t t <- go' f p return $! t `minusPtr` p -- actual length - where - go' pf pt = go pf pt - where - end = pf `plusPtr` l - go !f !t | f == end = return t - | otherwise = do - w <- peek f - if k w - then poke t w >> go (f `plusPtr` 1) (t `plusPtr` 1) - else go (f `plusPtr` 1) t {-# INLINE filter #-} {-