From 900f636e130d725f564b6649371eb420bb1b0dfc Mon Sep 17 00:00:00 2001 From: archaephyrryx Date: Wed, 26 Aug 2020 19:11:34 -0400 Subject: [PATCH] FindIndices optimized using findIndex and inlining Reimplements findIndices to iteratively call findIndex, which yields an approximate 2x improvement over original implementation for a simple equality predicate Adds inline pragma for findIndices, which yields an approximate 10x improvement for a simple equality predicate Adds rewrite rules that optimize calls of `findIndex (==x)` with `elemIndex x` and `findIndices (==x)`->`elemIndices x` (both left- and right-sections) for Data.ByteString and Data.ByteString.Char8 Adds phase number [1] for inline rules on `findIndex` and `findIndices` to allow said rules to fire properly --- Data/ByteString.hs | 28 ++++++++++++++++++++++++---- Data/ByteString/Char8.hs | 28 +++++++++++++++++++++++++++- Data/ByteString/Lazy.hs | 1 + Data/ByteString/Lazy/Char8.hs | 1 + 4 files changed, 53 insertions(+), 5 deletions(-) diff --git a/Data/ByteString.hs b/Data/ByteString.hs index 96bfff65c..09d0b6f97 100644 --- a/Data/ByteString.hs +++ b/Data/ByteString.hs @@ -1320,7 +1320,7 @@ findIndex k (BS x l) = accursedUnutterablePerformIO $ withForeignPtr x $ \f -> g if k w then return (Just n) else go (ptr `plusPtr` 1) (n+1) -{-# INLINE findIndex #-} +{-# INLINE [1] findIndex #-} -- | /O(n)/ The 'findIndexEnd' function takes a predicate and a 'ByteString' and -- returns the index of the last element in the ByteString @@ -1342,9 +1342,29 @@ findIndexEnd k (BS x l) = accursedUnutterablePerformIO $ withForeignPtr x $ \ f findIndices :: (Word8 -> Bool) -> ByteString -> [Int] findIndices p ps = loop 0 ps where - loop !n !qs | null qs = [] - | p (unsafeHead qs) = n : loop (n+1) (unsafeTail qs) - | otherwise = loop (n+1) (unsafeTail qs) + loop !n !qs = case findIndex p qs of + Just !i -> + let !j = n+i + in j : loop (j+1) (unsafeDrop (i+1) qs) + Nothing -> [] +{-# INLINE [1] findIndices #-} + + +#if MIN_VERSION_base(4,9,0) +{-# RULES +"ByteString specialise findIndex (x ==)" forall x. findIndex (x`eqWord8`) = elemIndex x +"ByteString specialise findIndex (== x)" forall x. findIndex (`eqWord8`x) = elemIndex x +"ByteString specialise findIndices (x ==)" forall x. findIndices (x`eqWord8`) = elemIndices x +"ByteString specialise findIndices (== x)" forall x. findIndices (`eqWord8`x) = elemIndices x + #-} +#else +{-# RULES +"ByteString specialise findIndex (x ==)" forall x. findIndex (x==) = elemIndex x +"ByteString specialise findIndex (== x)" forall x. findIndex (==x) = elemIndex x +"ByteString specialise findIndices (x ==)" forall x. findIndices (x==) = elemIndices x +"ByteString specialise findIndices (== x)" forall x. findIndices (==x) = elemIndices x + #-} +#endif -- --------------------------------------------------------------------- -- Searching ByteStrings diff --git a/Data/ByteString/Char8.hs b/Data/ByteString/Char8.hs index 25cc13ecf..d75ac52f5 100644 --- a/Data/ByteString/Char8.hs +++ b/Data/ByteString/Char8.hs @@ -693,12 +693,38 @@ elemIndices = B.elemIndices . c2w -- returns the index of the first element in the ByteString satisfying the predicate. findIndex :: (Char -> Bool) -> ByteString -> Maybe Int findIndex f = B.findIndex (f . w2c) -{-# INLINE findIndex #-} +{-# INLINE [1] findIndex #-} -- | The 'findIndices' function extends 'findIndex', by returning the -- indices of all elements satisfying the predicate, in ascending order. findIndices :: (Char -> Bool) -> ByteString -> [Int] findIndices f = B.findIndices (f . w2c) +{-# INLINE [1] findIndices #-} + +#if MIN_VERSION_base(4,9,0) +{-# RULES +"ByteString specialise findIndex (x==)" forall x. + findIndex (x `eqChar`) = elemIndex x +"ByteString specialise findIndex (==x)" forall x. + findIndex (`eqChar` x) = elemIndex x +"ByteString specialise findIndices (x==)" forall x. + findIndices (x `eqChar`) = elemIndices x +"ByteString specialise findIndices (==x)" forall x. + findIndices (`eqChar` x) = elemIndices x + #-} +#else +{-# RULES +"ByteString specialise findIndex (x==)" forall x. + findIndex (x==) = elemIndex x +"ByteString specialise findIndex (==x)" forall x. + findIndex (==x) = elemIndex x +"ByteString specialise findIndices (x==)" forall x. + findIndices (x==) = elemIndices x +"ByteString specialise findIndices (==x)" forall x. + findIndices (==x) = elemIndices x + #-} +#endif + -- | count returns the number of times its argument appears in the ByteString -- diff --git a/Data/ByteString/Lazy.hs b/Data/ByteString/Lazy.hs index 7e70f2bc8..6c4a1c173 100644 --- a/Data/ByteString/Lazy.hs +++ b/Data/ByteString/Lazy.hs @@ -1031,6 +1031,7 @@ findIndices k cs0 = findIndices' 0 cs0 where findIndices' _ Empty = [] findIndices' n (Chunk c cs) = L.map ((+n).fromIntegral) (S.findIndices k c) ++ findIndices' (n + fromIntegral (S.length c)) cs +{-# INLINE findIndices #-} -- --------------------------------------------------------------------- -- Searching ByteStrings diff --git a/Data/ByteString/Lazy/Char8.hs b/Data/ByteString/Lazy/Char8.hs index 317bb2bfb..83ba1b88c 100644 --- a/Data/ByteString/Lazy/Char8.hs +++ b/Data/ByteString/Lazy/Char8.hs @@ -572,6 +572,7 @@ findIndex f = L.findIndex (f . w2c) -- indices of all elements satisfying the predicate, in ascending order. findIndices :: (Char -> Bool) -> ByteString -> [Int64] findIndices f = L.findIndices (f . w2c) +{-# INLINE findIndices #-} -- | count returns the number of times its argument appears in the ByteString --