From 9ddbb7e7b9ee038c50c2a108653f44035beab08a Mon Sep 17 00:00:00 2001 From: Juan Jose Nicola Date: Mon, 18 May 2020 11:23:50 +0200 Subject: [PATCH 01/15] Pre-fork() the data manager. It is started at the beginning and avoid to inherit unnecessary data. --- ospd/ospd.py | 2 ++ ospd/scan.py | 3 --- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/ospd/ospd.py b/ospd/ospd.py index 0838505d..733abb13 100644 --- a/ospd/ospd.py +++ b/ospd/ospd.py @@ -1174,6 +1174,8 @@ def run(self) -> None: """ Starts the Daemon, handling commands until interrupted. """ + self.scan_collection.data_manager = multiprocessing.Manager() + try: while True: time.sleep(SCHEDULER_CHECK_PERIOD) diff --git a/ospd/scan.py b/ospd/scan.py index 516973e1..4d7bf663 100644 --- a/ospd/scan.py +++ b/ospd/scan.py @@ -209,9 +209,6 @@ def create_scan( if not target: target = {} - if self.data_manager is None: - self.data_manager = multiprocessing.Manager() - if not options: options = dict() From 8830657ddf4be1cf6f9c190d36c6493cfc0ba23f Mon Sep 17 00:00:00 2001 From: Juan Jose Nicola Date: Mon, 18 May 2020 11:26:15 +0200 Subject: [PATCH 02/15] Start a scan directly only if it is a dry scan. Otherwise, if it is a scan against a real target, the scan information is stored in the already forked data manager. This scan will be launched later by the scheduler. This avoid to fork the new scan process with all the data objects inherited from the parent process (stream, data, etree object, etc). At this moment, this reduces the memory usage of the new process from 110MB to ~30MB, for a full and fast single host target scan. --- ospd/command/command.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/ospd/command/command.py b/ospd/command/command.py index 6fa7f68c..295df59f 100644 --- a/ospd/command/command.py +++ b/ospd/command/command.py @@ -574,11 +574,10 @@ def handle_xml(self, xml: Element) -> bytes: vt_selection = OspRequest.process_vts_params(scanner_vts) # Dry run case. - if 'dry_run' in params and int(params['dry_run']): - scan_func = self._daemon.dry_run_scan + dry_run = 'dry_run' in params and int(params['dry_run']) + if dry_run: scan_params = None else: - scan_func = self._daemon.start_scan scan_params = self._daemon.process_scan_params(params) scan_id_aux = scan_id @@ -591,13 +590,13 @@ def handle_xml(self, xml: Element) -> bytes: id_.text = scan_id_aux return simple_response_str('start_scan', 100, 'Continue', id_) - scan_process = create_process( - func=scan_func, args=(scan_id, scan_target) - ) - - self._daemon.scan_processes[scan_id] = scan_process - - scan_process.start() + if dry_run: + scan_func = self._daemon.dry_run_scan + scan_process = create_process( + func=scan_func, args=(scan_id, scan_target) + ) + self._daemon.scan_processes[scan_id] = scan_process + scan_process.start() id_ = Element('id') id_.text = scan_id From 7586a9d08904373721864a9330cbb0f44d8fd44d Mon Sep 17 00:00:00 2001 From: Juan Jose Nicola Date: Mon, 18 May 2020 11:49:52 +0200 Subject: [PATCH 03/15] Add new scan status PENDING --- ospd/scan.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/ospd/scan.py b/ospd/scan.py index 4d7bf663..dc7d6312 100644 --- a/ospd/scan.py +++ b/ospd/scan.py @@ -32,10 +32,11 @@ class ScanStatus(Enum): """Scan status. """ - INIT = 0 - RUNNING = 1 - STOPPED = 2 - FINISHED = 3 + PENDING = 0 + INIT = 1 + RUNNING = 2 + STOPPED = 3 + FINISHED = 4 class ScanCollection: From f1feae1cd6a77878805578b30c7698a231df96d1 Mon Sep 17 00:00:00 2001 From: Juan Jose Nicola Date: Mon, 18 May 2020 11:50:49 +0200 Subject: [PATCH 04/15] Set the initial status as PENDING. The scheduler check for scans with this new status to launch the scan. --- ospd/scan.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ospd/scan.py b/ospd/scan.py index dc7d6312..e2f471f9 100644 --- a/ospd/scan.py +++ b/ospd/scan.py @@ -224,7 +224,7 @@ def create_scan( scan_info['options'] = options scan_info['start_time'] = int(time.time()) scan_info['end_time'] = 0 - scan_info['status'] = ScanStatus.INIT + scan_info['status'] = ScanStatus.PENDING if scan_id is None or scan_id == '': scan_id = str(uuid.uuid4()) From c30f4bbe2c3cdd0f734f35ad50a7232eaaa507fe Mon Sep 17 00:00:00 2001 From: Juan Jose Nicola Date: Mon, 18 May 2020 11:59:17 +0200 Subject: [PATCH 05/15] Check for pending scans. The main loop in the parent process check for pending scans in the scan table. It will fork a new scan process for the task. The scan data was already stored in the data manager in a previous step (during handling of start_scan cmd). Therefore the fork()'ed child only inherit the base memory of the parent process. --- ospd/ospd.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/ospd/ospd.py b/ospd/ospd.py index 733abb13..83e6a011 100644 --- a/ospd/ospd.py +++ b/ospd/ospd.py @@ -46,7 +46,7 @@ from ospd import __version__ from ospd.command import get_commands from ospd.errors import OspdCommandError -from ospd.misc import ResultType +from ospd.misc import ResultType, create_process from ospd.network import resolve_hostname, target_str_to_list from ospd.protocol import OspRequest, OspResponse, RequestParser from ospd.scan import ScanCollection, ScanStatus @@ -1181,10 +1181,25 @@ def run(self) -> None: time.sleep(SCHEDULER_CHECK_PERIOD) self.scheduler() self.clean_forgotten_scans() + self.check_pending_scans() self.wait_for_children() except KeyboardInterrupt: logger.info("Received Ctrl-C shutting-down ...") + def check_pending_scans(self): + for scan_id in list(self.scan_collection.ids_iterator()): + if self.get_scan_status(scan_id) == ScanStatus.PENDING: + scan_target = self.scan_collection.scans_table[scan_id].get( + 'target' + ) + scan_func = self.start_scan + scan_process = create_process( + func=scan_func, args=(scan_id, scan_target) + ) + self.scan_processes[scan_id] = scan_process + scan_process.start() + self.set_scan_status(scan_id, ScanStatus.INIT) + def scheduler(self): """ Should be implemented by subclass in case of need to run tasks periodically. """ From 766644b9625536c4c2a736bb39ac15e68404d3e9 Mon Sep 17 00:00:00 2001 From: Juan Jose Nicola Date: Mon, 18 May 2020 12:03:46 +0200 Subject: [PATCH 06/15] Only check scan progress of already started scans Pending scans are still not in the scan process table, still not started. Therefore, they are skipped. --- ospd/ospd.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ospd/ospd.py b/ospd/ospd.py index 83e6a011..9f2544df 100644 --- a/ospd/ospd.py +++ b/ospd/ospd.py @@ -1272,9 +1272,12 @@ def clean_forgotten_scans(self) -> None: def check_scan_process(self, scan_id: str) -> None: """ Check the scan's process, and terminate the scan if not alive. """ - scan_process = self.scan_processes[scan_id] + scan_process = self.scan_processes.get(scan_id) progress = self.get_scan_progress(scan_id) + if self.get_scan_status(scan_id) == ScanStatus.PENDING: + return + if progress < PROGRESS_FINISHED and not scan_process.is_alive(): if not self.get_scan_status(scan_id) == ScanStatus.STOPPED: self.set_scan_status(scan_id, ScanStatus.STOPPED) From c2c54a19b416f52c9c4f859258b026b1d3a238d0 Mon Sep 17 00:00:00 2001 From: Juan Jose Nicola Date: Mon, 18 May 2020 12:08:25 +0200 Subject: [PATCH 07/15] Pop the port list from the scan table once the scan was started. The port list is not used anymore once the scan was started. So, it is cleaned up from the data manager and it reduce the footprint during a scan. --- ospd/scan.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ospd/scan.py b/ospd/scan.py index e2f471f9..fc3f6742 100644 --- a/ospd/scan.py +++ b/ospd/scan.py @@ -347,7 +347,10 @@ def get_host_count(self, scan_id: str) -> int: def get_ports(self, scan_id: str): """ Get a scan's ports list. """ - return self.scans_table[scan_id]['target'].get('ports') + target = self.scans_table[scan_id].get('target') + ports = target.pop('ports') + self.scans_table[scan_id]['target'] = target + return ports def get_exclude_hosts(self, scan_id: str): """ Get an exclude host list for a given target. From e01529e40790a99aeca5f4593c94c244d19e24d9 Mon Sep 17 00:00:00 2001 From: Juan Jose Nicola Date: Mon, 18 May 2020 12:29:18 +0200 Subject: [PATCH 08/15] Pop the vts list from the scan table once the scan was started. The vts list is not used anymore once the scan was started. So, it is cleaned up from the data manager and it reduce the memory usage during a scan. --- ospd/scan.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/ospd/scan.py b/ospd/scan.py index fc3f6742..fa071748 100644 --- a/ospd/scan.py +++ b/ospd/scan.py @@ -377,15 +377,11 @@ def get_target_options(self, scan_id: str): def get_vts(self, scan_id: str) -> Dict: """ Get a scan's vts. """ + scan_info = self.scans_table[scan_id] + vts = scan_info.pop('vts') + self.scans_table[scan_id] = scan_info - return self.scans_table[scan_id]['vts'] - - def release_vts_list(self, scan_id: str) -> None: - """ Release the memory used for the vts list. """ - - scan_data = self.scans_table.get(scan_id) - if scan_data and 'vts' in scan_data: - del scan_data['vts'] + return vts def id_exists(self, scan_id: str) -> bool: """ Check whether a scan exists in the table. """ From 33c0a2c4fac61fe6fda84a548ef2381cae78cb41 Mon Sep 17 00:00:00 2001 From: Juan Jose Nicola Date: Mon, 18 May 2020 12:32:09 +0200 Subject: [PATCH 09/15] Don't delete the pre-fork()'ed data manager. --- ospd/scan.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/ospd/scan.py b/ospd/scan.py index fa071748..c7fb625a 100644 --- a/ospd/scan.py +++ b/ospd/scan.py @@ -394,10 +394,8 @@ def delete_scan(self, scan_id: str) -> bool: if self.get_status(scan_id) == ScanStatus.RUNNING: return False - self.scans_table.pop(scan_id) - - if len(self.scans_table) == 0: - del self.data_manager - self.data_manager = None + scans_table = self.scans_table + del scans_table[scan_id] + self.scans_table = scans_table return True From 077e3ab06b96cdca7cec117422103010450fba51 Mon Sep 17 00:00:00 2001 From: Juan Jose Nicola Date: Mon, 18 May 2020 17:43:33 +0200 Subject: [PATCH 10/15] Update tests --- tests/command/test_commands.py | 93 +++++++-- tests/helper.py | 10 + tests/test_scan_and_result.py | 336 ++++++++++++++++----------------- tests/test_ssh_daemon.py | 2 + 4 files changed, 247 insertions(+), 194 deletions(-) diff --git a/tests/command/test_commands.py b/tests/command/test_commands.py index 13106eb4..c5b579a8 100644 --- a/tests/command/test_commands.py +++ b/tests/command/test_commands.py @@ -31,7 +31,7 @@ from ospd.errors import OspdCommandError, OspdError from ospd.misc import create_process -from ..helper import DummyWrapper, assert_called, FakeStream +from ..helper import DummyWrapper, assert_called, FakeStream, FakeDataManager class GetPerformanceTestCase(TestCase): @@ -97,7 +97,7 @@ def test_scan_with_vts_empty_vt_list(self): with self.assertRaises(OspdCommandError): cmd.handle_xml(request) - @patch("ospd.command.command.create_process") + @patch("ospd.ospd.create_process") def test_scan_with_vts(self, mock_create_process): daemon = DummyWrapper([]) cmd = StartScan(daemon) @@ -121,13 +121,67 @@ def test_scan_with_vts(self, mock_create_process): response = et.fromstring(cmd.handle_xml(request)) scan_id = response.findtext('id') - self.assertEqual( - daemon.get_scan_vts(scan_id), {'1.2.3.4': {}, 'vt_groups': []} - ) - self.assertNotEqual(daemon.get_scan_vts(scan_id), {'1.2.3.6': {}}) + vts_collection = daemon.get_scan_vts(scan_id) + self.assertEqual(vts_collection, {'1.2.3.4': {}, 'vt_groups': []}) + self.assertNotEqual(vts_collection, {'1.2.3.6': {}}) + daemon.check_pending_scans() assert_called(mock_create_process) + def test_scan_pop_vts(self): + daemon = DummyWrapper([]) + cmd = StartScan(daemon) + + request = et.fromstring( + '' + '' + '' + 'localhost' + '80, 443' + '' + '' + '' + '' + '' + '' + '' + ) + + # With one vt, without params + response = et.fromstring(cmd.handle_xml(request)) + scan_id = response.findtext('id') + + vts_collection = daemon.get_scan_vts(scan_id) + self.assertEqual(vts_collection, {'1.2.3.4': {}, 'vt_groups': []}) + self.assertRaises(KeyError, daemon.get_scan_vts, scan_id) + + def test_scan_pop_ports(self): + daemon = DummyWrapper([]) + cmd = StartScan(daemon) + + request = et.fromstring( + '' + '' + '' + 'localhost' + '80, 443' + '' + '' + '' + '' + '' + '' + '' + ) + + # With one vt, without params + response = et.fromstring(cmd.handle_xml(request)) + scan_id = response.findtext('id') + + ports = daemon.scan_collection.get_ports(scan_id) + self.assertEqual(ports, '80, 443') + self.assertRaises(KeyError, daemon.scan_collection.get_ports, scan_id) + def test_is_new_scan_allowed_false(self): daemon = DummyWrapper([]) cmd = StartScan(daemon) @@ -152,7 +206,7 @@ def test_is_new_scan_allowed_true(self): self.assertTrue(cmd.is_new_scan_allowed()) - @patch("ospd.command.command.create_process") + @patch("ospd.ospd.create_process") def test_scan_without_vts(self, mock_create_process): daemon = DummyWrapper([]) cmd = StartScan(daemon) @@ -172,9 +226,9 @@ def test_scan_without_vts(self, mock_create_process): response = et.fromstring(cmd.handle_xml(request)) scan_id = response.findtext('id') - self.assertEqual(daemon.get_scan_vts(scan_id), {}) + daemon.check_pending_scans() assert_called(mock_create_process) def test_scan_with_vts_and_param_missing_vt_param_id(self): @@ -200,7 +254,7 @@ def test_scan_with_vts_and_param_missing_vt_param_id(self): with self.assertRaises(OspdError): cmd.handle_xml(request) - @patch("ospd.command.command.create_process") + @patch("ospd.ospd.create_process") def test_scan_with_vts_and_param(self, mock_create_process): daemon = DummyWrapper([]) cmd = StartScan(daemon) @@ -229,7 +283,7 @@ def test_scan_with_vts_and_param(self, mock_create_process): daemon.get_scan_vts(scan_id), {'1234': {'ABC': '200'}, 'vt_groups': []}, ) - + daemon.check_pending_scans() assert_called(mock_create_process) def test_scan_with_vts_and_param_missing_vt_group_filter(self): @@ -253,7 +307,7 @@ def test_scan_with_vts_and_param_missing_vt_group_filter(self): with self.assertRaises(OspdError): cmd.handle_xml(request) - @patch("ospd.command.command.create_process") + @patch("ospd.ospd.create_process") def test_scan_with_vts_and_param_with_vt_group_filter( self, mock_create_process ): @@ -280,9 +334,10 @@ def test_scan_with_vts_and_param_with_vt_group_filter( self.assertEqual(daemon.get_scan_vts(scan_id), {'vt_groups': ['a']}) + daemon.check_pending_scans() assert_called(mock_create_process) - @patch("ospd.command.command.create_process") + @patch("ospd.ospd.create_process") @patch("ospd.command.command.logger") def test_scan_ignore_multi_target(self, mock_logger, mock_create_process): daemon = DummyWrapper([]) @@ -300,16 +355,18 @@ def test_scan_ignore_multi_target(self, mock_logger, mock_create_process): ) cmd.handle_xml(request) - + daemon.check_pending_scans() assert_called(mock_logger.warning) assert_called(mock_create_process) - @patch("ospd.command.command.create_process") + @patch("ospd.ospd.create_process") @patch("ospd.command.command.logger") def test_scan_use_legacy_target_and_port( self, mock_logger, mock_create_process ): daemon = DummyWrapper([]) + daemon.scan_collection.datamanager = FakeDataManager() + cmd = StartScan(daemon) request = et.fromstring( '' @@ -325,20 +382,22 @@ def test_scan_use_legacy_target_and_port( self.assertEqual(daemon.get_scan_host(scan_id), 'localhost') self.assertEqual(daemon.get_scan_ports(scan_id), '22') + daemon.check_pending_scans() + assert_called(mock_logger.warning) assert_called(mock_create_process) class StopCommandTestCase(TestCase): @patch("ospd.ospd.os") - @patch("ospd.command.command.create_process") + @patch("ospd.ospd.create_process") def test_stop_scan(self, mock_create_process, mock_os): mock_process = mock_create_process.return_value mock_process.is_alive.return_value = True mock_process.pid = "foo" - fs = FakeStream() daemon = DummyWrapper([]) + daemon.scan_collection.datamanager = FakeDataManager() request = ( '' '' @@ -353,6 +412,8 @@ def test_stop_scan(self, mock_create_process, mock_os): daemon.handle_command(request, fs) response = fs.get_response() + daemon.check_pending_scans() + assert_called(mock_create_process) assert_called(mock_process.start) diff --git a/tests/helper.py b/tests/helper.py index b7abec7f..8aec21a0 100644 --- a/tests/helper.py +++ b/tests/helper.py @@ -16,6 +16,7 @@ # along with this program. If not, see . import time +import multiprocessing from unittest.mock import Mock @@ -48,12 +49,21 @@ def get_response(self): return et.fromstring(self.response) +class FakeDataManager: + def __init__(self): + pass + + def dict(self): + return dict() + + class DummyWrapper(OSPDaemon): def __init__(self, results, checkresult=True): super().__init__() self.checkresult = checkresult self.results = results self.initialized = True + self.scan_collection.data_manager = FakeDataManager() def check(self): return self.checkresult diff --git a/tests/test_scan_and_result.py b/tests/test_scan_and_result.py index e8015c67..b153ac28 100644 --- a/tests/test_scan_and_result.py +++ b/tests/test_scan_and_result.py @@ -32,7 +32,7 @@ from ospd.resultlist import ResultList from ospd.errors import OspdCommandError -from .helper import DummyWrapper, assert_called, FakeStream +from .helper import DummyWrapper, assert_called, FakeStream, FakeDataManager class FakeStartProcess: @@ -76,11 +76,14 @@ def __init__(self, type_, **kwargs): class ScanTestCase(unittest.TestCase): + def setUp(self): + self.daemon = DummyWrapper([]) + self.daemon.scan_collection.datamanager = FakeDataManager() + def test_get_default_scanner_params(self): - daemon = DummyWrapper([]) fs = FakeStream() - daemon.handle_command('', fs) + self.daemon.handle_command('', fs) response = fs.get_response() # The status of the response must be success (i.e. 200) @@ -91,50 +94,45 @@ def test_get_default_scanner_params(self): self.assertIsNotNone(response.find('scanner_params')) def test_get_default_help(self): - daemon = DummyWrapper([]) fs = FakeStream() - daemon.handle_command('', fs) + self.daemon.handle_command('', fs) response = fs.get_response() self.assertEqual(response.get('status'), '200') fs = FakeStream() - daemon.handle_command('', fs) + self.daemon.handle_command('', fs) response = fs.get_response() self.assertEqual(response.get('status'), '200') self.assertEqual(response.tag, 'help_response') def test_get_default_scanner_version(self): - daemon = DummyWrapper([]) fs = FakeStream() - daemon.handle_command('', fs) + self.daemon.handle_command('', fs) response = fs.get_response() self.assertEqual(response.get('status'), '200') self.assertIsNotNone(response.find('protocol')) def test_get_vts_no_vt(self): - daemon = DummyWrapper([]) fs = FakeStream() - daemon.handle_command('', fs) + self.daemon.handle_command('', fs) response = fs.get_response() self.assertEqual(response.get('status'), '200') self.assertIsNotNone(response.find('vts')) def test_get_vt_xml_no_dict(self): - daemon = DummyWrapper([]) single_vt = ('1234', None) - vt = daemon.get_vt_xml(single_vt) + vt = self.daemon.get_vt_xml(single_vt) self.assertFalse(vt.get('id')) def test_get_vts_single_vt(self): - daemon = DummyWrapper([]) fs = FakeStream() - daemon.add_vt('1.2.3.4', 'A vulnerability test') - daemon.handle_command('', fs) + self.daemon.add_vt('1.2.3.4', 'A vulnerability test') + self.daemon.handle_command('', fs) response = fs.get_response() self.assertEqual(response.get('status'), '200') @@ -146,26 +144,23 @@ def test_get_vts_single_vt(self): self.assertEqual(vt.get('id'), '1.2.3.4') def test_get_vts_still_not_init(self): - daemon = DummyWrapper([]) fs = FakeStream() - daemon.initialized = False - daemon.handle_command('', fs) + self.daemon.initialized = False + self.daemon.handle_command('', fs) response = fs.get_response() self.assertEqual(response.get('status'), '400') def test_get_help_still_not_init(self): - daemon = DummyWrapper([]) fs = FakeStream() - daemon.initialized = False - daemon.handle_command('', fs) + self.daemon.initialized = False + self.daemon.handle_command('', fs) response = fs.get_response() self.assertEqual(response.get('status'), '200') def test_get_vts_filter_positive(self): - daemon = DummyWrapper([]) - daemon.add_vt( + self.daemon.add_vt( '1.2.3.4', 'A vulnerability test', vt_params="a", @@ -173,7 +168,7 @@ def test_get_vts_filter_positive(self): ) fs = FakeStream() - daemon.handle_command( + self.daemon.handle_command( '', fs ) response = fs.get_response() @@ -192,15 +187,14 @@ def test_get_vts_filter_positive(self): ) def test_get_vts_filter_negative(self): - daemon = DummyWrapper([]) - daemon.add_vt( + self.daemon.add_vt( '1.2.3.4', 'A vulnerability test', vt_params="a", vt_modification_time='19000202', ) fs = FakeStream() - daemon.handle_command( + self.daemon.handle_command( '', fs, ) response = fs.get_response() @@ -220,22 +214,20 @@ def test_get_vts_filter_negative(self): ) def test_get_vts_bad_filter(self): - daemon = DummyWrapper([]) fs = FakeStream() cmd = '' - self.assertRaises(OspdCommandError, daemon.handle_command, cmd, fs) - self.assertTrue(daemon.vts.is_cache_available) + self.assertRaises(OspdCommandError, self.daemon.handle_command, cmd, fs) + self.assertTrue(self.daemon.vts.is_cache_available) def test_get_vtss_multiple_vts(self): - daemon = DummyWrapper([]) - daemon.add_vt('1.2.3.4', 'A vulnerability test') - daemon.add_vt('1.2.3.5', 'Another vulnerability test') - daemon.add_vt('123456789', 'Yet another vulnerability test') + self.daemon.add_vt('1.2.3.4', 'A vulnerability test') + self.daemon.add_vt('1.2.3.5', 'Another vulnerability test') + self.daemon.add_vt('123456789', 'Yet another vulnerability test') fs = FakeStream() - daemon.handle_command('', fs) + self.daemon.handle_command('', fs) response = fs.get_response() self.assertEqual(response.get('status'), '200') @@ -243,15 +235,16 @@ def test_get_vtss_multiple_vts(self): self.assertIsNotNone(vts.find('vt')) def test_get_vts_multiple_vts_with_custom(self): - daemon = DummyWrapper([]) - daemon.add_vt('1.2.3.4', 'A vulnerability test', custom='b') - daemon.add_vt( + self.daemon.add_vt('1.2.3.4', 'A vulnerability test', custom='b') + self.daemon.add_vt( '4.3.2.1', 'Another vulnerability test with custom info', custom='b' ) - daemon.add_vt('123456789', 'Yet another vulnerability test', custom='b') + self.daemon.add_vt( + '123456789', 'Yet another vulnerability test', custom='b' + ) fs = FakeStream() - daemon.handle_command('', fs) + self.daemon.handle_command('', fs) response = fs.get_response() custom = response.findall('vts/vt/custom') @@ -259,13 +252,12 @@ def test_get_vts_multiple_vts_with_custom(self): self.assertEqual(3, len(custom)) def test_get_vts_vts_with_params(self): - daemon = DummyWrapper([]) - daemon.add_vt( + self.daemon.add_vt( '1.2.3.4', 'A vulnerability test', vt_params="a", custom="b" ) fs = FakeStream() - daemon.handle_command('', fs) + self.daemon.handle_command('', fs) response = fs.get_response() # The status of the response must be success (i.e. 200) @@ -286,8 +278,7 @@ def test_get_vts_vts_with_params(self): self.assertEqual(2, len(params)) def test_get_vts_vts_with_refs(self): - daemon = DummyWrapper([]) - daemon.add_vt( + self.daemon.add_vt( '1.2.3.4', 'A vulnerability test', vt_params="a", @@ -296,7 +287,7 @@ def test_get_vts_vts_with_refs(self): ) fs = FakeStream() - daemon.handle_command('', fs) + self.daemon.handle_command('', fs) response = fs.get_response() # The status of the response must be success (i.e. 200) @@ -318,8 +309,7 @@ def test_get_vts_vts_with_refs(self): self.assertEqual(2, len(refs)) def test_get_vts_vts_with_dependencies(self): - daemon = DummyWrapper([]) - daemon.add_vt( + self.daemon.add_vt( '1.2.3.4', 'A vulnerability test', vt_params="a", @@ -328,7 +318,7 @@ def test_get_vts_vts_with_dependencies(self): ) fs = FakeStream() - daemon.handle_command('', fs) + self.daemon.handle_command('', fs) response = fs.get_response() @@ -336,8 +326,7 @@ def test_get_vts_vts_with_dependencies(self): self.assertEqual(2, len(deps)) def test_get_vts_vts_with_severities(self): - daemon = DummyWrapper([]) - daemon.add_vt( + self.daemon.add_vt( '1.2.3.4', 'A vulnerability test', vt_params="a", @@ -346,15 +335,14 @@ def test_get_vts_vts_with_severities(self): ) fs = FakeStream() - daemon.handle_command('', fs) + self.daemon.handle_command('', fs) response = fs.get_response() severity = response.findall('vts/vt/severities/severity') self.assertEqual(1, len(severity)) def test_get_vts_vts_with_detection_qodt(self): - daemon = DummyWrapper([]) - daemon.add_vt( + self.daemon.add_vt( '1.2.3.4', 'A vulnerability test', vt_params="a", @@ -364,15 +352,14 @@ def test_get_vts_vts_with_detection_qodt(self): ) fs = FakeStream() - daemon.handle_command('', fs) + self.daemon.handle_command('', fs) response = fs.get_response() detection = response.findall('vts/vt/detection') self.assertEqual(1, len(detection)) def test_get_vts_vts_with_detection_qodv(self): - daemon = DummyWrapper([]) - daemon.add_vt( + self.daemon.add_vt( '1.2.3.4', 'A vulnerability test', vt_params="a", @@ -382,15 +369,14 @@ def test_get_vts_vts_with_detection_qodv(self): ) fs = FakeStream() - daemon.handle_command('', fs) + self.daemon.handle_command('', fs) response = fs.get_response() detection = response.findall('vts/vt/detection') self.assertEqual(1, len(detection)) def test_get_vts_vts_with_summary(self): - daemon = DummyWrapper([]) - daemon.add_vt( + self.daemon.add_vt( '1.2.3.4', 'A vulnerability test', vt_params="a", @@ -399,15 +385,14 @@ def test_get_vts_vts_with_summary(self): ) fs = FakeStream() - daemon.handle_command('', fs) + self.daemon.handle_command('', fs) response = fs.get_response() summary = response.findall('vts/vt/summary') self.assertEqual(1, len(summary)) def test_get_vts_vts_with_impact(self): - daemon = DummyWrapper([]) - daemon.add_vt( + self.daemon.add_vt( '1.2.3.4', 'A vulnerability test', vt_params="a", @@ -416,15 +401,14 @@ def test_get_vts_vts_with_impact(self): ) fs = FakeStream() - daemon.handle_command('', fs) + self.daemon.handle_command('', fs) response = fs.get_response() impact = response.findall('vts/vt/impact') self.assertEqual(1, len(impact)) def test_get_vts_vts_with_affected(self): - daemon = DummyWrapper([]) - daemon.add_vt( + self.daemon.add_vt( '1.2.3.4', 'A vulnerability test', vt_params="a", @@ -433,15 +417,14 @@ def test_get_vts_vts_with_affected(self): ) fs = FakeStream() - daemon.handle_command('', fs) + self.daemon.handle_command('', fs) response = fs.get_response() affect = response.findall('vts/vt/affected') self.assertEqual(1, len(affect)) def test_get_vts_vts_with_insight(self): - daemon = DummyWrapper([]) - daemon.add_vt( + self.daemon.add_vt( '1.2.3.4', 'A vulnerability test', vt_params="a", @@ -450,15 +433,14 @@ def test_get_vts_vts_with_insight(self): ) fs = FakeStream() - daemon.handle_command('', fs) + self.daemon.handle_command('', fs) response = fs.get_response() insight = response.findall('vts/vt/insight') self.assertEqual(1, len(insight)) def test_get_vts_vts_with_solution(self): - daemon = DummyWrapper([]) - daemon.add_vt( + self.daemon.add_vt( '1.2.3.4', 'A vulnerability test', vt_params="a", @@ -469,15 +451,14 @@ def test_get_vts_vts_with_solution(self): ) fs = FakeStream() - daemon.handle_command('', fs) + self.daemon.handle_command('', fs) response = fs.get_response() solution = response.findall('vts/vt/solution') self.assertEqual(1, len(solution)) def test_get_vts_vts_with_ctime(self): - daemon = DummyWrapper([]) - daemon.add_vt( + self.daemon.add_vt( '1.2.3.4', 'A vulnerability test', vt_params="a", @@ -485,7 +466,7 @@ def test_get_vts_vts_with_ctime(self): ) fs = FakeStream() - daemon.handle_command('', fs) + self.daemon.handle_command('', fs) response = fs.get_response() creation_time = response.findall('vts/vt/creation_time') @@ -495,8 +476,7 @@ def test_get_vts_vts_with_ctime(self): ) def test_get_vts_vts_with_mtime(self): - daemon = DummyWrapper([]) - daemon.add_vt( + self.daemon.add_vt( '1.2.3.4', 'A vulnerability test', vt_params="a", @@ -504,7 +484,7 @@ def test_get_vts_vts_with_mtime(self): ) fs = FakeStream() - daemon.handle_command('', fs) + self.daemon.handle_command('', fs) response = fs.get_response() modification_time = response.findall('vts/vt/modification_time') @@ -514,10 +494,9 @@ def test_get_vts_vts_with_mtime(self): ) def test_clean_forgotten_scans(self): - daemon = DummyWrapper([]) fs = FakeStream() - daemon.handle_command( + self.daemon.handle_command( '', fs, @@ -528,9 +507,10 @@ def test_clean_forgotten_scans(self): finished = False + self.daemon.check_pending_scans() while not finished: fs = FakeStream() - daemon.handle_command( + self.daemon.handle_command( '' % scan_id, fs ) response = fs.get_response() @@ -546,41 +526,52 @@ def test_clean_forgotten_scans(self): time.sleep(0.01) fs = FakeStream() - daemon.handle_command( + self.daemon.handle_command( '' % scan_id, fs ) response = fs.get_response() - self.assertEqual(len(list(daemon.scan_collection.ids_iterator())), 1) + self.assertEqual( + len(list(self.daemon.scan_collection.ids_iterator())), 1 + ) # Set an old end_time - daemon.scan_collection.scans_table[scan_id]['end_time'] = 123456 + self.daemon.scan_collection.scans_table[scan_id]['end_time'] = 123456 # Run the check - daemon.clean_forgotten_scans() + self.daemon.clean_forgotten_scans() # Not removed - self.assertEqual(len(list(daemon.scan_collection.ids_iterator())), 1) + self.assertEqual( + len(list(self.daemon.scan_collection.ids_iterator())), 1 + ) # Set the max time and run again - daemon.scaninfo_store_time = 1 - daemon.clean_forgotten_scans() + self.daemon.scaninfo_store_time = 1 + self.daemon.clean_forgotten_scans() # Now is removed - self.assertEqual(len(list(daemon.scan_collection.ids_iterator())), 0) + self.assertEqual( + len(list(self.daemon.scan_collection.ids_iterator())), 0 + ) def test_scan_with_error(self): - daemon = DummyWrapper([Result('error', value='something went wrong')]) fs = FakeStream() - daemon.handle_command( + self.daemon.handle_command( '', fs, ) + response = fs.get_response() scan_id = response.findtext('id') finished = False + self.daemon.check_pending_scans() + self.daemon.add_scan_error( + scan_id, host='a', value='something went wrong' + ) + while not finished: fs = FakeStream() - daemon.handle_command( + self.daemon.handle_command( '' % scan_id, fs ) response = fs.get_response() @@ -599,7 +590,7 @@ def test_scan_with_error(self): fs = FakeStream() - daemon.handle_command( + self.daemon.handle_command( '' % scan_id, fs ) response = fs.get_response() @@ -608,16 +599,15 @@ def test_scan_with_error(self): response.findtext('scan/results/result'), 'something went wrong' ) fs = FakeStream() - daemon.handle_command('' % scan_id, fs) + self.daemon.handle_command('' % scan_id, fs) response = fs.get_response() self.assertEqual(response.get('status'), '200') def test_get_scan_pop(self): - daemon = DummyWrapper([Result('host-detail', value='Some Host Detail')]) fs = FakeStream() - daemon.handle_command( + self.daemon.handle_command( '' '', fs, @@ -625,17 +615,21 @@ def test_get_scan_pop(self): response = fs.get_response() scan_id = response.findtext('id') + self.daemon.add_scan_host_detail( + scan_id, host='a', value='Some Host Detail' + ) + time.sleep(1) fs = FakeStream() - daemon.handle_command('' % scan_id, fs) + self.daemon.handle_command('' % scan_id, fs) response = fs.get_response() self.assertEqual( response.findtext('scan/results/result'), 'Some Host Detail' ) fs = FakeStream() - daemon.handle_command( + self.daemon.handle_command( '' % scan_id, fs ) response = fs.get_response() @@ -645,7 +639,7 @@ def test_get_scan_pop(self): ) fs = FakeStream() - daemon.handle_command( + self.daemon.handle_command( '' % scan_id, fs, ) @@ -654,45 +648,40 @@ def test_get_scan_pop(self): self.assertEqual(response.findtext('scan/results/result'), None) def test_get_scan_pop_max_res(self): - daemon = DummyWrapper( - [ - Result('host-detail', value='Some Host Detail'), - Result('host-detail', value='Some Host Detail1'), - Result('host-detail', value='Some Host Detail2'), - ] - ) fs = FakeStream() - daemon.handle_command( + self.daemon.handle_command( '' '', fs, ) response = fs.get_response() - scan_id = response.findtext('id') - time.sleep(1) + + self.daemon.add_scan_log(scan_id, host='a', name='a') + self.daemon.add_scan_log(scan_id, host='c', name='c') + self.daemon.add_scan_log(scan_id, host='b', name='b') fs = FakeStream() - daemon.handle_command( + self.daemon.handle_command( '' % scan_id, fs, ) + response = fs.get_response() self.assertEqual(len(response.findall('scan/results/result')), 1) fs = FakeStream() - daemon.handle_command( + self.daemon.handle_command( '' % scan_id, fs ) response = fs.get_response() - self.assertEqual(len(response.findall('scan/results/result')), 2) def test_billon_laughs(self): # pylint: disable=line-too-long - daemon = DummyWrapper([]) + lol = ( '' '' ) fs = FakeStream() - self.assertRaises(EntitiesForbidden, daemon.handle_command, lol, fs) + self.assertRaises( + EntitiesForbidden, self.daemon.handle_command, lol, fs + ) def test_target_with_credentials(self): - daemon = DummyWrapper([]) fs = FakeStream() - daemon.handle_command( + self.daemon.handle_command( '' '' '' @@ -747,13 +737,12 @@ def test_target_with_credentials(self): 'smb': {'type': 'up', 'password': 'mypass', 'username': 'smbuser'}, } scan_id = response.findtext('id') - response = daemon.get_scan_credentials(scan_id) + response = self.daemon.get_scan_credentials(scan_id) self.assertEqual(response, cred_dict) def test_scan_get_target(self): - daemon = DummyWrapper([]) fs = FakeStream() - daemon.handle_command( + self.daemon.handle_command( '' '' '' @@ -768,16 +757,15 @@ def test_scan_get_target(self): scan_id = response.findtext('id') fs = FakeStream() - daemon.handle_command('' % scan_id, fs) + self.daemon.handle_command('' % scan_id, fs) response = fs.get_response() scan_res = response.find('scan') self.assertEqual(scan_res.get('target'), 'localhosts,192.168.0.0/24') def test_scan_get_target_options(self): - daemon = DummyWrapper([]) fs = FakeStream() - daemon.handle_command( + self.daemon.handle_command( '' '' '' @@ -792,14 +780,13 @@ def test_scan_get_target_options(self): scan_id = response.findtext('id') time.sleep(1) - target_options = daemon.get_scan_target_options(scan_id) + target_options = self.daemon.get_scan_target_options(scan_id) self.assertEqual(target_options, {'alive_test': '0'}) def test_progress(self): - daemon = DummyWrapper([]) fs = FakeStream() - daemon.handle_command( + self.daemon.handle_command( '' '' '' @@ -812,18 +799,17 @@ def test_progress(self): response = fs.get_response() scan_id = response.findtext('id') - daemon.set_scan_host_progress(scan_id, 'localhost1', 75) - daemon.set_scan_host_progress(scan_id, 'localhost2', 25) + self.daemon.set_scan_host_progress(scan_id, 'localhost1', 75) + self.daemon.set_scan_host_progress(scan_id, 'localhost2', 25) self.assertEqual( - daemon.scan_collection.calculate_target_progress(scan_id), 50 + self.daemon.scan_collection.calculate_target_progress(scan_id), 50 ) def test_sort_host_finished(self): - daemon = DummyWrapper([]) fs = FakeStream() - daemon.handle_command( + self.daemon.handle_command( '' '' '' @@ -836,23 +822,22 @@ def test_sort_host_finished(self): response = fs.get_response() scan_id = response.findtext('id') - daemon.set_scan_host_progress(scan_id, 'localhost3', -1) - daemon.set_scan_host_progress(scan_id, 'localhost1', 75) - daemon.set_scan_host_progress(scan_id, 'localhost4', 100) - daemon.set_scan_host_progress(scan_id, 'localhost2', 25) + self.daemon.set_scan_host_progress(scan_id, 'localhost3', -1) + self.daemon.set_scan_host_progress(scan_id, 'localhost1', 75) + self.daemon.set_scan_host_progress(scan_id, 'localhost4', 100) + self.daemon.set_scan_host_progress(scan_id, 'localhost2', 25) - daemon.sort_host_finished(scan_id, ['localhost3', 'localhost4']) + self.daemon.sort_host_finished(scan_id, ['localhost3', 'localhost4']) - rounded_progress = daemon.scan_collection.calculate_target_progress( + rounded_progress = self.daemon.scan_collection.calculate_target_progress( # pylint: disable=line-too-long) scan_id ) self.assertEqual(rounded_progress, 66) def test_calculate_progress_without_current_hosts(self): - daemon = DummyWrapper([]) fs = FakeStream() - daemon.handle_command( + self.daemon.handle_command( '' '' '' @@ -865,26 +850,25 @@ def test_calculate_progress_without_current_hosts(self): response = fs.get_response() scan_id = response.findtext('id') - daemon.set_scan_host_progress(scan_id) - daemon.set_scan_host_progress(scan_id, 'localhost3', -1) - daemon.set_scan_host_progress(scan_id, 'localhost4', 100) + self.daemon.set_scan_host_progress(scan_id) + self.daemon.set_scan_host_progress(scan_id, 'localhost3', -1) + self.daemon.set_scan_host_progress(scan_id, 'localhost4', 100) - daemon.sort_host_finished(scan_id, ['localhost3', 'localhost4']) + self.daemon.sort_host_finished(scan_id, ['localhost3', 'localhost4']) - float_progress = daemon.scan_collection.calculate_target_progress( + float_progress = self.daemon.scan_collection.calculate_target_progress( scan_id ) self.assertEqual(int(float_progress), 33) - daemon.scan_collection.set_progress(scan_id, float_progress) - progress = daemon.get_scan_progress(scan_id) + self.daemon.scan_collection.set_progress(scan_id, float_progress) + progress = self.daemon.get_scan_progress(scan_id) self.assertEqual(progress, 33) def test_get_scan_without_scanid(self): - daemon = DummyWrapper([]) fs = FakeStream() - daemon.handle_command( + self.daemon.handle_command( '' '' '' @@ -898,16 +882,15 @@ def test_get_scan_without_scanid(self): fs = FakeStream() self.assertRaises( OspdCommandError, - daemon.handle_command, + self.daemon.handle_command, '', fs, ) def test_get_scan_progress_xml(self): - daemon = DummyWrapper([]) fs = FakeStream() - daemon.handle_command( + self.daemon.handle_command( '' '' '' @@ -920,15 +903,15 @@ def test_get_scan_progress_xml(self): response = fs.get_response() scan_id = response.findtext('id') - daemon.set_scan_host_progress(scan_id, 'localhost3', -1) - daemon.set_scan_host_progress(scan_id, 'localhost4', 100) - daemon.sort_host_finished(scan_id, ['localhost3', 'localhost4']) + self.daemon.set_scan_host_progress(scan_id, 'localhost3', -1) + self.daemon.set_scan_host_progress(scan_id, 'localhost4', 100) + self.daemon.sort_host_finished(scan_id, ['localhost3', 'localhost4']) - daemon.set_scan_host_progress(scan_id, 'localhost1', 75) - daemon.set_scan_host_progress(scan_id, 'localhost2', 25) + self.daemon.set_scan_host_progress(scan_id, 'localhost1', 75) + self.daemon.set_scan_host_progress(scan_id, 'localhost2', 25) fs = FakeStream() - daemon.handle_command( + self.daemon.handle_command( '' % scan_id, fs, ) response = fs.get_response() @@ -951,21 +934,17 @@ def test_get_scan_progress_xml(self): self.assertEqual(count_excluded, '0') def test_set_get_vts_version(self): - daemon = DummyWrapper([]) - daemon.set_vts_version('1234') + self.daemon.set_vts_version('1234') - version = daemon.get_vts_version() + version = self.daemon.get_vts_version() self.assertEqual('1234', version) def test_set_get_vts_version_error(self): - daemon = DummyWrapper([]) - self.assertRaises(TypeError, daemon.set_vts_version) + self.assertRaises(TypeError, self.daemon.set_vts_version) @patch("ospd.ospd.os") - @patch("ospd.command.command.create_process") + @patch("ospd.ospd.create_process") def test_scan_exists(self, mock_create_process, _mock_os): - daemon = DummyWrapper([]) - fp = FakeStartProcess() mock_create_process.side_effect = fp mock_process = fp.call_mock @@ -974,7 +953,7 @@ def test_scan_exists(self, mock_create_process, _mock_os): mock_process.pid = "main-scan-process" fs = FakeStream() - daemon.handle_command( + self.daemon.handle_command( '' '' '' @@ -991,10 +970,12 @@ def test_scan_exists(self, mock_create_process, _mock_os): status = response.get('status_text') self.assertEqual(status, 'OK') + self.daemon.check_pending_scans() + assert_called(mock_create_process) assert_called(mock_process.start) - daemon.handle_command('' % scan_id, fs) + self.daemon.handle_command('' % scan_id, fs) fs = FakeStream() cmd = ( @@ -1007,7 +988,7 @@ def test_scan_exists(self, mock_create_process, _mock_os): '' ) - daemon.handle_command( + self.daemon.handle_command( cmd, fs, ) response = fs.get_response() @@ -1015,9 +996,9 @@ def test_scan_exists(self, mock_create_process, _mock_os): self.assertEqual(status, 'Continue') def test_result_order(self): - daemon = DummyWrapper([]) + fs = FakeStream() - daemon.handle_command( + self.daemon.handle_command( '' '' '' @@ -1032,13 +1013,13 @@ def test_result_order(self): scan_id = response.findtext('id') - daemon.add_scan_log(scan_id, host='a', name='a') - daemon.add_scan_log(scan_id, host='c', name='c') - daemon.add_scan_log(scan_id, host='b', name='b') + self.daemon.add_scan_log(scan_id, host='a', name='a') + self.daemon.add_scan_log(scan_id, host='c', name='c') + self.daemon.add_scan_log(scan_id, host='b', name='b') hosts = ['a', 'c', 'b'] fs = FakeStream() - daemon.handle_command( + self.daemon.handle_command( '' % scan_id, fs ) response = fs.get_response() @@ -1050,10 +1031,9 @@ def test_result_order(self): self.assertEqual(hosts[idx], att_dict['name']) def test_batch_result(self): - daemon = DummyWrapper([]) reslist = ResultList() fs = FakeStream() - daemon.handle_command( + self.daemon.handle_command( '' '' '' @@ -1070,12 +1050,12 @@ def test_batch_result(self): reslist.add_scan_log_to_list(host='a', name='a') reslist.add_scan_log_to_list(host='c', name='c') reslist.add_scan_log_to_list(host='b', name='b') - daemon.scan_collection.add_result_list(scan_id, reslist) + self.daemon.scan_collection.add_result_list(scan_id, reslist) hosts = ['a', 'c', 'b'] fs = FakeStream() - daemon.handle_command( + self.daemon.handle_command( '' % scan_id, fs ) response = fs.get_response() diff --git a/tests/test_ssh_daemon.py b/tests/test_ssh_daemon.py index 35f4da43..363bd449 100644 --- a/tests/test_ssh_daemon.py +++ b/tests/test_ssh_daemon.py @@ -22,6 +22,7 @@ from ospd import ospd_ssh from ospd.ospd_ssh import OSPDaemonSimpleSSH +from .helper import FakeDataManager class FakeFile(object): @@ -73,6 +74,7 @@ def AutoAddPolicy(): # pylint: disable=invalid-name class DummyWrapper(OSPDaemonSimpleSSH): def __init__(self, niceness=10): super().__init__(niceness=niceness) + self.scan_collection.data_manager = FakeDataManager() def check(self): return True From 9622136cd4ec83820b273a3feb2f249b56064a81 Mon Sep 17 00:00:00 2001 From: Juan Jose Nicola Date: Tue, 19 May 2020 09:32:54 +0200 Subject: [PATCH 11/15] Don't cast the iterator to list --- ospd/ospd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ospd/ospd.py b/ospd/ospd.py index 9f2544df..045bc16c 100644 --- a/ospd/ospd.py +++ b/ospd/ospd.py @@ -1187,7 +1187,7 @@ def run(self) -> None: logger.info("Received Ctrl-C shutting-down ...") def check_pending_scans(self): - for scan_id in list(self.scan_collection.ids_iterator()): + for scan_id in self.scan_collection.ids_iterator(): if self.get_scan_status(scan_id) == ScanStatus.PENDING: scan_target = self.scan_collection.scans_table[scan_id].get( 'target' From d84e40c91a6fa37e92ab4bdf7a7d170298fbd36e Mon Sep 17 00:00:00 2001 From: Juan Jose Nicola Date: Tue, 19 May 2020 09:39:11 +0200 Subject: [PATCH 12/15] Rename check_pending_scans() to start_pending_scans() Adjust tests as well --- ospd/ospd.py | 4 ++-- tests/command/test_commands.py | 14 +++++++------- tests/test_scan_and_result.py | 6 +++--- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/ospd/ospd.py b/ospd/ospd.py index 045bc16c..349a1607 100644 --- a/ospd/ospd.py +++ b/ospd/ospd.py @@ -1181,12 +1181,12 @@ def run(self) -> None: time.sleep(SCHEDULER_CHECK_PERIOD) self.scheduler() self.clean_forgotten_scans() - self.check_pending_scans() + self.start_pending_scans() self.wait_for_children() except KeyboardInterrupt: logger.info("Received Ctrl-C shutting-down ...") - def check_pending_scans(self): + def start_pending_scans(self): for scan_id in self.scan_collection.ids_iterator(): if self.get_scan_status(scan_id) == ScanStatus.PENDING: scan_target = self.scan_collection.scans_table[scan_id].get( diff --git a/tests/command/test_commands.py b/tests/command/test_commands.py index c5b579a8..2c658f46 100644 --- a/tests/command/test_commands.py +++ b/tests/command/test_commands.py @@ -125,7 +125,7 @@ def test_scan_with_vts(self, mock_create_process): self.assertEqual(vts_collection, {'1.2.3.4': {}, 'vt_groups': []}) self.assertNotEqual(vts_collection, {'1.2.3.6': {}}) - daemon.check_pending_scans() + daemon.start_pending_scans() assert_called(mock_create_process) def test_scan_pop_vts(self): @@ -228,7 +228,7 @@ def test_scan_without_vts(self, mock_create_process): scan_id = response.findtext('id') self.assertEqual(daemon.get_scan_vts(scan_id), {}) - daemon.check_pending_scans() + daemon.start_pending_scans() assert_called(mock_create_process) def test_scan_with_vts_and_param_missing_vt_param_id(self): @@ -283,7 +283,7 @@ def test_scan_with_vts_and_param(self, mock_create_process): daemon.get_scan_vts(scan_id), {'1234': {'ABC': '200'}, 'vt_groups': []}, ) - daemon.check_pending_scans() + daemon.start_pending_scans() assert_called(mock_create_process) def test_scan_with_vts_and_param_missing_vt_group_filter(self): @@ -334,7 +334,7 @@ def test_scan_with_vts_and_param_with_vt_group_filter( self.assertEqual(daemon.get_scan_vts(scan_id), {'vt_groups': ['a']}) - daemon.check_pending_scans() + daemon.start_pending_scans() assert_called(mock_create_process) @patch("ospd.ospd.create_process") @@ -355,7 +355,7 @@ def test_scan_ignore_multi_target(self, mock_logger, mock_create_process): ) cmd.handle_xml(request) - daemon.check_pending_scans() + daemon.start_pending_scans() assert_called(mock_logger.warning) assert_called(mock_create_process) @@ -382,7 +382,7 @@ def test_scan_use_legacy_target_and_port( self.assertEqual(daemon.get_scan_host(scan_id), 'localhost') self.assertEqual(daemon.get_scan_ports(scan_id), '22') - daemon.check_pending_scans() + daemon.start_pending_scans() assert_called(mock_logger.warning) assert_called(mock_create_process) @@ -412,7 +412,7 @@ def test_stop_scan(self, mock_create_process, mock_os): daemon.handle_command(request, fs) response = fs.get_response() - daemon.check_pending_scans() + daemon.start_pending_scans() assert_called(mock_create_process) assert_called(mock_process.start) diff --git a/tests/test_scan_and_result.py b/tests/test_scan_and_result.py index b153ac28..12e4f61b 100644 --- a/tests/test_scan_and_result.py +++ b/tests/test_scan_and_result.py @@ -507,7 +507,7 @@ def test_clean_forgotten_scans(self): finished = False - self.daemon.check_pending_scans() + self.daemon.start_pending_scans() while not finished: fs = FakeStream() self.daemon.handle_command( @@ -564,7 +564,7 @@ def test_scan_with_error(self): response = fs.get_response() scan_id = response.findtext('id') finished = False - self.daemon.check_pending_scans() + self.daemon.start_pending_scans() self.daemon.add_scan_error( scan_id, host='a', value='something went wrong' ) @@ -970,7 +970,7 @@ def test_scan_exists(self, mock_create_process, _mock_os): status = response.get('status_text') self.assertEqual(status, 'OK') - self.daemon.check_pending_scans() + self.daemon.start_pending_scans() assert_called(mock_create_process) assert_called(mock_process.start) From 156dca35834440468aa567120e60a4af3129dfa5 Mon Sep 17 00:00:00 2001 From: Juan Jose Nicola Date: Tue, 19 May 2020 09:53:38 +0200 Subject: [PATCH 13/15] Add method to initialize the data manager. Also initialize it before starting the server --- ospd/ospd.py | 3 +-- ospd/scan.py | 3 +++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/ospd/ospd.py b/ospd/ospd.py index 349a1607..d3e2d67b 100644 --- a/ospd/ospd.py +++ b/ospd/ospd.py @@ -168,6 +168,7 @@ def init(self, server: BaseServer) -> None: Will be called after check. """ + self.scan_collection.init_data_manager() server.start(self.handle_client_stream) self.initialized = True @@ -1174,8 +1175,6 @@ def run(self) -> None: """ Starts the Daemon, handling commands until interrupted. """ - self.scan_collection.data_manager = multiprocessing.Manager() - try: while True: time.sleep(SCHEDULER_CHECK_PERIOD) diff --git a/ospd/scan.py b/ospd/scan.py index c7fb625a..c8796816 100644 --- a/ospd/scan.py +++ b/ospd/scan.py @@ -64,6 +64,9 @@ def __init__(self) -> None: ) # type: Optional[multiprocessing.managers.SyncManager] self.scans_table = dict() # type: Dict + def init_data_manager(self): + self.data_manager = multiprocessing.Manager() + def add_result( self, scan_id: str, From 01b1496553952a1affd9eee480d4fe658eae8061 Mon Sep 17 00:00:00 2001 From: Juan Jose Nicola Date: Tue, 19 May 2020 10:43:47 +0200 Subject: [PATCH 14/15] Fix dry_run_scan() --- ospd/ospd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ospd/ospd.py b/ospd/ospd.py index d3e2d67b..d41f9015 100644 --- a/ospd/ospd.py +++ b/ospd/ospd.py @@ -547,7 +547,7 @@ def dry_run_scan(self, scan_id: str, target: Dict) -> None: os.setsid() - host = resolve_hostname(target[0]) + host = resolve_hostname(target.get('hosts')) if host is None: logger.info("Couldn't resolve %s.", self.get_scan_host(scan_id)) From 59fdfc4a68f960e860f60e6247d3743650d8e758 Mon Sep 17 00:00:00 2001 From: Juan Jose Nicola Date: Tue, 19 May 2020 12:35:41 +0200 Subject: [PATCH 15/15] Rename scan collection init method --- ospd/ospd.py | 2 +- ospd/scan.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ospd/ospd.py b/ospd/ospd.py index d41f9015..6549ca46 100644 --- a/ospd/ospd.py +++ b/ospd/ospd.py @@ -168,7 +168,7 @@ def init(self, server: BaseServer) -> None: Will be called after check. """ - self.scan_collection.init_data_manager() + self.scan_collection.init() server.start(self.handle_client_stream) self.initialized = True diff --git a/ospd/scan.py b/ospd/scan.py index c8796816..3df8bb4e 100644 --- a/ospd/scan.py +++ b/ospd/scan.py @@ -64,7 +64,7 @@ def __init__(self) -> None: ) # type: Optional[multiprocessing.managers.SyncManager] self.scans_table = dict() # type: Dict - def init_data_manager(self): + def init(self): self.data_manager = multiprocessing.Manager() def add_result(