Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass COM caller name to rest source in request header #3112

Merged
merged 3 commits into from
Mar 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions src/AppInstallerCLICore/Workflows/WorkflowBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,11 @@ namespace AppInstaller::CLI::Workflow
}
}

auto openFunction = [&](IProgressCallback& progress)->std::vector<Repository::SourceDetails> { return source.Open(progress); };
auto openFunction = [&](IProgressCallback& progress)->std::vector<Repository::SourceDetails>
{
source.SetCaller("winget-cli");
return source.Open(progress);
};
auto updateFailures = context.Reporter.ExecuteWithProgress(openFunction, true);

// We'll only report the source update failure as warning and continue
Expand Down Expand Up @@ -376,7 +380,11 @@ namespace AppInstaller::CLI::Workflow
// A well known predefined source should return a value.
THROW_HR_IF(E_UNEXPECTED, !source);

auto openFunction = [&](IProgressCallback& progress)->std::vector<Repository::SourceDetails> { return source.Open(progress); };
auto openFunction = [&](IProgressCallback& progress)->std::vector<Repository::SourceDetails>
{
source.SetCaller("winget-cli");
return source.Open(progress);
};
context.Reporter.ExecuteWithProgress(openFunction, true);
}
catch (...)
Expand Down
39 changes: 37 additions & 2 deletions src/AppInstallerCLITests/CustomHeader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ TEST_CASE("RestClient_CustomHeader", "[RestSource][CustomHeader]")
std::optional<std::string> customHeader = "Testing custom header";
auto header = std::make_pair<>(CustomHeaderName, JSON::GetUtilityString(customHeader.value()));
HttpClientHelper helper{ GetCustomHeaderVerificationHandler(web::http::status_codes::OK, sample, header) };
RestClient client = RestClient::Create(utility::conversions::to_utf8string("https://restsource.com/api"), customHeader, std::move(helper));
RestClient client = RestClient::Create(utility::conversions::to_utf8string("https://restsource.com/api"), customHeader, {}, std::move(helper));
REQUIRE(client.GetSourceIdentifier() == "Source123");
}

Expand Down Expand Up @@ -130,6 +130,41 @@ TEST_CASE("RestSourceSearch_CustomHeaderExceedingSize", "[RestSource][CustomHead
auto header = std::make_pair<>(CustomHeaderName, JSON::GetUtilityString(customHeader));
HttpClientHelper helper{ GetCustomHeaderVerificationHandler(web::http::status_codes::OK, sampleSearchResponse, header) };

REQUIRE_THROWS_HR(RestClient::Create(utility::conversions::to_utf8string("https://restsource.com/api"), customHeader, std::move(helper)),
REQUIRE_THROWS_HR(RestClient::Create(utility::conversions::to_utf8string("https://restsource.com/api"), customHeader, {}, std::move(helper)),
APPINSTALLER_CLI_ERROR_CUSTOMHEADER_EXCEEDS_MAXLENGTH);
}

TEST_CASE("RestClient_CustomUserAgentHeader", "[RestSource][CustomHeader]")
{
utility::string_t sample = _XPLATSTR(
R"delimiter({
"Data" : {
"SourceIdentifier": "Source123",
"ServerSupportedVersions": [
"1.0.0",
"2.0.0"]
}})delimiter");

std::string testCaller = "TestCaller";
auto header = std::make_pair<>(web::http::header_names::user_agent, JSON::GetUtilityString(Runtime::GetUserAgent(testCaller)));
HttpClientHelper helper{ GetCustomHeaderVerificationHandler(web::http::status_codes::OK, sample, header) };
RestClient client = RestClient::Create(utility::conversions::to_utf8string("https://restsource.com/api"), {}, testCaller, std::move(helper));
REQUIRE(client.GetSourceIdentifier() == "Source123");
}

TEST_CASE("RestClient_DefaultUserAgentHeader", "[RestSource][CustomHeader]")
{
utility::string_t sample = _XPLATSTR(
R"delimiter({
"Data" : {
"SourceIdentifier": "Source123",
"ServerSupportedVersions": [
"1.0.0",
"2.0.0"]
}})delimiter");

