diff --git a/src/AppInstallerCLICore/Workflows/WorkflowBase.cpp b/src/AppInstallerCLICore/Workflows/WorkflowBase.cpp index 7e9884877b..4544bf7930 100644 --- a/src/AppInstallerCLICore/Workflows/WorkflowBase.cpp +++ b/src/AppInstallerCLICore/Workflows/WorkflowBase.cpp @@ -101,7 +101,11 @@ namespace AppInstaller::CLI::Workflow } } - auto openFunction = [&](IProgressCallback& progress)->std::vector { return source.Open(progress); }; + auto openFunction = [&](IProgressCallback& progress)->std::vector + { + source.SetCaller("winget-cli"); + return source.Open(progress); + }; auto updateFailures = context.Reporter.ExecuteWithProgress(openFunction, true); // We'll only report the source update failure as warning and continue @@ -376,7 +380,11 @@ namespace AppInstaller::CLI::Workflow // A well known predefined source should return a value. THROW_HR_IF(E_UNEXPECTED, !source); - auto openFunction = [&](IProgressCallback& progress)->std::vector { return source.Open(progress); }; + auto openFunction = [&](IProgressCallback& progress)->std::vector + { + source.SetCaller("winget-cli"); + return source.Open(progress); + }; context.Reporter.ExecuteWithProgress(openFunction, true); } catch (...) diff --git a/src/AppInstallerCLITests/CustomHeader.cpp b/src/AppInstallerCLITests/CustomHeader.cpp index bc644e3ff8..4644bd7ee2 100644 --- a/src/AppInstallerCLITests/CustomHeader.cpp +++ b/src/AppInstallerCLITests/CustomHeader.cpp @@ -81,7 +81,7 @@ TEST_CASE("RestClient_CustomHeader", "[RestSource][CustomHeader]") std::optional customHeader = "Testing custom header"; auto header = std::make_pair<>(CustomHeaderName, JSON::GetUtilityString(customHeader.value())); HttpClientHelper helper{ GetCustomHeaderVerificationHandler(web::http::status_codes::OK, sample, header) }; - RestClient client = RestClient::Create(utility::conversions::to_utf8string("https://restsource.com/api"), customHeader, std::move(helper)); + RestClient client = RestClient::Create(utility::conversions::to_utf8string("https://restsource.com/api"), customHeader, {}, std::move(helper)); REQUIRE(client.GetSourceIdentifier() == "Source123"); } @@ -130,6 +130,41 @@ TEST_CASE("RestSourceSearch_CustomHeaderExceedingSize", "[RestSource][CustomHead auto header = std::make_pair<>(CustomHeaderName, JSON::GetUtilityString(customHeader)); HttpClientHelper helper{ GetCustomHeaderVerificationHandler(web::http::status_codes::OK, sampleSearchResponse, header) }; - REQUIRE_THROWS_HR(RestClient::Create(utility::conversions::to_utf8string("https://restsource.com/api"), customHeader, std::move(helper)), + REQUIRE_THROWS_HR(RestClient::Create(utility::conversions::to_utf8string("https://restsource.com/api"), customHeader, {}, std::move(helper)), APPINSTALLER_CLI_ERROR_CUSTOMHEADER_EXCEEDS_MAXLENGTH); } + +TEST_CASE("RestClient_CustomUserAgentHeader", "[RestSource][CustomHeader]") +{ + utility::string_t sample = _XPLATSTR( + R"delimiter({ + "Data" : { + "SourceIdentifier": "Source123", + "ServerSupportedVersions": [ + "1.0.0", + "2.0.0"] + }})delimiter"); + + std::string testCaller = "TestCaller"; + auto header = std::make_pair<>(web::http::header_names::user_agent, JSON::GetUtilityString(Runtime::GetUserAgent(testCaller))); + HttpClientHelper helper{ GetCustomHeaderVerificationHandler(web::http::status_codes::OK, sample, header) }; + RestClient client = RestClient::Create(utility::conversions::to_utf8string("https://restsource.com/api"), {}, testCaller, std::move(helper)); + REQUIRE(client.GetSourceIdentifier() == "Source123"); +} + +TEST_CASE("RestClient_DefaultUserAgentHeader", "[RestSource][CustomHeader]") +{ + utility::string_t sample = _XPLATSTR( + R"delimiter({ + "Data" : { + "SourceIdentifier": "Source123", + "ServerSupportedVersions": [ + "1.0.0", + "2.0.0"] + }})delimiter"); + + auto header = std::make_pair<>(web::http::header_names::user_agent, JSON::GetUtilityString(Runtime::GetDefaultUserAgent())); + HttpClientHelper helper{ GetCustomHeaderVerificationHandler(web::http::status_codes::OK, sample, header) }; + RestClient client = RestClient::Create(utility::conversions::to_utf8string("https://restsource.com/api"), {}, {}, std::move(helper)); + REQUIRE(client.GetSourceIdentifier() == "Source123"); +} \ No newline at end of file diff --git a/src/AppInstallerCLITests/RestClient.cpp b/src/AppInstallerCLITests/RestClient.cpp index d15178c6d4..882b235e2d 100644 --- a/src/AppInstallerCLITests/RestClient.cpp +++ b/src/AppInstallerCLITests/RestClient.cpp @@ -142,7 +142,7 @@ TEST_CASE("RestClientCreate_UnsupportedVersion", "[RestSource]") }})delimiter"); HttpClientHelper helper{ GetTestRestRequestHandler(web::http::status_codes::OK, sample) }; - REQUIRE_THROWS_HR(RestClient::Create("https://restsource.com/api", {}, std::move(helper)), APPINSTALLER_CLI_ERROR_UNSUPPORTED_RESTSOURCE); + REQUIRE_THROWS_HR(RestClient::Create("https://restsource.com/api", {}, {}, std::move(helper)), APPINSTALLER_CLI_ERROR_UNSUPPORTED_RESTSOURCE); } TEST_CASE("RestClientCreate_1.0_Success", "[RestSource]") @@ -157,7 +157,7 @@ TEST_CASE("RestClientCreate_1.0_Success", "[RestSource]") }})delimiter"); HttpClientHelper helper{ GetTestRestRequestHandler(web::http::status_codes::OK, sample) }; - RestClient client = RestClient::Create(utility::conversions::to_utf8string(TestRestUri), {}, std::move(helper)); + RestClient client = RestClient::Create(utility::conversions::to_utf8string(TestRestUri), {}, {}, std::move(helper)); REQUIRE(client.GetSourceIdentifier() == "Source123"); } @@ -194,7 +194,7 @@ TEST_CASE("RestClientCreate_1.1_Success", "[RestSource]") }})delimiter"); HttpClientHelper helper{ GetTestRestRequestHandler(web::http::status_codes::OK, sample) }; - RestClient client = RestClient::Create(utility::conversions::to_utf8string(TestRestUri), {}, std::move(helper)); + RestClient client = RestClient::Create(utility::conversions::to_utf8string(TestRestUri), {}, {}, std::move(helper)); REQUIRE(client.GetSourceIdentifier() == "Source123"); auto information = client.GetSourceInformation(); REQUIRE(information.SourceAgreementsIdentifier == "agreementV1"); diff --git a/src/AppInstallerCommonCore/Public/AppInstallerRuntime.h b/src/AppInstallerCommonCore/Public/AppInstallerRuntime.h index 0929f66b20..8683b567ad 100644 --- a/src/AppInstallerCommonCore/Public/AppInstallerRuntime.h +++ b/src/AppInstallerCommonCore/Public/AppInstallerRuntime.h @@ -138,4 +138,7 @@ namespace AppInstaller::Runtime // Gets the default user agent string for the Windows Package Manager. Utility::LocIndString GetDefaultUserAgent(); + + // Gets the user agent string from passed in caller for the Windows Package Manager. + Utility::LocIndString GetUserAgent(std::string_view caller); } diff --git a/src/AppInstallerCommonCore/Runtime.cpp b/src/AppInstallerCommonCore/Runtime.cpp index 69e21f5df4..2347c719c6 100644 --- a/src/AppInstallerCommonCore/Runtime.cpp +++ b/src/AppInstallerCommonCore/Runtime.cpp @@ -775,7 +775,17 @@ namespace AppInstaller::Runtime { std::ostringstream strstr; strstr << - "winget-cli" + "winget-cli" << + " WindowsPackageManager/" << GetClientVersion() << + " DesktopAppInstaller/" << GetPackageVersion(); + return Utility::LocIndString{ strstr.str() }; + } + + Utility::LocIndString GetUserAgent(std::string_view caller) + { + std::ostringstream strstr; + strstr << + caller << " WindowsPackageManager/" << GetClientVersion() << " DesktopAppInstaller/" << GetPackageVersion(); return Utility::LocIndString{ strstr.str() }; diff --git a/src/AppInstallerRepositoryCore/ISource.h b/src/AppInstallerRepositoryCore/ISource.h index 2749b26e99..df91ed5f8e 100644 --- a/src/AppInstallerRepositoryCore/ISource.h +++ b/src/AppInstallerRepositoryCore/ISource.h @@ -42,7 +42,10 @@ namespace AppInstaller::Repository virtual SourceInformation GetInformation() { return {}; } // Set custom header. Returns false if custom header is not supported. - virtual bool SetCustomHeader(std::optional header) { UNREFERENCED_PARAMETER(header); return false; } + virtual bool SetCustomHeader(std::optional) { return false; } + + // Set caller. + virtual void SetCaller(std::string) {} // Opens the source. This function should throw upon open failure rather than returning an empty pointer. virtual std::shared_ptr Open(IProgressCallback& progress) = 0; diff --git a/src/AppInstallerRepositoryCore/Public/winget/RepositorySource.h b/src/AppInstallerRepositoryCore/Public/winget/RepositorySource.h index f8dbde67dc..c565efac02 100644 --- a/src/AppInstallerRepositoryCore/Public/winget/RepositorySource.h +++ b/src/AppInstallerRepositoryCore/Public/winget/RepositorySource.h @@ -216,6 +216,9 @@ namespace AppInstaller::Repository // Set custom header. bool SetCustomHeader(std::optional header); + // Set caller. + void SetCaller(std::string caller); + // Execute a search on the source. SearchResult Search(const SearchRequest& request) const; diff --git a/src/AppInstallerRepositoryCore/RepositorySource.cpp b/src/AppInstallerRepositoryCore/RepositorySource.cpp index d1474e6a65..6c25dcbb16 100644 --- a/src/AppInstallerRepositoryCore/RepositorySource.cpp +++ b/src/AppInstallerRepositoryCore/RepositorySource.cpp @@ -427,6 +427,14 @@ namespace AppInstaller::Repository return m_sourceReferences[0]->SetCustomHeader(header); } + void Source::SetCaller(std::string caller) + { + for (auto& sourceReference : m_sourceReferences) + { + sourceReference->SetCaller(caller); + } + } + SearchResult Source::Search(const SearchRequest& request) const { THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_STATE), !m_source); diff --git a/src/AppInstallerRepositoryCore/Rest/RestClient.cpp b/src/AppInstallerRepositoryCore/Rest/RestClient.cpp index 7f2e9af901..665b929bfa 100644 --- a/src/AppInstallerRepositoryCore/Rest/RestClient.cpp +++ b/src/AppInstallerRepositoryCore/Rest/RestClient.cpp @@ -24,18 +24,30 @@ namespace AppInstaller::Repository::Rest constexpr size_t WindowsPackageManagerHeaderMaxLength = 1024; namespace { - std::unordered_map GetHeaders(std::optional customHeader) + std::unordered_map GetHeaders(std::optional customHeader, std::string_view caller) { - if (!customHeader) + std::unordered_map headers; + + if (customHeader) { - AICLI_LOG(Repo, Verbose, << "Custom header not found."); - return {}; + AICLI_LOG(Repo, Verbose, << "Custom header found: " << customHeader.value()); + THROW_HR_IF(APPINSTALLER_CLI_ERROR_CUSTOMHEADER_EXCEEDS_MAXLENGTH, customHeader.value().size() > WindowsPackageManagerHeaderMaxLength); + headers.emplace(JSON::GetUtilityString(WindowsPackageManagerHeader), JSON::GetUtilityString(customHeader.value())); } - THROW_HR_IF(APPINSTALLER_CLI_ERROR_CUSTOMHEADER_EXCEEDS_MAXLENGTH, customHeader.value().size() > WindowsPackageManagerHeaderMaxLength); + if (!caller.empty()) + { + AICLI_LOG(Repo, Verbose, << "User agent caller found: " << caller); + std::wstring userAgentWide = JSON::GetUtilityString(Runtime::GetUserAgent(caller)); + try + { + // Replace user profile if the caller binary is under user profile. + userAgentWide = Utility::ReplaceWhileCopying(userAgentWide, Runtime::GetPathTo(Runtime::PathName::UserProfile).wstring(), L"%USERPROFILE%"); + } + CATCH_LOG(); + headers.emplace(web::http::header_names::user_agent, userAgentWide); + } - std::unordered_map headers; - headers.emplace(JSON::GetUtilityString(WindowsPackageManagerHeader), JSON::GetUtilityString(customHeader.value())); return headers; } } @@ -139,12 +151,12 @@ namespace AppInstaller::Repository::Rest THROW_HR(APPINSTALLER_CLI_ERROR_RESTSOURCE_INVALID_VERSION); } - RestClient RestClient::Create(const std::string& restApi, std::optional customHeader, const HttpClientHelper& helper) + RestClient RestClient::Create(const std::string& restApi, std::optional customHeader, std::string_view caller, const HttpClientHelper& helper) { utility::string_t restEndpoint = RestHelper::GetRestAPIBaseUri(restApi); THROW_HR_IF(APPINSTALLER_CLI_ERROR_RESTSOURCE_INVALID_URL, !RestHelper::IsValidUri(restEndpoint)); - auto headers = GetHeaders(customHeader); + auto headers = GetHeaders(customHeader, caller); IRestClient::Information information = GetInformation(restEndpoint, headers, helper); std::optional latestCommonVersion = GetLatestCommonVersion(information.ServerSupportedVersions, WingetSupportedContracts); diff --git a/src/AppInstallerRepositoryCore/Rest/RestClient.h b/src/AppInstallerRepositoryCore/Rest/RestClient.h index af296f11bb..d30511a6e3 100644 --- a/src/AppInstallerRepositoryCore/Rest/RestClient.h +++ b/src/AppInstallerRepositoryCore/Rest/RestClient.h @@ -33,7 +33,7 @@ namespace AppInstaller::Repository::Rest static std::unique_ptr GetSupportedInterface(const std::string& restApi, const std::unordered_map& additionalHeaders, const Schema::IRestClient::Information& information, const AppInstaller::Utility::Version& version); - static RestClient Create(const std::string& restApi, std::optional customHeader, const Schema::HttpClientHelper& helper = {}); + static RestClient Create(const std::string& restApi, std::optional customHeader, std::string_view caller, const Schema::HttpClientHelper& helper = {}); private: RestClient(std::unique_ptr supportedInterface, std::string sourceIdentifier); diff --git a/src/AppInstallerRepositoryCore/Rest/RestSourceFactory.cpp b/src/AppInstallerRepositoryCore/Rest/RestSourceFactory.cpp index 1d97bf59da..c3f7320239 100644 --- a/src/AppInstallerRepositoryCore/Rest/RestSourceFactory.cpp +++ b/src/AppInstallerRepositoryCore/Rest/RestSourceFactory.cpp @@ -37,10 +37,15 @@ namespace AppInstaller::Repository::Rest return true; } + void SetCaller(std::string caller) override + { + m_caller = std::move(caller); + } + std::shared_ptr Open(IProgressCallback&) override { Initialize(); - RestClient restClient = RestClient::Create(m_details.Arg, m_customHeader, m_httpClientHelper); + RestClient restClient = RestClient::Create(m_details.Arg, m_customHeader, m_caller, m_httpClientHelper); return std::make_shared(m_details, m_information, std::move(restClient)); } @@ -51,7 +56,7 @@ namespace AppInstaller::Repository::Rest [&]() { m_httpClientHelper.SetPinningConfiguration(m_details.CertificatePinningConfiguration); - RestClient restClient = RestClient::Create(m_details.Arg, m_customHeader, m_httpClientHelper); + RestClient restClient = RestClient::Create(m_details.Arg, m_customHeader, m_caller, m_httpClientHelper); m_details.Identifier = restClient.GetSourceIdentifier(); @@ -73,6 +78,7 @@ namespace AppInstaller::Repository::Rest Schema::HttpClientHelper m_httpClientHelper; SourceInformation m_information; std::optional m_customHeader; + std::string m_caller; std::once_flag m_initializeFlag; }; diff --git a/src/Microsoft.Management.Deployment/Helpers.cpp b/src/Microsoft.Management.Deployment/Helpers.cpp index 5c228082b6..8adfb3fc36 100644 --- a/src/Microsoft.Management.Deployment/Helpers.cpp +++ b/src/Microsoft.Management.Deployment/Helpers.cpp @@ -13,6 +13,24 @@ using namespace std::string_view_literals; namespace winrt::Microsoft::Management::Deployment::implementation { + namespace + { + static std::optional s_callerName; + static wil::srwlock s_callerNameLock; + } + + void SetComCallerName(std::string name) + { + auto lock = s_callerNameLock.lock_exclusive(); + s_callerName.emplace(std::move(name)); + } + + std::string GetComCallerName(std::string defaultNameIfNotSet) + { + auto lock = s_callerNameLock.lock_shared(); + return s_callerName.has_value() ? s_callerName.value() : defaultNameIfNotSet; + } + std::pair GetCallerProcessId() { RPC_STATUS rpcStatus = RPC_S_OK; diff --git a/src/Microsoft.Management.Deployment/Helpers.h b/src/Microsoft.Management.Deployment/Helpers.h index 2cf38ba7c4..bd9c1df36a 100644 --- a/src/Microsoft.Management.Deployment/Helpers.h +++ b/src/Microsoft.Management.Deployment/Helpers.h @@ -7,6 +7,9 @@ namespace winrt::Microsoft::Management::Deployment::implementation { + void SetComCallerName(std::string name); + std::string GetComCallerName(std::string defaultNameIfNotSet); + // Enable custom code to run before creating any object through the factory. // Currently that means requiring the overall WinGet policy to be enabled. template diff --git a/src/Microsoft.Management.Deployment/PackageCatalogReference.cpp b/src/Microsoft.Management.Deployment/PackageCatalogReference.cpp index cb7fa46f5d..11031cc5c7 100644 --- a/src/Microsoft.Management.Deployment/PackageCatalogReference.cpp +++ b/src/Microsoft.Management.Deployment/PackageCatalogReference.cpp @@ -14,10 +14,38 @@ #include #include #include +#include #include namespace winrt::Microsoft::Management::Deployment::implementation { + namespace + { + std::string GetCallerName() + { + // See if caller name is set by caller + static auto callerName = GetComCallerName(""); + + // Get process string + if (callerName.empty()) + { + try + { + auto [hrGetCallerId, callerProcessId] = GetCallerProcessId(); + THROW_IF_FAILED(hrGetCallerId); + callerName = AppInstaller::Utility::ConvertToUTF8(TryGetCallerProcessInfo(callerProcessId)); + } + CATCH_LOG(); + } + + if (callerName.empty()) + { + callerName = "UnknownComCaller"; + } + + return callerName; + } + } void PackageCatalogReference::Initialize(winrt::Microsoft::Management::Deployment::PackageCatalogInfo packageCatalogInfo, ::AppInstaller::Repository::Source sourceReference) { m_info = packageCatalogInfo; @@ -77,6 +105,7 @@ namespace winrt::Microsoft::Management::Deployment::implementation auto catalog = m_compositePackageCatalogOptions.Catalogs().GetAt(i); winrt::Microsoft::Management::Deployment::implementation::PackageCatalogReference* catalogImpl = get_self(catalog); auto copy = catalogImpl->m_sourceReference; + copy.SetCaller(GetCallerName()); copy.Open(progress); remoteSources.emplace_back(std::move(copy)); } @@ -112,6 +141,7 @@ namespace winrt::Microsoft::Management::Deployment::implementation else { source = m_sourceReference; + source.SetCaller(GetCallerName()); source.Open(progress); } diff --git a/src/Microsoft.Management.Deployment/PackageManager.cpp b/src/Microsoft.Management.Deployment/PackageManager.cpp index 451da773a9..abde649745 100644 --- a/src/Microsoft.Management.Deployment/PackageManager.cpp +++ b/src/Microsoft.Management.Deployment/PackageManager.cpp @@ -38,18 +38,6 @@ using namespace ::AppInstaller::CLI::Execution; namespace winrt::Microsoft::Management::Deployment::implementation { - namespace - { - static std::optional s_callerName; - static wil::srwlock s_callerNameLock; - } - - void SetComCallerName(std::string name) - { - auto lock = s_callerNameLock.lock_exclusive(); - s_callerName.emplace(std::move(name)); - } - winrt::Windows::Foundation::Collections::IVectorView PackageManager::GetPackageCatalogs() { Windows::Foundation::Collections::IVector catalogs{ winrt::single_threaded_vector() }; @@ -454,12 +442,8 @@ namespace winrt::Microsoft::Management::Deployment::implementation { std::unique_ptr context = std::make_unique(); hstring correlationData = (options) ? options.CorrelationData() : L""; - std::string callerName; - { - auto lock = s_callerNameLock.lock_shared(); - callerName = s_callerName.has_value() ? s_callerName.value() : AppInstaller::Utility::ConvertToUTF8(callerProcessInfoString); - } - context->SetContextLoggers(correlationData, callerName); + + context->SetContextLoggers(correlationData, GetComCallerName(AppInstaller::Utility::ConvertToUTF8(callerProcessInfoString))); // Convert the options to arguments for the installer. if constexpr (std::is_same_v)