diff --git a/securedrop/tests/conftest.py b/securedrop/tests/conftest.py index 56a4d5d585..a8ef804e68 100644 --- a/securedrop/tests/conftest.py +++ b/securedrop/tests/conftest.py @@ -1,16 +1,21 @@ # -*- coding: utf-8 -*- -import os -import shutil -import signal -import subprocess -import logging import gnupg +import logging +import os import psutil import pytest +import shutil +import signal +import subprocess os.environ['SECUREDROP_ENV'] = 'test' # noqa -from sdconfig import config +from sdconfig import SDConfig, config as original_config + +from os import path + +from db import db +from source_app import create_app as create_source_app # TODO: the PID file for the redis worker is hard-coded below. # Ideally this constant would be provided by a test harness. @@ -44,11 +49,46 @@ def pytest_collection_modifyitems(config, items): @pytest.fixture(scope='session') -def setUptearDown(): - _start_test_rqworker(config) +def setUpTearDown(): + _start_test_rqworker(original_config) yield _stop_test_rqworker() - _cleanup_test_securedrop_dataroot(config) + _cleanup_test_securedrop_dataroot(original_config) + + +@pytest.fixture(scope='function') +def config(tmpdir): + '''Clone the module so we can modify it per test.''' + + cnf = SDConfig() + + data = tmpdir.mkdir('data') + keys = data.mkdir('keys') + store = data.mkdir('store') + tmp = data.mkdir('tmp') + sqlite = data.join('db.sqlite') + + gpg = gnupg.GPG(homedir=str(keys)) + with open(path.join(path.dirname(__file__), + 'files', + 'test_journalist_key.pub')) as f: + gpg.import_keys(f.read()) + + cnf.SECUREDROP_DATA_ROOT = str(data) + cnf.GPG_KEY_DIR = str(keys) + cnf.STORE_DIR = str(store) + cnf.TEMP_DIR = str(tmp) + cnf.DATABASE_FILE = str(sqlite) + + return cnf + + +@pytest.fixture(scope='function') +def source_app(config): + app = create_source_app(config) + with app.app_context(): + db.create_all() + return app def _start_test_rqworker(config): diff --git a/securedrop/tests/pytest.ini b/securedrop/tests/pytest.ini index bfb504ab97..affc0b16ff 100644 --- a/securedrop/tests/pytest.ini +++ b/securedrop/tests/pytest.ini @@ -1,4 +1,4 @@ [pytest] testpaths = . functional -usefixtures = setUptearDown +usefixtures = setUpTearDown addopts = --cov=../securedrop/ diff --git a/securedrop/tests/test_source.py b/securedrop/tests/test_source.py index e70cca6333..0aa4637b3e 100644 --- a/securedrop/tests/test_source.py +++ b/securedrop/tests/test_source.py @@ -2,534 +2,593 @@ import gzip import json import re +import subprocess from cStringIO import StringIO -from flask import session, escape, url_for, current_app -from flask_testing import TestCase +from flask import session, escape, current_app from mock import patch, ANY -from sdconfig import config +import crypto_util import source import utils import version from db import db from models import Source +from source_app import main as source_app_main from utils.db_helper import new_codename +from utils.instrument import InstrumentedApp overly_long_codename = 'a' * (Source.MAX_CODENAME_LEN + 1) -class TestSourceApp(TestCase): +def test_page_not_found(source_app): + """Verify the page not found condition returns the intended template""" + with InstrumentedApp(source_app) as ins: + with source_app.test_client() as app: + resp = app.get('UNKNOWN') + assert resp.status_code == 404 + ins.assert_template_used('notfound.html') - def create_app(self): - return source.app - def setUp(self): - utils.env.setup() +def test_index(source_app): + """Test that the landing page loads and looks how we expect""" + with source_app.test_client() as app: + resp = app.get('/') + assert resp.status_code == 200 + text = resp.data.decode('utf-8') + assert 'Submit documents for the first time' in text + assert 'Already submitted something?' in text - def tearDown(self): - utils.env.teardown() - def test_page_not_found(self): - """Verify the page not found condition returns the intended template""" - response = self.client.get('/UNKNOWN') - self.assert404(response) - self.assertTemplateUsed('notfound.html') - - def test_index(self): - """Test that the landing page loads and looks how we expect""" - response = self.client.get('/') - self.assertEqual(response.status_code, 200) - self.assertIn("Submit documents for the first time", response.data) - self.assertIn("Already submitted something?", response.data) - - def test_all_words_in_wordlist_validate(self): - """Verify that all words in the wordlist are allowed by the form - validation. Otherwise a source will have a codename and be unable to - return.""" +def test_all_words_in_wordlist_validate(source_app): + """Verify that all words in the wordlist are allowed by the form + validation. Otherwise a source will have a codename and be unable to + return.""" + with source_app.app_context(): wordlist_en = current_app.crypto_util.get_wordlist('en') - # chunk the words to cut down on the number of requets we make - # otherwise this test is *slow* - chunks = [wordlist_en[i:i + 7] for i in range(0, len(wordlist_en), 7)] + # chunk the words to cut down on the number of requets we make + # otherwise this test is *slow* + chunks = [wordlist_en[i:i + 7] for i in range(0, len(wordlist_en), 7)] + with source_app.test_client() as app: for words in chunks: - with self.client as c: - resp = c.post('/login', data=dict(codename=' '.join(words)), - follow_redirects=True) - self.assertEqual(resp.status_code, 200) - # If the word does not validate, then it will show - # 'Invalid input'. If it does validate, it should show that - # it isn't a recognized codename. - self.assertIn('Sorry, that is not a recognized codename.', - resp.data) - self.assertNotIn('logged_in', session) - - def _find_codename(self, html): - """Find a source codename (diceware passphrase) in HTML""" - # Codenames may contain HTML escape characters, and the wordlist - # contains various symbols. - codename_re = (r'

]*id="codename"[^>]*>' - r'(?P[a-z0-9 &#;?:=@_.*+()\'"$%!-]+)

') - codename_match = re.search(codename_re, html) - self.assertIsNotNone(codename_match) - return codename_match.group('codename') - - def test_generate(self): - with self.client as c: - resp = c.get('/generate') - self.assertEqual(resp.status_code, 200) - session_codename = session['codename'] - self.assertIn("This codename is what you will use in future visits", - resp.data) - codename = self._find_codename(resp.data) - self.assertEqual(len(codename.split()), Source.NUM_WORDS) - # codename is also stored in the session - make sure it matches the - # codename displayed to the source - self.assertEqual(codename, escape(session_codename)) - - def test_generate_already_logged_in(self): - with self.client as client: - new_codename(client, session) - # Make sure it redirects to /lookup when logged in - resp = client.get('/generate') - self.assertEqual(resp.status_code, 302) - # Make sure it flashes the message on the lookup page - resp = client.get('/generate', follow_redirects=True) - # Should redirect to /lookup - self.assertEqual(resp.status_code, 200) - self.assertIn("because you are already logged in.", resp.data) - - def test_create_new_source(self): - with self.client as c: - resp = c.get('/generate') - resp = c.post('/create', follow_redirects=True) - self.assertTrue(session['logged_in']) - # should be redirected to /lookup - self.assertIn("Submit Materials", resp.data) - - @patch('source.app.logger.warning') - @patch('crypto_util.CryptoUtil.genrandomid', - side_effect=[overly_long_codename, 'short codename']) - def test_generate_too_long_codename(self, genrandomid, logger): - """Generate a codename that exceeds the maximum codename length""" - - with self.client as c: - resp = c.post('/generate') - self.assertEqual(resp.status_code, 200) - - logger.assert_called_with( - "Generated a source codename that was too long, " - "skipping it. This should not happen. " - "(Codename='{}')".format(overly_long_codename) - ) - - @patch('source.app.logger.error') - def test_create_duplicate_codename(self, logger): - with self.client as c: - c.get('/generate') + resp = app.post('/login', data=dict(codename=' '.join(words)), + follow_redirects=True) + assert resp.status_code == 200 + text = resp.data.decode('utf-8') + # If the word does not validate, then it will show + # 'Invalid input'. If it does validate, it should show that + # it isn't a recognized codename. + assert 'Sorry, that is not a recognized codename.' in text + assert 'logged_in' not in session + + +def _find_codename(html): + """Find a source codename (diceware passphrase) in HTML""" + # Codenames may contain HTML escape characters, and the wordlist + # contains various symbols. + codename_re = (r'

]*id="codename"[^>]*>' + r'(?P[a-z0-9 &#;?:=@_.*+()\'"$%!-]+)

') + codename_match = re.search(codename_re, html) + assert codename_match is not None + return codename_match.group('codename') + + +def test_generate(source_app): + with source_app.test_client() as app: + resp = app.get('/generate') + assert resp.status_code == 200 + session_codename = session['codename'] + + text = resp.data.decode('utf-8') + assert "This codename is what you will use in future visits" in text + + codename = _find_codename(resp.data) + assert len(codename.split()) == Source.NUM_WORDS + # codename is also stored in the session - make sure it matches the + # codename displayed to the source + assert codename == escape(session_codename) + + +def test_generate_already_logged_in(source_app): + with source_app.test_client() as app: + new_codename(app, session) + # Make sure it redirects to /lookup when logged in + resp = app.get('/generate') + assert resp.status_code == 302 + # Make sure it flashes the message on the lookup page + resp = app.get('/generate', follow_redirects=True) + # Should redirect to /lookup + assert resp.status_code == 200 + text = resp.data.decode('utf-8') + assert "because you are already logged in." in text + + +def test_create_new_source(source_app): + with source_app.test_client() as app: + resp = app.get('/generate') + assert resp.status_code == 200 + resp = app.post('/create', follow_redirects=True) + assert session['logged_in'] is True + # should be redirected to /lookup + text = resp.data.decode('utf-8') + assert "Submit Materials" in text + + +def test_generate_too_long_codename(source_app): + """Generate a codename that exceeds the maximum codename length""" + + with patch.object(source_app.logger, 'warning') as logger: + with patch.object(crypto_util.CryptoUtil, 'genrandomid', + side_effect=[overly_long_codename, + 'short codename']): + with source_app.test_client() as app: + resp = app.post('/generate') + assert resp.status_code == 200 + + logger.assert_called_with( + "Generated a source codename that was too long, " + "skipping it. This should not happen. " + "(Codename='{}')".format(overly_long_codename) + ) + + +def test_create_duplicate_codename(source_app): + with patch.object(source.app.logger, 'error') as logger: + with source_app.test_client() as app: + resp = app.get('/generate') + assert resp.status_code == 200 # Create a source the first time - c.post('/create', follow_redirects=True) + resp = app.post('/create', follow_redirects=True) + assert resp.status_code == 200 # Attempt to add the same source - c.post('/create', follow_redirects=True) + app.post('/create', follow_redirects=True) logger.assert_called_once() - self.assertIn("Attempt to create a source with duplicate codename", - logger.call_args[0][0]) + assert ("Attempt to create a source with duplicate codename" + in logger.call_args[0][0]) assert 'codename' not in session - def test_lookup(self): - """Test various elements on the /lookup page.""" - with self.client as client: - codename = new_codename(client, session) - resp = client.post('login', data=dict(codename=codename), - follow_redirects=True) - # redirects to /lookup - self.assertIn("public key", resp.data) - # download the public key - resp = client.get('journalist-key') - self.assertIn("BEGIN PGP PUBLIC KEY BLOCK", resp.data) - - def test_login_and_logout(self): - resp = self.client.get('/login') - self.assertEqual(resp.status_code, 200) - self.assertIn("Enter Codename", resp.data) - - with self.client as client: - codename = new_codename(client, session) - resp = client.post('/login', data=dict(codename=codename), - follow_redirects=True) - self.assertEqual(resp.status_code, 200) - self.assertIn("Submit Materials", resp.data) - self.assertTrue(session['logged_in']) - resp = client.get('/logout', follow_redirects=True) - - with self.client as c: - resp = c.post('/login', data=dict(codename='invalid'), - follow_redirects=True) - self.assertEqual(resp.status_code, 200) - self.assertIn('Sorry, that is not a recognized codename.', - resp.data) - self.assertNotIn('logged_in', session) - - with self.client as c: - resp = c.post('/login', data=dict(codename=codename), - follow_redirects=True) - self.assertEqual(resp.status_code, 200) - self.assertTrue(session['logged_in']) - resp = c.get('/logout', follow_redirects=True) - - # sessions always have 'expires', so pop it for the next check - session.pop('expires', None) - - self.assertNotIn('logged_in', session) - self.assertNotIn('codename', session) - - self.assertIn('Thank you for exiting your session!', resp.data) - - def test_user_must_log_in_for_protected_views(self): - with self.client as c: - resp = c.get('/lookup', follow_redirects=True) - self.assertEqual(resp.status_code, 200) - self.assertIn("Enter Codename", resp.data) - - def test_login_with_whitespace(self): - """ - Test that codenames with leading or trailing whitespace still work""" - - with self.client as client: - def login_test(codename): - resp = client.get('/login') - self.assertEqual(resp.status_code, 200) - self.assertIn("Enter Codename", resp.data) - - resp = client.post('/login', data=dict(codename=codename), - follow_redirects=True) - self.assertEqual(resp.status_code, 200) - self.assertIn("Submit Materials", resp.data) - self.assertTrue(session['logged_in']) - resp = client.get('/logout', follow_redirects=True) - - codename = new_codename(client, session) - login_test(codename + ' ') - login_test(' ' + codename + ' ') - login_test(' ' + codename) - - def _dummy_submission(self, client): - """ - Helper to make a submission (content unimportant), mostly useful in - testing notification behavior for a source's first vs. their - subsequent submissions - """ - return client.post('/submit', data=dict( - msg="Pay no attention to the man behind the curtain.", + +def test_lookup(source_app): + """Test various elements on the /lookup page.""" + with source_app.test_client() as app: + codename = new_codename(app, session) + resp = app.post('/login', data=dict(codename=codename), + follow_redirects=True) + # redirects to /lookup + text = resp.data.decode('utf-8') + assert "public key" in text + # download the public key + resp = app.get('/journalist-key') + text = resp.data.decode('utf-8') + assert "BEGIN PGP PUBLIC KEY BLOCK" in text + + +def test_login_and_logout(source_app): + with source_app.test_client() as app: + resp = app.get('/login') + assert resp.status_code == 200 + text = resp.data.decode('utf-8') + assert "Enter Codename" in text + + codename = new_codename(app, session) + resp = app.post('/login', data=dict(codename=codename), + follow_redirects=True) + assert resp.status_code == 200 + text = resp.data.decode('utf-8') + assert "Submit Materials" in text + assert session['logged_in'] is True + + with source_app.test_client() as app: + resp = app.post('/login', data=dict(codename='invalid'), + follow_redirects=True) + assert resp.status_code == 200 + text = resp.data.decode('utf-8') + assert 'Sorry, that is not a recognized codename.' in text + assert 'logged_in' not in session + + with source_app.test_client() as app: + resp = app.post('/login', data=dict(codename=codename), + follow_redirects=True) + assert resp.status_code == 200 + assert session['logged_in'] is True + + resp = app.post('/login', data=dict(codename=codename), + follow_redirects=True) + assert resp.status_code == 200 + assert session['logged_in'] is True + + resp = app.get('/logout', follow_redirects=True) + assert 'logged_in' not in session + assert 'codename' not in session + text = resp.data.decode('utf-8') + assert 'Thank you for exiting your session!' in text + + +def test_user_must_log_in_for_protected_views(source_app): + with source_app.test_client() as app: + resp = app.get('/lookup', follow_redirects=True) + assert resp.status_code == 200 + text = resp.data.decode('utf-8') + assert "Enter Codename" in text + + +def test_login_with_whitespace(source_app): + """ + Test that codenames with leading or trailing whitespace still work""" + + def login_test(app, codename): + resp = app.get('/login') + assert resp.status_code == 200 + text = resp.data.decode('utf-8') + assert "Enter Codename" in text + + resp = app.post('/login', data=dict(codename=codename), + follow_redirects=True) + assert resp.status_code == 200 + text = resp.data.decode('utf-8') + assert "Submit Materials" in text + assert session['logged_in'] is True + + with source_app.test_client() as app: + codename = new_codename(app, session) + + codenames = [ + codename + ' ', + ' ' + codename + ' ', + ' ' + codename, + ] + + for codename_ in codenames: + with source_app.test_client() as app: + login_test(app, codename_) + + +def _dummy_submission(app): + """ + Helper to make a submission (content unimportant), mostly useful in + testing notification behavior for a source's first vs. their + subsequent submissions + """ + return app.post('/submit', data=dict( + msg="Pay no attention to the man behind the curtain.", + fh=(StringIO(''), ''), + ), follow_redirects=True) + + +def test_initial_submission_notification(source_app): + """ + Regardless of the type of submission (message, file, or both), the + first submission is always greeted with a notification + reminding sources to check back later for replies. + """ + with source_app.test_client() as app: + new_codename(app, session) + resp = _dummy_submission(app) + assert resp.status_code == 200 + text = resp.data.decode('utf-8') + assert "Thank you for sending this information to us." in text + + +def test_submit_message(source_app): + with source_app.test_client() as app: + new_codename(app, session) + _dummy_submission(app) + resp = app.post('/submit', data=dict( + msg="This is a test.", fh=(StringIO(''), ''), ), follow_redirects=True) + assert resp.status_code == 200 + text = resp.data.decode('utf-8') + assert "Thanks! We received your message" in text - def test_initial_submission_notification(self): - """ - Regardless of the type of submission (message, file, or both), the - first submission is always greeted with a notification - reminding sources to check back later for replies. - """ - with self.client as client: - new_codename(client, session) - resp = self._dummy_submission(client) - self.assertEqual(resp.status_code, 200) - self.assertIn( - "Thank you for sending this information to us.", - resp.data) - - def test_submit_message(self): - with self.client as client: - new_codename(client, session) - self._dummy_submission(client) - resp = client.post('/submit', data=dict( - msg="This is a test.", - fh=(StringIO(''), ''), - ), follow_redirects=True) - self.assertEqual(resp.status_code, 200) - self.assertIn("Thanks! We received your message", resp.data) - def test_submit_empty_message(self): - with self.client as client: - new_codename(client, session) - resp = client.post('/submit', data=dict( - msg="", - fh=(StringIO(''), ''), - ), follow_redirects=True) - self.assertIn("You must enter a message or choose a file to " - "submit.", - resp.data) - - def test_submit_big_message(self): - ''' - When the message is larger than 512KB it's written to disk instead of - just residing in memory. Make sure the different return type of - SecureTemporaryFile is handled as well as BytesIO. - ''' - with self.client as client: - new_codename(client, session) - self._dummy_submission(client) - resp = client.post('/submit', data=dict( - msg="AA" * (1024 * 512), - fh=(StringIO(''), ''), - ), follow_redirects=True) - self.assertEqual(resp.status_code, 200) - self.assertIn("Thanks! We received your message", resp.data) - - def test_submit_file(self): - with self.client as client: - new_codename(client, session) - self._dummy_submission(client) - resp = client.post('/submit', data=dict( - msg="", - fh=(StringIO('This is a test'), 'test.txt'), - ), follow_redirects=True) - self.assertEqual(resp.status_code, 200) - self.assertIn('Thanks! We received your document', resp.data) - - def test_submit_both(self): - with self.client as client: - new_codename(client, session) - self._dummy_submission(client) - resp = client.post('/submit', data=dict( - msg="This is a test", - fh=(StringIO('This is a test'), 'test.txt'), - ), follow_redirects=True) - self.assertEqual(resp.status_code, 200) - self.assertIn("Thanks! We received your message and document", - resp.data) - - @patch('source_app.main.async_genkey') - @patch('source_app.main.get_entropy_estimate') - def test_submit_message_with_low_entropy(self, get_entropy_estimate, - async_genkey): - get_entropy_estimate.return_value = 300 - - with self.client as client: - new_codename(client, session) - self._dummy_submission(client) - resp = client.post('/submit', data=dict( - msg="This is a test.", - fh=(StringIO(''), ''), - ), follow_redirects=True) - self.assertEqual(resp.status_code, 200) - self.assertFalse(async_genkey.called) - - @patch('source_app.main.async_genkey') - @patch('source_app.main.get_entropy_estimate') - def test_submit_message_with_enough_entropy(self, get_entropy_estimate, - async_genkey): - get_entropy_estimate.return_value = 2400 - - with self.client as client: - new_codename(client, session) - self._dummy_submission(client) - resp = client.post('/submit', data=dict( - msg="This is a test.", - fh=(StringIO(''), ''), - ), follow_redirects=True) - self.assertEqual(resp.status_code, 200) - self.assertTrue(async_genkey.called) +def test_submit_empty_message(source_app): + with source_app.test_client() as app: + new_codename(app, session) + resp = app.post('/submit', data=dict( + msg="", + fh=(StringIO(''), ''), + ), follow_redirects=True) + assert resp.status_code == 200 + text = resp.data.decode('utf-8') + assert "You must enter a message or choose a file to submit." \ + in text + + +def test_submit_big_message(source_app): + ''' + When the message is larger than 512KB it's written to disk instead of + just residing in memory. Make sure the different return type of + SecureTemporaryFile is handled as well as BytesIO. + ''' + with source_app.test_client() as app: + new_codename(app, session) + _dummy_submission(app) + resp = app.post('/submit', data=dict( + msg="AA" * (1024 * 512), + fh=(StringIO(''), ''), + ), follow_redirects=True) + assert resp.status_code == 200 + text = resp.data.decode('utf-8') + assert "Thanks! We received your message" in text + + +def test_submit_file(source_app): + with source_app.test_client() as app: + new_codename(app, session) + _dummy_submission(app) + resp = app.post('/submit', data=dict( + msg="", + fh=(StringIO('This is a test'), 'test.txt'), + ), follow_redirects=True) + assert resp.status_code == 200 + text = resp.data.decode('utf-8') + assert 'Thanks! We received your document' in text + + +def test_submit_both(source_app): + with source_app.test_client() as app: + new_codename(app, session) + _dummy_submission(app) + resp = app.post('/submit', data=dict( + msg="This is a test", + fh=(StringIO('This is a test'), 'test.txt'), + ), follow_redirects=True) + assert resp.status_code == 200 + text = resp.data.decode('utf-8') + assert "Thanks! We received your message and document" in text + + +def test_submit_message_with_low_entropy(source_app): + with patch.object(source_app_main, 'async_genkey') as async_genkey: + with patch.object(source_app_main, 'get_entropy_estimate') \ + as get_entropy_estimate: + get_entropy_estimate.return_value = 300 + + with source_app.test_client() as app: + new_codename(app, session) + _dummy_submission(app) + resp = app.post('/submit', data=dict( + msg="This is a test.", + fh=(StringIO(''), ''), + ), follow_redirects=True) + assert resp.status_code == 200 + assert not async_genkey.called + + +def test_submit_message_with_enough_entropy(source_app): + with patch.object(source_app_main, 'async_genkey') as async_genkey: + with patch.object(source_app_main, 'get_entropy_estimate') \ + as get_entropy_estimate: + get_entropy_estimate.return_value = 2400 + + with source_app.test_client() as app: + new_codename(app, session) + _dummy_submission(app) + resp = app.post('/submit', data=dict( + msg="This is a test.", + fh=(StringIO(''), ''), + ), follow_redirects=True) + assert resp.status_code == 200 + assert async_genkey.called - def test_delete_all_successfully_deletes_replies(self): + +def test_delete_all_successfully_deletes_replies(source_app): + with source_app.app_context(): journalist, _ = utils.db_helper.init_journalist() source, codename = utils.db_helper.init_source() utils.db_helper.reply(journalist, source, 1) - with self.client as c: - resp = c.post('/login', data=dict(codename=codename), - follow_redirects=True) - self.assertEqual(resp.status_code, 200) - resp = c.post('/delete-all', follow_redirects=True) - self.assertEqual(resp.status_code, 200) - self.assertIn("All replies have been deleted", resp.data) - - @patch('source.app.logger.error') - def test_delete_all_replies_already_deleted(self, logger): + + with source_app.test_client() as app: + resp = app.post('/login', data=dict(codename=codename), + follow_redirects=True) + assert resp.status_code == 200 + resp = app.post('/delete-all', follow_redirects=True) + assert resp.status_code == 200 + text = resp.data.decode('utf-8') + assert "All replies have been deleted" in text + + +def test_delete_all_replies_already_deleted(source_app): + with source_app.app_context(): journalist, _ = utils.db_helper.init_journalist() source, codename = utils.db_helper.init_source() # Note that we are creating the source and no replies - with self.client as c: - resp = c.post('/login', data=dict(codename=codename), - follow_redirects=True) - self.assertEqual(resp.status_code, 200) - resp = c.post('/delete-all', follow_redirects=True) - self.assertEqual(resp.status_code, 200) + with source_app.test_client() as app: + with patch.object(source_app.logger, 'error') as logger: + resp = app.post('/login', data=dict(codename=codename), + follow_redirects=True) + assert resp.status_code == 200 + resp = app.post('/delete-all', follow_redirects=True) + assert resp.status_code == 200 logger.assert_called_once_with( "Found no replies when at least one was expected" ) - @patch('gzip.GzipFile', wraps=gzip.GzipFile) - def test_submit_sanitizes_filename(self, gzipfile): - """Test that upload file name is sanitized""" - insecure_filename = '../../bin/gpg' - sanitized_filename = 'bin_gpg' - with self.client as client: - new_codename(client, session) - client.post('/submit', data=dict( +def test_submit_sanitizes_filename(source_app): + """Test that upload file name is sanitized""" + insecure_filename = '../../bin/gpg' + sanitized_filename = 'bin_gpg' + + with patch.object(gzip, 'GzipFile', wraps=gzip.GzipFile) as gzipfile: + with source_app.test_client() as app: + new_codename(app, session) + resp = app.post('/submit', data=dict( msg="", fh=(StringIO('This is a test'), insecure_filename), ), follow_redirects=True) + assert resp.status_code == 200 gzipfile.assert_called_with(filename=sanitized_filename, mode=ANY, fileobj=ANY) - def test_tor2web_warning_headers(self): - resp = self.client.get('/', headers=[('X-tor2web', 'encrypted')]) - self.assertEqual(resp.status_code, 200) - self.assertIn("You appear to be using Tor2Web.", resp.data) - - def test_tor2web_warning(self): - resp = self.client.get('/tor2web-warning') - self.assertEqual(resp.status_code, 200) - self.assertIn("Why is there a warning about Tor2Web?", resp.data) - - def test_why_use_tor_browser(self): - resp = self.client.get('/use-tor') - self.assertEqual(resp.status_code, 200) - self.assertIn("You Should Use Tor Browser", resp.data) - - def test_why_journalist_key(self): - resp = self.client.get('/why-journalist-key') - self.assertEqual(resp.status_code, 200) - self.assertIn("Why download the journalist's public key?", resp.data) - - def test_metadata_route(self): - resp = self.client.get('/metadata') - self.assertEqual(resp.status_code, 200) - self.assertEqual(resp.headers.get('Content-Type'), 'application/json') - self.assertEqual(json.loads(resp.data.decode('utf-8')).get( - 'sd_version'), version.__version__) - - @patch('crypto_util.CryptoUtil.hash_codename') - def test_login_with_overly_long_codename(self, mock_hash_codename): - """Attempting to login with an overly long codename should result in - an error, and scrypt should not be called to avoid DoS.""" - with self.client as c: - resp = c.post('/login', data=dict(codename=overly_long_codename), - follow_redirects=True) - self.assertEqual(resp.status_code, 200) - self.assertIn("Field must be between 1 and {} " - "characters long.".format(Source.MAX_CODENAME_LEN), - resp.data) - self.assertFalse(mock_hash_codename.called, - "Called hash_codename for codename w/ invalid " - "length") - - @patch('source.app.logger.warning') - @patch('subprocess.call', return_value=1) - def test_failed_normalize_timestamps_logs_warning(self, call, logger): - """If a normalize timestamps event fails, the subprocess that calls - touch will fail and exit 1. When this happens, the submission should - still occur, but a warning should be logged (this will trigger an - OSSEC alert).""" - - with self.client as client: - new_codename(client, session) - self._dummy_submission(client) - resp = client.post('/submit', data=dict( - msg="This is a test.", - fh=(StringIO(''), ''), - ), follow_redirects=True) - self.assertEqual(resp.status_code, 200) - self.assertIn("Thanks! We received your message", resp.data) - logger.assert_called_once_with( - "Couldn't normalize submission " - "timestamps (touch exited with 1)" - ) +def test_tor2web_warning_headers(source_app): + with source_app.test_client() as app: + resp = app.get('/', headers=[('X-tor2web', 'encrypted')]) + assert resp.status_code == 200 + text = resp.data.decode('utf-8') + assert "You appear to be using Tor2Web." in text + + +def test_tor2web_warning(source_app): + with source_app.test_client() as app: + resp = app.get('/tor2web-warning') + assert resp.status_code == 200 + text = resp.data.decode('utf-8') + assert "Why is there a warning about Tor2Web?" in text + + +def test_why_use_tor_browser(source_app): + with source_app.test_client() as app: + resp = app.get('/use-tor') + assert resp.status_code == 200 + text = resp.data.decode('utf-8') + assert "You Should Use Tor Browser" in text + + +def test_why_journalist_key(source_app): + with source_app.test_client() as app: + resp = app.get('/why-journalist-key') + assert resp.status_code == 200 + text = resp.data.decode('utf-8') + assert "Why download the journalist's public key?" in text + + +def test_metadata_route(source_app): + with source_app.test_client() as app: + resp = app.get('/metadata') + assert resp.status_code == 200 + assert resp.headers.get('Content-Type') == 'application/json' + assert json.loads(resp.data.decode('utf-8')).get('sd_version') \ + == version.__version__ + + +def test_login_with_overly_long_codename(source_app): + """Attempting to login with an overly long codename should result in + an error, and scrypt should not be called to avoid DoS.""" + with patch.object(crypto_util.CryptoUtil, 'hash_codename') \ + as mock_hash_codename: + with source_app.test_client() as app: + resp = app.post('/login', + data=dict(codename=overly_long_codename), + follow_redirects=True) + assert resp.status_code == 200 + text = resp.data.decode('utf-8') + assert ("Field must be between 1 and {} characters long." + .format(Source.MAX_CODENAME_LEN)) in text + assert not mock_hash_codename.called, \ + "Called hash_codename for codename w/ invalid length" + + +def test_failed_normalize_timestamps_logs_warning(source_app): + """If a normalize timestamps event fails, the subprocess that calls + touch will fail and exit 1. When this happens, the submission should + still occur, but a warning should be logged (this will trigger an + OSSEC alert).""" + + with patch.object(source_app.logger, 'warning') as logger: + with patch.object(subprocess, 'call', return_value=1): + with source_app.test_client() as app: + new_codename(app, session) + _dummy_submission(app) + resp = app.post('/submit', data=dict( + msg="This is a test.", + fh=(StringIO(''), ''), + ), follow_redirects=True) + assert resp.status_code == 200 + text = resp.data.decode('utf-8') + assert "Thanks! We received your message" in text + + logger.assert_called_once_with( + "Couldn't normalize submission " + "timestamps (touch exited with 1)" + ) + - @patch('source.app.logger.error') - def test_source_is_deleted_while_logged_in(self, logger): - """If a source is deleted by a journalist when they are logged in, - a NoResultFound will occur. The source should be redirected to the - index when this happens, and a warning logged.""" +def test_source_is_deleted_while_logged_in(source_app): + """If a source is deleted by a journalist when they are logged in, + a NoResultFound will occur. The source should be redirected to the + index when this happens, and a warning logged.""" - with self.client as client: - codename = new_codename(client, session) - resp = client.post('login', data=dict(codename=codename), - follow_redirects=True) + with patch.object(source_app.logger, 'error') as logger: + with source_app.test_client() as app: + codename = new_codename(app, session) + resp = app.post('login', data=dict(codename=codename), + follow_redirects=True) # Now the journalist deletes the source - filesystem_id = current_app.crypto_util.hash_codename(codename) - current_app.crypto_util.delete_reply_keypair(filesystem_id) - source = Source.query.filter_by(filesystem_id=filesystem_id).one() + filesystem_id = source_app.crypto_util.hash_codename(codename) + source_app.crypto_util.delete_reply_keypair(filesystem_id) + source = Source.query.filter_by( + filesystem_id=filesystem_id).one() db.session.delete(source) db.session.commit() # Source attempts to continue to navigate - resp = client.post('/lookup', follow_redirects=True) - self.assertEqual(resp.status_code, 200) - self.assertIn('Submit documents for the first time', resp.data) - self.assertNotIn('logged_in', session.keys()) - self.assertNotIn('codename', session.keys()) + resp = app.post('/lookup', follow_redirects=True) + assert resp.status_code == 200 + text = resp.data.decode('utf-8') + assert 'Submit documents for the first time' in text + assert 'logged_in' not in session + assert 'codename' not in session logger.assert_called_once_with( "Found no Sources when one was expected: " "No row was found for one()") - def test_login_with_invalid_codename(self): - """Logging in with a codename with invalid characters should return - an informative message to the user.""" - invalid_codename = '[]' +def test_login_with_invalid_codename(source_app): + """Logging in with a codename with invalid characters should return + an informative message to the user.""" - with self.client as c: - resp = c.post('/login', data=dict(codename=invalid_codename), - follow_redirects=True) - self.assertEqual(resp.status_code, 200) - self.assertIn("Invalid input.", resp.data) + invalid_codename = '[]' - def _test_source_session_expiration(self): - try: - old_expiration = config.SESSION_EXPIRATION_MINUTES - has_session_expiration = True - except AttributeError: - has_session_expiration = False + with source_app.test_client() as app: + resp = app.post('/login', data=dict(codename=invalid_codename), + follow_redirects=True) + assert resp.status_code == 200 + text = resp.data.decode('utf-8') + assert "Invalid input." in text - try: - with self.client as client: - codename = new_codename(client, session) - # set the expiration to ensure we trigger an expiration - config.SESSION_EXPIRATION_MINUTES = -1 +def test_source_session_expiration(config, source_app): + with source_app.test_client() as app: + codename = new_codename(app, session) - resp = client.post('/login', - data=dict(codename=codename), - follow_redirects=True) - assert resp.status_code == 200 - resp = client.get('/lookup', follow_redirects=True) - - # check that the session was cleared (apart from 'expires' - # which is always present and 'csrf_token' which leaks no info) - session.pop('expires', None) - session.pop('csrf_token', None) - assert not session, session - assert ('You have been logged out due to inactivity' in - resp.data.decode('utf-8')) - finally: - if has_session_expiration: - config.SESSION_EXPIRATION_MINUTES = old_expiration - else: - del config.SESSION_EXPIRATION_MINUTES - - def test_csrf_error_page(self): - old_enabled = self.app.config['WTF_CSRF_ENABLED'] - self.app.config['WTF_CSRF_ENABLED'] = True - - try: - with self.app.test_client() as app: - resp = app.post(url_for('main.create')) - self.assertRedirects(resp, url_for('main.index')) - - resp = app.post(url_for('main.create'), follow_redirects=True) - self.assertIn('Your session timed out due to inactivity', - resp.data) - finally: - self.app.config['WTF_CSRF_ENABLED'] = old_enabled + # set the expiration to ensure we trigger an expiration + config.SESSION_EXPIRATION_MINUTES = -1 + + resp = app.post('/login', + data=dict(codename=codename), + follow_redirects=True) + assert resp.status_code == 200 + resp = app.get('/lookup', follow_redirects=True) + + # check that the session was cleared (apart from 'expires' + # which is always present and 'csrf_token' which leaks no info) + session.pop('expires', None) + session.pop('csrf_token', None) + assert not session, session + text = resp.data.decode('utf-8') + assert 'Your session timed out due to inactivity' in text + + +def test_csrf_error_page(config, source_app): + source_app.config['WTF_CSRF_ENABLED'] = True + with source_app.test_client() as app: + with InstrumentedApp(source_app) as ins: + resp = app.post('/create') + ins.assert_redirects(resp, '/') + + resp = app.post('/create', follow_redirects=True) + text = resp.data.decode('utf-8') + assert 'Your session timed out due to inactivity' in text diff --git a/securedrop/tests/utils/db_helper.py b/securedrop/tests/utils/db_helper.py index 38101298cc..70ef0b0682 100644 --- a/securedrop/tests/utils/db_helper.py +++ b/securedrop/tests/utils/db_helper.py @@ -163,9 +163,6 @@ def submit(source, num_submissions): def new_codename(client, session): """Helper function to go through the "generate codename" flow. """ - # clear the session because our tests have implicit reliance on each other - session.clear() - client.get('/generate') codename = session['codename'] client.post('/create') diff --git a/securedrop/tests/utils/instrument.py b/securedrop/tests/utils/instrument.py new file mode 100644 index 0000000000..aab250085f --- /dev/null +++ b/securedrop/tests/utils/instrument.py @@ -0,0 +1,148 @@ +# -*- coding: utf-8 -*- +""" +Taken from: flask_testing.utils + +Flask unittest integration. + +:copyright: (c) 2010 by Dan Jacob. +:license: BSD, see LICENSE for more details. +""" +from __future__ import absolute_import, with_statement + +try: + from urllib.parse import urlparse, urljoin +except ImportError: + # Python 2 urlparse fallback + from urlparse import urlparse, urljoin + +import pytest + +from flask import template_rendered, message_flashed + + +__all__ = ['InstrumentedApp'] + + +class ContextVariableDoesNotExist(Exception): + pass + + +class InstrumentedApp: + + def __init__(self, app): + self.app = app + + def __enter__(self): + self.templates = [] + self.flashed_messages = [] + template_rendered.connect(self._add_template) + message_flashed.connect(self._add_flash_message) + return self + + def __exit__(self, *nargs): + if getattr(self, 'app', None) is not None: + del self.app + + del self.templates[:] + del self.flashed_messages[:] + + template_rendered.disconnect(self._add_template) + message_flashed.disconnect(self._add_flash_message) + + def _add_flash_message(self, app, message, category): + self.flashed_messages.append((message, category)) + + def _add_template(self, app, template, context): + if len(self.templates) > 0: + self.templates = [] + self.templates.append((template, context)) + + def assert_message_flashed(self, message, category='message'): + """ + Checks if a given message was flashed. + + :param message: expected message + :param category: expected message category + """ + for _message, _category in self.flashed_messages: + if _message == message and _category == category: + return True + + raise AssertionError("Message '{}' in category '{}' wasn't flashed" + .format(message, category)) + + def assert_template_used(self, name, tmpl_name_attribute='name'): + """ + Checks if a given template is used in the request. If the template + engine used is not Jinja2, provide ``tmpl_name_attribute`` with a + value of its `Template` class attribute name which contains the + provided ``name`` value. + + :param name: template name + :param tmpl_name_attribute: template engine specific attribute name + """ + used_templates = [] + + for template, context in self.templates: + if getattr(template, tmpl_name_attribute) == name: + return True + + used_templates.append(template) + + raise AssertionError("Template {} not used. Templates were used: {}" + .format(name, ' '.join(repr(used_templates)))) + + def get_context_variable(self, name): + """ + Returns a variable from the context passed to the template. + + Raises a ContextVariableDoesNotExist exception if does not exist in + context. + + :param name: name of variable + """ + for template, context in self.templates: + if name in context: + return context[name] + raise ContextVariableDoesNotExist + + def assert_context(self, name, value, message=None): + """ + Checks if given name exists in the template context + and equals the given value. + + :versionadded: 0.2 + :param name: name of context variable + :param value: value to check against + """ + + try: + assert self.get_context_variable(name) == value, message + except ContextVariableDoesNotExist: + pytest.fail(message or + "Context variable does not exist: {}".format(name)) + + def assert_redirects(self, response, location, message=None): + """ + Checks if response is an HTTP redirect to the + given location. + + :param response: Flask response + :param location: relative URL path to SERVER_NAME or an absolute URL + """ + parts = urlparse(location) + + if parts.netloc: + expected_location = location + else: + server_name = self.app.config.get('SERVER_NAME') or 'localhost' + expected_location = urljoin("http://%s" % server_name, location) + + valid_status_codes = (301, 302, 303, 305, 307) + valid_status_code_str = ', '.join([str(code) + for code in valid_status_codes]) + not_redirect = "HTTP Status {} expected but got {}" \ + .format(valid_status_code_str, response.status_code) + assert (response.status_code in (valid_status_codes, message) + or not_redirect) + assert response.location == expected_location, message