auto header = std::make_pair<>(web::http::header_names::user_agent, JSON::GetUtilityString(Runtime::GetDefaultUserAgent()));
HttpClientHelper helper{ GetCustomHeaderVerificationHandler(web::http::status_codes::OK, sample, header) };
RestClient client = RestClient::Create(utility::conversions::to_utf8string("https://restsource.com/api"), {}, {}, std::move(helper));
REQUIRE(client.GetSourceIdentifier() == "Source123");
}
6 changes: 3 additions & 3 deletions src/AppInstallerCLITests/RestClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ TEST_CASE("RestClientCreate_UnsupportedVersion", "[RestSource]")
}})delimiter");

HttpClientHelper helper{ GetTestRestRequestHandler(web::http::status_codes::OK, sample) };
REQUIRE_THROWS_HR(RestClient::Create("https://restsource.com/api", {}, std::move(helper)), APPINSTALLER_CLI_ERROR_UNSUPPORTED_RESTSOURCE);
REQUIRE_THROWS_HR(RestClient::Create("https://restsource.com/api", {}, {}, std::move(helper)), APPINSTALLER_CLI_ERROR_UNSUPPORTED_RESTSOURCE);
}

TEST_CASE("RestClientCreate_1.0_Success", "[RestSource]")
Expand All @@ -157,7 +157,7 @@ TEST_CASE("RestClientCreate_1.0_Success", "[RestSource]")
}})delimiter");

HttpClientHelper helper{ GetTestRestRequestHandler(web::http::status_codes::OK, sample) };
RestClient client = RestClient::Create(utility::conversions::to_utf8string(TestRestUri), {}, std::move(helper));
RestClient client = RestClient::Create(utility::conversions::to_utf8string(TestRestUri), {}, {}, std::move(helper));
REQUIRE(client.GetSourceIdentifier() == "Source123");
}

Expand Down Expand Up @@ -194,7 +194,7 @@ TEST_CASE("RestClientCreate_1.1_Success", "[RestSource]")
}})delimiter");

HttpClientHelper helper{ GetTestRestRequestHandler(web::http::status_codes::OK, sample) };
RestClient client = RestClient::Create(utility::conversions::to_utf8string(TestRestUri), {}, std::move(helper));
RestClient client = RestClient::Create(utility::conversions::to_utf8string(TestRestUri), {}, {}, std::move(helper));
REQUIRE(client.GetSourceIdentifier() == "Source123");
auto information = client.GetSourceInformation();
REQUIRE(information.SourceAgreementsIdentifier == "agreementV1");
Expand Down
3 changes: 3 additions & 0 deletions src/AppInstallerCommonCore/Public/AppInstallerRuntime.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,4 +138,7 @@ namespace AppInstaller::Runtime

// Gets the default user agent string for the Windows Package Manager.
Utility::LocIndString GetDefaultUserAgent();

// Gets the user agent string from passed in caller for the Windows Package Manager.
Utility::LocIndString GetUserAgent(std::string_view caller);
}
12 changes: 11 additions & 1 deletion src/AppInstallerCommonCore/Runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -775,7 +775,17 @@ namespace AppInstaller::Runtime
{
std::ostringstream strstr;
strstr <<
"winget-cli"
"winget-cli" <<
" WindowsPackageManager/" << GetClientVersion() <<
" DesktopAppInstaller/" << GetPackageVersion();
return Utility::LocIndString{ strstr.str() };
}

