Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes for CPP-928, and trusted certs #493

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions include/cassandra.h
Original file line number Diff line number Diff line change
Expand Up @@ -4602,6 +4602,12 @@ cass_ssl_add_trusted_cert_n(CassSsl* ssl,
* common name or one of its subject alternative names. This implies the
* certificate is also present. Hostname resolution must also be enabled.
*
* Notes:
* - CASS_SSL_VERIFY_PEER_IDENTITY and CASS_SSL_VERIFY_PEER_IDENTITY_DNS are
* mutually exclusive options.
* - The certificate Common Name is only checked against the IP address or
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The certificate Common Name is only checked against the IP address or hostname if there are no Subject Alternative Names in the certificate.

Maybe this is a bug. If it's present in a SAN or the CN the it should probably be valid? I'd need to dig into that more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From ssl_openssl_impl.cpp

  static Result match(X509* cert, const Address& address) {
    Result result = match_subject_alt_names_ipadd(cert, address);
    if (result == NO_SAN_PRESENT) {
      result = match_common_name_ipaddr(cert, address.hostname_or_address());
    }
    return result;
  }

  static Result match_dns(X509* cert, const String& hostname) {
    Result result = match_subject_alt_names_dns(cert, hostname);
    if (result == NO_SAN_PRESENT) {
      result = match_common_name_dns(cert, hostname);
    }
    return result;
  }

Those functions only return NO_SAN_PRESENT if X509_get_ext_d2i returns NULL, otherwise it then checks the SAN stack and returns either NO_MATCH, INVALID_CERT or MATCH.

* hostname if there are no Subject Alternative Names in the certificate.
*
* <b>Default:</b> CASS_SSL_VERIFY_PEER_CERT
*
* @public @memberof CassSsl
Expand Down
5 changes: 3 additions & 2 deletions src/address.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,9 @@ Address::Address(const uint8_t* address, uint8_t address_length, int port)
}
}

Address::Address(const struct sockaddr* addr)
: family_(UNRESOLVED)
Address::Address(const struct sockaddr* addr, const String& server_name)
: server_name_(server_name)
, family_(UNRESOLVED)
, port_(0) {
if (addr->sa_family == AF_INET) {
const struct sockaddr_in* addr_in = reinterpret_cast<const struct sockaddr_in*>(addr);
Expand Down
2 changes: 1 addition & 1 deletion src/address.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class Address : public Allocated {
Address(const Address& other, const String& server_name);
Address(const String& hostname_or_address, int port, const String& server_name = String());
Address(const uint8_t* address, uint8_t address_length, int port);
Address(const struct sockaddr* addr);
Address(const struct sockaddr* addr, const String& server_name);

bool equals(const Address& other, bool with_port = true) const;

Expand Down
5 changes: 3 additions & 2 deletions src/client_insights.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -635,7 +635,7 @@ class StartupMessageHandler : public RefCounted<StartupMessageHandler> {
new MultiResolver(bind_callback(&StartupMessageHandler::on_resolve, this)));
}
resolver->resolve(connection_->loop(), contact_point.hostname_or_address(), port,
config_.resolve_timeout_ms());
config_.resolve_timeout_ms(), contact_point.server_name());
}
}

