diff --git a/firebase-sample/app.py b/firebase-sample/app.py index 725c7ab..0916e7c 100644 --- a/firebase-sample/app.py +++ b/firebase-sample/app.py @@ -1,11 +1,12 @@ import googleclouddebugger -googleclouddebugger.enable(use_firebase= True) + +googleclouddebugger.enable(use_firebase=True) from flask import Flask app = Flask(__name__) + @app.route("/") def hello_world(): return "

Hello World!

" - diff --git a/src/googleclouddebugger/firebase_client.py b/src/googleclouddebugger/firebase_client.py index 0fd3fd0..be59a3c 100644 --- a/src/googleclouddebugger/firebase_client.py +++ b/src/googleclouddebugger/firebase_client.py @@ -102,6 +102,8 @@ def __init__(self): self._transmission_thread = None self._transmission_thread_startup_lock = threading.Lock() self._transmission_queue = deque(maxlen=100) + self._mark_active_timer = None + self._mark_active_interval_sec = 60 * 60 # 1 hour in seconds self._new_updates = threading.Event() self._breakpoint_subscription = None @@ -206,7 +208,8 @@ def SetupAuth(self, try: r = requests.get( f'{_METADATA_SERVER_URL}/project/project-id', - headers={'Metadata-Flavor': 'Google'}) + headers={'Metadata-Flavor': 'Google'}, + timeout=1) project_id = r.text except requests.exceptions.RequestException: native.LogInfo('Metadata server not available') @@ -246,6 +249,10 @@ def Stop(self): self._transmission_thread.join() self._transmission_thread = None + if self._mark_active_timer is not None: + self._mark_active_timer.cancel() + self._mark_active_timer = None + if self._breakpoint_subscription is not None: self._breakpoint_subscription.close() self._breakpoint_subscription = None @@ -302,6 +309,8 @@ def _MainThreadProc(self): subscription_required, delay = self._SubscribeToBreakpoints() self.subscription_complete.set() + self._StartMarkActiveTimer() + def _TransmissionThreadProc(self): """Entry point for the transmission worker thread.""" @@ -312,6 +321,22 @@ def _TransmissionThreadProc(self): self._new_updates.wait(delay) + def _MarkActiveTimerFunc(self): + """Entry point for the mark active timer.""" + + try: + self._MarkDebuggeeActive() + except: + native.LogInfo( + f'Failed to mark debuggee as active: {traceback.format_exc()}') + finally: + self._StartMarkActiveTimer() + + def _StartMarkActiveTimer(self): + self._mark_active_timer = threading.Timer(self._mark_active_interval_sec, + self._MarkActiveTimerFunc) + self._mark_active_timer.start() + def _RegisterDebuggee(self): """Single attempt to register the debuggee. @@ -334,12 +359,21 @@ def _RegisterDebuggee(self): return (True, self.register_backoff.Failed()) try: - debuggee_path = f'cdbg/debuggees/{self._debuggee_id}' - native.LogInfo( - f'registering at {self._database_url}, path: {debuggee_path}') - firebase_admin.db.reference(debuggee_path).set(debuggee) + present = self._CheckDebuggeePresence() + if present: + self._MarkDebuggeeActive() + else: + debuggee_path = f'cdbg/debuggees/{self._debuggee_id}' + native.LogInfo( + f'registering at {self._database_url}, path: {debuggee_path}') + debuggee_data = copy.deepcopy(debuggee) + debuggee_data['registrationTimeUnixMsec'] = {'.sv': 'timestamp'} + debuggee_data['lastUpdateTimeUnixMsec'] = {'.sv': 'timestamp'} + firebase_admin.db.reference(debuggee_path).set(debuggee_data) + native.LogInfo( f'Debuggee registered successfully, ID: {self._debuggee_id}') + self.register_backoff.Succeeded() return (False, 0) # Proceed immediately to subscribing to breakpoints. except BaseException: @@ -348,6 +382,26 @@ def _RegisterDebuggee(self): native.LogInfo(f'Failed to register debuggee: {traceback.format_exc()}') return (True, self.register_backoff.Failed()) + def _CheckDebuggeePresence(self): + path = f'cdbg/debuggees/{self._debuggee_id}/registrationTimeUnixMsec' + try: + snapshot = firebase_admin.db.reference(path).get() + # The value doesn't matter; just return true if there's any value. + return snapshot is not None + except BaseException: + native.LogInfo( + f'Failed to check debuggee presence: {traceback.format_exc()}') + return False + + def _MarkDebuggeeActive(self): + active_path = f'cdbg/debuggees/{self._debuggee_id}/lastUpdateTimeUnixMsec' + try: + server_time = {'.sv': 'timestamp'} + firebase_admin.db.reference(active_path).set(server_time) + except BaseException: + native.LogInfo( + f'Failed to mark debuggee active: {traceback.format_exc()}') + def _SubscribeToBreakpoints(self): # Kill any previous subscriptions first. if self._breakpoint_subscription is not None: @@ -374,7 +428,7 @@ def _ActiveBreakpointCallback(self, event): if event.path != '/': breakpoint_id = event.path[1:] # Breakpoint may have already been deleted, so pop for possible no-op. - self._breakpoints.pop(breakpoint_id, None) + self._breakpoints.pop(breakpoint_id, None) else: if event.path == '/': # New set of breakpoints. diff --git a/tests/firebase_client_test.py b/tests/firebase_client_test.py index cf60e3e..5cd8fb6 100644 --- a/tests/firebase_client_test.py +++ b/tests/firebase_client_test.py @@ -1,5 +1,6 @@ """Unit tests for firebase_client module.""" +import copy import os import sys import tempfile @@ -77,10 +78,14 @@ def setUp(self): self.addCleanup(patcher.stop) # Set up the mocks for the database refs. + self._mock_presence_ref = MagicMock() + self._mock_presence_ref.get.return_value = None + self._mock_active_ref = MagicMock() self._mock_register_ref = MagicMock() self._fake_subscribe_ref = FakeReference() self._mock_db_ref.side_effect = [ - self._mock_register_ref, self._fake_subscribe_ref + self._mock_presence_ref, self._mock_register_ref, + self._fake_subscribe_ref ] def tearDown(self): @@ -139,18 +144,46 @@ def testStart(self): self._mock_initialize_app.assert_called_with( None, {'databaseURL': f'https://{TEST_PROJECT_ID}-cdbg.firebaseio.com'}) self.assertEqual([ + call(f'cdbg/debuggees/{debuggee_id}/registrationTimeUnixMsec'), call(f'cdbg/debuggees/{debuggee_id}'), call(f'cdbg/breakpoints/{debuggee_id}/active') ], self._mock_db_ref.call_args_list) # Verify that the register call has been made. - self._mock_register_ref.set.assert_called_once_with( - self._client._GetDebuggee()) + expected_data = copy.deepcopy(self._client._GetDebuggee()) + expected_data['registrationTimeUnixMsec'] = {'.sv': 'timestamp'} + expected_data['lastUpdateTimeUnixMsec'] = {'.sv': 'timestamp'} + self._mock_register_ref.set.assert_called_once_with(expected_data) + + def testStartAlreadyPresent(self): + # Create a mock for just this test that claims the debuggee is registered. + mock_presence_ref = MagicMock() + mock_presence_ref.get.return_value = 'present!' + + self._mock_db_ref.side_effect = [ + mock_presence_ref, self._mock_active_ref, self._fake_subscribe_ref + ] + + self._client.SetupAuth(project_id=TEST_PROJECT_ID) + self._client.Start() + self._client.subscription_complete.wait() + + debuggee_id = self._client._debuggee_id + + self.assertEqual([ + call(f'cdbg/debuggees/{debuggee_id}/registrationTimeUnixMsec'), + call(f'cdbg/debuggees/{debuggee_id}/lastUpdateTimeUnixMsec'), + call(f'cdbg/breakpoints/{debuggee_id}/active') + ], self._mock_db_ref.call_args_list) + + # Verify that the register call has been made. + self._mock_active_ref.set.assert_called_once_with({'.sv': 'timestamp'}) def testStartRegisterRetry(self): - # A new db ref is fetched on each retry. + # A new set of db refs are fetched on each retry. self._mock_db_ref.side_effect = [ - self._mock_register_ref, self._mock_register_ref, + self._mock_presence_ref, self._mock_register_ref, + self._mock_presence_ref, self._mock_register_ref, self._fake_subscribe_ref ] @@ -169,6 +202,7 @@ def testStartSubscribeRetry(self): # A new db ref is fetched on each retry. self._mock_db_ref.side_effect = [ + self._mock_presence_ref, self._mock_register_ref, mock_subscribe_ref, # Fail the first time self._fake_subscribe_ref # Succeed the second time @@ -178,7 +212,28 @@ def testStartSubscribeRetry(self): self._client.Start() self._client.subscription_complete.wait() - self.assertEqual(3, self._mock_db_ref.call_count) + self.assertEqual(4, self._mock_db_ref.call_count) + + def testMarkActiveTimer(self): + # Make sure that there are enough refs queued up. + refs = list(self._mock_db_ref.side_effect) + refs.extend([self._mock_active_ref] * 10) + self._mock_db_ref.side_effect = refs + + # Speed things WAY up rather than waiting for hours. + self._client._mark_active_interval_sec = 0.1 + + self._client.SetupAuth(project_id=TEST_PROJECT_ID) + self._client.Start() + self._client.subscription_complete.wait() + + # wait long enough for the timer to trigger a few times. + time.sleep(0.5) + + print(f'Timer triggered {self._mock_active_ref.set.call_count} times') + self.assertTrue(self._mock_active_ref.set.call_count > 3) + self._mock_active_ref.set.assert_called_with({'.sv': 'timestamp'}) + def testBreakpointSubscription(self): # This class will keep track of the breakpoint updates and will check @@ -219,12 +274,10 @@ def callback(self, new_breakpoints): }, ] - expected_results = [[breakpoints[0]], - [breakpoints[0], breakpoints[1]], + expected_results = [[breakpoints[0]], [breakpoints[0], breakpoints[1]], [breakpoints[0], breakpoints[1], breakpoints[2]], [breakpoints[1], breakpoints[2]], - [breakpoints[1], breakpoints[2]] - ] + [breakpoints[1], breakpoints[2]]] result_checker = ResultChecker(expected_results, self) self._client.on_active_breakpoints_changed = result_checker.callback @@ -257,8 +310,9 @@ def testEnqueueBreakpointUpdate(self): final_ref_mock = MagicMock() self._mock_db_ref.side_effect = [ - self._mock_register_ref, self._fake_subscribe_ref, active_ref_mock, - snapshot_ref_mock, final_ref_mock + self._mock_presence_ref, self._mock_register_ref, + self._fake_subscribe_ref, active_ref_mock, snapshot_ref_mock, + final_ref_mock ] self._client.SetupAuth(project_id=TEST_PROJECT_ID) @@ -316,13 +370,13 @@ def testEnqueueBreakpointUpdate(self): db_ref_calls = self._mock_db_ref.call_args_list self.assertEqual( call(f'cdbg/breakpoints/{debuggee_id}/active/{breakpoint_id}'), - db_ref_calls[2]) + db_ref_calls[3]) self.assertEqual( call(f'cdbg/breakpoints/{debuggee_id}/snapshot/{breakpoint_id}'), - db_ref_calls[3]) + db_ref_calls[4]) self.assertEqual( call(f'cdbg/breakpoints/{debuggee_id}/final/{breakpoint_id}'), - db_ref_calls[4]) + db_ref_calls[5]) active_ref_mock.delete.assert_called_once() snapshot_ref_mock.set.assert_called_once_with(full_breakpoint) @@ -333,8 +387,8 @@ def testEnqueueBreakpointUpdateWithLogpoint(self): final_ref_mock = MagicMock() self._mock_db_ref.side_effect = [ - self._mock_register_ref, self._fake_subscribe_ref, active_ref_mock, - final_ref_mock + self._mock_presence_ref, self._mock_register_ref, + self._fake_subscribe_ref, active_ref_mock, final_ref_mock ] self._client.SetupAuth(project_id=TEST_PROJECT_ID) @@ -383,10 +437,10 @@ def testEnqueueBreakpointUpdateWithLogpoint(self): db_ref_calls = self._mock_db_ref.call_args_list self.assertEqual( call(f'cdbg/breakpoints/{debuggee_id}/active/{breakpoint_id}'), - db_ref_calls[2]) + db_ref_calls[3]) self.assertEqual( call(f'cdbg/breakpoints/{debuggee_id}/final/{breakpoint_id}'), - db_ref_calls[3]) + db_ref_calls[4]) active_ref_mock.delete.assert_called_once() final_ref_mock.set.assert_called_once_with(output_breakpoint) @@ -414,6 +468,7 @@ def testEnqueueBreakpointUpdateRetry(self): ] self._mock_db_ref.side_effect = [ + self._mock_presence_ref, self._mock_register_ref, self._fake_subscribe_ref, # setup active_ref_mock, # attempt 1