Skip to content

Commit

Permalink
Don't multiply by constant 1 uselessly in dense (#5911)
Browse files Browse the repository at this point in the history
  • Loading branch information
t-vi authored Jun 24, 2020
1 parent 4c78c03 commit 11815b8
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 2 deletions.
4 changes: 2 additions & 2 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -995,11 +995,11 @@ def _impl(inputs, input_types):
beta = inputs[3]
alpha = inputs[4]

if not isinstance(alpha, _expr.Expr):
if not isinstance(alpha, _expr.Expr) and alpha != 1:
alpha = _create_typed_const(alpha, data_type)
data *= alpha

if not isinstance(beta, _expr.Expr):
if not isinstance(beta, _expr.Expr) and beta != 1:
beta = _create_typed_const(beta, data_type)
weight *= beta

Expand Down
19 changes: 19 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,18 @@

sys.setrecursionlimit(10000)

def list_ops(expr):
class OpLister(tvm.relay.ExprVisitor):
def visit_op(self, expr):
if expr not in self.node_set:
self.node_list.append(expr)
return super().visit_op(expr)
def list_nodes(self, expr):
self.node_set = {}
self.node_list = []
self.visit(expr)
return self.node_list
return OpLister().list_nodes(expr)

def assert_shapes_match(tru, est):
if tru.shape != est.shape:
Expand Down Expand Up @@ -1047,6 +1059,13 @@ def forward(self, *args):
verify_model(Dense1().float().eval(), input_data=input_data)
verify_model(Dense2().float().eval(), input_data=input_data)

trace = torch.jit.trace(Dense1(), [input_data])
mod, params = relay.frontend.from_pytorch(
trace,
[('input', input_shape)],
)
assert not any([op.name == "multiply" for op in list_ops(mod['main'])])

def test_forward_dropout():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]
Expand Down

0 comments on commit 11815b8

Please sign in to comment.