Skip to content

Commit

Permalink
Merge pull request #554 from kazu-yamamoto/sockpair
Browse files Browse the repository at this point in the history
Sockpair
  • Loading branch information
kazu-yamamoto authored May 22, 2023
2 parents 34f74b4 + 0c2df5c commit e7165d0
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 48 deletions.
4 changes: 0 additions & 4 deletions Network/Socket/Info.hsc
Original file line number Diff line number Diff line change
Expand Up @@ -413,11 +413,7 @@ unpackBits ((k,v):xs) r
-- SockAddr

instance Show SockAddr where
#if defined(DOMAIN_SOCKET_SUPPORT)
showsPrec _ (SockAddrUnix str) = showString str
#else
showsPrec _ SockAddrUnix{} = error "showsPrec: not supported"
#endif
showsPrec _ (SockAddrInet port ha)
= showHostAddress ha
. showString ":"
Expand Down
18 changes: 0 additions & 18 deletions Network/Socket/Types.hsc
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,7 @@ import GHC.IO (IO (..))

import qualified Text.Read as P

#if defined(DOMAIN_SOCKET_SUPPORT)
import Foreign.Marshal.Array
#endif

import Network.Socket.Imports

Expand Down Expand Up @@ -1075,11 +1073,7 @@ isSupportedSockAddr :: SockAddr -> Bool
isSupportedSockAddr addr = case addr of
SockAddrInet{} -> True
SockAddrInet6{} -> True
#if defined(DOMAIN_SOCKET_SUPPORT)
SockAddrUnix{} -> True
#else
SockAddrUnix{} -> False
#endif

