diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py index 2be03c80c20b..f6f2d99e2ea5 100644 --- a/python/tvm/relay/frontend/keras.py +++ b/python/tvm/relay/frontend/keras.py @@ -334,6 +334,14 @@ def _convert_upsample(inexpr, keras_layer, _): raise TypeError("Unsupported upsampling type with different axes size : {}" .format(keras_layer.size)) params = {'scale': h} + + if hasattr(keras_layer, 'interpolation'): + interpolation = keras_layer.interpolation + if interpolation == 'nearest': + params['method'] = 'NEAREST_NEIGHBOR' + else: + params['method'] = 'BILINEAR' + elif upsample_type == 'UpSampling3D': h, w, d = keras_layer.size if h != w or w != d: diff --git a/tests/python/frontend/keras/test_forward.py b/tests/python/frontend/keras/test_forward.py index 576c69a35523..baa2e4fc203f 100644 --- a/tests/python/frontend/keras/test_forward.py +++ b/tests/python/frontend/keras/test_forward.py @@ -133,9 +133,9 @@ def test_forward_conv(): verify_keras_frontend(keras_model) -def test_forward_upsample(): +def test_forward_upsample(interpolation='nearest'): data = keras.layers.Input(shape=(32,32,3)) - x = keras.layers.UpSampling2D(size=(3,3))(data) + x = keras.layers.UpSampling2D(size=(3,3), interpolation=interpolation)(data) keras_model = keras.models.Model(data, x) verify_keras_frontend(keras_model) @@ -246,7 +246,8 @@ def test_forward_mobilenet(): test_forward_dense() test_forward_pool() test_forward_conv() - test_forward_upsample() + test_forward_upsample(interpolation='nearest') + test_forward_upsample(interpolation='bilinear') test_forward_reshape() test_forward_crop() test_forward_multi_inputs()