diff --git a/docs/api/ConnectionCertificateValidationComplete.md b/docs/api/ConnectionCertificateValidationComplete.md index 5f2cf99b7e..3ce71ae777 100644 --- a/docs/api/ConnectionCertificateValidationComplete.md +++ b/docs/api/ConnectionCertificateValidationComplete.md @@ -1,7 +1,7 @@ ConnectionCertificateValidationComplete function ====== -Uses the QUIC (client) handle to complete resumption ticket validation. This must be called after client app handles certificate validation and then return QUIC_STATUS_PENDING. +Uses the QUIC handle to complete certificate validation. This must be called after the app receives `QUIC_CONNECTION_EVENT_PEER_CERTIFICATE_RECEIVED` and returns QUIC_STATUS_PENDING. The app should complete certificate validation and call this before the idle timeout and disconnect timeouts occur. # Syntax @@ -23,7 +23,7 @@ The valid handle to an open connection object. `Result` -Ticket validation result. +Certificate validation result. # Return Value diff --git a/src/core/api.c b/src/core/api.c index 7c8d7caa51..f3aa0cb0b5 100644 --- a/src/core/api.c +++ b/src/core/api.c @@ -1756,11 +1756,6 @@ MsQuicConnectionCertificateValidationComplete( QUIC_CONN_VERIFY(Connection, !Connection->State.Freed); - if (QuicConnIsServer(Connection)) { - Status = QUIC_STATUS_INVALID_PARAMETER; - goto Error; - } - Oper = QuicOperationAlloc(Connection->Worker, QUIC_OPER_TYPE_API_CALL); if (Oper == NULL) { Status = QUIC_STATUS_OUT_OF_MEMORY; diff --git a/src/core/connection.c b/src/core/connection.c index c45c4a19c6..40b3d0adec 100644 --- a/src/core/connection.c +++ b/src/core/connection.c @@ -3158,6 +3158,9 @@ QuicConnPeerCertReceived( return FALSE; } if (Status == QUIC_STATUS_PENDING) { + // + // Don't set pending here because validation may have completed in the callback. + // QuicTraceLogConnInfo( CustomCertValidationPending, Connection, @@ -7314,7 +7317,6 @@ QuicConnProcessApiOperation( break; case QUIC_API_TYPE_CONN_COMPLETE_CERTIFICATE_VALIDATION: - CXPLAT_DBG_ASSERT(QuicConnIsClient(Connection)); QuicCryptoCustomCertValidationComplete( &Connection->Crypto, ApiCtx->CONN_COMPLETE_CERTIFICATE_VALIDATION.Result); diff --git a/src/core/crypto.c b/src/core/crypto.c index a6174f4f5b..97e318f43a 100644 --- a/src/core/crypto.c +++ b/src/core/crypto.c @@ -1368,6 +1368,7 @@ QuicCryptoProcessTlsCompletion( _In_ QUIC_CRYPTO* Crypto ) { + CXPLAT_DBG_ASSERT(!Crypto->TicketValidationPending && !Crypto->CertValidationPending); QUIC_CONNECTION* Connection = QuicCryptoGetConnection(Crypto); if (Crypto->ResultFlags & CXPLAT_TLS_RESULT_ERROR) { @@ -1556,6 +1557,7 @@ QuicCryptoProcessTlsCompletion( if (Crypto->ResultFlags & CXPLAT_TLS_RESULT_HANDSHAKE_COMPLETE) { CXPLAT_DBG_ASSERT(!(Crypto->ResultFlags & CXPLAT_TLS_RESULT_ERROR)); CXPLAT_TEL_ASSERT(!Connection->State.Connected); + CXPLAT_DBG_ASSERT(!Crypto->TicketValidationPending && !Crypto->CertValidationPending); QuicTraceEvent( ConnHandshakeComplete, diff --git a/src/test/MsQuicTests.h b/src/test/MsQuicTests.h index 1a40f6ff39..d0bd406c98 100644 --- a/src/test/MsQuicTests.h +++ b/src/test/MsQuicTests.h @@ -189,7 +189,13 @@ QuicTestFailedVersionNegotiation( #endif // QUIC_API_ENABLE_PREVIEW_FEATURES void -QuicTestCustomCertificateValidation( +QuicTestCustomServerCertificateValidation( + _In_ bool AcceptCert, + _In_ bool AsyncValidation + ); + +void +QuicTestCustomClientCertificateValidation( _In_ bool AcceptCert, _In_ bool AsyncValidation ); @@ -908,7 +914,7 @@ typedef struct { BOOLEAN AsyncValidation; } QUIC_RUN_CUSTOM_CERT_VALIDATION; -#define IOCTL_QUIC_RUN_CUSTOM_CERT_VALIDATION \ +#define IOCTL_QUIC_RUN_CUSTOM_SERVER_CERT_VALIDATION \ QUIC_CTL_CODE(47, METHOD_BUFFERED, FILE_WRITE_DATA) // QUIC_RUN_CUSTOM_CERT_VALIDATION @@ -1166,4 +1172,8 @@ typedef struct { QUIC_CTL_CODE(109, METHOD_BUFFERED, FILE_WRITE_DATA) // int - Family -#define QUIC_MAX_IOCTL_FUNC_CODE 109 +#define IOCTL_QUIC_RUN_CUSTOM_CLIENT_CERT_VALIDATION \ + QUIC_CTL_CODE(110, METHOD_BUFFERED, FILE_WRITE_DATA) + // QUIC_RUN_CUSTOM_CERT_VALIDATION + +#define QUIC_MAX_IOCTL_FUNC_CODE 110 diff --git a/src/test/bin/quic_gtest.cpp b/src/test/bin/quic_gtest.cpp index cc0ae5738e..646ec0d83f 100644 --- a/src/test/bin/quic_gtest.cpp +++ b/src/test/bin/quic_gtest.cpp @@ -966,16 +966,29 @@ TEST_P(WithFamilyArgs, FailedVersionNegotiation) { } #endif // QUIC_API_ENABLE_PREVIEW_FEATURES -TEST_P(WithHandshakeArgs5, CustomCertificateValidation) { - TestLoggerT Logger("QuicTestCustomCertificateValidation", GetParam()); +TEST_P(WithHandshakeArgs5, CustomServerCertificateValidation) { + TestLoggerT Logger("QuicTestCustomServerCertificateValidation", GetParam()); if (TestingKernelMode) { QUIC_RUN_CUSTOM_CERT_VALIDATION Params = { GetParam().AcceptCert, GetParam().AsyncValidation }; - ASSERT_TRUE(DriverClient.Run(IOCTL_QUIC_RUN_CUSTOM_CERT_VALIDATION, Params)); + ASSERT_TRUE(DriverClient.Run(IOCTL_QUIC_RUN_CUSTOM_SERVER_CERT_VALIDATION, Params)); } else { - QuicTestCustomCertificateValidation(GetParam().AcceptCert, GetParam().AsyncValidation); + QuicTestCustomServerCertificateValidation(GetParam().AcceptCert, GetParam().AsyncValidation); + } +} + +TEST_P(WithHandshakeArgs5, CustomClientCertificateValidation) { + TestLoggerT Logger("QuicTestCustomClientCertificateValidation", GetParam()); + if (TestingKernelMode) { + QUIC_RUN_CUSTOM_CERT_VALIDATION Params = { + GetParam().AcceptCert, + GetParam().AsyncValidation + }; + ASSERT_TRUE(DriverClient.Run(IOCTL_QUIC_RUN_CUSTOM_CLIENT_CERT_VALIDATION, Params)); + } else { + QuicTestCustomClientCertificateValidation(GetParam().AcceptCert, GetParam().AsyncValidation); } } diff --git a/src/test/bin/winkernel/control.cpp b/src/test/bin/winkernel/control.cpp index 051d39f6c6..36bcf5cf86 100644 --- a/src/test/bin/winkernel/control.cpp +++ b/src/test/bin/winkernel/control.cpp @@ -481,6 +481,7 @@ size_t QUIC_IOCTL_BUFFER_SIZES[] = sizeof(BOOLEAN), sizeof(INT32), sizeof(INT32), + sizeof(QUIC_RUN_CUSTOM_CERT_VALIDATION), }; CXPLAT_STATIC_ASSERT( @@ -920,10 +921,10 @@ QuicTestCtlEvtIoDeviceControl( QuicTestAckSendDelay(Params->Family)); break; - case IOCTL_QUIC_RUN_CUSTOM_CERT_VALIDATION: + case IOCTL_QUIC_RUN_CUSTOM_SERVER_CERT_VALIDATION: CXPLAT_FRE_ASSERT(Params != nullptr); QuicTestCtlRun( - QuicTestCustomCertificateValidation( + QuicTestCustomServerCertificateValidation( Params->CustomCertValidationParams.AcceptCert, Params->CustomCertValidationParams.AsyncValidation)); break; @@ -1330,6 +1331,14 @@ QuicTestCtlEvtIoDeviceControl( QuicTestCtlRun(QuicTestHandshakeSpecificLossPatterns(Params->Family)); break; + case IOCTL_QUIC_RUN_CUSTOM_CLIENT_CERT_VALIDATION: + CXPLAT_FRE_ASSERT(Params != nullptr); + QuicTestCtlRun( + QuicTestCustomClientCertificateValidation( + Params->CustomCertValidationParams.AcceptCert, + Params->CustomCertValidationParams.AsyncValidation)); + break; + default: Status = STATUS_NOT_IMPLEMENTED; break; diff --git a/src/test/lib/HandshakeTest.cpp b/src/test/lib/HandshakeTest.cpp index 4d6b12eb27..de4ace48db 100644 --- a/src/test/lib/HandshakeTest.cpp +++ b/src/test/lib/HandshakeTest.cpp @@ -111,8 +111,12 @@ ListenerAcceptConnection( ) { ServerAcceptContext* AcceptContext = (ServerAcceptContext*)Listener->Context; - *AcceptContext->NewConnection = new(std::nothrow) TestConnection(ConnectionHandle); + *AcceptContext->NewConnection = new(std::nothrow) TestConnection(ConnectionHandle, (NEW_STREAM_CALLBACK_HANDLER)AcceptContext->NewStreamHandler); (*AcceptContext->NewConnection)->SetExpectedCustomTicketValidationResult(AcceptContext->ExpectedCustomTicketValidationResult); + (*AcceptContext->NewConnection)->SetAsyncCustomValidationResult(AcceptContext->AsyncCustomCertValidation); + if (AcceptContext->IsCustomCertValidationResultSet) { + (*AcceptContext->NewConnection)->SetExpectedCustomValidationResult(AcceptContext->CustomCertValidationResult); + } if (*AcceptContext->NewConnection == nullptr || !(*AcceptContext->NewConnection)->IsValid()) { TEST_FAILURE("Failed to accept new TestConnection."); delete *AcceptContext->NewConnection; @@ -744,7 +748,7 @@ QuicTestConnectAndIdle( } void -QuicTestCustomCertificateValidation( +QuicTestCustomServerCertificateValidation( _In_ bool AcceptCert, _In_ bool AsyncValidation ) @@ -820,6 +824,122 @@ QuicTestCustomCertificateValidation( } } +void +NoOpStreamShutdownCallback( + _In_ TestStream* Stream + ) +{ + UNREFERENCED_PARAMETER(Stream); +} + +void +NewStreamCallbackTestFail( + _In_ TestConnection* Connection, + _In_ HQUIC StreamHandle, + _In_ QUIC_STREAM_OPEN_FLAGS Flags + ) +{ + UNREFERENCED_PARAMETER(Connection); + UNREFERENCED_PARAMETER(Flags); + MsQuic->StreamClose(StreamHandle); + TEST_FAILURE("Unexpected new Stream received"); +} + +void +QuicTestCustomClientCertificateValidation( + _In_ bool AcceptCert, + _In_ bool AsyncValidation + ) +{ + MsQuicRegistration Registration; + TEST_TRUE(Registration.IsValid()); + + MsQuicAlpn Alpn("MsQuicTest"); + + MsQuicSettings Settings; + Settings.SetPeerBidiStreamCount(1); + Settings.SetIdleTimeoutMs(3000); + + MsQuicConfiguration ServerConfiguration(Registration, Alpn, Settings, ServerSelfSignedCredConfigClientAuth); + TEST_TRUE(ServerConfiguration.IsValid()); + + MsQuicConfiguration ClientConfiguration(Registration, Alpn, Settings, ClientCertCredConfig); + TEST_TRUE(ClientConfiguration.IsValid()); + + { + TestListener Listener(Registration, ListenerAcceptConnection, ServerConfiguration); + TEST_TRUE(Listener.IsValid()); + TEST_QUIC_SUCCEEDED(Listener.Start(Alpn)); + + QuicAddr ServerLocalAddr; + TEST_QUIC_SUCCEEDED(Listener.GetLocalAddr(ServerLocalAddr)); + + { + UniquePtr Server; + ServerAcceptContext ServerAcceptCtx((TestConnection**)&Server); + if (!AcceptCert) { + ServerAcceptCtx.ExpectedTransportCloseStatus = QUIC_STATUS_BAD_CERTIFICATE; + ServerAcceptCtx.NewStreamHandler = (void*)NewStreamCallbackTestFail; + } + ServerAcceptCtx.AsyncCustomCertValidation = AsyncValidation; + if (!AsyncValidation) { + ServerAcceptCtx.IsCustomCertValidationResultSet = true; + ServerAcceptCtx.CustomCertValidationResult = AcceptCert; + } + ServerAcceptCtx.AddExpectedClientCertValidationResult(QUIC_STATUS_CERT_UNTRUSTED_ROOT); + Listener.Context = &ServerAcceptCtx; + + { + TestConnection Client(Registration); + TEST_TRUE(Client.IsValid()); + + if (!AcceptCert) { + Client.SetExpectedTransportCloseStatus(QUIC_STATUS_BAD_CERTIFICATE); + } + + UniquePtr ClientStream( + TestStream::FromConnectionHandle( + Client.GetConnection(), + NoOpStreamShutdownCallback, + QUIC_STREAM_OPEN_FLAG_NONE)); + + TEST_QUIC_SUCCEEDED(ClientStream->Start(QUIC_STREAM_START_FLAG_IMMEDIATE)); + + TEST_QUIC_SUCCEEDED( + Client.Start( + ClientConfiguration, + QUIC_ADDRESS_FAMILY_UNSPEC, + QUIC_TEST_LOOPBACK_FOR_AF( + QuicAddrGetFamily(&ServerLocalAddr.SockAddr)), + ServerLocalAddr.GetPort())); + + if (!CxPlatEventWaitWithTimeout(ServerAcceptCtx.NewConnectionReady, TestWaitTimeout)) { + TEST_FAILURE("Timed out waiting for server accept."); + } + + if (AsyncValidation) { + CxPlatSleep(2000); + TEST_QUIC_SUCCEEDED(Server->SetCustomValidationResult(AcceptCert)); + } + + if (!Client.WaitForConnectionComplete()) { + return; + } + + if (AcceptCert) { // Server will be deleted on reject case, so can't validate. + TEST_NOT_EQUAL(nullptr, Server); + if (!Server->WaitForConnectionComplete()) { + return; + } + TEST_TRUE(Server->GetIsConnected()); + } + // In all cases, the client "connects", but in the rejection case, it gets disconnected. + TEST_TRUE(Client.GetIsConnected()); + } + } + } +} + void QuicTestConnectUnreachable( _In_ int Family diff --git a/src/test/lib/TestHelpers.h b/src/test/lib/TestHelpers.h index 195274f2c5..0e2aa36e75 100644 --- a/src/test/lib/TestHelpers.h +++ b/src/test/lib/TestHelpers.h @@ -79,6 +79,7 @@ class TestConnection; struct ServerAcceptContext { CXPLAT_EVENT NewConnectionReady; TestConnection** NewConnection; + void* NewStreamHandler{nullptr}; QUIC_STATUS ExpectedTransportCloseStatus{QUIC_STATUS_SUCCESS}; QUIC_STATUS ExpectedClientCertValidationResult[2]{}; uint32_t ExpectedClientCertValidationResultCount{0}; @@ -86,6 +87,9 @@ struct ServerAcceptContext { QUIC_PRIVATE_TRANSPORT_PARAMETER* TestTP{nullptr}; bool AsyncCustomTicketValidation{false}; QUIC_STATUS ExpectedCustomTicketValidationResult{QUIC_STATUS_SUCCESS}; + bool AsyncCustomCertValidation{false}; + bool IsCustomCertValidationResultSet{false}; + bool CustomCertValidationResult{false}; ServerAcceptContext(TestConnection** _NewConnection) : NewConnection(_NewConnection) { CxPlatEventInitialize(&NewConnectionReady, TRUE, FALSE);