Skip to content
This repository has been archived by the owner on Aug 11, 2020. It is now read-only.

Commit

Permalink
quic: ensure callbacks of QuicSocket.connect() get called
Browse files Browse the repository at this point in the history
1. The callbacks of QuicSocket.connect() won't get called after
QuicSocket binding. To fix this issue, This PR calls
QuicSession[kReady]() directly when QuicSocket bound.

2. This PR also modify SocketAddress::Hash and SocketAddress::Compare
to accept struct values instead of pointers because SocketAddress might
get freed firstly which would cause the values aren't safe to use.
For example, the test added in this PR is likely to abort before this PR.
  • Loading branch information
sigma-zer0 committed Dec 9, 2019
1 parent 1c316ac commit 21377c1
Show file tree
Hide file tree
Showing 8 changed files with 124 additions and 27 deletions.
3 changes: 3 additions & 0 deletions lib/internal/quic/core.js
Original file line number Diff line number Diff line change
Expand Up @@ -1063,6 +1063,9 @@ class QuicSocket extends EventEmitter {
if (typeof callback === 'function')
session.on('ready', callback);

if (this.bound)
session[kReady]();

this[kMaybeBind](connectAfterBind.bind(
this,
session,
Expand Down
2 changes: 2 additions & 0 deletions src/node_crypto.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@

#include "env.h"
#include "base_object.h"
// TODO do not included
#include "base_object-inl.h"
#include "util.h"

#include "v8.h"
Expand Down
2 changes: 1 addition & 1 deletion src/node_quic_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2050,7 +2050,7 @@ void QuicSession::RemoveFromSocket() {
socket_->DisassociateCID(QuicCID(&cid));

Debug(this, "Removed from the QuicSocket.");
socket_->RemoveSession(QuicCID(scid_), **GetRemoteAddress());
socket_->RemoveSession(QuicCID(scid_), GetRemoteAddress()->GetSockaddrStorage());
socket_.reset();
}

Expand Down
32 changes: 18 additions & 14 deletions src/node_quic_socket.cc
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ void QuicSocket::AddSession(
const QuicCID& cid,
BaseObjectPtr<QuicSession> session) {
sessions_[cid.ToStr()] = session;
IncrementSocketAddressCounter(**session->GetRemoteAddress());
IncrementSocketAddressCounter(session->GetRemoteAddress()->GetSockaddrStorage());
IncrementSocketStat(
1, &socket_stats_,
session->IsServer() ?
Expand Down Expand Up @@ -485,7 +485,7 @@ int QuicSocket::ReceiveStop() {
return udp_->RecvStop();
}

void QuicSocket::RemoveSession(const QuicCID& cid, const sockaddr* addr) {
void QuicSocket::RemoveSession(const QuicCID& cid, const sockaddr_storage* addr) {
sessions_.erase(cid.ToStr());
DecrementSocketAddressCounter(addr);
}
Expand Down Expand Up @@ -659,7 +659,7 @@ namespace {
void QuicSocket::SetValidatedAddress(const sockaddr* addr) {
if (IsOptionSet(QUICSOCKET_OPTIONS_VALIDATE_ADDRESS_LRU)) {
// Remove the oldest item if we've hit the LRU limit
validated_addrs_.push_back(addr_hash(addr));
validated_addrs_.push_back(addr_hash(*addr));
if (validated_addrs_.size() > MAX_VALIDATE_ADDRESS_LRU)
validated_addrs_.pop_front();
}
Expand All @@ -669,7 +669,7 @@ bool QuicSocket::IsValidatedAddress(const sockaddr* addr) const {
if (IsOptionSet(QUICSOCKET_OPTIONS_VALIDATE_ADDRESS_LRU)) {
auto res = std::find(std::begin(validated_addrs_),
std::end(validated_addrs_),
addr_hash(addr));
addr_hash(*addr));
return res != std::end(validated_addrs_);
}
return false;
Expand Down Expand Up @@ -721,9 +721,13 @@ BaseObjectPtr<QuicSession> QuicSocket::AcceptInitialPacket(
// Check to see if the number of connections for this peer has been exceeded.
// If the count has been exceeded, shutdown the connection immediately
// after the initial keys are installed.
if (GetCurrentSocketAddressCounter(addr) >= max_connections_per_host_) {
Debug(this, "Connection count for address exceeded");
initial_connection_close = NGTCP2_SERVER_BUSY;
{
sockaddr_storage storage;
memcpy(&storage, addr, SocketAddress::GetLength(addr));
if (GetCurrentSocketAddressCounter(&storage) >= max_connections_per_host_) {
Debug(this, "Connection count for address exceeded");
initial_connection_close = NGTCP2_SERVER_BUSY;
}
}

// QUIC has address validation built in to the handshake but allows for
Expand Down Expand Up @@ -782,22 +786,22 @@ BaseObjectPtr<QuicSession> QuicSocket::AcceptInitialPacket(
return session;
}

void QuicSocket::IncrementSocketAddressCounter(const sockaddr* addr) {
addr_counts_[addr]++;
void QuicSocket::IncrementSocketAddressCounter(const sockaddr_storage* addr) {
addr_counts_[*addr]++;
}

void QuicSocket::DecrementSocketAddressCounter(const sockaddr* addr) {
auto it = addr_counts_.find(addr);
void QuicSocket::DecrementSocketAddressCounter(const sockaddr_storage* addr) {
auto it = addr_counts_.find(*addr);
if (it == std::end(addr_counts_))
return;
it->second--;
// Remove the address if the counter reaches zero again.
if (it->second == 0)
addr_counts_.erase(addr);
addr_counts_.erase(*addr);
}

size_t QuicSocket::GetCurrentSocketAddressCounter(const sockaddr* addr) {
auto it = addr_counts_.find(addr);
size_t QuicSocket::GetCurrentSocketAddressCounter(const sockaddr_storage* addr) {
auto it = addr_counts_.find(*addr);
if (it == std::end(addr_counts_))
return 0;
return it->second;
Expand Down
10 changes: 5 additions & 5 deletions src/node_quic_socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ class QuicSocket : public AsyncWrap,
int ReceiveStop();
void RemoveSession(
const QuicCID& cid,
const sockaddr* addr);
const sockaddr_storage* addr);
void ReportSendError(
int error);
int SendPacket(
Expand Down Expand Up @@ -233,9 +233,9 @@ class QuicSocket : public AsyncWrap,
const struct sockaddr* addr,
unsigned int flags);

void IncrementSocketAddressCounter(const sockaddr* addr);
void DecrementSocketAddressCounter(const sockaddr* addr);
size_t GetCurrentSocketAddressCounter(const sockaddr* addr);
void IncrementSocketAddressCounter(const sockaddr_storage* addr);
void DecrementSocketAddressCounter(const sockaddr_storage* addr);
size_t GetCurrentSocketAddressCounter(const sockaddr_storage* addr);

void IncrementPendingCallbacks() { pending_callbacks_++; }
void DecrementPendingCallbacks() { pending_callbacks_--; }
Expand Down Expand Up @@ -315,7 +315,7 @@ class QuicSocket : public AsyncWrap,
// value reaches the value of max_connections_per_host_,
// attempts to create new connections will be ignored
// until the value falls back below the limit.
std::unordered_map<const sockaddr*, size_t, SocketAddress::Hash,
std::unordered_map<const sockaddr_storage, size_t, SocketAddress::Hash,
SocketAddress::Compare> addr_counts_;

// The validated_addrs_ vector is used as an LRU cache for
Expand Down
23 changes: 18 additions & 5 deletions src/node_sockaddr-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ inline void hash_combine(size_t* seed, const T& value, Args... rest) {
}
} // namespace

size_t SocketAddress::Hash::operator()(const sockaddr* addr) const {
static size_t GetHash(const sockaddr* addr) {
size_t hash = 0;
switch (addr->sa_family) {
case AF_INET: {
Expand All @@ -48,11 +48,20 @@ size_t SocketAddress::Hash::operator()(const sockaddr* addr) const {
return hash;
}

size_t SocketAddress::Hash::operator()(const sockaddr& addr) const {
return GetHash(&addr);
}

size_t SocketAddress::Hash::operator()(const sockaddr_storage& addr_storage) const {
const sockaddr* addr = reinterpret_cast<const sockaddr*>(&addr_storage);
return GetHash(addr);
}

bool SocketAddress::Compare::operator()(
const sockaddr* laddr,
const sockaddr* raddr) const {
CHECK(laddr->sa_family == AF_INET || laddr->sa_family == AF_INET6);
return memcmp(laddr, raddr, GetLength(laddr)) == 0;
const sockaddr_storage& laddr,
const sockaddr_storage& raddr) const {
CHECK(laddr.ss_family == AF_INET || laddr.ss_family == AF_INET6);
return memcmp(&laddr, &raddr, GetLength(&laddr)) == 0;
}

bool SocketAddress::is_numeric_host(const char* hostname) {
Expand Down Expand Up @@ -146,6 +155,10 @@ const sockaddr* SocketAddress::operator*() const {
return reinterpret_cast<const sockaddr*>(&address_);
}

const sockaddr_storage* SocketAddress::GetSockaddrStorage() const {
return &address_;
}

size_t SocketAddress::GetLength() const {
return GetLength(&address_);
}
Expand Down
7 changes: 5 additions & 2 deletions src/node_sockaddr.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@ namespace node {
class SocketAddress {
public:
struct Hash {
inline size_t operator()(const sockaddr* addr) const;
inline size_t operator()(const sockaddr& addr) const;
inline size_t operator()(const sockaddr_storage& addr_storage) const;
};

struct Compare {
inline bool operator()(const sockaddr* laddr, const sockaddr* raddr) const;
inline bool operator()(const sockaddr_storage& laddr, const sockaddr_storage& raddr) const;
};

inline static bool is_numeric_host(const char* hostname);
Expand Down Expand Up @@ -56,6 +57,8 @@ class SocketAddress {

inline const sockaddr* operator*() const;

inline const sockaddr_storage* GetSockaddrStorage() const;

inline size_t GetLength() const;

inline int GetFamily() const;
Expand Down
72 changes: 72 additions & 0 deletions test/parallel/test-quic-client-connect-callback.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
'use strict';

const common = require('../common');
if (!common.hasQuic)
common.skip('missing quic');

const { createSocket } = require('quic');
const fixtures = require('../common/fixtures');
const Countdown = require('../common/countdown');
const key = fixtures.readKey('agent1-key.pem', 'binary');
const cert = fixtures.readKey('agent1-cert.pem', 'binary');
const ca = fixtures.readKey('ca1-cert.pem', 'binary');

const kServerName = 'agent2';
const kALPN = 'zzz';
const kIdleTimeout = 0;
const kConnections = 5;

// After QuicSocket bound, the callback of QuicSocket.connect()
// should still get called.
{
let client;
const server = createSocket({
port: 0,
});

server.listen({
key,
cert,
ca,
alpn: kALPN,
idleTimeout: kIdleTimeout,
});

const countdown = new Countdown(kConnections, () => {
client.close();
server.close();
});

server.on('ready', common.mustCall(() => {
const options = {
key,
cert,
ca,
address: common.localhostIPv4,
port: server.address.port,
servername: kServerName,
alpn: kALPN,
idleTimeout: kIdleTimeout,
};

client = createSocket({
port: 0,
});

const session = client.connect(options, common.mustCall(() => {
session.close(common.mustCall(() => {
// After a session being ready, the socket should have bound
// and we could start the test.
testConnections();
}));
}));

const testConnections = common.mustCall(() => {
for (let i = 0; i < kConnections; i += 1) {
client.connect(options, common.mustCall(() => {
countdown.dec();
}));
}
});
}));
}

0 comments on commit 21377c1

Please sign in to comment.