Skip to content
This repository has been archived by the owner on Nov 29, 2021. It is now read-only.

Refactor ospd connection handling #114

Merged
merged 4 commits into from
Jun 26, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 61 additions & 145 deletions ospd/ospd.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from __future__ import absolute_import

import logging
import select
import socket
import ssl
import multiprocessing
Expand All @@ -41,9 +40,10 @@
import defusedxml.ElementTree as secET

from ospd import __version__
from ospd.errors import OspdCommandError
from ospd.errors import OspdCommandError, OspdError
from ospd.misc import ScanCollection, ResultType, ScanStatus, valid_uuid
from ospd.network import resolve_hostname, target_str_to_list
from ospd.server import Server, UnixSocketServer, TlsServer
from ospd.vtfilter import VtsFilter
from ospd.xml import simple_response_str, get_result_xml

Expand Down Expand Up @@ -125,64 +125,13 @@
}


def bind_socket(address, port):
""" Returns a socket bound on (address:port). """

assert address
assert port
bindsocket = socket.socket()
try:
bindsocket.bind((address, port))
except socket.error:
logger.error("Couldn't bind socket on %s:%s", address, port)
return None

logger.info('Listening on %s:%s', address, port)
bindsocket.listen(0)
return bindsocket


def bind_unix_socket(path):
""" Returns a unix file socket bound on (path). """

assert path
bindsocket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
try:
os.unlink(path)
except OSError:
if os.path.exists(path):
raise
try:
bindsocket.bind(path)
except socket.error:
logger.error("Couldn't bind socket on %s", path)
return None

logger.info('Listening on %s', path)
bindsocket.listen(0)
return bindsocket


def close_client_stream(client_stream, unix_path):
""" Closes provided client stream """
try:
client_stream.shutdown(socket.SHUT_RDWR)
if unix_path:
logger.debug('%s: Connection closed', unix_path)
else:
peer = client_stream.getpeername()
logger.debug('%s:%s: Connection closed', peer[0], peer[1])
except (socket.error, OSError) as exception:
logger.debug('Connection closing error: %s', exception)
client_stream.close()


class OSPDaemon(object):
class OSPDaemon:

""" Daemon class for OSP traffic handling.

Every scanner wrapper should subclass it and make necessary additions and
changes.

* Add any needed parameters in __init__.
* Implement check() method which verifies scanner availability and other
environment related conditions.
Expand Down Expand Up @@ -282,60 +231,59 @@ def add_vt(
severities=None,
):
""" Add a vulnerability test information.

Returns: The new number of stored VTs.
-1 in case the VT ID was already present and thus the
new VT was not considered.
-2 in case the vt_id was invalid.
"""

if not vt_id:
return -2 # no valid vt_id
raise OspdError('Invalid vt_id {}'.format(vt_id))

if self.vt_id_pattern.fullmatch(vt_id) is None:
return -2 # no valid vt_id
raise OspdError('Invalid vt_id {}'.format(vt_id))

if vt_id in self.vts:
return -1 # The VT was already in the list.
raise OspdError('vt_id {} already exists'.format(vt_id))

if name is None:
name = ''

self.vts[vt_id] = {'name': name}
vt = {'name': name}
if custom is not None:
self.vts[vt_id]["custom"] = custom
vt["custom"] = custom
if vt_params is not None:
self.vts[vt_id]["vt_params"] = vt_params
vt["vt_params"] = vt_params
if vt_refs is not None:
self.vts[vt_id]["vt_refs"] = vt_refs
vt["vt_refs"] = vt_refs
if vt_dependencies is not None:
self.vts[vt_id]["vt_dependencies"] = vt_dependencies
vt["vt_dependencies"] = vt_dependencies
if vt_creation_time is not None:
self.vts[vt_id]["creation_time"] = vt_creation_time
vt["creation_time"] = vt_creation_time
if vt_modification_time is not None:
self.vts[vt_id]["modification_time"] = vt_modification_time
vt["modification_time"] = vt_modification_time
if summary is not None:
self.vts[vt_id]["summary"] = summary
vt["summary"] = summary
if impact is not None:
self.vts[vt_id]["impact"] = impact
vt["impact"] = impact
if affected is not None:
self.vts[vt_id]["affected"] = affected
vt["affected"] = affected
if insight is not None:
self.vts[vt_id]["insight"] = insight
vt["insight"] = insight

if solution is not None:
self.vts[vt_id]["solution"] = solution
vt["solution"] = solution
if solution_t is not None:
self.vts[vt_id]["solution_type"] = solution_t
vt["solution_type"] = solution_t

if detection is not None:
self.vts[vt_id]["detection"] = detection
vt["detection"] = detection

if qod_t is not None:
self.vts[vt_id]["qod_type"] = qod_t
vt["qod_type"] = qod_t
elif qod_v is not None:
self.vts[vt_id]["qod"] = qod_v
vt["qod"] = qod_v

if severities is not None:
self.vts[vt_id]["severities"] = severities
vt["severities"] = severities

return len(self.vts)
self.vts[vt_id] = vt

def set_vts_version(self, vts_version):
""" Add into the vts dictionary an entry to identify the
Expand Down Expand Up @@ -780,53 +728,31 @@ def new_client_stream(self, sock):
return None
return ssl_socket

