Skip to content

Commit

Permalink
Avoid C types in pure haskell implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
hsyl20 committed Jan 8, 2024
1 parent 3996c2c commit 6f67a1a
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 40 deletions.
57 changes: 27 additions & 30 deletions Data/ByteString/Internal/Pure.hs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ import GHC.Int (Int8(..))
import Data.Bits (Bits(..), shiftR, (.&.))
import Data.Word
import Foreign.Ptr (plusPtr)
import Foreign.C.Types (CSize(..), CPtrdiff, CInt)
import Foreign.Storable (Storable(..))
import Control.Monad (when)

Expand All @@ -46,7 +45,7 @@ import Control.Monad (when)

-- | duplicate a string, interspersing the character through the elements of the
-- duplicated string
intersperse :: Ptr Word8 -> Ptr Word8 -> CSize -> Word8 -> IO ()
intersperse :: Ptr Word8 -> Ptr Word8 -> Int -> Word8 -> IO ()
intersperse !dst !src !len !w = case len of
0 -> pure ()
1 -> do
Expand All @@ -59,41 +58,40 @@ intersperse !dst !src !len !w = case len of
pokeByteOff dst 1 w
intersperse (plusPtr dst 2) (plusPtr src 1) (len-1) w

-- | Count the bytes equal to @w@ at indices @0 .. len-1@ of the byte array.
-- The result is computed by a pure loop and only wrapped in 'IO' by 'return'.
-- NOTE(review): the loop stops only when @i == len@, so this assumes
-- @len >= 0@ — confirm against callers.
countOccBA :: ByteArray# -> Int -> Word8 -> IO CSize
countOccBA :: ByteArray# -> Int -> Word8 -> IO Int
countOccBA ba len w = return (go 0 0)
where
-- n is the number of matches so far, i the index of the next byte to inspect
go !n !i@(I# i#)
| i == len = n
| W8# (indexWord8Array# ba i#) == w = go (n+1) (i+1)
| otherwise = go n (i+1)

-- | Count the bytes equal to @w@ among the @len@ bytes starting at @p@.
-- A zero length yields 0 without reading memory; otherwise 'count_occ' scans
-- from @p@ up to and including the byte at offset @len - 1@.
countOcc :: Ptr Word8 -> CSize -> Word8 -> IO CSize
countOcc :: Ptr Word8 -> Int -> Word8 -> IO Int
countOcc p len w
| len == 0 = pure 0
| otherwise = count_occ (plusPtr p (fromIntegral len - 1)) w 0 p
| otherwise = count_occ (plusPtr p (len - 1)) w 0 p

-- | Worker loop for 'countOcc'. Arguments, in order: pointer to the last byte
-- to inspect, the byte searched for, the running count, and the current
-- pointer. The byte at @p@ is read before the end check, so every call
-- inspects at least one byte.
count_occ :: Ptr Word8 -> Word8 -> CSize -> Ptr Word8 -> IO CSize
count_occ :: Ptr Word8 -> Word8 -> Int -> Ptr Word8 -> IO Int
count_occ !plast !w !count !p = do
c <- peekByteOff p 0
let !count' = if c == w then count+1 else count
-- once the last byte has been counted we are done; otherwise step one byte on
if p == plast
then pure count'
else count_occ plast w count' (plusPtr p 1)

-- | Index of the first byte equal to @w@ among indices @0 .. len-1@ of the
-- byte array, or @-1@ when no such byte exists. The search is pure and only
-- wrapped in 'IO' by 'return'.
elemIndex :: ByteArray# -> Word8 -> CSize -> IO CPtrdiff
elemIndex :: ByteArray# -> Word8 -> Int -> IO Int
elemIndex !ba !w !len = return (go 0)
where
!len' = fromIntegral len
-- i is the index of the next byte to inspect; reaching the length means "not found"
go !i@(I# i#)
| i == len' = -1
| W8# (indexWord8Array# ba i#) == w = fromIntegral i
| i == len = -1
| W8# (indexWord8Array# ba i#) == w = i
| otherwise = go (i+1)

-- | Reverse n-bytes from the second pointer into the first
--
-- Nothing is done when @n@ is 0. Otherwise 'reverse_bytes' is started with its
-- destination cursor on the last destination byte (@dst + n - 1@) and its
-- source cursor on the first source byte.
reverseBytes :: Ptr Word8 -> Ptr Word8 -> CSize -> IO ()
reverseBytes :: Ptr Word8 -> Ptr Word8 -> Int -> IO ()
reverseBytes !dst !src !n
| n == 0 = pure ()
| otherwise = reverse_bytes dst (plusPtr dst (fromIntegral n - 1)) src
| otherwise = reverse_bytes dst (plusPtr dst (n - 1)) src

-- | Note that reverse_bytes reverses at least one byte.
-- Then it loops if necessary until the destination buffer is full
Expand All @@ -110,8 +108,8 @@ reverse_bytes orig_dst dst src = do
else reverse_bytes orig_dst (plusPtr dst (-1)) (plusPtr src 1)


-- | Find the maximum of the @n@ bytes at @p@ by delegating to 'find_maximum',
-- seeded with 'minBound' and given the pointer to the last byte (@p + n - 1@).
-- NOTE(review): unlike 'countOcc' and 'reverseBytes' there is no guard for
-- @n == 0@ here; presumably callers never pass an empty buffer — confirm.
findMaximum :: Ptr Word8 -> CSize -> IO Word8
findMaximum !p !n = find_maximum minBound p (plusPtr p (fromIntegral n - 1))
findMaximum :: Ptr Word8 -> Int -> IO Word8
findMaximum !p !n = find_maximum minBound p (plusPtr p (n - 1))

find_maximum :: Word8 -> Ptr Word8 -> Ptr Word8 -> IO Word8
find_maximum !m !p !plast = do
Expand All @@ -121,8 +119,8 @@ find_maximum !m !p !plast = do
then pure c'
else find_maximum c' (plusPtr p 1) plast

-- | Find the minimum of the @n@ bytes at @p@ by delegating to 'find_minimum',
-- seeded with 'maxBound' and given the pointer to the last byte (@p + n - 1@).
-- NOTE(review): unlike 'countOcc' and 'reverseBytes' there is no guard for
-- @n == 0@ here; presumably callers never pass an empty buffer — confirm.
findMinimum :: Ptr Word8 -> CSize -> IO Word8
findMinimum !p !n = find_minimum maxBound p (plusPtr p (fromIntegral n - 1))
findMinimum :: Ptr Word8 -> Int -> IO Word8
findMinimum !p !n = find_minimum maxBound p (plusPtr p (n - 1))

find_minimum :: Word8 -> Ptr Word8 -> Ptr Word8 -> IO Word8
find_minimum !m !p !plast = do
Expand All @@ -133,10 +131,10 @@ find_minimum !m !p !plast = do
else find_minimum c' (plusPtr p 1) plast


-- | Quick sort the @n@ bytes at @p@. A non-positive @n@ is a no-op; otherwise
-- 'quick_sort' is run with low index 0 and high index @n - 1@.
quickSort :: Ptr Word8 -> CSize -> IO ()
quickSort :: Ptr Word8 -> Int -> IO ()
quickSort !p !n
| n <= 0 = pure ()
| otherwise = quick_sort p 0 (fromIntegral n - 1)
| otherwise = quick_sort p 0 (n - 1)

quick_sort :: Ptr Word8 -> Int -> Int -> IO ()
quick_sort !p !low !high
Expand Down Expand Up @@ -170,23 +168,22 @@ partition !p !low !high = do
go i (j+1)
go low low

-- | Check whether the first @len'@ bytes of the byte array are valid UTF-8.
-- Delegates to @isValidUtf8'@, handing it 'indexWord8Array#' applied to the
-- array as the byte accessor.
isValidUtf8BA :: ByteArray# -> CSize -> IO CInt
isValidUtf8BA :: ByteArray# -> Int -> IO Bool
isValidUtf8BA !ba !len' = isValidUtf8' (indexWord8Array# ba) len'

-- | Check whether the @len'@ bytes at the given pointer are valid UTF-8.
-- Delegates to @isValidUtf8'@, handing it 'indexWord8OffAddr#' applied to the
-- pointer's raw address as the byte accessor.
isValidUtf8 :: Ptr Word8 -> CSize -> IO CInt
isValidUtf8 :: Ptr Word8 -> Int -> IO Bool
isValidUtf8 !(Ptr a) !len' = isValidUtf8' (indexWord8OffAddr# a) len'

isValidUtf8' :: (Int# -> Word8#) -> CSize -> IO CInt
isValidUtf8' idx !len' = go 0
isValidUtf8' :: (Int# -> Word8#) -> Int -> IO Bool
isValidUtf8' idx !len = go 0
where
!len = fromIntegral len'
indexWord8 (I# i) = W8# (idx i)
indexInt8 :: Int -> Int8
indexInt8 i = fromIntegral (indexWord8 i)
is_cont :: Int8 -> Bool
is_cont i = i <= (fromIntegral (0xBF :: Word8))
go !i
| i >= len = return 1 -- done
| i >= len = return True -- done
| otherwise = do
let !b0 = indexWord8 i
if | b0 <= 0x7F -> go (i+1) -- ASCII
Expand All @@ -195,17 +192,17 @@ isValidUtf8' idx !len' = go 0
| otherwise -> go4 (i+1) b0

go2 !i
| i >= len = return 0
| i >= len = return False
-- We use a signed comparison to avoid an extra comparison with
-- 0x80, since _signed_ 0x80 is -128.
| i1 <- indexInt8 i
, is_cont i1
= go (i+1)
| otherwise
= return 0
= return False

go3 !i !b0
| i+1 >= len = return 0
| i+1 >= len = return False
-- We use a signed comparison to avoid an extra comparison with
-- 0x80, since _signed_ 0x80 is -128.
| i1 <- indexInt8 i
Expand All @@ -219,10 +216,10 @@ isValidUtf8' idx !len' = go 0
|| (b0 >= 0xEE && b0 <= 0xEF) -- EE..EF, 80..BF, 80..BF
= go (i+2)
| otherwise
= return 0
= return False

go4 !i !b0
| i+2 >= len = return 0
| i+2 >= len = return False
-- We use a signed comparison to avoid an extra comparison with
-- 0x80, since _signed_ 0x80 is -128.
| i1 <- indexInt8 i
Expand All @@ -238,7 +235,7 @@ isValidUtf8' idx !len' = go 0
= go (i+3)

| otherwise
= return 0
= return False


----------------------------------------------------------------
Expand Down
24 changes: 14 additions & 10 deletions Data/ByteString/Internal/Type.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1220,48 +1220,52 @@ foreign import ccall safe "bytestring_is_valid_utf8" cIsValidUtf8Safe

-- | Reverse n-bytes from the second pointer into the first
--
-- Keeps the C-style 'CSize' signature; the pointed definition converts the
-- length with 'fromIntegral' before calling 'Pure.reverseBytes'.
c_reverse :: Ptr Word8 -> Ptr Word8 -> CSize -> IO ()
c_reverse = Pure.reverseBytes
c_reverse p1 p2 sz = Pure.reverseBytes p1 p2 (fromIntegral sz)

-- | find maximum char in a packed string
--
-- Keeps the C-style 'CSize' signature; the pointed definition converts the
-- length with 'fromIntegral' before calling 'Pure.findMaximum'.
c_maximum :: Ptr Word8 -> CSize -> IO Word8
c_maximum = Pure.findMaximum
c_maximum ptr sz = Pure.findMaximum ptr (fromIntegral sz)

-- | find minimum char in a packed string
--
-- Keeps the C-style 'CSize' signature; the pointed definition converts the
-- length with 'fromIntegral' before calling 'Pure.findMinimum'.
c_minimum :: Ptr Word8 -> CSize -> IO Word8
c_minimum = Pure.findMinimum
c_minimum ptr sz = Pure.findMinimum ptr (fromIntegral sz)

-- | count the number of occurrences of a char in a string
--
-- Keeps the C-style 'CSize' signature; the pointed definition converts the
-- length on the way in and the resulting count on the way out with
-- 'fromIntegral' around 'Pure.countOcc'.
c_count :: Ptr Word8 -> CSize -> Word8 -> IO CSize
c_count = Pure.countOcc
c_count ptr sz c = fromIntegral <$> Pure.countOcc ptr (fromIntegral sz) c

-- | count the number of occurrences of a byte in a byte array
--
-- The 'Int' argument is forwarded unchanged to 'Pure.countOccBA'; the pointed
-- definition only converts the resulting count to 'CSize' with 'fromIntegral'.
c_count_ba :: ByteArray# -> Int -> Word8 -> IO CSize
c_count_ba = Pure.countOccBA
c_count_ba ba o c = fromIntegral <$> Pure.countOccBA ba o c

-- | duplicate a string, interspersing the character through the elements of the
-- duplicated string
--
-- Keeps the C-style 'CSize' signature; the pointed definition converts the
-- length with 'fromIntegral' before calling 'Pure.intersperse'.
c_intersperse :: Ptr Word8 -> Ptr Word8 -> CSize -> Word8 -> IO ()
c_intersperse = Pure.intersperse
c_intersperse p1 p2 sz e = Pure.intersperse p1 p2 (fromIntegral sz) e

-- | Quick sort bytes
--
-- Keeps the C-style 'CSize' signature; the pointed definition converts the
-- length with 'fromIntegral' before calling 'Pure.quickSort'.
c_sort :: Ptr Word8 -> CSize -> IO ()
c_sort = Pure.quickSort
c_sort ptr sz = Pure.quickSort ptr (fromIntegral sz)

-- | Look up a byte in a byte array via 'Pure.elemIndex', keeping the C-style
-- 'CSize' / 'CPtrdiff' signature. The pointed definition converts the length
-- on the way in and the resulting index on the way out with 'fromIntegral'.
c_elem_index :: ByteArray# -> Word8 -> CSize -> IO CPtrdiff
c_elem_index = Pure.elemIndex
c_elem_index ba e sz = fromIntegral <$> Pure.elemIndex ba e (fromIntegral sz)

-- | UTF-8 validity check on a byte array via 'Pure.isValidUtf8BA', keeping the
-- C-style 'CSize' / 'CInt' signature. The pointed definition converts the
-- length with 'fromIntegral' and maps the 'Bool' result through 'bool_to_cint'.
cIsValidUtf8BA :: ByteArray# -> CSize -> IO CInt
cIsValidUtf8BA = Pure.isValidUtf8BA
cIsValidUtf8BA ba sz = bool_to_cint <$> Pure.isValidUtf8BA ba (fromIntegral sz)

-- | Plain alias of 'cIsValidUtf8BA'.
-- NOTE(review): presumably kept so this build exposes the same name as the
-- \"safe\" FFI variant — confirm against the foreign imports above.
cIsValidUtf8BASafe :: ByteArray# -> CSize -> IO CInt
cIsValidUtf8BASafe = cIsValidUtf8BA

-- | UTF-8 validity check on a pointed-to buffer via 'Pure.isValidUtf8', keeping
-- the C-style 'CSize' / 'CInt' signature. The pointed definition converts the
-- length with 'fromIntegral' and maps the 'Bool' result through 'bool_to_cint'.
cIsValidUtf8 :: Ptr Word8 -> CSize -> IO CInt
cIsValidUtf8 = Pure.isValidUtf8
cIsValidUtf8 ptr sz = bool_to_cint <$> Pure.isValidUtf8 ptr (fromIntegral sz)

-- | Plain alias of 'cIsValidUtf8'.
-- NOTE(review): presumably kept so this build exposes the same name as the
-- \"safe\" FFI import @bytestring_is_valid_utf8@ — confirm against the
-- foreign imports above.
cIsValidUtf8Safe :: Ptr Word8 -> CSize -> IO CInt
cIsValidUtf8Safe = cIsValidUtf8

-- | Encode a 'Bool' as a C-style integer flag: 'True' becomes 1 and 'False'
-- becomes 0.
bool_to_cint :: Bool -> CInt
bool_to_cint True = 1
bool_to_cint False = 0

----------------------------------------------------------------
-- Haskell version of functions in itoa.c
----------------------------------------------------------------
Expand Down

0 comments on commit 6f67a1a

Please sign in to comment.