Skip to content

Commit

Permalink
More stimes-related cleanup and bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
clyring committed Dec 19, 2021
1 parent 9906555 commit 554506d
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 31 deletions.
62 changes: 36 additions & 26 deletions Data/ByteString/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,7 @@ import Control.Exception (assert, throw, Exception)

import Data.Bits (Bits, (.&.), toIntegralSized)
import Data.Char (ord)
import Data.Int
import Data.Word
import Numeric.Natural (Natural)

import Data.Typeable (Typeable)
import Data.Data (Data(..), mkNoRepType)
Expand Down Expand Up @@ -735,12 +733,30 @@ stimesPolymorphic :: Integral a => a -> ByteString -> ByteString
{-# INLINABLE stimesPolymorphic #-}
stimesPolymorphic nRaw bs = case checkedToInt nRaw of
Just n -> stimesInt n bs
Nothing -> overflowError "stimes"
_ | nRaw < 0
-- It may seem odd to check for negative input in both the
-- polymorphic wrapper and the specialized Int worker. But
-- checking here before converting to Int would not remove the
-- need to check inside the worker, since the Ord instance can
-- potentially behave arbitrarily. At the moment this isn't a
-- memory-safety issue for StrictByteString since
-- mallocPlainForeignPtrBytes checks for negative sizes, but
-- this seems like a dangerous behavior to rely on. The same is
-- not true for ShortByteString and newByteArray#, for example.
-> stimesNegativeErr
| otherwise -> case bs of
BS _ 0 -> mempty -- "null" is only defined in Data.ByteString
_ -> stimesOverflowErr

stimesNegativeErr :: ByteString
stimesNegativeErr = error "stimes: non-negative multiplier expected"
stimesOverflowErr :: ByteString
stimesOverflowErr = overflowError "stimes"

-- | Repeats the given ByteString n times.
stimesInt :: Int -> ByteString -> ByteString
stimesInt n (BS fp len)
| n < 0 = error "stimes: non-negative multiplier expected"
| n < 0 = stimesNegativeErr
| n == 0 = mempty
| n == 1 = BS fp len
| len == 0 = mempty
Expand All @@ -754,10 +770,11 @@ stimesInt n (BS fp len)
fillFrom destptr len
where
size = checkedMul "stimes" n len
halfSize = (size - 1) `div` 2

fillFrom :: Ptr Word8 -> Int -> IO ()
fillFrom destptr copied
| 2 * copied < size = do
| copied <= halfSize = do
memcpy (destptr `plusPtr` copied) destptr copied
fillFrom destptr (copied * 2)
| otherwise = memcpy (destptr `plusPtr` copied) destptr (size - copied)
Expand Down Expand Up @@ -832,38 +849,31 @@ checkedMul fun !x@(I# x#) !y@(I# y#) = assert (min x y >= 0) $
_ -> overflowError fun
#else
case timesWord2# (int2Word# x#) (int2Word# y#) of
(# hi, lo #) -> case word2Int# (or# hi (uncheckedShiftRL# lo shiftAmt)) of
0# -> I# (word2Int# lo)
_ -> overflowError fun
(# hi, lo #) -> case or# hi (uncheckedShiftRL# lo shiftAmt) of
0## -> I# (word2Int# lo)
_ -> overflowError fun
where !(I# shiftAmt) = finiteBitSize (0 :: Word) - 1
#endif


_toIntegralSized :: (Integral a, Integral b, Bits a, Bits b) => a -> Maybe b
{-# INLINE _toIntegralSized #-}
-- stupid hack to make sure 'toIntegralSized' specializes, without
-- generating spcializations for every version of 'times'
-- stupid hack to make sure specialized versions of 'toIntegralSized'
-- are generated when the checkedToInt RULES fire
_toIntegralSized = inline toIntegralSized

checkedToInt :: Integral t => t -> Maybe Int
{-# RULES
"checkedToInt/Int" checkedToInt = _toIntegralSized :: Int -> Maybe Int
; "checkedToInt/Int8" checkedToInt = _toIntegralSized :: Int8 -> Maybe Int
; "checkedToInt/Int16" checkedToInt = _toIntegralSized :: Int16 -> Maybe Int
; "checkedToInt/Int32" checkedToInt = _toIntegralSized :: Int32 -> Maybe Int
; "checkedToInt/Int64" checkedToInt = _toIntegralSized :: Int64 -> Maybe Int
; "checkedToInt/Word" checkedToInt = _toIntegralSized :: Word -> Maybe Int
; "checkedToInt/Word8" checkedToInt = _toIntegralSized :: Word8 -> Maybe Int
; "checkedToInt/Word16" checkedToInt = _toIntegralSized :: Word16 -> Maybe Int
; "checkedToInt/Word32" checkedToInt = _toIntegralSized :: Word32 -> Maybe Int
; "checkedToInt/Word64" checkedToInt = _toIntegralSized :: Word64 -> Maybe Int
; "checkedToInt/Integer" checkedToInt = _toIntegralSized :: Integer -> Maybe Int
; "checkedToInt/Natural" checkedToInt = _toIntegralSized :: Natural -> Maybe Int
"checkedToInt/Int" checkedToInt = _toIntegralSized :: Int -> Maybe Int
; "checkedToInt/Word" checkedToInt = _toIntegralSized :: Word -> Maybe Int
#-}
{-# NOINLINE [1] checkedToInt #-}
checkedToInt x = if toInteger x == toInteger res
then Just res
else Nothing
where res = fromIntegral x :: Int
checkedToInt x
| xi == toInteger res = Just res
| otherwise = Nothing
where
xi = toInteger x
res = fromInteger xi :: Int


------------------------------------------------------------------------
Expand Down
14 changes: 9 additions & 5 deletions tests/Properties.hs
Original file line number Diff line number Diff line change
Expand Up @@ -287,8 +287,11 @@ prop_stimesOverflowScary bs =
n = P.length bs
reps = maxBound @Word `quot` fromIntegral @Int @Word n + 1

concat32bitOverflow :: (Int -> a) -> ([a] -> a) -> (a -> Int) -> Property
concat32bitOverflow replicateLike concatLike lengthLike = let
prop_stimesOverflowEmpty = forAll (choose (0, maxBound @Word)) $ \n ->
stimes n mempty === mempty @P.ByteString

concat32bitOverflow :: (Int -> a) -> ([a] -> a) -> Property
concat32bitOverflow replicateLike concatLike = let
intBits = finiteBitSize @Int 0
largeBS = concatLike $ replicate (bit 14) $ replicateLike (bit 17)
in if intBits /= 32
Expand All @@ -297,15 +300,15 @@ concat32bitOverflow replicateLike concatLike lengthLike = let

prop_32bitOverflow_Strict_mconcat :: Property
prop_32bitOverflow_Strict_mconcat =
concat32bitOverflow (`P.replicate` 0) mconcat P.length
concat32bitOverflow (`P.replicate` 0) mconcat

prop_32bitOverflow_Lazy_toStrict :: Property
prop_32bitOverflow_Lazy_toStrict =
concat32bitOverflow (`P.replicate` 0) (L.toStrict . L.fromChunks) P.length
concat32bitOverflow (`P.replicate` 0) (L.toStrict . L.fromChunks)

prop_32bitOverflow_Short_mconcat :: Property
prop_32bitOverflow_Short_mconcat =
concat32bitOverflow makeShort mconcat Short.length
concat32bitOverflow makeShort mconcat
where makeShort n = Short.toShort $ P.replicate n 0


Expand Down Expand Up @@ -647,6 +650,7 @@ overflow_tests =
, testProperty "checkedMul" prop_checkedMul
, testProperty "StrictByteString stimes (basic)" prop_stimesOverflowBasic
, testProperty "StrictByteString stimes (scary)" prop_stimesOverflowScary
, testProperty "StrictByteString stimes (empty)" prop_stimesOverflowEmpty
, testProperty "StrictByteString mconcat" prop_32bitOverflow_Strict_mconcat
, testProperty "LazyByteString toStrict" prop_32bitOverflow_Lazy_toStrict
, testProperty "ShortByteString mconcat" prop_32bitOverflow_Short_mconcat
Expand Down

0 comments on commit 554506d

Please sign in to comment.