Skip to content

Commit

Permalink
Merge pull request #555 from kazu-yamamoto/timo
Browse files Browse the repository at this point in the history
Socket timeout
  • Loading branch information
kazu-yamamoto authored May 24, 2023
2 parents 176aa09 + 353d8ef commit e73616c
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 91 deletions.
1 change: 1 addition & 0 deletions Network/Socket.hs
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ module Network.Socket
,RecvIPv4TTL,RecvIPv4TOS,RecvIPv4PktInfo
,RecvIPv6HopLimit,RecvIPv6TClass,RecvIPv6PktInfo)
, StructLinger (..)
, SocketTimeout (..)
, isSupportedSocketOption
, whenSupported
, getSocketOption
Expand Down
5 changes: 5 additions & 0 deletions Network/Socket/ByteString/IO.hsc
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,11 @@ sendManyTo s cs addr = do
--
-- For TCP sockets, a zero length return value means the peer has
-- closed its half side of the connection.
--
-- Currently, the 'recv' family is blocked on Windows because a proper
-- IO manager is not implemented. To use with 'System.Timeout.timeout'
-- on Windows, use 'Network.Socket.setSocketOption' with
-- 'Network.Socket.RecvTimeOut' as well.
recv :: Socket -- ^ Connected socket
-> Int -- ^ Maximum number of bytes to receive
-> IO ByteString -- ^ Data received
Expand Down
237 changes: 146 additions & 91 deletions Network/Socket/Options.hsc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ module Network.Socket.Options (
, getSockOpt
, setSockOpt
, StructLinger (..)
, SocketTimeout (..)
) where

import qualified Text.Read as P
Expand All @@ -38,7 +39,9 @@ import Network.Socket.Internal
import Network.Socket.Types
import Network.Socket.ReadShow

-----------------------------------------------------------------------------
#include <sys/time.h>

----------------------------------------------------------------
-- Socket Properties

-- | Socket options for use with 'setSocketOption' and 'getSocketOption'.
Expand All @@ -55,18 +58,75 @@ data SocketOption = SockOpt
#endif
deriving (Eq)

----------------------------------------------------------------

socketOptionBijection :: Bijection SocketOption String
socketOptionBijection =
[ (UnsupportedSocketOption, "UnsupportedSocketOption")
, (Debug, "Debug")
, (ReuseAddr, "ReuseAddr")
, (SoDomain, "SoDomain")
, (Type, "Type")
, (SoProtocol, "SoProtocol")
, (SoError, "SoError")
, (DontRoute, "DontRoute")
, (Broadcast, "Broadcast")
, (SendBuffer, "SendBuffer")
, (RecvBuffer, "RecvBuffer")
, (KeepAlive, "KeepAlive")
, (OOBInline, "OOBInline")
, (Linger, "Linger")
, (ReusePort, "ReusePort")
, (RecvLowWater, "RecvLowWater")
, (SendLowWater, "SendLowWater")
, (RecvTimeOut, "RecvTimeOut")
, (SendTimeOut, "SendTimeOut")
, (UseLoopBack, "UseLoopBack")
, (MaxSegment, "MaxSegment")
, (NoDelay, "NoDelay")
, (UserTimeout, "UserTimeout")
, (Cork, "Cork")
, (TimeToLive, "TimeToLive")
, (RecvIPv4TTL, "RecvIPv4TTL")
, (RecvIPv4TOS, "RecvIPv4TOS")
, (RecvIPv4PktInfo, "RecvIPv4PktInfo")
, (IPv6Only, "IPv6Only")
, (RecvIPv6HopLimit, "RecvIPv6HopLimit")
, (RecvIPv6TClass, "RecvIPv6TClass")
, (RecvIPv6PktInfo, "RecvIPv6PktInfo")
]

instance Show SocketOption where
showsPrec = bijectiveShow socketOptionBijection def
where
defname = "SockOpt"
unwrap = \(CustomSockOpt nm) -> nm
def = defShow defname unwrap showIntInt


instance Read SocketOption where
readPrec = bijectiveRead socketOptionBijection def
where
defname = "SockOpt"
def = defRead defname CustomSockOpt readIntInt

----------------------------------------------------------------

pattern UnsupportedSocketOption :: SocketOption
pattern UnsupportedSocketOption = SockOpt (-1) (-1)

-- | Does the 'SocketOption' exist on this system?
isSupportedSocketOption :: SocketOption -> Bool
isSupportedSocketOption opt = opt /= SockOpt (-1) (-1)

