Skip to content
This repository has been archived by the owner on Nov 5, 2019. It is now read-only.

Commit

Permalink
Populate scopes for gce.AppAssertionCredentials (#524)
Browse files Browse the repository at this point in the history
* Populate Scopes for gce.AppAssertionCredentials
* _retrieve_scopes -> _retrieve_info
* Add note about credentials being initially invalid
  • Loading branch information
elibixby authored and Jon Wayne Parrott committed Jun 28, 2016
1 parent c82816c commit 5454867
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 99 deletions.
19 changes: 8 additions & 11 deletions oauth2client/contrib/_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
METADATA_HEADERS = {'Metadata-Flavor': 'Google'}


def get(path, http_request=None, root=METADATA_ROOT, recursive=None):
def get(http_request, path, root=METADATA_ROOT, recursive=None):
"""Fetch a resource from the metadata server.
Args:
Expand All @@ -53,9 +53,6 @@ def get(path, http_request=None, root=METADATA_ROOT, recursive=None):
Raises:
httplib2.Httplib2Error if an error corrured while retrieving metadata.
"""
if not http_request:
http_request = httplib2.Http().request

url = urlparse.urljoin(root, path)
url = util._add_query_parameter(url, 'recursive', recursive)

Expand All @@ -76,7 +73,7 @@ def get(path, http_request=None, root=METADATA_ROOT, recursive=None):
'metadata service. Response:\n{1}'.format(url, response))


def get_service_account_info(service_account='default', http_request=None):
def get_service_account_info(http_request, service_account='default'):
"""Get information about a service account from the metadata server.
Args:
Expand All @@ -97,12 +94,12 @@ def get_service_account_info(service_account='default', http_request=None):
}
"""
return get(
'instance/service-accounts/{0}'.format(service_account),
recursive=True,
http_request=http_request)
http_request,
'instance/service-accounts/{0}/'.format(service_account),
recursive=True)


def get_token(service_account='default', http_request=None):
def get_token(http_request, service_account='default'):
"""Fetch an oauth token for the
Args:
Expand All @@ -119,8 +116,8 @@ def get_token(service_account='default', http_request=None):
that indicates when the access token will expire.
"""
token_json = get(
'instance/service-accounts/{0}/token'.format(service_account),
http_request=http_request)
http_request,
'instance/service-accounts/{0}/token'.format(service_account))
token_expiry = _UTCNOW() + datetime.timedelta(
seconds=token_json['expires_in'])
return token_json['access_token'], token_expiry
97 changes: 54 additions & 43 deletions oauth2client/contrib/gce.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,11 @@
Utilities for making it easier to use OAuth 2.0 on Google Compute Engine.
"""

import json
import logging
import warnings

import httplib2

from oauth2client._helpers import _from_bytes
from oauth2client import util
from oauth2client.client import AssertionCredentials
from oauth2client.client import HttpAccessTokenRefreshError
from oauth2client.contrib import _metadata
Expand Down Expand Up @@ -53,36 +50,72 @@ class AppAssertionCredentials(AssertionCredentials):
This credential does not require a flow to instantiate because it
represents a two legged flow, and therefore has all of the required
information to generate and refresh its own access tokens.
Note that :attr:`service_account_email` and :attr:`scopes`
will both return None until the credentials have been refreshed.
To check whether credentials have previously been refreshed use
:attr:`invalid`.
"""

@util.positional(2)
def __init__(self, scope='', **kwargs):
def __init__(self, email=None, *args, **kwargs):
"""Constructor for AppAssertionCredentials
Args:
scope: string or iterable of strings, scope(s) of the credentials
being requested. Using this argument will have no effect on
the actual scopes for tokens requested. These scopes are
set at VM instance creation time and won't change.
email: an email that specifies the service account to use.
Only necessary if using custom service accounts
(see https://cloud.google.com/compute/docs/access/create-enable-service-accounts-for-instances#createdefaultserviceaccount).
"""
if scope:
if 'scopes' in kwargs:
warnings.warn(_SCOPES_WARNING)
# This is just provided for backwards compatibility, but is not
# used by this class.
self.scope = util.scopes_to_string(scope)
self.kwargs = kwargs
kwargs['scopes'] = None

# Assertion type is no longer used, but still in the
# parent class signature.
super(AppAssertionCredentials, self).__init__(None)
super(AppAssertionCredentials, self).__init__(None, *args, **kwargs)

# Cache until Metadata Server supports Cache-Control Header
self._service_account_email = None
self.service_account_email = email
self.scopes = None
self.invalid = True

