diff --git a/capnp/helpers/capabilityHelper.cpp b/capnp/helpers/capabilityHelper.cpp index fda7560..98cf950 100644 --- a/capnp/helpers/capabilityHelper.cpp +++ b/capnp/helpers/capabilityHelper.cpp @@ -79,10 +79,14 @@ ::kj::Promise> then(kj::Promise> pro } kj::Promise PythonInterfaceDynamicImpl::call(capnp::InterfaceSchema::Method method, - capnp::CallContext< capnp::DynamicStruct, capnp::DynamicStruct> context) { + capnp::CallContext< capnp::DynamicStruct, + capnp::DynamicStruct> context) { auto methodName = method.getProto().getName(); - kj::Promise * promise = call_server_method(py_server, const_cast(methodName.cStr()), context); + kj::Promise * promise = call_server_method(this->py_server->obj, + const_cast(methodName.cStr()), + context, + this->kj_loop->obj); check_py_error(); diff --git a/capnp/helpers/capabilityHelper.h b/capnp/helpers/capabilityHelper.h index d1d270a..3aabbba 100644 --- a/capnp/helpers/capabilityHelper.h +++ b/capnp/helpers/capabilityHelper.h @@ -72,17 +72,15 @@ ::kj::Promise> then(kj::Promise> pro class PythonInterfaceDynamicImpl final: public capnp::DynamicCapability::Server { public: - PyObject * py_server; + kj::Own py_server; + kj::Own kj_loop; - PythonInterfaceDynamicImpl(capnp::InterfaceSchema & schema, PyObject * _py_server) - : capnp::DynamicCapability::Server(schema), py_server(_py_server) { - GILAcquire gil; - Py_INCREF(_py_server); - } + PythonInterfaceDynamicImpl(capnp::InterfaceSchema & schema, + kj::Own _py_server, + kj::Own kj_loop) + : capnp::DynamicCapability::Server(schema), py_server(kj::mv(_py_server)), kj_loop(kj::mv(kj_loop)) { } ~PythonInterfaceDynamicImpl() { - GILAcquire gil; - Py_DECREF(py_server); } kj::Promise call(capnp::InterfaceSchema::Method method, diff --git a/capnp/includes/capnp_cpp.pxd b/capnp/includes/capnp_cpp.pxd index 8428987..da3b7c7 100644 --- a/capnp/includes/capnp_cpp.pxd +++ b/capnp/includes/capnp_cpp.pxd @@ -499,7 +499,7 @@ cdef extern from "capnp/helpers/capabilityHelper.h": Exception makeException(StringPtr message) PyPromise tryReadMessage(AsyncIoStream& stream, ReaderOptions opts) cppclass PythonInterfaceDynamicImpl: - PythonInterfaceDynamicImpl(InterfaceSchema&, PyObject *) + PythonInterfaceDynamicImpl(InterfaceSchema&, Own[PyRefCounter] server, Own[PyRefCounter] kj_loop) cdef extern from "capnp/serialize-async.h" namespace " ::capnp": VoidPromise writeMessage(AsyncIoStream& output, MessageBuilder& builder) diff --git a/capnp/lib/capnp.pxd b/capnp/lib/capnp.pxd index aed34d1..e3938be 100644 --- a/capnp/lib/capnp.pxd +++ b/capnp/lib/capnp.pxd @@ -159,7 +159,7 @@ cdef _setDynamicFieldStatic(DynamicStruct_Builder thisptr, field, value, parent) cdef api object wrap_dynamic_struct_reader(Response & r) with gil cdef api Promise[void] * call_server_method( - object server, char * _method_name, CallContext & _context) except * with gil + object server, char * _method_name, CallContext & _context, object kj_loop) except * with gil cdef api object wrap_kj_exception(capnp.Exception & exception) with gil cdef api object wrap_kj_exception_for_reraise(capnp.Exception & exception) with gil cdef api object get_exception_info(object exc_type, object exc_obj, object exc_tb) with gil diff --git a/capnp/lib/capnp.pyx b/capnp/lib/capnp.pyx index 9016f82..e639f66 100644 --- a/capnp/lib/capnp.pyx +++ b/capnp/lib/capnp.pyx @@ -41,6 +41,7 @@ import traceback as _traceback from types import ModuleType as _ModuleType from operator import attrgetter as _attrgetter from functools import partial as _partial +from contextlib import asynccontextmanager as _asynccontextmanager _CAPNP_VERSION_MAJOR = capnp.CAPNP_VERSION_MAJOR _CAPNP_VERSION_MINOR = capnp.CAPNP_VERSION_MINOR @@ -73,6 +74,9 @@ cdef class _VoidPromiseFulfiller: return self def void_task_done_callback(method_name, _VoidPromiseFulfiller fulfiller, task): + if fulfiller.fulfiller == NULL: + return + if task.cancelled(): fulfiller.fulfiller.reject(makeException(capnp.StringPtr( f"Server task for method {method_name} was cancelled"))) @@ -92,9 +96,12 @@ def void_task_done_callback(method_name, _VoidPromiseFulfiller fulfiller, task): fulfiller.fulfiller.fulfill() cdef api void promise_task_add_done_callback(object task, object callback, VoidPromiseFulfiller& fulfiller): - task.add_done_callback(_partial(callback, _VoidPromiseFulfiller()._init(&fulfiller))) + wrapper = _VoidPromiseFulfiller()._init(&fulfiller) + task.add_done_callback(_partial(callback, wrapper)) + task._fulfiller = wrapper cdef api void promise_task_cancel(object task): + (<_VoidPromiseFulfiller>task._fulfiller).fulfiller = NULL task.cancel() def fill_context(method_name, context, returned_data): @@ -113,42 +120,41 @@ def fill_context(method_name, context, returned_data): setattr(results, arg_name, arg_val) cdef api VoidPromise * call_server_method(object server, - char * _method_name, CallContext & _context) except * with gil: + char * _method_name, + CallContext & _context, + object _kj_loop) except * with gil: method_name = _method_name + kj_loop = <_EventLoop>_kj_loop + kj_loop.check() context = _CallContext()._init(_context) # TODO:MEMORY: invalidate this with promise chain func = getattr(server, method_name+'_context', None) if func is not None: ret = func(context) - if asyncio.iscoroutine(ret): - task = asyncio.create_task(ret) - callback = _partial(void_task_done_callback, method_name) - return new VoidPromise(helpers.taskToPromise( - capnp.heap[PyRefCounter](task), - callback)) - else: + if not asyncio.iscoroutine(ret): raise ValueError( "Server function ({}) is not a coroutine" .format(method_name, str(ret))) + task = asyncio.create_task(ret) else: - func = getattr(server, method_name) # will raise if no function found - params = context.params - params_dict = {name: getattr(params, name) for name in params.schema.fieldnames} - params_dict['_context'] = context - ret = func(**params_dict) - - if asyncio.iscoroutine(ret): - async def finalize(): - fill_context(method_name, context, await ret) - task = asyncio.create_task(finalize()) - callback = _partial(void_task_done_callback, method_name) - return new VoidPromise(helpers.taskToPromise( - capnp.heap[PyRefCounter](task), - callback)) - else: - raise ValueError( - "Server function ({}) is not a coroutine" - .format(method_name, str(ret))) + async def finalize(): + params = context.params + params_dict = {name: getattr(params, name) for name in params.schema.fieldnames} + params_dict['_context'] = context + func = getattr(server, method_name) # will raise if no function found + ret = func(**params_dict) + if not asyncio.iscoroutine(ret): + raise ValueError( + "Server function ({}) is not a coroutine" + .format(method_name, str(ret))) + fill_context(method_name, context, await ret) + task = asyncio.create_task(finalize()) + + kj_loop.active_tasks.add(task) + callback = _partial(void_task_done_callback, method_name) + return new VoidPromise(helpers.taskToPromise( + capnp.heap[PyRefCounter](task), + callback)) cdef extern from "" namespace " ::kj": @@ -702,7 +708,11 @@ cdef C_DynamicValue.Reader _extract_dynamic_client(_DynamicCapabilityClient valu cdef C_DynamicValue.Reader _extract_dynamic_server(object value): cdef _InterfaceSchema schema = value.schema - return C_DynamicValue.Reader(capnp.heap[PythonInterfaceDynamicImpl](schema.thisptr, value)) + kj_loop = C_DEFAULT_EVENT_LOOP_GETTER() + return C_DynamicValue.Reader(capnp.heap[PythonInterfaceDynamicImpl]( + schema.thisptr, + capnp.heap[PyRefCounter](value), + capnp.heap[PyRefCounter](kj_loop))) cdef C_DynamicValue.Reader _extract_dynamic_enum(_DynamicEnum value): @@ -1773,6 +1783,8 @@ cdef cppclass AsyncIoEventPort(EventPort): this.asyncioLoop = asyncioLoop __dealloc__(): + if this.runHandle is not None: + this.runHandle.cancel() del this.kjLoop cbool wait() except* with gil: @@ -1797,57 +1809,78 @@ cdef cppclass AsyncIoEventPort(EventPort): EventLoop *getKjLoop(): return this.kjLoop -def _asyncio_close_patch(loop, oldclose, _EventLoop kjloop): - # The purpose of patching the asyncio close() function is to set up the kj-loop to be closed as well. - # We replace the event loop getter with a weakref, such that it can be destroyed when all other - # references to it are gone. Then, if a new asyncio loop ever gets started, a new kj-loop can also be - # started. - _C_DEFAULT_EVENT_LOOP_LOCAL.loop = _weakref.ref(kjloop) - loop.close = oldclose - return oldclose() - cdef class _EventLoop: - cdef object __weakref__ # Needed to make this class weak-referenceable - cdef WaitScope* waitScope - cdef AsyncIoEventPort* customPort + cdef Own[WaitScope] wait_scope + cdef Own[AsyncIoEventPort] event_port + cdef object active_streams + cdef object active_rpcs + cdef object active_tasks + cdef cbool closed + + cdef _init(self, asyncio_loop) except +reraise_kj_exception: + self.event_port = capnp.heap[AsyncIoEventPort](asyncio_loop) + kj_loop = deref(self.event_port).getKjLoop() + self.wait_scope = capnp.heap[WaitScope](deref(kj_loop)) + self.active_streams = _weakref.WeakSet() + self.active_rpcs = _weakref.WeakSet() + self.active_tasks = _weakref.WeakSet() + self.closed = False + return self - def __init__(self): - self._init() + def __dealloc__(self): + self.close() - cdef _init(self) except +reraise_kj_exception: - loop = asyncio.get_running_loop() - self.customPort = new AsyncIoEventPort(loop) - kjLoop = self.customPort.getKjLoop() - self.waitScope = new WaitScope(deref(kjLoop)) - loop.close = _partial(_asyncio_close_patch, loop, loop.close, self) + cdef close(self): + if not self.closed: + self.closed = True + deref(self.event_port).kjLoop.run() + self.wait_scope = Own[WaitScope]() + self.event_port = Own[AsyncIoEventPort]() - def __dealloc__(self): - del self.waitScope - del self.customPort + cdef check(self): + if self.closed: + raise RuntimeError( + "The KJ event-loop is not running (on this thread). Please start it through 'capnp.kj_loop()'") + +@_asynccontextmanager +async def kj_loop(): + asyncio_loop = asyncio.get_running_loop() + if hasattr(asyncio_loop, '_kj_loop'): + raise RuntimeError("The KJ event-loop is already running (on this thread).") + cdef _EventLoop kj_loop = _EventLoop()._init(asyncio_loop) + asyncio_loop._kj_loop = kj_loop + try: + yield + finally: + # Close any asynciostream that has not been closed + for stream in list(kj_loop.active_streams): stream.close() + + # Shut down all the RPC clients and servers + for rpc in list(kj_loop.active_rpcs): rpc.close() + # Cancel any pending task that is a RPC call + # TODO: What if the cancellation is inhibited? + tasks = list(kj_loop.active_tasks) + for task in tasks: task.cancel() + await asyncio.gather(*tasks, return_exceptions=True) -_C_DEFAULT_EVENT_LOOP_LOCAL = _threading.local() + try: + del asyncio_loop._kj_loop + except AttributeError: pass + kj_loop.close() +async def run(coro): + async with kj_loop(): + return await coro cdef _EventLoop C_DEFAULT_EVENT_LOOP_GETTER(): - global C_DEFAULT_EVENT_LOOP_LOCAL - loop = getattr(_C_DEFAULT_EVENT_LOOP_LOCAL, 'loop', None) - if type(loop) is _EventLoop: - return loop - elif type(loop) is _weakref.ref: - loop = loop() - if loop is not None: - raise RuntimeError( - "The capnproto event loop associated to an already closed Python asyncio event loop is " + - "still running, because not all I/O events associated to it have terminated. If you wish " + - " to start a new loop, make sure that all previous events are cleaned up.") - else: - _C_DEFAULT_EVENT_LOOP_LOCAL.loop = _EventLoop() - return _C_DEFAULT_EVENT_LOOP_LOCAL.loop - else: - assert loop is None - _C_DEFAULT_EVENT_LOOP_LOCAL.loop = _EventLoop() - return _C_DEFAULT_EVENT_LOOP_LOCAL.loop + asyncio_loop = asyncio.get_running_loop() + kj_loop = getattr(asyncio_loop, '_kj_loop', None) + if kj_loop is None: + raise RuntimeError( + "The KJ event-loop is not running (on this thread). Please start it through 'capnp.kj_loop()'") + elif type(kj_loop) is _EventLoop: return kj_loop + else: raise RuntimeError("Someone meddled with the KJ event loop!") cdef class _CallContext: @@ -1882,6 +1915,7 @@ cdef class _CallContext: cdef _promise_to_asyncio(PyPromise promise): + C_DEFAULT_EVENT_LOOP_GETTER() # Make sure the event loop is running fut = asyncio.get_running_loop().create_future() def success(res): return fut.set_result(res) if not fut.cancelled() else None def exception(err): return fut.set_exception(err) if not fut.cancelled() else None @@ -1924,21 +1958,13 @@ cdef class _RemotePromise: raise KjException( "Promise was already used in a consuming operation. You can no longer use this Promise object") - async def a_wait(self): - """ - Asyncio version of wait(). - Required when using asyncio for socket communication. - - Will still work with non-asyncio socket communication, but requires async handling of the function call. - """ - self._check_consumed() - cdef Own[RemotePromise] thisptr = move(self.thisptr) - return await _promise_to_asyncio(helpers.convert_to_pypromise(move(deref(thisptr)))) - def __await__(self): self._check_consumed() cdef Own[RemotePromise] thisptr = move(self.thisptr) - return _promise_to_asyncio(helpers.convert_to_pypromise(move(deref(thisptr)))).__await__() + return _promise_to_asyncio( + helpers.convert_to_pypromise(move(deref(thisptr))) + .attach(capnp.heap[PyRefCounter](self._parent)) + ).__await__() cpdef _get(self, field) except +reraise_kj_exception: self._check_consumed() @@ -1992,6 +2018,7 @@ cdef class _Request(_DynamicStructBuilder): del self.thisptr_child cpdef send(self): + C_DEFAULT_EVENT_LOOP_GETTER() # Make sure the event loop is running if self.is_consumed: raise KjException('Request has already been sent. You can only send a request once.') self.is_consumed = True @@ -2038,8 +2065,12 @@ cdef class _DynamicCapabilityClient: else: s = schema + kj_loop = C_DEFAULT_EVENT_LOOP_GETTER() self.thisptr = C_DynamicCapability.Client( - capnp.heap[PythonInterfaceDynamicImpl](s.thisptr, server)) + capnp.heap[PythonInterfaceDynamicImpl]( + s.thisptr, + capnp.heap[PyRefCounter](server), + capnp.heap[PyRefCounter](kj_loop))) self._parent = server return self @@ -2074,6 +2105,7 @@ cdef class _DynamicCapabilityClient: cpdef _send_helper(self, name, word_count, args, kwargs) except +reraise_kj_exception: # if word_count is None: # word_count = 0 + C_DEFAULT_EVENT_LOOP_GETTER() # Make sure the event loop is running cdef Request * request = new Request(self.thisptr.newRequest(name)) # TODO: pass word_count self._set_fields(request, name, args, kwargs) @@ -2162,6 +2194,9 @@ cdef class _TwoPartyVatNetwork: cdef Own[C_TwoPartyVatNetwork] thisptr cdef _AsyncIoStream stream + def close(self): + self.thisptr = Own[C_TwoPartyVatNetwork]() + cdef _init(self, _AsyncIoStream stream, Side side, schema_cpp.ReaderOptions opts): self.stream = stream self.thisptr = capnp.heap[C_TwoPartyVatNetwork](deref(stream.thisptr), side, opts) @@ -2179,16 +2214,24 @@ cdef class TwoPartyClient: :param traversal_limit_in_words: Pointer derefence limit (see https://capnproto.org/cxx.html). :param nesting_limit: Recursive limit when reading types (see https://capnproto.org/cxx.html). """ + cdef object __weakref__ # Needed to make this class weak-referenceable cdef Own[RpcSystem] thisptr cdef _TwoPartyVatNetwork _network + cdef cbool closed def __dealloc__(self): # Needed to make Python 3.7 happy, which seems to have trouble deallocating stack objects # appropriately self.thisptr = Own[RpcSystem]() - def __init__(self, socket=None, traversal_limit_in_words=None, nesting_limit=None): + def close(self): + self.closed = True + self.thisptr = Own[RpcSystem]() + self._network.close() + def __init__(self, socket=None, traversal_limit_in_words=None, nesting_limit=None): + cdef _EventLoop loop = C_DEFAULT_EVENT_LOOP_GETTER() + loop.active_rpcs.add(self) cdef schema_cpp.ReaderOptions opts = make_reader_opts(traversal_limit_in_words, nesting_limit) if isinstance(socket, _AsyncIoStream): @@ -2199,9 +2242,13 @@ cdef class TwoPartyClient: self.thisptr = capnp.heap[RpcSystem](makeRpcClient(deref(self._network.thisptr))) cpdef bootstrap(self) except +reraise_kj_exception: + if self.closed: + raise RuntimeError("This client is closed") return _CapabilityClient()._init(helpers.bootstrapHelper(deref(self.thisptr)), self) cpdef on_disconnect(self) except +reraise_kj_exception: + if self.closed: + raise RuntimeError("This client is closed") return self._network.on_disconnect() @@ -2214,15 +2261,24 @@ cdef class TwoPartyServer: :param traversal_limit_in_words: Pointer derefence limit (see https://capnproto.org/cxx.html). :param nesting_limit: Recursive limit when reading types (see https://capnproto.org/cxx.html). """ + cdef object __weakref__ # Needed to make this class weak-referenceable cdef Own[RpcSystem] thisptr cdef _TwoPartyVatNetwork _network + cdef cbool closed def __dealloc__(self): # Needed to make Python 3.7 happy, which seems to have trouble deallocating stack objects # appropriately self.thisptr = Own[RpcSystem]() + def close(self): + self.closed = True + self.thisptr = Own[RpcSystem]() + self._network.close() + def __init__(self, socket=None, bootstrap=None, traversal_limit_in_words=None, nesting_limit=None): + cdef _EventLoop loop = C_DEFAULT_EVENT_LOOP_GETTER() + loop.active_rpcs.add(self) if not bootstrap: raise KjException("You must provide a bootstrap interface to a server constructor.") @@ -2236,25 +2292,57 @@ cdef class TwoPartyServer: self.thisptr = capnp.heap[RpcSystem](makeRpcServer( deref(self._network.thisptr), C_DynamicCapability.Client(capnp.heap[PythonInterfaceDynamicImpl]( - schema.thisptr, bootstrap)))) + schema.thisptr, + capnp.heap[PyRefCounter](bootstrap), + capnp.heap[PyRefCounter](loop))))) cpdef bootstrap(self) except +reraise_kj_exception: + if self.closed: + raise RuntimeError("This server is closed") return _CapabilityClient()._init(helpers.bootstrapHelperServer(deref(self.thisptr)), self) cpdef on_disconnect(self) except +reraise_kj_exception: + if self.closed: + raise RuntimeError("This server is closed") return _voidpromise_to_asyncio(deref(self._network.thisptr).onDisconnect() .attach(capnp.heap[PyRefCounter](self))) cdef class _AsyncIoStream: + cdef object __weakref__ # Needed to make this class weak-referenceable cdef Own[AsyncIoStream] thisptr - cdef _EventLoop _event_loop # We hold a pointer to the event loop here, to ensure it remains alive + cdef cbool close_called + cdef object protocol + + def __init__(self): + cdef _EventLoop loop = C_DEFAULT_EVENT_LOOP_GETTER() + loop.active_streams.add(self) + self.close_called = False + + def _post_init(self, protocol): + if not self.close_called: + self.thisptr = capnp.heap[PyAsyncIoStream]( + capnp.heap[PyRefCounter](protocol)) + self.protocol = protocol + else: + protocol.transport.close() def __dealloc__(self): # Needed to make Python 3.7 happy, which seems to have trouble deallocating stack objects # appropriately self.thisptr = Own[AsyncIoStream]() + def close(self): + if self.protocol is None: # _post_init wasn't called yet + self.close_called = True + elif self.protocol.transport is not None and hasattr(self.protocol.transport, "close"): + self.protocol.transport.close() + # Call connection_lost immediately, instead of waiting for the transport to do it. + self.protocol.connection_lost("Stream is closing") + + async def wait_closed(self): + return await self.protocol.closed_future + @staticmethod async def create_connection(host = None, port = None, **kwargs): """Create a TCP connection. @@ -2263,11 +2351,10 @@ cdef class _AsyncIoStream: See that function for documentation on the possible arguments. """ cdef _AsyncIoStream self = _AsyncIoStream() - self._event_loop = C_DEFAULT_EVENT_LOOP_GETTER() loop = asyncio.get_running_loop() transport, protocol = await loop.create_connection( lambda: _PyAsyncIoStreamProtocol(), host, port, **kwargs) - self.thisptr = capnp.heap[PyAsyncIoStream](capnp.heap[PyRefCounter](protocol)) + self._post_init(protocol) return self @staticmethod @@ -2278,20 +2365,18 @@ cdef class _AsyncIoStream: See that function for documentation on the possible arguments. """ cdef _AsyncIoStream self = _AsyncIoStream() - self._event_loop = C_DEFAULT_EVENT_LOOP_GETTER() loop = asyncio.get_running_loop() transport, protocol = await loop.create_unix_connection( lambda: _PyAsyncIoStreamProtocol(), path, **kwargs) - self.thisptr = capnp.heap[PyAsyncIoStream](capnp.heap[PyRefCounter](protocol)) + self._post_init(protocol) return self @staticmethod def _connect(callback): cdef _AsyncIoStream self = _AsyncIoStream() - self._event_loop = C_DEFAULT_EVENT_LOOP_GETTER() loop = asyncio.get_running_loop() protocol = _PyAsyncIoStreamProtocol(callback, self) - self.thisptr = capnp.heap[PyAsyncIoStream](capnp.heap[PyRefCounter](protocol)) + self._post_init(protocol) return protocol @staticmethod @@ -2330,7 +2415,7 @@ cdef class _PyAsyncIoStreamProtocol(DummyBaseClass, asyncio.BufferedProtocol): # See https://github.com/python/cpython/issues/79575. Can be removed once Python 3.7 is unsupported. cdef dict __dict__ - cdef object transport + cdef public object transport cdef object connected_callback cdef object callback_arg @@ -2382,7 +2467,7 @@ cdef class _PyAsyncIoStreamProtocol(DummyBaseClass, asyncio.BufferedProtocol): if self.connected_callback is not None: callback_res = self.connected_callback(self.callback_arg) if asyncio.iscoroutine(callback_res): - asyncio.get_running_loop().create_task(callback_res) + self._task = asyncio.get_running_loop().create_task(callback_res) self.connected_callback = None self.callback_arg = None @@ -2396,6 +2481,7 @@ cdef class _PyAsyncIoStreamProtocol(DummyBaseClass, asyncio.BufferedProtocol): self.write_reset() self.write_paused = True self.transport = None + self._task = None def get_buffer(self, size_hint): if self.read_buffer == NULL: # Should not happen, but for SSL it does, see comment above @@ -3017,6 +3103,7 @@ class _StructModule(object): :param nesting_limit: Limits how many total words of data are allowed to be traversed. Default is 64. :rtype: :class:`_DynamicStructReader`""" + C_DEFAULT_EVENT_LOOP_GETTER() # Make sure the event loop is running cdef schema_cpp.ReaderOptions opts = make_reader_opts(traversal_limit_in_words, nesting_limit) reader = await _promise_to_asyncio(tryReadMessage(deref(stream.thisptr), opts)) if reader is None: @@ -3207,7 +3294,6 @@ class _InterfaceModule(object): self.Server = type(name + '.Server', (_DynamicCapabilityServer,), {'__init__': server_init, 'schema':schema}) def _new_client(self, server): - C_DEFAULT_EVENT_LOOP_GETTER() # Make sure that the event loop has been initialized return _DynamicCapabilityClient()._init_vals(self.schema, server) diff --git a/examples/async_calculator_client.py b/examples/async_calculator_client.py index c3e3e52..7cc8377 100755 --- a/examples/async_calculator_client.py +++ b/examples/async_calculator_client.py @@ -306,4 +306,4 @@ async def cmd_main(host): if __name__ == "__main__": - asyncio.run(cmd_main(parse_args().host)) + asyncio.run(capnp.run(cmd_main(parse_args().host))) diff --git a/examples/async_calculator_server.py b/examples/async_calculator_server.py index f8302c3..d9014b2 100755 --- a/examples/async_calculator_server.py +++ b/examples/async_calculator_server.py @@ -131,4 +131,4 @@ async def main(): if __name__ == "__main__": - asyncio.run(main()) + asyncio.run(capnp.run(main())) diff --git a/examples/async_client.py b/examples/async_client.py index 47492d3..ab4286f 100755 --- a/examples/async_client.py +++ b/examples/async_client.py @@ -24,11 +24,6 @@ async def status(self, value, **kwargs): print("status: {}".format(time.time())) -async def background(cap): - subscriber = StatusSubscriber() - await cap.subscribeStatus(subscriber) - - async def main(host): host, port = host.split(":") connection = await capnp.AsyncIoStream.create_connection(host=host, port=port) @@ -36,7 +31,7 @@ async def main(host): cap = client.bootstrap().cast_as(thread_capnp.Example) # Start background task for subscriber - asyncio.create_task(background(cap)) + task = asyncio.ensure_future(cap.subscribeStatus(StatusSubscriber())) # Run blocking tasks print("main: {}".format(time.time())) @@ -47,12 +42,14 @@ async def main(host): await cap.longRunning() print("main: {}".format(time.time())) + task.cancel() + if __name__ == "__main__": args = parse_args() - asyncio.run(main(args.host)) + asyncio.run(capnp.run(main(args.host))) # Test that we can run multiple asyncio loops in sequence. This is particularly tricky, because # main contains a background task that we never cancel. The entire loop gets cleaned up anyways, # and we can start a new loop. - asyncio.run(main(args.host)) + asyncio.run(capnp.run(main(args.host))) diff --git a/examples/async_reconnecting_ssl_client.py b/examples/async_reconnecting_ssl_client.py index 3d3acf7..62d9b6f 100755 --- a/examples/async_reconnecting_ssl_client.py +++ b/examples/async_reconnecting_ssl_client.py @@ -41,11 +41,6 @@ async def watch_connection(cap): return False -async def background(cap): - subscriber = StatusSubscriber() - await cap.subscribeStatus(subscriber) - - async def main(host): addr, port = host.split(":") @@ -71,7 +66,9 @@ async def main(host): # Start watcher to restart socket connection if it is lost and subscriber background task background_tasks = asyncio.gather( - background(cap), watch_connection(cap), return_exceptions=True + cap.subscribeStatus(StatusSubscriber()), + watch_connection(cap), + return_exceptions=True, ) # Run blocking tasks @@ -96,7 +93,7 @@ async def main(host): while retry: loop = asyncio.new_event_loop() try: - retry = not loop.run_until_complete(main(parse_args().host)) + retry = not loop.run_until_complete(capnp.run(main(parse_args().host))) except RuntimeError: # If an IO is hung, the event loop will be stopped # and will throw RuntimeError exception diff --git a/examples/async_server.py b/examples/async_server.py index 5d5b63d..7fa734c 100755 --- a/examples/async_server.py +++ b/examples/async_server.py @@ -46,4 +46,4 @@ async def main(): if __name__ == "__main__": - asyncio.run(main()) + asyncio.run(capnp.run(main())) diff --git a/examples/async_socket_message_client.py b/examples/async_socket_message_client.py index bb8397e..e656851 100644 --- a/examples/async_socket_message_client.py +++ b/examples/async_socket_message_client.py @@ -50,4 +50,4 @@ async def main(host): if __name__ == "__main__": args = parse_args() - asyncio.run(main(args.host)) + asyncio.run(capnp.run(main(args.host))) diff --git a/examples/async_socket_message_server.py b/examples/async_socket_message_server.py index 8896ff0..e7c9852 100644 --- a/examples/async_socket_message_server.py +++ b/examples/async_socket_message_server.py @@ -59,4 +59,4 @@ async def main(): if __name__ == "__main__": - asyncio.run(main()) + asyncio.run(capnp.run(main())) diff --git a/examples/async_ssl_calculator_client.py b/examples/async_ssl_calculator_client.py index 63e3d8d..118cd57 100755 --- a/examples/async_ssl_calculator_client.py +++ b/examples/async_ssl_calculator_client.py @@ -330,4 +330,4 @@ async def main(host): # https://bugs.python.org/issue36709 # asyncio.run(main(parse_args().host), loop=loop, debug=True) loop = asyncio.get_event_loop() - loop.run_until_complete(main(parse_args().host)) + loop.run_until_complete(capnp.run(main(parse_args().host))) diff --git a/examples/async_ssl_calculator_server.py b/examples/async_ssl_calculator_server.py index 8657d72..17cb749 100755 --- a/examples/async_ssl_calculator_server.py +++ b/examples/async_ssl_calculator_server.py @@ -155,4 +155,4 @@ async def main(): if __name__ == "__main__": - asyncio.run(main()) + asyncio.run(capnp.run(main())) diff --git a/examples/async_ssl_client.py b/examples/async_ssl_client.py index fa46062..4847b8b 100755 --- a/examples/async_ssl_client.py +++ b/examples/async_ssl_client.py @@ -15,8 +15,7 @@ def parse_args(): parser = argparse.ArgumentParser( - usage="Connects to the Example thread server \ -at the given address and does some RPCs" + usage="Connects to the Example thread server at the given address and does some RPCs" ) parser.add_argument("host", help="HOST:PORT") @@ -26,15 +25,10 @@ def parse_args(): class StatusSubscriber(thread_capnp.Example.StatusSubscriber.Server): """An implementation of the StatusSubscriber interface""" - def status(self, value, **kwargs): + async def status(self, value, **kwargs): print("status: {}".format(time.time())) -async def background(cap): - subscriber = StatusSubscriber() - await cap.subscribeStatus(subscriber) - - async def main(host): addr, port = host.split(":") @@ -59,7 +53,7 @@ async def main(host): cap = client.bootstrap().cast_as(thread_capnp.Example) # Start background task for subscriber - asyncio.create_task(background(cap)) + task = asyncio.ensure_future(cap.subscribeStatus(StatusSubscriber())) # Run blocking tasks print("main: {}".format(time.time())) @@ -70,10 +64,12 @@ async def main(host): await cap.longRunning() print("main: {}".format(time.time())) + task.cancel() + if __name__ == "__main__": # Using asyncio.run hits an asyncio ssl bug # https://bugs.python.org/issue36709 # asyncio.run(main(parse_args().host), loop=loop, debug=True) loop = asyncio.get_event_loop() - loop.run_until_complete(main(parse_args().host)) + loop.run_until_complete(capnp.run(main(parse_args().host))) diff --git a/examples/async_ssl_server.py b/examples/async_ssl_server.py index 9d629a5..aafa9bd 100755 --- a/examples/async_ssl_server.py +++ b/examples/async_ssl_server.py @@ -71,4 +71,4 @@ async def main(): if __name__ == "__main__": - asyncio.run(main()) + asyncio.run(capnp.run(main())) diff --git a/test/test_capability.py b/test/test_capability.py index 158ec7e..da732b2 100644 --- a/test/test_capability.py +++ b/test/test_capability.py @@ -4,6 +4,12 @@ import test_capability_capnp as capability +@pytest.fixture(autouse=True) +async def kj_loop(): + async with capnp.kj_loop(): + yield + + class Server(capability.TestInterface.Server): def __init__(self, val=1): self.val = val diff --git a/test/test_capability_context.py b/test/test_capability_context.py index 70c03c5..30cf2c0 100644 --- a/test/test_capability_context.py +++ b/test/test_capability_context.py @@ -4,6 +4,12 @@ import test_capability_capnp as capability +@pytest.fixture(autouse=True) +async def kj_loop(): + async with capnp.kj_loop(): + yield + + class Server(capability.TestInterface.Server): def __init__(self, val=1): self.val = val diff --git a/test/test_context_manager.py b/test/test_context_manager.py new file mode 100644 index 0000000..596f6cb --- /dev/null +++ b/test/test_context_manager.py @@ -0,0 +1,241 @@ +import pytest +import asyncio +import socket + +import capnp +import test_capability +import test_capability_capnp as capability + + +async def test_two_kj_one_asyncio(): + async with capnp.kj_loop(): + pass + async with capnp.kj_loop(): + pass + + +def test_two_kj_two_asyncio(): + async def do(): + async with capnp.kj_loop(): + pass + + asyncio.run(do()) + asyncio.run(do()) + + +async def test_nested_kj(): + with pytest.raises(RuntimeError) as exninfo: + async with capnp.kj_loop(): + async with capnp.kj_loop(): + pass + assert "The KJ event-loop is already running" in str(exninfo) + + +async def test_kj_loop_leak_new_client(): + async with capnp.kj_loop(): + client = capability.TestInterface._new_client(test_capability.Server()) + with pytest.raises(RuntimeError) as exninfo: + await client.foo(5, True) + assert "The KJ event-loop is not running" in str(exninfo) + + +async def test_kj_loop_leak_client(): + read, write = socket.socketpair() + async with capnp.kj_loop(): + read = await capnp.AsyncIoStream.create_connection(sock=read) + write = await capnp.AsyncIoStream.create_connection(sock=write) + _ = capnp.TwoPartyServer(write, bootstrap=test_capability.Server()) + client = capnp.TwoPartyClient(read) + cap = client.bootstrap().cast_as(capability.TestInterface) + with pytest.raises(RuntimeError) as exninfo: + await cap.foo(5, True) + assert "The KJ event-loop is not running" in str(exninfo) + + +async def test_kj_loop_leak_client2(): + read, write = socket.socketpair() + async with capnp.kj_loop(): + read = await capnp.AsyncIoStream.create_connection(sock=read) + write = await capnp.AsyncIoStream.create_connection(sock=write) + _ = capnp.TwoPartyServer(write, bootstrap=test_capability.Server()) + client = capnp.TwoPartyClient(read) + with pytest.raises(RuntimeError) as exninfo: + client.bootstrap().cast_as(capability.TestInterface) + assert "This client is closed" in str(exninfo) + + +async def test_kj_loop_leak_client3(): + read, write = socket.socketpair() + async with capnp.kj_loop(): + read = await capnp.AsyncIoStream.create_connection(sock=read) + write = await capnp.AsyncIoStream.create_connection(sock=write) + _ = capnp.TwoPartyServer(write, bootstrap=test_capability.Server()) + client = capnp.TwoPartyClient(read).bootstrap() + with pytest.raises(RuntimeError) as exninfo: + cap = client.cast_as(capability.TestInterface) + await cap.foo(5, True) + assert "The KJ event-loop is not running" in str(exninfo) + + +async def test_no_kj_loop(): + read, write = socket.socketpair() + with pytest.raises(RuntimeError) as exninfo: + await capnp.AsyncIoStream.create_connection(sock=read) + assert "The KJ event-loop is not running" in str(exninfo) + with pytest.raises(RuntimeError) as exninfo: + await capnp.AsyncIoStream.create_connection(sock=write) + assert "The KJ event-loop is not running" in str(exninfo) + with pytest.raises(RuntimeError) as exninfo: + capability.TestPipeline._new_client(test_capability.PipelineServer()) + assert "The KJ event-loop is not running" in str(exninfo) + + +async def test_promise_leaking1(): + async with capnp.kj_loop(): + client = capability.TestInterface._new_client(test_capability.Server()) + remote = client.foo(5, True) + task = asyncio.ensure_future(remote) + await asyncio.sleep(0) + with pytest.raises(capnp.KjException): + await task + + +async def test_promise_leaking2(): + async with capnp.kj_loop(): + client = capability.TestInterface._new_client(test_capability.Server()) + remote = client.foo(5, True) + task = asyncio.ensure_future(remote) + with pytest.raises(RuntimeError) as exninfo: + await task + assert "The KJ event-loop is not running" in str(exninfo) + + +async def test_promise_leaking3(): + async with capnp.kj_loop(): + client = capability.TestInterface._new_client(test_capability.Server()) + remote = client.foo(5, True) + with pytest.raises(RuntimeError) as exninfo: + await remote + assert "The KJ event-loop is not running" in str(exninfo) + + +async def test_promise_leaking4(): + read, _ = socket.socketpair() + async with capnp.kj_loop(): + connection = await capnp.AsyncIoStream.create_connection(sock=read) + client = capnp.TwoPartyClient(connection) + cap = client.bootstrap().cast_as(capability.TestInterface) + res = asyncio.ensure_future(cap.foo(5, True)) + await asyncio.sleep(0) + with pytest.raises(capnp.KjException): + await res + + +async def test_promise_leaking5(): + read, _ = socket.socketpair() + async with capnp.kj_loop(): + connection = await capnp.AsyncIoStream.create_connection(sock=read) + client = capnp.TwoPartyClient(connection) + cap = client.bootstrap().cast_as(capability.TestInterface) + res = asyncio.ensure_future(cap.foo(5, True)) + with pytest.raises(RuntimeError) as exninfo: + await res + assert "The KJ event-loop is not running" in str(exninfo) + + +async def test_promise_leaking6(): + read, _ = socket.socketpair() + async with capnp.kj_loop(): + connection = await capnp.AsyncIoStream.create_connection(sock=read) + client = capnp.TwoPartyClient(connection) + cap = client.bootstrap().cast_as(capability.TestInterface) + res = cap.foo(5, True) + with pytest.raises(RuntimeError) as exninfo: + await res + assert "The KJ event-loop is not running" in str(exninfo) + + +async def test_kj_loop_read_message_after_close(): + read, _ = socket.socketpair() + async with capnp.kj_loop(): + read = await capnp.AsyncIoStream.create_connection(sock=read) + with pytest.raises(RuntimeError) as exninfo: + await capability.TestSturdyRefHostId.read_async(read) + assert "The KJ event-loop is not running" in str(exninfo) + + +async def test_kj_loop_partial_read_message_after_close(): + read, _ = socket.socketpair() + async with capnp.kj_loop(): + read = await capnp.AsyncIoStream.create_connection(sock=read) + message = capability.TestSturdyRefHostId.read_async(read) + with pytest.raises(RuntimeError) as exninfo: + await message + assert "The KJ event-loop is not running" in str(exninfo) + + +async def test_kj_loop_write_message_after_close(): + _, write = socket.socketpair() + async with capnp.kj_loop(): + write = await capnp.AsyncIoStream.create_connection(sock=write) + message = capability.TestSturdyRefHostId.new_message() + with pytest.raises(RuntimeError) as exninfo: + await message.write_async(write) + assert "The KJ event-loop is not running" in str(exninfo) + + +async def test_kj_loop_partial_write_message_after_close(): + _, write = socket.socketpair() + async with capnp.kj_loop(): + write = await capnp.AsyncIoStream.create_connection(sock=write) + message = capability.TestSturdyRefHostId.new_message() + send = message.write_async(write) + with pytest.raises(RuntimeError) as exninfo: + await send + assert "The KJ event-loop is not running" in str(exninfo) + + +async def test_client_on_disconnect_memory(): + read, _ = socket.socketpair() + async with capnp.kj_loop(): + read = await capnp.AsyncIoStream.create_connection(sock=read) + client = capnp.TwoPartyClient(read) + with pytest.raises(RuntimeError) as exninfo: + await client.on_disconnect() + assert "This client is closed" in str(exninfo) + + +async def test_server_on_disconnect_memory(): + _, write = socket.socketpair() + async with capnp.kj_loop(): + write = await capnp.AsyncIoStream.create_connection(sock=write) + server = capnp.TwoPartyServer(write, bootstrap=test_capability.Server()) + with pytest.raises(RuntimeError) as exninfo: + await server.on_disconnect() + assert "This server is closed" in str(exninfo) + + +@pytest.mark.xfail( + strict=True, + reason="Fails because the promisefulfiller got destroyed. Possibly a bug in the C++ library.", +) +async def test_client_on_disconnect_memory2(): + """ + E capnp.lib.capnp.KjException: kj/async.c++:2813: failed: + PromiseFulfiller was destroyed without fulfilling the promise. + """ + read, _ = socket.socketpair() + async with capnp.kj_loop(): + read = await capnp.AsyncIoStream.create_connection(sock=read) + client = capnp.TwoPartyClient(read) + disc = client.on_disconnect() + await disc + + +async def test_server_on_disconnect_memory2(): + _, write = socket.socketpair() + async with capnp.kj_loop(): + write = await capnp.AsyncIoStream.create_connection(sock=write) + server = capnp.TwoPartyServer(write, bootstrap=test_capability.Server()) + disc = server.on_disconnect() + await disc diff --git a/test/test_memory_handling.py b/test/test_memory_handling.py new file mode 100644 index 0000000..ea3b6d8 --- /dev/null +++ b/test/test_memory_handling.py @@ -0,0 +1,32 @@ +from types import coroutine +import pytest +import socket +import gc + +import capnp +import test_capability +import test_capability_capnp as capability + + +@pytest.fixture(autouse=True) +async def kj_loop(): + async with capnp.kj_loop(): + yield + + +@coroutine +def wrap(p): + return (yield from p) + + +async def test_kj_loop_await_attach(): + read, write = socket.socketpair() + read = await capnp.AsyncIoStream.create_connection(sock=read) + write = await capnp.AsyncIoStream.create_connection(sock=write) + _ = capnp.TwoPartyServer(write, bootstrap=test_capability.Server()) + client = capnp.TwoPartyClient(read).bootstrap().cast_as(capability.TestInterface) + t = wrap(client.foo(5, True).__await__()) + del client + del read + gc.collect() + await t diff --git a/test/test_response.py b/test/test_response.py index d1f1c49..56b24f5 100644 --- a/test/test_response.py +++ b/test/test_response.py @@ -1,6 +1,15 @@ +import pytest + +import capnp import test_response_capnp +@pytest.fixture(autouse=True) +async def kj_loop(): + async with capnp.kj_loop(): + yield + + class FooServer(test_response_capnp.Foo.Server): def __init__(self, val=1): self.val = val diff --git a/test/test_rpc.py b/test/test_rpc.py index c6ae055..f340f02 100644 --- a/test/test_rpc.py +++ b/test/test_rpc.py @@ -9,6 +9,12 @@ import test_capability_capnp +@pytest.fixture(autouse=True) +async def kj_loop(): + async with capnp.kj_loop(): + yield + + class Server(test_capability_capnp.TestInterface.Server): def __init__(self, val=100): self.val = val diff --git a/test/test_rpc_calculator.py b/test/test_rpc_calculator.py index 19317c9..f0b1c7f 100644 --- a/test/test_rpc_calculator.py +++ b/test/test_rpc_calculator.py @@ -2,6 +2,7 @@ import os import socket import sys # add examples dir to sys.path +import pytest import capnp @@ -12,6 +13,12 @@ import async_calculator_server # noqa: E402 +@pytest.fixture(autouse=True) +async def kj_loop(): + async with capnp.kj_loop(): + yield + + async def test_calculator(): read, write = socket.socketpair() read = await capnp.AsyncIoStream.create_connection(sock=read)