Skip to content

Commit

Permalink
[PYTORCH]Reduce_ops support added
Browse files Browse the repository at this point in the history
  • Loading branch information
siju-samuel committed Apr 13, 2020
1 parent 0145cd5 commit 430f3b3
Show file tree
Hide file tree
Showing 2 changed files with 177 additions and 1 deletion.
40 changes: 39 additions & 1 deletion python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -934,7 +934,41 @@ def _impl(inputs, input_types):
def _reduce(name):
def _impl(inputs, input_types):
data = inputs[0]
return get_relay_op(name)(data)
axis = None
keepdims = False
if len(inputs) > 2: # default, torch have only data, axis=None, keepdims=False
if isinstance(inputs[1], int):
axis = int(inputs[1])
else:
axis = list(_infer_shape(inputs[1]))
keepdims = bool(inputs[2])
return get_relay_op(name)(data, axis=axis, keepdims=keepdims)
return _impl

def _std():
def _impl(inputs, input_types):
data = inputs[0]
axis = list(_infer_shape(inputs[1]))
keepdims = bool(inputs[3])
unbiased = bool(inputs[2])
if unbiased:
msg = "Currently only supports standard-deviation calculated via the biased "\
"estimator. Pytorch's Bessel's correction is not supported."
raise NotImplementedError(msg)
return _op.reduce.std(data, axis=axis, keepdims=keepdims)
return _impl

def _variance():
def _impl(inputs, input_types):
data = inputs[0]
axis = list(_infer_shape(inputs[1]))
keepdims = bool(inputs[3])
unbiased = bool(inputs[2])
if unbiased:
msg = "Currently only supports standard-deviation calculated via the biased "\
"estimator. Pytorch's Bessel's correction is not supported."
raise NotImplementedError(msg)
return _op.reduce.variance(data, axis=axis, keepdims=keepdims)
return _impl

def _mean():
Expand Down Expand Up @@ -1381,6 +1415,10 @@ def _get_convert_map(prelude):
"aten::permute" : _transpose(prelude),
"aten::sum" : _reduce("sum"),
"aten::prod" : _reduce("prod"),
"aten::argmin" : _reduce("argmin"),
"aten::argmax" : _reduce("argmax"),
"aten::std" : _std(),
"aten::var" : _variance(),
"aten::sqrt" : _sqrt(),
'aten::floor' : _floor(),
"aten::detach" : _identity(),
Expand Down
138 changes: 138 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1278,6 +1278,138 @@ def forward(self, xs):

verify_script_model(RNNLoop().eval(), [(10, 10, 4)])

def test_forward_reduce_sum():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]

class ReduceSum1(Module):
def forward(self, *args):
return args[0].sum(1)
class ReduceSum2(Module):
def forward(self, *args):
return args[0].sum(dim=1, keepdim=False)
class ReduceSum3(Module):
def forward(self, *args):
return args[0].sum(dim=2, keepdim=True)
class ReduceSum4(Module):
def forward(self, *args):
return args[0].sum(dim=(2,3), keepdim=True)
class ReduceSum5(Module):
def forward(self, *args):
return args[0].sum(dim=(2,3), keepdim=False)
input_data = torch.rand(input_shape).float()
verify_model(ReduceSum1().float().eval(), input_data=input_data)
verify_model(ReduceSum2().float().eval(), input_data=input_data)
verify_model(ReduceSum3().float().eval(), input_data=input_data)
verify_model(ReduceSum4().float().eval(), input_data=input_data)
verify_model(ReduceSum5().float().eval(), input_data=input_data)

def test_forward_reduce_prod():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]

class ReduceProd1(Module):
def forward(self, *args):
return args[0].prod(1)
class ReduceProd2(Module):
def forward(self, *args):
return args[0].prod(dim=1, keepdim=False)
class ReduceProd3(Module):
def forward(self, *args):
return args[0].prod(dim=2, keepdim=True)
input_data = torch.rand(input_shape).float()
verify_model(ReduceProd1().float().eval(), input_data=input_data)
verify_model(ReduceProd2().float().eval(), input_data=input_data)
verify_model(ReduceProd3().float().eval(), input_data=input_data)

