diff --git a/Data/ByteString/Internal.hs b/Data/ByteString/Internal.hs index 1bb385ee2..c31cec086 100644 --- a/Data/ByteString/Internal.hs +++ b/Data/ByteString/Internal.hs @@ -66,7 +66,10 @@ module Data.ByteString.Internal ( -- * Utilities nullForeignPtr, + SizeOverflowException, + overflowError, checkedAdd, + checkedMultiply, -- * Standard C Functions c_strlen, @@ -107,18 +110,17 @@ import Foreign.Storable (Storable(..)) import Foreign.C.Types (CInt(..), CSize(..)) import Foreign.C.String (CString) -#if MIN_VERSION_base(4,13,0) -import Data.Semigroup (Semigroup (sconcat, stimes)) -#else -import Data.Semigroup (Semigroup ((<>), sconcat, stimes)) +#if !MIN_VERSION_base(4,13,0) +import Data.Semigroup (Semigroup ((<>))) #endif +import Data.Semigroup (Semigroup (sconcat, stimes)) import Data.List.NonEmpty (NonEmpty ((:|))) import Control.DeepSeq (NFData(rnf)) import Data.String (IsString(..)) -import Control.Exception (assert) +import Control.Exception (assert, throw, Exception) import Data.Bits ((.&.)) import Data.Char (ord) @@ -131,6 +133,20 @@ import GHC.Base (nullAddr#,realWorld#,unsafeChr) import GHC.Exts (IsList(..)) import GHC.CString (unpackCString#) import GHC.Prim (Addr#) + +#define TIMES_INT_2_AVAILABLE MIN_VERSION_ghc_prim(0,7,0) +#if TIMES_INT_2_AVAILABLE +import GHC.Prim (timesInt2#) +#else +import GHC.Prim ( timesWord2# + , or# + , uncheckedShiftRL# + , int2Word# + , word2Int# + ) +import Data.Bits (finiteBitSize) +#endif + import GHC.IO (IO(IO),unsafeDupablePerformIO) import GHC.ForeignPtr (ForeignPtr(ForeignPtr) #if __GLASGOW_HASKELL__ < 900 @@ -151,9 +167,7 @@ import GHC.ForeignPtr (ForeignPtrContents(FinalPtr)) import GHC.Ptr (Ptr(..)) #endif -#if (__GLASGOW_HASKELL__ < 802) || (__GLASGOW_HASKELL__ >= 811) import GHC.Types (Int (..)) -#endif #if MIN_VERSION_base(4,15,0) import GHC.ForeignPtr (unsafeWithForeignPtr) @@ -237,7 +251,8 @@ instance Ord ByteString where instance Semigroup ByteString where (<>) = append sconcat (b:|bs) = concat (b:bs) - stimes = times + {-# INLINE stimes #-} + stimes = stimesPolymorphic instance Monoid ByteString where mempty = empty @@ -663,7 +678,7 @@ append :: ByteString -> ByteString -> ByteString append (BS _ 0) b = b append a (BS _ 0) = a append (BS fp1 len1) (BS fp2 len2) = - unsafeCreate (len1+len2) $ \destptr1 -> do + unsafeCreate (checkedAdd "append" len1 len2) $ \destptr1 -> do let destptr2 = destptr1 `plusPtr` len1 unsafeWithForeignPtr fp1 $ \p1 -> memcpy destptr1 p1 len1 unsafeWithForeignPtr fp2 $ \p2 -> memcpy destptr2 p2 len2 @@ -719,38 +734,58 @@ concat = \bss0 -> goLen0 bss0 bss0 concat [x] = x #-} --- | /O(log n)/ Repeats the given ByteString n times. -times :: Integral a => a -> ByteString -> ByteString -times n (BS fp len) - | n < 0 = error "stimes: non-negative multiplier expected" +-- | Repeats the given ByteString n times. +-- Polymorphic wrapper to make sure any generated +-- specializations are reasonably small. +stimesPolymorphic :: Integral a => a -> ByteString -> ByteString +{-# INLINABLE stimesPolymorphic #-} +stimesPolymorphic nRaw = \ !bs -> case checkedIntegerToInt n of + Just nInt + | nInt >= 0 -> stimesNonNegativeInt nInt bs + | otherwise -> stimesNegativeErr + Nothing + | n < 0 -> stimesNegativeErr + | BS _ 0 <- bs -> empty + | otherwise -> stimesOverflowErr + where n = toInteger nRaw + -- By exclusively using n instead of nRaw, the semantics are kept simple + -- and the likelihood of potentially dangerous mistakes minimized. + + +stimesNegativeErr :: ByteString +stimesNegativeErr + = error "stimes @ByteString: non-negative multiplier expected" + +stimesOverflowErr :: ByteString +-- Although this only appears once, it is extracted here to prevent it +-- from being duplicated in specializations of 'stimesPolymorphic' +stimesOverflowErr = overflowError "stimes" + +-- | Repeats the given ByteString n times. +stimesNonNegativeInt :: Int -> ByteString -> ByteString +stimesNonNegativeInt n (BS fp len) | n == 0 = empty | n == 1 = BS fp len | len == 0 = empty - | len == 1 = unsafeCreate size $ \destptr -> + | len == 1 = unsafeCreate n $ \destptr -> unsafeWithForeignPtr fp $ \p -> do byte <- peek p - void $ memset destptr byte (fromIntegral size) + void $ memset destptr byte (fromIntegral n) | otherwise = unsafeCreate size $ \destptr -> unsafeWithForeignPtr fp $ \p -> do memcpy destptr p len fillFrom destptr len where - size = len * fromIntegral n + size = checkedMultiply "stimes" n len + halfSize = (size - 1) `div` 2 -- subtraction and division won't overflow fillFrom :: Ptr Word8 -> Int -> IO () fillFrom destptr copied - | 2 * copied < size = do + | copied <= halfSize = do memcpy (destptr `plusPtr` copied) destptr copied fillFrom destptr (copied * 2) | otherwise = memcpy (destptr `plusPtr` copied) destptr (size - copied) --- | Add two non-negative numbers. Errors out on overflow. -checkedAdd :: String -> Int -> Int -> Int -checkedAdd fun x y - | r >= 0 = r - | otherwise = overflowError fun - where r = x + y -{-# INLINE checkedAdd #-} ------------------------------------------------------------------------ @@ -785,8 +820,64 @@ isSpaceChar8 :: Char -> Bool isSpaceChar8 = isSpaceWord8 . c2w {-# INLINE isSpaceChar8 #-} +------------------------------------------------------------------------ + +-- | The type of exception raised by 'overflowError' +-- and on failure by overflow-checked arithmetic operations. +newtype SizeOverflowException + = SizeOverflowException String + +instance Show SizeOverflowException where + show (SizeOverflowException err) = err + +instance Exception SizeOverflowException + +-- | Raises a 'SizeOverflowException', +-- with a message using the given function name. overflowError :: String -> a -overflowError fun = error $ "Data.ByteString." ++ fun ++ ": size overflow" +overflowError fun = throw $ SizeOverflowException msg + where msg = "Data.ByteString." ++ fun ++ ": size overflow" + +-- | Add two non-negative numbers. +-- Calls 'overflowError' on overflow. +checkedAdd :: String -> Int -> Int -> Int +{-# INLINE checkedAdd #-} +checkedAdd fun x y + | r >= 0 = r + | otherwise = overflowError fun + where r = assert (min x y >= 0) $ x + y + +-- | Multiplies two non-negative numbers. +-- Calls 'overflowError' on overflow. +checkedMultiply :: String -> Int -> Int -> Int +{-# INLINE checkedMultiply #-} +checkedMultiply fun !x@(I# x#) !y@(I# y#) = assert (min x y >= 0) $ +#if TIMES_INT_2_AVAILABLE + case timesInt2# x# y# of + (# 0#, _, result #) -> I# result + _ -> overflowError fun +#else + case timesWord2# (int2Word# x#) (int2Word# y#) of + (# hi, lo #) -> case or# hi (uncheckedShiftRL# lo shiftAmt) of + 0## -> I# (word2Int# lo) + _ -> overflowError fun + where !(I# shiftAmt) = finiteBitSize (0 :: Word) - 1 +#endif + + +-- | Attempts to convert an 'Integer' value to an 'Int', returning +-- 'Nothing' if doing so would result in an overflow. +checkedIntegerToInt :: Integer -> Maybe Int +{-# INLINE checkedIntegerToInt #-} +-- We could use Data.Bits.toIntegralSized, but this hand-rolled +-- version is currently a bit faster as of GHC 9.2. +-- It's even faster to just match on the Integer constructors, but +-- we'd still need a fallback implementation for integer-simple. +checkedIntegerToInt x + | x == toInteger res = Just res + | otherwise = Nothing + where res = fromInteger x :: Int + ------------------------------------------------------------------------ diff --git a/Data/ByteString/Lazy/Internal.hs b/Data/ByteString/Lazy/Internal.hs index 5cbf09c81..63dc3f670 100644 --- a/Data/ByteString/Lazy/Internal.hs +++ b/Data/ByteString/Lazy/Internal.hs @@ -307,11 +307,11 @@ toStrict = \cs -> goLen0 cs cs goLen1 _ bs Empty = bs goLen1 cs0 bs (Chunk (S.BS _ 0) cs) = goLen1 cs0 bs cs goLen1 cs0 (S.BS _ bl) (Chunk (S.BS _ cl) cs) = - goLen cs0 (S.checkedAdd "Lazy.concat" bl cl) cs + goLen cs0 (S.checkedAdd "Lazy.toStrict" bl cl) cs -- General case, just find the total length we'll need goLen cs0 !total (Chunk (S.BS _ cl) cs) = - goLen cs0 (S.checkedAdd "Lazy.concat" total cl) cs + goLen cs0 (S.checkedAdd "Lazy.toStrict" total cl) cs goLen cs0 total Empty = S.unsafeCreate total $ \ptr -> goCopy cs0 ptr diff --git a/Data/ByteString/Short/Internal.hs b/Data/ByteString/Short/Internal.hs index e39c93423..d0e5bf31b 100644 --- a/Data/ByteString/Short/Internal.hs +++ b/Data/ByteString/Short/Internal.hs @@ -650,7 +650,7 @@ append :: ShortByteString -> ShortByteString -> ShortByteString append src1 src2 = let !len1 = length src1 !len2 = length src2 - in create (len1 + len2) $ \dst -> do + in create (checkedAdd "Short.append" len1 len2) $ \dst -> do copyByteArray (asBA src1) 0 dst 0 len1 copyByteArray (asBA src2) 0 dst len1 len2 @@ -658,8 +658,9 @@ concat :: [ShortByteString] -> ShortByteString concat = \sbss -> create (totalLen 0 sbss) (\dst -> copy dst 0 sbss) where - totalLen !acc [] = acc - totalLen !acc (sbs: sbss) = totalLen (acc + length sbs) sbss + totalLen !acc [] = acc + totalLen !acc (curr : rest) + = totalLen (checkedAdd "Short.concat" acc $ length curr) rest copy :: MBA s -> Int -> [ShortByteString] -> ST s () copy !_ !_ [] = return () diff --git a/README.md b/README.md index 2509b81ce..17c9ff2a7 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ of `ByteString` values from smaller pieces during binary serialization. Requirements: * Cabal 1.10 or greater - * GHC 7.0 or greater + * GHC 8.0 or greater ### Authors diff --git a/bench/BenchAll.hs b/bench/BenchAll.hs index 81b36c1e1..451eebc65 100644 --- a/bench/BenchAll.hs +++ b/bench/BenchAll.hs @@ -16,6 +16,7 @@ module Main (main) where import Data.Foldable (foldMap) import Data.Monoid +import Data.Semigroup import Data.String import Test.Tasty.Bench import Prelude hiding (words) @@ -407,6 +408,11 @@ main = do ] ] , bgroup "sort" $ map (\s -> bench (S8.unpack s) $ nf S.sort s) sortInputs + , bgroup "stimes" $ let st = stimes :: Int -> S.ByteString -> S.ByteString + in + [ bench "strict (tiny)" $ whnf (st 4) (S8.pack "test") + , bench "strict (large)" $ whnf (st 50) byteStringData + ] , bgroup "words" [ bench "lorem ipsum" $ nf S8.words loremIpsum , bench "one huge word" $ nf S8.words byteStringData diff --git a/tests/Properties.hs b/tests/Properties.hs index 0b3c29f96..15cb5d724 100644 --- a/tests/Properties.hs +++ b/tests/Properties.hs @@ -30,6 +30,8 @@ import qualified Data.List as List import Data.Char import Data.Word import Data.Maybe +import Data.Either (isLeft) +import Data.Bits (finiteBitSize, bit) import Data.Int (Int8, Int16, Int32, Int64) import Data.Semigroup import GHC.Exts (Int(..), newPinnedByteArray#, unsafeFreezeByteArray#) @@ -55,7 +57,6 @@ import qualified Data.ByteString.Lazy.Char8 as LC import qualified Data.ByteString.Lazy.Char8 as D import qualified Data.ByteString.Lazy.Internal as L -import Prelude hiding (abs) import QuickCheckUtils import Test.Tasty @@ -246,6 +247,74 @@ prop_readIntBoundsLC = rdWordBounds @Word ------------------------------------------------------------------------ +expectSizeOverflow :: a -> Property +expectSizeOverflow val = ioProperty $ do + isLeft <$> try @P.SizeOverflowException (evaluate val) + +prop_checkedAdd = forAll (vectorOf 2 nonNeg) $ \[x, y] -> if oflo x y + then expectSizeOverflow (P.checkedAdd "" x y) + else property $ P.checkedAdd "" x y == x + y + where nonNeg = choose (0, (maxBound @Int)) + oflo x y = toInteger x + toInteger y /= toInteger @Int (x + y) + +multCompl :: Int -> Gen Int +multCompl x = choose (0, fromInteger @Int maxc) + -- This choice creates products with magnitude roughly in the range + -- [0..5*(maxBound @Int)], which results in a roughly even split + -- between positive and negative overflowed Int results, while still + -- producing a fair number of non-overflowing products. + where maxc = toInteger (maxBound @Int) * 5 `quot` max 5 (abs $ toInteger x) + +prop_checkedMultiply = forAll genScale $ \scale -> + forAll (genVal scale) $ \x -> + forAll (multCompl x) $ \y -> if oflo x y + then expectSizeOverflow (P.checkedMultiply "" x y) + else property $ P.checkedMultiply "" x y == x * y + where genScale = choose (0, finiteBitSize @Int 0 - 1) + genVal scale = choose (0, bit scale - 1) + oflo x y = toInteger x * toInteger y /= toInteger @Int (x * y) + +prop_stimesOverflowBasic bs = forAll (multCompl len) $ \n -> + toInteger n * toInteger len > maxInt ==> expectSizeOverflow (stimes n bs) + where + maxInt = toInteger @Int (maxBound @Int) + len = P.length bs + +prop_stimesOverflowScary bs = + -- "Scary" because this test will cause heap corruption + -- (not just memory exhaustion) with the old stimes implementation. + n > 1 ==> expectSizeOverflow (stimes reps bs) + where + n = P.length bs + reps = maxBound @Word `quot` fromIntegral @Int @Word n + 1 + +prop_stimesOverflowEmpty = forAll (choose (0, maxBound @Word)) $ \n -> + stimes n mempty === mempty @P.ByteString + +concat32bitOverflow :: (Int -> a) -> ([a] -> a) -> Property +concat32bitOverflow replicateLike concatLike = let + intBits = finiteBitSize @Int 0 + largeBS = concatLike $ replicate (bit 14) $ replicateLike (bit 17) + in if intBits /= 32 + then label "skipped due to non-32-bit Int" True + else expectSizeOverflow largeBS + +prop_32bitOverflow_Strict_mconcat :: Property +prop_32bitOverflow_Strict_mconcat = + concat32bitOverflow (`P.replicate` 0) mconcat + +prop_32bitOverflow_Lazy_toStrict :: Property +prop_32bitOverflow_Lazy_toStrict = + concat32bitOverflow (`P.replicate` 0) (L.toStrict . L.fromChunks) + +prop_32bitOverflow_Short_mconcat :: Property +prop_32bitOverflow_Short_mconcat = + concat32bitOverflow makeShort mconcat + where makeShort n = Short.toShort $ P.replicate n 0 + + +------------------------------------------------------------------------ + prop_packUptoLenBytes cs = forAll (choose (0, length cs + 1)) $ \n -> let (bs, cs') = P.packUptoLenBytes n cs @@ -557,6 +626,7 @@ testSuite = testGroup "Properties" , testGroup "StrictChar8" PropBS8.tests , testGroup "LazyWord8" PropBL.tests , testGroup "LazyChar8" PropBL8.tests + , testGroup "Overflow" overflow_tests , testGroup "Misc" misc_tests , testGroup "IO" io_tests , testGroup "Short" short_tests @@ -577,6 +647,17 @@ io_tests = , testProperty "packAddress " prop_packAddress ] +overflow_tests = + [ testProperty "checkedAdd" prop_checkedAdd + , testProperty "checkedMultiply" prop_checkedMultiply + , testProperty "StrictByteString stimes (basic)" prop_stimesOverflowBasic + , testProperty "StrictByteString stimes (scary)" prop_stimesOverflowScary + , testProperty "StrictByteString stimes (empty)" prop_stimesOverflowEmpty + , testProperty "StrictByteString mconcat" prop_32bitOverflow_Strict_mconcat + , testProperty "LazyByteString toStrict" prop_32bitOverflow_Lazy_toStrict + , testProperty "ShortByteString mconcat" prop_32bitOverflow_Short_mconcat + ] + misc_tests = [ testProperty "packUptoLenBytes" prop_packUptoLenBytes , testProperty "packUptoLenChars" prop_packUptoLenChars