diff --git a/thriftpy/rpc.py b/thriftpy/rpc.py index 389ad76..2d831c7 100644 --- a/thriftpy/rpc.py +++ b/thriftpy/rpc.py @@ -2,11 +2,12 @@ from __future__ import absolute_import +import signal import contextlib import warnings from thriftpy.protocol import TBinaryProtocolFactory -from thriftpy.server import TThreadedServer +from thriftpy.server import TThreadedServer, TProcessPoolServer from thriftpy.thrift import TProcessor, TClient from thriftpy.transport import ( TBufferedTransportFactory, @@ -43,11 +44,18 @@ def make_client(service, host="localhost", port=9090, unix_socket=None, return TClient(service, protocol) +def _init_handler(): + signal.signal(signal.SIGINT, signal.SIG_DFL) + signal.signal(signal.SIGTERM, signal.SIG_DFL) + signal.signal(signal.SIGHUP, signal.SIG_DFL) + + def make_server(service, handler, host="localhost", port=9090, unix_socket=None, proto_factory=TBinaryProtocolFactory(), trans_factory=TBufferedTransportFactory(), - client_timeout=3000, certfile=None): + client_timeout=3000, certfile=None, + num_workers=None): processor = TProcessor(service, handler) if unix_socket: @@ -65,9 +73,16 @@ def make_server(service, handler, else: raise ValueError("Either host/port or unix_socket must be provided.") - server = TThreadedServer(processor, server_socket, - iprot_factory=proto_factory, - itrans_factory=trans_factory) + if num_workers is None: + server = TThreadedServer(processor, server_socket, + iprot_factory=proto_factory, + itrans_factory=trans_factory) + else: + server = TProcessPoolServer(processor, server_socket, + iprot_factory=proto_factory, + itrans_factory=trans_factory) + server.setNumWorkers(num_workers) + server.setPostForkCallback(_init_handler) return server diff --git a/thriftpy/server.py b/thriftpy/server.py index 664e4ec..bb023a5 100644 --- a/thriftpy/server.py +++ b/thriftpy/server.py @@ -11,6 +11,7 @@ TTransportException ) +from multiprocessing import Process, Value, Condition, reduction logger = logging.getLogger(__name__) @@ -103,3 +104,105 @@ def handle(self, client): def close(self): self.closed = True + + +class TProcessPoolServer(TServer): + """Server with a fixed size pool of worker subprocesses to service requests + + Note that if you need shared state between the handlers - it's up to you! + Written by Dvir Volk, doat.com + """ + def __init__(self, *args, **kwargs): + self.daemon = kwargs.pop("daemon", False) + TServer.__init__(self, *args, **kwargs) + self.closed = False + + self.numWorkers = 1 + self.workers = [] + self.isRunning = Value('b', False) + self.stopCondition = Condition() + self.postForkCallback = None + + def setPostForkCallback(self, callback): + if not callable(callback): + raise TypeError("This is not a callback!") + self.postForkCallback = callback + + def setNumWorkers(self, num): + """Set the number of worker threads that should be created""" + self.numWorkers = num + + def workerProcess(self): + """Loop getting clients from the shared queue and process them""" + if self.postForkCallback: + self.postForkCallback() + + while self.isRunning.value: + try: + client = self.trans.accept() + if not client: + continue + self.serveClient(client) + except (KeyboardInterrupt, SystemExit): + return 0 + except Exception as x: + logger.exception(x) + + def serveClient(self, client): + """Process input/output from a client for as long as possible""" + + itrans = self.itrans_factory.get_transport(client) + otrans = self.otrans_factory.get_transport(client) + iprot = self.iprot_factory.get_protocol(itrans) + oprot = self.oprot_factory.get_protocol(otrans) + + try: + while True: + self.processor.process(iprot, oprot) + except TTransportException as tx: + pass + except Exception as x: + logger.exception(x) + + itrans.close() + otrans.close() + + def serve(self): + """Start workers and put into queue""" + # this is a shared state that can tell the workers to exit when False + self.isRunning.value = True + + # first bind and listen to the port + self.trans.listen() + + # fork the children + for i in range(self.numWorkers): + try: + w = Process(target=self.workerProcess) + w.daemon = True + w.start() + self.workers.append(w) + except Exception as x: + logger.exception(x) + + # wait until the condition is set by stop() + while True: + self.stopCondition.acquire() + try: + self.stopCondition.wait() + break + except (SystemExit, KeyboardInterrupt): + break + except Exception as x: + logger.exception(x) + + self.isRunning.value = False + + def stop(self): + self.isRunning.value = False + self.stopCondition.acquire() + self.stopCondition.notify() + self.stopCondition.release() + + def close(self): + self.closed = True