diff --git a/src/UnitTests/UTConn/test_conn.cpp b/src/UnitTests/UTConn/test_conn.cpp index 792bd2db4..e8535732b 100644 --- a/src/UnitTests/UTConn/test_conn.cpp +++ b/src/UnitTests/UTConn/test_conn.cpp @@ -123,6 +123,12 @@ TEST(TestDecodeHex, All_possible_hex) { EXPECT_EQ(expected, OktaCredentialsProvider::DecodeHex(hex_encoded)); } +TEST(TestGetUserAgent, Success) { + TSCommunication conn; + std::string expected = "ts-odbc." TIMESTREAMDRIVERVERSION " [ut_conn]"; + EXPECT_EQ(expected, conn.GetUserAgent()); +} + // TODO: enable gmock and mock the response from timestream //class TestTSConnConnectDBStart : public testing::Test { // protected: diff --git a/src/odfesqlodbc/mylog.c b/src/odfesqlodbc/mylog.c index 7530b55f1..34a56ffa8 100644 --- a/src/odfesqlodbc/mylog.c +++ b/src/odfesqlodbc/mylog.c @@ -27,6 +27,10 @@ #include "es_odbc.h" #include "misc.h" +#ifdef __APPLE__ +#include +#endif + #ifndef WIN32 #include #include @@ -124,6 +128,13 @@ const char *GetExeProgramName() { if (GetModuleFileName(NULL, pathname, sizeof(pathname)) > 0) _splitpath(pathname, NULL, NULL, exename, NULL); +#elif __APPLE__ + char pathname[PATH_MAX + 1]; + uint32_t size = PATH_MAX + 1; + + if (_NSGetExecutablePath(pathname, &size) == 0) { + STRCPY_FIXED(exename, po_basename(pathname)); + } #else CSTR flist[] = {"/proc/self/exe", "/proc/curproc/file", "/proc/curproc/exe"}; diff --git a/src/odfesqlodbc/ts_communication.cpp b/src/odfesqlodbc/ts_communication.cpp index e689affc5..ba1cf5c5e 100644 --- a/src/odfesqlodbc/ts_communication.cpp +++ b/src/odfesqlodbc/ts_communication.cpp @@ -34,6 +34,8 @@ // clang-format on namespace { + const Aws::String UA_ID_PREFIX = Aws::String("ts-odbc."); + typedef std::function< std::unique_ptr< Aws::TimestreamQuery::TimestreamQueryClient >( const runtime_options& options, const Aws::Client::ClientConfiguration& config) > QueryClientCreator; @@ -109,6 +111,7 @@ bool TSCommunication::Validate(const runtime_options& options) { std::unique_ptr< Aws::TimestreamQuery::TimestreamQueryClient > TSCommunication::CreateQueryClient(const runtime_options& options) { Aws::Client::ClientConfiguration config; + config.userAgent = GetUserAgent(); if (!options.auth.end_point_override.empty()) { config.endpointOverride = options.auth.end_point_override; } else { @@ -205,6 +208,14 @@ void TSCommunication::StopResultRetrieval(StatementClass* stmt) { } } +Aws::String TSCommunication::GetUserAgent() { + Aws::String program_name(GetExeProgramName()); + Aws::String name_suffix = " [" + program_name + "]"; + Aws::String msg = "Name of the application using the driver: " + name_suffix; + LogMsg(LOG_INFO, msg.c_str()); + return UA_ID_PREFIX + GetVersion() + name_suffix; +} + /** * Context class for Aws::Client::AsyncCallerContext * Only for execution diff --git a/src/odfesqlodbc/ts_communication.h b/src/odfesqlodbc/ts_communication.h index 603a04b0b..2e990c264 100644 --- a/src/odfesqlodbc/ts_communication.h +++ b/src/odfesqlodbc/ts_communication.h @@ -81,6 +81,12 @@ class TSCommunication : public Communication { */ virtual void StopResultRetrieval(StatementClass* stmt) override; + /** + * Get the user agent for Aws::Client::ClientConfiguration. + * @return the user agent. + */ + Aws::String GetUserAgent(); + private: /** * Create Timestream Query Client