Skip to content

Commit

Permalink
Change: Convert GvmConnection into a protocol
Browse files Browse the repository at this point in the history
Improve GvmConnection by being just a protocol and not a specific
implementation. This allows for more flexibility. For example the
DebugConnection is now also a GvmConnection.
  • Loading branch information
bjoernricks authored and greenbonebot committed Jan 29, 2024
1 parent dd53b48 commit 60712a8
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 39 deletions.
38 changes: 26 additions & 12 deletions gvm/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
import ssl
import sys
import time
from abc import ABC, abstractmethod
from os import PathLike
from pathlib import Path
from typing import Optional, Union
from typing import Optional, Protocol, Union, runtime_checkable

import paramiko
import paramiko.ssh_exception
Expand All @@ -41,6 +42,19 @@
Data = Union[str, bytes]


@runtime_checkable
class GvmConnection(Protocol):
def connect(self) -> None: ...

def disconnect(self) -> None: ...

def send(self, data: Data) -> None: ...

def read(self) -> str: ...

def finish_send(self): ...


class XmlReader:
"""
Read a XML command until its closing element
Expand Down Expand Up @@ -77,7 +91,7 @@ def feed_xml(self, data: Data) -> None:
) from None


class GvmConnection:
class AbstractGvmConnection(ABC):
"""
Base class for establishing a connection to a remote server daemon.
Expand All @@ -97,6 +111,7 @@ def _read(self) -> bytes:

return self._socket.recv(BUF_SIZE)

@abstractmethod
def connect(self) -> None:
"""Establish a connection to a remote server"""
raise NotImplementedError
Expand Down Expand Up @@ -164,7 +179,7 @@ def finish_send(self):
self._socket.shutdown(socketlib.SHUT_WR)


class SSHConnection(GvmConnection):
class SSHConnection(AbstractGvmConnection):
"""
SSH Class to connect, read and write from GVM via SSH
Expand All @@ -174,7 +189,7 @@ class SSHConnection(GvmConnection):
127.0.0.1.
port: Port of the remote SSH server. Default is port 22.
username: Username to use for SSH login. Default is "gmp".
password: Passwort to use for SSH login. Default is "".
password: Password to use for SSH login. Default is "".
"""

