From 5454867160f7f0d5aca8a75ad702ef1c13134011 Mon Sep 17 00:00:00 2001 From: elibixby Date: Tue, 28 Jun 2016 12:59:18 -0700 Subject: [PATCH] Populate scopes for gce.AppAssertionCredentials (#524) * Populate Scopes for gce.AppAssertionCredentials * _retrieve_scopes -> _retrieve_info * Add note about credentials being initially invalid --- oauth2client/contrib/_metadata.py | 19 +++--- oauth2client/contrib/gce.py | 97 ++++++++++++++++-------------- tests/contrib/test_gce.py | 99 ++++++++++++++++++------------- tests/contrib/test_metadata.py | 10 ++-- 4 files changed, 126 insertions(+), 99 deletions(-) diff --git a/oauth2client/contrib/_metadata.py b/oauth2client/contrib/_metadata.py index 9987da753..2995f6523 100644 --- a/oauth2client/contrib/_metadata.py +++ b/oauth2client/contrib/_metadata.py @@ -33,7 +33,7 @@ METADATA_HEADERS = {'Metadata-Flavor': 'Google'} -def get(path, http_request=None, root=METADATA_ROOT, recursive=None): +def get(http_request, path, root=METADATA_ROOT, recursive=None): """Fetch a resource from the metadata server. Args: @@ -53,9 +53,6 @@ def get(path, http_request=None, root=METADATA_ROOT, recursive=None): Raises: httplib2.Httplib2Error if an error corrured while retrieving metadata. """ - if not http_request: - http_request = httplib2.Http().request - url = urlparse.urljoin(root, path) url = util._add_query_parameter(url, 'recursive', recursive) @@ -76,7 +73,7 @@ def get(path, http_request=None, root=METADATA_ROOT, recursive=None): 'metadata service. Response:\n{1}'.format(url, response)) -def get_service_account_info(service_account='default', http_request=None): +def get_service_account_info(http_request, service_account='default'): """Get information about a service account from the metadata server. Args: @@ -97,12 +94,12 @@ def get_service_account_info(service_account='default', http_request=None): } """ return get( - 'instance/service-accounts/{0}'.format(service_account), - recursive=True, - http_request=http_request) + http_request, + 'instance/service-accounts/{0}/'.format(service_account), + recursive=True) -def get_token(service_account='default', http_request=None): +def get_token(http_request, service_account='default'): """Fetch an oauth token for the Args: @@ -119,8 +116,8 @@ def get_token(service_account='default', http_request=None): that indicates when the access token will expire. """ token_json = get( - 'instance/service-accounts/{0}/token'.format(service_account), - http_request=http_request) + http_request, + 'instance/service-accounts/{0}/token'.format(service_account)) token_expiry = _UTCNOW() + datetime.timedelta( seconds=token_json['expires_in']) return token_json['access_token'], token_expiry diff --git a/oauth2client/contrib/gce.py b/oauth2client/contrib/gce.py index 2aad4dbea..b495e44ac 100644 --- a/oauth2client/contrib/gce.py +++ b/oauth2client/contrib/gce.py @@ -17,14 +17,11 @@ Utilities for making it easier to use OAuth 2.0 on Google Compute Engine. """ -import json import logging import warnings import httplib2 -from oauth2client._helpers import _from_bytes -from oauth2client import util from oauth2client.client import AssertionCredentials from oauth2client.client import HttpAccessTokenRefreshError from oauth2client.contrib import _metadata @@ -53,36 +50,72 @@ class AppAssertionCredentials(AssertionCredentials): This credential does not require a flow to instantiate because it represents a two legged flow, and therefore has all of the required information to generate and refresh its own access tokens. + + Note that :attr:`service_account_email` and :attr:`scopes` + will both return None until the credentials have been refreshed. + To check whether credentials have previously been refreshed use + :attr:`invalid`. """ - @util.positional(2) - def __init__(self, scope='', **kwargs): + def __init__(self, email=None, *args, **kwargs): """Constructor for AppAssertionCredentials Args: - scope: string or iterable of strings, scope(s) of the credentials - being requested. Using this argument will have no effect on - the actual scopes for tokens requested. These scopes are - set at VM instance creation time and won't change. + email: an email that specifies the service account to use. + Only necessary if using custom service accounts + (see https://cloud.google.com/compute/docs/access/create-enable-service-accounts-for-instances#createdefaultserviceaccount). """ - if scope: + if 'scopes' in kwargs: warnings.warn(_SCOPES_WARNING) - # This is just provided for backwards compatibility, but is not - # used by this class. - self.scope = util.scopes_to_string(scope) - self.kwargs = kwargs + kwargs['scopes'] = None # Assertion type is no longer used, but still in the # parent class signature. - super(AppAssertionCredentials, self).__init__(None) + super(AppAssertionCredentials, self).__init__(None, *args, **kwargs) - # Cache until Metadata Server supports Cache-Control Header - self._service_account_email = None + self.service_account_email = email + self.scopes = None + self.invalid = True @classmethod def from_json(cls, json_data): - data = json.loads(_from_bytes(json_data)) - return AppAssertionCredentials(data['scope']) + raise NotImplementedError( + 'Cannot serialize credentials for GCE service accounts.') + + def to_json(self): + raise NotImplementedError( + 'Cannot serialize credentials for GCE service accounts.') + + def retrieve_scopes(self, http): + """Retrieves the canonical list of scopes for this access token. + + Overrides client.Credentials.retrieve_scopes. Fetches scopes info + from the metadata server. + + Args: + http: httplib2.Http, an http object to be used to make the refresh + request. + + Returns: + A set of strings containing the canonical list of scopes. + """ + self._retrieve_info(http.request) + return self.scopes + + def _retrieve_info(self, http_request): + """Validates invalid service accounts by retrieving service account info. + + Args: + http_request: callable, a callable that matches the method + signature of httplib2.Http.request, used to make the + request to the metadata server + """ + if self.invalid: + info = _metadata.get_service_account_info( + http_request, service_account=self.service_account_email or 'default') + self.invalid = False + self.service_account_email = info['email'] + self.scopes = info['scopes'] def _refresh(self, http_request): """Refreshes the access_token. @@ -98,8 +131,9 @@ def _refresh(self, http_request): HttpAccessTokenRefreshError: When the refresh fails. """ try: + self._retrieve_info(http_request) self.access_token, self.token_expiry = _metadata.get_token( - http_request=http_request) + http_request, service_account=self.service_account_email) except httplib2.HttpLib2Error as e: raise HttpAccessTokenRefreshError(str(e)) @@ -111,9 +145,6 @@ def serialization_data(self): def create_scoped_required(self): return False - def create_scoped(self, scopes): - return AppAssertionCredentials(scopes, **self.kwargs) - def sign_blob(self, blob): """Cryptographically sign a blob (of bytes). @@ -129,23 +160,3 @@ def sign_blob(self, blob): """ raise NotImplementedError( 'Compute Engine service accounts cannot sign blobs') - - @property - def service_account_email(self): - """Get the email for the current service account. - - Uses the Google Compute Engine metadata service to retrieve the email - of the default service account. - - Returns: - string, The email associated with the Google Compute Engine - service account. - - Raises: - AttributeError, if the email can not be retrieved from the Google - Compute Engine metadata service. - """ - if self._service_account_email is None: - self._service_account_email = ( - _metadata.get_service_account_info()['email']) - return self._service_account_email diff --git a/tests/contrib/test_gce.py b/tests/contrib/test_gce.py index 4da0341fc..4757a7d24 100644 --- a/tests/contrib/test_gce.py +++ b/tests/contrib/test_gce.py @@ -15,14 +15,13 @@ """Unit tests for oauth2client.contrib.gce.""" import datetime +import httplib2 import json import mock from six.moves import http_client -from six.moves import urllib import unittest2 -from oauth2client.client import Credentials from oauth2client.client import save_to_well_known_file from oauth2client.client import HttpAccessTokenRefreshError from oauth2client.contrib.gce import _SCOPES_WARNING @@ -31,44 +30,60 @@ __author__ = 'jcgregorio@google.com (Joe Gregorio)' +SERVICE_ACCOUNT_INFO = { + 'scopes': ['a', 'b'], + 'email': 'a@example.com', + 'aliases': ['default'] +} class AppAssertionCredentialsTests(unittest2.TestCase): def test_constructor(self): - credentials = AppAssertionCredentials(foo='bar') - self.assertEqual(credentials.scope, '') - self.assertEqual(credentials.kwargs, {'foo': 'bar'}) - self.assertEqual(credentials.assertion_type, None) + credentials = AppAssertionCredentials() + self.assertIsNone(credentials.assertion_type, None) + self.assertIsNone(credentials.service_account_email) + self.assertIsNone(credentials.scopes) + self.assertTrue(credentials.invalid) @mock.patch('warnings.warn') def test_constructor_with_scopes(self, warn_mock): scope = 'http://example.com/a http://example.com/b' scopes = scope.split() - credentials = AppAssertionCredentials(scope=scopes, foo='bar') - self.assertEqual(credentials.scope, scope) - self.assertEqual(credentials.kwargs, {'foo': 'bar'}) + credentials = AppAssertionCredentials(scopes=scopes) + self.assertEqual(credentials.scopes, None) self.assertEqual(credentials.assertion_type, None) warn_mock.assert_called_once_with(_SCOPES_WARNING) - def test_to_json_and_from_json(self): + def test_to_json(self): credentials = AppAssertionCredentials() - json = credentials.to_json() - credentials_from_json = Credentials.new_from_json(json) - self.assertEqual(credentials.access_token, - credentials_from_json.access_token) + with self.assertRaises(NotImplementedError): + credentials.to_json() + + def test_from_json(self): + with self.assertRaises(NotImplementedError): + AppAssertionCredentials.from_json({}) @mock.patch('oauth2client.contrib._metadata.get_token', side_effect=[('A', datetime.datetime.min), ('B', datetime.datetime.max)]) - def test_refresh_token(self, metadata): + @mock.patch('oauth2client.contrib._metadata.get_service_account_info', + return_value=SERVICE_ACCOUNT_INFO) + def test_refresh_token(self, get_info, get_token): + http_request = mock.MagicMock() + http_mock = mock.MagicMock(request=http_request) credentials = AppAssertionCredentials() + credentials.invalid = False + credentials.service_account_email = 'a@example.com' self.assertIsNone(credentials.access_token) - credentials.get_access_token() + credentials.get_access_token(http=http_mock) self.assertEqual(credentials.access_token, 'A') self.assertTrue(credentials.access_token_expired) - credentials.get_access_token() + get_token.assert_called_with(http_request, service_account='a@example.com') + credentials.get_access_token(http=http_mock) self.assertEqual(credentials.access_token, 'B') self.assertFalse(credentials.access_token_expired) + get_token.assert_called_with(http_request, service_account='a@example.com') + get_info.assert_not_called() def test_refresh_token_failed_fetch(self): http_request = request_mock( @@ -77,46 +92,50 @@ def test_refresh_token_failed_fetch(self): json.dumps({'access_token': 'a', 'expires_in': 100}) ) credentials = AppAssertionCredentials() - + credentials.invalid = False + credentials.service_account_email = 'a@example.com' with self.assertRaises(HttpAccessTokenRefreshError): - credentials._refresh(http_request=http_request) + credentials._refresh(http_request) def test_serialization_data(self): credentials = AppAssertionCredentials() self.assertRaises(NotImplementedError, getattr, credentials, 'serialization_data') - def test_create_scoped_required_without_scopes(self): + def test_create_scoped_required(self): credentials = AppAssertionCredentials() self.assertFalse(credentials.create_scoped_required()) - @mock.patch('warnings.warn') - def test_create_scoped_required_with_scopes(self, warn_mock): - credentials = AppAssertionCredentials(['dummy_scope']) - self.assertFalse(credentials.create_scoped_required()) - warn_mock.assert_called_once_with(_SCOPES_WARNING) - - @mock.patch('warnings.warn') - def test_create_scoped(self, warn_mock): - credentials = AppAssertionCredentials() - new_credentials = credentials.create_scoped(['dummy_scope']) - self.assertNotEqual(credentials, new_credentials) - self.assertTrue(isinstance(new_credentials, AppAssertionCredentials)) - self.assertEqual('dummy_scope', new_credentials.scope) - warn_mock.assert_called_once_with(_SCOPES_WARNING) - def test_sign_blob_not_implemented(self): credentials = AppAssertionCredentials([]) with self.assertRaises(NotImplementedError): credentials.sign_blob(b'blob') @mock.patch('oauth2client.contrib._metadata.get_service_account_info', - return_value={'email': 'a@example.com'}) - def test_service_account_email(self, metadata): + return_value=SERVICE_ACCOUNT_INFO) + def test_retrieve_scopes(self, metadata): + http_request = mock.MagicMock() + http_mock = mock.MagicMock(request=http_request) credentials = AppAssertionCredentials() - # Assert that service account isn't pre-fetched - metadata.assert_not_called() - self.assertEqual(credentials.service_account_email, 'a@example.com') + self.assertTrue(credentials.invalid) + self.assertIsNone(credentials.scopes) + scopes = credentials.retrieve_scopes(http_mock) + self.assertEqual(scopes, SERVICE_ACCOUNT_INFO['scopes']) + self.assertFalse(credentials.invalid) + credentials.retrieve_scopes(http_mock) + # Assert scopes weren't refetched + metadata.assert_called_once_with(http_request, service_account='default') + + @mock.patch('oauth2client.contrib._metadata.get_service_account_info', + side_effect=httplib2.HttpLib2Error('No Such Email')) + def test_retrieve_scopes_bad_email(self, metadata): + http_request = mock.MagicMock() + http_mock = mock.MagicMock(request=http_request) + credentials = AppAssertionCredentials(email='b@example.com') + with self.assertRaises(httplib2.HttpLib2Error): + credentials.retrieve_scopes(http_mock) + + metadata.assert_called_once_with(http_request, service_account='b@example.com') def test_save_to_well_known_file(self): import os diff --git a/tests/contrib/test_metadata.py b/tests/contrib/test_metadata.py index 4e48387c4..8c6b973ad 100644 --- a/tests/contrib/test_metadata.py +++ b/tests/contrib/test_metadata.py @@ -45,7 +45,7 @@ def test_get_success_json(self): http_request = request_mock( http_client.OK, 'application/json', json.dumps(DATA)) self.assertEqual( - _metadata.get(PATH, http_request=http_request), + _metadata.get(http_request, PATH), DATA ) http_request.assert_called_once_with(EXPECTED_URL, **EXPECTED_KWARGS) @@ -54,7 +54,7 @@ def test_get_success_string(self): http_request = request_mock( http_client.OK, 'text/html', '

Hello World!

') self.assertEqual( - _metadata.get(PATH, http_request=http_request), + _metadata.get(http_request, PATH), '

Hello World!

' ) http_request.assert_called_once_with(EXPECTED_URL, **EXPECTED_KWARGS) @@ -63,7 +63,7 @@ def test_get_failure(self): http_request = request_mock( http_client.NOT_FOUND, 'text/html', '

Error

') with self.assertRaises(httplib2.HttpLib2Error): - _metadata.get(PATH, http_request=http_request) + _metadata.get(http_request, PATH) http_request.assert_called_once_with(EXPECTED_URL, **EXPECTED_KWARGS) @@ -89,9 +89,9 @@ def test_get_token_success(self, now): def test_service_account_info(self): http_request = request_mock( http_client.OK, 'application/json', json.dumps(DATA)) - info = _metadata.get_service_account_info(http_request=http_request) + info = _metadata.get_service_account_info(http_request) self.assertEqual(info, DATA) http_request.assert_called_once_with( - EXPECTED_URL+'?recursive=True', + EXPECTED_URL+'/?recursive=True', **EXPECTED_KWARGS )