Skip to content

Commit

Permalink
Pass COM caller name to rest source in request header (#3112)
Browse files Browse the repository at this point in the history
  • Loading branch information
yao-msft authored Mar 29, 2023
1 parent 0f9554b commit dddc094
Show file tree
Hide file tree
Showing 15 changed files with 162 additions and 39 deletions.
12 changes: 10 additions & 2 deletions src/AppInstallerCLICore/Workflows/WorkflowBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,11 @@ namespace AppInstaller::CLI::Workflow
}
}

auto openFunction = [&](IProgressCallback& progress)->std::vector<Repository::SourceDetails> { return source.Open(progress); };
auto openFunction = [&](IProgressCallback& progress)->std::vector<Repository::SourceDetails>
{
source.SetCaller("winget-cli");
return source.Open(progress);
};
auto updateFailures = context.Reporter.ExecuteWithProgress(openFunction, true);

// We'll only report the source update failure as warning and continue
Expand Down Expand Up @@ -376,7 +380,11 @@ namespace AppInstaller::CLI::Workflow
// A well known predefined source should return a value.
THROW_HR_IF(E_UNEXPECTED, !source);

auto openFunction = [&](IProgressCallback& progress)->std::vector<Repository::SourceDetails> { return source.Open(progress); };
auto openFunction = [&](IProgressCallback& progress)->std::vector<Repository::SourceDetails>
{
source.SetCaller("winget-cli");
return source.Open(progress);
};
context.Reporter.ExecuteWithProgress(openFunction, true);
}
catch (...)
Expand Down
39 changes: 37 additions & 2 deletions src/AppInstallerCLITests/CustomHeader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ TEST_CASE("RestClient_CustomHeader", "[RestSource][CustomHeader]")
std::optional<std::string> customHeader = "Testing custom header";
auto header = std::make_pair<>(CustomHeaderName, JSON::GetUtilityString(customHeader.value()));
HttpClientHelper helper{ GetCustomHeaderVerificationHandler(web::http::status_codes::OK, sample, header) };
RestClient client = RestClient::Create(utility::conversions::to_utf8string("https://restsource.com/api"), customHeader, std::move(helper));
RestClient client = RestClient::Create(utility::conversions::to_utf8string("https://restsource.com/api"), customHeader, {}, std::move(helper));
REQUIRE(client.GetSourceIdentifier() == "Source123");
}

Expand Down Expand Up @@ -130,6 +130,41 @@ TEST_CASE("RestSourceSearch_CustomHeaderExceedingSize", "[RestSource][CustomHead
auto header = std::make_pair<>(CustomHeaderName, JSON::GetUtilityString(customHeader));
HttpClientHelper helper{ GetCustomHeaderVerificationHandler(web::http::status_codes::OK, sampleSearchResponse, header) };

REQUIRE_THROWS_HR(RestClient::Create(utility::conversions::to_utf8string("https://restsource.com/api"), customHeader, std::move(helper)),
REQUIRE_THROWS_HR(RestClient::Create(utility::conversions::to_utf8string("https://restsource.com/api"), customHeader, {}, std::move(helper)),
APPINSTALLER_CLI_ERROR_CUSTOMHEADER_EXCEEDS_MAXLENGTH);
}

TEST_CASE("RestClient_CustomUserAgentHeader", "[RestSource][CustomHeader]")
{
utility::string_t sample = _XPLATSTR(
R"delimiter({
"Data" : {
"SourceIdentifier": "Source123",
"ServerSupportedVersions": [
"1.0.0",
"2.0.0"]
}})delimiter");

std::string testCaller = "TestCaller";
auto header = std::make_pair<>(web::http::header_names::user_agent, JSON::GetUtilityString(Runtime::GetUserAgent(testCaller)));
HttpClientHelper helper{ GetCustomHeaderVerificationHandler(web::http::status_codes::OK, sample, header) };
RestClient client = RestClient::Create(utility::conversions::to_utf8string("https://restsource.com/api"), {}, testCaller, std::move(helper));
REQUIRE(client.GetSourceIdentifier() == "Source123");
}

