Skip to content

Commit

Permalink
feat: Introduce the functionality to override token_uri in credentials (
Browse files Browse the repository at this point in the history
#1159)

* feat: Introduce the functionality to override token_uri in credentials

* update rt
  • Loading branch information
sai-sunder-s authored Oct 11, 2022
1 parent 75326e3 commit 73bc7e9
Show file tree
Hide file tree
Showing 10 changed files with 216 additions and 4 deletions.
27 changes: 26 additions & 1 deletion google/auth/compute_engine/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,11 @@ def with_scopes(self, scopes, default_scopes=None):
_DEFAULT_TOKEN_URI = "https://www.googleapis.com/oauth2/v4/token"


class IDTokenCredentials(credentials.CredentialsWithQuotaProject, credentials.Signing):
class IDTokenCredentials(
credentials.CredentialsWithQuotaProject,
credentials.Signing,
credentials.CredentialsWithTokenUri,
):
"""Open ID Connect ID Token-based service account credentials.
These credentials relies on the default service account of a GCE instance.
Expand Down Expand Up @@ -302,6 +306,27 @@ def with_quota_project(self, quota_project_id):
quota_project_id=quota_project_id,
)

@_helpers.copy_docstring(credentials.CredentialsWithTokenUri)
def with_token_uri(self, token_uri):

# since the signer is already instantiated,
# the request is not needed
if self._use_metadata_identity_endpoint:
raise ValueError(
"If use_metadata_identity_endpoint is set, token_uri" " must not be set"
)
else:
return self.__class__(
None,
service_account_email=self._service_account_email,
token_uri=token_uri,
target_audience=self._target_audience,
additional_claims=self._additional_claims.copy(),
signer=self.signer,
use_metadata_identity_endpoint=False,
quota_project_id=self.quota_project_id,
)

def _make_authorization_grant_assertion(self):
"""Create the OAuth 2.0 assertion.
This assertion is used during the OAuth 2.0 grant to acquire an
Expand Down
15 changes: 15 additions & 0 deletions google/auth/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,21 @@ def with_quota_project(self, quota_project_id):
raise NotImplementedError("This credential does not support quota project.")


class CredentialsWithTokenUri(Credentials):
"""Abstract base for credentials supporting ``with_token_uri`` factory"""

def with_token_uri(self, token_uri):
"""Returns a copy of these credentials with a modified token uri.
Args:
token_uri (str): The uri to use for fetching/exchanging tokens
Returns:
google.oauth2.credentials.Credentials: A new credentials instance.
"""
raise NotImplementedError("This credential does not use token uri.")


class AnonymousCredentials(Credentials):
"""Credentials that do not provide any authentication information.
Expand Down
26 changes: 25 additions & 1 deletion google/auth/external_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,11 @@


@six.add_metaclass(abc.ABCMeta)
class Credentials(credentials.Scoped, credentials.CredentialsWithQuotaProject):
class Credentials(
credentials.Scoped,
credentials.CredentialsWithQuotaProject,
credentials.CredentialsWithTokenUri,
):
"""Base class for all external account credentials.
This is used to instantiate Credentials for exchanging external account
Expand Down Expand Up @@ -382,6 +386,26 @@ def with_quota_project(self, quota_project_id):
d.pop("workforce_pool_user_project")
return self.__class__(**d)

@_helpers.copy_docstring(credentials.CredentialsWithTokenUri)
def with_token_uri(self, token_uri):
d = dict(
audience=self._audience,
subject_token_type=self._subject_token_type,
token_url=token_uri,
credential_source=self._credential_source,
service_account_impersonation_url=self._service_account_impersonation_url,
service_account_impersonation_options=self._service_account_impersonation_options,
client_id=self._client_id,
client_secret=self._client_secret,
quota_project_id=self._quota_project_id,
scopes=self._scopes,
default_scopes=self._default_scopes,
workforce_pool_user_project=self._workforce_pool_user_project,
)
if not self.is_workforce_pool:
d.pop("workforce_pool_user_project")
return self.__class__(**d)

def _initialize_impersonated_credentials(self):
"""Generates an impersonated credentials.
Expand Down
17 changes: 17 additions & 0 deletions google/oauth2/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,23 @@ def with_quota_project(self, quota_project_id):
enable_reauth_refresh=self._enable_reauth_refresh,
)

@_helpers.copy_docstring(credentials.CredentialsWithTokenUri)
def with_token_uri(self, token_uri):

return self.__class__(
self.token,
refresh_token=self.refresh_token,
id_token=self.id_token,
token_uri=token_uri,
client_id=self.client_id,
client_secret=self.client_secret,
scopes=self.scopes,
default_scopes=self.default_scopes,
quota_project_id=self.quota_project_id,
rapt_token=self.rapt_token,
enable_reauth_refresh=self._enable_reauth_refresh,
)

@_helpers.copy_docstring(credentials.Credentials)
def refresh(self, request):
scopes = self._scopes if self._scopes is not None else self._default_scopes
Expand Down
38 changes: 36 additions & 2 deletions google/oauth2/service_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,10 @@


class Credentials(
credentials.Signing, credentials.Scoped, credentials.CredentialsWithQuotaProject
credentials.Signing,
credentials.Scoped,
credentials.CredentialsWithQuotaProject,
credentials.CredentialsWithTokenUri,
):
"""Service account credentials
Expand Down Expand Up @@ -364,6 +367,22 @@ def with_quota_project(self, quota_project_id):
always_use_jwt_access=self._always_use_jwt_access,
)