@classmethod
def from_json(cls, json_data):
data = json.loads(_from_bytes(json_data))
return AppAssertionCredentials(data['scope'])
raise NotImplementedError(
'Cannot serialize credentials for GCE service accounts.')

def to_json(self):
raise NotImplementedError(
'Cannot serialize credentials for GCE service accounts.')

def retrieve_scopes(self, http):
"""Retrieves the canonical list of scopes for this access token.
Overrides client.Credentials.retrieve_scopes. Fetches scopes info
from the metadata server.
Args:
http: httplib2.Http, an http object to be used to make the refresh
request.
Returns:
A set of strings containing the canonical list of scopes.
"""
self._retrieve_info(http.request)
return self.scopes

def _retrieve_info(self, http_request):
"""Validates invalid service accounts by retrieving service account info.
Args:
http_request: callable, a callable that matches the method
signature of httplib2.Http.request, used to make the
request to the metadata server
"""
if self.invalid:
info = _metadata.get_service_account_info(
http_request, service_account=self.service_account_email or 'default')
self.invalid = False
self.service_account_email = info['email']
self.scopes = info['scopes']

def _refresh(self, http_request):
"""Refreshes the access_token.
Expand All @@ -98,8 +131,9 @@ def _refresh(self, http_request):
HttpAccessTokenRefreshError: When the refresh fails.
"""
try:
self._retrieve_info(http_request)
self.access_token, self.token_expiry = _metadata.get_token(
http_request=http_request)
http_request, service_account=self.service_account_email)
except httplib2.HttpLib2Error as e:
raise HttpAccessTokenRefreshError(str(e))

Expand All @@ -111,9 +145,6 @@ def serialization_data(self):
def create_scoped_required(self):
return False

def create_scoped(self, scopes):
return AppAssertionCredentials(scopes, **self.kwargs)

def sign_blob(self, blob):
"""Cryptographically sign a blob (of bytes).
Expand All @@ -129,23 +160,3 @@ def sign_blob(self, blob):
"""
raise NotImplementedError(
'Compute Engine service accounts cannot sign blobs')

@property
def service_account_email(self):
"""Get the email for the current service account.
Uses the Google Compute Engine metadata service to retrieve the email
of the default service account.
Returns:
string, The email associated with the Google Compute Engine
service account.
Raises:
AttributeError, if the email can not be retrieved from the Google
Compute Engine metadata service.
"""
if self._service_account_email is None:
self._service_account_email = (
_metadata.get_service_account_info()['email'])
return self._service_account_email
99 changes: 59 additions & 40 deletions tests/contrib/test_gce.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,13 @@
"""Unit tests for oauth2client.contrib.gce."""

import datetime
import httplib2
import json

import mock
from six.moves import http_client
from six.moves import urllib
import unittest2

from oauth2client.client import Credentials
from oauth2client.client import save_to_well_known_file
from oauth2client.client import HttpAccessTokenRefreshError
from oauth2client.contrib.gce import _SCOPES_WARNING
Expand All @@ -31,44 +30,60 @@

__author__ = '[email protected] (Joe Gregorio)'

SERVICE_ACCOUNT_INFO = {
'scopes': ['a', 'b'],
'email': '[email protected]',
'aliases': ['default']
}

class AppAssertionCredentialsTests(unittest2.TestCase):

def test_constructor(self):
credentials = AppAssertionCredentials(foo='bar')
self.assertEqual(credentials.scope, '')
self.assertEqual(credentials.kwargs, {'foo': 'bar'})
self.assertEqual(credentials.assertion_type, None)
credentials = AppAssertionCredentials()
self.assertIsNone(credentials.assertion_type, None)
self.assertIsNone(credentials.service_account_email)
self.assertIsNone(credentials.scopes)
self.assertTrue(credentials.invalid)

@mock.patch('warnings.warn')
def test_constructor_with_scopes(self, warn_mock):
scope = 'http://example.com/a http://example.com/b'
scopes = scope.split()
credentials = AppAssertionCredentials(scope=scopes, foo='bar')
self.assertEqual(credentials.scope, scope)
self.assertEqual(credentials.kwargs, {'foo': 'bar'})
credentials = AppAssertionCredentials(scopes=scopes)
self.assertEqual(credentials.scopes, None)
self.assertEqual(credentials.assertion_type, None)
warn_mock.assert_called_once_with(_SCOPES_WARNING)

def test_to_json_and_from_json(self):
def test_to_json(self):
credentials = AppAssertionCredentials()
json = credentials.to_json()
credentials_from_json = Credentials.new_from_json(json)
self.assertEqual(credentials.access_token,
credentials_from_json.access_token)
with self.assertRaises(NotImplementedError):
credentials.to_json()

