Skip to content

Commit

Permalink
AT-765 Refactor the connection function (#64)
Browse files Browse the repository at this point in the history
* Refactored connect functions into smaller functions
  • Loading branch information
jerrytfleung authored May 19, 2021
1 parent 389dc1e commit 6130455
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 49 deletions.
16 changes: 16 additions & 0 deletions src/odfesqlodbc/aad_credentials_provider.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,22 @@ std::string AADCredentialsProvider::GetAADAccessToken() const {

Aws::String AADCredentialsProvider::GetSAMLAssertion() {
std::string access_token = GetAADAccessToken();
// Microsoft Azure AD doesn't send tail padding to us,
// the size of the access_token may not be a multiple of 4.
// While AWS::Utils::Base64 is expecting the size of the encoded is a multiple of 4.
// we need to pad ourself for the AWS::Utils::Base64 decoder
auto mod = access_token.size() % 4;
switch (mod) {
case 1:
access_token += "===";
break;
case 2:
access_token += "==";
break;
case 3:
access_token += "=";
break;
}
// Base64URL decode
auto decode_buffer = BASE64_URL.Decode(access_token);
auto size = decode_buffer.GetLength();
Expand Down
118 changes: 76 additions & 42 deletions src/odfesqlodbc/ts_communication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,60 @@
#include <aws/timestream-query/model/QueryRequest.h>
// clang-format on

namespace {
typedef std::function< std::unique_ptr< Aws::TimestreamQuery::TimestreamQueryClient >(
const runtime_options& options,
const Aws::Client::ClientConfiguration& config) > QueryClientCreator;

QueryClientCreator profile =
[](const runtime_options& options,
const Aws::Client::ClientConfiguration& config) {
if (!options.auth.profile_name.empty()) {
auto cp = std::make_shared<
Aws::Auth::ProfileConfigFileAWSCredentialsProvider >(
options.auth.profile_name.c_str());
return std::make_unique<
Aws::TimestreamQuery::TimestreamQueryClient >(cp, config);
} else {
return std::make_unique<
Aws::TimestreamQuery::TimestreamQueryClient >(config);
}
};

QueryClientCreator iam =
[](const runtime_options& options,
const Aws::Client::ClientConfiguration& config) {
Aws::Auth::AWSCredentials credentials(
options.auth.uid, options.auth.pwd, options.auth.session_token);
return std::make_unique< Aws::TimestreamQuery::TimestreamQueryClient >(
credentials, config);
};

QueryClientCreator aad =
[](const runtime_options& options,
const Aws::Client::ClientConfiguration& config) {
auto credential_provider =
std::make_unique< AADCredentialsProvider >(options.auth);
return std::make_unique< Aws::TimestreamQuery::TimestreamQueryClient >(credential_provider->GetAWSCredentials(), config);
};

QueryClientCreator okta =
[](const runtime_options& options,
const Aws::Client::ClientConfiguration& config) {
auto credential_provider =
std::make_unique< OktaCredentialsProvider >(options.auth);
return std::make_unique< Aws::TimestreamQuery::TimestreamQueryClient >(
credential_provider->GetAWSCredentials(), config);
};

std::unordered_map< std::string, QueryClientCreator > creators = {
{AUTHTYPE_AWS_PROFILE, profile},
{AUTHTYPE_IAM, iam},
{AUTHTYPE_AAD, aad},
{AUTHTYPE_OKTA, okta},
};
};

bool TSCommunication::Validate(const runtime_options& options) {
if (options.auth.region.empty() && options.auth.end_point_override.empty()) {
throw std::invalid_argument("Both region and end point cannot be empty.");
Expand All @@ -53,7 +107,7 @@ bool TSCommunication::Validate(const runtime_options& options) {
return true;
}

bool TSCommunication::Connect(const runtime_options& options) {
std::unique_ptr< Aws::TimestreamQuery::TimestreamQueryClient > TSCommunication::CreateQueryClient(const runtime_options& options) {
Aws::Client::ClientConfiguration config;
if (!options.auth.end_point_override.empty()) {
config.endpointOverride = options.auth.end_point_override;
Expand Down Expand Up @@ -83,41 +137,24 @@ bool TSCommunication::Connect(const runtime_options& options) {
config.maxConnections = max_connections;
}
if (!options.conn.max_retry_count_client.empty()) {
long max_retry_count_client = std::stol(options.conn.max_retry_count_client);
long max_retry_count_client =
std::stol(options.conn.max_retry_count_client);
if (max_retry_count_client < 0) {
throw std::invalid_argument("Max retry count client cannot be negative.");
}
config.retryStrategy = std::make_shared< Aws::Client::DefaultRetryStrategy >(max_retry_count_client);
}
if (options.auth.auth_type == AUTHTYPE_AWS_PROFILE) {
if (!options.auth.profile_name.empty()) {
auto cp = std::make_shared<Aws::Auth::ProfileConfigFileAWSCredentialsProvider>(options.auth.profile_name.c_str());
m_client =
std::make_unique< Aws::TimestreamQuery::TimestreamQueryClient >(cp, config);
} else {
m_client =
std::make_unique< Aws::TimestreamQuery::TimestreamQueryClient >(config);
throw std::invalid_argument(
"Max retry count client cannot be negative.");
}
} else if (options.auth.auth_type == AUTHTYPE_IAM) {
Aws::Auth::AWSCredentials credentials(options.auth.uid,
options.auth.pwd, options.auth.session_token);
m_client =
std::make_unique< Aws::TimestreamQuery::TimestreamQueryClient >(
credentials, config);
} else if (options.auth.auth_type == AUTHTYPE_AAD) {
m_client = CreateQueryClientWithIdp(
std::make_unique< AADCredentialsProvider >(options.auth), config);
} else if (options.auth.auth_type == AUTHTYPE_OKTA) {
m_client = CreateQueryClientWithIdp(
std::make_unique< OktaCredentialsProvider >(options.auth), config);
} else {
throw std::runtime_error("Unknown auth type: " + options.auth.auth_type);
config.retryStrategy =
std::make_shared< Aws::Client::DefaultRetryStrategy >(
max_retry_count_client);
}

if (m_client == nullptr) {
throw std::runtime_error("Unable to create TimestreamQueryClient.");
if (creators.find(options.auth.auth_type) == creators.end()) {
throw std::runtime_error("Unknown auth type: "
+ options.auth.auth_type);
}
return creators[options.auth.auth_type](options, config);
}

bool TSCommunication::TestQueryClient() {
Aws::TimestreamQuery::Model::QueryRequest req;
req.SetQueryString("select 1");
auto outcome = m_client->Query(req);
Expand All @@ -131,6 +168,14 @@ bool TSCommunication::Connect(const runtime_options& options) {
return true;
}

bool TSCommunication::Connect(const runtime_options& options) {
m_client = CreateQueryClient(options);
if (m_client == nullptr) {
throw std::runtime_error("Unable to create TimestreamQueryClient.");
}
return TestQueryClient();
}

void TSCommunication::Disconnect() {
LogMsg(LOG_DEBUG, "Disconnecting Timestream connection.");
if (m_client) {
Expand Down Expand Up @@ -160,17 +205,6 @@ void TSCommunication::StopResultRetrieval(StatementClass* stmt) {
}
}

std::unique_ptr< Aws::TimestreamQuery::TimestreamQueryClient >
TSCommunication::CreateQueryClientWithIdp(
std::unique_ptr< SAMLCredentialsProvider > cp,
const Aws::Client::ClientConfiguration& config) {
auto credentials = cp->GetAWSCredentials();
auto query_client =
std::make_unique< Aws::TimestreamQuery::TimestreamQueryClient >(
credentials, config);
return query_client;
}

/**
* Context class for Aws::Client::AsyncCallerContext
* Only for execution
Expand Down
17 changes: 10 additions & 7 deletions src/odfesqlodbc/ts_communication.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,16 +83,19 @@ class TSCommunication : public Communication {

private:
/**
* Create an unique_ptr of TimestreamQueryClient using the given credentials
* provider and configs
* @param cp std::unique_ptr<SAMLCredentialsProvider>
* @param auth const authentication_options&
* @param config const Aws::Client::ClientConfiguration&
* Create Timestream Query Client
* @param options const runtime_options&
* @return std::unique_ptr< Aws::TimestreamQuery::TimestreamQueryClient >
*/
std::unique_ptr< Aws::TimestreamQuery::TimestreamQueryClient >
CreateQueryClientWithIdp(std::unique_ptr< SAMLCredentialsProvider > cp,
const Aws::Client::ClientConfiguration& config);
CreateQueryClient(const runtime_options& options);

/**
* Test Query Client using "Select 1"
* @return bool
*/
bool TestQueryClient();

/**
* Timestream query client
*/
Expand Down

0 comments on commit 6130455

Please sign in to comment.