Skip to content

Commit

Permalink
[Relay][Frontend][darknet] Solve tvm parsing darknet resnext failure …
Browse files Browse the repository at this point in the history
…bug (apache#3778)

* test_darkent_bug

* test_darkent

* add resnext tests
  • Loading branch information
youluexx authored and MarisaKirisame committed Sep 4, 2019
1 parent 18077a5 commit 21deb82
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 3 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -231,4 +231,4 @@ conda/pkg

# antlr files
*.tokens
*.interp
*.interp
4 changes: 2 additions & 2 deletions python/tvm/relay/frontend/darknet.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,11 +458,11 @@ def _get_convolution_weights(self, layer, opname):
if layer.nweights == 0:
return None

if (layer.n * layer.c * layer.size * layer.size) != layer.nweights:
if (layer.n * layer.c // layer.groups * layer.size * layer.size) != layer.nweights:
raise RuntimeError("layer weights size not matching with n c h w")

params = {}
shape = (layer.n, layer.c, layer.size, layer.size)
shape = (layer.n, layer.c // layer.groups, layer.size, layer.size)
weights = self._read_memory_buffer(shape, layer.weights)

biases = self._read_memory_buffer((layer.n, ), layer.biases)
Expand Down
13 changes: 13 additions & 0 deletions tests/python/frontend/darknet/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,18 @@ def test_forward_resnet50():
verify_darknet_frontend(net)
LIB.free_network(net)

def test_forward_resnext50():
'''test resnet50 model'''
model_name = 'resnext50'
cfg_name = model_name + '.cfg'
weights_name = model_name + '.weights'
cfg_url = 'https://github.com/pjreddie/darknet/blob/master/cfg/' + cfg_name + '?raw=true'
weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true'
net = _load_net(cfg_url, cfg_name, weights_url, weights_name)
verify_darknet_frontend(net)
LIB.free_network(net)


def test_forward_yolov2():
'''test yolov2 model'''
model_name = 'yolov2'
Expand Down Expand Up @@ -441,6 +453,7 @@ def test_forward_rnn():

if __name__ == '__main__':
test_forward_resnet50()
test_forward_resnext50()
test_forward_alexnet()
test_forward_extraction()
test_forward_yolov2()
Expand Down

0 comments on commit 21deb82

Please sign in to comment.