@_helpers.copy_docstring(credentials.CredentialsWithTokenUri)
def with_token_uri(self, token_uri):

return self.__class__(
self._signer,
service_account_email=self._service_account_email,
default_scopes=self._default_scopes,
scopes=self._scopes,
token_uri=token_uri,
subject=self._subject,
project_id=self._project_id,
quota_project_id=self._quota_project_id,
additional_claims=self._additional_claims.copy(),
always_use_jwt_access=self._always_use_jwt_access,
)

def _make_authorization_grant_assertion(self):
"""Create the OAuth 2.0 assertion.
Expand Down Expand Up @@ -455,7 +474,11 @@ def signer_email(self):
return self._service_account_email


class IDTokenCredentials(credentials.Signing, credentials.CredentialsWithQuotaProject):
class IDTokenCredentials(
credentials.Signing,
credentials.CredentialsWithQuotaProject,
credentials.CredentialsWithTokenUri,
):
"""Open ID Connect ID Token-based service account credentials.
These credentials are largely similar to :class:`.Credentials`, but instead
Expand Down Expand Up @@ -627,6 +650,17 @@ def with_quota_project(self, quota_project_id):
quota_project_id=quota_project_id,
)

@_helpers.copy_docstring(credentials.CredentialsWithTokenUri)
def with_token_uri(self, token_uri):
return self.__class__(
self._signer,
service_account_email=self._service_account_email,
token_uri=token_uri,
target_audience=self._target_audience,
additional_claims=self._additional_claims.copy(),
quota_project_id=self._quota_project_id,
)

def _make_authorization_grant_assertion(self):
"""Create the OAuth 2.0 assertion.
Expand Down
Binary file modified system_tests/secrets.tar.enc
Binary file not shown.
44 changes: 44 additions & 0 deletions tests/compute_engine/test_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,50 @@ def test_with_quota_project(self, sign, get, utcnow):
# Check that the signer have been initialized with a Request object
assert isinstance(self.credentials._signer._request, transport.Request)

@mock.patch(
"google.auth._helpers.utcnow",
return_value=datetime.datetime.utcfromtimestamp(0),
)
@mock.patch("google.auth.compute_engine._metadata.get", autospec=True)
@mock.patch("google.auth.iam.Signer.sign", autospec=True)
def test_with_token_uri(self, sign, get, utcnow):
get.side_effect = [
{"email": "[email protected]", "scopes": ["one", "two"]}
]
sign.side_effect = [b"signature"]

request = mock.create_autospec(transport.Request, instance=True)
self.credentials = credentials.IDTokenCredentials(
request=request,
target_audience="https://audience.com",
token_uri="http://xyz.com",
)
assert self.credentials._token_uri == "http://xyz.com"
creds_with_token_uri = self.credentials.with_token_uri("http://abc.com")
assert creds_with_token_uri._token_uri == "http://abc.com"

@mock.patch(
"google.auth._helpers.utcnow",
return_value=datetime.datetime.utcfromtimestamp(0),
)
@mock.patch("google.auth.compute_engine._metadata.get", autospec=True)
@mock.patch("google.auth.iam.Signer.sign", autospec=True)
def test_with_token_uri_exception(self, sign, get, utcnow):
get.side_effect = [
{"email": "[email protected]", "scopes": ["one", "two"]}
]
sign.side_effect = [b"signature"]

request = mock.create_autospec(transport.Request, instance=True)
self.credentials = credentials.IDTokenCredentials(
request=request,
target_audience="https://audience.com",
use_metadata_identity_endpoint=True,
)
assert self.credentials._token_uri is None
with pytest.raises(ValueError):
self.credentials.with_token_uri("http://abc.com")

@responses.activate
def test_with_quota_project_integration(self):
""" Test that it is possible to refresh credentials
Expand Down
12 changes: 12 additions & 0 deletions tests/oauth2/test_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,18 @@ def test_with_quota_project(self):
creds.apply(headers)
assert "x-goog-user-project" in headers

