Skip to content

Commit

Permalink
Add the endpoint construction logic
Browse files Browse the repository at this point in the history
  • Loading branch information
waahm7 committed May 6, 2024
1 parent 1db0c06 commit b94275c
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 18 deletions.
21 changes: 20 additions & 1 deletion include/aws/auth/credentials.h
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,26 @@ struct aws_credentials_provider_sts_options {
*/
const struct aws_http_proxy_options *http_proxy_options;

/**
* (Optional)
* Use a cached config file profile collection (~/.aws/config). You can also pass a merged profile collection which
* contains both config file and credentials file.
* TODO: decide should I add override the path, we probably need it for tests.
*/
struct aws_profile_collection *profile_collection_cached;

/*
* (Optional)
* Override of what profile to use, if not set, 'default' will be used.
*/
struct aws_byte_cursor profile_name_override;

/*
* (Optional)
* Override path to the profile config file (~/.aws/config by default)
*/
struct aws_byte_cursor config_file_name_override;

struct aws_credentials_provider_shutdown_options shutdown_options;

/* For mocking, leave NULL otherwise */
Expand Down Expand Up @@ -555,7 +575,6 @@ struct aws_credentials_provider_chain_default_options {
* Use a cached merged profile collection. A merge collection has both config file
* (~/.aws/config) and credentials file based profile collection (~/.aws/credentials) using
* `aws_profile_collection_new_from_merge`.
* If this option is provided, `config_file_name_override` and `credentials_file_name_override` will be ignored.
*/
struct aws_profile_collection *profile_collection_cached;

Expand Down
59 changes: 45 additions & 14 deletions source/credentials_provider_sts.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
#include <aws/auth/signing.h>
#include <aws/auth/signing_config.h>
#include <aws/auth/signing_result.h>
#include <aws/common/xml_parser.h>

#include <aws/common/clock.h>
#include <aws/common/string.h>
#include <aws/common/xml_parser.h>
#include <aws/sdkutils/aws_profile.h>

#include <aws/http/connection.h>
#include <aws/http/connection_manager.h>
Expand Down Expand Up @@ -40,11 +40,6 @@ static int s_sts_xml_on_AssumeRoleResponse_child(struct aws_xml_node *, void *);
static int s_sts_xml_on_AssumeRoleResult_child(struct aws_xml_node *, void *);
static int s_sts_xml_on_Credentials_child(struct aws_xml_node *, void *);

static struct aws_http_header s_host_header = {
.name = AWS_BYTE_CUR_INIT_FROM_STRING_LITERAL("host"),
.value = AWS_BYTE_CUR_INIT_FROM_STRING_LITERAL("sts.amazonaws.com"),
};

static struct aws_http_header s_content_type_header = {
.name = AWS_BYTE_CUR_INIT_FROM_STRING_LITERAL("content-type"),
.value = AWS_BYTE_CUR_INIT_FROM_STRING_LITERAL("application/x-www-form-urlencoded"),
Expand All @@ -53,7 +48,7 @@ static struct aws_http_header s_content_type_header = {
static struct aws_byte_cursor s_content_length = AWS_BYTE_CUR_INIT_FROM_STRING_LITERAL("content-length");
static struct aws_byte_cursor s_path = AWS_BYTE_CUR_INIT_FROM_STRING_LITERAL("/");
static struct aws_byte_cursor s_signing_region = AWS_BYTE_CUR_INIT_FROM_STRING_LITERAL("us-east-1");
static struct aws_byte_cursor s_service_name = AWS_BYTE_CUR_INIT_FROM_STRING_LITERAL("sts");
AWS_STATIC_STRING_FROM_LITERAL(s_sts_service_name, "sts");
static const int s_max_retries = 8;

const uint16_t aws_sts_assume_role_default_duration_secs = 900;
Expand All @@ -62,6 +57,7 @@ struct aws_credentials_provider_sts_impl {
struct aws_http_connection_manager *connection_manager;
struct aws_string *assume_role_profile;
struct aws_string *role_session_name;
struct aws_string *endpoint;
uint16_t duration_seconds;
struct aws_credentials_provider *provider;
struct aws_credentials_provider_shutdown_options source_shutdown_options;
Expand Down Expand Up @@ -468,12 +464,17 @@ static void s_start_make_request(
struct aws_credentials_provider *provider,
struct sts_creds_provider_user_data *provider_user_data) {
provider_user_data->message = aws_http_message_new_request(provider->allocator);
struct aws_credentials_provider_sts_impl *impl = provider->impl;

if (!provider_user_data->message) {
goto error;
}
struct aws_http_header host_header = {
.name = aws_byte_cursor_from_c_str("host"),
.value = aws_byte_cursor_from_string(impl->endpoint),
};

if (aws_http_message_add_header(provider_user_data->message, s_host_header)) {
if (aws_http_message_add_header(provider_user_data->message, host_header)) {
goto error;
}

Expand Down Expand Up @@ -526,16 +527,14 @@ static void s_start_make_request(
goto error;
}

struct aws_credentials_provider_sts_impl *impl = provider->impl;

provider_user_data->signing_config.algorithm = AWS_SIGNING_ALGORITHM_V4;
provider_user_data->signing_config.signature_type = AWS_ST_HTTP_REQUEST_HEADERS;
provider_user_data->signing_config.signed_body_header = AWS_SBHT_NONE;
provider_user_data->signing_config.config_type = AWS_SIGNING_CONFIG_AWS;
provider_user_data->signing_config.credentials_provider = impl->provider;
aws_date_time_init_now(&provider_user_data->signing_config.date);
provider_user_data->signing_config.region = s_signing_region;
provider_user_data->signing_config.service = s_service_name;
provider_user_data->signing_config.service = aws_byte_cursor_from_string(s_sts_service_name);
provider_user_data->signing_config.flags.use_double_uri_encode = false;

if (aws_sign_request_aws(
Expand Down Expand Up @@ -647,6 +646,7 @@ static void s_on_credentials_provider_shutdown(void *user_data) {

aws_string_destroy(impl->role_session_name);
aws_string_destroy(impl->assume_role_profile);
aws_string_destroy(impl->endpoint);

aws_mem_release(provider->allocator, provider);
}
Expand Down Expand Up @@ -769,9 +769,40 @@ struct aws_credentials_provider *aws_credentials_provider_new_sts(
impl->provider = options->creds_provider;
aws_credentials_provider_acquire(impl->provider);

/*
* resolve the endpoint
*/
struct aws_profile_collection *profile_collection = NULL;
if (options->profile_collection_cached) {
profile_collection = aws_profile_collection_acquire(options->profile_collection_cached);
} else {
profile_collection =
aws_load_profile_collection_from_config_file(allocator, options->config_file_name_override);
}
struct aws_string *profile_name = NULL;
profile_name = aws_get_profile_name(allocator, &options->profile_name_override);

const struct aws_profile *profile = aws_profile_collection_get_profile(profile_collection, profile_name);

struct aws_string *region = aws_credentials_provider_resolve_region(allocator, profile);
if (region != NULL) {
struct aws_byte_buf endpoint;
AWS_ZERO_STRUCT(endpoint);
// construct endpoint
if (aws_credentials_provider_construct_endpoint(allocator, &endpoint, region, s_sts_service_name)) {
goto cleanup_provider;
}
impl->endpoint = aws_string_new_from_buf(allocator, &endpoint);
aws_byte_buf_clean_up(&endpoint);
} else {
// global endpoint
impl->endpoint = aws_string_new_from_c_str(allocator, "sts.amazonaws.com");
}
struct aws_byte_cursor endpoint_cursor = aws_byte_cursor_from_string(impl->endpoint);

aws_tls_connection_options_init_from_ctx(&tls_connection_options, options->tls_ctx);

if (aws_tls_connection_options_set_server_name(&tls_connection_options, allocator, &s_host_header.value)) {
if (aws_tls_connection_options_set_server_name(&tls_connection_options, allocator, &endpoint_cursor)) {
AWS_LOGF_ERROR(
AWS_LS_AUTH_CREDENTIALS_PROVIDER,
"(id=%p): failed to create a tls connection options with error %s",
Expand All @@ -788,7 +819,7 @@ struct aws_credentials_provider *aws_credentials_provider_new_sts(

struct aws_http_connection_manager_options connection_manager_options = {
.bootstrap = options->bootstrap,
.host = s_host_header.value,
.host = endpoint_cursor,
.initial_window_size = SIZE_MAX,
.max_connections = 2,
.port = 443,
Expand Down
6 changes: 3 additions & 3 deletions source/credentials_utils.c
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@

#include <aws/common/clock.h>
#include <aws/common/date_time.h>
#include <aws/common/environment.h>
#include <aws/common/json.h>
#include <aws/common/string.h>
#include <aws/common/uuid.h>
#include <aws/common/environment.h>
#include <aws/http/connection.h>
#include <aws/http/request_response.h>
#include <aws/http/status_code.h>
Expand Down Expand Up @@ -416,15 +416,15 @@ struct aws_string *aws_credentials_provider_resolve_region(
const struct aws_profile *profile) {
AWS_PRECONDITION(allocator);
AWS_PRECONDITION(profile);

/* check environment variable */
struct aws_string *region = NULL;
aws_get_environment_value(allocator, s_region_env, &region);

if (region != NULL && region->len > 0) {
return region;
}

/* check the config file */
const struct aws_profile_property *property = aws_profile_get_property(profile, s_region_config);
if (property) {
Expand Down

0 comments on commit b94275c

Please sign in to comment.