From 60b627a44c996dad67246c2d79620b929ab0c6d0 Mon Sep 17 00:00:00 2001 From: Nick Banks Date: Mon, 14 Jun 2021 13:21:23 -0700 Subject: [PATCH] Add Load Balancing Tests (#1707) --- src/core/binding.c | 84 +++++++++-- src/core/binding.h | 20 +++ src/core/connection.c | 8 +- src/core/library.c | 6 +- src/core/listener.c | 4 +- src/inc/msquic.h | 3 + src/inc/msquic.hpp | 230 +++++++++++++++++++---------- src/inc/msquicp.h | 18 +++ src/manifest/clog.sidecar | 188 +++++------------------ src/test/MsQuicTests.h | 11 +- src/test/bin/quic_gtest.cpp | 11 ++ src/test/bin/winkernel/control.cpp | 8 +- src/test/lib/HandshakeTest.cpp | 123 +++++++++++++++ src/test/lib/TestHelpers.h | 195 ++++++++++++++++++++++++ src/tools/spin/spinquic.cpp | 14 +- 15 files changed, 666 insertions(+), 257 deletions(-) diff --git a/src/core/binding.c b/src/core/binding.c index fc621da417..2c22c1ede6 100644 --- a/src/core/binding.c +++ b/src/core/binding.c @@ -125,14 +125,42 @@ QuicBindingInitialize( } #endif - Status = - CxPlatSocketCreateUdp( - MsQuicLib.Datapath, - LocalAddress, - RemoteAddress, - Binding, - 0, - &Binding->Socket); +#if QUIC_TEST_DATAPATH_HOOKS_ENABLED + QUIC_TEST_DATAPATH_HOOKS* Hooks = MsQuicLib.TestDatapathHooks; + if (Hooks != NULL) { + QUIC_ADDR RemoteAddressCopy; + if (RemoteAddress != NULL) { + RemoteAddressCopy = *RemoteAddress; + } + QUIC_ADDR LocalAddressCopy; + if (LocalAddress != NULL) { + LocalAddressCopy = *LocalAddress; + } + Hooks->Create( + RemoteAddress != NULL ? &RemoteAddressCopy : NULL, + LocalAddress != NULL ? &LocalAddressCopy : NULL); + + Status = + CxPlatSocketCreateUdp( + MsQuicLib.Datapath, + LocalAddress != NULL ? &LocalAddressCopy : NULL, + RemoteAddress != NULL ? &RemoteAddressCopy : NULL, + Binding, + 0, + &Binding->Socket); + } else { +#endif + Status = + CxPlatSocketCreateUdp( + MsQuicLib.Datapath, + LocalAddress, + RemoteAddress, + Binding, + 0, + &Binding->Socket); +#if QUIC_TEST_DATAPATH_HOOKS_ENABLED + } +#endif #ifdef QUIC_COMPARTMENT_ID if (RevertCompartmentId) { @@ -151,8 +179,8 @@ QuicBindingInitialize( } QUIC_ADDR DatapathLocalAddr, DatapathRemoteAddr; - CxPlatSocketGetLocalAddress(Binding->Socket, &DatapathLocalAddr); - CxPlatSocketGetRemoteAddress(Binding->Socket, &DatapathRemoteAddr); + QuicBindingGetLocalAddress(Binding, &DatapathLocalAddr); + QuicBindingGetRemoteAddress(Binding, &DatapathRemoteAddr); QuicTraceEvent( BindingCreated, "[bind][%p] Created, Udp=%p LocalAddr=%!ADDR! RemoteAddr=%!ADDR!", @@ -248,8 +276,8 @@ QuicBindingTraceRundown( // TODO - Trace datapath binding QUIC_ADDR DatapathLocalAddr, DatapathRemoteAddr; - CxPlatSocketGetLocalAddress(Binding->Socket, &DatapathLocalAddr); - CxPlatSocketGetRemoteAddress(Binding->Socket, &DatapathRemoteAddr); + QuicBindingGetLocalAddress(Binding, &DatapathLocalAddr); + QuicBindingGetRemoteAddress(Binding, &DatapathRemoteAddr); QuicTraceEvent( BindingRundown, "[bind][%p] Rundown, Udp=%p LocalAddr=%!ADDR! RemoteAddr=%!ADDR!", @@ -270,6 +298,38 @@ QuicBindingTraceRundown( CxPlatDispatchRwLockReleaseShared(&Binding->RwLock); } +_IRQL_requires_max_(DISPATCH_LEVEL) +void +QuicBindingGetLocalAddress( + _In_ QUIC_BINDING* Binding, + _Out_ QUIC_ADDR* Address + ) +{ + CxPlatSocketGetLocalAddress(Binding->Socket, Address); +#if QUIC_TEST_DATAPATH_HOOKS_ENABLED + QUIC_TEST_DATAPATH_HOOKS* Hooks = MsQuicLib.TestDatapathHooks; + if (Hooks != NULL) { + Hooks->GetLocalAddress(Address); + } +#endif +} + +_IRQL_requires_max_(DISPATCH_LEVEL) +void +QuicBindingGetRemoteAddress( + _In_ QUIC_BINDING* Binding, + _Out_ QUIC_ADDR* Address + ) +{ + CxPlatSocketGetRemoteAddress(Binding->Socket, Address); +#if QUIC_TEST_DATAPATH_HOOKS_ENABLED + QUIC_TEST_DATAPATH_HOOKS* Hooks = MsQuicLib.TestDatapathHooks; + if (Hooks != NULL) { + Hooks->GetRemoteAddress(Address); + } +#endif +} + // // Returns TRUE if there are any registered listeners on this binding. // diff --git a/src/core/binding.h b/src/core/binding.h index 2d63763bf3..c94105a0c2 100644 --- a/src/core/binding.h +++ b/src/core/binding.h @@ -289,6 +289,26 @@ QuicBindingTraceRundown( _In_ QUIC_BINDING* Binding ); +// +// Queries the local IP address of the binding. +// +_IRQL_requires_max_(DISPATCH_LEVEL) +void +QuicBindingGetLocalAddress( + _In_ QUIC_BINDING* Binding, + _Out_ QUIC_ADDR* Address + ); + +// +// Queries the remote IP address of the binding. +// +_IRQL_requires_max_(DISPATCH_LEVEL) +void +QuicBindingGetRemoteAddress( + _In_ QUIC_BINDING* Binding, + _Out_ QUIC_ADDR* Address + ); + // // Looks up the listener based on the ALPN list. Optionally, outputs the // first ALPN that matches. diff --git a/src/core/connection.c b/src/core/connection.c index d6c8867032..e5bfe3f2f8 100644 --- a/src/core/connection.c +++ b/src/core/connection.c @@ -1868,9 +1868,7 @@ QuicConnStart( } Connection->State.LocalAddressSet = TRUE; - CxPlatSocketGetLocalAddress( - Path->Binding->Socket, - &Path->LocalAddress); + QuicBindingGetLocalAddress(Path->Binding, &Path->LocalAddress); QuicTraceEvent( ConnLocalAddrAdded, "[conn][%p] New Local IP: %!ADDR!", @@ -5675,8 +5673,8 @@ QuicConnParamSet( Connection, CASTED_CLOG_BYTEARRAY(sizeof(Connection->Paths[0].LocalAddress), &Connection->Paths[0].LocalAddress)); - CxPlatSocketGetLocalAddress( - Connection->Paths[0].Binding->Socket, + QuicBindingGetLocalAddress( + Connection->Paths[0].Binding, &Connection->Paths[0].LocalAddress); QuicTraceEvent( diff --git a/src/core/library.c b/src/core/library.c index acab9603f2..c5109bfee8 100644 --- a/src/core/library.c +++ b/src/core/library.c @@ -1390,7 +1390,7 @@ QuicLibraryLookupBinding( #endif QUIC_ADDR BindingLocalAddr; - CxPlatSocketGetLocalAddress(Binding->Socket, &BindingLocalAddr); + QuicBindingGetLocalAddress(Binding, &BindingLocalAddr); if (!QuicAddrCompare(LocalAddress, &BindingLocalAddr)) { continue; @@ -1402,7 +1402,7 @@ QuicLibraryLookupBinding( } QUIC_ADDR BindingRemoteAddr; - CxPlatSocketGetRemoteAddress(Binding->Socket, &BindingRemoteAddr); + QuicBindingGetRemoteAddress(Binding, &BindingRemoteAddr); if (!QuicAddrCompare(RemoteAddress, &BindingRemoteAddr)) { continue; } @@ -1507,7 +1507,7 @@ QuicLibraryGetBinding( goto Exit; } - CxPlatSocketGetLocalAddress((*NewBinding)->Socket, &NewLocalAddress); + QuicBindingGetLocalAddress(*NewBinding, &NewLocalAddress); CxPlatDispatchLockAcquire(&MsQuicLib.DatapathLock); diff --git a/src/core/listener.c b/src/core/listener.c index 5ccf25f1b4..816d003310 100644 --- a/src/core/listener.c +++ b/src/core/listener.c @@ -293,9 +293,7 @@ MsQuicListenerStart( } if (PortUnspecified) { - CxPlatSocketGetLocalAddress( - Listener->Binding->Socket, - &BindingLocalAddress); + QuicBindingGetLocalAddress(Listener->Binding, &BindingLocalAddress); QuicAddrSetPort( &Listener->LocalAddress, QuicAddrGetPort(&BindingLocalAddress)); diff --git a/src/inc/msquic.h b/src/inc/msquic.h index de5dca7104..b13d291d00 100644 --- a/src/inc/msquic.h +++ b/src/inc/msquic.h @@ -885,6 +885,7 @@ typedef struct QUIC_CONNECTION_EVENT { union { struct { BOOLEAN SessionResumed; + _Field_range_(>, 0) uint8_t NegotiatedAlpnLength; _Field_size_(NegotiatedAlpnLength) const uint8_t* NegotiatedAlpn; @@ -934,7 +935,9 @@ typedef struct QUIC_CONNECTION_EVENT { const uint8_t* ResumptionState; } RESUMED; struct { + _Field_range_(>, 0) uint32_t ResumptionTicketLength; + _Field_size_(ResumptionTicketLength) const uint8_t* ResumptionTicket; } RESUMPTION_TICKET_RECEIVED; struct { diff --git a/src/inc/msquic.hpp b/src/inc/msquic.hpp index 7f870e5b80..b01091863e 100644 --- a/src/inc/msquic.hpp +++ b/src/inc/msquic.hpp @@ -31,6 +31,84 @@ Supported Platforms: #define CXPLAT_DBG_ASSERT(X) // no-op if not already defined #endif +#ifdef CX_PLATFORM_TYPE + +// +// Abstractions for platform specific types/interfaces +// + +struct CxPlatEvent { + CXPLAT_EVENT Handle; + CxPlatEvent() noexcept { CxPlatEventInitialize(&Handle, FALSE, FALSE); } + CxPlatEvent(bool ManualReset) noexcept { CxPlatEventInitialize(&Handle, ManualReset, FALSE); } + CxPlatEvent(CXPLAT_EVENT event) noexcept : Handle(event) { } + ~CxPlatEvent() noexcept { CxPlatEventUninitialize(Handle); } + CXPLAT_EVENT* operator &() noexcept { return &Handle; } + operator CXPLAT_EVENT() const noexcept { return Handle; } + void Set() { CxPlatEventSet(Handle); } + void Reset() { CxPlatEventReset(Handle); } + void WaitForever() { CxPlatEventWaitForever(Handle); } + bool WaitTimeout(uint32_t TimeoutMs) { return CxPlatEventWaitWithTimeout(Handle, TimeoutMs); } +}; + +#ifdef CXPLAT_HASH_MIN_SIZE + +struct HashTable { + bool Initialized; + CXPLAT_HASHTABLE Table; + HashTable() noexcept { Initialized = CxPlatHashtableInitializeEx(&Table, CXPLAT_HASH_MIN_SIZE); } + ~HashTable() noexcept { if (Initialized) { CxPlatHashtableUninitialize(&Table); } } + void Insert(CXPLAT_HASHTABLE_ENTRY* Entry) { CxPlatHashtableInsert(&Table, Entry, Entry->Signature, nullptr); } + void Remove(CXPLAT_HASHTABLE_ENTRY* Entry) { CxPlatHashtableRemove(&Table, Entry, nullptr); } + CXPLAT_HASHTABLE_ENTRY* Lookup(uint64_t Signature) { + CXPLAT_HASHTABLE_LOOKUP_CONTEXT LookupContext; + return CxPlatHashtableLookup(&Table, Signature, &LookupContext); + } + CXPLAT_HASHTABLE_ENTRY* LookupEx(uint64_t Signature, bool (*Equals)(CXPLAT_HASHTABLE_ENTRY* Entry, void* Context), void* Context) { + CXPLAT_HASHTABLE_LOOKUP_CONTEXT LookupContext; + CXPLAT_HASHTABLE_ENTRY* Entry = CxPlatHashtableLookup(&Table, Signature, &LookupContext); + while (Entry != NULL) { + if (Equals(Entry, Context)) return Entry; + Entry = CxPlatHashtableLookupNext(&Table, &LookupContext); + } + return NULL; + } +}; + +#endif // CXPLAT_HASH_MIN_SIZE + +#ifdef CXPLAT_FRE_ASSERT + +class CxPlatWatchdog { + CXPLAT_THREAD WatchdogThread; + CxPlatEvent ShutdownEvent {true}; + uint32_t TimeoutMs; + static CXPLAT_THREAD_CALLBACK(WatchdogThreadCallback, Context) { + auto This = (CxPlatWatchdog*)Context; + if (!This->ShutdownEvent.WaitTimeout(This->TimeoutMs)) { + CXPLAT_FRE_ASSERTMSG(FALSE, "Watchdog timeout fired!"); + } + CXPLAT_THREAD_RETURN(0); + } +public: + CxPlatWatchdog(uint32_t WatchdogTimeoutMs) : TimeoutMs(WatchdogTimeoutMs) { + CXPLAT_THREAD_CONFIG Config = { 0 }; + Config.Name = "cxplat_watchdog"; + Config.Callback = WatchdogThreadCallback; + Config.Context = this; + CXPLAT_FRE_ASSERT(QUIC_SUCCEEDED(CxPlatThreadCreate(&Config, &WatchdogThread))); + } + ~CxPlatWatchdog() { + ShutdownEvent.Set(); + CxPlatThreadWait(&WatchdogThread); + CxPlatThreadDelete(&WatchdogThread); + } +}; + +#endif // CXPLAT_FRE_ASSERT + +#endif // CX_PLATFORM_TYPE + struct QuicAddr { QUIC_ADDR SockAddr; QuicAddr() { @@ -563,6 +641,14 @@ struct MsQuicConnection { MsQuicConnectionCallback* Callback; void* Context; QUIC_STATUS InitStatus; + bool HandshakeComplete {false}; + bool HandshakeResumed {false}; + uint32_t ResumptionTicketLength {0}; + uint8_t* ResumptionTicket {nullptr}; +#ifdef CX_PLATFORM_TYPE + CxPlatEvent HandshakeCompleteEvent; + CxPlatEvent ResumptionTicketReceivedEvent; +#endif // CX_PLATFORM_TYPE MsQuicConnection( _In_ const MsQuicRegistration& Registration, @@ -600,6 +686,7 @@ struct MsQuicConnection { if (Handle) { MsQuic->ConnectionClose(Handle); } + delete[] ResumptionTicket; } void @@ -678,6 +765,28 @@ struct MsQuicConnection { return MsQuic->GetParam(Handle, Level, Param, BufferLength, Buffer); } + QUIC_STATUS + SetLocalAddr(_In_ const QuicAddr& Addr) noexcept { + return + MsQuic->SetParam( + Handle, + QUIC_PARAM_LEVEL_CONNECTION, + QUIC_PARAM_CONN_LOCAL_ADDRESS, + sizeof(Addr.SockAddr), + &Addr.SockAddr); + } + + QUIC_STATUS + SetResumptionTicket(_In_reads_(TicketLength) const uint8_t* Ticket, uint32_t TicketLength) noexcept { + return + MsQuic->SetParam( + Handle, + QUIC_PARAM_LEVEL_CONNECTION, + QUIC_PARAM_CONN_RESUMPTION_TICKET, + TicketLength, + Ticket); + } + QUIC_STATUS SetSettings(_In_ const MsQuicSettings& Settings) noexcept { const QUIC_SETTINGS* QSettings = &Settings; @@ -732,7 +841,28 @@ struct MsQuicConnection { if (Event->Type == QUIC_CONNECTION_EVENT_PEER_STREAM_STARTED) { // // Not great beacuse it doesn't provide an application specific - // error code. If you expect to get streams, you should not be no-op + // error code. If you expect to get streams, you should not no-op + // the callbacks. + // + MsQuic->StreamClose(Event->PEER_STREAM_STARTED.Stream); + } + return QUIC_STATUS_SUCCESS; + } + + static + QUIC_STATUS + QUIC_API + SendResumptionCallback( + _In_ MsQuicConnection* Connection, + _In_opt_ void* /* Context */, + _Inout_ QUIC_CONNECTION_EVENT* Event + ) noexcept { + if (Event->Type == QUIC_CONNECTION_EVENT_CONNECTED) { + MsQuic->ConnectionSendResumptionTicket(*Connection, QUIC_SEND_RESUMPTION_FLAG_FINAL, 0, nullptr); + } else if (Event->Type == QUIC_CONNECTION_EVENT_PEER_STREAM_STARTED) { + // + // Not great beacuse it doesn't provide an application specific + // error code. If you expect to get streams, you should not no-op // the callbacks. // MsQuic->StreamClose(Event->PEER_STREAM_STARTED.Stream); @@ -753,6 +883,23 @@ struct MsQuicConnection { _Inout_ QUIC_CONNECTION_EVENT* Event ) noexcept { CXPLAT_DBG_ASSERT(pThis); + if (Event->Type == QUIC_CONNECTION_EVENT_CONNECTED) { + pThis->HandshakeComplete = true; + pThis->HandshakeResumed = Event->CONNECTED.SessionResumed; +#ifdef CX_PLATFORM_TYPE + pThis->HandshakeCompleteEvent.Set(); +#endif // CX_PLATFORM_TYPE + } else if (Event->Type == QUIC_CONNECTION_EVENT_RESUMPTION_TICKET_RECEIVED && !pThis->ResumptionTicket) { + pThis->ResumptionTicketLength = Event->RESUMPTION_TICKET_RECEIVED.ResumptionTicketLength; + pThis->ResumptionTicket = new(std::nothrow) uint8_t[pThis->ResumptionTicketLength]; + if (pThis->ResumptionTicket) { + CXPLAT_DBG_ASSERT(pThis->ResumptionTicketLength != 0); + memcpy(pThis->ResumptionTicket, Event->RESUMPTION_TICKET_RECEIVED.ResumptionTicket, pThis->ResumptionTicketLength); +#ifdef CX_PLATFORM_TYPE + pThis->ResumptionTicketReceivedEvent.Set(); +#endif // CX_PLATFORM_TYPE + } + } auto DeleteOnExit = Event->Type == QUIC_CONNECTION_EVENT_SHUTDOWN_COMPLETE && pThis->CleanUpMode == CleanUpAutoDelete; @@ -768,6 +915,7 @@ struct MsQuicAutoAcceptListener : public MsQuicListener { const MsQuicConfiguration& Configuration; MsQuicConnectionCallback* ConnectionHandler; void* ConnectionContext; + uint32_t AcceptedConnectionCount {0}; MsQuicAutoAcceptListener( _In_ const MsQuicRegistration& Registration, @@ -805,6 +953,8 @@ struct MsQuicAutoAcceptListener : public MsQuicListener { // Connection->Handle = nullptr; delete Connection; + } else { + InterlockedIncrement((long*)&pThis->AcceptedConnectionCount); } } } @@ -992,81 +1142,3 @@ struct QuicBufferScope { operator QUIC_BUFFER* () noexcept { return Buffer; } ~QuicBufferScope() noexcept { if (Buffer) { delete[](uint8_t*) Buffer; } } }; - -#ifdef CX_PLATFORM_TYPE - -// -// Abstractions for platform specific types/interfaces -// - -struct CxPlatEvent { - CXPLAT_EVENT Handle; - CxPlatEvent() noexcept { CxPlatEventInitialize(&Handle, FALSE, FALSE); } - CxPlatEvent(bool ManualReset) noexcept { CxPlatEventInitialize(&Handle, ManualReset, FALSE); } - CxPlatEvent(CXPLAT_EVENT event) noexcept : Handle(event) { } - ~CxPlatEvent() noexcept { CxPlatEventUninitialize(Handle); } - CXPLAT_EVENT* operator &() noexcept { return &Handle; } - operator CXPLAT_EVENT() const noexcept { return Handle; } - void Set() { CxPlatEventSet(Handle); } - void Reset() { CxPlatEventReset(Handle); } - void WaitForever() { CxPlatEventWaitForever(Handle); } - bool WaitTimeout(uint32_t TimeoutMs) { return CxPlatEventWaitWithTimeout(Handle, TimeoutMs); } -}; - -#ifdef CXPLAT_HASH_MIN_SIZE - -struct HashTable { - bool Initialized; - CXPLAT_HASHTABLE Table; - HashTable() noexcept { Initialized = CxPlatHashtableInitializeEx(&Table, CXPLAT_HASH_MIN_SIZE); } - ~HashTable() noexcept { if (Initialized) { CxPlatHashtableUninitialize(&Table); } } - void Insert(CXPLAT_HASHTABLE_ENTRY* Entry) { CxPlatHashtableInsert(&Table, Entry, Entry->Signature, nullptr); } - void Remove(CXPLAT_HASHTABLE_ENTRY* Entry) { CxPlatHashtableRemove(&Table, Entry, nullptr); } - CXPLAT_HASHTABLE_ENTRY* Lookup(uint64_t Signature) { - CXPLAT_HASHTABLE_LOOKUP_CONTEXT LookupContext; - return CxPlatHashtableLookup(&Table, Signature, &LookupContext); - } - CXPLAT_HASHTABLE_ENTRY* LookupEx(uint64_t Signature, bool (*Equals)(CXPLAT_HASHTABLE_ENTRY* Entry, void* Context), void* Context) { - CXPLAT_HASHTABLE_LOOKUP_CONTEXT LookupContext; - CXPLAT_HASHTABLE_ENTRY* Entry = CxPlatHashtableLookup(&Table, Signature, &LookupContext); - while (Entry != NULL) { - if (Equals(Entry, Context)) return Entry; - Entry = CxPlatHashtableLookupNext(&Table, &LookupContext); - } - return NULL; - } -}; - -#endif // CXPLAT_HASH_MIN_SIZE - -#ifdef CXPLAT_FRE_ASSERT - -class CxPlatWatchdog { - CXPLAT_THREAD WatchdogThread; - CxPlatEvent ShutdownEvent {true}; - uint32_t TimeoutMs; - static CXPLAT_THREAD_CALLBACK(WatchdogThreadCallback, Context) { - auto This = (CxPlatWatchdog*)Context; - if (!This->ShutdownEvent.WaitTimeout(This->TimeoutMs)) { - CXPLAT_FRE_ASSERTMSG(FALSE, "Watchdog timeout fired!"); - } - CXPLAT_THREAD_RETURN(0); - } -public: - CxPlatWatchdog(uint32_t WatchdogTimeoutMs) : TimeoutMs(WatchdogTimeoutMs) { - CXPLAT_THREAD_CONFIG Config = { 0 }; - Config.Name = "cxplat_watchdog"; - Config.Callback = WatchdogThreadCallback; - Config.Context = this; - CXPLAT_FRE_ASSERT(QUIC_SUCCEEDED(CxPlatThreadCreate(&Config, &WatchdogThread))); - } - ~CxPlatWatchdog() { - ShutdownEvent.Set(); - CxPlatThreadWait(&WatchdogThread); - CxPlatThreadDelete(&WatchdogThread); - } -}; - -#endif // CXPLAT_FRE_ASSERT - -#endif // CX_PLATFORM_TYPE diff --git a/src/inc/msquicp.h b/src/inc/msquicp.h index 0ce9df798e..546ab54123 100644 --- a/src/inc/msquicp.h +++ b/src/inc/msquicp.h @@ -23,6 +23,21 @@ extern "C" { typedef struct CXPLAT_RECV_DATA CXPLAT_RECV_DATA; typedef struct CXPLAT_SEND_DATA CXPLAT_SEND_DATA; +typedef +_IRQL_requires_max_(DISPATCH_LEVEL) +void +(QUIC_API * QUIC_TEST_DATAPATH_CREATE_HOOK)( + _Inout_opt_ QUIC_ADDR* RemoteAddress, + _Inout_opt_ QUIC_ADDR* LocalAddress + ); + +typedef +_IRQL_requires_max_(DISPATCH_LEVEL) +void +(QUIC_API * QUIC_TEST_DATAPATH_GET_ADDRESS_HOOK)( + _Inout_ QUIC_ADDR* Address + ); + // // Returns TRUE to drop the packet. // @@ -46,6 +61,9 @@ BOOLEAN ); typedef struct QUIC_TEST_DATAPATH_HOOKS { + QUIC_TEST_DATAPATH_CREATE_HOOK Create; + QUIC_TEST_DATAPATH_GET_ADDRESS_HOOK GetLocalAddress; + QUIC_TEST_DATAPATH_GET_ADDRESS_HOOK GetRemoteAddress; QUIC_TEST_DATAPATH_RECEIVE_HOOK Receive; QUIC_TEST_DATAPATH_SEND_HOOK Send; } QUIC_TEST_DATAPATH_HOOKS; diff --git a/src/manifest/clog.sidecar b/src/manifest/clog.sidecar index d3a0dc8733..a10186c6cf 100644 --- a/src/manifest/clog.sidecar +++ b/src/manifest/clog.sidecar @@ -1054,10 +1054,10 @@ ], "macroName": "QuicTraceLogConnInfo" }, - "RttUpdated": { + "RttUpdatedMsg": { "ModuleProperites": {}, "TraceString": "[conn][%p] Updated Rtt=%u.%03u ms, Var=%u.%03u", - "UniqueId": "RttUpdated", + "UniqueId": "RttUpdatedMsg", "splitArgs": [ { "DefinationEncoding": "p", @@ -6100,10 +6100,10 @@ ], "macroName": "QuicTraceLogConnVerbose" }, - "RemoveSendFlags": { + "RemoveSendFlagsMsg": { "ModuleProperites": {}, "TraceString": "[conn][%p] Removing flags %x", - "UniqueId": "RemoveSendFlags", + "UniqueId": "RemoveSendFlagsMsg", "splitArgs": [ { "DefinationEncoding": "p", @@ -7304,10 +7304,10 @@ ], "macroName": "QuicTraceLogStreamVerbose" }, - "AckRange": { + "AckRangeMsg": { "ModuleProperites": {}, "TraceString": "[strm][%p] Received ack for %d bytes, offset=%llu, FF=0x%hx", - "UniqueId": "AckRange", + "UniqueId": "AckRangeMsg", "splitArgs": [ { "DefinationEncoding": "p", @@ -9061,10 +9061,10 @@ ], "macroName": "QuicTraceLogWarning" }, - "DatapathUnreachable": { + "DatapathUnreachableMsg": { "ModuleProperites": {}, "TraceString": "[sock][%p] Unreachable error from %!ADDR!", - "UniqueId": "DatapathUnreachable", + "UniqueId": "DatapathUnreachableMsg", "splitArgs": [ { "DefinationEncoding": "p", @@ -10334,6 +10334,22 @@ "splitArgs": [], "macroName": "QuicTraceLogVerbose" }, + "TestHookReplaceCreateSend": { + "ModuleProperites": {}, + "TraceString": "[test][hook] Create (remote) Addr :%hu => :%hu", + "UniqueId": "TestHookReplaceCreateSend", + "splitArgs": [ + { + "DefinationEncoding": "hu", + "MacroVariableName": "arg2" + }, + { + "DefinationEncoding": "hu", + "MacroVariableName": "arg3" + } + ], + "macroName": "QuicTraceLogVerbose" + }, "InteropTestStart": { "ModuleProperites": {}, "TraceString": "[ntrp] Test Start, Server: %s, Port: %hu, Tests: 0x%x.", @@ -10381,118 +10397,6 @@ } ], "macroName": "QuicTraceLogInfo" - }, - "RttUpdatedMsg": { - "ModuleProperites": {}, - "TraceString": "[conn][%p] Updated Rtt=%u.%03u ms, Var=%u.%03u", - "UniqueId": "RttUpdatedMsg", - "splitArgs": [ - { - "DefinationEncoding": "p", - "MacroVariableName": "arg1" - }, - { - "DefinationEncoding": "u", - "MacroVariableName": "arg3" - }, - { - "DefinationEncoding": "03u", - "MacroVariableName": "arg4" - }, - { - "DefinationEncoding": "u", - "MacroVariableName": "arg5" - }, - { - "DefinationEncoding": "03u", - "MacroVariableName": "arg6" - } - ], - "macroName": "QuicTraceLogConnVerbose" - }, - "RemoveSendFlagsMsg": { - "ModuleProperites": {}, - "TraceString": "[conn][%p] Removing flags %x", - "UniqueId": "RemoveSendFlagsMsg", - "splitArgs": [ - { - "DefinationEncoding": "p", - "MacroVariableName": "arg1" - }, - { - "DefinationEncoding": "x", - "MacroVariableName": "arg3" - } - ], - "macroName": "QuicTraceLogConnVerbose" - }, - "AckRangeMsg": { - "ModuleProperites": {}, - "TraceString": "[strm][%p] Received ack for %d bytes, offset=%llu, FF=0x%hx", - "UniqueId": "AckRangeMsg", - "splitArgs": [ - { - "DefinationEncoding": "p", - "MacroVariableName": "arg1" - }, - { - "DefinationEncoding": "d", - "MacroVariableName": "arg3" - }, - { - "DefinationEncoding": "llu", - "MacroVariableName": "arg4" - }, - { - "DefinationEncoding": "hx", - "MacroVariableName": "arg5" - } - ], - "macroName": "QuicTraceLogStreamVerbose" - }, - "TestSendIoctl": { - "ModuleProperites": {}, - "TraceString": "[test] Sending Write IOCTL %u with %u bytes.", - "UniqueId": "TestSendIoctl", - "splitArgs": [ - { - "DefinationEncoding": "u", - "MacroVariableName": "arg2" - }, - { - "DefinationEncoding": "u", - "MacroVariableName": "arg3" - } - ], - "macroName": "QuicTraceLogVerbose" - }, - "TestReadIoctl": { - "ModuleProperites": {}, - "TraceString": "[test] Sending Read IOCTL %u.", - "UniqueId": "TestReadIoctl", - "splitArgs": [ - { - "DefinationEncoding": "u", - "MacroVariableName": "arg2" - } - ], - "macroName": "QuicTraceLogVerbose" - }, - "DatapathUnreachableMsg": { - "ModuleProperites": {}, - "TraceString": "[sock][%p] Unreachable error from %!ADDR!", - "UniqueId": "DatapathUnreachableMsg", - "splitArgs": [ - { - "DefinationEncoding": "p", - "MacroVariableName": "arg2" - }, - { - "DefinationEncoding": "!ADDR!", - "MacroVariableName": "arg3" - } - ], - "macroName": "QuicTraceLogVerbose" } }, "Version": 1, @@ -11256,8 +11160,8 @@ "TraceID": "ApplySettings" }, { - "UniquenessHash": "ffeb66e8-2636-b98b-46a5-564068b26843", - "TraceID": "RttUpdated" + "UniquenessHash": "dc25a415-b386-47ec-0128-415c6f31795b", + "TraceID": "RttUpdatedMsg" }, { "UniquenessHash": "5bdd3273-8aaf-afec-f7e3-a031e8c0122c", @@ -12372,8 +12276,8 @@ "TraceID": "ScheduleSendFlags" }, { - "UniquenessHash": "93fcf6ee-a709-3273-c039-2c2b81572b48", - "TraceID": "RemoveSendFlags" + "UniquenessHash": "b4aa5fad-d1de-466c-0214-1bbda4b9eb95", + "TraceID": "RemoveSendFlagsMsg" }, { "UniquenessHash": "4a595158-b864-777b-d90a-dbb944e765ac", @@ -12712,8 +12616,8 @@ "TraceID": "RecoverRange" }, { - "UniquenessHash": "79ffdfe4-2d35-3e89-ccb1-9617b2c74c00", - "TraceID": "AckRange" + "UniquenessHash": "1f842f1b-027f-b5ec-e0ac-fc9491aeb629", + "TraceID": "AckRangeMsg" }, { "UniquenessHash": "0d9cddd5-17e7-d796-fa49-7f3450287fa6", @@ -13156,8 +13060,8 @@ "TraceID": "DatapathUroExceeded" }, { - "UniquenessHash": "37423b13-b9e3-7961-07d6-137252085127", - "TraceID": "DatapathUnreachable" + "UniquenessHash": "29acd049-f710-a71e-1271-698510cfc519", + "TraceID": "DatapathUnreachableMsg" }, { "UniquenessHash": "86f0cb85-ebae-d293-4103-fb695920b86c", @@ -13547,6 +13451,10 @@ "UniquenessHash": "78d1c2fe-fbce-b603-004e-f2026093416c", "TraceID": "TestHookDropLimitAddrSend" }, + { + "UniquenessHash": "be2459c1-0870-ccba-da9b-ae1ecc23ded8", + "TraceID": "TestHookReplaceCreateSend" + }, { "UniquenessHash": "e98e6411-dfab-d3a6-e7bd-d1782209846d", "TraceID": "InteropTestStart" @@ -13554,30 +13462,6 @@ { "UniquenessHash": "7368812e-a46d-0924-7643-009511288886", "TraceID": "InteropTestStop" - }, - { - "UniquenessHash": "dc25a415-b386-47ec-0128-415c6f31795b", - "TraceID": "RttUpdatedMsg" - }, - { - "UniquenessHash": "b4aa5fad-d1de-466c-0214-1bbda4b9eb95", - "TraceID": "RemoveSendFlagsMsg" - }, - { - "UniquenessHash": "1f842f1b-027f-b5ec-e0ac-fc9491aeb629", - "TraceID": "AckRangeMsg" - }, - { - "UniquenessHash": "28f77853-916b-5ceb-e928-405aff056b0a", - "TraceID": "TestSendIoctl" - }, - { - "UniquenessHash": "3c33979e-0db8-7562-7df1-41ef66cb75fa", - "TraceID": "TestReadIoctl" - }, - { - "UniquenessHash": "29acd049-f710-a71e-1271-698510cfc519", - "TraceID": "DatapathUnreachableMsg" } ] } diff --git a/src/test/MsQuicTests.h b/src/test/MsQuicTests.h index 0c28b0390e..4f78222806 100644 --- a/src/test/MsQuicTests.h +++ b/src/test/MsQuicTests.h @@ -161,6 +161,11 @@ QuicTestInvalidAlpnLengths( void ); +void +QuicTestLoadBalancedHandshake( + _In_ int Family + ); + // // Negative Handshake Tests // @@ -844,4 +849,8 @@ typedef struct { #define IOCTL_QUIC_RUN_MTU_DISCOVERY \ QUIC_CTL_CODE(68, METHOD_BUFFERED, FILE_WRITE_DATA) -#define QUIC_MAX_IOCTL_FUNC_CODE 68 +#define IOCTL_QUIC_RUN_LOAD_BALANCED_HANDSHAKE \ + QUIC_CTL_CODE(69, METHOD_BUFFERED, FILE_WRITE_DATA) + // int - Family + +#define QUIC_MAX_IOCTL_FUNC_CODE 69 diff --git a/src/test/bin/quic_gtest.cpp b/src/test/bin/quic_gtest.cpp index 144966725e..074fb519f4 100644 --- a/src/test/bin/quic_gtest.cpp +++ b/src/test/bin/quic_gtest.cpp @@ -980,6 +980,17 @@ TEST_P(WithFamilyArgs, ChangeMaxStreamIDs) { } } +#if QUIC_TEST_DATAPATH_HOOKS_ENABLED +TEST_P(WithFamilyArgs, LoadBalanced) { + TestLoggerT Logger("QuicTestLoadBalancedHandshake", GetParam()); + if (TestingKernelMode) { + ASSERT_TRUE(DriverClient.Run(IOCTL_QUIC_RUN_LOAD_BALANCED_HANDSHAKE, GetParam().Family)); + } else { + QuicTestLoadBalancedHandshake(GetParam().Family); + } +} +#endif // QUIC_TEST_DATAPATH_HOOKS_ENABLED + TEST_P(WithSendArgs1, Send) { TestLoggerT Logger("QuicTestConnectAndPing", GetParam()); if (TestingKernelMode) { diff --git a/src/test/bin/winkernel/control.cpp b/src/test/bin/winkernel/control.cpp index d143a0aa8a..dcac15d6c3 100644 --- a/src/test/bin/winkernel/control.cpp +++ b/src/test/bin/winkernel/control.cpp @@ -438,7 +438,8 @@ size_t QUIC_IOCTL_BUFFER_SIZES[] = 0, 0, 0, - sizeof(QUIC_RUN_MTU_DISCOVERY_PARAMS) + sizeof(QUIC_RUN_MTU_DISCOVERY_PARAMS), + sizeof(INT32) }; CXPLAT_STATIC_ASSERT( @@ -1058,6 +1059,11 @@ QuicTestCtlEvtIoDeviceControl( Params->MtuDiscoveryParams.RaiseMinimumMtu)); break; + case IOCTL_QUIC_RUN_LOAD_BALANCED_HANDSHAKE: + CXPLAT_FRE_ASSERT(Params != nullptr); + QuicTestCtlRun(QuicTestLoadBalancedHandshake(Params->Family)); + break; + default: Status = STATUS_NOT_IMPLEMENTED; break; diff --git a/src/test/lib/HandshakeTest.cpp b/src/test/lib/HandshakeTest.cpp index 1298f76e94..30c85a3a15 100644 --- a/src/test/lib/HandshakeTest.cpp +++ b/src/test/lib/HandshakeTest.cpp @@ -15,6 +15,9 @@ #endif QUIC_TEST_DATAPATH_HOOKS DatapathHooks::FuncTable = { + DatapathHooks::CreateCallback, + DatapathHooks::GetLocalAddressCallback, + DatapathHooks::GetRemoteAddressCallback, DatapathHooks::ReceiveCallback, DatapathHooks::SendCallback }; @@ -2567,3 +2570,123 @@ QuicTestConnectExpiredClientCertificate( } } } + +struct LoadBalancedServer { + QuicAddr PublicAddress; + QuicAddr* PrivateAddresses {nullptr}; + QUIC_TICKET_KEY_CONFIG KeyConfig; + MsQuicConfiguration** Configurations {nullptr}; + MsQuicAutoAcceptListener** Listeners {nullptr}; + uint32_t ListenerCount; + LoadBalancerHelper* LoadBalancer {nullptr}; + QUIC_STATUS InitStatus {QUIC_STATUS_INVALID_PARAMETER}; // Only hit in ListenerCount == 0 scenario + LoadBalancedServer( + _In_ const MsQuicRegistration& Registration, + _In_ QUIC_ADDRESS_FAMILY QuicAddrFamily = QUIC_ADDRESS_FAMILY_UNSPEC, + _In_ MsQuicConnectionCallback* ConnectionHandler = MsQuicConnection::NoOpCallback, + _In_ uint32_t ListenerCount = 2 + ) noexcept : + PublicAddress(QuicAddrFamily, (uint16_t)443), PrivateAddresses(new(std::nothrow) QuicAddr[ListenerCount]), + Configurations(new(std::nothrow) MsQuicConfiguration*[ListenerCount]), + Listeners(new(std::nothrow) MsQuicAutoAcceptListener*[ListenerCount]), ListenerCount(ListenerCount) { + CxPlatRandom(sizeof(KeyConfig), &KeyConfig); + KeyConfig.MaterialLength = sizeof(KeyConfig.Material); + CxPlatZeroMemory(Configurations, sizeof(MsQuicConfiguration*) * ListenerCount); + CxPlatZeroMemory(Listeners, sizeof(MsQuicAutoAcceptListener*) * ListenerCount); + MsQuicSettings Settings; + Settings.SetServerResumptionLevel(QUIC_SERVER_RESUME_AND_ZERORTT); + QuicAddrSetToLoopback(&PublicAddress.SockAddr); + for (uint32_t i = 0; i < ListenerCount; ++i) { + PrivateAddresses[i] = QuicAddr(QuicAddrFamily); + QuicAddrSetToLoopback(&PrivateAddresses[i].SockAddr); + Configurations[i] = new(std::nothrow) MsQuicConfiguration(Registration, "MsQuicTest", Settings, ServerSelfSignedCredConfig); + TEST_QUIC_SUCCEEDED(InitStatus = Configurations[i]->GetInitStatus()); + TEST_QUIC_SUCCEEDED(InitStatus = Configurations[i]->SetTicketKey(&KeyConfig)); + Listeners[i] = new(std::nothrow) MsQuicAutoAcceptListener(Registration, *Configurations[i], ConnectionHandler); + TEST_QUIC_SUCCEEDED(InitStatus = Listeners[i]->GetInitStatus()); + TEST_QUIC_SUCCEEDED(InitStatus = Listeners[i]->Start("MsQuicTest", &PrivateAddresses[i].SockAddr)); + TEST_QUIC_SUCCEEDED(InitStatus = Listeners[i]->GetLocalAddr(PrivateAddresses[i])); + } + LoadBalancer = new(std::nothrow) LoadBalancerHelper(PublicAddress.SockAddr, (QUIC_ADDR*)PrivateAddresses, ListenerCount); + } + ~LoadBalancedServer() noexcept { + delete LoadBalancer; + for (uint32_t i = 0; i < ListenerCount; ++i) { + delete Listeners[i]; + delete Configurations[i]; + } + delete[] Listeners; + delete[] Configurations; + delete[] PrivateAddresses; + } + QUIC_STATUS GetInitStatus() const noexcept { return InitStatus; } + void ValidateLoadBalancing() const noexcept { + for (uint32_t i = 0; i < ListenerCount; ++i) { + TEST_TRUE(Listeners[i]->AcceptedConnectionCount != 0); + } + } +}; + +void +QuicTestLoadBalancedHandshake( + _In_ int Family + ) +{ + MsQuicRegistration Registration(true); + TEST_QUIC_SUCCEEDED(Registration.GetInitStatus()); + + MsQuicConfiguration ClientConfiguration(Registration, "MsQuicTest", MsQuicCredentialConfig()); + TEST_QUIC_SUCCEEDED(ClientConfiguration.GetInitStatus()); + + QUIC_ADDRESS_FAMILY QuicAddrFamily = (Family == 4) ? QUIC_ADDRESS_FAMILY_INET : QUIC_ADDRESS_FAMILY_INET6; + LoadBalancedServer Listeners(Registration, QuicAddrFamily, MsQuicConnection::SendResumptionCallback, 3); + TEST_QUIC_SUCCEEDED(Listeners.GetInitStatus()); + + QuicAddr ConnLocalAddr(QuicAddrFamily, false); + uint32_t ResumptionTicketLength = 0; + uint8_t* ResumptionTicket = nullptr; + bool SchannelMode = false; // Only determined on first resumed connection. + ConnLocalAddr.SetPort(33667); // Randomly chosen! + for (uint32_t i = 0; i < 100; ++i) { + MsQuicConnection Connection(Registration); + TEST_QUIC_SUCCEEDED(Connection.GetInitStatus()); + bool TryingResumption = false; + if (ResumptionTicket) { + TEST_QUIC_SUCCEEDED(Connection.SetResumptionTicket(ResumptionTicket, ResumptionTicketLength)); + delete[] ResumptionTicket; + ResumptionTicket = nullptr; + TryingResumption = true; + } + TEST_QUIC_SUCCEEDED(Connection.SetLocalAddr(ConnLocalAddr)); // TODO - Put in loop in case addr is taken + TEST_QUIC_SUCCEEDED(Connection.StartLocalhost(ClientConfiguration, Listeners.PublicAddress)); + TEST_TRUE(Connection.HandshakeCompleteEvent.WaitTimeout(TestWaitTimeout)); + if (SchannelMode) { + // + // HACK: Schannel reuses tickets, so it always resumes. Also, no + // point in waiting for a ticket because it won't send it. + // + TEST_TRUE(Connection.HandshakeResumed); + + } else { + TEST_TRUE(Connection.HandshakeResumed == TryingResumption); + if (!Connection.ResumptionTicketReceivedEvent.WaitTimeout(TestWaitTimeout)) { + if (Connection.HandshakeResumed) { + SchannelMode = true; // Schannel doesn't send tickets on resumed connections. + ResumptionTicket = nullptr; + } else { + TEST_FAILURE("Timeout waiting for resumption ticket"); + return; + } + } else { + TEST_TRUE(Connection.ResumptionTicket != nullptr); + ResumptionTicketLength = Connection.ResumptionTicketLength; + ResumptionTicket = Connection.ResumptionTicket; + Connection.ResumptionTicket = nullptr; + } + } + Connection.Shutdown(0); // Best effort start peer shutdown + ConnLocalAddr.IncrementPort(); + } + delete[] ResumptionTicket; + Listeners.ValidateLoadBalancing(); +} diff --git a/src/test/lib/TestHelpers.h b/src/test/lib/TestHelpers.h index f625c71d56..614cfa25e8 100644 --- a/src/test/lib/TestHelpers.h +++ b/src/test/lib/TestHelpers.h @@ -14,6 +14,7 @@ #endif #include "msquic.hpp" +#include "quic_toeplitz.h" #define OLD_SUPPORTED_VERSION QUIC_VERSION_1_MS_H #define LATEST_SUPPORTED_VERSION QUIC_VERSION_LATEST_H @@ -154,6 +155,33 @@ struct DatapathHook DatapathHook() : Next(nullptr) { } + virtual ~DatapathHook() { } + + virtual + _IRQL_requires_max_(PASSIVE_LEVEL) + void + Create( + _Inout_opt_ QUIC_ADDR* /* RemoteAddress */, + _Inout_opt_ QUIC_ADDR* /* LocalAddress */ + ) { + } + + virtual + _IRQL_requires_max_(PASSIVE_LEVEL) + void + GetLocalAddress( + _Inout_ QUIC_ADDR* /* Address */ + ) { + } + + virtual + _IRQL_requires_max_(PASSIVE_LEVEL) + void + GetRemoteAddress( + _Inout_ QUIC_ADDR* /* Address */ + ) { + } + virtual _IRQL_requires_max_(DISPATCH_LEVEL) BOOLEAN @@ -182,6 +210,37 @@ class DatapathHooks DatapathHook* Hooks; CXPLAT_DISPATCH_LOCK Lock; + static + _IRQL_requires_max_(PASSIVE_LEVEL) + void + QUIC_API + CreateCallback( + _Inout_opt_ QUIC_ADDR* RemoteAddress, + _Inout_opt_ QUIC_ADDR* LocalAddress + ) { + return Instance->Create(RemoteAddress, LocalAddress); + } + + static + _IRQL_requires_max_(PASSIVE_LEVEL) + void + QUIC_API + GetLocalAddressCallback( + _Inout_ QUIC_ADDR* Address + ) { + return Instance->GetLocalAddress(Address); + } + + static + _IRQL_requires_max_(PASSIVE_LEVEL) + void + QUIC_API + GetRemoteAddressCallback( + _Inout_ QUIC_ADDR* Address + ) { + return Instance->GetRemoteAddress(Address); + } + static _IRQL_requires_max_(DISPATCH_LEVEL) BOOLEAN @@ -248,6 +307,46 @@ class DatapathHooks #endif } + void + Create( + _Inout_opt_ QUIC_ADDR* RemoteAddress, + _Inout_opt_ QUIC_ADDR* LocalAddress + ) { + CxPlatDispatchLockAcquire(&Lock); + DatapathHook* Iter = Hooks; + while (Iter) { + Iter->Create(RemoteAddress, LocalAddress); + Iter = Iter->Next; + } + CxPlatDispatchLockRelease(&Lock); + } + + void + GetLocalAddress( + _Inout_ QUIC_ADDR* Address + ) { + CxPlatDispatchLockAcquire(&Lock); + DatapathHook* Iter = Hooks; + while (Iter) { + Iter->GetLocalAddress(Address); + Iter = Iter->Next; + } + CxPlatDispatchLockRelease(&Lock); + } + + void + GetRemoteAddress( + _Inout_ QUIC_ADDR* Address + ) { + CxPlatDispatchLockAcquire(&Lock); + DatapathHook* Iter = Hooks; + while (Iter) { + Iter->GetRemoteAddress(Address); + Iter = Iter->Next; + } + CxPlatDispatchLockRelease(&Lock); + } + BOOLEAN Receive( _Inout_ struct CXPLAT_RECV_DATA* Datagram @@ -539,3 +638,99 @@ struct ReplaceAddressThenDropHelper : public DatapathHook return FALSE; } }; + +struct LoadBalancerHelper : public DatapathHook +{ + CXPLAT_TOEPLITZ_HASH Toeplitz; + QUIC_ADDR PublicAddress; + const QUIC_ADDR* PrivateAddresses; + uint32_t PrivateAddressesCount; + LoadBalancerHelper(const QUIC_ADDR& Public, const QUIC_ADDR* Private, uint32_t PrivateCount) : + PublicAddress(Public), PrivateAddresses(Private), PrivateAddressesCount(PrivateCount) { + CxPlatRandom(CXPLAT_TOEPLITZ_KEY_SIZE, &Toeplitz.HashKey); + CxPlatToeplitzHashInitialize(&Toeplitz); + DatapathHooks::Instance->AddHook(this); + } + ~LoadBalancerHelper() { + DatapathHooks::Instance->RemoveHook(this); + } + _IRQL_requires_max_(PASSIVE_LEVEL) + void + Create( + _Inout_opt_ QUIC_ADDR* RemoteAddress, + _Inout_opt_ QUIC_ADDR* LocalAddress + ) { + if (RemoteAddress && LocalAddress && + QuicAddrCompare(RemoteAddress, &PublicAddress)) { + *RemoteAddress = MapSendToPublic(LocalAddress); + QuicTraceLogVerbose( + TestHookReplaceCreateSend, + "[test][hook] Create (remote) Addr :%hu => :%hu", + QuicAddrGetPort(&PublicAddress), + QuicAddrGetPort(RemoteAddress)); + } + } + _IRQL_requires_max_(PASSIVE_LEVEL) + void + GetLocalAddress( + _Inout_ QUIC_ADDR* /* Address */ + ) { + } + _IRQL_requires_max_(PASSIVE_LEVEL) + void + GetRemoteAddress( + _Inout_ QUIC_ADDR* Address + ) { + for (uint32_t i = 0; i < PrivateAddressesCount; ++i) { + if (QuicAddrCompare( + Address, + &PrivateAddresses[i])) { + *Address = PublicAddress; + break; + } + } + } + _IRQL_requires_max_(DISPATCH_LEVEL) + BOOLEAN + Receive( + _Inout_ struct CXPLAT_RECV_DATA* Datagram + ) { + for (uint32_t i = 0; i < PrivateAddressesCount; ++i) { + if (QuicAddrCompare( + &Datagram->Tuple->RemoteAddress, + &PrivateAddresses[i])) { + Datagram->Tuple->RemoteAddress = PublicAddress; + QuicTraceLogVerbose( + TestHookReplaceAddrRecv, + "[test][hook] Recv Addr :%hu => :%hu", + QuicAddrGetPort(&PrivateAddresses[i]), + QuicAddrGetPort(&PublicAddress)); + break; + } + } + return FALSE; + } + _IRQL_requires_max_(PASSIVE_LEVEL) + BOOLEAN + Send( + _Inout_ QUIC_ADDR* RemoteAddress, + _Inout_opt_ QUIC_ADDR* LocalAddress, + _Inout_ struct CXPLAT_SEND_DATA* /* SendData */ + ) { + if (QuicAddrCompare(RemoteAddress, &PublicAddress)) { + *RemoteAddress = MapSendToPublic(LocalAddress); + QuicTraceLogVerbose( + TestHookReplaceAddrSend, + "[test][hook] Send Addr :%hu => :%hu", + QuicAddrGetPort(&PublicAddress), + QuicAddrGetPort(RemoteAddress)); + } + return FALSE; + } +private: + const QUIC_ADDR& MapSendToPublic(_In_ const QUIC_ADDR* SourceAddress) { + uint32_t Key = 0, Offset; + CxPlatToeplitzHashComputeAddr(&Toeplitz, SourceAddress, &Key, &Offset); + return PrivateAddresses[Key % PrivateAddressesCount]; + } +}; diff --git a/src/tools/spin/spinquic.cpp b/src/tools/spin/spinquic.cpp index 83a735e76b..d61490f828 100644 --- a/src/tools/spin/spinquic.cpp +++ b/src/tools/spin/spinquic.cpp @@ -766,6 +766,14 @@ CXPLAT_THREAD_CALLBACK(ClientSpin, Context) CXPLAT_THREAD_RETURN(0); } +void QUIC_API DatapathHookCreateCallback(_Inout_opt_ QUIC_ADDR* /* RemoteAddress */, _Inout_opt_ QUIC_ADDR* /* LocalAddress */) +{ +} + +void QUIC_API DatapathHookGetAddressCallback(_Inout_ QUIC_ADDR* /* Address */) +{ +} + BOOLEAN QUIC_API DatapathHookReceiveCallback(struct CXPLAT_RECV_DATA* /* Datagram */) { uint8_t RandomValue; @@ -779,7 +787,11 @@ BOOLEAN QUIC_API DatapathHookSendCallback(QUIC_ADDR* /* RemoteAddress */, QUIC_A } QUIC_TEST_DATAPATH_HOOKS DataPathHooks = { - DatapathHookReceiveCallback, DatapathHookSendCallback + DatapathHookCreateCallback, + DatapathHookGetAddressCallback, + DatapathHookGetAddressCallback, + DatapathHookReceiveCallback, + DatapathHookSendCallback }; void PrintHelpText(void)