def test_with_token_uri(self):
info = AUTH_USER_INFO.copy()

creds = credentials.Credentials.from_authorized_user_info(info)
new_token_uri = "https://oauth2-eu.googleapis.com/token"

assert creds._token_uri == credentials._GOOGLE_OAUTH2_TOKEN_ENDPOINT

creds_with_new_token_uri = creds.with_token_uri(new_token_uri)

assert creds_with_new_token_uri._token_uri == new_token_uri

def test_from_authorized_user_info(self):
info = AUTH_USER_INFO.copy()

Expand Down
14 changes: 14 additions & 0 deletions tests/oauth2/test_service_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,13 @@ def test_with_quota_project(self):
new_credentials.apply(hdrs, token="tok")
assert "x-goog-user-project" in hdrs

def test_with_token_uri(self):
credentials = self.make_credentials()
new_token_uri = "https://example2.com/oauth2/token"
assert credentials._token_uri == self.TOKEN_URI
creds_with_new_token_uri = credentials.with_token_uri(new_token_uri)
assert creds_with_new_token_uri._token_uri == new_token_uri

def test__with_always_use_jwt_access(self):
credentials = self.make_credentials()
assert not credentials._always_use_jwt_access
Expand Down Expand Up @@ -464,6 +471,13 @@ def test_with_quota_project(self):
new_credentials = credentials.with_quota_project("project-foo")
assert new_credentials._quota_project_id == "project-foo"

def test_with_token_uri(self):
credentials = self.make_credentials()
new_token_uri = "https://example2.com/oauth2/token"
assert credentials._token_uri == self.TOKEN_URI
creds_with_new_token_uri = credentials.with_token_uri(new_token_uri)
assert creds_with_new_token_uri._token_uri == new_token_uri

def test__make_authorization_grant_assertion(self):
credentials = self.make_credentials()
token = credentials._make_authorization_grant_assertion()
Expand Down
27 changes: 27 additions & 0 deletions tests/test_external_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,33 @@ def test_with_scopes_full_options_propagated(self):
workforce_pool_user_project=None,
)

def test_with_token_uri(self):
credentials = self.make_credentials()
new_token_uri = "https://eu-sts.googleapis.com/v1/token"

assert credentials._token_url == self.TOKEN_URL

creds_with_new_token_uri = credentials.with_token_uri(new_token_uri)

assert creds_with_new_token_uri._token_url == new_token_uri

def test_with_token_uri_workforce_pool(self):
credentials = self.make_workforce_pool_credentials(
workforce_pool_user_project=self.WORKFORCE_POOL_USER_PROJECT
)

new_token_uri = "https://eu-sts.googleapis.com/v1/token"

assert credentials._token_url == self.TOKEN_URL

creds_with_new_token_uri = credentials.with_token_uri(new_token_uri)

assert creds_with_new_token_uri._token_url == new_token_uri
assert (
creds_with_new_token_uri.info.get("workforce_pool_user_project")
== self.WORKFORCE_POOL_USER_PROJECT
)

def test_with_quota_project(self):
credentials = self.make_credentials()

Expand Down

0 comments on commit 73bc7e9

Please sign in to comment.