Skip to content
This repository has been archived by the owner on Nov 29, 2021. It is now read-only.

Refactor error module #112

Merged
merged 6 commits into from
Jun 24, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions ospd/error.py → ospd/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,34 @@
from ospd.xml import simple_response_str


class OSPDError(Exception):
class OspdError(Exception):
""" Base error class for all Ospd related errors """


class RequiredArgument(OspdError):
"""Raised if a required argument/parameter is missing

Derives from :py:class:`OspdError`
"""

def __init__(self, function, argument):
# pylint: disable=super-init-not-called
self.function = function
self.argument = argument

def __str__(self):
return "{}: Argument {} is required".format(
self.function, self.argument
)


class OspdCommandError(OspdError):

""" This is an exception that will result in an error message to the
client """

def __init__(self, message, command='osp', status=400):
super().__init__()
super().__init__(message)
self.message = message
self.command = command
self.status = status
Expand Down
60 changes: 35 additions & 25 deletions ospd/ospd.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
import defusedxml.ElementTree as secET

from ospd import __version__
from ospd.error import OSPDError
from ospd.errors import OspdCommandError
from ospd.misc import ScanCollection, ResultType, ScanStatus, valid_uuid
from ospd.network import resolve_hostname, target_str_to_list
from ospd.vtfilter import VtsFilter
Expand Down Expand Up @@ -204,7 +204,7 @@ def __init__(
cafile,
niceness=None, # pylint: disable=unused-argument
customvtfilter=None,
**kwargs # pylint: disable=unused-argument
**kwargs # pylint: disable=unused-argument
):
""" Initializes the daemon's internal data. """
# @todo: Actually it makes sense to move the certificate params to
Expand Down Expand Up @@ -345,7 +345,7 @@ def set_vts_version(self, vts_version):
vts_version (str): Identifies a unique vts version.
"""
if not vts_version:
raise OSPDError(
raise OspdCommandError(
'A vts_version parameter is required', 'set_vts_version'
)
self.vts_version = vts_version
Expand Down Expand Up @@ -400,16 +400,22 @@ def _preprocess_scan_params(self, xml_params):
try:
params[key] = int(params[key])
except ValueError:
raise OSPDError('Invalid %s value' % key, 'start_scan')
raise OspdCommandError(
'Invalid %s value' % key, 'start_scan'
)
if param_type == 'boolean':
if params[key] not in [0, 1]:
raise OSPDError('Invalid %s value' % key, 'start_scan')
raise OspdCommandError(
'Invalid %s value' % key, 'start_scan'
)
elif param_type == 'selection':
selection = self.get_scanner_param_default(key).split('|')
if params[key] not in selection:
raise OSPDError('Invalid %s value' % key, 'start_scan')
raise OspdCommandError(
'Invalid %s value' % key, 'start_scan'
)
if self.get_scanner_param_mandatory(key) and params[key] == '':
raise OSPDError(
raise OspdCommandError(
'Mandatory %s value is missing' % key, 'start_scan'
)
return params
Expand Down Expand Up @@ -451,7 +457,7 @@ def process_vts_params(self, scanner_vts):
vt_selection[vt_id] = {}
for vt_value in vt:
if not vt_value.attrib.get('id'):
raise OSPDError(
raise OspdCommandError(
'Invalid VT preference. No attribute id',
'start_scan',
)
Expand All @@ -461,7 +467,7 @@ def process_vts_params(self, scanner_vts):
if vt.tag == 'vt_group':
vts_filter = vt.attrib.get('filter', None)
if vts_filter is None:
raise OSPDError(
raise OspdCommandError(
'Invalid VT group. No filter given.', 'start_scan'
)
filters.append(vts_filter)
Expand Down Expand Up @@ -571,7 +577,7 @@ def process_targets_element(cls, scanner_target):
if hosts:
target_list.append([hosts, ports, credentials, exclude_hosts])
else:
raise OSPDError('No target to scan', 'start_scan')
raise OspdCommandError('No target to scan', 'start_scan')

return target_list

Expand All @@ -588,7 +594,7 @@ def handle_start_scan_command(self, scan_et):
if target_str is None or ports_str is None:
target_list = scan_et.find('targets')
if target_list is None or len(target_list) == 0:
raise OSPDError('No targets or ports', 'start_scan')
raise OspdCommandError('No targets or ports', 'start_scan')
else:
scan_targets = self.process_targets_element(target_list)
else:
Expand All @@ -598,21 +604,21 @@ def handle_start_scan_command(self, scan_et):

scan_id = scan_et.attrib.get('scan_id')
if scan_id is not None and scan_id != '' and not valid_uuid(scan_id):
raise OSPDError('Invalid scan_id UUID', 'start_scan')
raise OspdCommandError('Invalid scan_id UUID', 'start_scan')

try:
parallel = int(scan_et.attrib.get('parallel', '1'))
if parallel < 1 or parallel > 20:
parallel = 1
except ValueError:
raise OSPDError(
raise OspdCommandError(
'Invalid value for parallel scans. ' 'It must be a number',
'start_scan',
)

scanner_params = scan_et.find('scanner_params')
if scanner_params is None:
raise OSPDError('No scanner_params element', 'start_scan')
raise OspdCommandError('No scanner_params element', 'start_scan')

params = self._preprocess_scan_params(scanner_params)

Expand All @@ -621,7 +627,7 @@ def handle_start_scan_command(self, scan_et):
scanner_vts = scan_et.find('vt_selection')
if scanner_vts is not None:
if len(scanner_vts) == 0:
raise OSPDError('VTs list is empty', 'start_scan')
raise OspdCommandError('VTs list is empty', 'start_scan')
else:
vt_selection = self.process_vts_params(scanner_vts)

Expand Down Expand Up @@ -653,17 +659,21 @@ def handle_stop_scan_command(self, scan_et):

scan_id = scan_et.attrib.get('scan_id')
if scan_id is None or scan_id == '':
raise OSPDError('No scan_id attribute', 'stop_scan')
raise OspdCommandError('No scan_id attribute', 'stop_scan')
self.stop_scan(scan_id)

return simple_response_str('stop_scan', 200, 'OK')

def stop_scan(self, scan_id):
scan_process = self.scan_processes.get(scan_id)
if not scan_process:
raise OSPDError('Scan not found {0}.'.format(scan_id), 'stop_scan')
raise OspdCommandError(
'Scan not found {0}.'.format(scan_id), 'stop_scan'
)
if not scan_process.is_alive():
raise OSPDError('Scan already stopped or finished.', 'stop_scan')
raise OspdCommandError(
'Scan already stopped or finished.', 'stop_scan'
)

self.set_scan_status(scan_id, ScanStatus.STOPPED)
logger.info('%s: Scan stopping %s.', scan_id, scan_process.ident)
Expand Down Expand Up @@ -819,12 +829,12 @@ def handle_client_stream(self, stream, is_unix=False):
return
try:
response = self.handle_command(data)
except OSPDError as exception:
except OspdCommandError as exception:
response = exception.as_xml()
logger.debug('Command error: %s', exception.message)
except Exception: # pylint: disable=broad-except
logger.exception('While handling client command:')
exception = OSPDError('Fatal error', 'error')
exception = OspdCommandError('Fatal error', 'error')
response = exception.as_xml()
if is_unix:
send_method = stream.send
Expand Down Expand Up @@ -917,7 +927,7 @@ def start_scan(self, scan_id, targets, parallel=1):
logger.info("%s: Scan started.", scan_id)
target_list = targets
if target_list is None or not target_list:
raise OSPDError('Erroneous targets list', 'start_scan')
raise OspdCommandError('Erroneous targets list', 'start_scan')

self.process_exclude_hosts(scan_id, target_list)

Expand Down Expand Up @@ -1084,7 +1094,7 @@ def handle_help_command(self, scan_et):
elif help_format == "xml":
text = self.get_xml_str(self.commands)
return simple_response_str('help', 200, 'OK', text)
raise OSPDError('Bogus help format', 'help')
raise OspdCommandError('Bogus help format', 'help')

def get_help_text(self):
""" Returns the help output in plain text format."""
Expand Down Expand Up @@ -1143,7 +1153,7 @@ def handle_delete_scan_command(self, scan_et):
self.check_scan_process(scan_id)
if self.delete_scan(scan_id):
return simple_response_str('delete_scan', 200, 'OK')
raise OSPDError('Scan in progress', 'delete_scan')
raise OspdCommandError('Scan in progress', 'delete_scan')

def delete_scan(self, scan_id):
""" Deletes scan_id scan from collection.
Expand Down Expand Up @@ -1610,10 +1620,10 @@ def handle_command(self, command):
tree = secET.fromstring(command)
except secET.ParseError:
logger.debug("Erroneous client input: %s", command)
raise OSPDError('Invalid data')
raise OspdCommandError('Invalid data')

