Skip to content

Commit

Permalink
Implement socket endpoints, in particular reading a string descriptio…
Browse files Browse the repository at this point in the history
…n of them

Useful for 3rd-party networking applications that might want to pass around
service specifiers without worrying whether these are IP addresses, DNS names,
or UNIX-domain socket paths.

Previously, there was no data type to encapsulate these options together. In
particular, getAddrInfo had to be used to resolve DNS names into a SockAddr
before calling connect/bind, but it could not deal with UNIX domain sockets.
The new function sockNameToAddr takes this role, transparently converting DNS
names and passing through non-DNS-names unaltered, so that it can be used
uniformly without worrying about the specific type of input name/address.
  • Loading branch information
infinity0 committed May 27, 2020
1 parent 29c11bf commit f29aa06
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 2 deletions.
7 changes: 7 additions & 0 deletions Network/Socket.hs
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ module Network.Socket
, Socket
, socket
, openSocket
, socketFromEndpoint
, withFdSocket
, unsafeFdSocket
, touchSocket
Expand Down Expand Up @@ -182,8 +183,14 @@ module Network.Socket
-- ** Protocol number
, ProtocolNumber
, defaultProtocol
-- * Basic socket endpoint type
, SockEndpoint(..)
, readSockEndpoint
, showSockEndpoint
, resolveEndpoint
-- * Basic socket address type
, SockAddr(..)
, sockAddrFamily
, isSupportedSockAddr
, getPeerName
, getSocketName
Expand Down
71 changes: 69 additions & 2 deletions Network/Socket/Info.hsc
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,17 @@

module Network.Socket.Info where

import Control.Exception (try, IOException)
import Foreign.Marshal.Alloc (alloca, allocaBytes)
import Foreign.Marshal.Utils (maybeWith, with)
import GHC.IO.Exception (IOErrorType(NoSuchThing))
import System.IO.Error (ioeSetErrorString, mkIOError)
import System.IO.Unsafe (unsafePerformIO)
import Text.Read (readEither)

import Network.Socket.Imports
import Network.Socket.Internal
import Network.Socket.Syscall
import Network.Socket.Syscall (socket)
import Network.Socket.Types

-----------------------------------------------------------------------------
Expand Down Expand Up @@ -467,10 +470,74 @@ showHostAddress6 ha6@(a1, a2, a3, a4)
scanl (\c i -> if i == 0 then c - 1 else 0) 0 fields `zip` [0..]

-----------------------------------------------------------------------------

-- | A utility function to open a socket with `AddrInfo`.
-- This is a just wrapper for the following code:
--
-- > \addr -> socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr)
openSocket :: AddrInfo -> IO Socket
openSocket addr = socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr)

-----------------------------------------------------------------------------
-- SockEndpoint

-- | Read a string representing a socket endpoint.
readSockEndpoint :: PortNumber -> String -> Either String SockEndpoint
readSockEndpoint defPort hostport = case hostport of
'/':_ -> Right $ EndpointByAddr $ SockAddrUnix hostport
'[':tl -> case span ((/=) ']') tl of
(_, []) -> Left $ "unterminated IPv6 address: " <> hostport
(ipv6, _:port) -> case readAddr ipv6 of
Nothing -> Left $ "invalid IPv6 address: " <> ipv6
Just addr -> EndpointByAddr . sockAddrPort addr <$> readPort port
_ -> case span ((/=) ':') hostport of
(host, port) -> case readAddr host of
Nothing -> EndpointByName host <$> readPort port
Just addr -> EndpointByAddr . sockAddrPort addr <$> readPort port
where
readPort "" = Right defPort
readPort ":" = Right defPort
readPort (':':port) = case readEither port of
Right p -> Right p
Left _ -> Left $ "bad port: " <> port
readPort x = Left $ "bad port: " <> x
hints = Just $ defaultHints { addrFlags = [AI_NUMERICHOST] }
readAddr host = case unsafePerformIO (try (getAddrInfo hints (Just host) Nothing)) of
Left e -> Nothing where _ = e :: IOException
Right r -> Just (addrAddress (head r))
sockAddrPort h p = case h of
SockAddrInet _ a -> SockAddrInet p a
SockAddrInet6 _ f a s -> SockAddrInet6 p f a s
x -> x

