Skip to content
This repository has been archived by the owner on Nov 29, 2021. It is now read-only.

Commit

Permalink
Merge pull request #274 from jjnicola/prefork
Browse files Browse the repository at this point in the history
Create a pre fork()'ed data manager to store scan data information
  • Loading branch information
jjnicola authored May 19, 2020
2 parents 7458307 + 59fdfc4 commit 54fe4da
Show file tree
Hide file tree
Showing 7 changed files with 298 additions and 229 deletions.
19 changes: 9 additions & 10 deletions ospd/command/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,11 +574,10 @@ def handle_xml(self, xml: Element) -> bytes:
vt_selection = OspRequest.process_vts_params(scanner_vts)

# Dry run case.
if 'dry_run' in params and int(params['dry_run']):
scan_func = self._daemon.dry_run_scan
dry_run = 'dry_run' in params and int(params['dry_run'])
if dry_run:
scan_params = None
else:
scan_func = self._daemon.start_scan
scan_params = self._daemon.process_scan_params(params)

scan_id_aux = scan_id
Expand All @@ -591,13 +590,13 @@ def handle_xml(self, xml: Element) -> bytes:
id_.text = scan_id_aux
return simple_response_str('start_scan', 100, 'Continue', id_)

scan_process = create_process(
func=scan_func, args=(scan_id, scan_target)
)

self._daemon.scan_processes[scan_id] = scan_process

scan_process.start()
if dry_run:
scan_func = self._daemon.dry_run_scan
scan_process = create_process(
func=scan_func, args=(scan_id, scan_target)
)
self._daemon.scan_processes[scan_id] = scan_process
scan_process.start()

id_ = Element('id')
id_.text = scan_id
Expand Down
25 changes: 22 additions & 3 deletions ospd/ospd.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
from ospd import __version__
from ospd.command import get_commands
from ospd.errors import OspdCommandError
from ospd.misc import ResultType
from ospd.misc import ResultType, create_process
from ospd.network import resolve_hostname, target_str_to_list
from ospd.protocol import OspRequest, OspResponse, RequestParser
from ospd.scan import ScanCollection, ScanStatus
Expand Down Expand Up @@ -168,6 +168,7 @@ def init(self, server: BaseServer) -> None:
Will be called after check.
"""
self.scan_collection.init()
server.start(self.handle_client_stream)
self.initialized = True

Expand Down Expand Up @@ -546,7 +547,7 @@ def dry_run_scan(self, scan_id: str, target: Dict) -> None:

os.setsid()

host = resolve_hostname(target[0])
host = resolve_hostname(target.get('hosts'))
if host is None:
logger.info("Couldn't resolve %s.", self.get_scan_host(scan_id))

Expand Down Expand Up @@ -1179,10 +1180,25 @@ def run(self) -> None:
time.sleep(SCHEDULER_CHECK_PERIOD)
self.scheduler()
self.clean_forgotten_scans()
self.start_pending_scans()
self.wait_for_children()
except KeyboardInterrupt:
logger.info("Received Ctrl-C shutting-down ...")

def start_pending_scans(self):
for scan_id in self.scan_collection.ids_iterator():
if self.get_scan_status(scan_id) == ScanStatus.PENDING:
scan_target = self.scan_collection.scans_table[scan_id].get(
'target'
)
scan_func = self.start_scan
scan_process = create_process(
func=scan_func, args=(scan_id, scan_target)
)
self.scan_processes[scan_id] = scan_process
scan_process.start()
self.set_scan_status(scan_id, ScanStatus.INIT)

def scheduler(self):
""" Should be implemented by subclass in case of need
to run tasks periodically. """
Expand Down Expand Up @@ -1255,9 +1271,12 @@ def clean_forgotten_scans(self) -> None:

def check_scan_process(self, scan_id: str) -> None:
""" Check the scan's process, and terminate the scan if not alive. """
scan_process = self.scan_processes[scan_id]
scan_process = self.scan_processes.get(scan_id)
progress = self.get_scan_progress(scan_id)

if self.get_scan_status(scan_id) == ScanStatus.PENDING:
return

if progress < PROGRESS_FINISHED and not scan_process.is_alive():
if not self.get_scan_status(scan_id) == ScanStatus.STOPPED:
self.set_scan_status(scan_id, ScanStatus.STOPPED)
Expand Down
42 changes: 20 additions & 22 deletions ospd/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,11 @@
class ScanStatus(Enum):
"""Scan status. """

INIT = 0
RUNNING = 1
STOPPED = 2
FINISHED = 3
PENDING = 0
INIT = 1
RUNNING = 2
STOPPED = 3
FINISHED = 4


class ScanCollection:
Expand Down Expand Up @@ -63,6 +64,9 @@ def __init__(self) -> None:
) # type: Optional[multiprocessing.managers.SyncManager]
self.scans_table = dict() # type: Dict

