diff --git a/Data/Primitive/Array.hs b/Data/Primitive/Array.hs index fc0301c9..1943a862 100644 --- a/Data/Primitive/Array.hs +++ b/Data/Primitive/Array.hs @@ -1,6 +1,7 @@ {-# LANGUAGE CPP, MagicHash, UnboxedTuples, DeriveDataTypeable, BangPatterns #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE ScopedTypeVariables #-} -- | -- Module : Data.Primitive.Array @@ -24,7 +25,8 @@ module Data.Primitive.Array ( sizeofArray, sizeofMutableArray, fromListN, fromList, mapArray', - traverseArrayP + traverseArrayP, + filterArray ) where import Control.Monad.Primitive @@ -68,12 +70,14 @@ import GHC.Exts (runRW#) import GHC.Base (runRW#) #endif -import Text.ParserCombinators.ReadP +import Text.ParserCombinators.ReadP (string, skipSpaces, readS_to_P, readP_to_S) #if MIN_VERSION_base(4,9,0) || MIN_VERSION_transformers(0,4,0) import Data.Functor.Classes (Eq1(..),Ord1(..),Show1(..),Read1(..)) #endif +import Data.Primitive.Internal.Bit + -- | Boxed arrays data Array a = Array { array# :: Array# a } @@ -591,6 +595,49 @@ arrayFromListN n l = arrayFromList :: [a] -> Array a arrayFromList l = arrayFromListN (length l) l +filterArray :: forall a. (a -> Bool) -> Array a -> Array a +filterArray f arr = runArray $ + newBitArray s >>= check 0 0 + where + s = sizeofArray arr + check :: Int -> Int -> MutableBitArray s -> ST s (MutableArray s a) + check i count ba + | i /= s + = do + v <- indexArrayM arr i + if f v + then setBitArray ba i >> check (i + 1) (count + 1) ba + else check (i + 1) count ba + | otherwise + = do + mary <- newArray count (die "filterArray" "invalid") + fill 0 0 ba mary + + -- This performs a few bit operations and a conditional + -- jump for every element of the original array. This is + -- not so great if most element are filtered out. We should + -- consider going word by word through the bit array and + -- using countTrailingZeroes. We could even choose + -- a different strategy for each word depending on its + -- popCount. + fill :: forall s. Int -> Int -> MutableBitArray s -> MutableArray s a -> ST s (MutableArray s a) + fill !i0 !i'0 !ba !mary = go i0 i'0 + where + go :: Int -> Int -> ST s (MutableArray s a) + go i i' + | i == s + = return mary + | otherwise + = do + b <- readBitArray ba i + if b + then do + v <- indexArrayM arr i + writeArray mary i' v + go (i + 1) (i' + 1) + else go (i + 1) i' + + #if MIN_VERSION_base(4,7,0) instance Exts.IsList (Array a) where type Item (Array a) = a diff --git a/Data/Primitive/Internal/Bit.hs b/Data/Primitive/Internal/Bit.hs new file mode 100644 index 00000000..0596ed0e --- /dev/null +++ b/Data/Primitive/Internal/Bit.hs @@ -0,0 +1,68 @@ +{-# LANGUAGE CPP #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE BangPatterns #-} +module Data.Primitive.Internal.Bit + ( + MutableBitArray + , newBitArray + , readBitArray + , setBitArray + ) where + +import Data.Primitive.ByteArray +import Control.Monad.Primitive +import Data.Bits +import Control.Monad.ST + +newtype MutableBitArray s = MBA (MutableByteArray s) + +newBitArray :: Int -> ST s (MutableBitArray s) +--newBitArray :: PrimMonad m => Int -> m (MutableBitArray (PrimState m)) +newBitArray !n = do + let s = ((n + wordSize - 1) `unsafeShiftR` 3) + mary <- newByteArray s + fillByteArray mary 0 s 0 + return (MBA mary) + +readBitArray :: MutableBitArray s -> Int -> ST s Bool +--readBitArray :: PrimMonad m => MutableBitArray (PrimState m) -> Int -> m Bool +readBitArray !(MBA mry) !i = do + wd :: Word <- readByteArray mry (whichWord i) + return $! (((wd `unsafeShiftR` whichBit i) .&. 1) == 1) + +setBitArray :: MutableBitArray s -> Int -> ST s () +--setBitArray :: PrimMonad m => MutableBitArray (PrimState m) -> Int -> m () +setBitArray !(MBA mry) !i = do + let ww = whichWord i + wd :: Word <- readByteArray mry ww + let wd' = wd .|. (1 `unsafeShiftL` (whichBit i)) + writeByteArray mry ww wd' + +wordSize :: Int +wordSize = finiteBitSize (undefined :: Word) + +ctlws :: Int +ctlws + | wordSize == 64 = 6 + | wordSize == 32 = 5 + | otherwise = countTrailingZeros wordSize + +whichWord :: Int -> Int +whichWord i = i `unsafeShiftR` ctlws + +whichBit :: Int -> Int +whichBit i = i .&. (wordSize - 1) + +{- +-- For debugging +freezeByteArray + :: PrimMonad m => MutableByteArray (PrimState m) -> m ByteArray +freezeByteArray mary = do + s <- getSizeofMutableByteArray mary + cop <- newByteArray s + copyMutableByteArray cop 0 mary 0 s + unsafeFreezeByteArray cop + +prant :: MutableBitArray RealWorld -> IO () +prant (MBA x) = freezeByteArray x >>= print +-} diff --git a/primitive.cabal b/primitive.cabal index 9974832e..fdcee5c7 100644 --- a/primitive.cabal +++ b/primitive.cabal @@ -49,6 +49,7 @@ Library Data.Primitive.MutVar Other-Modules: + Data.Primitive.Internal.Bit Data.Primitive.Internal.Compat Data.Primitive.Internal.Operations