TEST_CASE("RestClient_DefaultUserAgentHeader", "[RestSource][CustomHeader]")
{
utility::string_t sample = _XPLATSTR(
R"delimiter({
"Data" : {
"SourceIdentifier": "Source123",
"ServerSupportedVersions": [
"1.0.0",
"2.0.0"]
}})delimiter");

auto header = std::make_pair<>(web::http::header_names::user_agent, JSON::GetUtilityString(Runtime::GetDefaultUserAgent()));
HttpClientHelper helper{ GetCustomHeaderVerificationHandler(web::http::status_codes::OK, sample, header) };
RestClient client = RestClient::Create(utility::conversions::to_utf8string("https://restsource.com/api"), {}, {}, std::move(helper));
REQUIRE(client.GetSourceIdentifier() == "Source123");
}
6 changes: 3 additions & 3 deletions src/AppInstallerCLITests/RestClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ TEST_CASE("RestClientCreate_UnsupportedVersion", "[RestSource]")
}})delimiter");

HttpClientHelper helper{ GetTestRestRequestHandler(web::http::status_codes::OK, sample) };
REQUIRE_THROWS_HR(RestClient::Create("https://restsource.com/api", {}, std::move(helper)), APPINSTALLER_CLI_ERROR_UNSUPPORTED_RESTSOURCE);
REQUIRE_THROWS_HR(RestClient::Create("https://restsource.com/api", {}, {}, std::move(helper)), APPINSTALLER_CLI_ERROR_UNSUPPORTED_RESTSOURCE);
}

TEST_CASE("RestClientCreate_1.0_Success", "[RestSource]")
Expand All @@ -157,7 +157,7 @@ TEST_CASE("RestClientCreate_1.0_Success", "[RestSource]")
}})delimiter");

HttpClientHelper helper{ GetTestRestRequestHandler(web::http::status_codes::OK, sample) };
RestClient client = RestClient::Create(utility::conversions::to_utf8string(TestRestUri), {}, std::move(helper));
RestClient client = RestClient::Create(utility::conversions::to_utf8string(TestRestUri), {}, {}, std::move(helper));
REQUIRE(client.GetSourceIdentifier() == "Source123");
}

Expand Down Expand Up @@ -194,7 +194,7 @@ TEST_CASE("RestClientCreate_1.1_Success", "[RestSource]")
}})delimiter");

HttpClientHelper helper{ GetTestRestRequestHandler(web::http::status_codes::OK, sample) };
RestClient client = RestClient::Create(utility::conversions::to_utf8string(TestRestUri), {}, std::move(helper));
RestClient client = RestClient::Create(utility::conversions::to_utf8string(TestRestUri), {}, {}, std::move(helper));
REQUIRE(client.GetSourceIdentifier() == "Source123");
auto information = client.GetSourceInformation();
REQUIRE(information.SourceAgreementsIdentifier == "agreementV1");
Expand Down
3 changes: 3 additions & 0 deletions src/AppInstallerCommonCore/Public/AppInstallerRuntime.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,4 +138,7 @@ namespace AppInstaller::Runtime

// Gets the default user agent string for the Windows Package Manager.
Utility::LocIndString GetDefaultUserAgent();

// Gets the user agent string from passed in caller for the Windows Package Manager.
Utility::LocIndString GetUserAgent(std::string_view caller);
}
12 changes: 11 additions & 1 deletion src/AppInstallerCommonCore/Runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -775,7 +775,17 @@ namespace AppInstaller::Runtime
{
std::ostringstream strstr;
strstr <<
"winget-cli"
"winget-cli" <<
" WindowsPackageManager/" << GetClientVersion() <<
" DesktopAppInstaller/" << GetPackageVersion();
return Utility::LocIndString{ strstr.str() };
}

