Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Frontend] clean duplicates of infer_type and infer_shape in frontends #8709

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions python/tvm/relay/frontend/caffe.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from .. import op as _op
from ... import nd as _nd
from .common import ExprTable
from .common import infer_shape as _infer_shape
from .common import infer_shape

__all__ = ["from_caffe"]

Expand Down Expand Up @@ -84,8 +84,8 @@ def convert_eltwise(self, op):
lhs_expr = self.exp_tab.get_expr(inputs[0])
rhs_expr = self.exp_tab.get_expr(inputs[1])

lhs_shape = _infer_shape(lhs_expr)
rhs_shape = _infer_shape(rhs_expr)
lhs_shape = infer_shape(lhs_expr)
rhs_shape = infer_shape(rhs_expr)

assert lhs_shape == rhs_shape, "input tensors shape should be equal"

Expand Down Expand Up @@ -163,7 +163,7 @@ def convert_batch_norm(self, op):
"""Convert BatchNorm layer"""
inputs = op.bottom
in_expr = self.exp_tab.get_expr(inputs[0])
n, c, h, w = _infer_shape(in_expr)
n, c, h, w = infer_shape(in_expr)

if op.name in self.new_bn:
mean, var, eps, gamma, beta = self.new_bn[op.name]
Expand Down Expand Up @@ -234,7 +234,7 @@ def convert_scale(self, op):
np.zeros(gamma.shape, dtype=np.float32), dtype="float32"
)

_, c, _, _ = _infer_shape(in_expr)
_, c, _, _ = infer_shape(in_expr)
gamma_expr = _op.reshape(gamma_expr, newshape=(1, c, 1, 1))
beta_expr = _op.reshape(beta_expr, newshape=(1, c, 1, 1))
out = _op.multiply(in_expr, gamma_expr)
Expand Down Expand Up @@ -262,7 +262,7 @@ def convert_reshape(self, op):
dims = list(reshape_param.shape.dim)

in_expr = self.exp_tab.get_expr(input_name)
input_shape = list(_infer_shape(in_expr))
input_shape = list(infer_shape(in_expr))

start_axis = int(reshape_param.axis)
if start_axis < 0:
Expand Down Expand Up @@ -571,7 +571,7 @@ def convert_crop(self, op):
offset = list(getattr(crop_params, "offset", 0))

# expand offset to (offset1, offset2, ...)
in_a_shape = _infer_shape(in_expr_a)
in_a_shape = infer_shape(in_expr_a)
num_to_crop = len(in_a_shape) - axis
if not offset:
offset = [0] * num_to_crop
Expand Down
15 changes: 11 additions & 4 deletions python/tvm/relay/frontend/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,9 @@ def get_name(node):

def infer_type(node, mod=None):
"""A method to infer the type of an intermediate node in the relay graph."""
if isinstance(node, tvm.relay.Var):
return node.type_annotation

if isinstance(mod, IRModule):
mod["main"] = _function.Function(tvm.relay.analysis.free_vars(node), node)
mod = _transform.InferType()(mod)
Expand All @@ -484,11 +487,16 @@ def infer_type(node, mod=None):
if mod is not None:
new_mod.update(mod)

new_mod = _transform.RemoveUnusedFunctions()(new_mod)
new_mod = _transform.InferType()(new_mod)
entry = new_mod["main"]
ret = entry if isinstance(node, _function.Function) else entry.body

return ret
return ret.checked_type


def infer_type_with_prelude(val, prelude):
return infer_type(val, prelude.mod)


def fold_constant(node, mod=None):
Expand All @@ -502,15 +510,14 @@ def infer_channels(inputs, transpose=False):
these attributes. We check the shape of weights provided to get the number.
"""
out_type = infer_type(inputs)
out_shapes = [get_const_tuple(out_type.checked_type.shape)]
out_shapes = [get_const_tuple(out_type.shape)]
channels = out_shapes[0][0] if not transpose else out_shapes[0][1]
return channels


def infer_shape(inputs, mod=None):
"""A method to get the output type of an intermediate node in the graph."""
out_type = infer_type(inputs, mod=mod)
checked_type = out_type.checked_type
checked_type = infer_type(inputs, mod=mod)
if hasattr(checked_type, "shape"):
# Regular operator that outputs tensors
return get_const_tuple(checked_type.shape)
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/relay/frontend/coreml.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from ... import nd as _nd
from ..._ffi import base as _base
from .common import ExprTable
from .common import infer_shape as _infer_shape
from .common import infer_shape

__all__ = ["from_coreml"]

Expand Down Expand Up @@ -67,7 +67,7 @@ def _ConvolutionLayerParams(op, inexpr, etab):
dilation = list(op.dilationFactor)
if not dilation:
dilation = [1, 1]
N, C, H, W = _infer_shape(inexpr)
N, C, H, W = infer_shape(inexpr)
params = {
"channels": op.outputChannels,
"kernel_size": list(op.kernelSize),
Expand Down
Loading