Skip to content

Commit

Permalink
Add cancellable and AddressFamily-specific name resolution. (#33420)
Browse files Browse the repository at this point in the history
Add AddressFamily-specific name resolution and cancellation support for Windows. Resolves #939
  • Loading branch information
scalablecory authored Oct 26, 2020
1 parent 2b2955e commit b0be1ab
Show file tree
Hide file tree
Showing 14 changed files with 478 additions and 153 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ internal unsafe struct HostEntry
}

[DllImport(Libraries.SystemNative, EntryPoint = "SystemNative_GetHostEntryForName")]
internal static extern unsafe int GetHostEntryForName(string address, HostEntry* entry);
internal static extern unsafe int GetHostEntryForName(string address, System.Net.Sockets.AddressFamily family, HostEntry* entry);

[DllImport(Libraries.SystemNative, EntryPoint = "SystemNative_FreeHostEntry")]
internal static extern unsafe void FreeHostEntry(HostEntry* entry);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ internal static partial class Interop
{
internal static partial class Winsock
{
internal const int WSA_INVALID_HANDLE = 6;
internal const int WSA_E_CANCELLED = 10111;

internal const string GetAddrInfoExCancelFunctionName = "GetAddrInfoExCancel";

internal const int NS_ALL = 0;
Expand All @@ -28,6 +31,9 @@ internal static extern unsafe int GetAddrInfoExW(
[In] delegate* unmanaged<int, int, NativeOverlapped*, void> lpCompletionRoutine,
[Out] IntPtr* lpNameHandle);

[DllImport(Libraries.Ws2_32, ExactSpelling = true)]
internal static extern unsafe int GetAddrInfoExCancel([In] IntPtr* lpHandle);

[DllImport(Libraries.Ws2_32, ExactSpelling = true)]
internal static extern unsafe void FreeAddrInfoExW(AddressInfoEx* pAddrInfo);

Expand Down
159 changes: 82 additions & 77 deletions src/libraries/Native/Unix/System.Native/pal_networking.c
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,80 @@ c_static_assert(offsetof(IOVector, Count) == offsetof(iovec, iov_len));

#define Min(left,right) (((left) < (right)) ? (left) : (right))

static bool TryConvertAddressFamilyPlatformToPal(sa_family_t platformAddressFamily, int32_t* palAddressFamily)
{
assert(palAddressFamily != NULL);

switch (platformAddressFamily)
{
case AF_UNSPEC:
*palAddressFamily = AddressFamily_AF_UNSPEC;
return true;

case AF_UNIX:
*palAddressFamily = AddressFamily_AF_UNIX;
return true;

case AF_INET:
*palAddressFamily = AddressFamily_AF_INET;
return true;

case AF_INET6:
*palAddressFamily = AddressFamily_AF_INET6;
return true;
#ifdef AF_PACKET
case AF_PACKET:
*palAddressFamily = AddressFamily_AF_PACKET;
return true;
#endif
#ifdef AF_CAN
case AF_CAN:
*palAddressFamily = AddressFamily_AF_CAN;
return true;
#endif
default:
*palAddressFamily = platformAddressFamily;
return false;
}
}

static bool TryConvertAddressFamilyPalToPlatform(int32_t palAddressFamily, sa_family_t* platformAddressFamily)
{
assert(platformAddressFamily != NULL);

switch (palAddressFamily)
{
case AddressFamily_AF_UNSPEC:
*platformAddressFamily = AF_UNSPEC;
return true;

case AddressFamily_AF_UNIX:
*platformAddressFamily = AF_UNIX;
return true;

case AddressFamily_AF_INET:
*platformAddressFamily = AF_INET;
return true;

case AddressFamily_AF_INET6:
*platformAddressFamily = AF_INET6;
return true;
#ifdef AF_PACKET
case AddressFamily_AF_PACKET:
*platformAddressFamily = AF_PACKET;
return true;
#endif
#ifdef AF_CAN
case AddressFamily_AF_CAN:
*platformAddressFamily = AF_CAN;
return true;
#endif
default:
*platformAddressFamily = (sa_family_t)palAddressFamily;
return false;
}
}

static void ConvertByteArrayToIn6Addr(struct in6_addr* addr, const uint8_t* buffer, int32_t bufferLength)
{
assert(bufferLength == NUM_BYTES_IN_IPV6_ADDRESS);
Expand Down Expand Up @@ -261,7 +335,7 @@ static int32_t CopySockAddrToIPAddress(sockaddr* addr, sa_family_t family, IPAdd
return -1;
}

int32_t SystemNative_GetHostEntryForName(const uint8_t* address, HostEntry* entry)
int32_t SystemNative_GetHostEntryForName(const uint8_t* address, int32_t addressFamily, HostEntry* entry)
{
if (address == NULL || entry == NULL)
{
Expand All @@ -275,11 +349,16 @@ int32_t SystemNative_GetHostEntryForName(const uint8_t* address, HostEntry* entr
struct ifaddrs* addrs = NULL;
#endif

// Get all address families and the canonical name
sa_family_t platformFamily;
if (!TryConvertAddressFamilyPalToPlatform(addressFamily, &platformFamily))
{
return GetAddrInfoErrorFlags_EAI_FAMILY;
}

struct addrinfo hint;
memset(&hint, 0, sizeof(struct addrinfo));
hint.ai_family = AF_UNSPEC;
hint.ai_flags = AI_CANONNAME;
hint.ai_family = platformFamily;

int result = getaddrinfo((const char*)address, NULL, &hint, &info);
if (result != 0)
Expand Down Expand Up @@ -593,80 +672,6 @@ int32_t SystemNative_GetIPSocketAddressSizes(int32_t* ipv4SocketAddressSize, int
return Error_SUCCESS;
}

static bool TryConvertAddressFamilyPlatformToPal(sa_family_t platformAddressFamily, int32_t* palAddressFamily)
{
assert(palAddressFamily != NULL);

switch (platformAddressFamily)
{
case AF_UNSPEC:
*palAddressFamily = AddressFamily_AF_UNSPEC;
return true;

case AF_UNIX:
*palAddressFamily = AddressFamily_AF_UNIX;
return true;

case AF_INET:
*palAddressFamily = AddressFamily_AF_INET;
return true;

case AF_INET6:
*palAddressFamily = AddressFamily_AF_INET6;
return true;
#ifdef AF_PACKET
case AF_PACKET:
*palAddressFamily = AddressFamily_AF_PACKET;
return true;
#endif
#ifdef AF_CAN
case AF_CAN:
*palAddressFamily = AddressFamily_AF_CAN;
return true;
#endif
default:
*palAddressFamily = platformAddressFamily;
return false;
}
}

static bool TryConvertAddressFamilyPalToPlatform(int32_t palAddressFamily, sa_family_t* platformAddressFamily)
{
assert(platformAddressFamily != NULL);

switch (palAddressFamily)
{
case AddressFamily_AF_UNSPEC:
*platformAddressFamily = AF_UNSPEC;
return true;

case AddressFamily_AF_UNIX:
*platformAddressFamily = AF_UNIX;
return true;

case AddressFamily_AF_INET:
*platformAddressFamily = AF_INET;
return true;

case AddressFamily_AF_INET6:
*platformAddressFamily = AF_INET6;
return true;
#ifdef AF_PACKET
case AddressFamily_AF_PACKET:
*platformAddressFamily = AF_PACKET;
return true;
#endif
#ifdef AF_CAN
case AddressFamily_AF_CAN:
*platformAddressFamily = AF_CAN;
return true;
#endif
default:
*platformAddressFamily = (sa_family_t)palAddressFamily;
return false;
}
}

int32_t SystemNative_GetAddressFamily(const uint8_t* socketAddress, int32_t socketAddressLen, int32_t* addressFamily)
{
if (socketAddress == NULL || addressFamily == NULL || socketAddressLen < 0)
Expand Down
2 changes: 1 addition & 1 deletion src/libraries/Native/Unix/System.Native/pal_networking.h
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ typedef struct
uint32_t Padding; // Pad out to 8-byte alignment
} SocketEvent;

PALEXPORT int32_t SystemNative_GetHostEntryForName(const uint8_t* address, HostEntry* entry);
PALEXPORT int32_t SystemNative_GetHostEntryForName(const uint8_t* address, int32_t addressFamily, HostEntry* entry);

PALEXPORT void SystemNative_FreeHostEntry(HostEntry* entry);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ private static async Task<HttpResponseMessage> SendWithNtAuthAsync(HttpRequestMe
}
else
{
IPHostEntry result = await Dns.GetHostEntryAsync(authUri.IdnHost).ConfigureAwait(false);
IPHostEntry result = await Dns.GetHostEntryAsync(authUri.IdnHost, cancellationToken).ConfigureAwait(false);
hostName = result.HostName;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@ public static partial class Dns
[System.ObsoleteAttribute("EndResolve is obsoleted for this type, please use EndGetHostEntry instead. https://go.microsoft.com/fwlink/?linkid=14202")]
public static System.Net.IPHostEntry EndResolve(System.IAsyncResult asyncResult) { throw null; }
public static System.Net.IPAddress[] GetHostAddresses(string hostNameOrAddress) { throw null; }
public static System.Net.IPAddress[] GetHostAddresses(string hostNameOrAddress, System.Net.Sockets.AddressFamily family) { throw null; }
public static System.Threading.Tasks.Task<System.Net.IPAddress[]> GetHostAddressesAsync(string hostNameOrAddress) { throw null; }
public static System.Threading.Tasks.Task<System.Net.IPAddress[]> GetHostAddressesAsync(string hostNameOrAddress, System.Net.Sockets.AddressFamily family, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public static System.Threading.Tasks.Task<System.Net.IPAddress[]> GetHostAddressesAsync(string hostNameOrAddress, System.Threading.CancellationToken cancellationToken) { throw null; }
[System.ObsoleteAttribute("GetHostByAddress is obsoleted for this type, please use GetHostEntry instead. https://go.microsoft.com/fwlink/?linkid=14202")]
public static System.Net.IPHostEntry GetHostByAddress(System.Net.IPAddress address) { throw null; }
[System.ObsoleteAttribute("GetHostByAddress is obsoleted for this type, please use GetHostEntry instead. https://go.microsoft.com/fwlink/?linkid=14202")]
Expand All @@ -31,8 +34,11 @@ public static partial class Dns
public static System.Net.IPHostEntry GetHostByName(string hostName) { throw null; }
public static System.Net.IPHostEntry GetHostEntry(System.Net.IPAddress address) { throw null; }
public static System.Net.IPHostEntry GetHostEntry(string hostNameOrAddress) { throw null; }
public static System.Net.IPHostEntry GetHostEntry(string hostNameOrAddress, System.Net.Sockets.AddressFamily family) { throw null; }
public static System.Threading.Tasks.Task<System.Net.IPHostEntry> GetHostEntryAsync(System.Net.IPAddress address) { throw null; }
public static System.Threading.Tasks.Task<System.Net.IPHostEntry> GetHostEntryAsync(string hostNameOrAddress) { throw null; }
public static System.Threading.Tasks.Task<System.Net.IPHostEntry> GetHostEntryAsync(string hostNameOrAddress, System.Net.Sockets.AddressFamily family, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public static System.Threading.Tasks.Task<System.Net.IPHostEntry> GetHostEntryAsync(string hostNameOrAddress, System.Threading.CancellationToken cancellationToken) { throw null; }
public static string GetHostName() { throw null; }
[System.ObsoleteAttribute("Resolve is obsoleted for this type, please use GetHostEntry instead. https://go.microsoft.com/fwlink/?linkid=14202")]
public static System.Net.IPHostEntry Resolve(string hostName) { throw null; }
Expand Down
Loading

0 comments on commit b0be1ab

Please sign in to comment.