diff --git a/common/experiment_utils.py b/common/experiment_utils.py
index 3911751ac..91e1639b9 100644
--- a/common/experiment_utils.py
+++ b/common/experiment_utils.py
@@ -104,7 +104,8 @@ def get_oss_fuzz_corpora_unarchived_path():
 
 
 def get_random_corpora_filestore_path():
-    """Returns path containing seed corpora for the target fuzzing experiment."""  # pylint: disable=line-too-long
+    """Returns path containing seed corpora for the target fuzzing
+    experiment."""
     return posixpath.join(get_experiment_filestore_path(), 'random_corpora')
 
 
diff --git a/database/models.py b/database/models.py
index 02bfff7a9..49c745b86 100644
--- a/database/models.py
+++ b/database/models.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """SQLAlchemy Database Models."""
+import json
 import sqlalchemy
 from sqlalchemy.ext import declarative
 from sqlalchemy import Boolean
@@ -79,6 +80,20 @@ class Snapshot(Base):
         primaryjoin=
         'and_(Snapshot.time==Crash.time, Snapshot.trial_id==Crash.trial_id)')
 
+    def as_dict(self):
+        """Transforms the object into a dictionary. This helper is needed
+        because __dict__ also returns _sa_instance_state, an internal
+        SQLAlchemy attribute that is not serializable."""
+        return {
+            column.name: getattr(self, column.name)
+            for column in self.__table__.columns
+        }
+
+    def to_bytes(self):
+        """Serializes the object to bytes, which is helpful for publishing
+        a snapshot object to a Pub/Sub queue."""
+        return json.dumps(self.as_dict()).encode('utf-8')
+
 
 class Crash(Base):
     """Represents crashes found in experiments."""
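For context (this note is not part of the patch): these two helpers define the serialization contract the Pub/Sub measurer path below relies on, since `to_bytes()` emits JSON that the manager later feeds back into the model constructor. A minimal round-trip sketch, assuming all column values are JSON-serializable (a populated `DateTime` column such as `time` would need a custom encoder):

```python
import json

from database import models

# Worker side: serialize the measured snapshot for publishing.
snapshot = models.Snapshot(trial_id=1)
payload = snapshot.to_bytes()  # e.g. b'{"time": null, "trial_id": 1, ...}'

# Manager side: rebuild the model from the pulled message data.
restored = models.Snapshot(**json.loads(payload))
assert restored.trial_id == snapshot.trial_id
```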
+"""Tests for methods under models.py.""" +import pytest + +from database import models + + +@pytest.fixture() +def snapshot(): + """Simple pytest fixture to return a model snapshot.""" + return models.Snapshot(trial_id=1) + + +def assert_dicts_equal_ignoring_order(dict1, dict2): + """Helping function to check if two dictionaries have the same keys, and + same values for each key, ignoring the keys order.""" + assert set(dict1.keys()) == set(dict2.keys()) + for key in dict1: + assert dict1[key] == dict2[key] + + +def test_snapshot_to_bytes(snapshot): # pylint: disable=redefined-outer-name + """Tests if a snapshot model is being successfully converted to bytes + format.""" + snapshot_as_bytes = snapshot.to_bytes() + assert isinstance(snapshot_as_bytes, bytes) + + +def test_snapshot_as_dict(snapshot): # pylint: disable=redefined-outer-name + """Tests if a snapshot model is being successfully converted to a + dictionary.""" + snapshot_as_dict = snapshot.as_dict() + expected_dict = { + 'edges_covered': None, + 'fuzzer_stats': None, + 'time': None, + 'trial_id': 1 + } + assert_dicts_equal_ignoring_order(snapshot_as_dict, expected_dict) diff --git a/experiment/measurer/datatypes.py b/experiment/measurer/datatypes.py index 21415b336..28deed0cc 100644 --- a/experiment/measurer/datatypes.py +++ b/experiment/measurer/datatypes.py @@ -11,11 +11,38 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Module for common data types shared under the measurer module.""" +"""Module for common data types and helping functions shared under the measurer +module.""" import collections +import json SnapshotMeasureRequest = collections.namedtuple( 'SnapshotMeasureRequest', ['fuzzer', 'benchmark', 'trial_id', 'cycle']) RetryRequest = collections.namedtuple( 'RetryRequest', ['fuzzer', 'benchmark', 'trial_id', 'cycle']) + + +def from_dict_to_snapshot_retry_request(values: dict): + """Converts a dict into a RetryRequest named tuple.""" + return RetryRequest(values['fuzzer'], values['benchmark'], + values['trial_id'], values['cycle']) + + +def from_dict_to_snapshot_measure_request(values: dict): + """Converts a dict into a SnapshotMeasureRequest named tuple.""" + return SnapshotMeasureRequest(values['fuzzer'], values['benchmark'], + values['trial_id'], values['cycle']) + + +def from_snapshot_measure_request_to_bytes( + snapshot_measure_request: SnapshotMeasureRequest) -> bytes: + """Takes a snapshot measure request and transform it into bytes, so + it can be published in a pub sub queue.""" + return json.dumps(snapshot_measure_request._asdict()).encode('utf-8') + + +def from_retry_request_to_bytes(retry_request: RetryRequest) -> bytes: + """Takes a snapshot retry request and transform it into bytes, so + it can be published in a pub sub queue.""" + return json.dumps(retry_request._asdict()).encode('utf-8') diff --git a/experiment/measurer/measure_manager.py b/experiment/measurer/measure_manager.py index 288148401..9742de181 100644 --- a/experiment/measurer/measure_manager.py +++ b/experiment/measurer/measure_manager.py @@ -11,12 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+# pylint: disable=too-many-lines """Module for measuring snapshots from trial runners.""" import collections import gc import glob import multiprocessing +import multiprocessing.queues import json import os import pathlib @@ -25,13 +27,15 @@ import tempfile import tarfile import time -from typing import List +from typing import List, Optional, Tuple import queue import psutil +import google.api_core.exceptions + from sqlalchemy import func from sqlalchemy import orm - +from google.cloud import pubsub_v1 from common import benchmark_utils from common import experiment_utils from common import experiment_path as exp_path @@ -75,9 +79,14 @@ def measure_main(experiment_config): experiment = experiment_config['experiment'] max_total_time = experiment_config['max_total_time'] measurers_cpus = experiment_config['measurers_cpus'] - region_coverage = experiment_config['region_coverage'] - measure_manager_loop(experiment, max_total_time, measurers_cpus, - region_coverage) + region_coverage = experiment_config.get('region_coverage', False) + cloud_project = experiment_config.get('cloud_project', '') + local_experiment = experiment_config.get('local_experiment', False) + + measure_manager = get_measure_manager(local_experiment, experiment, + cloud_project, measurers_cpus, + region_coverage) + measure_manager.measure_manager_loop(max_total_time) # Clean up resources. gc.collect() @@ -672,72 +681,6 @@ def initialize_logs(): }) -def consume_snapshots_from_response_queue( - response_queue, queued_snapshots) -> List[models.Snapshot]: - """Consume response_queue, allows retry objects to retried, and - return all measured snapshots in a list.""" - measured_snapshots = [] - while True: - try: - response_object = response_queue.get_nowait() - if isinstance(response_object, measurer_datatypes.RetryRequest): - # Need to retry measurement task, will remove identifier from - # the set so task can be retried in next loop iteration. - snapshot_identifier = (response_object.trial_id, - response_object.cycle) - queued_snapshots.remove(snapshot_identifier) - logger.info('Reescheduling task for trial %s and cycle %s', - response_object.trial_id, response_object.cycle) - elif isinstance(response_object, models.Snapshot): - measured_snapshots.append(response_object) - else: - logger.error('Type of response object not mapped! %s', - type(response_object)) - except queue.Empty: - break - return measured_snapshots - - -def measure_manager_inner_loop(experiment: str, max_cycle: int, request_queue, - response_queue, queued_snapshots): - """Reads from database to determine which snapshots needs measuring. Write - measurements tasks to request queue, get results from response queue, and - write measured snapshots to database. Returns False if there's no more - snapshots left to be measured""" - initialize_logs() - # Read database to determine which snapshots needs measuring. - unmeasured_snapshots = get_unmeasured_snapshots(experiment, max_cycle) - logger.info('Retrieved %d unmeasured snapshots from measure manager', - len(unmeasured_snapshots)) - # When there are no more snapshots left to be measured, should break loop. - if not unmeasured_snapshots: - return False - - # Write measurements requests to request queue - for unmeasured_snapshot in unmeasured_snapshots: - # No need to insert fuzzer and benchmark info here as it's redundant - # (Can be retrieved through trial_id). 
-        unmeasured_snapshot_identifier = (unmeasured_snapshot.trial_id,
-                                          unmeasured_snapshot.cycle)
-        # Checking if snapshot already was queued so workers will not repeat
-        # measurement for same snapshot
-        if unmeasured_snapshot_identifier not in queued_snapshots:
-            request_queue.put(unmeasured_snapshot)
-            queued_snapshots.add(unmeasured_snapshot_identifier)
-
-    # Read results from response queue.
-    measured_snapshots = consume_snapshots_from_response_queue(
-        response_queue, queued_snapshots)
-    logger.info('Retrieved %d measured snapshots from response queue',
-                len(measured_snapshots))
-
-    # Save measured snapshots to database.
-    if measured_snapshots:
-        db_utils.add_all(measured_snapshots)
-
-    return True
-
-
 def get_pool_args(measurers_cpus, runners_cpus):
     """Return pool args based on measurer cpus and runner cpus arguments."""
     if measurers_cpus is None or runners_cpus is None:
@@ -755,49 +698,341 @@ def get_pool_args(measurers_cpus, runners_cpus):
     return (measurers_cpus, _process_init, (cores_queue,))
 
 
-def measure_manager_loop(experiment: str,
-                         max_total_time: int,
-                         measurers_cpus=None,
-                         region_coverage=False):  # pylint: disable=too-many-locals
-    """Measure manager loop. Creates request and response queues, request
-    measurements tasks from workers, retrieve measurement results from response
-    queue and writes measured snapshots in database."""
-    logger.info('Starting measure manager loop.')
-    if not measurers_cpus:
-        measurers_cpus = multiprocessing.cpu_count()
-        logger.info('Number of measurer CPUs not passed as argument. using %d',
-                    measurers_cpus)
-    with multiprocessing.Pool() as pool, multiprocessing.Manager() as manager:
-        logger.info('Setting up coverage binaries')
-        set_up_coverage_binaries(pool, experiment)
-        request_queue = manager.Queue()
-        response_queue = manager.Queue()
+class BaseMeasureManager:
+    """Base class for measure managers. Encapsulates core methods that are
+    implemented by the Local and Google Cloud measure managers."""
+
+    def __init__(self, experiment: str, region_coverage: bool):
+        self.region_coverage = region_coverage
+        self.experiment = experiment
+        self.measurers_cpus = None
+
+    def initialize_queues(self, manager):
+        """Initializes and returns the request and response queues."""
+        raise NotImplementedError
+
+    def start_workers(self, request_queue, response_queue, pool):
+        """Initializes measure workers."""
+        raise NotImplementedError
+
+    def put_task_in_request_queue(self, task, request_queue):
+        """Puts a task in the request queue. The request queue can be a
+        Pub/Sub queue or a multiprocessing in-memory queue, depending on the
+        implementation."""
+        raise NotImplementedError
+
+    def get_result_from_response_queue(self, response_queue):
+        """Gets a result from the response queue. Can be a Pub/Sub queue or a
+        multiprocessing in-memory queue, depending on the implementation."""
+        raise NotImplementedError
+
+    def consume_snapshots_from_response_queue(
+            self, response_queue, queued_snapshots) -> List[models.Snapshot]:
+        """Consumes the response_queue, allows retry objects to be retried,
+        and returns all measured snapshots in a list."""
+        measured_snapshots = []
+        while True:
+            try:
+                response_object = self.get_result_from_response_queue(
+                    response_queue)
+                if isinstance(response_object, measurer_datatypes.RetryRequest):
+                    # Need to retry the measurement task, so remove the
+                    # identifier from the set to allow the task to be retried
+                    # in the next loop iteration.
+                    snapshot_identifier = (response_object.trial_id,
+                                           response_object.cycle)
+                    queued_snapshots.remove(snapshot_identifier)
+                    logger.info('Rescheduling task for trial %s and cycle %s',
+                                response_object.trial_id, response_object.cycle)
+                elif isinstance(response_object, models.Snapshot):
+                    measured_snapshots.append(response_object)
+                elif response_object is None:
+                    logger.info('Result is None; response queue is empty.')
+                    raise queue.Empty()
+                else:
+                    logger.error('Type of response object not mapped! %s',
+                                 type(response_object))
+            except queue.Empty:
+                break
+        return measured_snapshots
+
+    def measure_manager_inner_loop(self, max_cycle: int, request_queue,
+                                   response_queue, queued_snapshots):
+        """Reads the database to determine which snapshots need measuring,
+        writes measurement tasks to the request queue, gets results from the
+        response queue, and writes measured snapshots to the database. Returns
+        False if there are no snapshots left to be measured."""
+        initialize_logs()
+        # Read database to determine which snapshots need measuring.
+        unmeasured_snapshots = get_unmeasured_snapshots(self.experiment,
+                                                        max_cycle)
+        logger.info('Retrieved %d unmeasured snapshots from measure manager.',
+                    len(unmeasured_snapshots))
+        # When there are no more snapshots left to be measured, break the
+        # loop.
+        if not unmeasured_snapshots:
+            return False
+
+        # Write measurement requests to the request queue.
+        for unmeasured_snapshot in unmeasured_snapshots:
+            # No need to insert fuzzer and benchmark info here as it's redundant
+            # (it can be retrieved through trial_id).
+            unmeasured_snapshot_identifier = (unmeasured_snapshot.trial_id,
+                                              unmeasured_snapshot.cycle)
+            # Check if the snapshot was already queued so workers will not
+            # repeat the measurement for the same snapshot.
+            if unmeasured_snapshot_identifier not in queued_snapshots:
+                self.put_task_in_request_queue(unmeasured_snapshot,
+                                               request_queue)
+                queued_snapshots.add(unmeasured_snapshot_identifier)
+
+        # Read results from response queue.
+        measured_snapshots = self.consume_snapshots_from_response_queue(
+            response_queue, queued_snapshots)
+        logger.info('Retrieved %d measured snapshots from response queue.',
+                    len(measured_snapshots))
+
+        # Save measured snapshots to database.
+        if measured_snapshots:
+            db_utils.add_all(measured_snapshots)
+
+        return True
+
+    def measure_manager_loop(self, max_total_time: int):
+        """Measure manager loop. Creates request and response queues, requests
+        measurement tasks from workers, retrieves measurement results from the
+        response queue, and writes measured snapshots to the database."""
+        logger.info('Starting measure manager loop.')
+        if not self.measurers_cpus:
+            self.measurers_cpus = multiprocessing.cpu_count()
+            logger.info(
+                'Number of measurer CPUs not passed as argument; using %d.',
+                self.measurers_cpus)
+        with multiprocessing.Pool(processes=self.measurers_cpus) as pool, multiprocessing.Manager() as manager:  # pylint: disable=line-too-long
+            set_up_coverage_binaries(pool, self.experiment)
+            (request_queue, response_queue) = self.initialize_queues(manager)
+            self.start_workers(request_queue, response_queue, pool)
+            max_cycle = _time_to_cycle(max_total_time)
+            queued_snapshots = set()
+            while not scheduler.all_trials_ended(self.experiment):
+                continue_inner_loop = self.measure_manager_inner_loop(
+                    max_cycle, request_queue, response_queue, queued_snapshots)
+                if not continue_inner_loop:
+                    break
+                time.sleep(MEASUREMENT_LOOP_WAIT)
+            logger.info('All trials ended. Ending measure manager loop.')
+
+
+class LocalMeasureManager(BaseMeasureManager):
+    """Class that holds implementations of core methods for running a
+    measure manager locally."""
+
+    def __init__(self,
+                 experiment: str,
+                 region_coverage: bool,
+                 measurers_cpus: Optional[int] = None):
+        super().__init__(experiment, region_coverage)
+        self.measurers_cpus = measurers_cpus
+
+    def initialize_queues(
+            self, manager
+    ) -> Tuple[multiprocessing.queues.Queue, multiprocessing.queues.Queue]:
+        return (manager.Queue(), manager.Queue())
+
+    def start_workers(self, request_queue: multiprocessing.queues.Queue,
+                      response_queue: multiprocessing.queues.Queue, pool):
         config = {
             'request_queue': request_queue,
             'response_queue': response_queue,
-            'region_coverage': region_coverage,
+            'region_coverage': self.region_coverage,
         }
         local_measure_worker = measure_worker.LocalMeasureWorker(config)
         # Since each worker is going to be in an infinite loop, we dont need
-        # result return. Workers' life scope will end automatically when there
-        # are no more snapshots left to measure.
-        logger.info('Starting measure worker loop for %d workers',
-                    measurers_cpus)
-        for _ in range(measurers_cpus):
-            _result = pool.apply_async(local_measure_worker.measure_worker_loop)
-
-        max_cycle = _time_to_cycle(max_total_time)
-        queued_snapshots = set()
-        while not scheduler.all_trials_ended(experiment):
-            continue_inner_loop = measure_manager_inner_loop(
-                experiment, max_cycle, request_queue, response_queue,
-                queued_snapshots)
-            if not continue_inner_loop:
-                break
-            time.sleep(MEASUREMENT_LOOP_WAIT)
-        logger.info('All trials ended. Ending measure manager loop')
+        # result return. Workers' life scope will end automatically when
+        # there are no more snapshots left to measure.
+        log_message = ('Starting measure worker loop for '
+                       f'{self.measurers_cpus}'
+                       ' workers in local measure manager')
+        logger.info(log_message)
+        for _ in range(self.measurers_cpus):
+            pool.apply_async(local_measure_worker.measure_worker_loop)
+
+    def get_result_from_response_queue(
+            self, response_queue: multiprocessing.queues.Queue):
+        return response_queue.get_nowait()
+
+    def put_task_in_request_queue(
+            self, task: measurer_datatypes.SnapshotMeasureRequest,
+            request_queue: multiprocessing.queues.Queue):
+        request_queue.put_nowait(task)
+
+
+def _start_gcloud_measure_worker(config):
+    """Instantiates a Google Cloud measure worker and runs its loop. Defined
+    at module level because pool task targets must be picklable."""
+    google_cloud_worker = measure_worker.GoogleCloudMeasureWorker(config)
+    google_cloud_worker.measure_worker_loop()
+
+
+class GoogleCloudMeasureManager(BaseMeasureManager):  # pylint: disable=too-many-instance-attributes
+    """Measure manager implementation that subscribes and publishes to a
+    Google Cloud Pub/Sub queue, instead of a multiprocessing queue."""
+
+    def __init__(self,
+                 experiment: str,
+                 cloud_project: str,
+                 region_coverage: bool,
+                 measurers_cpus: Optional[int] = None):
+        super().__init__(experiment, region_coverage)
+        self.subscriber_client = pubsub_v1.SubscriberClient()
+        self.publisher_client = pubsub_v1.PublisherClient(
+            publisher_options=pubsub_v1.types.PublisherOptions(
+                enable_message_ordering=True))
+        self.project_id = cloud_project
+        self.request_queue_topic_id = f'request-queue-topic-{self.experiment}'
+        self.request_queue_topic_path = self.publisher_client.topic_path(
+            self.project_id, self.request_queue_topic_id)
+        self.response_queue_topic_id = f'response-queue-topic-{self.experiment}'
+        self.response_queue_topic_path = self.subscriber_client.topic_path(
+            self.project_id, self.response_queue_topic_id)
+        self.response_queue_subscription_id = ('response-queue-subscription-'
+                                               f'{self.experiment}')
+        self.subscription_path = self.subscriber_client.subscription_path(
+            self.project_id, self.response_queue_subscription_id)
+        self.measurers_cpus = measurers_cpus
+
+    def initialize_queues(self, manager) -> Tuple[Optional[str], Optional[str]]:
+        try:
+            request_queue_topic = self.publisher_client.create_topic(
+                request={'name': self.request_queue_topic_path})
+            logger.info('Request queue topic created successfully: %s',
+                        request_queue_topic.name)
+
+            response_queue_topic = self.publisher_client.create_topic(
+                request={'name': self.response_queue_topic_path})
+            logger.info('Response queue topic created successfully: %s',
+                        response_queue_topic.name)
+
+            return request_queue_topic.name, response_queue_topic.name
+        except google.api_core.exceptions.GoogleAPICallError as error:
+            logger.error('Error while initializing queues: %s', error)
+            return None, None
+
+    def _create_response_queue_subscription(self):
+        """Creates a new Pub/Sub subscription for the response queue."""
+        try:
+            subscription = pubsub_v1.SubscriberClient().create_subscription(
+                request={
+                    'name': self.subscription_path,
+                    'topic': self.response_queue_topic_path
+                })
+            logger.info('Subscription %s created successfully.',
+                        subscription.name)
+            return subscription.name
+        except google.api_core.exceptions.GoogleAPICallError as error:
+            logger.error('Error creating subscription: %s', error)
+            return None
+
+    def start_workers(self, request_queue, response_queue, pool):
+        self._create_response_queue_subscription()
+
+        # Since each worker is going to be in an infinite loop, we don't need
+        # result return. Workers' life scope will end automatically when
+        # there are no more snapshots left to measure.
+        log_message = ('Starting measure worker loop for '
+                       f'{self.measurers_cpus}'
+                       ' workers in Google Cloud measure manager')
+        logger.info(log_message)
+
+        config = {
+            'request_queue_topic_id': self.request_queue_topic_id,
+            'response_queue_topic_id': self.response_queue_topic_id,
+            'region_coverage': self.region_coverage,
+            'project_id': self.project_id,
+            'experiment': self.experiment,
+        }
+
+        # Create the worker request queue subscription once, before starting
+        # all workers.
+        worker_request_queue_subscription = ('request-queue-subscription-'
+                                             f'{self.experiment}')
+        worker_subscription_path = self.subscriber_client.subscription_path(
+            self.project_id, worker_request_queue_subscription)
+        worker_request_queue_topic_path = self.subscriber_client.topic_path(
+            self.project_id, self.request_queue_topic_id)
+        measure_worker.GoogleCloudMeasureWorker.create_request_queue_subscription(  # pylint: disable=line-too-long
+            worker_subscription_path, worker_request_queue_topic_path)
+
+        for _ in range(self.measurers_cpus):
+            # The task target is a module-level helper because locally defined
+            # closures cannot be pickled and sent to pool workers.
+            pool.apply_async(_start_gcloud_measure_worker, (config,))
+
+    def get_result_from_response_queue(self, response_queue: str):
+        try:
+            response = self.subscriber_client.pull(request={
+                'subscription': self.subscription_path,
+                'max_messages': 1
+            })
+        except google.api_core.exceptions.GoogleAPICallError as error:
+            logger.error('Error when calling pubsub API: %s', error)
+            return None
+
+        if not response.received_messages:
+            return None
+
+        message = response.received_messages[0]
+        ack_ids = [message.ack_id]
+
+        # Acknowledge the received message to remove it from the queue.
+        self.subscriber_client.acknowledge(request={
+            'subscription': self.subscription_path,
+            'ack_ids': ack_ids
+        })
+
+        serialized_result = message.message.data
+        deserialized_result = json.loads(serialized_result)
+        # Message attributes are always strings, hence the 'True' comparison.
+        if message.message.attributes.get('is_retry') == 'True':
+            return measurer_datatypes.from_dict_to_snapshot_retry_request(
+                deserialized_result)
+        return models.Snapshot(**deserialized_result)
+
+    def put_task_in_request_queue(
+            self, task: measurer_datatypes.SnapshotMeasureRequest,
+            request_queue: str):
+        try:
+            # Convert the task to bytes.
+            message_as_bytes = measurer_datatypes.from_snapshot_measure_request_to_bytes(  # pylint: disable=line-too-long
+                task)
+            # Publish the message to the request queue topic.
+            future = self.publisher_client.publish(topic=request_queue,
+                                                   data=message_as_bytes,
+                                                   ordering_key=str(task.cycle))
+            message_id = future.result()  # Get the published message ID.
+            logger.info(
+                'Manager successfully published task with message ID %s to %s.',
+                message_id, request_queue)
+        except Exception as error:  # pylint: disable=broad-except
+            logger.error(
+                'An error occurred when publishing task to request queue: %s.',
+                error)
+
+
+def get_measure_manager(local_experiment: bool, experiment: str,
+                        cloud_project: str, measurers_cpus: Optional[int],
+                        region_coverage: bool):
+    """Returns a measure manager object created from the right class, depending
+    on whether the experiment is local or not (i.e. a measure manager factory)."""
+    if local_experiment:
+        logger.info('local_experiment is True, using local measure manager.')
+        return LocalMeasureManager(experiment, region_coverage, measurers_cpus)
+    logger.info(
+        'local_experiment is False, using Google Cloud measure manager.')
+    return GoogleCloudMeasureManager(experiment, cloud_project, region_coverage,
+                                     measurers_cpus)
 
 
 def main():
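A note on the protocol that ties the manager above to the workers below (this sketch is not part of the patch): the worker answers each task with either a serialized `Snapshot` or a serialized `RetryRequest`, distinguished by the string-valued `is_retry` message attribute. With the Pub/Sub transport replaced by plain variables:

```python
import json

import experiment.measurer.datatypes as measurer_datatypes
from database import models

# Worker side: no corpus was found for this cycle, so ask for a retry.
retry_request = measurer_datatypes.RetryRequest('fuzzer', 'benchmark', 1, 0)
data = measurer_datatypes.from_retry_request_to_bytes(retry_request)
attributes = {'is_retry': 'True'}  # Pub/Sub attributes must be strings.

# Manager side: dispatch on the attribute, mirroring
# GoogleCloudMeasureManager.get_result_from_response_queue().
if attributes.get('is_retry') == 'True':
    result = measurer_datatypes.from_dict_to_snapshot_retry_request(
        json.loads(data))
else:
    result = models.Snapshot(**json.loads(data))
assert isinstance(result, measurer_datatypes.RetryRequest)
```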
"""Module for measurer workers logic.""" import time -from typing import Dict, Optional +import json +import os +from typing import Dict, Union, Optional +import google.api_core.exceptions +from google.cloud import pubsub_v1 +import google.api from common import logs from database.models import Snapshot import experiment.measurer.datatypes as measurer_datatypes @@ -28,39 +33,66 @@ class BaseMeasureWorker: implemented for Local and Google Cloud measure workers.""" def __init__(self, config: Dict): - self.request_queue = config['request_queue'] - self.response_queue = config['response_queue'] self.region_coverage = config['region_coverage'] def get_task_from_request_queue(self): """"Get task from request queue""" raise NotImplementedError - def put_result_in_response_queue(self, measured_snapshot, request): + def process_measured_snapshot_result( + self, measured_snapshot: Optional[Snapshot], + request: measurer_datatypes.SnapshotMeasureRequest): + """Process a measured snapshot result, and return either a serialized + measured snapshot, or a serialized retry request, depending on whether a + corpus was found for that cycle or not""" + raise NotImplementedError + + def put_result_in_response_queue( + self, result: Union[measurer_datatypes.RetryRequest, Snapshot, + bytes], retry: bool): """Save measurement result in response queue, for the measure manager to retrieve""" raise NotImplementedError + def _write_pid_to_fs(self): + """Debugging method""" + pid = os.getpid() + with open('worker-pid.txt', 'w+', encoding='utf-8') as pid_file: + pid_file.write(str(pid)) + def measure_worker_loop(self): """Periodically retrieves request from request queue, measure it, and put result in response queue""" - logs.initialize(default_extras={ - 'component': 'measurer', - 'subcomponent': 'worker', - }) + # Write pid to file system to check if worker process is being started + # correctly. 
Only for debug purposes, will be removed later + self._write_pid_to_fs() + + try: + logs.initialize(default_extras={ + 'component': 'measurer', + 'subcomponent': 'worker', + }) + except Exception as error: # pylint: disable=broad-except + logger.error('Error while initializing logs: %s', error) + logger.info('Starting one measure worker loop') while True: # 'SnapshotMeasureRequest', ['fuzzer', 'benchmark', 'trial_id', # 'cycle'] request = self.get_task_from_request_queue() - logger.info( - 'Measurer worker: Got request %s %s %d %d from request queue', - request.fuzzer, request.benchmark, request.trial_id, - request.cycle) - measured_snapshot = measure_manager.measure_snapshot_coverage( - request.fuzzer, request.benchmark, request.trial_id, - request.cycle, self.region_coverage) - self.put_result_in_response_queue(measured_snapshot, request) + if request: + logger.info( + 'Measurer worker: Got request %s %s %d %d from request queue', # pylint: disable=line-too-long + request.fuzzer, + request.benchmark, + request.trial_id, + request.cycle) + measured_snapshot = measure_manager.measure_snapshot_coverage( + request.fuzzer, request.benchmark, request.trial_id, + request.cycle, self.region_coverage) + result, retry = self.process_measured_snapshot_result( + measured_snapshot, request) + self.put_result_in_response_queue(result, retry) time.sleep(MEASUREMENT_TIMEOUT) @@ -68,6 +100,11 @@ class LocalMeasureWorker(BaseMeasureWorker): """Class that holds implementations of core methods for running a measure worker locally.""" + def __init__(self, config: Dict): + self.request_queue = config['request_queue'] + self.response_queue = config['response_queue'] + super().__init__(config) + def get_task_from_request_queue( self) -> measurer_datatypes.SnapshotMeasureRequest: """Get item from request multiprocessing queue, block if necessary until @@ -75,14 +112,109 @@ def get_task_from_request_queue( request = self.request_queue.get(block=True) return request - def put_result_in_response_queue( - self, measured_snapshot: Optional[Snapshot], - request: measurer_datatypes.SnapshotMeasureRequest): + def process_measured_snapshot_result(self, measured_snapshot, request): if measured_snapshot: - logger.info('Put measured snapshot in response_queue') - self.response_queue.put(measured_snapshot) - else: - retry_request = measurer_datatypes.RetryRequest( - request.fuzzer, request.benchmark, request.trial_id, - request.cycle) - self.response_queue.put(retry_request) + return measured_snapshot, False + retry_request = measurer_datatypes.RetryRequest(request.fuzzer, + request.benchmark, + request.trial_id, + request.cycle) + return retry_request, True + + def put_result_in_response_queue(self, result, retry): + self.response_queue.put(result) + + +class GoogleCloudMeasureWorker(BaseMeasureWorker): # pylint: disable=too-many-instance-attributes + """Worker that consumes from a Google Cloud Pub/Sub Queue, instead of a + multiprocessing queue""" + + def __init__(self, config: Dict): + super().__init__(config) + self.publisher_client = pubsub_v1.PublisherClient() + self.subscriber_client = pubsub_v1.SubscriberClient() + self.project_id = config['project_id'] + self.request_queue_topic_id = config['request_queue_topic_id'] + self.request_queue_topic_path = self.subscriber_client.topic_path( + self.project_id, self.request_queue_topic_id) + self.response_queue_topic_id = config['response_queue_topic_id'] + self.response_queue_topic_path = self.publisher_client.topic_path( + self.project_id, self.response_queue_topic_id) + 
self.experiment = config['experiment'] + self.request_queue_subscription = ('request-queue-subscription-' + f'{self.experiment}') + self.subscription_path = self.subscriber_client.subscription_path( + self.project_id, self.request_queue_subscription) + + @staticmethod + def create_request_queue_subscription(subscription_path, + request_queue_topic_path): + """Creates a new Pub/Sub subscription for the request queue.""" + try: + subscription = pubsub_v1.SubscriberClient().create_subscription( + request={ + 'name': subscription_path, + 'topic': request_queue_topic_path, + 'enable_message_ordering': True, + }) + logger.info('Subscription %s created successfully.', + subscription.name) + return subscription.name + except google.api_core.exceptions.GoogleAPICallError as error: + logger.error('Error while creating request queue subscription: %s.', + error) + return None + + def get_task_from_request_queue( + self) -> Optional[measurer_datatypes.SnapshotMeasureRequest]: + try: + response = self.subscriber_client.pull(request={ + 'subscription': self.subscription_path, + 'max_messages': 1 + }) + except google.api_core.exceptions.GoogleAPICallError as error: + logger.error('Error when calling pubsub API: %s', error) + return None + + if not response.received_messages: + return None + + message = response.received_messages[0] + ack_ids = [message.ack_id] + + # Acknowledge the received message to remove it from the + # queue. + self.subscriber_client.acknowledge(request={ + 'subscription': self.subscription_path, + 'ack_ids': ack_ids + }) + + # Needs to deserialize data from bytes to + # SnapshotMeasureRequest + serialized_data = json.loads(message.message.data) + return measurer_datatypes.from_dict_to_snapshot_measure_request( # pylint: disable=line-too-long + serialized_data) + + def process_measured_snapshot_result(self, measured_snapshot, request): + if measured_snapshot: + measured_snapshot_serialized = json.dumps( + measured_snapshot.as_dict()).encode('utf-8') + return measured_snapshot_serialized, False + + retry_request = measurer_datatypes.RetryRequest(request.fuzzer, + request.benchmark, + request.trial_id, + request.cycle) + retry_request_encoded = json.dumps( + retry_request._asdict()).encode('utf-8') + return retry_request_encoded, True + + def put_result_in_response_queue(self, result, retry): + try: + self.publisher_client.publish(topic=self.response_queue_topic_path, + data=result, + attrs={'retry': retry}) + logger.info('Result published successfully in response queue.') + except google.api_core.exceptions.GoogleAPICallError as error: + logger.error('Error when publishing result in response queue %s.', + error) diff --git a/experiment/measurer/test_datatypes.py b/experiment/measurer/test_datatypes.py new file mode 100644 index 000000000..279e9ca68 --- /dev/null +++ b/experiment/measurer/test_datatypes.py @@ -0,0 +1,66 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tests for datatypes.py.""" +import experiment.measurer.datatypes as measurer_datatypes + + +def test_from_dict_to_snapshot_retry_request(): + """Tests if a dictionary is being properly converted to a RetryRequest named + tuple object.""" + dictionary = { + 'fuzzer': 'test-fuzzer', + 'benchmark': 'test-benchmark', + 'trial_id': 1, + 'cycle': 0 + } + result = measurer_datatypes.from_dict_to_snapshot_retry_request(dictionary) + expected_retry_request = measurer_datatypes.RetryRequest( + 'test-fuzzer', 'test-benchmark', 1, 0) + assert result == expected_retry_request + + +def test_from_dict_to_snapshot_measure_request(): + """Tests if a dictionary is being properly converted to a + SnapshotMeasureRequest named tuple object.""" + dictionary = { + 'fuzzer': 'test-fuzzer', + 'benchmark': 'test-benchmark', + 'trial_id': 1, + 'cycle': 0 + } + result = measurer_datatypes.from_dict_to_snapshot_measure_request( + dictionary) + expected_tuple = measurer_datatypes.SnapshotMeasureRequest( + 'test-fuzzer', 'test-benchmark', 1, 0) + assert result == expected_tuple + + +def test_from_snapshot_measure_request_to_bytes(): + """Tests if a SnapshotMeasureRequest named tuple object is being + successfully converted to bytes format.""" + snapshot_measure_request = measurer_datatypes.SnapshotMeasureRequest( + 'test-fuzzer', 'test-benchmark', 1, 0) + req_as_bytes = measurer_datatypes.from_snapshot_measure_request_to_bytes( + snapshot_measure_request) + assert isinstance(req_as_bytes, bytes) + + +def test_from_snapshot_retry_request_to_bytes(): + """Tests if a RetryRequest named tuple object is being successfully + converted to bytes format.""" + snapshot_retry_request = measurer_datatypes.RetryRequest( + 'test-fuzzer', 'test-benchmark', 1, 0) + snapshot_as_bytes = measurer_datatypes.from_retry_request_to_bytes( + snapshot_retry_request) + assert isinstance(snapshot_as_bytes, bytes) diff --git a/experiment/measurer/test_measure_manager.py b/experiment/measurer/test_measure_manager.py index 7b6521869..6dd6c0902 100644 --- a/experiment/measurer/test_measure_manager.py +++ b/experiment/measurer/test_measure_manager.py @@ -15,8 +15,8 @@ import os import shutil from unittest import mock +import multiprocessing import queue - import pytest from common import experiment_utils @@ -411,19 +411,56 @@ def test_path_exists_in_experiment_filestore(mocked_execute, environ): expect_zero=False) -def test_consume_unmapped_type_from_response_queue(): +@pytest.fixture +def local_measure_manager(): + """Fixture for instantiating a local measure manager object.""" + local_measure_manager = measure_manager.LocalMeasureManager( + 'experiment', False, None) + return local_measure_manager + + +@pytest.fixture +@mock.patch('google.cloud.pubsub_v1.PublisherClient') +@mock.patch('google.cloud.pubsub_v1.SubscriberClient') +def gcloud_measure_manager(mock_subscriber, mock_publisher): + """Fixture for instantiating a google cloud measure manager object.""" + gcloud_measure_manager = measure_manager.GoogleCloudMeasureManager( + 'experiment', 'fuzzbench-test', False, None) + return gcloud_measure_manager + + +def test_consume_unmapped_type_from_response_queue(local_measure_manager): """Tests the scenario where an unmapped type is retrieved from the response queue. This scenario is not expected to happen, so in this case no snapshots are returned.""" # Use normal queue here as multiprocessing queue gives flaky tests. 
diff --git a/experiment/measurer/test_measure_manager.py b/experiment/measurer/test_measure_manager.py
index 7b6521869..6dd6c0902 100644
--- a/experiment/measurer/test_measure_manager.py
+++ b/experiment/measurer/test_measure_manager.py
@@ -15,8 +15,8 @@
 import os
 import shutil
 from unittest import mock
+import multiprocessing
 import queue
-
 import pytest
 
 from common import experiment_utils
@@ -411,19 +411,56 @@ def test_path_exists_in_experiment_filestore(mocked_execute, environ):
                            expect_zero=False)
 
 
-def test_consume_unmapped_type_from_response_queue():
+@pytest.fixture
+def local_measure_manager():
+    """Fixture for instantiating a local measure manager object."""
+    local_measure_manager = measure_manager.LocalMeasureManager(
+        'experiment', False, None)
+    return local_measure_manager
+
+
+@pytest.fixture
+@mock.patch('google.cloud.pubsub_v1.PublisherClient')
+@mock.patch('google.cloud.pubsub_v1.SubscriberClient')
+def gcloud_measure_manager(mock_subscriber, mock_publisher):
+    """Fixture for instantiating a Google Cloud measure manager object."""
+    gcloud_measure_manager = measure_manager.GoogleCloudMeasureManager(
+        'experiment', 'fuzzbench-test', False, None)
+    return gcloud_measure_manager
+
+
+def test_consume_unmapped_type_from_response_queue(local_measure_manager):
     """Tests the scenario where an unmapped type is retrieved from the
     response queue. This scenario is not expected to happen, so in this case
     no snapshots are returned."""
     # Use normal queue here as multiprocessing queue gives flaky tests.
     response_queue = queue.Queue()
     response_queue.put('unexpected string')
-    snapshots = measure_manager.consume_snapshots_from_response_queue(
+    snapshots = local_measure_manager.consume_snapshots_from_response_queue(
         response_queue, set())
     assert not snapshots
 
 
-def test_consume_retry_type_from_response_queue():
+def test_consume_none_from_response_queue(local_measure_manager):
+    """Tests the scenario where None is retrieved from the response queue.
+    The consumer should raise queue.Empty, break the loop early, and return
+    no snapshots."""
+    response_queue = queue.Queue()
+    response_queue.put(None)
+    # Mock the get method while keeping its functionality, to assert later
+    # that it was only called once.
+    local_measure_manager.get_result_from_response_queue = mock.MagicMock(
+        wraps=local_measure_manager.get_result_from_response_queue)
+    snapshots = local_measure_manager.consume_snapshots_from_response_queue(
+        response_queue, set())
+    # Get result should only be called once, since it raises a queue.Empty
+    # exception on the first call.
+    local_measure_manager.get_result_from_response_queue.assert_called_once()
+    # Should return an empty list.
+    assert not snapshots
+
+
+def test_consume_retry_type_from_response_queue(local_measure_manager):
     """Tests the scenario where a retry object is retrieved from the
     response queue. In this scenario, we want to remove the snapshot
     identifier from the
     queued_snapshots set, as this allows the measurement task to be
@@ -435,13 +472,13 @@
     snapshot_identifier = (TRIAL_NUM, CYCLE)
     response_queue.put(retry_request_object)
     queued_snapshots_set = set([snapshot_identifier])
-    snapshots = measure_manager.consume_snapshots_from_response_queue(
+    snapshots = local_measure_manager.consume_snapshots_from_response_queue(
         response_queue, queued_snapshots_set)
     assert not snapshots
     assert len(queued_snapshots_set) == 0
 
 
-def test_consume_snapshot_type_from_response_queue():
+def test_consume_snapshot_type_from_response_queue(local_measure_manager):
     """Tests the scenario where a measured snapshot is retrieved from the
     response queue. In this scenario, we want to return the snapshot in the
     function."""
@@ -452,31 +489,32 @@
     measured_snapshot = models.Snapshot(trial_id=TRIAL_NUM)
     response_queue.put(measured_snapshot)
     assert response_queue.qsize() == 1
-    snapshots = measure_manager.consume_snapshots_from_response_queue(
+    snapshots = local_measure_manager.consume_snapshots_from_response_queue(
         response_queue, queued_snapshots_set)
     assert len(snapshots) == 1
 
 
 @mock.patch('experiment.measurer.measure_manager.get_unmeasured_snapshots')
 def test_measure_manager_inner_loop_break_condition(
-        mocked_get_unmeasured_snapshots):
+        mocked_get_unmeasured_snapshots, local_measure_manager):
     """Tests that the measure manager inner loop returns False when there's
     no more snapshots left to be measured."""
     # Empty list means no more snapshots left to be measured.
mocked_get_unmeasured_snapshots.return_value = [] request_queue = queue.Queue() response_queue = queue.Queue() - continue_inner_loop = measure_manager.measure_manager_inner_loop( - 'experiment', 1, request_queue, response_queue, set()) + continue_inner_loop = local_measure_manager.measure_manager_inner_loop( + 1, request_queue, response_queue, set()) assert not continue_inner_loop @mock.patch('experiment.measurer.measure_manager.get_unmeasured_snapshots') @mock.patch( - 'experiment.measurer.measure_manager.consume_snapshots_from_response_queue') + 'experiment.measurer.measure_manager.BaseMeasureManager.consume_snapshots_from_response_queue' # pylint: disable=line-too-long +) def test_measure_manager_inner_loop_writes_to_request_queue( mocked_consume_snapshots_from_response_queue, - mocked_get_unmeasured_snapshots): + mocked_get_unmeasured_snapshots, local_measure_manager): """Tests that the measure manager inner loop is writing measurement tasks to request queue.""" mocked_get_unmeasured_snapshots.return_value = [ @@ -485,18 +523,19 @@ def test_measure_manager_inner_loop_writes_to_request_queue( mocked_consume_snapshots_from_response_queue.return_value = [] request_queue = queue.Queue() response_queue = queue.Queue() - measure_manager.measure_manager_inner_loop('experiment', 1, request_queue, - response_queue, set()) + local_measure_manager.measure_manager_inner_loop(1, request_queue, + response_queue, set()) assert request_queue.qsize() == 1 @mock.patch('experiment.measurer.measure_manager.get_unmeasured_snapshots') @mock.patch( - 'experiment.measurer.measure_manager.consume_snapshots_from_response_queue') + 'experiment.measurer.measure_manager.BaseMeasureManager.consume_snapshots_from_response_queue' # pylint: disable=line-too-long +) @mock.patch('database.utils.add_all') def test_measure_manager_inner_loop_dont_write_to_db( mocked_add_all, mocked_consume_snapshots_from_response_queue, - mocked_get_unmeasured_snapshots): + mocked_get_unmeasured_snapshots, local_measure_manager): """Tests that the measure manager inner loop does not call add_all to write to the database, when there are no measured snapshots to be written.""" mocked_get_unmeasured_snapshots.return_value = [ @@ -505,18 +544,19 @@ def test_measure_manager_inner_loop_dont_write_to_db( request_queue = queue.Queue() response_queue = queue.Queue() mocked_consume_snapshots_from_response_queue.return_value = [] - measure_manager.measure_manager_inner_loop('experiment', 1, request_queue, - response_queue, set()) + local_measure_manager.measure_manager_inner_loop(1, request_queue, + response_queue, set()) mocked_add_all.not_called() @mock.patch('experiment.measurer.measure_manager.get_unmeasured_snapshots') @mock.patch( - 'experiment.measurer.measure_manager.consume_snapshots_from_response_queue') + 'experiment.measurer.measure_manager.BaseMeasureManager.consume_snapshots_from_response_queue' # pylint: disable=line-too-long +) @mock.patch('database.utils.add_all') def test_measure_manager_inner_loop_writes_to_db( mocked_add_all, mocked_consume_snapshots_from_response_queue, - mocked_get_unmeasured_snapshots): + mocked_get_unmeasured_snapshots, local_measure_manager): """Tests that the measure manager inner loop calls add_all to write to the database, when there are measured snapshots to be written.""" mocked_get_unmeasured_snapshots.return_value = [ @@ -526,6 +566,109 @@ def test_measure_manager_inner_loop_writes_to_db( response_queue = queue.Queue() snapshot_model = models.Snapshot(trial_id=1) 
     mocked_consume_snapshots_from_response_queue.return_value = [snapshot_model]
-    measure_manager.measure_manager_inner_loop('experiment', 1, request_queue,
-                                               response_queue, set())
+    local_measure_manager.measure_manager_inner_loop(1, request_queue,
+                                                     response_queue, set())
     mocked_add_all.assert_called_with([snapshot_model])
+
+
+def test_gcloud_measure_manager_get_result_from_response_queue_not_acknowledge(
+        gcloud_measure_manager):
+    """Tests that the subscriber client does not acknowledge any received
+    message when the pull call on the response queue returns an empty
+    received messages list."""
+    pull_return_value = mock.MagicMock()
+    pull_return_value.received_messages = []
+    gcloud_measure_manager.subscriber_client.pull.return_value = pull_return_value  # pylint: disable=line-too-long
+    gcloud_measure_manager.get_result_from_response_queue('')
+    assert not gcloud_measure_manager.subscriber_client.acknowledge.called
+
+
+def test_gcloud_measure_manager_get_result_from_response_queue_acknowledge(
+        gcloud_measure_manager):
+    """Tests that the subscriber client acknowledges received messages when
+    the pull call on the response queue returns a non-empty received
+    messages list."""
+    pull_return = mock.MagicMock()
+    received_message = mock.MagicMock()
+    received_message.ack_id = 0
+
+    # Since we are simulating getting an object from the queue, the object
+    # arrives serialized in bytes format.
+    unserialized_data = measurer_datatypes.RetryRequest('fuzzer', 'benchmark',
+                                                        1748392, 0)
+    serialized_data = measurer_datatypes.from_retry_request_to_bytes(
+        unserialized_data)
+    received_message.message.data = serialized_data
+    received_message.message.attributes = {'is_retry': 'True'}
+    # Mocking the pull method to return the serialized message.
+    pull_return.received_messages = [received_message]
+    gcloud_measure_manager.subscriber_client.pull.return_value = pull_return
+
+    gcloud_measure_manager.get_result_from_response_queue('')
+    gcloud_measure_manager.subscriber_client.acknowledge.assert_called_once()
+
+
+def test_gcloud_measure_manager_get_retry_request_from_response_queue(
+        gcloud_measure_manager):
+    """Tests that the subscriber client returns a retry request from the
+    response queue, when the pull method gets a serialized retry request from
+    the queue, with the is_retry attribute set to 'True'."""
+    pull_return = mock.MagicMock()
+    received_message = mock.MagicMock()
+    received_message.ack_id = 0
+
+    # Since we are simulating getting an object from the queue, the object
+    # arrives serialized in bytes format.
+    unserialized_data = measurer_datatypes.RetryRequest('fuzzer', 'benchmark',
+                                                        1748392, 0)
+    serialized_data = measurer_datatypes.from_retry_request_to_bytes(
+        unserialized_data)
+    received_message.message.data = serialized_data
+    received_message.message.attributes = {'is_retry': 'True'}
+
+    # Mocking the pull method to return the serialized message.
+    pull_return.received_messages = [received_message]
+    gcloud_measure_manager.subscriber_client.pull.return_value = pull_return
+
+    result = gcloud_measure_manager.get_result_from_response_queue('')
+    assert isinstance(result, measurer_datatypes.RetryRequest)
+
+
+def test_gcloud_measure_manager_get_snapshot_from_response_queue(
+        gcloud_measure_manager):
+    """Tests that the subscriber client returns a measured snapshot from the
+    response queue, when the pull method gets a serialized measured snapshot
+    from the queue, with the is_retry attribute set to 'False'."""
+    pull_return = mock.MagicMock()
+    received_message = mock.MagicMock()
+    received_message.ack_id = 0
+
+    # Since we are simulating getting an object from the queue, the object
+    # arrives serialized in bytes format.
+    unserialized_data = models.Snapshot(trial_id=1)
+    serialized_data = unserialized_data.to_bytes()
+    received_message.message.data = serialized_data
+    received_message.message.attributes = {'is_retry': 'False'}
+
+    # Mocking the pull method to return the serialized message.
+    pull_return.received_messages = [received_message]
+    gcloud_measure_manager.subscriber_client.pull.return_value = pull_return
+
+    result = gcloud_measure_manager.get_result_from_response_queue('')
+    assert isinstance(result, models.Snapshot)
+
+
+@mock.patch('experiment.measurer.measure_worker.GoogleCloudMeasureWorker')
+@mock.patch('google.cloud.pubsub_v1.PublisherClient')
+@mock.patch('google.cloud.pubsub_v1.SubscriberClient')
+def test_gcloud_measure_manager_start_workers(_mock_subscriber, _mock_publisher,
+                                              mock_gcloud_measure_worker,
+                                              gcloud_measure_manager):
+    """Tests that the start_workers method calls the measure worker loop
+    method a number of times equal to the number of measurer CPUs."""
+    cpus_available = multiprocessing.cpu_count()
+    gcloud_measure_manager.measurers_cpus = cpus_available
+    pool = mock.MagicMock()
+    gcloud_measure_manager.start_workers('request-queue-topic',
+                                         'response-queue-topic', pool)
+    assert pool.apply_async.call_count == cpus_available
diff --git a/experiment/measurer/test_measure_worker.py b/experiment/measurer/test_measure_worker.py
index 4e5bd7b05..ca78c6d69 100644
--- a/experiment/measurer/test_measure_worker.py
+++ b/experiment/measurer/test_measure_worker.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Tests for measure_worker.py."""
+from unittest import mock
 import multiprocessing
 
 import pytest
@@ -22,7 +23,7 @@
 
 @pytest.fixture
 def local_measure_worker():
-    """Fixture for instantiating a local measure worker object"""
+    """Fixture for instantiating a local measure worker object."""
     request_queue = multiprocessing.Queue()
     response_queue = multiprocessing.Queue()
     region_coverage = False
@@ -34,24 +35,112 @@
     return measure_worker.LocalMeasureWorker(config)
 
 
-def test_put_snapshot_in_response_queue(local_measure_worker):  # pylint: disable=redefined-outer-name
-    """Tests the scenario where measure_snapshot is not None, so snapshot is put
-    in response_queue"""
+@pytest.fixture
+@mock.patch('google.cloud.pubsub_v1.PublisherClient')
+@mock.patch('google.cloud.pubsub_v1.SubscriberClient')
+def gcloud_measure_worker(_mock_subscriber_client, _mock_publisher_client):
+    """Fixture for instantiating a Google Cloud measure worker object, with
+    mocked subscriber and publisher clients."""
+    config = {
+        'region_coverage': False,
+        'project_id': 'fuzzbench-test',
+        'request_queue_topic_id': 'request_queue_topic_id',
+        'response_queue_topic_id': 'response_queue_topic_id',
+        'experiment': 'test',
+    }
+    return measure_worker.GoogleCloudMeasureWorker(config)
+
+
+def test_process_measured_snapshot_as_serialized_snapshot(
+        gcloud_measure_worker):  # pylint: disable=redefined-outer-name
+    """Tests that process_measured_snapshot_result serializes the snapshot
+    when called by a Google Cloud measure worker."""
     request = measurer_datatypes.SnapshotMeasureRequest('fuzzer', 'benchmark',
                                                         1, 0)
     snapshot = Snapshot(trial_id=1)
-    local_measure_worker.put_result_in_response_queue(snapshot, request)
-    response_queue = local_measure_worker.response_queue
-    assert response_queue.qsize() == 1
-    assert isinstance(response_queue.get(), Snapshot)
+    result, _retry = gcloud_measure_worker.process_measured_snapshot_result(  # pylint: disable=line-too-long
+        snapshot, request)
+    assert isinstance(result, bytes)
 
 
-def test_put_retry_in_response_queue(local_measure_worker):  # pylint: disable=redefined-outer-name
-    """Tests the scenario where measure_snapshot is None, so task needs to be
-    retried"""
-    request = measurer_datatypes.RetryRequest('fuzzer', 'benchmark', 1, 0)
+def test_process_measured_snapshot_as_retry_request(local_measure_worker):  # pylint: disable=redefined-outer-name
+    """Tests the scenario where measure_snapshot is None, so the task needs
+    to be retried."""
+    request = measurer_datatypes.SnapshotMeasureRequest('fuzzer', 'benchmark',
+                                                        1, 0)
     snapshot = None
-    local_measure_worker.put_result_in_response_queue(snapshot, request)
-    response_queue = local_measure_worker.response_queue
-    assert response_queue.qsize() == 1
-    assert isinstance(response_queue.get(), measurer_datatypes.RetryRequest)
+    result, _retry = local_measure_worker.process_measured_snapshot_result(
+        snapshot, request)
+    assert isinstance(result, measurer_datatypes.RetryRequest)
+
+
+def test_process_measured_snapshot_as_snapshot(local_measure_worker):  # pylint: disable=redefined-outer-name
+    """Tests the scenario where measure_snapshot is not None, so the snapshot
+    is returned."""
+    request = measurer_datatypes.SnapshotMeasureRequest('fuzzer', 'benchmark',
+                                                        1, 0)
+    snapshot = Snapshot(trial_id=1)
+    result, _retry = local_measure_worker.process_measured_snapshot_result(
+        snapshot, request)
+    assert isinstance(result, Snapshot)
+
+
+def test_put_snapshot_in_response_queue(local_measure_worker):  # pylint: disable=redefined-outer-name
+    """Tests that the result is put in the response queue as expected."""
+    request = measurer_datatypes.SnapshotMeasureRequest('fuzzer', 'benchmark',
+                                                        1, 0)
+    snapshot = Snapshot(trial_id=1)
+    result, retry = local_measure_worker.process_measured_snapshot_result(
+        snapshot, request)
+    local_measure_worker.put_result_in_response_queue(result, retry)
+    assert local_measure_worker.response_queue.qsize() == 1
+
+
+def test_get_task_from_request_queue_gcloud_worker_calls_acknowledge(
+        gcloud_measure_worker):  # pylint: disable=redefined-outer-name
+    """Tests that the method get_task_from_request_queue from the
+    GoogleCloudMeasureWorker worker calls acknowledge after pulling a
+    message from the queue."""
+    pull_return = mock.MagicMock()
+    received_message = mock.MagicMock()
+    received_message.ack_id = 0
+
+    # Since we are simulating getting an object from the queue, the object
+    # arrives serialized in bytes format.
+    unserialized_data = measurer_datatypes.SnapshotMeasureRequest(
+        'fuzzer', 'benchmark', 1748392, 0)
+    serialized_data = measurer_datatypes.from_snapshot_measure_request_to_bytes(
+        unserialized_data)
+    received_message.message.data = serialized_data
+
+    # Mocking the pull method to return the serialized message.
+    pull_return.received_messages = [received_message]
+    gcloud_measure_worker.subscriber_client.pull.return_value = pull_return
+
+    _result = gcloud_measure_worker.get_task_from_request_queue()
+    gcloud_measure_worker.subscriber_client.acknowledge.assert_called_once()
+
+
+def test_get_task_from_request_queue_gcloud_worker(gcloud_measure_worker):  # pylint: disable=redefined-outer-name
+    """Tests that the method get_task_from_request_queue from the
+    GoogleCloudMeasureWorker worker properly returns a snapshot measure
+    request, meaning it successfully deserializes the message from the
+    queue."""
+    pull_return = mock.MagicMock()
+    received_message = mock.MagicMock()
+    received_message.ack_id = 0
+
+    # Since we are simulating getting an object from the queue, the object
+    # arrives serialized in bytes format.
+    unserialized_data = measurer_datatypes.SnapshotMeasureRequest(
+        'fuzzer', 'benchmark', 1748392, 0)
+    serialized_data = measurer_datatypes.from_snapshot_measure_request_to_bytes(
+        unserialized_data)
+    received_message.message.data = serialized_data
+
+    # Mocking the pull method to return the serialized message.
+    pull_return.received_messages = [received_message]
+    gcloud_measure_worker.subscriber_client.pull.return_value = pull_return
+
+    result = gcloud_measure_worker.get_task_from_request_queue()
+    assert isinstance(result, measurer_datatypes.SnapshotMeasureRequest)
diff --git a/experiment/resources/dispatcher-startup-script-template.sh b/experiment/resources/dispatcher-startup-script-template.sh
index 3a6c7b465..7e6911f42 100644
--- a/experiment/resources/dispatcher-startup-script-template.sh
+++ b/experiment/resources/dispatcher-startup-script-template.sh
@@ -32,4 +32,4 @@ docker run --rm \
   -e PRIVATE={{private}} \
   --cap-add=SYS_PTRACE --cap-add=SYS_NICE \
   -v /var/run/docker.sock:/var/run/docker.sock --name=dispatcher-container \
-  {{docker_registry}}/dispatcher-image /work/startup-dispatcher.sh &> /tmp/dispatcher.log
+  {{docker_registry}}/dispatcher-image/progressive-pubsub-measurer /work/startup-dispatcher.sh &> /tmp/dispatcher.log
diff --git a/requirements.txt b/requirements.txt
index 56b835357..a6280ea19 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,6 +4,7 @@ google-auth==2.12.0
 google-cloud-error-reporting==1.6.3
 google-cloud-logging==3.1.2
 google-cloud-secret-manager==2.12.6
+google-cloud-pubsub==2.19.2
 clusterfuzz==2.6.0
 Jinja2==3.1.2
 numpy==1.23.4
diff --git a/service/experiment-config.yaml b/service/experiment-config.yaml
index b9acb09f8..877c853e7 100644
--- a/service/experiment-config.yaml
+++ b/service/experiment-config.yaml
@@ -2,8 +2,8 @@
 # Unless you are a fuzzbench maintainer running this service, this
 # will not work with your setup.
 
-trials: 20
-max_total_time: 82800 # 23 hours, the default time for preemptible experiments.
+trials: 3
+max_total_time: 3660
 cloud_project: fuzzbench
 docker_registry: gcr.io/fuzzbench
 cloud_compute_zone: us-central1-c
diff --git a/service/gcbrun_experiment.py b/service/gcbrun_experiment.py
index f19ab493d..b30f12f28 100644
--- a/service/gcbrun_experiment.py
+++ b/service/gcbrun_experiment.py
@@ -16,6 +16,7 @@
 """Entrypoint for gcbrun into run_experiment. This script will get the
 command from the last PR comment containing "/gcbrun" and pass it to
 run_experiment.py which will run an experiment."""
+# Dummy comment to trigger run experiment action
 import logging
 import os
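Taken together, the entry point now reduces to a factory call plus one loop. A condensed sketch of the wiring in `measure_main()`; the config values here are illustrative, only the keys match what the code reads:

```python
from experiment.measurer import measure_manager

# Keys consumed by measure_main(); values are illustrative.
experiment_config = {
    'experiment': 'test-experiment',
    'max_total_time': 3660,
    'measurers_cpus': 4,
    'region_coverage': False,
    'cloud_project': 'fuzzbench-test',
    'local_experiment': True,  # False selects GoogleCloudMeasureManager.
}

manager = measure_manager.get_measure_manager(
    experiment_config['local_experiment'], experiment_config['experiment'],
    experiment_config['cloud_project'], experiment_config['measurers_cpus'],
    experiment_config['region_coverage'])
manager.measure_manager_loop(experiment_config['max_total_time'])
```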