Skip to content

Commit

Permalink
Allow a control message to contain multiple file descriptors (issue h…
Browse files Browse the repository at this point in the history
…askell#566)

ControlMessage has lost the Storable constraint, because it is not possible to implement Storable [Fd] because [Fd] is not fixed-size when encoded.
  • Loading branch information
Dretch committed Jul 29, 2023
1 parent 734a3e7 commit 82d70ce
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 22 deletions.
4 changes: 1 addition & 3 deletions Network/Socket.hs
Original file line number Diff line number Diff line change
Expand Up @@ -245,13 +245,11 @@ module Network.Socket
,CmsgIdIPv6TClass
,CmsgIdIPv4PktInfo
,CmsgIdIPv6PktInfo
,CmsgIdFd
,CmsgIdFds
,UnsupportedCmsgId)
-- ** APIs for control message
, lookupCmsg
, filterCmsg
, decodeCmsg
, encodeCmsg
-- ** Class and types for control message
, ControlMessage(..)
, IPv4TTL(..)
Expand Down
2 changes: 1 addition & 1 deletion Network/Socket/ByteString/IO.hsc
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ sendManyWithFds s bss fds =
sendBufMsg s addr bufsizs cmsgs flags
where
addr = NullSockAddr
cmsgs = encodeCmsg <$> fds
cmsgs = encodeCmsg . (:[]) <$> fds
flags = mempty

-- ----------------------------------------------------------------------------
Expand Down
60 changes: 47 additions & 13 deletions Network/Socket/Posix/Cmsg.hsc
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RecordWildCards #-}
Expand All @@ -15,6 +16,7 @@ module Network.Socket.Posix.Cmsg where

import Data.ByteString.Internal
import Foreign.ForeignPtr
import Foreign.Marshal.Array (peekArray, pokeArray)
import System.IO.Unsafe (unsafeDupablePerformIO)
import System.Posix.Types (Fd(..))

