diff --git a/pyproject.toml b/pyproject.toml index cbf07d3..c46036c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,7 @@ dev = [ "pytest>=6.1.1,<8.2.0", "tornado>=4.0,<7.0; python_version>='3.12'", "tornado>=4.0,<6.0; python_version<'3.12'", + "multiprocess>=0.70.12.2", ] tornado = [ diff --git a/tests/test_all_protocols_binary_field.py b/tests/test_all_protocols_binary_field.py index e75e141..df25df8 100644 --- a/tests/test_all_protocols_binary_field.py +++ b/tests/test_all_protocols_binary_field.py @@ -3,7 +3,7 @@ import sys import time import traceback -from multiprocessing import Process +from multiprocess import Process import pytest import six @@ -26,8 +26,6 @@ from thriftpy2.transport import TBufferedTransportFactory, TCyMemoryBuffer -if sys.platform == "win32": - pytest.skip("requires fork", allow_module_level=True) protocols = [TApacheJSONProtocolFactory, @@ -72,6 +70,11 @@ def test(t): trans_factory = TBufferedTransportFactory def run_server(): + import thriftpy2 + test_thrift = thriftpy2.load( + "apache_json_test.thrift", + module_name="test_thrift" + ) server = server_func[0]( test_thrift.TestService, handler=Handler(), @@ -161,11 +164,16 @@ def test_exceptions(server_func, proto_factory): ) TestException = test_thrift.TestException - class Handler(object): - def do_error(self, arg): - raise TestException(message=arg) - def do_server(): + import thriftpy2 + test_thrift = thriftpy2.load( + "apache_json_test.thrift", + module_name="test_thrift" + ) + TestException = test_thrift.TestException + class Handler(object): + def do_error(self, arg): + raise TestException(message=arg) server = server_func[0]( service=test_thrift.TestService, handler=Handler(), @@ -237,6 +245,8 @@ def test(t): trans_factory = TBufferedTransportFactory def run_server(): + import thriftpy2 + spec = thriftpy2.load("bin_test.thrift", module_name="bin_thrift") server = make_rpc_server( spec.BinService, handler=Handler(), diff --git a/tests/test_http.py b/tests/test_http.py index 6d7493a..a9bc7a2 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -3,9 +3,8 @@ from __future__ import absolute_import import os -import multiprocessing +import multiprocess import socket -import sys import time import uuid @@ -22,10 +21,6 @@ "addressbook.thrift")) -if sys.platform == "win32": - pytest.skip("requires fork", allow_module_level=True) - - class Dispatcher(): def __init__(self): self.ab = addressbook.AddressBook() @@ -79,9 +74,12 @@ def get_headers(self): @pytest.fixture(scope="module") def server(request): + import thriftpy2 + addressbook = thriftpy2.load(os.path.join(os.path.dirname(__file__), + "addressbook.thrift")) server = make_server(addressbook.AddressBookService, Dispatcher(), host="127.0.0.1", port=6080) - ps = multiprocessing.Process(target=server.serve) + ps = multiprocess.Process(target=server.serve) ps.start() time.sleep(0.1) diff --git a/tests/test_multiplexed.py b/tests/test_multiplexed.py index 3f38abf..7a2a8b0 100644 --- a/tests/test_multiplexed.py +++ b/tests/test_multiplexed.py @@ -1,10 +1,10 @@ # -*- coding: utf-8 -*- from __future__ import absolute_import +import sys -import multiprocessing +import multiprocess import os -import sys import time import pytest @@ -20,10 +20,6 @@ from thriftpy2.transport import TBufferedTransportFactory, TServerSocket -if sys.platform == "win32": - pytest.skip("requires fork", allow_module_level=True) - - mux = thriftpy2.load(os.path.join(os.path.dirname(__file__), "multiplexed.thrift")) sock_path = "/tmp/thriftpy_test.sock" @@ -51,7 +47,7 @@ def server(request): _server = TThreadedServer(mux_proc, TServerSocket(unix_socket=sock_path), iprot_factory=TBinaryProtocolFactory(), itrans_factory=TBufferedTransportFactory()) - ps = multiprocessing.Process(target=_server.serve) + ps = multiprocess.Process(target=_server.serve) ps.start() time.sleep(0.1) @@ -82,7 +78,7 @@ def client_two(timeout=3000): socket_timeout=timeout, connect_timeout=timeout, proto_factory=multiplexing_factory) - +@pytest.mark.skipif(sys.platform == "win32", reason="Unix domain socket required") def test_multiplexed_server(server): with client_one() as c: assert c.doThingOne() is True diff --git a/tests/test_oneway.py b/tests/test_oneway.py index fcf88ba..c33658f 100644 --- a/tests/test_oneway.py +++ b/tests/test_oneway.py @@ -3,15 +3,11 @@ import pytest -import multiprocessing +import multiprocess import thriftpy2 from thriftpy2.rpc import make_client, make_server -if sys.platform == "win32": - pytest.skip("requires fork", allow_module_level=True) - - class Dispatcher(object): def Test(self, req): print("Get req msg: %s" % req) @@ -24,9 +20,8 @@ class TestOneway(object): oneway_thrift = thriftpy2.load("oneway.thrift") def setup_class(self): - ctx = multiprocessing.get_context("fork") server = make_server(self.oneway_thrift.echo, Dispatcher(), '127.0.0.1', 6000) - self.p = ctx.Process(target=server.serve) + self.p = multiprocess.Process(target=server.serve) self.p.start() time.sleep(1) # Wait a second for server to start. diff --git a/tests/test_tracking.py b/tests/test_tracking.py index 9d71d4d..a971379 100644 --- a/tests/test_tracking.py +++ b/tests/test_tracking.py @@ -18,12 +18,11 @@ from __future__ import absolute_import import contextlib -import multiprocessing +import multiprocess import os import pickle import random import socket -import sys import tempfile import time @@ -56,11 +55,7 @@ except ImportError: pass else: - cleanup_on_sigterm() - - -if sys.platform == "win32": - pytest.skip("requires fork", allow_module_level=True) + cleanup_on_sigterm() addressbook = thriftpy2.load(os.path.join(os.path.dirname(__file__), @@ -191,7 +186,7 @@ def gen_server(port, tracker=tracker, processor=TTrackedProcessor): server = TSampleServer(processor, server_socket, prot_factory=TBinaryProtocolFactory(), trans_factory=TBufferedTransportFactory()) - ps = multiprocessing.Process(target=server.serve) + ps = multiprocess.Process(target=server.serve) ps.start() return ps, server diff --git a/tox.ini b/tox.ini index 1acd731..b24d3c6 100644 --- a/tox.ini +++ b/tox.ini @@ -18,6 +18,7 @@ deps = tornado>=4.0,<6.0 cython py35,py36,py37,py38,py39,pypy3,coverage: pytest_asyncio + multiprocess [testenv:flake8] deps =