diff --git a/ospd/error.py b/ospd/errors.py similarity index 68% rename from ospd/error.py rename to ospd/errors.py index d900c2f1..77be60e3 100644 --- a/ospd/error.py +++ b/ospd/errors.py @@ -22,13 +22,34 @@ from ospd.xml import simple_response_str -class OSPDError(Exception): +class OspdError(Exception): + """ Base error class for all Ospd related errors """ + + +class RequiredArgument(OspdError): + """Raised if a required argument/parameter is missing + + Derives from :py:class:`OspdError` + """ + + def __init__(self, function, argument): + # pylint: disable=super-init-not-called + self.function = function + self.argument = argument + + def __str__(self): + return "{}: Argument {} is required".format( + self.function, self.argument + ) + + +class OspdCommandError(OspdError): """ This is an exception that will result in an error message to the client """ def __init__(self, message, command='osp', status=400): - super().__init__() + super().__init__(message) self.message = message self.command = command self.status = status diff --git a/ospd/ospd.py b/ospd/ospd.py index ecb0b4b0..96f7ece7 100644 --- a/ospd/ospd.py +++ b/ospd/ospd.py @@ -41,7 +41,7 @@ import defusedxml.ElementTree as secET from ospd import __version__ -from ospd.error import OSPDError +from ospd.errors import OspdCommandError from ospd.misc import ScanCollection, ResultType, ScanStatus, valid_uuid from ospd.network import resolve_hostname, target_str_to_list from ospd.vtfilter import VtsFilter @@ -204,7 +204,7 @@ def __init__( cafile, niceness=None, # pylint: disable=unused-argument customvtfilter=None, - **kwargs # pylint: disable=unused-argument + **kwargs # pylint: disable=unused-argument ): """ Initializes the daemon's internal data. """ # @todo: Actually it makes sense to move the certificate params to @@ -345,7 +345,7 @@ def set_vts_version(self, vts_version): vts_version (str): Identifies a unique vts version. """ if not vts_version: - raise OSPDError( + raise OspdCommandError( 'A vts_version parameter is required', 'set_vts_version' ) self.vts_version = vts_version @@ -400,16 +400,22 @@ def _preprocess_scan_params(self, xml_params): try: params[key] = int(params[key]) except ValueError: - raise OSPDError('Invalid %s value' % key, 'start_scan') + raise OspdCommandError( + 'Invalid %s value' % key, 'start_scan' + ) if param_type == 'boolean': if params[key] not in [0, 1]: - raise OSPDError('Invalid %s value' % key, 'start_scan') + raise OspdCommandError( + 'Invalid %s value' % key, 'start_scan' + ) elif param_type == 'selection': selection = self.get_scanner_param_default(key).split('|') if params[key] not in selection: - raise OSPDError('Invalid %s value' % key, 'start_scan') + raise OspdCommandError( + 'Invalid %s value' % key, 'start_scan' + ) if self.get_scanner_param_mandatory(key) and params[key] == '': - raise OSPDError( + raise OspdCommandError( 'Mandatory %s value is missing' % key, 'start_scan' ) return params @@ -451,7 +457,7 @@ def process_vts_params(self, scanner_vts): vt_selection[vt_id] = {} for vt_value in vt: if not vt_value.attrib.get('id'): - raise OSPDError( + raise OspdCommandError( 'Invalid VT preference. No attribute id', 'start_scan', ) @@ -461,7 +467,7 @@ def process_vts_params(self, scanner_vts): if vt.tag == 'vt_group': vts_filter = vt.attrib.get('filter', None) if vts_filter is None: - raise OSPDError( + raise OspdCommandError( 'Invalid VT group. No filter given.', 'start_scan' ) filters.append(vts_filter) @@ -571,7 +577,7 @@ def process_targets_element(cls, scanner_target): if hosts: target_list.append([hosts, ports, credentials, exclude_hosts]) else: - raise OSPDError('No target to scan', 'start_scan') + raise OspdCommandError('No target to scan', 'start_scan') return target_list @@ -588,7 +594,7 @@ def handle_start_scan_command(self, scan_et): if target_str is None or ports_str is None: target_list = scan_et.find('targets') if target_list is None or len(target_list) == 0: - raise OSPDError('No targets or ports', 'start_scan') + raise OspdCommandError('No targets or ports', 'start_scan') else: scan_targets = self.process_targets_element(target_list) else: @@ -598,21 +604,21 @@ def handle_start_scan_command(self, scan_et): scan_id = scan_et.attrib.get('scan_id') if scan_id is not None and scan_id != '' and not valid_uuid(scan_id): - raise OSPDError('Invalid scan_id UUID', 'start_scan') + raise OspdCommandError('Invalid scan_id UUID', 'start_scan') try: parallel = int(scan_et.attrib.get('parallel', '1')) if parallel < 1 or parallel > 20: parallel = 1 except ValueError: - raise OSPDError( + raise OspdCommandError( 'Invalid value for parallel scans. ' 'It must be a number', 'start_scan', ) scanner_params = scan_et.find('scanner_params') if scanner_params is None: - raise OSPDError('No scanner_params element', 'start_scan') + raise OspdCommandError('No scanner_params element', 'start_scan') params = self._preprocess_scan_params(scanner_params) @@ -621,7 +627,7 @@ def handle_start_scan_command(self, scan_et): scanner_vts = scan_et.find('vt_selection') if scanner_vts is not None: if len(scanner_vts) == 0: - raise OSPDError('VTs list is empty', 'start_scan') + raise OspdCommandError('VTs list is empty', 'start_scan') else: vt_selection = self.process_vts_params(scanner_vts) @@ -653,7 +659,7 @@ def handle_stop_scan_command(self, scan_et): scan_id = scan_et.attrib.get('scan_id') if scan_id is None or scan_id == '': - raise OSPDError('No scan_id attribute', 'stop_scan') + raise OspdCommandError('No scan_id attribute', 'stop_scan') self.stop_scan(scan_id) return simple_response_str('stop_scan', 200, 'OK') @@ -661,9 +667,13 @@ def handle_stop_scan_command(self, scan_et): def stop_scan(self, scan_id): scan_process = self.scan_processes.get(scan_id) if not scan_process: - raise OSPDError('Scan not found {0}.'.format(scan_id), 'stop_scan') + raise OspdCommandError( + 'Scan not found {0}.'.format(scan_id), 'stop_scan' + ) if not scan_process.is_alive(): - raise OSPDError('Scan already stopped or finished.', 'stop_scan') + raise OspdCommandError( + 'Scan already stopped or finished.', 'stop_scan' + ) self.set_scan_status(scan_id, ScanStatus.STOPPED) logger.info('%s: Scan stopping %s.', scan_id, scan_process.ident) @@ -819,12 +829,12 @@ def handle_client_stream(self, stream, is_unix=False): return try: response = self.handle_command(data) - except OSPDError as exception: + except OspdCommandError as exception: response = exception.as_xml() logger.debug('Command error: %s', exception.message) except Exception: # pylint: disable=broad-except logger.exception('While handling client command:') - exception = OSPDError('Fatal error', 'error') + exception = OspdCommandError('Fatal error', 'error') response = exception.as_xml() if is_unix: send_method = stream.send @@ -917,7 +927,7 @@ def start_scan(self, scan_id, targets, parallel=1): logger.info("%s: Scan started.", scan_id) target_list = targets if target_list is None or not target_list: - raise OSPDError('Erroneous targets list', 'start_scan') + raise OspdCommandError('Erroneous targets list', 'start_scan') self.process_exclude_hosts(scan_id, target_list) @@ -1084,7 +1094,7 @@ def handle_help_command(self, scan_et): elif help_format == "xml": text = self.get_xml_str(self.commands) return simple_response_str('help', 200, 'OK', text) - raise OSPDError('Bogus help format', 'help') + raise OspdCommandError('Bogus help format', 'help') def get_help_text(self): """ Returns the help output in plain text format.""" @@ -1143,7 +1153,7 @@ def handle_delete_scan_command(self, scan_et): self.check_scan_process(scan_id) if self.delete_scan(scan_id): return simple_response_str('delete_scan', 200, 'OK') - raise OSPDError('Scan in progress', 'delete_scan') + raise OspdCommandError('Scan in progress', 'delete_scan') def delete_scan(self, scan_id): """ Deletes scan_id scan from collection. @@ -1610,10 +1620,10 @@ def handle_command(self, command): tree = secET.fromstring(command) except secET.ParseError: logger.debug("Erroneous client input: %s", command) - raise OSPDError('Invalid data') + raise OspdCommandError('Invalid data') if not self.command_exists(tree.tag) and tree.tag != "authenticate": - raise OSPDError('Bogus command name') + raise OspdCommandError('Bogus command name') if tree.tag == "get_version": return self.handle_get_version_command() diff --git a/ospd/vtfilter.py b/ospd/vtfilter.py index ad2a905e..4a48e851 100644 --- a/ospd/vtfilter.py +++ b/ospd/vtfilter.py @@ -21,7 +21,7 @@ import re import operator -from ospd.error import OSPDError +from ospd.errors import OspdCommandError class VtsFilter(object): @@ -57,14 +57,14 @@ def parse_filters(self, vt_filter): for single_filter in filter_list: filter_aux = re.split(r'(\W)', single_filter, 1) if len(filter_aux) < 3: - raise OSPDError( + raise OspdCommandError( "Invalid number of argument in the filter", "get_vts" ) _element, _oper, _val = filter_aux if _element not in self.allowed_filter: - raise OSPDError("Invalid filter element", "get_vts") + raise OspdCommandError("Invalid filter element", "get_vts") if _oper not in self.filter_operator: - raise OSPDError("Invalid filter operator", "get_vts") + raise OspdCommandError("Invalid filter operator", "get_vts") filters.append(filter_aux) @@ -109,7 +109,7 @@ def get_filtered_vts_list(self, vts, vt_filter): Dictionary with filtered vulnerability tests. """ if not vt_filter: - raise OSPDError('vt_filter: A valid filter is required.') + raise OspdCommandError('vt_filter: A valid filter is required.') filters = self.parse_filters(vt_filter) if not filters: diff --git a/tests/test_error.py b/tests/test_errors.py similarity index 52% rename from tests/test_error.py rename to tests/test_errors.py index 4dd2e207..70946911 100644 --- a/tests/test_error.py +++ b/tests/test_errors.py @@ -16,32 +16,59 @@ # along with this program; if not, write to the Free Software # Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. -""" Test module for OSPDError class +""" Test module for OspdCommandError class """ import unittest -from ospd.ospd import OSPDError +from ospd.errors import OspdError, OspdCommandError, RequiredArgument -class OSPDErrorTestCase(unittest.TestCase): +class OspdCommandErrorTestCase(unittest.TestCase): + def test_is_ospd_error(self): + e = OspdCommandError('message') + self.assertIsInstance(e, OspdError) + def test_default_params(self): - e = OSPDError('message') + e = OspdCommandError('message') self.assertEqual('message', e.message) self.assertEqual(400, e.status) self.assertEqual('osp', e.command) def test_constructor(self): - e = OSPDError('message', 'command', '304') + e = OspdCommandError('message', 'command', '304') self.assertEqual('message', e.message) self.assertEqual('command', e.command) self.assertEqual('304', e.status) + def test_string_conversion(self): + e = OspdCommandError('message foo bar', 'command', '304') + + self.assertEqual('message foo bar', str(e)) + def test_as_xml(self): - e = OSPDError('message') + e = OspdCommandError('message') self.assertEqual( b'', e.as_xml() ) + + +class RequiredArgumentTestCase(unittest.TestCase): + def test_raise_exception(self): + with self.assertRaises(RequiredArgument) as cm: + raise RequiredArgument('foo', 'bar') + + ex = cm.exception + self.assertEqual(ex.function, 'foo') + self.assertEqual(ex.argument, 'bar') + + def test_string_conversion(self): + ex = RequiredArgument('foo', 'bar') + self.assertEqual(str(ex), 'foo: Argument bar is required') + + def test_is_ospd_error(self): + e = RequiredArgument('foo', 'bar') + self.assertIsInstance(e, OspdError) diff --git a/tests/test_scan_and_result.py b/tests/test_scan_and_result.py index 926eddb8..786d1c4e 100644 --- a/tests/test_scan_and_result.py +++ b/tests/test_scan_and_result.py @@ -30,7 +30,7 @@ from defusedxml.common import EntitiesForbidden from ospd.ospd import OSPDaemon -from ospd.error import OSPDError +from ospd.errors import OspdCommandError class Result(object): @@ -140,13 +140,17 @@ def get_solution_vt_as_xml_str(vt_id, solution, solution_type=None): return response @staticmethod - def get_creation_time_vt_as_xml_str(vt_id, creation_time): # pylint: disable=arguments-differ + def get_creation_time_vt_as_xml_str( + vt_id, creation_time + ): # pylint: disable=arguments-differ response = '%s' % creation_time return response @staticmethod - def get_modification_time_vt_as_xml_str(vt_id, modification_time): # pylint: disable=arguments-differ + def get_modification_time_vt_as_xml_str( + vt_id, modification_time + ): # pylint: disable=arguments-differ response = ( '%s' % modification_time ) @@ -199,7 +203,6 @@ def exec_scan(self, scan_id, target): class ScanTestCase(unittest.TestCase): - def test_get_default_scanner_params(self): daemon = DummyWrapper([]) response = secET.fromstring( @@ -691,7 +694,7 @@ def test_get_scan_pop(self): self.assertIn( response.findtext('scan/results/result'), - ['Scan process failure.', 'Scan stopped.'] + ['Scan process failure.', 'Scan stopped.'], ) response = secET.fromstring( @@ -716,10 +719,14 @@ def test_stop_scan(self): time.sleep(3) cmd = secET.fromstring('' % scan_id) - self.assertRaises(OSPDError, daemon.handle_stop_scan_command, cmd) + self.assertRaises( + OspdCommandError, daemon.handle_stop_scan_command, cmd + ) cmd = secET.fromstring('') - self.assertRaises(OSPDError, daemon.handle_stop_scan_command, cmd) + self.assertRaises( + OspdCommandError, daemon.handle_stop_scan_command, cmd + ) def test_scan_with_vts(self): daemon = DummyWrapper([]) @@ -730,7 +737,7 @@ def test_scan_with_vts(self): '' ) - with self.assertRaises(OSPDError): + with self.assertRaises(OspdCommandError): daemon.handle_start_scan_command(cmd) # With one vt, without params @@ -776,7 +783,7 @@ def test_scan_with_vts_and_param(self): '' ) - with self.assertRaises(OSPDError): + with self.assertRaises(OspdCommandError): daemon.handle_start_scan_command(cmd) # No error @@ -804,7 +811,9 @@ def test_scan_with_vts_and_param(self): '' '' ) - self.assertRaises(OSPDError, daemon.handle_start_scan_command, cmd) + self.assertRaises( + OspdCommandError, daemon.handle_start_scan_command, cmd + ) # No error response = secET.fromstring( @@ -956,7 +965,9 @@ def test_scan_multi_target_parallel_with_error(self): '' ) time.sleep(1) - self.assertRaises(OSPDError, daemon.handle_start_scan_command, cmd) + self.assertRaises( + OspdCommandError, daemon.handle_start_scan_command, cmd + ) def test_scan_multi_target_parallel_100(self): daemon = DummyWrapper([]) @@ -1038,7 +1049,7 @@ def test_resume_task(self): time.sleep(3) cmd = secET.fromstring('' % scan_id) - with self.assertRaises(OSPDError): + with self.assertRaises(OspdCommandError): daemon.handle_stop_scan_command(cmd) response = secET.fromstring( @@ -1053,8 +1064,7 @@ def test_resume_task(self): # Resume the task cmd = ( '' - '' - % scan_id + '' % scan_id ) response = secET.fromstring(daemon.handle_command(cmd))