Utility::LocIndString GetUserAgent(std::string_view caller)
{
std::ostringstream strstr;
strstr <<
caller <<
" WindowsPackageManager/" << GetClientVersion() <<
" DesktopAppInstaller/" << GetPackageVersion();
return Utility::LocIndString{ strstr.str() };
Expand Down
5 changes: 4 additions & 1 deletion src/AppInstallerRepositoryCore/ISource.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,10 @@ namespace AppInstaller::Repository
virtual SourceInformation GetInformation() { return {}; }

// Set custom header. Returns false if custom header is not supported.
virtual bool SetCustomHeader(std::optional<std::string> header) { UNREFERENCED_PARAMETER(header); return false; }
virtual bool SetCustomHeader(std::optional<std::string>) { return false; }

// Set caller.
virtual void SetCaller(std::string) {}

// Opens the source. This function should throw upon open failure rather than returning an empty pointer.
virtual std::shared_ptr<ISource> Open(IProgressCallback& progress) = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,9 @@ namespace AppInstaller::Repository
// Set custom header.
bool SetCustomHeader(std::optional<std::string> header);

// Set caller.
void SetCaller(std::string caller);

// Execute a search on the source.
SearchResult Search(const SearchRequest& request) const;

Expand Down
8 changes: 8 additions & 0 deletions src/AppInstallerRepositoryCore/RepositorySource.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,14 @@ namespace AppInstaller::Repository
return m_sourceReferences[0]->SetCustomHeader(header);
}

void Source::SetCaller(std::string caller)
{
for (auto& sourceReference : m_sourceReferences)
{
sourceReference->SetCaller(caller);
}
}

SearchResult Source::Search(const SearchRequest& request) const
{
THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_STATE), !m_source);
Expand Down
30 changes: 21 additions & 9 deletions src/AppInstallerRepositoryCore/Rest/RestClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,30 @@ namespace AppInstaller::Repository::Rest
constexpr size_t WindowsPackageManagerHeaderMaxLength = 1024;

namespace {
std::unordered_map<utility::string_t, utility::string_t> GetHeaders(std::optional<std::string> customHeader)
std::unordered_map<utility::string_t, utility::string_t> GetHeaders(std::optional<std::string> customHeader, std::string_view caller)
{
if (!customHeader)
std::unordered_map<utility::string_t, utility::string_t> headers;

if (customHeader)
{
AICLI_LOG(Repo, Verbose, << "Custom header not found.");
return {};
AICLI_LOG(Repo, Verbose, << "Custom header found: " << customHeader.value());
THROW_HR_IF(APPINSTALLER_CLI_ERROR_CUSTOMHEADER_EXCEEDS_MAXLENGTH, customHeader.value().size() > WindowsPackageManagerHeaderMaxLength);
headers.emplace(JSON::GetUtilityString(WindowsPackageManagerHeader), JSON::GetUtilityString(customHeader.value()));
}

THROW_HR_IF(APPINSTALLER_CLI_ERROR_CUSTOMHEADER_EXCEEDS_MAXLENGTH, customHeader.value().size() > WindowsPackageManagerHeaderMaxLength);
if (!caller.empty())
{
AICLI_LOG(Repo, Verbose, << "User agent caller found: " << caller);
std::wstring userAgentWide = JSON::GetUtilityString(Runtime::GetUserAgent(caller));
try
{
// Replace user profile if the caller binary is under user profile.
userAgentWide = Utility::ReplaceWhileCopying(userAgentWide, Runtime::GetPathTo(Runtime::PathName::UserProfile).wstring(), L"%USERPROFILE%");
}
CATCH_LOG();
headers.emplace(web::http::header_names::user_agent, userAgentWide);
}

std::unordered_map<utility::string_t, utility::string_t> headers;
headers.emplace(JSON::GetUtilityString(WindowsPackageManagerHeader), JSON::GetUtilityString(customHeader.value()));
return headers;
}
}
Expand Down Expand Up @@ -139,12 +151,12 @@ namespace AppInstaller::Repository::Rest
THROW_HR(APPINSTALLER_CLI_ERROR_RESTSOURCE_INVALID_VERSION);
}

