diff --git a/pytest_socket.py b/pytest_socket.py index bd19573..5e4ca3c 100644 --- a/pytest_socket.py +++ b/pytest_socket.py @@ -34,6 +34,12 @@ def pytest_addoption(parser): metavar='ALLOWED_HOSTS_CSV', help='Only allow specified hosts through socket.socket.connect((host, port)).' ) + group.addoption( + '--allow-unix-socket', + action='store_true', + dest='allow_unix_socket', + help='Allow calls if they are to Unix domain sockets' + ) @pytest.fixture(autouse=True) @@ -55,31 +61,40 @@ def _socket_marker(request): @pytest.fixture -def socket_disabled(): +def socket_disabled(pytestconfig): """ disable socket.socket for duration of this test function """ - disable_socket() + allow_unix_socket = pytestconfig.getoption('--allow-unix-socket') + disable_socket(allow_unix_socket) yield enable_socket() @pytest.fixture -def socket_enabled(): +def socket_enabled(pytestconfig): """ enable socket.socket for duration of this test function """ enable_socket() yield - disable_socket() + allow_unix_socket = pytestconfig.getoption('--allow-unix-socket') + disable_socket(allow_unix_socket) -def disable_socket(): +def disable_socket(allow_unix_socket=False): """ disable socket.socket to disable the Internet. useful in testing. """ class GuardedSocket(socket.socket): """ socket guard to disable socket creation (from pytest-socket) """ def __new__(cls, *args, **kwargs): - if args[0] != socket.AddressFamily.AF_UNIX: - raise SocketBlockedError() - return super().__new__(cls, *args, **kwargs) + try: + is_unix_socket = args[0] == socket.AF_UNIX + except AttributeError: + # AF_UNIX not supported on Windows https://bugs.python.org/issue33408 + is_unix_socket = False + + if is_unix_socket and allow_unix_socket: + return super().__new__(cls, *args, **kwargs) + + raise SocketBlockedError() socket.socket = GuardedSocket diff --git a/tests/test_socket.py b/tests/test_socket.py index acccfce..83942d9 100644 --- a/tests/test_socket.py +++ b/tests/test_socket.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- import pytest +import socket from pytest_socket import enable_socket @@ -228,3 +229,29 @@ class MySocket(socket.socket): """) result = testdir.runpytest("--verbose") assert_socket_blocked(result) + + +@pytest.mark.skipif(not hasattr(socket, "AF_UNIX"), reason="Skip any platform that does not support AF_UNIX") +def test_unix_domain_sockets_blocked_with_disable_socket(testdir): + testdir.makepyfile(""" + import socket + + def test_unix_socket(): + socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + """) + result = testdir.runpytest("--verbose", "--disable-socket") + assert_socket_blocked(result) + + +@pytest.mark.skipif(not hasattr(socket, "AF_UNIX"), reason="Skip any platform that does not support AF_UNIX") +def test_enabling_unix_domain_sockets_with_disable_socket(testdir): + testdir.makepyfile(""" + import socket + + def test_inet(): + socket.socket(socket.AF_INET, socket.SOCK_STREAM) + def test_unix_socket(): + socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + """) + result = testdir.runpytest("--verbose", "--disable-socket", "--allow-unix-socket") + result.assert_outcomes(passed=1, skipped=0, failed=1)