-- | Get the 'SocketType' of an active socket.
--
-- Since: 3.0.1.0
getSocketType :: Socket -> IO SocketType
getSocketType s = unpackSocketType <$> getSockOpt s Type
-- | Execute the given action only when the specified socket option is
-- supported. Any return value is ignored.
whenSupported :: SocketOption -> IO a -> IO ()
whenSupported s action
| isSupportedSocketOption s = action >> return ()
| otherwise = return ()

pattern UnsupportedSocketOption :: SocketOption
pattern UnsupportedSocketOption = SockOpt (-1) (-1)
----------------------------------------------------------------

#ifdef SOL_SOCKET
-- | SO_ACCEPTCONN, read-only
Expand Down Expand Up @@ -192,14 +252,14 @@ pattern SendLowWater = SockOpt (#const SOL_SOCKET) (#const SO_SNDLOWAT)
#else
pattern SendLowWater = SockOpt (-1) (-1)
#endif
-- | SO_RCVTIMEO: this does not work at this moment.
-- | SO_RCVTIMEO: timeout in microseconds
pattern RecvTimeOut :: SocketOption
#ifdef SO_RCVTIMEO
pattern RecvTimeOut = SockOpt (#const SOL_SOCKET) (#const SO_RCVTIMEO)
#else
pattern RecvTimeOut = SockOpt (-1) (-1)
#endif
-- | SO_SNDTIMEO: this does not work at this moment.
-- | SO_SNDTIMEO: timeout in microseconds
pattern SendTimeOut :: SocketOption
#ifdef SO_SNDTIMEO
pattern SendTimeOut = SockOpt (#const SOL_SOCKET) (#const SO_SNDTIMEO)
Expand Down Expand Up @@ -317,41 +377,7 @@ pattern CustomSockOpt xy <- ((\(SockOpt x y) -> (x, y)) -> xy)
where
CustomSockOpt (x, y) = SockOpt x y

#if __GLASGOW_HASKELL__ >= 806
{-# COMPLETE CustomSockOpt #-}
#endif
#ifdef SO_LINGER
-- | Low level 'SO_LINBER' option value, which can be used with 'setSockOpt'.
--
data StructLinger = StructLinger {
-- | Set the linger option on.
sl_onoff :: CInt,

-- | Linger timeout.
sl_linger :: CInt
}
deriving (Eq, Ord, Show)

instance Storable StructLinger where
sizeOf _ = (#const sizeof(struct linger))
alignment _ = alignment (0 :: CInt)

peek p = do
onoff <- (#peek struct linger, l_onoff) p
linger <- (#peek struct linger, l_linger) p
return $ StructLinger onoff linger

poke p (StructLinger onoff linger) = do
(#poke struct linger, l_onoff) p onoff
(#poke struct linger, l_linger) p linger
#endif

-- | Execute the given action only when the specified socket option is
-- supported. Any return value is ignored.
whenSupported :: SocketOption -> IO a -> IO ()
whenSupported s action
| isSupportedSocketOption s = action >> return ()
| otherwise = return ()
----------------------------------------------------------------

-- | Set a socket option that expects an 'Int' value.
setSocketOption :: Socket
Expand All @@ -363,6 +389,8 @@ setSocketOption s so@Linger v = do
let arg = if v == 0 then StructLinger 0 0 else StructLinger 1 (fromIntegral v)
setSockOpt s so arg
#endif
setSocketOption s so@RecvTimeOut v = setSockOpt s so $ SocketTimeout $ fromIntegral v
setSocketOption s so@SendTimeOut v = setSockOpt s so $ SocketTimeout $ fromIntegral v
setSocketOption s sa v = setSockOpt s sa (fromIntegral v :: CInt)

-- | Set a socket option.
Expand All @@ -378,6 +406,8 @@ setSockOpt s (SockOpt level opt) v = do
throwSocketErrorIfMinus1_ "Network.Socket.setSockOpt" $
c_setsockopt fd level opt ptr sz

----------------------------------------------------------------

-- | Get a socket option that gives an 'Int' value.
getSocketOption :: Socket
-> SocketOption -- Option Name
Expand All @@ -387,6 +417,12 @@ getSocketOption s so@Linger = do
StructLinger onoff linger <- getSockOpt s so
return $ fromIntegral $ if onoff == 0 then 0 else linger
#endif
getSocketOption s so@RecvTimeOut = do
SocketTimeout to <- getSockOpt s so
return $ fromIntegral to
getSocketOption s so@SendTimeOut = do
SocketTimeout to <- getSockOpt s so
return $ fromIntegral to
getSocketOption s so = do
n :: CInt <- getSockOpt s so
return $ fromIntegral n
Expand All @@ -404,56 +440,75 @@ getSockOpt s (SockOpt level opt) = do
c_getsockopt fd level opt ptr ptr_sz
peek ptr

----------------------------------------------------------------

socketOptionBijection :: Bijection SocketOption String
socketOptionBijection =
[ (UnsupportedSocketOption, "UnsupportedSocketOption")
, (Debug, "Debug")
, (ReuseAddr, "ReuseAddr")
, (SoDomain, "SoDomain")
, (Type, "Type")
, (SoProtocol, "SoProtocol")
, (SoError, "SoError")
, (DontRoute, "DontRoute")
, (Broadcast, "Broadcast")
, (SendBuffer, "SendBuffer")
, (RecvBuffer, "RecvBuffer")
, (KeepAlive, "KeepAlive")
, (OOBInline, "OOBInline")
, (Linger, "Linger")
, (ReusePort, "ReusePort")
, (RecvLowWater, "RecvLowWater")
, (SendLowWater, "SendLowWater")
, (RecvTimeOut, "RecvTimeOut")
, (SendTimeOut, "SendTimeOut")
, (UseLoopBack, "UseLoopBack")
, (MaxSegment, "MaxSegment")
, (NoDelay, "NoDelay")
, (UserTimeout, "UserTimeout")
, (Cork, "Cork")
, (TimeToLive, "TimeToLive")
, (RecvIPv4TTL, "RecvIPv4TTL")
, (RecvIPv4TOS, "RecvIPv4TOS")
, (RecvIPv4PktInfo, "RecvIPv4PktInfo")
, (IPv6Only, "IPv6Only")
, (RecvIPv6HopLimit, "RecvIPv6HopLimit")
, (RecvIPv6TClass, "RecvIPv6TClass")
, (RecvIPv6PktInfo, "RecvIPv6PktInfo")
]
-- | Get the 'SocketType' of an active socket.
--
-- Since: 3.0.1.0
getSocketType :: Socket -> IO SocketType
getSocketType s = unpackSocketType <$> getSockOpt s Type

instance Show SocketOption where
showsPrec = bijectiveShow socketOptionBijection def
where
defname = "SockOpt"
unwrap = \(CustomSockOpt nm) -> nm
def = defShow defname unwrap showIntInt
----------------------------------------------------------------

#if __GLASGOW_HASKELL__ >= 806
{-# COMPLETE CustomSockOpt #-}
#endif
#ifdef SO_LINGER
-- | Low level 'SO_LINBER' option value, which can be used with 'setSockOpt'.
--
data StructLinger = StructLinger {
-- | Set the linger option on.
sl_onoff :: CInt,

instance Read SocketOption where
readPrec = bijectiveRead socketOptionBijection def
where
defname = "SockOpt"
def = defRead defname CustomSockOpt readIntInt
-- | Linger timeout.
sl_linger :: CInt
}
deriving (Eq, Ord, Show)

instance Storable StructLinger where
sizeOf _ = (#const sizeof(struct linger))
alignment _ = alignment (0 :: CInt)

peek p = do
onoff <- (#peek struct linger, l_onoff) p
linger <- (#peek struct linger, l_linger) p
return $ StructLinger onoff linger

poke p (StructLinger onoff linger) = do
(#poke struct linger, l_onoff) p onoff
(#poke struct linger, l_linger) p linger
#endif

----------------------------------------------------------------

-- | Timeout in microseconds.
-- This will be converted into struct timeval on Unix and
-- DWORD (as milliseconds) on Windows.
newtype SocketTimeout = SocketTimeout Word32 deriving (Eq, Ord, Show)

#if defined(mingw32_HOST_OS)
instance Storable SocketTimeout where
sizeOf (SocketTimeout to) = sizeOf to -- DWORD as milliseconds
alignment _ = 0
peek ptr = do
to <- peek (castPtr ptr)
return $ SocketTimeout (to * 1000)
poke ptr (SocketTimeout to) = poke (castPtr ptr) (to `div` 1000)
#else
instance Storable SocketTimeout where
sizeOf _ = (#size struct timeval)
alignment _ = (#const offsetof(struct {char x__; struct timeval (y__); }, y__))
peek ptr = do
sec <- (#peek struct timeval, tv_sec) ptr
usec <- (#peek struct timeval, tv_usec) ptr
return $ SocketTimeout (sec * 1000000 + usec)
poke ptr (SocketTimeout to) = do
let (sec, usec) = to `divMod` 1000000
(#poke struct timeval, tv_sec) ptr sec
(#poke struct timeval, tv_usec) ptr usec
#endif

----------------------------------------------------------------

foreign import CALLCONV unsafe "getsockopt"
c_getsockopt :: CInt -> CInt -> CInt -> Ptr a -> Ptr CInt -> IO CInt
Expand Down

0 comments on commit e73616c

Please sign in to comment.