Expand Down Expand Up @@ -668,7 +668,8 @@ class StartupMessageHandler : public RefCounted<StartupMessageHandler> {
Address::SocketStorage name;
int namelen = sizeof(name);
if (uv_tcp_getsockname(tcp, name.addr(), &namelen) == 0) {
Address address(name.addr());
// Pass a blank server name as this is a temporary address.
Address address(name.addr(), String());
if (address.is_valid_and_resolved()) {
return address.to_string();
}
Expand Down
3 changes: 2 additions & 1 deletion src/cluster_config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ CassError cass_cluster_set_contact_points_n(CassCluster* cluster, const char* co
explode(String(contact_points, contact_points_length), exploded);
for (Vector<String>::const_iterator it = exploded.begin(), end = exploded.end(); it != end;
++it) {
cluster->config().contact_points().push_back(Address(*it, -1));
// Treat the address string as the server name.
cluster->config().contact_points().push_back(Address(*it, -1, *it));
}
}
return CASS_OK;
Expand Down
4 changes: 2 additions & 2 deletions src/cluster_metadata_resolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,13 @@ class DefaultClusterMetadataResolver : public ClusterMetadataResolver {
int port = it->port() <= 0 ? port_ : it->port();

if (it->is_resolved()) {
resolved_contact_points_.push_back(Address(it->hostname_or_address(), port));
resolved_contact_points_.push_back(Address(it->hostname_or_address(), port, it->server_name()));
} else {
if (!resolver_) {
resolver_.reset(
new MultiResolver(bind_callback(&DefaultClusterMetadataResolver::on_resolve, this)));
}
resolver_->resolve(loop, it->hostname_or_address(), port, resolve_timeout_ms_);
resolver_->resolve(loop, it->hostname_or_address(), port, resolve_timeout_ms_, it->server_name());
}
}

Expand Down
11 changes: 7 additions & 4 deletions src/resolver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,9 @@ class Resolver : public RefCounted<Resolver> {
SUCCESS
};

Resolver(const String& hostname, int port, const Callback& callback)
Resolver(const String& hostname, int port, const Callback& callback, const String& server_name)
: hostname_(hostname)
, server_name_(server_name)
, port_(port)
, status_(NEW)
, callback_(callback) {
Expand Down Expand Up @@ -139,7 +140,7 @@ class Resolver : public RefCounted<Resolver> {
bool init_addresses(struct addrinfo* res) {
bool status = false;
do {
Address address(res->ai_addr);
Address address(res->ai_addr, server_name_);
if (address.is_valid_and_resolved()) {
addresses_.push_back(address);
status = true;
Expand All @@ -153,6 +154,7 @@ class Resolver : public RefCounted<Resolver> {
uv_getaddrinfo_t req_;
Timer timer_;
String hostname_;
String server_name_;
int port_;
Status status_;
int uv_status_;
Expand All @@ -175,10 +177,11 @@ class MultiResolver : public RefCounted<MultiResolver> {
const Resolver::Vec& resolvers() { return resolvers_; }

void resolve(uv_loop_t* loop, const String& host, int port, uint64_t timeout,
struct addrinfo* hints = NULL) {
const String& server_name, struct addrinfo* hints = NULL) {
inc_ref();
Resolver::Ptr resolver(
new Resolver(host, port, bind_callback(&MultiResolver::on_resolve, this)));
new Resolver(host, port, bind_callback(&MultiResolver::on_resolve, this),
server_name));
resolver->resolve(loop, timeout, hints);
resolvers_.push_back(resolver);
remaining_++;
Expand Down
4 changes: 2 additions & 2 deletions src/socket_connector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,11 @@ void SocketConnector::connect(uv_loop_t* loop) {
hostname_ = address_.hostname_or_address();

resolver_.reset(new Resolver(hostname_, address_.port(),
bind_callback(&SocketConnector::on_resolve, this)));
bind_callback(&SocketConnector::on_resolve, this),
address_.server_name()));
resolver_->resolve(loop, settings_.resolve_timeout_ms);
} else {
resolved_address_ = address_;

if (settings_.hostname_resolution_enabled) { // Run hostname resolution then connect.
name_resolver_.reset(
new NameResolver(address_, bind_callback(&SocketConnector::on_name_resolve, this)));
Expand Down
45 changes: 23 additions & 22 deletions src/ssl/ssl_openssl_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,22 +228,6 @@ static int SSL_CTX_use_certificate_chain_bio(SSL_CTX* ctx, BIO* in) {
return ret;
}

static X509* load_cert(const char* cert, size_t cert_size) {
BIO* bio = BIO_new_mem_buf(const_cast<char*>(cert), cert_size);
if (bio == NULL) {
return NULL;
}

X509* x509 = PEM_read_bio_X509(bio, NULL, pem_password_callback, NULL);
if (x509 == NULL) {
ssl_log_errors("Unable to load certificate");
}

BIO_free_all(bio);

return x509;
}

static EVP_PKEY* load_key(const char* key, size_t key_size, const char* password) {
BIO* bio = BIO_new_mem_buf(const_cast<char*>(key), key_size);
if (bio == NULL) {
Expand Down Expand Up @@ -489,8 +473,8 @@ void OpenSslSession::verify() {
return;
}
} else if (verify_flags_ &
CASS_SSL_VERIFY_PEER_IDENTITY_DNS) { // Match using hostnames (including wildcards)
switch (OpenSslVerifyIdentity::match_dns(peer_cert, hostname_)) {
CASS_SSL_VERIFY_PEER_IDENTITY_DNS) { // Match using the server name (including wildcards)
switch (OpenSslVerifyIdentity::match_dns(peer_cert, sni_server_name_)) {
case OpenSslVerifyIdentity::MATCH:
// Success
break;
Expand Down Expand Up @@ -556,13 +540,30 @@ SslSession* OpenSslContext::create_session(const Address& address, const String&
}

CassError OpenSslContext::add_trusted_cert(const char* cert, size_t cert_length) {
X509* x509 = load_cert(cert, cert_length);
if (x509 == NULL) {
BIO* bio = BIO_new_mem_buf(const_cast<char*>(cert), cert_length);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 on loading the whole cert chain. Awesome!

if (bio == NULL) {
return CASS_ERROR_SSL_INVALID_CERT;
}

X509_STORE_add_cert(trusted_store_, x509);
X509_free(x509);
int num_certs = 0;

// Iterate over the bio, reading out as many certificates as possible.
for (X509* cert = PEM_read_bio_X509(bio, NULL, pem_password_callback, NULL);
cert != NULL;
cert = PEM_read_bio_X509(bio, NULL, pem_password_callback, NULL))
{
X509_STORE_add_cert(trusted_store_, cert);
X509_free(cert);
num_certs++;
}

BIO_free_all(bio);

// If no certificates were read from the bio, that is an error.
if (num_certs == 0) {
ssl_log_errors("Unable to load certificate(s)");
return CASS_ERROR_SSL_INVALID_CERT;
}

return CASS_OK;
}
Expand Down
6 changes: 4 additions & 2 deletions tests/src/unit/tests/test_address.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@
#include <gtest/gtest.h>

#include "address.hpp"
#include "string.hpp"

using datastax::internal::core::Address;
using datastax::internal::core::AddressSet;
using datastax::String;

TEST(AddressUnitTest, FromString) {
EXPECT_TRUE(Address("127.0.0.1", 9042).is_resolved());
Expand Down Expand Up @@ -64,14 +66,14 @@ TEST(AddressUnitTest, CompareIPv6) {
TEST(AddressUnitTest, ToSockAddrIPv4) {
Address expected("127.0.0.1", 9042);
Address::SocketStorage storage;
Address actual(expected.to_sockaddr(&storage));
Address actual(expected.to_sockaddr(&storage), String());
EXPECT_EQ(expected, actual);
}

TEST(AddressUnitTest, ToSockAddrIPv6) {
Address expected("::1", 9042);
Address::SocketStorage storage;
Address actual(expected.to_sockaddr(&storage));
Address actual(expected.to_sockaddr(&storage), String());
EXPECT_EQ(expected, actual);
}

Expand Down
27 changes: 14 additions & 13 deletions tests/src/unit/tests/test_resolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ class ResolverUnitTest : public LoopTest {
: status_(Resolver::NEW) {}

Resolver::Ptr create(const String& hostname, int port = 9042) {
// Use the hostname as the TLS server name.
return Resolver::Ptr(
new Resolver(hostname, port, bind_callback(&ResolverUnitTest::on_resolve, this)));
new Resolver(hostname, port, bind_callback(&ResolverUnitTest::on_resolve, this), hostname));
}

MultiResolver::Ptr create_multi() {
Expand Down Expand Up @@ -108,9 +109,9 @@ TEST_F(ResolverUnitTest, Cancel) {

TEST_F(ResolverUnitTest, Multi) {
MultiResolver::Ptr resolver(create_multi());
resolver->resolve(loop(), "localhost", 9042, RESOLVE_TIMEOUT);
resolver->resolve(loop(), "localhost", 9042, RESOLVE_TIMEOUT);
resolver->resolve(loop(), "localhost", 9042, RESOLVE_TIMEOUT);
resolver->resolve(loop(), "localhost", 9042, RESOLVE_TIMEOUT, "localhost");
resolver->resolve(loop(), "localhost", 9042, RESOLVE_TIMEOUT, "localhost");
resolver->resolve(loop(), "localhost", 9042, RESOLVE_TIMEOUT, "localhost");
run_loop();
ASSERT_EQ(3u, resolvers().size());
for (Resolver::Vec::const_iterator it = resolvers().begin(), end = resolvers().end(); end != it;
Expand All @@ -130,9 +131,9 @@ TEST_F(ResolverUnitTest, MultiTimeout) {
starve_thread_pool(200);

// Use shortest possible timeout for all requests
resolver->resolve(loop(), "localhost", 9042, 1);
resolver->resolve(loop(), "localhost", 9042, 1);
resolver->resolve(loop(), "localhost", 9042, 1);
resolver->resolve(loop(), "localhost", 9042, 1, "localhost");
resolver->resolve(loop(), "localhost", 9042, 1, "localhost");
resolver->resolve(loop(), "localhost", 9042, 1, "localhost");

run_loop();
ASSERT_EQ(3u, resolvers().size());
Expand All @@ -145,9 +146,9 @@ TEST_F(ResolverUnitTest, MultiTimeout) {

TEST_F(ResolverUnitTest, MultiInvalid) {
MultiResolver::Ptr resolver(create_multi());
resolver->resolve(loop(), "doesnotexist1.dne", 9042, RESOLVE_TIMEOUT);
resolver->resolve(loop(), "doesnotexist2.dne", 9042, RESOLVE_TIMEOUT);
resolver->resolve(loop(), "doesnotexist3.dne", 9042, RESOLVE_TIMEOUT);
resolver->resolve(loop(), "doesnotexist1.dne", 9042, RESOLVE_TIMEOUT, "doesnotexist1.dne");
resolver->resolve(loop(), "doesnotexist2.dne", 9042, RESOLVE_TIMEOUT, "doesnotexist2.dne");
resolver->resolve(loop(), "doesnotexist3.dne", 9042, RESOLVE_TIMEOUT, "doesnotexist3.dne");
run_loop();
ASSERT_EQ(3u, resolvers().size());
for (Resolver::Vec::const_iterator it = resolvers().begin(), end = resolvers().end(); end != it;
Expand All @@ -159,9 +160,9 @@ TEST_F(ResolverUnitTest, MultiInvalid) {

TEST_F(ResolverUnitTest, MultiCancel) {
MultiResolver::Ptr resolver(create_multi());
resolver->resolve(loop(), "localhost", 9042, RESOLVE_TIMEOUT);
resolver->resolve(loop(), "localhost", 9042, RESOLVE_TIMEOUT);
resolver->resolve(loop(), "localhost", 9042, RESOLVE_TIMEOUT);
resolver->resolve(loop(), "localhost", 9042, RESOLVE_TIMEOUT, "localhost");
resolver->resolve(loop(), "localhost", 9042, RESOLVE_TIMEOUT, "localhost");
resolver->resolve(loop(), "localhost", 9042, RESOLVE_TIMEOUT, "localhost");
resolver->cancel();
run_loop();
ASSERT_EQ(3u, resolvers().size());
Expand Down
3 changes: 2 additions & 1 deletion tests/src/unit/tests/test_socket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,8 @@ class SocketUnitTest : public LoopTest {
} else {
bool match = false;
do {
Address address(res->ai_addr);
// Use a blank server name as it's not needed here.
Address address(res->ai_addr, String());
if (address.is_valid_and_resolved() && address == Address(DNS_IP_ADDRESS, 8888)) {
match = true;
break;
Expand Down