diff --git a/.gitignore b/.gitignore index 7d04173..c6c3b67 100644 --- a/.gitignore +++ b/.gitignore @@ -16,3 +16,4 @@ docs/_build/ .coverage .coverage.* htmlcov/ +.hypothesis/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7e29622..7be015d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/asottile/reorder_python_imports - rev: v1.2.0 + rev: v1.3.1 hooks: - id: reorder-python-imports args: ["--application-directories", "src"] @@ -9,7 +9,10 @@ repos: hooks: - id: black - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v1.4.0-1 + rev: v2.0.0 hooks: + - id: check-byte-order-marker + - id: trailing-whitespace + - id: end-of-file-fixer - id: flake8 - additional_dependencies: [flake8-bugbear] \ No newline at end of file + additional_dependencies: [flake8-bugbear] diff --git a/.travis.yml b/.travis.yml index 92acace..e87eaa0 100644 --- a/.travis.yml +++ b/.travis.yml @@ -25,7 +25,9 @@ script: - tox cache: - - pip + directories: + - $HOME/.cache/pip + - $HOME/.cache/pre-commit branches: only: diff --git a/LICENSE.rst b/LICENSE.rst index fbfceb2..e506dca 100644 --- a/LICENSE.rst +++ b/LICENSE.rst @@ -44,4 +44,4 @@ The initial implementation of It's Dangerous was inspired by Django's signing module. Copyright © Django Software Foundation and individual contributors. -All rights reserved. \ No newline at end of file +All rights reserved. diff --git a/setup.cfg b/setup.cfg index bfbcafd..6614658 100644 --- a/setup.cfg +++ b/setup.cfg @@ -33,4 +33,4 @@ ignore = E203, E501, W503 # up to 88 allowed by bugbear B950 max-line-length = 80 # init is used to export public API, ignore import warnings -exclude = src/itsdangerous/__init__.py \ No newline at end of file +exclude = src/itsdangerous/__init__.py diff --git a/src/itsdangerous/_compat.py b/src/itsdangerous/_compat.py index de70f9c..2291bce 100644 --- a/src/itsdangerous/_compat.py +++ b/src/itsdangerous/_compat.py @@ -16,7 +16,7 @@ number_types = (numbers.Real, decimal.Decimal) -def constant_time_compare(val1, val2): +def _constant_time_compare(val1, val2): """Return ``True`` if the two strings are equal, ``False`` otherwise. @@ -43,4 +43,4 @@ def constant_time_compare(val1, val2): # Starting with 2.7/3.3 the standard library has a c-implementation for # constant time string compares. -constant_time_compare = getattr(hmac, "compare_digest", constant_time_compare) +constant_time_compare = getattr(hmac, "compare_digest", _constant_time_compare) diff --git a/src/itsdangerous/jws.py b/src/itsdangerous/jws.py index b1f31de..92e9ec8 100644 --- a/src/itsdangerous/jws.py +++ b/src/itsdangerous/jws.py @@ -190,13 +190,13 @@ def loads(self, s, salt=None, return_header=False): if "exp" not in header: raise BadSignature("Missing expiry date", payload=payload) + int_date_error = BadHeader("Expiry date is not an IntDate", payload=payload) try: header["exp"] = int(header["exp"]) except ValueError: - raise BadHeader("Expiry date is not valid timestamp", payload=payload) - - if not (isinstance(header["exp"], number_types) and header["exp"] > 0): - raise BadSignature("expiry date is not an IntDate", payload=payload) + raise int_date_error + if header["exp"] < 0: + raise int_date_error if header["exp"] < self.now(): raise SignatureExpired( diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_compat.py b/tests/test_compat.py new file mode 100644 index 0000000..2043fad --- /dev/null +++ b/tests/test_compat.py @@ -0,0 +1,11 @@ +import pytest + +from itsdangerous._compat import _constant_time_compare + + +@pytest.mark.parametrize( + ("a", "b", "expect"), + ((b"a", b"a", True), (b"a", b"b", False), (b"a", b"aa", False)), +) +def test_python_constant_time_compare(a, b, expect): + assert _constant_time_compare(a, b) == expect diff --git a/tests/test_encoding.py b/tests/test_encoding.py new file mode 100644 index 0000000..d60ec17 --- /dev/null +++ b/tests/test_encoding.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- +import pytest + +from itsdangerous.encoding import base64_decode +from itsdangerous.encoding import base64_encode +from itsdangerous.encoding import bytes_to_int +from itsdangerous.encoding import int_to_bytes +from itsdangerous.encoding import want_bytes +from itsdangerous.exc import BadData + + +@pytest.mark.parametrize("value", (u"mañana", b"tomorrow")) +def test_want_bytes(value): + out = want_bytes(value) + assert isinstance(out, bytes) + + +@pytest.mark.parametrize("value", (u"無限", b"infinite")) +def test_base64(value): + enc = base64_encode(value) + assert isinstance(enc, bytes) + dec = base64_decode(enc) + assert dec == want_bytes(value) + + +def test_base64_bad(): + with pytest.raises(BadData): + base64_decode("12345") + + +@pytest.mark.parametrize( + ("value", "expect"), ((0, b""), (192, b"\xc0"), (18446744073709551615, b"\xff" * 8)) +) +def test_int_bytes(value, expect): + enc = int_to_bytes(value) + assert enc == expect + dec = bytes_to_int(enc) + assert dec == value diff --git a/tests/test_itsdangerous.py b/tests/test_itsdangerous.py deleted file mode 100755 index 1a9f1ec..0000000 --- a/tests/test_itsdangerous.py +++ /dev/null @@ -1,359 +0,0 @@ -#!/usr/bin/env python -import hashlib -import pickle -import time -import unittest -from datetime import datetime - -import pytest - -import itsdangerous -from itsdangerous._compat import PY2 -from itsdangerous._compat import text_type -from itsdangerous.encoding import want_bytes - -# Helper function for some unsafe string manipulation on encoded -# data. This is required for Python 3 but would break on Python 2 -if PY2: - - def _coerce_string(reference_string, value): - return value - - -else: - - def _coerce_string(reference_string, value): - assert isinstance(value, text_type), "rhs needs to be a string" - if type(reference_string) != type(value): - value = value.encode("utf-8") - return value - - -class UtilityTestCase(unittest.TestCase): - def test_want_bytes(self): - self.assertEqual(want_bytes(b"foobar"), b"foobar") - self.assertEqual(want_bytes(u"foobar"), b"foobar") - - -class SignerTestCase(unittest.TestCase): - signer_class = itsdangerous.Signer - - def make_signer(self, *args, **kwargs): - return self.signer_class(*args, **kwargs) - - def test_sign(self): - s = self.make_signer("secret-key") - assert isinstance(s.sign("my string"), bytes) - - def test_sign_invalid_separator(self): - with pytest.raises(ValueError) as excinfo: - self.make_signer("secret-key", sep="-") - assert "separator cannot be used" in str(excinfo.value) - - -class SerializerTestCase(unittest.TestCase): - serializer_class = itsdangerous.Serializer - - def make_serializer(self, *args, **kwargs): - return self.serializer_class(*args, **kwargs) - - def test_dumps_loads(self): - objects = ( - ["a", "list"], - "a string", - u"a unicode string \u2019", - {"a": "dictionary"}, - 42, - 42.5, - ) - s = self.make_serializer("Test") - for o in objects: - value = s.dumps(o) - self.assertNotEqual(o, value) - self.assertEqual(o, s.loads(value)) - - def test_decode_detects_tampering(self): - s = self.make_serializer("Test") - - transforms = ( - lambda s: s.upper(), - lambda s: s + _coerce_string(s, "a"), - lambda s: _coerce_string(s, "a") + s[1:], - lambda s: s.replace(_coerce_string(s, "."), _coerce_string(s, "")), - ) - value = {"foo": "bar", "baz": 1} - encoded = s.dumps(value) - self.assertEqual(value, s.loads(encoded)) - for transform in transforms: - self.assertRaises(itsdangerous.BadSignature, s.loads, transform(encoded)) - - def test_accepts_unicode(self): - objects = ( - ["a", "list"], - "a string", - u"a unicode string \u2019", - {"a": "dictionary"}, - 42, - 42.5, - ) - s = self.make_serializer("Test") - for o in objects: - value = s.dumps(o) - self.assertNotEqual(o, value) - self.assertEqual(o, s.loads(value)) - - def test_exception_attributes(self): - secret_key = "predictable-key" - value = u"hello" - - s = self.make_serializer(secret_key) - ts = s.dumps(value) - - try: - s.loads(ts + _coerce_string(ts, "x")) - except itsdangerous.BadSignature as e: - self.assertEqual(want_bytes(e.payload), want_bytes(ts).rsplit(b".", 1)[0]) - self.assertEqual(s.load_payload(e.payload), value) - else: - self.fail("Did not get bad signature") - - def test_unsafe_load(self): - secret_key = "predictable-key" - value = u"hello" - - s = self.make_serializer(secret_key) - ts = s.dumps(value) - self.assertEqual(s.loads_unsafe(ts), (True, u"hello")) - self.assertEqual(s.loads_unsafe(ts, salt="modified"), (False, u"hello")) - - def test_load_unsafe_with_unicode_strings(self): - secret_key = "predictable-key" - value = u"hello" - - s = self.make_serializer(secret_key) - ts = s.dumps(value) - self.assertEqual(s.loads_unsafe(ts), (True, u"hello")) - self.assertEqual(s.loads_unsafe(ts, salt="modified"), (False, u"hello")) - - try: - s.loads(ts, salt="modified") - except itsdangerous.BadSignature as e: - self.assertEqual(s.load_payload(e.payload), u"hello") - - def test_signer_kwargs(self): - secret_key = "predictable-key" - value = "hello" - s = self.make_serializer( - secret_key, - signer_kwargs=dict(digest_method=hashlib.md5, key_derivation="hmac"), - ) - ts = s.dumps(value) - self.assertEqual(s.loads(ts), u"hello") - - def test_serializer_kwargs(self): - s = self.make_serializer( - "predictable-key", serializer_kwargs={"sort_keys": True} - ) - - # pickle tests pop serializer kwargs, so skip this test for those - if not s.serializer_kwargs: - return - - ts1 = s.dumps({"c": 3, "a": 1, "b": 2}) - ts2 = s.dumps(dict(a=1, b=2, c=3)) - - self.assertEqual(ts1, ts2) - - -class TimedSerializerTestCase(SerializerTestCase): - serializer_class = itsdangerous.TimedSerializer - - def setUp(self): - self._time = time.time - time.time = lambda: 0 - - def tearDown(self): - time.time = self._time - - def test_decode_with_timeout(self): - secret_key = "predictable-key" - value = u"hello" - - s = self.make_serializer(secret_key) - ts = s.dumps(value) - self.assertNotEqual(ts, itsdangerous.Serializer(secret_key).dumps(value)) - - self.assertEqual(s.loads(ts), value) - time.time = lambda: 10 - self.assertEqual(s.loads(ts, max_age=11), value) - self.assertEqual(s.loads(ts, max_age=10), value) - self.assertRaises(itsdangerous.SignatureExpired, s.loads, ts, max_age=9) - - def test_decode_return_timestamp(self): - secret_key = "predictable-key" - value = u"hello" - - s = self.make_serializer(secret_key) - ts = s.dumps(value) - loaded, timestamp = s.loads(ts, return_timestamp=True) - self.assertEqual(loaded, value) - self.assertEqual(timestamp, datetime.utcfromtimestamp(time.time())) - - def test_exception_attributes(self): - secret_key = "predictable-key" - value = u"hello" - - s = self.make_serializer(secret_key) - ts = s.dumps(value) - try: - s.loads(ts, max_age=-1) - except itsdangerous.SignatureExpired as e: - self.assertEqual(e.date_signed, datetime.utcfromtimestamp(time.time())) - self.assertEqual(want_bytes(e.payload), want_bytes(ts).rsplit(b".", 2)[0]) - self.assertEqual(s.load_payload(e.payload), value) - else: - self.fail("Did not get expiration") - - -class JSONWebSignatureSerializerTestCase(SerializerTestCase): - serializer_class = itsdangerous.JSONWebSignatureSerializer - - def test_decode_return_header(self): - secret_key = "predictable-key" - value = u"hello" - header = {"typ": "dummy"} - - s = self.make_serializer(secret_key) - full_header = header.copy() - full_header["alg"] = s.algorithm_name - - ts = s.dumps(value, header_fields=header) - loaded, loaded_header = s.loads(ts, return_header=True) - self.assertEqual(loaded, value) - self.assertEqual(loaded_header, full_header) - - def test_hmac_algorithms(self): - secret_key = "predictable-key" - value = u"hello" - - algorithms = ("HS256", "HS384", "HS512") - for algorithm in algorithms: - s = self.make_serializer(secret_key, algorithm_name=algorithm) - ts = s.dumps(value) - self.assertEqual(s.loads(ts), value) - - def test_none_algorithm(self): - secret_key = "predictable-key" - value = u"hello" - - s = self.make_serializer(secret_key) - ts = s.dumps(value) - self.assertEqual(s.loads(ts), value) - - def test_algorithm_mismatch(self): - secret_key = "predictable-key" - value = u"hello" - - s = self.make_serializer(secret_key, algorithm_name="HS256") - ts = s.dumps(value) - - s = self.make_serializer(secret_key, algorithm_name="HS384") - try: - s.loads(ts) - except itsdangerous.BadSignature as e: - self.assertEqual(s.load_payload(e.payload), value) - else: - self.fail("Did not get algorithm mismatch") - - -class TimedJSONWebSignatureSerializerTest(unittest.TestCase): - serializer_class = itsdangerous.TimedJSONWebSignatureSerializer - - def test_token_contains_issue_date_and_expiry_time(self): - s = self.serializer_class("secret") - result = s.dumps({"es": "geht"}) - self.assertTrue("exp" in s.loads(result, return_header=True)[1]) - self.assertTrue("iat" in s.loads(result, return_header=True)[1]) - - def test_token_expires_at_given_expiry_time(self): - s = self.serializer_class("secret") - an_hour_ago = int(time.time()) - 3601 - s.now = lambda: an_hour_ago - result = s.dumps({"foo": "bar"}) - s = self.serializer_class("secret") - self.assertRaises(itsdangerous.SignatureExpired, s.loads, result) - - def test_token_is_invalid_if_expiry_time_is_missing(self): - bad_s = itsdangerous.JSONWebSignatureSerializer("secret") - invalid_token_empty = bad_s.dumps({}) - s = self.serializer_class("secret") - self.assertRaises(itsdangerous.BadSignature, s.loads, invalid_token_empty) - - def test_token_is_invalid_if_expiry_time_is_negative(self): - s = self.serializer_class("secret", expires_in=-123) - result = s.dumps({"foo": "bar"}) - self.assertRaises(itsdangerous.BadSignature, s.loads, result) - - def test_creating_a_token_adds_the_expiry_date(self): - expires_in_two_hours = 7200 - s = self.serializer_class("secret", expires_in=expires_in_two_hours) - result, header = s.loads(s.dumps({"foo": "bar"}), return_header=True) - self.assertEqual(header["exp"] - header["iat"], expires_in_two_hours) - - -class URLSafeSerializerMixin(object): - def test_is_base62(self): - allowed = frozenset( - b"0123456789abcdefghijklmnopqrstuvwxyz" + b"ABCDEFGHIJKLMNOPQRSTUVWXYZ_-." - ) - objects = ( - ["a", "list"], - "a string", - u"a unicode string \u2019", - {"a": "dictionary"}, - 42, - 42.5, - ) - s = self.make_serializer("Test") - for o in objects: - value = want_bytes(s.dumps(o)) - self.assertTrue(set(value).issubset(set(allowed))) - self.assertNotEqual(o, value) - self.assertEqual(o, s.loads(value)) - - def test_invalid_base64_does_not_fail_load_payload(self): - s = itsdangerous.URLSafeSerializer("aha!") - self.assertRaises(itsdangerous.BadPayload, s.load_payload, b"kZ4m3du844lIN") - - -class PickleSerializerMixin(object): - def make_serializer(self, *args, **kwargs): - kwargs.pop("serializer_kwargs", "") - kwargs.setdefault("serializer", pickle) - return super(PickleSerializerMixin, self).make_serializer(*args, **kwargs) - - -class URLSafeSerializerTestCase(URLSafeSerializerMixin, SerializerTestCase): - serializer_class = itsdangerous.URLSafeSerializer - - -class URLSafeTimedSerializerTestCase(URLSafeSerializerMixin, TimedSerializerTestCase): - serializer_class = itsdangerous.URLSafeTimedSerializer - - -class PickleSerializerTestCase(PickleSerializerMixin, SerializerTestCase): - pass - - -class PickleTimedSerializerTestCase(PickleSerializerMixin, TimedSerializerTestCase): - pass - - -class PickleURLSafeSerializerTestCase(PickleSerializerMixin, URLSafeSerializerTestCase): - pass - - -class PickleURLSafeTimedSerializerTestCase( - PickleSerializerMixin, URLSafeTimedSerializerTestCase -): - pass diff --git a/tests/test_jws.py b/tests/test_jws.py new file mode 100644 index 0000000..9938311 --- /dev/null +++ b/tests/test_jws.py @@ -0,0 +1,122 @@ +from functools import partial + +import pytest +from tests.test_serializer import TestSerializer +from tests.test_timed import TestTimedSerializer + +from itsdangerous.exc import BadData +from itsdangerous.exc import BadHeader +from itsdangerous.exc import BadPayload +from itsdangerous.exc import BadSignature +from itsdangerous.exc import SignatureExpired +from itsdangerous.jws import JSONWebSignatureSerializer +from itsdangerous.jws import TimedJSONWebSignatureSerializer + + +class TestJWSSerializer(TestSerializer): + @pytest.fixture() + def serializer_factory(self): + return partial(JSONWebSignatureSerializer, secret_key="secret-key") + + test_signer_cls = None + test_signer_kwargs = None + + @pytest.mark.parametrize("algorithm_name", ("HS256", "HS384", "HS512", "none")) + def test_algorithm(self, serializer_factory, algorithm_name): + serializer = serializer_factory(algorithm_name=algorithm_name) + assert serializer.loads(serializer.dumps("value")) == "value" + + def test_invalid_algorithm(self, serializer_factory): + with pytest.raises(NotImplementedError) as exc_info: + serializer_factory(algorithm_name="invalid") + + assert "not supported" in str(exc_info.value) + + def test_algorithm_mismatch(self, serializer_factory, serializer): + other = serializer_factory(algorithm_name="HS256") + other.algorithm = serializer.algorithm + signed = other.dumps("value") + + with pytest.raises(BadHeader) as exc_info: + serializer.loads(signed) + + assert "mismatch" in str(exc_info.value) + + @pytest.mark.parametrize( + ("value", "exc_cls", "match"), + ( + ("ab", BadPayload, '"."'), + ("a.b", BadHeader, "base64 decode"), + ("ew.b", BadPayload, "base64 decode"), + ("ew.ab", BadData, "malformed"), + ("W10.ab", BadHeader, "JSON object"), + ), + ) + def test_load_payload_exceptions(self, serializer, value, exc_cls, match): + signer = serializer.make_signer() + signed = signer.sign(value) + + with pytest.raises(exc_cls) as exc_info: + serializer.loads(signed) + + assert match in str(exc_info.value) + + +class TestTimedJWSSerializer(TestJWSSerializer, TestTimedSerializer): + @pytest.fixture() + def serializer_factory(self): + return partial( + TimedJSONWebSignatureSerializer, secret_key="secret-key", expires_in=10 + ) + + def test_default_expires_in(self, serializer_factory): + serializer = serializer_factory(expires_in=None) + assert serializer.expires_in == serializer.DEFAULT_EXPIRES_IN + + test_max_age = None + + def test_exp(self, serializer, value, ts, freeze): + signed = serializer.dumps(value) + freeze.tick() + assert serializer.loads(signed) == value + freeze.tick(10) + + with pytest.raises(SignatureExpired) as exc_info: + serializer.loads(signed) + + assert exc_info.value.date_signed == ts + assert exc_info.value.payload == value + + test_return_payload = None + + def test_return_header(self, serializer, value, ts): + signed = serializer.dumps(value) + payload, header = serializer.loads(signed, return_header=True) + date_signed = serializer.get_issue_date(header) + assert (payload, date_signed) == (value, ts) + + def test_missing_exp(self, serializer): + header = serializer.make_header(None) + del header["exp"] + signer = serializer.make_signer() + signed = signer.sign(serializer.dump_payload(header, "value")) + + with pytest.raises(BadSignature): + serializer.loads(signed) + + @pytest.mark.parametrize("exp", ("invalid", -1)) + def test_invalid_exp(self, serializer, exp): + header = serializer.make_header(None) + header["exp"] = exp + signer = serializer.make_signer() + signed = signer.sign(serializer.dump_payload(header, "value")) + + with pytest.raises(BadHeader) as exc_info: + serializer.loads(signed) + + assert "IntDate" in str(exc_info.value) + + def test_invalid_iat(self, serializer): + header = serializer.make_header(None) + header["iat"] = "invalid" + assert serializer.get_issue_date(header) is None diff --git a/tests/test_serializer.py b/tests/test_serializer.py new file mode 100644 index 0000000..465d507 --- /dev/null +++ b/tests/test_serializer.py @@ -0,0 +1,133 @@ +import pickle +from functools import partial +from io import BytesIO +from io import StringIO + +import pytest + +from itsdangerous.exc import BadPayload +from itsdangerous.exc import BadSignature +from itsdangerous.serializer import Serializer + + +def coerce_str(ref, s): + if not isinstance(s, type(ref)): + return s.encode("utf8") + + return s + + +class TestSerializer(object): + @pytest.fixture(params=(Serializer, partial(Serializer, serializer=pickle))) + def serializer_factory(self, request): + return partial(request.param, secret_key="secret_key") + + @pytest.fixture() + def serializer(self, serializer_factory): + return serializer_factory() + + @pytest.fixture() + def value(self): + return {"id": 42} + + @pytest.mark.parametrize( + "value", (None, True, "str", u"text", [1, 2, 3], {"id": 42}) + ) + def test_serializer(self, serializer, value): + assert serializer.loads(serializer.dumps(value)) == value + + @pytest.mark.parametrize( + "transform", + ( + lambda s: s.upper(), + lambda s: s + coerce_str(s, "a"), + lambda s: coerce_str(s, "a") + s[1:], + lambda s: s.replace(coerce_str(s, "."), coerce_str(s, "")), + ), + ) + def test_changed_value(self, serializer, value, transform): + signed = serializer.dumps(value) + assert serializer.loads(signed) == value + changed = transform(signed) + + with pytest.raises(BadSignature): + serializer.loads(changed) + + def test_bad_signature_exception(self, serializer, value): + bad_signed = serializer.dumps(value)[:-1] + + with pytest.raises(BadSignature) as exc_info: + serializer.loads(bad_signed) + + assert serializer.load_payload(exc_info.value.payload) == value + + def test_bad_payload_exception(self, serializer, value): + original = serializer.dumps(value) + payload = original.rsplit(coerce_str(original, "."), 1)[0] + bad = serializer.make_signer().sign(payload[:-1]) + + with pytest.raises(BadPayload) as exc_info: + serializer.loads(bad) + + assert exc_info.value.original_error is not None + + def test_loads_unsafe(self, serializer, value): + signed = serializer.dumps(value) + assert serializer.loads_unsafe(signed) == (True, value) + + bad_signed = signed[:-1] + assert serializer.loads_unsafe(bad_signed) == (False, value) + + payload = signed.rsplit(coerce_str(signed, "."), 1)[0] + bad_payload = serializer.make_signer().sign(payload[:-1])[:-1] + assert serializer.loads_unsafe(bad_payload) == (False, None) + + class BadUnsign(serializer.signer): + def unsign(self, signed_value, *args, **kwargs): + try: + return super(BadUnsign, self).unsign(signed_value, *args, **kwargs) + except BadSignature as e: + e.payload = None + raise + + serializer.signer = BadUnsign + assert serializer.loads_unsafe(bad_signed) == (False, None) + + def test_file(self, serializer, value): + f = BytesIO() if isinstance(serializer.dumps(value), bytes) else StringIO() + serializer.dump(value, f) + f.seek(0) + assert serializer.load(f) == value + f.seek(0) + assert serializer.load_unsafe(f) == (True, value) + + def test_alt_salt(self, serializer, value): + signed = serializer.dumps(value, salt="other") + + with pytest.raises(BadSignature): + serializer.loads(signed) + + assert serializer.loads(signed, salt="other") == value + + def test_signer_cls(self, serializer_factory, serializer, value): + class Other(serializer.signer): + default_key_derivation = "hmac" + + other = serializer_factory(signer=Other) + assert other.loads(other.dumps(value)) == value + assert other.dumps(value) != serializer.dumps(value) + + def test_signer_kwargs(self, serializer_factory, serializer, value): + other = serializer_factory(signer_kwargs={"key_derivation": "hmac"}) + assert other.loads(other.dumps(value)) == value + assert other.dumps("value") != serializer.dumps("value") + + def test_serializer_kwargs(self, serializer_factory): + serializer = serializer_factory(serializer_kwargs={"skipkeys": True}) + + try: + serializer.serializer.dumps(None, skipkeys=True) + except TypeError: + return + + assert serializer.loads(serializer.dumps({(): 1})) == {} diff --git a/tests/test_signer.py b/tests/test_signer.py new file mode 100644 index 0000000..5f7fe8e --- /dev/null +++ b/tests/test_signer.py @@ -0,0 +1,99 @@ +import hashlib +from functools import partial + +import pytest + +from itsdangerous.exc import BadSignature +from itsdangerous.signer import HMACAlgorithm +from itsdangerous.signer import NoneAlgorithm +from itsdangerous.signer import Signer +from itsdangerous.signer import SigningAlgorithm + + +class _ReverseAlgorithm(SigningAlgorithm): + def get_signature(self, key, value): + return (key + value)[::-1] + + +class TestSigner(object): + @pytest.fixture() + def signer_factory(self): + return partial(Signer, secret_key="secret-key") + + @pytest.fixture() + def signer(self, signer_factory): + return signer_factory() + + def test_signer(self, signer): + signed = signer.sign("my string") + assert isinstance(signed, bytes) + assert signer.validate(signed) + out = signer.unsign(signed) + assert out == b"my string" + + def test_no_separator(self, signer): + signed = signer.sign("my string") + signed = signed.replace(signer.sep, b"*", 1) + assert not signer.validate(signed) + + with pytest.raises(BadSignature): + signer.unsign(signed) + + def test_broken_signature(self, signer): + signed = signer.sign("b") + bad_signed = signed[:-1] + bad_sig = bad_signed.rsplit(b".", 1)[1] + assert not signer.verify_signature(b"b", bad_sig) + + with pytest.raises(BadSignature) as exc_info: + signer.unsign(bad_signed) + + assert exc_info.value.payload == b"b" + + def test_changed_value(self, signer): + signed = signer.sign("my string") + signed = signed.replace(b"my", b"other", 1) + assert not signer.validate(signed) + + with pytest.raises(BadSignature): + signer.unsign(signed) + + def test_invalid_separator(self, signer_factory): + with pytest.raises(ValueError) as exc_info: + signer_factory(sep="-") + + assert "separator cannot be used" in str(exc_info.value) + + @pytest.mark.parametrize( + "key_derivation", ("concat", "django-concat", "hmac", "none") + ) + def test_key_derivation(self, signer_factory, key_derivation): + signer = signer_factory(key_derivation=key_derivation) + assert signer.unsign(signer.sign("value")) == b"value" + + def test_invalid_key_derivation(self, signer_factory): + signer = signer_factory(key_derivation="invalid") + + with pytest.raises(TypeError): + signer.derive_key() + + def test_digest_method(self, signer_factory): + signer = signer_factory(digest_method=hashlib.md5) + assert signer.unsign(signer.sign("value")) == b"value" + + @pytest.mark.parametrize( + "algorithm", (None, NoneAlgorithm(), HMACAlgorithm(), _ReverseAlgorithm()) + ) + def test_algorithm(self, signer_factory, algorithm): + signer = signer_factory(algorithm=algorithm) + assert signer.unsign(signer.sign("value")) == b"value" + + if algorithm is None: + assert signer.algorithm.digest_method == signer.digest_method + + +def test_abstract_algorithm(): + alg = SigningAlgorithm() + + with pytest.raises(NotImplementedError): + alg.get_signature("a", "b") diff --git a/tests/test_timed.py b/tests/test_timed.py new file mode 100644 index 0000000..36c1a86 --- /dev/null +++ b/tests/test_timed.py @@ -0,0 +1,85 @@ +from datetime import datetime +from functools import partial + +import pytest +from freezegun import freeze_time +from tests.test_serializer import TestSerializer +from tests.test_signer import TestSigner + +from itsdangerous import Signer +from itsdangerous.exc import BadTimeSignature +from itsdangerous.exc import SignatureExpired +from itsdangerous.timed import TimedSerializer +from itsdangerous.timed import TimestampSigner + + +class FreezeMixin(object): + @pytest.fixture() + def ts(self): + return datetime(2011, 6, 24, 0, 9, 5) + + @pytest.fixture(autouse=True) + def freeze(self, ts): + with freeze_time(ts) as ft: + yield ft + + +class TestTimestampSigner(FreezeMixin, TestSigner): + @pytest.fixture() + def signer_factory(self): + return partial(TimestampSigner, secret_key="secret-key") + + def test_max_age(self, signer, ts, freeze): + signed = signer.sign("value") + freeze.tick() + assert signer.unsign(signed, max_age=10) == b"value" + freeze.tick(10) + + with pytest.raises(SignatureExpired) as exc_info: + signer.unsign(signed, max_age=10) + + assert exc_info.value.date_signed == ts + + def test_return_timestamp(self, signer, ts): + signed = signer.sign("value") + assert signer.unsign(signed, return_timestamp=True) == (b"value", ts) + + def test_timestamp_missing(self, signer): + other = Signer("secret-key") + signed = other.sign("value") + + with pytest.raises(BadTimeSignature) as exc_info: + signer.unsign(signed) + + assert "missing" in str(exc_info.value) + + def test_malformed_timestamp(self, signer): + other = Signer("secret-key") + signed = other.sign(b"value.____________") + + with pytest.raises(BadTimeSignature) as exc_info: + signer.unsign(signed) + + assert "Malformed" in str(exc_info.value) + + +class TestTimedSerializer(FreezeMixin, TestSerializer): + @pytest.fixture() + def serializer_factory(self): + return partial(TimedSerializer, secret_key="secret_key") + + def test_max_age(self, serializer, value, ts, freeze): + signed = serializer.dumps(value) + freeze.tick() + assert serializer.loads(signed, max_age=10) == value + freeze.tick(10) + + with pytest.raises(SignatureExpired) as exc_info: + serializer.loads(signed, max_age=10) + + assert exc_info.value.date_signed == ts + assert serializer.load_payload(exc_info.value.payload) == value + + def test_return_payload(self, serializer, value, ts): + signed = serializer.dumps(value) + assert serializer.loads(signed, return_timestamp=True) == (value, ts) diff --git a/tests/test_url_safe.py b/tests/test_url_safe.py new file mode 100644 index 0000000..5cb7f2c --- /dev/null +++ b/tests/test_url_safe.py @@ -0,0 +1,24 @@ +from functools import partial + +import pytest +from tests.test_serializer import TestSerializer +from tests.test_timed import TestTimedSerializer + +from itsdangerous import URLSafeSerializer +from itsdangerous import URLSafeTimedSerializer + + +class TestURLSafeSerializer(TestSerializer): + @pytest.fixture() + def serializer_factory(self): + return partial(URLSafeSerializer, secret_key="secret-key") + + @pytest.fixture(params=({"id": 42}, pytest.param("a" * 1000, id="zlib"))) + def value(self, request): + return request.param + + +class TestURLSafeTimedSerializer(TestURLSafeSerializer, TestTimedSerializer): + @pytest.fixture() + def serializer_factory(self): + return partial(URLSafeTimedSerializer, secret_key="secret-key") diff --git a/tox.ini b/tox.ini index a91897d..d64c2c9 100644 --- a/tox.ini +++ b/tox.ini @@ -9,7 +9,9 @@ skip_missing_interpreters = true [testenv] setenv = COVERAGE_FILE = .coverage.{envname} -deps = pytest-cov +deps = + pytest-cov + freezegun commands = pytest --cov --cov-report= {posargs} [testenv:stylecheck]