From 07bab02bc3a62c82103993cc1d214d6cd608ef6c Mon Sep 17 00:00:00 2001 From: Matus Valo Date: Wed, 15 Sep 2021 17:00:08 +0200 Subject: [PATCH] Remove dependency to case (#1389) * Remove dependency to case * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix flake8 errors * Remove unused code Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- conftest.py | 4 + requirements/test.txt | 1 - t/mocks.py | 26 +++- t/unit/conftest.py | 236 +++++++++++++++++++++++++++++++ t/unit/test_common.py | 3 +- t/unit/test_compression.py | 9 +- t/unit/test_mixins.py | 2 +- t/unit/test_serialization.py | 9 +- t/unit/transport/test_pyamqp.py | 5 +- t/unit/transport/test_redis.py | 95 +++++++------ t/unit/utils/test_amq_manager.py | 13 +- t/unit/utils/test_compat.py | 10 +- t/unit/utils/test_functional.py | 17 ++- 13 files changed, 343 insertions(+), 87 deletions(-) diff --git a/conftest.py b/conftest.py index ab0a5d90e..3fc0a687f 100644 --- a/conftest.py +++ b/conftest.py @@ -16,6 +16,10 @@ def pytest_configure(config): "markers", "env(name): mark test to run only on named environment", ) + config.addinivalue_line("markers", "replace_module_value") + config.addinivalue_line("markers", "masked_modules") + config.addinivalue_line("markers", "ensured_modules") + config.addinivalue_line("markers", "sleepdeprived_patched_module") def pytest_runtest_setup(item): diff --git a/requirements/test.txt b/requirements/test.txt index cb4823306..7566efae8 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -1,5 +1,4 @@ pytz>dev -case>=1.5.2 pytest~=6.2 pytest-sugar Pyro4 diff --git a/t/mocks.py b/t/mocks.py index d1bccf391..b02a34d6a 100644 --- a/t/mocks.py +++ b/t/mocks.py @@ -1,12 +1,34 @@ from itertools import count from unittest.mock import Mock -from case import ContextMock - from kombu.transport import base from kombu.utils import json +class _ContextMock(Mock): + """Dummy class implementing __enter__ and __exit__ + as the :keyword:`with` statement requires these to be implemented + in the class, not just the instance.""" + + def __enter__(self): + return self + + def __exit__(self, *exc_info): + pass + + +def ContextMock(*args, **kwargs): + """Mock that mocks :keyword:`with` statement contexts.""" + obj = _ContextMock(*args, **kwargs) + obj.attach_mock(_ContextMock(), '__enter__') + obj.attach_mock(_ContextMock(), '__exit__') + obj.__enter__.return_value = obj + # if __exit__ return a value the exception is ignored, + # so it must return None here. + obj.__exit__.return_value = None + return obj + + def PromiseMock(*args, **kwargs): m = Mock(*args, **kwargs) diff --git a/t/unit/conftest.py b/t/unit/conftest.py index d9cc02444..b798e3e59 100644 --- a/t/unit/conftest.py +++ b/t/unit/conftest.py @@ -1,11 +1,19 @@ import atexit +import builtins +import io import os import sys +import types +from unittest.mock import MagicMock import pytest from kombu.exceptions import VersionMismatch +_SIO_write = io.StringIO.write +_SIO_init = io.StringIO.__init__ +sentinel = object() + @pytest.fixture(scope='session') def multiprocessing_workaround(request): @@ -88,3 +96,231 @@ def cover_all_modules(): # so coverage sees all our modules. if is_in_coverage(): import_all_modules() + + +class WhateverIO(io.StringIO): + + def __init__(self, v=None, *a, **kw): + _SIO_init(self, v.decode() if isinstance(v, bytes) else v, *a, **kw) + + def write(self, data): + _SIO_write(self, data.decode() if isinstance(data, bytes) else data) + + +def noop(*args, **kwargs): + pass + + +def module_name(s): + if isinstance(s, bytes): + return s.decode() + return s + + +class _patching: + + def __init__(self, monkeypatch, request): + self.monkeypatch = monkeypatch + self.request = request + + def __getattr__(self, name): + return getattr(self.monkeypatch, name) + + def __call__(self, path, value=sentinel, name=None, + new=MagicMock, **kwargs): + value = self._value_or_mock(value, new, name, path, **kwargs) + self.monkeypatch.setattr(path, value) + return value + + def _value_or_mock(self, value, new, name, path, **kwargs): + if value is sentinel: + value = new(name=name or path.rpartition('.')[2]) + for k, v in kwargs.items(): + setattr(value, k, v) + return value + + def setattr(self, target, name=sentinel, value=sentinel, **kwargs): + # alias to __call__ with the interface of pytest.monkeypatch.setattr + if value is sentinel: + value, name = name, None + return self(target, value, name=name) + + def setitem(self, dic, name, value=sentinel, new=MagicMock, **kwargs): + # same as pytest.monkeypatch.setattr but default value is MagicMock + value = self._value_or_mock(value, new, name, dic, **kwargs) + self.monkeypatch.setitem(dic, name, value) + return value + + +class _stdouts: + + def __init__(self, stdout, stderr): + self.stdout = stdout + self.stderr = stderr + + +@pytest.fixture +def stdouts(): + """Override `sys.stdout` and `sys.stderr` with `StringIO` + instances. + Decorator example:: + @mock.stdouts + def test_foo(self, stdout, stderr): + something() + self.assertIn('foo', stdout.getvalue()) + Context example:: + with mock.stdouts() as (stdout, stderr): + something() + self.assertIn('foo', stdout.getvalue()) + """ + prev_out, prev_err = sys.stdout, sys.stderr + prev_rout, prev_rerr = sys.__stdout__, sys.__stderr__ + mystdout, mystderr = WhateverIO(), WhateverIO() + sys.stdout = sys.__stdout__ = mystdout + sys.stderr = sys.__stderr__ = mystderr + + try: + yield _stdouts(mystdout, mystderr) + finally: + sys.stdout = prev_out + sys.stderr = prev_err + sys.__stdout__ = prev_rout + sys.__stderr__ = prev_rerr + + +@pytest.fixture +def patching(monkeypatch, request): + """Monkeypath.setattr shortcut. + Example: + .. code-block:: python + def test_foo(patching): + # execv value here will be mock.MagicMock by default. + execv = patching('os.execv') + patching('sys.platform', 'darwin') # set concrete value + patching.setenv('DJANGO_SETTINGS_MODULE', 'x.settings') + # val will be of type mock.MagicMock by default + val = patching.setitem('path.to.dict', 'KEY') + """ + return _patching(monkeypatch, request) + + +@pytest.fixture +def sleepdeprived(request): + """Mock sleep method in patched module to do nothing. + + Example: + >>> import time + >>> @pytest.mark.sleepdeprived_patched_module(time) + >>> def test_foo(self, patched_module): + >>> pass + """ + module = request.node.get_closest_marker( + "sleepdeprived_patched_module").args[0] + old_sleep, module.sleep = module.sleep, noop + try: + yield + finally: + module.sleep = old_sleep + + +@pytest.fixture +def module_exists(request): + """Patch one or more modules to ensure they exist. + + A module name with multiple paths (e.g. gevent.monkey) will + ensure all parent modules are also patched (``gevent`` + + ``gevent.monkey``). + + Example: + >>> @pytest.mark.ensured_modules('gevent.monkey') + >>> def test_foo(self, module_exists): + ... pass + + """ + gen = [] + old_modules = [] + modules = request.node.get_closest_marker("ensured_modules").args + for module in modules: + if isinstance(module, str): + module = types.ModuleType(module_name(module)) + gen.append(module) + if module.__name__ in sys.modules: + old_modules.append(sys.modules[module.__name__]) + sys.modules[module.__name__] = module + name = module.__name__ + if '.' in name: + parent, _, attr = name.rpartition('.') + setattr(sys.modules[parent], attr, module) + try: + yield + finally: + for module in gen: + sys.modules.pop(module.__name__, None) + for module in old_modules: + sys.modules[module.__name__] = module + + +# Taken from +# http://bitbucket.org/runeh/snippets/src/tip/missing_modules.py +@pytest.fixture +def mask_modules(request): + """Ban some modules from being importable inside the context + + For example:: + + >>> @pytest.mark.masked_modules('gevent.monkey') + >>> def test_foo(self, mask_modules): + ... try: + ... import sys + ... except ImportError: + ... print('sys not found') + sys not found + """ + realimport = builtins.__import__ + modnames = request.node.get_closest_marker("masked_modules").args + + def myimp(name, *args, **kwargs): + if name in modnames: + raise ImportError('No module named %s' % name) + else: + return realimport(name, *args, **kwargs) + + builtins.__import__ = myimp + try: + yield + finally: + builtins.__import__ = realimport + + +@pytest.fixture +def replace_module_value(request): + """Mock module value, given a module, attribute name and value. + + Decorator example:: + + >>> @pytest.mark.replace_module_value(module, 'CONSTANT', 3.03) + >>> def test_foo(self, replace_module_value): + ... pass + """ + module = request.node.get_closest_marker("replace_module_value").args[0] + name = request.node.get_closest_marker("replace_module_value").args[1] + value = request.node.get_closest_marker("replace_module_value").args[2] + has_prev = hasattr(module, name) + prev = getattr(module, name, None) + if value: + setattr(module, name, value) + else: + try: + delattr(module, name) + except AttributeError: + pass + try: + yield + finally: + if prev is not None: + setattr(module, name, prev) + if not has_prev: + try: + delattr(module, name) + except AttributeError: + pass diff --git a/t/unit/test_common.py b/t/unit/test_common.py index 33e2571fa..c99bcb0df 100644 --- a/t/unit/test_common.py +++ b/t/unit/test_common.py @@ -3,13 +3,12 @@ import pytest from amqp import RecoverableConnectionError -from case import ContextMock from kombu import common from kombu.common import (PREFETCH_COUNT_MAX, Broadcast, QoS, collect_replies, declaration_cached, generate_oid, ignore_errors, maybe_declare, send_reply) -from t.mocks import MockPool +from t.mocks import ContextMock, MockPool def test_generate_oid(): diff --git a/t/unit/test_compression.py b/t/unit/test_compression.py index d62444af6..f1f426b74 100644 --- a/t/unit/test_compression.py +++ b/t/unit/test_compression.py @@ -1,7 +1,6 @@ import sys import pytest -from case import mock from kombu import compression @@ -71,8 +70,8 @@ def test_compress__decompress__zstd(self): d = compression.decompress(c, ctype) assert d == text - @mock.mask_modules('bz2') - def test_no_bz2(self): + @pytest.mark.masked_modules('bz2') + def test_no_bz2(self, mask_modules): c = sys.modules.pop('kombu.compression') try: import kombu.compression @@ -81,8 +80,8 @@ def test_no_bz2(self): if c is not None: sys.modules['kombu.compression'] = c - @mock.mask_modules('lzma') - def test_no_lzma(self): + @pytest.mark.masked_modules('lzma') + def test_no_lzma(self, mask_modules): c = sys.modules.pop('kombu.compression') try: import kombu.compression diff --git a/t/unit/test_mixins.py b/t/unit/test_mixins.py index a86c39787..04a56a6c0 100644 --- a/t/unit/test_mixins.py +++ b/t/unit/test_mixins.py @@ -2,9 +2,9 @@ from unittest.mock import Mock, patch import pytest -from case import ContextMock from kombu.mixins import ConsumerMixin +from t.mocks import ContextMock def Message(body, content_type='text/plain', content_encoding='utf-8'): diff --git a/t/unit/test_serialization.py b/t/unit/test_serialization.py index 29c9d8095..14952e5e9 100644 --- a/t/unit/test_serialization.py +++ b/t/unit/test_serialization.py @@ -4,7 +4,6 @@ from unittest.mock import call, patch import pytest -from case import mock import t.skip from kombu.exceptions import ContentDisallowed, DecodeError, EncodeError @@ -294,14 +293,14 @@ def test_raw_encode(self): 'application/data', 'binary', b'foo', ) - @mock.mask_modules('yaml') - def test_register_yaml__no_yaml(self): + @pytest.mark.masked_modules('yaml') + def test_register_yaml__no_yaml(self, mask_modules): register_yaml() with pytest.raises(SerializerNotInstalled): loads('foo', 'application/x-yaml', 'utf-8') - @mock.mask_modules('msgpack') - def test_register_msgpack__no_msgpack(self): + @pytest.mark.masked_modules('msgpack') + def test_register_msgpack__no_msgpack(self, mask_modules): register_msgpack() with pytest.raises(SerializerNotInstalled): loads('foo', 'application/x-msgpack', 'utf-8') diff --git a/t/unit/transport/test_pyamqp.py b/t/unit/transport/test_pyamqp.py index de5473332..d5f6d7e24 100644 --- a/t/unit/transport/test_pyamqp.py +++ b/t/unit/transport/test_pyamqp.py @@ -3,7 +3,6 @@ from unittest.mock import MagicMock, Mock, patch import pytest -from case import mock from kombu import Connection from kombu.transport import pyamqp @@ -133,8 +132,8 @@ def test_close_connection(self): assert connection.client is None connection.close.assert_called_with() - @mock.mask_modules('ssl') - def test_import_no_ssl(self): + @pytest.mark.masked_modules('ssl') + def test_import_no_ssl(self, mask_modules): pm = sys.modules.pop('amqp.connection') try: from amqp.connection import SSLError diff --git a/t/unit/transport/test_redis.py b/t/unit/transport/test_redis.py index 0841c41aa..4222d2448 100644 --- a/t/unit/transport/test_redis.py +++ b/t/unit/transport/test_redis.py @@ -7,13 +7,47 @@ from unittest.mock import ANY, Mock, call, patch import pytest -from case import ContextMock, mock from kombu import Connection, Consumer, Exchange, Producer, Queue from kombu.exceptions import InconsistencyError, VersionMismatch from kombu.transport import virtual from kombu.utils import eventio # patch poll from kombu.utils.json import dumps +from t.mocks import ContextMock + + +def _redis_modules(): + + class ConnectionError(Exception): + pass + + class AuthenticationError(Exception): + pass + + class InvalidData(Exception): + pass + + class InvalidResponse(Exception): + pass + + class ResponseError(Exception): + pass + + exceptions = types.ModuleType('redis.exceptions') + exceptions.ConnectionError = ConnectionError + exceptions.AuthenticationError = AuthenticationError + exceptions.InvalidData = InvalidData + exceptions.InvalidResponse = InvalidResponse + exceptions.ResponseError = ResponseError + + class Redis: + pass + + myredis = types.ModuleType('redis') + myredis.exceptions = exceptions + myredis.Redis = Redis + + return myredis, exceptions class _poll(eventio._select): @@ -980,8 +1014,8 @@ def setup(self): def teardown(self): self.connection.close() - @mock.replace_module_value(redis.redis, 'VERSION', [3, 0, 0]) - def test_publish__get_redispyv3(self): + @pytest.mark.replace_module_value(redis.redis, 'VERSION', [3, 0, 0]) + def test_publish__get_redispyv3(self, replace_module_value): channel = self.connection.channel() producer = Producer(channel, self.exchange, routing_key='test_Redis') self.queue(channel).declare() @@ -993,8 +1027,8 @@ def test_publish__get_redispyv3(self): assert self.queue(channel).get() is None assert self.queue(channel).get() is None - @mock.replace_module_value(redis.redis, 'VERSION', [2, 5, 10]) - def test_publish__get_redispyv2(self): + @pytest.mark.replace_module_value(redis.redis, 'VERSION', [2, 5, 10]) + def test_publish__get_redispyv2(self, replace_module_value): channel = self.connection.channel() producer = Producer(channel, self.exchange, routing_key='test_Redis') self.queue(channel).declare() @@ -1089,14 +1123,15 @@ def test_get__Empty(self): channel._get('does-not-exist') channel.close() - def test_get_client(self): - with mock.module_exists(*_redis_modules()): - conn = Connection(transport=Transport) - chan = conn.channel() - assert chan.Client - assert chan.ResponseError - assert conn.transport.connection_errors - assert conn.transport.channel_errors + @pytest.mark.ensured_modules(*_redis_modules()) + def test_get_client(self, module_exists): + # with module_exists(*_redis_modules()): + conn = Connection(transport=Transport) + chan = conn.channel() + assert chan.Client + assert chan.ResponseError + assert conn.transport.connection_errors + assert conn.transport.channel_errors def test_check_at_least_we_try_to_connect_and_fail(self): import redis @@ -1107,40 +1142,6 @@ def test_check_at_least_we_try_to_connect_and_fail(self): chan._size('some_queue') -def _redis_modules(): - - class ConnectionError(Exception): - pass - - class AuthenticationError(Exception): - pass - - class InvalidData(Exception): - pass - - class InvalidResponse(Exception): - pass - - class ResponseError(Exception): - pass - - exceptions = types.ModuleType('redis.exceptions') - exceptions.ConnectionError = ConnectionError - exceptions.AuthenticationError = AuthenticationError - exceptions.InvalidData = InvalidData - exceptions.InvalidResponse = InvalidResponse - exceptions.ResponseError = ResponseError - - class Redis: - pass - - myredis = types.ModuleType('redis') - myredis.exceptions = exceptions - myredis.Redis = Redis - - return myredis, exceptions - - class test_MultiChannelPoller: def setup(self): diff --git a/t/unit/utils/test_amq_manager.py b/t/unit/utils/test_amq_manager.py index aa3d7574c..ca6adb6e7 100644 --- a/t/unit/utils/test_amq_manager.py +++ b/t/unit/utils/test_amq_manager.py @@ -1,20 +1,19 @@ from unittest.mock import patch import pytest -from case import mock from kombu import Connection class test_get_manager: - @mock.mask_modules('pyrabbit') - def test_without_pyrabbit(self): + @pytest.mark.masked_modules('pyrabbit') + def test_without_pyrabbit(self, mask_modules): with pytest.raises(ImportError): Connection('amqp://').get_manager() - @mock.module_exists('pyrabbit') - def test_with_pyrabbit(self): + @pytest.mark.ensured_modules('pyrabbit') + def test_with_pyrabbit(self, module_exists): with patch('pyrabbit.Client', create=True) as Client: manager = Connection('amqp://').get_manager() assert manager is not None @@ -22,8 +21,8 @@ def test_with_pyrabbit(self): 'localhost:15672', 'guest', 'guest', ) - @mock.module_exists('pyrabbit') - def test_transport_options(self): + @pytest.mark.ensured_modules('pyrabbit') + def test_transport_options(self, module_exists): with patch('pyrabbit.Client', create=True) as Client: manager = Connection('amqp://', transport_options={ 'manager_hostname': 'admin.mq.vandelay.com', diff --git a/t/unit/utils/test_compat.py b/t/unit/utils/test_compat.py index a0bf18b00..d3159b766 100644 --- a/t/unit/utils/test_compat.py +++ b/t/unit/utils/test_compat.py @@ -3,7 +3,7 @@ import types from unittest.mock import Mock, patch -from case import mock +import pytest from kombu.utils import compat from kombu.utils.compat import entrypoints, maybe_fileno @@ -42,8 +42,8 @@ def test_detect_environment(self): finally: compat._environment = None - @mock.module_exists('eventlet', 'eventlet.patcher') - def test_detect_environment_eventlet(self): + @pytest.mark.ensured_modules('eventlet', 'eventlet.patcher') + def test_detect_environment_eventlet(self, module_exists): with patch('eventlet.patcher.is_monkey_patched', create=True) as m: assert sys.modules['eventlet'] m.return_value = True @@ -51,8 +51,8 @@ def test_detect_environment_eventlet(self): m.assert_called_with(socket) assert env == 'eventlet' - @mock.module_exists('gevent') - def test_detect_environment_gevent(self): + @pytest.mark.ensured_modules('gevent') + def test_detect_environment_gevent(self, module_exists): with patch('gevent.socket', create=True) as m: prev, socket.socket = socket.socket, m.socket try: diff --git a/t/unit/utils/test_functional.py b/t/unit/utils/test_functional.py index 42c53caf4..73a98e520 100644 --- a/t/unit/utils/test_functional.py +++ b/t/unit/utils/test_functional.py @@ -3,7 +3,6 @@ from unittest.mock import Mock import pytest -from case import mock from kombu.utils import functional as utils from kombu.utils.functional import (ChannelPromise, LRUCache, accepts_argument, @@ -182,8 +181,8 @@ def errback(self, exc, intervals, retries): assert interval == sleepvals[self.index] return interval - @mock.sleepdeprived(module=utils) - def test_simple(self): + @pytest.mark.sleepdeprived_patched_module(utils) + def test_simple(self, sleepdeprived): prev_count, utils.count = utils.count, Mock() try: utils.count.return_value = list(range(1)) @@ -217,8 +216,8 @@ def test_retry_timeout(self): errback=None, timeout=1, ) - @mock.sleepdeprived(module=utils) - def test_retry_zero(self): + @pytest.mark.sleepdeprived_patched_module(utils) + def test_retry_zero(self, sleepdeprived): with pytest.raises(self.Predicate): retry_over_time( self.myfun, self.Predicate, @@ -232,8 +231,8 @@ def test_retry_zero(self): max_retries=0, errback=None, interval_max=14, ) - @mock.sleepdeprived(module=utils) - def test_retry_once(self): + @pytest.mark.sleepdeprived_patched_module(utils) + def test_retry_once(self, sleepdeprived): with pytest.raises(self.Predicate): retry_over_time( self.myfun, self.Predicate, @@ -247,8 +246,8 @@ def test_retry_once(self): max_retries=1, errback=None, interval_max=14, ) - @mock.sleepdeprived(module=utils) - def test_retry_always(self): + @pytest.mark.sleepdeprived_patched_module(utils) + def test_retry_always(self, sleepdeprived): Predicate = self.Predicate class Fun: