Skip to content
This repository has been archived by the owner on Nov 29, 2021. It is now read-only.

Create a pre fork()'ed data manager to store scan data information #274

Merged
merged 15 commits into from
May 19, 2020
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 9 additions & 10 deletions ospd/command/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,11 +574,10 @@ def handle_xml(self, xml: Element) -> bytes:
vt_selection = OspRequest.process_vts_params(scanner_vts)

# Dry run case.
if 'dry_run' in params and int(params['dry_run']):
scan_func = self._daemon.dry_run_scan
dry_run = 'dry_run' in params and int(params['dry_run'])
if dry_run:
scan_params = None
else:
scan_func = self._daemon.start_scan
scan_params = self._daemon.process_scan_params(params)

scan_id_aux = scan_id
Expand All @@ -591,13 +590,13 @@ def handle_xml(self, xml: Element) -> bytes:
id_.text = scan_id_aux
return simple_response_str('start_scan', 100, 'Continue', id_)

scan_process = create_process(
func=scan_func, args=(scan_id, scan_target)
)

self._daemon.scan_processes[scan_id] = scan_process

scan_process.start()
if dry_run:
bjoernricks marked this conversation as resolved.
Show resolved Hide resolved
scan_func = self._daemon.dry_run_scan
scan_process = create_process(
func=scan_func, args=(scan_id, scan_target)
)
self._daemon.scan_processes[scan_id] = scan_process
scan_process.start()

id_ = Element('id')
id_.text = scan_id
Expand Down
23 changes: 21 additions & 2 deletions ospd/ospd.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
from ospd import __version__
from ospd.command import get_commands
from ospd.errors import OspdCommandError
from ospd.misc import ResultType
from ospd.misc import ResultType, create_process
from ospd.network import resolve_hostname, target_str_to_list
from ospd.protocol import OspRequest, OspResponse, RequestParser
from ospd.scan import ScanCollection, ScanStatus
Expand Down Expand Up @@ -168,6 +168,7 @@ def init(self, server: BaseServer) -> None:

Will be called after check.
"""
self.scan_collection.init_data_manager()
server.start(self.handle_client_stream)
self.initialized = True

Expand Down Expand Up @@ -1179,10 +1180,25 @@ def run(self) -> None:
time.sleep(SCHEDULER_CHECK_PERIOD)
self.scheduler()
self.clean_forgotten_scans()
self.start_pending_scans()
self.wait_for_children()
except KeyboardInterrupt:
logger.info("Received Ctrl-C shutting-down ...")

def start_pending_scans(self):
for scan_id in self.scan_collection.ids_iterator():
if self.get_scan_status(scan_id) == ScanStatus.PENDING:
scan_target = self.scan_collection.scans_table[scan_id].get(
'target'
)
bjoernricks marked this conversation as resolved.
Show resolved Hide resolved
scan_func = self.start_scan
scan_process = create_process(
func=scan_func, args=(scan_id, scan_target)
)
self.scan_processes[scan_id] = scan_process
scan_process.start()
self.set_scan_status(scan_id, ScanStatus.INIT)

def scheduler(self):
""" Should be implemented by subclass in case of need
to run tasks periodically. """
Expand Down Expand Up @@ -1255,9 +1271,12 @@ def clean_forgotten_scans(self) -> None:

def check_scan_process(self, scan_id: str) -> None:
""" Check the scan's process, and terminate the scan if not alive. """
scan_process = self.scan_processes[scan_id]
scan_process = self.scan_processes.get(scan_id)
progress = self.get_scan_progress(scan_id)

if self.get_scan_status(scan_id) == ScanStatus.PENDING:
return

if progress < PROGRESS_FINISHED and not scan_process.is_alive():
if not self.get_scan_status(scan_id) == ScanStatus.STOPPED:
self.set_scan_status(scan_id, ScanStatus.STOPPED)
Expand Down
42 changes: 20 additions & 22 deletions ospd/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,11 @@
class ScanStatus(Enum):
"""Scan status. """

INIT = 0
RUNNING = 1
STOPPED = 2
FINISHED = 3
PENDING = 0
INIT = 1
RUNNING = 2
STOPPED = 3
FINISHED = 4


class ScanCollection:
Expand Down Expand Up @@ -63,6 +64,9 @@ def __init__(self) -> None:
) # type: Optional[multiprocessing.managers.SyncManager]
self.scans_table = dict() # type: Dict

