Add weak script modules (pytorch#12682)
Summary:
Adds support for weak script modules: `nn.Module`s marked as weak are compiled to `ScriptModule`s once they are added as a submodule of a `ScriptModule`:

```python
@weak_module
class Test(torch.nn.Module):
    ...

    @weak_script_method
    def forward(self, x):
        ...
```
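As a rough illustration of the intended behavior (a minimal sketch mirroring the tests in this commit; `Scale` and `Runner` are made-up names, and the decorators are the internal `torch._jit_internal` ones the tests exercise), a weak module stays a plain `nn.Module` when used from Python, but is compiled once it is assigned as a submodule of a `ScriptModule`:

```python
import torch

@torch._jit_internal.weak_module
class Scale(torch.nn.Module):
    # 'factor' must be declared constant so the script compiler can use it
    __constants__ = ['factor']

    def __init__(self):
        super(Scale, self).__init__()
        self.factor = 2

    @torch._jit_internal.weak_script_method
    def forward(self, x):
        return x * self.factor


class Runner(torch.jit.ScriptModule):
    def __init__(self):
        super(Runner, self).__init__()
        # Assigning the weak module here is what triggers compilation
        self.scale = Scale()

    @torch.jit.script_method
    def forward(self, x):
        return self.scale(x)


scale = Scale()    # still a plain nn.Module; forward runs in Python
runner = Runner()  # runner.scale has been compiled to a ScriptModule
print(isinstance(runner.scale, torch.jit.ScriptModule))  # True
print(runner(torch.ones(3)))                             # tensor([2., 2., 2.])
```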
Pull Request resolved: pytorch#12682

Differential Revision: D10458626

Pulled By: driazati

fbshipit-source-id: 10ae23cb83cdafc4646cee58f399e14b2e60acd4
David Riazati authored and facebook-github-bot committed Oct 23, 2018
1 parent 3fb3a07 commit af78d4c
Showing 9 changed files with 523 additions and 15 deletions.
9 changes: 9 additions & 0 deletions test/expect/TestScript.test_weak_module-basic.expect
@@ -0,0 +1,9 @@
graph(%x : Dynamic) {
  %1 : int = prim::Constant[value=55]()
  %2 : int = prim::Constant[value=199]()
  %3 : int = prim::Constant[value=1]()
  %4 : int = aten::add(%1, %2)
  %5 : Dynamic = ^python_op_in_weak_module()(%x)
  %6 : Dynamic = aten::add(%5, %4, %3)
  return (%6);
}
18 changes: 18 additions & 0 deletions test/expect/TestScript.test_weak_module-scope_test.expect
@@ -0,0 +1,18 @@
graph(%x : Dynamic) {
  %1 : int = prim::Constant[value=357]()
  %2 : int = prim::Constant[value=55]()
  %3 : int = prim::Constant[value=199]()
  %4 : int = prim::Constant[value=2]()
  %5 : int = prim::Constant[value=1]()
  %y : Dynamic = aten::mul(%x, %4)
  %7 : Dynamic = aten::add(%y, %5, %5)
  %8 : int = aten::add(%2, %3)
  %9 : Dynamic = ^python_op_in_weak_module()(%y)
  %10 : Dynamic = aten::add(%9, %8, %5)
  %11 : Dynamic = aten::add(%7, %10, %5)
  %12 : Dynamic = aten::add(%y, %1, %5)
  %13 : Dynamic = ^python_op_in_strong_module()(%y)
  %14 : Dynamic = aten::add(%12, %13, %5)
  %15 : Dynamic = aten::add(%11, %14, %5)
  return (%15);
}
20 changes: 20 additions & 0 deletions test/expect/TestScript.test_weak_module_nested.expect
@@ -0,0 +1,20 @@
graph(%x : Dynamic
      %1 : Dynamic
      %2 : Dynamic
      %3 : Dynamic
      %4 : Dynamic) {
  %5 : int = prim::Constant[value=1]()
  %6 : int = prim::Constant[value=3]()
  %7 : int = prim::Constant[value=27]()
  %8 : Dynamic = aten::mul(%x, %x)
  %9 : Dynamic = aten::add(%8, %6, %5)
  %10 : Dynamic = aten::linear(%x, %1, %2)
  %11 : Dynamic = aten::add(%9, %10, %5)
  %12 : Dynamic = aten::add(%x, %11, %5)
  %13 : Dynamic = aten::add(%x, %7, %5)
  %14 : Dynamic = aten::add(%12, %13, %5)
  %15 : Dynamic = aten::linear(%x, %3, %4)
  %16 : Dynamic = aten::add(%14, %15, %5)
  %17 : Dynamic = aten::add(%x, %16, %5)
  return (%17);
}
19 changes: 19 additions & 0 deletions test/expect/TestScript.test_weak_module_parameters_and_buffers.expect
@@ -0,0 +1,19 @@
graph(%x : Dynamic
      %1 : Dynamic
      %2 : Dynamic
      %3 : Dynamic
      %4 : Dynamic
      %5 : Dynamic
      %6 : Dynamic) {
  %7 : int = prim::Constant[value=1]()
  %8 : Dynamic = aten::linear(%x, %1, %2)
  %9 : Dynamic = aten::add(%8, %3, %7)
  %10 : Dynamic = aten::add(%x, %9, %7)
  %11 : Dynamic = aten::linear(%x, %1, %2)
  %12 : Dynamic = aten::add(%11, %3, %7)
  %13 : Dynamic = aten::add(%10, %12, %7)
  %14 : Dynamic = aten::linear(%x, %4, %5)
  %15 : Dynamic = aten::add(%14, %6, %7)
  %16 : Dynamic = aten::add(%13, %15, %7)
  return (%16);
}
287 changes: 287 additions & 0 deletions test/test_jit.py
@@ -7442,6 +7442,293 @@ def foo(x):
        traced = torch.jit.trace(foo, (x,))
        self.assertExpectedGraph(traced.graph)

    def test_weak_module(self):

        @torch._jit_internal.weak_module
        class Weak(torch.nn.Module):
            __constants__ = ['number']

            def __init__(self):
                super(Weak, self).__init__()
                self.number = 199

            def python_op_in_weak_module(self, x):
                return x + 123

            @torch._jit_internal.weak_script_method
            def forward(self, x):
                return 55 + self.number + self.python_op_in_weak_module(x)

        class OtherStrong(torch.jit.ScriptModule):
            __constants__ = ['number']

            def __init__(self):
                super(OtherStrong, self).__init__()
                self.number = 357

            def python_op_in_strong_module(self, x):
                return x + 456

            @torch.jit.script_method
            def forward(self, x):
                return x + self.number + self.python_op_in_strong_module(x)

        class Passthrough(torch.jit.ScriptModule):
            def __init__(self):
                super(Passthrough, self).__init__()
                self.weak = Weak()

            @torch.jit.script_method
            def forward(self, x):
                return self.weak(x)

        weak_mod = Weak()
        x = torch.ones(1)
        expected_result = 55 + 199 + (x + 123)

        # Ensure weak mod is running without the JIT by passing the wrong type
        # (i.e. not a tensor)
        weak_mod(2)

        python_result = weak_mod(x)
        strong_mod = Passthrough()
        script_result = strong_mod(x)
        self.assertEqual(python_result, expected_result)
        self.assertEqual(script_result, expected_result)
        self.assertExpectedGraph(strong_mod.graph, "basic")

        class Strong(torch.jit.ScriptModule):
            def __init__(self):
                super(Strong, self).__init__()
                self.weak = Weak()
                self.strong = OtherStrong()

            @torch.jit.script_method
            def forward(self, x):
                y = 2 * x
                return y + 1 + self.weak(y) + self.strong(y)

        strong_mod = Strong()
        strong_mod2 = Strong()
        x = torch.ones(1)
        expected_result = (x * 2) + 1 + (55 + 199 + x * 2 + 123) + (x * 2 + 357 + x * 2 + 456)
        script_result = strong_mod(x)
        script_result2 = strong_mod2(x)
        self.assertEqual(script_result, expected_result)
        self.assertEqual(script_result, script_result2)
        self.assertExpectedGraph(strong_mod.graph, "scope_test")

    def test_weak_module_parameters_and_buffers(self):
        import math
        weights = torch.randn(10, 10)
        bias = torch.randn(10)
        weights2 = torch.randn(10, 10)
        bias2 = torch.randn(10)

        @torch._jit_internal.weak_module
        class TestLinear(torch.nn.Module):
            def __init__(self, in_features, out_features):
                super(TestLinear, self).__init__()
                self.in_features = in_features
                self.out_features = out_features
                self.weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
                self.bias = torch.nn.Parameter(torch.Tensor(out_features))
                self.register_buffer('counter', torch.ones(out_features))
                self.reset_parameters()

            def reset_parameters(self):
                torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
                if self.bias is not None:
                    fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
                    bound = 1 / math.sqrt(fan_in)
                    torch.nn.init.uniform_(self.bias, -bound, bound)

            @torch._jit_internal.weak_script_method
            def forward(self, input):
                return F.linear(input, self.weight, self.bias) + self.counter

        # Initialize a ScriptModule that uses the weak module above multiple times
        class Strong(torch.jit.ScriptModule):
            def __init__(self):
                super(Strong, self).__init__()
                self.fc1 = TestLinear(10, 10)
                self.fc1.weight = torch.nn.Parameter(weights)
                self.fc1.bias = torch.nn.Parameter(bias)
                self.fc2 = TestLinear(10, 10)
                self.fc2.weight = torch.nn.Parameter(weights2)
                self.fc2.bias = torch.nn.Parameter(bias2)

            @torch.jit.script_method
            def forward(self, x):
                return x + self.fc1(x) + self.fc1(x) + self.fc2(x)

        strong_mod = Strong()
        self.assertExpectedGraph(strong_mod.graph)

        # Run same calculation as module
        inp = torch.ones(10)
        lin = torch.nn.Linear(10, 10)
        lin.weight = torch.nn.Parameter(weights)
        lin.bias = torch.nn.Parameter(bias)
        lin2 = torch.nn.Linear(10, 10)
        lin2.weight = torch.nn.Parameter(weights2)
        lin2.bias = torch.nn.Parameter(bias2)
        expected_result = inp + (lin(inp) + torch.ones(10)) * 2 + lin2(inp) + torch.ones(10)

        self.assertEqual(strong_mod(inp), expected_result)

    def test_weak_module_nested(self):
        @torch._jit_internal.weak_module
        class OtherWeak(torch.nn.Module):
            __constants__ = ['constant']

            def __init__(self, in_features, out_features):
                super(OtherWeak, self).__init__()
                self.in_features = in_features
                self.out_features = out_features
                self.weight = torch.nn.Parameter(torch.ones(out_features, in_features))
                self.bias = torch.nn.Parameter(torch.ones(out_features))
                self.constant = 3

            @torch._jit_internal.weak_script_method
            def forward(self, x):
                return x * x + self.constant + F.linear(x, self.weight, self.bias)

        class OtherStrong(torch.jit.ScriptModule):

            def __init__(self):
                super(OtherStrong, self).__init__()

            @torch.jit.script_method
            def forward(self, x):
                return x + 27

        @torch._jit_internal.weak_module
        class Weak(torch.nn.Module):
            def __init__(self, in_features, out_features):
                super(Weak, self).__init__()
                self.in_features = in_features
                self.out_features = out_features
                self.weight = torch.nn.Parameter(2 * torch.ones(out_features, in_features))
                self.bias = torch.nn.Parameter(2 * torch.ones(out_features))
                self.weak_submodule = OtherWeak(10, 10)
                self.strong_submodule = OtherStrong()

            @torch._jit_internal.weak_script_method
            def forward(self, x):
                return x + self.weak_submodule(x) + self.strong_submodule(x) \
                    + F.linear(x, self.weight, self.bias)

        class Strong(torch.jit.ScriptModule):
            __constants__ = ['constant']

            def __init__(self):
                super(Strong, self).__init__()
                self.weak = Weak(10, 10)

            @torch.jit.script_method
            def forward(self, x):
                return x + self.weak(x)

        strong_mod = Strong()
        self.assertExpectedGraph(strong_mod.graph)
        inp = torch.randn(10)
        result = strong_mod(inp)
        expected_result = inp + (inp + inp * inp + inp + 27) + 3 \
            + F.linear(inp, torch.ones(10, 10), torch.ones(10)) \
            + F.linear(inp, 2 * torch.ones(10, 10), 2 * torch.ones(10))
        self.assertEqual(result, expected_result)

    def test_weak_module_submodule(self):
        @torch._jit_internal.weak_module
        class Weak(torch.nn.Module):
            def __init__(self):
                super(Weak, self).__init__()
                self.param = torch.nn.Parameter(100 * torch.ones(5))

            @torch._jit_internal.weak_script_method
            def forward(self, x):
                return x + self.param

        weak = Weak()

        class OtherStrong(torch.jit.ScriptModule):
            def __init__(self):
                super(OtherStrong, self).__init__()
                self.weak = weak
                self.weak2 = weak

            @torch.jit.script_method
            def forward(self, x):
                return x + self.weak(x)

        class Strong(torch.jit.ScriptModule):
            def __init__(self):
                super(Strong, self).__init__()
                self.weak = Weak()

            @torch.jit.script_method
            def forward(self, x):
                return self.weak(x) + weak(x)

        other_strong_mod = OtherStrong()

        self.assertIs(other_strong_mod.weak, other_strong_mod.weak2)

        with self.assertRaisesRegex(RuntimeError, "Attempted to inline a Module with param"):
            strong_mod = Strong()

    def test_weak_module_copying(self):
        class Submodule(torch.nn.Module):
            def __init__(self):
                super(Submodule, self).__init__()

            def forward(self, x):
                return x + 100

        @torch._jit_internal.weak_module
        class Weak(torch.nn.Module):
            def __init__(self, in_features, out_features):
                super(Weak, self).__init__()
                self.weight = torch.nn.Parameter(torch.ones(out_features, in_features))
                self.register_buffer("buffer", torch.ones(out_features))
                self.submodule = Submodule()

            @torch._jit_internal.weak_script_method
            def forward(self, x):
                return F.linear(x, self.weight) + self.buffer + self.submodule(x)

        class Strong(torch.jit.ScriptModule):
            def __init__(self, weak):
                super(Strong, self).__init__()
                self.weak = weak

            @torch.jit.script_method
            def forward(self, x):
                return self.weak(x)

        inp = torch.ones(5, 5) * 5
        weak_mod = Weak(5, 5)
        strong_mod = Strong(weak_mod)

        self.assertTrue(isinstance(strong_mod.weak, torch.jit.ScriptModule))
        self.assertFalse(isinstance(weak_mod, torch.jit.ScriptModule))

        self.assertIs(strong_mod.weak.weight, weak_mod.weight)
        self.assertIs(strong_mod.weak.buffer, weak_mod.buffer)
        self.assertIs(strong_mod.weak.submodule, weak_mod.submodule)

        # Test lookup fallback
        weak_mod.new_attribute = 10
        self.assertIs(strong_mod.weak.new_attribute, weak_mod.new_attribute)

        weak_mod.weight.data += torch.ones(5, 5) * 100
        self.assertTrue(strong_mod(inp).allclose(weak_mod(inp)))

        # Re-assignment is not tracked
        weak_mod.weight = torch.nn.Parameter(torch.ones(5, 5) * 100)
        self.assertFalse(strong_mod(inp).allclose(weak_mod(inp)))


class MnistNet(nn.Module):
    def __init__(self):