Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
update code to use gluon model zoo; update doc
Browse files Browse the repository at this point in the history
  • Loading branch information
mseth10 committed May 29, 2020
1 parent 709d372 commit 568b35d
Showing 1 changed file with 37 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,32 +17,29 @@

# Image Classication using pretrained ResNet-50 model on Jetson module

This tutorial shows how to install latest MXNet v1.6 with Jetson support and use it to deploy a pre-trained MXNet model for image classification on a Jetson module.
This tutorial shows how to install MXNet v1.6 with Jetson support and use it to deploy a pre-trained MXNet model for image classification on a Jetson module.

## What's in this tutorial?

This tutorial shows how to:

1. Install MXNet v1.6 with Jetson support along with its dependencies
1. Install MXNet v1.6 along with its dependencies on a Jetson module (This tutorial has been tested on Jetson Xavier AGX and Jetson Nano modules)

2. Deploy a pre-trained MXNet model for image classifcation on a Jetson module
2. Deploy a pre-trained MXNet model for image classifcation on the module

### Who's this tutorial for?
## Who's this tutorial for?

This tutorial would benefit developers working on any Jetson module implementing a deep learning application. It assumes that readers have a Jetson module setup, are familiar with the Jetson working environment and are somewhat familiar with deep learning using MXNet.
This tutorial would benefit developers working on Jetson modules implementing deep learning applications. It assumes that readers have a Jetson module setup with Jetpack installed, are familiar with the Jetson working environment and are somewhat familiar with deep learning using MXNet.

### How to use this tutorial?

To follow this tutorial, you need to setup a [Jetson module](https://developer.nvidia.com/embedded/develop/hardware) and install latest [Jetpack 4.4](https://docs.nvidia.com/jetson/jetpack/release-notes/) using NVIDIA [SDK manager](https://developer.nvidia.com/nvidia-sdk-manager).
## Prerequisites

All instructions described in this tutorial can be executed on the any Jetson module directly or via SSH.
To complete this tutorial, you need:

## Prerequisites
* A [Jetson module](https://developer.nvidia.com/embedded/develop/hardware) setup with [Jetpack 4.4](https://docs.nvidia.com/jetson/jetpack/release-notes/) installed using NVIDIA [SDK Manager](https://developer.nvidia.com/nvidia-sdk-manager)

To complete this tutorial, you will need:
* An SSH connection to the module OR display and keyboard setup to directly open shell on the module

* A Jetson module with Jetpack 4.4 installed
* [Swapfile](https://help.ubuntu.com/community/SwapFaq) installed (in case of Jetson Nano) for additional memory
* [Swapfile](https://help.ubuntu.com/community/SwapFaq) installed, especially on Jetson Nano for additional memory (increase memory if the inference script terminates with a `Killed` message)

## Installing MXNet v1.6 with Jetson support

Expand All @@ -69,62 +66,44 @@ And we are done. You can test the installation now by importing mxnet from pytho
We are now ready to run a pre-trained model and run inference on a Jetson module. In this tutorial we are using ResNet-50 model trained on Imagenet dataset. We run the following classification script with either cpu/gpu context using python3.

```python
from mxnet.gluon import nn
from mxnet import gluon
import mxnet as mx
import numpy as np
import urllib.request
import cv2

# set context
ctx = mx.gpu()
dtype = 'float32'
bsize = 1

# download model files
path = 'http://data.mxnet.io/models/imagenet/'
symbol,_ = urllib.request.urlretrieve(path+'resnet/50-layers/resnet-50-symbol.json')
params,_ = urllib.request.urlretrieve(path+'resnet/50-layers/resnet-50-0000.params')
label_file,_ = urllib.request.urlretrieve(path+'synset.txt')

# load model
input_names = ['data', 'softmax_label']
net = nn.SymbolBlock.imports(symbol, input_names, params, ctx)
net.cast(dtype)

# load pre-trained model
net = gluon.model_zoo.vision.resnet50_v1(pretrained=True, ctx=ctx)
net.hybridize(static_alloc=True, static_shape=True)

# load labels
with open(label_file, 'r') as f:
lbl_path = gluon.utils.download('http://data.mxnet.io/models/imagenet/synset.txt')
with open(lbl_path, 'r') as f:
labels = [l.rstrip() for l in f]

# load image
img_file,_ = urllib.request.urlretrieve('https://github.com/dmlc/web-data/blob/master/mxnet/doc/tutorials/python/predict_image/cat.jpg?raw=true')
img = cv2.imread(img_file)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (224, 224,))
img = np.swapaxes(img, 0, 2)
img = np.swapaxes(img, 1, 2)

# format input
batch = mx.nd.zeros((bsize,) + img.shape)
for i in range(bsize):
batch[i] = img
inputs = batch.astype(dtype)
mx_img = [mx.nd.array(inputs,ctx), mx.nd.zeros((bsize,),ctx)]

# infer
results = net(*mx_img)
prob = results[0].asnumpy()
prob = np.squeeze(prob)
a = np.argsort(prob)[::-1]
for i in a[0:5]:
print('probability=%f, class=%s' %(prob[i], labels[i]))
# download and format image as (batch, RGB, width, height)
img_path = gluon.utils.download('https://github.com/dmlc/web-data/blob/master/mxnet/doc/tutorials/python/predict_image/cat.jpg?raw=true')
img = mx.image.imread(img_path)
img = mx.image.imresize(img, 224, 224) # resize
img = mx.image.color_normalize(img.astype(dtype='float32')/255,
mean=mx.nd.array([0.485, 0.456, 0.406]),
std=mx.nd.array([0.229, 0.224, 0.225])) # normalize
img = img.transpose((2, 0, 1)) # channel first
img = img.expand_dims(axis=0) # batchify
img = img.as_in_context(ctx)

prob = net(img).softmax() # predict and normalize output
idx = prob.topk(k=5)[0] # get top 5 result
for i in idx:
i = int(i.asscalar())
print('With prob = %.5f, it contains %s' % (prob[0,i].asscalar(), labels[i]))
```

After running the above script, you should get the following output showing the five classes that the image most relates to with probability:
```bash
probability=0.418679, class=n02119789 kit fox, Vulpes macrotis
probability=0.293494, class=n02119022 red fox, Vulpes vulpes
probability=0.029321, class=n02120505 grey fox, gray fox, Urocyon cinereoargenteus
probability=0.026230, class=n02124075 Egyptian cat
probability=0.022557, class=n02085620 Chihuahua
With prob = 0.41940, it contains n02119789 kit fox, Vulpes macrotis
With prob = 0.28096, it contains n02119022 red fox, Vulpes vulpes
With prob = 0.06857, it contains n02124075 Egyptian cat
With prob = 0.03046, it contains n02120505 grey fox, gray fox, Urocyon cinereoargenteus
With prob = 0.02770, it contains n02441942 weasel
```

0 comments on commit 568b35d

Please sign in to comment.