From 20d029ebe38ba4821194872adc853a10826c4a4b Mon Sep 17 00:00:00 2001 From: Kushal Das Date: Tue, 7 Jan 2020 16:22:01 +0530 Subject: [PATCH] Restructures the code base with more object methods proxy.py Now, Proxy class gets a must conf_path argument, and it creates the inital `conf` attribute from that. Proxy also has default on_save and on_done and err_on_done methods. def handle_response(self) method has a `assert self.res` to mark that res is populated before this call. This is for mypy main.py Mostly has changes from black entrypoint.py No need for fancy dynamic err_call_back, the proxy object will call self.err_call_back if any issue in reading configuration. The test cases now have thier own configuration files to create the proxy object. Also, to do proper dynamic attachment of any method of Proxy class we are using https://docs.python.org/3/library/types.html#types.MethodType so that our own on_save or on_done or err_on_done will be called during tests. --- securedrop_proxy/callbacks.py | 44 --- securedrop_proxy/config.py | 58 ---- securedrop_proxy/entrypoint.py | 52 ++-- securedrop_proxy/main.py | 33 +-- securedrop_proxy/proxy.py | 165 ++++++++++-- tests/files/badgateway-config.yaml | 5 + tests/files/invalid-config.yaml | 5 + tests/files/local-config.yaml | 5 + tests/test_callbacks.py | 129 --------- tests/test_config.py | 96 ------- tests/test_main.py | 96 ++++--- tests/test_proxy.py | 412 +++++++++++++++++++++-------- 12 files changed, 546 insertions(+), 554 deletions(-) delete mode 100644 securedrop_proxy/callbacks.py delete mode 100644 securedrop_proxy/config.py create mode 100644 tests/files/badgateway-config.yaml create mode 100644 tests/files/invalid-config.yaml create mode 100644 tests/files/local-config.yaml delete mode 100644 tests/test_callbacks.py delete mode 100644 tests/test_config.py diff --git a/securedrop_proxy/callbacks.py b/securedrop_proxy/callbacks.py deleted file mode 100644 index 0e890d2..0000000 --- a/securedrop_proxy/callbacks.py +++ /dev/null @@ -1,44 +0,0 @@ -import os -import subprocess -import sys -import json -import tempfile -import uuid - - -def err_on_done(res): - print(json.dumps(res.__dict__)) - sys.exit(1) - - -# callback for handling non-JSON content. in production-like -# environments, we want to call `qvm-move-to-vm` (and expressly not -# `qvm-move`, since we want to include the destination VM name) to -# move the content to the target VM. for development and testing, we -# keep the file on the local VM. -# -# In any case, this callback mutates the given result object (in -# `res`) to include the name of the new file, or to indicate errors. -def on_save(fh, res, conf): - fn = str(uuid.uuid4()) - - try: - with tempfile.TemporaryDirectory() as tmpdir: - tmpfile = os.path.join(os.path.abspath(tmpdir), fn) - subprocess.run(["cp", fh.name, tmpfile]) - if conf.dev is not True: - subprocess.run(["qvm-move-to-vm", conf.target_vm, tmpfile]) - except Exception: - res.status = 500 - res.headers['Content-Type'] = 'application/json' - res.headers['X-Origin-Content-Type'] = res.headers['Content-Type'] - res.body = json.dumps({"error": "Unhandled error while handling non-JSON content, sorry"}) - return - - res.headers['Content-Type'] = 'application/json' - res.headers['X-Origin-Content-Type'] = res.headers['Content-Type'] - res.body = json.dumps({'filename': fn}) - - -def on_done(res): - print(json.dumps(res.__dict__)) diff --git a/securedrop_proxy/config.py b/securedrop_proxy/config.py deleted file mode 100644 index d82cdb4..0000000 --- a/securedrop_proxy/config.py +++ /dev/null @@ -1,58 +0,0 @@ -import os -import yaml - - -class Conf: - scheme = '' - host = '' - port = 0 - dev = False - - -def read_conf(conf_path, p): - - if not os.path.isfile(conf_path): - p.simple_error(500, 'Configuration file does not exist at {}'.format(conf_path)) - p.on_done(p.res) - - try: - fh = open(conf_path, 'r') - conf_in = yaml.safe_load(fh) - except yaml.YAMLError: - p.simple_error( - 500, "YAML syntax error while reading configuration file {}".format(conf_path) - ) - p.on_done(p.res) - except Exception: - p.simple_error( - 500, "Error while opening or reading configuration file {}".format(conf_path) - ) - p.on_done(p.res) - - req_conf_keys = set(('host', 'scheme', 'port')) - missing_keys = req_conf_keys - set(conf_in.keys()) - if len(missing_keys) > 0: - p.simple_error(500, 'Configuration file missing required keys: {}'.format(missing_keys)) - p.on_done(p.res) - - c = Conf() - c.host = conf_in['host'] - c.scheme = conf_in['scheme'] - c.port = conf_in['port'] - - if 'dev' in conf_in and conf_in['dev'] is True: - c.dev = True - else: - if "target_vm" not in conf_in: - p.simple_error( - 500, - ( - "Configuration file missing `target_vm` key, which is required " - "when not in development mode" - ), - ) - p.on_done(p.res) - - c.target_vm = conf_in['target_vm'] - - return c diff --git a/securedrop_proxy/entrypoint.py b/securedrop_proxy/entrypoint.py index fb0a8c1..cee130f 100755 --- a/securedrop_proxy/entrypoint.py +++ b/securedrop_proxy/entrypoint.py @@ -14,76 +14,68 @@ from logging.handlers import TimedRotatingFileHandler -from securedrop_proxy import callbacks -from securedrop_proxy import config from securedrop_proxy import main from securedrop_proxy import proxy from securedrop_proxy.version import version DEFAULT_HOME = os.path.join(os.path.expanduser("~"), ".securedrop_proxy") -LOGLEVEL = os.environ.get('LOGLEVEL', 'info').upper() +LOGLEVEL = os.environ.get("LOGLEVEL", "info").upper() -def start(): - ''' +def start() -> None: + """ Set up a new proxy object with an error handler, configuration that we read from argv[1], and the original user request from STDIN. - ''' + """ try: configure_logging() - logging.debug('Starting SecureDrop Proxy {}'.format(version)) - - # a fresh, new proxy object - p = proxy.Proxy() - - # set up an error handler early, so we can use it during - # configuration, etc - original_on_done = p.on_done - p.on_done = callbacks.err_on_done + logging.debug("Starting SecureDrop Proxy {}".format(version)) # path to config file must be at argv[1] if len(sys.argv) != 2: - raise ValueError("sd-proxy script not called with path to configuration file") + raise ValueError( + "sd-proxy script not called with path to configuration file" + ) - # read config. `read_conf` will call `p.on_done` if there is a config + # read config. `read_conf` will call `p.err_on_done` if there is a config # problem, and will return a Conf object on success. conf_path = sys.argv[1] - p.conf = config.read_conf(conf_path, p) + # a fresh, new proxy object + p = proxy.Proxy(conf_path=conf_path) # read user request from STDIN - incoming = [] + incoming_lines = [] for line in sys.stdin: - incoming.append(line) - incoming = "\n".join(incoming) + incoming_lines.append(line) + incoming = "\n".join(incoming_lines) - p.on_done = original_on_done main.__main__(incoming, p) except Exception as e: response = { "status": http.HTTPStatus.INTERNAL_SERVER_ERROR, - "body": json.dumps({ - "error": str(e), - }) + "body": json.dumps({"error": str(e)}), } print(json.dumps(response)) sys.exit(1) def configure_logging() -> None: - ''' + """ All logging related settings are set up by this function. - ''' + """ home = os.getenv("SECUREDROP_HOME", DEFAULT_HOME) - log_folder = os.path.join(home, 'logs') + log_folder = os.path.join(home, "logs") if not os.path.exists(log_folder): os.makedirs(log_folder) - log_file = os.path.join(home, 'logs', 'proxy.log') + log_file = os.path.join(home, "logs", "proxy.log") # set logging format - log_fmt = ('%(asctime)s - %(name)s:%(lineno)d(%(funcName)s) %(levelname)s: %(message)s') + log_fmt = ( + "%(asctime)s - %(name)s:%(lineno)d(%(funcName)s) %(levelname)s: %(message)s" + ) formatter = logging.Formatter(log_fmt) # define log handlers such as for rotating log files diff --git a/securedrop_proxy/main.py b/securedrop_proxy/main.py index e67f158..69abf48 100644 --- a/securedrop_proxy/main.py +++ b/securedrop_proxy/main.py @@ -1,44 +1,45 @@ import json import logging +from typing import Dict, Any -from securedrop_proxy import callbacks from securedrop_proxy import proxy +from securedrop_proxy.proxy import Proxy + logger = logging.getLogger(__name__) -def __main__(incoming, p): - ''' +def __main__(incoming: str, p: Proxy) -> None: + """ Deserialize incoming request in order to build and send a proxy request. - ''' - logging.debug('Creating request to be sent by proxy') + """ + logging.debug("Creating request to be sent by proxy") - client_req = None + client_req: Dict[str, Any] = {} try: client_req = json.loads(incoming) except json.decoder.JSONDecodeError as e: logging.error(e) - p.simple_error(400, 'Invalid JSON in request') - p.on_done(p.res) + p.simple_error(400, "Invalid JSON in request") + p.on_done() return req = proxy.Req() try: - req.method = client_req['method'] - req.path_query = client_req['path_query'] + req.method = client_req["method"] + req.path_query = client_req["path_query"] except KeyError as e: logging.error(e) - p.simple_error(400, 'Missing keys in request') - p.on_done(p.res) + p.simple_error(400, "Missing keys in request") + p.on_done() if "headers" in client_req: - req.headers = client_req['headers'] + req.headers = client_req["headers"] if "body" in client_req: - req.body = client_req['body'] + req.body = client_req["body"] p.req = req - if not p.on_save: - p.on_save = callbacks.on_save + p.proxy() diff --git a/securedrop_proxy/proxy.py b/securedrop_proxy/proxy.py index 733d2ed..ab29554 100644 --- a/securedrop_proxy/proxy.py +++ b/securedrop_proxy/proxy.py @@ -6,53 +6,158 @@ import tempfile import werkzeug +import os +import subprocess +import sys +import uuid +import yaml +from typing import Dict, Optional + import securedrop_proxy.version as version -from securedrop_proxy import callbacks +from tempfile import _TemporaryFileWrapper # type: ignore logger = logging.getLogger(__name__) +class Conf: + scheme = "" + host = "" + port = 0 + dev = False + target_vm = "" + + class Req: - def __init__(self): + def __init__(self) -> None: self.method = "" self.path_query = "" - self.body = None - self.headers = {} + self.body = "" + self.headers: Dict[str, str] = {} class Response: - def __init__(self, status): + def __init__(self, status: int) -> None: self.status = status - self.body = None - self.headers = {} + self.body = "" + self.headers: Dict[str, str] = {} self.version = version.version class Proxy: - def __init__(self, conf=None, req=Req(), on_save=None, on_done=None, timeout: float = None): - self.conf = conf - self.req = req - self.res = None - self.on_save = on_save - if on_done: - self.on_done = on_done + def __init__( + self, conf_path: str, req: Req = Req(), timeout: float = None, + ) -> None: + # The configuration path for Proxy is a must. + self.read_conf(conf_path) + self.req = req + self.res: Optional[Response] = None self.timeout = float(timeout) if timeout else 10 self._prepared_request = None - def on_done(self, res): # type: ignore - callbacks.on_done(res) + def on_done(self) -> None: + print(json.dumps(self.res.__dict__)) @staticmethod - def valid_path(path): + def valid_path(path: str) -> bool: u = furl.furl(path) if u.host is not None: return False return True + def err_on_done(self): + print(json.dumps(self.res.__dict__)) + sys.exit(1) + + def read_conf(self, conf_path: str) -> None: + + if not os.path.isfile(conf_path): + self.simple_error( + 500, "Configuration file does not exist at {}".format(conf_path) + ) + self.err_on_done() + + try: + with open(conf_path) as fh: + conf_in = yaml.safe_load(fh) + except yaml.YAMLError: + self.simple_error( + 500, + "YAML syntax error while reading configuration file {}".format( + conf_path + ), + ) + self.err_on_done() + except Exception: + self.simple_error( + 500, + "Error while opening or reading configuration file {}".format( + conf_path + ), + ) + self.err_on_done() + + req_conf_keys = set(("host", "scheme", "port")) + missing_keys = req_conf_keys - set(conf_in.keys()) + if len(missing_keys) > 0: + self.simple_error( + 500, "Configuration file missing required keys: {}".format(missing_keys) + ) + self.err_on_done() + + self.conf = Conf() + self.conf.host = conf_in["host"] + self.conf.scheme = conf_in["scheme"] + self.conf.port = conf_in["port"] + + if "dev" in conf_in and conf_in["dev"]: + self.conf.dev = True + else: + if "target_vm" not in conf_in: + self.simple_error( + 500, + ( + "Configuration file missing `target_vm` key, which is required " + "when not in development mode" + ), + ) + self.err_on_done() + + self.conf.target_vm = conf_in["target_vm"] + + # callback for handling non-JSON content. in production-like + # environments, we want to call `qvm-move-to-vm` (and expressly not + # `qvm-move`, since we want to include the destination VM name) to + # move the content to the target VM. for development and testing, we + # keep the file on the local VM. + # + # In any case, this callback mutates the given result object (in + # `res`) to include the name of the new file, or to indicate errors. + def on_save(self, fh: _TemporaryFileWrapper, res: Response) -> None: + fn = str(uuid.uuid4()) + + try: + with tempfile.TemporaryDirectory() as tmpdir: + tmpfile = os.path.join(os.path.abspath(tmpdir), fn) + subprocess.run(["cp", fh.name, tmpfile]) + if self.conf.dev is not True: + subprocess.run(["qvm-move-to-vm", self.conf.target_vm, tmpfile]) + except Exception: + res.status = 500 + res.headers["Content-Type"] = "application/json" + res.headers["X-Origin-Content-Type"] = res.headers["Content-Type"] + res.body = json.dumps( + {"error": "Unhandled error while handling non-JSON content, sorry"} + ) + return + + res.headers["Content-Type"] = "application/json" + res.headers["X-Origin-Content-Type"] = res.headers["Content-Type"] + res.body = json.dumps({"filename": fn}) + def simple_error(self, status, err): res = Response(status) res.body = json.dumps({"error": err}) @@ -60,7 +165,7 @@ def simple_error(self, status, err): self.res = res - def prep_request(self): + def prep_request(self) -> None: scheme = self.conf.scheme host = self.conf.host @@ -83,14 +188,13 @@ def prep_request(self): url.path.normalize() preq = requests.Request(method, url.url) - preq.stream = True preq.headers = self.req.headers preq.data = self.req.body prep = preq.prepare() self._prepared_request = prep - def handle_json_response(self): + def handle_json_response(self) -> None: res = Response(self._presp.status_code) @@ -114,11 +218,11 @@ def handle_non_json_response(self): res.headers = self._presp.headers - self.on_save(fh, res, self.conf) + self.on_save(fh, res) self.res = res - def handle_response(self): + def handle_response(self) -> None: logger.debug("Handling response") ctype = werkzeug.http.parse_options_header(self._presp.headers["content-type"]) @@ -128,11 +232,14 @@ def handle_response(self): else: self.handle_non_json_response() + # https://mypy.readthedocs.io/en/latest/kinds_of_types.html#union-types + # To make sure that mypy knows the type of self.res is not None. + assert self.res # headers is a Requests class which doesn't JSON serialize. # coerce it into a normal dict so it will self.res.headers = dict(self.res.headers) - def proxy(self): + def proxy(self) -> None: try: if not self.on_save: @@ -162,13 +269,15 @@ def proxy(self): requests.exceptions.TooManyRedirects, ) as e: logger.error(e) - self.simple_error(http.HTTPStatus.BAD_GATEWAY, "could not connect to server") + self.simple_error( + http.HTTPStatus.BAD_GATEWAY, "could not connect to server" + ) except requests.exceptions.HTTPError as e: logger.error(e) try: self.simple_error( e.response.status_code, - http.HTTPStatus(e.response.status_code).phrase.lower() + http.HTTPStatus(e.response.status_code).phrase.lower(), ) except ValueError: # Return a generic error message when the response @@ -176,5 +285,7 @@ def proxy(self): self.simple_error(e.response.status_code, "unspecified server error") except Exception as e: logger.error(e) - self.simple_error(http.HTTPStatus.INTERNAL_SERVER_ERROR, "internal proxy error") - self.on_done(self.res) + self.simple_error( + http.HTTPStatus.INTERNAL_SERVER_ERROR, "internal proxy error" + ) + self.on_done() diff --git a/tests/files/badgateway-config.yaml b/tests/files/badgateway-config.yaml new file mode 100644 index 0000000..8939644 --- /dev/null +++ b/tests/files/badgateway-config.yaml @@ -0,0 +1,5 @@ +host: sdproxytest.local +scheme: https +port: 8000 +target_vm: compost +dev: False diff --git a/tests/files/invalid-config.yaml b/tests/files/invalid-config.yaml new file mode 100644 index 0000000..1338eef --- /dev/null +++ b/tests/files/invalid-config.yaml @@ -0,0 +1,5 @@ +host: jsonplaceholder.typicode.com +scheme: https://http +port: 443 +target_vm: compost +dev: False diff --git a/tests/files/local-config.yaml b/tests/files/local-config.yaml new file mode 100644 index 0000000..7bd20fc --- /dev/null +++ b/tests/files/local-config.yaml @@ -0,0 +1,5 @@ +host: localhost +scheme: http +port: 8000 +target_vm: compost +dev: False diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py deleted file mode 100644 index 83c861b..0000000 --- a/tests/test_callbacks.py +++ /dev/null @@ -1,129 +0,0 @@ -from io import StringIO -import json -import sys -import tempfile -import unittest -from unittest.mock import patch - -import vcr - -from securedrop_proxy import callbacks -from securedrop_proxy import config -from securedrop_proxy import proxy - - -class TestCallbacks(unittest.TestCase): - def setUp(self): - self.res = proxy.Response(status=200) - self.res.body = "babbys request" - - self.conf = config.Conf() - self.conf.host = 'jsonplaceholder.typicode.com' - self.conf.scheme = 'https' - self.conf.port = 443 - self.conf.dev = True - - def test_err_on_done(self): - saved_stdout = sys.stdout - try: - out = StringIO() - sys.stdout = out - with self.assertRaises(SystemExit): - callbacks.err_on_done(self.res) - output = out.getvalue().strip() - finally: - sys.stdout = saved_stdout - - response = json.loads(output) - self.assertEqual(response['status'], 200) - self.assertEqual(response['body'], 'babbys request') - - def test_on_done(self): - saved_stdout = sys.stdout - try: - out = StringIO() - sys.stdout = out - callbacks.on_done(self.res) - output = out.getvalue().strip() - finally: - sys.stdout = saved_stdout - - response = json.loads(output) - self.assertEqual(response['status'], 200) - self.assertEqual(response['body'], 'babbys request') - - def test_on_save_500_unhandled_error(self): - fh = tempfile.NamedTemporaryFile() - - # Let's generate an error and ensure that an appropriate response - # is sent back to the user - with patch("subprocess.run", side_effect=IOError): - callbacks.on_save(fh, self.res, self.conf) - - self.assertEqual(self.res.status, 500) - self.assertEqual(self.res.headers['Content-Type'], - 'application/json') - self.assertEqual(self.res.headers['X-Origin-Content-Type'], - 'application/json') - self.assertIn('Unhandled error', self.res.body) - - def test_on_save_200_success(self): - fh = tempfile.NamedTemporaryFile() - - callbacks.on_save(fh, self.res, self.conf) - - self.assertEqual(self.res.headers['Content-Type'], - 'application/json') - self.assertEqual(self.res.headers['X-Origin-Content-Type'], - 'application/json') - self.assertEqual(self.res.status, 200) - self.assertIn('filename', self.res.body) - - @vcr.use_cassette("fixtures/proxy_callbacks.yaml") - def test_custom_callbacks(self): - """ - Test the handlers in a real proxy request. - """ - conf = config.Conf() - conf.host = 'jsonplaceholder.typicode.com' - conf.scheme = 'https' - conf.port = 443 - - req = proxy.Req() - req.method = "GET" - - on_save_addition = "added by the on_save callback\n" - on_done_addition = "added by the on_done callback\n" - - def on_save(fh, res, conf): - res.headers['Content-Type'] = 'text/plain' - res.body = on_save_addition - - def on_done(res): - res.headers['Content-Type'] = 'text/plain' - res.body += on_done_addition - - p = proxy.Proxy(self.conf, req, on_save=on_save, on_done=on_done) - p.proxy() - - self.assertEqual( - p.res.body, - "{}{}".format(on_save_addition, on_done_addition) - ) - - @vcr.use_cassette("fixtures/proxy_callbacks.yaml") - def test_production_on_save(self): - """ - Test on_save's production file handling. - """ - conf = config.Conf() - conf.host = 'jsonplaceholder.typicode.com' - conf.scheme = 'https' - conf.port = 443 - conf.dev = False - conf.target_vm = "sd-svs-dispvm" - - with patch("subprocess.run") as patched_run: - fh = tempfile.NamedTemporaryFile() - callbacks.on_save(fh, self.res, conf) - self.assertEqual(patched_run.call_args[0][0][0], "qvm-move-to-vm") diff --git a/tests/test_config.py b/tests/test_config.py deleted file mode 100644 index 0ae93ea..0000000 --- a/tests/test_config.py +++ /dev/null @@ -1,96 +0,0 @@ -import sys -import unittest -from unittest.mock import patch - -from securedrop_proxy import proxy -from securedrop_proxy import config - - -class TestConfig(unittest.TestCase): - def setUp(self): - self.p = proxy.Proxy() - - def test_config_file_does_not_exist(self): - def err_on_done(res): - res = res.__dict__ - self.assertEqual(res['status'], 500) - self.assertIn("Configuration file does not exist", - res['body']) - self.assertEqual(res['headers']['Content-Type'], - 'application/json') - sys.exit(1) - - self.p.on_done = err_on_done - with self.assertRaises(SystemExit): - config.read_conf('not/a/real/path', self.p) - - def test_config_file_when_yaml_is_invalid(self): - def err_on_done(res): - res = res.__dict__ - self.assertEqual(res['status'], 500) - self.assertIn("YAML syntax error", res['body']) - self.assertEqual(res['headers']['Content-Type'], - 'application/json') - sys.exit(1) - - self.p.on_done = err_on_done - with self.assertRaises(SystemExit): - config.read_conf('tests/files/invalid_yaml.yaml', self.p) - - def test_config_file_open_generic_exception(self): - def err_on_done(res): - res = res.__dict__ - self.assertEqual(res['status'], 500) - self.assertEqual(res['headers']['Content-Type'], - 'application/json') - sys.exit(1) - - self.p.on_done = err_on_done - - with self.assertRaises(SystemExit): - # Patching open so that we can simulate a non-YAML error - # (e.g. permissions) - with patch("builtins.open", side_effect=IOError): - config.read_conf('tests/files/valid-config.yaml', self.p) - - def test_config_has_valid_keys(self): - c = config.read_conf('tests/files/valid-config.yaml', self.p) - - # Verify we have a valid Conf object - self.assertEqual(c.host, 'jsonplaceholder.typicode.com') - self.assertEqual(c.port, 443) - self.assertFalse(c.dev) - self.assertEqual(c.scheme, 'https') - self.assertEqual(c.target_vm, 'compost') - - def test_config_500_when_missing_a_required_key(self): - def err_on_done(res): - res = res.__dict__ - self.assertEqual(res['status'], 500) - self.assertIn("missing required keys", res['body']) - self.assertEqual(res['headers']['Content-Type'], - 'application/json') - sys.exit(1) - - self.p.on_done = err_on_done - - with self.assertRaises(SystemExit): - config.read_conf('tests/files/missing-key.yaml', self.p) - - def test_config_500_when_missing_target_vm(self): - def err_on_done(res): - res = res.__dict__ - self.assertEqual(res['status'], 500) - self.assertIn("missing `target_vm` key", res['body']) - self.assertEqual(res['headers']['Content-Type'], - 'application/json') - sys.exit(1) - - self.p.on_done = err_on_done - - with self.assertRaises(SystemExit): - config.read_conf('tests/files/missing-target-vm.yaml', self.p) - - def test_dev_config(self): - c = config.read_conf('tests/files/dev-config.yaml', self.p) - self.assertTrue(c.dev) diff --git a/tests/test_main.py b/tests/test_main.py index c4e10e2..19bc4e5 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -5,21 +5,17 @@ import sys import unittest import uuid +import types import vcr -from securedrop_proxy import config from securedrop_proxy import main from securedrop_proxy import proxy class TestMain(unittest.TestCase): def setUp(self): - self.conf = config.Conf() - self.conf.host = "jsonplaceholder.typicode.com" - self.conf.scheme = "https" - self.conf.port = 443 - self.conf.dev = True + self.conf_path = "tests/files/valid-config.yaml" @vcr.use_cassette("fixtures/main_json_response.yaml") def test_json_response(self): @@ -32,14 +28,19 @@ def test_json_response(self): req.headers = {"Accept": "application/json"} # Use custom callbacks - def on_save(res, fh, conf): + def on_save(self, fh, res): pass - def on_done(res): - self.assertEqual(res.status, http.HTTPStatus.OK) - print(json.dumps(res.__dict__)) + def on_done(self): + assert self.res.status == http.HTTPStatus.OK + print(json.dumps(self.res.__dict__)) - self.p = proxy.Proxy(self.conf, req, on_save, on_done) + self.p = proxy.Proxy(self.conf_path, req) + + # Patching on_save and on_done + + self.p.on_done = types.MethodType(on_done, self.p) + self.p.on_save = types.MethodType(on_save, self.p) saved_stdout = sys.stdout try: @@ -59,8 +60,7 @@ def test_non_json_response(self): test_input_json = """{ "method": "GET", "path_query": "" }""" - def on_save(fh, res, conf): - self.fn = str(uuid.uuid4()) + def on_save(self, fh, res): subprocess.run(["cp", fh.name, "/tmp/{}".format(self.fn)]) @@ -68,7 +68,11 @@ def on_save(fh, res, conf): res.headers["Content-Type"] = "application/json" res.body = json.dumps({"filename": self.fn}) - self.p = proxy.Proxy(self.conf, proxy.Req(), on_save) + self.p = proxy.Proxy(self.conf_path, proxy.Req()) + + # Patching on_save to tests + self.p.on_save = types.MethodType(on_save, self.p) + self.p.fn = str(uuid.uuid4()) saved_stdout = sys.stdout try: @@ -86,7 +90,7 @@ def on_save(fh, res, conf): self.assertIn("filename", response["body"]) # The file should not be empty - with open("/tmp/{}".format(self.fn)) as f: + with open("/tmp/{}".format(self.p.fn)) as f: saved_file = f.read() # We expect HTML content in the file from the test data @@ -95,15 +99,20 @@ def on_save(fh, res, conf): def test_input_invalid_json(self): test_input_json = """"foo": "bar", "baz": "bliff" }""" - def on_save(fh, res, conf): + def on_save(self, fh, res): pass - def on_done(res): - res = res.__dict__ - self.assertEqual(res["status"], 400) + def on_done(self): + res = self.res.__dict__ + assert res["status"] == 400 sys.exit(1) - p = proxy.Proxy(self.conf, proxy.Req(), on_save, on_done) + p = proxy.Proxy(self.conf_path, proxy.Req()) + + # patching on_save and on_done for tests + + p.on_done = types.MethodType(on_done, p) + p.on_save = types.MethodType(on_save, p) with self.assertRaises(SystemExit): main.__main__(test_input_json, p) @@ -111,16 +120,22 @@ def on_done(res): def test_input_missing_keys(self): test_input_json = """{ "foo": "bar", "baz": "bliff" }""" - def on_save(fh, res, conf): + def on_save(self, fh, res): pass - def on_done(res): - res = res.__dict__ - self.assertEqual(res["status"], 400) - self.assertEqual(res["body"], '{"error": "Missing keys in request"}') + def on_done(self): + res = self.res.__dict__ + assert res["status"] == 400 + assert res["body"] == '{"error": "Missing keys in request"}', res sys.exit(1) - p = proxy.Proxy(self.conf, proxy.Req(), on_save, on_done) + p = proxy.Proxy(self.conf_path, proxy.Req()) + + # patching on_save and on_done for tests + + p.on_done = types.MethodType(on_done, p) + p.on_save = types.MethodType(on_save, p) + with self.assertRaises(SystemExit): main.__main__(test_input_json, p) @@ -132,10 +147,10 @@ def test_input_headers(self): "headers": {"X-Test-Header": "th"}, } - def on_save(fh, res, conf): + def on_save(self, fh, res): pass - p = proxy.Proxy(self.conf, proxy.Req(), on_save) + p = proxy.Proxy(self.conf_path, proxy.Req()) main.__main__(json.dumps(test_input), p) self.assertEqual(p.req.headers, test_input["headers"]) @@ -147,10 +162,15 @@ def test_input_body(self): "body": {"id": 42, "title": "test"}, } - def on_save(fh, res, conf): + def on_save(self, fh, res): pass - p = proxy.Proxy(self.conf, proxy.Req(), on_save) + p = proxy.Proxy(self.conf_path, proxy.Req()) + + # Patching on_save for tests + + p.on_save = types.MethodType(on_save, p) + main.__main__(json.dumps(test_input), p) self.assertEqual(p.req.body, test_input["body"]) @@ -161,12 +181,10 @@ def test_default_callbacks(self): "path_query": "", } - p = proxy.Proxy(self.conf, proxy.Req()) - with unittest.mock.patch( - "securedrop_proxy.callbacks.on_done" - ) as on_done, unittest.mock.patch( - "securedrop_proxy.callbacks.on_save" - ) as on_save: - main.__main__(json.dumps(test_input), p) - self.assertEqual(on_save.call_count, 1) - self.assertEqual(on_done.call_count, 1) + p = proxy.Proxy(self.conf_path, proxy.Req()) + p.on_done = unittest.mock.MagicMock() + p.on_save = unittest.mock.MagicMock() + + main.__main__(json.dumps(test_input), p) + self.assertEqual(p.on_save.call_count, 1) + self.assertEqual(p.on_done.call_count, 1) diff --git a/tests/test_proxy.py b/tests/test_proxy.py index bc73b05..410a073 100644 --- a/tests/test_proxy.py +++ b/tests/test_proxy.py @@ -1,184 +1,158 @@ +import sys import http import json import unittest import uuid +import types +from io import StringIO +import tempfile +from unittest.mock import patch import requests import vcr -from securedrop_proxy import callbacks from securedrop_proxy import proxy -from securedrop_proxy import config from securedrop_proxy import version class TestProxyValidConfig(unittest.TestCase): def setUp(self): - self.conf = config.Conf() - self.conf.host = 'jsonplaceholder.typicode.com' - self.conf.scheme = 'https' - self.conf.port = 443 + self.conf_path = "tests/files/valid-config.yaml" - def on_save(self, fh, res, conf): - self.fn = str(uuid.uuid4()) - res.headers['X-Origin-Content-Type'] = res.headers['Content-Type'] - res.headers['Content-Type'] = 'application/json' - res.body = json.dumps({'filename': self.fn}) + def on_save(self, fh, res): + res.headers["X-Origin-Content-Type"] = res.headers["Content-Type"] + res.headers["Content-Type"] = "application/json" + res.body = json.dumps({"filename": self.fn}) def on_done(self, res): - res.headers['X-Origin-Content-Type'] = res.headers['Content-Type'] - res.headers['Content-Type'] = 'application/json' + res.headers["X-Origin-Content-Type"] = res.headers["Content-Type"] + res.headers["Content-Type"] = "application/json" def test_version(self): req = proxy.Req() - req.method = 'GET' - req.path_query = '' - req.headers = {'Accept': 'application/json'} + req.method = "GET" + req.path_query = "" + req.headers = {"Accept": "application/json"} - p = proxy.Proxy() + p = proxy.Proxy(self.conf_path) p.proxy() self.assertEqual(p.res.version, version.version) - def test_400_if_callback_not_set(self): - req = proxy.Req() - req.method = 'GET' - req.path_query = '' - req.headers = {'Accept': 'application/json'} - - p = proxy.Proxy() - p.proxy() - - self.assertEqual(p.res.status, 400) - - @vcr.use_cassette('fixtures/basic_proxy_functionality.yaml') + @vcr.use_cassette("fixtures/basic_proxy_functionality.yaml") def test_proxy_basic_functionality(self): req = proxy.Req() - req.method = 'GET' - req.path_query = '' - req.headers = {'Accept': 'application/json'} - - p = proxy.Proxy(self.conf, req, self.on_save) + req.method = "GET" + req.path_query = "" + req.headers = {"Accept": "application/json"} + + def on_save(self, fh, res): + res.headers["X-Origin-Content-Type"] = res.headers["Content-Type"] + res.headers["Content-Type"] = "application/json" + res.body = json.dumps({"filename": self.fn}) + + p = proxy.Proxy(self.conf_path, req) + # Patching on_save for test + p.on_save = types.MethodType(on_save, p) + p.fn = str(uuid.uuid4()) p.proxy() self.assertEqual(p.res.status, 200) - self.assertEqual(p.res.body, json.dumps({'filename': self.fn})) - self.assertEqual(p.res.headers['Content-Type'], 'application/json') + self.assertEqual(p.res.body, json.dumps({"filename": p.fn})) + self.assertEqual(p.res.headers["Content-Type"], "application/json") - @vcr.use_cassette('fixtures/proxy_404.yaml') + @vcr.use_cassette("fixtures/proxy_404.yaml") def test_proxy_produces_404(self): req = proxy.Req() - req.method = 'GET' - req.path_query = '/notfound' - req.headers = {'Accept': 'application/json'} + req.method = "GET" + req.path_query = "/notfound" + req.headers = {"Accept": "application/json"} + + p = proxy.Proxy(self.conf_path, req) - p = proxy.Proxy(self.conf, req) - p.on_save = self.on_save - p.on_done = self.on_done p.proxy() self.assertEqual(p.res.status, 404) - self.assertEqual(p.res.headers['Content-Type'], 'application/json') + self.assertEqual(p.res.headers["Content-Type"], "application/json") - @vcr.use_cassette('fixtures/proxy_parameters.yaml') + @vcr.use_cassette("fixtures/proxy_parameters.yaml") def test_proxy_handles_query_params_gracefully(self): req = proxy.Req() - req.method = 'GET' - req.path_query = '/posts?userId=1' - req.headers = {'Accept': 'application/json'} + req.method = "GET" + req.path_query = "/posts?userId=1" + req.headers = {"Accept": "application/json"} + + p = proxy.Proxy(self.conf_path, req) - p = proxy.Proxy(self.conf, req, self.on_save) p.proxy() self.assertEqual(p.res.status, 200) - self.assertIn('application/json', p.res.headers['Content-Type']) + self.assertIn("application/json", p.res.headers["Content-Type"]) body = json.loads(p.res.body) for item in body: - self.assertEqual(item['userId'], 1) + self.assertEqual(item["userId"], 1) # No cassette needed as no network request should be sent def test_proxy_400_bad_path(self): req = proxy.Req() - req.method = 'GET' - req.path_query = 'http://badpath.lol/path' - req.headers = {'Accept': 'application/json'} + req.method = "GET" + req.path_query = "http://badpath.lol/path" + req.headers = {"Accept": "application/json"} + + p = proxy.Proxy(self.conf_path, req) - p = proxy.Proxy(self.conf, req) - p.on_save = self.on_save - p.on_done = self.on_done p.proxy() self.assertEqual(p.res.status, 400) - self.assertEqual(p.res.headers['Content-Type'], 'application/json') - self.assertIn('Path provided in request did not look valid', - p.res.body) + self.assertEqual(p.res.headers["Content-Type"], "application/json") + self.assertIn("Path provided in request did not look valid", p.res.body) - @vcr.use_cassette('fixtures/proxy_200_valid_path.yaml') + @vcr.use_cassette("fixtures/proxy_200_valid_path.yaml") def test_proxy_200_valid_path(self): req = proxy.Req() - req.method = 'GET' - req.path_query = '/posts/1' - req.headers = {'Accept': 'application/json'} + req.method = "GET" + req.path_query = "/posts/1" + req.headers = {"Accept": "application/json"} + + p = proxy.Proxy(self.conf_path, req) - p = proxy.Proxy(self.conf, req, self.on_save) p.proxy() self.assertEqual(p.res.status, 200) - self.assertIn('application/json', p.res.headers['Content-Type']) + self.assertIn("application/json", p.res.headers["Content-Type"]) body = json.loads(p.res.body) - self.assertEqual(body['userId'], 1) - - # No cassette needed as no network request should be sent - def test_proxy_400_no_handler(self): - req = proxy.Req() - req.method = 'GET' - req.path_query = 'http://badpath.lol/path' - req.headers = {'Accept': 'application/json'} - - p = proxy.Proxy(self.conf, req) - p.proxy() - - self.assertEqual(p.res.status, 400) - self.assertEqual(p.res.headers['Content-Type'], 'application/json') - self.assertIn('Request on_save callback is not set', - p.res.body) + self.assertEqual(body["userId"], 1) class TestProxyInvalidConfig(unittest.TestCase): def setUp(self): - self.conf = config.Conf() - self.conf.host = 'jsonplaceholder.typicode.com' - self.conf.scheme = 'https://http' # bad - self.conf.port = 443 + self.conf_path = "tests/files/invalid-config.yaml" - def on_save(self, fh, res, conf): + def on_save(self, fh, res): self.fn = str(uuid.uuid4()) - res.headers['X-Origin-Content-Type'] = res.headers['Content-Type'] - res.headers['Content-Type'] = 'application/json' - res.body = json.dumps({'filename': self.fn}) + res.headers["X-Origin-Content-Type"] = res.headers["Content-Type"] + res.headers["Content-Type"] = "application/json" + res.body = json.dumps({"filename": self.fn}) # No cassette needed as no network request should be sent def test_proxy_500_misconfiguration(self): req = proxy.Req() - req.method = 'GET' - req.path_query = '/posts/1' - req.headers = {'Accept': 'application/json'} + req.method = "GET" + req.path_query = "/posts/1" + req.headers = {"Accept": "application/json"} + + p = proxy.Proxy(self.conf_path, req) - p = proxy.Proxy(self.conf, req, self.on_save) p.proxy() self.assertEqual(p.res.status, 500) - self.assertEqual(p.res.headers['Content-Type'], 'application/json') - self.assertIn('Proxy error while generating URL to request', - p.res.body) + self.assertEqual(p.res.headers["Content-Type"], "application/json") + self.assertIn("Proxy error while generating URL to request", p.res.body) class TestServerErrorHandling(unittest.TestCase): def setUp(self): - self.conf = config.Conf() - self.conf.host = "localhost" - self.conf.scheme = "http" - self.conf.port = 8000 + self.conf_path = "tests/files/local-config.yaml" def make_request(self, method="GET", path_query="/", headers=None): req = proxy.Req() @@ -193,12 +167,7 @@ def test_cannot_connect(self): """ req = self.make_request() - conf = config.Conf() - conf.host = "sdproxytest.local" - conf.scheme = "https" - conf.port = 8000 - - p = proxy.Proxy(conf, req, on_save=callbacks.on_save) + p = proxy.Proxy("tests/files/badgateway-config.yaml", req) p.proxy() self.assertEqual(p.res.status, http.HTTPStatus.BAD_GATEWAY) @@ -210,6 +179,7 @@ def test_server_timeout(self): """ Test for "504 Gateway Timeout" when the server times out. """ + class TimeoutProxy(proxy.Proxy): """ Mocks a slow upstream server. @@ -218,11 +188,12 @@ class TimeoutProxy(proxy.Proxy): long. This Proxy subclass raises the exception that would cause. """ + def prep_request(self): - raise requests.exceptions.Timeout('test timeout') + raise requests.exceptions.Timeout("test timeout") req = self.make_request(path_query="/tarpit") - p = TimeoutProxy(self.conf, req, on_save=callbacks.on_save, timeout=0.00001) + p = TimeoutProxy(self.conf_path, req, timeout=0.00001) p.proxy() self.assertEqual(p.res.status, http.HTTPStatus.GATEWAY_TIMEOUT) @@ -236,7 +207,7 @@ def test_bad_request(self): Test handling of "400 Bad Request" from the server. """ req = self.make_request(path_query="/bad") - p = proxy.Proxy(self.conf, req, on_save=callbacks.on_save) + p = proxy.Proxy(self.conf_path, req) p.proxy() self.assertEqual(p.res.status, http.HTTPStatus.BAD_REQUEST) @@ -254,7 +225,7 @@ def test_unofficial_status(self): proper JSON error response with a generic error message. """ req = self.make_request(path_query="/teapot") - p = proxy.Proxy(self.conf, req, on_save=callbacks.on_save) + p = proxy.Proxy(self.conf_path, req) p.proxy() self.assertEqual(p.res.status, 418) @@ -268,15 +239,14 @@ def test_internal_server_error(self): Test handling of "500 Internal Server Error" from the server. """ req = self.make_request(path_query="/crash") - p = proxy.Proxy(self.conf, req, on_save=callbacks.on_save) + p = proxy.Proxy(self.conf_path, req) p.proxy() self.assertEqual(p.res.status, http.HTTPStatus.INTERNAL_SERVER_ERROR) self.assertIn("application/json", p.res.headers["Content-Type"]) body = json.loads(p.res.body) self.assertEqual( - body["error"], - http.HTTPStatus.INTERNAL_SERVER_ERROR.phrase.lower() + body["error"], http.HTTPStatus.INTERNAL_SERVER_ERROR.phrase.lower() ) @vcr.use_cassette("fixtures/proxy_internal_error.yaml") @@ -284,14 +254,226 @@ def test_internal_error(self): """ Ensure that the proxy returns JSON despite internal errors. """ + def bad_on_save(self, fh, res, conf): raise Exception("test internal proxy error") req = self.make_request() - p = proxy.Proxy(self.conf, req, on_save=bad_on_save) + p = proxy.Proxy(self.conf_path, req) + + # Patching on_save for tests + p.on_save = types.MethodType(bad_on_save, p) p.proxy() self.assertEqual(p.res.status, http.HTTPStatus.INTERNAL_SERVER_ERROR) self.assertIn("application/json", p.res.headers["Content-Type"]) body = json.loads(p.res.body) self.assertEqual(body["error"], "internal proxy error") + + +class TestProxyMethods(unittest.TestCase): + def setUp(self): + self.res = proxy.Response(status=200) + self.res.body = "babbys request" + + self.conf_path = "tests/files/dev-config.yaml" + + def test_err_on_done(self): + saved_stdout = sys.stdout + try: + out = StringIO() + sys.stdout = out + with self.assertRaises(SystemExit): + p = proxy.Proxy(self.conf_path) + p.res = self.res + p.err_on_done() + output = out.getvalue().strip() + finally: + sys.stdout = saved_stdout + + response = json.loads(output) + self.assertEqual(response["status"], 200) + self.assertEqual(response["body"], "babbys request") + + def test_on_done(self): + saved_stdout = sys.stdout + try: + out = StringIO() + sys.stdout = out + p = proxy.Proxy(self.conf_path) + p.res = self.res + p.on_done() + output = out.getvalue().strip() + finally: + sys.stdout = saved_stdout + + response = json.loads(output) + self.assertEqual(response["status"], 200) + self.assertEqual(response["body"], "babbys request") + + def test_on_save_500_unhandled_error(self): + fh = tempfile.NamedTemporaryFile() + + # Let's generate an error and ensure that an appropriate response + # is sent back to the user + with patch("subprocess.run", side_effect=IOError): + p = proxy.Proxy(self.conf_path) + p.on_save(fh, self.res) + + self.assertEqual(self.res.status, 500) + self.assertEqual(self.res.headers["Content-Type"], "application/json") + self.assertEqual(self.res.headers["X-Origin-Content-Type"], "application/json") + self.assertIn("Unhandled error", self.res.body) + + def test_on_save_200_success(self): + fh = tempfile.NamedTemporaryFile() + + p = proxy.Proxy(self.conf_path) + p.on_save(fh, self.res) + + self.assertEqual(self.res.headers["Content-Type"], "application/json") + self.assertEqual(self.res.headers["X-Origin-Content-Type"], "application/json") + self.assertEqual(self.res.status, 200) + self.assertIn("filename", self.res.body) + + @vcr.use_cassette("fixtures/proxy_callbacks.yaml") + def test_custom_callbacks(self): + """ + Test the handlers in a real proxy request. + """ + conf = proxy.Conf() + conf.host = "jsonplaceholder.typicode.com" + conf.scheme = "https" + conf.port = 443 + + req = proxy.Req() + req.method = "GET" + + on_save_addition = "added by the on_save callback\n" + on_done_addition = "added by the on_done callback\n" + + def on_save(self, fh, res): + res.headers["Content-Type"] = "text/plain" + res.body = on_save_addition + + def on_done(self): + self.res.headers["Content-Type"] = "text/plain" + self.res.body += on_done_addition + + p = proxy.Proxy(self.conf_path, req) + # Patching for tests + p.conf = conf + p.on_done = types.MethodType(on_done, p) + p.on_save = types.MethodType(on_save, p) + p.proxy() + + self.assertEqual(p.res.body, "{}{}".format(on_save_addition, on_done_addition)) + + @vcr.use_cassette("fixtures/proxy_callbacks.yaml") + def test_production_on_save(self): + """ + Test on_save's production file handling. + """ + conf = proxy.Conf() + conf.host = "jsonplaceholder.typicode.com" + conf.scheme = "https" + conf.port = 443 + conf.dev = False + conf.target_vm = "sd-svs-dispvm" + + with patch("subprocess.run") as patched_run: + fh = tempfile.NamedTemporaryFile() + p = proxy.Proxy(self.conf_path) + # Patching for tests + p.conf = conf + p.on_save(fh, self.res) + self.assertEqual(patched_run.call_args[0][0][0], "qvm-move-to-vm") + + +class TestConfig(unittest.TestCase): + def setUp(self): + self.conf_path = "tests/files/dev-config.yaml" + + def test_config_file_does_not_exist(self): + def err_on_done(self): + res = self.res.__dict__ + assert res["status"] == 500 + assert "Configuration file does not exist" in res["body"] + assert res["headers"]["Content-Type"] == "application/json" + sys.exit(1) + + p = proxy.Proxy(self.conf_path) + p.err_on_done = types.MethodType(err_on_done, p) + with self.assertRaises(SystemExit): + p.read_conf("not/a/real/path") + + def test_config_file_when_yaml_is_invalid(self): + def err_on_done(self): + res = self.res.__dict__ + assert res["status"] == 500 + assert "YAML syntax error" in res["body"] + assert res["headers"]["Content-Type"] == "application/json" + sys.exit(1) + + p = proxy.Proxy(self.conf_path) + p.err_on_done = types.MethodType(err_on_done, p) + with self.assertRaises(SystemExit): + p.read_conf("tests/files/invalid_yaml.yaml") + + def test_config_file_open_generic_exception(self): + def err_on_done(self): + res = self.res.__dict__ + assert res["status"] == 500 + assert res["headers"]["Content-Type"] == "application/json" + sys.exit(1) + + p = proxy.Proxy(self.conf_path) + p.err_on_done = types.MethodType(err_on_done, p) + + with self.assertRaises(SystemExit): + # Patching open so that we can simulate a non-YAML error + # (e.g. permissions) + with patch("builtins.open", side_effect=IOError): + p.read_conf("tests/files/valid-config.yaml") + + def test_config_has_valid_keys(self): + p = proxy.Proxy("tests/files/valid-config.yaml") + + # Verify we have a valid Conf object + self.assertEqual(p.conf.host, "jsonplaceholder.typicode.com") + self.assertEqual(p.conf.port, 443) + self.assertFalse(p.conf.dev) + self.assertEqual(p.conf.scheme, "https") + self.assertEqual(p.conf.target_vm, "compost") + + def test_config_500_when_missing_a_required_key(self): + def err_on_done(self): + res = self.res.__dict__ + assert res["status"] == 500 + assert "missing required keys" in res["body"] + assert res["headers"]["Content-Type"] == "application/json" + sys.exit(1) + + p = proxy.Proxy(self.conf_path) + p.err_on_done = types.MethodType(err_on_done, p) + + with self.assertRaises(SystemExit): + p.read_conf("tests/files/missing-key.yaml") + + def test_config_500_when_missing_target_vm(self): + def err_on_done(self): + res = self.res.__dict__ + assert res["status"] == 500 + assert "missing `target_vm` key" in res["body"] + assert res["headers"]["Content-Type"] == "application/json" + sys.exit(1) + + p = proxy.Proxy(self.conf_path) + p.err_on_done = types.MethodType(err_on_done, p) + + with self.assertRaises(SystemExit): + p.read_conf("tests/files/missing-target-vm.yaml") + + def test_dev_config(self): + p = proxy.Proxy("tests/files/dev-config.yaml") + assert p.conf.dev