Skip to content

Commit

Permalink
Pre-trained model tutorial fixes. (apache#6453)
Browse files Browse the repository at this point in the history
Before the change on running the tutorial for the first time: "UserWarning: Data provided by label_shapes don't match names specified by label_names ([] vs. ['softmax_label'])". It also showed probability of >>1 due to incorrect usage of np.argsort().
  • Loading branch information
pracheer authored and piiswrong committed May 26, 2017
1 parent 7b15c56 commit 9fb9868
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
2 changes: 1 addition & 1 deletion docs/tutorials/basic/ndarray.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ to `numpy.ndarray`. Like the corresponding NumPy data structure, MXNet's
So you might wonder, why not just use NumPy? MXNet offers two compelling
advantages. First, MXNet's `NDArray` supports fast execution on a wide range of
hardware configurations, including CPU, GPU, and multi-GPU machines. _MXNet_
also scales to distribute systems in the cloud. Second, MXNet's NDArray
also scales to distributed systems in the cloud. Second, MXNet's `NDArray`
executes code lazily, allowing it to automatically parallelize multiple
operations across the available hardware.

Expand Down
11 changes: 6 additions & 5 deletions docs/tutorials/python/predict_image.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@ occurances of `mx.cpu()` with `mx.gpu()` to accelerate the computation.

```python
sym, arg_params, aux_params = mx.model.load_checkpoint('resnet-152', 0)
mod = mx.mod.Module(symbol=sym, context=mx.cpu())
mod.bind(for_training=False, data_shapes=[('data', (1,3,224,224))])
mod.set_params(arg_params, aux_params)
mod = mx.mod.Module(symbol=sym, context=mx.cpu(), label_names=None)
mod.bind(for_training=False, data_shapes=[('data', (1,3,224,224))],
label_shapes=mod._label_shapes)
mod.set_params(arg_params, aux_params, allow_missing=True)
with open('synset.txt', 'r') as f:
labels = [l.rstrip() for l in f]
```
Expand Down Expand Up @@ -68,8 +69,8 @@ def predict(url):
prob = mod.get_outputs()[0].asnumpy()
# print the top-5
prob = np.squeeze(prob)
prob = np.argsort(prob)[::-1]
for i in prob[0:5]:
a = np.argsort(prob)[::-1]
for i in a[0:5]:
print('probability=%f, class=%s' %(prob[i], labels[i]))
```

Expand Down

0 comments on commit 9fb9868

Please sign in to comment.