RestClient RestClient::Create(const std::string& restApi, std::optional<std::string> customHeader, const HttpClientHelper& helper)
RestClient RestClient::Create(const std::string& restApi, std::optional<std::string> customHeader, std::string_view caller, const HttpClientHelper& helper)
{
utility::string_t restEndpoint = RestHelper::GetRestAPIBaseUri(restApi);
THROW_HR_IF(APPINSTALLER_CLI_ERROR_RESTSOURCE_INVALID_URL, !RestHelper::IsValidUri(restEndpoint));

auto headers = GetHeaders(customHeader);
auto headers = GetHeaders(customHeader, caller);

IRestClient::Information information = GetInformation(restEndpoint, headers, helper);
std::optional<Version> latestCommonVersion = GetLatestCommonVersion(information.ServerSupportedVersions, WingetSupportedContracts);
Expand Down
2 changes: 1 addition & 1 deletion src/AppInstallerRepositoryCore/Rest/RestClient.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ namespace AppInstaller::Repository::Rest

static std::unique_ptr<Schema::IRestClient> GetSupportedInterface(const std::string& restApi, const std::unordered_map<utility::string_t, utility::string_t>& additionalHeaders, const Schema::IRestClient::Information& information, const AppInstaller::Utility::Version& version);

static RestClient Create(const std::string& restApi, std::optional<std::string> customHeader, const Schema::HttpClientHelper& helper = {});
static RestClient Create(const std::string& restApi, std::optional<std::string> customHeader, std::string_view caller, const Schema::HttpClientHelper& helper = {});
private:
RestClient(std::unique_ptr<Schema::IRestClient> supportedInterface, std::string sourceIdentifier);

Expand Down
10 changes: 8 additions & 2 deletions src/AppInstallerRepositoryCore/Rest/RestSourceFactory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,15 @@ namespace AppInstaller::Repository::Rest
return true;
}

void SetCaller(std::string caller) override
{
m_caller = std::move(caller);
}

std::shared_ptr<ISource> Open(IProgressCallback&) override
{
Initialize();
RestClient restClient = RestClient::Create(m_details.Arg, m_customHeader, m_httpClientHelper);
RestClient restClient = RestClient::Create(m_details.Arg, m_customHeader, m_caller, m_httpClientHelper);
return std::make_shared<RestSource>(m_details, m_information, std::move(restClient));
}

Expand All @@ -51,7 +56,7 @@ namespace AppInstaller::Repository::Rest
[&]()
{
m_httpClientHelper.SetPinningConfiguration(m_details.CertificatePinningConfiguration);
RestClient restClient = RestClient::Create(m_details.Arg, m_customHeader, m_httpClientHelper);
RestClient restClient = RestClient::Create(m_details.Arg, m_customHeader, m_caller, m_httpClientHelper);

m_details.Identifier = restClient.GetSourceIdentifier();

Expand All @@ -73,6 +78,7 @@ namespace AppInstaller::Repository::Rest
Schema::HttpClientHelper m_httpClientHelper;
SourceInformation m_information;
std::optional<std::string> m_customHeader;
std::string m_caller;
std::once_flag m_initializeFlag;
};

Expand Down
18 changes: 18 additions & 0 deletions src/Microsoft.Management.Deployment/Helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,24 @@ using namespace std::string_view_literals;

