diff --git a/CHANGELOG.md b/CHANGELOG.md index 79c38f6f..bfd03149 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,6 +34,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). - Add pid file creation to avoid having two daemons. [#126](https://github.com/greenbone/ospd/pull/126) [#128](https://github.com/greenbone/ospd/pull/128) - Add OSP command. [#131](https://github.com/greenbone/ospd/pull/131) [#137](https://github.com/greenbone/ospd/pull/137) - Add method to check if a target finished cleanly or crashed. [#133](https://github.com/greenbone/ospd/pull/133) +- Add the --stream-timeout option to configure the socket timeout. [#136](https://github.com/greenbone/ospd/pull/136) +- Add support to handle multiple requests simultaneously. [#136](https://github.com/greenbone/ospd/pull/136) ### Changed - Improve documentation. diff --git a/ospd/main.py b/ospd/main.py index 80378579..afe300c4 100644 --- a/ospd/main.py +++ b/ospd/main.py @@ -123,10 +123,19 @@ def main( ) if args.port == 0: - server = UnixSocketServer(args.unix_socket, args.socket_mode) + server = UnixSocketServer( + args.unix_socket, + args.socket_mode, + args.stream_timeout, + ) else: server = TlsServer( - args.address, args.port, args.cert_file, args.key_file, args.ca_file + args.address, + args.port, + args.cert_file, + args.key_file, + args.ca_file, + args.stream_timeout, ) daemon = daemon_class(**vars(args)) diff --git a/ospd/ospd.py b/ospd/ospd.py index eb369ed6..d759bed6 100644 --- a/ospd/ospd.py +++ b/ospd/ospd.py @@ -44,7 +44,7 @@ from ospd.errors import OspdCommandError, OspdError from ospd.misc import ScanCollection, ResultType, ScanStatus, valid_uuid from ospd.network import resolve_hostname, target_str_to_list -from ospd.server import Server +from ospd.server import BaseServer from ospd.vtfilter import VtsFilter from ospd.xml import simple_response_str, get_result_xml @@ -1645,21 +1645,20 @@ def check(self): """ Asserts to False. Should be implemented by subclass. """ raise NotImplementedError - def run(self, server: Server): + def run(self, server: BaseServer): """ Starts the Daemon, handling commands until interrupted. """ - server.bind() + server.start(self.handle_client_stream) try: while True: - server.select( - self.handle_client_stream, timeout=SCHEDULER_CHECK_PERIOD - ) + time.sleep(10) self.scheduler() except KeyboardInterrupt: logger.info("Received Ctrl-C shutting-down ...") finally: + logger.info("Shutting-down server ...") server.close() def scheduler(self): diff --git a/ospd/parser.py b/ospd/parser.py index 05350411..dbb21bf7 100644 --- a/ospd/parser.py +++ b/ospd/parser.py @@ -34,6 +34,7 @@ DEFAULT_CONFIG_PATH = "~/.config/ospd.conf" DEFAULT_UNIX_SOCKET_PATH = "/tmp/ospd.sock" DEFAULT_PID_PATH = "/run/ospd/ospd.pid" +DEFAULT_STREAM_TIMEOUT = 10 # ten seconds ParserType = argparse.ArgumentParser Arguments = argparse.Namespace @@ -122,6 +123,13 @@ def __init__(self, description): action='store_true', help='Run in foreground and logs all messages to console.', ) + parser.add_argument( + '-t', + '--stream-timeout', + default=DEFAULT_STREAM_TIMEOUT, + type=int, + help='Stream timeout. Default: %(default)s', + ) parser.add_argument( '-l', '--log-file', help='Path to the logging file.' ) diff --git a/ospd/server.py b/ospd/server.py index e9d2878a..1339dbc2 100644 --- a/ospd/server.py +++ b/ospd/server.py @@ -20,11 +20,12 @@ """ import logging -import select import socket import ssl import time import os +import threading +import socketserver from abc import ABC, abstractmethod from pathlib import Path @@ -34,14 +35,12 @@ logger = logging.getLogger(__name__) - -DEFAULT_STREAM_TIMEOUT = 2 # two seconds DEFAULT_BUFSIZE = 1024 class Stream: - def __init__(self, sock: socket.socket): + def __init__(self, sock: socket.socket, stream_timeout: int): self.socket = sock - self.socket.settimeout(DEFAULT_STREAM_TIMEOUT) + self.socket.settimeout(stream_timeout) def close(self): """ Close the stream @@ -83,65 +82,131 @@ def write(self, data: bytes): StreamCallbackType = Callable[[Stream], None] +def validate_cacert_file(cacert: str): + """ Check if provided file is a valid CA Certificate """ + try: + context = ssl.create_default_context(cafile=cacert) + except AttributeError: + # Python version < 2.7.9 + return + except IOError: + raise OspdError('CA Certificate not found') -class Server(ABC): - @abstractmethod - def bind(self): - """ Start listening for incoming connections - """ + try: + not_after = context.get_ca_certs()[0]['notAfter'] + not_after = ssl.cert_time_to_seconds(not_after) + not_before = context.get_ca_certs()[0]['notBefore'] + not_before = ssl.cert_time_to_seconds(not_before) + except (KeyError, IndexError): + raise OspdError('CA Certificate is erroneous') - @abstractmethod - def select( - self, - stream_callback: StreamCallbackType, - timeout: Optional[float] = None, - ): - """ Wait for incoming connections or until timeout is reached + now = int(time.time()) + if not_after < now: + raise OspdError('CA Certificate expired') - If a new client connects the stream_callback is called with a Stream + if not_before > now: + raise OspdError('CA Certificate not active yet') - Arguments: + +def start_server(stream_callback, stream_timeout, newsocket, tls_ctx=None): + """ Starts listening and creates a new thread for each new client + connection. + Arguments: stream_callback (function): Callback function to be called when a stream is ready - timeout (float): Timeout in seconds to wait for new streams - """ + newsocket (path to socket or socket tuple): The tuple with + address and port or the path to the socket for unix domain + sockets. + Returns the created server object. + """ + class ThreadedRequestHandler(socketserver.BaseRequestHandler): + """ Class to handle the request.""" + def handle(self): + if tls_ctx: + logger.debug( + "New connection from" " %s:%s", newsocket[0], newsocket[1] + ) + req_socket = tls_ctx.wrap_socket(self.request, server_side=True) + else: + req_socket = self.request + logger.debug("New connection from %s", newsocket) -class BaseServer(Server): - def __init__(self): - self.socket = None + stream = Stream(req_socket, stream_timeout) + stream_callback(stream) - @abstractmethod - def _accept(self) -> Stream: + class ThreadedUnixSockServer( + socketserver.ThreadingMixIn, + socketserver.UnixStreamServer, + ): pass - def select( - self, - stream_callback: StreamCallbackType, - timeout: Optional[float] = None, + class ThreadedTlsSockServer( + socketserver.ThreadingMixIn, + socketserver.TCPServer, ): - inputs = [self.socket] + pass - readable, _, _ = select.select(inputs, [], inputs, timeout) + if tls_ctx: + try: + server = ThreadedTlsSockServer(newsocket, ThreadedRequestHandler) + except OSError as e: + logger.error( + "Couldn't bind socket on %s:%s", newsocket[0], newsocket[1] + ) + raise OspdError( + "Couldn't bind socket on {}:{}. {}".format( + newsocket[0], str(newsocket[1]), e, + )) + else: + try: + server = ThreadedUnixSockServer( + str(newsocket), ThreadedRequestHandler + ) + except OSError as e: + logger.error("Couldn't bind socket on %s", str(newsocket)) + raise OspdError( + "Couldn't bind socket on {}. {}".format(str(newsocket), e) + ) - # timeout has fired if readable is empty otherwise a new connection is - # available - if readable: - stream = self._accept() - stream_callback(stream) + + server_thread = threading.Thread(target=server.serve_forever) + server_thread.daemon = True + server_thread.start() + + return server + + +class BaseServer(ABC): + def __init__(self, stream_timeout): + self.server = None + self.stream_timeout = stream_timeout + + @abstractmethod + def start( + self, + stream_callback: StreamCallbackType, + ): + """ Starts a server with capabilities to handle multiple client + connections simultaneously. + If a new client connects the stream_callback is called with a Stream + Arguments: + stream_callback (function): Callback function to be called when + a stream is ready + """ def close(self): - if self.socket: - self.socket.shutdown(socket.SHUT_RDWR) - self.socket.close() + """ Shutdown the server""" + self.server.shutdown() + self.server.server_close() class UnixSocketServer(BaseServer): """ Server for accepting connections via a Unix domain socket """ - def __init__(self, socket_path: str, socket_mode: str): - super().__init__() + def __init__(self, socket_path, socket_mode, stream_timeout: int): + super().__init__(stream_timeout) self.socket_path = Path(socket_path) self.socket_mode = int(socket_mode, 8) @@ -154,69 +219,26 @@ def _create_parent_dirs(self): parent = self.socket_path.parent parent.mkdir(parents=True, exist_ok=True) - def bind(self): + def start( + self, + stream_callback: StreamCallbackType, + ): self._cleanup_socket() self._create_parent_dirs() - bindsocket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - - try: - bindsocket.bind(str(self.socket_path)) - except socket.error: - raise OspdError( - "Couldn't bind socket on {}".format(self.socket_path) - ) - - os.chmod(str(self.socket_path), self.socket_mode) - - logger.info( - 'Unix domain socket server listening on %s', self.socket_path + self.server = start_server( + stream_callback, + self.stream_timeout, + self.socket_path ) - bindsocket.listen(0) - bindsocket.setblocking(False) - - self.socket = bindsocket - - def _accept(self) -> Stream: - new_socket, _addr = self.socket.accept() - - logger.debug("New connection from %s", self.socket_path) - - return Stream(new_socket) + if self.socket_path.exists(): + os.chmod(str(self.socket_path), self.socket_mode) def close(self): super().close() - self._cleanup_socket() - -def validate_cacert_file(cacert: str): - """ Check if provided file is a valid CA Certificate """ - try: - context = ssl.create_default_context(cafile=cacert) - except AttributeError: - # Python version < 2.7.9 - return - except IOError: - raise OspdError('CA Certificate not found') - - try: - not_after = context.get_ca_certs()[0]['notAfter'] - not_after = ssl.cert_time_to_seconds(not_after) - not_before = context.get_ca_certs()[0]['notBefore'] - not_before = ssl.cert_time_to_seconds(not_before) - except (KeyError, IndexError): - raise OspdError('CA Certificate is erroneous') - - now = int(time.time()) - if not_after < now: - raise OspdError('CA Certificate expired') - - if not_before > now: - raise OspdError('CA Certificate not active yet') - - class TlsServer(BaseServer): """ Server for accepting TLS encrypted connections via a TCP socket """ @@ -228,10 +250,11 @@ def __init__( cert_file: str, key_file: str, ca_file: str, + stream_timeout: int, ): - super().__init__() - self.address = address - self.port = port + super().__init__(stream_timeout) + self.socket = (address, port) + if not Path(cert_file).exists(): raise OspdError('cert file {} not found'.format(cert_file)) @@ -262,28 +285,13 @@ def __init__( self.tls_context.load_cert_chain(cert_file, keyfile=key_file) self.tls_context.load_verify_locations(ca_file) - def _accept(self) -> Stream: - new_socket, addr = self.socket.accept() - - logger.debug("New connection from" " %s:%s", addr[0], addr[1]) - - ssl_socket = self.tls_context.wrap_socket(new_socket, server_side=True) - - return Stream(ssl_socket) - - def bind(self): - bindsocket = socket.socket() - try: - bindsocket.bind((self.address, self.port)) - except socket.error: - logger.error( - "Couldn't bind socket on %s:%s", self.address, self.port - ) - return None - - logger.info('TLS server listening on %s:%s', self.address, self.port) - - bindsocket.listen(0) - bindsocket.setblocking(False) - - self.socket = bindsocket + def start( + self, + stream_callback: StreamCallbackType, + ): + self.server = start_server( + stream_callback, + self.stream_timeout, + self.socket, + tls_ctx=self.tls_context + )