From 1ae887f472a7675b186b24439f86f61366e6ea88 Mon Sep 17 00:00:00 2001 From: Kazu Yamamoto Date: Fri, 29 Nov 2019 20:59:07 +0900 Subject: [PATCH 01/48] defining RECV options. --- Network/Socket/Options.hsc | 40 ++++++++++++++++++++++++++++++++++++++ include/HsNet.h | 1 + 2 files changed, 41 insertions(+) diff --git a/Network/Socket/Options.hsc b/Network/Socket/Options.hsc index 1d41b09f..82217323 100644 --- a/Network/Socket/Options.hsc +++ b/Network/Socket/Options.hsc @@ -52,6 +52,12 @@ data SocketOption | UseLoopBack -- ^ SO_USELOOPBACK | UserTimeout -- ^ TCP_USER_TIMEOUT | IPv6Only -- ^ IPV6_V6ONLY: don't use this on OpenBSD. + | RecvIPv4TTL -- ^ Receiving IPv4 TTL. + | RecvIPv4TOS -- ^ Receiving IPv4 TOS. + | RecvIPv4PktInfo -- ^ Receiving IP_PKTINFO (struct in_pktinfo). + | RecvIPv6HopLimit -- ^ Receiving IPv6 hop limit. + | RecvIPv6TClass -- ^ Receiving IPv6 traffic class. + | RecvIPv6PktInfo -- ^ Receiving IPV6_PKTINFO (struct in6_pktinfo). | CustomSockOpt (CInt, CInt) deriving (Show, Typeable) @@ -158,6 +164,40 @@ packSocketOption so = #if HAVE_DECL_IPV6_V6ONLY Just IPv6Only -> Just ((#const IPPROTO_IPV6), (#const IPV6_V6ONLY)) #endif +#endif // HAVE_DECL_IPPROTO_IPV6 +#if HAVE_DECL_IPPROTO_IP +#ifdef IP_RECVTTL + Just RecvIPv4TTL -> Just ((#const IPPROTO_IP), (#const IP_RECVTTL)) +#endif +#endif // HAVE_DECL_IPPROTO_IP +#if HAVE_DECL_IPPROTO_IP +#ifdef IP_RECVTOS + Just RecvIPv4TOS -> Just ((#const IPPROTO_IP), (#const IP_RECVTOS)) +#endif +#endif // HAVE_DECL_IPPROTO_IP +#if HAVE_DECL_IPPROTO_IP +#if defined(IP_RECVPKTINFO) + Just RecvIPv4PktInfo -> Just ((#const IPPROTO_IP), (#const IP_RECVPKTINFO)) +#elif defined(IP_PKTINFO) + Just RecvIPv4PktInfo -> Just ((#const IPPROTO_IP), (#const IP_PKTINFO)) +#endif +#endif // HAVE_DECL_IPPROTO_IP +#if HAVE_DECL_IPPROTO_IPV6 +#ifdef IPV6_RECVHOPLIMIT + Just RecvIPv6HopLimit -> Just ((#const IPPROTO_IPV6), (#const IPV6_RECVHOPLIMIT)) +#endif +#endif // HAVE_DECL_IPPROTO_IPV6 +#if HAVE_DECL_IPPROTO_IPV6 +#ifdef IPV6_RECVTCLASS + Just RecvIPv6TClass -> Just ((#const IPPROTO_IPV6), (#const IPV6_RECVTCLASS)) +#endif +#endif // HAVE_DECL_IPPROTO_IPV6 +#if HAVE_DECL_IPPROTO_IPV6 +#ifdef IPV6_RECVPKTINFO + Just RecvIPv6PktInfo -> Just ((#const IPPROTO_IPV6), (#const IPV6_RECVPKTINFO)) +#elif defined(IPV6_PKTINFO) + Just RecvIPv6PktInfo -> Just ((#const IPPROTO_IPV6), (#const IPV6_PKTINFO)) +#endif #endif // HAVE_DECL_IPPROTO_IPV6 Just (CustomSockOpt opt) -> Just opt _ -> Nothing diff --git a/include/HsNet.h b/include/HsNet.h index 73fb3ba3..33c3954c 100644 --- a/include/HsNet.h +++ b/include/HsNet.h @@ -20,6 +20,7 @@ #endif #define _GNU_SOURCE 1 /* for struct ucred on Linux */ +#define __APPLE_USE_RFC_3542 1 /* for IPV6_RECVPKTINFO */ #ifdef _WIN32 # include From 1fb594e7b476c7c4cb57b947120caf0f975940ff Mon Sep 17 00:00:00 2001 From: Kazu Yamamoto Date: Fri, 29 Nov 2019 23:18:59 +0900 Subject: [PATCH 02/48] extending MsgHdr. --- Network/Socket/ByteString/IO.hsc | 12 +++++++++--- Network/Socket/ByteString/MsgHdr.hsc | 11 ++++++++++- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/Network/Socket/ByteString/IO.hsc b/Network/Socket/ByteString/IO.hsc index a5a3bd7e..05cf9f87 100644 --- a/Network/Socket/ByteString/IO.hsc +++ b/Network/Socket/ByteString/IO.hsc @@ -158,9 +158,15 @@ sendManyTo s cs addr = do sendManyToInner = withSockAddr addr $ \addrPtr addrSize -> withIOVec cs $ \(iovsPtr, iovsLen) -> do - let msgHdr = MsgHdr - addrPtr (fromIntegral addrSize) - iovsPtr (fromIntegral iovsLen) + let msgHdr = MsgHdr { + msgName = addrPtr + , msgNameLen = fromIntegral addrSize + , msgIov = iovsPtr + , msgIovLen = fromIntegral iovsLen + , msgCtrl = nullPtr + , msgCtrlLen = 0 + , msgFlags = 0 + } withFdSocket s $ \fd -> with msgHdr $ \msgHdrPtr -> throwSocketErrorWaitWrite s "Network.Socket.ByteString.sendManyTo" $ diff --git a/Network/Socket/ByteString/MsgHdr.hsc b/Network/Socket/ByteString/MsgHdr.hsc index 6ccd9d24..ad42c226 100644 --- a/Network/Socket/ByteString/MsgHdr.hsc +++ b/Network/Socket/ByteString/MsgHdr.hsc @@ -21,6 +21,9 @@ data MsgHdr = MsgHdr , msgNameLen :: !CUInt , msgIov :: !(Ptr IOVec) , msgIovLen :: !CSize + , msgCtrl :: !(Ptr Word8) + , msgCtrlLen :: !CInt + , msgFlags :: !CInt } instance Storable MsgHdr where @@ -32,7 +35,10 @@ instance Storable MsgHdr where nameLen <- (#peek struct msghdr, msg_namelen) p iov <- (#peek struct msghdr, msg_iov) p iovLen <- (#peek struct msghdr, msg_iovlen) p - return $ MsgHdr name nameLen iov iovLen + ctrl <- (#peek struct msghdr, msg_control) p + ctrlLen <- (#peek struct msghdr, msg_controllen) p + flags <- (#peek struct msghdr, msg_flags) p + return $ MsgHdr name nameLen iov iovLen ctrl ctrlLen flags poke p mh = do -- We need to zero the msg_control, msg_controllen, and msg_flags @@ -43,3 +49,6 @@ instance Storable MsgHdr where (#poke struct msghdr, msg_namelen) p (msgNameLen mh) (#poke struct msghdr, msg_iov) p (msgIov mh) (#poke struct msghdr, msg_iovlen) p (msgIovLen mh) + (#poke struct msghdr, msg_control) p (msgCtrl mh) + (#poke struct msghdr, msg_controllen) p (msgCtrlLen mh) + (#poke struct msghdr, msg_flags) p (msgFlags mh) From 3bca80bfd0fb277a4f9711f55fe497b6a89494cf Mon Sep 17 00:00:00 2001 From: Kazu Yamamoto Date: Fri, 29 Nov 2019 23:38:57 +0900 Subject: [PATCH 03/48] implementing MsgFlag. --- Network/Socket/ByteString/Flag.hsc | 59 ++++++++++++++++++++++++++++++ network.cabal | 1 + 2 files changed, 60 insertions(+) create mode 100644 Network/Socket/ByteString/Flag.hsc diff --git a/Network/Socket/ByteString/Flag.hsc b/Network/Socket/ByteString/Flag.hsc new file mode 100644 index 00000000..175f1a60 --- /dev/null +++ b/Network/Socket/ByteString/Flag.hsc @@ -0,0 +1,59 @@ +#include "HsNet.h" + +module Network.Socket.ByteString.Flag where + +import Network.Socket.Imports +import Network.Socket.Info + +-- | Message flags. +data MsgFlag = + MSG_OOB -- ^ Send or receive OOB(out-of-bound) data. + | MSG_DONTROUTE -- ^ Bypass routing table lookup. + | MSG_PEEK -- ^ Peek at incoming message without removing it from the queue. + | MSG_EOR -- ^ End of record. + | MSG_TRUNC -- ^ Received data is truncated. More data exist. + | MSG_CTRUNC -- ^ Received control message is truncated. More control message exist. + | MSG_WAITALL -- ^ Wait until the requested number of bytes have been read. + deriving (Eq, Show) + +msgFlagMapping :: [(MsgFlag, CInt)] +msgFlagMapping = [ +#ifdef MSG_OOB + (MSG_OOB, #const MSG_OOB) +#else + (MSG_OOB, 0) +#endif +#ifdef MSG_DONTROUTE + , (MSG_DONTROUTE, #const MSG_DONTROUTE) +#else + , (MSG_DONTROUTE, 0) +#endif +#ifdef MSG_PEEK + , (MSG_PEEK, #const MSG_PEEK) +#else + , (MSG_PEEK, 0) +#endif +#ifdef MSG_EOR + , (MSG_EOR, #const MSG_EOR) +#else + , (MSG_EOR, 0) +#endif +#ifdef MSG_TRUNC + , (MSG_TRUNC, #const MSG_TRUNC) +#else + , (MSG_TRUNC, 0) +#endif +#ifdef MSG_CTRUNC + , (MSG_CTRUNC, #const MSG_CTRUNC) +#else + , (MSG_CTRUNC, 0) +#endif +#ifdef MSG_WAITALL + , (MSG_WAITALL, #const MSG_WAITALL) +#else + , (MSG_WAITALL, 0) +#endif + ] + +msgFlagImplemented :: MsgFlag -> Bool +msgFlagImplemented f = packBits msgFlagMapping [f] /= 0 diff --git a/network.cabal b/network.cabal index 4d7afe36..9481fedc 100644 --- a/network.cabal +++ b/network.cabal @@ -96,6 +96,7 @@ library -- Add some platform specific stuff if !os(windows) other-modules: + Network.Socket.ByteString.Flag Network.Socket.ByteString.IOVec Network.Socket.ByteString.Lazy.Posix Network.Socket.ByteString.MsgHdr From 5e79cbf706d63ab059a02510ed67e7fc6e4cbb3f Mon Sep 17 00:00:00 2001 From: Kazu Yamamoto Date: Wed, 4 Dec 2019 10:07:58 +0900 Subject: [PATCH 04/48] implementing sendMsg and recvMsg. --- Network/Socket/ByteString.hs | 5 ++ Network/Socket/ByteString/IO.hsc | 85 +++++++++++++++++++++++++++ Network/Socket/ByteString/Internal.hs | 4 ++ 3 files changed, 94 insertions(+) diff --git a/Network/Socket/ByteString.hs b/Network/Socket/ByteString.hs index 1adefe61..8f0a7872 100644 --- a/Network/Socket/ByteString.hs +++ b/Network/Socket/ByteString.hs @@ -33,6 +33,11 @@ module Network.Socket.ByteString -- * Receive data from a socket , recv , recvFrom + + -- * Advanced send and recv + , sendMsg + , recvMsg + , MsgFlag(..) ) where import Data.ByteString (ByteString) diff --git a/Network/Socket/ByteString/IO.hsc b/Network/Socket/ByteString/IO.hsc index 05cf9f87..1382b077 100644 --- a/Network/Socket/ByteString/IO.hsc +++ b/Network/Socket/ByteString/IO.hsc @@ -29,6 +29,11 @@ module Network.Socket.ByteString.IO , recv , recvFrom , waitWhen0 + + -- * Advanced send and recv + , sendMsg + , recvMsg + , MsgFlag(..) ) where import Control.Concurrent (threadWaitWrite, rtsSupportsBoundThreads) @@ -44,12 +49,15 @@ import Network.Socket.Imports import Network.Socket.Types #if !defined(mingw32_HOST_OS) +import Data.ByteString.Internal (create, ByteString(..)) import Foreign.Marshal.Array (allocaArray) import Foreign.Marshal.Utils (with) import Network.Socket.Internal +import Network.Socket.Info (packBits, unpackBits) import Network.Socket.ByteString.IOVec (IOVec(..)) import Network.Socket.ByteString.MsgHdr (MsgHdr(..)) +import Network.Socket.ByteString.Flag #endif -- ---------------------------------------------------------------------------- @@ -243,3 +251,80 @@ withIOVec cs f = unsafeUseAsCStringLen s $ \(sPtr, sLen) -> poke ptr $ IOVec sPtr (fromIntegral sLen) #endif + +-- | Send data from the socket using sendmsg(2). +sendMsg :: Socket -- ^ Socket + -> SockAddr -- ^ Destination address + -> [ByteString] -- ^ Data to be sent + -> [MsgFlag] -- ^ Message flags + -> IO Int -- ^ The length actually sent +sendMsg _ _ [] _ = return 0 +sendMsg s addr bss flags = do + sz <- withSockAddr addr $ \addrPtr addrSize -> + withIOVec bss $ \(iovsPtr, iovsLen) -> do + let msgHdr = MsgHdr { + msgName = addrPtr + , msgNameLen = fromIntegral addrSize + , msgIov = iovsPtr + , msgIovLen = fromIntegral iovsLen + , msgCtrl = nullPtr + , msgCtrlLen = 0 + , msgFlags = 0 + } + cflags = packBits msgFlagMapping flags + withFdSocket s $ \fd -> + with msgHdr $ \msgHdrPtr -> + throwSocketErrorWaitWrite s "Network.Socket.ByteString.sendMsg" $ + c_sendmsg fd msgHdrPtr cflags + return $ fromIntegral sz + +-- | Receive data from the socket using recvmsg(2). +-- The receive buffers are created according to the second argument. +-- If the length of received data is less than the total of +-- the second argument, the buffers are truncated properly. +-- So, only the received data can be seen. +recvMsg :: Socket -- ^ Socket + -> [Int] -- ^ a list of length of data to be received + -- If the total length is not large enough, + -- 'MSG_TRUNC' is returned + -> [MsgFlag] -- ^ Message flags + -> IO (SockAddr, [ByteString], [MsgFlag]) -- ^ Source address, received data and received message flags +recvMsg _ [] _ = ioError (mkInvalidRecvArgError "Network.Socket.ByteString.recvMsg") +recvMsg s sizs flags = do + bss <- mapM newBS sizs + withNewSocketAddress $ \addrPtr addrSize -> + withIOVec bss $ \(iovsPtr, iovsLen) -> do + let msgHdr = MsgHdr { + msgName = addrPtr + , msgNameLen = fromIntegral addrSize + , msgIov = iovsPtr + , msgIovLen = fromIntegral iovsLen + , msgCtrl = nullPtr + , msgCtrlLen = 0 + , msgFlags = 0 + } + cflags = packBits msgFlagMapping flags + withFdSocket s $ \fd -> do + with msgHdr $ \msgHdrPtr -> do + len <- fromIntegral <$> throwSocketErrorWaitRead s "Network.Socket.ByteString.recvmg" (c_recvmsg fd msgHdrPtr cflags) + let total = sum sizs + let bss' = case len `compare` total of + EQ -> bss + LT -> trunc bss len + GT -> error "recvMsg" -- never reach + sockaddr <- peekSocketAddress addrPtr + hdr <- peek msgHdrPtr + let flags' = unpackBits msgFlagMapping $ msgFlags hdr + return (sockaddr, bss', flags') + +newBS :: Int -> IO ByteString +newBS n = create n $ \ptr -> zeroMemory ptr (fromIntegral n) + +trunc :: [ByteString] -> Int -> [ByteString] +trunc bss0 siz0 = loop bss0 siz0 id + where + -- off is always 0 + loop (bs@(PS buf off len):bss) siz build + | siz >= len = loop bss (siz - len) (build . (bs :)) + | otherwise = build [PS buf off siz] + loop _ _ build = build [] diff --git a/Network/Socket/ByteString/Internal.hs b/Network/Socket/ByteString/Internal.hs index 7799f60d..3d6aea48 100644 --- a/Network/Socket/ByteString/Internal.hs +++ b/Network/Socket/ByteString/Internal.hs @@ -15,6 +15,7 @@ module Network.Socket.ByteString.Internal #if !defined(mingw32_HOST_OS) , c_writev , c_sendmsg + , c_recvmsg #endif ) where @@ -40,4 +41,7 @@ foreign import ccall unsafe "writev" foreign import ccall unsafe "sendmsg" c_sendmsg :: CInt -> Ptr MsgHdr -> CInt -> IO CSsize + +foreign import ccall unsafe "recvmsg" + c_recvmsg :: CInt -> Ptr MsgHdr -> CInt -> IO CSsize #endif From 69109a529b2d34f4218a0fba9811ae3b4343638c Mon Sep 17 00:00:00 2001 From: Kazu Yamamoto Date: Mon, 2 Dec 2019 16:24:32 +0900 Subject: [PATCH 05/48] implementing Cmsg. --- Network/Socket/ByteString.hs | 1 + Network/Socket/ByteString/Cmsg.hsc | 103 +++++++++++++++++++++++++++++ Network/Socket/ByteString/IO.hsc | 53 +++++++++------ cbits/cmsg.c | 22 ++++++ include/HsNet.h | 15 +++++ network.cabal | 5 +- 6 files changed, 175 insertions(+), 24 deletions(-) create mode 100644 Network/Socket/ByteString/Cmsg.hsc create mode 100644 cbits/cmsg.c diff --git a/Network/Socket/ByteString.hs b/Network/Socket/ByteString.hs index 8f0a7872..d723b226 100644 --- a/Network/Socket/ByteString.hs +++ b/Network/Socket/ByteString.hs @@ -38,6 +38,7 @@ module Network.Socket.ByteString , sendMsg , recvMsg , MsgFlag(..) + , Cmsg(..) ) where import Data.ByteString (ByteString) diff --git a/Network/Socket/ByteString/Cmsg.hsc b/Network/Socket/ByteString/Cmsg.hsc new file mode 100644 index 00000000..0cffc94e --- /dev/null +++ b/Network/Socket/ByteString/Cmsg.hsc @@ -0,0 +1,103 @@ +{-# OPTIONS_GHC -funbox-strict-fields #-} + +#include "HsNet.h" + +module Network.Socket.ByteString.Cmsg ( + Cmsg(..) + , withCmsgs + , parseCmsgs + ) where + +#include +#include + +import Foreign.Marshal.Alloc (allocaBytes) +import Foreign.ForeignPtr +import qualified Data.ByteString as B +import Data.ByteString.Internal + +import Network.Socket.ByteString.MsgHdr +import Network.Socket.Imports +import Network.Socket.Types + +-- | Control message including a pair of level and type. +data Cmsg = Cmsg { + cmsgLevelType :: (CInt,CInt) + , cmsgBody :: ByteString + } deriving (Eq, Show) + +data CmsgHdr = CmsgHdr CInt CInt CInt deriving (Eq, Show) + +instance Storable CmsgHdr where + sizeOf _ = (#size struct cmsghdr) + alignment _ = alignment (undefined :: CInt) + + peek p = do + len <- (#peek struct cmsghdr, cmsg_len) p + lvl <- (#peek struct cmsghdr, cmsg_level) p + typ <- (#peek struct cmsghdr, cmsg_type) p + return $ CmsgHdr len lvl typ + + poke p (CmsgHdr len lvl typ) = do + zeroMemory p (#size struct cmsghdr) + (#poke struct cmsghdr, cmsg_len) p len + (#poke struct cmsghdr, cmsg_level) p lvl + (#poke struct cmsghdr, cmsg_type) p typ + +withCmsgs :: [Cmsg] -> (Ptr CmsgHdr -> Int -> IO a) -> IO a +withCmsgs cmsgs0 action + | total == 0 = action nullPtr 0 + | otherwise = allocaBytes total $ \ctrlPtr -> do + loop ctrlPtr cmsgs0 spaces + action ctrlPtr total + where + loop ctrlPtr (cmsg:cmsgs) (s:ss) = do + encodeCmsg ctrlPtr cmsg + let nextPtr = ctrlPtr `plusPtr` s + loop nextPtr cmsgs ss + loop _ _ _ = return () + cmsg_space = fromIntegral . c_cmsg_space . fromIntegral + spaces = map (cmsg_space . B.length . cmsgBody) cmsgs0 + total = sum spaces + +encodeCmsg :: Ptr CmsgHdr -> Cmsg -> IO () +encodeCmsg ctrlPtr (Cmsg (lvl,typ) (PS fptr off len)) = do + poke ctrlPtr $ CmsgHdr (c_cmsg_len (fromIntegral len)) lvl typ + withForeignPtr fptr $ \src0 -> do + let src = src0 `plusPtr` off + dst <- c_cmsg_data ctrlPtr + memcpy dst src len + +parseCmsgs :: Ptr MsgHdr -> IO [Cmsg] +parseCmsgs msgptr = do + ptr <- c_cmsg_firsthdr msgptr + loop ptr id + where + loop ptr build + | ptr == nullPtr = return $ build [] + | otherwise = do + cmsg <- decodeCmsg ptr + nextPtr <- c_cmsg_nxthdr msgptr ptr + loop nextPtr (build . (cmsg :)) + +decodeCmsg :: Ptr CmsgHdr -> IO Cmsg +decodeCmsg ptr = do + CmsgHdr len lvl typ <- peek ptr + src <- c_cmsg_data ptr + let siz = fromIntegral len - (src `minusPtr` ptr) + Cmsg (lvl,typ) <$> create (fromIntegral siz) (\dst -> memcpy dst src siz) + +foreign import ccall unsafe "cmsg_firsthdr" + c_cmsg_firsthdr :: Ptr MsgHdr -> IO (Ptr CmsgHdr) + +foreign import ccall unsafe "cmsg_nxthdr" + c_cmsg_nxthdr :: Ptr MsgHdr -> Ptr CmsgHdr -> IO (Ptr CmsgHdr) + +foreign import ccall unsafe "cmsg_data" + c_cmsg_data :: Ptr CmsgHdr -> IO (Ptr Word8) + +foreign import ccall unsafe "cmsg_space" + c_cmsg_space :: CInt -> CInt + +foreign import ccall unsafe "cmsg_len" + c_cmsg_len :: CInt -> CInt diff --git a/Network/Socket/ByteString/IO.hsc b/Network/Socket/ByteString/IO.hsc index 1382b077..fa45e637 100644 --- a/Network/Socket/ByteString/IO.hsc +++ b/Network/Socket/ByteString/IO.hsc @@ -34,6 +34,7 @@ module Network.Socket.ByteString.IO , sendMsg , recvMsg , MsgFlag(..) + , Cmsg(..) ) where import Control.Concurrent (threadWaitWrite, rtsSupportsBoundThreads) @@ -58,6 +59,7 @@ import Network.Socket.Info (packBits, unpackBits) import Network.Socket.ByteString.IOVec (IOVec(..)) import Network.Socket.ByteString.MsgHdr (MsgHdr(..)) import Network.Socket.ByteString.Flag +import Network.Socket.ByteString.Cmsg #endif -- ---------------------------------------------------------------------------- @@ -256,26 +258,28 @@ withIOVec cs f = sendMsg :: Socket -- ^ Socket -> SockAddr -- ^ Destination address -> [ByteString] -- ^ Data to be sent + -> [Cmsg] -- ^ Control messages -> [MsgFlag] -- ^ Message flags -> IO Int -- ^ The length actually sent -sendMsg _ _ [] _ = return 0 -sendMsg s addr bss flags = do +sendMsg _ _ [] _ _ = return 0 +sendMsg s addr bss cmsgs flags = do sz <- withSockAddr addr $ \addrPtr addrSize -> withIOVec bss $ \(iovsPtr, iovsLen) -> do - let msgHdr = MsgHdr { - msgName = addrPtr - , msgNameLen = fromIntegral addrSize - , msgIov = iovsPtr - , msgIovLen = fromIntegral iovsLen - , msgCtrl = nullPtr - , msgCtrlLen = 0 - , msgFlags = 0 - } + withCmsgs cmsgs $ \ctrlPtr ctrlLen -> do + let msgHdr = MsgHdr { + msgName = addrPtr + , msgNameLen = fromIntegral addrSize + , msgIov = iovsPtr + , msgIovLen = fromIntegral iovsLen + , msgCtrl = castPtr ctrlPtr + , msgCtrlLen = fromIntegral ctrlLen + , msgFlags = 0 + } cflags = packBits msgFlagMapping flags - withFdSocket s $ \fd -> - with msgHdr $ \msgHdrPtr -> - throwSocketErrorWaitWrite s "Network.Socket.ByteString.sendMsg" $ - c_sendmsg fd msgHdrPtr cflags + withFdSocket s $ \fd -> + with msgHdr $ \msgHdrPtr -> + throwSocketErrorWaitWrite s "Network.Socket.ByteString.sendMsg" $ + c_sendmsg fd msgHdrPtr cflags return $ fromIntegral sz -- | Receive data from the socket using recvmsg(2). @@ -284,23 +288,27 @@ sendMsg s addr bss flags = do -- the second argument, the buffers are truncated properly. -- So, only the received data can be seen. recvMsg :: Socket -- ^ Socket - -> [Int] -- ^ a list of length of data to be received + -> [Int] -- ^ A list of length of data to be received -- If the total length is not large enough, -- 'MSG_TRUNC' is returned + -> Int -- ^ The buffer size for control messages. + -- If the length is not large enough, + -- 'MSG_CTRUNC' is returned -> [MsgFlag] -- ^ Message flags - -> IO (SockAddr, [ByteString], [MsgFlag]) -- ^ Source address, received data and received message flags -recvMsg _ [] _ = ioError (mkInvalidRecvArgError "Network.Socket.ByteString.recvMsg") -recvMsg s sizs flags = do + -> IO (SockAddr, [ByteString], [Cmsg], [MsgFlag]) -- ^ Source address, received data, control messages and message flags +recvMsg _ [] _ _ = ioError (mkInvalidRecvArgError "Network.Socket.ByteString.recvMsg") +recvMsg s sizs clen flags = do bss <- mapM newBS sizs withNewSocketAddress $ \addrPtr addrSize -> + allocaBytes clen $ \ctrlPtr -> withIOVec bss $ \(iovsPtr, iovsLen) -> do let msgHdr = MsgHdr { msgName = addrPtr , msgNameLen = fromIntegral addrSize , msgIov = iovsPtr , msgIovLen = fromIntegral iovsLen - , msgCtrl = nullPtr - , msgCtrlLen = 0 + , msgCtrl = castPtr ctrlPtr + , msgCtrlLen = fromIntegral clen , msgFlags = 0 } cflags = packBits msgFlagMapping flags @@ -314,8 +322,9 @@ recvMsg s sizs flags = do GT -> error "recvMsg" -- never reach sockaddr <- peekSocketAddress addrPtr hdr <- peek msgHdrPtr + cmsgs <- parseCmsgs msgHdrPtr let flags' = unpackBits msgFlagMapping $ msgFlags hdr - return (sockaddr, bss', flags') + return (sockaddr, bss', cmsgs, flags') newBS :: Int -> IO ByteString newBS n = create n $ \ptr -> zeroMemory ptr (fromIntegral n) diff --git a/cbits/cmsg.c b/cbits/cmsg.c new file mode 100644 index 00000000..71f4d4ef --- /dev/null +++ b/cbits/cmsg.c @@ -0,0 +1,22 @@ +#include "HsNet.h" +#include + +struct cmsghdr *cmsg_firsthdr(struct msghdr *mhdr) { + return (CMSG_FIRSTHDR(mhdr)); +} + +struct cmsghdr *cmsg_nxthdr(struct msghdr *mhdr, struct cmsghdr *cmsg) { + return (CMSG_NXTHDR(mhdr, cmsg)); +} + +unsigned char *cmsg_data(struct cmsghdr *cmsg) { + return (CMSG_DATA(cmsg)); +} + +int cmsg_space(int l) { + return (CMSG_SPACE(l)); +} + +int cmsg_len(int l) { + return (CMSG_LEN(l)); +} diff --git a/include/HsNet.h b/include/HsNet.h index 33c3954c..f848c0a1 100644 --- a/include/HsNet.h +++ b/include/HsNet.h @@ -85,6 +85,21 @@ sendFd(int sock, int outfd); extern int recvFd(int sock); + +extern struct cmsghdr * +cmsg_firsthdr(struct msghdr *mhdr); + +extern struct cmsghdr * +cmsg_nxthdr(struct msghdr *mhdr, struct cmsghdr *cmsg); + +extern unsigned char * +cmsg_data(struct cmsghdr *cmsg); + +extern int +cmsg_space(int l); + +extern int +cmsg_len(int l); #endif /* _WIN32 */ INLINE char * diff --git a/network.cabal b/network.cabal index 9481fedc..3a90fbce 100644 --- a/network.cabal +++ b/network.cabal @@ -44,7 +44,7 @@ extra-source-files: include/HsNetworkConfig.h.in include/HsNet.h include/HsNetDef.h -- C sources only used on some systems cbits/ancilData.c cbits/asyncAccept.c cbits/initWinSock.c - cbits/winSockErr.c + cbits/winSockErr.c cbits/cmsg.c homepage: https://github.com/haskell/network bug-reports: https://github.com/haskell/network/issues tested-with: GHC == 7.8.4 @@ -100,7 +100,8 @@ library Network.Socket.ByteString.IOVec Network.Socket.ByteString.Lazy.Posix Network.Socket.ByteString.MsgHdr - c-sources: cbits/ancilData.c + Network.Socket.ByteString.Cmsg + c-sources: cbits/ancilData.c cbits/cmsg.c if os(solaris) extra-libraries: nsl, socket From 579b94597e9d4f6de201927701787b1977d5b217 Mon Sep 17 00:00:00 2001 From: Kazu Yamamoto Date: Tue, 3 Dec 2019 10:40:08 +0900 Subject: [PATCH 06/48] implementing Auxiliary. --- Network/Socket/ByteString.hs | 18 ++ Network/Socket/ByteString/Auxiliary.hsc | 217 ++++++++++++++++++++++++ Network/Socket/Types.hsc | 1 + network.cabal | 1 + 4 files changed, 237 insertions(+) create mode 100644 Network/Socket/ByteString/Auxiliary.hsc diff --git a/Network/Socket/ByteString.hs b/Network/Socket/ByteString.hs index d723b226..5f7c0040 100644 --- a/Network/Socket/ByteString.hs +++ b/Network/Socket/ByteString.hs @@ -39,6 +39,23 @@ module Network.Socket.ByteString , recvMsg , MsgFlag(..) , Cmsg(..) + -- ** Auxiliary data + , Auxiliary(..) + , AuxiliaryID + , auxiliaryIPv4TTL + , auxiliaryIPv6HopLimit + , auxiliaryIPv4TOS + , auxiliaryIPv6TClass + , auxiliaryIPv4PktInfo + , auxiliaryIPv6PktInfo + , lookupAuxiliary + -- ** Types + , IPv4TTL(..) + , IPv6HopLimit(..) + , IPv4TOS(..) + , IPv6TClass(..) + , IPv4PktInfo(..) + , IPv6PktInfo(..) ) where import Data.ByteString (ByteString) @@ -46,6 +63,7 @@ import Data.ByteString (ByteString) import Network.Socket.ByteString.IO hiding (sendTo, sendAllTo, recvFrom) import qualified Network.Socket.ByteString.IO as G import Network.Socket.Types +import Network.Socket.ByteString.Auxiliary -- ---------------------------------------------------------------------------- -- ** Vectored I/O diff --git a/Network/Socket/ByteString/Auxiliary.hsc b/Network/Socket/ByteString/Auxiliary.hsc new file mode 100644 index 00000000..283a6208 --- /dev/null +++ b/Network/Socket/ByteString/Auxiliary.hsc @@ -0,0 +1,217 @@ +{-# LANGUAGE CPP #-} +{-# LANGUAGE ScopedTypeVariables #-} + +module Network.Socket.ByteString.Auxiliary where + +#include "HsNet.h" + +#include +#include + +import Data.ByteString.Internal +import Foreign.ForeignPtr +import System.IO.Unsafe (unsafeDupablePerformIO) + +import Network.Socket.ByteString.Cmsg +import Network.Socket.Imports +import Network.Socket.Types + +---------------------------------------------------------------- + +-- | Identifier of auxiliary data. A pair of level and type. +type AuxiliaryID = (CInt, CInt) + +-- | The identifier for 'IPv4TTL'. +auxiliaryIPv4TTL :: AuxiliaryID +#if defined(darwin_HOST_OS) +auxiliaryIPv4TTL = ((#const IPPROTO_IP), (#const IP_RECVTTL)) +#else +auxiliaryIPv4TTL = ((#const IPPROTO_IP), (#const IP_TTL)) +#endif + +-- | The identifier for 'IPv6HopLimit'. +auxiliaryIPv6HopLimit :: AuxiliaryID +auxiliaryIPv6HopLimit = ((#const IPPROTO_IPV6), (#const IPV6_HOPLIMIT)) + +-- | The identifier for 'IPv4TOS'. +auxiliaryIPv4TOS :: AuxiliaryID +#if defined(darwin_HOST_OS) +auxiliaryIPv4TOS = ((#const IPPROTO_IP), (#const IP_RECVTOS)) +#else +auxiliaryIPv4TOS = ((#const IPPROTO_IP), (#const IP_TOS)) +#endif + +-- | The identifier for 'IPv6TClass'. +auxiliaryIPv6TClass :: AuxiliaryID +auxiliaryIPv6TClass = ((#const IPPROTO_IPV6), (#const IPV6_TCLASS)) + +-- | The identifier for 'IPv4PktInfo'. +auxiliaryIPv4PktInfo :: AuxiliaryID +auxiliaryIPv4PktInfo = ((#const IPPROTO_IP), (#const IP_PKTINFO)) + +-- | The identifier for 'IPv6PktInfo'. +auxiliaryIPv6PktInfo :: AuxiliaryID +auxiliaryIPv6PktInfo = ((#const IPPROTO_IPV6), (#const IPV6_PKTINFO)) + +---------------------------------------------------------------- + +-- | Looking up auxiliary data. The following shows an example usage: +-- +-- > (lookupAuxiliary auxiliaryIPv4TOS cmsgs >>= auxiliaryDecode) :: Maybe IPv4TOS +lookupAuxiliary :: AuxiliaryID -> [Cmsg] -> Maybe Cmsg +lookupAuxiliary _ [] = Nothing +lookupAuxiliary aid (cmsg@(Cmsg cid _):cmsgs) + | aid == cid = Just cmsg + | otherwise = lookupAuxiliary aid cmsgs + +---------------------------------------------------------------- + +-- | A class to encode and decode auxiliary data. +class Auxiliary a where + auxiliaryEncode :: a -> Cmsg + auxiliaryDecode :: Cmsg -> Maybe a + +---------------------------------------------------------------- + +packCInt :: CInt -> ByteString +packCInt n = unsafeDupablePerformIO $ create siz $ \p0 -> do + let p = castPtr p0 :: Ptr CInt + poke p n + where + siz = (#size int) + +unpackCInt :: ByteString -> Maybe CInt +unpackCInt (PS fptr off len) + | len < siz = Nothing + | otherwise = unsafeDupablePerformIO $ withForeignPtr fptr $ \p0 -> do + let p = castPtr (p0 `plusPtr` off) :: Ptr CInt + Just <$> peek p + where + siz = (#size int) + +packCChar :: CChar -> ByteString +packCChar n = unsafeDupablePerformIO $ create siz $ \p0 -> do + let p = castPtr p0 :: Ptr CChar + poke p n + where + siz = (#size char) + +unpackCChar :: ByteString -> Maybe CChar +unpackCChar (PS fptr off len) + | len < siz = Nothing + | otherwise = unsafeDupablePerformIO $ withForeignPtr fptr $ \p0 -> do + let p = castPtr (p0 `plusPtr` off) :: Ptr CChar + Just <$> peek p + where + siz = (#size char) + +---------------------------------------------------------------- + +-- | Time to live of IPv4. +newtype IPv4TTL = IPv4TTL Int deriving (Eq, Show) + +instance Auxiliary IPv4TTL where +#if defined(darwin_HOST_OS) + auxiliaryEncode (IPv4TTL ttl) = Cmsg auxiliaryIPv4TTL $ packCChar $ fromIntegral ttl +#else + auxiliaryEncode (IPv4TTL ttl) = Cmsg auxiliaryIPv4TTL $ packCInt $ fromIntegral ttl +#endif +#if defined(darwin_HOST_OS) + auxiliaryDecode (Cmsg _ bs) = IPv4TTL . fromIntegral <$> unpackCChar bs +#else + auxiliaryDecode (Cmsg _ bs) = IPv4TTL . fromIntegral <$> unpackCInt bs +#endif + +---------------------------------------------------------------- + +-- | Hop limit of IPv6. +newtype IPv6HopLimit = IPv6HopLimit Int deriving (Eq, Show) + +instance Auxiliary IPv6HopLimit where + auxiliaryEncode (IPv6HopLimit ttl) = Cmsg auxiliaryIPv6HopLimit $ packCInt $ fromIntegral ttl + auxiliaryDecode (Cmsg _ bs) = IPv6HopLimit . fromIntegral <$> unpackCInt bs + +---------------------------------------------------------------- + +-- | TOS of IPv4. +newtype IPv4TOS = IPv4TOS Int deriving (Eq, Show) + +instance Auxiliary IPv4TOS where + auxiliaryEncode (IPv4TOS ttl) = Cmsg auxiliaryIPv4TOS $ packCChar $ fromIntegral ttl + auxiliaryDecode (Cmsg _ bs) = IPv4TOS . fromIntegral <$> unpackCChar bs + +---------------------------------------------------------------- + +-- | Traffic class of IPv6. +newtype IPv6TClass = IPv6TClass Int deriving (Eq, Show) + +instance Auxiliary IPv6TClass where + auxiliaryEncode (IPv6TClass ttl) = Cmsg auxiliaryIPv6TClass $ packCInt $ fromIntegral ttl + auxiliaryDecode (Cmsg _ bs) = IPv6TClass . fromIntegral <$> unpackCInt bs + +---------------------------------------------------------------- + +-- | Network interface ID and local IPv4 address. +data IPv4PktInfo = IPv4PktInfo Int HostAddress deriving (Eq) + +instance Show IPv4PktInfo where + show (IPv4PktInfo n ha) = "IPv4PktInfo " ++ show n ++ " " ++ show (hostAddressToTuple ha) + +instance Auxiliary IPv4PktInfo where + auxiliaryEncode pktinfo = Cmsg auxiliaryIPv4PktInfo $ packIPv4PktInfo pktinfo + auxiliaryDecode (Cmsg _ bs) = unpackIPv4PktInfo bs + +{-# NOINLINE packIPv4PktInfo #-} +packIPv4PktInfo :: IPv4PktInfo -> ByteString +packIPv4PktInfo (IPv4PktInfo n ha) = unsafeDupablePerformIO $ + create siz $ \p -> do + (#poke struct in_pktinfo, ipi_ifindex) p (fromIntegral n :: CInt) + (#poke struct in_pktinfo, ipi_spec_dst) p (0 :: CInt) + (#poke struct in_pktinfo, ipi_addr) p ha + where + siz = (#size struct in_pktinfo) + +{-# NOINLINE unpackIPv4PktInfo #-} +unpackIPv4PktInfo :: ByteString -> Maybe IPv4PktInfo +unpackIPv4PktInfo (PS fptr off len) + | len < siz = Nothing + | otherwise = unsafeDupablePerformIO $ withForeignPtr fptr $ \p0 -> do + let p = p0 `plusPtr` off + n <- (#peek struct in_pktinfo, ipi_ifindex) p + ha <- (#peek struct in_pktinfo, ipi_addr) p + return $ Just $ IPv4PktInfo n ha + where + siz = (#size struct in_pktinfo) + +---------------------------------------------------------------- + +-- | Network interface ID and local IPv4 address. +data IPv6PktInfo = IPv6PktInfo Int HostAddress6 deriving (Eq) + +instance Show IPv6PktInfo where + show (IPv6PktInfo n ha6) = "IPv6PktInfo " ++ show n ++ " " ++ show (hostAddress6ToTuple ha6) + +instance Auxiliary IPv6PktInfo where + auxiliaryEncode pktinfo = Cmsg auxiliaryIPv6PktInfo $ packIPv6PktInfo pktinfo + auxiliaryDecode (Cmsg _ bs) = unpackIPv6PktInfo bs + +{-# NOINLINE packIPv6PktInfo #-} +packIPv6PktInfo :: IPv6PktInfo -> ByteString +packIPv6PktInfo (IPv6PktInfo n ha6) = unsafeDupablePerformIO $ + create siz $ \p -> do + (#poke struct in6_pktinfo, ipi6_ifindex) p (fromIntegral n :: CInt) + (#poke struct in6_pktinfo, ipi6_addr) p (In6Addr ha6) + where + siz = (#size struct in6_pktinfo) + +{-# NOINLINE unpackIPv6PktInfo #-} +unpackIPv6PktInfo :: ByteString -> Maybe IPv6PktInfo +unpackIPv6PktInfo (PS fptr off len) + | len < siz = Nothing + | otherwise = unsafeDupablePerformIO $ withForeignPtr fptr $ \p0 -> do + let p = p0 `plusPtr` off + In6Addr ha6 <- (#peek struct in6_pktinfo, ipi6_addr) p + n :: CInt <- (#peek struct in6_pktinfo, ipi6_ifindex) p + return $ Just $ IPv6PktInfo (fromIntegral n) ha6 + where + siz = (#size struct in6_pktinfo) diff --git a/Network/Socket/Types.hsc b/Network/Socket/Types.hsc index 1ac7607d..c854f823 100644 --- a/Network/Socket/Types.hsc +++ b/Network/Socket/Types.hsc @@ -66,6 +66,7 @@ module Network.Socket.Types ( , zeroMemory , htonl , ntohl + , In6Addr(..) ) where import Data.IORef (IORef, newIORef, readIORef, atomicModifyIORef', mkWeakIORef) diff --git a/network.cabal b/network.cabal index 3a90fbce..533bc283 100644 --- a/network.cabal +++ b/network.cabal @@ -101,6 +101,7 @@ library Network.Socket.ByteString.Lazy.Posix Network.Socket.ByteString.MsgHdr Network.Socket.ByteString.Cmsg + Network.Socket.ByteString.Auxiliary c-sources: cbits/ancilData.c cbits/cmsg.c if os(solaris) From 54e0892003195259163a742ced191cc2ee5cb560 Mon Sep 17 00:00:00 2001 From: Kazu Yamamoto Date: Wed, 4 Dec 2019 10:57:55 +0900 Subject: [PATCH 07/48] making recvMsg usable for TCP. --- Network/Socket/ByteString/IO.hsc | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/Network/Socket/ByteString/IO.hsc b/Network/Socket/ByteString/IO.hsc index fa45e637..7fde334c 100644 --- a/Network/Socket/ByteString/IO.hsc +++ b/Network/Socket/ByteString/IO.hsc @@ -54,12 +54,14 @@ import Data.ByteString.Internal (create, ByteString(..)) import Foreign.Marshal.Array (allocaArray) import Foreign.Marshal.Utils (with) import Network.Socket.Internal +import System.IO.Error (catchIOError) -import Network.Socket.Info (packBits, unpackBits) +import Network.Socket.ByteString.Cmsg +import Network.Socket.ByteString.Flag import Network.Socket.ByteString.IOVec (IOVec(..)) import Network.Socket.ByteString.MsgHdr (MsgHdr(..)) -import Network.Socket.ByteString.Flag -import Network.Socket.ByteString.Cmsg +import Network.Socket.Info (packBits, unpackBits) +import Network.Socket.Name (getPeerName) #endif -- ---------------------------------------------------------------------------- @@ -320,7 +322,7 @@ recvMsg s sizs clen flags = do EQ -> bss LT -> trunc bss len GT -> error "recvMsg" -- never reach - sockaddr <- peekSocketAddress addrPtr + sockaddr <- peekSocketAddress addrPtr `catchIOError` \_ -> getPeerName s hdr <- peek msgHdrPtr cmsgs <- parseCmsgs msgHdrPtr let flags' = unpackBits msgFlagMapping $ msgFlags hdr From 9d5a69c54f99cc7398c6449d7e813c5cfd684ca1 Mon Sep 17 00:00:00 2001 From: Kazu Yamamoto Date: Wed, 4 Dec 2019 12:05:32 +0900 Subject: [PATCH 08/48] adding tests. --- tests/Network/Socket/ByteStringSpec.hs | 112 +++++++++++++++++++++++++ 1 file changed, 112 insertions(+) diff --git a/tests/Network/Socket/ByteStringSpec.hs b/tests/Network/Socket/ByteStringSpec.hs index 5a26eff5..a600ac87 100644 --- a/tests/Network/Socket/ByteStringSpec.hs +++ b/tests/Network/Socket/ByteStringSpec.hs @@ -174,3 +174,115 @@ spec = do seg1 `shouldBe` S.empty client sock = shutdown sock ShutdownSend tcpTest client server + + describe "sendMsg" $ do + it "works well" $ do + let server sock = recv sock 1024 `shouldReturn` S.append seg1 seg2 + client sock addr = sendMsg sock addr [seg1, seg2] [] [] + + seg1 = C.pack "This is a " + seg2 = C.pack "test message." + udpTest client server + + it "throws when closed" $ do + let server _ = return () + client sock addr = do + close sock + sendMsg sock addr [seg1, seg2] [] [] `shouldThrow` anyException + + seg1 = C.pack "This is a " + seg2 = C.pack "test message." + udpTest client server + + describe "recvMsg" $ do + it "works well" $ do + let server sock = do + (_, msgs, cmsgs, flags) <- recvMsg sock [1024] 0 [] + S.concat msgs `shouldBe` seg + cmsgs `shouldBe` [] + flags `shouldBe` [] + client sock addr = sendTo sock seg addr + + seg = C.pack "This is a test message" + udpTest client server + + it "receives message fragments" $ do + let server sock = do + (_, msgs, _, _) <- recvMsg sock [1,2,3,4] 0 [] + S.concat msgs `shouldBe` S.take 10 seg + client sock addr = sendTo sock seg addr + + seg = C.pack "This is a test message" + udpTest client server + + it "receives message fragments with truncation" $ do + let server sock = do + (_, msgs, _, _) <- recvMsg sock [10,10,10,10] 0 [] + msgs `shouldBe` ["0123456789", "0123456789", "012345"] + client sock addr = sendTo sock seg addr + + seg = C.pack "01234567890123456789012345" + udpTest client server + + it "receives truncated flag" $ do + let server sock = do + (_, _, _, flags) <- recvMsg sock [S.length seg - 2] 0 [] + flags `shouldContain` [MSG_TRUNC] + client sock addr = sendTo sock seg addr + + seg = C.pack "This is a test message" + udpTest client server + + it "peek" $ do + let server sock = do + (_, msgs, _, _flags) <- recvMsg sock [1024] 0 [MSG_PEEK] + -- flags `shouldContain` [MSG_PEEK] -- Mac only + (_, msgs', _, _) <- recvMsg sock [1024] 0 [] + msgs `shouldBe` msgs' + client sock addr = sendTo sock seg addr + + seg = C.pack "This is a test message" + udpTest client server + + it "receives control messages for IPv4" $ do + let server sock = do + setSocketOption sock RecvIPv4TTL 1 + setSocketOption sock RecvIPv4TOS 1 + setSocketOption sock RecvIPv4PktInfo 1 + (_, _, cmsgs, _) <- recvMsg sock [1024] 128 [] + + ((lookupAuxiliary auxiliaryIPv4TTL cmsgs >>= auxiliaryDecode) :: Maybe IPv4TTL) `shouldNotBe` Nothing + ((lookupAuxiliary auxiliaryIPv4TOS cmsgs >>= auxiliaryDecode) :: Maybe IPv4TOS) `shouldNotBe` Nothing + ((lookupAuxiliary auxiliaryIPv4PktInfo cmsgs >>= auxiliaryDecode) :: Maybe IPv4PktInfo) `shouldNotBe` Nothing + client sock addr = sendTo sock seg addr + + seg = C.pack "This is a test message" + udpTest client server + + it "receives control messages for IPv6" $ do + let server sock = do + setSocketOption sock RecvIPv6HopLimit 1 + setSocketOption sock RecvIPv6TClass 1 + setSocketOption sock RecvIPv6PktInfo 1 + (_, _, cmsgs, _) <- recvMsg sock [1024] 128 [] + + ((lookupAuxiliary auxiliaryIPv6HopLimit cmsgs >>= auxiliaryDecode) :: Maybe IPv6HopLimit) `shouldNotBe` Nothing + ((lookupAuxiliary auxiliaryIPv6TClass cmsgs >>= auxiliaryDecode) :: Maybe IPv6TClass) `shouldNotBe` Nothing + ((lookupAuxiliary auxiliaryIPv6PktInfo cmsgs >>= auxiliaryDecode) :: Maybe IPv6PktInfo) `shouldNotBe` Nothing + client sock addr = sendTo sock seg addr + + seg = C.pack "This is a test message" + udpTest6 client server + + it "receives truncated control messages" $ do + let server sock = do + setSocketOption sock RecvIPv4TTL 1 + setSocketOption sock RecvIPv4TOS 1 + setSocketOption sock RecvIPv4PktInfo 1 + (_, _, _, flags) <- recvMsg sock [1024] 10 [] + flags `shouldContain` [MSG_CTRUNC] + + client sock addr = sendTo sock seg addr + + seg = C.pack "This is a test message" + udpTest client server From 10845029707eb9e26cc0e4baaf98fdce1503eee2 Mon Sep 17 00:00:00 2001 From: Kazu Yamamoto Date: Thu, 9 Jan 2020 13:09:28 +0900 Subject: [PATCH 09/48] dropping support for GHC 7.x. Now we support GHC 8.x only. --- .travis.yml | 6 ------ appveyor.yml | 3 +-- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/.travis.yml b/.travis.yml index 81aad3d3..21a0f34b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -26,12 +26,6 @@ before_cache: matrix: include: - - compiler: "ghc-7.8.4" - # env: TEST=--disable-tests BENCH=--disable-benchmarks - addons: {apt: {packages: [ghc-ppa-tools,cabal-install-2.4,ghc-7.8.4], sources: [hvr-ghc]}} - - compiler: "ghc-7.10.3" - # env: TEST=--disable-tests BENCH=--disable-benchmarks - addons: {apt: {packages: [ghc-ppa-tools,cabal-install-2.4,ghc-7.10.3], sources: [hvr-ghc]}} - compiler: "ghc-8.0.2" # env: TEST=--disable-tests BENCH=--disable-benchmarks addons: {apt: {packages: [ghc-ppa-tools,cabal-install-2.4,ghc-8.0.2], sources: [hvr-ghc]}} diff --git a/appveyor.yml b/appveyor.yml index 4a248de1..951603fe 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -11,12 +11,11 @@ environment: CABOPTS: "--store-dir=C:\\SR --http-transport=plain-http" DOCTEST: YES matrix: - - GHCVER: 7.8.4.1 - - GHCVER: 7.10.3.2 - GHCVER: 8.0.2 - GHCVER: 8.2.2 - GHCVER: 8.4.4 - GHCVER: 8.6.3 + - GHCVER: 8.8.1 platform: # - x86 # We may want to test x86 as well, but it would double the 23min build time. From 68f2990f1202b328cb5c5bb4cc2cf02d0931e8ba Mon Sep 17 00:00:00 2001 From: Kazu Yamamoto Date: Thu, 9 Jan 2020 13:15:45 +0900 Subject: [PATCH 10/48] fixing typos: s/auxiliary/ancillary/ --- Network/Socket/ByteString.hs | 22 ++--- .../{Auxiliary.hsc => Ancillary.hsc} | 94 +++++++++---------- network.cabal | 2 +- tests/Network/Socket/ByteStringSpec.hs | 12 +-- 4 files changed, 65 insertions(+), 65 deletions(-) rename Network/Socket/ByteString/{Auxiliary.hsc => Ancillary.hsc} (68%) diff --git a/Network/Socket/ByteString.hs b/Network/Socket/ByteString.hs index 5f7c0040..ce2daf67 100644 --- a/Network/Socket/ByteString.hs +++ b/Network/Socket/ByteString.hs @@ -39,16 +39,16 @@ module Network.Socket.ByteString , recvMsg , MsgFlag(..) , Cmsg(..) - -- ** Auxiliary data - , Auxiliary(..) - , AuxiliaryID - , auxiliaryIPv4TTL - , auxiliaryIPv6HopLimit - , auxiliaryIPv4TOS - , auxiliaryIPv6TClass - , auxiliaryIPv4PktInfo - , auxiliaryIPv6PktInfo - , lookupAuxiliary + -- ** Ancillary data + , Ancillary(..) + , AncillaryID + , ancillaryIPv4TTL + , ancillaryIPv6HopLimit + , ancillaryIPv4TOS + , ancillaryIPv6TClass + , ancillaryIPv4PktInfo + , ancillaryIPv6PktInfo + , lookupAncillary -- ** Types , IPv4TTL(..) , IPv6HopLimit(..) @@ -63,7 +63,7 @@ import Data.ByteString (ByteString) import Network.Socket.ByteString.IO hiding (sendTo, sendAllTo, recvFrom) import qualified Network.Socket.ByteString.IO as G import Network.Socket.Types -import Network.Socket.ByteString.Auxiliary +import Network.Socket.ByteString.Ancillary -- ---------------------------------------------------------------------------- -- ** Vectored I/O diff --git a/Network/Socket/ByteString/Auxiliary.hsc b/Network/Socket/ByteString/Ancillary.hsc similarity index 68% rename from Network/Socket/ByteString/Auxiliary.hsc rename to Network/Socket/ByteString/Ancillary.hsc index 283a6208..8b135478 100644 --- a/Network/Socket/ByteString/Auxiliary.hsc +++ b/Network/Socket/ByteString/Ancillary.hsc @@ -1,7 +1,7 @@ {-# LANGUAGE CPP #-} {-# LANGUAGE ScopedTypeVariables #-} -module Network.Socket.ByteString.Auxiliary where +module Network.Socket.ByteString.Ancillary where #include "HsNet.h" @@ -18,58 +18,58 @@ import Network.Socket.Types ---------------------------------------------------------------- --- | Identifier of auxiliary data. A pair of level and type. -type AuxiliaryID = (CInt, CInt) +-- | Identifier of ancillary data. A pair of level and type. +type AncillaryID = (CInt, CInt) -- | The identifier for 'IPv4TTL'. -auxiliaryIPv4TTL :: AuxiliaryID +ancillaryIPv4TTL :: AncillaryID #if defined(darwin_HOST_OS) -auxiliaryIPv4TTL = ((#const IPPROTO_IP), (#const IP_RECVTTL)) +ancillaryIPv4TTL = ((#const IPPROTO_IP), (#const IP_RECVTTL)) #else -auxiliaryIPv4TTL = ((#const IPPROTO_IP), (#const IP_TTL)) +ancillaryIPv4TTL = ((#const IPPROTO_IP), (#const IP_TTL)) #endif -- | The identifier for 'IPv6HopLimit'. -auxiliaryIPv6HopLimit :: AuxiliaryID -auxiliaryIPv6HopLimit = ((#const IPPROTO_IPV6), (#const IPV6_HOPLIMIT)) +ancillaryIPv6HopLimit :: AncillaryID +ancillaryIPv6HopLimit = ((#const IPPROTO_IPV6), (#const IPV6_HOPLIMIT)) -- | The identifier for 'IPv4TOS'. -auxiliaryIPv4TOS :: AuxiliaryID +ancillaryIPv4TOS :: AncillaryID #if defined(darwin_HOST_OS) -auxiliaryIPv4TOS = ((#const IPPROTO_IP), (#const IP_RECVTOS)) +ancillaryIPv4TOS = ((#const IPPROTO_IP), (#const IP_RECVTOS)) #else -auxiliaryIPv4TOS = ((#const IPPROTO_IP), (#const IP_TOS)) +ancillaryIPv4TOS = ((#const IPPROTO_IP), (#const IP_TOS)) #endif -- | The identifier for 'IPv6TClass'. -auxiliaryIPv6TClass :: AuxiliaryID -auxiliaryIPv6TClass = ((#const IPPROTO_IPV6), (#const IPV6_TCLASS)) +ancillaryIPv6TClass :: AncillaryID +ancillaryIPv6TClass = ((#const IPPROTO_IPV6), (#const IPV6_TCLASS)) -- | The identifier for 'IPv4PktInfo'. -auxiliaryIPv4PktInfo :: AuxiliaryID -auxiliaryIPv4PktInfo = ((#const IPPROTO_IP), (#const IP_PKTINFO)) +ancillaryIPv4PktInfo :: AncillaryID +ancillaryIPv4PktInfo = ((#const IPPROTO_IP), (#const IP_PKTINFO)) -- | The identifier for 'IPv6PktInfo'. -auxiliaryIPv6PktInfo :: AuxiliaryID -auxiliaryIPv6PktInfo = ((#const IPPROTO_IPV6), (#const IPV6_PKTINFO)) +ancillaryIPv6PktInfo :: AncillaryID +ancillaryIPv6PktInfo = ((#const IPPROTO_IPV6), (#const IPV6_PKTINFO)) ---------------------------------------------------------------- --- | Looking up auxiliary data. The following shows an example usage: +-- | Looking up ancillary data. The following shows an example usage: -- --- > (lookupAuxiliary auxiliaryIPv4TOS cmsgs >>= auxiliaryDecode) :: Maybe IPv4TOS -lookupAuxiliary :: AuxiliaryID -> [Cmsg] -> Maybe Cmsg -lookupAuxiliary _ [] = Nothing -lookupAuxiliary aid (cmsg@(Cmsg cid _):cmsgs) +-- > (lookupAncillary ancillaryIPv4TOS cmsgs >>= ancillaryDecode) :: Maybe IPv4TOS +lookupAncillary :: AncillaryID -> [Cmsg] -> Maybe Cmsg +lookupAncillary _ [] = Nothing +lookupAncillary aid (cmsg@(Cmsg cid _):cmsgs) | aid == cid = Just cmsg - | otherwise = lookupAuxiliary aid cmsgs + | otherwise = lookupAncillary aid cmsgs ---------------------------------------------------------------- --- | A class to encode and decode auxiliary data. -class Auxiliary a where - auxiliaryEncode :: a -> Cmsg - auxiliaryDecode :: Cmsg -> Maybe a +-- | A class to encode and decode ancillary data. +class Ancillary a where + ancillaryEncode :: a -> Cmsg + ancillaryDecode :: Cmsg -> Maybe a ---------------------------------------------------------------- @@ -110,16 +110,16 @@ unpackCChar (PS fptr off len) -- | Time to live of IPv4. newtype IPv4TTL = IPv4TTL Int deriving (Eq, Show) -instance Auxiliary IPv4TTL where +instance Ancillary IPv4TTL where #if defined(darwin_HOST_OS) - auxiliaryEncode (IPv4TTL ttl) = Cmsg auxiliaryIPv4TTL $ packCChar $ fromIntegral ttl + ancillaryEncode (IPv4TTL ttl) = Cmsg ancillaryIPv4TTL $ packCChar $ fromIntegral ttl #else - auxiliaryEncode (IPv4TTL ttl) = Cmsg auxiliaryIPv4TTL $ packCInt $ fromIntegral ttl + ancillaryEncode (IPv4TTL ttl) = Cmsg ancillaryIPv4TTL $ packCInt $ fromIntegral ttl #endif #if defined(darwin_HOST_OS) - auxiliaryDecode (Cmsg _ bs) = IPv4TTL . fromIntegral <$> unpackCChar bs + ancillaryDecode (Cmsg _ bs) = IPv4TTL . fromIntegral <$> unpackCChar bs #else - auxiliaryDecode (Cmsg _ bs) = IPv4TTL . fromIntegral <$> unpackCInt bs + ancillaryDecode (Cmsg _ bs) = IPv4TTL . fromIntegral <$> unpackCInt bs #endif ---------------------------------------------------------------- @@ -127,27 +127,27 @@ instance Auxiliary IPv4TTL where -- | Hop limit of IPv6. newtype IPv6HopLimit = IPv6HopLimit Int deriving (Eq, Show) -instance Auxiliary IPv6HopLimit where - auxiliaryEncode (IPv6HopLimit ttl) = Cmsg auxiliaryIPv6HopLimit $ packCInt $ fromIntegral ttl - auxiliaryDecode (Cmsg _ bs) = IPv6HopLimit . fromIntegral <$> unpackCInt bs +instance Ancillary IPv6HopLimit where + ancillaryEncode (IPv6HopLimit ttl) = Cmsg ancillaryIPv6HopLimit $ packCInt $ fromIntegral ttl + ancillaryDecode (Cmsg _ bs) = IPv6HopLimit . fromIntegral <$> unpackCInt bs ---------------------------------------------------------------- -- | TOS of IPv4. newtype IPv4TOS = IPv4TOS Int deriving (Eq, Show) -instance Auxiliary IPv4TOS where - auxiliaryEncode (IPv4TOS ttl) = Cmsg auxiliaryIPv4TOS $ packCChar $ fromIntegral ttl - auxiliaryDecode (Cmsg _ bs) = IPv4TOS . fromIntegral <$> unpackCChar bs +instance Ancillary IPv4TOS where + ancillaryEncode (IPv4TOS ttl) = Cmsg ancillaryIPv4TOS $ packCChar $ fromIntegral ttl + ancillaryDecode (Cmsg _ bs) = IPv4TOS . fromIntegral <$> unpackCChar bs ---------------------------------------------------------------- -- | Traffic class of IPv6. newtype IPv6TClass = IPv6TClass Int deriving (Eq, Show) -instance Auxiliary IPv6TClass where - auxiliaryEncode (IPv6TClass ttl) = Cmsg auxiliaryIPv6TClass $ packCInt $ fromIntegral ttl - auxiliaryDecode (Cmsg _ bs) = IPv6TClass . fromIntegral <$> unpackCInt bs +instance Ancillary IPv6TClass where + ancillaryEncode (IPv6TClass ttl) = Cmsg ancillaryIPv6TClass $ packCInt $ fromIntegral ttl + ancillaryDecode (Cmsg _ bs) = IPv6TClass . fromIntegral <$> unpackCInt bs ---------------------------------------------------------------- @@ -157,9 +157,9 @@ data IPv4PktInfo = IPv4PktInfo Int HostAddress deriving (Eq) instance Show IPv4PktInfo where show (IPv4PktInfo n ha) = "IPv4PktInfo " ++ show n ++ " " ++ show (hostAddressToTuple ha) -instance Auxiliary IPv4PktInfo where - auxiliaryEncode pktinfo = Cmsg auxiliaryIPv4PktInfo $ packIPv4PktInfo pktinfo - auxiliaryDecode (Cmsg _ bs) = unpackIPv4PktInfo bs +instance Ancillary IPv4PktInfo where + ancillaryEncode pktinfo = Cmsg ancillaryIPv4PktInfo $ packIPv4PktInfo pktinfo + ancillaryDecode (Cmsg _ bs) = unpackIPv4PktInfo bs {-# NOINLINE packIPv4PktInfo #-} packIPv4PktInfo :: IPv4PktInfo -> ByteString @@ -191,9 +191,9 @@ data IPv6PktInfo = IPv6PktInfo Int HostAddress6 deriving (Eq) instance Show IPv6PktInfo where show (IPv6PktInfo n ha6) = "IPv6PktInfo " ++ show n ++ " " ++ show (hostAddress6ToTuple ha6) -instance Auxiliary IPv6PktInfo where - auxiliaryEncode pktinfo = Cmsg auxiliaryIPv6PktInfo $ packIPv6PktInfo pktinfo - auxiliaryDecode (Cmsg _ bs) = unpackIPv6PktInfo bs +instance Ancillary IPv6PktInfo where + ancillaryEncode pktinfo = Cmsg ancillaryIPv6PktInfo $ packIPv6PktInfo pktinfo + ancillaryDecode (Cmsg _ bs) = unpackIPv6PktInfo bs {-# NOINLINE packIPv6PktInfo #-} packIPv6PktInfo :: IPv6PktInfo -> ByteString diff --git a/network.cabal b/network.cabal index 533bc283..aca610ab 100644 --- a/network.cabal +++ b/network.cabal @@ -101,7 +101,7 @@ library Network.Socket.ByteString.Lazy.Posix Network.Socket.ByteString.MsgHdr Network.Socket.ByteString.Cmsg - Network.Socket.ByteString.Auxiliary + Network.Socket.ByteString.Ancillary c-sources: cbits/ancilData.c cbits/cmsg.c if os(solaris) diff --git a/tests/Network/Socket/ByteStringSpec.hs b/tests/Network/Socket/ByteStringSpec.hs index a600ac87..7d6d36c3 100644 --- a/tests/Network/Socket/ByteStringSpec.hs +++ b/tests/Network/Socket/ByteStringSpec.hs @@ -251,9 +251,9 @@ spec = do setSocketOption sock RecvIPv4PktInfo 1 (_, _, cmsgs, _) <- recvMsg sock [1024] 128 [] - ((lookupAuxiliary auxiliaryIPv4TTL cmsgs >>= auxiliaryDecode) :: Maybe IPv4TTL) `shouldNotBe` Nothing - ((lookupAuxiliary auxiliaryIPv4TOS cmsgs >>= auxiliaryDecode) :: Maybe IPv4TOS) `shouldNotBe` Nothing - ((lookupAuxiliary auxiliaryIPv4PktInfo cmsgs >>= auxiliaryDecode) :: Maybe IPv4PktInfo) `shouldNotBe` Nothing + ((lookupAncillary ancillaryIPv4TTL cmsgs >>= ancillaryDecode) :: Maybe IPv4TTL) `shouldNotBe` Nothing + ((lookupAncillary ancillaryIPv4TOS cmsgs >>= ancillaryDecode) :: Maybe IPv4TOS) `shouldNotBe` Nothing + ((lookupAncillary ancillaryIPv4PktInfo cmsgs >>= ancillaryDecode) :: Maybe IPv4PktInfo) `shouldNotBe` Nothing client sock addr = sendTo sock seg addr seg = C.pack "This is a test message" @@ -266,9 +266,9 @@ spec = do setSocketOption sock RecvIPv6PktInfo 1 (_, _, cmsgs, _) <- recvMsg sock [1024] 128 [] - ((lookupAuxiliary auxiliaryIPv6HopLimit cmsgs >>= auxiliaryDecode) :: Maybe IPv6HopLimit) `shouldNotBe` Nothing - ((lookupAuxiliary auxiliaryIPv6TClass cmsgs >>= auxiliaryDecode) :: Maybe IPv6TClass) `shouldNotBe` Nothing - ((lookupAuxiliary auxiliaryIPv6PktInfo cmsgs >>= auxiliaryDecode) :: Maybe IPv6PktInfo) `shouldNotBe` Nothing + ((lookupAncillary ancillaryIPv6HopLimit cmsgs >>= ancillaryDecode) :: Maybe IPv6HopLimit) `shouldNotBe` Nothing + ((lookupAncillary ancillaryIPv6TClass cmsgs >>= ancillaryDecode) :: Maybe IPv6TClass) `shouldNotBe` Nothing + ((lookupAncillary ancillaryIPv6PktInfo cmsgs >>= ancillaryDecode) :: Maybe IPv6PktInfo) `shouldNotBe` Nothing client sock addr = sendTo sock seg addr seg = C.pack "This is a test message" From ca7b6edbb33f9094225e272d70747917bd989ccf Mon Sep 17 00:00:00 2001 From: Kazu Yamamoto Date: Thu, 9 Jan 2020 13:32:58 +0900 Subject: [PATCH 11/48] adding fields for CmsgHdr. --- Network/Socket/ByteString/Cmsg.hsc | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/Network/Socket/ByteString/Cmsg.hsc b/Network/Socket/ByteString/Cmsg.hsc index 0cffc94e..8337ef4a 100644 --- a/Network/Socket/ByteString/Cmsg.hsc +++ b/Network/Socket/ByteString/Cmsg.hsc @@ -26,7 +26,11 @@ data Cmsg = Cmsg { , cmsgBody :: ByteString } deriving (Eq, Show) -data CmsgHdr = CmsgHdr CInt CInt CInt deriving (Eq, Show) +data CmsgHdr = CmsgHdr { + cmsgHdrLen :: !CInt + , cmsgHdrLevel :: !CInt + , cmsgHdrType :: !CInt + } deriving (Eq, Show) instance Storable CmsgHdr where sizeOf _ = (#size struct cmsghdr) From 367a5aadf383aba49f1504bad5f5620548107476 Mon Sep 17 00:00:00 2001 From: Kazu Yamamoto Date: Thu, 9 Jan 2020 13:34:59 +0900 Subject: [PATCH 12/48] deleting outdated comment. --- Network/Socket/ByteString/MsgHdr.hsc | 2 -- 1 file changed, 2 deletions(-) diff --git a/Network/Socket/ByteString/MsgHdr.hsc b/Network/Socket/ByteString/MsgHdr.hsc index ad42c226..13c36f6b 100644 --- a/Network/Socket/ByteString/MsgHdr.hsc +++ b/Network/Socket/ByteString/MsgHdr.hsc @@ -14,8 +14,6 @@ import Network.Socket.Types (SockAddr) import Network.Socket.ByteString.IOVec (IOVec) --- We don't use msg_control, msg_controllen, and msg_flags as these --- don't exist on OpenSolaris. data MsgHdr = MsgHdr { msgName :: !(Ptr SockAddr) , msgNameLen :: !CUInt From 2e3cfdd507a3e21ade30facce5ea82f237c3c132 Mon Sep 17 00:00:00 2001 From: Kazu Yamamoto Date: Thu, 9 Jan 2020 14:45:53 +0900 Subject: [PATCH 13/48] MsgFlag is now a newtype of CInt. --- Network/Socket/ByteString.hs | 7 ++- Network/Socket/ByteString/Flag.hsc | 81 +++++++++++++++----------- Network/Socket/ByteString/IO.hsc | 13 ++--- tests/Network/Socket/ByteStringSpec.hs | 31 +++++----- 4 files changed, 74 insertions(+), 58 deletions(-) diff --git a/Network/Socket/ByteString.hs b/Network/Socket/ByteString.hs index ce2daf67..3dc3f994 100644 --- a/Network/Socket/ByteString.hs +++ b/Network/Socket/ByteString.hs @@ -37,7 +37,7 @@ module Network.Socket.ByteString -- * Advanced send and recv , sendMsg , recvMsg - , MsgFlag(..) + , MsgFlag(MSG_OOB,MSG_DONTROUTE,MSG_PEEK,MSG_EOR,MSG_TRUNC,MSG_CTRUNC,MSG_WAITALL) , Cmsg(..) -- ** Ancillary data , Ancillary(..) @@ -60,10 +60,11 @@ module Network.Socket.ByteString import Data.ByteString (ByteString) -import Network.Socket.ByteString.IO hiding (sendTo, sendAllTo, recvFrom) +import Network.Socket.ByteString.Ancillary +import Network.Socket.ByteString.Flag import qualified Network.Socket.ByteString.IO as G +import Network.Socket.ByteString.IO hiding (sendTo, sendAllTo, recvFrom) import Network.Socket.Types -import Network.Socket.ByteString.Ancillary -- ---------------------------------------------------------------------------- -- ** Vectored I/O diff --git a/Network/Socket/ByteString/Flag.hsc b/Network/Socket/ByteString/Flag.hsc index 175f1a60..db445373 100644 --- a/Network/Socket/ByteString/Flag.hsc +++ b/Network/Socket/ByteString/Flag.hsc @@ -1,59 +1,74 @@ +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE PatternSynonyms #-} + #include "HsNet.h" module Network.Socket.ByteString.Flag where import Network.Socket.Imports -import Network.Socket.Info - --- | Message flags. -data MsgFlag = - MSG_OOB -- ^ Send or receive OOB(out-of-bound) data. - | MSG_DONTROUTE -- ^ Bypass routing table lookup. - | MSG_PEEK -- ^ Peek at incoming message without removing it from the queue. - | MSG_EOR -- ^ End of record. - | MSG_TRUNC -- ^ Received data is truncated. More data exist. - | MSG_CTRUNC -- ^ Received control message is truncated. More control message exist. - | MSG_WAITALL -- ^ Wait until the requested number of bytes have been read. - deriving (Eq, Show) - -msgFlagMapping :: [(MsgFlag, CInt)] -msgFlagMapping = [ + +-- | Message flags. To combine flags, use '(<>)'. +newtype MsgFlag = MsgFlag { fromMsgFlag :: CInt } + deriving (Show, Eq, Ord, Num, Bits) + +instance Semigroup MsgFlag where + (<>) = (.|.) + +instance Monoid MsgFlag where + mempty = 0 + +-- | Send or receive OOB(out-of-bound) data. +pattern MSG_OOB :: MsgFlag #ifdef MSG_OOB - (MSG_OOB, #const MSG_OOB) +pattern MSG_OOB = MsgFlag (#const MSG_OOB) #else - (MSG_OOB, 0) +pattern MSG_OOB = mempty #endif + +-- | Bypass routing table lookup. +pattern MSG_DONTROUTE :: MsgFlag #ifdef MSG_DONTROUTE - , (MSG_DONTROUTE, #const MSG_DONTROUTE) +pattern MSG_DONTROUTE = MsgFlag (#const MSG_DONTROUTE) #else - , (MSG_DONTROUTE, 0) +pattern MSG_DONTROUTE = mempty #endif + +-- | Peek at incoming message without removing it from the queue. +pattern MSG_PEEK :: MsgFlag #ifdef MSG_PEEK - , (MSG_PEEK, #const MSG_PEEK) +pattern MSG_PEEK = MsgFlag (#const MSG_PEEK) #else - , (MSG_PEEK, 0) +pattern MSG_PEEK = mempty #endif + +-- | End of record. +pattern MSG_EOR :: MsgFlag #ifdef MSG_EOR - , (MSG_EOR, #const MSG_EOR) +pattern MSG_EOR = MsgFlag (#const MSG_EOR) #else - , (MSG_EOR, 0) +pattern MSG_EOR = mempty #endif + +-- | Received data is truncated. More data exist. +pattern MSG_TRUNC :: MsgFlag #ifdef MSG_TRUNC - , (MSG_TRUNC, #const MSG_TRUNC) +pattern MSG_TRUNC = MsgFlag (#const MSG_TRUNC) #else - , (MSG_TRUNC, 0) +pattern MSG_TRUNC = mempty #endif + +-- | Received control message is truncated. More control message exist. +pattern MSG_CTRUNC :: MsgFlag #ifdef MSG_CTRUNC - , (MSG_CTRUNC, #const MSG_CTRUNC) +pattern MSG_CTRUNC = MsgFlag (#const MSG_CTRUNC) #else - , (MSG_CTRUNC, 0) +pattern MSG_CTRUNC = mempty #endif + +-- | Wait until the requested number of bytes have been read. +pattern MSG_WAITALL :: MsgFlag #ifdef MSG_WAITALL - , (MSG_WAITALL, #const MSG_WAITALL) +pattern MSG_WAITALL = MsgFlag (#const MSG_WAITALL) #else - , (MSG_WAITALL, 0) +pattern MSG_WAITALL = mempty #endif - ] - -msgFlagImplemented :: MsgFlag -> Bool -msgFlagImplemented f = packBits msgFlagMapping [f] /= 0 diff --git a/Network/Socket/ByteString/IO.hsc b/Network/Socket/ByteString/IO.hsc index 7fde334c..ff3f058d 100644 --- a/Network/Socket/ByteString/IO.hsc +++ b/Network/Socket/ByteString/IO.hsc @@ -60,7 +60,6 @@ import Network.Socket.ByteString.Cmsg import Network.Socket.ByteString.Flag import Network.Socket.ByteString.IOVec (IOVec(..)) import Network.Socket.ByteString.MsgHdr (MsgHdr(..)) -import Network.Socket.Info (packBits, unpackBits) import Network.Socket.Name (getPeerName) #endif @@ -261,7 +260,7 @@ sendMsg :: Socket -- ^ Socket -> SockAddr -- ^ Destination address -> [ByteString] -- ^ Data to be sent -> [Cmsg] -- ^ Control messages - -> [MsgFlag] -- ^ Message flags + -> MsgFlag -- ^ Message flags -> IO Int -- ^ The length actually sent sendMsg _ _ [] _ _ = return 0 sendMsg s addr bss cmsgs flags = do @@ -277,7 +276,7 @@ sendMsg s addr bss cmsgs flags = do , msgCtrlLen = fromIntegral ctrlLen , msgFlags = 0 } - cflags = packBits msgFlagMapping flags + cflags = fromMsgFlag flags withFdSocket s $ \fd -> with msgHdr $ \msgHdrPtr -> throwSocketErrorWaitWrite s "Network.Socket.ByteString.sendMsg" $ @@ -296,8 +295,8 @@ recvMsg :: Socket -- ^ Socket -> Int -- ^ The buffer size for control messages. -- If the length is not large enough, -- 'MSG_CTRUNC' is returned - -> [MsgFlag] -- ^ Message flags - -> IO (SockAddr, [ByteString], [Cmsg], [MsgFlag]) -- ^ Source address, received data, control messages and message flags + -> MsgFlag -- ^ Message flags + -> IO (SockAddr, [ByteString], [Cmsg], MsgFlag) -- ^ Source address, received data, control messages and message flags recvMsg _ [] _ _ = ioError (mkInvalidRecvArgError "Network.Socket.ByteString.recvMsg") recvMsg s sizs clen flags = do bss <- mapM newBS sizs @@ -313,7 +312,7 @@ recvMsg s sizs clen flags = do , msgCtrlLen = fromIntegral clen , msgFlags = 0 } - cflags = packBits msgFlagMapping flags + cflags = fromMsgFlag flags withFdSocket s $ \fd -> do with msgHdr $ \msgHdrPtr -> do len <- fromIntegral <$> throwSocketErrorWaitRead s "Network.Socket.ByteString.recvmg" (c_recvmsg fd msgHdrPtr cflags) @@ -325,7 +324,7 @@ recvMsg s sizs clen flags = do sockaddr <- peekSocketAddress addrPtr `catchIOError` \_ -> getPeerName s hdr <- peek msgHdrPtr cmsgs <- parseCmsgs msgHdrPtr - let flags' = unpackBits msgFlagMapping $ msgFlags hdr + let flags' = MsgFlag $ msgFlags hdr return (sockaddr, bss', cmsgs, flags') newBS :: Int -> IO ByteString diff --git a/tests/Network/Socket/ByteStringSpec.hs b/tests/Network/Socket/ByteStringSpec.hs index 7d6d36c3..ff557582 100644 --- a/tests/Network/Socket/ByteStringSpec.hs +++ b/tests/Network/Socket/ByteStringSpec.hs @@ -2,6 +2,7 @@ module Network.Socket.ByteStringSpec (main, spec) where +import Data.Bits import qualified Data.ByteString as S import qualified Data.ByteString.Char8 as C import Network.Socket @@ -178,7 +179,7 @@ spec = do describe "sendMsg" $ do it "works well" $ do let server sock = recv sock 1024 `shouldReturn` S.append seg1 seg2 - client sock addr = sendMsg sock addr [seg1, seg2] [] [] + client sock addr = sendMsg sock addr [seg1, seg2] [] mempty seg1 = C.pack "This is a " seg2 = C.pack "test message." @@ -188,7 +189,7 @@ spec = do let server _ = return () client sock addr = do close sock - sendMsg sock addr [seg1, seg2] [] [] `shouldThrow` anyException + sendMsg sock addr [seg1, seg2] [] mempty `shouldThrow` anyException seg1 = C.pack "This is a " seg2 = C.pack "test message." @@ -197,10 +198,10 @@ spec = do describe "recvMsg" $ do it "works well" $ do let server sock = do - (_, msgs, cmsgs, flags) <- recvMsg sock [1024] 0 [] + (_, msgs, cmsgs, flags) <- recvMsg sock [1024] 0 mempty S.concat msgs `shouldBe` seg cmsgs `shouldBe` [] - flags `shouldBe` [] + flags `shouldBe` mempty client sock addr = sendTo sock seg addr seg = C.pack "This is a test message" @@ -208,7 +209,7 @@ spec = do it "receives message fragments" $ do let server sock = do - (_, msgs, _, _) <- recvMsg sock [1,2,3,4] 0 [] + (_, msgs, _, _) <- recvMsg sock [1,2,3,4] 0 mempty S.concat msgs `shouldBe` S.take 10 seg client sock addr = sendTo sock seg addr @@ -217,7 +218,7 @@ spec = do it "receives message fragments with truncation" $ do let server sock = do - (_, msgs, _, _) <- recvMsg sock [10,10,10,10] 0 [] + (_, msgs, _, _) <- recvMsg sock [10,10,10,10] 0 mempty msgs `shouldBe` ["0123456789", "0123456789", "012345"] client sock addr = sendTo sock seg addr @@ -226,8 +227,8 @@ spec = do it "receives truncated flag" $ do let server sock = do - (_, _, _, flags) <- recvMsg sock [S.length seg - 2] 0 [] - flags `shouldContain` [MSG_TRUNC] + (_, _, _, flags) <- recvMsg sock [S.length seg - 2] 0 mempty + flags .&. MSG_TRUNC `shouldBe` MSG_TRUNC client sock addr = sendTo sock seg addr seg = C.pack "This is a test message" @@ -235,9 +236,9 @@ spec = do it "peek" $ do let server sock = do - (_, msgs, _, _flags) <- recvMsg sock [1024] 0 [MSG_PEEK] - -- flags `shouldContain` [MSG_PEEK] -- Mac only - (_, msgs', _, _) <- recvMsg sock [1024] 0 [] + (_, msgs, _, _flags) <- recvMsg sock [1024] 0 MSG_PEEK + -- flags .&. MSG_PEEK `shouldBe` MSG_PEEK -- Mac only + (_, msgs', _, _) <- recvMsg sock [1024] 0 mempty msgs `shouldBe` msgs' client sock addr = sendTo sock seg addr @@ -249,7 +250,7 @@ spec = do setSocketOption sock RecvIPv4TTL 1 setSocketOption sock RecvIPv4TOS 1 setSocketOption sock RecvIPv4PktInfo 1 - (_, _, cmsgs, _) <- recvMsg sock [1024] 128 [] + (_, _, cmsgs, _) <- recvMsg sock [1024] 128 mempty ((lookupAncillary ancillaryIPv4TTL cmsgs >>= ancillaryDecode) :: Maybe IPv4TTL) `shouldNotBe` Nothing ((lookupAncillary ancillaryIPv4TOS cmsgs >>= ancillaryDecode) :: Maybe IPv4TOS) `shouldNotBe` Nothing @@ -264,7 +265,7 @@ spec = do setSocketOption sock RecvIPv6HopLimit 1 setSocketOption sock RecvIPv6TClass 1 setSocketOption sock RecvIPv6PktInfo 1 - (_, _, cmsgs, _) <- recvMsg sock [1024] 128 [] + (_, _, cmsgs, _) <- recvMsg sock [1024] 128 mempty ((lookupAncillary ancillaryIPv6HopLimit cmsgs >>= ancillaryDecode) :: Maybe IPv6HopLimit) `shouldNotBe` Nothing ((lookupAncillary ancillaryIPv6TClass cmsgs >>= ancillaryDecode) :: Maybe IPv6TClass) `shouldNotBe` Nothing @@ -279,8 +280,8 @@ spec = do setSocketOption sock RecvIPv4TTL 1 setSocketOption sock RecvIPv4TOS 1 setSocketOption sock RecvIPv4PktInfo 1 - (_, _, _, flags) <- recvMsg sock [1024] 10 [] - flags `shouldContain` [MSG_CTRUNC] + (_, _, _, flags) <- recvMsg sock [1024] 10 mempty + flags .&. MSG_CTRUNC `shouldBe` MSG_CTRUNC client sock addr = sendTo sock seg addr From a52c4df12fe292d7c2435375b567f0c5aa301fcb Mon Sep 17 00:00:00 2001 From: Kazu Yamamoto Date: Thu, 9 Jan 2020 16:10:48 +0900 Subject: [PATCH 14/48] making Posix directory. --- Network/Socket/ByteString.hs | 4 ++-- Network/Socket/ByteString/IO.hsc | 8 ++++---- Network/Socket/ByteString/Internal.hs | 4 ++-- Network/Socket/ByteString/Lazy/Posix.hs | 4 ++-- Network/Socket/{ByteString => }/Flag.hsc | 2 +- Network/Socket/{ByteString => Posix}/Ancillary.hsc | 4 ++-- Network/Socket/{ByteString => Posix}/Cmsg.hsc | 4 ++-- Network/Socket/{ByteString => Posix}/IOVec.hsc | 2 +- Network/Socket/{ByteString => Posix}/MsgHdr.hsc | 4 ++-- network.cabal | 14 +++++++------- 10 files changed, 25 insertions(+), 25 deletions(-) rename Network/Socket/{ByteString => }/Flag.hsc (97%) rename Network/Socket/{ByteString => Posix}/Ancillary.hsc (98%) rename Network/Socket/{ByteString => Posix}/Cmsg.hsc (97%) rename Network/Socket/{ByteString => Posix}/IOVec.hsc (94%) rename Network/Socket/{ByteString => Posix}/MsgHdr.hsc (95%) diff --git a/Network/Socket/ByteString.hs b/Network/Socket/ByteString.hs index 3dc3f994..52a8142c 100644 --- a/Network/Socket/ByteString.hs +++ b/Network/Socket/ByteString.hs @@ -60,10 +60,10 @@ module Network.Socket.ByteString import Data.ByteString (ByteString) -import Network.Socket.ByteString.Ancillary -import Network.Socket.ByteString.Flag import qualified Network.Socket.ByteString.IO as G import Network.Socket.ByteString.IO hiding (sendTo, sendAllTo, recvFrom) +import Network.Socket.Flag +import Network.Socket.Posix.Ancillary import Network.Socket.Types -- ---------------------------------------------------------------------------- diff --git a/Network/Socket/ByteString/IO.hsc b/Network/Socket/ByteString/IO.hsc index ff3f058d..bf21e46d 100644 --- a/Network/Socket/ByteString/IO.hsc +++ b/Network/Socket/ByteString/IO.hsc @@ -56,11 +56,11 @@ import Foreign.Marshal.Utils (with) import Network.Socket.Internal import System.IO.Error (catchIOError) -import Network.Socket.ByteString.Cmsg -import Network.Socket.ByteString.Flag -import Network.Socket.ByteString.IOVec (IOVec(..)) -import Network.Socket.ByteString.MsgHdr (MsgHdr(..)) +import Network.Socket.Flag import Network.Socket.Name (getPeerName) +import Network.Socket.Posix.Cmsg +import Network.Socket.Posix.IOVec (IOVec(..)) +import Network.Socket.Posix.MsgHdr (MsgHdr(..)) #endif -- ---------------------------------------------------------------------------- diff --git a/Network/Socket/ByteString/Internal.hs b/Network/Socket/ByteString/Internal.hs index 3d6aea48..4e9ebbad 100644 --- a/Network/Socket/ByteString/Internal.hs +++ b/Network/Socket/ByteString/Internal.hs @@ -26,8 +26,8 @@ import System.IO.Error (ioeSetErrorString, mkIOError) import System.Posix.Types (CSsize(..)) import Network.Socket.Imports -import Network.Socket.ByteString.IOVec (IOVec) -import Network.Socket.ByteString.MsgHdr (MsgHdr) +import Network.Socket.Posix.IOVec (IOVec) +import Network.Socket.Posix.MsgHdr (MsgHdr) #endif mkInvalidRecvArgError :: String -> IOError diff --git a/Network/Socket/ByteString/Lazy/Posix.hs b/Network/Socket/ByteString/Lazy/Posix.hs index 6eae5517..6f629837 100644 --- a/Network/Socket/ByteString/Lazy/Posix.hs +++ b/Network/Socket/ByteString/Lazy/Posix.hs @@ -11,11 +11,11 @@ import qualified Data.ByteString.Lazy as L import Data.ByteString.Unsafe (unsafeUseAsCStringLen) import Foreign.Marshal.Array (allocaArray) -import Network.Socket.ByteString.Internal (c_writev) import Network.Socket.ByteString.IO (waitWhen0) -import Network.Socket.ByteString.IOVec (IOVec (IOVec)) +import Network.Socket.ByteString.Internal (c_writev) import Network.Socket.Imports import Network.Socket.Internal +import Network.Socket.Posix.IOVec (IOVec (IOVec)) import Network.Socket.Types -- ----------------------------------------------------------------------------- diff --git a/Network/Socket/ByteString/Flag.hsc b/Network/Socket/Flag.hsc similarity index 97% rename from Network/Socket/ByteString/Flag.hsc rename to Network/Socket/Flag.hsc index db445373..40ef2d96 100644 --- a/Network/Socket/ByteString/Flag.hsc +++ b/Network/Socket/Flag.hsc @@ -3,7 +3,7 @@ #include "HsNet.h" -module Network.Socket.ByteString.Flag where +module Network.Socket.Flag where import Network.Socket.Imports diff --git a/Network/Socket/ByteString/Ancillary.hsc b/Network/Socket/Posix/Ancillary.hsc similarity index 98% rename from Network/Socket/ByteString/Ancillary.hsc rename to Network/Socket/Posix/Ancillary.hsc index 8b135478..cc8cb54b 100644 --- a/Network/Socket/ByteString/Ancillary.hsc +++ b/Network/Socket/Posix/Ancillary.hsc @@ -1,7 +1,7 @@ {-# LANGUAGE CPP #-} {-# LANGUAGE ScopedTypeVariables #-} -module Network.Socket.ByteString.Ancillary where +module Network.Socket.Posix.Ancillary where #include "HsNet.h" @@ -12,8 +12,8 @@ import Data.ByteString.Internal import Foreign.ForeignPtr import System.IO.Unsafe (unsafeDupablePerformIO) -import Network.Socket.ByteString.Cmsg import Network.Socket.Imports +import Network.Socket.Posix.Cmsg import Network.Socket.Types ---------------------------------------------------------------- diff --git a/Network/Socket/ByteString/Cmsg.hsc b/Network/Socket/Posix/Cmsg.hsc similarity index 97% rename from Network/Socket/ByteString/Cmsg.hsc rename to Network/Socket/Posix/Cmsg.hsc index 8337ef4a..b0f4df7c 100644 --- a/Network/Socket/ByteString/Cmsg.hsc +++ b/Network/Socket/Posix/Cmsg.hsc @@ -2,7 +2,7 @@ #include "HsNet.h" -module Network.Socket.ByteString.Cmsg ( +module Network.Socket.Posix.Cmsg ( Cmsg(..) , withCmsgs , parseCmsgs @@ -16,8 +16,8 @@ import Foreign.ForeignPtr import qualified Data.ByteString as B import Data.ByteString.Internal -import Network.Socket.ByteString.MsgHdr import Network.Socket.Imports +import Network.Socket.Posix.MsgHdr import Network.Socket.Types -- | Control message including a pair of level and type. diff --git a/Network/Socket/ByteString/IOVec.hsc b/Network/Socket/Posix/IOVec.hsc similarity index 94% rename from Network/Socket/ByteString/IOVec.hsc rename to Network/Socket/Posix/IOVec.hsc index e2b85fc4..fe6a4c09 100644 --- a/Network/Socket/ByteString/IOVec.hsc +++ b/Network/Socket/Posix/IOVec.hsc @@ -1,7 +1,7 @@ {-# OPTIONS_GHC -funbox-strict-fields #-} -- | Support module for the POSIX writev system call. -module Network.Socket.ByteString.IOVec +module Network.Socket.Posix.IOVec ( IOVec(..) ) where diff --git a/Network/Socket/ByteString/MsgHdr.hsc b/Network/Socket/Posix/MsgHdr.hsc similarity index 95% rename from Network/Socket/ByteString/MsgHdr.hsc rename to Network/Socket/Posix/MsgHdr.hsc index 13c36f6b..df3fd198 100644 --- a/Network/Socket/ByteString/MsgHdr.hsc +++ b/Network/Socket/Posix/MsgHdr.hsc @@ -1,7 +1,7 @@ {-# OPTIONS_GHC -funbox-strict-fields #-} -- | Support module for the POSIX 'sendmsg' system call. -module Network.Socket.ByteString.MsgHdr +module Network.Socket.Posix.MsgHdr ( MsgHdr(..) ) where @@ -12,7 +12,7 @@ import Network.Socket.Imports import Network.Socket.Internal (zeroMemory) import Network.Socket.Types (SockAddr) -import Network.Socket.ByteString.IOVec (IOVec) +import Network.Socket.Posix.IOVec (IOVec) data MsgHdr = MsgHdr { msgName :: !(Ptr SockAddr) diff --git a/network.cabal b/network.cabal index aca610ab..32f0b150 100644 --- a/network.cabal +++ b/network.cabal @@ -64,13 +64,14 @@ library Network.Socket.Internal other-modules: Network.Socket.Buffer - Network.Socket.ByteString.Internal Network.Socket.ByteString.IO + Network.Socket.ByteString.Internal Network.Socket.Cbits Network.Socket.Fcntl + Network.Socket.Flag Network.Socket.Handle - Network.Socket.Imports Network.Socket.If + Network.Socket.Imports Network.Socket.Info Network.Socket.Name Network.Socket.Options @@ -96,12 +97,11 @@ library -- Add some platform specific stuff if !os(windows) other-modules: - Network.Socket.ByteString.Flag - Network.Socket.ByteString.IOVec Network.Socket.ByteString.Lazy.Posix - Network.Socket.ByteString.MsgHdr - Network.Socket.ByteString.Cmsg - Network.Socket.ByteString.Ancillary + Network.Socket.Posix.Ancillary + Network.Socket.Posix.Cmsg + Network.Socket.Posix.IOVec + Network.Socket.Posix.MsgHdr c-sources: cbits/ancilData.c cbits/cmsg.c if os(solaris) From 3d666ce5b471e0a71cbf0e1b7a71657bf700d399 Mon Sep 17 00:00:00 2001 From: Kazu Yamamoto Date: Fri, 10 Jan 2020 09:58:53 +0900 Subject: [PATCH 15/48] implementing sendBufMsg and recvBufMsg. --- Network/Socket.hs | 26 +++++- Network/Socket/Buffer.hsc | 80 ++++++++++++++++++- Network/Socket/ByteString.hs | 21 ----- Network/Socket/ByteString/IO.hsc | 102 +++++++----------------- Network/Socket/ByteString/Lazy/Posix.hs | 2 +- Network/Socket/Posix/IOVec.hsc | 19 ++++- 6 files changed, 154 insertions(+), 96 deletions(-) diff --git a/Network/Socket.hs b/Network/Socket.hs index 375e7e91..d5239ea5 100644 --- a/Network/Socket.hs +++ b/Network/Socket.hs @@ -205,7 +205,28 @@ module Network.Socket , recvBuf , sendBufTo , recvBufFrom - + -- ** IO with ancillary data + , sendBufMsg + , recvBufMsg + , MsgFlag(MSG_OOB,MSG_DONTROUTE,MSG_PEEK,MSG_EOR,MSG_TRUNC,MSG_CTRUNC,MSG_WAITALL) + , Cmsg(..) + -- ** Ancillary data + , Ancillary(..) + , AncillaryID + , ancillaryIPv4TTL + , ancillaryIPv6HopLimit + , ancillaryIPv4TOS + , ancillaryIPv6TClass + , ancillaryIPv4PktInfo + , ancillaryIPv6PktInfo + , lookupAncillary + -- ** Types + , IPv4TTL(..) + , IPv6HopLimit(..) + , IPv4TOS(..) + , IPv6TClass(..) + , IPv4PktInfo(..) + , IPv6PktInfo(..) -- * Special constants , maxListenQueue ) where @@ -213,12 +234,15 @@ module Network.Socket import Network.Socket.Buffer hiding (sendBufTo, recvBufFrom) import Network.Socket.Cbits import Network.Socket.Fcntl +import Network.Socket.Flag import Network.Socket.Handle import Network.Socket.If import Network.Socket.Info import Network.Socket.Internal import Network.Socket.Name hiding (getPeerName, getSocketName) import Network.Socket.Options +import Network.Socket.Posix.Ancillary +import Network.Socket.Posix.Cmsg import Network.Socket.Shutdown import Network.Socket.SockAddr import Network.Socket.Syscall hiding (connect, bind, accept) diff --git a/Network/Socket/Buffer.hsc b/Network/Socket/Buffer.hsc index 8d355520..91d5a2ee 100644 --- a/Network/Socket/Buffer.hsc +++ b/Network/Socket/Buffer.hsc @@ -11,12 +11,15 @@ module Network.Socket.Buffer ( , recvBufFrom , recvBuf , recvBufNoWait + , sendBufMsg + , recvBufMsg ) where #if !defined(mingw32_HOST_OS) import Foreign.C.Error (getErrno, eAGAIN, eWOULDBLOCK) #endif -import Foreign.Marshal.Alloc (alloca) +import Foreign.Marshal.Alloc (alloca, allocaBytes) +import Foreign.Marshal.Utils (with) import GHC.IO.Exception (IOErrorType(InvalidArgument)) import System.IO.Error (mkIOError, ioeSetErrorString, catchIOError) @@ -28,6 +31,10 @@ import Network.Socket.Imports import Network.Socket.Internal import Network.Socket.Name import Network.Socket.Types +import Network.Socket.Posix.Cmsg +import Network.Socket.Posix.MsgHdr +import Network.Socket.Posix.IOVec +import Network.Socket.Flag -- | Send data to the socket. The recipient can be specified -- explicitly, so the socket need not be in a connected state. @@ -178,6 +185,72 @@ mkInvalidRecvArgError loc = ioeSetErrorString (mkIOError InvalidArgument loc Nothing Nothing) "non-positive length" +-- | Send data from the socket using sendmsg(2). +sendBufMsg :: Socket -- ^ Socket + -> SockAddr -- ^ Destination address -- fixme + -> [(Ptr Word8,Int)] -- ^ Data to be sent + -> [Cmsg] -- ^ Control messages + -> MsgFlag -- ^ Message flags + -> IO Int -- ^ The length actually sent +sendBufMsg _ _ [] _ _ = return 0 +sendBufMsg s addr bufsizs cmsgs flags = do + sz <- withSockAddr addr $ \addrPtr addrSize -> + withIOVec bufsizs $ \(iovsPtr, iovsLen) -> do + withCmsgs cmsgs $ \ctrlPtr ctrlLen -> do + let msgHdr = MsgHdr { + msgName = addrPtr + , msgNameLen = fromIntegral addrSize + , msgIov = iovsPtr + , msgIovLen = fromIntegral iovsLen + , msgCtrl = castPtr ctrlPtr + , msgCtrlLen = fromIntegral ctrlLen + , msgFlags = 0 + } + cflags = fromMsgFlag flags + withFdSocket s $ \fd -> + with msgHdr $ \msgHdrPtr -> + throwSocketErrorWaitWrite s "Network.Socket.ByteString.sendMsg" $ + c_sendmsg fd msgHdrPtr cflags + return $ fromIntegral sz + +-- | Receive data from the socket using recvmsg(2). +-- The receive buffers are created according to the second argument. +-- If the length of received data is less than the total of +-- the second argument, the buffers are truncated properly. +-- So, only the received data can be seen. +recvBufMsg :: Socket -- ^ Socket + -> [(Ptr Word8,Int)] -- ^ A list of a pair of buffer and its size + -- If the total length is not large enough, + -- 'MSG_TRUNC' is returned + -> Int -- ^ The buffer size for control messages. + -- If the length is not large enough, + -- 'MSG_CTRUNC' is returned + -> MsgFlag -- ^ Message flags + -> IO (SockAddr,Int,[Cmsg],MsgFlag) -- ^ Source address, received data, control messages and message flags +recvBufMsg _ [] _ _ = ioError (mkInvalidRecvArgError "Network.Socket.ByteString.recvMsg") +recvBufMsg s bufsizs clen flags = do + withNewSocketAddress $ \addrPtr addrSize -> + allocaBytes clen $ \ctrlPtr -> + withIOVec bufsizs $ \(iovsPtr, iovsLen) -> do + let msgHdr = MsgHdr { + msgName = addrPtr + , msgNameLen = fromIntegral addrSize + , msgIov = iovsPtr + , msgIovLen = fromIntegral iovsLen + , msgCtrl = castPtr ctrlPtr + , msgCtrlLen = fromIntegral clen + , msgFlags = 0 + } + cflags = fromMsgFlag flags + withFdSocket s $ \fd -> do + with msgHdr $ \msgHdrPtr -> do + len <- fromIntegral <$> throwSocketErrorWaitRead s "Network.Socket.ByteString.recvmg" (c_recvmsg fd msgHdrPtr cflags) + sockaddr <- peekSocketAddress addrPtr `catchIOError` \_ -> getPeerName s + hdr <- peek msgHdrPtr + cmsgs <- parseCmsgs msgHdrPtr + let flags' = MsgFlag $ msgFlags hdr + return (sockaddr, len, cmsgs, flags') + #if !defined(mingw32_HOST_OS) foreign import ccall unsafe "send" c_send :: CInt -> Ptr a -> CSize -> CInt -> IO CInt @@ -193,3 +266,8 @@ foreign import CALLCONV SAFE_ON_WIN "sendto" c_sendto :: CInt -> Ptr a -> CSize -> CInt -> Ptr sa -> CInt -> IO CInt foreign import CALLCONV SAFE_ON_WIN "recvfrom" c_recvfrom :: CInt -> Ptr a -> CSize -> CInt -> Ptr sa -> Ptr CInt -> IO CInt + +foreign import ccall unsafe "sendmsg" + c_sendmsg :: CInt -> Ptr MsgHdr -> CInt -> IO CInt -- fixme CSsize +foreign import ccall unsafe "recvmsg" + c_recvmsg :: CInt -> Ptr MsgHdr -> CInt -> IO CInt diff --git a/Network/Socket/ByteString.hs b/Network/Socket/ByteString.hs index 52a8142c..f5810212 100644 --- a/Network/Socket/ByteString.hs +++ b/Network/Socket/ByteString.hs @@ -37,33 +37,12 @@ module Network.Socket.ByteString -- * Advanced send and recv , sendMsg , recvMsg - , MsgFlag(MSG_OOB,MSG_DONTROUTE,MSG_PEEK,MSG_EOR,MSG_TRUNC,MSG_CTRUNC,MSG_WAITALL) - , Cmsg(..) - -- ** Ancillary data - , Ancillary(..) - , AncillaryID - , ancillaryIPv4TTL - , ancillaryIPv6HopLimit - , ancillaryIPv4TOS - , ancillaryIPv6TClass - , ancillaryIPv4PktInfo - , ancillaryIPv6PktInfo - , lookupAncillary - -- ** Types - , IPv4TTL(..) - , IPv6HopLimit(..) - , IPv4TOS(..) - , IPv6TClass(..) - , IPv4PktInfo(..) - , IPv6PktInfo(..) ) where import Data.ByteString (ByteString) import qualified Network.Socket.ByteString.IO as G import Network.Socket.ByteString.IO hiding (sendTo, sendAllTo, recvFrom) -import Network.Socket.Flag -import Network.Socket.Posix.Ancillary import Network.Socket.Types -- ---------------------------------------------------------------------------- diff --git a/Network/Socket/ByteString/IO.hsc b/Network/Socket/ByteString/IO.hsc index bf21e46d..15a32b1d 100644 --- a/Network/Socket/ByteString/IO.hsc +++ b/Network/Socket/ByteString/IO.hsc @@ -51,15 +51,13 @@ import Network.Socket.Types #if !defined(mingw32_HOST_OS) import Data.ByteString.Internal (create, ByteString(..)) -import Foreign.Marshal.Array (allocaArray) +import Foreign.ForeignPtr (withForeignPtr) import Foreign.Marshal.Utils (with) import Network.Socket.Internal -import System.IO.Error (catchIOError) import Network.Socket.Flag -import Network.Socket.Name (getPeerName) import Network.Socket.Posix.Cmsg -import Network.Socket.Posix.IOVec (IOVec(..)) +import Network.Socket.Posix.IOVec import Network.Socket.Posix.MsgHdr (MsgHdr(..)) #endif @@ -140,7 +138,7 @@ sendMany s cs = do when (sent >= 0) $ sendMany s $ remainingChunks sent cs where sendManyInner = - fmap fromIntegral . withIOVec cs $ \(iovsPtr, iovsLen) -> + fmap fromIntegral . withIOVecfromBS cs $ \(iovsPtr, iovsLen) -> withFdSocket s $ \fd -> do let len = fromIntegral $ min iovsLen (#const IOV_MAX) throwSocketErrorWaitWrite s "Network.Socket.ByteString.sendMany" $ @@ -168,7 +166,7 @@ sendManyTo s cs addr = do where sendManyToInner = withSockAddr addr $ \addrPtr addrSize -> - withIOVec cs $ \(iovsPtr, iovsLen) -> do + withIOVecfromBS cs $ \(iovsPtr, iovsLen) -> do let msgHdr = MsgHdr { msgName = addrPtr , msgNameLen = fromIntegral addrSize @@ -238,23 +236,19 @@ remainingChunks i (x:xs) where len = B.length x --- | @withIOVec cs f@ executes the computation @f@, passing as argument a pair +-- | @withIOVecfromBS cs f@ executes the computation @f@, passing as argument a pair -- consisting of a pointer to a temporarily allocated array of pointers to -- IOVec made from @cs@ and the number of pointers (@length cs@). -- /Unix only/. -withIOVec :: [ByteString] -> ((Ptr IOVec, Int) -> IO a) -> IO a -withIOVec cs f = - allocaArray csLen $ \aPtr -> do - zipWithM_ pokeIov (ptrs aPtr) cs - f (aPtr, csLen) - where - csLen = length cs - ptrs = iterate (`plusPtr` sizeOf (undefined :: IOVec)) - pokeIov ptr s = - unsafeUseAsCStringLen s $ \(sPtr, sLen) -> - poke ptr $ IOVec sPtr (fromIntegral sLen) +withIOVecfromBS :: [ByteString] -> ((Ptr IOVec, Int) -> IO a) -> IO a +withIOVecfromBS cs f = do + bufsizs <- mapM getBufsiz cs + withIOVec bufsizs f #endif +getBufsiz :: ByteString -> IO (Ptr Word8, Int) +getBufsiz (PS fptr off len) = withForeignPtr fptr $ \ptr -> return (ptr `plusPtr` off, len) + -- | Send data from the socket using sendmsg(2). sendMsg :: Socket -- ^ Socket -> SockAddr -- ^ Destination address @@ -264,68 +258,34 @@ sendMsg :: Socket -- ^ Socket -> IO Int -- ^ The length actually sent sendMsg _ _ [] _ _ = return 0 sendMsg s addr bss cmsgs flags = do - sz <- withSockAddr addr $ \addrPtr addrSize -> - withIOVec bss $ \(iovsPtr, iovsLen) -> do - withCmsgs cmsgs $ \ctrlPtr ctrlLen -> do - let msgHdr = MsgHdr { - msgName = addrPtr - , msgNameLen = fromIntegral addrSize - , msgIov = iovsPtr - , msgIovLen = fromIntegral iovsLen - , msgCtrl = castPtr ctrlPtr - , msgCtrlLen = fromIntegral ctrlLen - , msgFlags = 0 - } - cflags = fromMsgFlag flags - withFdSocket s $ \fd -> - with msgHdr $ \msgHdrPtr -> - throwSocketErrorWaitWrite s "Network.Socket.ByteString.sendMsg" $ - c_sendmsg fd msgHdrPtr cflags - return $ fromIntegral sz + bufsizs <- mapM getBufsiz bss + sendBufMsg s addr bufsizs cmsgs flags -- | Receive data from the socket using recvmsg(2). -- The receive buffers are created according to the second argument. -- If the length of received data is less than the total of -- the second argument, the buffers are truncated properly. -- So, only the received data can be seen. -recvMsg :: Socket -- ^ Socket - -> [Int] -- ^ A list of length of data to be received - -- If the total length is not large enough, - -- 'MSG_TRUNC' is returned - -> Int -- ^ The buffer size for control messages. - -- If the length is not large enough, - -- 'MSG_CTRUNC' is returned - -> MsgFlag -- ^ Message flags +recvMsg :: Socket -- ^ Socket + -> [Int] -- ^ A list of length of data to be received + -- If the total length is not large enough, + -- 'MSG_TRUNC' is returned + -> Int -- ^ The buffer size for control messages. + -- If the length is not large enough, + -- 'MSG_CTRUNC' is returned + -> MsgFlag -- ^ Message flags -> IO (SockAddr, [ByteString], [Cmsg], MsgFlag) -- ^ Source address, received data, control messages and message flags recvMsg _ [] _ _ = ioError (mkInvalidRecvArgError "Network.Socket.ByteString.recvMsg") recvMsg s sizs clen flags = do - bss <- mapM newBS sizs - withNewSocketAddress $ \addrPtr addrSize -> - allocaBytes clen $ \ctrlPtr -> - withIOVec bss $ \(iovsPtr, iovsLen) -> do - let msgHdr = MsgHdr { - msgName = addrPtr - , msgNameLen = fromIntegral addrSize - , msgIov = iovsPtr - , msgIovLen = fromIntegral iovsLen - , msgCtrl = castPtr ctrlPtr - , msgCtrlLen = fromIntegral clen - , msgFlags = 0 - } - cflags = fromMsgFlag flags - withFdSocket s $ \fd -> do - with msgHdr $ \msgHdrPtr -> do - len <- fromIntegral <$> throwSocketErrorWaitRead s "Network.Socket.ByteString.recvmg" (c_recvmsg fd msgHdrPtr cflags) - let total = sum sizs - let bss' = case len `compare` total of - EQ -> bss - LT -> trunc bss len - GT -> error "recvMsg" -- never reach - sockaddr <- peekSocketAddress addrPtr `catchIOError` \_ -> getPeerName s - hdr <- peek msgHdrPtr - cmsgs <- parseCmsgs msgHdrPtr - let flags' = MsgFlag $ msgFlags hdr - return (sockaddr, bss', cmsgs, flags') + bss <- mapM newBS sizs + bufsizs <- mapM getBufsiz bss + (addr,len,cmsgs,flags') <- recvBufMsg s bufsizs clen flags + let total = sum sizs + let bss' = case len `compare` total of + EQ -> bss + LT -> trunc bss len + GT -> error "recvMsg" -- never reach + return (addr, bss', cmsgs, flags') newBS :: Int -> IO ByteString newBS n = create n $ \ptr -> zeroMemory ptr (fromIntegral n) diff --git a/Network/Socket/ByteString/Lazy/Posix.hs b/Network/Socket/ByteString/Lazy/Posix.hs index 6f629837..fa1c4fde 100644 --- a/Network/Socket/ByteString/Lazy/Posix.hs +++ b/Network/Socket/ByteString/Lazy/Posix.hs @@ -36,7 +36,7 @@ send s lbs = do where loop (c:cs) q k !niovs | k < maxNumBytes = unsafeUseAsCStringLen c $ \(ptr, len) -> do - poke q $ IOVec ptr (fromIntegral len) + poke q $ IOVec (castPtr ptr) (fromIntegral len) loop cs (q `plusPtr` sizeOf (undefined :: IOVec)) (k + fromIntegral len) diff --git a/Network/Socket/Posix/IOVec.hsc b/Network/Socket/Posix/IOVec.hsc index fe6a4c09..0f5f8a40 100644 --- a/Network/Socket/Posix/IOVec.hsc +++ b/Network/Socket/Posix/IOVec.hsc @@ -3,15 +3,18 @@ -- | Support module for the POSIX writev system call. module Network.Socket.Posix.IOVec ( IOVec(..) + , withIOVec ) where +import Foreign.Marshal.Array (allocaArray) + import Network.Socket.Imports #include #include data IOVec = IOVec - { iovBase :: !(Ptr CChar) + { iovBase :: !(Ptr Word8) , iovLen :: !CSize } @@ -27,3 +30,17 @@ instance Storable IOVec where poke p iov = do (#poke struct iovec, iov_base) p (iovBase iov) (#poke struct iovec, iov_len) p (iovLen iov) + +-- | @withIOVec cs f@ executes the computation @f@, passing as argument a pair +-- consisting of a pointer to a temporarily allocated array of pointers to +-- IOVec made from @cs@ and the number of pointers (@length cs@). +-- /Unix only/. +withIOVec :: [(Ptr Word8, Int)] -> ((Ptr IOVec, Int) -> IO a) -> IO a +withIOVec cs f = + allocaArray csLen $ \aPtr -> do + zipWithM_ pokeIov (ptrs aPtr) cs + f (aPtr, csLen) + where + csLen = length cs + ptrs = iterate (`plusPtr` sizeOf (IOVec nullPtr 0)) + pokeIov ptr (sPtr, sLen) = poke ptr $ IOVec sPtr (fromIntegral sLen) From d5b5d0ed4a1c598d40c98a9813a4c3f7409d6a9b Mon Sep 17 00:00:00 2001 From: Kazu Yamamoto Date: Fri, 10 Jan 2020 10:35:32 +0900 Subject: [PATCH 16/48] recvMsg now takes Int instead of [Int]. --- Network/Socket/ByteString/IO.hsc | 38 ++++++-------------------- tests/Network/Socket/ByteStringSpec.hs | 34 ++++++----------------- 2 files changed, 17 insertions(+), 55 deletions(-) diff --git a/Network/Socket/ByteString/IO.hsc b/Network/Socket/ByteString/IO.hsc index 15a32b1d..7c4946ec 100644 --- a/Network/Socket/ByteString/IO.hsc +++ b/Network/Socket/ByteString/IO.hsc @@ -262,39 +262,19 @@ sendMsg s addr bss cmsgs flags = do sendBufMsg s addr bufsizs cmsgs flags -- | Receive data from the socket using recvmsg(2). --- The receive buffers are created according to the second argument. --- If the length of received data is less than the total of --- the second argument, the buffers are truncated properly. --- So, only the received data can be seen. recvMsg :: Socket -- ^ Socket - -> [Int] -- ^ A list of length of data to be received + -> Int -- ^ The maximum length of data to be received -- If the total length is not large enough, -- 'MSG_TRUNC' is returned -> Int -- ^ The buffer size for control messages. -- If the length is not large enough, -- 'MSG_CTRUNC' is returned -> MsgFlag -- ^ Message flags - -> IO (SockAddr, [ByteString], [Cmsg], MsgFlag) -- ^ Source address, received data, control messages and message flags -recvMsg _ [] _ _ = ioError (mkInvalidRecvArgError "Network.Socket.ByteString.recvMsg") -recvMsg s sizs clen flags = do - bss <- mapM newBS sizs - bufsizs <- mapM getBufsiz bss - (addr,len,cmsgs,flags') <- recvBufMsg s bufsizs clen flags - let total = sum sizs - let bss' = case len `compare` total of - EQ -> bss - LT -> trunc bss len - GT -> error "recvMsg" -- never reach - return (addr, bss', cmsgs, flags') - -newBS :: Int -> IO ByteString -newBS n = create n $ \ptr -> zeroMemory ptr (fromIntegral n) - -trunc :: [ByteString] -> Int -> [ByteString] -trunc bss0 siz0 = loop bss0 siz0 id - where - -- off is always 0 - loop (bs@(PS buf off len):bss) siz build - | siz >= len = loop bss (siz - len) (build . (bs :)) - | otherwise = build [PS buf off siz] - loop _ _ build = build [] + -> IO (SockAddr, ByteString, [Cmsg], MsgFlag) -- ^ Source address, received data, control messages and message flags +recvMsg s siz clen flags = do + bs <- create siz $ \ptr -> zeroMemory ptr (fromIntegral siz) + bufsiz <- getBufsiz bs + (addr,len,cmsgs,flags') <- recvBufMsg s [bufsiz] clen flags + let bs' | len < siz = let PS buf 0 _ = bs in PS buf 0 len + | otherwise = bs + return (addr, bs', cmsgs, flags') diff --git a/tests/Network/Socket/ByteStringSpec.hs b/tests/Network/Socket/ByteStringSpec.hs index ff557582..ffe7aeb6 100644 --- a/tests/Network/Socket/ByteStringSpec.hs +++ b/tests/Network/Socket/ByteStringSpec.hs @@ -198,8 +198,8 @@ spec = do describe "recvMsg" $ do it "works well" $ do let server sock = do - (_, msgs, cmsgs, flags) <- recvMsg sock [1024] 0 mempty - S.concat msgs `shouldBe` seg + (_, msg, cmsgs, flags) <- recvMsg sock 1024 0 mempty + msg `shouldBe` seg cmsgs `shouldBe` [] flags `shouldBe` mempty client sock addr = sendTo sock seg addr @@ -207,27 +207,9 @@ spec = do seg = C.pack "This is a test message" udpTest client server - it "receives message fragments" $ do - let server sock = do - (_, msgs, _, _) <- recvMsg sock [1,2,3,4] 0 mempty - S.concat msgs `shouldBe` S.take 10 seg - client sock addr = sendTo sock seg addr - - seg = C.pack "This is a test message" - udpTest client server - - it "receives message fragments with truncation" $ do - let server sock = do - (_, msgs, _, _) <- recvMsg sock [10,10,10,10] 0 mempty - msgs `shouldBe` ["0123456789", "0123456789", "012345"] - client sock addr = sendTo sock seg addr - - seg = C.pack "01234567890123456789012345" - udpTest client server - it "receives truncated flag" $ do let server sock = do - (_, _, _, flags) <- recvMsg sock [S.length seg - 2] 0 mempty + (_, _, _, flags) <- recvMsg sock (S.length seg - 2) 0 mempty flags .&. MSG_TRUNC `shouldBe` MSG_TRUNC client sock addr = sendTo sock seg addr @@ -236,9 +218,9 @@ spec = do it "peek" $ do let server sock = do - (_, msgs, _, _flags) <- recvMsg sock [1024] 0 MSG_PEEK + (_, msgs, _, _flags) <- recvMsg sock 1024 0 MSG_PEEK -- flags .&. MSG_PEEK `shouldBe` MSG_PEEK -- Mac only - (_, msgs', _, _) <- recvMsg sock [1024] 0 mempty + (_, msgs', _, _) <- recvMsg sock 1024 0 mempty msgs `shouldBe` msgs' client sock addr = sendTo sock seg addr @@ -250,7 +232,7 @@ spec = do setSocketOption sock RecvIPv4TTL 1 setSocketOption sock RecvIPv4TOS 1 setSocketOption sock RecvIPv4PktInfo 1 - (_, _, cmsgs, _) <- recvMsg sock [1024] 128 mempty + (_, _, cmsgs, _) <- recvMsg sock 1024 128 mempty ((lookupAncillary ancillaryIPv4TTL cmsgs >>= ancillaryDecode) :: Maybe IPv4TTL) `shouldNotBe` Nothing ((lookupAncillary ancillaryIPv4TOS cmsgs >>= ancillaryDecode) :: Maybe IPv4TOS) `shouldNotBe` Nothing @@ -265,7 +247,7 @@ spec = do setSocketOption sock RecvIPv6HopLimit 1 setSocketOption sock RecvIPv6TClass 1 setSocketOption sock RecvIPv6PktInfo 1 - (_, _, cmsgs, _) <- recvMsg sock [1024] 128 mempty + (_, _, cmsgs, _) <- recvMsg sock 1024 128 mempty ((lookupAncillary ancillaryIPv6HopLimit cmsgs >>= ancillaryDecode) :: Maybe IPv6HopLimit) `shouldNotBe` Nothing ((lookupAncillary ancillaryIPv6TClass cmsgs >>= ancillaryDecode) :: Maybe IPv6TClass) `shouldNotBe` Nothing @@ -280,7 +262,7 @@ spec = do setSocketOption sock RecvIPv4TTL 1 setSocketOption sock RecvIPv4TOS 1 setSocketOption sock RecvIPv4PktInfo 1 - (_, _, _, flags) <- recvMsg sock [1024] 10 mempty + (_, _, _, flags) <- recvMsg sock 1024 10 mempty flags .&. MSG_CTRUNC `shouldBe` MSG_CTRUNC client sock addr = sendTo sock seg addr From 5aac55f0f71392acaddf36b9b38ec58622f71fb3 Mon Sep 17 00:00:00 2001 From: Kazu Yamamoto Date: Fri, 10 Jan 2020 11:39:17 +0900 Subject: [PATCH 17/48] generalizing sendBufMsg and recvBufMsg. --- Network/Socket.hs | 2 +- Network/Socket/Address.hs | 3 +++ Network/Socket/Buffer.hsc | 36 +++++++++++++-------------- Network/Socket/ByteString/Internal.hs | 5 ++-- Network/Socket/Posix/Cmsg.hsc | 6 ++--- Network/Socket/Posix/MsgHdr.hsc | 7 +++--- Network/Socket/SockAddr.hs | 25 +++++++++++++++++++ 7 files changed, 55 insertions(+), 29 deletions(-) diff --git a/Network/Socket.hs b/Network/Socket.hs index d5239ea5..20ef7c3d 100644 --- a/Network/Socket.hs +++ b/Network/Socket.hs @@ -231,7 +231,7 @@ module Network.Socket , maxListenQueue ) where -import Network.Socket.Buffer hiding (sendBufTo, recvBufFrom) +import Network.Socket.Buffer hiding (sendBufTo, recvBufFrom, sendBufMsg, recvBufMsg) import Network.Socket.Cbits import Network.Socket.Fcntl import Network.Socket.Flag diff --git a/Network/Socket/Address.hs b/Network/Socket/Address.hs index e9be7cb1..415b910c 100644 --- a/Network/Socket/Address.hs +++ b/Network/Socket/Address.hs @@ -16,6 +16,9 @@ module Network.Socket.Address ( -- * Sending and receiving data from a buffer , sendBufTo , recvBufFrom + -- * IO with ancillary data + , sendBufMsg + , recvBufMsg ) where import Network.Socket.ByteString.IO diff --git a/Network/Socket/Buffer.hsc b/Network/Socket/Buffer.hsc index 91d5a2ee..3a783e96 100644 --- a/Network/Socket/Buffer.hsc +++ b/Network/Socket/Buffer.hsc @@ -186,15 +186,16 @@ mkInvalidRecvArgError loc = ioeSetErrorString (mkIOError loc Nothing Nothing) "non-positive length" -- | Send data from the socket using sendmsg(2). -sendBufMsg :: Socket -- ^ Socket - -> SockAddr -- ^ Destination address -- fixme +sendBufMsg :: SocketAddress sa + => Socket -- ^ Socket + -> sa -- ^ Destination address -> [(Ptr Word8,Int)] -- ^ Data to be sent -> [Cmsg] -- ^ Control messages -> MsgFlag -- ^ Message flags -> IO Int -- ^ The length actually sent sendBufMsg _ _ [] _ _ = return 0 -sendBufMsg s addr bufsizs cmsgs flags = do - sz <- withSockAddr addr $ \addrPtr addrSize -> +sendBufMsg s sa bufsizs cmsgs flags = do + sz <- withSocketAddress sa $ \addrPtr addrSize -> withIOVec bufsizs $ \(iovsPtr, iovsLen) -> do withCmsgs cmsgs $ \ctrlPtr ctrlLen -> do let msgHdr = MsgHdr { @@ -214,19 +215,16 @@ sendBufMsg s addr bufsizs cmsgs flags = do return $ fromIntegral sz -- | Receive data from the socket using recvmsg(2). --- The receive buffers are created according to the second argument. --- If the length of received data is less than the total of --- the second argument, the buffers are truncated properly. --- So, only the received data can be seen. -recvBufMsg :: Socket -- ^ Socket - -> [(Ptr Word8,Int)] -- ^ A list of a pair of buffer and its size - -- If the total length is not large enough, - -- 'MSG_TRUNC' is returned - -> Int -- ^ The buffer size for control messages. - -- If the length is not large enough, - -- 'MSG_CTRUNC' is returned - -> MsgFlag -- ^ Message flags - -> IO (SockAddr,Int,[Cmsg],MsgFlag) -- ^ Source address, received data, control messages and message flags +recvBufMsg :: SocketAddress sa + => Socket -- ^ Socket + -> [(Ptr Word8,Int)] -- ^ A list of a pair of buffer and its size. + -- If the total length is not large enough, + -- 'MSG_TRUNC' is returned + -> Int -- ^ The buffer size for control messages. + -- If the length is not large enough, + -- 'MSG_CTRUNC' is returned + -> MsgFlag -- ^ Message flags + -> IO (sa,Int,[Cmsg],MsgFlag) -- ^ Source address, received data, control messages and message flags recvBufMsg _ [] _ _ = ioError (mkInvalidRecvArgError "Network.Socket.ByteString.recvMsg") recvBufMsg s bufsizs clen flags = do withNewSocketAddress $ \addrPtr addrSize -> @@ -268,6 +266,6 @@ foreign import CALLCONV SAFE_ON_WIN "recvfrom" c_recvfrom :: CInt -> Ptr a -> CSize -> CInt -> Ptr sa -> Ptr CInt -> IO CInt foreign import ccall unsafe "sendmsg" - c_sendmsg :: CInt -> Ptr MsgHdr -> CInt -> IO CInt -- fixme CSsize + c_sendmsg :: CInt -> Ptr (MsgHdr sa) -> CInt -> IO CInt -- fixme CSsize foreign import ccall unsafe "recvmsg" - c_recvmsg :: CInt -> Ptr MsgHdr -> CInt -> IO CInt + c_recvmsg :: CInt -> Ptr (MsgHdr sa) -> CInt -> IO CInt diff --git a/Network/Socket/ByteString/Internal.hs b/Network/Socket/ByteString/Internal.hs index 4e9ebbad..b8f8853e 100644 --- a/Network/Socket/ByteString/Internal.hs +++ b/Network/Socket/ByteString/Internal.hs @@ -28,6 +28,7 @@ import System.Posix.Types (CSsize(..)) import Network.Socket.Imports import Network.Socket.Posix.IOVec (IOVec) import Network.Socket.Posix.MsgHdr (MsgHdr) +import Network.Socket.Types #endif mkInvalidRecvArgError :: String -> IOError @@ -40,8 +41,8 @@ foreign import ccall unsafe "writev" c_writev :: CInt -> Ptr IOVec -> CInt -> IO CSsize foreign import ccall unsafe "sendmsg" - c_sendmsg :: CInt -> Ptr MsgHdr -> CInt -> IO CSsize + c_sendmsg :: CInt -> Ptr (MsgHdr SockAddr) -> CInt -> IO CSsize foreign import ccall unsafe "recvmsg" - c_recvmsg :: CInt -> Ptr MsgHdr -> CInt -> IO CSsize + c_recvmsg :: CInt -> Ptr (MsgHdr SockAddr) -> CInt -> IO CSsize #endif diff --git a/Network/Socket/Posix/Cmsg.hsc b/Network/Socket/Posix/Cmsg.hsc index b0f4df7c..bd00946c 100644 --- a/Network/Socket/Posix/Cmsg.hsc +++ b/Network/Socket/Posix/Cmsg.hsc @@ -72,7 +72,7 @@ encodeCmsg ctrlPtr (Cmsg (lvl,typ) (PS fptr off len)) = do dst <- c_cmsg_data ctrlPtr memcpy dst src len -parseCmsgs :: Ptr MsgHdr -> IO [Cmsg] +parseCmsgs :: SocketAddress sa => Ptr (MsgHdr sa) -> IO [Cmsg] parseCmsgs msgptr = do ptr <- c_cmsg_firsthdr msgptr loop ptr id @@ -92,10 +92,10 @@ decodeCmsg ptr = do Cmsg (lvl,typ) <$> create (fromIntegral siz) (\dst -> memcpy dst src siz) foreign import ccall unsafe "cmsg_firsthdr" - c_cmsg_firsthdr :: Ptr MsgHdr -> IO (Ptr CmsgHdr) + c_cmsg_firsthdr :: Ptr (MsgHdr sa) -> IO (Ptr CmsgHdr) foreign import ccall unsafe "cmsg_nxthdr" - c_cmsg_nxthdr :: Ptr MsgHdr -> Ptr CmsgHdr -> IO (Ptr CmsgHdr) + c_cmsg_nxthdr :: Ptr (MsgHdr sa) -> Ptr CmsgHdr -> IO (Ptr CmsgHdr) foreign import ccall unsafe "cmsg_data" c_cmsg_data :: Ptr CmsgHdr -> IO (Ptr Word8) diff --git a/Network/Socket/Posix/MsgHdr.hsc b/Network/Socket/Posix/MsgHdr.hsc index df3fd198..3c812a06 100644 --- a/Network/Socket/Posix/MsgHdr.hsc +++ b/Network/Socket/Posix/MsgHdr.hsc @@ -10,12 +10,11 @@ module Network.Socket.Posix.MsgHdr import Network.Socket.Imports import Network.Socket.Internal (zeroMemory) -import Network.Socket.Types (SockAddr) import Network.Socket.Posix.IOVec (IOVec) -data MsgHdr = MsgHdr - { msgName :: !(Ptr SockAddr) +data MsgHdr sa = MsgHdr + { msgName :: !(Ptr sa) , msgNameLen :: !CUInt , msgIov :: !(Ptr IOVec) , msgIovLen :: !CSize @@ -24,7 +23,7 @@ data MsgHdr = MsgHdr , msgFlags :: !CInt } -instance Storable MsgHdr where +instance Storable (MsgHdr sa) where sizeOf _ = (#const sizeof(struct msghdr)) alignment _ = alignment (undefined :: CInt) diff --git a/Network/Socket/SockAddr.hs b/Network/Socket/SockAddr.hs index a16b2e2b..25049853 100644 --- a/Network/Socket/SockAddr.hs +++ b/Network/Socket/SockAddr.hs @@ -6,12 +6,16 @@ module Network.Socket.SockAddr ( , accept , sendBufTo , recvBufFrom + , sendBufMsg + , recvBufMsg ) where import qualified Network.Socket.Buffer as G import qualified Network.Socket.Name as G import qualified Network.Socket.Syscall as G +import Network.Socket.Flag import Network.Socket.Imports +import Network.Socket.Posix.Cmsg import Network.Socket.Types -- | Getting peer's 'SockAddr'. @@ -64,3 +68,24 @@ sendBufTo = G.sendBufTo -- GHC ticket #1129) recvBufFrom :: Socket -> Ptr a -> Int -> IO (Int, SockAddr) recvBufFrom = G.recvBufFrom + +-- | Send data from the socket using sendmsg(2). +sendBufMsg :: Socket -- ^ Socket + -> SockAddr -- ^ Destination address + -> [(Ptr Word8,Int)] -- ^ Data to be sent + -> [Cmsg] -- ^ Control messages + -> MsgFlag -- ^ Message flags + -> IO Int -- ^ The length actually sent +sendBufMsg = G.sendBufMsg + +-- | Receive data from the socket using recvmsg(2). +recvBufMsg :: Socket -- ^ Socket + -> [(Ptr Word8,Int)] -- ^ A list of a pair of buffer and its size. + -- If the total length is not large enough, + -- 'MSG_TRUNC' is returned + -> Int -- ^ The buffer size for control messages. + -- If the length is not large enough, + -- 'MSG_CTRUNC' is returned + -> MsgFlag -- ^ Message flags + -> IO (SockAddr,Int,[Cmsg],MsgFlag) -- ^ Source address, received data, control messages and message flags +recvBufMsg = G.recvBufMsg From 58ca30c5eb03bc784560a4d38b6218d3dd5af097 Mon Sep 17 00:00:00 2001 From: Kazu Yamamoto Date: Fri, 10 Jan 2020 13:28:47 +0900 Subject: [PATCH 18/48] implementing sendFd and recvFd on sendBufMsg and recvBufMsg. --- Network/Socket/Buffer.hsc | 6 +- Network/Socket/Posix/Ancillary.hsc | 11 +++ Network/Socket/Posix/IOVec.hsc | 1 + Network/Socket/Types.hsc | 5 +- Network/Socket/Unix.hsc | 31 +++++-- cbits/ancilData.c | 132 ----------------------------- network.cabal | 4 +- 7 files changed, 43 insertions(+), 147 deletions(-) delete mode 100644 cbits/ancilData.c diff --git a/Network/Socket/Buffer.hsc b/Network/Socket/Buffer.hsc index 3a783e96..513725c8 100644 --- a/Network/Socket/Buffer.hsc +++ b/Network/Socket/Buffer.hsc @@ -193,7 +193,6 @@ sendBufMsg :: SocketAddress sa -> [Cmsg] -- ^ Control messages -> MsgFlag -- ^ Message flags -> IO Int -- ^ The length actually sent -sendBufMsg _ _ [] _ _ = return 0 sendBufMsg s sa bufsizs cmsgs flags = do sz <- withSocketAddress sa $ \addrPtr addrSize -> withIOVec bufsizs $ \(iovsPtr, iovsLen) -> do @@ -210,7 +209,7 @@ sendBufMsg s sa bufsizs cmsgs flags = do cflags = fromMsgFlag flags withFdSocket s $ \fd -> with msgHdr $ \msgHdrPtr -> - throwSocketErrorWaitWrite s "Network.Socket.ByteString.sendMsg" $ + throwSocketErrorWaitWrite s "Network.Socket.Buffer.sendMsg" $ c_sendmsg fd msgHdrPtr cflags return $ fromIntegral sz @@ -225,7 +224,6 @@ recvBufMsg :: SocketAddress sa -- 'MSG_CTRUNC' is returned -> MsgFlag -- ^ Message flags -> IO (sa,Int,[Cmsg],MsgFlag) -- ^ Source address, received data, control messages and message flags -recvBufMsg _ [] _ _ = ioError (mkInvalidRecvArgError "Network.Socket.ByteString.recvMsg") recvBufMsg s bufsizs clen flags = do withNewSocketAddress $ \addrPtr addrSize -> allocaBytes clen $ \ctrlPtr -> @@ -242,7 +240,7 @@ recvBufMsg s bufsizs clen flags = do cflags = fromMsgFlag flags withFdSocket s $ \fd -> do with msgHdr $ \msgHdrPtr -> do - len <- fromIntegral <$> throwSocketErrorWaitRead s "Network.Socket.ByteString.recvmg" (c_recvmsg fd msgHdrPtr cflags) + len <- fromIntegral <$> throwSocketErrorWaitRead s "Network.Socket.Buffer.recvmg" (c_recvmsg fd msgHdrPtr cflags) sockaddr <- peekSocketAddress addrPtr `catchIOError` \_ -> getPeerName s hdr <- peek msgHdrPtr cmsgs <- parseCmsgs msgHdrPtr diff --git a/Network/Socket/Posix/Ancillary.hsc b/Network/Socket/Posix/Ancillary.hsc index cc8cb54b..448c844d 100644 --- a/Network/Socket/Posix/Ancillary.hsc +++ b/Network/Socket/Posix/Ancillary.hsc @@ -11,6 +11,7 @@ module Network.Socket.Posix.Ancillary where import Data.ByteString.Internal import Foreign.ForeignPtr import System.IO.Unsafe (unsafeDupablePerformIO) +import System.Posix.Types (Fd(..)) import Network.Socket.Imports import Network.Socket.Posix.Cmsg @@ -53,6 +54,10 @@ ancillaryIPv4PktInfo = ((#const IPPROTO_IP), (#const IP_PKTINFO)) ancillaryIPv6PktInfo :: AncillaryID ancillaryIPv6PktInfo = ((#const IPPROTO_IPV6), (#const IPV6_PKTINFO)) +-- | The identifier for 'Fd'. +ancillaryFd :: AncillaryID +ancillaryFd = ((#const SOL_SOCKET), (#const SCM_RIGHTS)) + ---------------------------------------------------------------- -- | Looking up ancillary data. The following shows an example usage: @@ -215,3 +220,9 @@ unpackIPv6PktInfo (PS fptr off len) return $ Just $ IPv6PktInfo (fromIntegral n) ha6 where siz = (#size struct in6_pktinfo) + +---------------------------------------------------------------- + +instance Ancillary Fd where + ancillaryEncode (Fd fd) = Cmsg ancillaryFd $ packCInt $ fromIntegral fd + ancillaryDecode (Cmsg _ bs) = Fd . fromIntegral <$> unpackCInt bs diff --git a/Network/Socket/Posix/IOVec.hsc b/Network/Socket/Posix/IOVec.hsc index 0f5f8a40..90d0b5ae 100644 --- a/Network/Socket/Posix/IOVec.hsc +++ b/Network/Socket/Posix/IOVec.hsc @@ -36,6 +36,7 @@ instance Storable IOVec where -- IOVec made from @cs@ and the number of pointers (@length cs@). -- /Unix only/. withIOVec :: [(Ptr Word8, Int)] -> ((Ptr IOVec, Int) -> IO a) -> IO a +withIOVec [] f = f (nullPtr, 0) withIOVec cs f = allocaArray csLen $ \aPtr -> do zipWithM_ pokeIov (ptrs aPtr) cs diff --git a/Network/Socket/Types.hsc b/Network/Socket/Types.hsc index c854f823..3eb2d465 100644 --- a/Network/Socket/Types.hsc +++ b/Network/Socket/Types.hsc @@ -938,7 +938,10 @@ sockaddrStorageLen = 128 withSocketAddress :: SocketAddress sa => sa -> (Ptr sa -> Int -> IO a) -> IO a withSocketAddress addr f = do let sz = sizeOfSocketAddress addr - allocaBytes sz $ \p -> pokeSocketAddress p addr >> f (castPtr p) sz + if sz == 0 then + f nullPtr 0 + else + allocaBytes sz $ \p -> pokeSocketAddress p addr >> f (castPtr p) sz withNewSocketAddress :: SocketAddress sa => (Ptr sa -> Int -> IO a) -> IO a withNewSocketAddress f = allocaBytes sockaddrStorageLen $ \ptr -> do diff --git a/Network/Socket/Unix.hsc b/Network/Socket/Unix.hsc index ebd03a2f..06ec5856 100644 --- a/Network/Socket/Unix.hsc +++ b/Network/Socket/Unix.hsc @@ -13,7 +13,11 @@ module Network.Socket.Unix ( , getPeerEid ) where +import System.Posix.Types (Fd(..)) + +import Network.Socket.Buffer import Network.Socket.Imports +import Network.Socket.Posix.Ancillary import Network.Socket.Types #if defined(HAVE_GETPEEREID) @@ -122,15 +126,23 @@ isUnixDomainSocketAvailable = True isUnixDomainSocketAvailable = False #endif +data NullSockAddr = NullSockAddr + +instance SocketAddress NullSockAddr where + sizeOfSocketAddress _ = 0 + peekSocketAddress _ = return NullSockAddr + pokeSocketAddress _ _ = return () + -- | Send a file descriptor over a UNIX-domain socket. -- Use this function in the case where 'isUnixDomainSocketAvailable' is -- 'True'. sendFd :: Socket -> CInt -> IO () #if defined(DOMAIN_SOCKET_SUPPORT) -sendFd s outfd = void $ do - withFdSocket s $ \fd -> - throwSocketErrorWaitWrite s "Network.Socket.sendFd" $ c_sendFd fd outfd -foreign import ccall SAFE_ON_WIN "sendFd" c_sendFd :: CInt -> CInt -> IO CInt +sendFd s outfd = void $ allocaBytes dummyBufSize $ \buf -> do + let cmsg = ancillaryEncode $ Fd outfd + sendBufMsg s NullSockAddr [(buf,dummyBufSize)] [cmsg] mempty + where + dummyBufSize = 1 #else sendFd _ _ = error "Network.Socket.sendFd" #endif @@ -142,10 +154,13 @@ sendFd _ _ = error "Network.Socket.sendFd" -- 'True'. recvFd :: Socket -> IO CInt #if defined(DOMAIN_SOCKET_SUPPORT) -recvFd s = do - withFdSocket s $ \fd -> - throwSocketErrorWaitRead s "Network.Socket.recvFd" $ c_recvFd fd -foreign import ccall SAFE_ON_WIN "recvFd" c_recvFd :: CInt -> IO CInt +recvFd s = allocaBytes dummyBufSize $ \buf -> do + (NullSockAddr, _, cmsgs, _) <- recvBufMsg s [(buf,dummyBufSize)] 32 mempty + case (lookupAncillary ancillaryFd cmsgs >>= ancillaryDecode) :: Maybe Fd of + Nothing -> return (-1) + Just (Fd fd) -> return fd + where + dummyBufSize = 16 #else recvFd _ = error "Network.Socket.recvFd" #endif diff --git a/cbits/ancilData.c b/cbits/ancilData.c deleted file mode 100644 index 7d6bef2f..00000000 --- a/cbits/ancilData.c +++ /dev/null @@ -1,132 +0,0 @@ -/* - * Copyright(c), 2002 The GHC Team. - */ - -#ifdef aix_HOST_OS -#define _LINUX_SOURCE_COMPAT -// Required to get CMSG_SPACE/CMSG_LEN macros. See #265. -// Alternative is to #define COMPAT_43 and use the -// HAVE_STRUCT_MSGHDR_MSG_ACCRIGHTS code instead, but that means -// fiddling with the configure script too. -#endif - -#include "HsNet.h" -#include - -#if HAVE_STRUCT_MSGHDR_MSG_CONTROL || HAVE_STRUCT_MSGHDR_MSG_ACCRIGHTS /* until end */ - -/* - * Support for transmitting file descriptors. - * - * - */ - - -/* - * sendmsg() and recvmsg() wrappers for transmitting - * ancillary socket data. - * - * Doesn't provide the full generality of either, specifically: - * - * - no support for scattered read/writes. - * - only possible to send one ancillary chunk of data at a time. - */ - -int -sendFd(int sock, - int outfd) -{ - struct msghdr msg = {0}; - struct iovec iov[1]; - char buf[2]; -#if HAVE_STRUCT_MSGHDR_MSG_ACCRIGHTS - msg.msg_accrights = (void*)&outfd; - msg.msg_accrightslen = sizeof(int); -#else - struct cmsghdr *cmsg; - char ancBuffer[CMSG_SPACE(sizeof(int))]; - char* dPtr; - - msg.msg_control = ancBuffer; - msg.msg_controllen = sizeof(ancBuffer); - - cmsg = CMSG_FIRSTHDR(&msg); - cmsg->cmsg_level = SOL_SOCKET; - cmsg->cmsg_type = SCM_RIGHTS; - cmsg->cmsg_len = CMSG_LEN(sizeof(int)); - dPtr = (char*)CMSG_DATA(cmsg); - - *(int*)dPtr = outfd; - msg.msg_controllen = cmsg->cmsg_len; -#endif - - buf[0] = 0; buf[1] = '\0'; - iov[0].iov_base = buf; - iov[0].iov_len = 2; - - msg.msg_iov = iov; - msg.msg_iovlen = 1; - - return sendmsg(sock,&msg,0); -} - -int -recvFd(int sock) -{ - struct msghdr msg = {0}; - char duffBuf[10]; - int rc; - int len = sizeof(int); - struct iovec iov[1]; -#if HAVE_STRUCT_MSGHDR_MSG_CONTROL - struct cmsghdr *cmsg = NULL; - struct cmsghdr *cptr; -#else - int* fdBuffer; -#endif - int fd; - - iov[0].iov_base = duffBuf; - iov[0].iov_len = sizeof(duffBuf); - msg.msg_iov = iov; - msg.msg_iovlen = 1; - -#if HAVE_STRUCT_MSGHDR_MSG_CONTROL - cmsg = (struct cmsghdr*)malloc(CMSG_SPACE(len)); - if (cmsg==NULL) { - return -1; - } - - msg.msg_control = (void *)cmsg; - msg.msg_controllen = CMSG_LEN(len); -#else - fdBuffer = (int*)malloc(len); - if (fdBuffer) { - msg.msg_accrights = (void *)fdBuffer; - } else { - return -1; - } - msg.msg_accrightslen = len; -#endif - - if ((rc = recvmsg(sock,&msg,0)) < 0) { -#if HAVE_STRUCT_MSGHDR_MSG_CONTROL - free(cmsg); -#else - free(fdBuffer); -#endif - return rc; - } - -#if HAVE_STRUCT_MSGHDR_MSG_CONTROL - cptr = (struct cmsghdr*)CMSG_FIRSTHDR(&msg); - fd = *(int*)CMSG_DATA(cptr); - free(cmsg); -#else - fd = *(int*)fdBuffer; - free(fdBuffer); -#endif - return fd; -} - -#endif diff --git a/network.cabal b/network.cabal index 32f0b150..d359bb73 100644 --- a/network.cabal +++ b/network.cabal @@ -43,7 +43,7 @@ extra-source-files: configure.ac configure include/HsNetworkConfig.h.in include/HsNet.h include/HsNetDef.h -- C sources only used on some systems - cbits/ancilData.c cbits/asyncAccept.c cbits/initWinSock.c + cbits/asyncAccept.c cbits/initWinSock.c cbits/winSockErr.c cbits/cmsg.c homepage: https://github.com/haskell/network bug-reports: https://github.com/haskell/network/issues @@ -102,7 +102,7 @@ library Network.Socket.Posix.Cmsg Network.Socket.Posix.IOVec Network.Socket.Posix.MsgHdr - c-sources: cbits/ancilData.c cbits/cmsg.c + c-sources: cbits/cmsg.c if os(solaris) extra-libraries: nsl, socket From e3c38f8a6bc3146fd9833ee53354319b4036cbea Mon Sep 17 00:00:00 2001 From: Kazu Yamamoto Date: Fri, 10 Jan 2020 13:58:58 +0900 Subject: [PATCH 19/48] Semigroup hack for GHC 8.0 and 8.2. --- Network/Socket/Flag.hsc | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/Network/Socket/Flag.hsc b/Network/Socket/Flag.hsc index 40ef2d96..a1cd1da8 100644 --- a/Network/Socket/Flag.hsc +++ b/Network/Socket/Flag.hsc @@ -1,3 +1,4 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE PatternSynonyms #-} @@ -5,17 +6,22 @@ module Network.Socket.Flag where +import qualified Data.Semigroup as Sem + import Network.Socket.Imports -- | Message flags. To combine flags, use '(<>)'. newtype MsgFlag = MsgFlag { fromMsgFlag :: CInt } deriving (Show, Eq, Ord, Num, Bits) -instance Semigroup MsgFlag where - (<>) = (.|.) +instance Sem.Semigroup MsgFlag where + (<>) = (.|.) instance Monoid MsgFlag where - mempty = 0 + mempty = 0 +#if !(MIN_VERSION_base(4,11,0)) + mappend = (Sem.<>) +#endif -- | Send or receive OOB(out-of-bound) data. pattern MSG_OOB :: MsgFlag From ef226a888d5ce0fb68c157f929050c540bc460ce Mon Sep 17 00:00:00 2001 From: Kazu Yamamoto Date: Fri, 10 Jan 2020 15:30:27 +0900 Subject: [PATCH 20/48] let Ancillary be a subclass of Storable. --- Network/Socket.hs | 4 +- Network/Socket/Posix/Ancillary.hsc | 147 ++++++++++------------------- 2 files changed, 51 insertions(+), 100 deletions(-) diff --git a/Network/Socket.hs b/Network/Socket.hs index 20ef7c3d..5b96c65f 100644 --- a/Network/Socket.hs +++ b/Network/Socket.hs @@ -212,7 +212,9 @@ module Network.Socket , Cmsg(..) -- ** Ancillary data , Ancillary(..) - , AncillaryID + , AncillaryId + , ancillaryEncode + , ancillaryDecode , ancillaryIPv4TTL , ancillaryIPv6HopLimit , ancillaryIPv4TOS diff --git a/Network/Socket/Posix/Ancillary.hsc b/Network/Socket/Posix/Ancillary.hsc index 448c844d..16b47413 100644 --- a/Network/Socket/Posix/Ancillary.hsc +++ b/Network/Socket/Posix/Ancillary.hsc @@ -1,4 +1,5 @@ {-# LANGUAGE CPP #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE ScopedTypeVariables #-} module Network.Socket.Posix.Ancillary where @@ -20,10 +21,10 @@ import Network.Socket.Types ---------------------------------------------------------------- -- | Identifier of ancillary data. A pair of level and type. -type AncillaryID = (CInt, CInt) +type AncillaryId = (CInt, CInt) -- | The identifier for 'IPv4TTL'. -ancillaryIPv4TTL :: AncillaryID +ancillaryIPv4TTL :: AncillaryId #if defined(darwin_HOST_OS) ancillaryIPv4TTL = ((#const IPPROTO_IP), (#const IP_RECVTTL)) #else @@ -31,11 +32,11 @@ ancillaryIPv4TTL = ((#const IPPROTO_IP), (#const IP_TTL)) #endif -- | The identifier for 'IPv6HopLimit'. -ancillaryIPv6HopLimit :: AncillaryID +ancillaryIPv6HopLimit :: AncillaryId ancillaryIPv6HopLimit = ((#const IPPROTO_IPV6), (#const IPV6_HOPLIMIT)) -- | The identifier for 'IPv4TOS'. -ancillaryIPv4TOS :: AncillaryID +ancillaryIPv4TOS :: AncillaryId #if defined(darwin_HOST_OS) ancillaryIPv4TOS = ((#const IPPROTO_IP), (#const IP_RECVTOS)) #else @@ -43,19 +44,19 @@ ancillaryIPv4TOS = ((#const IPPROTO_IP), (#const IP_TOS)) #endif -- | The identifier for 'IPv6TClass'. -ancillaryIPv6TClass :: AncillaryID +ancillaryIPv6TClass :: AncillaryId ancillaryIPv6TClass = ((#const IPPROTO_IPV6), (#const IPV6_TCLASS)) -- | The identifier for 'IPv4PktInfo'. -ancillaryIPv4PktInfo :: AncillaryID +ancillaryIPv4PktInfo :: AncillaryId ancillaryIPv4PktInfo = ((#const IPPROTO_IP), (#const IP_PKTINFO)) -- | The identifier for 'IPv6PktInfo'. -ancillaryIPv6PktInfo :: AncillaryID +ancillaryIPv6PktInfo :: AncillaryId ancillaryIPv6PktInfo = ((#const IPPROTO_IPV6), (#const IPV6_PKTINFO)) -- | The identifier for 'Fd'. -ancillaryFd :: AncillaryID +ancillaryFd :: AncillaryId ancillaryFd = ((#const SOL_SOCKET), (#const SCM_RIGHTS)) ---------------------------------------------------------------- @@ -63,7 +64,7 @@ ancillaryFd = ((#const SOL_SOCKET), (#const SCM_RIGHTS)) -- | Looking up ancillary data. The following shows an example usage: -- -- > (lookupAncillary ancillaryIPv4TOS cmsgs >>= ancillaryDecode) :: Maybe IPv4TOS -lookupAncillary :: AncillaryID -> [Cmsg] -> Maybe Cmsg +lookupAncillary :: AncillaryId -> [Cmsg] -> Maybe Cmsg lookupAncillary _ [] = Nothing lookupAncillary aid (cmsg@(Cmsg cid _):cmsgs) | aid == cid = Just cmsg @@ -72,121 +73,81 @@ lookupAncillary aid (cmsg@(Cmsg cid _):cmsgs) ---------------------------------------------------------------- -- | A class to encode and decode ancillary data. -class Ancillary a where - ancillaryEncode :: a -> Cmsg - ancillaryDecode :: Cmsg -> Maybe a - ----------------------------------------------------------------- - -packCInt :: CInt -> ByteString -packCInt n = unsafeDupablePerformIO $ create siz $ \p0 -> do - let p = castPtr p0 :: Ptr CInt - poke p n +class Storable a => Ancillary a where + ancillaryId :: a -> AncillaryId + +ancillaryEncode :: Ancillary a => a -> Cmsg +ancillaryEncode x = unsafeDupablePerformIO $ do + bs <- create siz $ \p0 -> do + let p = castPtr p0 + poke p x + return $ Cmsg (ancillaryId x) bs where - siz = (#size int) + siz = sizeOf x -unpackCInt :: ByteString -> Maybe CInt -unpackCInt (PS fptr off len) +ancillaryDecode :: forall a . Storable a => Cmsg -> Maybe a +ancillaryDecode (Cmsg _ (PS fptr off len)) | len < siz = Nothing | otherwise = unsafeDupablePerformIO $ withForeignPtr fptr $ \p0 -> do - let p = castPtr (p0 `plusPtr` off) :: Ptr CInt + let p = castPtr (p0 `plusPtr` off) Just <$> peek p where - siz = (#size int) - -packCChar :: CChar -> ByteString -packCChar n = unsafeDupablePerformIO $ create siz $ \p0 -> do - let p = castPtr p0 :: Ptr CChar - poke p n - where - siz = (#size char) - -unpackCChar :: ByteString -> Maybe CChar -unpackCChar (PS fptr off len) - | len < siz = Nothing - | otherwise = unsafeDupablePerformIO $ withForeignPtr fptr $ \p0 -> do - let p = castPtr (p0 `plusPtr` off) :: Ptr CChar - Just <$> peek p - where - siz = (#size char) + siz = sizeOf (undefined :: a) ---------------------------------------------------------------- -- | Time to live of IPv4. -newtype IPv4TTL = IPv4TTL Int deriving (Eq, Show) +newtype IPv4TTL = IPv4TTL CChar deriving (Eq, Show, Storable) instance Ancillary IPv4TTL where -#if defined(darwin_HOST_OS) - ancillaryEncode (IPv4TTL ttl) = Cmsg ancillaryIPv4TTL $ packCChar $ fromIntegral ttl -#else - ancillaryEncode (IPv4TTL ttl) = Cmsg ancillaryIPv4TTL $ packCInt $ fromIntegral ttl -#endif -#if defined(darwin_HOST_OS) - ancillaryDecode (Cmsg _ bs) = IPv4TTL . fromIntegral <$> unpackCChar bs -#else - ancillaryDecode (Cmsg _ bs) = IPv4TTL . fromIntegral <$> unpackCInt bs -#endif + ancillaryId _ = ancillaryIPv4TTL ---------------------------------------------------------------- -- | Hop limit of IPv6. -newtype IPv6HopLimit = IPv6HopLimit Int deriving (Eq, Show) +newtype IPv6HopLimit = IPv6HopLimit CInt deriving (Eq, Show, Storable) instance Ancillary IPv6HopLimit where - ancillaryEncode (IPv6HopLimit ttl) = Cmsg ancillaryIPv6HopLimit $ packCInt $ fromIntegral ttl - ancillaryDecode (Cmsg _ bs) = IPv6HopLimit . fromIntegral <$> unpackCInt bs + ancillaryId _ = ancillaryIPv6HopLimit ---------------------------------------------------------------- -- | TOS of IPv4. -newtype IPv4TOS = IPv4TOS Int deriving (Eq, Show) +newtype IPv4TOS = IPv4TOS CChar deriving (Eq, Show, Storable) instance Ancillary IPv4TOS where - ancillaryEncode (IPv4TOS ttl) = Cmsg ancillaryIPv4TOS $ packCChar $ fromIntegral ttl - ancillaryDecode (Cmsg _ bs) = IPv4TOS . fromIntegral <$> unpackCChar bs + ancillaryId _ = ancillaryIPv4TOS ---------------------------------------------------------------- -- | Traffic class of IPv6. -newtype IPv6TClass = IPv6TClass Int deriving (Eq, Show) +newtype IPv6TClass = IPv6TClass CInt deriving (Eq, Show, Storable) instance Ancillary IPv6TClass where - ancillaryEncode (IPv6TClass ttl) = Cmsg ancillaryIPv6TClass $ packCInt $ fromIntegral ttl - ancillaryDecode (Cmsg _ bs) = IPv6TClass . fromIntegral <$> unpackCInt bs + ancillaryId _ = ancillaryIPv6TClass ---------------------------------------------------------------- -- | Network interface ID and local IPv4 address. -data IPv4PktInfo = IPv4PktInfo Int HostAddress deriving (Eq) +data IPv4PktInfo = IPv4PktInfo CInt HostAddress deriving (Eq) instance Show IPv4PktInfo where show (IPv4PktInfo n ha) = "IPv4PktInfo " ++ show n ++ " " ++ show (hostAddressToTuple ha) instance Ancillary IPv4PktInfo where - ancillaryEncode pktinfo = Cmsg ancillaryIPv4PktInfo $ packIPv4PktInfo pktinfo - ancillaryDecode (Cmsg _ bs) = unpackIPv4PktInfo bs + ancillaryId _ = ancillaryIPv4PktInfo -{-# NOINLINE packIPv4PktInfo #-} -packIPv4PktInfo :: IPv4PktInfo -> ByteString -packIPv4PktInfo (IPv4PktInfo n ha) = unsafeDupablePerformIO $ - create siz $ \p -> do +instance Storable IPv4PktInfo where + sizeOf _ = (#size struct in_pktinfo) + alignment = undefined + poke p (IPv4PktInfo n ha) = do (#poke struct in_pktinfo, ipi_ifindex) p (fromIntegral n :: CInt) (#poke struct in_pktinfo, ipi_spec_dst) p (0 :: CInt) (#poke struct in_pktinfo, ipi_addr) p ha - where - siz = (#size struct in_pktinfo) - -{-# NOINLINE unpackIPv4PktInfo #-} -unpackIPv4PktInfo :: ByteString -> Maybe IPv4PktInfo -unpackIPv4PktInfo (PS fptr off len) - | len < siz = Nothing - | otherwise = unsafeDupablePerformIO $ withForeignPtr fptr $ \p0 -> do - let p = p0 `plusPtr` off + peek p = do n <- (#peek struct in_pktinfo, ipi_ifindex) p ha <- (#peek struct in_pktinfo, ipi_addr) p - return $ Just $ IPv4PktInfo n ha - where - siz = (#size struct in_pktinfo) + return $ IPv4PktInfo n ha ---------------------------------------------------------------- @@ -197,32 +158,20 @@ instance Show IPv6PktInfo where show (IPv6PktInfo n ha6) = "IPv6PktInfo " ++ show n ++ " " ++ show (hostAddress6ToTuple ha6) instance Ancillary IPv6PktInfo where - ancillaryEncode pktinfo = Cmsg ancillaryIPv6PktInfo $ packIPv6PktInfo pktinfo - ancillaryDecode (Cmsg _ bs) = unpackIPv6PktInfo bs + ancillaryId _ = ancillaryIPv6PktInfo -{-# NOINLINE packIPv6PktInfo #-} -packIPv6PktInfo :: IPv6PktInfo -> ByteString -packIPv6PktInfo (IPv6PktInfo n ha6) = unsafeDupablePerformIO $ - create siz $ \p -> do +instance Storable IPv6PktInfo where + sizeOf _ = (#size struct in6_pktinfo) + alignment = undefined + poke p (IPv6PktInfo n ha6) = do (#poke struct in6_pktinfo, ipi6_ifindex) p (fromIntegral n :: CInt) (#poke struct in6_pktinfo, ipi6_addr) p (In6Addr ha6) - where - siz = (#size struct in6_pktinfo) - -{-# NOINLINE unpackIPv6PktInfo #-} -unpackIPv6PktInfo :: ByteString -> Maybe IPv6PktInfo -unpackIPv6PktInfo (PS fptr off len) - | len < siz = Nothing - | otherwise = unsafeDupablePerformIO $ withForeignPtr fptr $ \p0 -> do - let p = p0 `plusPtr` off + peek p = do In6Addr ha6 <- (#peek struct in6_pktinfo, ipi6_addr) p n :: CInt <- (#peek struct in6_pktinfo, ipi6_ifindex) p - return $ Just $ IPv6PktInfo (fromIntegral n) ha6 - where - siz = (#size struct in6_pktinfo) + return $ IPv6PktInfo (fromIntegral n) ha6 ---------------------------------------------------------------- instance Ancillary Fd where - ancillaryEncode (Fd fd) = Cmsg ancillaryFd $ packCInt $ fromIntegral fd - ancillaryDecode (Cmsg _ bs) = Fd . fromIntegral <$> unpackCInt bs + ancillaryId _ = ancillaryFd From b1bdfba114f506fc18932e61e932e4cc04b04c6a Mon Sep 17 00:00:00 2001 From: Kazu Yamamoto Date: Fri, 10 Jan 2020 16:31:38 +0900 Subject: [PATCH 21/48] pattern synopsis for SocketOption. --- Network/Socket.hs | 9 +- Network/Socket/Options.hsc | 293 ++++++++++++++++++++++--------------- 2 files changed, 184 insertions(+), 118 deletions(-) diff --git a/Network/Socket.hs b/Network/Socket.hs index 5b96c65f..beddd0a1 100644 --- a/Network/Socket.hs +++ b/Network/Socket.hs @@ -131,7 +131,14 @@ module Network.Socket , ShutdownCmd(..) -- * Socket options - , SocketOption(..) + , SocketOption(Debug,ReuseAddr,Type,SoError,DontRoute,Broadcast + ,SendBuffer,RecvBuffer,KeepAlive,OOBInline,TimeToLive + ,MaxSegment,NoDelay,Cork,Linger,ReusePort + ,RecvLowWater,SendLowWater,RecvTimeOut,SendTimeOut + ,UseLoopBack,UserTimeout,IPv6Only + ,RecvIPv4TTL,RecvIPv4TOS,RecvIPv4PktInfo + ,RecvIPv6HopLimit,RecvIPv6TClass,RecvIPv6PktInfo + ,CustomSockOpt) , isSupportedSocketOption , getSocketOption , setSocketOption diff --git a/Network/Socket/Options.hsc b/Network/Socket/Options.hsc index 82217323..27f52e7a 100644 --- a/Network/Socket/Options.hsc +++ b/Network/Socket/Options.hsc @@ -1,11 +1,19 @@ {-# LANGUAGE CPP #-} {-# LANGUAGE DeriveDataTypeable #-} +{-# LANGUAGE PatternSynonyms #-} #include "HsNet.h" ##include "HsNetDef.h" module Network.Socket.Options ( - SocketOption(..) + SocketOption(Debug,ReuseAddr,Type,SoError,DontRoute,Broadcast + ,SendBuffer,RecvBuffer,KeepAlive,OOBInline,TimeToLive + ,MaxSegment,NoDelay,Cork,Linger,ReusePort + ,RecvLowWater,SendLowWater,RecvTimeOut,SendTimeOut + ,UseLoopBack,UserTimeout,IPv6Only + ,RecvIPv4TTL,RecvIPv4TOS,RecvIPv4PktInfo + ,RecvIPv6HopLimit,RecvIPv6TClass,RecvIPv6PktInfo + ,CustomSockOpt) , isSupportedSocketOption , getSocketType , getSocketOption @@ -28,42 +36,11 @@ import Network.Socket.Types -- -- The existence of a constructor does not imply that the relevant option -- is supported on your system: see 'isSupportedSocketOption' -data SocketOption - = Debug -- ^ SO_DEBUG - | ReuseAddr -- ^ SO_REUSEADDR - | Type -- ^ SO_TYPE - | SoError -- ^ SO_ERROR - | DontRoute -- ^ SO_DONTROUTE - | Broadcast -- ^ SO_BROADCAST - | SendBuffer -- ^ SO_SNDBUF - | RecvBuffer -- ^ SO_RCVBUF - | KeepAlive -- ^ SO_KEEPALIVE - | OOBInline -- ^ SO_OOBINLINE - | TimeToLive -- ^ IP_TTL - | MaxSegment -- ^ TCP_MAXSEG - | NoDelay -- ^ TCP_NODELAY - | Cork -- ^ TCP_CORK - | Linger -- ^ SO_LINGER: timeout in seconds, 0 means disabling/disabled. - | ReusePort -- ^ SO_REUSEPORT - | RecvLowWater -- ^ SO_RCVLOWAT - | SendLowWater -- ^ SO_SNDLOWAT - | RecvTimeOut -- ^ SO_RCVTIMEO: this does not work at this moment. - | SendTimeOut -- ^ SO_SNDTIMEO: this does not work at this moment. - | UseLoopBack -- ^ SO_USELOOPBACK - | UserTimeout -- ^ TCP_USER_TIMEOUT - | IPv6Only -- ^ IPV6_V6ONLY: don't use this on OpenBSD. - | RecvIPv4TTL -- ^ Receiving IPv4 TTL. - | RecvIPv4TOS -- ^ Receiving IPv4 TOS. - | RecvIPv4PktInfo -- ^ Receiving IP_PKTINFO (struct in_pktinfo). - | RecvIPv6HopLimit -- ^ Receiving IPv6 hop limit. - | RecvIPv6TClass -- ^ Receiving IPv6 traffic class. - | RecvIPv6PktInfo -- ^ Receiving IPV6_PKTINFO (struct in6_pktinfo). - | CustomSockOpt (CInt, CInt) - deriving (Show, Typeable) +newtype SocketOption = SockOpt (CInt,CInt) deriving (Eq, Show) -- | Does the 'SocketOption' exist on this system? isSupportedSocketOption :: SocketOption -> Bool -isSupportedSocketOption = isJust . packSocketOption +isSupportedSocketOption opt = opt /= SockOpt (-1,-1) -- | Get the 'SocketType' of an active socket. -- @@ -72,144 +49,228 @@ getSocketType :: Socket -> IO SocketType getSocketType s = (fromMaybe NoSocketType . unpackSocketType . fromIntegral) <$> getSocketOption s Type --- | For a socket option, return Just (level, value) where level is the --- corresponding C option level constant (e.g. SOL_SOCKET) and value is --- the option constant itself (e.g. SO_DEBUG) --- If either constant does not exist, return Nothing. -packSocketOption :: SocketOption -> Maybe (CInt, CInt) -packSocketOption so = - -- The Just here is a hack to disable GHC's overlapping pattern detection: - -- the problem is if all constants are present, the fallback pattern is - -- redundant, but if they aren't then it isn't. Hence we introduce an - -- extra pattern (Nothing) that can't possibly happen, so that the - -- fallback is always (in principle) necessary. - -- I feel a little bad for including this, but such are the sacrifices we - -- make while working with CPP - excluding the fallback pattern correctly - -- would be a serious nuisance. - -- (NB: comments elsewhere in this file refer to this one) - case Just so of #ifdef SOL_SOCKET +-- | SO_DEBUG +pattern Debug :: SocketOption #ifdef SO_DEBUG - Just Debug -> Just ((#const SOL_SOCKET), (#const SO_DEBUG)) +pattern Debug = SockOpt ((#const SOL_SOCKET), (#const SO_DEBUG)) +#else +pattern Debug = SockOpt (-1,-1) #endif +-- | SO_REUSEADDR +pattern ReuseAddr :: SocketOption #ifdef SO_REUSEADDR - Just ReuseAddr -> Just ((#const SOL_SOCKET), (#const SO_REUSEADDR)) +pattern ReuseAddr = SockOpt ((#const SOL_SOCKET), (#const SO_REUSEADDR)) +#else +pattern ReuseAddr = SockOpt (-1,-1) #endif +-- | SO_TYPE +pattern Type :: SocketOption #ifdef SO_TYPE - Just Type -> Just ((#const SOL_SOCKET), (#const SO_TYPE)) +pattern Type = SockOpt ((#const SOL_SOCKET), (#const SO_TYPE)) +#else +pattern Type = SockOpt (-1,-1) #endif +-- | SO_ERROR +pattern SoError :: SocketOption #ifdef SO_ERROR - Just SoError -> Just ((#const SOL_SOCKET), (#const SO_ERROR)) +pattern SoError = SockOpt ((#const SOL_SOCKET), (#const SO_ERROR)) +#else +pattern SoError = SockOpt (-1,-1) #endif +-- | SO_DONTROUTE +pattern DontRoute :: SocketOption #ifdef SO_DONTROUTE - Just DontRoute -> Just ((#const SOL_SOCKET), (#const SO_DONTROUTE)) +pattern DontRoute = SockOpt ((#const SOL_SOCKET), (#const SO_DONTROUTE)) +#else +pattern DontRoute = SockOpt (-1,-1) #endif +-- | SO_BROADCAST +pattern Broadcast :: SocketOption #ifdef SO_BROADCAST - Just Broadcast -> Just ((#const SOL_SOCKET), (#const SO_BROADCAST)) +pattern Broadcast = SockOpt ((#const SOL_SOCKET), (#const SO_BROADCAST)) +#else +pattern Broadcast = SockOpt (-1,-1) #endif +-- | SO_SNDBUF +pattern SendBuffer :: SocketOption #ifdef SO_SNDBUF - Just SendBuffer -> Just ((#const SOL_SOCKET), (#const SO_SNDBUF)) +pattern SendBuffer = SockOpt ((#const SOL_SOCKET), (#const SO_SNDBUF)) +#else +pattern SendBuffer = SockOpt (-1,-1) #endif +-- | SO_RCVBUF +pattern RecvBuffer :: SocketOption #ifdef SO_RCVBUF - Just RecvBuffer -> Just ((#const SOL_SOCKET), (#const SO_RCVBUF)) +pattern RecvBuffer = SockOpt ((#const SOL_SOCKET), (#const SO_RCVBUF)) +#else +pattern RecvBuffer = SockOpt (-1,-1) #endif +-- | SO_KEEPALIVE +pattern KeepAlive :: SocketOption #ifdef SO_KEEPALIVE - Just KeepAlive -> Just ((#const SOL_SOCKET), (#const SO_KEEPALIVE)) +pattern KeepAlive = SockOpt ((#const SOL_SOCKET), (#const SO_KEEPALIVE)) +#else +pattern KeepAlive = SockOpt (-1,-1) #endif +-- | SO_OOBINLINE +pattern OOBInline :: SocketOption #ifdef SO_OOBINLINE - Just OOBInline -> Just ((#const SOL_SOCKET), (#const SO_OOBINLINE)) +pattern OOBInline = SockOpt ((#const SOL_SOCKET), (#const SO_OOBINLINE)) +#else +pattern OOBINLINE = SockOpt (-1,-1) #endif +-- | SO_LINGER: timeout in seconds, 0 means disabling/disabled. +pattern Linger :: SocketOption #ifdef SO_LINGER - Just Linger -> Just ((#const SOL_SOCKET), (#const SO_LINGER)) +pattern Linger = SockOpt ((#const SOL_SOCKET), (#const SO_LINGER)) +#else +pattern Linger = SockOpt (-1,-1) #endif +-- | SO_REUSEPORT +pattern ReusePort :: SocketOption #ifdef SO_REUSEPORT - Just ReusePort -> Just ((#const SOL_SOCKET), (#const SO_REUSEPORT)) +pattern ReusePort = SockOpt ((#const SOL_SOCKET), (#const SO_REUSEPORT)) +#else +pattern ReusePort = SockOpt (-1,-1) #endif +-- | SO_RCVLOWAT +pattern RecvLowWater :: SocketOption #ifdef SO_RCVLOWAT - Just RecvLowWater -> Just ((#const SOL_SOCKET), (#const SO_RCVLOWAT)) +pattern RecvLowWater = SockOpt ((#const SOL_SOCKET), (#const SO_RCVLOWAT)) +#else +pattern RecvLowWater = SockOpt (-1,-1) #endif +-- | SO_SNDLOWAT +pattern SendLowWater :: SocketOption #ifdef SO_SNDLOWAT - Just SendLowWater -> Just ((#const SOL_SOCKET), (#const SO_SNDLOWAT)) +pattern SendLowWater = SockOpt ((#const SOL_SOCKET), (#const SO_SNDLOWAT)) +#else +pattern SendLowWater = SockOpt (-1,-1) #endif +-- | SO_RCVTIMEO: this does not work at this moment. +pattern RecvTimeOut :: SocketOption #ifdef SO_RCVTIMEO - Just RecvTimeOut -> Just ((#const SOL_SOCKET), (#const SO_RCVTIMEO)) +pattern RecvTimeOut = SockOpt ((#const SOL_SOCKET), (#const SO_RCVTIMEO)) +#else +pattern RecvTimeOut = SockOpt (-1,-1) #endif +-- | SO_SNDTIMEO: this does not work at this moment. +pattern SendTimeOut :: SocketOption #ifdef SO_SNDTIMEO - Just SendTimeOut -> Just ((#const SOL_SOCKET), (#const SO_SNDTIMEO)) +pattern SendTimeOut = SockOpt ((#const SOL_SOCKET), (#const SO_SNDTIMEO)) +#else +pattern SendTimeOut = SockOpt (-1,-1) #endif +-- | SO_USELOOPBACK +pattern UseLoopBack :: SocketOption #ifdef SO_USELOOPBACK - Just UseLoopBack -> Just ((#const SOL_SOCKET), (#const SO_USELOOPBACK)) +pattern UseLoopBack = SockOpt ((#const SOL_SOCKET), (#const SO_USELOOPBACK)) +#else +pattern UseLoopBack = SockOpt (-1,-1) #endif #endif // SOL_SOCKET -#if HAVE_DECL_IPPROTO_IP -#ifdef IP_TTL - Just TimeToLive -> Just ((#const IPPROTO_IP), (#const IP_TTL)) -#endif -#endif // HAVE_DECL_IPPROTO_IP + #if HAVE_DECL_IPPROTO_TCP +-- | TCP_MAXSEG +pattern MaxSegment :: SocketOption #ifdef TCP_MAXSEG - Just MaxSegment -> Just ((#const IPPROTO_TCP), (#const TCP_MAXSEG)) +pattern MaxSegment = SockOpt ((#const IPPROTO_TCP), (#const TCP_MAXSEG)) +#else +pattern MaxSegment = SockOpt (-1,-1) #endif +-- | TCP_NODELAY +pattern NoDelay :: SocketOption #ifdef TCP_NODELAY - Just NoDelay -> Just ((#const IPPROTO_TCP), (#const TCP_NODELAY)) +pattern NoDelay = SockOpt ((#const IPPROTO_TCP), (#const TCP_NODELAY)) +#else +pattern NoDelay = SockOpt (-1,-1) #endif +-- | TCP_USER_TIMEOUT +pattern UserTimeout :: SocketOption #ifdef TCP_USER_TIMEOUT - Just UserTimeout -> Just ((#const IPPROTO_TCP), (#const TCP_USER_TIMEOUT)) +pattern UserTimeout = SockOpt ((#const IPPROTO_TCP), (#const TCP_USER_TIMEOUT)) +#else +pattern UserTimeout = SockOpt (-1, -1) #endif +-- | TCP_CORK +pattern Cork :: SocketOption #ifdef TCP_CORK - Just Cork -> Just ((#const IPPROTO_TCP), (#const TCP_CORK)) +pattern Cork = SockOpt ((#const IPPROTO_TCP), (#const TCP_CORK)) +#else +pattern Cork = SockOpt (-1,-1) #endif #endif // HAVE_DECL_IPPROTO_TCP -#if HAVE_DECL_IPPROTO_IPV6 -#if HAVE_DECL_IPV6_V6ONLY - Just IPv6Only -> Just ((#const IPPROTO_IPV6), (#const IPV6_V6ONLY)) -#endif -#endif // HAVE_DECL_IPPROTO_IPV6 + #if HAVE_DECL_IPPROTO_IP +-- | IP_TTL +pattern TimeToLive :: SocketOption +#ifdef IP_TTL +pattern TimeToLive = SockOpt ((#const IPPROTO_IP), (#const IP_TTL)) +#else +pattern TimeToLive = SockOpt (-1,-1) +#endif +-- | Receiving IPv4 TTL. +pattern RecvIPv4TTL :: SocketOption #ifdef IP_RECVTTL - Just RecvIPv4TTL -> Just ((#const IPPROTO_IP), (#const IP_RECVTTL)) +pattern RecvIPv4TTL = SockOpt ((#const IPPROTO_IP), (#const IP_RECVTTL)) +#else +pattern RecvIPv4TTL = SockOpt (-1,-1) #endif -#endif // HAVE_DECL_IPPROTO_IP -#if HAVE_DECL_IPPROTO_IP +-- | Receiving IPv4 TOS. +pattern RecvIPv4TOS :: SocketOption #ifdef IP_RECVTOS - Just RecvIPv4TOS -> Just ((#const IPPROTO_IP), (#const IP_RECVTOS)) -#endif -#endif // HAVE_DECL_IPPROTO_IP -#if HAVE_DECL_IPPROTO_IP -#if defined(IP_RECVPKTINFO) - Just RecvIPv4PktInfo -> Just ((#const IPPROTO_IP), (#const IP_RECVPKTINFO)) +pattern RecvIPv4TOS = SockOpt ((#const IPPROTO_IP), (#const IP_RECVTOS)) +#else +pattern RecvIPv4TOS = SockOpt (-1,-1) +#endif +-- | Receiving IP_PKTINFO (struct in_pktinfo). +pattern RecvIPv4PktInfo :: SocketOption +#ifdef IP_RECVPKTINFO +pattern RecvIPv4PktInfo = SockOpt ((#const IPPROTO_IP), (#const IP_RECVPKTINFO)) #elif defined(IP_PKTINFO) - Just RecvIPv4PktInfo -> Just ((#const IPPROTO_IP), (#const IP_PKTINFO)) +pattern RecvIPv4PktInfo = SockOpt ((#const IPPROTO_IP), (#const IP_PKTINFO)) +#else +pattern RecvIPv4PktInfo = SockOpt (-1,-1) #endif #endif // HAVE_DECL_IPPROTO_IP + #if HAVE_DECL_IPPROTO_IPV6 +-- | IPV6_V6ONLY: don't use this on OpenBSD. +pattern IPv6Only :: SocketOption +#if HAVE_DECL_IPV6_V6ONLY +pattern IPv6Only = SockOpt ((#const IPPROTO_IPV6), (#const IPV6_V6ONLY)) +#else +pattern IPv6Only = SockOpt (-1,-1) +#endif +-- | Receiving IPv6 hop limit. +pattern RecvIPv6HopLimit :: SocketOption #ifdef IPV6_RECVHOPLIMIT - Just RecvIPv6HopLimit -> Just ((#const IPPROTO_IPV6), (#const IPV6_RECVHOPLIMIT)) +pattern RecvIPv6HopLimit = SockOpt ((#const IPPROTO_IPV6), (#const IPV6_RECVHOPLIMIT)) +#else +pattern RecvIPv6HopLimit = SockOpt (-1,-1) #endif -#endif // HAVE_DECL_IPPROTO_IPV6 -#if HAVE_DECL_IPPROTO_IPV6 +-- | Receiving IPv6 traffic class. +pattern RecvIPv6TClass :: SocketOption #ifdef IPV6_RECVTCLASS - Just RecvIPv6TClass -> Just ((#const IPPROTO_IPV6), (#const IPV6_RECVTCLASS)) +pattern RecvIPv6TClass = SockOpt ((#const IPPROTO_IPV6), (#const IPV6_RECVTCLASS)) +#else +pattern RecvIPv6TClass = SockOpt (-1,-1) #endif -#endif // HAVE_DECL_IPPROTO_IPV6 -#if HAVE_DECL_IPPROTO_IPV6 +-- | Receiving IPV6_PKTINFO (struct in6_pktinfo). +pattern RecvIPv6PktInfo :: SocketOption #ifdef IPV6_RECVPKTINFO - Just RecvIPv6PktInfo -> Just ((#const IPPROTO_IPV6), (#const IPV6_RECVPKTINFO)) +pattern RecvIPv6PktInfo = SockOpt ((#const IPPROTO_IPV6), (#const IPV6_RECVPKTINFO)) #elif defined(IPV6_PKTINFO) - Just RecvIPv6PktInfo -> Just ((#const IPPROTO_IPV6), (#const IPV6_PKTINFO)) +pattern RecvIPv6PktInfo = SockOpt ((#const IPPROTO_IPV6), (#const IPV6_PKTINFO)) +#else +pattern RecvIPv6PktInfo = SockOpt (-1,-1) #endif #endif // HAVE_DECL_IPPROTO_IPV6 - Just (CustomSockOpt opt) -> Just opt - _ -> Nothing --- | Return the option level and option value if they exist, --- otherwise throw an error that begins "Network.Socket." ++ the String --- parameter -packSocketOption' :: String -> SocketOption -> IO (CInt, CInt) -packSocketOption' caller so = maybe err return (packSocketOption so) - where - err = ioError . userError . concat $ ["Network.Socket.", caller, - ": socket option ", show so, " unsupported on this system"] +-- | Customizable socket option. +pattern CustomSockOpt :: (CInt,CInt) -> SocketOption +pattern CustomSockOpt opt = SockOpt opt #ifdef SO_LINGER data StructLinger = StructLinger CInt CInt @@ -235,8 +296,8 @@ setSocketOption :: Socket -> Int -- Option Value -> IO () #ifdef SO_LINGER -setSocketOption s Linger v = do - (level, opt) <- packSocketOption' "setSocketOption" Linger +setSocketOption s so@Linger v = do + let SockOpt (level,opt) = so let arg = if v == 0 then StructLinger 0 0 else StructLinger 1 (fromIntegral v) with arg $ \ptr_arg -> void $ do let ptr = ptr_arg :: Ptr StructLinger @@ -245,8 +306,7 @@ setSocketOption s Linger v = do throwSocketErrorIfMinus1_ "Network.Socket.setSocketOption" $ c_setsockopt fd level opt ptr sz #endif -setSocketOption s so v = do - (level, opt) <- packSocketOption' "setSocketOption" so +setSocketOption s (SockOpt (level,opt)) v = do with (fromIntegral v) $ \ptr_v -> void $ do let ptr = ptr_v :: Ptr CInt sz = fromIntegral $ sizeOf (undefined :: CInt) @@ -260,8 +320,8 @@ getSocketOption :: Socket -> SocketOption -- Option Name -> IO Int -- Option Value #ifdef SO_LINGER -getSocketOption s Linger = do - (level, opt) <- packSocketOption' "getSocketOption" Linger +getSocketOption s so@Linger = do + let SockOpt (level,opt) = so alloca $ \ptr_v -> do let ptr = ptr_v :: Ptr StructLinger sz = fromIntegral $ sizeOf (undefined :: StructLinger) @@ -271,8 +331,7 @@ getSocketOption s Linger = do StructLinger onoff linger <- peek ptr return $ fromIntegral $ if onoff == 0 then 0 else linger #endif -getSocketOption s so = do - (level, opt) <- packSocketOption' "getSocketOption" so +getSocketOption s (SockOpt (level,opt)) = do alloca $ \ptr_v -> do let ptr = ptr_v :: Ptr CInt sz = fromIntegral $ sizeOf (undefined :: CInt) From 8efb7af86cf9f90ae5ae27eaccd2876b3eb12c78 Mon Sep 17 00:00:00 2001 From: Kazu Yamamoto Date: Fri, 10 Jan 2020 16:47:01 +0900 Subject: [PATCH 22/48] implementing getSockOpt and setSockOpt. --- Network/Socket.hs | 2 ++ Network/Socket/Options.hsc | 73 ++++++++++++++++++++------------------ 2 files changed, 41 insertions(+), 34 deletions(-) diff --git a/Network/Socket.hs b/Network/Socket.hs index beddd0a1..623fbb00 100644 --- a/Network/Socket.hs +++ b/Network/Socket.hs @@ -142,6 +142,8 @@ module Network.Socket , isSupportedSocketOption , getSocketOption , setSocketOption + , getSockOpt + , setSockOpt -- * Socket , Socket diff --git a/Network/Socket/Options.hsc b/Network/Socket/Options.hsc index 27f52e7a..8585492a 100644 --- a/Network/Socket/Options.hsc +++ b/Network/Socket/Options.hsc @@ -1,6 +1,7 @@ {-# LANGUAGE CPP #-} {-# LANGUAGE DeriveDataTypeable #-} {-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE ScopedTypeVariables #-} #include "HsNet.h" ##include "HsNetDef.h" @@ -18,6 +19,8 @@ module Network.Socket.Options ( , getSocketType , getSocketOption , setSocketOption + , getSockOpt + , setSockOpt , c_getsockopt , c_setsockopt ) where @@ -297,22 +300,23 @@ setSocketOption :: Socket -> IO () #ifdef SO_LINGER setSocketOption s so@Linger v = do - let SockOpt (level,opt) = so - let arg = if v == 0 then StructLinger 0 0 else StructLinger 1 (fromIntegral v) - with arg $ \ptr_arg -> void $ do - let ptr = ptr_arg :: Ptr StructLinger - sz = fromIntegral $ sizeOf (undefined :: StructLinger) - withFdSocket s $ \fd -> - throwSocketErrorIfMinus1_ "Network.Socket.setSocketOption" $ - c_setsockopt fd level opt ptr sz -#endif -setSocketOption s (SockOpt (level,opt)) v = do - with (fromIntegral v) $ \ptr_v -> void $ do - let ptr = ptr_v :: Ptr CInt - sz = fromIntegral $ sizeOf (undefined :: CInt) - withFdSocket s $ \fd -> - throwSocketErrorIfMinus1_ "Network.Socket.setSocketOption" $ - c_setsockopt fd level opt ptr sz + let arg = if v == 0 then StructLinger 0 0 else StructLinger 1 (fromIntegral v) + setSockOpt s so arg +#endif +setSocketOption s sa v = setSockOpt s sa (fromIntegral v :: CInt) + +-- | Set a socket option. +setSockOpt :: Storable a + => Socket + -> SocketOption + -> a + -> IO () +setSockOpt s (SockOpt (level,opt)) v = do + with v $ \ptr -> void $ do + let sz = fromIntegral $ sizeOf v + withFdSocket s $ \fd -> + throwSocketErrorIfMinus1_ "Network.Socket.setSockOpt" $ + c_setsockopt fd level opt ptr sz -- | Get a socket option that gives an Int value. -- There is currently no API to get e.g. the timeval socket options @@ -321,24 +325,25 @@ getSocketOption :: Socket -> IO Int -- Option Value #ifdef SO_LINGER getSocketOption s so@Linger = do - let SockOpt (level,opt) = so - alloca $ \ptr_v -> do - let ptr = ptr_v :: Ptr StructLinger - sz = fromIntegral $ sizeOf (undefined :: StructLinger) - withFdSocket s $ \fd -> with sz $ \ptr_sz -> do - throwSocketErrorIfMinus1Retry_ "Network.Socket.getSocketOption" $ - c_getsockopt fd level opt ptr ptr_sz - StructLinger onoff linger <- peek ptr - return $ fromIntegral $ if onoff == 0 then 0 else linger -#endif -getSocketOption s (SockOpt (level,opt)) = do - alloca $ \ptr_v -> do - let ptr = ptr_v :: Ptr CInt - sz = fromIntegral $ sizeOf (undefined :: CInt) - withFdSocket s $ \fd -> with sz $ \ptr_sz -> do - throwSocketErrorIfMinus1Retry_ "Network.Socket.getSocketOption" $ - c_getsockopt fd level opt ptr ptr_sz - fromIntegral <$> peek ptr + StructLinger onoff linger <- getSockOpt s so + return $ fromIntegral $ if onoff == 0 then 0 else linger +#endif +getSocketOption s so = do + n :: CInt <- getSockOpt s so + return $ fromIntegral n + +-- | Get a socket option. +getSockOpt :: forall a . Storable a + => Socket + -> SocketOption -- Option Name + -> IO a -- Option Value +getSockOpt s (SockOpt (level,opt)) = do + alloca $ \ptr -> do + let sz = fromIntegral $ sizeOf (undefined :: a) + withFdSocket s $ \fd -> with sz $ \ptr_sz -> do + throwSocketErrorIfMinus1Retry_ "Network.Socket.getSockOpt" $ + c_getsockopt fd level opt ptr ptr_sz + peek ptr foreign import CALLCONV unsafe "getsockopt" c_getsockopt :: CInt -> CInt -> CInt -> Ptr a -> Ptr CInt -> IO CInt From a19f5afafdbe235589a93fbde81b92bf729ad2a5 Mon Sep 17 00:00:00 2001 From: Kazu Yamamoto Date: Fri, 10 Jan 2020 16:56:55 +0900 Subject: [PATCH 23/48] using getSockOpt for cred. --- Network/Socket/Options.hsc | 2 -- Network/Socket/Unix.hsc | 23 ++++++++++++++--------- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/Network/Socket/Options.hsc b/Network/Socket/Options.hsc index 8585492a..5ce9fa60 100644 --- a/Network/Socket/Options.hsc +++ b/Network/Socket/Options.hsc @@ -21,8 +21,6 @@ module Network.Socket.Options ( , setSocketOption , getSockOpt , setSockOpt - , c_getsockopt - , c_setsockopt ) where import Foreign.Marshal.Alloc (alloca) diff --git a/Network/Socket/Unix.hsc b/Network/Socket/Unix.hsc index 06ec5856..67a62953 100644 --- a/Network/Socket/Unix.hsc +++ b/Network/Socket/Unix.hsc @@ -78,15 +78,20 @@ getPeerCredential _ = return (Nothing, Nothing, Nothing) getPeerCred :: Socket -> IO (CUInt, CUInt, CUInt) #ifdef HAVE_STRUCT_UCRED_SO_PEERCRED getPeerCred s = do - let sz = (#const sizeof(struct ucred)) - withFdSocket s $ \fd -> allocaBytes sz $ \ ptr_cr -> - with (fromIntegral sz) $ \ ptr_sz -> do - _ <- ($) throwSocketErrorIfMinus1Retry "Network.Socket.getPeerCred" $ - c_getsockopt fd (#const SOL_SOCKET) (#const SO_PEERCRED) ptr_cr ptr_sz - pid <- (#peek struct ucred, pid) ptr_cr - uid <- (#peek struct ucred, uid) ptr_cr - gid <- (#peek struct ucred, gid) ptr_cr - return (pid, uid, gid) + let opt = SockOpt (#const SOL_SOCKET) (#const SO_PEERCRED) + PeerCred cred <- getSockOpt s opt + return cred + +newtype PeerCred = PeerCred (CUInt, CUInt, CUInt) +instance Storable PeerCred where + sizeOf _ = (#const sizeof(struct ucred)) + alignment = undefined + poke = undefined + peek p (PeerCred (pid, uid, gid)) = do + pid <- (#peek struct ucred, pid) ptr_cr + uid <- (#peek struct ucred, uid) ptr_cr + gid <- (#peek struct ucred, gid) ptr_cr + return $ PeerCred (pid, uid, gid) #else getPeerCred _ = return (0, 0, 0) #endif From 5f1cc581c727d87a544a6d5a9996fb0a0b8771ca Mon Sep 17 00:00:00 2001 From: Kazu Yamamoto Date: Fri, 10 Jan 2020 16:58:14 +0900 Subject: [PATCH 24/48] exporting SockOpt. --- Network/Socket.hs | 3 ++- Network/Socket/Options.hsc | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/Network/Socket.hs b/Network/Socket.hs index 623fbb00..df6b57ea 100644 --- a/Network/Socket.hs +++ b/Network/Socket.hs @@ -131,7 +131,8 @@ module Network.Socket , ShutdownCmd(..) -- * Socket options - , SocketOption(Debug,ReuseAddr,Type,SoError,DontRoute,Broadcast + , SocketOption(SockOpt + ,Debug,ReuseAddr,Type,SoError,DontRoute,Broadcast ,SendBuffer,RecvBuffer,KeepAlive,OOBInline,TimeToLive ,MaxSegment,NoDelay,Cork,Linger,ReusePort ,RecvLowWater,SendLowWater,RecvTimeOut,SendTimeOut diff --git a/Network/Socket/Options.hsc b/Network/Socket/Options.hsc index 5ce9fa60..ef6d5c93 100644 --- a/Network/Socket/Options.hsc +++ b/Network/Socket/Options.hsc @@ -7,7 +7,8 @@ ##include "HsNetDef.h" module Network.Socket.Options ( - SocketOption(Debug,ReuseAddr,Type,SoError,DontRoute,Broadcast + SocketOption(SockOpt + ,Debug,ReuseAddr,Type,SoError,DontRoute,Broadcast ,SendBuffer,RecvBuffer,KeepAlive,OOBInline,TimeToLive ,MaxSegment,NoDelay,Cork,Linger,ReusePort ,RecvLowWater,SendLowWater,RecvTimeOut,SendTimeOut From 77aa0f4d9b33e28aeca9a389f8567a15243f3d65 Mon Sep 17 00:00:00 2001 From: Kazu Yamamoto Date: Sat, 11 Jan 2020 08:38:34 +0900 Subject: [PATCH 25/48] importing getSockOpt. --- Network/Socket/Unix.hsc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Network/Socket/Unix.hsc b/Network/Socket/Unix.hsc index 67a62953..7da94294 100644 --- a/Network/Socket/Unix.hsc +++ b/Network/Socket/Unix.hsc @@ -37,7 +37,7 @@ import Network.Socket.Fcntl import Network.Socket.Internal #endif #ifdef HAVE_STRUCT_UCRED_SO_PEERCRED -import Network.Socket.Options (c_getsockopt) +import Network.Socket.Options (getSockOpt) #endif -- | Getting process ID, user ID and group ID for UNIX-domain sockets. From 1eec5dcdb3145ff149f229436028f6584e633c04 Mon Sep 17 00:00:00 2001 From: Kazu Yamamoto Date: Sat, 11 Jan 2020 08:52:15 +0900 Subject: [PATCH 26/48] fixing cred again. --- Network/Socket/Unix.hsc | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/Network/Socket/Unix.hsc b/Network/Socket/Unix.hsc index 7da94294..f1105b02 100644 --- a/Network/Socket/Unix.hsc +++ b/Network/Socket/Unix.hsc @@ -23,9 +23,6 @@ import Network.Socket.Types #if defined(HAVE_GETPEEREID) import System.IO.Error (catchIOError) #endif -#ifdef HAVE_STRUCT_UCRED_SO_PEERCRED -import Foreign.Marshal.Utils (with) -#endif #ifdef HAVE_GETPEEREID import Foreign.Marshal.Alloc (alloca) #endif @@ -37,7 +34,7 @@ import Network.Socket.Fcntl import Network.Socket.Internal #endif #ifdef HAVE_STRUCT_UCRED_SO_PEERCRED -import Network.Socket.Options (getSockOpt) +import Network.Socket.Options #endif -- | Getting process ID, user ID and group ID for UNIX-domain sockets. @@ -78,19 +75,19 @@ getPeerCredential _ = return (Nothing, Nothing, Nothing) getPeerCred :: Socket -> IO (CUInt, CUInt, CUInt) #ifdef HAVE_STRUCT_UCRED_SO_PEERCRED getPeerCred s = do - let opt = SockOpt (#const SOL_SOCKET) (#const SO_PEERCRED) + let opt = SockOpt ((#const SOL_SOCKET),(#const SO_PEERCRED)) PeerCred cred <- getSockOpt s opt return cred newtype PeerCred = PeerCred (CUInt, CUInt, CUInt) instance Storable PeerCred where sizeOf _ = (#const sizeof(struct ucred)) - alignment = undefined + alignment _ = (#const sizeof(int)) poke = undefined - peek p (PeerCred (pid, uid, gid)) = do - pid <- (#peek struct ucred, pid) ptr_cr - uid <- (#peek struct ucred, uid) ptr_cr - gid <- (#peek struct ucred, gid) ptr_cr + peek p = do + pid <- (#peek struct ucred, pid) p + uid <- (#peek struct ucred, uid) p + gid <- (#peek struct ucred, gid) p return $ PeerCred (pid, uid, gid) #else getPeerCred _ = return (0, 0, 0) From e25509748ac02453596ab36d556bb610dd53abd0 Mon Sep 17 00:00:00 2001 From: Kazu Yamamoto Date: Tue, 14 Jan 2020 12:39:20 +0900 Subject: [PATCH 27/48] SocketOption now contains CInt directly. Breacking change: CustomSockOpt was deleted. --- Network/Socket.hs | 3 +- Network/Socket/Options.hsc | 138 ++++++++++++++++++------------------- 2 files changed, 69 insertions(+), 72 deletions(-) diff --git a/Network/Socket.hs b/Network/Socket.hs index df6b57ea..05feb2d2 100644 --- a/Network/Socket.hs +++ b/Network/Socket.hs @@ -138,8 +138,7 @@ module Network.Socket ,RecvLowWater,SendLowWater,RecvTimeOut,SendTimeOut ,UseLoopBack,UserTimeout,IPv6Only ,RecvIPv4TTL,RecvIPv4TOS,RecvIPv4PktInfo - ,RecvIPv6HopLimit,RecvIPv6TClass,RecvIPv6PktInfo - ,CustomSockOpt) + ,RecvIPv6HopLimit,RecvIPv6TClass,RecvIPv6PktInfo) , isSupportedSocketOption , getSocketOption , setSocketOption diff --git a/Network/Socket/Options.hsc b/Network/Socket/Options.hsc index ef6d5c93..1dfc067a 100644 --- a/Network/Socket/Options.hsc +++ b/Network/Socket/Options.hsc @@ -14,8 +14,7 @@ module Network.Socket.Options ( ,RecvLowWater,SendLowWater,RecvTimeOut,SendTimeOut ,UseLoopBack,UserTimeout,IPv6Only ,RecvIPv4TTL,RecvIPv4TOS,RecvIPv4PktInfo - ,RecvIPv6HopLimit,RecvIPv6TClass,RecvIPv6PktInfo - ,CustomSockOpt) + ,RecvIPv6HopLimit,RecvIPv6TClass,RecvIPv6PktInfo) , isSupportedSocketOption , getSocketType , getSocketOption @@ -38,11 +37,14 @@ import Network.Socket.Types -- -- The existence of a constructor does not imply that the relevant option -- is supported on your system: see 'isSupportedSocketOption' -newtype SocketOption = SockOpt (CInt,CInt) deriving (Eq, Show) +data SocketOption = SockOpt { + sockOptLevel :: CInt + , sockOptName :: CInt + } deriving (Eq, Show) -- | Does the 'SocketOption' exist on this system? isSupportedSocketOption :: SocketOption -> Bool -isSupportedSocketOption opt = opt /= SockOpt (-1,-1) +isSupportedSocketOption opt = opt /= SockOpt (-1) (-1) -- | Get the 'SocketType' of an active socket. -- @@ -55,121 +57,121 @@ getSocketType s = (fromMaybe NoSocketType . unpackSocketType . fromIntegral) -- | SO_DEBUG pattern Debug :: SocketOption #ifdef SO_DEBUG -pattern Debug = SockOpt ((#const SOL_SOCKET), (#const SO_DEBUG)) +pattern Debug = SockOpt (#const SOL_SOCKET) (#const SO_DEBUG) #else -pattern Debug = SockOpt (-1,-1) +pattern Debug = SockOpt (-1) (-1) #endif -- | SO_REUSEADDR pattern ReuseAddr :: SocketOption #ifdef SO_REUSEADDR -pattern ReuseAddr = SockOpt ((#const SOL_SOCKET), (#const SO_REUSEADDR)) +pattern ReuseAddr = SockOpt (#const SOL_SOCKET) (#const SO_REUSEADDR) #else -pattern ReuseAddr = SockOpt (-1,-1) +pattern ReuseAddr = SockOpt (-1) (-1) #endif -- | SO_TYPE pattern Type :: SocketOption #ifdef SO_TYPE -pattern Type = SockOpt ((#const SOL_SOCKET), (#const SO_TYPE)) +pattern Type = SockOpt (#const SOL_SOCKET) (#const SO_TYPE) #else -pattern Type = SockOpt (-1,-1) +pattern Type = SockOpt (-1) (-1) #endif -- | SO_ERROR pattern SoError :: SocketOption #ifdef SO_ERROR -pattern SoError = SockOpt ((#const SOL_SOCKET), (#const SO_ERROR)) +pattern SoError = SockOpt (#const SOL_SOCKET) (#const SO_ERROR) #else -pattern SoError = SockOpt (-1,-1) +pattern SoError = SockOpt (-1) (-1) #endif -- | SO_DONTROUTE pattern DontRoute :: SocketOption #ifdef SO_DONTROUTE -pattern DontRoute = SockOpt ((#const SOL_SOCKET), (#const SO_DONTROUTE)) +pattern DontRoute = SockOpt (#const SOL_SOCKET) (#const SO_DONTROUTE) #else -pattern DontRoute = SockOpt (-1,-1) +pattern DontRoute = SockOpt (-1) (-1) #endif -- | SO_BROADCAST pattern Broadcast :: SocketOption #ifdef SO_BROADCAST -pattern Broadcast = SockOpt ((#const SOL_SOCKET), (#const SO_BROADCAST)) +pattern Broadcast = SockOpt (#const SOL_SOCKET) (#const SO_BROADCAST) #else -pattern Broadcast = SockOpt (-1,-1) +pattern Broadcast = SockOpt (-1) (-1) #endif -- | SO_SNDBUF pattern SendBuffer :: SocketOption #ifdef SO_SNDBUF -pattern SendBuffer = SockOpt ((#const SOL_SOCKET), (#const SO_SNDBUF)) +pattern SendBuffer = SockOpt (#const SOL_SOCKET) (#const SO_SNDBUF) #else -pattern SendBuffer = SockOpt (-1,-1) +pattern SendBuffer = SockOpt (-1) (-1) #endif -- | SO_RCVBUF pattern RecvBuffer :: SocketOption #ifdef SO_RCVBUF -pattern RecvBuffer = SockOpt ((#const SOL_SOCKET), (#const SO_RCVBUF)) +pattern RecvBuffer = SockOpt (#const SOL_SOCKET) (#const SO_RCVBUF) #else -pattern RecvBuffer = SockOpt (-1,-1) +pattern RecvBuffer = SockOpt (-1) (-1) #endif -- | SO_KEEPALIVE pattern KeepAlive :: SocketOption #ifdef SO_KEEPALIVE -pattern KeepAlive = SockOpt ((#const SOL_SOCKET), (#const SO_KEEPALIVE)) +pattern KeepAlive = SockOpt (#const SOL_SOCKET) (#const SO_KEEPALIVE) #else -pattern KeepAlive = SockOpt (-1,-1) +pattern KeepAlive = SockOpt (-1) (-1) #endif -- | SO_OOBINLINE pattern OOBInline :: SocketOption #ifdef SO_OOBINLINE -pattern OOBInline = SockOpt ((#const SOL_SOCKET), (#const SO_OOBINLINE)) +pattern OOBInline = SockOpt (#const SOL_SOCKET) (#const SO_OOBINLINE) #else -pattern OOBINLINE = SockOpt (-1,-1) +pattern OOBINLINE = SockOpt (-1) (-1) #endif -- | SO_LINGER: timeout in seconds, 0 means disabling/disabled. pattern Linger :: SocketOption #ifdef SO_LINGER -pattern Linger = SockOpt ((#const SOL_SOCKET), (#const SO_LINGER)) +pattern Linger = SockOpt (#const SOL_SOCKET) (#const SO_LINGER) #else -pattern Linger = SockOpt (-1,-1) +pattern Linger = SockOpt (-1) (-1) #endif -- | SO_REUSEPORT pattern ReusePort :: SocketOption #ifdef SO_REUSEPORT -pattern ReusePort = SockOpt ((#const SOL_SOCKET), (#const SO_REUSEPORT)) +pattern ReusePort = SockOpt (#const SOL_SOCKET) (#const SO_REUSEPORT) #else -pattern ReusePort = SockOpt (-1,-1) +pattern ReusePort = SockOpt (-1) (-1) #endif -- | SO_RCVLOWAT pattern RecvLowWater :: SocketOption #ifdef SO_RCVLOWAT -pattern RecvLowWater = SockOpt ((#const SOL_SOCKET), (#const SO_RCVLOWAT)) +pattern RecvLowWater = SockOpt (#const SOL_SOCKET) (#const SO_RCVLOWAT) #else -pattern RecvLowWater = SockOpt (-1,-1) +pattern RecvLowWater = SockOpt (-1) (-1) #endif -- | SO_SNDLOWAT pattern SendLowWater :: SocketOption #ifdef SO_SNDLOWAT -pattern SendLowWater = SockOpt ((#const SOL_SOCKET), (#const SO_SNDLOWAT)) +pattern SendLowWater = SockOpt (#const SOL_SOCKET) (#const SO_SNDLOWAT) #else -pattern SendLowWater = SockOpt (-1,-1) +pattern SendLowWater = SockOpt (-1) (-1) #endif -- | SO_RCVTIMEO: this does not work at this moment. pattern RecvTimeOut :: SocketOption #ifdef SO_RCVTIMEO -pattern RecvTimeOut = SockOpt ((#const SOL_SOCKET), (#const SO_RCVTIMEO)) +pattern RecvTimeOut = SockOpt (#const SOL_SOCKET) (#const SO_RCVTIMEO) #else -pattern RecvTimeOut = SockOpt (-1,-1) +pattern RecvTimeOut = SockOpt (-1) (-1) #endif -- | SO_SNDTIMEO: this does not work at this moment. pattern SendTimeOut :: SocketOption #ifdef SO_SNDTIMEO -pattern SendTimeOut = SockOpt ((#const SOL_SOCKET), (#const SO_SNDTIMEO)) +pattern SendTimeOut = SockOpt (#const SOL_SOCKET) (#const SO_SNDTIMEO) #else -pattern SendTimeOut = SockOpt (-1,-1) +pattern SendTimeOut = SockOpt (-1) (-1) #endif -- | SO_USELOOPBACK pattern UseLoopBack :: SocketOption #ifdef SO_USELOOPBACK -pattern UseLoopBack = SockOpt ((#const SOL_SOCKET), (#const SO_USELOOPBACK)) +pattern UseLoopBack = SockOpt (#const SOL_SOCKET) (#const SO_USELOOPBACK) #else -pattern UseLoopBack = SockOpt (-1,-1) +pattern UseLoopBack = SockOpt (-1) (-1) #endif #endif // SOL_SOCKET @@ -177,30 +179,30 @@ pattern UseLoopBack = SockOpt (-1,-1) -- | TCP_MAXSEG pattern MaxSegment :: SocketOption #ifdef TCP_MAXSEG -pattern MaxSegment = SockOpt ((#const IPPROTO_TCP), (#const TCP_MAXSEG)) +pattern MaxSegment = SockOpt (#const IPPROTO_TCP) (#const TCP_MAXSEG) #else -pattern MaxSegment = SockOpt (-1,-1) +pattern MaxSegment = SockOpt (-1) (-1) #endif -- | TCP_NODELAY pattern NoDelay :: SocketOption #ifdef TCP_NODELAY -pattern NoDelay = SockOpt ((#const IPPROTO_TCP), (#const TCP_NODELAY)) +pattern NoDelay = SockOpt (#const IPPROTO_TCP) (#const TCP_NODELAY) #else -pattern NoDelay = SockOpt (-1,-1) +pattern NoDelay = SockOpt (-1) (-1) #endif -- | TCP_USER_TIMEOUT pattern UserTimeout :: SocketOption #ifdef TCP_USER_TIMEOUT -pattern UserTimeout = SockOpt ((#const IPPROTO_TCP), (#const TCP_USER_TIMEOUT)) +pattern UserTimeout = SockOpt (#const IPPROTO_TCP) (#const TCP_USER_TIMEOUT) #else -pattern UserTimeout = SockOpt (-1, -1) +pattern UserTimeout = SockOpt (-1) (-1) #endif -- | TCP_CORK pattern Cork :: SocketOption #ifdef TCP_CORK -pattern Cork = SockOpt ((#const IPPROTO_TCP), (#const TCP_CORK)) +pattern Cork = SockOpt (#const IPPROTO_TCP) (#const TCP_CORK) #else -pattern Cork = SockOpt (-1,-1) +pattern Cork = SockOpt (-1) (-1) #endif #endif // HAVE_DECL_IPPROTO_TCP @@ -208,32 +210,32 @@ pattern Cork = SockOpt (-1,-1) -- | IP_TTL pattern TimeToLive :: SocketOption #ifdef IP_TTL -pattern TimeToLive = SockOpt ((#const IPPROTO_IP), (#const IP_TTL)) +pattern TimeToLive = SockOpt (#const IPPROTO_IP) (#const IP_TTL) #else -pattern TimeToLive = SockOpt (-1,-1) +pattern TimeToLive = SockOpt (-1) (-1) #endif -- | Receiving IPv4 TTL. pattern RecvIPv4TTL :: SocketOption #ifdef IP_RECVTTL -pattern RecvIPv4TTL = SockOpt ((#const IPPROTO_IP), (#const IP_RECVTTL)) +pattern RecvIPv4TTL = SockOpt (#const IPPROTO_IP) (#const IP_RECVTTL) #else -pattern RecvIPv4TTL = SockOpt (-1,-1) +pattern RecvIPv4TTL = SockOpt (-1) (-1) #endif -- | Receiving IPv4 TOS. pattern RecvIPv4TOS :: SocketOption #ifdef IP_RECVTOS -pattern RecvIPv4TOS = SockOpt ((#const IPPROTO_IP), (#const IP_RECVTOS)) +pattern RecvIPv4TOS = SockOpt (#const IPPROTO_IP) (#const IP_RECVTOS) #else -pattern RecvIPv4TOS = SockOpt (-1,-1) +pattern RecvIPv4TOS = SockOpt (-1) (-1) #endif -- | Receiving IP_PKTINFO (struct in_pktinfo). pattern RecvIPv4PktInfo :: SocketOption #ifdef IP_RECVPKTINFO -pattern RecvIPv4PktInfo = SockOpt ((#const IPPROTO_IP), (#const IP_RECVPKTINFO)) +pattern RecvIPv4PktInfo = SockOpt (#const IPPROTO_IP) (#const IP_RECVPKTINFO) #elif defined(IP_PKTINFO) -pattern RecvIPv4PktInfo = SockOpt ((#const IPPROTO_IP), (#const IP_PKTINFO)) +pattern RecvIPv4PktInfo = SockOpt (#const IPPROTO_IP) (#const IP_PKTINFO) #else -pattern RecvIPv4PktInfo = SockOpt (-1,-1) +pattern RecvIPv4PktInfo = SockOpt (-1) (-1) #endif #endif // HAVE_DECL_IPPROTO_IP @@ -241,39 +243,35 @@ pattern RecvIPv4PktInfo = SockOpt (-1,-1) -- | IPV6_V6ONLY: don't use this on OpenBSD. pattern IPv6Only :: SocketOption #if HAVE_DECL_IPV6_V6ONLY -pattern IPv6Only = SockOpt ((#const IPPROTO_IPV6), (#const IPV6_V6ONLY)) +pattern IPv6Only = SockOpt (#const IPPROTO_IPV6) (#const IPV6_V6ONLY) #else -pattern IPv6Only = SockOpt (-1,-1) +pattern IPv6Only = SockOpt (-1) (-1) #endif -- | Receiving IPv6 hop limit. pattern RecvIPv6HopLimit :: SocketOption #ifdef IPV6_RECVHOPLIMIT -pattern RecvIPv6HopLimit = SockOpt ((#const IPPROTO_IPV6), (#const IPV6_RECVHOPLIMIT)) +pattern RecvIPv6HopLimit = SockOpt (#const IPPROTO_IPV6) (#const IPV6_RECVHOPLIMIT) #else -pattern RecvIPv6HopLimit = SockOpt (-1,-1) +pattern RecvIPv6HopLimit = SockOpt (-1) (-1) #endif -- | Receiving IPv6 traffic class. pattern RecvIPv6TClass :: SocketOption #ifdef IPV6_RECVTCLASS -pattern RecvIPv6TClass = SockOpt ((#const IPPROTO_IPV6), (#const IPV6_RECVTCLASS)) +pattern RecvIPv6TClass = SockOpt (#const IPPROTO_IPV6) (#const IPV6_RECVTCLASS) #else -pattern RecvIPv6TClass = SockOpt (-1,-1) +pattern RecvIPv6TClass = SockOpt (-1) (-1) #endif -- | Receiving IPV6_PKTINFO (struct in6_pktinfo). pattern RecvIPv6PktInfo :: SocketOption #ifdef IPV6_RECVPKTINFO -pattern RecvIPv6PktInfo = SockOpt ((#const IPPROTO_IPV6), (#const IPV6_RECVPKTINFO)) +pattern RecvIPv6PktInfo = SockOpt (#const IPPROTO_IPV6) (#const IPV6_RECVPKTINFO) #elif defined(IPV6_PKTINFO) -pattern RecvIPv6PktInfo = SockOpt ((#const IPPROTO_IPV6), (#const IPV6_PKTINFO)) +pattern RecvIPv6PktInfo = SockOpt (#const IPPROTO_IPV6) (#const IPV6_PKTINFO) #else -pattern RecvIPv6PktInfo = SockOpt (-1,-1) +pattern RecvIPv6PktInfo = SockOpt (-1) (-1) #endif #endif // HAVE_DECL_IPPROTO_IPV6 --- | Customizable socket option. -pattern CustomSockOpt :: (CInt,CInt) -> SocketOption -pattern CustomSockOpt opt = SockOpt opt - #ifdef SO_LINGER data StructLinger = StructLinger CInt CInt @@ -310,7 +308,7 @@ setSockOpt :: Storable a -> SocketOption -> a -> IO () -setSockOpt s (SockOpt (level,opt)) v = do +setSockOpt s (SockOpt level opt) v = do with v $ \ptr -> void $ do let sz = fromIntegral $ sizeOf v withFdSocket s $ \fd -> @@ -336,7 +334,7 @@ getSockOpt :: forall a . Storable a => Socket -> SocketOption -- Option Name -> IO a -- Option Value -getSockOpt s (SockOpt (level,opt)) = do +getSockOpt s (SockOpt level opt) = do alloca $ \ptr -> do let sz = fromIntegral $ sizeOf (undefined :: a) withFdSocket s $ \fd -> with sz $ \ptr_sz -> do From f8305c4732abd5a7fc71470a897832117982c01f Mon Sep 17 00:00:00 2001 From: Kazu Yamamoto Date: Tue, 14 Jan 2020 12:46:44 +0900 Subject: [PATCH 28/48] renaming: cmsg instead of ancillary --- Network/Socket.hs | 30 +-- Network/Socket/Address.hs | 2 +- Network/Socket/Buffer.hsc | 2 +- Network/Socket/Posix/Ancillary.hsc | 177 ----------------- Network/Socket/Posix/Cmsg.hsc | 257 ++++++++++++++++--------- Network/Socket/Posix/CmsgHdr.hsc | 102 ++++++++++ Network/Socket/Unix.hsc | 6 +- network.cabal | 2 +- tests/Network/Socket/ByteStringSpec.hs | 12 +- 9 files changed, 297 insertions(+), 293 deletions(-) delete mode 100644 Network/Socket/Posix/Ancillary.hsc create mode 100644 Network/Socket/Posix/CmsgHdr.hsc diff --git a/Network/Socket.hs b/Network/Socket.hs index 05feb2d2..650e3c3c 100644 --- a/Network/Socket.hs +++ b/Network/Socket.hs @@ -214,24 +214,25 @@ module Network.Socket , recvBuf , sendBufTo , recvBufFrom - -- ** IO with ancillary data + -- ** Advanced IO , sendBufMsg , recvBufMsg , MsgFlag(MSG_OOB,MSG_DONTROUTE,MSG_PEEK,MSG_EOR,MSG_TRUNC,MSG_CTRUNC,MSG_WAITALL) + -- ** Control message (ancillary data) , Cmsg(..) - -- ** Ancillary data - , Ancillary(..) - , AncillaryId - , ancillaryEncode - , ancillaryDecode - , ancillaryIPv4TTL - , ancillaryIPv6HopLimit - , ancillaryIPv4TOS - , ancillaryIPv6TClass - , ancillaryIPv4PktInfo - , ancillaryIPv6PktInfo - , lookupAncillary - -- ** Types + , CmsgId(CmsgId + ,CmsgIdIPv4TTL + ,CmsgIdIPv6HopLimit + ,CmsgIdIPv4TOS + ,CmsgIdIPv6TClass + ,CmsgIdIPv4PktInfo + ,CmsgIdIPv6PktInfo) + -- ** APIs for control message + , lookupCmsg + , decodeCmsg + , encodeCmsg + -- ** Class and yypes for control message + , ControlMessage(..) , IPv4TTL(..) , IPv6HopLimit(..) , IPv4TOS(..) @@ -252,7 +253,6 @@ import Network.Socket.Info import Network.Socket.Internal import Network.Socket.Name hiding (getPeerName, getSocketName) import Network.Socket.Options -import Network.Socket.Posix.Ancillary import Network.Socket.Posix.Cmsg import Network.Socket.Shutdown import Network.Socket.SockAddr diff --git a/Network/Socket/Address.hs b/Network/Socket/Address.hs index 415b910c..ccf3fcf1 100644 --- a/Network/Socket/Address.hs +++ b/Network/Socket/Address.hs @@ -16,7 +16,7 @@ module Network.Socket.Address ( -- * Sending and receiving data from a buffer , sendBufTo , recvBufFrom - -- * IO with ancillary data + -- * Advanced IO , sendBufMsg , recvBufMsg ) where diff --git a/Network/Socket/Buffer.hsc b/Network/Socket/Buffer.hsc index 513725c8..26f70019 100644 --- a/Network/Socket/Buffer.hsc +++ b/Network/Socket/Buffer.hsc @@ -31,7 +31,7 @@ import Network.Socket.Imports import Network.Socket.Internal import Network.Socket.Name import Network.Socket.Types -import Network.Socket.Posix.Cmsg +import Network.Socket.Posix.CmsgHdr import Network.Socket.Posix.MsgHdr import Network.Socket.Posix.IOVec import Network.Socket.Flag diff --git a/Network/Socket/Posix/Ancillary.hsc b/Network/Socket/Posix/Ancillary.hsc deleted file mode 100644 index 16b47413..00000000 --- a/Network/Socket/Posix/Ancillary.hsc +++ /dev/null @@ -1,177 +0,0 @@ -{-# LANGUAGE CPP #-} -{-# LANGUAGE GeneralizedNewtypeDeriving #-} -{-# LANGUAGE ScopedTypeVariables #-} - -module Network.Socket.Posix.Ancillary where - -#include "HsNet.h" - -#include -#include - -import Data.ByteString.Internal -import Foreign.ForeignPtr -import System.IO.Unsafe (unsafeDupablePerformIO) -import System.Posix.Types (Fd(..)) - -import Network.Socket.Imports -import Network.Socket.Posix.Cmsg -import Network.Socket.Types - ----------------------------------------------------------------- - --- | Identifier of ancillary data. A pair of level and type. -type AncillaryId = (CInt, CInt) - --- | The identifier for 'IPv4TTL'. -ancillaryIPv4TTL :: AncillaryId -#if defined(darwin_HOST_OS) -ancillaryIPv4TTL = ((#const IPPROTO_IP), (#const IP_RECVTTL)) -#else -ancillaryIPv4TTL = ((#const IPPROTO_IP), (#const IP_TTL)) -#endif - --- | The identifier for 'IPv6HopLimit'. -ancillaryIPv6HopLimit :: AncillaryId -ancillaryIPv6HopLimit = ((#const IPPROTO_IPV6), (#const IPV6_HOPLIMIT)) - --- | The identifier for 'IPv4TOS'. -ancillaryIPv4TOS :: AncillaryId -#if defined(darwin_HOST_OS) -ancillaryIPv4TOS = ((#const IPPROTO_IP), (#const IP_RECVTOS)) -#else -ancillaryIPv4TOS = ((#const IPPROTO_IP), (#const IP_TOS)) -#endif - --- | The identifier for 'IPv6TClass'. -ancillaryIPv6TClass :: AncillaryId -ancillaryIPv6TClass = ((#const IPPROTO_IPV6), (#const IPV6_TCLASS)) - --- | The identifier for 'IPv4PktInfo'. -ancillaryIPv4PktInfo :: AncillaryId -ancillaryIPv4PktInfo = ((#const IPPROTO_IP), (#const IP_PKTINFO)) - --- | The identifier for 'IPv6PktInfo'. -ancillaryIPv6PktInfo :: AncillaryId -ancillaryIPv6PktInfo = ((#const IPPROTO_IPV6), (#const IPV6_PKTINFO)) - --- | The identifier for 'Fd'. -ancillaryFd :: AncillaryId -ancillaryFd = ((#const SOL_SOCKET), (#const SCM_RIGHTS)) - ----------------------------------------------------------------- - --- | Looking up ancillary data. The following shows an example usage: --- --- > (lookupAncillary ancillaryIPv4TOS cmsgs >>= ancillaryDecode) :: Maybe IPv4TOS -lookupAncillary :: AncillaryId -> [Cmsg] -> Maybe Cmsg -lookupAncillary _ [] = Nothing -lookupAncillary aid (cmsg@(Cmsg cid _):cmsgs) - | aid == cid = Just cmsg - | otherwise = lookupAncillary aid cmsgs - ----------------------------------------------------------------- - --- | A class to encode and decode ancillary data. -class Storable a => Ancillary a where - ancillaryId :: a -> AncillaryId - -ancillaryEncode :: Ancillary a => a -> Cmsg -ancillaryEncode x = unsafeDupablePerformIO $ do - bs <- create siz $ \p0 -> do - let p = castPtr p0 - poke p x - return $ Cmsg (ancillaryId x) bs - where - siz = sizeOf x - -ancillaryDecode :: forall a . Storable a => Cmsg -> Maybe a -ancillaryDecode (Cmsg _ (PS fptr off len)) - | len < siz = Nothing - | otherwise = unsafeDupablePerformIO $ withForeignPtr fptr $ \p0 -> do - let p = castPtr (p0 `plusPtr` off) - Just <$> peek p - where - siz = sizeOf (undefined :: a) - ----------------------------------------------------------------- - --- | Time to live of IPv4. -newtype IPv4TTL = IPv4TTL CChar deriving (Eq, Show, Storable) - -instance Ancillary IPv4TTL where - ancillaryId _ = ancillaryIPv4TTL - ----------------------------------------------------------------- - --- | Hop limit of IPv6. -newtype IPv6HopLimit = IPv6HopLimit CInt deriving (Eq, Show, Storable) - -instance Ancillary IPv6HopLimit where - ancillaryId _ = ancillaryIPv6HopLimit - ----------------------------------------------------------------- - --- | TOS of IPv4. -newtype IPv4TOS = IPv4TOS CChar deriving (Eq, Show, Storable) - -instance Ancillary IPv4TOS where - ancillaryId _ = ancillaryIPv4TOS - ----------------------------------------------------------------- - --- | Traffic class of IPv6. -newtype IPv6TClass = IPv6TClass CInt deriving (Eq, Show, Storable) - -instance Ancillary IPv6TClass where - ancillaryId _ = ancillaryIPv6TClass - ----------------------------------------------------------------- - --- | Network interface ID and local IPv4 address. -data IPv4PktInfo = IPv4PktInfo CInt HostAddress deriving (Eq) - -instance Show IPv4PktInfo where - show (IPv4PktInfo n ha) = "IPv4PktInfo " ++ show n ++ " " ++ show (hostAddressToTuple ha) - -instance Ancillary IPv4PktInfo where - ancillaryId _ = ancillaryIPv4PktInfo - -instance Storable IPv4PktInfo where - sizeOf _ = (#size struct in_pktinfo) - alignment = undefined - poke p (IPv4PktInfo n ha) = do - (#poke struct in_pktinfo, ipi_ifindex) p (fromIntegral n :: CInt) - (#poke struct in_pktinfo, ipi_spec_dst) p (0 :: CInt) - (#poke struct in_pktinfo, ipi_addr) p ha - peek p = do - n <- (#peek struct in_pktinfo, ipi_ifindex) p - ha <- (#peek struct in_pktinfo, ipi_addr) p - return $ IPv4PktInfo n ha - ----------------------------------------------------------------- - --- | Network interface ID and local IPv4 address. -data IPv6PktInfo = IPv6PktInfo Int HostAddress6 deriving (Eq) - -instance Show IPv6PktInfo where - show (IPv6PktInfo n ha6) = "IPv6PktInfo " ++ show n ++ " " ++ show (hostAddress6ToTuple ha6) - -instance Ancillary IPv6PktInfo where - ancillaryId _ = ancillaryIPv6PktInfo - -instance Storable IPv6PktInfo where - sizeOf _ = (#size struct in6_pktinfo) - alignment = undefined - poke p (IPv6PktInfo n ha6) = do - (#poke struct in6_pktinfo, ipi6_ifindex) p (fromIntegral n :: CInt) - (#poke struct in6_pktinfo, ipi6_addr) p (In6Addr ha6) - peek p = do - In6Addr ha6 <- (#peek struct in6_pktinfo, ipi6_addr) p - n :: CInt <- (#peek struct in6_pktinfo, ipi6_ifindex) p - return $ IPv6PktInfo (fromIntegral n) ha6 - ----------------------------------------------------------------- - -instance Ancillary Fd where - ancillaryId _ = ancillaryFd diff --git a/Network/Socket/Posix/Cmsg.hsc b/Network/Socket/Posix/Cmsg.hsc index bd00946c..71ccc9e1 100644 --- a/Network/Socket/Posix/Cmsg.hsc +++ b/Network/Socket/Posix/Cmsg.hsc @@ -1,107 +1,186 @@ -{-# OPTIONS_GHC -funbox-strict-fields #-} +{-# LANGUAGE CPP #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE ScopedTypeVariables #-} -#include "HsNet.h" +module Network.Socket.Posix.Cmsg where -module Network.Socket.Posix.Cmsg ( - Cmsg(..) - , withCmsgs - , parseCmsgs - ) where +#include "HsNet.h" #include #include -import Foreign.Marshal.Alloc (allocaBytes) -import Foreign.ForeignPtr -import qualified Data.ByteString as B import Data.ByteString.Internal +import Foreign.ForeignPtr +import System.IO.Unsafe (unsafeDupablePerformIO) +import System.Posix.Types (Fd(..)) import Network.Socket.Imports -import Network.Socket.Posix.MsgHdr import Network.Socket.Types --- | Control message including a pair of level and type. +-- | Control message (ancillary data) including a pair of level and type. data Cmsg = Cmsg { - cmsgLevelType :: (CInt,CInt) - , cmsgBody :: ByteString + cmsgId :: CmsgId + , cmsgData :: ByteString } deriving (Eq, Show) -data CmsgHdr = CmsgHdr { - cmsgHdrLen :: !CInt - , cmsgHdrLevel :: !CInt - , cmsgHdrType :: !CInt +---------------------------------------------------------------- + +-- | Identifier of control message (ancillary data). +data CmsgId = CmsgId { + cmsgLevel :: CInt + , cmsglType :: CInt } deriving (Eq, Show) -instance Storable CmsgHdr where - sizeOf _ = (#size struct cmsghdr) - alignment _ = alignment (undefined :: CInt) - - peek p = do - len <- (#peek struct cmsghdr, cmsg_len) p - lvl <- (#peek struct cmsghdr, cmsg_level) p - typ <- (#peek struct cmsghdr, cmsg_type) p - return $ CmsgHdr len lvl typ - - poke p (CmsgHdr len lvl typ) = do - zeroMemory p (#size struct cmsghdr) - (#poke struct cmsghdr, cmsg_len) p len - (#poke struct cmsghdr, cmsg_level) p lvl - (#poke struct cmsghdr, cmsg_type) p typ - -withCmsgs :: [Cmsg] -> (Ptr CmsgHdr -> Int -> IO a) -> IO a -withCmsgs cmsgs0 action - | total == 0 = action nullPtr 0 - | otherwise = allocaBytes total $ \ctrlPtr -> do - loop ctrlPtr cmsgs0 spaces - action ctrlPtr total +-- | The identifier for 'IPv4TTL'. +pattern CmsgIdIPv4TTL :: CmsgId +#if defined(darwin_HOST_OS) +pattern CmsgIdIPv4TTL = CmsgId (#const IPPROTO_IP) (#const IP_RECVTTL) +#else +pattern CmsgIdIPv4TTL = CmsgId (#const IPPROTO_IP) (#const IP_TTL) +#endif + +-- | The identifier for 'IPv6HopLimit'. +pattern CmsgIdIPv6HopLimit :: CmsgId +pattern CmsgIdIPv6HopLimit = CmsgId (#const IPPROTO_IPV6) (#const IPV6_HOPLIMIT) + +-- | The identifier for 'IPv4TOS'. +pattern CmsgIdIPv4TOS :: CmsgId +#if defined(darwin_HOST_OS) +pattern CmsgIdIPv4TOS = CmsgId (#const IPPROTO_IP) (#const IP_RECVTOS) +#else +pattern CmsgIdIPv4TOS = CmsgId (#const IPPROTO_IP) (#const IP_TOS) +#endif + +-- | The identifier for 'IPv6TClass'. +pattern CmsgIdIPv6TClass :: CmsgId +pattern CmsgIdIPv6TClass = CmsgId (#const IPPROTO_IPV6) (#const IPV6_TCLASS) + +-- | The identifier for 'IPv4PktInfo'. +pattern CmsgIdIPv4PktInfo :: CmsgId +pattern CmsgIdIPv4PktInfo = CmsgId (#const IPPROTO_IP) (#const IP_PKTINFO) + +-- | The identifier for 'IPv6PktInfo'. +pattern CmsgIdIPv6PktInfo :: CmsgId +pattern CmsgIdIPv6PktInfo = CmsgId (#const IPPROTO_IPV6) (#const IPV6_PKTINFO) + +-- | The identifier for 'Fd'. +pattern CmsgIdFd :: CmsgId +pattern CmsgIdFd = CmsgId (#const SOL_SOCKET) (#const SCM_RIGHTS) + +---------------------------------------------------------------- + +-- | Looking up control message. The following shows an example usage: +-- +-- > (lookupCmsg CmsgIdIPv4TOS cmsgs >>= decodeCmsg) :: Maybe IPv4TOS +lookupCmsg :: CmsgId -> [Cmsg] -> Maybe Cmsg +lookupCmsg _ [] = Nothing +lookupCmsg aid (cmsg@(Cmsg cid _):cmsgs) + | aid == cid = Just cmsg + | otherwise = lookupCmsg aid cmsgs + +---------------------------------------------------------------- + +-- | A class to encode and decode control message. +class Storable a => ControlMessage a where + controlMessageId :: a -> CmsgId + +encodeCmsg :: ControlMessage a => a -> Cmsg +encodeCmsg x = unsafeDupablePerformIO $ do + bs <- create siz $ \p0 -> do + let p = castPtr p0 + poke p x + return $ Cmsg (controlMessageId x) bs where - loop ctrlPtr (cmsg:cmsgs) (s:ss) = do - encodeCmsg ctrlPtr cmsg - let nextPtr = ctrlPtr `plusPtr` s - loop nextPtr cmsgs ss - loop _ _ _ = return () - cmsg_space = fromIntegral . c_cmsg_space . fromIntegral - spaces = map (cmsg_space . B.length . cmsgBody) cmsgs0 - total = sum spaces - -encodeCmsg :: Ptr CmsgHdr -> Cmsg -> IO () -encodeCmsg ctrlPtr (Cmsg (lvl,typ) (PS fptr off len)) = do - poke ctrlPtr $ CmsgHdr (c_cmsg_len (fromIntegral len)) lvl typ - withForeignPtr fptr $ \src0 -> do - let src = src0 `plusPtr` off - dst <- c_cmsg_data ctrlPtr - memcpy dst src len - -parseCmsgs :: SocketAddress sa => Ptr (MsgHdr sa) -> IO [Cmsg] -parseCmsgs msgptr = do - ptr <- c_cmsg_firsthdr msgptr - loop ptr id + siz = sizeOf x + +decodeCmsg :: forall a . Storable a => Cmsg -> Maybe a +decodeCmsg (Cmsg _ (PS fptr off len)) + | len < siz = Nothing + | otherwise = unsafeDupablePerformIO $ withForeignPtr fptr $ \p0 -> do + let p = castPtr (p0 `plusPtr` off) + Just <$> peek p where - loop ptr build - | ptr == nullPtr = return $ build [] - | otherwise = do - cmsg <- decodeCmsg ptr - nextPtr <- c_cmsg_nxthdr msgptr ptr - loop nextPtr (build . (cmsg :)) - -decodeCmsg :: Ptr CmsgHdr -> IO Cmsg -decodeCmsg ptr = do - CmsgHdr len lvl typ <- peek ptr - src <- c_cmsg_data ptr - let siz = fromIntegral len - (src `minusPtr` ptr) - Cmsg (lvl,typ) <$> create (fromIntegral siz) (\dst -> memcpy dst src siz) - -foreign import ccall unsafe "cmsg_firsthdr" - c_cmsg_firsthdr :: Ptr (MsgHdr sa) -> IO (Ptr CmsgHdr) - -foreign import ccall unsafe "cmsg_nxthdr" - c_cmsg_nxthdr :: Ptr (MsgHdr sa) -> Ptr CmsgHdr -> IO (Ptr CmsgHdr) - -foreign import ccall unsafe "cmsg_data" - c_cmsg_data :: Ptr CmsgHdr -> IO (Ptr Word8) - -foreign import ccall unsafe "cmsg_space" - c_cmsg_space :: CInt -> CInt - -foreign import ccall unsafe "cmsg_len" - c_cmsg_len :: CInt -> CInt + siz = sizeOf (undefined :: a) + +---------------------------------------------------------------- + +-- | Time to live of IPv4. +newtype IPv4TTL = IPv4TTL CChar deriving (Eq, Show, Storable) + +instance ControlMessage IPv4TTL where + controlMessageId _ = CmsgIdIPv4TTL + +---------------------------------------------------------------- + +-- | Hop limit of IPv6. +newtype IPv6HopLimit = IPv6HopLimit CInt deriving (Eq, Show, Storable) + +instance ControlMessage IPv6HopLimit where + controlMessageId _ = CmsgIdIPv6HopLimit + +---------------------------------------------------------------- + +-- | TOS of IPv4. +newtype IPv4TOS = IPv4TOS CChar deriving (Eq, Show, Storable) + +instance ControlMessage IPv4TOS where + controlMessageId _ = CmsgIdIPv4TOS + +---------------------------------------------------------------- + +-- | Traffic class of IPv6. +newtype IPv6TClass = IPv6TClass CInt deriving (Eq, Show, Storable) + +instance ControlMessage IPv6TClass where + controlMessageId _ = CmsgIdIPv6TClass + +---------------------------------------------------------------- + +-- | Network interface ID and local IPv4 address. +data IPv4PktInfo = IPv4PktInfo CInt HostAddress deriving (Eq) + +instance Show IPv4PktInfo where + show (IPv4PktInfo n ha) = "IPv4PktInfo " ++ show n ++ " " ++ show (hostAddressToTuple ha) + +instance ControlMessage IPv4PktInfo where + controlMessageId _ = CmsgIdIPv4PktInfo + +instance Storable IPv4PktInfo where + sizeOf _ = (#size struct in_pktinfo) + alignment = undefined + poke p (IPv4PktInfo n ha) = do + (#poke struct in_pktinfo, ipi_ifindex) p (fromIntegral n :: CInt) + (#poke struct in_pktinfo, ipi_spec_dst) p (0 :: CInt) + (#poke struct in_pktinfo, ipi_addr) p ha + peek p = do + n <- (#peek struct in_pktinfo, ipi_ifindex) p + ha <- (#peek struct in_pktinfo, ipi_addr) p + return $ IPv4PktInfo n ha + +---------------------------------------------------------------- + +-- | Network interface ID and local IPv4 address. +data IPv6PktInfo = IPv6PktInfo Int HostAddress6 deriving (Eq) + +instance Show IPv6PktInfo where + show (IPv6PktInfo n ha6) = "IPv6PktInfo " ++ show n ++ " " ++ show (hostAddress6ToTuple ha6) + +instance ControlMessage IPv6PktInfo where + controlMessageId _ = CmsgIdIPv6PktInfo + +instance Storable IPv6PktInfo where + sizeOf _ = (#size struct in6_pktinfo) + alignment = undefined + poke p (IPv6PktInfo n ha6) = do + (#poke struct in6_pktinfo, ipi6_ifindex) p (fromIntegral n :: CInt) + (#poke struct in6_pktinfo, ipi6_addr) p (In6Addr ha6) + peek p = do + In6Addr ha6 <- (#peek struct in6_pktinfo, ipi6_addr) p + n :: CInt <- (#peek struct in6_pktinfo, ipi6_ifindex) p + return $ IPv6PktInfo (fromIntegral n) ha6 + +---------------------------------------------------------------- + +instance ControlMessage Fd where + controlMessageId _ = CmsgIdFd diff --git a/Network/Socket/Posix/CmsgHdr.hsc b/Network/Socket/Posix/CmsgHdr.hsc new file mode 100644 index 00000000..3a26ced2 --- /dev/null +++ b/Network/Socket/Posix/CmsgHdr.hsc @@ -0,0 +1,102 @@ +{-# OPTIONS_GHC -funbox-strict-fields #-} + +#include "HsNet.h" + +module Network.Socket.Posix.CmsgHdr ( + Cmsg(..) + , withCmsgs + , parseCmsgs + ) where + +#include +#include + +import Foreign.Marshal.Alloc (allocaBytes) +import Foreign.ForeignPtr +import qualified Data.ByteString as B +import Data.ByteString.Internal + +import Network.Socket.Imports +import Network.Socket.Posix.Cmsg +import Network.Socket.Posix.MsgHdr +import Network.Socket.Types + +data CmsgHdr = CmsgHdr { + cmsgHdrLen :: !CInt + , cmsgHdrLevel :: !CInt + , cmsgHdrType :: !CInt + } deriving (Eq, Show) + +instance Storable CmsgHdr where + sizeOf _ = (#size struct cmsghdr) + alignment _ = alignment (undefined :: CInt) + + peek p = do + len <- (#peek struct cmsghdr, cmsg_len) p + lvl <- (#peek struct cmsghdr, cmsg_level) p + typ <- (#peek struct cmsghdr, cmsg_type) p + return $ CmsgHdr len lvl typ + + poke p (CmsgHdr len lvl typ) = do + zeroMemory p (#size struct cmsghdr) + (#poke struct cmsghdr, cmsg_len) p len + (#poke struct cmsghdr, cmsg_level) p lvl + (#poke struct cmsghdr, cmsg_type) p typ + +withCmsgs :: [Cmsg] -> (Ptr CmsgHdr -> Int -> IO a) -> IO a +withCmsgs cmsgs0 action + | total == 0 = action nullPtr 0 + | otherwise = allocaBytes total $ \ctrlPtr -> do + loop ctrlPtr cmsgs0 spaces + action ctrlPtr total + where + loop ctrlPtr (cmsg:cmsgs) (s:ss) = do + toCmsgHdr cmsg ctrlPtr + let nextPtr = ctrlPtr `plusPtr` s + loop nextPtr cmsgs ss + loop _ _ _ = return () + cmsg_space = fromIntegral . c_cmsg_space . fromIntegral + spaces = map (cmsg_space . B.length . cmsgData) cmsgs0 + total = sum spaces + +toCmsgHdr :: Cmsg -> Ptr CmsgHdr -> IO () +toCmsgHdr (Cmsg (CmsgId lvl typ) (PS fptr off len)) ctrlPtr = do + poke ctrlPtr $ CmsgHdr (c_cmsg_len (fromIntegral len)) lvl typ + withForeignPtr fptr $ \src0 -> do + let src = src0 `plusPtr` off + dst <- c_cmsg_data ctrlPtr + memcpy dst src len + +parseCmsgs :: SocketAddress sa => Ptr (MsgHdr sa) -> IO [Cmsg] +parseCmsgs msgptr = do + ptr <- c_cmsg_firsthdr msgptr + loop ptr id + where + loop ptr build + | ptr == nullPtr = return $ build [] + | otherwise = do + cmsg <- fromCmsgHdr ptr + nextPtr <- c_cmsg_nxthdr msgptr ptr + loop nextPtr (build . (cmsg :)) + +fromCmsgHdr :: Ptr CmsgHdr -> IO Cmsg +fromCmsgHdr ptr = do + CmsgHdr len lvl typ <- peek ptr + src <- c_cmsg_data ptr + let siz = fromIntegral len - (src `minusPtr` ptr) + Cmsg (CmsgId lvl typ) <$> create (fromIntegral siz) (\dst -> memcpy dst src siz) + +foreign import ccall unsafe "cmsg_firsthdr" + c_cmsg_firsthdr :: Ptr (MsgHdr sa) -> IO (Ptr CmsgHdr) + +foreign import ccall unsafe "cmsg_nxthdr" + c_cmsg_nxthdr :: Ptr (MsgHdr sa) -> Ptr CmsgHdr -> IO (Ptr CmsgHdr) + +foreign import ccall unsafe "cmsg_data" + c_cmsg_data :: Ptr CmsgHdr -> IO (Ptr Word8) + +foreign import ccall unsafe "cmsg_space" + c_cmsg_space :: CInt -> CInt + +foreign import ccall unsafe "cmsg_len" + c_cmsg_len :: CInt -> CInt diff --git a/Network/Socket/Unix.hsc b/Network/Socket/Unix.hsc index f1105b02..b7d82de6 100644 --- a/Network/Socket/Unix.hsc +++ b/Network/Socket/Unix.hsc @@ -17,7 +17,7 @@ import System.Posix.Types (Fd(..)) import Network.Socket.Buffer import Network.Socket.Imports -import Network.Socket.Posix.Ancillary +import Network.Socket.Posix.Cmsg import Network.Socket.Types #if defined(HAVE_GETPEEREID) @@ -141,7 +141,7 @@ instance SocketAddress NullSockAddr where sendFd :: Socket -> CInt -> IO () #if defined(DOMAIN_SOCKET_SUPPORT) sendFd s outfd = void $ allocaBytes dummyBufSize $ \buf -> do - let cmsg = ancillaryEncode $ Fd outfd + let cmsg = encodeCmsg $ Fd outfd sendBufMsg s NullSockAddr [(buf,dummyBufSize)] [cmsg] mempty where dummyBufSize = 1 @@ -158,7 +158,7 @@ recvFd :: Socket -> IO CInt #if defined(DOMAIN_SOCKET_SUPPORT) recvFd s = allocaBytes dummyBufSize $ \buf -> do (NullSockAddr, _, cmsgs, _) <- recvBufMsg s [(buf,dummyBufSize)] 32 mempty - case (lookupAncillary ancillaryFd cmsgs >>= ancillaryDecode) :: Maybe Fd of + case (lookupCmsg CmsgIdFd cmsgs >>= decodeCmsg) :: Maybe Fd of Nothing -> return (-1) Just (Fd fd) -> return fd where diff --git a/network.cabal b/network.cabal index d359bb73..765afd0a 100644 --- a/network.cabal +++ b/network.cabal @@ -98,8 +98,8 @@ library if !os(windows) other-modules: Network.Socket.ByteString.Lazy.Posix - Network.Socket.Posix.Ancillary Network.Socket.Posix.Cmsg + Network.Socket.Posix.CmsgHdr Network.Socket.Posix.IOVec Network.Socket.Posix.MsgHdr c-sources: cbits/cmsg.c diff --git a/tests/Network/Socket/ByteStringSpec.hs b/tests/Network/Socket/ByteStringSpec.hs index ffe7aeb6..75af7fd5 100644 --- a/tests/Network/Socket/ByteStringSpec.hs +++ b/tests/Network/Socket/ByteStringSpec.hs @@ -234,9 +234,9 @@ spec = do setSocketOption sock RecvIPv4PktInfo 1 (_, _, cmsgs, _) <- recvMsg sock 1024 128 mempty - ((lookupAncillary ancillaryIPv4TTL cmsgs >>= ancillaryDecode) :: Maybe IPv4TTL) `shouldNotBe` Nothing - ((lookupAncillary ancillaryIPv4TOS cmsgs >>= ancillaryDecode) :: Maybe IPv4TOS) `shouldNotBe` Nothing - ((lookupAncillary ancillaryIPv4PktInfo cmsgs >>= ancillaryDecode) :: Maybe IPv4PktInfo) `shouldNotBe` Nothing + ((lookupCmsg CmsgIdIPv4TTL cmsgs >>= decodeCmsg) :: Maybe IPv4TTL) `shouldNotBe` Nothing + ((lookupCmsg CmsgIdIPv4TOS cmsgs >>= decodeCmsg) :: Maybe IPv4TOS) `shouldNotBe` Nothing + ((lookupCmsg CmsgIdIPv4PktInfo cmsgs >>= decodeCmsg) :: Maybe IPv4PktInfo) `shouldNotBe` Nothing client sock addr = sendTo sock seg addr seg = C.pack "This is a test message" @@ -249,9 +249,9 @@ spec = do setSocketOption sock RecvIPv6PktInfo 1 (_, _, cmsgs, _) <- recvMsg sock 1024 128 mempty - ((lookupAncillary ancillaryIPv6HopLimit cmsgs >>= ancillaryDecode) :: Maybe IPv6HopLimit) `shouldNotBe` Nothing - ((lookupAncillary ancillaryIPv6TClass cmsgs >>= ancillaryDecode) :: Maybe IPv6TClass) `shouldNotBe` Nothing - ((lookupAncillary ancillaryIPv6PktInfo cmsgs >>= ancillaryDecode) :: Maybe IPv6PktInfo) `shouldNotBe` Nothing + ((lookupCmsg CmsgIdIPv6HopLimit cmsgs >>= decodeCmsg) :: Maybe IPv6HopLimit) `shouldNotBe` Nothing + ((lookupCmsg CmsgIdIPv6TClass cmsgs >>= decodeCmsg) :: Maybe IPv6TClass) `shouldNotBe` Nothing + ((lookupCmsg CmsgIdIPv6PktInfo cmsgs >>= decodeCmsg) :: Maybe IPv6PktInfo) `shouldNotBe` Nothing client sock addr = sendTo sock seg addr seg = C.pack "This is a test message" From 7ff9f9ff96c96d1a4e839ce65dc3caf8f5bf7b0d Mon Sep 17 00:00:00 2001 From: Kazu Yamamoto Date: Tue, 14 Jan 2020 13:17:43 +0900 Subject: [PATCH 29/48] making fields strict. --- Network/Socket/Options.hsc | 4 ++-- Network/Socket/Posix/Cmsg.hsc | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/Network/Socket/Options.hsc b/Network/Socket/Options.hsc index 1dfc067a..817f71c5 100644 --- a/Network/Socket/Options.hsc +++ b/Network/Socket/Options.hsc @@ -38,8 +38,8 @@ import Network.Socket.Types -- The existence of a constructor does not imply that the relevant option -- is supported on your system: see 'isSupportedSocketOption' data SocketOption = SockOpt { - sockOptLevel :: CInt - , sockOptName :: CInt + sockOptLevel :: !CInt + , sockOptName :: !CInt } deriving (Eq, Show) -- | Does the 'SocketOption' exist on this system? diff --git a/Network/Socket/Posix/Cmsg.hsc b/Network/Socket/Posix/Cmsg.hsc index 71ccc9e1..f8905bdc 100644 --- a/Network/Socket/Posix/Cmsg.hsc +++ b/Network/Socket/Posix/Cmsg.hsc @@ -20,16 +20,16 @@ import Network.Socket.Types -- | Control message (ancillary data) including a pair of level and type. data Cmsg = Cmsg { - cmsgId :: CmsgId - , cmsgData :: ByteString + cmsgId :: !CmsgId + , cmsgData :: !ByteString } deriving (Eq, Show) ---------------------------------------------------------------- -- | Identifier of control message (ancillary data). data CmsgId = CmsgId { - cmsgLevel :: CInt - , cmsglType :: CInt + cmsgLevel :: !CInt + , cmsglType :: !CInt } deriving (Eq, Show) -- | The identifier for 'IPv4TTL'. From b1df8f25c61d97248ec8e312ead8964b673683e4 Mon Sep 17 00:00:00 2001 From: Kazu Yamamoto Date: Tue, 14 Jan 2020 13:29:28 +0900 Subject: [PATCH 30/48] fix for Linux. --- Network/Socket/Unix.hsc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Network/Socket/Unix.hsc b/Network/Socket/Unix.hsc index b7d82de6..a63c1f37 100644 --- a/Network/Socket/Unix.hsc +++ b/Network/Socket/Unix.hsc @@ -75,7 +75,7 @@ getPeerCredential _ = return (Nothing, Nothing, Nothing) getPeerCred :: Socket -> IO (CUInt, CUInt, CUInt) #ifdef HAVE_STRUCT_UCRED_SO_PEERCRED getPeerCred s = do - let opt = SockOpt ((#const SOL_SOCKET),(#const SO_PEERCRED)) + let opt = SockOpt (#const SOL_SOCKET) (#const SO_PEERCRED) PeerCred cred <- getSockOpt s opt return cred From 5676bc9896bbf0cfc52689746e1acacfd23db23c Mon Sep 17 00:00:00 2001 From: Kazu Yamamoto Date: Tue, 14 Jan 2020 14:31:50 +0900 Subject: [PATCH 31/48] TTL is CInt on Linux. --- Network/Socket/Posix/Cmsg.hsc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/Network/Socket/Posix/Cmsg.hsc b/Network/Socket/Posix/Cmsg.hsc index f8905bdc..2834b201 100644 --- a/Network/Socket/Posix/Cmsg.hsc +++ b/Network/Socket/Posix/Cmsg.hsc @@ -106,7 +106,11 @@ decodeCmsg (Cmsg _ (PS fptr off len)) ---------------------------------------------------------------- -- | Time to live of IPv4. +#if defined(darwin_HOST_OS) newtype IPv4TTL = IPv4TTL CChar deriving (Eq, Show, Storable) +#else +newtype IPv4TTL = IPv4TTL CInt deriving (Eq, Show, Storable) +#endif instance ControlMessage IPv4TTL where controlMessageId _ = CmsgIdIPv4TTL From 5f3706fc72a0b67878ef2021cde68157a97fb3ca Mon Sep 17 00:00:00 2001 From: Kazu Yamamoto Date: Tue, 14 Jan 2020 15:53:39 +0900 Subject: [PATCH 32/48] Using Int in IPv4PktInfo. --- Network/Socket/Posix/Cmsg.hsc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Network/Socket/Posix/Cmsg.hsc b/Network/Socket/Posix/Cmsg.hsc index 2834b201..7d39788d 100644 --- a/Network/Socket/Posix/Cmsg.hsc +++ b/Network/Socket/Posix/Cmsg.hsc @@ -142,7 +142,7 @@ instance ControlMessage IPv6TClass where ---------------------------------------------------------------- -- | Network interface ID and local IPv4 address. -data IPv4PktInfo = IPv4PktInfo CInt HostAddress deriving (Eq) +data IPv4PktInfo = IPv4PktInfo Int HostAddress deriving (Eq) instance Show IPv4PktInfo where show (IPv4PktInfo n ha) = "IPv4PktInfo " ++ show n ++ " " ++ show (hostAddressToTuple ha) From d040a33a0462ac6b5e8c0dacc1d2e9f780f4b52a Mon Sep 17 00:00:00 2001 From: Kazu Yamamoto Date: Tue, 14 Jan 2020 16:06:44 +0900 Subject: [PATCH 33/48] fixing the Storable instance. --- Network/Socket/Unix.hsc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Network/Socket/Unix.hsc b/Network/Socket/Unix.hsc index a63c1f37..e70409d2 100644 --- a/Network/Socket/Unix.hsc +++ b/Network/Socket/Unix.hsc @@ -82,8 +82,8 @@ getPeerCred s = do newtype PeerCred = PeerCred (CUInt, CUInt, CUInt) instance Storable PeerCred where sizeOf _ = (#const sizeof(struct ucred)) - alignment _ = (#const sizeof(int)) - poke = undefined + alignment _ = alignment (undefined :: CInt) + poke _ _ = return () peek p = do pid <- (#peek struct ucred, pid) p uid <- (#peek struct ucred, uid) p From 031df10e9e25d20886081dc2fbcea6c40e126050 Mon Sep 17 00:00:00 2001 From: Kazu Yamamoto Date: Wed, 15 Jan 2020 13:52:38 +0900 Subject: [PATCH 34/48] fixing alignment. --- Network/Socket/Posix/Cmsg.hsc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Network/Socket/Posix/Cmsg.hsc b/Network/Socket/Posix/Cmsg.hsc index 7d39788d..1674915a 100644 --- a/Network/Socket/Posix/Cmsg.hsc +++ b/Network/Socket/Posix/Cmsg.hsc @@ -152,7 +152,7 @@ instance ControlMessage IPv4PktInfo where instance Storable IPv4PktInfo where sizeOf _ = (#size struct in_pktinfo) - alignment = undefined + alignment _ = alignment (undefined :: CInt) poke p (IPv4PktInfo n ha) = do (#poke struct in_pktinfo, ipi_ifindex) p (fromIntegral n :: CInt) (#poke struct in_pktinfo, ipi_spec_dst) p (0 :: CInt) @@ -175,7 +175,7 @@ instance ControlMessage IPv6PktInfo where instance Storable IPv6PktInfo where sizeOf _ = (#size struct in6_pktinfo) - alignment = undefined + alignment _ = alignment (undefined :: CInt) poke p (IPv6PktInfo n ha6) = do (#poke struct in6_pktinfo, ipi6_ifindex) p (fromIntegral n :: CInt) (#poke struct in6_pktinfo, ipi6_addr) p (In6Addr ha6) From aef4e4200eb6f7f4116a52261457aa0dc4a032d2 Mon Sep 17 00:00:00 2001 From: Kazu Yamamoto Date: Thu, 16 Jan 2020 10:18:47 +0900 Subject: [PATCH 35/48] IPv4PktInfo now contains ipi_spec_dst. --- Network/Socket/Posix/Cmsg.hsc | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/Network/Socket/Posix/Cmsg.hsc b/Network/Socket/Posix/Cmsg.hsc index 1674915a..4f8d95d6 100644 --- a/Network/Socket/Posix/Cmsg.hsc +++ b/Network/Socket/Posix/Cmsg.hsc @@ -142,10 +142,10 @@ instance ControlMessage IPv6TClass where ---------------------------------------------------------------- -- | Network interface ID and local IPv4 address. -data IPv4PktInfo = IPv4PktInfo Int HostAddress deriving (Eq) +data IPv4PktInfo = IPv4PktInfo Int HostAddress HostAddress deriving (Eq) instance Show IPv4PktInfo where - show (IPv4PktInfo n ha) = "IPv4PktInfo " ++ show n ++ " " ++ show (hostAddressToTuple ha) + show (IPv4PktInfo n sa ha) = "IPv4PktInfo " ++ show n ++ " " ++ show (hostAddressToTuple sa) ++ " " ++ show (hostAddressToTuple ha) instance ControlMessage IPv4PktInfo where controlMessageId _ = CmsgIdIPv4PktInfo @@ -153,14 +153,15 @@ instance ControlMessage IPv4PktInfo where instance Storable IPv4PktInfo where sizeOf _ = (#size struct in_pktinfo) alignment _ = alignment (undefined :: CInt) - poke p (IPv4PktInfo n ha) = do + poke p (IPv4PktInfo n sa ha) = do (#poke struct in_pktinfo, ipi_ifindex) p (fromIntegral n :: CInt) - (#poke struct in_pktinfo, ipi_spec_dst) p (0 :: CInt) + (#poke struct in_pktinfo, ipi_spec_dst) p sa (#poke struct in_pktinfo, ipi_addr) p ha peek p = do - n <- (#peek struct in_pktinfo, ipi_ifindex) p - ha <- (#peek struct in_pktinfo, ipi_addr) p - return $ IPv4PktInfo n ha + n <- (#peek struct in_pktinfo, ipi_ifindex) p + sa <- (#peek struct in_pktinfo, ipi_spec_dst) p + ha <- (#peek struct in_pktinfo, ipi_addr) p + return $ IPv4PktInfo n sa ha ---------------------------------------------------------------- From 54f3afc66a0c0024b7abd8f01ce75372ec10ed01 Mon Sep 17 00:00:00 2001 From: Kazu Yamamoto Date: Thu, 16 Jan 2020 11:19:35 +0900 Subject: [PATCH 36/48] adding filterCmsg. --- Network/Socket.hs | 1 + Network/Socket/Posix/Cmsg.hsc | 10 +++++++--- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/Network/Socket.hs b/Network/Socket.hs index 650e3c3c..1c5a416a 100644 --- a/Network/Socket.hs +++ b/Network/Socket.hs @@ -229,6 +229,7 @@ module Network.Socket ,CmsgIdIPv6PktInfo) -- ** APIs for control message , lookupCmsg + , filterCmsg , decodeCmsg , encodeCmsg -- ** Class and yypes for control message diff --git a/Network/Socket/Posix/Cmsg.hsc b/Network/Socket/Posix/Cmsg.hsc index 4f8d95d6..71620951 100644 --- a/Network/Socket/Posix/Cmsg.hsc +++ b/Network/Socket/Posix/Cmsg.hsc @@ -75,9 +75,13 @@ pattern CmsgIdFd = CmsgId (#const SOL_SOCKET) (#const SCM_RIGHTS) -- > (lookupCmsg CmsgIdIPv4TOS cmsgs >>= decodeCmsg) :: Maybe IPv4TOS lookupCmsg :: CmsgId -> [Cmsg] -> Maybe Cmsg lookupCmsg _ [] = Nothing -lookupCmsg aid (cmsg@(Cmsg cid _):cmsgs) - | aid == cid = Just cmsg - | otherwise = lookupCmsg aid cmsgs +lookupCmsg cid (cmsg:cmsgs) + | cmsgId cmsg == cid = Just cmsg + | otherwise = lookupCmsg cid cmsgs + +-- | Filtering control message. +filterCmsg :: CmsgId -> [Cmsg] -> [Cmsg] +filterCmsg cid cmsgs = filter (\cmsg -> cmsgId cmsg == cid) cmsgs ---------------------------------------------------------------- From 3da86ef08c81941fa4107df5d57c92567abcc3a1 Mon Sep 17 00:00:00 2001 From: Kazu Yamamoto Date: Wed, 5 Feb 2020 10:21:02 +0900 Subject: [PATCH 37/48] fixing pattern synonym. --- Network/Socket/Flag.hsc | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/Network/Socket/Flag.hsc b/Network/Socket/Flag.hsc index a1cd1da8..ad55f12d 100644 --- a/Network/Socket/Flag.hsc +++ b/Network/Socket/Flag.hsc @@ -18,7 +18,7 @@ instance Sem.Semigroup MsgFlag where (<>) = (.|.) instance Monoid MsgFlag where - mempty = 0 + mempty = MsgFlag 0 #if !(MIN_VERSION_base(4,11,0)) mappend = (Sem.<>) #endif @@ -28,7 +28,7 @@ pattern MSG_OOB :: MsgFlag #ifdef MSG_OOB pattern MSG_OOB = MsgFlag (#const MSG_OOB) #else -pattern MSG_OOB = mempty +pattern MSG_OOB = MsgFlag 0 #endif -- | Bypass routing table lookup. @@ -36,7 +36,7 @@ pattern MSG_DONTROUTE :: MsgFlag #ifdef MSG_DONTROUTE pattern MSG_DONTROUTE = MsgFlag (#const MSG_DONTROUTE) #else -pattern MSG_DONTROUTE = mempty +pattern MSG_DONTROUTE = MsgFlag 0 #endif -- | Peek at incoming message without removing it from the queue. @@ -44,7 +44,7 @@ pattern MSG_PEEK :: MsgFlag #ifdef MSG_PEEK pattern MSG_PEEK = MsgFlag (#const MSG_PEEK) #else -pattern MSG_PEEK = mempty +pattern MSG_PEEK = MsgFlag 0 #endif -- | End of record. @@ -52,7 +52,7 @@ pattern MSG_EOR :: MsgFlag #ifdef MSG_EOR pattern MSG_EOR = MsgFlag (#const MSG_EOR) #else -pattern MSG_EOR = mempty +pattern MSG_EOR = MsgFlag 0 #endif -- | Received data is truncated. More data exist. @@ -60,7 +60,7 @@ pattern MSG_TRUNC :: MsgFlag #ifdef MSG_TRUNC pattern MSG_TRUNC = MsgFlag (#const MSG_TRUNC) #else -pattern MSG_TRUNC = mempty +pattern MSG_TRUNC = MsgFlag 0 #endif -- | Received control message is truncated. More control message exist. @@ -68,7 +68,7 @@ pattern MSG_CTRUNC :: MsgFlag #ifdef MSG_CTRUNC pattern MSG_CTRUNC = MsgFlag (#const MSG_CTRUNC) #else -pattern MSG_CTRUNC = mempty +pattern MSG_CTRUNC = MsgFlag 0 #endif -- | Wait until the requested number of bytes have been read. @@ -76,5 +76,5 @@ pattern MSG_WAITALL :: MsgFlag #ifdef MSG_WAITALL pattern MSG_WAITALL = MsgFlag (#const MSG_WAITALL) #else -pattern MSG_WAITALL = mempty +pattern MSG_WAITALL = MsgFlag 0 #endif From e1a405729ac557d12b291acc71762eba780495d7 Mon Sep 17 00:00:00 2001 From: Tamar Christina Date: Sun, 23 Feb 2020 18:27:52 +0000 Subject: [PATCH 38/48] network: Initial implementation --- .gitignore | 1 + Network/Socket/Buffer.hsc | 65 ++++++++++-- Network/Socket/Win32/Cmsg.hsc | 172 +++++++++++++++++++++++++++++++ Network/Socket/Win32/CmsgHdr.hsc | 101 ++++++++++++++++++ Network/Socket/Win32/MsgHdr.hsc | 51 +++++++++ Network/Socket/Win32/WSABuf.hsc | 48 +++++++++ cbits/cmsg.c | 23 +++++ include/HsNet.h | 15 +++ include/alignment.h | 3 + network.cabal | 6 +- 10 files changed, 475 insertions(+), 10 deletions(-) create mode 100644 Network/Socket/Win32/Cmsg.hsc create mode 100644 Network/Socket/Win32/CmsgHdr.hsc create mode 100644 Network/Socket/Win32/MsgHdr.hsc create mode 100644 Network/Socket/Win32/WSABuf.hsc create mode 100644 include/alignment.h diff --git a/.gitignore b/.gitignore index 71db6466..53e11f3f 100644 --- a/.gitignore +++ b/.gitignore @@ -19,3 +19,4 @@ cabal.sandbox.config .cabal-sandbox .stack-work/ .ghc.* +.vscode \ No newline at end of file diff --git a/Network/Socket/Buffer.hsc b/Network/Socket/Buffer.hsc index 26f70019..cb8040f2 100644 --- a/Network/Socket/Buffer.hsc +++ b/Network/Socket/Buffer.hsc @@ -17,6 +17,9 @@ module Network.Socket.Buffer ( #if !defined(mingw32_HOST_OS) import Foreign.C.Error (getErrno, eAGAIN, eWOULDBLOCK) +#else +import System.Win32.Types +import Foreign.Ptr (nullPtr) #endif import Foreign.Marshal.Alloc (alloca, allocaBytes) import Foreign.Marshal.Utils (with) @@ -25,15 +28,19 @@ import System.IO.Error (mkIOError, ioeSetErrorString, catchIOError) #if defined(mingw32_HOST_OS) import GHC.IO.FD (FD(..), readRawBufferPtr, writeRawBufferPtr) +import Network.Socket.Win32.CmsgHdr +import Network.Socket.Win32.MsgHdr +import Network.Socket.Win32.WSABuf +#else +import Network.Socket.Posix.CmsgHdr +import Network.Socket.Posix.MsgHdr +import Network.Socket.Posix.IOVec #endif import Network.Socket.Imports import Network.Socket.Internal import Network.Socket.Name import Network.Socket.Types -import Network.Socket.Posix.CmsgHdr -import Network.Socket.Posix.MsgHdr -import Network.Socket.Posix.IOVec import Network.Socket.Flag -- | Send data to the socket. The recipient can be specified @@ -195,13 +202,22 @@ sendBufMsg :: SocketAddress sa -> IO Int -- ^ The length actually sent sendBufMsg s sa bufsizs cmsgs flags = do sz <- withSocketAddress sa $ \addrPtr addrSize -> +#if !defined(mingw32_HOST_OS) withIOVec bufsizs $ \(iovsPtr, iovsLen) -> do +#else + withWSABuf bufsizs $ \(wsaBPtr, wsaBLen) -> do +#endif withCmsgs cmsgs $ \ctrlPtr ctrlLen -> do let msgHdr = MsgHdr { msgName = addrPtr , msgNameLen = fromIntegral addrSize +#if !defined(mingw32_HOST_OS) , msgIov = iovsPtr , msgIovLen = fromIntegral iovsLen +#else + , msgBuffer = wsaBPtr + , msgBufferLen = fromIntegral wsaBLen +#endif , msgCtrl = castPtr ctrlPtr , msgCtrlLen = fromIntegral ctrlLen , msgFlags = 0 @@ -210,7 +226,12 @@ sendBufMsg s sa bufsizs cmsgs flags = do withFdSocket s $ \fd -> with msgHdr $ \msgHdrPtr -> throwSocketErrorWaitWrite s "Network.Socket.Buffer.sendMsg" $ +#if !defined(mingw32_HOST_OS) c_sendmsg fd msgHdrPtr cflags +#else + alloca $ \send_ptr -> + c_sendmsg fd msgHdrPtr cflags send_ptr nullPtr nullPtr +#endif return $ fromIntegral sz -- | Receive data from the socket using recvmsg(2). @@ -227,20 +248,38 @@ recvBufMsg :: SocketAddress sa recvBufMsg s bufsizs clen flags = do withNewSocketAddress $ \addrPtr addrSize -> allocaBytes clen $ \ctrlPtr -> +#if !defined(mingw32_HOST_OS) withIOVec bufsizs $ \(iovsPtr, iovsLen) -> do +#else + withWSABuf bufsizs $ \(wsaBPtr, wsaBLen) -> do +#endif let msgHdr = MsgHdr { msgName = addrPtr , msgNameLen = fromIntegral addrSize +#if !defined(mingw32_HOST_OS) , msgIov = iovsPtr , msgIovLen = fromIntegral iovsLen +#else + , msgBuffer = wsaBPtr + , msgBufferLen = fromIntegral wsaBLen +#endif , msgCtrl = castPtr ctrlPtr , msgCtrlLen = fromIntegral clen , msgFlags = 0 } - cflags = fromMsgFlag flags + _cflags = fromMsgFlag flags withFdSocket s $ \fd -> do with msgHdr $ \msgHdrPtr -> do - len <- fromIntegral <$> throwSocketErrorWaitRead s "Network.Socket.Buffer.recvmg" (c_recvmsg fd msgHdrPtr cflags) + len <- fromIntegral <$> +#if !defined(mingw32_HOST_OS) + throwSocketErrorWaitRead s "Network.Socket.Buffer.recvmg" $ + c_recvmsg fd msgHdrPtr _cflags +#else + alloca $ \len_ptr -> + throwSocketErrorWaitRead s "Network.Socket.Buffer.recvmg" $ + c_recvmsg fd msgHdrPtr len_ptr nullPtr nullPtr + peek len_ptr +#endif sockaddr <- peekSocketAddress addrPtr `catchIOError` \_ -> getPeerName s hdr <- peek msgHdrPtr cmsgs <- parseCmsgs msgHdrPtr @@ -250,12 +289,24 @@ recvBufMsg s bufsizs clen flags = do #if !defined(mingw32_HOST_OS) foreign import ccall unsafe "send" c_send :: CInt -> Ptr a -> CSize -> CInt -> IO CInt +foreign import ccall unsafe "sendmsg" + c_sendmsg :: CInt -> Ptr (MsgHdr sa) -> CInt -> IO CInt -- fixme CSsize +foreign import ccall unsafe "recvmsg" + c_recvmsg :: CInt -> Ptr (MsgHdr sa) -> CInt -> IO CInt #else foreign import CALLCONV SAFE_ON_WIN "ioctlsocket" c_ioctlsocket :: CInt -> CLong -> Ptr CULong -> IO CInt foreign import CALLCONV SAFE_ON_WIN "WSAGetLastError" c_WSAGetLastError :: IO CInt +foreign import CALLCONV SAFE_ON_WIN "sendmsg" + -- fixme Handle for SOCKET, see #426 + c_sendmsg :: CInt -> Ptr (MsgHdr sa) -> DWORD -> LPDWORD -> Ptr () -> Ptr () -> IO CInt +foreign import CALLCONV SAFE_ON_WIN "recvmsg" + c_recvmsg :: CInt -> Ptr (MsgHdr sa) -> LPDWORD -> Ptr () -> Ptr () -> IO CInt + +failIfSockError = failIf_ (==#{const SOCKET_ERROR}) #endif + foreign import ccall unsafe "recv" c_recv :: CInt -> Ptr CChar -> CSize -> CInt -> IO CInt foreign import CALLCONV SAFE_ON_WIN "sendto" @@ -263,7 +314,3 @@ foreign import CALLCONV SAFE_ON_WIN "sendto" foreign import CALLCONV SAFE_ON_WIN "recvfrom" c_recvfrom :: CInt -> Ptr a -> CSize -> CInt -> Ptr sa -> Ptr CInt -> IO CInt -foreign import ccall unsafe "sendmsg" - c_sendmsg :: CInt -> Ptr (MsgHdr sa) -> CInt -> IO CInt -- fixme CSsize -foreign import ccall unsafe "recvmsg" - c_recvmsg :: CInt -> Ptr (MsgHdr sa) -> CInt -> IO CInt diff --git a/Network/Socket/Win32/Cmsg.hsc b/Network/Socket/Win32/Cmsg.hsc new file mode 100644 index 00000000..222698c5 --- /dev/null +++ b/Network/Socket/Win32/Cmsg.hsc @@ -0,0 +1,172 @@ +{-# LANGUAGE CPP #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE ScopedTypeVariables #-} + +module Network.Socket.Win32.Cmsg where + +#include "HsNet.h" + +import Data.ByteString.Internal +import Foreign.ForeignPtr +import System.IO.Unsafe (unsafeDupablePerformIO) +import System.Win32.Types (HANDLE) + +import Network.Socket.Imports +import Network.Socket.Types + +-- | Control message (ancillary data) including a pair of level and type. +data Cmsg = Cmsg { + cmsgId :: !CmsgId + , cmsgData :: !ByteString + } deriving (Eq, Show) + +---------------------------------------------------------------- + +-- | Identifier of control message (ancillary data). +data CmsgId = CmsgId { + cmsgLevel :: !CInt + , cmsglType :: !CInt + } deriving (Eq, Show) + +-- | The identifier for 'IPv4TTL'. +pattern CmsgIdIPv4TTL :: CmsgId +pattern CmsgIdIPv4TTL = CmsgId (#const IPPROTO_IP) (#const IP_TTL) + +-- | The identifier for 'IPv6HopLimit'. +pattern CmsgIdIPv6HopLimit :: CmsgId +pattern CmsgIdIPv6HopLimit = CmsgId (#const IPPROTO_IPV6) (#const IPV6_HOPLIMIT) + +-- | The identifier for 'IPv4TOS'. +pattern CmsgIdIPv4TOS :: CmsgId +pattern CmsgIdIPv4TOS = CmsgId (#const IPPROTO_IP) (#const IP_RECVTOS) + +-- | The identifier for 'IPv6TClass'. +pattern CmsgIdIPv6TClass :: CmsgId +pattern CmsgIdIPv6TClass = CmsgId (#const IPPROTO_IPV6) (#const IPV6_RECVTCLASS) + +-- | The identifier for 'IPv4PktInfo'. +pattern CmsgIdIPv4PktInfo :: CmsgId +pattern CmsgIdIPv4PktInfo = CmsgId (#const IPPROTO_IP) (#const IP_PKTINFO) + +-- | The identifier for 'IPv6PktInfo'. +pattern CmsgIdIPv6PktInfo :: CmsgId +pattern CmsgIdIPv6PktInfo = CmsgId (#const IPPROTO_IPV6) (#const IPV6_PKTINFO) + +-- Use WSADuplicateSocket for CmsgIdFd +-- pattern CmsgIdFd :: CmsgId + +---------------------------------------------------------------- + +-- | Looking up control message. The following shows an example usage: +-- +-- > (lookupCmsg CmsgIdIPv4TOS cmsgs >>= decodeCmsg) :: Maybe IPv4TOS +lookupCmsg :: CmsgId -> [Cmsg] -> Maybe Cmsg +lookupCmsg _ [] = Nothing +lookupCmsg cid (cmsg:cmsgs) + | cmsgId cmsg == cid = Just cmsg + | otherwise = lookupCmsg cid cmsgs + +-- | Filtering control message. +filterCmsg :: CmsgId -> [Cmsg] -> [Cmsg] +filterCmsg cid cmsgs = filter (\cmsg -> cmsgId cmsg == cid) cmsgs + +---------------------------------------------------------------- + +-- | A class to encode and decode control message. +class Storable a => ControlMessage a where + controlMessageId :: a -> CmsgId + +encodeCmsg :: ControlMessage a => a -> Cmsg +encodeCmsg x = unsafeDupablePerformIO $ do + bs <- create siz $ \p0 -> do + let p = castPtr p0 + poke p x + return $ Cmsg (controlMessageId x) bs + where + siz = sizeOf x + +decodeCmsg :: forall a . Storable a => Cmsg -> Maybe a +decodeCmsg (Cmsg _ (PS fptr off len)) + | len < siz = Nothing + | otherwise = unsafeDupablePerformIO $ withForeignPtr fptr $ \p0 -> do + let p = castPtr (p0 `plusPtr` off) + Just <$> peek p + where + siz = sizeOf (undefined :: a) + +---------------------------------------------------------------- + +-- | Time to live of IPv4. +newtype IPv4TTL = IPv4TTL DWORD deriving (Eq, Show, Storable) + +instance ControlMessage IPv4TTL where + controlMessageId _ = CmsgIdIPv4TTL + +---------------------------------------------------------------- + +-- | Hop limit of IPv6. +newtype IPv6HopLimit = IPv6HopLimit DWORD deriving (Eq, Show, Storable) + +instance ControlMessage IPv6HopLimit where + controlMessageId _ = CmsgIdIPv6HopLimit + +---------------------------------------------------------------- + +-- | TOS of IPv4. +newtype IPv4TOS = IPv4TOS DWORD deriving (Eq, Show, Storable) + +instance ControlMessage IPv4TOS where + controlMessageId _ = CmsgIdIPv4TOS + +---------------------------------------------------------------- + +-- | Traffic class of IPv6. +newtype IPv6TClass = IPv6TClass DWORD deriving (Eq, Show, Storable) + +instance ControlMessage IPv6TClass where + controlMessageId _ = CmsgIdIPv6TClass + +---------------------------------------------------------------- + +-- | Network interface ID and local IPv4 address. +data IPv4PktInfo = IPv4PktInfo ULONG HostAddress deriving (Eq) + +instance Show IPv4PktInfo where + show (IPv4PktInfo n sa ha) = "IPv4PktInfo " ++ show n ++ " " ++ show (hostAddressToTuple ha) + +instance ControlMessage IPv4PktInfo where + controlMessageId _ = CmsgIdIPv4PktInfo + +instance Storable IPv4PktInfo where + sizeOf _ = const #{size IN_PKTINFO} + alignment _ = #alignment IN_PKTINFO + poke p (IPv4PktInfo n sa ha) = do + (#poke IN_PKTINFO, ipi_ifindex) p (fromIntegral n :: CInt) + (#poke IN_PKTINFO, ipi_addr) p ha + peek p = do + n <- (#peek IN_PKTINFO, ipi_ifindex) p + ha <- (#peek IN_PKTINFO, ipi_addr) p + return $ IPv4PktInfo n ha + +---------------------------------------------------------------- + +-- | Network interface ID and local IPv4 address. +data IPv6PktInfo = IPv6PktInfo Int HostAddress6 deriving (Eq) + +instance Show IPv6PktInfo where + show (IPv6PktInfo n ha6) = "IPv6PktInfo " ++ show n ++ " " ++ show (hostAddress6ToTuple ha6) + +instance ControlMessage IPv6PktInfo where + controlMessageId _ = CmsgIdIPv6PktInfo + +instance Storable IPv6PktInfo where + sizeOf _ = const #{size IN6_PKTINFO} + alignment _ = #alignment IN6_PKTINFO + poke p (IPv6PktInfo n ha6) = do + (#poke IN6_PKTINFO, ipi6_ifindex) p (fromIntegral n :: CInt) + (#poke IN6_PKTINFO, ipi6_addr) p (In6Addr ha6) + peek p = do + In6Addr ha6 <- (#peek IN6_PKTINFO, ipi6_addr) p + n :: ULONG <- (#peek IN6_PKTINFO, ipi6_ifindex) p + return $ IPv6PktInfo (fromIntegral n) ha6 diff --git a/Network/Socket/Win32/CmsgHdr.hsc b/Network/Socket/Win32/CmsgHdr.hsc new file mode 100644 index 00000000..da7dbc3d --- /dev/null +++ b/Network/Socket/Win32/CmsgHdr.hsc @@ -0,0 +1,101 @@ +{-# OPTIONS_GHC -funbox-strict-fields #-} + +#include "HsNet.h" + +module Network.Socket.Win32.CmsgHdr ( + Cmsg(..) + , withCmsgs + , parseCmsgs + ) where + +import Foreign.Marshal.Alloc (allocaBytes) +import Foreign.ForeignPtr +import qualified Data.ByteString as B +import Data.ByteString.Internal + +import Network.Socket.Imports +import Network.Socket.Win32.Cmsg +import Network.Socket.Win32.MsgHdr +import Network.Socket.Types + +import System.Win32.Types + +data CmsgHdr = CmsgHdr { + cmsgHdrLen :: !CUInt + , cmsgHdrLevel :: !CInt + , cmsgHdrType :: !CInt + } deriving (Eq, Show) + +instance Storable CmsgHdr where + sizeOf _ = const #{size WSACMSGHDR} + alignment _ = #alignment WSACMSGHDR + + peek p = do + len <- (#peek WSACMSGHDR, cmsg_len) p + lvl <- (#peek WSACMSGHDR, cmsg_level) p + typ <- (#peek WSACMSGHDR, cmsg_type) p + return $ CmsgHdr len lvl typ + + poke p (CmsgHdr len lvl typ) = do + zeroMemory p (#size WSACMSGHDR) + (#poke WSACMSGHDR, cmsg_len) p len + (#poke WSACMSGHDR, cmsg_level) p lvl + (#poke WSACMSGHDR, cmsg_type) p typ + +withCmsgs :: [Cmsg] -> (Ptr CmsgHdr -> Int -> IO a) -> IO a +withCmsgs cmsgs0 action + | total == 0 = action nullPtr 0 + | otherwise = allocaBytes total $ \ctrlPtr -> do + loop ctrlPtr cmsgs0 spaces + action ctrlPtr total + where + loop ctrlPtr (cmsg:cmsgs) (s:ss) = do + toCmsgHdr cmsg ctrlPtr + let nextPtr = ctrlPtr `plusPtr` s + loop nextPtr cmsgs ss + loop _ _ _ = return () + cmsg_space = fromIntegral . c_cmsg_space . fromIntegral + spaces = map (cmsg_space . B.length . cmsgData) cmsgs0 + total = sum spaces + +toCmsgHdr :: Cmsg -> Ptr CmsgHdr -> IO () +toCmsgHdr (Cmsg (CmsgId lvl typ) (PS fptr off len)) ctrlPtr = do + poke ctrlPtr $ CmsgHdr (c_cmsg_len (fromIntegral len)) lvl typ + withForeignPtr fptr $ \src0 -> do + let src = src0 `plusPtr` off + dst <- c_cmsg_data ctrlPtr + memcpy dst src len + +parseCmsgs :: SocketAddress sa => Ptr (MsgHdr sa) -> IO [Cmsg] +parseCmsgs msgptr = do + ptr <- c_cmsg_firsthdr msgptr + loop ptr id + where + loop ptr build + | ptr == nullPtr = return $ build [] + | otherwise = do + cmsg <- fromCmsgHdr ptr + nextPtr <- c_cmsg_nxthdr msgptr ptr + loop nextPtr (build . (cmsg :)) + +fromCmsgHdr :: Ptr CmsgHdr -> IO Cmsg +fromCmsgHdr ptr = do + CmsgHdr len lvl typ <- peek ptr + src <- c_cmsg_data ptr + let siz = fromIntegral len - (src `minusPtr` ptr) + Cmsg (CmsgId lvl typ) <$> create (fromIntegral siz) (\dst -> memcpy dst src siz) + +foreign import ccall unsafe "cmsg_firsthdr" + c_cmsg_firsthdr :: Ptr (MsgHdr sa) -> IO (Ptr CmsgHdr) + +foreign import ccall unsafe "cmsg_nxthdr" + c_cmsg_nxthdr :: Ptr (MsgHdr sa) -> Ptr CmsgHdr -> IO (Ptr CmsgHdr) + +foreign import ccall unsafe "cmsg_data" + c_cmsg_data :: Ptr CmsgHdr -> IO (Ptr Word8) + +foreign import ccall unsafe "cmsg_space" + c_cmsg_space :: CUInt -> CUInt + +foreign import ccall unsafe "cmsg_len" + c_cmsg_len :: CUInt -> CUInt diff --git a/Network/Socket/Win32/MsgHdr.hsc b/Network/Socket/Win32/MsgHdr.hsc new file mode 100644 index 00000000..cf206098 --- /dev/null +++ b/Network/Socket/Win32/MsgHdr.hsc @@ -0,0 +1,51 @@ +{-# OPTIONS_GHC -funbox-strict-fields #-} + +-- | Support module for the Windows 'WSASendMsg' system call. +module Network.Socket.Win32.MsgHdr + ( MsgHdr(..) + ) where + +import Network.Socket.Imports +import Network.Socket.Internal (zeroMemory) +import Network.Socket.Win32.WSABuf + +import System.Win32.Types + +-- The size of BufferLen is different on pre-vista compilers. +-- But since those platforms are out of support anyway we ignore that. +data MsgHdr sa = MsgHdr + { msgName :: !(Ptr sa) + , msgNameLen :: !CInt + , msgBuffer :: !(Ptr WSABuf) + , msgBufferLen :: !DWORD + , msgCtr :: !(Ptr Word8) + , msgCtrLen :: !ULONG + , msgFlags :: !DWORD + } + +instance Storable (MsgHdr sa) where + sizeOf _ = const #{size WSAMSG} + alignment _ = #alignment WSAMSG + + peek p = do + name <- (#peek WSAMSG, name) p + nameLen <- (#peek WSAMSG, namelen) p + buffer <- (#peek WSAMSG, lpBuffers) p + bufferLen <- (#peek WSAMSG, dwBufferCount) p + ctrl <- (#peek WSAMSG, Control.buf) p + ctrlLen <- (#peek WSAMSG, Control.len) p + flags <- (#peek WSAMSG, dwFlags) p + return $ MsgHdr name nameLen buffer bufferLen ctrl ctrlLen flags + + poke p mh = do + -- We need to zero the msg_control, msg_controllen, and msg_flags + -- fields, but they only exist on some platforms (e.g. not on + -- Solaris). Instead of using CPP, we zero the entire struct. + zeroMemory p (#const sizeof(WSAMSG)) + (#poke WSAMSG, name) p (msgName mh) + (#poke WSAMSG, namelen) p (msgNameLen mh) + (#poke WSAMSG, lpBuffers) p (msgBuffer mh) + (#poke WSAMSG, dwBufferCount) p (msgBufferLen mh) + (#poke WSAMSG, Control.buf) p (msgCtrl mh) + (#poke WSAMSG, Control.len) p (msgCtrlLen mh) + (#poke WSAMSG, dwFlags) p (msgFlags mh) diff --git a/Network/Socket/Win32/WSABuf.hsc b/Network/Socket/Win32/WSABuf.hsc new file mode 100644 index 00000000..cafe72b9 --- /dev/null +++ b/Network/Socket/Win32/WSABuf.hsc @@ -0,0 +1,48 @@ +{-# OPTIONS_GHC -funbox-strict-fields #-} + +-- | Support module for the Windows winsock system calls. +module Network.Socket.Win32.WSABuf + ( WSABuf(..) + , withWSABuf + ) where + +#include "HsNet.h" + +import Foreign.Marshal.Array (allocaArray) + +import Network.Socket.Imports + +import System.Win32.Types + +data WSABuf = WSABuf + { wsaBufPtr :: !(Ptr Word8) + , wsaBufLen :: !ULONG + } + +instance Storable WSABuf where + sizeOf _ = const #{size WSABUF} + alignment _ = #alignment WSABUF + + peek p = do + base <- (#peek WSABUF, buf) p + len <- (#peek WSABUF, len) p + return $ IOVec base len + + poke p iov = do + (#poke WSABUF, buf) p (wsaBufPtr iov) + (#poke WSABUF, len) p (wsaBufLen iov) + +-- | @withWSABuf cs f@ executes the computation @f@, passing as argument a pair +-- consisting of a pointer to a temporarily allocated array of pointers to +-- WSABBUF made from @cs@ and the number of pointers (@length cs@). +-- /Windows only/. +withWSABuf :: [(Ptr Word8, Int)] -> ((Ptr WSABuf, Int) -> IO a) -> IO a +withWSABuf [] f = f (nullPtr, 0) +withWSABuf cs f = + allocaArray csLen $ \aPtr -> do + zipWithM_ pokeWsaBuf (ptrs aPtr) cs + f (aPtr, csLen) + where + csLen = length cs + ptrs = iterate (`plusPtr` sizeOf (WSABuf nullPtr 0)) + pokeWsaBuf ptr (sPtr, sLen) = poke ptr $ WSABuf sPtr (fromIntegral sLen) diff --git a/cbits/cmsg.c b/cbits/cmsg.c index 71f4d4ef..748f1381 100644 --- a/cbits/cmsg.c +++ b/cbits/cmsg.c @@ -1,6 +1,28 @@ #include "HsNet.h" #include +#ifdef _WIN32 + +struct LPCMSGHDR cmsg_firsthdr(LPWSAMSG mhdr) { + return (WSA_CMSG_FIRSTHDR(mhdr)); +} + +struct LPCMSGHDR cmsg_nxthdr(LPWSAMSG mhdr, LPWSACMSGHDR cmsg) { + return (WSA_CMSG_NXTHDR(mhdr, cmsg)); +} + +unsigned char *cmsg_data(LPWSACMSGHDR cmsg) { + return (WSA_CMSG_DATA(cmsg)); +} + +unsigned int cmsg_space(unsigned int l) { + return (WSA_CMSG_SPACE(l)); +} + +unsigned int cmsg_len(unsigned int l) { + return (WSA_CMSG_LEN(l)); +} +#else struct cmsghdr *cmsg_firsthdr(struct msghdr *mhdr) { return (CMSG_FIRSTHDR(mhdr)); } @@ -20,3 +42,4 @@ int cmsg_space(int l) { int cmsg_len(int l) { return (CMSG_LEN(l)); } +#endif /* _WIN32 */ diff --git a/include/HsNet.h b/include/HsNet.h index f848c0a1..32a05ed6 100644 --- a/include/HsNet.h +++ b/include/HsNet.h @@ -79,6 +79,21 @@ extern void* newAcceptParams(int sock, void* sockaddr); extern int acceptNewSock(void* d); extern int acceptDoProc(void* param); + +extern struct LPCMSGHDR +cmsg_firsthdr(LPWSAMSG mhdr); + +extern LPCMSGHDR +cmsg_nxthdr(LPWSAMSG mhdr, LPWSACMSGHDR cmsg); + +extern unsigned char * +cmsg_data(LPWSACMSGHDR cmsg); + +extern unsigned int +cmsg_space(unsigned int l); + +extern unsigned int +cmsg_len(unsigned int l); #else /* _WIN32 */ extern int sendFd(int sock, int outfd); diff --git a/include/alignment.h b/include/alignment.h new file mode 100644 index 00000000..787fe4a5 --- /dev/null +++ b/include/alignment.h @@ -0,0 +1,3 @@ +#if __GLASGOW_HASKELL__ < 711 +#define hsc_alignment(t ) hsc_printf ( "%lu", (unsigned long)offsetof(struct {char x__; t(y__); }, y__)); +#endif diff --git a/network.cabal b/network.cabal index 765afd0a..ece83010 100644 --- a/network.cabal +++ b/network.cabal @@ -87,7 +87,7 @@ library deepseq include-dirs: include - includes: HsNet.h HsNetDef.h + includes: HsNet.h HsNetDef.h alignment.h install-includes: HsNet.h HsNetDef.h c-sources: cbits/HsNet.c ghc-options: -Wall -fwarn-tabs @@ -110,6 +110,10 @@ library if os(windows) other-modules: Network.Socket.ByteString.Lazy.Windows + Network.Socket.Win32.Cmsg + Network.Socket.Win32.CmsgHdr + Network.Socket.Win32.WSABuf + Network.Socket.Win32.MsgHdr c-sources: cbits/initWinSock.c, cbits/winSockErr.c, cbits/asyncAccept.c extra-libraries: ws2_32, iphlpapi -- See https://github.com/haskell/network/pull/362 From 4ae61a3afe90602dc1833476abb880a6b2091bfa Mon Sep 17 00:00:00 2001 From: Tamar Christina Date: Sun, 23 Feb 2020 18:30:18 +0000 Subject: [PATCH 39/48] network: Add explicit cast --- Network/Socket/Buffer.hsc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Network/Socket/Buffer.hsc b/Network/Socket/Buffer.hsc index cb8040f2..739d44e5 100644 --- a/Network/Socket/Buffer.hsc +++ b/Network/Socket/Buffer.hsc @@ -278,7 +278,7 @@ recvBufMsg s bufsizs clen flags = do alloca $ \len_ptr -> throwSocketErrorWaitRead s "Network.Socket.Buffer.recvmg" $ c_recvmsg fd msgHdrPtr len_ptr nullPtr nullPtr - peek len_ptr + peek len_ptr :: IO DWORD #endif sockaddr <- peekSocketAddress addrPtr `catchIOError` \_ -> getPeerName s hdr <- peek msgHdrPtr From 834e6ffe271d462c6cfbd6c43eeeeeca5b13c4c2 Mon Sep 17 00:00:00 2001 From: Tamar Christina Date: Mon, 16 Mar 2020 01:34:41 +0000 Subject: [PATCH 40/48] First implementation Windows msg --- Network/Socket.hs | 8 +- Network/Socket/Buffer.hsc | 22 ++--- Network/Socket/ByteString/IO.hsc | 49 +++++++++-- Network/Socket/ByteString/Internal.hs | 25 +++++- Network/Socket/SockAddr.hs | 5 ++ Network/Socket/Win32/Cmsg.hsc | 12 +-- Network/Socket/Win32/CmsgHdr.hsc | 4 +- Network/Socket/Win32/MsgHdr.hsc | 12 ++- Network/Socket/Win32/WSABuf.hsc | 6 +- include/HsNet.h | 6 +- include/win32defs.h | 117 ++++++++++++++++++++++++++ network.cabal | 7 +- 12 files changed, 235 insertions(+), 38 deletions(-) create mode 100644 include/win32defs.h diff --git a/Network/Socket.hs b/Network/Socket.hs index 1c5a416a..e0f8f3fd 100644 --- a/Network/Socket.hs +++ b/Network/Socket.hs @@ -192,12 +192,14 @@ module Network.Socket , socketPortSafe , socketPort +#if !defined(mingw32_HOST_OS) -- * UNIX-domain socket , isUnixDomainSocketAvailable , socketPair , sendFd , recvFd , getPeerCredential +#endif -- * Name information , getNameInfo @@ -254,9 +256,13 @@ import Network.Socket.Info import Network.Socket.Internal import Network.Socket.Name hiding (getPeerName, getSocketName) import Network.Socket.Options -import Network.Socket.Posix.Cmsg import Network.Socket.Shutdown import Network.Socket.SockAddr import Network.Socket.Syscall hiding (connect, bind, accept) import Network.Socket.Types +#if !defined(mingw32_HOST_OS) +import Network.Socket.Posix.Cmsg import Network.Socket.Unix +#else +import Network.Socket.Win32.Cmsg +#endif \ No newline at end of file diff --git a/Network/Socket/Buffer.hsc b/Network/Socket/Buffer.hsc index 739d44e5..9efaced4 100644 --- a/Network/Socket/Buffer.hsc +++ b/Network/Socket/Buffer.hsc @@ -18,7 +18,6 @@ module Network.Socket.Buffer ( #if !defined(mingw32_HOST_OS) import Foreign.C.Error (getErrno, eAGAIN, eWOULDBLOCK) #else -import System.Win32.Types import Foreign.Ptr (nullPtr) #endif import Foreign.Marshal.Alloc (alloca, allocaBytes) @@ -43,6 +42,11 @@ import Network.Socket.Name import Network.Socket.Types import Network.Socket.Flag +#if defined(mingw32_HOST_OS) +type DWORD = Word32 +type LPDWORD = Ptr DWORD +#endif + -- | Send data to the socket. The recipient can be specified -- explicitly, so the socket need not be in a connected state. -- Returns the number of bytes sent. Applications are responsible for @@ -230,7 +234,7 @@ sendBufMsg s sa bufsizs cmsgs flags = do c_sendmsg fd msgHdrPtr cflags #else alloca $ \send_ptr -> - c_sendmsg fd msgHdrPtr cflags send_ptr nullPtr nullPtr + c_sendmsg fd msgHdrPtr (fromIntegral cflags) send_ptr nullPtr nullPtr #endif return $ fromIntegral sz @@ -270,20 +274,20 @@ recvBufMsg s bufsizs clen flags = do _cflags = fromMsgFlag flags withFdSocket s $ \fd -> do with msgHdr $ \msgHdrPtr -> do - len <- fromIntegral <$> + len <- (fmap fromIntegral) <$> #if !defined(mingw32_HOST_OS) throwSocketErrorWaitRead s "Network.Socket.Buffer.recvmg" $ c_recvmsg fd msgHdrPtr _cflags #else - alloca $ \len_ptr -> - throwSocketErrorWaitRead s "Network.Socket.Buffer.recvmg" $ - c_recvmsg fd msgHdrPtr len_ptr nullPtr nullPtr - peek len_ptr :: IO DWORD + alloca $ \len_ptr -> do + _ <- throwSocketErrorWaitRead s "Network.Socket.Buffer.recvmg" $ + c_recvmsg fd msgHdrPtr len_ptr nullPtr nullPtr + peek len_ptr #endif sockaddr <- peekSocketAddress addrPtr `catchIOError` \_ -> getPeerName s hdr <- peek msgHdrPtr cmsgs <- parseCmsgs msgHdrPtr - let flags' = MsgFlag $ msgFlags hdr + let flags' = MsgFlag $ fromIntegral $ msgFlags hdr return (sockaddr, len, cmsgs, flags') #if !defined(mingw32_HOST_OS) @@ -303,8 +307,6 @@ foreign import CALLCONV SAFE_ON_WIN "sendmsg" c_sendmsg :: CInt -> Ptr (MsgHdr sa) -> DWORD -> LPDWORD -> Ptr () -> Ptr () -> IO CInt foreign import CALLCONV SAFE_ON_WIN "recvmsg" c_recvmsg :: CInt -> Ptr (MsgHdr sa) -> LPDWORD -> Ptr () -> Ptr () -> IO CInt - -failIfSockError = failIf_ (==#{const SOCKET_ERROR}) #endif foreign import ccall unsafe "recv" diff --git a/Network/Socket/ByteString/IO.hsc b/Network/Socket/ByteString/IO.hsc index 7c4946ec..51898726 100644 --- a/Network/Socket/ByteString/IO.hsc +++ b/Network/Socket/ByteString/IO.hsc @@ -49,16 +49,22 @@ import Network.Socket.ByteString.Internal import Network.Socket.Imports import Network.Socket.Types -#if !defined(mingw32_HOST_OS) import Data.ByteString.Internal (create, ByteString(..)) import Foreign.ForeignPtr (withForeignPtr) import Foreign.Marshal.Utils (with) import Network.Socket.Internal import Network.Socket.Flag + +#if !defined(mingw32_HOST_OS) import Network.Socket.Posix.Cmsg import Network.Socket.Posix.IOVec import Network.Socket.Posix.MsgHdr (MsgHdr(..)) +#else +import Foreign.Marshal.Alloc (alloca) +import Network.Socket.Win32.Cmsg +import Network.Socket.Win32.WSABuf +import Network.Socket.Win32.MsgHdr (MsgHdr(..)) #endif -- ---------------------------------------------------------------------------- @@ -130,7 +136,6 @@ sendAllTo s xs sa = do sendMany :: Socket -- ^ Connected socket -> [ByteString] -- ^ Data to send -> IO () -#if !defined(mingw32_HOST_OS) sendMany _ [] = return () sendMany s cs = do sent <- sendManyInner @@ -138,13 +143,20 @@ sendMany s cs = do when (sent >= 0) $ sendMany s $ remainingChunks sent cs where sendManyInner = +#if !defined(mingw32_HOST_OS) fmap fromIntegral . withIOVecfromBS cs $ \(iovsPtr, iovsLen) -> withFdSocket s $ \fd -> do let len = fromIntegral $ min iovsLen (#const IOV_MAX) throwSocketErrorWaitWrite s "Network.Socket.ByteString.sendMany" $ c_writev fd iovsPtr len #else -sendMany s = sendAll s . B.concat + fmap fromIntegral . withWSABuffromBS cs $ \(wsabsPtr, wsabsLen) -> + withFdSocket s $ \fd -> do + let len = fromIntegral wsabsLen + alloca $ \send_ptr -> do + _ <- throwSocketErrorWaitWrite s "Network.Socket.ByteString.sendMany" $ + c_wsasend fd wsabsPtr len send_ptr 0 nullPtr nullPtr + peek send_ptr #endif -- | Send data to the socket. The recipient can be specified @@ -157,7 +169,6 @@ sendManyTo :: Socket -- ^ Socket -> [ByteString] -- ^ Data to send -> SockAddr -- ^ Recipient address -> IO () -#if !defined(mingw32_HOST_OS) sendManyTo _ [] _ = return () sendManyTo s cs addr = do sent <- fromIntegral <$> sendManyToInner @@ -166,6 +177,7 @@ sendManyTo s cs addr = do where sendManyToInner = withSockAddr addr $ \addrPtr addrSize -> +#if !defined(mingw32_HOST_OS) withIOVecfromBS cs $ \(iovsPtr, iovsLen) -> do let msgHdr = MsgHdr { msgName = addrPtr @@ -181,7 +193,22 @@ sendManyTo s cs addr = do throwSocketErrorWaitWrite s "Network.Socket.ByteString.sendManyTo" $ c_sendmsg fd msgHdrPtr 0 #else -sendManyTo s cs = sendAllTo s (B.concat cs) + withWSABuffromBS cs $ \(wsabsPtr, wsabsLen) -> do + let msgHdr = MsgHdr { + msgName = addrPtr + , msgNameLen = fromIntegral addrSize + , msgBuffer = wsabsPtr + , msgBufferLen = fromIntegral wsabsLen + , msgCtrl = nullPtr + , msgCtrlLen = 0 + , msgFlags = 0 + } + withFdSocket s $ \fd -> + with msgHdr $ \msgHdrPtr -> + alloca $ \send_ptr -> do + _ <- throwSocketErrorWaitWrite s "Network.Socket.ByteString.sendManyTo" $ + c_sendmsg fd msgHdrPtr 0 send_ptr nullPtr nullPtr + peek send_ptr #endif -- ---------------------------------------------------------------------------- @@ -224,7 +251,7 @@ recvFrom sock nbytes = -- ---------------------------------------------------------------------------- -- Not exported -#if !defined(mingw32_HOST_OS) + -- | Suppose we try to transmit a list of chunks @cs@ via a gathering write -- operation and find that @n@ bytes were sent. Then @remainingChunks n cs@ is -- list of chunks remaining to be sent. @@ -236,6 +263,7 @@ remainingChunks i (x:xs) where len = B.length x +#if !defined(mingw32_HOST_OS) -- | @withIOVecfromBS cs f@ executes the computation @f@, passing as argument a pair -- consisting of a pointer to a temporarily allocated array of pointers to -- IOVec made from @cs@ and the number of pointers (@length cs@). @@ -244,6 +272,15 @@ withIOVecfromBS :: [ByteString] -> ((Ptr IOVec, Int) -> IO a) -> IO a withIOVecfromBS cs f = do bufsizs <- mapM getBufsiz cs withIOVec bufsizs f +#else +-- | @withWSABuffromBS cs f@ executes the computation @f@, passing as argument a pair +-- consisting of a pointer to a temporarily allocated array of pointers to +-- WSABuf made from @cs@ and the number of pointers (@length cs@). +-- /Windows only/. +withWSABuffromBS :: [ByteString] -> ((Ptr WSABuf, Int) -> IO a) -> IO a +withWSABuffromBS cs f = do + bufsizs <- mapM getBufsiz cs + withWSABuf bufsizs f #endif getBufsiz :: ByteString -> IO (Ptr Word8, Int) diff --git a/Network/Socket/ByteString/Internal.hs b/Network/Socket/ByteString/Internal.hs index b8f8853e..84c47ebc 100644 --- a/Network/Socket/ByteString/Internal.hs +++ b/Network/Socket/ByteString/Internal.hs @@ -14,11 +14,15 @@ module Network.Socket.ByteString.Internal mkInvalidRecvArgError #if !defined(mingw32_HOST_OS) , c_writev +#else + , c_wsasend +#endif , c_sendmsg , c_recvmsg -#endif ) where +#include "HsNetDef.h" + import GHC.IO.Exception (IOErrorType(..)) import System.IO.Error (ioeSetErrorString, mkIOError) @@ -29,6 +33,17 @@ import Network.Socket.Imports import Network.Socket.Posix.IOVec (IOVec) import Network.Socket.Posix.MsgHdr (MsgHdr) import Network.Socket.Types +#else +import Data.Word +import Foreign.C.Types +import Foreign.Ptr + +import Network.Socket.Win32.WSABuf (WSABuf) +import Network.Socket.Win32.MsgHdr (MsgHdr) +import Network.Socket.Types + +type DWORD = Word32 +type LPDWORD = Ptr DWORD #endif mkInvalidRecvArgError :: String -> IOError @@ -45,4 +60,12 @@ foreign import ccall unsafe "sendmsg" foreign import ccall unsafe "recvmsg" c_recvmsg :: CInt -> Ptr (MsgHdr SockAddr) -> CInt -> IO CSsize +#else + -- fixme Handle for SOCKET, see #426 +foreign import CALLCONV SAFE_ON_WIN "wsasend" + c_wsasend :: CInt -> Ptr WSABuf -> DWORD -> LPDWORD -> DWORD -> Ptr () -> Ptr () -> IO CInt +foreign import CALLCONV SAFE_ON_WIN "sendmsg" + c_sendmsg :: CInt -> Ptr (MsgHdr SockAddr) -> DWORD -> LPDWORD -> Ptr () -> Ptr () -> IO CInt +foreign import CALLCONV SAFE_ON_WIN "recvmsg" + c_recvmsg :: CInt -> Ptr (MsgHdr SockAddr) -> LPDWORD -> Ptr () -> Ptr () -> IO CInt #endif diff --git a/Network/Socket/SockAddr.hs b/Network/Socket/SockAddr.hs index 25049853..f668849b 100644 --- a/Network/Socket/SockAddr.hs +++ b/Network/Socket/SockAddr.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE CPP #-} module Network.Socket.SockAddr ( getPeerName , getSocketName @@ -15,7 +16,11 @@ import qualified Network.Socket.Name as G import qualified Network.Socket.Syscall as G import Network.Socket.Flag import Network.Socket.Imports +#if !defined(mingw32_HOST_OS) import Network.Socket.Posix.Cmsg +#else +import Network.Socket.Win32.Cmsg +#endif import Network.Socket.Types -- | Getting peer's 'SockAddr'. diff --git a/Network/Socket/Win32/Cmsg.hsc b/Network/Socket/Win32/Cmsg.hsc index 222698c5..6fd02d4e 100644 --- a/Network/Socket/Win32/Cmsg.hsc +++ b/Network/Socket/Win32/Cmsg.hsc @@ -10,11 +10,13 @@ module Network.Socket.Win32.Cmsg where import Data.ByteString.Internal import Foreign.ForeignPtr import System.IO.Unsafe (unsafeDupablePerformIO) -import System.Win32.Types (HANDLE) import Network.Socket.Imports import Network.Socket.Types +type DWORD = Word32 +type ULONG = Word32 + -- | Control message (ancillary data) including a pair of level and type. data Cmsg = Cmsg { cmsgId :: !CmsgId @@ -133,15 +135,15 @@ instance ControlMessage IPv6TClass where data IPv4PktInfo = IPv4PktInfo ULONG HostAddress deriving (Eq) instance Show IPv4PktInfo where - show (IPv4PktInfo n sa ha) = "IPv4PktInfo " ++ show n ++ " " ++ show (hostAddressToTuple ha) + show (IPv4PktInfo n ha) = "IPv4PktInfo " ++ show n ++ " " ++ show (hostAddressToTuple ha) instance ControlMessage IPv4PktInfo where controlMessageId _ = CmsgIdIPv4PktInfo instance Storable IPv4PktInfo where - sizeOf _ = const #{size IN_PKTINFO} + sizeOf = const #{size IN_PKTINFO} alignment _ = #alignment IN_PKTINFO - poke p (IPv4PktInfo n sa ha) = do + poke p (IPv4PktInfo n ha) = do (#poke IN_PKTINFO, ipi_ifindex) p (fromIntegral n :: CInt) (#poke IN_PKTINFO, ipi_addr) p ha peek p = do @@ -161,7 +163,7 @@ instance ControlMessage IPv6PktInfo where controlMessageId _ = CmsgIdIPv6PktInfo instance Storable IPv6PktInfo where - sizeOf _ = const #{size IN6_PKTINFO} + sizeOf = const #{size IN6_PKTINFO} alignment _ = #alignment IN6_PKTINFO poke p (IPv6PktInfo n ha6) = do (#poke IN6_PKTINFO, ipi6_ifindex) p (fromIntegral n :: CInt) diff --git a/Network/Socket/Win32/CmsgHdr.hsc b/Network/Socket/Win32/CmsgHdr.hsc index da7dbc3d..467e452a 100644 --- a/Network/Socket/Win32/CmsgHdr.hsc +++ b/Network/Socket/Win32/CmsgHdr.hsc @@ -18,8 +18,6 @@ import Network.Socket.Win32.Cmsg import Network.Socket.Win32.MsgHdr import Network.Socket.Types -import System.Win32.Types - data CmsgHdr = CmsgHdr { cmsgHdrLen :: !CUInt , cmsgHdrLevel :: !CInt @@ -27,7 +25,7 @@ data CmsgHdr = CmsgHdr { } deriving (Eq, Show) instance Storable CmsgHdr where - sizeOf _ = const #{size WSACMSGHDR} + sizeOf = const #{size WSACMSGHDR} alignment _ = #alignment WSACMSGHDR peek p = do diff --git a/Network/Socket/Win32/MsgHdr.hsc b/Network/Socket/Win32/MsgHdr.hsc index cf206098..9f26257c 100644 --- a/Network/Socket/Win32/MsgHdr.hsc +++ b/Network/Socket/Win32/MsgHdr.hsc @@ -1,15 +1,19 @@ {-# OPTIONS_GHC -funbox-strict-fields #-} +{-# LANGUAGE CPP #-} -- | Support module for the Windows 'WSASendMsg' system call. module Network.Socket.Win32.MsgHdr ( MsgHdr(..) ) where +#include "HsNet.h" + import Network.Socket.Imports import Network.Socket.Internal (zeroMemory) import Network.Socket.Win32.WSABuf -import System.Win32.Types +type DWORD = Word32 +type ULONG = Word32 -- The size of BufferLen is different on pre-vista compilers. -- But since those platforms are out of support anyway we ignore that. @@ -18,13 +22,13 @@ data MsgHdr sa = MsgHdr , msgNameLen :: !CInt , msgBuffer :: !(Ptr WSABuf) , msgBufferLen :: !DWORD - , msgCtr :: !(Ptr Word8) - , msgCtrLen :: !ULONG + , msgCtrl :: !(Ptr Word8) + , msgCtrlLen :: !ULONG , msgFlags :: !DWORD } instance Storable (MsgHdr sa) where - sizeOf _ = const #{size WSAMSG} + sizeOf = const #{size WSAMSG} alignment _ = #alignment WSAMSG peek p = do diff --git a/Network/Socket/Win32/WSABuf.hsc b/Network/Socket/Win32/WSABuf.hsc index cafe72b9..c5026434 100644 --- a/Network/Socket/Win32/WSABuf.hsc +++ b/Network/Socket/Win32/WSABuf.hsc @@ -12,7 +12,7 @@ import Foreign.Marshal.Array (allocaArray) import Network.Socket.Imports -import System.Win32.Types +type ULONG = Word32 data WSABuf = WSABuf { wsaBufPtr :: !(Ptr Word8) @@ -20,13 +20,13 @@ data WSABuf = WSABuf } instance Storable WSABuf where - sizeOf _ = const #{size WSABUF} + sizeOf = const #{size WSABUF} alignment _ = #alignment WSABUF peek p = do base <- (#peek WSABUF, buf) p len <- (#peek WSABUF, len) p - return $ IOVec base len + return $ WSABuf base len poke p iov = do (#poke WSABUF, buf) p (wsaBufPtr iov) diff --git a/include/HsNet.h b/include/HsNet.h index 32a05ed6..9c15b441 100644 --- a/include/HsNet.h +++ b/include/HsNet.h @@ -25,6 +25,8 @@ #ifdef _WIN32 # include # include +# include +# include "win32defs.h" # define IPV6_V6ONLY 27 #endif @@ -80,10 +82,10 @@ extern void* newAcceptParams(int sock, extern int acceptNewSock(void* d); extern int acceptDoProc(void* param); -extern struct LPCMSGHDR +extern struct LPWSACMSGHDR cmsg_firsthdr(LPWSAMSG mhdr); -extern LPCMSGHDR +extern LPWSACMSGHDR cmsg_nxthdr(LPWSAMSG mhdr, LPWSACMSGHDR cmsg); extern unsigned char * diff --git a/include/win32defs.h b/include/win32defs.h new file mode 100644 index 00000000..ea39a170 --- /dev/null +++ b/include/win32defs.h @@ -0,0 +1,117 @@ +#ifndef IP_OPTIONS +#define IP_OPTIONS 1 // Set/get IP options. +#endif +#ifndef IP_HDRINCL +#define IP_HDRINCL 2 // Header is included with data. +#endif +#ifndef IP_TOS +#define IP_TOS 3 // IP type of service. +#endif +#ifndef IP_TTL +#define IP_TTL 4 // IP TTL (hop limit). +#endif +#ifndef IP_MULTICAST_IF +#define IP_MULTICAST_IF 9 // IP multicast interface. +#endif +#ifndef IP_MULTICAST_TTL +#define IP_MULTICAST_TTL 10 // IP multicast TTL (hop limit). +#endif +#ifndef IP_MULTICAST_LOOP +#define IP_MULTICAST_LOOP 11 // IP multicast loopback. +#endif +#ifndef IP_ADD_MEMBERSHIP +#define IP_ADD_MEMBERSHIP 12 // Add an IP group membership. +#endif +#ifndef IP_DROP_MEMBERSHIP +#define IP_DROP_MEMBERSHIP 13 // Drop an IP group membership. +#endif +#ifndef IP_DONTFRAGMENT +#define IP_DONTFRAGMENT 14 // Don't fragment IP datagrams. +#endif +#ifndef IP_ADD_SOURCE_MEMBERSHIP +#define IP_ADD_SOURCE_MEMBERSHIP 15 // Join IP group/source. +#endif +#ifndef IP_DROP_SOURCE_MEMBERSHIP +#define IP_DROP_SOURCE_MEMBERSHIP 16 // Leave IP group/source. +#endif +#ifndef IP_BLOCK_SOURCE +#define IP_BLOCK_SOURCE 17 // Block IP group/source. +#endif +#ifndef IP_UNBLOCK_SOURCE +#define IP_UNBLOCK_SOURCE 18 // Unblock IP group/source. +#endif +#ifndef IP_PKTINFO +#define IP_PKTINFO 19 // Receive packet information. +#endif +#ifndef IP_HOPLIMIT +#define IP_HOPLIMIT 21 // Receive packet hop limit. +#endif +#ifndef IP_RECVTTL +#define IP_RECVTTL 21 // Receive packet Time To Live (TTL). +#endif +#ifndef IP_RECEIVE_BROADCAST +#define IP_RECEIVE_BROADCAST 22 // Allow/block broadcast reception. +#endif +#ifndef IP_RECVIF +#define IP_RECVIF 24 // Receive arrival interface. +#endif +#ifndef IP_RECVDSTADDR +#define IP_RECVDSTADDR 25 // Receive destination address. +#endif +#ifndef IP_IFLIST +#define IP_IFLIST 28 // Enable/Disable an interface list. +#endif +#ifndef IP_ADD_IFLIST +#define IP_ADD_IFLIST 29 // Add an interface list entry. +#endif +#ifndef IP_DEL_IFLIST +#define IP_DEL_IFLIST 30 // Delete an interface list entry. +#endif +#ifndef IP_UNICAST_IF +#define IP_UNICAST_IF 31 // IP unicast interface. +#endif +#ifndef IP_RTHDR +#define IP_RTHDR 32 // Set/get IPv6 routing header. +#endif +#ifndef IP_GET_IFLIST +#define IP_GET_IFLIST 33 // Get an interface list. +#endif +#ifndef IP_RECVRTHDR +#define IP_RECVRTHDR 38 // Receive the routing header. +#endif +#ifndef IP_TCLASS +#define IP_TCLASS 39 // Packet traffic class. +#endif +#ifndef IP_RECVTCLASS +#define IP_RECVTCLASS 40 // Receive packet traffic class. +#endif +#ifndef IP_RECVTOS +#define IP_RECVTOS 40 // Receive packet Type Of Service (TOS). +#endif +#ifndef IP_ORIGINAL_ARRIVAL_IF +#define IP_ORIGINAL_ARRIVAL_IF 47 // Original Arrival Interface Index. +#endif +#ifndef IP_ECN +#define IP_ECN 50 // Receive ECN codepoints in the IP header. +#endif +#ifndef IP_PKTINFO_EX +#define IP_PKTINFO_EX 51 // Receive extended packet information. +#endif +#ifndef IP_WFP_REDIRECT_RECORDS +#define IP_WFP_REDIRECT_RECORDS 60 // WFP's Connection Redirect Records. +#endif +#ifndef IP_WFP_REDIRECT_CONTEXT +#define IP_WFP_REDIRECT_CONTEXT 70 // WFP's Connection Redirect Context. +#endif +#ifndef IP_MTU_DISCOVER +#define IP_MTU_DISCOVER 71 // Set/get path MTU discover state. +#endif +#ifndef IP_MTU +#define IP_MTU 73 // Get path MTU. +#endif +#ifndef IP_NRT_INTERFACE +#define IP_NRT_INTERFACE 74 // Set NRT interface constraint (outbound). +#endif +#ifndef IP_RECVERR +#define IP_RECVERR 75 // Receive ICMP errors. +#endif \ No newline at end of file diff --git a/network.cabal b/network.cabal index ece83010..6955acc9 100644 --- a/network.cabal +++ b/network.cabal @@ -79,7 +79,6 @@ library Network.Socket.SockAddr Network.Socket.Syscall Network.Socket.Types - Network.Socket.Unix build-depends: base >= 4.7 && < 5, @@ -87,8 +86,8 @@ library deepseq include-dirs: include - includes: HsNet.h HsNetDef.h alignment.h - install-includes: HsNet.h HsNetDef.h + includes: HsNet.h HsNetDef.h alignment.h win32defs.h + install-includes: HsNet.h HsNetDef.h win32defs.h c-sources: cbits/HsNet.c ghc-options: -Wall -fwarn-tabs build-tools: hsc2hs @@ -102,8 +101,10 @@ library Network.Socket.Posix.CmsgHdr Network.Socket.Posix.IOVec Network.Socket.Posix.MsgHdr + Network.Socket.Unix c-sources: cbits/cmsg.c + if os(solaris) extra-libraries: nsl, socket From 9683add7b3e74acee56f32da8fd9091ad3fe261c Mon Sep 17 00:00:00 2001 From: Tamar Christina Date: Mon, 16 Mar 2020 08:12:56 +0000 Subject: [PATCH 41/48] Fix Win32 linking issues. now need to fixtest failures --- Network/Socket/Buffer.hsc | 4 +-- Network/Socket/ByteString/Internal.hs | 6 ++-- cbits/cmsg.c | 47 +++++++++++++++++++++++++-- include/HsNet.h | 16 ++++++++- network.cabal | 7 ++-- tests/Network/SocketSpec.hs | 4 +++ 6 files changed, 72 insertions(+), 12 deletions(-) diff --git a/Network/Socket/Buffer.hsc b/Network/Socket/Buffer.hsc index 9efaced4..c45eb790 100644 --- a/Network/Socket/Buffer.hsc +++ b/Network/Socket/Buffer.hsc @@ -302,10 +302,10 @@ foreign import CALLCONV SAFE_ON_WIN "ioctlsocket" c_ioctlsocket :: CInt -> CLong -> Ptr CULong -> IO CInt foreign import CALLCONV SAFE_ON_WIN "WSAGetLastError" c_WSAGetLastError :: IO CInt -foreign import CALLCONV SAFE_ON_WIN "sendmsg" +foreign import CALLCONV SAFE_ON_WIN "WSASendMsg" -- fixme Handle for SOCKET, see #426 c_sendmsg :: CInt -> Ptr (MsgHdr sa) -> DWORD -> LPDWORD -> Ptr () -> Ptr () -> IO CInt -foreign import CALLCONV SAFE_ON_WIN "recvmsg" +foreign import CALLCONV SAFE_ON_WIN "WSARecvMsg" c_recvmsg :: CInt -> Ptr (MsgHdr sa) -> LPDWORD -> Ptr () -> Ptr () -> IO CInt #endif diff --git a/Network/Socket/ByteString/Internal.hs b/Network/Socket/ByteString/Internal.hs index 84c47ebc..3d789d97 100644 --- a/Network/Socket/ByteString/Internal.hs +++ b/Network/Socket/ByteString/Internal.hs @@ -62,10 +62,10 @@ foreign import ccall unsafe "recvmsg" c_recvmsg :: CInt -> Ptr (MsgHdr SockAddr) -> CInt -> IO CSsize #else -- fixme Handle for SOCKET, see #426 -foreign import CALLCONV SAFE_ON_WIN "wsasend" +foreign import CALLCONV SAFE_ON_WIN "WSASend" c_wsasend :: CInt -> Ptr WSABuf -> DWORD -> LPDWORD -> DWORD -> Ptr () -> Ptr () -> IO CInt -foreign import CALLCONV SAFE_ON_WIN "sendmsg" +foreign import CALLCONV SAFE_ON_WIN "WSASendMsg" c_sendmsg :: CInt -> Ptr (MsgHdr SockAddr) -> DWORD -> LPDWORD -> Ptr () -> Ptr () -> IO CInt -foreign import CALLCONV SAFE_ON_WIN "recvmsg" +foreign import CALLCONV SAFE_ON_WIN "WSARecvMsg" c_recvmsg :: CInt -> Ptr (MsgHdr SockAddr) -> LPDWORD -> Ptr () -> Ptr () -> IO CInt #endif diff --git a/cbits/cmsg.c b/cbits/cmsg.c index 748f1381..b971e2c9 100644 --- a/cbits/cmsg.c +++ b/cbits/cmsg.c @@ -3,11 +3,11 @@ #ifdef _WIN32 -struct LPCMSGHDR cmsg_firsthdr(LPWSAMSG mhdr) { +LPWSACMSGHDR cmsg_firsthdr(LPWSAMSG mhdr) { return (WSA_CMSG_FIRSTHDR(mhdr)); } -struct LPCMSGHDR cmsg_nxthdr(LPWSAMSG mhdr, LPWSACMSGHDR cmsg) { +LPWSACMSGHDR cmsg_nxthdr(LPWSAMSG mhdr, LPWSACMSGHDR cmsg) { return (WSA_CMSG_NXTHDR(mhdr, cmsg)); } @@ -22,6 +22,49 @@ unsigned int cmsg_space(unsigned int l) { unsigned int cmsg_len(unsigned int l) { return (WSA_CMSG_LEN(l)); } + +static LPFN_WSASENDMSG ptr_SendMsg; +static LPFN_WSARECVMSG ptr_RecvMsg; +/* GUIDS to lookup WSASend/RecvMsg */ +static GUID WSARecvMsgGUID = WSAID_WSARECVMSG; +static GUID WSASendMsgGUID = WSAID_WSASENDMSG; + +int WINAPI +WSASendMsg (SOCKET s, LPWSAMSG lpMsg, DWORD flags, + LPDWORD lpdwNumberOfBytesRecvd, LPWSAOVERLAPPED lpOverlapped, + LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine) { + + if (!ptr_SendMsg) { + DWORD len; + if (WSAIoctl(s, SIO_GET_EXTENSION_FUNCTION_POINTER, + &WSASendMsgGUID, sizeof(WSASendMsgGUID), &ptr_SendMsg, + sizeof(ptr_SendMsg), &len, NULL, NULL) != 0) + return -1; + } + + return ptr_SendMsg (s, lpMsg, flags, lpdwNumberOfBytesRecvd, lpOverlapped, + lpCompletionRoutine); +} + +/** + * WSARecvMsg function + */ +int WINAPI +WSARecvMsg (SOCKET s, LPWSAMSG lpMsg, LPDWORD lpdwNumberOfBytesRecvd, + LPWSAOVERLAPPED lpOverlapped, + LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine) { + + if (!ptr_RecvMsg) { + DWORD len; + if (WSAIoctl(s, SIO_GET_EXTENSION_FUNCTION_POINTER, + &WSARecvMsgGUID, sizeof(WSARecvMsgGUID), &ptr_RecvMsg, + sizeof(ptr_RecvMsg), &len, NULL, NULL) != 0) + return -1; + } + + return ptr_RecvMsg (s, lpMsg, lpdwNumberOfBytesRecvd, lpOverlapped, + lpCompletionRoutine); +} #else struct cmsghdr *cmsg_firsthdr(struct msghdr *mhdr) { return (CMSG_FIRSTHDR(mhdr)); diff --git a/include/HsNet.h b/include/HsNet.h index 9c15b441..bd90f870 100644 --- a/include/HsNet.h +++ b/include/HsNet.h @@ -82,7 +82,7 @@ extern void* newAcceptParams(int sock, extern int acceptNewSock(void* d); extern int acceptDoProc(void* param); -extern struct LPWSACMSGHDR +extern LPWSACMSGHDR cmsg_firsthdr(LPWSAMSG mhdr); extern LPWSACMSGHDR @@ -96,6 +96,20 @@ cmsg_space(unsigned int l); extern unsigned int cmsg_len(unsigned int l); + +/** + * WSASendMsg function + */ +extern WINAPI int +WSASendMsg (SOCKET, LPWSAMSG, DWORD, LPDWORD, + LPWSAOVERLAPPED, LPWSAOVERLAPPED_COMPLETION_ROUTINE); + +/** + * WSARecvMsg function + */ +extern WINAPI int +WSARecvMsg (SOCKET, LPWSAMSG, LPDWORD, + LPWSAOVERLAPPED, LPWSAOVERLAPPED_COMPLETION_ROUTINE); #else /* _WIN32 */ extern int sendFd(int sock, int outfd); diff --git a/network.cabal b/network.cabal index 6955acc9..95c0bb57 100644 --- a/network.cabal +++ b/network.cabal @@ -88,7 +88,7 @@ library include-dirs: include includes: HsNet.h HsNetDef.h alignment.h win32defs.h install-includes: HsNet.h HsNetDef.h win32defs.h - c-sources: cbits/HsNet.c + c-sources: cbits/HsNet.c cbits/cmsg.c ghc-options: -Wall -fwarn-tabs build-tools: hsc2hs @@ -102,8 +102,6 @@ library Network.Socket.Posix.IOVec Network.Socket.Posix.MsgHdr Network.Socket.Unix - c-sources: cbits/cmsg.c - if os(solaris) extra-libraries: nsl, socket @@ -116,10 +114,11 @@ library Network.Socket.Win32.WSABuf Network.Socket.Win32.MsgHdr c-sources: cbits/initWinSock.c, cbits/winSockErr.c, cbits/asyncAccept.c - extra-libraries: ws2_32, iphlpapi + extra-libraries: ws2_32, iphlpapi, mswsock -- See https://github.com/haskell/network/pull/362 if impl(ghc >= 7.10) cpp-options: -D_WIN32_WINNT=0x0600 + cc-options: -D_WIN32_WINNT=0x0600 test-suite spec default-language: Haskell2010 diff --git a/tests/Network/SocketSpec.hs b/tests/Network/SocketSpec.hs index c80af019..42d94912 100644 --- a/tests/Network/SocketSpec.hs +++ b/tests/Network/SocketSpec.hs @@ -126,6 +126,7 @@ spec = do -- check if an exception is not thrown. isSupportedSockAddr addr `shouldBe` True +#if !defined(mingw32_HOST_OS) when isUnixDomainSocketAvailable $ do context "unix sockets" $ do it "basic unix sockets end-to-end" $ do @@ -134,6 +135,7 @@ spec = do recv sock 1024 `shouldReturn` testMsg addr `shouldBe` (SockAddrUnix "") test . setClientAction client $ unixWithUnlink unixAddr server +#endif #ifdef linux_HOST_OS it "can end-to-end with an abstract socket" $ do @@ -152,6 +154,7 @@ spec = do bind sock (SockAddrUnix abstractAddress) `shouldThrow` anyErrorCall #endif +#if !defined(mingw32_HOST_OS) describe "socketPair" $ do it "can send and recieve bi-directionally" $ do (s1, s2) <- socketPair AF_UNIX Stream defaultProtocol @@ -206,6 +209,7 @@ spec = do cred1 <- getPeerCredential s cred1 `shouldBe` (Nothing,Nothing,Nothing) -} +#endif describe "gracefulClose" $ do it "does not send TCP RST back" $ do From fb8529f2bda52e23288612554135de9e9cbcbd0f Mon Sep 17 00:00:00 2001 From: Tamar Christina Date: Mon, 16 Mar 2020 08:21:48 +0000 Subject: [PATCH 42/48] Fix bindist missin header --- network.cabal | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/network.cabal b/network.cabal index 95c0bb57..0c2f3024 100644 --- a/network.cabal +++ b/network.cabal @@ -87,7 +87,7 @@ library include-dirs: include includes: HsNet.h HsNetDef.h alignment.h win32defs.h - install-includes: HsNet.h HsNetDef.h win32defs.h + install-includes: HsNet.h HsNetDef.h alignment.h win32defs.h c-sources: cbits/HsNet.c cbits/cmsg.c ghc-options: -Wall -fwarn-tabs build-tools: hsc2hs From 507a2c0e0b3367e29975950cd22a4d4f0ba9e386 Mon Sep 17 00:00:00 2001 From: Tamar Christina Date: Sun, 5 Apr 2020 19:54:46 +0100 Subject: [PATCH 43/48] Finish windows implementation --- Network/Socket.hs | 1 + Network/Socket/Buffer.hsc | 10 +++++- Network/Socket/Internal.hs | 44 ++++++++++++++++++++++++-- Network/Socket/Options.hsc | 8 +++++ Network/Socket/Win32/Cmsg.hsc | 4 +-- Network/Socket/Win32/CmsgHdr.hsc | 17 ++++++---- Network/Socket/Win32/MsgHdr.hsc | 2 +- cbits/cmsg.c | 13 ++++++-- include/win32defs.h | 3 ++ tests/Network/Socket/ByteStringSpec.hs | 37 +++++++++++++--------- 10 files changed, 110 insertions(+), 29 deletions(-) diff --git a/Network/Socket.hs b/Network/Socket.hs index e0f8f3fd..7c3f736d 100644 --- a/Network/Socket.hs +++ b/Network/Socket.hs @@ -140,6 +140,7 @@ module Network.Socket ,RecvIPv4TTL,RecvIPv4TOS,RecvIPv4PktInfo ,RecvIPv6HopLimit,RecvIPv6TClass,RecvIPv6PktInfo) , isSupportedSocketOption + , whenSupported , getSocketOption , setSocketOption , getSockOpt diff --git a/Network/Socket/Buffer.hsc b/Network/Socket/Buffer.hsc index c45eb790..16934709 100644 --- a/Network/Socket/Buffer.hsc +++ b/Network/Socket/Buffer.hsc @@ -267,9 +267,17 @@ recvBufMsg s bufsizs clen flags = do , msgBuffer = wsaBPtr , msgBufferLen = fromIntegral wsaBLen #endif +#if !defined(mingw32_HOST_OS) , msgCtrl = castPtr ctrlPtr +#else + , msgCtrl = if clen == 0 then nullPtr else castPtr ctrlPtr +#endif , msgCtrlLen = fromIntegral clen +#if !defined(mingw32_HOST_OS) , msgFlags = 0 +#else + , msgFlags = fromIntegral $ fromMsgFlag flags +#endif } _cflags = fromMsgFlag flags withFdSocket s $ \fd -> do @@ -280,7 +288,7 @@ recvBufMsg s bufsizs clen flags = do c_recvmsg fd msgHdrPtr _cflags #else alloca $ \len_ptr -> do - _ <- throwSocketErrorWaitRead s "Network.Socket.Buffer.recvmg" $ + _ <- throwSocketErrorWaitReadBut (== #{const WSAEMSGSIZE}) s "Network.Socket.Buffer.recvmg" $ c_recvmsg fd msgHdrPtr len_ptr nullPtr nullPtr peek len_ptr #endif diff --git a/Network/Socket/Internal.hs b/Network/Socket/Internal.hs index bfe07385..dd2f68d0 100644 --- a/Network/Socket/Internal.hs +++ b/Network/Socket/Internal.hs @@ -34,12 +34,14 @@ module Network.Socket.Internal , throwSocketErrorIfMinus1Retry , throwSocketErrorIfMinus1Retry_ , throwSocketErrorIfMinus1RetryMayBlock + , throwSocketErrorIfMinus1ButRetry -- ** Guards that wait and retry if the operation would block -- | These guards are based on 'throwSocketErrorIfMinus1RetryMayBlock'. -- They wait for socket readiness if the action fails with @EWOULDBLOCK@ -- or similar. , throwSocketErrorWaitRead + , throwSocketErrorWaitReadBut , throwSocketErrorWaitWrite -- * Initialization @@ -134,16 +136,37 @@ throwSocketErrorIfMinus1RetryMayBlock {-# SPECIALIZE throwSocketErrorIfMinus1RetryMayBlock :: String -> IO b -> IO CInt -> IO CInt #-} + +-- | Throw an 'IOError' corresponding to the current socket error if +-- the IO action returns a result of @-1@, but retries in case of an +-- interrupted operation. Checks for operations that would block and +-- executes an alternative action before retrying in that case. If the error +-- is one handled by the exempt filter then ignore it and return the errorcode. +throwSocketErrorIfMinus1RetryMayBlockBut + :: (Eq a, Num a) + => (CInt -> Bool) -- ^ exception exempt filter + -> String -- ^ textual description of the location + -> IO b -- ^ action to execute before retrying if an + -- immediate retry would block + -> IO a -- ^ the 'IO' operation to be executed + -> IO a + +{-# SPECIALIZE throwSocketErrorIfMinus1RetryMayBlock + :: String -> IO b -> IO CInt -> IO CInt #-} + #if defined(mingw32_HOST_OS) throwSocketErrorIfMinus1RetryMayBlock name _ act = throwSocketErrorIfMinus1Retry name act +throwSocketErrorIfMinus1RetryMayBlockBut exempt name _ act + = throwSocketErrorIfMinus1ButRetry exempt name act + throwSocketErrorIfMinus1_ name act = do _ <- throwSocketErrorIfMinus1Retry name act return () -throwSocketErrorIfMinus1Retry name act = do +throwSocketErrorIfMinus1ButRetry exempt name act = do r <- act if (r == -1) then do @@ -155,7 +178,9 @@ throwSocketErrorIfMinus1Retry name act = do then throwSocketError name else return r' else - throwSocketError name + if (exempt rc) + then return r + else throwSocketError name else return r throwSocketErrorCode name rc = do @@ -177,6 +202,9 @@ foreign import ccall unsafe "getWSErrorDescr" throwSocketErrorIfMinus1RetryMayBlock name on_block act = throwErrnoIfMinus1RetryMayBlock name act on_block +throwSocketErrorIfMinus1RetryMayBlockBut _exempt name on_block act = + throwErrnoIfMinus1RetryMayBlock name act on_block + throwSocketErrorIfMinus1Retry = throwErrnoIfMinus1Retry throwSocketErrorIfMinus1_ = throwErrnoIfMinus1_ @@ -188,6 +216,9 @@ throwSocketErrorCode loc errno = #endif +throwSocketErrorIfMinus1Retry + = throwSocketErrorIfMinus1ButRetry (const False) + -- | Like 'throwSocketErrorIfMinus1Retry', but if the action fails with -- @EWOULDBLOCK@ or similar, wait for the socket to be read-ready, -- and try again. @@ -196,6 +227,15 @@ throwSocketErrorWaitRead s name io = withFdSocket s $ \fd -> throwSocketErrorIfMinus1RetryMayBlock name (threadWaitRead $ fromIntegral fd) io +-- | Like 'throwSocketErrorIfMinus1Retry', but if the action fails with +-- @EWOULDBLOCK@ or similar, wait for the socket to be read-ready, +-- and try again. If it fails with the error the user was expecting then +-- ignore the error +throwSocketErrorWaitReadBut :: (Eq a, Num a) => (CInt -> Bool) -> Socket -> String -> IO a -> IO a +throwSocketErrorWaitReadBut exempt s name io = withFdSocket s $ \fd -> + throwSocketErrorIfMinus1RetryMayBlockBut exempt name + (threadWaitRead $ fromIntegral fd) io + -- | Like 'throwSocketErrorIfMinus1Retry', but if the action fails with -- @EWOULDBLOCK@ or similar, wait for the socket to be write-ready, -- and try again. diff --git a/Network/Socket/Options.hsc b/Network/Socket/Options.hsc index 817f71c5..65511754 100644 --- a/Network/Socket/Options.hsc +++ b/Network/Socket/Options.hsc @@ -16,6 +16,7 @@ module Network.Socket.Options ( ,RecvIPv4TTL,RecvIPv4TOS,RecvIPv4PktInfo ,RecvIPv6HopLimit,RecvIPv6TClass,RecvIPv6PktInfo) , isSupportedSocketOption + , whenSupported , getSocketType , getSocketOption , setSocketOption @@ -289,6 +290,13 @@ instance Storable StructLinger where (#poke struct linger, l_linger) p linger #endif +-- | Executes the given action and ignoring the result only when the specified +-- socket option is valid. +whenSupported :: SocketOption -> IO a -> IO () +whenSupported s action + | isSupportedSocketOption s = action >> return () + | otherwise = return () + -- | Set a socket option that expects an Int value. -- There is currently no API to set e.g. the timeval socket options setSocketOption :: Socket diff --git a/Network/Socket/Win32/Cmsg.hsc b/Network/Socket/Win32/Cmsg.hsc index 6fd02d4e..531d9e70 100644 --- a/Network/Socket/Win32/Cmsg.hsc +++ b/Network/Socket/Win32/Cmsg.hsc @@ -41,11 +41,11 @@ pattern CmsgIdIPv6HopLimit = CmsgId (#const IPPROTO_IPV6) (#const IPV6_HOPLIMIT) -- | The identifier for 'IPv4TOS'. pattern CmsgIdIPv4TOS :: CmsgId -pattern CmsgIdIPv4TOS = CmsgId (#const IPPROTO_IP) (#const IP_RECVTOS) +pattern CmsgIdIPv4TOS = CmsgId (#const IPPROTO_IP) (#const IP_TOS) -- | The identifier for 'IPv6TClass'. pattern CmsgIdIPv6TClass :: CmsgId -pattern CmsgIdIPv6TClass = CmsgId (#const IPPROTO_IPV6) (#const IPV6_RECVTCLASS) +pattern CmsgIdIPv6TClass = CmsgId (#const IPPROTO_IPV6) (#const IPV6_TCLASS) -- | The identifier for 'IPv4PktInfo'. pattern CmsgIdIPv4PktInfo :: CmsgId diff --git a/Network/Socket/Win32/CmsgHdr.hsc b/Network/Socket/Win32/CmsgHdr.hsc index 467e452a..e1248a6d 100644 --- a/Network/Socket/Win32/CmsgHdr.hsc +++ b/Network/Socket/Win32/CmsgHdr.hsc @@ -72,16 +72,21 @@ parseCmsgs msgptr = do loop ptr build | ptr == nullPtr = return $ build [] | otherwise = do - cmsg <- fromCmsgHdr ptr - nextPtr <- c_cmsg_nxthdr msgptr ptr - loop nextPtr (build . (cmsg :)) - -fromCmsgHdr :: Ptr CmsgHdr -> IO Cmsg + val <- fromCmsgHdr ptr + case val of + Nothing -> return $ build [] + Just cmsg -> do + nextPtr <- c_cmsg_nxthdr msgptr ptr + loop nextPtr (build . (cmsg :)) + +fromCmsgHdr :: Ptr CmsgHdr -> IO (Maybe Cmsg) fromCmsgHdr ptr = do CmsgHdr len lvl typ <- peek ptr src <- c_cmsg_data ptr let siz = fromIntegral len - (src `minusPtr` ptr) - Cmsg (CmsgId lvl typ) <$> create (fromIntegral siz) (\dst -> memcpy dst src siz) + if siz < 0 + then return Nothing + else Just . Cmsg (CmsgId lvl typ) <$> create (fromIntegral siz) (\dst -> memcpy dst src siz) foreign import ccall unsafe "cmsg_firsthdr" c_cmsg_firsthdr :: Ptr (MsgHdr sa) -> IO (Ptr CmsgHdr) diff --git a/Network/Socket/Win32/MsgHdr.hsc b/Network/Socket/Win32/MsgHdr.hsc index 9f26257c..da5f6b01 100644 --- a/Network/Socket/Win32/MsgHdr.hsc +++ b/Network/Socket/Win32/MsgHdr.hsc @@ -25,7 +25,7 @@ data MsgHdr sa = MsgHdr , msgCtrl :: !(Ptr Word8) , msgCtrlLen :: !ULONG , msgFlags :: !DWORD - } + } deriving Show instance Storable (MsgHdr sa) where sizeOf = const #{size WSAMSG} diff --git a/cbits/cmsg.c b/cbits/cmsg.c index b971e2c9..c532f5ad 100644 --- a/cbits/cmsg.c +++ b/cbits/cmsg.c @@ -62,8 +62,17 @@ WSARecvMsg (SOCKET s, LPWSAMSG lpMsg, LPDWORD lpdwNumberOfBytesRecvd, return -1; } - return ptr_RecvMsg (s, lpMsg, lpdwNumberOfBytesRecvd, lpOverlapped, - lpCompletionRoutine); + int res = ptr_RecvMsg (s, lpMsg, lpdwNumberOfBytesRecvd, lpOverlapped, + lpCompletionRoutine); + + /* If the msg was truncated then this pointer can be garbage. */ + if (res == SOCKET_ERROR && GetLastError () == WSAEMSGSIZE) + { + lpMsg->Control.len = 0; + lpMsg->Control.buf = NULL; + } + + return res; } #else struct cmsghdr *cmsg_firsthdr(struct msghdr *mhdr) { diff --git a/include/win32defs.h b/include/win32defs.h index ea39a170..d9261be3 100644 --- a/include/win32defs.h +++ b/include/win32defs.h @@ -114,4 +114,7 @@ #endif #ifndef IP_RECVERR #define IP_RECVERR 75 // Receive ICMP errors. +#endif +#ifndef IPV6_TCLASS +#define IPV6_TCLASS 39 #endif \ No newline at end of file diff --git a/tests/Network/Socket/ByteStringSpec.hs b/tests/Network/Socket/ByteStringSpec.hs index 75af7fd5..502e87dd 100644 --- a/tests/Network/Socket/ByteStringSpec.hs +++ b/tests/Network/Socket/ByteStringSpec.hs @@ -229,14 +229,17 @@ spec = do it "receives control messages for IPv4" $ do let server sock = do - setSocketOption sock RecvIPv4TTL 1 - setSocketOption sock RecvIPv4TOS 1 - setSocketOption sock RecvIPv4PktInfo 1 + whenSupported RecvIPv4TTL $ setSocketOption sock RecvIPv4TTL 1 + whenSupported RecvIPv4TOS $ setSocketOption sock RecvIPv4TOS 1 + whenSupported RecvIPv4PktInfo $ setSocketOption sock RecvIPv4PktInfo 1 (_, _, cmsgs, _) <- recvMsg sock 1024 128 mempty - ((lookupCmsg CmsgIdIPv4TTL cmsgs >>= decodeCmsg) :: Maybe IPv4TTL) `shouldNotBe` Nothing - ((lookupCmsg CmsgIdIPv4TOS cmsgs >>= decodeCmsg) :: Maybe IPv4TOS) `shouldNotBe` Nothing - ((lookupCmsg CmsgIdIPv4PktInfo cmsgs >>= decodeCmsg) :: Maybe IPv4PktInfo) `shouldNotBe` Nothing + whenSupported RecvIPv4TTL $ + ((lookupCmsg CmsgIdIPv4TTL cmsgs >>= decodeCmsg) :: Maybe IPv4TTL) `shouldNotBe` Nothing + whenSupported RecvIPv4TOS $ + ((lookupCmsg CmsgIdIPv4TOS cmsgs >>= decodeCmsg) :: Maybe IPv4TOS) `shouldNotBe` Nothing + whenSupported RecvIPv4PktInfo $ + ((lookupCmsg CmsgIdIPv4PktInfo cmsgs >>= decodeCmsg) :: Maybe IPv4PktInfo) `shouldNotBe` Nothing client sock addr = sendTo sock seg addr seg = C.pack "This is a test message" @@ -244,14 +247,18 @@ spec = do it "receives control messages for IPv6" $ do let server sock = do - setSocketOption sock RecvIPv6HopLimit 1 - setSocketOption sock RecvIPv6TClass 1 - setSocketOption sock RecvIPv6PktInfo 1 + whenSupported RecvIPv6HopLimit $ setSocketOption sock RecvIPv6HopLimit 1 + whenSupported RecvIPv6TClass $ setSocketOption sock RecvIPv6TClass 1 + whenSupported RecvIPv6PktInfo $ setSocketOption sock RecvIPv6PktInfo 1 (_, _, cmsgs, _) <- recvMsg sock 1024 128 mempty - ((lookupCmsg CmsgIdIPv6HopLimit cmsgs >>= decodeCmsg) :: Maybe IPv6HopLimit) `shouldNotBe` Nothing - ((lookupCmsg CmsgIdIPv6TClass cmsgs >>= decodeCmsg) :: Maybe IPv6TClass) `shouldNotBe` Nothing - ((lookupCmsg CmsgIdIPv6PktInfo cmsgs >>= decodeCmsg) :: Maybe IPv6PktInfo) `shouldNotBe` Nothing + + whenSupported RecvIPv6HopLimit $ + ((lookupCmsg CmsgIdIPv6HopLimit cmsgs >>= decodeCmsg) :: Maybe IPv6HopLimit) `shouldNotBe` Nothing + whenSupported RecvIPv6TClass $ + ((lookupCmsg CmsgIdIPv6TClass cmsgs >>= decodeCmsg) :: Maybe IPv6TClass) `shouldNotBe` Nothing + whenSupported RecvIPv6PktInfo $ + ((lookupCmsg CmsgIdIPv6PktInfo cmsgs >>= decodeCmsg) :: Maybe IPv6PktInfo) `shouldNotBe` Nothing client sock addr = sendTo sock seg addr seg = C.pack "This is a test message" @@ -259,9 +266,9 @@ spec = do it "receives truncated control messages" $ do let server sock = do - setSocketOption sock RecvIPv4TTL 1 - setSocketOption sock RecvIPv4TOS 1 - setSocketOption sock RecvIPv4PktInfo 1 + whenSupported RecvIPv4TTL $ setSocketOption sock RecvIPv4TTL 1 + whenSupported RecvIPv4TOS $ setSocketOption sock RecvIPv4TOS 1 + whenSupported RecvIPv4PktInfo $ setSocketOption sock RecvIPv4PktInfo 1 (_, _, _, flags) <- recvMsg sock 1024 10 mempty flags .&. MSG_CTRUNC `shouldBe` MSG_CTRUNC From 7cc239a12db62d75fcb8f41e2597ac93f2130766 Mon Sep 17 00:00:00 2001 From: Kazu Yamamoto Date: Wed, 8 Apr 2020 10:26:12 +0900 Subject: [PATCH 44/48] fixing a gap between Unix and Windows. --- Network/Socket/Internal.hs | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/Network/Socket/Internal.hs b/Network/Socket/Internal.hs index dd2f68d0..6d8ebb2e 100644 --- a/Network/Socket/Internal.hs +++ b/Network/Socket/Internal.hs @@ -34,8 +34,9 @@ module Network.Socket.Internal , throwSocketErrorIfMinus1Retry , throwSocketErrorIfMinus1Retry_ , throwSocketErrorIfMinus1RetryMayBlock +#if defined(mingw32_HOST_OS) , throwSocketErrorIfMinus1ButRetry - +#endif -- ** Guards that wait and retry if the operation would block -- | These guards are based on 'throwSocketErrorIfMinus1RetryMayBlock'. -- They wait for socket readiness if the action fails with @EWOULDBLOCK@ @@ -183,6 +184,9 @@ throwSocketErrorIfMinus1ButRetry exempt name act = do else throwSocketError name else return r +throwSocketErrorIfMinus1Retry + = throwSocketErrorIfMinus1ButRetry (const False) + throwSocketErrorCode name rc = do pstr <- c_getWSError rc str <- peekCString pstr @@ -216,9 +220,6 @@ throwSocketErrorCode loc errno = #endif -throwSocketErrorIfMinus1Retry - = throwSocketErrorIfMinus1ButRetry (const False) - -- | Like 'throwSocketErrorIfMinus1Retry', but if the action fails with -- @EWOULDBLOCK@ or similar, wait for the socket to be read-ready, -- and try again. From f536d352ad2c196bba321504b00ed676edae57b8 Mon Sep 17 00:00:00 2001 From: Kazu Yamamoto Date: Wed, 8 Apr 2020 10:57:27 +0900 Subject: [PATCH 45/48] improving docs and definitions as suggested by vdukhovni. --- Network/Socket/Buffer.hsc | 9 ++++++--- Network/Socket/Posix/Cmsg.hsc | 8 +++----- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/Network/Socket/Buffer.hsc b/Network/Socket/Buffer.hsc index 16934709..353b10c2 100644 --- a/Network/Socket/Buffer.hsc +++ b/Network/Socket/Buffer.hsc @@ -238,17 +238,20 @@ sendBufMsg s sa bufsizs cmsgs flags = do #endif return $ fromIntegral sz --- | Receive data from the socket using recvmsg(2). +-- | Receive data from the socket using recvmsg(2). The supplied +-- buffers are filled in order, with subsequent buffers used only +-- after all the preceding buffers are full. If the message is short +-- enough some of the supplied buffers may remain unused. recvBufMsg :: SocketAddress sa => Socket -- ^ Socket - -> [(Ptr Word8,Int)] -- ^ A list of a pair of buffer and its size. + -> [(Ptr Word8,Int)] -- ^ A list of (buffer, buffer-length) pairs. -- If the total length is not large enough, -- 'MSG_TRUNC' is returned -> Int -- ^ The buffer size for control messages. -- If the length is not large enough, -- 'MSG_CTRUNC' is returned -> MsgFlag -- ^ Message flags - -> IO (sa,Int,[Cmsg],MsgFlag) -- ^ Source address, received data, control messages and message flags + -> IO (sa,Int,[Cmsg],MsgFlag) -- ^ Source address, total bytes received, control messages and message flags recvBufMsg s bufsizs clen flags = do withNewSocketAddress $ \addrPtr addrSize -> allocaBytes clen $ \ctrlPtr -> diff --git a/Network/Socket/Posix/Cmsg.hsc b/Network/Socket/Posix/Cmsg.hsc index 71620951..a3f7d0a8 100644 --- a/Network/Socket/Posix/Cmsg.hsc +++ b/Network/Socket/Posix/Cmsg.hsc @@ -70,14 +70,12 @@ pattern CmsgIdFd = CmsgId (#const SOL_SOCKET) (#const SCM_RIGHTS) ---------------------------------------------------------------- --- | Looking up control message. The following shows an example usage: +-- | Locate a control message of the given type in a list of control +-- messages. The following shows an example usage: -- -- > (lookupCmsg CmsgIdIPv4TOS cmsgs >>= decodeCmsg) :: Maybe IPv4TOS lookupCmsg :: CmsgId -> [Cmsg] -> Maybe Cmsg -lookupCmsg _ [] = Nothing -lookupCmsg cid (cmsg:cmsgs) - | cmsgId cmsg == cid = Just cmsg - | otherwise = lookupCmsg cid cmsgs +lookupCmsg cid cmsgs = find (\cmsg -> cmsgId cmsg == cid) cmsgs -- | Filtering control message. filterCmsg :: CmsgId -> [Cmsg] -> [Cmsg] From 3aa7b18724b0765a861cf74b555facd2acf09c14 Mon Sep 17 00:00:00 2001 From: Kazu Yamamoto Date: Wed, 8 Apr 2020 10:59:09 +0900 Subject: [PATCH 46/48] using GHC 8.8.3. --- .travis.yml | 2 +- appveyor.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.travis.yml b/.travis.yml index 21a0f34b..85c5491a 100644 --- a/.travis.yml +++ b/.travis.yml @@ -38,7 +38,7 @@ matrix: - compiler: "ghc-8.6.5" # env: TEST=--disable-tests BENCH=--disable-benchmarks addons: {apt: {packages: [ghc-ppa-tools,cabal-install-2.4,ghc-8.6.5], sources: [hvr-ghc]}} - - compiler: "ghc-8.8.1" + - compiler: "ghc-8.8.3" # env: TEST=--disable-tests BENCH=--disable-benchmarks addons: {apt: {packages: [ghc-ppa-tools,cabal-install-3.0,ghc-8.8.1], sources: [hvr-ghc]}} - compiler: "ghc-head" diff --git a/appveyor.yml b/appveyor.yml index 951603fe..b47f1ee6 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -15,7 +15,7 @@ environment: - GHCVER: 8.2.2 - GHCVER: 8.4.4 - GHCVER: 8.6.3 - - GHCVER: 8.8.1 + - GHCVER: 8.8.3 platform: # - x86 # We may want to test x86 as well, but it would double the 23min build time. From 15193be394b43df7b82cd3ec8bade0fa2a41e06d Mon Sep 17 00:00:00 2001 From: Kazu Yamamoto Date: Wed, 8 Apr 2020 13:38:31 +0900 Subject: [PATCH 47/48] improving docs as suggested by vdukhovni. --- Network/Socket/Buffer.hsc | 2 +- Network/Socket/ByteString/IO.hsc | 2 +- Network/Socket/Options.hsc | 4 ++-- Network/Socket/Posix/Cmsg.hsc | 4 +++- Network/Socket/SockAddr.hs | 2 +- 5 files changed, 8 insertions(+), 6 deletions(-) diff --git a/Network/Socket/Buffer.hsc b/Network/Socket/Buffer.hsc index 353b10c2..f398fe2d 100644 --- a/Network/Socket/Buffer.hsc +++ b/Network/Socket/Buffer.hsc @@ -196,7 +196,7 @@ mkInvalidRecvArgError loc = ioeSetErrorString (mkIOError InvalidArgument loc Nothing Nothing) "non-positive length" --- | Send data from the socket using sendmsg(2). +-- | Send data to the socket using sendmsg(2). sendBufMsg :: SocketAddress sa => Socket -- ^ Socket -> sa -- ^ Destination address diff --git a/Network/Socket/ByteString/IO.hsc b/Network/Socket/ByteString/IO.hsc index 51898726..a45d469d 100644 --- a/Network/Socket/ByteString/IO.hsc +++ b/Network/Socket/ByteString/IO.hsc @@ -286,7 +286,7 @@ withWSABuffromBS cs f = do getBufsiz :: ByteString -> IO (Ptr Word8, Int) getBufsiz (PS fptr off len) = withForeignPtr fptr $ \ptr -> return (ptr `plusPtr` off, len) --- | Send data from the socket using sendmsg(2). +-- | Send data to the socket using sendmsg(2). sendMsg :: Socket -- ^ Socket -> SockAddr -- ^ Destination address -> [ByteString] -- ^ Data to be sent diff --git a/Network/Socket/Options.hsc b/Network/Socket/Options.hsc index 65511754..e1ff45b5 100644 --- a/Network/Socket/Options.hsc +++ b/Network/Socket/Options.hsc @@ -290,8 +290,8 @@ instance Storable StructLinger where (#poke struct linger, l_linger) p linger #endif --- | Executes the given action and ignoring the result only when the specified --- socket option is valid. +-- | Execute the given action only when the specified socket option is +-- supported. Any return value is ignored. whenSupported :: SocketOption -> IO a -> IO () whenSupported s action | isSupportedSocketOption s = action >> return () diff --git a/Network/Socket/Posix/Cmsg.hsc b/Network/Socket/Posix/Cmsg.hsc index a3f7d0a8..ada81124 100644 --- a/Network/Socket/Posix/Cmsg.hsc +++ b/Network/Socket/Posix/Cmsg.hsc @@ -83,7 +83,9 @@ filterCmsg cid cmsgs = filter (\cmsg -> cmsgId cmsg == cid) cmsgs ---------------------------------------------------------------- --- | A class to encode and decode control message. +-- | Control message type class. +-- Each control message type has a numeric 'CmsgId' and a 'Storable' +-- data representation. class Storable a => ControlMessage a where controlMessageId :: a -> CmsgId diff --git a/Network/Socket/SockAddr.hs b/Network/Socket/SockAddr.hs index f668849b..8468e2a9 100644 --- a/Network/Socket/SockAddr.hs +++ b/Network/Socket/SockAddr.hs @@ -74,7 +74,7 @@ sendBufTo = G.sendBufTo recvBufFrom :: Socket -> Ptr a -> Int -> IO (Int, SockAddr) recvBufFrom = G.recvBufFrom --- | Send data from the socket using sendmsg(2). +-- | Send data to the socket using sendmsg(2). sendBufMsg :: Socket -- ^ Socket -> SockAddr -- ^ Destination address -> [(Ptr Word8,Int)] -- ^ Data to be sent From f6b1f7c10a3ac026d368cd42b9333dad0172f0ba Mon Sep 17 00:00:00 2001 From: Kazu Yamamoto Date: Wed, 8 Apr 2020 13:39:58 +0900 Subject: [PATCH 48/48] using GHC 8.6.5. --- appveyor.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/appveyor.yml b/appveyor.yml index b47f1ee6..e68a3a83 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -14,7 +14,7 @@ environment: - GHCVER: 8.0.2 - GHCVER: 8.2.2 - GHCVER: 8.4.4 - - GHCVER: 8.6.3 + - GHCVER: 8.6.5 - GHCVER: 8.8.3 platform: