diff --git a/Network/Socket/Types.hsc b/Network/Socket/Types.hsc index a6a27d03..4b0e800b 100644 --- a/Network/Socket/Types.hsc +++ b/Network/Socket/Types.hsc @@ -814,13 +814,20 @@ class SocketAddress sa where peekSocketAddress :: Ptr sa -> IO sa pokeSocketAddress :: Ptr a -> sa -> IO () +-- sizeof(struct sockaddr_storage) which has enough space to contain +-- sockaddr_in, sockaddr_in6 and sockaddr_un. +sockaddrStorageLen :: Int +sockaddrStorageLen = 128 + withSocketAddress :: SocketAddress sa => sa -> (Ptr sa -> Int -> IO a) -> IO a withSocketAddress addr f = do let sz = sizeOfSocketAddress addr allocaBytes sz $ \p -> pokeSocketAddress p addr >> f (castPtr p) sz withNewSocketAddress :: SocketAddress sa => (Ptr sa -> Int -> IO a) -> IO a -withNewSocketAddress f = allocaBytes 128 $ \ptr -> f ptr 128 +withNewSocketAddress f = allocaBytes sockaddrStorageLen $ \ptr -> do + zeroMemory ptr $ fromIntegral sockaddrStorageLen + f ptr sockaddrStorageLen ------------------------------------------------------------------------ -- Socket addresses @@ -893,14 +900,11 @@ type CSaFamily = (#type sa_family_t) -- in that the value of the argument /is/ used. sizeOfSockAddr :: SockAddr -> Int #if defined(DOMAIN_SOCKET_SUPPORT) -sizeOfSockAddr (SockAddrUnix path) = - case path of - '\0':_ -> (#const sizeof(sa_family_t)) + length path - _ -> #const sizeof(struct sockaddr_un) +sizeOfSockAddr SockAddrUnix{} = #const sizeof(struct sockaddr_un) #else -sizeOfSockAddr SockAddrUnix{} = error "sizeOfSockAddr: not supported" +sizeOfSockAddr SockAddrUnix{} = error "sizeOfSockAddr: not supported" #endif -sizeOfSockAddr SockAddrInet{} = #const sizeof(struct sockaddr_in) +sizeOfSockAddr SockAddrInet{} = #const sizeof(struct sockaddr_in) sizeOfSockAddr SockAddrInet6{} = #const sizeof(struct sockaddr_in6) -- | Use a 'SockAddr' with a function requiring a pointer to a @@ -910,6 +914,17 @@ withSockAddr addr f = do let sz = sizeOfSockAddr addr allocaBytes sz $ \p -> pokeSockAddr p addr >> f (castPtr p) sz +-- We cannot bind sun_paths longer than than the space in the sockaddr_un +-- structure, and attempting to do so could overflow the allocated storage +-- space. This constant holds the maximum allowable path length. +-- +unixPathMax :: Int +#if defined(DOMAIN_SOCKET_SUPPORT) +unixPathMax = #const sizeof(((struct sockaddr_un *)NULL)->sun_path) +#else +unixPathMax = 0 +#endif + -- We can't write an instance of 'Storable' for 'SockAddr' because -- @sockaddr@ is a sum type of variable size but -- 'Foreign.Storable.sizeOf' is required to be constant. @@ -920,28 +935,21 @@ withSockAddr addr f = do -- | Write the given 'SockAddr' to the given memory location. pokeSockAddr :: Ptr a -> SockAddr -> IO () #if defined(DOMAIN_SOCKET_SUPPORT) -pokeSockAddr p (SockAddrUnix path) = do -# if defined(darwin_HOST_OS) - zeroMemory p (#const sizeof(struct sockaddr_un)) -# else - case path of - ('\0':_) -> zeroMemory p (#const sizeof(struct sockaddr_un)) - _ -> return () -# endif +pokeSockAddr p sa@(SockAddrUnix path) = do + when (length path > unixPathMax) $ error "pokeSockAddr: path is too long" + zeroMemory p $ fromIntegral $ sizeOfSockAddr sa # if defined(HAVE_STRUCT_SOCKADDR_SA_LEN) (#poke struct sockaddr_un, sun_len) p ((#const sizeof(struct sockaddr_un)) :: Word8) # endif (#poke struct sockaddr_un, sun_family) p ((#const AF_UNIX) :: CSaFamily) let pathC = map castCharToCChar path - poker = case path of ('\0':_) -> pokeArray; _ -> pokeArray0 0 - poker ((#ptr struct sockaddr_un, sun_path) p) pathC + -- the buffer is already filled with nulls. + pokeArray ((#ptr struct sockaddr_un, sun_path) p) pathC #else pokeSockAddr _ SockAddrUnix{} = error "pokeSockAddr: not supported" #endif pokeSockAddr p (SockAddrInet (PortNum port) addr) = do -#if defined(darwin_HOST_OS) zeroMemory p (#const sizeof(struct sockaddr_in)) -#endif #if defined(HAVE_STRUCT_SOCKADDR_SA_LEN) (#poke struct sockaddr_in, sin_len) p ((#const sizeof(struct sockaddr_in)) :: Word8) #endif @@ -949,9 +957,7 @@ pokeSockAddr p (SockAddrInet (PortNum port) addr) = do (#poke struct sockaddr_in, sin_port) p port (#poke struct sockaddr_in, sin_addr) p addr pokeSockAddr p (SockAddrInet6 (PortNum port) flow addr scope) = do -# if defined(darwin_HOST_OS) zeroMemory p (#const sizeof(struct sockaddr_in6)) -# endif # if defined(HAVE_STRUCT_SOCKADDR_SA_LEN) (#poke struct sockaddr_in6, sin6_len) p ((#const sizeof(struct sockaddr_in6)) :: Word8) # endif @@ -968,7 +974,7 @@ peekSockAddr p = do case family :: CSaFamily of #if defined(DOMAIN_SOCKET_SUPPORT) (#const AF_UNIX) -> do - str <- peekCString ((#ptr struct sockaddr_un, sun_path) p) + str <- peekCAString ((#ptr struct sockaddr_un, sun_path) p) return (SockAddrUnix str) #endif (#const AF_INET) -> do diff --git a/network.cabal b/network.cabal index a1c9894d..9b0cec5b 100644 --- a/network.cabal +++ b/network.cabal @@ -97,6 +97,7 @@ test-suite spec build-depends: base < 5, bytestring, + directory, HUnit, network, hspec diff --git a/tests/SimpleSpec.hs b/tests/SimpleSpec.hs index c54502f6..3bb8836a 100644 --- a/tests/SimpleSpec.hs +++ b/tests/SimpleSpec.hs @@ -13,6 +13,7 @@ import qualified Data.ByteString as S import qualified Data.ByteString.Char8 as C import Network.Socket import Network.Socket.ByteString +import System.Directory import System.Timeout (timeout) import Test.Hspec @@ -147,6 +148,15 @@ spec = do `shouldBe` (0x2001, 0x0db8, 0x85a3, 0x0000, 0x0000, 0x8a2e, 0x0370, 0x7334) #endif + describe "unix sockets" $ do + it "basic unix sockets end-to-end" $ do + when isUnixDomainSocketAvailable $ do + let client sock = send sock testMsg + server (sock, addr) = do + recv sock 1024 `shouldReturn` testMsg + addr `shouldBe` (SockAddrUnix "") + unixTest client server + ------------------------------------------------------------------------ serverAddr :: String @@ -155,9 +165,44 @@ serverAddr = "127.0.0.1" testMsg :: ByteString testMsg = "This is a test message." +unixAddr :: String +unixAddr = "/tmp/network-test" + ------------------------------------------------------------------------ -- Test helpers +-- | Establish a connection between client and server and then run +-- 'clientAct' and 'serverAct', in different threads. Both actions +-- get passed a connected 'Socket', used for communicating between +-- client and server. 'unixTest' makes sure that the 'Socket' is +-- closed after the actions have run. +unixTest :: (Socket -> IO a) -> ((Socket, SockAddr) -> IO b) -> IO () +unixTest clientAct serverAct = do + test clientSetup clientAct serverSetup server + where + clientSetup = do + sock <- socket AF_UNIX Stream defaultProtocol + connect sock (SockAddrUnix unixAddr) + return sock + + serverSetup = do + sock <- socket AF_UNIX Stream defaultProtocol + unlink unixAddr -- just in case + bind sock (SockAddrUnix unixAddr) + listen sock 1 + return sock + + server sock = E.bracket (accept sock) (killClientSock . fst) serverAct + + unlink file = do + exist <- doesFileExist file + when exist $ removeFile file + + killClientSock sock = do + shutdown sock ShutdownBoth + close sock + unlink unixAddr + -- | Establish a connection between client and server and then run -- 'clientAct' and 'serverAct', in different threads. Both actions -- get passed a connected 'Socket', used for communicating between