Skip to content

Commit

Permalink
Merge PR haskell#306.
Browse files Browse the repository at this point in the history
  • Loading branch information
kazu-yamamoto committed Feb 16, 2018
2 parents b9f3463 + b80dbeb commit a4e86b5
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 22 deletions.
50 changes: 28 additions & 22 deletions Network/Socket/Types.hsc
Original file line number Diff line number Diff line change
Expand Up @@ -814,13 +814,20 @@ class SocketAddress sa where
peekSocketAddress :: Ptr sa -> IO sa
pokeSocketAddress :: Ptr a -> sa -> IO ()

-- sizeof(struct sockaddr_storage) which has enough space to contain
-- sockaddr_in, sockaddr_in6 and sockaddr_un.
sockaddrStorageLen :: Int
sockaddrStorageLen = 128

withSocketAddress :: SocketAddress sa => sa -> (Ptr sa -> Int -> IO a) -> IO a
withSocketAddress addr f = do
let sz = sizeOfSocketAddress addr
allocaBytes sz $ \p -> pokeSocketAddress p addr >> f (castPtr p) sz

withNewSocketAddress :: SocketAddress sa => (Ptr sa -> Int -> IO a) -> IO a
withNewSocketAddress f = allocaBytes 128 $ \ptr -> f ptr 128
withNewSocketAddress f = allocaBytes sockaddrStorageLen $ \ptr -> do
zeroMemory ptr $ fromIntegral sockaddrStorageLen
f ptr sockaddrStorageLen

------------------------------------------------------------------------
-- Socket addresses
Expand Down Expand Up @@ -893,14 +900,11 @@ type CSaFamily = (#type sa_family_t)
-- in that the value of the argument /is/ used.
sizeOfSockAddr :: SockAddr -> Int
#if defined(DOMAIN_SOCKET_SUPPORT)
sizeOfSockAddr (SockAddrUnix path) =
case path of
'\0':_ -> (#const sizeof(sa_family_t)) + length path
_ -> #const sizeof(struct sockaddr_un)
sizeOfSockAddr SockAddrUnix{} = #const sizeof(struct sockaddr_un)
#else
sizeOfSockAddr SockAddrUnix{} = error "sizeOfSockAddr: not supported"
sizeOfSockAddr SockAddrUnix{} = error "sizeOfSockAddr: not supported"
#endif
sizeOfSockAddr SockAddrInet{} = #const sizeof(struct sockaddr_in)
sizeOfSockAddr SockAddrInet{} = #const sizeof(struct sockaddr_in)
sizeOfSockAddr SockAddrInet6{} = #const sizeof(struct sockaddr_in6)

-- | Use a 'SockAddr' with a function requiring a pointer to a
Expand All @@ -910,6 +914,17 @@ withSockAddr addr f = do
let sz = sizeOfSockAddr addr
allocaBytes sz $ \p -> pokeSockAddr p addr >> f (castPtr p) sz

-- We cannot bind sun_paths longer than than the space in the sockaddr_un
-- structure, and attempting to do so could overflow the allocated storage
-- space. This constant holds the maximum allowable path length.
--
unixPathMax :: Int
#if defined(DOMAIN_SOCKET_SUPPORT)
unixPathMax = #const sizeof(((struct sockaddr_un *)NULL)->sun_path)
#else
unixPathMax = 0
#endif

-- We can't write an instance of 'Storable' for 'SockAddr' because
-- @sockaddr@ is a sum type of variable size but
-- 'Foreign.Storable.sizeOf' is required to be constant.
Expand All @@ -920,38 +935,29 @@ withSockAddr addr f = do
-- | Write the given 'SockAddr' to the given memory location.
pokeSockAddr :: Ptr a -> SockAddr -> IO ()
#if defined(DOMAIN_SOCKET_SUPPORT)
pokeSockAddr p (SockAddrUnix path) = do
# if defined(darwin_HOST_OS)
zeroMemory p (#const sizeof(struct sockaddr_un))
# else
case path of
('\0':_) -> zeroMemory p (#const sizeof(struct sockaddr_un))
_ -> return ()
# endif
pokeSockAddr p sa@(SockAddrUnix path) = do
when (length path > unixPathMax) $ error "pokeSockAddr: path is too long"
zeroMemory p $ fromIntegral $ sizeOfSockAddr sa
# if defined(HAVE_STRUCT_SOCKADDR_SA_LEN)
(#poke struct sockaddr_un, sun_len) p ((#const sizeof(struct sockaddr_un)) :: Word8)
# endif
(#poke struct sockaddr_un, sun_family) p ((#const AF_UNIX) :: CSaFamily)
let pathC = map castCharToCChar path
poker = case path of ('\0':_) -> pokeArray; _ -> pokeArray0 0
poker ((#ptr struct sockaddr_un, sun_path) p) pathC
-- the buffer is already filled with nulls.
pokeArray ((#ptr struct sockaddr_un, sun_path) p) pathC
#else
pokeSockAddr _ SockAddrUnix{} = error "pokeSockAddr: not supported"
#endif
pokeSockAddr p (SockAddrInet (PortNum port) addr) = do
#if defined(darwin_HOST_OS)
zeroMemory p (#const sizeof(struct sockaddr_in))
#endif
#if defined(HAVE_STRUCT_SOCKADDR_SA_LEN)
(#poke struct sockaddr_in, sin_len) p ((#const sizeof(struct sockaddr_in)) :: Word8)
#endif
(#poke struct sockaddr_in, sin_family) p ((#const AF_INET) :: CSaFamily)
(#poke struct sockaddr_in, sin_port) p port
(#poke struct sockaddr_in, sin_addr) p addr
pokeSockAddr p (SockAddrInet6 (PortNum port) flow addr scope) = do
# if defined(darwin_HOST_OS)
zeroMemory p (#const sizeof(struct sockaddr_in6))
# endif
# if defined(HAVE_STRUCT_SOCKADDR_SA_LEN)
(#poke struct sockaddr_in6, sin6_len) p ((#const sizeof(struct sockaddr_in6)) :: Word8)
# endif
Expand All @@ -968,7 +974,7 @@ peekSockAddr p = do
case family :: CSaFamily of
#if defined(DOMAIN_SOCKET_SUPPORT)
(#const AF_UNIX) -> do
str <- peekCString ((#ptr struct sockaddr_un, sun_path) p)
str <- peekCAString ((#ptr struct sockaddr_un, sun_path) p)
return (SockAddrUnix str)
#endif
(#const AF_INET) -> do
Expand Down
1 change: 1 addition & 0 deletions network.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ test-suite spec
build-depends:
base < 5,
bytestring,
directory,
HUnit,
network,
hspec
Expand Down
45 changes: 45 additions & 0 deletions tests/SimpleSpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import qualified Data.ByteString as S
import qualified Data.ByteString.Char8 as C
import Network.Socket
import Network.Socket.ByteString
import System.Directory
import System.Timeout (timeout)

import Test.Hspec
Expand Down Expand Up @@ -147,6 +148,15 @@ spec = do
`shouldBe` (0x2001, 0x0db8, 0x85a3, 0x0000, 0x0000, 0x8a2e, 0x0370, 0x7334)
#endif

describe "unix sockets" $ do
it "basic unix sockets end-to-end" $ do
when isUnixDomainSocketAvailable $ do
let client sock = send sock testMsg
server (sock, addr) = do
recv sock 1024 `shouldReturn` testMsg
addr `shouldBe` (SockAddrUnix "")
unixTest client server

------------------------------------------------------------------------

serverAddr :: String
Expand All @@ -155,9 +165,44 @@ serverAddr = "127.0.0.1"
testMsg :: ByteString
testMsg = "This is a test message."

unixAddr :: String
unixAddr = "/tmp/network-test"

------------------------------------------------------------------------
-- Test helpers

-- | Establish a connection between client and server and then run
-- 'clientAct' and 'serverAct', in different threads. Both actions
-- get passed a connected 'Socket', used for communicating between
-- client and server. 'unixTest' makes sure that the 'Socket' is
-- closed after the actions have run.
unixTest :: (Socket -> IO a) -> ((Socket, SockAddr) -> IO b) -> IO ()
unixTest clientAct serverAct = do
test clientSetup clientAct serverSetup server
where
clientSetup = do
sock <- socket AF_UNIX Stream defaultProtocol
connect sock (SockAddrUnix unixAddr)
return sock

serverSetup = do
sock <- socket AF_UNIX Stream defaultProtocol
unlink unixAddr -- just in case
bind sock (SockAddrUnix unixAddr)
listen sock 1
return sock

server sock = E.bracket (accept sock) (killClientSock . fst) serverAct

unlink file = do
exist <- doesFileExist file
when exist $ removeFile file

killClientSock sock = do
shutdown sock ShutdownBoth
close sock
unlink unixAddr

-- | Establish a connection between client and server and then run
-- 'clientAct' and 'serverAct', in different threads. Both actions
-- get passed a connected 'Socket', used for communicating between
Expand Down

0 comments on commit a4e86b5

Please sign in to comment.