Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OAuth2: fix token session subdomain #35845

Merged
merged 15 commits into from
Aug 31, 2024
56 changes: 39 additions & 17 deletions source/extensions/filters/http/oauth2/filter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,25 +136,25 @@ Http::Utility::QueryParamsMulti buildAutorizationQueryParams(
return query_params;
}

std::string encodeHmacHexBase64(const std::vector<uint8_t>& secret, absl::string_view host,
std::string encodeHmacHexBase64(const std::vector<uint8_t>& secret, absl::string_view domain,
absl::string_view expires, absl::string_view token = "",
absl::string_view id_token = "",
absl::string_view refresh_token = "") {
auto& crypto_util = Envoy::Common::Crypto::UtilitySingleton::get();
const auto hmac_payload =
absl::StrJoin({host, expires, token, id_token, refresh_token}, HmacPayloadSeparator);
absl::StrJoin({domain, expires, token, id_token, refresh_token}, HmacPayloadSeparator);
std::string encoded_hmac;
absl::Base64Escape(Hex::encode(crypto_util.getSha256Hmac(secret, hmac_payload)), &encoded_hmac);
return encoded_hmac;
}

std::string encodeHmacBase64(const std::vector<uint8_t>& secret, absl::string_view host,
std::string encodeHmacBase64(const std::vector<uint8_t>& secret, absl::string_view domain,
absl::string_view expires, absl::string_view token = "",
absl::string_view id_token = "",
absl::string_view refresh_token = "") {
auto& crypto_util = Envoy::Common::Crypto::UtilitySingleton::get();
const auto hmac_payload =
absl::StrJoin({host, expires, token, id_token, refresh_token}, HmacPayloadSeparator);
absl::StrJoin({domain, expires, token, id_token, refresh_token}, HmacPayloadSeparator);

std::string base64_encoded_hmac;
std::vector<uint8_t> hmac_result = crypto_util.getSha256Hmac(secret, hmac_payload);
Expand All @@ -163,10 +163,10 @@ std::string encodeHmacBase64(const std::vector<uint8_t>& secret, absl::string_vi
return base64_encoded_hmac;
}

std::string encodeHmac(const std::vector<uint8_t>& secret, absl::string_view host,
std::string encodeHmac(const std::vector<uint8_t>& secret, absl::string_view domain,
absl::string_view expires, absl::string_view token = "",
absl::string_view id_token = "", absl::string_view refresh_token = "") {
return encodeHmacBase64(secret, host, expires, token, id_token, refresh_token);
return encodeHmacBase64(secret, domain, expires, token, id_token, refresh_token);
}

} // namespace
Expand Down Expand Up @@ -235,9 +235,14 @@ void OAuth2CookieValidator::setParams(const Http::RequestHeaderMap& headers,
bool OAuth2CookieValidator::canUpdateTokenByRefreshToken() const { return !refresh_token_.empty(); }

bool OAuth2CookieValidator::hmacIsValid() const {
return (
(encodeHmacBase64(secret_, host_, expires_, token_, id_token_, refresh_token_) == hmac_) ||
(encodeHmacHexBase64(secret_, host_, expires_, token_, id_token_, refresh_token_) == hmac_));
absl::string_view cookie_domain = host_;
if (!cookie_domain_.empty()) {
cookie_domain = cookie_domain_;
}
return ((encodeHmacBase64(secret_, cookie_domain, expires_, token_, id_token_, refresh_token_) ==
hmac_) ||
(encodeHmacHexBase64(secret_, cookie_domain, expires_, token_, id_token_,
refresh_token_) == hmac_));
}

bool OAuth2CookieValidator::timestampIsValid() const {
Expand All @@ -254,7 +259,8 @@ bool OAuth2CookieValidator::isValid() const { return hmacIsValid() && timestampI

OAuth2Filter::OAuth2Filter(FilterConfigSharedPtr config,
std::unique_ptr<OAuth2Client>&& oauth_client, TimeSource& time_source)
: validator_(std::make_shared<OAuth2CookieValidator>(time_source, config->cookieNames())),
: validator_(std::make_shared<OAuth2CookieValidator>(time_source, config->cookieNames(),
config->cookieDomain())),
oauth_client_(std::move(oauth_client)), config_(std::move(config)),
time_source_(time_source) {

Expand Down Expand Up @@ -500,18 +506,28 @@ Http::FilterHeadersStatus OAuth2Filter::signOutUser(const Http::RequestHeaderMap
{{Http::Headers::get().Status, std::to_string(enumToInt(Http::Code::Found))}})};

const std::string new_path = absl::StrCat(headers.getSchemeValue(), "://", host_, "/");

std::string cookie_domain;
if (!config_->cookieDomain().empty()) {
cookie_domain = fmt::format(CookieDomainFormatString, config_->cookieDomain());
}

response_headers->addReferenceKey(
Http::Headers::get().SetCookie,
fmt::format(CookieDeleteFormatString, config_->cookieNames().oauth_hmac_));
absl::StrCat(fmt::format(CookieDeleteFormatString, config_->cookieNames().oauth_hmac_),
cookie_domain));
response_headers->addReferenceKey(
Http::Headers::get().SetCookie,
fmt::format(CookieDeleteFormatString, config_->cookieNames().bearer_token_));
absl::StrCat(fmt::format(CookieDeleteFormatString, config_->cookieNames().bearer_token_),
cookie_domain));
response_headers->addReferenceKey(
Http::Headers::get().SetCookie,
fmt::format(CookieDeleteFormatString, config_->cookieNames().id_token_));
absl::StrCat(fmt::format(CookieDeleteFormatString, config_->cookieNames().id_token_),
cookie_domain));
response_headers->addReferenceKey(
Http::Headers::get().SetCookie,
fmt::format(CookieDeleteFormatString, config_->cookieNames().refresh_token_));
absl::StrCat(fmt::format(CookieDeleteFormatString, config_->cookieNames().refresh_token_),
cookie_domain));
response_headers->setLocation(new_path);
decoder_callbacks_->encodeHeaders(std::move(response_headers), true, SIGN_OUT);

Expand Down Expand Up @@ -542,11 +558,17 @@ std::string OAuth2Filter::getEncodedToken() const {
auto token_secret = config_->tokenSecret();
std::vector<uint8_t> token_secret_vec(token_secret.begin(), token_secret.end());
std::string encoded_token;

absl::string_view domain = host_;
if (!config_->cookieDomain().empty()) {
domain = config_->cookieDomain();
}

if (config_->forwardBearerToken()) {
encoded_token =
encodeHmac(token_secret_vec, host_, new_expires_, access_token_, id_token_, refresh_token_);
encoded_token = encodeHmac(token_secret_vec, domain, new_expires_, access_token_, id_token_,
refresh_token_);
} else {
encoded_token = encodeHmac(token_secret_vec, host_, new_expires_);
encoded_token = encodeHmac(token_secret_vec, domain, new_expires_);
}
return encoded_token;
}
Expand Down
6 changes: 4 additions & 2 deletions source/extensions/filters/http/oauth2/filter.h
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,9 @@ class CookieValidator {

class OAuth2CookieValidator : public CookieValidator {
public:
explicit OAuth2CookieValidator(TimeSource& time_source, const CookieNames& cookie_names)
: time_source_(time_source), cookie_names_(cookie_names) {}
explicit OAuth2CookieValidator(TimeSource& time_source, const CookieNames& cookie_names,
const std::string& cookie_domain)
: time_source_(time_source), cookie_names_(cookie_names), cookie_domain_(cookie_domain) {}

const std::string& token() const override { return token_; }
const std::string& refreshToken() const override { return refresh_token_; }
Expand All @@ -226,6 +227,7 @@ class OAuth2CookieValidator : public CookieValidator {
absl::string_view host_;
TimeSource& time_source_;
const CookieNames cookie_names_;
const std::string cookie_domain_;
};

/**
Expand Down
50 changes: 41 additions & 9 deletions test/extensions/filters/http/oauth2/filter_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ class OAuth2Test : public testing::TestWithParam<int> {
}

// Validates the behavior of the cookie validator.
void expectValidCookies(const CookieNames& cookie_names) {
void expectValidCookies(const CookieNames& cookie_names, const std::string& cookie_domain) {
// Set SystemTime to a fixed point so we get consistent HMAC encodings between test runs.
test_time_.setSystemTime(SystemTime(std::chrono::seconds(0)));

Expand All @@ -188,7 +188,8 @@ class OAuth2Test : public testing::TestWithParam<int> {
absl::StrCat(cookie_names.oauth_hmac_, "=dCu0otMcLoaGF73jrT+R8rGA0pnWyMgNf4+GivGrHEI=")},
};

auto cookie_validator = std::make_shared<OAuth2CookieValidator>(test_time_, cookie_names);
auto cookie_validator =
std::make_shared<OAuth2CookieValidator>(test_time_, cookie_names, cookie_domain);
EXPECT_EQ(cookie_validator->token(), "");
EXPECT_EQ(cookie_validator->refreshToken(), "");
cookie_validator->setParams(request_headers, "mock-secret");
Expand Down Expand Up @@ -881,13 +882,44 @@ TEST_F(OAuth2Test, AjaxDoesNotRedirect) {
// Validates the behavior of the cookie validator.
TEST_F(OAuth2Test, CookieValidator) {
expectValidCookies(
CookieNames{"BearerToken", "OauthHMAC", "OauthExpires", "IdToken", "RefreshToken"});
CookieNames{"BearerToken", "OauthHMAC", "OauthExpires", "IdToken", "RefreshToken"}, "");
}

// Validates the behavior of the cookie validator with custom cookie names.
TEST_F(OAuth2Test, CookieValidatorWithCustomNames) {
expectValidCookies(CookieNames{"CustomBearerToken", "CustomOauthHMAC", "CustomOauthExpires",
"CustomIdToken", "CustomRefreshToken"});
"CustomIdToken", "CustomRefreshToken"},
"");
}

// Validates the behavior of the cookie validator with custom cookie domain.
TEST_F(OAuth2Test, CookieValidatorWithCookieDomain) {
test_time_.setSystemTime(SystemTime(std::chrono::seconds(0)));
auto cookie_names =
CookieNames{"BearerToken", "OauthHMAC", "OauthExpires", "IdToken", "RefreshToken"};
const auto expires_at_s = DateUtil::nowToSeconds(test_time_.timeSystem()) + 5;

Http::TestRequestHeaderMapImpl request_headers{
{Http::Headers::get().Host.get(), "traffic.example.com"},
{Http::Headers::get().Path.get(), "/anypath"},
{Http::Headers::get().Method.get(), Http::Headers::get().MethodValues.Get},
{Http::Headers::get().Cookie.get(),
fmt::format("{}={}", cookie_names.oauth_expires_, expires_at_s)},
{Http::Headers::get().Cookie.get(), absl::StrCat(cookie_names.bearer_token_, "=xyztoken")},
{Http::Headers::get().Cookie.get(),
absl::StrCat(cookie_names.oauth_hmac_, "=zgWoFFmB6rbPHQQYQj35H+Fz+GYZgUrh/C48y0WHWRM=")},
};

auto cookie_validator =
std::make_shared<OAuth2CookieValidator>(test_time_, cookie_names, "example.com");

EXPECT_EQ(cookie_validator->token(), "");
EXPECT_EQ(cookie_validator->refreshToken(), "");
cookie_validator->setParams(request_headers, "mock-secret");

EXPECT_TRUE(cookie_validator->hmacIsValid());
EXPECT_TRUE(cookie_validator->timestampIsValid());
EXPECT_TRUE(cookie_validator->isValid());
}

// Validates the behavior of the cookie validator when the combination of some fields could be same.
Expand All @@ -909,7 +941,7 @@ TEST_F(OAuth2Test, CookieValidatorSame) {
absl::StrCat(cookie_names.oauth_hmac_, "=MSq8mkNQGdXx2LKGlLHMwSIj8rLZRnrHE6EWvvTUFx0=")},
};

auto cookie_validator = std::make_shared<OAuth2CookieValidator>(test_time_, cookie_names);
auto cookie_validator = std::make_shared<OAuth2CookieValidator>(test_time_, cookie_names, "");
EXPECT_EQ(cookie_validator->token(), "");
cookie_validator->setParams(request_headers, "mock-secret");

Expand Down Expand Up @@ -966,7 +998,7 @@ TEST_F(OAuth2Test, CookieValidatorInvalidExpiresAt) {

auto cookie_validator = std::make_shared<OAuth2CookieValidator>(
test_time_,
CookieNames{"BearerToken", "OauthHMAC", "OauthExpires", "IdToken", "RefreshToken"});
CookieNames{"BearerToken", "OauthHMAC", "OauthExpires", "IdToken", "RefreshToken"}, "");
cookie_validator->setParams(request_headers, "mock-secret");

EXPECT_TRUE(cookie_validator->hmacIsValid());
Expand All @@ -986,7 +1018,7 @@ TEST_F(OAuth2Test, CookieValidatorCanUpdateToken) {

auto cookie_validator = std::make_shared<OAuth2CookieValidator>(
test_time_,
CookieNames("BearerToken", "OauthHMAC", "OauthExpires", "IdToken", "RefreshToken"));
CookieNames("BearerToken", "OauthHMAC", "OauthExpires", "IdToken", "RefreshToken"), "");
cookie_validator->setParams(request_headers, "mock-secret");

EXPECT_TRUE(cookie_validator->canUpdateTokenByRefreshToken());
Expand Down Expand Up @@ -1398,7 +1430,7 @@ TEST_F(OAuth2Test, OAuthTestFullFlowPostWithCookieDomain) {
Http::TestRequestHeaderMapImpl second_response_headers{
{Http::Headers::get().Status.get(), "302"},
{Http::Headers::get().SetCookie.get(),
"OauthHMAC=fV62OgLipChTQQC3UFgDp+l5sCiSb3zt7nCoJiVivWw=;"
"OauthHMAC=aPoIhN7QYMrYc9nTGCCWgd3rJpZIEdjOtxPDdmVDS6E=;"
"domain=example.com;path=/;Max-Age=;secure;HttpOnly"},
{Http::Headers::get().SetCookie.get(),
"OauthExpires=;domain=example.com;path=/;Max-Age=;secure;HttpOnly"},
Expand Down Expand Up @@ -1988,7 +2020,7 @@ TEST_F(OAuth2Test, CookieValidatorInTransition) {

auto cookie_validator = std::make_shared<OAuth2CookieValidator>(
test_time_,
CookieNames{"BearerToken", "OauthHMAC", "OauthExpires", "IdToken", "RefreshToken"});
CookieNames{"BearerToken", "OauthHMAC", "OauthExpires", "IdToken", "RefreshToken"}, "");
cookie_validator->setParams(request_headers_base64only, "mock-secret");
EXPECT_TRUE(cookie_validator->hmacIsValid());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ name: oauth
Http::Headers::get().Cookie,
absl::StrCat(default_cookie_names_.refresh_token_, "=", refreshToken));

OAuth2CookieValidator validator{api_->timeSource(), default_cookie_names_};
OAuth2CookieValidator validator{api_->timeSource(), default_cookie_names_, ""};
validator.setParams(validate_headers, std::string(hmac_secret));
return validator.isValid();
}
Expand Down