def test_from_json(self):
with self.assertRaises(NotImplementedError):
AppAssertionCredentials.from_json({})

@mock.patch('oauth2client.contrib._metadata.get_token',
side_effect=[('A', datetime.datetime.min),
('B', datetime.datetime.max)])
def test_refresh_token(self, metadata):
@mock.patch('oauth2client.contrib._metadata.get_service_account_info',
return_value=SERVICE_ACCOUNT_INFO)
def test_refresh_token(self, get_info, get_token):
http_request = mock.MagicMock()
http_mock = mock.MagicMock(request=http_request)
credentials = AppAssertionCredentials()
credentials.invalid = False
credentials.service_account_email = '[email protected]'
self.assertIsNone(credentials.access_token)
credentials.get_access_token()
credentials.get_access_token(http=http_mock)
self.assertEqual(credentials.access_token, 'A')
self.assertTrue(credentials.access_token_expired)
credentials.get_access_token()
get_token.assert_called_with(http_request, service_account='[email protected]')
credentials.get_access_token(http=http_mock)
self.assertEqual(credentials.access_token, 'B')
self.assertFalse(credentials.access_token_expired)
get_token.assert_called_with(http_request, service_account='[email protected]')
get_info.assert_not_called()

def test_refresh_token_failed_fetch(self):
http_request = request_mock(
Expand All @@ -77,46 +92,50 @@ def test_refresh_token_failed_fetch(self):
json.dumps({'access_token': 'a', 'expires_in': 100})
)
credentials = AppAssertionCredentials()

credentials.invalid = False
credentials.service_account_email = '[email protected]'
with self.assertRaises(HttpAccessTokenRefreshError):
credentials._refresh(http_request=http_request)
credentials._refresh(http_request)

def test_serialization_data(self):
credentials = AppAssertionCredentials()
self.assertRaises(NotImplementedError, getattr,
credentials, 'serialization_data')

def test_create_scoped_required_without_scopes(self):
def test_create_scoped_required(self):
credentials = AppAssertionCredentials()
self.assertFalse(credentials.create_scoped_required())

@mock.patch('warnings.warn')
def test_create_scoped_required_with_scopes(self, warn_mock):
credentials = AppAssertionCredentials(['dummy_scope'])
self.assertFalse(credentials.create_scoped_required())
warn_mock.assert_called_once_with(_SCOPES_WARNING)

@mock.patch('warnings.warn')
def test_create_scoped(self, warn_mock):
credentials = AppAssertionCredentials()
new_credentials = credentials.create_scoped(['dummy_scope'])
self.assertNotEqual(credentials, new_credentials)
self.assertTrue(isinstance(new_credentials, AppAssertionCredentials))
self.assertEqual('dummy_scope', new_credentials.scope)
warn_mock.assert_called_once_with(_SCOPES_WARNING)

def test_sign_blob_not_implemented(self):
credentials = AppAssertionCredentials([])
with self.assertRaises(NotImplementedError):
credentials.sign_blob(b'blob')

@mock.patch('oauth2client.contrib._metadata.get_service_account_info',
return_value={'email': '[email protected]'})
def test_service_account_email(self, metadata):
return_value=SERVICE_ACCOUNT_INFO)
def test_retrieve_scopes(self, metadata):
http_request = mock.MagicMock()
http_mock = mock.MagicMock(request=http_request)
credentials = AppAssertionCredentials()
# Assert that service account isn't pre-fetched
metadata.assert_not_called()
self.assertEqual(credentials.service_account_email, '[email protected]')
self.assertTrue(credentials.invalid)
self.assertIsNone(credentials.scopes)
scopes = credentials.retrieve_scopes(http_mock)
self.assertEqual(scopes, SERVICE_ACCOUNT_INFO['scopes'])
self.assertFalse(credentials.invalid)
credentials.retrieve_scopes(http_mock)
# Assert scopes weren't refetched
metadata.assert_called_once_with(http_request, service_account='default')

@mock.patch('oauth2client.contrib._metadata.get_service_account_info',
side_effect=httplib2.HttpLib2Error('No Such Email'))
def test_retrieve_scopes_bad_email(self, metadata):
http_request = mock.MagicMock()
http_mock = mock.MagicMock(request=http_request)
credentials = AppAssertionCredentials(email='[email protected]')
with self.assertRaises(httplib2.HttpLib2Error):
credentials.retrieve_scopes(http_mock)

metadata.assert_called_once_with(http_request, service_account='[email protected]')

def test_save_to_well_known_file(self):
import os
Expand Down
Loading

0 comments on commit 5454867

Please sign in to comment.