diff --git a/Network.hs b/Network.hs index 4edb763e..6c04bf5a 100644 --- a/Network.hs +++ b/Network.hs @@ -317,8 +317,8 @@ accept (MkSocket _ family _ _ _) = error $ "Sorry, address family " ++ (show family) ++ " is not supported!" --- | Close the socket. All future operations on the socket object will fail. --- The remote end will receive no more data (after queued data is flushed). +-- | Close the socket. Sending data to or receiving data from a closed socket +-- causes undefined behaviour. sClose :: Socket -> IO () sClose = close -- Explicit redefinition because Network.sClose is deperecated, -- hence the re-export would also be marked as such. diff --git a/Network/Socket.hsc b/Network/Socket.hsc index c1d56d72..34f9a435 100644 --- a/Network/Socket.hsc +++ b/Network/Socket.hsc @@ -561,6 +561,10 @@ foreign import ccall unsafe "free" -- -- NOTE: blocking on Windows unless you compile with -threaded (see -- GHC ticket #1129) +-- +-- Sending data to a closed socket causes undefined behaviour. To always get an +-- exception, use the version in "Network.Socket.Safe". + {-# WARNING sendTo "Use sendTo defined in \"Network.Socket.ByteString\"" #-} sendTo :: Socket -- (possibly) bound/connected Socket -> String -- Data to send @@ -640,6 +644,10 @@ recvBufFrom sock@(MkSocket s family _stype _protocol _status) ptr nbytes -- | Send data to the socket. The socket must be connected to a remote -- socket. Returns the number of bytes sent. Applications are -- responsible for ensuring that all data has been sent. +-- +-- Sending data to a closed socket causes undefined behaviour. To always get an +-- exception, use the version in "Network.Socket.Safe". + {-# WARNING send "Use send defined in \"Network.Socket.ByteString\"" #-} send :: Socket -- Bound/Connected Socket -> String -- Data to send @@ -650,6 +658,10 @@ send sock xs = withCStringLen xs $ \(str, len) -> -- | Send data to the socket. The socket must be connected to a remote -- socket. Returns the number of bytes sent. Applications are -- responsible for ensuring that all data has been sent. +-- +-- Sending data to a closed socket causes undefined behaviour. To always get an +-- exception, use the version in "Network.Socket.Safe". + sendBuf :: Socket -- Bound/Connected Socket -> Ptr Word8 -- Pointer to the data to send -> Int -- Length of the buffer @@ -684,6 +696,10 @@ sendBuf sock@(MkSocket s _family _stype _protocol _status) str len = do -- -- For TCP sockets, a zero length return value means the peer has -- closed its half side of the connection. +-- +-- Receiving data from a closed socket causes undefined behaviour. To always get +-- an exception, use the version in "Network.Socket.Safe". + {-# WARNING recv "Use recv defined in \"Network.Socket.ByteString\"" #-} recv :: Socket -> Int -> IO String recv sock l = fst <$> recvLen sock l @@ -707,6 +723,9 @@ recvLen sock nbytes = -- -- For TCP sockets, a zero length return value means the peer has -- closed its half side of the connection. +-- +-- Receiving data from a closed socket causes undefined behaviour. To always get +-- an exception, use "Network.Socket.Safe". recvBuf :: Socket -> Ptr Word8 -> Int -> IO Int recvBuf sock@(MkSocket s _family _stype _protocol _status) ptr nbytes | nbytes <= 0 = ioError (mkInvalidRecvArgError "Network.Socket.recvBuf") @@ -1072,9 +1091,8 @@ shutdown (MkSocket s _ _ _ _) stype = do -- ----------------------------------------------------------------------------- --- | Close the socket. All future operations on the socket object --- will fail. The remote end will receive no more data (after queued --- data is flushed). +-- | Close the socket. Sending data to or receiving data from a closed socket +-- causes undefined behaviour. close :: Socket -> IO () close (MkSocket s _ _ _ socketStatus) = do modifyMVar_ socketStatus $ \ status -> diff --git a/Network/Socket/ByteString.hsc b/Network/Socket/ByteString.hsc index a5250339..5a4ae466 100644 --- a/Network/Socket/ByteString.hsc +++ b/Network/Socket/ByteString.hsc @@ -75,6 +75,10 @@ import Network.Socket.ByteString.MsgHdr (MsgHdr(..)) -- | Send data to the socket. The socket must be connected to a -- remote socket. Returns the number of bytes sent. Applications are -- responsible for ensuring that all data has been sent. +-- +-- Sending data to a closed socket causes undefined behaviour. To always get +-- an exception, use the verion in "Network.Socket.ByteString.Safe". + send :: Socket -- ^ Connected socket -> ByteString -- ^ Data to send -> IO Int -- ^ Number of bytes sent @@ -86,6 +90,10 @@ send sock xs = unsafeUseAsCStringLen xs $ \(str, len) -> -- until either all data has been sent or an error occurs. On error, -- an exception is raised, and there is no way to determine how much -- data, if any, was successfully sent. +-- +-- Sending data to a closed socket causes undefined behaviour. To always get an +-- exception, use the verion in "Network.Socket.ByteString.Safe". + sendAll :: Socket -- ^ Connected socket -> ByteString -- ^ Data to send -> IO () @@ -97,6 +105,10 @@ sendAll sock bs = do -- explicitly, so the socket need not be in a connected state. -- Returns the number of bytes sent. Applications are responsible for -- ensuring that all data has been sent. +-- +-- Sending data to a closed socket causes undefined behaviour. To always get an +-- exception, use the verion in "Network.Socket.ByteString.Safe". + sendTo :: Socket -- ^ Socket -> ByteString -- ^ Data to send -> SockAddr -- ^ Recipient address @@ -110,6 +122,10 @@ sendTo sock xs addr = -- data has been sent or an error occurs. On error, an exception is -- raised, and there is no way to determine how much data, if any, was -- successfully sent. +-- +-- Sending data to a closed socket causes undefined behaviour. To always get an +-- exception, use the verion in "Network.Socket.ByteString.Safe". + sendAllTo :: Socket -- ^ Socket -> ByteString -- ^ Data to send -> SockAddr -- ^ Recipient address @@ -146,6 +162,10 @@ sendAllTo sock xs addr = do -- sent or an error occurs. On error, an exception is raised, and -- there is no way to determine how much data, if any, was -- successfully sent. +-- +-- Sending data to a closed socket causes undefined behaviour. To always get an +-- exception, use the verion in "Network.Socket.ByteString.Safe". + sendMany :: Socket -- ^ Connected socket -> [ByteString] -- ^ Data to send -> IO () @@ -169,6 +189,10 @@ sendMany sock = sendAll sock . B.concat -- continues to send data until either all data has been sent or an -- error occurs. On error, an exception is raised, and there is no -- way to determine how much data, if any, was successfully sent. +-- +-- Sending data to a closed socket causes undefined behaviour. To always get an +-- exception, use the verion in "Network.Socket.ByteString.Safe". + sendManyTo :: Socket -- ^ Socket -> [ByteString] -- ^ Data to send -> SockAddr -- ^ Recipient address @@ -205,6 +229,10 @@ sendManyTo sock cs = sendAllTo sock (B.concat cs) -- -- For TCP sockets, a zero length return value means the peer has -- closed its half side of the connection. +-- +-- Receiving data from a closed socket causes undefined behaviour. To +-- always get an exception, use the verion in "Network.Socket.ByteString.Safe". + recv :: Socket -- ^ Connected socket -> Int -- ^ Maximum number of bytes to receive -> IO ByteString -- ^ Data received @@ -217,6 +245,10 @@ recv sock nbytes -- connected state. Returns @(bytes, address)@ where @bytes@ is a -- 'ByteString' representing the data received and @address@ is a -- 'SockAddr' representing the address of the sending socket. +-- +-- Receiving data from a closed socket causes undefined behaviour. To always get +-- an exception, use the verion in "Network.Socket.ByteString.Safe". + recvFrom :: Socket -- ^ Socket -> Int -- ^ Maximum number of bytes to receive -> IO (ByteString, SockAddr) -- ^ Data received and sender address diff --git a/Network/Socket/ByteString/Lazy.hs b/Network/Socket/ByteString/Lazy.hs index 8f993175..a53ceae8 100644 --- a/Network/Socket/ByteString/Lazy.hs +++ b/Network/Socket/ByteString/Lazy.hs @@ -61,6 +61,9 @@ import Network.Socket.ByteString.Lazy.Posix (send, sendAll) -- more data to be received, the receiving side of the socket is shut -- down. If there is an error and an exception is thrown, the socket -- is not shut down. +-- +-- Receiving data from a closed socket causes undefined behaviour. To always get +-- an exception, use the verion in "Network.Socket.ByteString.Lazy.Safe". getContents :: Socket -- ^ Connected socket -> IO ByteString -- ^ Data received getContents sock = loop where @@ -77,6 +80,9 @@ getContents sock = loop where -- until a message arrives. -- -- If there is no more data to be received, returns an empty 'ByteString'. +-- +-- Receiving data from a closed socket causes undefined behaviour. To always get +-- an exception, use the verion in "Network.Socket.ByteString.Lazy.Safe". recv :: Socket -- ^ Connected socket -> Int64 -- ^ Maximum number of bytes to receive -> IO ByteString -- ^ Data received diff --git a/Network/Socket/ByteString/Lazy/Safe.hs b/Network/Socket/ByteString/Lazy/Safe.hs new file mode 100644 index 00000000..f0ba8e79 --- /dev/null +++ b/Network/Socket/ByteString/Lazy/Safe.hs @@ -0,0 +1,42 @@ +----------------------------------------------------------------------------- +-- | +-- Module : Network.Socket.ByteString.Lazy.Safe +-- Copyright : Echo Nolan 2016 +-- License : BSD-style (see the file libraries/network/LICENSE) +-- +-- Maintainer : libraries@haskell.org +-- Stability : provisional +-- Portability : portable +-- +-- A drop in replacement for "Network.Socket.ByteString.Lazy" that sacrifices +-- some performance for correctness. See "Network.Socket.Safe" for what exactly +-- that means. See "Network.Socket.ByteString.Lazy" for API documentation. +----------------------------------------------------------------------------- + +module Network.Socket.ByteString.Lazy.Safe + ( + send + , sendAll + , getContents + , recv + ) where + +import qualified Network.Socket.ByteString.Lazy as Unsafe +import Network.Socket.Internal +import Network.Socket.Types + +import Prelude hiding (getContents) +import Data.ByteString.Lazy (ByteString) +import Data.Int (Int64) + +send :: Socket -> ByteString -> IO Int64 +send = wrapCheckStatus2 Unsafe.send "Network.Socket.ByteString.Lazy.Safe.send" + +sendAll :: Socket -> ByteString -> IO () +sendAll = wrapCheckStatus2 Unsafe.sendAll "Network.Socket.ByteString.Lazy.Safe.sendAll" + +getContents :: Socket -> IO ByteString +getContents = wrapCheckStatus Unsafe.getContents "Network.Socket.ByteString.Lazy.Safe.getContents" + +recv :: Socket -> Int64 -> IO ByteString +recv = wrapCheckStatus2 Unsafe.recv "Network.Socket.ByteString.Lazy.Safe.recv" diff --git a/Network/Socket/ByteString/Safe.hs b/Network/Socket/ByteString/Safe.hs new file mode 100644 index 00000000..d82c3216 --- /dev/null +++ b/Network/Socket/ByteString/Safe.hs @@ -0,0 +1,56 @@ +----------------------------------------------------------------------------- +-- | +-- Module : Network.Socket.ByteString.Safe +-- Copyright : Echo Nolan 2016 +-- License : BSD-style (see the file libraries/network/LICENSE) +-- +-- Maintainer : libraries@haskell.org +-- Stability : provisional +-- Portability : portable +-- +-- A drop in replacement for "Network.Socket.ByteString" that sacrifices some +-- performance for correctness. See "Network.Socket.Safe" for what exactly that +-- means. See "Network.Socket.ByteString" for API documentation. +----------------------------------------------------------------------------- + +module Network.Socket.ByteString.Safe + ( + send + , sendAll + , sendTo + , sendAllTo + , sendMany + , sendManyTo + , recv + , recvFrom + ) where + +import qualified Network.Socket.ByteString as Unsafe +import Network.Socket.Internal +import Network.Socket.Types + +import Data.ByteString (ByteString) + +send :: Socket -> ByteString -> IO Int +send = wrapCheckStatus2 Unsafe.send "Network.Socket.ByteString.Safe.send" + +sendAll :: Socket -> ByteString -> IO () +sendAll = wrapCheckStatus2 Unsafe.sendAll "Network.Socket.ByteString.Safe.sendAll" + +sendTo :: Socket -> ByteString -> SockAddr -> IO Int +sendTo = wrapCheckStatus3 Unsafe.sendTo "Network.Socket.ByteString.Safe.sendTo" + +sendAllTo :: Socket -> ByteString -> SockAddr -> IO () +sendAllTo = wrapCheckStatus3 Unsafe.sendAllTo "Network.Socket.ByteString.Safe.sendAllTo" + +sendMany :: Socket -> [ByteString] -> IO () +sendMany = wrapCheckStatus2 Unsafe.sendMany "Network.Socket.ByteString.Safe.sendMany" + +sendManyTo :: Socket -> [ByteString] -> SockAddr -> IO () +sendManyTo = wrapCheckStatus3 Unsafe.sendManyTo "Network.Socket.ByteString.Safe.sendManyTo" + +recv :: Socket -> Int -> IO ByteString +recv = wrapCheckStatus2 Unsafe.recv "Network.Socket.ByteString.Safe.recv" + +recvFrom :: Socket -> Int -> IO (ByteString,SockAddr) +recvFrom = wrapCheckStatus2 Unsafe.recvFrom "Network.Socket.ByteString.Safe.recvFrom" diff --git a/Network/Socket/Internal.hsc b/Network/Socket/Internal.hsc index c8bf4f68..bedab69a 100644 --- a/Network/Socket/Internal.hsc +++ b/Network/Socket/Internal.hsc @@ -67,6 +67,12 @@ module Network.Socket.Internal -- * Low-level helpers , zeroMemory + + -- * Helpers for Network.Socket.Safe + , wrapCheckStatus + , wrapCheckStatus2 + , wrapCheckStatus3 + , wrapCheckStatus4 ) where import Foreign.C.Error (throwErrno, throwErrnoIfMinus1Retry, @@ -92,6 +98,9 @@ import Foreign.C.Types ( CChar ) import System.IO.Error ( ioeSetErrorString, mkIOError ) #endif +import Control.Concurrent (withMVar) +import Control.Exception (throwIO) + import Network.Socket.Types -- --------------------------------------------------------------------- @@ -271,3 +280,26 @@ withSocketsInit = unsafePerformIO $ do foreign import ccall unsafe "initWinSock" initWinSock :: IO Int #endif + +wrapCheckStatus :: (Socket -> IO a) -> String -> Socket -> IO a +wrapCheckStatus act fnName sock@(MkSocket _ _ _ _ statusVar) = + withMVar statusVar $ \status -> + case status of + Closed -> throwIO $ userError $ + fnName ++ ": attempted to use a closed socket" + _ -> act sock + +wrapCheckStatus2 :: (Socket -> a -> IO b) -> + String -> Socket -> a -> IO b +wrapCheckStatus2 act fnName sock a = + wrapCheckStatus (\s -> act s a) fnName sock + +wrapCheckStatus3 :: (Socket -> a -> b -> IO c) -> + String -> Socket -> a -> b -> IO c +wrapCheckStatus3 act fnName sock a b = + wrapCheckStatus (\s -> act s a b) fnName sock + +wrapCheckStatus4 :: (Socket -> a -> b -> c -> IO d) -> + String -> Socket -> a -> b -> c -> IO d +wrapCheckStatus4 act fnName sock a b c = + wrapCheckStatus (\s -> act s a b c) fnName sock diff --git a/Network/Socket/Safe.hs b/Network/Socket/Safe.hs new file mode 100644 index 00000000..7d79eb39 --- /dev/null +++ b/Network/Socket/Safe.hs @@ -0,0 +1,78 @@ +----------------------------------------------------------------------------- +-- | +-- Module : Network.Socket.Safe +-- Copyright : Echo Nolan 2016 +-- License : BSD-style (see the file libraries/network/LICENSE) +-- +-- Maintainer : libraries@haskell.org +-- Stability : provisional +-- Portability : portable +-- +-- A drop in replacement for "Network.Socket" that sacrifices some performance +-- for correctness. Specifically, this module's functions check that a socket +-- hasn't been closed before attempting to read or write from it. With the +-- "Network.Socket" API, reading or writing to a socket after closing it causes +-- undefined behavior. In this module it always throws an exception. N.b. this +-- serializes all use of a given socket via an MVar. +-- +-- See "Network.Socket" for API documentation. +----------------------------------------------------------------------------- + +-- We reexport things with attached warnings, but we don't want those warnings +-- here. +{-# OPTIONS_GHC -fno-warn-warnings-deprecations #-} + +module Network.Socket.Safe + ( + module Network.Socket + , send + , sendTo + , recv + , recvFrom + , recvLen + , sendBuf + , recvBuf + , sendBufTo + , recvBufFrom + ) where + +import Network.Socket hiding (send, sendTo, recv, recvFrom, recvLen, sendBuf, + recvBuf, sendBufTo, recvBufFrom) +import qualified Network.Socket as Unsafe + +import Network.Socket.Internal + +import Data.Word (Word8) +import Foreign.Ptr (Ptr) + +{-# WARNING send "Use send defined in \"Network.Socket.ByteString.Safe\"" #-} +send :: Socket -> String -> IO Int +send = wrapCheckStatus2 Unsafe.send "Network.Socket.Safe.send" + +{-# WARNING sendTo "Use sendTo defined in \"Network.Socket.ByteString.Safe\"" #-} +sendTo :: Socket -> String -> SockAddr -> IO Int +sendTo = wrapCheckStatus3 Unsafe.sendTo "Network.Socket.Safe.sendTo" + +{-# WARNING recv "Use recv defined in \"Network.Socket.ByteString.Safe\"" #-} +recv :: Socket -> Int -> IO String +recv = wrapCheckStatus2 Unsafe.recv "Network.Socket.Safe.recv" + +{-# WARNING recvFrom "Use recvFrom defined in \"Network.Socket.ByteString.Safe\"" #-} +recvFrom :: Socket -> Int -> IO (String, Int, SockAddr) +recvFrom = wrapCheckStatus2 Unsafe.recvFrom "Network.Socket.Safe.recvFrom" + +{-# WARNING recvLen "Use recvLen defined in \"Network.Socket.ByteString.Safe\"" #-} +recvLen :: Socket -> Int -> IO (String, Int) +recvLen = wrapCheckStatus2 Unsafe.recvLen "Network.Socket.Safe.recvLen" + +sendBuf :: Socket -> Ptr Word8 -> Int -> IO Int +sendBuf = wrapCheckStatus3 Unsafe.sendBuf "Network.Socket.Safe.sendBuf" + +recvBuf :: Socket -> Ptr Word8 -> Int -> IO Int +recvBuf = wrapCheckStatus3 Unsafe.recvBuf "Network.Socket.Safe.recvBuf" + +sendBufTo :: Socket -> Ptr a -> Int -> SockAddr -> IO Int +sendBufTo = wrapCheckStatus4 Unsafe.sendBufTo "Network.Socket.Safe.sendBufTo" + +recvBufFrom :: Socket -> Ptr a -> Int -> IO (Int, SockAddr) +recvBufFrom = wrapCheckStatus3 Unsafe.recvBufFrom "Network.Socket.Safe.recvBufFrom" diff --git a/network.cabal b/network.cabal index 076a1dff..7d4d0b0d 100644 --- a/network.cabal +++ b/network.cabal @@ -48,8 +48,11 @@ library Network Network.BSD Network.Socket + Network.Socket.Safe Network.Socket.ByteString + Network.Socket.ByteString.Safe Network.Socket.ByteString.Lazy + Network.Socket.ByteString.Lazy.Safe Network.Socket.Internal other-modules: Network.Socket.ByteString.Internal diff --git a/tests/Simple.hs b/tests/Simple.hs index c2aaabb0..aadbd4af 100644 --- a/tests/Simple.hs +++ b/tests/Simple.hs @@ -12,6 +12,7 @@ import qualified Data.ByteString.Char8 as C import Data.Maybe (fromJust) #endif import Network.Socket hiding (recv, recvFrom, send, sendTo) +import qualified Network.Socket.ByteString.Safe as Safe import Network.Socket.ByteString --- To tests for AF_CAN on Linux, you need to bring up a virtual (or real can @@ -30,7 +31,7 @@ import Network.BSD (ifNameToIndex) #endif import Test.Framework (Test, defaultMain, testGroup) import Test.Framework.Providers.HUnit (testCase) -import Test.HUnit (Assertion, (@=?)) +import Test.HUnit (Assertion, (@=?), assertFailure) ------------------------------------------------------------------------ @@ -141,6 +142,29 @@ testUserTimeout = do getSocketOption sock UserTimeout >>= (@=?) 2000 sClose sock +testSafeSend :: Assertion +testSafeSend = tcpTest client server + where + server sock = do + close sock + res :: Either E.IOException Int <- E.try (Safe.send sock testMsg) + case res of + Left ex -> show ex @=? + "user error (Network.Socket.ByteString.Safe.send: attempted to use a closed socket)" + Right _ -> assertFailure "send didn't throw an exception on a closed socket" + client sock = return () + +testSafeRecv :: Assertion +testSafeRecv = tcpTest client server + where + server sock = do + close sock + res :: Either E.IOException S.ByteString <- E.try (Safe.recv sock 1024) + case res of + Left ex -> show ex @=? + "user error (Network.Socket.ByteString.Safe.recv: attempted to use a closed socket)" + Right _ -> assertFailure "recv didn't throw an exception on a closed socket" + client sock = return () {- testGetPeerCred:: Assertion testGetPeerCred = @@ -252,6 +276,8 @@ basicTests = testGroup "Basic socket operations" , testCase "testUserTimeout" testUserTimeout -- , testCase "testGetPeerCred" testGetPeerCred -- , testCase "testGetPeerEid" testGetPeerEid + , testCase "testSafeSend" testSafeSend + , testCase "testSafeRecv" testSafeRecv #if defined(HAVE_LINUX_CAN_H) , testCase "testCanSend" testCanSend #endif