Skip to content

Commit

Permalink
Change port from uint16_t to uint32_t, to support VSOCK (#613)
Browse files Browse the repository at this point in the history
  • Loading branch information
graebm authored Dec 30, 2023
1 parent df64f57 commit 749c87e
Show file tree
Hide file tree
Showing 9 changed files with 226 additions and 74 deletions.
4 changes: 2 additions & 2 deletions include/aws/io/channel_bootstrap.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ struct aws_server_bootstrap {
struct aws_socket_channel_bootstrap_options {
struct aws_client_bootstrap *bootstrap;
const char *host_name;
uint16_t port;
uint32_t port;
const struct aws_socket_options *socket_options;
const struct aws_tls_connection_options *tls_options;
aws_client_bootstrap_on_channel_event_fn *creation_callback;
Expand Down Expand Up @@ -208,7 +208,7 @@ struct aws_socket_channel_bootstrap_options {
struct aws_server_socket_channel_bootstrap_options {
struct aws_server_bootstrap *bootstrap;
const char *host_name;
uint16_t port;
uint32_t port;
const struct aws_socket_options *socket_options;
const struct aws_tls_connection_options *tls_options;
aws_server_bootstrap_on_accept_channel_setup_fn *incoming_callback;
Expand Down
18 changes: 17 additions & 1 deletion include/aws/io/socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ typedef void(aws_socket_on_readable_fn)(struct aws_socket *socket, int error_cod
#endif
struct aws_socket_endpoint {
char address[AWS_ADDRESS_MAX_LEN];
uint16_t port;
uint32_t port;
};

struct aws_socket {
Expand Down Expand Up @@ -302,6 +302,22 @@ AWS_IO_API int aws_socket_get_error(struct aws_socket *socket);
*/
AWS_IO_API bool aws_socket_is_open(struct aws_socket *socket);

/**
* Raises AWS_IO_SOCKET_INVALID_ADDRESS and logs an error if connecting to this port is illegal.
* For example, port must be in range 1-65535 to connect with IPv4.
* These port values would fail eventually in aws_socket_connect(),
* but you can use this function to validate earlier.
*/
AWS_IO_API int aws_socket_validate_port_for_connect(uint32_t port, enum aws_socket_domain domain);

/**
* Raises AWS_IO_SOCKET_INVALID_ADDRESS and logs an error if binding to this port is illegal.
* For example, port must in range 0-65535 to bind with IPv4.
* These port values would fail eventually in aws_socket_bind(),
* but you can use this function to validate earlier.
*/
AWS_IO_API int aws_socket_validate_port_for_bind(uint32_t port, enum aws_socket_domain domain);

/**
* Assigns a random address (UUID) for use with AWS_SOCKET_LOCAL (Unix Domain Sockets).
* For use in internal tests only.
Expand Down
12 changes: 6 additions & 6 deletions source/channel_bootstrap.c
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ struct client_connection_args {
aws_client_bootstrap_on_channel_event_fn *shutdown_callback;
struct client_channel_data channel_data;
struct aws_socket_options outgoing_options;
uint16_t outgoing_port;
uint32_t outgoing_port;
struct aws_string *host_name;
void *user_data;
uint8_t addresses_count;
Expand Down Expand Up @@ -764,14 +764,14 @@ int aws_client_bootstrap_new_socket_channel(struct aws_socket_channel_bootstrap_
}

const char *host_name = options->host_name;
uint16_t port = options->port;
uint32_t port = options->port;

AWS_LOGF_TRACE(
AWS_LS_IO_CHANNEL_BOOTSTRAP,
"id=%p: attempting to initialize a new client channel to %s:%d",
"id=%p: attempting to initialize a new client channel to %s:%u",
(void *)bootstrap,
host_name,
(int)port);
port);

aws_ref_count_init(
&client_connection_args->ref_count,
Expand Down Expand Up @@ -1363,10 +1363,10 @@ struct aws_socket *aws_server_bootstrap_new_socket_listener(
AWS_LOGF_INFO(
AWS_LS_IO_CHANNEL_BOOTSTRAP,
"id=%p: attempting to initialize a new "
"server socket listener for %s:%d",
"server socket listener for %s:%u",
(void *)bootstrap_options->bootstrap,
bootstrap_options->host_name,
(int)bootstrap_options->port);
bootstrap_options->port);

aws_ref_count_init(
&server_connection_args->ref_count,
Expand Down
53 changes: 25 additions & 28 deletions source/posix/socket.c
Original file line number Diff line number Diff line change
Expand Up @@ -340,18 +340,7 @@ static int s_update_local_endpoint(struct aws_socket *socket) {
} else if (address.ss_family == AF_VSOCK) {
struct sockaddr_vm *s = (struct sockaddr_vm *)&address;

/* VSOCK port is 32bit, but aws_socket_endpoint.port is only 16bit.
* Hopefully this isn't an issue, since users can only pass in 16bit values.
* But if it becomes an issue, we'll need to make aws_socket_endpoint more flexible */
if (s->svm_port > UINT16_MAX) {
AWS_LOGF_ERROR(
AWS_LS_IO_SOCKET,
"id=%p fd=%d: aws_socket_endpoint can't deal with VSOCK port > UINT16_MAX",
(void *)socket,
socket->io_handle.data.fd);
return aws_raise_error(AWS_IO_SOCKET_INVALID_ADDRESS);
}
tmp_endpoint.port = (uint16_t)s->svm_port;
tmp_endpoint.port = s->svm_port;

snprintf(tmp_endpoint.address, sizeof(tmp_endpoint.address), "%" PRIu32, s->svm_cid);
return AWS_OP_SUCCESS;
Expand Down Expand Up @@ -642,18 +631,22 @@ int aws_socket_connect(
return AWS_OP_ERR;
}

if (aws_socket_validate_port_for_connect(remote_endpoint->port, socket->options.domain)) {
return AWS_OP_ERR;
}

struct socket_address address;
AWS_ZERO_STRUCT(address);
socklen_t sock_size = 0;
int pton_err = 1;
if (socket->options.domain == AWS_SOCKET_IPV4) {
pton_err = inet_pton(AF_INET, remote_endpoint->address, &address.sock_addr_types.addr_in.sin_addr);
address.sock_addr_types.addr_in.sin_port = htons(remote_endpoint->port);
address.sock_addr_types.addr_in.sin_port = htons((uint16_t)remote_endpoint->port);
address.sock_addr_types.addr_in.sin_family = AF_INET;
sock_size = sizeof(address.sock_addr_types.addr_in);
} else if (socket->options.domain == AWS_SOCKET_IPV6) {
pton_err = inet_pton(AF_INET6, remote_endpoint->address, &address.sock_addr_types.addr_in6.sin6_addr);
address.sock_addr_types.addr_in6.sin6_port = htons(remote_endpoint->port);
address.sock_addr_types.addr_in6.sin6_port = htons((uint16_t)remote_endpoint->port);
address.sock_addr_types.addr_in6.sin6_family = AF_INET6;
sock_size = sizeof(address.sock_addr_types.addr_in6);
} else if (socket->options.domain == AWS_SOCKET_LOCAL) {
Expand All @@ -664,7 +657,7 @@ int aws_socket_connect(
} else if (socket->options.domain == AWS_SOCKET_VSOCK) {
pton_err = parse_cid(remote_endpoint->address, &address.sock_addr_types.vm_addr.svm_cid);
address.sock_addr_types.vm_addr.svm_family = AF_VSOCK;
address.sock_addr_types.vm_addr.svm_port = (unsigned int)remote_endpoint->port;
address.sock_addr_types.vm_addr.svm_port = remote_endpoint->port;
sock_size = sizeof(address.sock_addr_types.vm_addr);
#endif
} else {
Expand All @@ -676,21 +669,21 @@ int aws_socket_connect(
int errno_value = errno; /* Always cache errno before potential side-effect */
AWS_LOGF_DEBUG(
AWS_LS_IO_SOCKET,
"id=%p fd=%d: failed to parse address %s:%d.",
"id=%p fd=%d: failed to parse address %s:%u.",
(void *)socket,
socket->io_handle.data.fd,
remote_endpoint->address,
(int)remote_endpoint->port);
remote_endpoint->port);
return aws_raise_error(s_convert_pton_error(pton_err, errno_value));
}

AWS_LOGF_DEBUG(
AWS_LS_IO_SOCKET,
"id=%p fd=%d: connecting to endpoint %s:%d.",
"id=%p fd=%d: connecting to endpoint %s:%u.",
(void *)socket,
socket->io_handle.data.fd,
remote_endpoint->address,
(int)remote_endpoint->port);
remote_endpoint->port);

socket->state = CONNECTING;
socket->remote_endpoint = *remote_endpoint;
Expand Down Expand Up @@ -806,26 +799,30 @@ int aws_socket_bind(struct aws_socket *socket, const struct aws_socket_endpoint
return AWS_OP_ERR;
}

if (aws_socket_validate_port_for_bind(local_endpoint->port, socket->options.domain)) {
return AWS_OP_ERR;
}

AWS_LOGF_INFO(
AWS_LS_IO_SOCKET,
"id=%p fd=%d: binding to %s:%d.",
"id=%p fd=%d: binding to %s:%u.",
(void *)socket,
socket->io_handle.data.fd,
local_endpoint->address,
(int)local_endpoint->port);
local_endpoint->port);

struct socket_address address;
AWS_ZERO_STRUCT(address);
socklen_t sock_size = 0;
int pton_err = 1;
if (socket->options.domain == AWS_SOCKET_IPV4) {
pton_err = inet_pton(AF_INET, local_endpoint->address, &address.sock_addr_types.addr_in.sin_addr);
address.sock_addr_types.addr_in.sin_port = htons(local_endpoint->port);
address.sock_addr_types.addr_in.sin_port = htons((uint16_t)local_endpoint->port);
address.sock_addr_types.addr_in.sin_family = AF_INET;
sock_size = sizeof(address.sock_addr_types.addr_in);
} else if (socket->options.domain == AWS_SOCKET_IPV6) {
pton_err = inet_pton(AF_INET6, local_endpoint->address, &address.sock_addr_types.addr_in6.sin6_addr);
address.sock_addr_types.addr_in6.sin6_port = htons(local_endpoint->port);
address.sock_addr_types.addr_in6.sin6_port = htons((uint16_t)local_endpoint->port);
address.sock_addr_types.addr_in6.sin6_family = AF_INET6;
sock_size = sizeof(address.sock_addr_types.addr_in6);
} else if (socket->options.domain == AWS_SOCKET_LOCAL) {
Expand All @@ -836,7 +833,7 @@ int aws_socket_bind(struct aws_socket *socket, const struct aws_socket_endpoint
} else if (socket->options.domain == AWS_SOCKET_VSOCK) {
pton_err = parse_cid(local_endpoint->address, &address.sock_addr_types.vm_addr.svm_cid);
address.sock_addr_types.vm_addr.svm_family = AF_VSOCK;
address.sock_addr_types.vm_addr.svm_port = (unsigned int)local_endpoint->port;
address.sock_addr_types.vm_addr.svm_port = local_endpoint->port;
sock_size = sizeof(address.sock_addr_types.vm_addr);
#endif
} else {
Expand All @@ -848,11 +845,11 @@ int aws_socket_bind(struct aws_socket *socket, const struct aws_socket_endpoint
int errno_value = errno; /* Always cache errno before potential side-effect */
AWS_LOGF_ERROR(
AWS_LS_IO_SOCKET,
"id=%p fd=%d: failed to parse address %s:%d.",
"id=%p fd=%d: failed to parse address %s:%u.",
(void *)socket,
socket->io_handle.data.fd,
local_endpoint->address,
(int)local_endpoint->port);
local_endpoint->port);
return aws_raise_error(s_convert_pton_error(pton_err, errno_value));
}

Expand Down Expand Up @@ -882,7 +879,7 @@ int aws_socket_bind(struct aws_socket *socket, const struct aws_socket_endpoint

AWS_LOGF_DEBUG(
AWS_LS_IO_SOCKET,
"id=%p fd=%d: successfully bound to %s:%d",
"id=%p fd=%d: successfully bound to %s:%u",
(void *)socket,
socket->io_handle.data.fd,
socket->local_endpoint.address,
Expand Down Expand Up @@ -996,7 +993,7 @@ static void s_socket_accept_event(

new_sock->local_endpoint = socket->local_endpoint;
new_sock->state = CONNECTED_READ | CONNECTED_WRITE;
uint16_t port = 0;
uint32_t port = 0;

/* get the info on the incoming socket's address */
if (in_addr.ss_family == AF_INET) {
Expand Down
75 changes: 75 additions & 0 deletions source/socket_shared.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/**
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0.
*/
#include <aws/io/socket.h>

#include <aws/io/logging.h>

/* common validation for connect() and bind() */
static int s_socket_validate_port_for_domain(uint32_t port, enum aws_socket_domain domain) {
switch (domain) {
case AWS_SOCKET_IPV4:
case AWS_SOCKET_IPV6:
if (port > UINT16_MAX) {
AWS_LOGF_ERROR(
AWS_LS_IO_SOCKET,
"Invalid port=%u for %s. Cannot exceed 65535",
port,
domain == AWS_SOCKET_IPV4 ? "IPv4" : "IPv6");
return aws_raise_error(AWS_IO_SOCKET_INVALID_ADDRESS);
}
break;

case AWS_SOCKET_LOCAL:
/* port is ignored */
break;

case AWS_SOCKET_VSOCK:
/* any 32bit port is legal */
break;

default:
AWS_LOGF_ERROR(AWS_LS_IO_SOCKET, "Cannot validate port for unknown domain=%d", domain);
return aws_raise_error(AWS_IO_SOCKET_INVALID_ADDRESS);
}
return AWS_OP_SUCCESS;
}

int aws_socket_validate_port_for_connect(uint32_t port, enum aws_socket_domain domain) {
if (s_socket_validate_port_for_domain(port, domain)) {
return AWS_OP_ERR;
}

/* additional validation */
switch (domain) {
case AWS_SOCKET_IPV4:
case AWS_SOCKET_IPV6:
if (port == 0) {
AWS_LOGF_ERROR(
AWS_LS_IO_SOCKET,
"Invalid port=%u for %s connections. Must use 1-65535",
port,
domain == AWS_SOCKET_IPV4 ? "IPv4" : "IPv6");
return aws_raise_error(AWS_IO_SOCKET_INVALID_ADDRESS);
}
break;

case AWS_SOCKET_VSOCK:
if (port == (uint32_t)-1) {
AWS_LOGF_ERROR(
AWS_LS_IO_SOCKET, "Invalid port for VSOCK connections. Cannot use VMADDR_PORT_ANY (-1U).");
return aws_raise_error(AWS_IO_SOCKET_INVALID_ADDRESS);
}
break;

default:
/* no extra validation */
break;
}
return AWS_OP_SUCCESS;
}

int aws_socket_validate_port_for_bind(uint32_t port, enum aws_socket_domain domain) {
return s_socket_validate_port_for_domain(port, domain);
}
Loading

0 comments on commit 749c87e

Please sign in to comment.