Utility::LocIndString GetUserAgent(std::string_view caller)
{
std::ostringstream strstr;
strstr <<
caller <<
" WindowsPackageManager/" << GetClientVersion() <<
" DesktopAppInstaller/" << GetPackageVersion();
return Utility::LocIndString{ strstr.str() };
Expand Down
5 changes: 4 additions & 1 deletion src/AppInstallerRepositoryCore/ISource.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,10 @@ namespace AppInstaller::Repository
virtual SourceInformation GetInformation() { return {}; }

// Set custom header. Returns false if custom header is not supported.
virtual bool SetCustomHeader(std::optional<std::string> header) { UNREFERENCED_PARAMETER(header); return false; }
virtual bool SetCustomHeader(std::optional<std::string>) { return false; }

// Set caller.
virtual void SetCaller(std::string) {}

// Opens the source. This function should throw upon open failure rather than returning an empty pointer.
virtual std::shared_ptr<ISource> Open(IProgressCallback& progress) = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,9 @@ namespace AppInstaller::Repository
// Set custom header.
bool SetCustomHeader(std::optional<std::string> header);

// Set caller.
void SetCaller(std::string caller);

// Execute a search on the source.
SearchResult Search(const SearchRequest& request) const;

Expand Down
8 changes: 8 additions & 0 deletions src/AppInstallerRepositoryCore/RepositorySource.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,14 @@ namespace AppInstaller::Repository
return m_sourceReferences[0]->SetCustomHeader(header);
}

void Source::SetCaller(std::string caller)
{
for (auto& sourceReference : m_sourceReferences)
{
sourceReference->SetCaller(caller);
}
}

SearchResult Source::Search(const SearchRequest& request) const
{
THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_STATE), !m_source);
Expand Down
30 changes: 21 additions & 9 deletions src/AppInstallerRepositoryCore/Rest/RestClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,30 @@ namespace AppInstaller::Repository::Rest
constexpr size_t WindowsPackageManagerHeaderMaxLength = 1024;

namespace {
std::unordered_map<utility::string_t, utility::string_t> GetHeaders(std::optional<std::string> customHeader)
std::unordered_map<utility::string_t, utility::string_t> GetHeaders(std::optional<std::string> customHeader, std::string_view caller)
{
if (!customHeader)
std::unordered_map<utility::string_t, utility::string_t> headers;

if (customHeader)
{
AICLI_LOG(Repo, Verbose, << "Custom header not found.");
return {};
AICLI_LOG(Repo, Verbose, << "Custom header found: " << customHeader.value());
THROW_HR_IF(APPINSTALLER_CLI_ERROR_CUSTOMHEADER_EXCEEDS_MAXLENGTH, customHeader.value().size() > WindowsPackageManagerHeaderMaxLength);
headers.emplace(JSON::GetUtilityString(WindowsPackageManagerHeader), JSON::GetUtilityString(customHeader.value()));
}

THROW_HR_IF(APPINSTALLER_CLI_ERROR_CUSTOMHEADER_EXCEEDS_MAXLENGTH, customHeader.value().size() > WindowsPackageManagerHeaderMaxLength);
if (!caller.empty())
{
AICLI_LOG(Repo, Verbose, << "User agent caller found: " << caller);
std::wstring userAgentWide = JSON::GetUtilityString(Runtime::GetUserAgent(caller));
try
{
// Replace user profile if the caller binary is under user profile.
userAgentWide = Utility::ReplaceWhileCopying(userAgentWide, Runtime::GetPathTo(Runtime::PathName::UserProfile).wstring(), L"%USERPROFILE%");
}
CATCH_LOG();
headers.emplace(web::http::header_names::user_agent, userAgentWide);
}

std::unordered_map<utility::string_t, utility::string_t> headers;
headers.emplace(JSON::GetUtilityString(WindowsPackageManagerHeader), JSON::GetUtilityString(customHeader.value()));
return headers;
}
}
Expand Down Expand Up @@ -139,12 +151,12 @@ namespace AppInstaller::Repository::Rest
THROW_HR(APPINSTALLER_CLI_ERROR_RESTSOURCE_INVALID_VERSION);
}

