Skip to content
This repository was archived by the owner on Jan 18, 2025. It is now read-only.

Use transport module for GCE environment check. #612

Merged
merged 1 commit into from
Aug 16, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 9 additions & 13 deletions oauth2client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,10 @@
GCE_METADATA_TIMEOUT = 3

_SERVER_SOFTWARE = 'SERVER_SOFTWARE'
_GCE_METADATA_HOST = '169.254.169.254'
_METADATA_FLAVOR_HEADER = 'Metadata-Flavor'
_GCE_METADATA_URI = 'http://169.254.169.254'
_METADATA_FLAVOR_HEADER = 'metadata-flavor' # lowercase header

This comment was marked as spam.

This comment was marked as spam.

This comment was marked as spam.

_DESIRED_METADATA_FLAVOR = 'Google'

This comment was marked as spam.

This comment was marked as spam.

_GCE_HEADERS = {_METADATA_FLAVOR_HEADER: _DESIRED_METADATA_FLAVOR}

# Expose utcnow() at module level to allow for
# easier testing (by replacing with a stub).
Expand Down Expand Up @@ -997,21 +998,16 @@ def _detect_gce_environment():
# could lead to false negatives in the event that we are on GCE, but
# the metadata resolution was particularly slow. The latter case is
# "unlikely".
connection = six.moves.http_client.HTTPConnection(
_GCE_METADATA_HOST, timeout=GCE_METADATA_TIMEOUT)

http = transport.get_http_object(timeout=GCE_METADATA_TIMEOUT)

This comment was marked as spam.

This comment was marked as spam.

This comment was marked as spam.

This comment was marked as spam.

try:
headers = {_METADATA_FLAVOR_HEADER: _DESIRED_METADATA_FLAVOR}
connection.request('GET', '/', headers=headers)
response = connection.getresponse()
if response.status == http_client.OK:
return (response.getheader(_METADATA_FLAVOR_HEADER) ==
_DESIRED_METADATA_FLAVOR)
response, _ = transport.request(
http, _GCE_METADATA_URI, headers=_GCE_HEADERS)
return (
response.status == http_client.OK and
response.get(_METADATA_FLAVOR_HEADER) == _DESIRED_METADATA_FLAVOR)
except socket.error: # socket.timeout or socket.error(64, 'Host is down')
logger.info('Timeout attempting to reach GCE metadata service.')
return False
finally:
connection.close()


def _in_gae_environment():
Expand Down
79 changes: 32 additions & 47 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,6 @@
DATA_DIR = os.path.join(os.path.dirname(__file__), 'data')


# TODO(craigcitro): This is duplicated from
# googleapiclient.test_discovery; consolidate these definitions.
def assertUrisEqual(testcase, expected, actual):
"""Test that URIs are the same, up to reordering of query parameters."""
expected = urllib.parse.urlparse(expected)
Expand Down Expand Up @@ -357,67 +355,41 @@ def test_environment_caching(self):
# is cached.
self.assertTrue(client._in_gae_environment())

def _environment_check_gce_helper(self, status_ok=True, socket_error=False,
def _environment_check_gce_helper(self, status_ok=True,
server_software=''):
response = mock.Mock()
if status_ok:
response.status = http_client.OK
response.getheader = mock.Mock(
name='getheader',
return_value=client._DESIRED_METADATA_FLAVOR)
headers = {'status': http_client.OK}
headers.update(client._GCE_HEADERS)
else:
response.status = http_client.NOT_FOUND

connection = mock.Mock()
connection.getresponse = mock.Mock(name='getresponse',
return_value=response)
if socket_error:
connection.getresponse.side_effect = socket.error()
headers = {'status': http_client.NOT_FOUND}

http = http_mock.HttpMock(headers=headers)
with mock.patch('oauth2client.client.os') as os_module:
os_module.environ = {client._SERVER_SOFTWARE: server_software}
with mock.patch('oauth2client.client.six') as six_module:
http_client_module = six_module.moves.http_client
http_client_module.HTTPConnection = mock.Mock(
name='HTTPConnection', return_value=connection)

with mock.patch('oauth2client.transport.get_http_object',
return_value=http) as new_http:
if server_software == '':
self.assertFalse(client._in_gae_environment())
else:
self.assertTrue(client._in_gae_environment())

if status_ok and not socket_error and server_software == '':
if status_ok and server_software == '':
self.assertTrue(client._in_gce_environment())
else:
self.assertFalse(client._in_gce_environment())

# Verify mocks.
if server_software == '':
http_client_module.HTTPConnection.assert_called_once_with(
client._GCE_METADATA_HOST,
new_http.assert_called_once_with(
timeout=client.GCE_METADATA_TIMEOUT)
connection.getresponse.assert_called_once_with()
# Remaining calls are not "getresponse"
headers = {
client._METADATA_FLAVOR_HEADER: (
client._DESIRED_METADATA_FLAVOR),
}
self.assertEqual(connection.method_calls, [
mock.call.request('GET', '/',
headers=headers),
mock.call.close(),
])
self.assertEqual(response.method_calls, [])
if status_ok and not socket_error:
response.getheader.assert_called_once_with(
client._METADATA_FLAVOR_HEADER)
self.assertEqual(http.requests, 1)
self.assertEqual(http.uri, client._GCE_METADATA_URI)
self.assertEqual(http.method, 'GET')
self.assertIsNone(http.body)
self.assertEqual(http.headers, client._GCE_HEADERS)
else:
self.assertEqual(
http_client_module.HTTPConnection.mock_calls, [])
self.assertEqual(connection.getresponse.mock_calls, [])
# Remaining calls are not "getresponse"
self.assertEqual(connection.method_calls, [])
self.assertEqual(response.method_calls, [])
self.assertEqual(response.getheader.mock_calls, [])
new_http.assert_not_called()
self.assertEqual(http.requests, 0)

def test_environment_check_gce_production(self):
self._environment_check_gce_helper(status_ok=True)
Expand All @@ -426,8 +398,21 @@ def test_environment_check_gce_prod_with_working_gae_imports(self):
with mock_module_import('google.appengine'):
self._environment_check_gce_helper(status_ok=True)

def test_environment_check_gce_timeout(self):
self._environment_check_gce_helper(socket_error=True)
@mock.patch('oauth2client.client.os.environ',
new={client._SERVER_SOFTWARE: ''})
@mock.patch('oauth2client.transport.get_http_object',
return_value=object())
@mock.patch('oauth2client.transport.request',
side_effect=socket.timeout())
def test_environment_check_gce_timeout(self, mock_request, new_http):
self.assertFalse(client._in_gae_environment())
self.assertFalse(client._in_gce_environment())

# Verify mocks.
new_http.assert_called_once_with(timeout=client.GCE_METADATA_TIMEOUT)
mock_request.assert_called_once_with(
new_http.return_value, client._GCE_METADATA_URI,
headers=client._GCE_HEADERS)

def test_environ_check_gae_module_unknown(self):
with mock_module_import('google.appengine'):
Expand Down