def init(self):
self.data_manager = multiprocessing.Manager()

def add_result(
self,
scan_id: str,
Expand Down Expand Up @@ -209,9 +213,6 @@ def create_scan(
if not target:
target = {}

if self.data_manager is None:
self.data_manager = multiprocessing.Manager()

if not options:
options = dict()

Expand All @@ -226,7 +227,7 @@ def create_scan(
scan_info['options'] = options
scan_info['start_time'] = int(time.time())
scan_info['end_time'] = 0
scan_info['status'] = ScanStatus.INIT
scan_info['status'] = ScanStatus.PENDING

if scan_id is None or scan_id == '':
scan_id = str(uuid.uuid4())
Expand Down Expand Up @@ -349,7 +350,10 @@ def get_host_count(self, scan_id: str) -> int:
def get_ports(self, scan_id: str):
""" Get a scan's ports list.
"""
return self.scans_table[scan_id]['target'].get('ports')
target = self.scans_table[scan_id].get('target')
ports = target.pop('ports')
self.scans_table[scan_id]['target'] = target
return ports

def get_exclude_hosts(self, scan_id: str):
""" Get an exclude host list for a given target.
Expand All @@ -376,15 +380,11 @@ def get_target_options(self, scan_id: str):

def get_vts(self, scan_id: str) -> Dict:
""" Get a scan's vts. """
scan_info = self.scans_table[scan_id]
vts = scan_info.pop('vts')
self.scans_table[scan_id] = scan_info

return self.scans_table[scan_id]['vts']

def release_vts_list(self, scan_id: str) -> None:
""" Release the memory used for the vts list. """

scan_data = self.scans_table.get(scan_id)
if scan_data and 'vts' in scan_data:
del scan_data['vts']
return vts

def id_exists(self, scan_id: str) -> bool:
""" Check whether a scan exists in the table. """
Expand All @@ -397,10 +397,8 @@ def delete_scan(self, scan_id: str) -> bool:
if self.get_status(scan_id) == ScanStatus.RUNNING:
return False

self.scans_table.pop(scan_id)

if len(self.scans_table) == 0:
del self.data_manager
self.data_manager = None
scans_table = self.scans_table
del scans_table[scan_id]
self.scans_table = scans_table

return True
93 changes: 77 additions & 16 deletions tests/command/test_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from ospd.errors import OspdCommandError, OspdError
from ospd.misc import create_process

from ..helper import DummyWrapper, assert_called, FakeStream
from ..helper import DummyWrapper, assert_called, FakeStream, FakeDataManager


class GetPerformanceTestCase(TestCase):
Expand Down Expand Up @@ -97,7 +97,7 @@ def test_scan_with_vts_empty_vt_list(self):
with self.assertRaises(OspdCommandError):
cmd.handle_xml(request)

@patch("ospd.command.command.create_process")
@patch("ospd.ospd.create_process")
def test_scan_with_vts(self, mock_create_process):
daemon = DummyWrapper([])
cmd = StartScan(daemon)
Expand All @@ -121,13 +121,67 @@ def test_scan_with_vts(self, mock_create_process):
response = et.fromstring(cmd.handle_xml(request))
scan_id = response.findtext('id')

self.assertEqual(
daemon.get_scan_vts(scan_id), {'1.2.3.4': {}, 'vt_groups': []}
)
self.assertNotEqual(daemon.get_scan_vts(scan_id), {'1.2.3.6': {}})
vts_collection = daemon.get_scan_vts(scan_id)
self.assertEqual(vts_collection, {'1.2.3.4': {}, 'vt_groups': []})
self.assertNotEqual(vts_collection, {'1.2.3.6': {}})

daemon.start_pending_scans()
assert_called(mock_create_process)

def test_scan_pop_vts(self):
daemon = DummyWrapper([])
cmd = StartScan(daemon)

request = et.fromstring(
'<start_scan>'
'<targets>'
'<target>'
'<hosts>localhost</hosts>'
'<ports>80, 443</ports>'
'</target>'
'</targets>'
'<scanner_params />'
'<vt_selection>'
'<vt_single id="1.2.3.4" />'
'</vt_selection>'
'</start_scan>'
)

# With one vt, without params
response = et.fromstring(cmd.handle_xml(request))
scan_id = response.findtext('id')

vts_collection = daemon.get_scan_vts(scan_id)
self.assertEqual(vts_collection, {'1.2.3.4': {}, 'vt_groups': []})
self.assertRaises(KeyError, daemon.get_scan_vts, scan_id)

def test_scan_pop_ports(self):
daemon = DummyWrapper([])
cmd = StartScan(daemon)

request = et.fromstring(
'<start_scan>'
'<targets>'
'<target>'
'<hosts>localhost</hosts>'
'<ports>80, 443</ports>'
'</target>'
'</targets>'
'<scanner_params />'
'<vt_selection>'
'<vt_single id="1.2.3.4" />'
'</vt_selection>'
'</start_scan>'
)

# With one vt, without params
response = et.fromstring(cmd.handle_xml(request))
scan_id = response.findtext('id')

ports = daemon.scan_collection.get_ports(scan_id)
self.assertEqual(ports, '80, 443')
self.assertRaises(KeyError, daemon.scan_collection.get_ports, scan_id)

def test_is_new_scan_allowed_false(self):
daemon = DummyWrapper([])
cmd = StartScan(daemon)
Expand All @@ -152,7 +206,7 @@ def test_is_new_scan_allowed_true(self):

self.assertTrue(cmd.is_new_scan_allowed())

@patch("ospd.command.command.create_process")
@patch("ospd.ospd.create_process")
def test_scan_without_vts(self, mock_create_process):
daemon = DummyWrapper([])
cmd = StartScan(daemon)
Expand All @@ -172,9 +226,9 @@ def test_scan_without_vts(self, mock_create_process):
response = et.fromstring(cmd.handle_xml(request))

scan_id = response.findtext('id')

self.assertEqual(daemon.get_scan_vts(scan_id), {})

daemon.start_pending_scans()
assert_called(mock_create_process)

def test_scan_with_vts_and_param_missing_vt_param_id(self):
Expand All @@ -200,7 +254,7 @@ def test_scan_with_vts_and_param_missing_vt_param_id(self):
with self.assertRaises(OspdError):
cmd.handle_xml(request)

@patch("ospd.command.command.create_process")
@patch("ospd.ospd.create_process")
def test_scan_with_vts_and_param(self, mock_create_process):
daemon = DummyWrapper([])
cmd = StartScan(daemon)
Expand Down Expand Up @@ -229,7 +283,7 @@ def test_scan_with_vts_and_param(self, mock_create_process):
daemon.get_scan_vts(scan_id),
{'1234': {'ABC': '200'}, 'vt_groups': []},
)

daemon.start_pending_scans()
assert_called(mock_create_process)

def test_scan_with_vts_and_param_missing_vt_group_filter(self):
Expand All @@ -253,7 +307,7 @@ def test_scan_with_vts_and_param_missing_vt_group_filter(self):
with self.assertRaises(OspdError):
cmd.handle_xml(request)

@patch("ospd.command.command.create_process")
@patch("ospd.ospd.create_process")
def test_scan_with_vts_and_param_with_vt_group_filter(
self, mock_create_process
):
Expand All @@ -280,9 +334,10 @@ def test_scan_with_vts_and_param_with_vt_group_filter(

self.assertEqual(daemon.get_scan_vts(scan_id), {'vt_groups': ['a']})

daemon.start_pending_scans()
assert_called(mock_create_process)

@patch("ospd.command.command.create_process")
@patch("ospd.ospd.create_process")
@patch("ospd.command.command.logger")
def test_scan_ignore_multi_target(self, mock_logger, mock_create_process):
daemon = DummyWrapper([])
Expand All @@ -300,16 +355,18 @@ def test_scan_ignore_multi_target(self, mock_logger, mock_create_process):
)

cmd.handle_xml(request)

daemon.start_pending_scans()
assert_called(mock_logger.warning)
assert_called(mock_create_process)

@patch("ospd.command.command.create_process")
@patch("ospd.ospd.create_process")
@patch("ospd.command.command.logger")
def test_scan_use_legacy_target_and_port(
self, mock_logger, mock_create_process
):
daemon = DummyWrapper([])
daemon.scan_collection.datamanager = FakeDataManager()

cmd = StartScan(daemon)
request = et.fromstring(
'<start_scan target="localhost" ports="22">'
Expand All @@ -325,20 +382,22 @@ def test_scan_use_legacy_target_and_port(
self.assertEqual(daemon.get_scan_host(scan_id), 'localhost')
self.assertEqual(daemon.get_scan_ports(scan_id), '22')

daemon.start_pending_scans()

assert_called(mock_logger.warning)
assert_called(mock_create_process)


class StopCommandTestCase(TestCase):
@patch("ospd.ospd.os")
@patch("ospd.command.command.create_process")
@patch("ospd.ospd.create_process")
def test_stop_scan(self, mock_create_process, mock_os):
mock_process = mock_create_process.return_value
mock_process.is_alive.return_value = True
mock_process.pid = "foo"

fs = FakeStream()
daemon = DummyWrapper([])
daemon.scan_collection.datamanager = FakeDataManager()
request = (
'<start_scan>'
'<targets>'
Expand All @@ -353,6 +412,8 @@ def test_stop_scan(self, mock_create_process, mock_os):
daemon.handle_command(request, fs)
response = fs.get_response()

daemon.start_pending_scans()

assert_called(mock_create_process)
assert_called(mock_process.start)

Expand Down
Loading

0 comments on commit 54fe4da

Please sign in to comment.