diff --git a/python/tvm/relay/testing/__init__.py b/python/tvm/relay/testing/__init__.py index 8310a0202c17b..404bb80120232 100644 --- a/python/tvm/relay/testing/__init__.py +++ b/python/tvm/relay/testing/__init__.py @@ -17,6 +17,7 @@ #pylint: disable=invalid-name """Utilities for testing and benchmarks""" from __future__ import absolute_import as _abs +import collections import numpy as np import tvm @@ -135,3 +136,18 @@ def check_grad(func, inputs=None, eps=1e-6, atol=1e-5, rtol=1e-3, scale=None, me def rand(dtype, *shape): return tvm.nd.array(np.random.rand(*shape).astype(dtype)) + + +def count_ops(expr): + """count number of times a given op is called in the graph""" + class OpCounter(tvm.relay.ExprVisitor): + def visit_call(self, call): + if hasattr(call, 'op'): + self.node_counter[call.op.name] += 1 + return super().visit_call(call) + def count(self, expr): + self.node_set = {} + self.node_counter = collections.Counter() + self.visit(expr) + return self.node_counter + return OpCounter().count(expr) diff --git a/tests/python/relay/test_pass_gradient.py b/tests/python/relay/test_pass_gradient.py index 66743c65ca98b..4838c6a4e7fce 100644 --- a/tests/python/relay/test_pass_gradient.py +++ b/tests/python/relay/test_pass_gradient.py @@ -25,24 +25,10 @@ from tvm.relay import create_executor, transform from tvm.relay.transform import gradient from tvm.relay.prelude import Prelude -from tvm.relay.testing import add_nat_definitions, make_nat_expr, run_infer_type, check_grad, rand +from tvm.relay.testing import add_nat_definitions, make_nat_expr, run_infer_type, check_grad, rand, count_ops import tvm.relay.op as op -def count_ops(expr): - class OpCounter(tvm.relay.ExprVisitor): - def visit_call(self, expr): - if hasattr(expr, 'op'): - self.node_counter[expr.op.name] += 1 - return super().visit_call(expr) - def count(self, expr): - self.node_set = {} - self.node_counter = collections.Counter() - self.visit(expr) - return self.node_counter - return OpCounter().count(expr) - - def test_id(): shape = (10, 10) dtype = 'float32'