def test_forward_argmin():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]

class ArgMin1(Module):
def forward(self, *args):
return args[0].argmin(1)
class ArgMin2(Module):
def forward(self, *args):
return args[0].argmin(dim=1, keepdim=False)
class ArgMin3(Module):
def forward(self, *args):
return args[0].argmin(dim=2, keepdim=True)
input_data = torch.rand(input_shape).float()
verify_model(ArgMin1().float().eval(), input_data=input_data)
verify_model(ArgMin2().float().eval(), input_data=input_data)
verify_model(ArgMin3().float().eval(), input_data=input_data)

def test_forward_argmax():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]

class ArgMax1(Module):
def forward(self, *args):
return args[0].argmax(1)
class ArgMax2(Module):
def forward(self, *args):
return args[0].argmax(dim=1, keepdim=False)
class ArgMax3(Module):
def forward(self, *args):
return args[0].argmax(dim=2, keepdim=True)
input_data = torch.rand(input_shape).float()
verify_model(ArgMax1().float().eval(), input_data=input_data)
verify_model(ArgMax2().float().eval(), input_data=input_data)
verify_model(ArgMax3().float().eval(), input_data=input_data)

def test_forward_std():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]

class Std1(Module):
def forward(self, *args):
return args[0].std(1, unbiased=False)
class Std2(Module):
def forward(self, *args):
return args[0].std(dim=1, keepdim=False, unbiased=False)
class Std3(Module):
def forward(self, *args):
return args[0].std(dim=2, keepdim=True, unbiased=False)
class Std4(Module):
def forward(self, *args):
return args[0].std(dim=(2,3), keepdim=True, unbiased=False)
class Std5(Module):
def forward(self, *args):
return args[0].std(dim=(2,3), keepdim=False, unbiased=False)
input_data = torch.rand(input_shape).float()
verify_model(Std1().float().eval(), input_data=input_data)
verify_model(Std2().float().eval(), input_data=input_data)
verify_model(Std3().float().eval(), input_data=input_data)
verify_model(Std4().float().eval(), input_data=input_data)
verify_model(Std5().float().eval(), input_data=input_data)

def test_forward_var():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]

class Var1(Module):
def forward(self, *args):
return args[0].var(1, unbiased=False)
class Var2(Module):
def forward(self, *args):
return args[0].var(dim=1, keepdim=False, unbiased=False)
class Var3(Module):
def forward(self, *args):
return args[0].var(dim=2, keepdim=True, unbiased=False)
class Var4(Module):
def forward(self, *args):
return args[0].var(dim=(2,3), keepdim=True, unbiased=False)
class Var5(Module):
def forward(self, *args):
return args[0].var(dim=(2,3), keepdim=False, unbiased=False)
input_data = torch.rand(input_shape).float()
verify_model(Var1().float().eval(), input_data=input_data)
verify_model(Var2().float().eval(), input_data=input_data)
verify_model(Var3().float().eval(), input_data=input_data)
verify_model(Var4().float().eval(), input_data=input_data)
verify_model(Var5().float().eval(), input_data=input_data)


if __name__ == "__main__":
# Single operator tests
Expand All @@ -1291,6 +1423,12 @@ def forward(self, xs):
test_forward_squeeze()
test_forward_unsqueeze()
test_forward_concatenate()
test_forward_reduce_sum()
test_forward_reduce_prod()
test_forward_argmin()
test_forward_argmax()
test_forward_std()
test_forward_var()
test_forward_relu()
test_forward_prelu()
test_forward_leakyrelu()
Expand Down

0 comments on commit 430f3b3

Please sign in to comment.