Skip to content
This repository has been archived by the owner on Jan 23, 2024. It is now read-only.

feat: add active debuggee support #64

Merged
merged 3 commits into from
Dec 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions firebase-sample/app.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import googleclouddebugger
googleclouddebugger.enable(use_firebase= True)

googleclouddebugger.enable(use_firebase=True)

from flask import Flask

app = Flask(__name__)


@app.route("/")
def hello_world():
return "<p>Hello World!</p>"

66 changes: 60 additions & 6 deletions src/googleclouddebugger/firebase_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ def __init__(self):
self._transmission_thread = None
self._transmission_thread_startup_lock = threading.Lock()
self._transmission_queue = deque(maxlen=100)
self._mark_active_timer = None
self._mark_active_interval_sec = 60 * 60 # 1 hour in seconds
self._new_updates = threading.Event()
self._breakpoint_subscription = None

Expand Down Expand Up @@ -206,7 +208,8 @@ def SetupAuth(self,
try:
r = requests.get(
f'{_METADATA_SERVER_URL}/project/project-id',
headers={'Metadata-Flavor': 'Google'})
headers={'Metadata-Flavor': 'Google'},
timeout=1)
project_id = r.text
except requests.exceptions.RequestException:
native.LogInfo('Metadata server not available')
Expand Down Expand Up @@ -246,6 +249,10 @@ def Stop(self):
self._transmission_thread.join()
self._transmission_thread = None

if self._mark_active_timer is not None:
self._mark_active_timer.cancel()
self._mark_active_timer = None

if self._breakpoint_subscription is not None:
self._breakpoint_subscription.close()
self._breakpoint_subscription = None
Expand Down Expand Up @@ -302,6 +309,8 @@ def _MainThreadProc(self):
subscription_required, delay = self._SubscribeToBreakpoints()
self.subscription_complete.set()

self._StartMarkActiveTimer()

def _TransmissionThreadProc(self):
"""Entry point for the transmission worker thread."""

Expand All @@ -312,6 +321,22 @@ def _TransmissionThreadProc(self):

self._new_updates.wait(delay)

def _MarkActiveTimerFunc(self):
"""Entry point for the mark active timer."""

try:
self._MarkDebuggeeActive()
except:
native.LogInfo(
f'Failed to mark debuggee as active: {traceback.format_exc()}')
finally:
self._StartMarkActiveTimer()
mctavish marked this conversation as resolved.
Show resolved Hide resolved

def _StartMarkActiveTimer(self):
self._mark_active_timer = threading.Timer(self._mark_active_interval_sec,
self._MarkActiveTimerFunc)
self._mark_active_timer.start()

def _RegisterDebuggee(self):
"""Single attempt to register the debuggee.

Expand All @@ -334,12 +359,21 @@ def _RegisterDebuggee(self):
return (True, self.register_backoff.Failed())

try:
debuggee_path = f'cdbg/debuggees/{self._debuggee_id}'
native.LogInfo(
f'registering at {self._database_url}, path: {debuggee_path}')
firebase_admin.db.reference(debuggee_path).set(debuggee)
present = self._CheckDebuggeePresence()
if present:
self._MarkDebuggeeActive()
else:
debuggee_path = f'cdbg/debuggees/{self._debuggee_id}'
native.LogInfo(
f'registering at {self._database_url}, path: {debuggee_path}')
debuggee_data = copy.deepcopy(debuggee)
debuggee_data['registrationTimeUnixMsec'] = {'.sv': 'timestamp'}
debuggee_data['lastUpdateTimeUnixMsec'] = {'.sv': 'timestamp'}
firebase_admin.db.reference(debuggee_path).set(debuggee_data)

native.LogInfo(
f'Debuggee registered successfully, ID: {self._debuggee_id}')

self.register_backoff.Succeeded()
return (False, 0) # Proceed immediately to subscribing to breakpoints.
except BaseException:
Expand All @@ -348,6 +382,26 @@ def _RegisterDebuggee(self):
native.LogInfo(f'Failed to register debuggee: {traceback.format_exc()}')
return (True, self.register_backoff.Failed())

def _CheckDebuggeePresence(self):
path = f'cdbg/debuggees/{self._debuggee_id}/registrationTimeUnixMsec'
try:
snapshot = firebase_admin.db.reference(path).get()
# The value doesn't matter; just return true if there's any value.
return snapshot is not None
except BaseException:
native.LogInfo(
f'Failed to check debuggee presence: {traceback.format_exc()}')
return False

def _MarkDebuggeeActive(self):
active_path = f'cdbg/debuggees/{self._debuggee_id}/lastUpdateTimeUnixMsec'
try:
server_time = {'.sv': 'timestamp'}
firebase_admin.db.reference(active_path).set(server_time)
except BaseException:
native.LogInfo(
f'Failed to mark debuggee active: {traceback.format_exc()}')

def _SubscribeToBreakpoints(self):
# Kill any previous subscriptions first.
if self._breakpoint_subscription is not None:
Expand All @@ -374,7 +428,7 @@ def _ActiveBreakpointCallback(self, event):
if event.path != '/':
breakpoint_id = event.path[1:]
# Breakpoint may have already been deleted, so pop for possible no-op.
self._breakpoints.pop(breakpoint_id, None)
self._breakpoints.pop(breakpoint_id, None)
else:
if event.path == '/':
# New set of breakpoints.
Expand Down
93 changes: 74 additions & 19 deletions tests/firebase_client_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Unit tests for firebase_client module."""

