diff --git a/src/bindings/python/flux/future.py b/src/bindings/python/flux/future.py index 5155174a00f2..7090f3aa01c2 100644 --- a/src/bindings/python/flux/future.py +++ b/src/bindings/python/flux/future.py @@ -68,9 +68,11 @@ def check_wrap(self, fun, name): func = super(Future.InnerWrapper, self).check_wrap(fun, name) return check_future_error(func) - def __init__(self, future_handle, prefixes=None): + def __init__(self, future_handle, prefixes=None, pimpl_t=None): super(Future, self).__init__() - self.pimpl = self.InnerWrapper(handle=future_handle, prefixes=prefixes) + if pimpl_t is None: + pimpl_t = self.InnerWrapper + self.pimpl = pimpl_t(handle=future_handle, prefixes=prefixes) self.then_cb = None self.then_arg = None self.cb_handle = None diff --git a/src/bindings/python/flux/rpc.py b/src/bindings/python/flux/rpc.py index b14c3cda590d..b7ca796a7d53 100644 --- a/src/bindings/python/flux/rpc.py +++ b/src/bindings/python/flux/rpc.py @@ -20,6 +20,9 @@ class RPC(Future): """An RPC state object""" + class RPCInnerWrapper(Future.InnerWrapper): + pass + def __init__( self, flux_handle, @@ -37,7 +40,11 @@ def __init__( payload = encode_payload(payload) future_handle = raw.flux_rpc(flux_handle, topic, payload, nodeid, flags) - super(RPC, self).__init__(future_handle, prefixes=["flux_rpc_", "flux_future_"]) + super(RPC, self).__init__( + future_handle, + prefixes=["flux_rpc_", "flux_future_"], + pimpl_t=self.RPCInnerWrapper, + ) def get_str(self): payload_str = ffi.new("char *[1]") diff --git a/src/bindings/python/flux/wrapper.py b/src/bindings/python/flux/wrapper.py index d981c2a87b8b..6ed568edfe16 100644 --- a/src/bindings/python/flux/wrapper.py +++ b/src/bindings/python/flux/wrapper.py @@ -17,7 +17,7 @@ import os import errno import inspect -import weakref +from types import MethodType import six @@ -250,6 +250,20 @@ def __init__( self.filter_match = filter_match self.prefixes = prefixes self.destructor = destructor + # this is an error-checking dance to ensure that the class-based caching of + # callables is safe by only allowing one set of prefixes, filter-matches, etc. + # per derived class of wrapper + signature = (match, filter_match, prefixes) + mytype = type(self) + if getattr(mytype, "signature", None) is None: + setattr(mytype, "signature", signature) + else: + assert signature == getattr( + mytype, "signature" + ), f""" +signatures do not match, create a new subclass to change matching parameters: +{mytype}: mysig: {getattr(mytype, "signature")} sig:{signature} + """ def check_handle(self, name, fun_type): if self.match is not None and self._handle is not None: @@ -322,12 +336,15 @@ def __getattr__(self, name): return fun new_fun = self.check_wrap(fun, name) - new_method = six.create_bound_method(new_fun, weakref.proxy(self)) + new_meth = MethodType(new_fun, self) + + def wrap_class(self_renamed, *args, **kwargs): + return new_fun(self_renamed, *args, **kwargs) - # Store the wrapper function into the instance + # Store the wrapper function into the class # to prevent a second lookup - setattr(self, name, new_method) - return new_method + setattr(type(self), name, wrap_class) + return new_meth def _clear(self): # avoid recursion