Skip to content

Commit

Permalink
Use transport module for GCE environment check.
Browse files Browse the repository at this point in the history
  • Loading branch information
dhermes committed Aug 15, 2016
1 parent 9c1ece5 commit 0053fb4
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 53 deletions.
20 changes: 9 additions & 11 deletions oauth2client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,8 @@
GCE_METADATA_TIMEOUT = 3

_SERVER_SOFTWARE = 'SERVER_SOFTWARE'
_GCE_METADATA_HOST = '169.254.169.254'
_METADATA_FLAVOR_HEADER = 'Metadata-Flavor'
_GCE_METADATA_URI = 'http://169.254.169.254'
_METADATA_FLAVOR_HEADER = 'metadata-flavor' # lowercase header
_DESIRED_METADATA_FLAVOR = 'Google'

# Expose utcnow() at module level to allow for
Expand Down Expand Up @@ -999,21 +999,19 @@ def _detect_gce_environment():
# could lead to false negatives in the event that we are on GCE, but
# the metadata resolution was particularly slow. The latter case is
# "unlikely".
connection = six.moves.http_client.HTTPConnection(
_GCE_METADATA_HOST, timeout=GCE_METADATA_TIMEOUT)

http = transport.get_http_object(timeout=GCE_METADATA_TIMEOUT)
headers = {_METADATA_FLAVOR_HEADER: _DESIRED_METADATA_FLAVOR}
try:
headers = {_METADATA_FLAVOR_HEADER: _DESIRED_METADATA_FLAVOR}
connection.request('GET', '/', headers=headers)
response = connection.getresponse()
response, _ = transport.request(
http, _GCE_METADATA_URI, headers=headers)
if response.status == http_client.OK:
return (response.getheader(_METADATA_FLAVOR_HEADER) ==
return (response.get(_METADATA_FLAVOR_HEADER) ==
_DESIRED_METADATA_FLAVOR)
else:
return False
except socket.error: # socket.timeout or socket.error(64, 'Host is down')
logger.info('Timeout attempting to reach GCE metadata service.')
return False
finally:
connection.close()


def _in_gae_environment():
Expand Down
80 changes: 38 additions & 42 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,67 +358,48 @@ def test_environment_caching(self):
# is cached.
self.assertTrue(client._in_gae_environment())

def _environment_check_gce_helper(self, status_ok=True, socket_error=False,
def _environment_check_gce_helper(self, status_ok=True,
server_software=''):
response = mock.Mock()
if status_ok:
response.status = http_client.OK
response.getheader = mock.Mock(
name='getheader',
return_value=client._DESIRED_METADATA_FLAVOR)
headers = {
'status': http_client.OK,
client._METADATA_FLAVOR_HEADER: (
client._DESIRED_METADATA_FLAVOR),
}
else:
response.status = http_client.NOT_FOUND

connection = mock.Mock()
connection.getresponse = mock.Mock(name='getresponse',
return_value=response)
if socket_error:
connection.getresponse.side_effect = socket.error()
headers = {'status': http_client.NOT_FOUND}

http = http_mock.HttpMock(headers=headers)
with mock.patch('oauth2client.client.os') as os_module:
os_module.environ = {client._SERVER_SOFTWARE: server_software}
with mock.patch('oauth2client.client.six') as six_module:
http_client_module = six_module.moves.http_client
http_client_module.HTTPConnection = mock.Mock(
name='HTTPConnection', return_value=connection)

with mock.patch('oauth2client.transport.get_http_object',
return_value=http) as new_http:
if server_software == '':
self.assertFalse(client._in_gae_environment())
else:
self.assertTrue(client._in_gae_environment())

if status_ok and not socket_error and server_software == '':
if status_ok and server_software == '':
self.assertTrue(client._in_gce_environment())
else:
self.assertFalse(client._in_gce_environment())

# Verify mocks.
if server_software == '':
http_client_module.HTTPConnection.assert_called_once_with(
client._GCE_METADATA_HOST,
new_http.assert_called_once_with(
timeout=client.GCE_METADATA_TIMEOUT)
connection.getresponse.assert_called_once_with()
# Remaining calls are not "getresponse"
headers = {
self.assertEqual(http.requests, 1)
self.assertEqual(http.uri, client._GCE_METADATA_URI)
self.assertEqual(http.method, 'GET')
self.assertIsNone(http.body)
request_headers = {
client._METADATA_FLAVOR_HEADER: (
client._DESIRED_METADATA_FLAVOR),
}
self.assertEqual(connection.method_calls, [
mock.call.request('GET', '/',
headers=headers),
mock.call.close(),
])
self.assertEqual(response.method_calls, [])
if status_ok and not socket_error:
response.getheader.assert_called_once_with(
client._METADATA_FLAVOR_HEADER)
self.assertEqual(http.headers, request_headers)
else:
self.assertEqual(
http_client_module.HTTPConnection.mock_calls, [])
self.assertEqual(connection.getresponse.mock_calls, [])
# Remaining calls are not "getresponse"
self.assertEqual(connection.method_calls, [])
self.assertEqual(response.method_calls, [])
self.assertEqual(response.getheader.mock_calls, [])
new_http.assert_not_called()
self.assertEqual(http.requests, 0)

def test_environment_check_gce_production(self):
self._environment_check_gce_helper(status_ok=True)
Expand All @@ -427,8 +408,23 @@ def test_environment_check_gce_prod_with_working_gae_imports(self):
with mock_module_import('google.appengine'):
self._environment_check_gce_helper(status_ok=True)

def test_environment_check_gce_timeout(self):
self._environment_check_gce_helper(socket_error=True)
@mock.patch('oauth2client.client.os.environ',
new={client._SERVER_SOFTWARE: ''})
@mock.patch('oauth2client.transport.get_http_object',
return_value=object())
@mock.patch('oauth2client.transport.request',
side_effect=socket.timeout())
def test_environment_check_gce_timeout(self, mock_request, new_http):
self.assertFalse(client._in_gae_environment())
self.assertFalse(client._in_gce_environment())

# Verify mocks.
new_http.assert_called_once_with(timeout=client.GCE_METADATA_TIMEOUT)
headers = {
client._METADATA_FLAVOR_HEADER: client._DESIRED_METADATA_FLAVOR,
}
mock_request.assert_called_once_with(
new_http.return_value, client._GCE_METADATA_URI, headers=headers)

def test_environ_check_gae_module_unknown(self):
with mock_module_import('google.appengine'):
Expand Down

0 comments on commit 0053fb4

Please sign in to comment.