instance SocketAddress SockAddr where
sizeOfSocketAddress = sizeOfSockAddr
Expand All @@ -1098,7 +1092,6 @@ type CSaFamily = (#type sa_family_t)
-- 'SockAddr'. This function differs from 'Foreign.Storable.sizeOf'
-- in that the value of the argument /is/ used.
sizeOfSockAddr :: SockAddr -> Int
#if defined(DOMAIN_SOCKET_SUPPORT)
# ifdef linux_HOST_OS
-- http://man7.org/linux/man-pages/man7/unix.7.html says:
-- "an abstract socket address is distinguished (from a
Expand All @@ -1118,9 +1111,6 @@ sizeOfSockAddr (SockAddrUnix path) =
# else
sizeOfSockAddr SockAddrUnix{} = #const sizeof(struct sockaddr_un)
# endif
#else
sizeOfSockAddr SockAddrUnix{} = error "sizeOfSockAddr: not supported"
#endif
sizeOfSockAddr SockAddrInet{} = #const sizeof(struct sockaddr_in)
sizeOfSockAddr SockAddrInet6{} = #const sizeof(struct sockaddr_in6)

Expand All @@ -1135,10 +1125,8 @@ withSockAddr addr f = do
-- structure, and attempting to do so could overflow the allocated storage
-- space. This constant holds the maximum allowable path length.
--
#if defined(DOMAIN_SOCKET_SUPPORT)
unixPathMax :: Int
unixPathMax = #const sizeof(((struct sockaddr_un *)NULL)->sun_path)
#endif

-- We can't write an instance of 'Storable' for 'SockAddr' because
-- @sockaddr@ is a sum type of variable size but
Expand All @@ -1149,7 +1137,6 @@ unixPathMax = #const sizeof(((struct sockaddr_un *)NULL)->sun_path)

-- | Write the given 'SockAddr' to the given memory location.
pokeSockAddr :: Ptr a -> SockAddr -> IO ()
#if defined(DOMAIN_SOCKET_SUPPORT)
pokeSockAddr p sa@(SockAddrUnix path) = do
when (length path > unixPathMax) $ error
$ "pokeSockAddr: path is too long in SockAddrUnix " <> show path
Expand All @@ -1162,9 +1149,6 @@ pokeSockAddr p sa@(SockAddrUnix path) = do
let pathC = map castCharToCChar path
-- the buffer is already filled with nulls.
pokeArray ((#ptr struct sockaddr_un, sun_path) p) pathC
#else
pokeSockAddr _ SockAddrUnix{} = error "pokeSockAddr: not supported"
#endif
pokeSockAddr p (SockAddrInet port addr) = do
zeroMemory p (#const sizeof(struct sockaddr_in))
#if defined(HAVE_STRUCT_SOCKADDR_SA_LEN)
Expand All @@ -1189,11 +1173,9 @@ peekSockAddr :: Ptr SockAddr -> IO SockAddr
peekSockAddr p = do
family <- (#peek struct sockaddr, sa_family) p
case family :: CSaFamily of
#if defined(DOMAIN_SOCKET_SUPPORT)
(#const AF_UNIX) -> do
str <- peekCAString ((#ptr struct sockaddr_un, sun_path) p)
return (SockAddrUnix str)
#endif
(#const AF_INET) -> do
addr <- (#peek struct sockaddr_in, sin_addr) p
port <- (#peek struct sockaddr_in, sin_port) p
Expand Down
51 changes: 27 additions & 24 deletions Network/Socket/Unix.hsc
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
{-# LANGUAGE CPP #-}
{-# LANGUAGE ScopedTypeVariables #-}

#include "HsNet.h"
##include "HsNetDef.h"
Expand All @@ -13,30 +14,32 @@ module Network.Socket.Unix (
, getPeerEid
) where

import System.Posix.Types (Fd(..))

import Foreign.Marshal.Alloc (allocaBytes)
import Network.Socket.Buffer
import Network.Socket.Fcntl
import Network.Socket.Imports
import Network.Socket.Types
import System.Posix.Types (Fd(..))

#if defined(mingw32_HOST_OS)
import Network.Socket.Syscall
import Network.Socket.Win32.Cmsg
import System.Directory
import System.IO
import System.IO.Temp
#else
import Foreign.Marshal.Array (peekArray)
import Network.Socket.Internal
import Network.Socket.Posix.Cmsg
#endif
import Network.Socket.Types

#if defined(HAVE_GETPEEREID)
import System.IO.Error (catchIOError)
#endif
#ifdef HAVE_GETPEEREID
import Foreign.Marshal.Alloc (alloca)
#endif
#ifdef DOMAIN_SOCKET_SUPPORT
import Foreign.Marshal.Alloc (allocaBytes)
import Foreign.Marshal.Array (peekArray)

import Network.Socket.Fcntl
import Network.Socket.Internal
#endif
#ifdef HAVE_STRUCT_UCRED_SO_PEERCRED
import Network.Socket.Options
#endif
Expand Down Expand Up @@ -126,11 +129,7 @@ getPeerEid _ = return (0, 0)
--
-- Since 2.7.0.0.
isUnixDomainSocketAvailable :: Bool
#if defined(DOMAIN_SOCKET_SUPPORT)
isUnixDomainSocketAvailable = True
#else
isUnixDomainSocketAvailable = False
#endif

data NullSockAddr = NullSockAddr

Expand All @@ -143,33 +142,25 @@ instance SocketAddress NullSockAddr where
-- Use this function in the case where 'isUnixDomainSocketAvailable' is
-- 'True'.
sendFd :: Socket -> CInt -> IO ()
#if defined(DOMAIN_SOCKET_SUPPORT)
sendFd s outfd = void $ allocaBytes dummyBufSize $ \buf -> do
let cmsg = encodeCmsg $ Fd outfd
sendBufMsg s NullSockAddr [(buf,dummyBufSize)] [cmsg] mempty
where
dummyBufSize = 1
#else
sendFd _ _ = error "Network.Socket.sendFd"
#endif

-- | Receive a file descriptor over a UNIX-domain socket. Note that the resulting
-- file descriptor may have to be put into non-blocking mode in order to be
-- used safely. See 'setNonBlockIfNeeded'.
-- Use this function in the case where 'isUnixDomainSocketAvailable' is
-- 'True'.
recvFd :: Socket -> IO CInt
#if defined(DOMAIN_SOCKET_SUPPORT)
recvFd s = allocaBytes dummyBufSize $ \buf -> do
(NullSockAddr, _, cmsgs, _) <- recvBufMsg s [(buf,dummyBufSize)] 32 mempty
case (lookupCmsg CmsgIdFd cmsgs >>= decodeCmsg) :: Maybe Fd of
Nothing -> return (-1)
Just (Fd fd) -> return fd
where
dummyBufSize = 16
#else
recvFd _ = error "Network.Socket.recvFd"
#endif

-- | Build a pair of connected socket objects.
-- For portability, use this function in the case
Expand All @@ -179,7 +170,21 @@ socketPair :: Family -- Family Name (usually AF_UNIX)
-> SocketType -- Socket Type (usually Stream)
-> ProtocolNumber -- Protocol Number
-> IO (Socket, Socket) -- unnamed and connected.
#if defined(DOMAIN_SOCKET_SUPPORT)
#if defined(mingw32_HOST_OS)
socketPair _ _ _ = withSystemTempFile "temp-for-pair" $ \file hdl -> do
hClose hdl
removeFile file
listenSock <- socket AF_UNIX Stream defaultProtocol
bind listenSock $ SockAddrUnix file
listen listenSock 10
clientSock <- socket AF_UNIX Stream defaultProtocol
connect clientSock $ SockAddrUnix file
(serverSock, _ :: SockAddr) <- accept listenSock
close listenSock
withFdSocket clientSock setNonBlockIfNeeded
withFdSocket serverSock setNonBlockIfNeeded
return (clientSock, serverSock)
#else
socketPair family stype protocol =
allocaBytes (2 * sizeOf (1 :: CInt)) $ \ fdArr -> do
let c_stype = packSocketType stype
Expand All @@ -194,6 +199,4 @@ socketPair family stype protocol =

foreign import ccall unsafe "socketpair"
c_socketpair :: CInt -> CInt -> CInt -> Ptr CInt -> IO CInt
#else
socketPair _ _ _ = error "Network.Socket.socketPair"
#endif
2 changes: 0 additions & 2 deletions include/HsNetDef.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
#undef PACKAGE_TARNAME
#undef PACKAGE_VERSION

#define DOMAIN_SOCKET_SUPPORT 1

#if defined(HAVE_STRUCT_UCRED) && HAVE_DECL_SO_PEERCRED
# define HAVE_STRUCT_UCRED_SO_PEERCRED 1
#else
Expand Down
3 changes: 3 additions & 0 deletions network.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,9 @@ library
cpp-options: -D_WIN32_WINNT=0x0600
cc-options: -D_WIN32_WINNT=0x0600

build-depends:
temporary

test-suite spec
type: exitcode-stdio-1.0
main-is: Spec.hs
Expand Down

0 comments on commit e7165d0

Please sign in to comment.