if not self.command_exists(tree.tag) and tree.tag != "authenticate":
raise OSPDError('Bogus command name')
raise OspdCommandError('Bogus command name')

if tree.tag == "get_version":
return self.handle_get_version_command()
Expand Down
10 changes: 5 additions & 5 deletions ospd/vtfilter.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import re
import operator

from ospd.error import OSPDError
from ospd.errors import OspdCommandError


class VtsFilter(object):
Expand Down Expand Up @@ -57,14 +57,14 @@ def parse_filters(self, vt_filter):
for single_filter in filter_list:
filter_aux = re.split(r'(\W)', single_filter, 1)
if len(filter_aux) < 3:
raise OSPDError(
raise OspdCommandError(
"Invalid number of argument in the filter", "get_vts"
)
_element, _oper, _val = filter_aux
if _element not in self.allowed_filter:
raise OSPDError("Invalid filter element", "get_vts")
raise OspdCommandError("Invalid filter element", "get_vts")
if _oper not in self.filter_operator:
raise OSPDError("Invalid filter operator", "get_vts")
raise OspdCommandError("Invalid filter operator", "get_vts")

filters.append(filter_aux)

Expand Down Expand Up @@ -109,7 +109,7 @@ def get_filtered_vts_list(self, vts, vt_filter):
Dictionary with filtered vulnerability tests.
"""
if not vt_filter:
raise OSPDError('vt_filter: A valid filter is required.')
raise OspdCommandError('vt_filter: A valid filter is required.')

filters = self.parse_filters(vt_filter)
if not filters:
Expand Down
39 changes: 33 additions & 6 deletions tests/test_error.py → tests/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,32 +16,59 @@
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.

""" Test module for OSPDError class
""" Test module for OspdCommandError class
"""

import unittest

from ospd.ospd import OSPDError
from ospd.errors import OspdError, OspdCommandError, RequiredArgument


class OSPDErrorTestCase(unittest.TestCase):
class OspdCommandErrorTestCase(unittest.TestCase):
def test_is_ospd_error(self):
e = OspdCommandError('message')
self.assertIsInstance(e, OspdError)

def test_default_params(self):
e = OSPDError('message')
e = OspdCommandError('message')

self.assertEqual('message', e.message)
self.assertEqual(400, e.status)
self.assertEqual('osp', e.command)

def test_constructor(self):
e = OSPDError('message', 'command', '304')
e = OspdCommandError('message', 'command', '304')

self.assertEqual('message', e.message)
self.assertEqual('command', e.command)
self.assertEqual('304', e.status)

def test_string_conversion(self):
e = OspdCommandError('message foo bar', 'command', '304')

self.assertEqual('message foo bar', str(e))

def test_as_xml(self):
e = OSPDError('message')
e = OspdCommandError('message')

self.assertEqual(
b'<osp_response status="400" status_text="message" />', e.as_xml()
)


class RequiredArgumentTestCase(unittest.TestCase):
def test_raise_exception(self):
with self.assertRaises(RequiredArgument) as cm:
raise RequiredArgument('foo', 'bar')

ex = cm.exception
self.assertEqual(ex.function, 'foo')
self.assertEqual(ex.argument, 'bar')

def test_string_conversion(self):
ex = RequiredArgument('foo', 'bar')
self.assertEqual(str(ex), 'foo: Argument bar is required')

def test_is_ospd_error(self):
e = RequiredArgument('foo', 'bar')
self.assertIsInstance(e, OspdError)
Loading