diff --git a/securedrop/tests/conftest.py b/securedrop/tests/conftest.py
index 56a4d5d585..a8ef804e68 100644
--- a/securedrop/tests/conftest.py
+++ b/securedrop/tests/conftest.py
@@ -1,16 +1,21 @@
# -*- coding: utf-8 -*-
-import os
-import shutil
-import signal
-import subprocess
-import logging
import gnupg
+import logging
+import os
import psutil
import pytest
+import shutil
+import signal
+import subprocess
os.environ['SECUREDROP_ENV'] = 'test' # noqa
-from sdconfig import config
+from sdconfig import SDConfig, config as original_config
+
+from os import path
+
+from db import db
+from source_app import create_app as create_source_app
# TODO: the PID file for the redis worker is hard-coded below.
# Ideally this constant would be provided by a test harness.
@@ -44,11 +49,46 @@ def pytest_collection_modifyitems(config, items):
@pytest.fixture(scope='session')
-def setUptearDown():
- _start_test_rqworker(config)
+def setUpTearDown():
+ _start_test_rqworker(original_config)
yield
_stop_test_rqworker()
- _cleanup_test_securedrop_dataroot(config)
+ _cleanup_test_securedrop_dataroot(original_config)
+
+
+@pytest.fixture(scope='function')
+def config(tmpdir):
+ '''Clone the module so we can modify it per test.'''
+
+ cnf = SDConfig()
+
+ data = tmpdir.mkdir('data')
+ keys = data.mkdir('keys')
+ store = data.mkdir('store')
+ tmp = data.mkdir('tmp')
+ sqlite = data.join('db.sqlite')
+
+ gpg = gnupg.GPG(homedir=str(keys))
+ with open(path.join(path.dirname(__file__),
+ 'files',
+ 'test_journalist_key.pub')) as f:
+ gpg.import_keys(f.read())
+
+ cnf.SECUREDROP_DATA_ROOT = str(data)
+ cnf.GPG_KEY_DIR = str(keys)
+ cnf.STORE_DIR = str(store)
+ cnf.TEMP_DIR = str(tmp)
+ cnf.DATABASE_FILE = str(sqlite)
+
+ return cnf
+
+
+@pytest.fixture(scope='function')
+def source_app(config):
+ app = create_source_app(config)
+ with app.app_context():
+ db.create_all()
+ return app
def _start_test_rqworker(config):
diff --git a/securedrop/tests/pytest.ini b/securedrop/tests/pytest.ini
index bfb504ab97..affc0b16ff 100644
--- a/securedrop/tests/pytest.ini
+++ b/securedrop/tests/pytest.ini
@@ -1,4 +1,4 @@
[pytest]
testpaths = . functional
-usefixtures = setUptearDown
+usefixtures = setUpTearDown
addopts = --cov=../securedrop/
diff --git a/securedrop/tests/test_source.py b/securedrop/tests/test_source.py
index e70cca6333..0aa4637b3e 100644
--- a/securedrop/tests/test_source.py
+++ b/securedrop/tests/test_source.py
@@ -2,534 +2,593 @@
import gzip
import json
import re
+import subprocess
from cStringIO import StringIO
-from flask import session, escape, url_for, current_app
-from flask_testing import TestCase
+from flask import session, escape, current_app
from mock import patch, ANY
-from sdconfig import config
+import crypto_util
import source
import utils
import version
from db import db
from models import Source
+from source_app import main as source_app_main
from utils.db_helper import new_codename
+from utils.instrument import InstrumentedApp
overly_long_codename = 'a' * (Source.MAX_CODENAME_LEN + 1)
-class TestSourceApp(TestCase):
+def test_page_not_found(source_app):
+ """Verify the page not found condition returns the intended template"""
+ with InstrumentedApp(source_app) as ins:
+ with source_app.test_client() as app:
+ resp = app.get('UNKNOWN')
+ assert resp.status_code == 404
+ ins.assert_template_used('notfound.html')
- def create_app(self):
- return source.app
- def setUp(self):
- utils.env.setup()
+def test_index(source_app):
+ """Test that the landing page loads and looks how we expect"""
+ with source_app.test_client() as app:
+ resp = app.get('/')
+ assert resp.status_code == 200
+ text = resp.data.decode('utf-8')
+ assert 'Submit documents for the first time' in text
+ assert 'Already submitted something?' in text
- def tearDown(self):
- utils.env.teardown()
- def test_page_not_found(self):
- """Verify the page not found condition returns the intended template"""
- response = self.client.get('/UNKNOWN')
- self.assert404(response)
- self.assertTemplateUsed('notfound.html')
-
- def test_index(self):
- """Test that the landing page loads and looks how we expect"""
- response = self.client.get('/')
- self.assertEqual(response.status_code, 200)
- self.assertIn("Submit documents for the first time", response.data)
- self.assertIn("Already submitted something?", response.data)
-
- def test_all_words_in_wordlist_validate(self):
- """Verify that all words in the wordlist are allowed by the form
- validation. Otherwise a source will have a codename and be unable to
- return."""
+def test_all_words_in_wordlist_validate(source_app):
+ """Verify that all words in the wordlist are allowed by the form
+ validation. Otherwise a source will have a codename and be unable to
+ return."""
+ with source_app.app_context():
wordlist_en = current_app.crypto_util.get_wordlist('en')
- # chunk the words to cut down on the number of requets we make
- # otherwise this test is *slow*
- chunks = [wordlist_en[i:i + 7] for i in range(0, len(wordlist_en), 7)]
+ # chunk the words to cut down on the number of requets we make
+ # otherwise this test is *slow*
+ chunks = [wordlist_en[i:i + 7] for i in range(0, len(wordlist_en), 7)]
+ with source_app.test_client() as app:
for words in chunks:
- with self.client as c:
- resp = c.post('/login', data=dict(codename=' '.join(words)),
- follow_redirects=True)
- self.assertEqual(resp.status_code, 200)
- # If the word does not validate, then it will show
- # 'Invalid input'. If it does validate, it should show that
- # it isn't a recognized codename.
- self.assertIn('Sorry, that is not a recognized codename.',
- resp.data)
- self.assertNotIn('logged_in', session)
-
- def _find_codename(self, html):
- """Find a source codename (diceware passphrase) in HTML"""
- # Codenames may contain HTML escape characters, and the wordlist
- # contains various symbols.
- codename_re = (r'
]*id="codename"[^>]*>'
- r'(?P[a-z0-9 ?:=@_.*+()\'"$%!-]+)
')
- codename_match = re.search(codename_re, html)
- self.assertIsNotNone(codename_match)
- return codename_match.group('codename')
-
- def test_generate(self):
- with self.client as c:
- resp = c.get('/generate')
- self.assertEqual(resp.status_code, 200)
- session_codename = session['codename']
- self.assertIn("This codename is what you will use in future visits",
- resp.data)
- codename = self._find_codename(resp.data)
- self.assertEqual(len(codename.split()), Source.NUM_WORDS)
- # codename is also stored in the session - make sure it matches the
- # codename displayed to the source
- self.assertEqual(codename, escape(session_codename))
-
- def test_generate_already_logged_in(self):
- with self.client as client:
- new_codename(client, session)
- # Make sure it redirects to /lookup when logged in
- resp = client.get('/generate')
- self.assertEqual(resp.status_code, 302)
- # Make sure it flashes the message on the lookup page
- resp = client.get('/generate', follow_redirects=True)
- # Should redirect to /lookup
- self.assertEqual(resp.status_code, 200)
- self.assertIn("because you are already logged in.", resp.data)
-
- def test_create_new_source(self):
- with self.client as c:
- resp = c.get('/generate')
- resp = c.post('/create', follow_redirects=True)
- self.assertTrue(session['logged_in'])
- # should be redirected to /lookup
- self.assertIn("Submit Materials", resp.data)
-
- @patch('source.app.logger.warning')
- @patch('crypto_util.CryptoUtil.genrandomid',
- side_effect=[overly_long_codename, 'short codename'])
- def test_generate_too_long_codename(self, genrandomid, logger):
- """Generate a codename that exceeds the maximum codename length"""
-
- with self.client as c:
- resp = c.post('/generate')
- self.assertEqual(resp.status_code, 200)
-
- logger.assert_called_with(
- "Generated a source codename that was too long, "
- "skipping it. This should not happen. "
- "(Codename='{}')".format(overly_long_codename)
- )
-
- @patch('source.app.logger.error')
- def test_create_duplicate_codename(self, logger):
- with self.client as c:
- c.get('/generate')
+ resp = app.post('/login', data=dict(codename=' '.join(words)),
+ follow_redirects=True)
+ assert resp.status_code == 200
+ text = resp.data.decode('utf-8')
+ # If the word does not validate, then it will show
+ # 'Invalid input'. If it does validate, it should show that
+ # it isn't a recognized codename.
+ assert 'Sorry, that is not a recognized codename.' in text
+ assert 'logged_in' not in session
+
+
+def _find_codename(html):
+ """Find a source codename (diceware passphrase) in HTML"""
+ # Codenames may contain HTML escape characters, and the wordlist
+ # contains various symbols.
+ codename_re = (r']*id="codename"[^>]*>'
+ r'(?P[a-z0-9 ?:=@_.*+()\'"$%!-]+)
')
+ codename_match = re.search(codename_re, html)
+ assert codename_match is not None
+ return codename_match.group('codename')
+
+
+def test_generate(source_app):
+ with source_app.test_client() as app:
+ resp = app.get('/generate')
+ assert resp.status_code == 200
+ session_codename = session['codename']
+
+ text = resp.data.decode('utf-8')
+ assert "This codename is what you will use in future visits" in text
+
+ codename = _find_codename(resp.data)
+ assert len(codename.split()) == Source.NUM_WORDS
+ # codename is also stored in the session - make sure it matches the
+ # codename displayed to the source
+ assert codename == escape(session_codename)
+
+
+def test_generate_already_logged_in(source_app):
+ with source_app.test_client() as app:
+ new_codename(app, session)
+ # Make sure it redirects to /lookup when logged in
+ resp = app.get('/generate')
+ assert resp.status_code == 302
+ # Make sure it flashes the message on the lookup page
+ resp = app.get('/generate', follow_redirects=True)
+ # Should redirect to /lookup
+ assert resp.status_code == 200
+ text = resp.data.decode('utf-8')
+ assert "because you are already logged in." in text
+
+
+def test_create_new_source(source_app):
+ with source_app.test_client() as app:
+ resp = app.get('/generate')
+ assert resp.status_code == 200
+ resp = app.post('/create', follow_redirects=True)
+ assert session['logged_in'] is True
+ # should be redirected to /lookup
+ text = resp.data.decode('utf-8')
+ assert "Submit Materials" in text
+
+
+def test_generate_too_long_codename(source_app):
+ """Generate a codename that exceeds the maximum codename length"""
+
+ with patch.object(source_app.logger, 'warning') as logger:
+ with patch.object(crypto_util.CryptoUtil, 'genrandomid',
+ side_effect=[overly_long_codename,
+ 'short codename']):
+ with source_app.test_client() as app:
+ resp = app.post('/generate')
+ assert resp.status_code == 200
+
+ logger.assert_called_with(
+ "Generated a source codename that was too long, "
+ "skipping it. This should not happen. "
+ "(Codename='{}')".format(overly_long_codename)
+ )
+
+
+def test_create_duplicate_codename(source_app):
+ with patch.object(source.app.logger, 'error') as logger:
+ with source_app.test_client() as app:
+ resp = app.get('/generate')
+ assert resp.status_code == 200
# Create a source the first time
- c.post('/create', follow_redirects=True)
+ resp = app.post('/create', follow_redirects=True)
+ assert resp.status_code == 200
# Attempt to add the same source
- c.post('/create', follow_redirects=True)
+ app.post('/create', follow_redirects=True)
logger.assert_called_once()
- self.assertIn("Attempt to create a source with duplicate codename",
- logger.call_args[0][0])
+ assert ("Attempt to create a source with duplicate codename"
+ in logger.call_args[0][0])
assert 'codename' not in session
- def test_lookup(self):
- """Test various elements on the /lookup page."""
- with self.client as client:
- codename = new_codename(client, session)
- resp = client.post('login', data=dict(codename=codename),
- follow_redirects=True)
- # redirects to /lookup
- self.assertIn("public key", resp.data)
- # download the public key
- resp = client.get('journalist-key')
- self.assertIn("BEGIN PGP PUBLIC KEY BLOCK", resp.data)
-
- def test_login_and_logout(self):
- resp = self.client.get('/login')
- self.assertEqual(resp.status_code, 200)
- self.assertIn("Enter Codename", resp.data)
-
- with self.client as client:
- codename = new_codename(client, session)
- resp = client.post('/login', data=dict(codename=codename),
- follow_redirects=True)
- self.assertEqual(resp.status_code, 200)
- self.assertIn("Submit Materials", resp.data)
- self.assertTrue(session['logged_in'])
- resp = client.get('/logout', follow_redirects=True)
-
- with self.client as c:
- resp = c.post('/login', data=dict(codename='invalid'),
- follow_redirects=True)
- self.assertEqual(resp.status_code, 200)
- self.assertIn('Sorry, that is not a recognized codename.',
- resp.data)
- self.assertNotIn('logged_in', session)
-
- with self.client as c:
- resp = c.post('/login', data=dict(codename=codename),
- follow_redirects=True)
- self.assertEqual(resp.status_code, 200)
- self.assertTrue(session['logged_in'])
- resp = c.get('/logout', follow_redirects=True)
-
- # sessions always have 'expires', so pop it for the next check
- session.pop('expires', None)
-
- self.assertNotIn('logged_in', session)
- self.assertNotIn('codename', session)
-
- self.assertIn('Thank you for exiting your session!', resp.data)
-
- def test_user_must_log_in_for_protected_views(self):
- with self.client as c:
- resp = c.get('/lookup', follow_redirects=True)
- self.assertEqual(resp.status_code, 200)
- self.assertIn("Enter Codename", resp.data)
-
- def test_login_with_whitespace(self):
- """
- Test that codenames with leading or trailing whitespace still work"""
-
- with self.client as client:
- def login_test(codename):
- resp = client.get('/login')
- self.assertEqual(resp.status_code, 200)
- self.assertIn("Enter Codename", resp.data)
-
- resp = client.post('/login', data=dict(codename=codename),
- follow_redirects=True)
- self.assertEqual(resp.status_code, 200)
- self.assertIn("Submit Materials", resp.data)
- self.assertTrue(session['logged_in'])
- resp = client.get('/logout', follow_redirects=True)
-
- codename = new_codename(client, session)
- login_test(codename + ' ')
- login_test(' ' + codename + ' ')
- login_test(' ' + codename)
-
- def _dummy_submission(self, client):
- """
- Helper to make a submission (content unimportant), mostly useful in
- testing notification behavior for a source's first vs. their
- subsequent submissions
- """
- return client.post('/submit', data=dict(
- msg="Pay no attention to the man behind the curtain.",
+
+def test_lookup(source_app):
+ """Test various elements on the /lookup page."""
+ with source_app.test_client() as app:
+ codename = new_codename(app, session)
+ resp = app.post('/login', data=dict(codename=codename),
+ follow_redirects=True)
+ # redirects to /lookup
+ text = resp.data.decode('utf-8')
+ assert "public key" in text
+ # download the public key
+ resp = app.get('/journalist-key')
+ text = resp.data.decode('utf-8')
+ assert "BEGIN PGP PUBLIC KEY BLOCK" in text
+
+
+def test_login_and_logout(source_app):
+ with source_app.test_client() as app:
+ resp = app.get('/login')
+ assert resp.status_code == 200
+ text = resp.data.decode('utf-8')
+ assert "Enter Codename" in text
+
+ codename = new_codename(app, session)
+ resp = app.post('/login', data=dict(codename=codename),
+ follow_redirects=True)
+ assert resp.status_code == 200
+ text = resp.data.decode('utf-8')
+ assert "Submit Materials" in text
+ assert session['logged_in'] is True
+
+ with source_app.test_client() as app:
+ resp = app.post('/login', data=dict(codename='invalid'),
+ follow_redirects=True)
+ assert resp.status_code == 200
+ text = resp.data.decode('utf-8')
+ assert 'Sorry, that is not a recognized codename.' in text
+ assert 'logged_in' not in session
+
+ with source_app.test_client() as app:
+ resp = app.post('/login', data=dict(codename=codename),
+ follow_redirects=True)
+ assert resp.status_code == 200
+ assert session['logged_in'] is True
+
+ resp = app.post('/login', data=dict(codename=codename),
+ follow_redirects=True)
+ assert resp.status_code == 200
+ assert session['logged_in'] is True
+
+ resp = app.get('/logout', follow_redirects=True)
+ assert 'logged_in' not in session
+ assert 'codename' not in session
+ text = resp.data.decode('utf-8')
+ assert 'Thank you for exiting your session!' in text
+
+
+def test_user_must_log_in_for_protected_views(source_app):
+ with source_app.test_client() as app:
+ resp = app.get('/lookup', follow_redirects=True)
+ assert resp.status_code == 200
+ text = resp.data.decode('utf-8')
+ assert "Enter Codename" in text
+
+
+def test_login_with_whitespace(source_app):
+ """
+ Test that codenames with leading or trailing whitespace still work"""
+
+ def login_test(app, codename):
+ resp = app.get('/login')
+ assert resp.status_code == 200
+ text = resp.data.decode('utf-8')
+ assert "Enter Codename" in text
+
+ resp = app.post('/login', data=dict(codename=codename),
+ follow_redirects=True)
+ assert resp.status_code == 200
+ text = resp.data.decode('utf-8')
+ assert "Submit Materials" in text
+ assert session['logged_in'] is True
+
+ with source_app.test_client() as app:
+ codename = new_codename(app, session)
+
+ codenames = [
+ codename + ' ',
+ ' ' + codename + ' ',
+ ' ' + codename,
+ ]
+
+ for codename_ in codenames:
+ with source_app.test_client() as app:
+ login_test(app, codename_)
+
+
+def _dummy_submission(app):
+ """
+ Helper to make a submission (content unimportant), mostly useful in
+ testing notification behavior for a source's first vs. their
+ subsequent submissions
+ """
+ return app.post('/submit', data=dict(
+ msg="Pay no attention to the man behind the curtain.",
+ fh=(StringIO(''), ''),
+ ), follow_redirects=True)
+
+
+def test_initial_submission_notification(source_app):
+ """
+ Regardless of the type of submission (message, file, or both), the
+ first submission is always greeted with a notification
+ reminding sources to check back later for replies.
+ """
+ with source_app.test_client() as app:
+ new_codename(app, session)
+ resp = _dummy_submission(app)
+ assert resp.status_code == 200
+ text = resp.data.decode('utf-8')
+ assert "Thank you for sending this information to us." in text
+
+
+def test_submit_message(source_app):
+ with source_app.test_client() as app:
+ new_codename(app, session)
+ _dummy_submission(app)
+ resp = app.post('/submit', data=dict(
+ msg="This is a test.",
fh=(StringIO(''), ''),
), follow_redirects=True)
+ assert resp.status_code == 200
+ text = resp.data.decode('utf-8')
+ assert "Thanks! We received your message" in text
- def test_initial_submission_notification(self):
- """
- Regardless of the type of submission (message, file, or both), the
- first submission is always greeted with a notification
- reminding sources to check back later for replies.
- """
- with self.client as client:
- new_codename(client, session)
- resp = self._dummy_submission(client)
- self.assertEqual(resp.status_code, 200)
- self.assertIn(
- "Thank you for sending this information to us.",
- resp.data)
-
- def test_submit_message(self):
- with self.client as client:
- new_codename(client, session)
- self._dummy_submission(client)
- resp = client.post('/submit', data=dict(
- msg="This is a test.",
- fh=(StringIO(''), ''),
- ), follow_redirects=True)
- self.assertEqual(resp.status_code, 200)
- self.assertIn("Thanks! We received your message", resp.data)
- def test_submit_empty_message(self):
- with self.client as client:
- new_codename(client, session)
- resp = client.post('/submit', data=dict(
- msg="",
- fh=(StringIO(''), ''),
- ), follow_redirects=True)
- self.assertIn("You must enter a message or choose a file to "
- "submit.",
- resp.data)
-
- def test_submit_big_message(self):
- '''
- When the message is larger than 512KB it's written to disk instead of
- just residing in memory. Make sure the different return type of
- SecureTemporaryFile is handled as well as BytesIO.
- '''
- with self.client as client:
- new_codename(client, session)
- self._dummy_submission(client)
- resp = client.post('/submit', data=dict(
- msg="AA" * (1024 * 512),
- fh=(StringIO(''), ''),
- ), follow_redirects=True)
- self.assertEqual(resp.status_code, 200)
- self.assertIn("Thanks! We received your message", resp.data)
-
- def test_submit_file(self):
- with self.client as client:
- new_codename(client, session)
- self._dummy_submission(client)
- resp = client.post('/submit', data=dict(
- msg="",
- fh=(StringIO('This is a test'), 'test.txt'),
- ), follow_redirects=True)
- self.assertEqual(resp.status_code, 200)
- self.assertIn('Thanks! We received your document', resp.data)
-
- def test_submit_both(self):
- with self.client as client:
- new_codename(client, session)
- self._dummy_submission(client)
- resp = client.post('/submit', data=dict(
- msg="This is a test",
- fh=(StringIO('This is a test'), 'test.txt'),
- ), follow_redirects=True)
- self.assertEqual(resp.status_code, 200)
- self.assertIn("Thanks! We received your message and document",
- resp.data)
-
- @patch('source_app.main.async_genkey')
- @patch('source_app.main.get_entropy_estimate')
- def test_submit_message_with_low_entropy(self, get_entropy_estimate,
- async_genkey):
- get_entropy_estimate.return_value = 300
-
- with self.client as client:
- new_codename(client, session)
- self._dummy_submission(client)
- resp = client.post('/submit', data=dict(
- msg="This is a test.",
- fh=(StringIO(''), ''),
- ), follow_redirects=True)
- self.assertEqual(resp.status_code, 200)
- self.assertFalse(async_genkey.called)
-
- @patch('source_app.main.async_genkey')
- @patch('source_app.main.get_entropy_estimate')
- def test_submit_message_with_enough_entropy(self, get_entropy_estimate,
- async_genkey):
- get_entropy_estimate.return_value = 2400
-
- with self.client as client:
- new_codename(client, session)
- self._dummy_submission(client)
- resp = client.post('/submit', data=dict(
- msg="This is a test.",
- fh=(StringIO(''), ''),
- ), follow_redirects=True)
- self.assertEqual(resp.status_code, 200)
- self.assertTrue(async_genkey.called)
+def test_submit_empty_message(source_app):
+ with source_app.test_client() as app:
+ new_codename(app, session)
+ resp = app.post('/submit', data=dict(
+ msg="",
+ fh=(StringIO(''), ''),
+ ), follow_redirects=True)
+ assert resp.status_code == 200
+ text = resp.data.decode('utf-8')
+ assert "You must enter a message or choose a file to submit." \
+ in text
+
+
+def test_submit_big_message(source_app):
+ '''
+ When the message is larger than 512KB it's written to disk instead of
+ just residing in memory. Make sure the different return type of
+ SecureTemporaryFile is handled as well as BytesIO.
+ '''
+ with source_app.test_client() as app:
+ new_codename(app, session)
+ _dummy_submission(app)
+ resp = app.post('/submit', data=dict(
+ msg="AA" * (1024 * 512),
+ fh=(StringIO(''), ''),
+ ), follow_redirects=True)
+ assert resp.status_code == 200
+ text = resp.data.decode('utf-8')
+ assert "Thanks! We received your message" in text
+
+
+def test_submit_file(source_app):
+ with source_app.test_client() as app:
+ new_codename(app, session)
+ _dummy_submission(app)
+ resp = app.post('/submit', data=dict(
+ msg="",
+ fh=(StringIO('This is a test'), 'test.txt'),
+ ), follow_redirects=True)
+ assert resp.status_code == 200
+ text = resp.data.decode('utf-8')
+ assert 'Thanks! We received your document' in text
+
+
+def test_submit_both(source_app):
+ with source_app.test_client() as app:
+ new_codename(app, session)
+ _dummy_submission(app)
+ resp = app.post('/submit', data=dict(
+ msg="This is a test",
+ fh=(StringIO('This is a test'), 'test.txt'),
+ ), follow_redirects=True)
+ assert resp.status_code == 200
+ text = resp.data.decode('utf-8')
+ assert "Thanks! We received your message and document" in text
+
+
+def test_submit_message_with_low_entropy(source_app):
+ with patch.object(source_app_main, 'async_genkey') as async_genkey:
+ with patch.object(source_app_main, 'get_entropy_estimate') \
+ as get_entropy_estimate:
+ get_entropy_estimate.return_value = 300
+
+ with source_app.test_client() as app:
+ new_codename(app, session)
+ _dummy_submission(app)
+ resp = app.post('/submit', data=dict(
+ msg="This is a test.",
+ fh=(StringIO(''), ''),
+ ), follow_redirects=True)
+ assert resp.status_code == 200
+ assert not async_genkey.called
+
+
+def test_submit_message_with_enough_entropy(source_app):
+ with patch.object(source_app_main, 'async_genkey') as async_genkey:
+ with patch.object(source_app_main, 'get_entropy_estimate') \
+ as get_entropy_estimate:
+ get_entropy_estimate.return_value = 2400
+
+ with source_app.test_client() as app:
+ new_codename(app, session)
+ _dummy_submission(app)
+ resp = app.post('/submit', data=dict(
+ msg="This is a test.",
+ fh=(StringIO(''), ''),
+ ), follow_redirects=True)
+ assert resp.status_code == 200
+ assert async_genkey.called
- def test_delete_all_successfully_deletes_replies(self):
+
+def test_delete_all_successfully_deletes_replies(source_app):
+ with source_app.app_context():
journalist, _ = utils.db_helper.init_journalist()
source, codename = utils.db_helper.init_source()
utils.db_helper.reply(journalist, source, 1)
- with self.client as c:
- resp = c.post('/login', data=dict(codename=codename),
- follow_redirects=True)
- self.assertEqual(resp.status_code, 200)
- resp = c.post('/delete-all', follow_redirects=True)
- self.assertEqual(resp.status_code, 200)
- self.assertIn("All replies have been deleted", resp.data)
-
- @patch('source.app.logger.error')
- def test_delete_all_replies_already_deleted(self, logger):
+
+ with source_app.test_client() as app:
+ resp = app.post('/login', data=dict(codename=codename),
+ follow_redirects=True)
+ assert resp.status_code == 200
+ resp = app.post('/delete-all', follow_redirects=True)
+ assert resp.status_code == 200
+ text = resp.data.decode('utf-8')
+ assert "All replies have been deleted" in text
+
+
+def test_delete_all_replies_already_deleted(source_app):
+ with source_app.app_context():
journalist, _ = utils.db_helper.init_journalist()
source, codename = utils.db_helper.init_source()
# Note that we are creating the source and no replies
- with self.client as c:
- resp = c.post('/login', data=dict(codename=codename),
- follow_redirects=True)
- self.assertEqual(resp.status_code, 200)
- resp = c.post('/delete-all', follow_redirects=True)
- self.assertEqual(resp.status_code, 200)
+ with source_app.test_client() as app:
+ with patch.object(source_app.logger, 'error') as logger:
+ resp = app.post('/login', data=dict(codename=codename),
+ follow_redirects=True)
+ assert resp.status_code == 200
+ resp = app.post('/delete-all', follow_redirects=True)
+ assert resp.status_code == 200
logger.assert_called_once_with(
"Found no replies when at least one was expected"
)
- @patch('gzip.GzipFile', wraps=gzip.GzipFile)
- def test_submit_sanitizes_filename(self, gzipfile):
- """Test that upload file name is sanitized"""
- insecure_filename = '../../bin/gpg'
- sanitized_filename = 'bin_gpg'
- with self.client as client:
- new_codename(client, session)
- client.post('/submit', data=dict(
+def test_submit_sanitizes_filename(source_app):
+ """Test that upload file name is sanitized"""
+ insecure_filename = '../../bin/gpg'
+ sanitized_filename = 'bin_gpg'
+
+ with patch.object(gzip, 'GzipFile', wraps=gzip.GzipFile) as gzipfile:
+ with source_app.test_client() as app:
+ new_codename(app, session)
+ resp = app.post('/submit', data=dict(
msg="",
fh=(StringIO('This is a test'), insecure_filename),
), follow_redirects=True)
+ assert resp.status_code == 200
gzipfile.assert_called_with(filename=sanitized_filename,
mode=ANY,
fileobj=ANY)
- def test_tor2web_warning_headers(self):
- resp = self.client.get('/', headers=[('X-tor2web', 'encrypted')])
- self.assertEqual(resp.status_code, 200)
- self.assertIn("You appear to be using Tor2Web.", resp.data)
-
- def test_tor2web_warning(self):
- resp = self.client.get('/tor2web-warning')
- self.assertEqual(resp.status_code, 200)
- self.assertIn("Why is there a warning about Tor2Web?", resp.data)
-
- def test_why_use_tor_browser(self):
- resp = self.client.get('/use-tor')
- self.assertEqual(resp.status_code, 200)
- self.assertIn("You Should Use Tor Browser", resp.data)
-
- def test_why_journalist_key(self):
- resp = self.client.get('/why-journalist-key')
- self.assertEqual(resp.status_code, 200)
- self.assertIn("Why download the journalist's public key?", resp.data)
-
- def test_metadata_route(self):
- resp = self.client.get('/metadata')
- self.assertEqual(resp.status_code, 200)
- self.assertEqual(resp.headers.get('Content-Type'), 'application/json')
- self.assertEqual(json.loads(resp.data.decode('utf-8')).get(
- 'sd_version'), version.__version__)
-
- @patch('crypto_util.CryptoUtil.hash_codename')
- def test_login_with_overly_long_codename(self, mock_hash_codename):
- """Attempting to login with an overly long codename should result in
- an error, and scrypt should not be called to avoid DoS."""
- with self.client as c:
- resp = c.post('/login', data=dict(codename=overly_long_codename),
- follow_redirects=True)
- self.assertEqual(resp.status_code, 200)
- self.assertIn("Field must be between 1 and {} "
- "characters long.".format(Source.MAX_CODENAME_LEN),
- resp.data)
- self.assertFalse(mock_hash_codename.called,
- "Called hash_codename for codename w/ invalid "
- "length")
-
- @patch('source.app.logger.warning')
- @patch('subprocess.call', return_value=1)
- def test_failed_normalize_timestamps_logs_warning(self, call, logger):
- """If a normalize timestamps event fails, the subprocess that calls
- touch will fail and exit 1. When this happens, the submission should
- still occur, but a warning should be logged (this will trigger an
- OSSEC alert)."""
-
- with self.client as client:
- new_codename(client, session)
- self._dummy_submission(client)
- resp = client.post('/submit', data=dict(
- msg="This is a test.",
- fh=(StringIO(''), ''),
- ), follow_redirects=True)
- self.assertEqual(resp.status_code, 200)
- self.assertIn("Thanks! We received your message", resp.data)
- logger.assert_called_once_with(
- "Couldn't normalize submission "
- "timestamps (touch exited with 1)"
- )
+def test_tor2web_warning_headers(source_app):
+ with source_app.test_client() as app:
+ resp = app.get('/', headers=[('X-tor2web', 'encrypted')])
+ assert resp.status_code == 200
+ text = resp.data.decode('utf-8')
+ assert "You appear to be using Tor2Web." in text
+
+
+def test_tor2web_warning(source_app):
+ with source_app.test_client() as app:
+ resp = app.get('/tor2web-warning')
+ assert resp.status_code == 200
+ text = resp.data.decode('utf-8')
+ assert "Why is there a warning about Tor2Web?" in text
+
+
+def test_why_use_tor_browser(source_app):
+ with source_app.test_client() as app:
+ resp = app.get('/use-tor')
+ assert resp.status_code == 200
+ text = resp.data.decode('utf-8')
+ assert "You Should Use Tor Browser" in text
+
+
+def test_why_journalist_key(source_app):
+ with source_app.test_client() as app:
+ resp = app.get('/why-journalist-key')
+ assert resp.status_code == 200
+ text = resp.data.decode('utf-8')
+ assert "Why download the journalist's public key?" in text
+
+
+def test_metadata_route(source_app):
+ with source_app.test_client() as app:
+ resp = app.get('/metadata')
+ assert resp.status_code == 200
+ assert resp.headers.get('Content-Type') == 'application/json'
+ assert json.loads(resp.data.decode('utf-8')).get('sd_version') \
+ == version.__version__
+
+
+def test_login_with_overly_long_codename(source_app):
+ """Attempting to login with an overly long codename should result in
+ an error, and scrypt should not be called to avoid DoS."""
+ with patch.object(crypto_util.CryptoUtil, 'hash_codename') \
+ as mock_hash_codename:
+ with source_app.test_client() as app:
+ resp = app.post('/login',
+ data=dict(codename=overly_long_codename),
+ follow_redirects=True)
+ assert resp.status_code == 200
+ text = resp.data.decode('utf-8')
+ assert ("Field must be between 1 and {} characters long."
+ .format(Source.MAX_CODENAME_LEN)) in text
+ assert not mock_hash_codename.called, \
+ "Called hash_codename for codename w/ invalid length"
+
+
+def test_failed_normalize_timestamps_logs_warning(source_app):
+ """If a normalize timestamps event fails, the subprocess that calls
+ touch will fail and exit 1. When this happens, the submission should
+ still occur, but a warning should be logged (this will trigger an
+ OSSEC alert)."""
+
+ with patch.object(source_app.logger, 'warning') as logger:
+ with patch.object(subprocess, 'call', return_value=1):
+ with source_app.test_client() as app:
+ new_codename(app, session)
+ _dummy_submission(app)
+ resp = app.post('/submit', data=dict(
+ msg="This is a test.",
+ fh=(StringIO(''), ''),
+ ), follow_redirects=True)
+ assert resp.status_code == 200
+ text = resp.data.decode('utf-8')
+ assert "Thanks! We received your message" in text
+
+ logger.assert_called_once_with(
+ "Couldn't normalize submission "
+ "timestamps (touch exited with 1)"
+ )
+
- @patch('source.app.logger.error')
- def test_source_is_deleted_while_logged_in(self, logger):
- """If a source is deleted by a journalist when they are logged in,
- a NoResultFound will occur. The source should be redirected to the
- index when this happens, and a warning logged."""
+def test_source_is_deleted_while_logged_in(source_app):
+ """If a source is deleted by a journalist when they are logged in,
+ a NoResultFound will occur. The source should be redirected to the
+ index when this happens, and a warning logged."""
- with self.client as client:
- codename = new_codename(client, session)
- resp = client.post('login', data=dict(codename=codename),
- follow_redirects=True)
+ with patch.object(source_app.logger, 'error') as logger:
+ with source_app.test_client() as app:
+ codename = new_codename(app, session)
+ resp = app.post('login', data=dict(codename=codename),
+ follow_redirects=True)
# Now the journalist deletes the source
- filesystem_id = current_app.crypto_util.hash_codename(codename)
- current_app.crypto_util.delete_reply_keypair(filesystem_id)
- source = Source.query.filter_by(filesystem_id=filesystem_id).one()
+ filesystem_id = source_app.crypto_util.hash_codename(codename)
+ source_app.crypto_util.delete_reply_keypair(filesystem_id)
+ source = Source.query.filter_by(
+ filesystem_id=filesystem_id).one()
db.session.delete(source)
db.session.commit()
# Source attempts to continue to navigate
- resp = client.post('/lookup', follow_redirects=True)
- self.assertEqual(resp.status_code, 200)
- self.assertIn('Submit documents for the first time', resp.data)
- self.assertNotIn('logged_in', session.keys())
- self.assertNotIn('codename', session.keys())
+ resp = app.post('/lookup', follow_redirects=True)
+ assert resp.status_code == 200
+ text = resp.data.decode('utf-8')
+ assert 'Submit documents for the first time' in text
+ assert 'logged_in' not in session
+ assert 'codename' not in session
logger.assert_called_once_with(
"Found no Sources when one was expected: "
"No row was found for one()")
- def test_login_with_invalid_codename(self):
- """Logging in with a codename with invalid characters should return
- an informative message to the user."""
- invalid_codename = '[]'
+def test_login_with_invalid_codename(source_app):
+ """Logging in with a codename with invalid characters should return
+ an informative message to the user."""
- with self.client as c:
- resp = c.post('/login', data=dict(codename=invalid_codename),
- follow_redirects=True)
- self.assertEqual(resp.status_code, 200)
- self.assertIn("Invalid input.", resp.data)
+ invalid_codename = '[]'
- def _test_source_session_expiration(self):
- try:
- old_expiration = config.SESSION_EXPIRATION_MINUTES
- has_session_expiration = True
- except AttributeError:
- has_session_expiration = False
+ with source_app.test_client() as app:
+ resp = app.post('/login', data=dict(codename=invalid_codename),
+ follow_redirects=True)
+ assert resp.status_code == 200
+ text = resp.data.decode('utf-8')
+ assert "Invalid input." in text
- try:
- with self.client as client:
- codename = new_codename(client, session)
- # set the expiration to ensure we trigger an expiration
- config.SESSION_EXPIRATION_MINUTES = -1
+def test_source_session_expiration(config, source_app):
+ with source_app.test_client() as app:
+ codename = new_codename(app, session)
- resp = client.post('/login',
- data=dict(codename=codename),
- follow_redirects=True)
- assert resp.status_code == 200
- resp = client.get('/lookup', follow_redirects=True)
-
- # check that the session was cleared (apart from 'expires'
- # which is always present and 'csrf_token' which leaks no info)
- session.pop('expires', None)
- session.pop('csrf_token', None)
- assert not session, session
- assert ('You have been logged out due to inactivity' in
- resp.data.decode('utf-8'))
- finally:
- if has_session_expiration:
- config.SESSION_EXPIRATION_MINUTES = old_expiration
- else:
- del config.SESSION_EXPIRATION_MINUTES
-
- def test_csrf_error_page(self):
- old_enabled = self.app.config['WTF_CSRF_ENABLED']
- self.app.config['WTF_CSRF_ENABLED'] = True
-
- try:
- with self.app.test_client() as app:
- resp = app.post(url_for('main.create'))
- self.assertRedirects(resp, url_for('main.index'))
-
- resp = app.post(url_for('main.create'), follow_redirects=True)
- self.assertIn('Your session timed out due to inactivity',
- resp.data)
- finally:
- self.app.config['WTF_CSRF_ENABLED'] = old_enabled
+ # set the expiration to ensure we trigger an expiration
+ config.SESSION_EXPIRATION_MINUTES = -1
+
+ resp = app.post('/login',
+ data=dict(codename=codename),
+ follow_redirects=True)
+ assert resp.status_code == 200
+ resp = app.get('/lookup', follow_redirects=True)
+
+ # check that the session was cleared (apart from 'expires'
+ # which is always present and 'csrf_token' which leaks no info)
+ session.pop('expires', None)
+ session.pop('csrf_token', None)
+ assert not session, session
+ text = resp.data.decode('utf-8')
+ assert 'Your session timed out due to inactivity' in text
+
+
+def test_csrf_error_page(config, source_app):
+ source_app.config['WTF_CSRF_ENABLED'] = True
+ with source_app.test_client() as app:
+ with InstrumentedApp(source_app) as ins:
+ resp = app.post('/create')
+ ins.assert_redirects(resp, '/')
+
+ resp = app.post('/create', follow_redirects=True)
+ text = resp.data.decode('utf-8')
+ assert 'Your session timed out due to inactivity' in text
diff --git a/securedrop/tests/utils/db_helper.py b/securedrop/tests/utils/db_helper.py
index 38101298cc..70ef0b0682 100644
--- a/securedrop/tests/utils/db_helper.py
+++ b/securedrop/tests/utils/db_helper.py
@@ -163,9 +163,6 @@ def submit(source, num_submissions):
def new_codename(client, session):
"""Helper function to go through the "generate codename" flow.
"""
- # clear the session because our tests have implicit reliance on each other
- session.clear()
-
client.get('/generate')
codename = session['codename']
client.post('/create')
diff --git a/securedrop/tests/utils/instrument.py b/securedrop/tests/utils/instrument.py
new file mode 100644
index 0000000000..aab250085f
--- /dev/null
+++ b/securedrop/tests/utils/instrument.py
@@ -0,0 +1,148 @@
+# -*- coding: utf-8 -*-
+"""
+Taken from: flask_testing.utils
+
+Flask unittest integration.
+
+:copyright: (c) 2010 by Dan Jacob.
+:license: BSD, see LICENSE for more details.
+"""
+from __future__ import absolute_import, with_statement
+
+try:
+ from urllib.parse import urlparse, urljoin
+except ImportError:
+ # Python 2 urlparse fallback
+ from urlparse import urlparse, urljoin
+
+import pytest
+
+from flask import template_rendered, message_flashed
+
+
+__all__ = ['InstrumentedApp']
+
+
+class ContextVariableDoesNotExist(Exception):
+ pass
+
+
+class InstrumentedApp:
+
+ def __init__(self, app):
+ self.app = app
+
+ def __enter__(self):
+ self.templates = []
+ self.flashed_messages = []
+ template_rendered.connect(self._add_template)
+ message_flashed.connect(self._add_flash_message)
+ return self
+
+ def __exit__(self, *nargs):
+ if getattr(self, 'app', None) is not None:
+ del self.app
+
+ del self.templates[:]
+ del self.flashed_messages[:]
+
+ template_rendered.disconnect(self._add_template)
+ message_flashed.disconnect(self._add_flash_message)
+
+ def _add_flash_message(self, app, message, category):
+ self.flashed_messages.append((message, category))
+
+ def _add_template(self, app, template, context):
+ if len(self.templates) > 0:
+ self.templates = []
+ self.templates.append((template, context))
+
+ def assert_message_flashed(self, message, category='message'):
+ """
+ Checks if a given message was flashed.
+
+ :param message: expected message
+ :param category: expected message category
+ """
+ for _message, _category in self.flashed_messages:
+ if _message == message and _category == category:
+ return True
+
+ raise AssertionError("Message '{}' in category '{}' wasn't flashed"
+ .format(message, category))
+
+ def assert_template_used(self, name, tmpl_name_attribute='name'):
+ """
+ Checks if a given template is used in the request. If the template
+ engine used is not Jinja2, provide ``tmpl_name_attribute`` with a
+ value of its `Template` class attribute name which contains the
+ provided ``name`` value.
+
+ :param name: template name
+ :param tmpl_name_attribute: template engine specific attribute name
+ """
+ used_templates = []
+
+ for template, context in self.templates:
+ if getattr(template, tmpl_name_attribute) == name:
+ return True
+
+ used_templates.append(template)
+
+ raise AssertionError("Template {} not used. Templates were used: {}"
+ .format(name, ' '.join(repr(used_templates))))
+
+ def get_context_variable(self, name):
+ """
+ Returns a variable from the context passed to the template.
+
+ Raises a ContextVariableDoesNotExist exception if does not exist in
+ context.
+
+ :param name: name of variable
+ """
+ for template, context in self.templates:
+ if name in context:
+ return context[name]
+ raise ContextVariableDoesNotExist
+
+ def assert_context(self, name, value, message=None):
+ """
+ Checks if given name exists in the template context
+ and equals the given value.
+
+ :versionadded: 0.2
+ :param name: name of context variable
+ :param value: value to check against
+ """
+
+ try:
+ assert self.get_context_variable(name) == value, message
+ except ContextVariableDoesNotExist:
+ pytest.fail(message or
+ "Context variable does not exist: {}".format(name))
+
+ def assert_redirects(self, response, location, message=None):
+ """
+ Checks if response is an HTTP redirect to the
+ given location.
+
+ :param response: Flask response
+ :param location: relative URL path to SERVER_NAME or an absolute URL
+ """
+ parts = urlparse(location)
+
+ if parts.netloc:
+ expected_location = location
+ else:
+ server_name = self.app.config.get('SERVER_NAME') or 'localhost'
+ expected_location = urljoin("http://%s" % server_name, location)
+
+ valid_status_codes = (301, 302, 303, 305, 307)
+ valid_status_code_str = ', '.join([str(code)
+ for code in valid_status_codes])
+ not_redirect = "HTTP Status {} expected but got {}" \
+ .format(valid_status_code_str, response.status_code)
+ assert (response.status_code in (valid_status_codes, message)
+ or not_redirect)
+ assert response.location == expected_location, message