diff --git a/src/googleclouddebugger/firebase_client.py b/src/googleclouddebugger/firebase_client.py index 0da2a9e..4cb414a 100644 --- a/src/googleclouddebugger/firebase_client.py +++ b/src/googleclouddebugger/firebase_client.py @@ -14,14 +14,15 @@ """Communicates with Firebase RTDB backend.""" from collections import deque +import copy import hashlib import json import os import platform import requests -import socket import sys import threading +import time import traceback import firebase_admin @@ -114,6 +115,7 @@ def __init__(self): # Delay before retrying failed request. self.register_backoff = backoff.Backoff() # Register debuggee. + self.subscribe_backoff = backoff.Backoff() # Subscribe to updates. self.update_backoff = backoff.Backoff() # Update breakpoint. # Maximum number of times that the message is re-transmitted before it @@ -279,13 +281,25 @@ def _MainThreadProc(self): self._breakpoint_subscription. """ # Note: if self._credentials is None, default app credentials will be used. - # TODO: Error handling. - firebase_admin.initialize_app(self._credentials, - {'databaseURL': self._database_url}) + try: + firebase_admin.initialize_app(self._credentials, + {'databaseURL': self._database_url}) + except ValueError: + native.LogWarning( + f'Failed to initialize firebase: {traceback.format_exc()}') + native.LogError('Failed to start debugger agent. 
Giving up.') + return - self._RegisterDebuggee() + registration_required, delay = True, 0 + while registration_required: + time.sleep(delay) + registration_required, delay = self._RegisterDebuggee() self.registration_complete.set() - self._SubscribeToBreakpoints() + + subscription_required, delay = True, 0 + while subscription_required: + time.sleep(delay) + subscription_required, delay = self._SubscribeToBreakpoints() self.subscription_complete.set() def _TransmissionThreadProc(self): @@ -310,26 +324,29 @@ def _RegisterDebuggee(self): Returns: (registration_required, delay) tuple """ + debuggee = None try: debuggee = self._GetDebuggee() self._debuggee_id = debuggee['id'] - - try: - debuggee_path = f'cdbg/debuggees/{self._debuggee_id}' - native.LogInfo( - f'registering at {self._database_url}, path: {debuggee_path}') - firebase_admin.db.reference(debuggee_path).set(debuggee) - native.LogInfo( - f'Debuggee registered successfully, ID: {self._debuggee_id}') - self.register_backoff.Succeeded() - return (False, 0) # Proceed immediately to subscribing to breakpoints. - except BaseException: - native.LogInfo(f'Failed to register debuggee: {traceback.format_exc()}') except BaseException: - native.LogWarning('Debuggee information not available: ' + - traceback.format_exc()) + native.LogWarning( + f'Debuggee information not available: {traceback.format_exc()}') + return (True, self.register_backoff.Failed()) - return (True, self.register_backoff.Failed()) + try: + debuggee_path = f'cdbg/debuggees/{self._debuggee_id}' + native.LogInfo( + f'registering at {self._database_url}, path: {debuggee_path}') + firebase_admin.db.reference(debuggee_path).set(debuggee) + native.LogInfo( + f'Debuggee registered successfully, ID: {self._debuggee_id}') + self.register_backoff.Succeeded() + return (False, 0) # Proceed immediately to subscribing to breakpoints. 
+ except BaseException: + # There is no significant benefit to handling different exceptions + # in different ways; we will log and retry regardless. + native.LogInfo(f'Failed to register debuggee: {traceback.format_exc()}') + return (True, self.register_backoff.Failed()) def _SubscribeToBreakpoints(self): # Kill any previous subscriptions first. @@ -340,7 +357,13 @@ def _SubscribeToBreakpoints(self): path = f'cdbg/breakpoints/{self._debuggee_id}/active' native.LogInfo(f'Subscribing to breakpoint updates at {path}') ref = firebase_admin.db.reference(path) - self._breakpoint_subscription = ref.listen(self._ActiveBreakpointCallback) + try: + self._breakpoint_subscription = ref.listen(self._ActiveBreakpointCallback) + return (False, 0) + except firebase_admin.exceptions.FirebaseError: + native.LogInfo( + f'Failed to subscribe to breakpoints: {traceback.format_exc()}') + return (True, self.subscribe_backoff.Failed()) def _ActiveBreakpointCallback(self, event): if event.event_type == 'put': @@ -410,7 +433,7 @@ def _TransmitBreakpointUpdates(self): try: # Something has changed on the breakpoint. # It should be going from active to final, but let's make sure. - if not breakpoint_data['isFinalState']: + if not breakpoint_data.get('isFinalState', False): raise BaseException( f'Unexpected breakpoint update requested: {breakpoint_data}') @@ -428,6 +451,7 @@ def _TransmitBreakpointUpdates(self): f'cdbg/breakpoints/{self._debuggee_id}/active/{bp_id}') bp_ref.delete() + summary_data = breakpoint_data # Save snapshot data for snapshots only. if is_snapshot: # Note that there may not be snapshot data. @@ -436,14 +460,15 @@ def _TransmitBreakpointUpdates(self): bp_ref.set(breakpoint_data) # Now strip potential snapshot data.
- breakpoint_data.pop('evaluatedExpressions', None) - breakpoint_data.pop('stackFrames', None) - breakpoint_data.pop('variableTable', None) + summary_data = copy.deepcopy(breakpoint_data) + summary_data.pop('evaluatedExpressions', None) + summary_data.pop('stackFrames', None) + summary_data.pop('variableTable', None) # Then add it to the list of final breakpoints. bp_ref = firebase_admin.db.reference( f'cdbg/breakpoints/{self._debuggee_id}/final/{bp_id}') - bp_ref.set(breakpoint_data) + bp_ref.set(summary_data) native.LogInfo(f'Breakpoint {bp_id} update transmitted successfully') @@ -460,15 +485,7 @@ def _TransmitBreakpointUpdates(self): # This is very common if multiple instances are sending final update # simultaneously. native.LogInfo(f'{err}, breakpoint: {bp_id}') - except socket.error as err: - if retry_count < self.max_transmit_attempts - 1: - native.LogInfo(f'Socket error {err.errno} while sending breakpoint ' - f'{bp_id} update: {traceback.format_exc()}') - retry_list.append((breakpoint_data, retry_count + 1)) - else: - native.LogWarning(f'Breakpoint {bp_id} retry count exceeded maximum') - # Socket errors shouldn't persist like this; reconnect. 
- #reconnect = True + except BaseException: native.LogWarning(f'Fatal error sending breakpoint {bp_id} update: ' f'{traceback.format_exc()}') diff --git a/tests/firebase_client_test.py b/tests/firebase_client_test.py index c1690b2..1986a9a 100644 --- a/tests/firebase_client_test.py +++ b/tests/firebase_client_test.py @@ -1,10 +1,9 @@ """Unit tests for firebase_client module.""" -import errno import os -import socket import sys import tempfile +import time from unittest import mock from unittest.mock import MagicMock from unittest.mock import call @@ -12,7 +11,6 @@ import requests import requests_mock -from googleapiclient.errors import HttpError from googleclouddebugger import version from googleclouddebugger import firebase_client @@ -20,6 +18,7 @@ from absl.testing import parameterized import firebase_admin.credentials +from firebase_admin.exceptions import FirebaseError TEST_PROJECT_ID = 'test-project-id' METADATA_PROJECT_URL = ('http://metadata.google.internal/computeMetadata/' @@ -59,6 +58,31 @@ def setUp(self): self.breakpoints_changed_count = 0 self.breakpoints = {} + # Speed up the delays for retry loops. + for backoff in [ + self._client.register_backoff, self._client.subscribe_backoff, + self._client.update_backoff + ]: + backoff.min_interval_sec /= 100000.0 + backoff.max_interval_sec /= 100000.0 + backoff._current_interval_sec /= 100000.0 + + # Set up patchers. + patcher = patch('firebase_admin.initialize_app') + self._mock_initialize_app = patcher.start() + self.addCleanup(patcher.stop) + + patcher = patch('firebase_admin.db.reference') + self._mock_db_ref = patcher.start() + self.addCleanup(patcher.stop) + + # Set up the mocks for the database refs. 
+ self._mock_register_ref = MagicMock() + self._fake_subscribe_ref = FakeReference() + self._mock_db_ref.side_effect = [ + self._mock_register_ref, self._fake_subscribe_ref + ] + def tearDown(self): self._client.Stop() @@ -105,33 +129,58 @@ def testSetupAuthNoProjectId(self): with self.assertRaises(firebase_client.NoProjectIdError): self._client.SetupAuth() - @patch('firebase_admin.db.reference') - @patch('firebase_admin.initialize_app') - def testStart(self, mock_initialize_app, mock_db_ref): + def testStart(self): self._client.SetupAuth(project_id=TEST_PROJECT_ID) self._client.Start() self._client.subscription_complete.wait() debuggee_id = self._client._debuggee_id - mock_initialize_app.assert_called_with( + self._mock_initialize_app.assert_called_with( None, {'databaseURL': f'https://{TEST_PROJECT_ID}-cdbg.firebaseio.com'}) self.assertEqual([ call(f'cdbg/debuggees/{debuggee_id}'), call(f'cdbg/breakpoints/{debuggee_id}/active') - ], mock_db_ref.call_args_list) + ], self._mock_db_ref.call_args_list) + + # Verify that the register call has been made. + self._mock_register_ref.set.assert_called_once_with( + self._client._GetDebuggee()) + + def testStartRegisterRetry(self): + # A new db ref is fetched on each retry. + self._mock_db_ref.side_effect = [ + self._mock_register_ref, self._mock_register_ref, + self._fake_subscribe_ref + ] + + # Fail once, then succeed on retry. + self._mock_register_ref.set.side_effect = [FirebaseError(1, 'foo'), None] + + self._client.SetupAuth(project_id=TEST_PROJECT_ID) + self._client.Start() + self._client.registration_complete.wait() - # TODO: testStartRegisterRetry - # TODO: testStartSubscribeRetry - # - Note: failures don't require retrying registration. 
+ self.assertEqual(2, self._mock_register_ref.set.call_count) - @patch('firebase_admin.db.reference') - @patch('firebase_admin.initialize_app') - def testBreakpointSubscription(self, mock_initialize_app, mock_db_ref): - mock_register_ref = MagicMock() - fake_subscribe_ref = FakeReference() - mock_db_ref.side_effect = [mock_register_ref, fake_subscribe_ref] + def testStartSubscribeRetry(self): + mock_subscribe_ref = MagicMock() + mock_subscribe_ref.listen.side_effect = FirebaseError(1, 'foo') + # A new db ref is fetched on each retry. + self._mock_db_ref.side_effect = [ + self._mock_register_ref, + mock_subscribe_ref, # Fail the first time + self._fake_subscribe_ref # Succeed the second time + ] + + self._client.SetupAuth(project_id=TEST_PROJECT_ID) + self._client.Start() + self._client.subscription_complete.wait() + + self.assertEqual(3, self._mock_db_ref.call_count) + + def testBreakpointSubscription(self): # This class will keep track of the breakpoint updates and will check # them against expectations. class ResultChecker: @@ -182,15 +231,247 @@ def callback(self, new_breakpoints): self._client.subscription_complete.wait() # Send in updates to trigger the subscription callback. 
- fake_subscribe_ref.update('put', '/', - {breakpoints[0]['id']: breakpoints[0]}) - fake_subscribe_ref.update('patch', '/', - {breakpoints[1]['id']: breakpoints[1]}) - fake_subscribe_ref.update('put', f'/{breakpoints[2]["id"]}', breakpoints[2]) - fake_subscribe_ref.update('put', f'/{breakpoints[0]["id"]}', None) + self._fake_subscribe_ref.update('put', '/', + {breakpoints[0]['id']: breakpoints[0]}) + self._fake_subscribe_ref.update('patch', '/', + {breakpoints[1]['id']: breakpoints[1]}) + self._fake_subscribe_ref.update('put', f'/{breakpoints[2]["id"]}', + breakpoints[2]) + self._fake_subscribe_ref.update('put', f'/{breakpoints[0]["id"]}', None) self.assertEqual(len(expected_results), result_checker._change_count) + def testEnqueueBreakpointUpdate(self): + active_ref_mock = MagicMock() + snapshot_ref_mock = MagicMock() + final_ref_mock = MagicMock() + + self._mock_db_ref.side_effect = [ + self._mock_register_ref, self._fake_subscribe_ref, active_ref_mock, + snapshot_ref_mock, final_ref_mock + ] + + self._client.SetupAuth(project_id=TEST_PROJECT_ID) + self._client.Start() + self._client.subscription_complete.wait() + + debuggee_id = self._client._debuggee_id + breakpoint_id = 'breakpoint-0' + + input_breakpoint = { + 'id': breakpoint_id, + 'location': { + 'path': 'foo.py', + 'line': 18 + }, + 'isFinalState': True, + 'evaluatedExpressions': ['expressions go here'], + 'stackFrames': ['stuff goes here'], + 'variableTable': ['lots', 'of', 'variables'], + } + short_breakpoint = { + 'id': breakpoint_id, + 'location': { + 'path': 'foo.py', + 'line': 18 + }, + 'isFinalState': True, + 'action': 'CAPTURE', + 'finalTimeUnixMsec': { + '.sv': 'timestamp' + } + } + full_breakpoint = { + 'id': breakpoint_id, + 'location': { + 'path': 'foo.py', + 'line': 18 + }, + 'isFinalState': True, + 'action': 'CAPTURE', + 'evaluatedExpressions': ['expressions go here'], + 'stackFrames': ['stuff goes here'], + 'variableTable': ['lots', 'of', 'variables'], + 'finalTimeUnixMsec': { + '.sv': 
'timestamp' + } + } + + self._client.EnqueueBreakpointUpdate(input_breakpoint) + + # Wait for the breakpoint to be sent. + while self._client._transmission_queue: + time.sleep(0.1) + + db_ref_calls = self._mock_db_ref.call_args_list + self.assertEqual( + call(f'cdbg/breakpoints/{debuggee_id}/active/{breakpoint_id}'), + db_ref_calls[2]) + self.assertEqual( + call(f'cdbg/breakpoints/{debuggee_id}/snapshots/{breakpoint_id}'), + db_ref_calls[3]) + self.assertEqual( + call(f'cdbg/breakpoints/{debuggee_id}/final/{breakpoint_id}'), + db_ref_calls[4]) + + active_ref_mock.delete.assert_called_once() + snapshot_ref_mock.set.assert_called_once_with(full_breakpoint) + final_ref_mock.set.assert_called_once_with(short_breakpoint) + + def testEnqueueBreakpointUpdateWithLogpoint(self): + active_ref_mock = MagicMock() + final_ref_mock = MagicMock() + + self._mock_db_ref.side_effect = [ + self._mock_register_ref, self._fake_subscribe_ref, active_ref_mock, + final_ref_mock + ] + + self._client.SetupAuth(project_id=TEST_PROJECT_ID) + self._client.Start() + self._client.subscription_complete.wait() + + debuggee_id = self._client._debuggee_id + breakpoint_id = 'logpoint-0' + + input_breakpoint = { + 'id': breakpoint_id, + 'location': { + 'path': 'foo.py', + 'line': 18 + }, + 'action': 'LOG', + 'isFinalState': True, + 'status': { + 'isError': True, + 'refersTo': 'BREAKPOINT_SOURCE_LOCATION', + }, + } + output_breakpoint = { + 'id': breakpoint_id, + 'location': { + 'path': 'foo.py', + 'line': 18 + }, + 'isFinalState': True, + 'action': 'LOG', + 'status': { + 'isError': True, + 'refersTo': 'BREAKPOINT_SOURCE_LOCATION', + }, + 'finalTimeUnixMsec': { + '.sv': 'timestamp' + } + } + + self._client.EnqueueBreakpointUpdate(input_breakpoint) + + # Wait for the breakpoint to be sent. 
+ while self._client._transmission_queue: + time.sleep(0.1) + + db_ref_calls = self._mock_db_ref.call_args_list + self.assertEqual( + call(f'cdbg/breakpoints/{debuggee_id}/active/{breakpoint_id}'), + db_ref_calls[2]) + self.assertEqual( + call(f'cdbg/breakpoints/{debuggee_id}/final/{breakpoint_id}'), + db_ref_calls[3]) + + active_ref_mock.delete.assert_called_once() + final_ref_mock.set.assert_called_once_with(output_breakpoint) + + # Make sure that the snapshots node was not accessed. + self.assertTrue( + call(f'cdbg/breakpoints/{debuggee_id}/snapshots/{breakpoint_id}') not in + db_ref_calls) + + def testEnqueueBreakpointUpdateRetry(self): + active_ref_mock = MagicMock() + snapshot_ref_mock = MagicMock() + final_ref_mock = MagicMock() + + # This test will have three failures, one for each of the firebase writes. + # UNAVAILABLE errors are retryable. + active_ref_mock.delete.side_effect = [ + FirebaseError('UNAVAILABLE', 'active error'), None, None, None + ] + snapshot_ref_mock.set.side_effect = [ + FirebaseError('UNAVAILABLE', 'snapshot error'), None, None + ] + final_ref_mock.set.side_effect = [ + FirebaseError('UNAVAILABLE', 'final error'), None + ] + + self._mock_db_ref.side_effect = [ + self._mock_register_ref, + self._fake_subscribe_ref, # setup + active_ref_mock, # attempt 1 + active_ref_mock, + snapshot_ref_mock, # attempt 2 + active_ref_mock, + snapshot_ref_mock, + final_ref_mock, # attempt 3 + active_ref_mock, + snapshot_ref_mock, + final_ref_mock # attempt 4 + ] + + self._client.SetupAuth(project_id=TEST_PROJECT_ID) + self._client.Start() + self._client.subscription_complete.wait() + + breakpoint_id = 'breakpoint-0' + + input_breakpoint = { + 'id': breakpoint_id, + 'location': { + 'path': 'foo.py', + 'line': 18 + }, + 'isFinalState': True, + 'evaluatedExpressions': ['expressions go here'], + 'stackFrames': ['stuff goes here'], + 'variableTable': ['lots', 'of', 'variables'], + } + short_breakpoint = { + 'id': breakpoint_id, + 'location': { + 'path': 
'foo.py', + 'line': 18 + }, + 'isFinalState': True, + 'action': 'CAPTURE', + 'finalTimeUnixMsec': { + '.sv': 'timestamp' + } + } + full_breakpoint = { + 'id': breakpoint_id, + 'location': { + 'path': 'foo.py', + 'line': 18 + }, + 'isFinalState': True, + 'action': 'CAPTURE', + 'evaluatedExpressions': ['expressions go here'], + 'stackFrames': ['stuff goes here'], + 'variableTable': ['lots', 'of', 'variables'], + 'finalTimeUnixMsec': { + '.sv': 'timestamp' + } + } + + self._client.EnqueueBreakpointUpdate(input_breakpoint) + + # Wait for the breakpoint to be sent. Retries will have occurred. + while self._client._transmission_queue: + time.sleep(0.1) + + active_ref_mock.delete.assert_has_calls([call()] * 4) + snapshot_ref_mock.set.assert_has_calls([call(full_breakpoint)] * 3) + final_ref_mock.set.assert_has_calls([call(short_breakpoint)] * 2) + def _TestInitializeLabels(self, module_var, version_var, minor_var): self._client.SetupAuth(project_id=TEST_PROJECT_ID)