def init_data_manager(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Personally I wont call it init_data_manager because the data manager should be an internal detail of the class. If we are using a data manager or if we are storing the data in a dict, db, redis, memory, json, ... should not matter for the outside. For the outside using code it must only be obvious that a init function has to be called after the instance has been created.

self.data_manager = multiprocessing.Manager()

def add_result(
self,
scan_id: str,
Expand Down Expand Up @@ -209,9 +213,6 @@ def create_scan(
if not target:
target = {}

if self.data_manager is None:
self.data_manager = multiprocessing.Manager()

if not options:
options = dict()

Expand All @@ -226,7 +227,7 @@ def create_scan(
scan_info['options'] = options
scan_info['start_time'] = int(time.time())
scan_info['end_time'] = 0
scan_info['status'] = ScanStatus.INIT
scan_info['status'] = ScanStatus.PENDING

if scan_id is None or scan_id == '':
scan_id = str(uuid.uuid4())
Expand Down Expand Up @@ -349,7 +350,10 @@ def get_host_count(self, scan_id: str) -> int:
def get_ports(self, scan_id: str):
""" Get a scan's ports list.
"""
return self.scans_table[scan_id]['target'].get('ports')
target = self.scans_table[scan_id].get('target')
ports = target.pop('ports')
self.scans_table[scan_id]['target'] = target
return ports

def get_exclude_hosts(self, scan_id: str):
""" Get an exclude host list for a given target.
Expand All @@ -376,15 +380,11 @@ def get_target_options(self, scan_id: str):

def get_vts(self, scan_id: str) -> Dict:
""" Get a scan's vts. """
scan_info = self.scans_table[scan_id]
vts = scan_info.pop('vts')
self.scans_table[scan_id] = scan_info

return self.scans_table[scan_id]['vts']

def release_vts_list(self, scan_id: str) -> None:
""" Release the memory used for the vts list. """

scan_data = self.scans_table.get(scan_id)
if scan_data and 'vts' in scan_data:
del scan_data['vts']
return vts

def id_exists(self, scan_id: str) -> bool:
""" Check whether a scan exists in the table. """
Expand All @@ -397,10 +397,8 @@ def delete_scan(self, scan_id: str) -> bool:
if self.get_status(scan_id) == ScanStatus.RUNNING:
return False

self.scans_table.pop(scan_id)

if len(self.scans_table) == 0:
del self.data_manager
self.data_manager = None
scans_table = self.scans_table
del scans_table[scan_id]
self.scans_table = scans_table

return True
93 changes: 77 additions & 16 deletions tests/command/test_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from ospd.errors import OspdCommandError, OspdError
from ospd.misc import create_process

from ..helper import DummyWrapper, assert_called, FakeStream
from ..helper import DummyWrapper, assert_called, FakeStream, FakeDataManager


class GetPerformanceTestCase(TestCase):
Expand Down Expand Up @@ -97,7 +97,7 @@ def test_scan_with_vts_empty_vt_list(self):
with self.assertRaises(OspdCommandError):
cmd.handle_xml(request)

@patch("ospd.command.command.create_process")
@patch("ospd.ospd.create_process")
def test_scan_with_vts(self, mock_create_process):
daemon = DummyWrapper([])
cmd = StartScan(daemon)
Expand All @@ -121,13 +121,67 @@ def test_scan_with_vts(self, mock_create_process):
response = et.fromstring(cmd.handle_xml(request))
scan_id = response.findtext('id')

self.assertEqual(
daemon.get_scan_vts(scan_id), {'1.2.3.4': {}, 'vt_groups': []}
)
self.assertNotEqual(daemon.get_scan_vts(scan_id), {'1.2.3.6': {}})
vts_collection = daemon.get_scan_vts(scan_id)
self.assertEqual(vts_collection, {'1.2.3.4': {}, 'vt_groups': []})
self.assertNotEqual(vts_collection, {'1.2.3.6': {}})

daemon.start_pending_scans()
assert_called(mock_create_process)

def test_scan_pop_vts(self):
daemon = DummyWrapper([])
cmd = StartScan(daemon)

request = et.fromstring(
'<start_scan>'
'<targets>'
'<target>'
'<hosts>localhost</hosts>'
'<ports>80, 443</ports>'
'</target>'
'</targets>'
'<scanner_params />'
'<vt_selection>'
'<vt_single id="1.2.3.4" />'
'</vt_selection>'
'</start_scan>'
)

# With one vt, without params
response = et.fromstring(cmd.handle_xml(request))
scan_id = response.findtext('id')

vts_collection = daemon.get_scan_vts(scan_id)
self.assertEqual(vts_collection, {'1.2.3.4': {}, 'vt_groups': []})
self.assertRaises(KeyError, daemon.get_scan_vts, scan_id)

def test_scan_pop_ports(self):
daemon = DummyWrapper([])
cmd = StartScan(daemon)

request = et.fromstring(
'<start_scan>'
'<targets>'
'<target>'
'<hosts>localhost</hosts>'
'<ports>80, 443</ports>'
'</target>'
'</targets>'
'<scanner_params />'
'<vt_selection>'
'<vt_single id="1.2.3.4" />'
'</vt_selection>'
'</start_scan>'
)

# With one vt, without params
response = et.fromstring(cmd.handle_xml(request))
scan_id = response.findtext('id')

ports = daemon.scan_collection.get_ports(scan_id)
self.assertEqual(ports, '80, 443')
self.assertRaises(KeyError, daemon.scan_collection.get_ports, scan_id)

def test_is_new_scan_allowed_false(self):
daemon = DummyWrapper([])
cmd = StartScan(daemon)
Expand All @@ -152,7 +206,7 @@ def test_is_new_scan_allowed_true(self):

self.assertTrue(cmd.is_new_scan_allowed())

@patch("ospd.command.command.create_process")
@patch("ospd.ospd.create_process")
def test_scan_without_vts(self, mock_create_process):
daemon = DummyWrapper([])
cmd = StartScan(daemon)
Expand All @@ -172,9 +226,9 @@ def test_scan_without_vts(self, mock_create_process):
response = et.fromstring(cmd.handle_xml(request))

scan_id = response.findtext('id')

self.assertEqual(daemon.get_scan_vts(scan_id), {})

daemon.start_pending_scans()
assert_called(mock_create_process)

def test_scan_with_vts_and_param_missing_vt_param_id(self):
Expand All @@ -200,7 +254,7 @@ def test_scan_with_vts_and_param_missing_vt_param_id(self):
with self.assertRaises(OspdError):
cmd.handle_xml(request)

@patch("ospd.command.command.create_process")
@patch("ospd.ospd.create_process")
def test_scan_with_vts_and_param(self, mock_create_process):
daemon = DummyWrapper([])
cmd = StartScan(daemon)
Expand Down Expand Up @@ -229,7 +283,7 @@ def test_scan_with_vts_and_param(self, mock_create_process):
daemon.get_scan_vts(scan_id),
{'1234': {'ABC': '200'}, 'vt_groups': []},
)

daemon.start_pending_scans()
assert_called(mock_create_process)

def test_scan_with_vts_and_param_missing_vt_group_filter(self):
Expand All @@ -253,7 +307,7 @@ def test_scan_with_vts_and_param_missing_vt_group_filter(self):
with self.assertRaises(OspdError):
cmd.handle_xml(request)

@patch("ospd.command.command.create_process")
@patch("ospd.ospd.create_process")
def test_scan_with_vts_and_param_with_vt_group_filter(
self, mock_create_process
):
Expand All @@ -280,9 +334,10 @@ def test_scan_with_vts_and_param_with_vt_group_filter(

self.assertEqual(daemon.get_scan_vts(scan_id), {'vt_groups': ['a']})

daemon.start_pending_scans()
assert_called(mock_create_process)

@patch("ospd.command.command.create_process")
@patch("ospd.ospd.create_process")
@patch("ospd.command.command.logger")
def test_scan_ignore_multi_target(self, mock_logger, mock_create_process):
daemon = DummyWrapper([])
Expand All @@ -300,16 +355,18 @@ def test_scan_ignore_multi_target(self, mock_logger, mock_create_process):
)

cmd.handle_xml(request)

daemon.start_pending_scans()
assert_called(mock_logger.warning)
assert_called(mock_create_process)

@patch("ospd.command.command.create_process")
@patch("ospd.ospd.create_process")
@patch("ospd.command.command.logger")
def test_scan_use_legacy_target_and_port(
self, mock_logger, mock_create_process
):
daemon = DummyWrapper([])
daemon.scan_collection.datamanager = FakeDataManager()

cmd = StartScan(daemon)
request = et.fromstring(
'<start_scan target="localhost" ports="22">'
Expand All @@ -325,20 +382,22 @@ def test_scan_use_legacy_target_and_port(
self.assertEqual(daemon.get_scan_host(scan_id), 'localhost')
self.assertEqual(daemon.get_scan_ports(scan_id), '22')

daemon.start_pending_scans()

assert_called(mock_logger.warning)
assert_called(mock_create_process)


class StopCommandTestCase(TestCase):
@patch("ospd.ospd.os")
@patch("ospd.command.command.create_process")
@patch("ospd.ospd.create_process")
def test_stop_scan(self, mock_create_process, mock_os):
mock_process = mock_create_process.return_value
mock_process.is_alive.return_value = True
mock_process.pid = "foo"

fs = FakeStream()
daemon = DummyWrapper([])
daemon.scan_collection.datamanager = FakeDataManager()
request = (
'<start_scan>'
'<targets>'
Expand All @@ -353,6 +412,8 @@ def test_stop_scan(self, mock_create_process, mock_os):
daemon.handle_command(request, fs)
response = fs.get_response()

daemon.start_pending_scans()

assert_called(mock_create_process)
assert_called(mock_process.start)

Expand Down
Loading