Skip to content

Commit

Permalink
move count_ops to common tvm.relay.testing
Browse files Browse the repository at this point in the history
  • Loading branch information
t-vi committed Jun 29, 2020
1 parent 0fbedf5 commit e344cfb
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 15 deletions.
16 changes: 16 additions & 0 deletions python/tvm/relay/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#pylint: disable=invalid-name
"""Utilities for testing and benchmarks"""
from __future__ import absolute_import as _abs
import collections
import numpy as np

import tvm
Expand Down Expand Up @@ -135,3 +136,18 @@ def check_grad(func, inputs=None, eps=1e-6, atol=1e-5, rtol=1e-3, scale=None, me

def rand(dtype, *shape):
return tvm.nd.array(np.random.rand(*shape).astype(dtype))


def count_ops(expr):
"""count number of times a given op is called in the graph"""
class OpCounter(tvm.relay.ExprVisitor):
def visit_call(self, call):
if hasattr(call, 'op'):
self.node_counter[call.op.name] += 1
return super().visit_call(call)
def count(self, expr):
self.node_set = {}
self.node_counter = collections.Counter()
self.visit(expr)
return self.node_counter
return OpCounter().count(expr)
16 changes: 1 addition & 15 deletions tests/python/relay/test_pass_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,24 +25,10 @@
from tvm.relay import create_executor, transform
from tvm.relay.transform import gradient
from tvm.relay.prelude import Prelude
from tvm.relay.testing import add_nat_definitions, make_nat_expr, run_infer_type, check_grad, rand
from tvm.relay.testing import add_nat_definitions, make_nat_expr, run_infer_type, check_grad, rand, count_ops
import tvm.relay.op as op


def count_ops(expr):
class OpCounter(tvm.relay.ExprVisitor):
def visit_call(self, expr):
if hasattr(expr, 'op'):
self.node_counter[expr.op.name] += 1
return super().visit_call(expr)
def count(self, expr):
self.node_set = {}
self.node_counter = collections.Counter()
self.visit(expr)
return self.node_counter
return OpCounter().count(expr)


def test_id():
shape = (10, 10)
dtype = 'float32'
Expand Down

0 comments on commit e344cfb

Please sign in to comment.