[NDArray] add clip op #149

Merged · 1 commit · Sep 24, 2015

579 changes: 0 additions & 579 deletions example/notebooks/alexnet.ipynb

This file was deleted.

57 changes: 30 additions & 27 deletions example/notebooks/cifar-recipe.ipynb
@@ -237,7 +237,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 10,
"metadata": {
"collapsed": false
},
@@ -247,16 +247,16 @@
"output_type": "stream",
"text": [
"INFO:root:Start training with [gpu(0)]\n",
"INFO:root:Batch [50]\tSpeed: 1091.84 samples/sec\n",
"INFO:root:Batch [100]\tSpeed: 1084.80 samples/sec\n",
"INFO:root:Batch [150]\tSpeed: 1084.55 samples/sec\n",
"INFO:root:Batch [200]\tSpeed: 1077.30 samples/sec\n",
"INFO:root:Batch [250]\tSpeed: 1074.73 samples/sec\n",
"INFO:root:Batch [300]\tSpeed: 1075.67 samples/sec\n",
"INFO:root:Batch [350]\tSpeed: 1067.09 samples/sec\n",
"INFO:root:Iteration[0] Train-accuracy=0.525695\n",
"INFO:root:Iteration[0] Time cost=47.012\n",
"INFO:root:Iteration[0] Validation-accuracy=0.660008\n"
"INFO:root:Batch [50]\tSpeed: 1003.50 samples/sec\n",
"INFO:root:Batch [100]\tSpeed: 976.31 samples/sec\n",
"INFO:root:Batch [150]\tSpeed: 975.57 samples/sec\n",
"INFO:root:Batch [200]\tSpeed: 964.21 samples/sec\n",
"INFO:root:Batch [250]\tSpeed: 963.53 samples/sec\n",
"INFO:root:Batch [300]\tSpeed: 963.95 samples/sec\n",
"INFO:root:Batch [350]\tSpeed: 963.71 samples/sec\n",
"INFO:root:Iteration[0] Train-accuracy=0.520520\n",
"INFO:root:Iteration[0] Time cost=52.424\n",
"INFO:root:Iteration[0] Validation-accuracy=0.652393\n"
]
}
],
@@ -272,14 +272,14 @@
"# eval_data=test_dataiter,\n",
"# eval_metric=\"accuracy\",\n",
"# epoch_end_callback=mx.helper.Speedometer(batch_size),\n",
"# iter_end_callback=mx.model.do_checkpoint(model_prefix))\n"
"# iter_end_callback=mx.callback.do_checkpoint(model_prefix))\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"After only 1 epoch, our model is able to acheive about 66% accuracy on testset.\n",
"After only 1 epoch, our model is able to acheive about 65% accuracy on testset.\n",
"We can save our model by calling either ```save``` or using ```pickle```.\n"
]
},
@@ -348,7 +348,7 @@
"output_type": "stream",
"text": [
"INFO:root:Finish predict...\n",
"INFO:root:final accuracy = 0.651000\n"
"INFO:root:final accuracy = 0.652600\n"
]
}
],
@@ -385,33 +385,36 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 14,
"metadata": {
"collapsed": false
},
"outputs": [
{
"ename": "TypeError",
"evalue": "Symbol only support integer index to fetch i-th output",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mTypeError\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m<ipython-input-17-0e3d13f4a151>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m()\u001b[0m\n\u001b[0;32m 2\u001b[0m \u001b[0minternals\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0msoftmax\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mget_internals\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 3\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 4\u001b[1;33m \u001b[0mfea_symbol\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0minternals\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;34m\"global_avg_output\"\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 5\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 6\u001b[0m feature_extractor = mx.model.FeedForward(ctx=mx.gpu(), symbol=group, \n",
"\u001b[1;32m/home/bing/wtf/mxnet/python/mxnet/symbol.py\u001b[0m in \u001b[0;36m__getitem__\u001b[1;34m(self, index)\u001b[0m\n\u001b[0;32m 156\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m__getitem__\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mindex\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 157\u001b[0m \u001b[1;32mif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mindex\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mint\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 158\u001b[1;33m \u001b[1;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m'Symbol only support integer index to fetch i-th output'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 159\u001b[0m \u001b[0mhandle\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mSymbolHandle\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 160\u001b[0m check_call(_LIB.MXSymbolGetOutput(\n",
"\u001b[1;31mTypeError\u001b[0m: Symbol only support integer index to fetch i-th output"
"name": "stdout",
"output_type": "stream",
"text": [
"(10000, 336, 1, 1)\n"
]
}
],
"source": [
"# predict internal featuremaps\n",
"# Predict internal featuremaps\n",
"# From a symbol, we are able to get all internals. Note it is still a symbol\n",
"internals = softmax.get_internals()\n",
"\n",
"# We get get an internal symbol for the feature.\n",
"# By default, the symbol is named as \"symbol_name + _output\"\n",
"# in this case we'd like to get global_avg\" layer's output as feature, so its \"global_avg_output\"\n",
"# You may call ```internals.list_outputs()``` to find the target\n",
"# but we strongly suggests set a special name for special symbol \n",
"fea_symbol = internals[\"global_avg_output\"]\n",
"\n",
"feature_extractor = mx.model.FeedForward(ctx=mx.gpu(), symbol=group, \n",
"# Make a new model by using an internal symbol. We can reuse all parameters from model we trained before\n",
"# In this case, we must set ```allow_extra_params``` to True\n",
"feature_extractor = mx.model.FeedForward(ctx=mx.gpu(), symbol=fea_symbol, \n",
" arg_params=model.arg_params, aux_params=model.aux_params,\n",
" allow_extra_params=True)\n",
"# Predict as normal\n",
"global_pooling_feature = feature_extractor.predict(test_dataiter)\n",
"print(global_pooling_feature.shape)"
]
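For readability, here is the feature-extraction cell above restated as plain Python. This is a sketch only: `softmax`, `model`, and `test_dataiter` are assumed to be defined in earlier notebook cells.

```python
import mxnet as mx

# Sketch of the notebook cell above. `softmax`, `model`, and
# `test_dataiter` are assumed to come from earlier cells.
internals = softmax.get_internals()
fea_symbol = internals["global_avg_output"]
feature_extractor = mx.model.FeedForward(
    ctx=mx.gpu(), symbol=fea_symbol,
    arg_params=model.arg_params, aux_params=model.aux_params,
    allow_extra_params=True)
global_pooling_feature = feature_extractor.predict(test_dataiter)
print(global_pooling_feature.shape)  # the notebook prints (10000, 336, 1, 1)
```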
2 changes: 2 additions & 0 deletions python/mxnet/__init__.py
@@ -22,6 +22,8 @@
from . import optimizer
from . import model
from . import initializer
# use mx.init as short for mx.initializer
from . import initializer as init
from . import visualization
# use viz as short for mx.visualization
from . import visualization as viz
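A quick sketch of the new alias in use (assumes the `Uniform` initializer exists in this version of the package):

```python
import mxnet as mx

# mx.init is now shorthand for mx.initializer.
assert mx.init is mx.initializer

# Assumed example: construct a uniform weight initializer via the alias.
weight_init = mx.init.Uniform(scale=0.07)
```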
13 changes: 13 additions & 0 deletions python/mxnet/ndarray.py
@@ -243,6 +243,16 @@ def _slice(self, start, stop):
self.handle, start, stop, ctypes.byref(handle)))
return NDArray(handle=handle, writable=self.writable)

def clip(self, value):
"""Clip the NDArray into the range [-value, value], mapping NaN to 0.

Parameters
----------
value : float
clipping range
"""
return NDArray._clip_scalar(self, float(value))

def wait_to_read(self):
"""Block until all pending write operations on the current NDArray are finished.

@@ -636,6 +646,9 @@ def generic_ndarray_function(*args, **kwargs):
ret_function.__name__ = func_name
ret_function.__doc__ = doc_str
return ret_function



# pylint: enable=too-many-locals, invalid-name

def _init_ndarray_module():
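A minimal usage sketch of the new `clip` method (illustrative values; assumes a build that includes this change):

```python
import mxnet as mx

a = mx.nd.ones((2, 3)) * 5.0   # every element is 5.0
b = a.clip(1.0)                # clipped into [-1.0, 1.0]; NaN would map to 0
print(b.asnumpy())             # expected: all ones
```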
13 changes: 11 additions & 2 deletions python/mxnet/optimizer.py
@@ -34,14 +34,19 @@ class SGD(Optimizer):

rescale_grad : float, optional
rescaling factor of gradient.

clip_gradient : float, optional
clip the gradient into the range [-clip_gradient, clip_gradient]
"""
def __init__(self, learning_rate=0.01, momentum=0.0,
wd=0.0001, rescale_grad=1, lr_scheduler=None):
wd=0.0001, rescale_grad=1, clip_gradient=None,
lr_scheduler=None):
super(SGD, self).__init__()
self.lr = learning_rate
self.momentum = momentum
self.wd = wd
self.rescale_grad = rescale_grad
self.clip_gradient = clip_gradient
self.lr_scheduler = lr_scheduler
if lr_scheduler != None:
self.lr_scheduler.base_lr = learning_rate
@@ -89,7 +94,11 @@ def update(self, index, weight, grad, state):
if state:
mom = state
mom[:] *= self.momentum
mom[:] += -lr * (grad * self.rescale_grad + self.wd * weight)
if self.clip_gradient is None:
mom[:] += -lr * (grad * self.rescale_grad + self.wd * weight)
else:
mom[:] += -lr * (grad.clip(self.clip_gradient) * self.rescale_grad +
self.wd * weight)
weight[:] += mom
else:
assert self.momentum == 0.0
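A sketch of how the new option is meant to be used (illustrative values; `create_state` is assumed to be the optimizer's momentum-buffer factory in this version):

```python
import mxnet as mx
from mxnet.optimizer import SGD

# With clip_gradient=0.5, each gradient element is clipped into
# [-0.5, 0.5] before it enters the momentum update.
opt = SGD(learning_rate=0.1, momentum=0.9, clip_gradient=0.5)

weight = mx.nd.ones((4,))
grad = mx.nd.ones((4,)) * 100.0       # deliberately exploding gradient
state = opt.create_state(0, weight)   # momentum buffer
opt.update(0, weight, grad, state)
print(weight.asnumpy())               # update stays bounded despite the large gradient
```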
3 changes: 2 additions & 1 deletion src/ndarray/ndarray.cc
@@ -315,6 +315,7 @@ NDArray operator*(const NDArray &lhs, const real_t &rhs) {
NDArray operator/(const NDArray &lhs, const real_t &rhs) {
return ScalarOpRet<ndarray::Div, false>(lhs, rhs);
}

// Binary
NDArray &NDArray::operator=(real_t scalar) {
SetValueOp(scalar, this);
@@ -510,7 +511,7 @@ MXNET_REGISTER_NDARRAY_FUN(_plus_scalar).set_function(ScalarOp<ndarray::Plus, false>);
MXNET_REGISTER_NDARRAY_FUN(_minus_scalar).set_function(ScalarOp<ndarray::Minus, false>);
MXNET_REGISTER_NDARRAY_FUN(_mul_scalar).set_function(ScalarOp<ndarray::Mul, false>);
MXNET_REGISTER_NDARRAY_FUN(_div_scalar).set_function(ScalarOp<ndarray::Div, false>);

MXNET_REGISTER_NDARRAY_FUN(_clip_scalar).set_function(ScalarOp<ndarray::Clip, false>);
// register API function
// scalar, reverse scalar
MXNET_REGISTER_NDARRAY_FUN(_rminus_scalar).set_function(ScalarOp<ndarray::Minus, true>);
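The registration above is what makes the op reachable from Python: the module's `_init_ndarray_module` picks up the function registered as `_clip_scalar` and exposes it on the `NDArray` class, which is exactly what the new `NDArray.clip` method calls. A sketch of the effect:

```python
import mxnet as mx

a = mx.nd.ones((2, 2)) * 10.0
b = mx.nd.NDArray._clip_scalar(a, 3.0)  # equivalent to a.clip(3.0)
print(b.asnumpy())                      # expected: all 3.0
```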
2 changes: 2 additions & 0 deletions src/ndarray/ndarray_function-inl.h
@@ -100,11 +100,13 @@ DECL_SCALAR(DEVICE, Plus, EvalScalar_, true)
DECL_SCALAR(DEVICE, Minus, EvalScalar_, true)
DECL_SCALAR(DEVICE, Mul, EvalScalar_, true)
DECL_SCALAR(DEVICE, Div, EvalScalar_, true)
DECL_SCALAR(DEVICE, Clip, EvalScalar_, true)
// for reverse seq
DECL_SCALAR(DEVICE, Plus, EvalScalar_, false)
DECL_SCALAR(DEVICE, Minus, EvalScalar_, false)
DECL_SCALAR(DEVICE, Mul, EvalScalar_, false)
DECL_SCALAR(DEVICE, Div, EvalScalar_, false)
DECL_SCALAR(DEVICE, Clip, EvalScalar_, false)
} // namespace ndarray
} // namespace mxnet

10 changes: 10 additions & 0 deletions src/ndarray/ndarray_function.h
@@ -37,6 +37,16 @@ struct Div : public BinaryBase {
typedef mshadow::op::div mshadow_op;
};

struct Clip : public BinaryBase {
struct mshadow_op {
MSHADOW_XINLINE static real_t Map(real_t a, real_t b) {
if (isnan(a)) return 0.0f;
if (a < -b) return -b;
if (a > b) return b;
return a;
}
};
};
// type holder for random number generators
struct UniformDistribution {};

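For reference, the per-element semantics of `Clip::mshadow_op::Map` restated as a NumPy sketch: NaN maps to 0, and everything else is clamped into `[-b, b]`.

```python
import numpy as np

def clip_scalar(a, b):
    # Mirror of Clip::mshadow_op::Map: NaN -> 0, then clamp to [-b, b].
    out = np.where(np.isnan(a), 0.0, a)
    return np.clip(out, -b, b)

x = np.array([np.nan, -3.0, 0.5, 7.0], dtype=np.float32)
print(clip_scalar(x, 2.0))  # [ 0.  -2.   0.5  2. ]
```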