From 0053fb49ec33b576a283deebaeec9dcec4f6e914 Mon Sep 17 00:00:00 2001 From: Danny Hermes Date: Wed, 10 Aug 2016 18:43:38 -0700 Subject: [PATCH] Use transport module for GCE environment check. Fixes #599. --- oauth2client/client.py | 20 +++++------ tests/test_client.py | 80 ++++++++++++++++++++---------------------- 2 files changed, 47 insertions(+), 53 deletions(-) diff --git a/oauth2client/client.py b/oauth2client/client.py index 2d1f6e833..42afeaf08 100644 --- a/oauth2client/client.py +++ b/oauth2client/client.py @@ -110,8 +110,8 @@ GCE_METADATA_TIMEOUT = 3 _SERVER_SOFTWARE = 'SERVER_SOFTWARE' -_GCE_METADATA_HOST = '169.254.169.254' -_METADATA_FLAVOR_HEADER = 'Metadata-Flavor' +_GCE_METADATA_URI = 'http://169.254.169.254' +_METADATA_FLAVOR_HEADER = 'metadata-flavor' # lowercase header _DESIRED_METADATA_FLAVOR = 'Google' # Expose utcnow() at module level to allow for @@ -999,21 +999,19 @@ def _detect_gce_environment(): # could lead to false negatives in the event that we are on GCE, but # the metadata resolution was particularly slow. The latter case is # "unlikely". - connection = six.moves.http_client.HTTPConnection( - _GCE_METADATA_HOST, timeout=GCE_METADATA_TIMEOUT) - + http = transport.get_http_object(timeout=GCE_METADATA_TIMEOUT) + headers = {_METADATA_FLAVOR_HEADER: _DESIRED_METADATA_FLAVOR} try: - headers = {_METADATA_FLAVOR_HEADER: _DESIRED_METADATA_FLAVOR} - connection.request('GET', '/', headers=headers) - response = connection.getresponse() + response, _ = transport.request( + http, _GCE_METADATA_URI, headers=headers) if response.status == http_client.OK: - return (response.getheader(_METADATA_FLAVOR_HEADER) == + return (response.get(_METADATA_FLAVOR_HEADER) == _DESIRED_METADATA_FLAVOR) + else: + return False except socket.error: # socket.timeout or socket.error(64, 'Host is down') logger.info('Timeout attempting to reach GCE metadata service.') return False - finally: - connection.close() def _in_gae_environment(): diff --git a/tests/test_client.py b/tests/test_client.py index 27f24d87b..78871ed3c 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -358,67 +358,48 @@ def test_environment_caching(self): # is cached. self.assertTrue(client._in_gae_environment()) - def _environment_check_gce_helper(self, status_ok=True, socket_error=False, + def _environment_check_gce_helper(self, status_ok=True, server_software=''): - response = mock.Mock() if status_ok: - response.status = http_client.OK - response.getheader = mock.Mock( - name='getheader', - return_value=client._DESIRED_METADATA_FLAVOR) + headers = { + 'status': http_client.OK, + client._METADATA_FLAVOR_HEADER: ( + client._DESIRED_METADATA_FLAVOR), + } else: - response.status = http_client.NOT_FOUND - - connection = mock.Mock() - connection.getresponse = mock.Mock(name='getresponse', - return_value=response) - if socket_error: - connection.getresponse.side_effect = socket.error() + headers = {'status': http_client.NOT_FOUND} + http = http_mock.HttpMock(headers=headers) with mock.patch('oauth2client.client.os') as os_module: os_module.environ = {client._SERVER_SOFTWARE: server_software} - with mock.patch('oauth2client.client.six') as six_module: - http_client_module = six_module.moves.http_client - http_client_module.HTTPConnection = mock.Mock( - name='HTTPConnection', return_value=connection) - + with mock.patch('oauth2client.transport.get_http_object', + return_value=http) as new_http: if server_software == '': self.assertFalse(client._in_gae_environment()) else: self.assertTrue(client._in_gae_environment()) - if status_ok and not socket_error and server_software == '': + if status_ok and server_software == '': self.assertTrue(client._in_gce_environment()) else: self.assertFalse(client._in_gce_environment()) + # Verify mocks. if server_software == '': - http_client_module.HTTPConnection.assert_called_once_with( - client._GCE_METADATA_HOST, + new_http.assert_called_once_with( timeout=client.GCE_METADATA_TIMEOUT) - connection.getresponse.assert_called_once_with() - # Remaining calls are not "getresponse" - headers = { + self.assertEqual(http.requests, 1) + self.assertEqual(http.uri, client._GCE_METADATA_URI) + self.assertEqual(http.method, 'GET') + self.assertIsNone(http.body) + request_headers = { client._METADATA_FLAVOR_HEADER: ( client._DESIRED_METADATA_FLAVOR), } - self.assertEqual(connection.method_calls, [ - mock.call.request('GET', '/', - headers=headers), - mock.call.close(), - ]) - self.assertEqual(response.method_calls, []) - if status_ok and not socket_error: - response.getheader.assert_called_once_with( - client._METADATA_FLAVOR_HEADER) + self.assertEqual(http.headers, request_headers) else: - self.assertEqual( - http_client_module.HTTPConnection.mock_calls, []) - self.assertEqual(connection.getresponse.mock_calls, []) - # Remaining calls are not "getresponse" - self.assertEqual(connection.method_calls, []) - self.assertEqual(response.method_calls, []) - self.assertEqual(response.getheader.mock_calls, []) + new_http.assert_not_called() + self.assertEqual(http.requests, 0) def test_environment_check_gce_production(self): self._environment_check_gce_helper(status_ok=True) @@ -427,8 +408,23 @@ def test_environment_check_gce_prod_with_working_gae_imports(self): with mock_module_import('google.appengine'): self._environment_check_gce_helper(status_ok=True) - def test_environment_check_gce_timeout(self): - self._environment_check_gce_helper(socket_error=True) + @mock.patch('oauth2client.client.os.environ', + new={client._SERVER_SOFTWARE: ''}) + @mock.patch('oauth2client.transport.get_http_object', + return_value=object()) + @mock.patch('oauth2client.transport.request', + side_effect=socket.timeout()) + def test_environment_check_gce_timeout(self, mock_request, new_http): + self.assertFalse(client._in_gae_environment()) + self.assertFalse(client._in_gce_environment()) + + # Verify mocks. + new_http.assert_called_once_with(timeout=client.GCE_METADATA_TIMEOUT) + headers = { + client._METADATA_FLAVOR_HEADER: client._DESIRED_METADATA_FLAVOR, + } + mock_request.assert_called_once_with( + new_http.return_value, client._GCE_METADATA_URI, headers=headers) def test_environ_check_gae_module_unknown(self): with mock_module_import('google.appengine'):