namespace winrt::Microsoft::Management::Deployment::implementation
{
namespace
{
static std::optional<std::string> s_callerName;
static wil::srwlock s_callerNameLock;
}

void SetComCallerName(std::string name)
{
auto lock = s_callerNameLock.lock_exclusive();
s_callerName.emplace(std::move(name));
}

std::string GetComCallerName(std::string defaultNameIfNotSet)
{
auto lock = s_callerNameLock.lock_shared();
return s_callerName.has_value() ? s_callerName.value() : defaultNameIfNotSet;
}

std::pair<HRESULT, DWORD> GetCallerProcessId()
{
RPC_STATUS rpcStatus = RPC_S_OK;
Expand Down
3 changes: 3 additions & 0 deletions src/Microsoft.Management.Deployment/Helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@

namespace winrt::Microsoft::Management::Deployment::implementation
{
void SetComCallerName(std::string name);
std::string GetComCallerName(std::string defaultNameIfNotSet);

// Enable custom code to run before creating any object through the factory.
// Currently that means requiring the overall WinGet policy to be enabled.
template <typename TCppWinRTClass>
Expand Down
30 changes: 30 additions & 0 deletions src/Microsoft.Management.Deployment/PackageCatalogReference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,38 @@
#include <wil\cppwinrt_wrl.h>
#include <winget/GroupPolicy.h>
#include <AppInstallerErrors.h>
#include <AppInstallerStrings.h>
#include <Helpers.h>

namespace winrt::Microsoft::Management::Deployment::implementation
{
namespace
{
std::string GetCallerName()
{
// See if caller name is set by caller
static auto callerName = GetComCallerName("");

// Get process string
if (callerName.empty())
{
try
{
auto [hrGetCallerId, callerProcessId] = GetCallerProcessId();
THROW_IF_FAILED(hrGetCallerId);
callerName = AppInstaller::Utility::ConvertToUTF8(TryGetCallerProcessInfo(callerProcessId));
}
CATCH_LOG();
}

if (callerName.empty())
{
callerName = "UnknownComCaller";
}

return callerName;
}
}
void PackageCatalogReference::Initialize(winrt::Microsoft::Management::Deployment::PackageCatalogInfo packageCatalogInfo, ::AppInstaller::Repository::Source sourceReference)
{
m_info = packageCatalogInfo;
Expand Down Expand Up @@ -77,6 +105,7 @@ namespace winrt::Microsoft::Management::Deployment::implementation
auto catalog = m_compositePackageCatalogOptions.Catalogs().GetAt(i);
winrt::Microsoft::Management::Deployment::implementation::PackageCatalogReference* catalogImpl = get_self<winrt::Microsoft::Management::Deployment::implementation::PackageCatalogReference>(catalog);
auto copy = catalogImpl->m_sourceReference;
copy.SetCaller(GetCallerName());
copy.Open(progress);
remoteSources.emplace_back(std::move(copy));
}
Expand Down Expand Up @@ -112,6 +141,7 @@ namespace winrt::Microsoft::Management::Deployment::implementation
else
{
source = m_sourceReference;
source.SetCaller(GetCallerName());
source.Open(progress);
}

Expand Down
20 changes: 2 additions & 18 deletions src/Microsoft.Management.Deployment/PackageManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,18 +38,6 @@ using namespace ::AppInstaller::CLI::Execution;

namespace winrt::Microsoft::Management::Deployment::implementation
{
namespace
{
static std::optional<std::string> s_callerName;
static wil::srwlock s_callerNameLock;
}

void SetComCallerName(std::string name)
{
auto lock = s_callerNameLock.lock_exclusive();
s_callerName.emplace(std::move(name));
}

winrt::Windows::Foundation::Collections::IVectorView<winrt::Microsoft::Management::Deployment::PackageCatalogReference> PackageManager::GetPackageCatalogs()
{
Windows::Foundation::Collections::IVector<Microsoft::Management::Deployment::PackageCatalogReference> catalogs{ winrt::single_threaded_vector<Microsoft::Management::Deployment::PackageCatalogReference>() };
Expand Down Expand Up @@ -454,12 +442,8 @@ namespace winrt::Microsoft::Management::Deployment::implementation
{
std::unique_ptr<COMContext> context = std::make_unique<COMContext>();
hstring correlationData = (options) ? options.CorrelationData() : L"";
std::string callerName;
{
auto lock = s_callerNameLock.lock_shared();
callerName = s_callerName.has_value() ? s_callerName.value() : AppInstaller::Utility::ConvertToUTF8(callerProcessInfoString);
}
context->SetContextLoggers(correlationData, callerName);

context->SetContextLoggers(correlationData, GetComCallerName(AppInstaller::Utility::ConvertToUTF8(callerProcessInfoString)));

// Convert the options to arguments for the installer.
if constexpr (std::is_same_v<TOptions, winrt::Microsoft::Management::Deployment::InstallOptions>)
Expand Down

0 comments on commit dddc094

Please sign in to comment.