diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 93bd8efc67528..dd63042b24540 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -430,6 +430,20 @@ def _mx_roi_align(inputs, attrs): new_attrs["layout"] = "NCHW" return _op.vision.roi_align(inputs[0], inputs[1], **new_attrs) +def _mx_upsampling(inputs, attrs): + scale_height = attrs.get_float("scale_height", None) + scale_width = attrs.get_float("scale_width", None) + if scale_height == None: + height = attrs.get_int("height", 1) + scale_height = float(height) // inputs[0].shape[2] + if scale_width == None: + width = attrs.get_int("width", 1) + scale_width = float(width) // inputs[0].shape[3] + assert scale_height == scale_width + scale = scale_width + layout = 'NCHW' + method = 'BILINEAR' + return _op.nn.upsampling(inputs[0], scale=scale, layout=layout, method=method) def _mx_proposal(inputs, attrs): new_attrs = {} @@ -616,6 +630,7 @@ def _mx_l2_normalize(inputs, attrs): "SoftmaxOutput" : _mx_softmax_output, "SoftmaxActivation" : _mx_softmax_activation, # vision + "_contrib_BilinearResize2D" : _mx_upsampling, "_contrib_MultiBoxPrior" : _mx_multibox_prior, "_contrib_MultiBoxDetection" : _mx_multibox_detection, "_contrib_ROIAlign" : _mx_roi_align,