diff --git a/ospd/error.py b/ospd/errors.py
similarity index 68%
rename from ospd/error.py
rename to ospd/errors.py
index d900c2f1..77be60e3 100644
--- a/ospd/error.py
+++ b/ospd/errors.py
@@ -22,13 +22,34 @@
from ospd.xml import simple_response_str
-class OSPDError(Exception):
+class OspdError(Exception):
+ """ Base error class for all Ospd related errors """
+
+
+class RequiredArgument(OspdError):
+ """Raised if a required argument/parameter is missing
+
+ Derives from :py:class:`OspdError`
+ """
+
+ def __init__(self, function, argument):
+ # pylint: disable=super-init-not-called
+ self.function = function
+ self.argument = argument
+
+ def __str__(self):
+ return "{}: Argument {} is required".format(
+ self.function, self.argument
+ )
+
+
+class OspdCommandError(OspdError):
""" This is an exception that will result in an error message to the
client """
def __init__(self, message, command='osp', status=400):
- super().__init__()
+ super().__init__(message)
self.message = message
self.command = command
self.status = status
diff --git a/ospd/ospd.py b/ospd/ospd.py
index ecb0b4b0..96f7ece7 100644
--- a/ospd/ospd.py
+++ b/ospd/ospd.py
@@ -41,7 +41,7 @@
import defusedxml.ElementTree as secET
from ospd import __version__
-from ospd.error import OSPDError
+from ospd.errors import OspdCommandError
from ospd.misc import ScanCollection, ResultType, ScanStatus, valid_uuid
from ospd.network import resolve_hostname, target_str_to_list
from ospd.vtfilter import VtsFilter
@@ -204,7 +204,7 @@ def __init__(
cafile,
niceness=None, # pylint: disable=unused-argument
customvtfilter=None,
- **kwargs # pylint: disable=unused-argument
+ **kwargs # pylint: disable=unused-argument
):
""" Initializes the daemon's internal data. """
# @todo: Actually it makes sense to move the certificate params to
@@ -345,7 +345,7 @@ def set_vts_version(self, vts_version):
vts_version (str): Identifies a unique vts version.
"""
if not vts_version:
- raise OSPDError(
+ raise OspdCommandError(
'A vts_version parameter is required', 'set_vts_version'
)
self.vts_version = vts_version
@@ -400,16 +400,22 @@ def _preprocess_scan_params(self, xml_params):
try:
params[key] = int(params[key])
except ValueError:
- raise OSPDError('Invalid %s value' % key, 'start_scan')
+ raise OspdCommandError(
+ 'Invalid %s value' % key, 'start_scan'
+ )
if param_type == 'boolean':
if params[key] not in [0, 1]:
- raise OSPDError('Invalid %s value' % key, 'start_scan')
+ raise OspdCommandError(
+ 'Invalid %s value' % key, 'start_scan'
+ )
elif param_type == 'selection':
selection = self.get_scanner_param_default(key).split('|')
if params[key] not in selection:
- raise OSPDError('Invalid %s value' % key, 'start_scan')
+ raise OspdCommandError(
+ 'Invalid %s value' % key, 'start_scan'
+ )
if self.get_scanner_param_mandatory(key) and params[key] == '':
- raise OSPDError(
+ raise OspdCommandError(
'Mandatory %s value is missing' % key, 'start_scan'
)
return params
@@ -451,7 +457,7 @@ def process_vts_params(self, scanner_vts):
vt_selection[vt_id] = {}
for vt_value in vt:
if not vt_value.attrib.get('id'):
- raise OSPDError(
+ raise OspdCommandError(
'Invalid VT preference. No attribute id',
'start_scan',
)
@@ -461,7 +467,7 @@ def process_vts_params(self, scanner_vts):
if vt.tag == 'vt_group':
vts_filter = vt.attrib.get('filter', None)
if vts_filter is None:
- raise OSPDError(
+ raise OspdCommandError(
'Invalid VT group. No filter given.', 'start_scan'
)
filters.append(vts_filter)
@@ -571,7 +577,7 @@ def process_targets_element(cls, scanner_target):
if hosts:
target_list.append([hosts, ports, credentials, exclude_hosts])
else:
- raise OSPDError('No target to scan', 'start_scan')
+ raise OspdCommandError('No target to scan', 'start_scan')
return target_list
@@ -588,7 +594,7 @@ def handle_start_scan_command(self, scan_et):
if target_str is None or ports_str is None:
target_list = scan_et.find('targets')
if target_list is None or len(target_list) == 0:
- raise OSPDError('No targets or ports', 'start_scan')
+ raise OspdCommandError('No targets or ports', 'start_scan')
else:
scan_targets = self.process_targets_element(target_list)
else:
@@ -598,21 +604,21 @@ def handle_start_scan_command(self, scan_et):
scan_id = scan_et.attrib.get('scan_id')
if scan_id is not None and scan_id != '' and not valid_uuid(scan_id):
- raise OSPDError('Invalid scan_id UUID', 'start_scan')
+ raise OspdCommandError('Invalid scan_id UUID', 'start_scan')
try:
parallel = int(scan_et.attrib.get('parallel', '1'))
if parallel < 1 or parallel > 20:
parallel = 1
except ValueError:
- raise OSPDError(
+ raise OspdCommandError(
'Invalid value for parallel scans. ' 'It must be a number',
'start_scan',
)
scanner_params = scan_et.find('scanner_params')
if scanner_params is None:
- raise OSPDError('No scanner_params element', 'start_scan')
+ raise OspdCommandError('No scanner_params element', 'start_scan')
params = self._preprocess_scan_params(scanner_params)
@@ -621,7 +627,7 @@ def handle_start_scan_command(self, scan_et):
scanner_vts = scan_et.find('vt_selection')
if scanner_vts is not None:
if len(scanner_vts) == 0:
- raise OSPDError('VTs list is empty', 'start_scan')
+ raise OspdCommandError('VTs list is empty', 'start_scan')
else:
vt_selection = self.process_vts_params(scanner_vts)
@@ -653,7 +659,7 @@ def handle_stop_scan_command(self, scan_et):
scan_id = scan_et.attrib.get('scan_id')
if scan_id is None or scan_id == '':
- raise OSPDError('No scan_id attribute', 'stop_scan')
+ raise OspdCommandError('No scan_id attribute', 'stop_scan')
self.stop_scan(scan_id)
return simple_response_str('stop_scan', 200, 'OK')
@@ -661,9 +667,13 @@ def handle_stop_scan_command(self, scan_et):
def stop_scan(self, scan_id):
scan_process = self.scan_processes.get(scan_id)
if not scan_process:
- raise OSPDError('Scan not found {0}.'.format(scan_id), 'stop_scan')
+ raise OspdCommandError(
+ 'Scan not found {0}.'.format(scan_id), 'stop_scan'
+ )
if not scan_process.is_alive():
- raise OSPDError('Scan already stopped or finished.', 'stop_scan')
+ raise OspdCommandError(
+ 'Scan already stopped or finished.', 'stop_scan'
+ )
self.set_scan_status(scan_id, ScanStatus.STOPPED)
logger.info('%s: Scan stopping %s.', scan_id, scan_process.ident)
@@ -819,12 +829,12 @@ def handle_client_stream(self, stream, is_unix=False):
return
try:
response = self.handle_command(data)
- except OSPDError as exception:
+ except OspdCommandError as exception:
response = exception.as_xml()
logger.debug('Command error: %s', exception.message)
except Exception: # pylint: disable=broad-except
logger.exception('While handling client command:')
- exception = OSPDError('Fatal error', 'error')
+ exception = OspdCommandError('Fatal error', 'error')
response = exception.as_xml()
if is_unix:
send_method = stream.send
@@ -917,7 +927,7 @@ def start_scan(self, scan_id, targets, parallel=1):
logger.info("%s: Scan started.", scan_id)
target_list = targets
if target_list is None or not target_list:
- raise OSPDError('Erroneous targets list', 'start_scan')
+ raise OspdCommandError('Erroneous targets list', 'start_scan')
self.process_exclude_hosts(scan_id, target_list)
@@ -1084,7 +1094,7 @@ def handle_help_command(self, scan_et):
elif help_format == "xml":
text = self.get_xml_str(self.commands)
return simple_response_str('help', 200, 'OK', text)
- raise OSPDError('Bogus help format', 'help')
+ raise OspdCommandError('Bogus help format', 'help')
def get_help_text(self):
""" Returns the help output in plain text format."""
@@ -1143,7 +1153,7 @@ def handle_delete_scan_command(self, scan_et):
self.check_scan_process(scan_id)
if self.delete_scan(scan_id):
return simple_response_str('delete_scan', 200, 'OK')
- raise OSPDError('Scan in progress', 'delete_scan')
+ raise OspdCommandError('Scan in progress', 'delete_scan')
def delete_scan(self, scan_id):
""" Deletes scan_id scan from collection.
@@ -1610,10 +1620,10 @@ def handle_command(self, command):
tree = secET.fromstring(command)
except secET.ParseError:
logger.debug("Erroneous client input: %s", command)
- raise OSPDError('Invalid data')
+ raise OspdCommandError('Invalid data')
if not self.command_exists(tree.tag) and tree.tag != "authenticate":
- raise OSPDError('Bogus command name')
+ raise OspdCommandError('Bogus command name')
if tree.tag == "get_version":
return self.handle_get_version_command()
diff --git a/ospd/vtfilter.py b/ospd/vtfilter.py
index ad2a905e..4a48e851 100644
--- a/ospd/vtfilter.py
+++ b/ospd/vtfilter.py
@@ -21,7 +21,7 @@
import re
import operator
-from ospd.error import OSPDError
+from ospd.errors import OspdCommandError
class VtsFilter(object):
@@ -57,14 +57,14 @@ def parse_filters(self, vt_filter):
for single_filter in filter_list:
filter_aux = re.split(r'(\W)', single_filter, 1)
if len(filter_aux) < 3:
- raise OSPDError(
+ raise OspdCommandError(
"Invalid number of argument in the filter", "get_vts"
)
_element, _oper, _val = filter_aux
if _element not in self.allowed_filter:
- raise OSPDError("Invalid filter element", "get_vts")
+ raise OspdCommandError("Invalid filter element", "get_vts")
if _oper not in self.filter_operator:
- raise OSPDError("Invalid filter operator", "get_vts")
+ raise OspdCommandError("Invalid filter operator", "get_vts")
filters.append(filter_aux)
@@ -109,7 +109,7 @@ def get_filtered_vts_list(self, vts, vt_filter):
Dictionary with filtered vulnerability tests.
"""
if not vt_filter:
- raise OSPDError('vt_filter: A valid filter is required.')
+ raise OspdCommandError('vt_filter: A valid filter is required.')
filters = self.parse_filters(vt_filter)
if not filters:
diff --git a/tests/test_error.py b/tests/test_errors.py
similarity index 52%
rename from tests/test_error.py
rename to tests/test_errors.py
index 4dd2e207..70946911 100644
--- a/tests/test_error.py
+++ b/tests/test_errors.py
@@ -16,32 +16,59 @@
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
-""" Test module for OSPDError class
+""" Test module for OspdCommandError class
"""
import unittest
-from ospd.ospd import OSPDError
+from ospd.errors import OspdError, OspdCommandError, RequiredArgument
-class OSPDErrorTestCase(unittest.TestCase):
+class OspdCommandErrorTestCase(unittest.TestCase):
+ def test_is_ospd_error(self):
+ e = OspdCommandError('message')
+ self.assertIsInstance(e, OspdError)
+
def test_default_params(self):
- e = OSPDError('message')
+ e = OspdCommandError('message')
self.assertEqual('message', e.message)
self.assertEqual(400, e.status)
self.assertEqual('osp', e.command)
def test_constructor(self):
- e = OSPDError('message', 'command', '304')
+ e = OspdCommandError('message', 'command', '304')
self.assertEqual('message', e.message)
self.assertEqual('command', e.command)
self.assertEqual('304', e.status)
+ def test_string_conversion(self):
+ e = OspdCommandError('message foo bar', 'command', '304')
+
+ self.assertEqual('message foo bar', str(e))
+
def test_as_xml(self):
- e = OSPDError('message')
+ e = OspdCommandError('message')
self.assertEqual(
b'', e.as_xml()
)
+
+
+class RequiredArgumentTestCase(unittest.TestCase):
+ def test_raise_exception(self):
+ with self.assertRaises(RequiredArgument) as cm:
+ raise RequiredArgument('foo', 'bar')
+
+ ex = cm.exception
+ self.assertEqual(ex.function, 'foo')
+ self.assertEqual(ex.argument, 'bar')
+
+ def test_string_conversion(self):
+ ex = RequiredArgument('foo', 'bar')
+ self.assertEqual(str(ex), 'foo: Argument bar is required')
+
+ def test_is_ospd_error(self):
+ e = RequiredArgument('foo', 'bar')
+ self.assertIsInstance(e, OspdError)
diff --git a/tests/test_scan_and_result.py b/tests/test_scan_and_result.py
index 926eddb8..786d1c4e 100644
--- a/tests/test_scan_and_result.py
+++ b/tests/test_scan_and_result.py
@@ -30,7 +30,7 @@
from defusedxml.common import EntitiesForbidden
from ospd.ospd import OSPDaemon
-from ospd.error import OSPDError
+from ospd.errors import OspdCommandError
class Result(object):
@@ -140,13 +140,17 @@ def get_solution_vt_as_xml_str(vt_id, solution, solution_type=None):
return response
@staticmethod
- def get_creation_time_vt_as_xml_str(vt_id, creation_time): # pylint: disable=arguments-differ
+ def get_creation_time_vt_as_xml_str(
+ vt_id, creation_time
+ ): # pylint: disable=arguments-differ
response = '%s' % creation_time
return response
@staticmethod
- def get_modification_time_vt_as_xml_str(vt_id, modification_time): # pylint: disable=arguments-differ
+ def get_modification_time_vt_as_xml_str(
+ vt_id, modification_time
+ ): # pylint: disable=arguments-differ
response = (
'%s' % modification_time
)
@@ -199,7 +203,6 @@ def exec_scan(self, scan_id, target):
class ScanTestCase(unittest.TestCase):
-
def test_get_default_scanner_params(self):
daemon = DummyWrapper([])
response = secET.fromstring(
@@ -691,7 +694,7 @@ def test_get_scan_pop(self):
self.assertIn(
response.findtext('scan/results/result'),
- ['Scan process failure.', 'Scan stopped.']
+ ['Scan process failure.', 'Scan stopped.'],
)
response = secET.fromstring(
@@ -716,10 +719,14 @@ def test_stop_scan(self):
time.sleep(3)
cmd = secET.fromstring('' % scan_id)
- self.assertRaises(OSPDError, daemon.handle_stop_scan_command, cmd)
+ self.assertRaises(
+ OspdCommandError, daemon.handle_stop_scan_command, cmd
+ )
cmd = secET.fromstring('')
- self.assertRaises(OSPDError, daemon.handle_stop_scan_command, cmd)
+ self.assertRaises(
+ OspdCommandError, daemon.handle_stop_scan_command, cmd
+ )
def test_scan_with_vts(self):
daemon = DummyWrapper([])
@@ -730,7 +737,7 @@ def test_scan_with_vts(self):
''
)
- with self.assertRaises(OSPDError):
+ with self.assertRaises(OspdCommandError):
daemon.handle_start_scan_command(cmd)
# With one vt, without params
@@ -776,7 +783,7 @@ def test_scan_with_vts_and_param(self):
''
)
- with self.assertRaises(OSPDError):
+ with self.assertRaises(OspdCommandError):
daemon.handle_start_scan_command(cmd)
# No error
@@ -804,7 +811,9 @@ def test_scan_with_vts_and_param(self):
''
''
)
- self.assertRaises(OSPDError, daemon.handle_start_scan_command, cmd)
+ self.assertRaises(
+ OspdCommandError, daemon.handle_start_scan_command, cmd
+ )
# No error
response = secET.fromstring(
@@ -956,7 +965,9 @@ def test_scan_multi_target_parallel_with_error(self):
''
)
time.sleep(1)
- self.assertRaises(OSPDError, daemon.handle_start_scan_command, cmd)
+ self.assertRaises(
+ OspdCommandError, daemon.handle_start_scan_command, cmd
+ )
def test_scan_multi_target_parallel_100(self):
daemon = DummyWrapper([])
@@ -1038,7 +1049,7 @@ def test_resume_task(self):
time.sleep(3)
cmd = secET.fromstring('' % scan_id)
- with self.assertRaises(OSPDError):
+ with self.assertRaises(OspdCommandError):
daemon.handle_stop_scan_command(cmd)
response = secET.fromstring(
@@ -1053,8 +1064,7 @@ def test_resume_task(self):
# Resume the task
cmd = (
''
- ''
- % scan_id
+ '' % scan_id
)
response = secET.fromstring(daemon.handle_command(cmd))