import copy
import os
import sys
import tempfile
Expand Down Expand Up @@ -77,10 +78,14 @@ def setUp(self):
self.addCleanup(patcher.stop)

# Set up the mocks for the database refs.
self._mock_presence_ref = MagicMock()
self._mock_presence_ref.get.return_value = None
self._mock_active_ref = MagicMock()
self._mock_register_ref = MagicMock()
self._fake_subscribe_ref = FakeReference()
self._mock_db_ref.side_effect = [
self._mock_register_ref, self._fake_subscribe_ref
self._mock_presence_ref, self._mock_register_ref,
self._fake_subscribe_ref
]

def tearDown(self):
Expand Down Expand Up @@ -139,18 +144,46 @@ def testStart(self):
self._mock_initialize_app.assert_called_with(
None, {'databaseURL': f'https://{TEST_PROJECT_ID}-cdbg.firebaseio.com'})
self.assertEqual([
call(f'cdbg/debuggees/{debuggee_id}/registrationTimeUnixMsec'),
call(f'cdbg/debuggees/{debuggee_id}'),
call(f'cdbg/breakpoints/{debuggee_id}/active')
], self._mock_db_ref.call_args_list)

# Verify that the register call has been made.
self._mock_register_ref.set.assert_called_once_with(
self._client._GetDebuggee())
expected_data = copy.deepcopy(self._client._GetDebuggee())
expected_data['registrationTimeUnixMsec'] = {'.sv': 'timestamp'}
expected_data['lastUpdateTimeUnixMsec'] = {'.sv': 'timestamp'}
self._mock_register_ref.set.assert_called_once_with(expected_data)

def testStartAlreadyPresent(self):
# Create a mock for just this test that claims the debuggee is registered.
mock_presence_ref = MagicMock()
mock_presence_ref.get.return_value = 'present!'

self._mock_db_ref.side_effect = [
mock_presence_ref, self._mock_active_ref, self._fake_subscribe_ref
]

self._client.SetupAuth(project_id=TEST_PROJECT_ID)
self._client.Start()
self._client.subscription_complete.wait()

debuggee_id = self._client._debuggee_id

self.assertEqual([
call(f'cdbg/debuggees/{debuggee_id}/registrationTimeUnixMsec'),
call(f'cdbg/debuggees/{debuggee_id}/lastUpdateTimeUnixMsec'),
call(f'cdbg/breakpoints/{debuggee_id}/active')
], self._mock_db_ref.call_args_list)

# Verify that the register call has been made.
self._mock_active_ref.set.assert_called_once_with({'.sv': 'timestamp'})

def testStartRegisterRetry(self):
# A new db ref is fetched on each retry.
# A new set of db refs are fetched on each retry.
self._mock_db_ref.side_effect = [
self._mock_register_ref, self._mock_register_ref,
self._mock_presence_ref, self._mock_register_ref,
self._mock_presence_ref, self._mock_register_ref,
self._fake_subscribe_ref
]

Expand All @@ -169,6 +202,7 @@ def testStartSubscribeRetry(self):

