From 82d70ce3007db674a2d7d326b1d6e581e6efccdf Mon Sep 17 00:00:00 2001 From: Gareth Daniel Smith Date: Sat, 29 Jul 2023 17:22:31 +0100 Subject: [PATCH] Allow a control message to contain multiple file descriptors (issue #566) ControlMessage has lost the Storable constraint, because it is not possible to implement Storable [Fd] because [Fd] is not fixed-size when encoded. --- Network/Socket.hs | 4 +-- Network/Socket/ByteString/IO.hsc | 2 +- Network/Socket/Posix/Cmsg.hsc | 60 +++++++++++++++++++++++++------- Network/Socket/Unix.hsc | 8 ++--- tests/Network/SocketSpec.hs | 10 +++++- 5 files changed, 62 insertions(+), 22 deletions(-) diff --git a/Network/Socket.hs b/Network/Socket.hs index 5a90f9df..cf003a5a 100644 --- a/Network/Socket.hs +++ b/Network/Socket.hs @@ -245,13 +245,11 @@ module Network.Socket ,CmsgIdIPv6TClass ,CmsgIdIPv4PktInfo ,CmsgIdIPv6PktInfo - ,CmsgIdFd + ,CmsgIdFds ,UnsupportedCmsgId) -- ** APIs for control message , lookupCmsg , filterCmsg - , decodeCmsg - , encodeCmsg -- ** Class and types for control message , ControlMessage(..) , IPv4TTL(..) diff --git a/Network/Socket/ByteString/IO.hsc b/Network/Socket/ByteString/IO.hsc index 30cd0c98..d4fcbca0 100644 --- a/Network/Socket/ByteString/IO.hsc +++ b/Network/Socket/ByteString/IO.hsc @@ -231,7 +231,7 @@ sendManyWithFds s bss fds = sendBufMsg s addr bufsizs cmsgs flags where addr = NullSockAddr - cmsgs = encodeCmsg <$> fds + cmsgs = encodeCmsg . (:[]) <$> fds flags = mempty -- ---------------------------------------------------------------------------- diff --git a/Network/Socket/Posix/Cmsg.hsc b/Network/Socket/Posix/Cmsg.hsc index fed4e7bc..868ca6ba 100644 --- a/Network/Socket/Posix/Cmsg.hsc +++ b/Network/Socket/Posix/Cmsg.hsc @@ -1,5 +1,6 @@ {-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE CPP #-} +{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE RecordWildCards #-} @@ -15,6 +16,7 @@ module Network.Socket.Posix.Cmsg where import Data.ByteString.Internal import Foreign.ForeignPtr +import Foreign.Marshal.Array (peekArray, pokeArray) import System.IO.Unsafe (unsafeDupablePerformIO) import System.Posix.Types (Fd(..)) @@ -82,9 +84,9 @@ pattern CmsgIdIPv6PktInfo = CmsgId (#const IPPROTO_IPV6) (#const IPV6_PKTINFO) pattern CmsgIdIPv6PktInfo = CmsgId (-1) (-1) #endif --- | The identifier for 'Fd'. -pattern CmsgIdFd :: CmsgId -pattern CmsgIdFd = CmsgId (#const SOL_SOCKET) (#const SCM_RIGHTS) +-- | The identifier for 'Fds'. +pattern CmsgIdFds :: CmsgId +pattern CmsgIdFds = CmsgId (#const SOL_SOCKET) (#const SCM_RIGHTS) ---------------------------------------------------------------- @@ -102,13 +104,15 @@ filterCmsg cid cmsgs = filter (\cmsg -> cmsgId cmsg == cid) cmsgs ---------------------------------------------------------------- -- | Control message type class. --- Each control message type has a numeric 'CmsgId' and a 'Storable' --- data representation. -class Storable a => ControlMessage a where +-- Each control message type has a numeric 'CmsgId' and encode +-- and decode functions. +class ControlMessage a where controlMessageId :: CmsgId + encodeCmsg :: a -> Cmsg + decodeCmsg :: Cmsg -> Maybe a -encodeCmsg :: forall a . ControlMessage a => a -> Cmsg -encodeCmsg x = unsafeDupablePerformIO $ do +encodeStorableCmsg :: forall a . (ControlMessage a, Storable a) => a -> Cmsg +encodeStorableCmsg x = unsafeDupablePerformIO $ do bs <- create siz $ \p0 -> do let p = castPtr p0 poke p x @@ -117,8 +121,8 @@ encodeCmsg x = unsafeDupablePerformIO $ do where siz = sizeOf x -decodeCmsg :: forall a . (ControlMessage a, Storable a) => Cmsg -> Maybe a -decodeCmsg (Cmsg cmsid (PS fptr off len)) +decodeStorableCmsg :: forall a . (ControlMessage a, Storable a) => Cmsg -> Maybe a +decodeStorableCmsg (Cmsg cmsid (PS fptr off len)) | cid /= cmsid = Nothing | len < siz = Nothing | otherwise = unsafeDupablePerformIO $ withForeignPtr fptr $ \p0 -> do @@ -139,6 +143,8 @@ newtype IPv4TTL = IPv4TTL CInt deriving (Eq, Show, Storable) instance ControlMessage IPv4TTL where controlMessageId = CmsgIdIPv4TTL + encodeCmsg = encodeStorableCmsg + decodeCmsg = decodeStorableCmsg ---------------------------------------------------------------- @@ -147,6 +153,8 @@ newtype IPv6HopLimit = IPv6HopLimit CInt deriving (Eq, Show, Storable) instance ControlMessage IPv6HopLimit where controlMessageId = CmsgIdIPv6HopLimit + encodeCmsg = encodeStorableCmsg + decodeCmsg = decodeStorableCmsg ---------------------------------------------------------------- @@ -155,6 +163,8 @@ newtype IPv4TOS = IPv4TOS CChar deriving (Eq, Show, Storable) instance ControlMessage IPv4TOS where controlMessageId = CmsgIdIPv4TOS + encodeCmsg = encodeStorableCmsg + decodeCmsg = decodeStorableCmsg ---------------------------------------------------------------- @@ -163,6 +173,8 @@ newtype IPv6TClass = IPv6TClass CInt deriving (Eq, Show, Storable) instance ControlMessage IPv6TClass where controlMessageId = CmsgIdIPv6TClass + encodeCmsg = encodeStorableCmsg + decodeCmsg = decodeStorableCmsg ---------------------------------------------------------------- @@ -174,6 +186,8 @@ instance Show IPv4PktInfo where instance ControlMessage IPv4PktInfo where controlMessageId = CmsgIdIPv4PktInfo + encodeCmsg = encodeStorableCmsg + decodeCmsg = decodeStorableCmsg instance Storable IPv4PktInfo where #if defined (IP_PKTINFO) @@ -205,6 +219,8 @@ instance Show IPv6PktInfo where instance ControlMessage IPv6PktInfo where controlMessageId = CmsgIdIPv6PktInfo + encodeCmsg = encodeStorableCmsg + decodeCmsg = decodeStorableCmsg instance Storable IPv6PktInfo where #if defined (IPV6_PKTINFO) @@ -226,8 +242,26 @@ instance Storable IPv6PktInfo where ---------------------------------------------------------------- -instance ControlMessage Fd where - controlMessageId = CmsgIdFd +instance ControlMessage [Fd] where + controlMessageId = CmsgIdFds + + encodeCmsg fds = unsafeDupablePerformIO $ do + bs <- create siz $ \p0 -> do + let p = castPtr p0 + pokeArray p fds + return $ Cmsg CmsgIdFds bs + where + siz = sizeOf (undefined :: Fd) * length fds + + decodeCmsg (Cmsg cmsid (PS fptr off len)) + | cmsid /= CmsgIdFds = Nothing + | otherwise = + unsafeDupablePerformIO $ withForeignPtr fptr $ \p0 -> do + let p = castPtr (p0 `plusPtr` off) + numFds = len `div` sizeOfFd + Just <$> peekArray numFds p + where + sizeOfFd = sizeOf (undefined :: Fd) cmsgIdBijection :: Bijection CmsgId String cmsgIdBijection = @@ -238,7 +272,7 @@ cmsgIdBijection = , (CmsgIdIPv6TClass, "CmsgIdIPv6TClass") , (CmsgIdIPv4PktInfo, "CmsgIdIPv4PktInfo") , (CmsgIdIPv6PktInfo, "CmsgIdIPv6PktInfo") - , (CmsgIdFd, "CmsgIdFd") + , (CmsgIdFds, "CmsgIdFds") ] instance Show CmsgId where diff --git a/Network/Socket/Unix.hsc b/Network/Socket/Unix.hsc index 709fb1c8..2562f117 100644 --- a/Network/Socket/Unix.hsc +++ b/Network/Socket/Unix.hsc @@ -137,7 +137,7 @@ isUnixDomainSocketAvailable = True -- This function does not work on Windows. sendFd :: Socket -> CInt -> IO () sendFd s outfd = void $ allocaBytes dummyBufSize $ \buf -> do - let cmsg = encodeCmsg $ Fd outfd + let cmsg = encodeCmsg [Fd outfd] sendBufMsg s NullSockAddr [(buf,dummyBufSize)] [cmsg] mempty where dummyBufSize = 1 @@ -149,9 +149,9 @@ sendFd s outfd = void $ allocaBytes dummyBufSize $ \buf -> do recvFd :: Socket -> IO CInt recvFd s = allocaBytes dummyBufSize $ \buf -> do (NullSockAddr, _, cmsgs, _) <- recvBufMsg s [(buf,dummyBufSize)] 32 mempty - case (lookupCmsg CmsgIdFd cmsgs >>= decodeCmsg) :: Maybe Fd of - Nothing -> return (-1) - Just (Fd fd) -> return fd + case (lookupCmsg CmsgIdFds cmsgs >>= decodeCmsg) :: Maybe [Fd] of + Just (Fd fd : _) -> return fd + _ -> return (-1) where dummyBufSize = 16 diff --git a/tests/Network/SocketSpec.hs b/tests/Network/SocketSpec.hs index b5c05988..30294050 100644 --- a/tests/Network/SocketSpec.hs +++ b/tests/Network/SocketSpec.hs @@ -14,6 +14,7 @@ import Network.Test.Common import System.Mem (performGC) import System.IO.Error (tryIOError) import System.IO.Temp (withSystemTempDirectory) +import System.Posix.Types (Fd(..)) import Foreign.C.Types () import Test.Hspec @@ -378,6 +379,10 @@ spec = do let msgid = CmsgId (-300) (-300) in show msgid `shouldBe` "CmsgId (-300) (-300)" + describe "bijective encodeCmsg-decodeCmsg roundtrip equality" $ do + it "holds for [Fd]" $ forAll genFds $ + \x -> (decodeCmsg . encodeCmsg $ x) == Just (x :: [Fd]) + describe "bijective read-show roundtrip equality" $ do it "holds for Family" $ forAll familyGen $ \x -> (read . show $ x) == (x :: Family) @@ -415,6 +420,9 @@ sockoptGen = biasedGen (\g -> SockOpt <$> g <*> g) sockoptPatterns arbitrary cmsgidGen :: Gen CmsgId cmsgidGen = biasedGen (\g -> CmsgId <$> g <*> g) cmsgidPatterns arbitrary +genFds :: Gen [Fd] +genFds = listOf (Fd <$> arbitrary) + -- pruned lists of pattern synonym values for each type to generate values from familyPatterns :: [Family] @@ -462,5 +470,5 @@ cmsgidPatterns = nub , CmsgIdIPv6TClass , CmsgIdIPv4PktInfo , CmsgIdIPv6PktInfo - , CmsgIdFd + , CmsgIdFds ]