From af78d4cd49abf8affbe1e64b49266f91df3b71b8 Mon Sep 17 00:00:00 2001
From: David Riazati
Date: Tue, 23 Oct 2018 09:02:50 -0700
Subject: [PATCH] Add weak script modules (#12682)

Summary:
Adds support for weak script modules that get compiled to `ScriptModule`s once they are added as a submodule of a `ScriptModule`:

```python
@weak_module
class Test(torch.nn.Module):
    ...
    @weak_script_method
    def forward(self, x):
        ...
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/12682

Differential Revision: D10458626

Pulled By: driazati

fbshipit-source-id: 10ae23cb83cdafc4646cee58f399e14b2e60acd4
---
 .../TestScript.test_weak_module-basic.expect  |   9 +
 ...tScript.test_weak_module-scope_test.expect |  18 ++
 .../TestScript.test_weak_module_nested.expect |  20 ++
 ..._weak_module_parameters_and_buffers.expect |  19 ++
 test/test_jit.py                              | 287 ++++++++++++++++++
 torch/_jit_internal.py                        |  34 ++-
 torch/_six.py                                 |   9 +
 torch/csrc/jit/script/init.cpp                |   6 +
 torch/jit/__init__.py                         | 136 ++++++++-
 9 files changed, 523 insertions(+), 15 deletions(-)
 create mode 100644 test/expect/TestScript.test_weak_module-basic.expect
 create mode 100644 test/expect/TestScript.test_weak_module-scope_test.expect
 create mode 100644 test/expect/TestScript.test_weak_module_nested.expect
 create mode 100644 test/expect/TestScript.test_weak_module_parameters_and_buffers.expect

diff --git a/test/expect/TestScript.test_weak_module-basic.expect b/test/expect/TestScript.test_weak_module-basic.expect
new file mode 100644
index 0000000000000..4a6372b57550b
--- /dev/null
+++ b/test/expect/TestScript.test_weak_module-basic.expect
@@ -0,0 +1,9 @@
+graph(%x : Dynamic) {
+  %1 : int = prim::Constant[value=55]()
+  %2 : int = prim::Constant[value=199]()
+  %3 : int = prim::Constant[value=1]()
+  %4 : int = aten::add(%1, %2)
+  %5 : Dynamic = ^python_op_in_weak_module()(%x)
+  %6 : Dynamic = aten::add(%5, %4, %3)
+  return (%6);
+}
diff --git a/test/expect/TestScript.test_weak_module-scope_test.expect b/test/expect/TestScript.test_weak_module-scope_test.expect
new file mode 100644
index 0000000000000..9376661a45adc
--- /dev/null
+++ b/test/expect/TestScript.test_weak_module-scope_test.expect
@@ -0,0 +1,18 @@
+graph(%x : Dynamic) {
+  %1 : int = prim::Constant[value=357]()
+  %2 : int = prim::Constant[value=55]()
+  %3 : int = prim::Constant[value=199]()
+  %4 : int = prim::Constant[value=2]()
+  %5 : int = prim::Constant[value=1]()
+  %y : Dynamic = aten::mul(%x, %4)
+  %7 : Dynamic = aten::add(%y, %5, %5)
+  %8 : int = aten::add(%2, %3)
+  %9 : Dynamic = ^python_op_in_weak_module()(%y)
+  %10 : Dynamic = aten::add(%9, %8, %5)
+  %11 : Dynamic = aten::add(%7, %10, %5)
+  %12 : Dynamic = aten::add(%y, %1, %5)
+  %13 : Dynamic = ^python_op_in_strong_module()(%y)
+  %14 : Dynamic = aten::add(%12, %13, %5)
+  %15 : Dynamic = aten::add(%11, %14, %5)
+  return (%15);
+}
diff --git a/test/expect/TestScript.test_weak_module_nested.expect b/test/expect/TestScript.test_weak_module_nested.expect
new file mode 100644
index 0000000000000..daf2936550246
--- /dev/null
+++ b/test/expect/TestScript.test_weak_module_nested.expect
@@ -0,0 +1,20 @@
+graph(%x : Dynamic
+      %1 : Dynamic
+      %2 : Dynamic
+      %3 : Dynamic
+      %4 : Dynamic) {
+  %5 : int = prim::Constant[value=1]()
+  %6 : int = prim::Constant[value=3]()
+  %7 : int = prim::Constant[value=27]()
+  %8 : Dynamic = aten::mul(%x, %x)
+  %9 : Dynamic = aten::add(%8, %6, %5)
+  %10 : Dynamic = aten::linear(%x, %1, %2)
+  %11 : Dynamic = aten::add(%9, %10, %5)
+  %12 : Dynamic = aten::add(%x, %11, %5)
+  %13 : Dynamic = aten::add(%x, %7, %5)
+  %14 : Dynamic = aten::add(%12, %13, %5)
+  %15 : Dynamic = aten::linear(%x, %3, %4)
+  %16 : Dynamic = aten::add(%14, %15, %5)
+  %17 : Dynamic = aten::add(%x, %16, %5)
+  return (%17);
+}
diff --git a/test/expect/TestScript.test_weak_module_parameters_and_buffers.expect b/test/expect/TestScript.test_weak_module_parameters_and_buffers.expect
new file mode 100644
index 0000000000000..f982be6cd812b
--- /dev/null
+++ b/test/expect/TestScript.test_weak_module_parameters_and_buffers.expect
@@ -0,0 +1,19 @@
+graph(%x : Dynamic
+      %1 : Dynamic
+      %2 : Dynamic
+      %3 : Dynamic
+      %4 : Dynamic
+      %5 : Dynamic
+      %6 : Dynamic) {
+  %7 : int = prim::Constant[value=1]()
+  %8 : Dynamic = aten::linear(%x, %1, %2)
+  %9 : Dynamic = aten::add(%8, %3, %7)
+  %10 : Dynamic = aten::add(%x, %9, %7)
+  %11 : Dynamic = aten::linear(%x, %1, %2)
+  %12 : Dynamic = aten::add(%11, %3, %7)
+  %13 : Dynamic = aten::add(%10, %12, %7)
+  %14 : Dynamic = aten::linear(%x, %4, %5)
+  %15 : Dynamic = aten::add(%14, %6, %7)
+  %16 : Dynamic = aten::add(%13, %15, %7)
+  return (%16);
+}
diff --git a/test/test_jit.py b/test/test_jit.py
index 3246c5180b535..a1050d56633b9 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -7442,6 +7442,293 @@ def foo(x):
         traced = torch.jit.trace(foo, (x,))
         self.assertExpectedGraph(traced.graph)
 
+    def test_weak_module(self):
+
+        @torch._jit_internal.weak_module
+        class Weak(torch.nn.Module):
+            __constants__ = ['number']
+
+            def __init__(self):
+                super(Weak, self).__init__()
+                self.number = 199
+
+            def python_op_in_weak_module(self, x):
+                return x + 123
+
+            @torch._jit_internal.weak_script_method
+            def forward(self, x):
+                return 55 + self.number + self.python_op_in_weak_module(x)
+
+        class OtherStrong(torch.jit.ScriptModule):
+            __constants__ = ['number']
+
+            def __init__(self):
+                super(OtherStrong, self).__init__()
+                self.number = 357
+
+            def python_op_in_strong_module(self, x):
+                return x + 456
+
+            @torch.jit.script_method
+            def forward(self, x):
+                return x + self.number + self.python_op_in_strong_module(x)
+
+        class Passthrough(torch.jit.ScriptModule):
+            def __init__(self):
+                super(Passthrough, self).__init__()
+                self.weak = Weak()
+
+            @torch.jit.script_method
+            def forward(self, x):
+                return self.weak(x)
+
+        weak_mod = Weak()
+        x = torch.ones(1)
+        expected_result = 55 + 199 + (x + 123)
+
+        # Ensure weak mod is running without the JIT by passing the wrong type
+        # (i.e. not a tensor)
+        weak_mod(2)
+
+        python_result = weak_mod(x)
+        strong_mod = Passthrough()
+        script_result = strong_mod(x)
+        self.assertEqual(python_result, expected_result)
+        self.assertEqual(script_result, expected_result)
+        self.assertExpectedGraph(strong_mod.graph, "basic")
+
+        class Strong(torch.jit.ScriptModule):
+            def __init__(self):
+                super(Strong, self).__init__()
+                self.weak = Weak()
+                self.strong = OtherStrong()
+
+            @torch.jit.script_method
+            def forward(self, x):
+                y = 2 * x
+                return y + 1 + self.weak(y) + self.strong(y)
+
+        strong_mod = Strong()
+        strong_mod2 = Strong()
+        x = torch.ones(1)
+        expected_result = (x * 2) + 1 + (55 + 199 + x * 2 + 123) + (x * 2 + 357 + x * 2 + 456)
+        script_result = strong_mod(x)
+        script_result2 = strong_mod2(x)
+        self.assertEqual(script_result, expected_result)
+        self.assertEqual(script_result, script_result2)
+        self.assertExpectedGraph(strong_mod.graph, "scope_test")
+
+    def test_weak_module_parameters_and_buffers(self):
+        import math
+        weights = torch.randn(10, 10)
+        bias = torch.randn(10)
+        weights2 = torch.randn(10, 10)
+        bias2 = torch.randn(10)
+
+        @torch._jit_internal.weak_module
+        class TestLinear(torch.nn.Module):
+            def __init__(self, in_features, out_features):
+                super(TestLinear, self).__init__()
+                self.in_features = in_features
+                self.out_features = out_features
+                self.weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
+                self.bias = torch.nn.Parameter(torch.Tensor(out_features))
+                self.register_buffer('counter', torch.ones(out_features))
+                self.reset_parameters()
+
+            def reset_parameters(self):
+                torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+                if self.bias is not None:
+                    fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
+                    bound = 1 / math.sqrt(fan_in)
+                    torch.nn.init.uniform_(self.bias, -bound, bound)
+
+            @torch._jit_internal.weak_script_method
+            def forward(self, input):
+                return F.linear(input, self.weight, self.bias) + self.counter
+
+        # Initialize a ScriptModule that uses the weak module above multiple times
+        class Strong(torch.jit.ScriptModule):
+            def __init__(self):
+                super(Strong, self).__init__()
+                self.fc1 = TestLinear(10, 10)
+                self.fc1.weight = torch.nn.Parameter(weights)
+                self.fc1.bias = torch.nn.Parameter(bias)
+                self.fc2 = TestLinear(10, 10)
+                self.fc2.weight = torch.nn.Parameter(weights2)
+                self.fc2.bias = torch.nn.Parameter(bias2)
+
+            @torch.jit.script_method
+            def forward(self, x):
+                return x + self.fc1(x) + self.fc1(x) + self.fc2(x)
+
+        strong_mod = Strong()
+        self.assertExpectedGraph(strong_mod.graph)
+
+        # Run same calculation as module
+        inp = torch.ones(10)
+        lin = torch.nn.Linear(10, 10)
+        lin.weight = torch.nn.Parameter(weights)
+        lin.bias = torch.nn.Parameter(bias)
+        lin2 = torch.nn.Linear(10, 10)
+        lin2.weight = torch.nn.Parameter(weights2)
+        lin2.bias = torch.nn.Parameter(bias2)
+        expected_result = inp + (lin(inp) + torch.ones(10)) * 2 + lin2(inp) + torch.ones(10)
+
+        self.assertEqual(strong_mod(inp), expected_result)
+
+    def test_weak_module_nested(self):
+        @torch._jit_internal.weak_module
+        class OtherWeak(torch.nn.Module):
+            __constants__ = ['constant']
+
+            def __init__(self, in_features, out_features):
+                super(OtherWeak, self).__init__()
+                self.in_features = in_features
+                self.out_features = out_features
+                self.weight = torch.nn.Parameter(torch.ones(out_features, in_features))
+                self.bias = torch.nn.Parameter(torch.ones(out_features))
+                self.constant = 3
+
+            @torch._jit_internal.weak_script_method
+            def forward(self, x):
+                return x * x + self.constant + \
+                    F.linear(x, self.weight, self.bias)
+
+        class OtherStrong(torch.jit.ScriptModule):
+
+            def __init__(self):
+                super(OtherStrong, self).__init__()
+
+            @torch.jit.script_method
+            def forward(self, x):
+                return x + 27
+
+        @torch._jit_internal.weak_module
+        class Weak(torch.nn.Module):
+            def __init__(self, in_features, out_features):
+                super(Weak, self).__init__()
+                self.in_features = in_features
+                self.out_features = out_features
+                self.weight = torch.nn.Parameter(2 * torch.ones(out_features, in_features))
+                self.bias = torch.nn.Parameter(2 * torch.ones(out_features))
+                self.weak_submodule = OtherWeak(10, 10)
+                self.strong_submodule = OtherStrong()
+
+            @torch._jit_internal.weak_script_method
+            def forward(self, x):
+                return x + self.weak_submodule(x) + self.strong_submodule(x) \
+                    + F.linear(x, self.weight, self.bias)
+
+        class Strong(torch.jit.ScriptModule):
+            __constants__ = ['constant']
+
+            def __init__(self):
+                super(Strong, self).__init__()
+                self.weak = Weak(10, 10)
+
+            @torch.jit.script_method
+            def forward(self, x):
+                return x + self.weak(x)
+
+        strong_mod = Strong()
+        self.assertExpectedGraph(strong_mod.graph)
+        inp = torch.randn(10)
+        result = strong_mod(inp)
+        expected_result = inp + (inp + inp * inp + inp + 27) + 3 \
+            + F.linear(inp, torch.ones(10, 10), torch.ones(10)) \
+            + F.linear(inp, 2 * torch.ones(10, 10), 2 * torch.ones(10))
+        self.assertEqual(result, expected_result)
+
+    def test_weak_module_submodule(self):
+        @torch._jit_internal.weak_module
+        class Weak(torch.nn.Module):
+            def __init__(self):
+                super(Weak, self).__init__()
+                self.param = torch.nn.Parameter(100 * torch.ones(5))
+
+            @torch._jit_internal.weak_script_method
+            def forward(self, x):
+                return x + self.param
+
+        weak = Weak()
+
+        class OtherStrong(torch.jit.ScriptModule):
+            def __init__(self):
+                super(OtherStrong, self).__init__()
+                self.weak = weak
+                self.weak2 = weak
+
+            @torch.jit.script_method
+            def forward(self, x):
+                return x + self.weak(x)
+
+        class Strong(torch.jit.ScriptModule):
+            def __init__(self):
+                super(Strong, self).__init__()
+                self.weak = Weak()
+
+            @torch.jit.script_method
+            def forward(self, x):
+                return self.weak(x) + weak(x)
+
+        other_strong_mod = OtherStrong()
+
+        self.assertIs(other_strong_mod.weak, other_strong_mod.weak2)
+
+        with self.assertRaisesRegex(RuntimeError, "Attempted to inline a Module with param"):
+            strong_mod = Strong()
+
+    def test_weak_module_copying(self):
+        class Submodule(torch.nn.Module):
+            def __init__(self):
+                super(Submodule, self).__init__()
+
+            def forward(self, x):
+                return x + 100
+
+        @torch._jit_internal.weak_module
+        class Weak(torch.nn.Module):
+            def __init__(self, in_features, out_features):
+                super(Weak, self).__init__()
+                self.weight = torch.nn.Parameter(torch.ones(out_features, in_features))
+                self.register_buffer("buffer", torch.ones(out_features))
+                self.submodule = Submodule()
+
+            @torch._jit_internal.weak_script_method
+            def forward(self, x):
+                return F.linear(x, self.weight) + self.buffer + self.submodule(x)
+
+        class Strong(torch.jit.ScriptModule):
+            def __init__(self, weak):
+                super(Strong, self).__init__()
+                self.weak = weak
+
+            @torch.jit.script_method
+            def forward(self, x):
+                return self.weak(x)
+
+        inp = torch.ones(5, 5) * 5
+        weak_mod = Weak(5, 5)
+        strong_mod = Strong(weak_mod)
+
+        self.assertTrue(isinstance(strong_mod.weak, torch.jit.ScriptModule))
+        self.assertFalse(isinstance(weak_mod, torch.jit.ScriptModule))
+
+        self.assertIs(strong_mod.weak.weight, weak_mod.weight)
+        self.assertIs(strong_mod.weak.buffer, weak_mod.buffer)
+        self.assertIs(strong_mod.weak.submodule, weak_mod.submodule)
+
+        # Test lookup fallback
+        weak_mod.new_attribute = 10
+        self.assertIs(strong_mod.weak.new_attribute, weak_mod.new_attribute)
+
+        weak_mod.weight.data += torch.ones(5, 5) * 100
+        self.assertTrue(strong_mod(inp).allclose(weak_mod(inp)))
+
+        # Re-assignment is not tracked
+        weak_mod.weight = torch.nn.Parameter(torch.ones(5, 5) * 100)
+        self.assertFalse(strong_mod(inp).allclose(weak_mod(inp)))
+
 
 class MnistNet(nn.Module):
     def __init__(self):
diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py
index 6fc06f3b0c59c..e6e4fad1f0921 100644
--- a/torch/_jit_internal.py
+++ b/torch/_jit_internal.py
@@ -7,11 +7,22 @@
 import weakref
 import inspect
 try:
-    import builtins  # py3
+    import builtins  # PY3
 except Exception:
-    import __builtin__ as builtins  # py2
+    import __builtin__ as builtins  # PY2
+
+# Tracks standalone weak script functions
+_compiled_weak_fns = weakref.WeakKeyDictionary()
+
+# Tracks which methods should be converted to strong methods
+_weak_script_methods = weakref.WeakKeyDictionary()
+
+# Converted modules and their corresponding WeakScriptModuleProxy objects
+_weak_modules = weakref.WeakKeyDictionary()
+
+# Types that have been declared as weak modules
+_weak_types = weakref.WeakKeyDictionary()
 
-compiled_weak_fns = weakref.WeakKeyDictionary()
 COMPILATION_PENDING = object()
 COMPILED = object()
@@ -67,9 +78,24 @@ def weak_script(fn, _frames_up=0):
     inlined in the graph. When not used in a script function, the weak script
     annotation has no effect.
     """
-    compiled_weak_fns[fn] = {
+    _compiled_weak_fns[fn] = {
         "status": COMPILATION_PENDING,
         "compiled_fn": None,
         "rcb": createResolutionCallback(_frames_up + 1)
     }
     return fn
+
+
+def weak_module(cls):
+    _weak_types[cls] = {
+        "method_stubs": None
+    }
+    return cls
+
+
+def weak_script_method(fn):
+    _weak_script_methods[fn] = {
+        "rcb": createResolutionCallback(frames_up=2),
+        "original_method": fn
+    }
+    return fn
diff --git a/torch/_six.py b/torch/_six.py
index 84ba9a464891b..924e641f638f6 100644
--- a/torch/_six.py
+++ b/torch/_six.py
@@ -115,3 +115,12 @@ def raise_from(value, from_value):
 elif PY3:
     import collections.abc
     container_abcs = collections.abc
+
+# Gets a function from the name of a method on a type
+if PY2:
+    def get_function_from_type(cls, name):
+        method = getattr(cls, name, None)
+        return getattr(method, "__func__", None)
+elif PY3:
+    def get_function_from_type(cls, name):
+        return getattr(cls, name, None)
diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp
index df626baa69625..d530e64ee2b00 100644
--- a/torch/csrc/jit/script/init.cpp
+++ b/torch/csrc/jit/script/init.cpp
@@ -283,6 +283,12 @@ std::shared_ptr<SugaredValue> toSugaredValue(
       return std::make_shared<ConstantPythonTupleValue>(obj);
     }
   }
+
+  auto weak_obj =
+      py::module::import("torch.jit").attr("_try_get_weak_module")(obj);
+  if (!weak_obj.is_none()) {
+    obj = weak_obj;
+  }
   if (py::isinstance<Module>(obj)) {
     auto mod = py::cast<std::shared_ptr<Module>>(obj);
     // In the case that this Python object is not a submodule, inline *ONLY
diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py
index 2b7825288ee2a..6387139d4421d 100644
--- a/torch/jit/__init__.py
+++ b/torch/jit/__init__.py
@@ -4,8 +4,10 @@
 from torch.nn import Module, ModuleList, ParameterList, Parameter, Sequential
 from torch.jit.frontend import get_jit_ast, get_default_args
 import torch.jit.annotations
-from torch._six import raise_from, with_metaclass
-from .._jit_internal import createResolutionCallback, compiled_weak_fns, COMPILED, COMPILATION_PENDING
+from torch._six import raise_from, with_metaclass, get_function_from_type
+from .._jit_internal import createResolutionCallback, _compiled_weak_fns, \
+    _weak_script_methods, _weak_modules, _weak_types, COMPILED, \
+    COMPILATION_PENDING
 import torch.testing
 from collections import defaultdict, OrderedDict, namedtuple
 import sys
@@ -592,13 +594,13 @@ def __getattr__(self, attr):
 
 
 def _try_compile_weak_script(fn):
-    entry = compiled_weak_fns.get(fn)
+    entry = _compiled_weak_fns.get(fn)
     if entry is None:
         return None
     if entry["status"] == COMPILATION_PENDING:
         compiled_fn = torch.jit.script(fn, True, 0, entry["rcb"])
         del entry["rcb"]
-        compiled_weak_fns[fn]["compiled_fn"] = compiled_fn
+        _compiled_weak_fns[fn]["compiled_fn"] = compiled_fn
         entry["status"] = COMPILED
         return compiled_fn
     else:
@@ -627,7 +629,7 @@ def script(fn, optimize=True, _frames_up=0, _rcb=None):
 ScriptMethodStub = namedtuple('ScriptMethodStub', ('resolution_callback', 'def_', 'original_method'))
 
 
-def script_method(fn):
+def script_method(fn, _rcb=None):
     if not _enabled:
         return fn
     # NOTE: we need to traverse two frames here because the meta-class frame
@@ -642,9 +644,26 @@ def script_method(fn):
     #
     # createResolutionCallback internally adds 1 to get us to the scope of this
     # function (the calling function). Adding 2 gets us to the proper surrounding scope.
-    rcb = createResolutionCallback(frames_up=2)
+    if _rcb is None:
+        _rcb = createResolutionCallback(frames_up=2)
     ast = get_jit_ast(fn, is_method=True)
-    return ScriptMethodStub(rcb, ast, fn)
+    return ScriptMethodStub(_rcb, ast, fn)
+
+
+def _try_get_weak_module(mod):
+    """
+    Get the WeakScriptModuleProxy corresponding to mod if it exists
+    """
+    if not isinstance(mod, Module):
+        return None
+    return _weak_modules.get(mod)
+
+
+def _is_weak_type(cls):
+    """
+    Check if a type has been annotated with `weak_module`
+    """
+    return cls in _weak_types
 
 
 def batch(batch_size=1, optimize=True, _frames_up=0):
@@ -817,6 +836,13 @@ def _get_valid_constant(v):
                        " 2. a value of type {{{}}}\n".format(constants) +
                        " 3. a list or tuple of (2)\n")
 
+
+def _create_methods_from_stubs(self, stubs):
+    defs = [m.def_ for m in stubs]
+    rcbs = [m.resolution_callback for m in stubs]
+    defaults = [get_default_args(m.original_method) for m in stubs]
+    self._create_methods(defs, rcbs, defaults)
+
 # For each user-defined class that subclasses ScriptModule this meta-class,
 # (1) finds all the methods annotated with @script_method
 # in a ScriptModule and removes them from the class attributes, and
@@ -854,10 +880,7 @@ def init_then_register(self, *args, **kwargs):
             if cls is type(self):
                 torch._C.ScriptModule.__init__(self)
                 original_init(self, *args, **kwargs)
-                defs = [m.def_ for m in methods]
-                rcbs = [m.resolution_callback for m in methods]
-                defaults = [get_default_args(m.original_method) for m in methods]
-                self._create_methods(defs, rcbs, defaults)
+                _create_methods_from_stubs(self, methods)
 
         cls.__init__ = init_then_register
         return super(ScriptMeta, cls).__init__(name, bases, attrs)
@@ -1011,6 +1034,9 @@ def __getattr__(self, attr):
 
         def __setattr__(self, attr, value):
             if attr not in self._constants_set:
+                if isinstance(value, Module) and _is_weak_type(type(value)):
+                    # Compile weak script module
+                    value = _make_strong(value)
                 return super(ScriptModule, self).__setattr__(attr, value)
             if hasattr(self, attr):
                 raise RuntimeError("attempting to re-assign constant '{}'".format(attr))
@@ -1039,10 +1065,98 @@ def define(self, lang):
             # we add 1 to get to the proper surrounding scope.
             rcb = createResolutionCallback(frames_up=1)
             self._define(lang, rcb, True)
+
+    class WeakScriptModuleProxy(ScriptModule):
+        def __init__(self, original, stubs):
+            # Guards behavior of __setattr__ and __getattr__ so ScriptModule
+            # __init__ can run correctly
+            self.__dict__['_initialized'] = False
+            super(WeakScriptModuleProxy, self).__init__()
+
+            # Copy constants
+            self.__dict__["_original"] = weakref.ref(original)
+            self.__dict__["_constants_set"] = set(getattr(original, "__constants__", []))
+
+            # Copy Parameters / Modules / Buffers
+            for name in dir(original):
+                item = getattr(original, name)
+                if isinstance(item, Parameter) or (isinstance(item, Module) and item is not self):
+                    ScriptModule.__setattr__(self, name, item)
+            for name in original._buffers:
+                self.register_buffer(name, original._buffers[name])
+
+            self.__dict__["_initialized"] = True
+            _create_methods_from_stubs(self, stubs)
+
+        def __getattr__(self, attr):
+            # Try to get the attribute directly; if that fails, fall back to
+            # the weak module itself
+            try:
+                return ScriptModule.__getattr__(self, attr)
+            except AttributeError:
+                # Only fall back to the original module once __init__() is done
+                if self.__dict__["_initialized"]:
+                    return getattr(self.__dict__["_original"](), attr)
+                else:
+                    raise AttributeError("Weak module has no attribute '{}'"
+                                         .format(attr))
+
+        def __setattr__(self, attr, value):
+            # Once constructed, no new attributes can be set
+
+            if not self.__dict__["_initialized"]:
+                # If constructing, don't fall back to original module
+                return ScriptModule.__setattr__(self, attr, value)
+
+            if hasattr(self, attr):
+                return ScriptModule.__setattr__(self, attr, value)
+            else:
+                raise AttributeError("Cannot set new attribute '{}' on "
+                                     "weak script module once it has been "
+                                     "created".format(attr))
+
 else:
     ScriptModule = torch.nn.Module
 
 
+def _get_weak_stubs(cls):
+    """
+    Calls script_method for each weak script method defined on the given type
+    and returns the generated ScriptMethodStubs
+    """
+    stubs = []
+    for name in dir(cls):
+        func = get_function_from_type(cls, name)
+        if func in _weak_script_methods:
+            entry = _weak_script_methods[func]
+            stub = script_method(entry["original_method"], entry["rcb"])
+            stubs.append(stub)
+    return stubs
+
+
+def _make_strong(mod):
+    """
+    Converts a weak module into a subclass of ScriptModule
+    """
+    if mod in _weak_modules:
+        return _weak_modules[mod]
+
+    stubs = _weak_types.get(type(mod))["method_stubs"]
+
+    if stubs is None:
+        # Generate stubs and store them on _weak_types in case this type is
+        # used again
+        stubs = _get_weak_stubs(type(mod))
+        _weak_types[type(mod)]["method_stubs"] = stubs
+
+    # Create proxy with stubs
+    proxy = WeakScriptModuleProxy(mod, stubs)
+
+    _weak_modules[mod] = proxy
+
+    return proxy
+
+
 def _get_methods(cls):
     import inspect
     # In Python 3 unbound methods are functions, but in Python 2 they are methods
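
Usage note (not part of the patch): the tests above exercise the new flow end to end. Below is a minimal standalone sketch of the same behavior; the names `MyLinear` and `Strong` are illustrative and not defined anywhere in this PR. A `@weak_module` class stays a plain `torch.nn.Module` when used from Python, and is converted to a compiled `ScriptModule` only at the moment it is assigned as a submodule of a `ScriptModule`:

```python
import torch
import torch.nn.functional as F
from torch._jit_internal import weak_module, weak_script_method


@weak_module
class MyLinear(torch.nn.Module):
    def __init__(self, in_features, out_features):
        super(MyLinear, self).__init__()
        self.weight = torch.nn.Parameter(torch.ones(out_features, in_features))

    @weak_script_method
    def forward(self, x):
        return F.linear(x, self.weight)


class Strong(torch.jit.ScriptModule):
    def __init__(self):
        super(Strong, self).__init__()
        # Assigning a weak module to a ScriptModule attribute triggers
        # _make_strong: the instance is wrapped in a WeakScriptModuleProxy
        # and its @weak_script_method methods are compiled
        self.fc = MyLinear(10, 10)

    @torch.jit.script_method
    def forward(self, x):
        return self.fc(x)


weak = MyLinear(10, 10)
weak(torch.ones(2, 10))    # used on its own, runs as plain Python (no JIT)
strong = Strong()          # strong.fc is now a compiled ScriptModule
strong(torch.ones(2, 10))
```

Since `_weak_modules` is keyed on the instance (a `WeakKeyDictionary`), assigning the same weak module to two different strong modules reuses a single proxy, which is what `test_weak_module_submodule` checks with `assertIs`.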