# A new db ref is fetched on each retry.
self._mock_db_ref.side_effect = [
self._mock_presence_ref,
self._mock_register_ref,
mock_subscribe_ref, # Fail the first time
self._fake_subscribe_ref # Succeed the second time
Expand All @@ -178,7 +212,28 @@ def testStartSubscribeRetry(self):
self._client.Start()
self._client.subscription_complete.wait()

self.assertEqual(3, self._mock_db_ref.call_count)
self.assertEqual(4, self._mock_db_ref.call_count)

def testMarkActiveTimer(self):
# Make sure that there are enough refs queued up.
refs = list(self._mock_db_ref.side_effect)
refs.extend([self._mock_active_ref] * 10)
self._mock_db_ref.side_effect = refs

# Speed things WAY up rather than waiting for hours.
self._client._mark_active_interval_sec = 0.1

self._client.SetupAuth(project_id=TEST_PROJECT_ID)
self._client.Start()
self._client.subscription_complete.wait()

# wait long enough for the timer to trigger a few times.
time.sleep(0.5)

print(f'Timer triggered {self._mock_active_ref.set.call_count} times')
self.assertTrue(self._mock_active_ref.set.call_count > 3)
self._mock_active_ref.set.assert_called_with({'.sv': 'timestamp'})


def testBreakpointSubscription(self):
# This class will keep track of the breakpoint updates and will check
Expand Down Expand Up @@ -219,12 +274,10 @@ def callback(self, new_breakpoints):
},
]

expected_results = [[breakpoints[0]],
[breakpoints[0], breakpoints[1]],
expected_results = [[breakpoints[0]], [breakpoints[0], breakpoints[1]],
[breakpoints[0], breakpoints[1], breakpoints[2]],
[breakpoints[1], breakpoints[2]],
[breakpoints[1], breakpoints[2]]
]
[breakpoints[1], breakpoints[2]]]
result_checker = ResultChecker(expected_results, self)

self._client.on_active_breakpoints_changed = result_checker.callback
Expand Down Expand Up @@ -257,8 +310,9 @@ def testEnqueueBreakpointUpdate(self):
final_ref_mock = MagicMock()

self._mock_db_ref.side_effect = [
self._mock_register_ref, self._fake_subscribe_ref, active_ref_mock,
snapshot_ref_mock, final_ref_mock
self._mock_presence_ref, self._mock_register_ref,
self._fake_subscribe_ref, active_ref_mock, snapshot_ref_mock,
final_ref_mock
]

self._client.SetupAuth(project_id=TEST_PROJECT_ID)
Expand Down Expand Up @@ -316,13 +370,13 @@ def testEnqueueBreakpointUpdate(self):
db_ref_calls = self._mock_db_ref.call_args_list
self.assertEqual(
call(f'cdbg/breakpoints/{debuggee_id}/active/{breakpoint_id}'),
db_ref_calls[2])
db_ref_calls[3])
self.assertEqual(
call(f'cdbg/breakpoints/{debuggee_id}/snapshot/{breakpoint_id}'),
db_ref_calls[3])
db_ref_calls[4])
self.assertEqual(
call(f'cdbg/breakpoints/{debuggee_id}/final/{breakpoint_id}'),
db_ref_calls[4])
db_ref_calls[5])

active_ref_mock.delete.assert_called_once()
snapshot_ref_mock.set.assert_called_once_with(full_breakpoint)
Expand All @@ -333,8 +387,8 @@ def testEnqueueBreakpointUpdateWithLogpoint(self):
final_ref_mock = MagicMock()

self._mock_db_ref.side_effect = [
self._mock_register_ref, self._fake_subscribe_ref, active_ref_mock,
final_ref_mock
self._mock_presence_ref, self._mock_register_ref,
self._fake_subscribe_ref, active_ref_mock, final_ref_mock
]

self._client.SetupAuth(project_id=TEST_PROJECT_ID)
Expand Down Expand Up @@ -383,10 +437,10 @@ def testEnqueueBreakpointUpdateWithLogpoint(self):
db_ref_calls = self._mock_db_ref.call_args_list
self.assertEqual(
call(f'cdbg/breakpoints/{debuggee_id}/active/{breakpoint_id}'),
db_ref_calls[2])
db_ref_calls[3])
self.assertEqual(
call(f'cdbg/breakpoints/{debuggee_id}/final/{breakpoint_id}'),
db_ref_calls[3])
db_ref_calls[4])

active_ref_mock.delete.assert_called_once()
final_ref_mock.set.assert_called_once_with(output_breakpoint)
Expand Down Expand Up @@ -414,6 +468,7 @@ def testEnqueueBreakpointUpdateRetry(self):
]

self._mock_db_ref.side_effect = [
self._mock_presence_ref,
self._mock_register_ref,
self._fake_subscribe_ref, # setup
active_ref_mock, # attempt 1
Expand Down