Skip to content
This repository has been archived by the owner on Nov 29, 2021. It is now read-only.

Commit

Permalink
Merge pull request #139 from bjoernricks/master
Browse files Browse the repository at this point in the history
Refactor stream handling
  • Loading branch information
jjnicola authored Sep 30, 2019
2 parents 47acd97 + 50ae334 commit 0674066
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 98 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
- Add OSP <get_performance> command. [#131](https://github.com/greenbone/ospd/pull/131) [#137](https://github.com/greenbone/ospd/pull/137)
- Add method to check if a target finished cleanly or crashed. [#133](https://github.com/greenbone/ospd/pull/133)
- Add the --stream-timeout option to configure the socket timeout. [#136](https://github.com/greenbone/ospd/pull/136)
- Add support to handle multiple requests simultaneously. [#136](https://github.com/greenbone/ospd/pull/136)
- Add support to handle multiple requests simultaneously.
[#136](https://github.com/greenbone/ospd/pull/136), [#139](https://github.com/greenbone/ospd/pull/139)

### Changed
- Improve documentation.
Expand Down
185 changes: 88 additions & 97 deletions ospd/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,15 @@

from abc import ABC, abstractmethod
from pathlib import Path
from typing import Callable, Optional
from typing import Callable, Optional, Tuple, Union

from ospd.errors import OspdError

logger = logging.getLogger(__name__)

DEFAULT_BUFSIZE = 1024


class Stream:
def __init__(self, sock: socket.socket, stream_timeout: int):
self.socket = sock
Expand Down Expand Up @@ -72,16 +73,17 @@ def write(self, data: bytes):
try:
b_sent = self.socket.send(data[b_start:b_end])
except socket.error as e:
logger.error(
"Error sending data to the client. %s", e
)
logger.error("Error sending data to the client. %s", e)
return
b_start = b_end
b_end += b_sent


StreamCallbackType = Callable[[Stream], None]

InetAddress = Tuple[str, int]


def validate_cacert_file(cacert: str):
""" Check if provided file is a valid CA Certificate """
try:
Expand All @@ -108,88 +110,24 @@ def validate_cacert_file(cacert: str):
raise OspdError('CA Certificate not active yet')


def start_server(stream_callback, stream_timeout, newsocket, tls_ctx=None):
""" Starts listening and creates a new thread for each new client
connection.
Arguments:
stream_callback (function): Callback function to be called when
a stream is ready
newsocket (path to socket or socket tuple): The tuple with
address and port or the path to the socket for unix domain
sockets.
Returns the created server object.
"""
class ThreadedRequestHandler(socketserver.BaseRequestHandler):
""" Class to handle the request."""
class RequestHandler(socketserver.BaseRequestHandler):
""" Class to handle the request."""

def handle(self):
if tls_ctx:
logger.debug(
"New connection from" " %s:%s", newsocket[0], newsocket[1]
)
req_socket = tls_ctx.wrap_socket(self.request, server_side=True)
else:
req_socket = self.request
logger.debug("New connection from %s", newsocket)

stream = Stream(req_socket, stream_timeout)
stream_callback(stream)

class ThreadedUnixSockServer(
socketserver.ThreadingMixIn,
socketserver.UnixStreamServer,
):
pass

class ThreadedTlsSockServer(
socketserver.ThreadingMixIn,
socketserver.TCPServer,
):
pass

if tls_ctx:
try:
server = ThreadedTlsSockServer(newsocket, ThreadedRequestHandler)
except OSError as e:
logger.error(
"Couldn't bind socket on %s:%s", newsocket[0], newsocket[1]
)
raise OspdError(
"Couldn't bind socket on {}:{}. {}".format(
newsocket[0], str(newsocket[1]), e,
))
else:
try:
server = ThreadedUnixSockServer(
str(newsocket), ThreadedRequestHandler
)
except OSError as e:
logger.error("Couldn't bind socket on %s", str(newsocket))
raise OspdError(
"Couldn't bind socket on {}. {}".format(str(newsocket), e)
)


server_thread = threading.Thread(target=server.serve_forever)
server_thread.daemon = True
server_thread.start()

return server
def handle(self):
self.server.handle_request(self.request, self.client_address)


class BaseServer(ABC):
def __init__(self, stream_timeout):
def __init__(self, stream_timeout: int):
self.server = None
self.stream_timeout = stream_timeout

@abstractmethod
def start(
self,
stream_callback: StreamCallbackType,
):
def start(self, stream_callback: StreamCallbackType):
""" Starts a server with capabilities to handle multiple client
connections simultaneously.
If a new client connects the stream_callback is called with a Stream
Arguments:
stream_callback (function): Callback function to be called when
a stream is ready
Expand All @@ -200,12 +138,44 @@ def close(self):
self.server.shutdown()
self.server.server_close()

@abstractmethod
def handle_request(self, request, client_address):
""" Handle an incomming client request"""

def _start_threading_server(self):
server_thread = threading.Thread(target=self.server.serve_forever)
server_thread.daemon = True
server_thread.start()


class SocketServerMixin:
def __init__(self, server: BaseServer, address: Union[str, InetAddress]):
self.server = server
super().__init__(address, RequestHandler, bind_and_activate=True)

def handle_request(self, request, client_address):
self.server.handle_request(request, client_address)


class ThreadedUnixSocketServer(
SocketServerMixin,
socketserver.ThreadingMixIn,
socketserver.UnixStreamServer,
):
pass


class ThreadedTlsSocketServer(
SocketServerMixin, socketserver.ThreadingMixIn, socketserver.TCPServer
):
pass


class UnixSocketServer(BaseServer):
""" Server for accepting connections via a Unix domain socket
"""

def __init__(self, socket_path, socket_mode, stream_timeout: int):
def __init__(self, socket_path: str, socket_mode: str, stream_timeout: int):
super().__init__(stream_timeout)
self.socket_path = Path(socket_path)
self.socket_mode = int(socket_mode, 8)
Expand All @@ -219,26 +189,36 @@ def _create_parent_dirs(self):
parent = self.socket_path.parent
parent.mkdir(parents=True, exist_ok=True)

def start(
self,
stream_callback: StreamCallbackType,
):
def start(self, stream_callback: StreamCallbackType):
self._cleanup_socket()
self._create_parent_dirs()

self.server = start_server(
stream_callback,
self.stream_timeout,
self.socket_path
)

if self.socket_path.exists():
os.chmod(str(self.socket_path), self.socket_mode)

try:
self.stream_callback = stream_callback
self.server = ThreadedUnixSocketServer(self, str(self.socket_path))
self._start_threading_server()
except OSError as e:
logger.error("Couldn't bind socket on %s", str(self.socket_path))
raise OspdError(
"Couldn't bind socket on {}. {}".format(
str(self.socket_path), e
)
)

def close(self):
super().close()
self._cleanup_socket()

def handle_request(self, request, client_address):
logger.debug("New connection from %s", str(self.socket_path))

stream = Stream(request, self.stream_timeout)
self.stream_callback(stream)


class TlsServer(BaseServer):
""" Server for accepting TLS encrypted connections via a TCP socket
"""
Expand All @@ -255,7 +235,6 @@ def __init__(
super().__init__(stream_timeout)
self.socket = (address, port)


if not Path(cert_file).exists():
raise OspdError('cert file {} not found'.format(cert_file))

Expand Down Expand Up @@ -285,13 +264,25 @@ def __init__(
self.tls_context.load_cert_chain(cert_file, keyfile=key_file)
self.tls_context.load_verify_locations(ca_file)

def start(
self,
stream_callback: StreamCallbackType,
):
self.server = start_server(
stream_callback,
self.stream_timeout,
self.socket,
tls_ctx=self.tls_context
)
def start(self, stream_callback: StreamCallbackType):
try:
self.stream_callback = stream_callback
self.server = ThreadedTlsSocketServer(self, self.socket)
self._start_threading_server()
except OSError as e:
logger.error(
"Couldn't bind socket on %s:%s", self.socket[0], self.socket[1]
)
raise OspdError(
"Couldn't bind socket on {}:{}. {}".format(
self.socket[0], str(self.socket[1]), e
)
)

def handle_request(self, request, client_address):
logger.debug("New connection from %s", client_address)

req_socket = self.tls_context.wrap_socket(request, server_side=True)

stream = Stream(req_socket, self.stream_timeout)
self.stream_callback(stream)

0 comments on commit 0674066

Please sign in to comment.