From 5067c81dea74ba633aa952cc5a8cace7bf69da81 Mon Sep 17 00:00:00 2001 From: Joe Tsoi Date: Thu, 25 Feb 2021 11:05:33 +0000 Subject: [PATCH 1/2] Allow AF_UNIX sockets in GuardedSocket asyncio uses unix domain sockets, as we're mainly interested in blocking external requests, allowing AF_UNIX through should be fine. --- pytest_socket.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytest_socket.py b/pytest_socket.py index 0c8dd63..bd19573 100644 --- a/pytest_socket.py +++ b/pytest_socket.py @@ -77,7 +77,9 @@ def disable_socket(): class GuardedSocket(socket.socket): """ socket guard to disable socket creation (from pytest-socket) """ def __new__(cls, *args, **kwargs): - raise SocketBlockedError() + if args[0] != socket.AddressFamily.AF_UNIX: + raise SocketBlockedError() + return super().__new__(cls, *args, **kwargs) socket.socket = GuardedSocket From 197d3e35fef7b1ceb4eaa4309c60d074146b07ff Mon Sep 17 00:00:00 2001 From: Joe Tsoi Date: Fri, 5 Mar 2021 11:37:31 +0000 Subject: [PATCH 2/2] Add --allow-unix-socket config option --- pytest_socket.py | 31 +++++++++++++++++++++++-------- tests/test_socket.py | 27 +++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 8 deletions(-) diff --git a/pytest_socket.py b/pytest_socket.py index bd19573..5e4ca3c 100644 --- a/pytest_socket.py +++ b/pytest_socket.py @@ -34,6 +34,12 @@ def pytest_addoption(parser): metavar='ALLOWED_HOSTS_CSV', help='Only allow specified hosts through socket.socket.connect((host, port)).' ) + group.addoption( + '--allow-unix-socket', + action='store_true', + dest='allow_unix_socket', + help='Allow calls if they are to Unix domain sockets' + ) @pytest.fixture(autouse=True) @@ -55,31 +61,40 @@ def _socket_marker(request): @pytest.fixture -def socket_disabled(): +def socket_disabled(pytestconfig): """ disable socket.socket for duration of this test function """ - disable_socket() + allow_unix_socket = pytestconfig.getoption('--allow-unix-socket') + disable_socket(allow_unix_socket) yield enable_socket() @pytest.fixture -def socket_enabled(): +def socket_enabled(pytestconfig): """ enable socket.socket for duration of this test function """ enable_socket() yield - disable_socket() + allow_unix_socket = pytestconfig.getoption('--allow-unix-socket') + disable_socket(allow_unix_socket) -def disable_socket(): +def disable_socket(allow_unix_socket=False): """ disable socket.socket to disable the Internet. useful in testing. """ class GuardedSocket(socket.socket): """ socket guard to disable socket creation (from pytest-socket) """ def __new__(cls, *args, **kwargs): - if args[0] != socket.AddressFamily.AF_UNIX: - raise SocketBlockedError() - return super().__new__(cls, *args, **kwargs) + try: + is_unix_socket = args[0] == socket.AF_UNIX + except AttributeError: + # AF_UNIX not supported on Windows https://bugs.python.org/issue33408 + is_unix_socket = False + + if is_unix_socket and allow_unix_socket: + return super().__new__(cls, *args, **kwargs) + + raise SocketBlockedError() socket.socket = GuardedSocket diff --git a/tests/test_socket.py b/tests/test_socket.py index acccfce..83942d9 100644 --- a/tests/test_socket.py +++ b/tests/test_socket.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- import pytest +import socket from pytest_socket import enable_socket @@ -228,3 +229,29 @@ class MySocket(socket.socket): """) result = testdir.runpytest("--verbose") assert_socket_blocked(result) + + +@pytest.mark.skipif(not hasattr(socket, "AF_UNIX"), reason="Skip any platform that does not support AF_UNIX") +def test_unix_domain_sockets_blocked_with_disable_socket(testdir): + testdir.makepyfile(""" + import socket + + def test_unix_socket(): + socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + """) + result = testdir.runpytest("--verbose", "--disable-socket") + assert_socket_blocked(result) + + +@pytest.mark.skipif(not hasattr(socket, "AF_UNIX"), reason="Skip any platform that does not support AF_UNIX") +def test_enabling_unix_domain_sockets_with_disable_socket(testdir): + testdir.makepyfile(""" + import socket + + def test_inet(): + socket.socket(socket.AF_INET, socket.SOCK_STREAM) + def test_unix_socket(): + socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + """) + result = testdir.runpytest("--verbose", "--disable-socket", "--allow-unix-socket") + result.assert_outcomes(passed=1, skipped=0, failed=1)