def __init__(
Expand All @@ -188,8 +203,7 @@ def __init__(
known_hosts_file: Optional[Union[str, PathLike]] = None,
auto_accept_host: Optional[bool] = None,
) -> None:
super().__init__(timeout=timeout)

super().__init__(timeout)
self.hostname = hostname if hostname is not None else DEFAULT_HOSTNAME
self.port = int(port) if port is not None else DEFAULT_SSH_PORT
self.username = (
Expand Down Expand Up @@ -414,11 +428,11 @@ def connect(self) -> None:
def _read(self) -> bytes:
return self._stdout.channel.recv(BUF_SIZE)

def send(self, data: Union[bytes, str]) -> int:
def send(self, data: Data) -> None:
if isinstance(data, str):
return self._send_all(data.encode())

return self._send_all(data)
self._send_all(data.encode())
else:
self._send_all(data)

def finish_send(self) -> None:
# shutdown socket for sending. only allow reading data afterwards
Expand All @@ -439,7 +453,7 @@ def disconnect(self) -> None:
del self._socket, self._stdin, self._stdout, self._stderr


class TLSConnection(GvmConnection):
class TLSConnection(AbstractGvmConnection):
"""
TLS class to connect, read and write from a remote GVM daemon via TLS
secured socket.
Expand Down Expand Up @@ -524,7 +538,7 @@ def disconnect(self):
return super().disconnect()


class UnixSocketConnection(GvmConnection):
class UnixSocketConnection(AbstractGvmConnection):
"""
UNIX-Socket class to connect, read, write from a daemon via direct
communicating UNIX-Socket
Expand Down
40 changes: 28 additions & 12 deletions tests/connections/test_gvm_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,13 @@
import unittest
from unittest.mock import patch

from gvm.connections import DEFAULT_TIMEOUT, GvmConnection, XmlReader
from gvm.connections import (
DEFAULT_TIMEOUT,
AbstractGvmConnection,
DebugConnection,
GvmConnection,
XmlReader,
)
from gvm.errors import GvmError


Expand All @@ -19,39 +25,49 @@ def test_is_end_xml_false(self):
self.assertFalse(false)


class TestConnection(AbstractGvmConnection):
def connect(self) -> None:
pass


class GvmConnectionTestCase(unittest.TestCase):
# pylint: disable=protected-access
def test_init_no_args(self):
connection = GvmConnection()
connection = TestConnection()
self.check_for_default_values(connection)

def test_init_with_none(self):
connection = GvmConnection(timeout=None)
connection = TestConnection(timeout=None)
self.check_for_default_values(connection)

def check_for_default_values(self, gvm_connection: GvmConnection):
self.assertIsNone(gvm_connection._socket)
self.assertEqual(gvm_connection._timeout, DEFAULT_TIMEOUT)

def test_connect_not_implemented(self):
connection = GvmConnection()
with self.assertRaises(NotImplementedError):
connection.connect()

@patch("gvm.connections.GvmConnection._read")
@patch("gvm.connections.AbstractGvmConnection._read")
def test_read_no_data(self, _read_mock):
_read_mock.return_value = None
connection = GvmConnection()
connection = TestConnection()
with self.assertRaises(GvmError, msg="Remote closed the connection"):
connection.read()

@patch("gvm.connections.GvmConnection._read")
@patch("gvm.connections.AbstractGvmConnection._read")
def test_read_trigger_timeout(self, _read_mock):
# mocking the response into two parts, so we run into the timeout
# check in the loop
_read_mock.side_effect = [b"<foo>xyz<bar></bar>", b"</foo>"]
connection = GvmConnection(timeout=0)
connection = TestConnection(timeout=0)
with self.assertRaises(
GvmError, msg="Timeout while reading the response"
):
connection.read()

def test_is_gvm_connection(self):
connection = TestConnection()
self.assertTrue(isinstance(connection, GvmConnection))


class DebugConnectionTestCase(unittest.TestCase):
def test_is_gvm_connection(self):
connection = DebugConnection(TestConnection())
self.assertTrue(isinstance(connection, GvmConnection))
19 changes: 11 additions & 8 deletions tests/connections/test_ssh_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
DEFAULT_SSH_PASSWORD,
DEFAULT_SSH_PORT,
DEFAULT_SSH_USERNAME,
GvmConnection,
SSHConnection,
)
from gvm.errors import GvmError
Expand Down Expand Up @@ -177,7 +178,7 @@ def test_connect_adding_and_save_hostkey(self, input_mock, _print_mock):
)

with self.assertLogs("gvm.connections", level="INFO") as cm:
hostkeys = paramiko.HostKeys(filename=self.known_hosts_file)
hostkeys = paramiko.HostKeys(filename=str(self.known_hosts_file))
ssh_connection._ssh_authentication_input_loop(
hostkeys=hostkeys, key=key
)
Expand Down Expand Up @@ -229,7 +230,7 @@ def test_connect_adding_and_dont_save_hostkey(
)

with self.assertLogs("gvm.connections", level="INFO") as cm:
hostkeys = paramiko.HostKeys(filename=self.known_hosts_file)
hostkeys = paramiko.HostKeys(filename=str(self.known_hosts_file))
ssh_connection._ssh_authentication_input_loop(
hostkeys=hostkeys, key=key
)
Expand Down Expand Up @@ -274,7 +275,7 @@ def test_connect_wrong_input(self, stdout_mock, input_mock):
ssh_connection._socket = paramiko.SSHClient()

with self.assertLogs("gvm.connections", level="INFO") as cm:
hostkeys = paramiko.HostKeys(filename=self.known_hosts_file)
hostkeys = paramiko.HostKeys(filename=str(self.known_hosts_file))
ssh_connection._ssh_authentication_input_loop(
hostkeys=hostkeys, key=key
)
Expand Down Expand Up @@ -323,7 +324,7 @@ def test_user_denies_auth(self, input_mock):
with self.assertRaises(
SystemExit, msg="User denied key. Host key verification failed."
):
hostkeys = paramiko.HostKeys(filename=self.known_hosts_file)
hostkeys = paramiko.HostKeys(filename=str(self.known_hosts_file))
ssh_connection._ssh_authentication_input_loop(
hostkeys=hostkeys, key=key
)
Expand Down Expand Up @@ -441,8 +442,7 @@ def test_send(self):
)

ssh_connection.connect()
req = ssh_connection.send("blah")
self.assertEqual(req, 4)
ssh_connection.send("blah")
ssh_connection.disconnect()

def test_send_error(self):
Expand Down Expand Up @@ -473,8 +473,7 @@ def test_send_and_slice(self):
)

ssh_connection.connect()
req = ssh_connection.send("blah")
self.assertEqual(req, 4)
ssh_connection.send("blah")

stdin.channel.send.assert_called()
with self.assertRaises(AssertionError):
Expand All @@ -495,3 +494,7 @@ def test_read(self):
recved = ssh_connection._read()
self.assertEqual(recved, b"foo bar baz")
ssh_connection.disconnect()

def test_is_gvm_connection(self):
ssh_connection = SSHConnection(known_hosts_file=self.known_hosts_file)
self.assertTrue(isinstance(ssh_connection, GvmConnection))
5 changes: 5 additions & 0 deletions tests/connections/test_tls_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
DEFAULT_GVM_PORT,
DEFAULT_HOSTNAME,
DEFAULT_TIMEOUT,
GvmConnection,
TLSConnection,
)

Expand Down Expand Up @@ -62,3 +63,7 @@ def test_connect_auth(self):
context_mock.load_cert_chain.assert_called_once()
context_mock.wrap_socket.assert_called_once()
self.assertFalse(context_mock.check_hostname)

def test_is_gvm_connection(self):
connection = TLSConnection()
self.assertTrue(isinstance(connection, GvmConnection))
13 changes: 6 additions & 7 deletions tests/connections/test_unix_socket_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from gvm.connections import (
DEFAULT_TIMEOUT,
DEFAULT_UNIX_SOCKET_PATH,
GvmConnection,
UnixSocketConnection,
)
from gvm.errors import GvmError
Expand Down Expand Up @@ -65,8 +66,7 @@ def test_unix_socket_connection_connect_send_bytes_read(self):
path=self.socketname, timeout=DEFAULT_TIMEOUT
)
connection.connect()
req = connection.send(bytes("<gmp/>", "utf-8"))
self.assertIsNone(req)
connection.send(bytes("<gmp/>", "utf-8"))
resp = connection.read()
self.assertEqual(resp, '<gmp_response status="200" status_text="OK"/>')
connection.disconnect()
Expand All @@ -76,8 +76,7 @@ def test_unix_socket_connection_connect_send_str_read(self):
path=self.socketname, timeout=DEFAULT_TIMEOUT
)
connection.connect()
req = connection.send("<gmp/>")
self.assertIsNone(req)
connection.send("<gmp/>")
resp = connection.read()
self.assertEqual(resp, '<gmp_response status="200" status_text="OK"/>')
connection.disconnect()
Expand Down Expand Up @@ -120,6 +119,6 @@ def check_default_values(self, connection: UnixSocketConnection):
self.assertEqual(connection._timeout, DEFAULT_TIMEOUT)
self.assertEqual(connection.path, DEFAULT_UNIX_SOCKET_PATH)


if __name__ == "__main__":
unittest.main()
def test_is_gvm_connection(self):
connection = UnixSocketConnection()
self.assertTrue(isinstance(connection, GvmConnection))

0 comments on commit 60712a8

Please sign in to comment.