Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improving sendMsg and recvMsg #445

Merged
merged 5 commits into from
May 6, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ matrix:
addons: {apt: {packages: [ghc-ppa-tools,cabal-install-2.4,ghc-8.6.5], sources: [hvr-ghc]}}
- compiler: "ghc-8.8.3"
# env: TEST=--disable-tests BENCH=--disable-benchmarks
addons: {apt: {packages: [ghc-ppa-tools,cabal-install-3.0,ghc-8.8.1], sources: [hvr-ghc]}}
addons: {apt: {packages: [ghc-ppa-tools,cabal-install-3.0,ghc-8.8.3], sources: [hvr-ghc]}}
- compiler: "ghc-head"
# env: TEST=--disable-tests BENCH=--disable-benchmarks
addons: {apt: {packages: [ghc-ppa-tools,cabal-install-head,ghc-head], sources: [hvr-ghc]}}
Expand Down
33 changes: 19 additions & 14 deletions Network/Socket/Posix/Cmsg.hsc
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}

module Network.Socket.Posix.Cmsg where

Expand Down Expand Up @@ -87,24 +89,27 @@ filterCmsg cid cmsgs = filter (\cmsg -> cmsgId cmsg == cid) cmsgs
-- Each control message type has a numeric 'CmsgId' and a 'Storable'
-- data representation.
class Storable a => ControlMessage a where
controlMessageId :: a -> CmsgId
controlMessageId :: CmsgId

encodeCmsg :: ControlMessage a => a -> Cmsg
encodeCmsg :: forall a . ControlMessage a => a -> Cmsg
encodeCmsg x = unsafeDupablePerformIO $ do
bs <- create siz $ \p0 -> do
let p = castPtr p0
poke p x
return $ Cmsg (controlMessageId x) bs
let cmsid = controlMessageId @a
return $ Cmsg cmsid bs
where
siz = sizeOf x

decodeCmsg :: forall a . Storable a => Cmsg -> Maybe a
decodeCmsg (Cmsg _ (PS fptr off len))
| len < siz = Nothing
| otherwise = unsafeDupablePerformIO $ withForeignPtr fptr $ \p0 -> do
decodeCmsg :: forall a . (ControlMessage a, Storable a) => Cmsg -> Maybe a
decodeCmsg (Cmsg cmsid (PS fptr off len))
| cid /= cmsid = Nothing
| len < siz = Nothing
| otherwise = unsafeDupablePerformIO $ withForeignPtr fptr $ \p0 -> do
let p = castPtr (p0 `plusPtr` off)
Just <$> peek p
where
cid = controlMessageId @a
siz = sizeOf (undefined :: a)

----------------------------------------------------------------
Expand All @@ -117,31 +122,31 @@ newtype IPv4TTL = IPv4TTL CInt deriving (Eq, Show, Storable)
#endif

instance ControlMessage IPv4TTL where
controlMessageId _ = CmsgIdIPv4TTL
controlMessageId = CmsgIdIPv4TTL

----------------------------------------------------------------

-- | Hop limit of IPv6.
newtype IPv6HopLimit = IPv6HopLimit CInt deriving (Eq, Show, Storable)

instance ControlMessage IPv6HopLimit where
controlMessageId _ = CmsgIdIPv6HopLimit
controlMessageId = CmsgIdIPv6HopLimit

----------------------------------------------------------------

-- | TOS of IPv4.
newtype IPv4TOS = IPv4TOS CChar deriving (Eq, Show, Storable)

instance ControlMessage IPv4TOS where
controlMessageId _ = CmsgIdIPv4TOS
controlMessageId = CmsgIdIPv4TOS

----------------------------------------------------------------

-- | Traffic class of IPv6.
newtype IPv6TClass = IPv6TClass CInt deriving (Eq, Show, Storable)

instance ControlMessage IPv6TClass where
controlMessageId _ = CmsgIdIPv6TClass
controlMessageId = CmsgIdIPv6TClass

----------------------------------------------------------------

Expand All @@ -152,7 +157,7 @@ instance Show IPv4PktInfo where
show (IPv4PktInfo n sa ha) = "IPv4PktInfo " ++ show n ++ " " ++ show (hostAddressToTuple sa) ++ " " ++ show (hostAddressToTuple ha)

instance ControlMessage IPv4PktInfo where
controlMessageId _ = CmsgIdIPv4PktInfo
controlMessageId = CmsgIdIPv4PktInfo

instance Storable IPv4PktInfo where
sizeOf _ = (#size struct in_pktinfo)
Expand All @@ -176,7 +181,7 @@ instance Show IPv6PktInfo where
show (IPv6PktInfo n ha6) = "IPv6PktInfo " ++ show n ++ " " ++ show (hostAddress6ToTuple ha6)

instance ControlMessage IPv6PktInfo where
controlMessageId _ = CmsgIdIPv6PktInfo
controlMessageId = CmsgIdIPv6PktInfo

instance Storable IPv6PktInfo where
sizeOf _ = (#size struct in6_pktinfo)
Expand All @@ -192,4 +197,4 @@ instance Storable IPv6PktInfo where
----------------------------------------------------------------

instance ControlMessage Fd where
controlMessageId _ = CmsgIdFd
controlMessageId = CmsgIdFd