Skip to content

Commit

Permalink
Clean up duplicated infer_type and infer_shape helpers in the frontends
Browse files Browse the repository at this point in the history
  • Loading branch information
Valery Chernov committed Aug 10, 2021
1 parent b7488ef commit f6ffa3c
Show file tree
Hide file tree
Showing 11 changed files with 175 additions and 192 deletions.
14 changes: 7 additions & 7 deletions python/tvm/relay/frontend/caffe.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from .. import op as _op
from ... import nd as _nd
from .common import ExprTable
from .common import infer_shape as _infer_shape
from .common import infer_shape

__all__ = ["from_caffe"]

Expand Down Expand Up @@ -84,8 +84,8 @@ def convert_eltwise(self, op):
lhs_expr = self.exp_tab.get_expr(inputs[0])
rhs_expr = self.exp_tab.get_expr(inputs[1])

lhs_shape = _infer_shape(lhs_expr)
rhs_shape = _infer_shape(rhs_expr)
lhs_shape = infer_shape(lhs_expr)
rhs_shape = infer_shape(rhs_expr)

assert lhs_shape == rhs_shape, "input tensors shape should be equal"

Expand Down Expand Up @@ -163,7 +163,7 @@ def convert_batch_norm(self, op):
"""Convert BatchNorm layer"""
inputs = op.bottom
in_expr = self.exp_tab.get_expr(inputs[0])
n, c, h, w = _infer_shape(in_expr)
n, c, h, w = infer_shape(in_expr)

if op.name in self.new_bn:
mean, var, eps, gamma, beta = self.new_bn[op.name]
Expand Down Expand Up @@ -234,7 +234,7 @@ def convert_scale(self, op):
np.zeros(gamma.shape, dtype=np.float32), dtype="float32"
)

_, c, _, _ = _infer_shape(in_expr)
_, c, _, _ = infer_shape(in_expr)
gamma_expr = _op.reshape(gamma_expr, newshape=(1, c, 1, 1))
beta_expr = _op.reshape(beta_expr, newshape=(1, c, 1, 1))
out = _op.multiply(in_expr, gamma_expr)
Expand Down Expand Up @@ -262,7 +262,7 @@ def convert_reshape(self, op):
dims = list(reshape_param.shape.dim)

in_expr = self.exp_tab.get_expr(input_name)
input_shape = list(_infer_shape(in_expr))
input_shape = list(infer_shape(in_expr))

start_axis = int(reshape_param.axis)
if start_axis < 0:
Expand Down Expand Up @@ -571,7 +571,7 @@ def convert_crop(self, op):
offset = list(getattr(crop_params, "offset", 0))

# expand offset to (offset1, offset2, ...)
in_a_shape = _infer_shape(in_expr_a)
in_a_shape = infer_shape(in_expr_a)
num_to_crop = len(in_a_shape) - axis
if not offset:
offset = [0] * num_to_crop
Expand Down
15 changes: 11 additions & 4 deletions python/tvm/relay/frontend/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,9 @@ def get_name(node):

def infer_type(node, mod=None):
"""A method to infer the type of an intermediate node in the relay graph."""
if isinstance(node, tvm.relay.Var):
return node.type_annotation

if isinstance(mod, IRModule):
mod["main"] = _function.Function(tvm.relay.analysis.free_vars(node), node)
mod = _transform.InferType()(mod)
Expand All @@ -484,11 +487,16 @@ def infer_type(node, mod=None):
if mod is not None:
new_mod.update(mod)

new_mod = _transform.RemoveUnusedFunctions()(new_mod)
new_mod = _transform.InferType()(new_mod)
entry = new_mod["main"]
ret = entry if isinstance(node, _function.Function) else entry.body

return ret
return ret.checked_type


def infer_type_with_prelude(val, prelude):
    """Infer the type of *val* using the module attached to a Prelude.

    Thin convenience wrapper around ``infer_type`` that passes
    ``prelude.mod`` as the module, so type inference can resolve
    references to prelude-defined functions and ADTs.

    Parameters
    ----------
    val : tvm.relay.Expr
        The relay expression whose type should be inferred.
        # NOTE(review): presumably any relay expression accepted by
        # infer_type — confirm against callers.
    prelude : tvm.relay.prelude.Prelude
        Prelude whose ``mod`` (IRModule) provides the inference context.

    Returns
    -------
    The inferred type, as returned by ``infer_type``.
    """
    return infer_type(val, prelude.mod)


def fold_constant(node, mod=None):
Expand All @@ -502,15 +510,14 @@ def infer_channels(inputs, transpose=False):
these attributes. We check the shape of weights provided to get the number.
"""
out_type = infer_type(inputs)
out_shapes = [get_const_tuple(out_type.checked_type.shape)]
out_shapes = [get_const_tuple(out_type.shape)]
channels = out_shapes[0][0] if not transpose else out_shapes[0][1]
return channels


def infer_shape(inputs, mod=None):
"""A method to get the output type of an intermediate node in the graph."""
out_type = infer_type(inputs, mod=mod)
checked_type = out_type.checked_type
checked_type = infer_type(inputs, mod=mod)
if hasattr(checked_type, "shape"):
# Regular operator that outputs tensors
return get_const_tuple(checked_type.shape)
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/relay/frontend/coreml.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from ... import nd as _nd
from ..._ffi import base as _base
from .common import ExprTable
from .common import infer_shape as _infer_shape
from .common import infer_shape

__all__ = ["from_coreml"]

Expand Down Expand Up @@ -67,7 +67,7 @@ def _ConvolutionLayerParams(op, inexpr, etab):
dilation = list(op.dilationFactor)
if not dilation:
dilation = [1, 1]
N, C, H, W = _infer_shape(inexpr)
N, C, H, W = infer_shape(inexpr)
params = {
"channels": op.outputChannels,
"kernel_size": list(op.kernelSize),
Expand Down
Loading

0 comments on commit f6ffa3c

Please sign in to comment.