From 21377c19d6755aca52be53bc0242596a9e599dcf Mon Sep 17 00:00:00 2001 From: ouyangyadong Date: Mon, 9 Dec 2019 15:31:49 +0800 Subject: [PATCH] quic: ensure callbacks of QuicSocket.connect() get called 1. The callbacks of QuicSocket.connect() won't get called after QuicSocket binding. To fix this issue, This PR calls QuicSession[kReady]() directly when QuicSocket bound. 2. This PR also modify SocketAddress::Hash and SocketAddress::Compare to accept struct values instead of pointers because SocketAddress might get freed firstly which would cause the values aren't safe to use. For example, the test added in this PR is likely to abort before this PR. --- lib/internal/quic/core.js | 3 + src/node_crypto.h | 2 + src/node_quic_session.cc | 2 +- src/node_quic_socket.cc | 32 +++++---- src/node_quic_socket.h | 10 +-- src/node_sockaddr-inl.h | 23 ++++-- src/node_sockaddr.h | 7 +- .../test-quic-client-connect-callback.js | 72 +++++++++++++++++++ 8 files changed, 124 insertions(+), 27 deletions(-) create mode 100644 test/parallel/test-quic-client-connect-callback.js diff --git a/lib/internal/quic/core.js b/lib/internal/quic/core.js index cb9a0c6eb9..5dfbf810f4 100644 --- a/lib/internal/quic/core.js +++ b/lib/internal/quic/core.js @@ -1063,6 +1063,9 @@ class QuicSocket extends EventEmitter { if (typeof callback === 'function') session.on('ready', callback); + if (this.bound) + session[kReady](); + this[kMaybeBind](connectAfterBind.bind( this, session, diff --git a/src/node_crypto.h b/src/node_crypto.h index 40a234bd4a..f9df6a26ef 100644 --- a/src/node_crypto.h +++ b/src/node_crypto.h @@ -29,6 +29,8 @@ #include "env.h" #include "base_object.h" +// TODO do not included +#include "base_object-inl.h" #include "util.h" #include "v8.h" diff --git a/src/node_quic_session.cc b/src/node_quic_session.cc index 8fe1451a67..d8bdd89fe9 100644 --- a/src/node_quic_session.cc +++ b/src/node_quic_session.cc @@ -2050,7 +2050,7 @@ void QuicSession::RemoveFromSocket() { socket_->DisassociateCID(QuicCID(&cid)); Debug(this, "Removed from the QuicSocket."); - socket_->RemoveSession(QuicCID(scid_), **GetRemoteAddress()); + socket_->RemoveSession(QuicCID(scid_), GetRemoteAddress()->GetSockaddrStorage()); socket_.reset(); } diff --git a/src/node_quic_socket.cc b/src/node_quic_socket.cc index c6de5f33f6..8773eb13a6 100644 --- a/src/node_quic_socket.cc +++ b/src/node_quic_socket.cc @@ -247,7 +247,7 @@ void QuicSocket::AddSession( const QuicCID& cid, BaseObjectPtr session) { sessions_[cid.ToStr()] = session; - IncrementSocketAddressCounter(**session->GetRemoteAddress()); + IncrementSocketAddressCounter(session->GetRemoteAddress()->GetSockaddrStorage()); IncrementSocketStat( 1, &socket_stats_, session->IsServer() ? @@ -485,7 +485,7 @@ int QuicSocket::ReceiveStop() { return udp_->RecvStop(); } -void QuicSocket::RemoveSession(const QuicCID& cid, const sockaddr* addr) { +void QuicSocket::RemoveSession(const QuicCID& cid, const sockaddr_storage* addr) { sessions_.erase(cid.ToStr()); DecrementSocketAddressCounter(addr); } @@ -659,7 +659,7 @@ namespace { void QuicSocket::SetValidatedAddress(const sockaddr* addr) { if (IsOptionSet(QUICSOCKET_OPTIONS_VALIDATE_ADDRESS_LRU)) { // Remove the oldest item if we've hit the LRU limit - validated_addrs_.push_back(addr_hash(addr)); + validated_addrs_.push_back(addr_hash(*addr)); if (validated_addrs_.size() > MAX_VALIDATE_ADDRESS_LRU) validated_addrs_.pop_front(); } @@ -669,7 +669,7 @@ bool QuicSocket::IsValidatedAddress(const sockaddr* addr) const { if (IsOptionSet(QUICSOCKET_OPTIONS_VALIDATE_ADDRESS_LRU)) { auto res = std::find(std::begin(validated_addrs_), std::end(validated_addrs_), - addr_hash(addr)); + addr_hash(*addr)); return res != std::end(validated_addrs_); } return false; @@ -721,9 +721,13 @@ BaseObjectPtr QuicSocket::AcceptInitialPacket( // Check to see if the number of connections for this peer has been exceeded. // If the count has been exceeded, shutdown the connection immediately // after the initial keys are installed. - if (GetCurrentSocketAddressCounter(addr) >= max_connections_per_host_) { - Debug(this, "Connection count for address exceeded"); - initial_connection_close = NGTCP2_SERVER_BUSY; + { + sockaddr_storage storage; + memcpy(&storage, addr, SocketAddress::GetLength(addr)); + if (GetCurrentSocketAddressCounter(&storage) >= max_connections_per_host_) { + Debug(this, "Connection count for address exceeded"); + initial_connection_close = NGTCP2_SERVER_BUSY; + } } // QUIC has address validation built in to the handshake but allows for @@ -782,22 +786,22 @@ BaseObjectPtr QuicSocket::AcceptInitialPacket( return session; } -void QuicSocket::IncrementSocketAddressCounter(const sockaddr* addr) { - addr_counts_[addr]++; +void QuicSocket::IncrementSocketAddressCounter(const sockaddr_storage* addr) { + addr_counts_[*addr]++; } -void QuicSocket::DecrementSocketAddressCounter(const sockaddr* addr) { - auto it = addr_counts_.find(addr); +void QuicSocket::DecrementSocketAddressCounter(const sockaddr_storage* addr) { + auto it = addr_counts_.find(*addr); if (it == std::end(addr_counts_)) return; it->second--; // Remove the address if the counter reaches zero again. if (it->second == 0) - addr_counts_.erase(addr); + addr_counts_.erase(*addr); } -size_t QuicSocket::GetCurrentSocketAddressCounter(const sockaddr* addr) { - auto it = addr_counts_.find(addr); +size_t QuicSocket::GetCurrentSocketAddressCounter(const sockaddr_storage* addr) { + auto it = addr_counts_.find(*addr); if (it == std::end(addr_counts_)) return 0; return it->second; diff --git a/src/node_quic_socket.h b/src/node_quic_socket.h index 4550490101..cd24c36d4d 100644 --- a/src/node_quic_socket.h +++ b/src/node_quic_socket.h @@ -133,7 +133,7 @@ class QuicSocket : public AsyncWrap, int ReceiveStop(); void RemoveSession( const QuicCID& cid, - const sockaddr* addr); + const sockaddr_storage* addr); void ReportSendError( int error); int SendPacket( @@ -233,9 +233,9 @@ class QuicSocket : public AsyncWrap, const struct sockaddr* addr, unsigned int flags); - void IncrementSocketAddressCounter(const sockaddr* addr); - void DecrementSocketAddressCounter(const sockaddr* addr); - size_t GetCurrentSocketAddressCounter(const sockaddr* addr); + void IncrementSocketAddressCounter(const sockaddr_storage* addr); + void DecrementSocketAddressCounter(const sockaddr_storage* addr); + size_t GetCurrentSocketAddressCounter(const sockaddr_storage* addr); void IncrementPendingCallbacks() { pending_callbacks_++; } void DecrementPendingCallbacks() { pending_callbacks_--; } @@ -315,7 +315,7 @@ class QuicSocket : public AsyncWrap, // value reaches the value of max_connections_per_host_, // attempts to create new connections will be ignored // until the value falls back below the limit. - std::unordered_map addr_counts_; // The validated_addrs_ vector is used as an LRU cache for diff --git a/src/node_sockaddr-inl.h b/src/node_sockaddr-inl.h index 24dd9a9e9a..e680c0cf49 100644 --- a/src/node_sockaddr-inl.h +++ b/src/node_sockaddr-inl.h @@ -25,7 +25,7 @@ inline void hash_combine(size_t* seed, const T& value, Args... rest) { } } // namespace -size_t SocketAddress::Hash::operator()(const sockaddr* addr) const { +static size_t GetHash(const sockaddr* addr) { size_t hash = 0; switch (addr->sa_family) { case AF_INET: { @@ -48,11 +48,20 @@ size_t SocketAddress::Hash::operator()(const sockaddr* addr) const { return hash; } +size_t SocketAddress::Hash::operator()(const sockaddr& addr) const { + return GetHash(&addr); +} + +size_t SocketAddress::Hash::operator()(const sockaddr_storage& addr_storage) const { + const sockaddr* addr = reinterpret_cast(&addr_storage); + return GetHash(addr); +} + bool SocketAddress::Compare::operator()( - const sockaddr* laddr, - const sockaddr* raddr) const { - CHECK(laddr->sa_family == AF_INET || laddr->sa_family == AF_INET6); - return memcmp(laddr, raddr, GetLength(laddr)) == 0; + const sockaddr_storage& laddr, + const sockaddr_storage& raddr) const { + CHECK(laddr.ss_family == AF_INET || laddr.ss_family == AF_INET6); + return memcmp(&laddr, &raddr, GetLength(&laddr)) == 0; } bool SocketAddress::is_numeric_host(const char* hostname) { @@ -146,6 +155,10 @@ const sockaddr* SocketAddress::operator*() const { return reinterpret_cast(&address_); } +const sockaddr_storage* SocketAddress::GetSockaddrStorage() const { + return &address_; +} + size_t SocketAddress::GetLength() const { return GetLength(&address_); } diff --git a/src/node_sockaddr.h b/src/node_sockaddr.h index 1ac5cbbdd4..86e13f5eb0 100644 --- a/src/node_sockaddr.h +++ b/src/node_sockaddr.h @@ -15,11 +15,12 @@ namespace node { class SocketAddress { public: struct Hash { - inline size_t operator()(const sockaddr* addr) const; + inline size_t operator()(const sockaddr& addr) const; + inline size_t operator()(const sockaddr_storage& addr_storage) const; }; struct Compare { - inline bool operator()(const sockaddr* laddr, const sockaddr* raddr) const; + inline bool operator()(const sockaddr_storage& laddr, const sockaddr_storage& raddr) const; }; inline static bool is_numeric_host(const char* hostname); @@ -56,6 +57,8 @@ class SocketAddress { inline const sockaddr* operator*() const; + inline const sockaddr_storage* GetSockaddrStorage() const; + inline size_t GetLength() const; inline int GetFamily() const; diff --git a/test/parallel/test-quic-client-connect-callback.js b/test/parallel/test-quic-client-connect-callback.js new file mode 100644 index 0000000000..ed2f7e877c --- /dev/null +++ b/test/parallel/test-quic-client-connect-callback.js @@ -0,0 +1,72 @@ +'use strict'; + +const common = require('../common'); +if (!common.hasQuic) + common.skip('missing quic'); + +const { createSocket } = require('quic'); +const fixtures = require('../common/fixtures'); +const Countdown = require('../common/countdown'); +const key = fixtures.readKey('agent1-key.pem', 'binary'); +const cert = fixtures.readKey('agent1-cert.pem', 'binary'); +const ca = fixtures.readKey('ca1-cert.pem', 'binary'); + +const kServerName = 'agent2'; +const kALPN = 'zzz'; +const kIdleTimeout = 0; +const kConnections = 5; + +// After QuicSocket bound, the callback of QuicSocket.connect() +// should still get called. +{ + let client; + const server = createSocket({ + port: 0, + }); + + server.listen({ + key, + cert, + ca, + alpn: kALPN, + idleTimeout: kIdleTimeout, + }); + + const countdown = new Countdown(kConnections, () => { + client.close(); + server.close(); + }); + + server.on('ready', common.mustCall(() => { + const options = { + key, + cert, + ca, + address: common.localhostIPv4, + port: server.address.port, + servername: kServerName, + alpn: kALPN, + idleTimeout: kIdleTimeout, + }; + + client = createSocket({ + port: 0, + }); + + const session = client.connect(options, common.mustCall(() => { + session.close(common.mustCall(() => { + // After a session being ready, the socket should have bound + // and we could start the test. + testConnections(); + })); + })); + + const testConnections = common.mustCall(() => { + for (let i = 0; i < kConnections; i += 1) { + client.connect(options, common.mustCall(() => { + countdown.dec(); + })); + } + }); + })); +}