Skip to content

Commit

Permalink
[Relay][Frontend][keras] added interpolation method of Upsampling2D (a…
Browse files Browse the repository at this point in the history
…pache#2854)

* [Relay][Frontend][keras] added interpolation method of Upsampling2D.

* added testcase

* small fixes
  • Loading branch information
lhelontra authored and wweic committed Mar 20, 2019
1 parent b3a7b02 commit fc2f444
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 3 deletions.
8 changes: 8 additions & 0 deletions python/tvm/relay/frontend/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,14 @@ def _convert_upsample(inexpr, keras_layer, _):
raise TypeError("Unsupported upsampling type with different axes size : {}"
.format(keras_layer.size))
params = {'scale': h}

if hasattr(keras_layer, 'interpolation'):
interpolation = keras_layer.interpolation
if interpolation == 'nearest':
params['method'] = 'NEAREST_NEIGHBOR'
else:
params['method'] = 'BILINEAR'

elif upsample_type == 'UpSampling3D':
h, w, d = keras_layer.size
if h != w or w != d:
Expand Down
7 changes: 4 additions & 3 deletions tests/python/frontend/keras/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,9 @@ def test_forward_conv():
verify_keras_frontend(keras_model)


def test_forward_upsample():
def test_forward_upsample(interpolation='nearest'):
data = keras.layers.Input(shape=(32,32,3))
x = keras.layers.UpSampling2D(size=(3,3))(data)
x = keras.layers.UpSampling2D(size=(3,3), interpolation=interpolation)(data)
keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model)

Expand Down Expand Up @@ -246,7 +246,8 @@ def test_forward_mobilenet():
test_forward_dense()
test_forward_pool()
test_forward_conv()
test_forward_upsample()
test_forward_upsample(interpolation='nearest')
test_forward_upsample(interpolation='bilinear')
test_forward_reshape()
test_forward_crop()
test_forward_multi_inputs()
Expand Down

0 comments on commit fc2f444

Please sign in to comment.