diff --git a/google/auth/compute_engine/credentials.py b/google/auth/compute_engine/credentials.py index 59b48dae6..e97fabea9 100644 --- a/google/auth/compute_engine/credentials.py +++ b/google/auth/compute_engine/credentials.py @@ -154,7 +154,11 @@ def with_scopes(self, scopes, default_scopes=None): _DEFAULT_TOKEN_URI = "https://www.googleapis.com/oauth2/v4/token" -class IDTokenCredentials(credentials.CredentialsWithQuotaProject, credentials.Signing): +class IDTokenCredentials( + credentials.CredentialsWithQuotaProject, + credentials.Signing, + credentials.CredentialsWithTokenUri, +): """Open ID Connect ID Token-based service account credentials. These credentials relies on the default service account of a GCE instance. @@ -302,6 +306,27 @@ def with_quota_project(self, quota_project_id): quota_project_id=quota_project_id, ) + @_helpers.copy_docstring(credentials.CredentialsWithTokenUri) + def with_token_uri(self, token_uri): + + # since the signer is already instantiated, + # the request is not needed + if self._use_metadata_identity_endpoint: + raise ValueError( + "If use_metadata_identity_endpoint is set, token_uri" " must not be set" + ) + else: + return self.__class__( + None, + service_account_email=self._service_account_email, + token_uri=token_uri, + target_audience=self._target_audience, + additional_claims=self._additional_claims.copy(), + signer=self.signer, + use_metadata_identity_endpoint=False, + quota_project_id=self.quota_project_id, + ) + def _make_authorization_grant_assertion(self): """Create the OAuth 2.0 assertion. This assertion is used during the OAuth 2.0 grant to acquire an diff --git a/google/auth/credentials.py b/google/auth/credentials.py index 004fde9c2..2735892d4 100644 --- a/google/auth/credentials.py +++ b/google/auth/credentials.py @@ -150,6 +150,21 @@ def with_quota_project(self, quota_project_id): raise NotImplementedError("This credential does not support quota project.") +class CredentialsWithTokenUri(Credentials): + """Abstract base for credentials supporting ``with_token_uri`` factory""" + + def with_token_uri(self, token_uri): + """Returns a copy of these credentials with a modified token uri. + + Args: + token_uri (str): The uri to use for fetching/exchanging tokens + + Returns: + google.oauth2.credentials.Credentials: A new credentials instance. + """ + raise NotImplementedError("This credential does not use token uri.") + + class AnonymousCredentials(Credentials): """Credentials that do not provide any authentication information. diff --git a/google/auth/external_account.py b/google/auth/external_account.py index eb216fb72..c1ba5efa0 100644 --- a/google/auth/external_account.py +++ b/google/auth/external_account.py @@ -55,7 +55,11 @@ @six.add_metaclass(abc.ABCMeta) -class Credentials(credentials.Scoped, credentials.CredentialsWithQuotaProject): +class Credentials( + credentials.Scoped, + credentials.CredentialsWithQuotaProject, + credentials.CredentialsWithTokenUri, +): """Base class for all external account credentials. This is used to instantiate Credentials for exchanging external account @@ -382,6 +386,26 @@ def with_quota_project(self, quota_project_id): d.pop("workforce_pool_user_project") return self.__class__(**d) + @_helpers.copy_docstring(credentials.CredentialsWithTokenUri) + def with_token_uri(self, token_uri): + d = dict( + audience=self._audience, + subject_token_type=self._subject_token_type, + token_url=token_uri, + credential_source=self._credential_source, + service_account_impersonation_url=self._service_account_impersonation_url, + service_account_impersonation_options=self._service_account_impersonation_options, + client_id=self._client_id, + client_secret=self._client_secret, + quota_project_id=self._quota_project_id, + scopes=self._scopes, + default_scopes=self._default_scopes, + workforce_pool_user_project=self._workforce_pool_user_project, + ) + if not self.is_workforce_pool: + d.pop("workforce_pool_user_project") + return self.__class__(**d) + def _initialize_impersonated_credentials(self): """Generates an impersonated credentials. diff --git a/google/oauth2/credentials.py b/google/oauth2/credentials.py index 4cc502fea..8f1c3dda4 100644 --- a/google/oauth2/credentials.py +++ b/google/oauth2/credentials.py @@ -254,6 +254,23 @@ def with_quota_project(self, quota_project_id): enable_reauth_refresh=self._enable_reauth_refresh, ) + @_helpers.copy_docstring(credentials.CredentialsWithTokenUri) + def with_token_uri(self, token_uri): + + return self.__class__( + self.token, + refresh_token=self.refresh_token, + id_token=self.id_token, + token_uri=token_uri, + client_id=self.client_id, + client_secret=self.client_secret, + scopes=self.scopes, + default_scopes=self.default_scopes, + quota_project_id=self.quota_project_id, + rapt_token=self.rapt_token, + enable_reauth_refresh=self._enable_reauth_refresh, + ) + @_helpers.copy_docstring(credentials.Credentials) def refresh(self, request): scopes = self._scopes if self._scopes is not None else self._default_scopes diff --git a/google/oauth2/service_account.py b/google/oauth2/service_account.py index 5c4f340fa..0989750db 100644 --- a/google/oauth2/service_account.py +++ b/google/oauth2/service_account.py @@ -84,7 +84,10 @@ class Credentials( - credentials.Signing, credentials.Scoped, credentials.CredentialsWithQuotaProject + credentials.Signing, + credentials.Scoped, + credentials.CredentialsWithQuotaProject, + credentials.CredentialsWithTokenUri, ): """Service account credentials @@ -364,6 +367,22 @@ def with_quota_project(self, quota_project_id): always_use_jwt_access=self._always_use_jwt_access, ) + @_helpers.copy_docstring(credentials.CredentialsWithTokenUri) + def with_token_uri(self, token_uri): + + return self.__class__( + self._signer, + service_account_email=self._service_account_email, + default_scopes=self._default_scopes, + scopes=self._scopes, + token_uri=token_uri, + subject=self._subject, + project_id=self._project_id, + quota_project_id=self._quota_project_id, + additional_claims=self._additional_claims.copy(), + always_use_jwt_access=self._always_use_jwt_access, + ) + def _make_authorization_grant_assertion(self): """Create the OAuth 2.0 assertion. @@ -455,7 +474,11 @@ def signer_email(self): return self._service_account_email -class IDTokenCredentials(credentials.Signing, credentials.CredentialsWithQuotaProject): +class IDTokenCredentials( + credentials.Signing, + credentials.CredentialsWithQuotaProject, + credentials.CredentialsWithTokenUri, +): """Open ID Connect ID Token-based service account credentials. These credentials are largely similar to :class:`.Credentials`, but instead @@ -627,6 +650,17 @@ def with_quota_project(self, quota_project_id): quota_project_id=quota_project_id, ) + @_helpers.copy_docstring(credentials.CredentialsWithTokenUri) + def with_token_uri(self, token_uri): + return self.__class__( + self._signer, + service_account_email=self._service_account_email, + token_uri=token_uri, + target_audience=self._target_audience, + additional_claims=self._additional_claims.copy(), + quota_project_id=self._quota_project_id, + ) + def _make_authorization_grant_assertion(self): """Create the OAuth 2.0 assertion. diff --git a/system_tests/secrets.tar.enc b/system_tests/secrets.tar.enc index 78f7254cb..37bd82904 100644 Binary files a/system_tests/secrets.tar.enc and b/system_tests/secrets.tar.enc differ diff --git a/tests/compute_engine/test_credentials.py b/tests/compute_engine/test_credentials.py index ebce176e8..6a2f8cc20 100644 --- a/tests/compute_engine/test_credentials.py +++ b/tests/compute_engine/test_credentials.py @@ -483,6 +483,50 @@ def test_with_quota_project(self, sign, get, utcnow): # Check that the signer have been initialized with a Request object assert isinstance(self.credentials._signer._request, transport.Request) + @mock.patch( + "google.auth._helpers.utcnow", + return_value=datetime.datetime.utcfromtimestamp(0), + ) + @mock.patch("google.auth.compute_engine._metadata.get", autospec=True) + @mock.patch("google.auth.iam.Signer.sign", autospec=True) + def test_with_token_uri(self, sign, get, utcnow): + get.side_effect = [ + {"email": "service-account@example.com", "scopes": ["one", "two"]} + ] + sign.side_effect = [b"signature"] + + request = mock.create_autospec(transport.Request, instance=True) + self.credentials = credentials.IDTokenCredentials( + request=request, + target_audience="https://audience.com", + token_uri="http://xyz.com", + ) + assert self.credentials._token_uri == "http://xyz.com" + creds_with_token_uri = self.credentials.with_token_uri("http://abc.com") + assert creds_with_token_uri._token_uri == "http://abc.com" + + @mock.patch( + "google.auth._helpers.utcnow", + return_value=datetime.datetime.utcfromtimestamp(0), + ) + @mock.patch("google.auth.compute_engine._metadata.get", autospec=True) + @mock.patch("google.auth.iam.Signer.sign", autospec=True) + def test_with_token_uri_exception(self, sign, get, utcnow): + get.side_effect = [ + {"email": "service-account@example.com", "scopes": ["one", "two"]} + ] + sign.side_effect = [b"signature"] + + request = mock.create_autospec(transport.Request, instance=True) + self.credentials = credentials.IDTokenCredentials( + request=request, + target_audience="https://audience.com", + use_metadata_identity_endpoint=True, + ) + assert self.credentials._token_uri is None + with pytest.raises(ValueError): + self.credentials.with_token_uri("http://abc.com") + @responses.activate def test_with_quota_project_integration(self): """ Test that it is possible to refresh credentials diff --git a/tests/oauth2/test_credentials.py b/tests/oauth2/test_credentials.py index e5f71def0..c8301078d 100644 --- a/tests/oauth2/test_credentials.py +++ b/tests/oauth2/test_credentials.py @@ -701,6 +701,18 @@ def test_with_quota_project(self): creds.apply(headers) assert "x-goog-user-project" in headers + def test_with_token_uri(self): + info = AUTH_USER_INFO.copy() + + creds = credentials.Credentials.from_authorized_user_info(info) + new_token_uri = "https://oauth2-eu.googleapis.com/token" + + assert creds._token_uri == credentials._GOOGLE_OAUTH2_TOKEN_ENDPOINT + + creds_with_new_token_uri = creds.with_token_uri(new_token_uri) + + assert creds_with_new_token_uri._token_uri == new_token_uri + def test_from_authorized_user_info(self): info = AUTH_USER_INFO.copy() diff --git a/tests/oauth2/test_service_account.py b/tests/oauth2/test_service_account.py index 1d1438485..4bd194b35 100644 --- a/tests/oauth2/test_service_account.py +++ b/tests/oauth2/test_service_account.py @@ -155,6 +155,13 @@ def test_with_quota_project(self): new_credentials.apply(hdrs, token="tok") assert "x-goog-user-project" in hdrs + def test_with_token_uri(self): + credentials = self.make_credentials() + new_token_uri = "https://example2.com/oauth2/token" + assert credentials._token_uri == self.TOKEN_URI + creds_with_new_token_uri = credentials.with_token_uri(new_token_uri) + assert creds_with_new_token_uri._token_uri == new_token_uri + def test__with_always_use_jwt_access(self): credentials = self.make_credentials() assert not credentials._always_use_jwt_access @@ -464,6 +471,13 @@ def test_with_quota_project(self): new_credentials = credentials.with_quota_project("project-foo") assert new_credentials._quota_project_id == "project-foo" + def test_with_token_uri(self): + credentials = self.make_credentials() + new_token_uri = "https://example2.com/oauth2/token" + assert credentials._token_uri == self.TOKEN_URI + creds_with_new_token_uri = credentials.with_token_uri(new_token_uri) + assert creds_with_new_token_uri._token_uri == new_token_uri + def test__make_authorization_grant_assertion(self): credentials = self.make_credentials() token = credentials._make_authorization_grant_assertion() diff --git a/tests/test_external_account.py b/tests/test_external_account.py index 920aa34ea..468152e05 100644 --- a/tests/test_external_account.py +++ b/tests/test_external_account.py @@ -542,6 +542,33 @@ def test_with_scopes_full_options_propagated(self): workforce_pool_user_project=None, ) + def test_with_token_uri(self): + credentials = self.make_credentials() + new_token_uri = "https://eu-sts.googleapis.com/v1/token" + + assert credentials._token_url == self.TOKEN_URL + + creds_with_new_token_uri = credentials.with_token_uri(new_token_uri) + + assert creds_with_new_token_uri._token_url == new_token_uri + + def test_with_token_uri_workforce_pool(self): + credentials = self.make_workforce_pool_credentials( + workforce_pool_user_project=self.WORKFORCE_POOL_USER_PROJECT + ) + + new_token_uri = "https://eu-sts.googleapis.com/v1/token" + + assert credentials._token_url == self.TOKEN_URL + + creds_with_new_token_uri = credentials.with_token_uri(new_token_uri) + + assert creds_with_new_token_uri._token_url == new_token_uri + assert ( + creds_with_new_token_uri.info.get("workforce_pool_user_project") + == self.WORKFORCE_POOL_USER_PROJECT + ) + def test_with_quota_project(self): credentials = self.make_credentials()