Skip to content

Commit

Permalink
Add --allow-unix-socket config option
Browse files Browse the repository at this point in the history
  • Loading branch information
joetsoi committed Mar 23, 2021
1 parent 5067c81 commit 01e04fc
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 8 deletions.
31 changes: 23 additions & 8 deletions pytest_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ def pytest_addoption(parser):
metavar='ALLOWED_HOSTS_CSV',
help='Only allow specified hosts through socket.socket.connect((host, port)).'
)
group.addoption(
'--allow-unix-socket',
action='store_true',
dest='allow_unix_socket',
help='Allow calls if they are to Unix domain sockets'
)


@pytest.fixture(autouse=True)
Expand All @@ -55,31 +61,40 @@ def _socket_marker(request):


@pytest.fixture
def socket_disabled():
def socket_disabled(pytestconfig):
""" disable socket.socket for duration of this test function """
disable_socket()
allow_unix_socket = pytestconfig.getoption('--allow-unix-socket')
disable_socket(allow_unix_socket)
yield
enable_socket()


@pytest.fixture
def socket_enabled():
def socket_enabled(pytestconfig):
""" enable socket.socket for duration of this test function """
enable_socket()
yield
disable_socket()
allow_unix_socket = pytestconfig.getoption('--allow-unix-socket')
disable_socket(allow_unix_socket)


def disable_socket():
def disable_socket(allow_unix_socket=False):
""" disable socket.socket to disable the Internet. useful in testing.
"""

class GuardedSocket(socket.socket):
""" socket guard to disable socket creation (from pytest-socket) """
def __new__(cls, *args, **kwargs):
if args[0] != socket.AddressFamily.AF_UNIX:
raise SocketBlockedError()
return super().__new__(cls, *args, **kwargs)
try:
is_unix_socket = args[0] == socket.AF_UNIX
except AttributeError:
# AF_UNIX not supported on Windows https://bugs.python.org/issue33408
is_unix_socket = False

if is_unix_socket and allow_unix_socket:
return super().__new__(cls, *args, **kwargs)

raise SocketBlockedError()

socket.socket = GuardedSocket

Expand Down
41 changes: 41 additions & 0 deletions tests/test_socket.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
import pytest
import socket

from pytest_socket import enable_socket

Expand Down Expand Up @@ -228,3 +229,43 @@ class MySocket(socket.socket):
""")
result = testdir.runpytest("--verbose")
assert_socket_blocked(result)


@pytest.mark.skipif(not hasattr(socket, "AF_UNIX"), reason="Skip any platform that does not support AF_UNIX")
def test_unix_domain_sockets_blocked_with_disable_socket(testdir):
testdir.makepyfile("""
import socket
def test_unix_socket():
socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
""")
result = testdir.runpytest("--verbose", "--disable-socket")
assert_socket_blocked(result)


@pytest.mark.skipif(not hasattr(socket, "AF_UNIX"), reason="Skip any platform that does not support AF_UNIX")
def test_enabling_unix_domain_sockets_with_disable_socket(testdir):
testdir.makepyfile("""
import socket
def test_inet():
socket.socket(socket.AF_INET, socket.SOCK_STREAM)
def test_unix_socket():
socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
""")
result = testdir.runpytest("--verbose", "--disable-socket", "--allow-unix-socket")
result.assert_outcomes(passed=1, skipped=0, failed=1)


def test_disable_socket_with_allowed_hosts(testdir):
testdir.makepyfile("""
import socket
import pytest
@pytest.mark.allowed_hosts("localhost")
def test_allowed_hosts():
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.conect(("localhost", 80))
""")
result = testdir.runpytest("--verbose", "--disable-socket")
result.assert_outcomes(passed=1, skipped=0, failed=0)

0 comments on commit 01e04fc

Please sign in to comment.