RestClient RestClient::Create(const std::string& restApi, std::optional<std::string> customHeader, const HttpClientHelper& helper)
RestClient RestClient::Create(const std::string& restApi, std::optional<std::string> customHeader, std::string_view caller, const HttpClientHelper& helper)
{
utility::string_t restEndpoint = RestHelper::GetRestAPIBaseUri(restApi);
THROW_HR_IF(APPINSTALLER_CLI_ERROR_RESTSOURCE_INVALID_URL, !RestHelper::IsValidUri(restEndpoint));

auto headers = GetHeaders(customHeader);
auto headers = GetHeaders(customHeader, caller);

IRestClient::Information information = GetInformation(restEndpoint, headers, helper);
std::optional<Version> latestCommonVersion = GetLatestCommonVersion(information.ServerSupportedVersions, WingetSupportedContracts);
Expand Down
2 changes: 1 addition & 1 deletion src/AppInstallerRepositoryCore/Rest/RestClient.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ namespace AppInstaller::Repository::Rest

static std::unique_ptr<Schema::IRestClient> GetSupportedInterface(const std::string& restApi, const std::unordered_map<utility::string_t, utility::string_t>& additionalHeaders, const Schema::IRestClient::Information& information, const AppInstaller::Utility::Version& version);

static RestClient Create(const std::string& restApi, std::optional<std::string> customHeader, const Schema::HttpClientHelper& helper = {});
static RestClient Create(const std::string& restApi, std::optional<std::string> customHeader, std::string_view caller, const Schema::HttpClientHelper& helper = {});
private:
RestClient(std::unique_ptr<Schema::IRestClient> supportedInterface, std::string sourceIdentifier);

Expand Down
10 changes: 8 additions & 2 deletions src/AppInstallerRepositoryCore/Rest/RestSourceFactory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,15 @@ namespace AppInstaller::Repository::Rest
return true;
}

void SetCaller(std::string caller) override
{
m_caller = std::move(caller);
}

std::shared_ptr<ISource> Open(IProgressCallback&) override
{
Initialize();
RestClient restClient = RestClient::Create(m_details.Arg, m_customHeader, m_httpClientHelper);
RestClient restClient = RestClient::Create(m_details.Arg, m_customHeader, m_caller, m_httpClientHelper);
return std::make_shared<RestSource>(m_details, m_information, std::move(restClient));
}

Expand All @@ -51,7 +56,7 @@ namespace AppInstaller::Repository::Rest
[&]()
{
m_httpClientHelper.SetPinningConfiguration(m_details.CertificatePinningConfiguration);
RestClient restClient = RestClient::Create(m_details.Arg, m_customHeader, m_httpClientHelper);
RestClient restClient = RestClient::Create(m_details.Arg, m_customHeader, m_caller, m_httpClientHelper);

m_details.Identifier = restClient.GetSourceIdentifier();

Expand All @@ -73,6 +78,7 @@ namespace AppInstaller::Repository::Rest
Schema::HttpClientHelper m_httpClientHelper;
SourceInformation m_information;
std::optional<std::string> m_customHeader;
std::string m_caller;
std::once_flag m_initializeFlag;
};

Expand Down
18 changes: 18 additions & 0 deletions src/Microsoft.Management.Deployment/Helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,24 @@ using namespace std::string_view_literals;

namespace winrt::Microsoft::Management::Deployment::implementation
{
namespace
{
static std::optional<std::string> s_callerName;
static wil::srwlock s_callerNameLock;
}

void SetComCallerName(std::string name)
{
auto lock = s_callerNameLock.lock_exclusive();
s_callerName.emplace(std::move(name));
}

std::string GetComCallerName(std::string defaultNameIfNotSet)
{
auto lock = s_callerNameLock.lock_shared();
return s_callerName.has_value() ? s_callerName.value() : defaultNameIfNotSet;
}

std::pair<HRESULT, DWORD> GetCallerProcessId()
{
RPC_STATUS rpcStatus = RPC_S_OK;
Expand Down
3 changes: 3 additions & 0 deletions src/Microsoft.Management.Deployment/Helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@

namespace winrt::Microsoft::Management::Deployment::implementation
{
void SetComCallerName(std::string name);
std::string GetComCallerName(std::string defaultNameIfNotSet);

// Enable custom code to run before creating any object through the factory.
// Currently that means requiring the overall WinGet policy to be enabled.
template <typename TCppWinRTClass>
Expand Down
30 changes: 30 additions & 0 deletions src/Microsoft.Management.Deployment/PackageCatalogReference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,38 @@
#include <wil\cppwinrt_wrl.h>
#include <winget/GroupPolicy.h>
#include <AppInstallerErrors.h>
#include <AppInstallerStrings.h>
#include <Helpers.h>

namespace winrt::Microsoft::Management::Deployment::implementation
{
namespace
{
std::string GetCallerName()
{
// See if caller name is set by caller
static auto callerName = GetComCallerName("");

// Get process string
if (callerName.empty())
{
try
{
auto [hrGetCallerId, callerProcessId] = GetCallerProcessId();
THROW_IF_FAILED(hrGetCallerId);
callerName = AppInstaller::Utility::ConvertToUTF8(TryGetCallerProcessInfo(callerProcessId));
}
CATCH_LOG();
}

if (callerName.empty())
{
callerName = "UnknownComCaller";
}

return callerName;
}
}
void PackageCatalogReference::Initialize(winrt::Microsoft::Management::Deployment::PackageCatalogInfo packageCatalogInfo, ::AppInstaller::Repository::Source sourceReference)
{
m_info = packageCatalogInfo;
Expand Down Expand Up @@ -77,6 +105,7 @@ namespace winrt::Microsoft::Management::Deployment::implementation
auto catalog = m_compositePackageCatalogOptions.Catalogs().GetAt(i);
winrt::Microsoft::Management::Deployment::implementation::PackageCatalogReference* catalogImpl = get_self<winrt::Microsoft::Management::Deployment::implementation::PackageCatalogReference>(catalog);
auto copy = catalogImpl->m_sourceReference;
copy.SetCaller(GetCallerName());
copy.Open(progress);
remoteSources.emplace_back(std::move(copy));
}
Expand Down Expand Up @@ -112,6 +141,7 @@ namespace winrt::Microsoft::Management::Deployment::implementation
else
{
source = m_sourceReference;
source.SetCaller(GetCallerName());
source.Open(progress);
}

Expand Down
20 changes: 2 additions & 18 deletions src/Microsoft.Management.Deployment/PackageManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,18 +38,6 @@ using namespace ::AppInstaller::CLI::Execution;

namespace winrt::Microsoft::Management::Deployment::implementation
{
namespace
{
static std::optional<std::string> s_callerName;
static wil::srwlock s_callerNameLock;
}

void SetComCallerName(std::string name)
{
auto lock = s_callerNameLock.lock_exclusive();
s_callerName.emplace(std::move(name));
}

winrt::Windows::Foundation::Collections::IVectorView<winrt::Microsoft::Management::Deployment::PackageCatalogReference> PackageManager::GetPackageCatalogs()
{
Windows::Foundation::Collections::IVector<Microsoft::Management::Deployment::PackageCatalogReference> catalogs{ winrt::single_threaded_vector<Microsoft::Management::Deployment::PackageCatalogReference>() };
Expand Down Expand Up @@ -454,12 +442,8 @@ namespace winrt::Microsoft::Management::Deployment::implementation
{
std::unique_ptr<COMContext> context = std::make_unique<COMContext>();
hstring correlationData = (options) ? options.CorrelationData() : L"";
std::string callerName;
{
auto lock = s_callerNameLock.lock_shared();
callerName = s_callerName.has_value() ? s_callerName.value() : AppInstaller::Utility::ConvertToUTF8(callerProcessInfoString);
}
context->SetContextLoggers(correlationData, callerName);

context->SetContextLoggers(correlationData, GetComCallerName(AppInstaller::Utility::ConvertToUTF8(callerProcessInfoString)));

// Convert the options to arguments for the installer.
if constexpr (std::is_same_v<TOptions, winrt::Microsoft::Management::Deployment::InstallOptions>)
Expand Down