Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Predict from pretrained example #171

Merged
merged 3 commits into from
Sep 27, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dmlc-core
9 changes: 8 additions & 1 deletion example/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@ MXNet Examples
==============
This folder contains examples of MXNet.

Notebooks
--------
* [composite symbol](composite_symbol.ipynb) gives you a demo of how to compose a symbolic Inception-BatchNorm Network
* [cifar-10 recipe](cifar-recipe.ipynb) gives you a step by step demo of how to use MXNet
* [cifar-100](cifar-100.ipynb) gives you a demo of how to train a 75.68% accuracy CIFAR-100 model
* [predict with pretrained model](predict-with-pretrained-model.ipynb) gives you a demo of how to use a pretrained Inception-BN Network


Contents
--------
* [mnist](mnist) gives examples on training mnist.
Expand All @@ -13,4 +21,3 @@ Python Howto
[Python Howto](python-howto) is a folder containing short snippet of code
introducing a certain feature of mxnet.


7 changes: 5 additions & 2 deletions example/notebooks/cifar-recipe.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -410,9 +410,12 @@
"fea_symbol = internals[\"global_avg_output\"]\n",
"\n",
"# Make a new model by using an internal symbol. We can reuse all parameters from model we trained before\n",
"# In this case, we must set ```allow_extra_params``` to True\n",
"# In this case, we must set ```allow_extra_params``` to True \n",
"# Because we don't need params of FullyConnected Layer\n",
"\n",
"feature_extractor = mx.model.FeedForward(ctx=mx.gpu(), symbol=fea_symbol, \n",
" arg_params=model.arg_params, aux_params=model.aux_params,\n",
" arg_params=model.arg_params,\n",
" aux_params=model.aux_params,\n",
" allow_extra_params=True)\n",
"# Predict as normal\n",
"global_pooling_feature = feature_extractor.predict(test_dataiter)\n",
Expand Down
241 changes: 241 additions & 0 deletions example/notebooks/predict-with-pretrained-model.ipynb

Large diffs are not rendered by default.

18 changes: 18 additions & 0 deletions python/mxnet/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,22 @@ def update(self, label, pred):
self.sum_metric += numpy.sum(py == label)
self.num_inst += label.size

class LogLoss(EvalMetric):
    """Negative log-likelihood (log loss) evaluation metric.

    For each sample, accumulates ``-log(p)`` where ``p`` is the predicted
    probability assigned to the sample's true class. Probabilities are
    clipped to ``[eps, 1 - eps]`` so that ``log(0)`` never occurs.
    """
    def __init__(self):
        # Clipping bound keeping probabilities strictly inside (0, 1).
        self.eps = 1e-15
        super(LogLoss, self).__init__('logloss')

    def update(self, label, pred):
        """Accumulate log loss over one batch.

        Parameters
        ----------
        label : NDArray
            True class indices, one integer per sample.
        pred : NDArray
            Predicted class probabilities; row i is the distribution
            for sample i.

        Raises
        ------
        ValueError
            If the predicted probability of a true class is NaN.
        """
        # pylint: disable=invalid-name
        pred = pred.asnumpy()
        label = label.asnumpy().astype('int32')
        for i in range(label.size):
            p = pred[i][label[i]]
            # Raise instead of assert: asserts are stripped under `python -O`.
            # (Matches model.py, which raises ValueError for missing params.)
            if numpy.isnan(p):
                raise ValueError('prediction for sample %d is NaN' % i)
            # Clip to avoid -log(0) diverging for over-confident predictions.
            p = max(min(p, 1 - self.eps), self.eps)
            self.sum_metric += -numpy.log(p)
        self.num_inst += label.size

class CustomMetric(EvalMetric):
"""Custom evaluation metric that takes a NDArray function.
Expand Down Expand Up @@ -110,5 +126,7 @@ def create(metric):
raise TypeError('metric should either be callable or str')
if metric == 'acc' or metric == 'accuracy':
return Accuracy()
elif metric == 'logloss':
return LogLoss()
else:
raise ValueError('Cannot find metric %s' % metric)
3 changes: 2 additions & 1 deletion python/mxnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,8 @@ def _init_predictor(self, input_shape):

for name, value in list(zip(self.symbol.list_arguments(), pred_exec.arg_arrays)):
if not self._is_data_arg(name):
assert name in self.arg_params
if not name in self.arg_params:
raise ValueError("%s not exist in arg_params" % name)
self.arg_params[name].copyto(value)
for name, value in list(zip(self.symbol.list_auxiliary_states(), pred_exec.aux_arrays)):
assert name in self.aux_params
Expand Down