@staticmethod
def write_to_stream(stream, response, block_len=1024):
"""
Send the response in blocks of the given len using the
passed method dependending on the socket type.
"""
try:
i_start = 0
i_end = block_len
while True:
if i_end > len(response):
stream(response[i_start:])
break
stream(response[i_start:i_end])
i_start = i_end
i_end += block_len
except (socket.timeout, socket.error) as exception:
logger.debug('Error sending response to the client: %s', exception)

def handle_client_stream(self, stream, is_unix=False):
def handle_client_stream(self, stream):
""" Handles stream of data received from client. """

assert stream
data = []
stream.settimeout(2)
data = b''

while True:
try:
if is_unix:
buf = stream.recv(1024)
else:
buf = stream.read(1024)
buf = stream.read()
if not buf:
break
data.append(buf)

data += buf
except (AttributeError, ValueError) as message:
logger.error(message)
return
except (ssl.SSLError) as exception:
logger.debug('Error: %s', exception[0])
logger.debug('Error: %s', exception)
break
except (socket.timeout) as exception:
logger.debug('Error: %s', exception)
break
data = b''.join(data)

if len(data) <= 0:
logger.debug("Empty client stream")
return

try:
response = self.handle_command(data)
except OspdCommandError as exception:
Expand All @@ -836,11 +762,9 @@ def handle_client_stream(self, stream, is_unix=False):
logger.exception('While handling client command:')
exception = OspdCommandError('Fatal error', 'error')
response = exception.as_xml()
if is_unix:
send_method = stream.send
else:
send_method = stream.write
self.write_to_stream(send_method, response)

stream.write(response)
stream.close()

def parallel_scan(self, scan_id, target):
""" Starts the scan with scan_id. """
Expand Down Expand Up @@ -1648,44 +1572,36 @@ def check(self):
""" Asserts to False. Should be implemented by subclass. """
raise NotImplementedError

def run(self, address, port, unix_path):
def run_server(self, server: Server):
""" Starts the Daemon, handling commands until interrupted.

@return False if error. Runs indefinitely otherwise.
"""
assert address or unix_path
if unix_path:
sock = bind_unix_socket(unix_path)
else:
sock = bind_socket(address, port)
if sock is None:
return False

sock.setblocking(False)
inputs = [sock]
outputs = []
server.bind()

try:
while True:
readable, _, _ = select.select(
inputs, outputs, inputs, SCHEDULER_CHECK_PERIOD
server.select(
self.handle_client_stream, timeout=SCHEDULER_CHECK_PERIOD
)
for r_socket in readable:
if unix_path and r_socket is sock:
client_stream, _ = sock.accept()
logger.debug("New connection from %s", unix_path)
self.handle_client_stream(client_stream, True)
else:
client_stream = self.new_client_stream(sock)
if client_stream is None:
continue
self.handle_client_stream(client_stream, False)
close_client_stream(client_stream, unix_path)
self.scheduler()
except KeyboardInterrupt:
logger.info("Received Ctrl-C shutting-down ...")
finally:
sock.shutdown(socket.SHUT_RDWR)
sock.close()
server.close()

def run(self, address, port, unix_path):
if unix_path:
server = UnixSocketServer(unix_path)
else:
server = TlsServer(
address,
port,
self.certs['cert_file'],
self.certs['key_file'],
self.certs['ca_file'],
)

self.run_server(server)

def scheduler(self):
""" Should be implemented by subclass in case of need
Expand Down
Loading