Expand Down Expand Up @@ -82,9 +84,9 @@ pattern CmsgIdIPv6PktInfo = CmsgId (#const IPPROTO_IPV6) (#const IPV6_PKTINFO)
pattern CmsgIdIPv6PktInfo = CmsgId (-1) (-1)
#endif

-- | The identifier for 'Fd'.
pattern CmsgIdFd :: CmsgId
pattern CmsgIdFd = CmsgId (#const SOL_SOCKET) (#const SCM_RIGHTS)
-- | The identifier for 'Fds'.
pattern CmsgIdFds :: CmsgId
pattern CmsgIdFds = CmsgId (#const SOL_SOCKET) (#const SCM_RIGHTS)

----------------------------------------------------------------

Expand All @@ -102,13 +104,15 @@ filterCmsg cid cmsgs = filter (\cmsg -> cmsgId cmsg == cid) cmsgs
----------------------------------------------------------------

-- | Control message type class.
-- Each control message type has a numeric 'CmsgId' and a 'Storable'
-- data representation.
class Storable a => ControlMessage a where
-- Each control message type has a numeric 'CmsgId' and encode
-- and decode functions.
class ControlMessage a where
controlMessageId :: CmsgId
encodeCmsg :: a -> Cmsg
decodeCmsg :: Cmsg -> Maybe a

encodeCmsg :: forall a . ControlMessage a => a -> Cmsg
encodeCmsg x = unsafeDupablePerformIO $ do
encodeStorableCmsg :: forall a . (ControlMessage a, Storable a) => a -> Cmsg
encodeStorableCmsg x = unsafeDupablePerformIO $ do
bs <- create siz $ \p0 -> do
let p = castPtr p0
poke p x
Expand All @@ -117,8 +121,8 @@ encodeCmsg x = unsafeDupablePerformIO $ do
where
siz = sizeOf x

decodeCmsg :: forall a . (ControlMessage a, Storable a) => Cmsg -> Maybe a
decodeCmsg (Cmsg cmsid (PS fptr off len))
decodeStorableCmsg :: forall a . (ControlMessage a, Storable a) => Cmsg -> Maybe a
decodeStorableCmsg (Cmsg cmsid (PS fptr off len))
| cid /= cmsid = Nothing
| len < siz = Nothing
| otherwise = unsafeDupablePerformIO $ withForeignPtr fptr $ \p0 -> do
Expand All @@ -139,6 +143,8 @@ newtype IPv4TTL = IPv4TTL CInt deriving (Eq, Show, Storable)

instance ControlMessage IPv4TTL where
controlMessageId = CmsgIdIPv4TTL
encodeCmsg = encodeStorableCmsg
decodeCmsg = decodeStorableCmsg

----------------------------------------------------------------

Expand All @@ -147,6 +153,8 @@ newtype IPv6HopLimit = IPv6HopLimit CInt deriving (Eq, Show, Storable)

instance ControlMessage IPv6HopLimit where
controlMessageId = CmsgIdIPv6HopLimit
encodeCmsg = encodeStorableCmsg
decodeCmsg = decodeStorableCmsg

----------------------------------------------------------------

Expand All @@ -155,6 +163,8 @@ newtype IPv4TOS = IPv4TOS CChar deriving (Eq, Show, Storable)

instance ControlMessage IPv4TOS where
controlMessageId = CmsgIdIPv4TOS
encodeCmsg = encodeStorableCmsg
decodeCmsg = decodeStorableCmsg

----------------------------------------------------------------

Expand All @@ -163,6 +173,8 @@ newtype IPv6TClass = IPv6TClass CInt deriving (Eq, Show, Storable)

instance ControlMessage IPv6TClass where
controlMessageId = CmsgIdIPv6TClass
encodeCmsg = encodeStorableCmsg
decodeCmsg = decodeStorableCmsg

----------------------------------------------------------------

Expand All @@ -174,6 +186,8 @@ instance Show IPv4PktInfo where

instance ControlMessage IPv4PktInfo where
controlMessageId = CmsgIdIPv4PktInfo
encodeCmsg = encodeStorableCmsg
decodeCmsg = decodeStorableCmsg

instance Storable IPv4PktInfo where
#if defined (IP_PKTINFO)
Expand Down Expand Up @@ -205,6 +219,8 @@ instance Show IPv6PktInfo where

instance ControlMessage IPv6PktInfo where
controlMessageId = CmsgIdIPv6PktInfo
encodeCmsg = encodeStorableCmsg
decodeCmsg = decodeStorableCmsg

instance Storable IPv6PktInfo where
#if defined (IPV6_PKTINFO)
Expand All @@ -226,8 +242,26 @@ instance Storable IPv6PktInfo where

----------------------------------------------------------------

instance ControlMessage Fd where
controlMessageId = CmsgIdFd
instance ControlMessage [Fd] where
controlMessageId = CmsgIdFds

encodeCmsg fds = unsafeDupablePerformIO $ do
bs <- create siz $ \p0 -> do
let p = castPtr p0
pokeArray p fds
return $ Cmsg CmsgIdFds bs
where
siz = sizeOf (undefined :: Fd) * length fds

decodeCmsg (Cmsg cmsid (PS fptr off len))
| cmsid /= CmsgIdFds = Nothing
| otherwise =
unsafeDupablePerformIO $ withForeignPtr fptr $ \p0 -> do
let p = castPtr (p0 `plusPtr` off)
numFds = len `div` sizeOfFd
Just <$> peekArray numFds p
where
sizeOfFd = sizeOf (undefined :: Fd)

cmsgIdBijection :: Bijection CmsgId String
cmsgIdBijection =
Expand All @@ -238,7 +272,7 @@ cmsgIdBijection =
, (CmsgIdIPv6TClass, "CmsgIdIPv6TClass")
, (CmsgIdIPv4PktInfo, "CmsgIdIPv4PktInfo")
, (CmsgIdIPv6PktInfo, "CmsgIdIPv6PktInfo")
, (CmsgIdFd, "CmsgIdFd")
, (CmsgIdFds, "CmsgIdFds")
]

instance Show CmsgId where
Expand Down
8 changes: 4 additions & 4 deletions Network/Socket/Unix.hsc
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ isUnixDomainSocketAvailable = True
-- This function does not work on Windows.
sendFd :: Socket -> CInt -> IO ()
sendFd s outfd = void $ allocaBytes dummyBufSize $ \buf -> do
let cmsg = encodeCmsg $ Fd outfd
let cmsg = encodeCmsg [Fd outfd]
sendBufMsg s NullSockAddr [(buf,dummyBufSize)] [cmsg] mempty
where
dummyBufSize = 1
Expand All @@ -149,9 +149,9 @@ sendFd s outfd = void $ allocaBytes dummyBufSize $ \buf -> do
recvFd :: Socket -> IO CInt
recvFd s = allocaBytes dummyBufSize $ \buf -> do
(NullSockAddr, _, cmsgs, _) <- recvBufMsg s [(buf,dummyBufSize)] 32 mempty
case (lookupCmsg CmsgIdFd cmsgs >>= decodeCmsg) :: Maybe Fd of
Nothing -> return (-1)
Just (Fd fd) -> return fd
case (lookupCmsg CmsgIdFds cmsgs >>= decodeCmsg) :: Maybe [Fd] of
Just (Fd fd : _) -> return fd
_ -> return (-1)
where
dummyBufSize = 16

Expand Down
10 changes: 9 additions & 1 deletion tests/Network/SocketSpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import Network.Test.Common
import System.Mem (performGC)
import System.IO.Error (tryIOError)
import System.IO.Temp (withSystemTempDirectory)
import System.Posix.Types (Fd(..))
import Foreign.C.Types ()

import Test.Hspec
Expand Down Expand Up @@ -378,6 +379,10 @@ spec = do
let msgid = CmsgId (-300) (-300) in
show msgid `shouldBe` "CmsgId (-300) (-300)"

describe "bijective encodeCmsg-decodeCmsg roundtrip equality" $ do
it "holds for [Fd]" $ forAll genFds $
\x -> (decodeCmsg . encodeCmsg $ x) == Just (x :: [Fd])

describe "bijective read-show roundtrip equality" $ do
it "holds for Family" $ forAll familyGen $
\x -> (read . show $ x) == (x :: Family)
Expand Down Expand Up @@ -415,6 +420,9 @@ sockoptGen = biasedGen (\g -> SockOpt <$> g <*> g) sockoptPatterns arbitrary
cmsgidGen :: Gen CmsgId
cmsgidGen = biasedGen (\g -> CmsgId <$> g <*> g) cmsgidPatterns arbitrary

genFds :: Gen [Fd]
genFds = listOf (Fd <$> arbitrary)

-- pruned lists of pattern synonym values for each type to generate values from

familyPatterns :: [Family]
Expand Down Expand Up @@ -462,5 +470,5 @@ cmsgidPatterns = nub
, CmsgIdIPv6TClass
, CmsgIdIPv4PktInfo
, CmsgIdIPv6PktInfo
, CmsgIdFd
, CmsgIdFds
]

0 comments on commit 82d70ce

Please sign in to comment.