showSockEndpoint :: SockEndpoint -> String
showSockEndpoint n = case n of
EndpointByName h p -> h <> ":" <> show p
EndpointByAddr a -> show a

-- | Resolve a socket endpoint into a list of socket addresses.
-- The result is always non-empty; Haskell throws an exception if name
-- resolution fails.
resolveEndpoint :: SockEndpoint -> IO [SockAddr]
resolveEndpoint name = case name of
EndpointByAddr a -> pure [a]
EndpointByName host port -> fmap addrAddress <$> getAddrInfo hints (Just host) (Just (show port))
where
hints = Just $ defaultHints { addrSocketType = Stream }
-- prevents duplicates, otherwise getAddrInfo returns all socket types

-- | Shortcut for creating a socket from a socket endpoint.
--
-- >>> import Network.Socket
-- >>> let Right sn = readSockEndpoint 0 "0.0.0.0:0"
-- >>> (s, a) <- socketFromEndpoint sn head Stream defaultProtocol
-- >>> bind s a
socketFromEndpoint
:: SockEndpoint
-> ([SockAddr] -> SockAddr)
-> SocketType
-> ProtocolNumber
-> IO (Socket, SockAddr)
socketFromEndpoint end select stype protocol = do
a <- select <$> resolveEndpoint end
s <- socket (sockAddrFamily a) stype protocol
pure (s, a)
25 changes: 25 additions & 0 deletions Network/Socket/Types.hsc
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ module Network.Socket.Types (
, withNewSocketAddress

-- * Socket address type
, SockEndpoint(..)
, SockAddr(..)
, sockAddrFamily
, isSupportedSockAddr
, HostAddress
, hostAddressToTuple
Expand Down Expand Up @@ -1041,6 +1043,23 @@ type FlowInfo = Word32
-- | Scope identifier.
type ScopeID = Word32

-- | Socket endpoints.
--
-- A wrapper around socket addresses that also accommodates the
-- popular usage of specifying them by name, e.g. "example.com:80".
-- We don't support service names here (string aliases for port
-- numbers) because they also imply a particular socket type, which
-- is outside of the scope of this data type.
--
-- This roughly corresponds to the "authority" part of a URI, as
-- defined here: https://tools.ietf.org/html/rfc3986#section-3.2
--
-- See also 'Network.Socket.socketFromEndpoint'.
data SockEndpoint
= EndpointByName !String !PortNumber
| EndpointByAddr !SockAddr
deriving (Eq, Ord)

-- | Socket addresses.
-- The existence of a constructor does not necessarily imply that
-- that socket address type is supported on your system: see
Expand All @@ -1064,6 +1083,12 @@ instance NFData SockAddr where
rnf (SockAddrInet6 _ _ _ _) = ()
rnf (SockAddrUnix str) = rnf str

sockAddrFamily :: SockAddr -> Family
sockAddrFamily addr = case addr of
SockAddrInet _ _ -> AF_INET
SockAddrInet6 _ _ _ _ -> AF_INET6
SockAddrUnix _ -> AF_UNIX

-- | Is the socket address type supported on this system?
isSupportedSockAddr :: SockAddr -> Bool
isSupportedSockAddr addr = case addr of
Expand Down
21 changes: 21 additions & 0 deletions tests/Network/SocketSpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,27 @@ spec = do
-- check if an exception is not thrown.
isSupportedSockAddr addr `shouldBe` True

it "endpoints API, IPv4" $ do
let Right end = readSockEndpoint 0 "127.0.0.1:6001"
(sock, addr) <- socketFromEndpoint end head Stream defaultProtocol
bind sock addr
listen sock 1
close sock

it "endpoints API, IPv6" $ do
let Right end = readSockEndpoint 0 "[::1]:6001"
(sock, addr) <- socketFromEndpoint end head Stream defaultProtocol
bind sock addr
listen sock 1
close sock

it "endpoints API, DNS" $ do
let Right end = readSockEndpoint 0 "localhost:6001"
(sock, addr) <- socketFromEndpoint end head Stream defaultProtocol
bind sock addr
listen sock 1
close sock

#if !defined(mingw32_HOST_OS)
when isUnixDomainSocketAvailable $ do
context "unix sockets" $ do
Expand Down

0 comments on commit f29aa06

Please sign in to comment.