merge with upstream/master. resolve conflict in c_api_ndarray.cc
eric-haibin-lin committed Aug 5, 2017
2 parents 1f07771 + 0d8d27e commit d511938
Showing 78 changed files with 2,371 additions and 471 deletions.
12 changes: 12 additions & 0 deletions DISCLAIMER
@@ -0,0 +1,12 @@
Apache MXNet (incubating) is an effort undergoing incubation at The
Apache Software Foundation (ASF), sponsored by the Apache Incubator PMC.

Incubation is required of all newly accepted
projects until a further review indicates that the
infrastructure, communications, and decision making process have
stabilized in a manner consistent with other successful ASF
projects.

While incubation status is not necessarily a reflection
of the completeness or stability of the code, it does indicate
that the project has yet to be fully endorsed by the ASF.
22 changes: 22 additions & 0 deletions Jenkinsfile
@@ -281,6 +281,28 @@ try {
        }
      }
    },
    'Perl: CPU': {
      node('mxnetlinux') {
        ws('workspace/ut-perl-cpu') {
          init_git()
          unpack_lib('cpu')
          timeout(time: max_time, unit: 'MINUTES') {
            sh "${docker_run} cpu ./perl-package/test.sh /workspace/ut-perl-cpu /workspace/ut-perl-cpu"
          }
        }
      }
    },
    'Perl: GPU': {
      node('mxnetlinux') {
        ws('workspace/ut-perl-gpu') {
          init_git()
          unpack_lib('gpu')
          timeout(time: max_time, unit: 'MINUTES') {
            sh "${docker_run} gpu ./perl-package/test.sh /workspace/ut-perl-gpu /workspace/ut-perl-gpu"
          }
        }
      }
    },
    'R: CPU': {
      node('mxnetlinux') {
        ws('workspace/ut-r-cpu') {
132 changes: 132 additions & 0 deletions KEYS
@@ -0,0 +1,132 @@
This file contains the PGP keys of various developers.
Please don't use them for email unless you have to. Their main
purpose is code signing.

Examples of importing this file in your keystore:
gpg --import KEYS.txt
(need pgp and other examples here)

Examples of adding your key to this file:
pgp -kxa <your name> and append it to this file.
(pgpk -ll <your name> && pgpk -xa <your name>) >> this file.
(gpg --list-sigs <your name>
&& gpg --armor --export <your name>) >> this file.

-----------------------------------------------------------------------------------
pub 4096R/D3541808 2014-01-09
uid [ultimate] Suneel Marthi (CODE SIGNING KEY) <[email protected]>
sig 3 D3541808 2014-01-09 Suneel Marthi (CODE SIGNING KEY) <[email protected]>
sub 4096R/AF46E2DE 2014-01-09
sig D3541808 2014-01-09 Suneel Marthi (CODE SIGNING KEY) <[email protected]>

-----BEGIN PGP PUBLIC KEY BLOCK-----
Comment: GPGTools - https://gpgtools.org

mQINBFLPJmEBEAC9d/dUZCXeyhB0fVGmJAjdjXfLebav4VqGdNZC+M1T9C3dcVsh
X/JGme5bjJeIgVwiH5UsdNceYn1+hyxs8jXuRAWEWKP76gD+pNrp8Az0ZdBkJoAy
zCywOPtJV2PCOz7+S5ri2nUA2+1Kgcu6IlSLMmYAGO0IAmRrjBEzxy9iGaxiNGTc
LvQt/iVtIXWkKKI8yvpoJ8iFf3TGhpjgaC/h7cJP3zpy0SScmhJJASLXRsfocLv9
sle6ndN9IPbDtRW8cL7Fk3VQlzp1ToVjmnQTyZZ6S1WafsjzCZ9hLN+k++o8VbvY
v3icY6Sy0BKz0J6KwaxTkuZ6w1K7oUkVOQboKaWFIEdO+jwrEmU+Puyd8Np8jLnF
Q0Y5GPfyMlqM3S/zaDm1t4D1eb5FLciStkxfg5wPVK6TkqB325KVD3aio5C7E7kt
aQechHxaJXCQOtCtVY4X+L4iClnMSuk+hcSc8W8MYRTSVansItK0vI9eQZXMnpan
w9/jk5rS4Gts1rHB7+kdjT3QRJmkyk6fEFT0fz5tfMC7N8waeEUhCaRW6lAoiqDW
NW1h+0UGxJw+9YcGxBC0kkt3iofNOWQWmuf/BS3DHPKT7XV/YtBHe44wW0sF5L5P
nfQUHpnA3pcZ0En6bXAvepKVZTNdOWWJqMyHV+436DA+33h45QL6lWb/GwARAQAB
tDVTdW5lZWwgTWFydGhpIChDT0RFIFNJR05JTkcgS0VZKSA8c21hcnRoaUBhcGFj
aGUub3JnPokCNwQTAQoAIQUCUs8mYQIbAwULCQgHAwUVCgkICwUWAgMBAAIeAQIX
gAAKCRC08czE01QYCOKKEAChRtHBoYNTX+RZbFO0Kl1GlN+i1Ik0shEm5ZJ56XHv
AnFx/gRK7CfZzJswWo7kf2s/dvJiFfs+rrolYVuO6E8gNhAaTEomSuvWQAMHdPcR
9G5APRKCSkbZYugElqplEbSphk78FKoFO+sml52M7Pr9jj88ApBjoFVVY8njdnNq
6DVlaDsg8YninCD78Z7PNFnRGwxyZ8Qd4Dh0rG+MUTfAWopZu6/MxpQxU7QpeVeX
SIMLg7ClFrGfXnZcszYF4dnav1aa0i7W88PAdYNPko7tC5qz5yv2ep7t2gRbcYKf
RXhYC2FHQey3wPhMKjA8V436lAqmfYnY/YdmhEy9Xq/1EdX1nHsQ7OEkfgXK14WM
F+rnqXRAl/0cwiyb41eocdg5kpZFIKgCYT02usLWxwNnd3jOCe109Ze3y3acN/G8
+xOf9YRfNVAe6pD8H6ieRbv9gRjBmsbz9bXQCmxFnDqxNri5Me6gBAQPNmYTJD0h
jgJTK6o0vJ0pwjBLauasJsLu+1tR3Cb0dxPE+JVaTF26FCd7pM7W6KdVfod9ZfrN
cSyJ/cECc2KvYVGmTjQNVo1dYG0awBachlWnYNt+0Qx4opLsczZOLtPKtFY4BJA7
aZoXT4Qf9yB8km7x2/cgNExVbFummToJ/IP3M39/EaryspsQQuM5Qu5Q5lZp8Qnn
ybkCDQRSzyZhARAA7bAawFzbJaghYnm6mTZyGG5hQmfAynbF6cPAE+g2SnXcNQjP
6kjYx3tSpb7rEzmjQqs46ztqdec6PIVBMhakON6z27Zz+IviAtO/TcaZHWNuCAjw
FXVQZ+tYsSeiKInttfkrQc8jXAHWwSkSjLqNpvQpBdBEX80MYkFB6ZPOeON2+/Ta
GC1H/HU2YngF0qQSmG33KKG6ezihBJdKxU6t2tsQfTlCmZW6R6MGpS9fVurYMKBk
vR+7RGZ/H6dSjWPcpxhusGg92J9uz7r5SopN1wSdyPMUCMAFGeyoxcAuBDl38quU
H/ENG3x5LDPq2aEH2AJ6yvZfIXbeJ1zmXf2cAHv+HbmvZaTSp0XIjq8Yxh8NkYEC
ZdfRWmsGLIpU16TkBijpK3Dn9MDXjHGT3V8/qfdpURtMvIaL8WFrq9ejcy/vGRFn
mCYqxIIPH+vLiMXKWtuMc61GN3ES21msKQH6IuQxxfQLyhK44L/pv7FpF4E+6LaE
8uRwAex5HIDpR1v4aJq089rRtye9VXTJJLZ7lYs0HctdZ30QbBRWT4jS9d9rj3cr
HgQ7mIGO9TAfK2kWc6AJN/EvxPWNbOwptsTUzAF/adiy9ax8C18iw7nKczC+2eN6
UcbxXiPdytuKYK7O9A8S9e1w89GwpxYN7Xfn2o6QfpSbL9cLKiinOeV+xikAEQEA
AYkCHwQYAQoACQUCUs8mYQIbDAAKCRC08czE01QYCG7yD/471dmyOD+go8cZkdqR
3CHhjH03odtI0EJNVy4VGEC0r9paz3BWYTy18LqWYkw3ygphOIU1r8/7QK3H5Ke3
c4yCSUxaMk5SlAJ+iVRek5TABkR8+zI+ZN5pQtqRH+ya5JxV4F/Sx5Q3KWMzpvgY
n6AgSSc3hEfkgdI7SalIeyLaLDWv+RFdGZ5JU5gD28C0G8BeH8L62x6sixZcqoGT
oy9rwkjs45/ZmmvBZhd1wLvC/au8l2Ecou6O8+8m26W8Z7vCuGKxuWn0KV3DLLWe
66uchDVlakGoMJSPIK06JWYUlE+gL0CW+U2ekt/v2qb8hGgMVET3CBAMq+bFWuJ6
juX7hJd7wHtCFfjnFDDAkdp2IIIZAlBW6FZGv7pJ82xsW6pSAg0A7VrV6nTtMtDv
T8esOfo/t4t0gaL7bivy9DVVdATbUBcJJFpoVoe5MxiyjptveqPzIRwzt04n52Ph
ordVWAnX5AokXWTg+Glem/EWEuf7jUuZArfqCSl/sZoQdXGTjR7G4iFscispji4+
kNjVQsItqFbgDpuc6n+GcFxlKQ7YMCnu5MVtTV01U4lFs0qy0NTUqsuR35DM4z14
DkFmj1upWAayCoXTpKzsHBvJZPC+Wqf9Pl3O47apelg7KxU3S011YfXpVPvCTKBv
kD2o/5GKWS5QkSUEUXXY1oDiLg==
=f8kJ
-----END PGP PUBLIC KEY BLOCK-----
pub rsa4096 2017-07-12 [SC]
406DCA257CD2BE237B79AE6BC9D353CA4AFF2E24
uid [ultimate] Ly Nguyen (CODE SIGNING KEY) <[email protected]>
sig 3 C9D353CA4AFF2E24 2017-07-12 Ly Nguyen (CODE SIGNING KEY) <[email protected]>
sub rsa4096 2017-07-12 [E]
sig C9D353CA4AFF2E24 2017-07-12 Ly Nguyen (CODE SIGNING KEY) <[email protected]>

-----BEGIN PGP PUBLIC KEY BLOCK-----

mQINBFlmSIMBEADIr6FzNJ6o/owjqgqWdOtreIRuU47/uzNRZw8c2lEys2Fw+3CI
iUitkWpb7jR0BGLk+8yUk+1VGdXPuJ+zj8XWcCnCJ7TUy3Hudp/BrX7y388m9hP9
3LP5yx+AUKbXRZiEr5EG2lyTmJBB5lmreVlRMs74Ie3uFtH6US/DVZMqULEtumcH
yCL30kKugUjfftO1mbx901kB0WpB705od3Wrde0Jd9sniMz4HkXMsd93gExh/s1H
3XApXes+yDIEILiUJRawgzgcPIuTyOq4bbafoiFd8ipZU0G7AQPtNUAnpTUtrUaJ
5CDGzOiqGUgwi+M3zwsRcW2MjDi9MyNTmlW2P6Gifzn3EaJ0EVdz4fPmIokC5h+H
6nMHqSPUEu0WA/qpirVrOiUku34lpkP0vZwb8UOyjgBCFTxDMPX70DuUmCbij1rr
vGM0rKLV+LFclEQFpnXckUnza8f/Zbk9T3yWcPQykXyi7+1Z1WJSPVkF4l8ynpDy
4DdUnLGdF8HZAGHdroi/jGVrH2NYy42XQqOZoLfk2BTGiFYpQem/Bfzo3OdEPBT7
zpZUVqixtXbnGseL1sdHao1BdinIbvSpPOPEbObINenk65NtXWc+9YbauGkJ5kwd
opAkBmZC4IycFWkpmHecbGXJN61eYvARuXKAev7DeYH7g6Zuzp4n07rtIwARAQAB
tC5MeSBOZ3V5ZW4gKENPREUgU0lHTklORyBLRVkpIDxseG4yQGFwYWNoZS5vcmc+
iQJOBBMBCgA4FiEEQG3KJXzSviN7ea5rydNTykr/LiQFAllmSIMCGwMFCwkIBwMF
FQoJCAsFFgIDAQACHgECF4AACgkQydNTykr/LiT2/Q//aW1qOLX7msuJDqhlHFIM
hCUZzWClljfCHMHZJooJY5YOcvzE5mVgwVdWjgAgZfgk/bFsNhuOb+jIqlatsNfI
Eg7sm6VjfHRo3pP1W7NN+CQNu5JnEEZAIVLy2gn+Eq1rQc7g2pfylVh/HV14TGon
OWbk7BfaZubGLtLJTIimHAPd+TrRsGsLnd9JiDZj0gsPPKV6HHXHgZoAeStIUPNX
13mN/WMDAAqroPPUfMEMXPbmJgNf/ukIFxsS/y8MwU32BjVCBvvh8ojN3RIgUJnX
chdjT9i/QVKi9TyoF20R7mR80x/P9CBwqKoN9+QuHjTPDuZkol4xD3jyzOsKHPwZ
CpltwdhI2JCYJzEIFtrZ0R59fXJ+8NNXZzIOqnx83qarC+eSf8cunqPS/ZBIvEJ0
qM1adZlJiY96La10wXSjYnEc+XEw+dad3D3ChVsvDceJirelaAVrRS2Dz4ugNShy
W0cZFFUL0aCTNNJnF9sHAfexbbg06BTzSSAeYrEWLmmpjEYHXAtFyToHzk0jTUr4
66SeIUVHIqBLk8yx1L9zQK38JS9usYj1PFJri9J6iYyqiIS7zRinoO8MIySZOOGp
Z3Q5xJbnwzjwl4frGaXg2/zyD7rfQGG3P23WOselgNWMKuYtVAA+AHo/CxLIinKk
JAMljesV3vfeawK5HHnfcgK5Ag0EWWZIgwEQAMsmr5lOFe4n9iGdTciYFXxZYSEX
ZqmtWyxNsXkih2icfohygx/YLFBSkdXSfIywS7w7+Na4OYdhp3uaRdU+yA4ianY7
qH5guni98KtyZmsRnnjT1DgyR0pNNqAdAyfWeCglMx5SWLLtzKxHazqF0t6Jb6M/
sAew+KdoTXsYzKb9d/R81spvefJoBopaxKLF1tijaX98RiquKLlFBD+88XP6pxSB
nwNxNybgJVlGT/RdxPiRiRj0CySuvx27i8w8Rc2HaT9CFumzdy6moz+RJbuuIjDN
QzIOpNy4+LJKSysPGh8AwRu6xCl9gnfbJ9thiFwYGZ7S3lVvS23/poI1YzLZZY+5
XvpiiogF7j5Aj/zTTli8BI/CiNVrGKJuzeJJyLFfBMmrbysi9mV/fR8wC7xd5P9g
LjElkA4j1Xv5I47AVsILAbHLhphpxNDoKBmr1EbP/CJitEYjRmdjn4Mo6sYwMlVN
CA+rl/VMS3Nc0Iixu/Y070H3kE9IfitksiuXIJfeX5RW/uWegEO1e1dSpi+rreb8
lvVtQk4tMUHyM16qPqO08tPGSunt6J0HiPi7J+xDwbJjJS7gNDW4AYHG5q4/dZsx
PtpcZC7zFOlFV0BwFftYnluccDhsWPc48mDmmhOe9p42irMAx6ms/Y42jgh4OmgD
bjMzKIyYFI40URGnABEBAAGJAjYEGAEKACAWIQRAbcolfNK+I3t5rmvJ01PKSv8u
JAUCWWZIgwIbDAAKCRDJ01PKSv8uJCAtD/97SuVGnCP3kbWfI/qfTTVKwuWTdbIg
rPvOjGo5F57l1PAgARt8N1ccqREbR3JwhRdsU3ewz5eDQEyEZVffPgufhqZr8liI
EP783m83VgRSMKYt6HzORX0os2BapsHHuejvlME9XpN0UG5AnvbzXDxP3wJufB1K
GkmC+rlpqfyMu60xFXzym9QuePksbdf/xXZduvLGaB1u+AYtvHp3+NGV382vat7C
xwRShVJTb8Zr9y5tA+JDqfhDDb5CepcPH6Uk2frU8aV7vZ3hmVmGcDcUddu3U9hg
L7Lcpr1E0D7xOuQ4QMAFhcDO+aB8aPv+JRkH4Y6wDFPrEgcEJ1YK6hhW5KSdslyK
QrKHKMSl+hwPmh9fKX4wC+FjMMXJ/PHtEG3N3f7/TyyO4iza5xDIJkYcyKkDXc0l
VcHLJvtjsJziMJNV3lKAeTp/uzbaJHRhLmpPHukQPnlpjfhnmsYh3wydnd03pfzQ
k6XJ4iGeSSQqtW6T14yqkCl5HDH2ms1ufhe4Os217CMXnaRbM/K6Zl4iGGozzXgd
no02+jTN3NqmUw0hUBR/9ZEn+IKmZ6f0Azsgio0M9ez1T0CCDZvo19kJw9b3VdOF
TZQhIRekaaV+bCQQxnwDOJ31bIUUpxaMdvygjq55Gri/5C75TsMNcgbhqYWLGKe2
kRsGTxyO+fQ6/Q==
=FuXU
-----END PGP PUBLIC KEY BLOCK-----
159 changes: 159 additions & 0 deletions R-package/vignettes/CustomLossFunction.Rmd
@@ -0,0 +1,159 @@
---
title: "Customized loss function"
output:
md_document:
variant: markdown_github
---

```{r setup, include=FALSE}
knitr::opts_chunk$set(echo = TRUE)
```

This tutorial shows how to use a customized loss function when constructing a network.

Model Training Example
----------

Let's begin with a small regression example. We can build and train a regression model with the following code:

```{r}
data(BostonHousing, package = "mlbench")
BostonHousing[, sapply(BostonHousing, is.factor)] <-
  as.numeric(as.character(BostonHousing[, sapply(BostonHousing, is.factor)]))
BostonHousing <- data.frame(scale(BostonHousing))
test.ind <- seq(1, 506, 5)  # every fifth point is held out for testing
train.x <- data.matrix(BostonHousing[-test.ind, -14])
train.y <- BostonHousing[-test.ind, 14]
test.x <- data.matrix(BostonHousing[test.ind, -14])
test.y <- BostonHousing[test.ind, 14]
require(mxnet)
data <- mx.symbol.Variable("data")
label <- mx.symbol.Variable("label")
fc1 <- mx.symbol.FullyConnected(data, num_hidden = 14, name = "fc1")
tanh1 <- mx.symbol.Activation(fc1, act_type = "tanh", name = "tanh1")
fc2 <- mx.symbol.FullyConnected(tanh1, num_hidden = 1, name = "fc2")
lro <- mx.symbol.LinearRegressionOutput(fc2, name = "lro")
mx.set.seed(0)
model <- mx.model.FeedForward.create(lro, X = train.x, y = train.y,
                                     ctx = mx.cpu(),
                                     num.round = 5,
                                     array.batch.size = 60,
                                     optimizer = "rmsprop",
                                     verbose = TRUE,
                                     array.layout = "rowmajor",
                                     batch.end.callback = NULL,
                                     epoch.end.callback = NULL)
pred <- predict(model, test.x)
sum((test.y - pred[1,])^2) / length(test.y)
```

Besides the `LinearRegressionOutput`, we also provide `LogisticRegressionOutput` and `MAERegressionOutput`.
However, these built-in losses might not be enough for real-world models: you can provide your own loss
function by using `mx.symbol.MakeLoss` when constructing the network.
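
For instance, with a 0/1 target the output layer can be swapped without touching the rest of the network. This is only a minimal sketch under that assumption, reusing `fc2` from above; `lro_logit` is a hypothetical name:

```{r}
# Sketch: swap the linear output for a logistic one when the label is binary.
# fc2 is the network defined above; only the output symbol changes.
lro_logit <- mx.symbol.LogisticRegressionOutput(fc2, name = "lro_logit")
```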

How to Use Your Own Loss Function
---------

We still use our previous example, but this time we construct the loss explicitly with `mx.symbol.MakeLoss` to minimize `(pred - label)^2`:

```{r}
data <- mx.symbol.Variable("data")
label <- mx.symbol.Variable("label")
fc1 <- mx.symbol.FullyConnected(data, num_hidden = 14, name = "fc1")
tanh1 <- mx.symbol.Activation(fc1, act_type = "tanh", name = "tanh1")
fc2 <- mx.symbol.FullyConnected(tanh1, num_hidden = 1, name = "fc2")
lro2 <- mx.symbol.MakeLoss(mx.symbol.square(mx.symbol.Reshape(fc2, shape = 0) - label), name="lro2")
```

Then we can train the network just as usual.

```{r}
mx.set.seed(0)
model2 <- mx.model.FeedForward.create(lro2, X = train.x, y = train.y,
                                      ctx = mx.cpu(),
                                      num.round = 5,
                                      array.batch.size = 60,
                                      optimizer = "rmsprop",
                                      verbose = TRUE,
                                      array.layout = "rowmajor",
                                      batch.end.callback = NULL,
                                      epoch.end.callback = NULL)
```

Because we are minimizing the same loss function, we should get very similar parameters.
The value returned by `predict`, however, is quite different.

```{r}
pred2 <- predict(model2, test.x)
sum((test.y - pred2)^2) / length(test.y)
```

This is because the output of `mx.symbol.MakeLoss` is the gradient of the loss with respect to the input data, not the prediction itself.
We can recover the real prediction as shown below.

```{r}
internals <- internals(model2$symbol)
fc_symbol <- internals[[match("fc2_output", outputs(internals))]]
model3 <- list(symbol = fc_symbol,
               arg.params = model2$arg.params,
               aux.params = model2$aux.params)
class(model3) <- "MXFeedForwardModel"
pred3 <- predict(model3, test.x)
sum((test.y - pred3[1,])^2) / length(test.y)
```

We provide many operations on symbols, so custom losses can be composed freely. An example using `|pred - label|` follows.

```{r}
lro_abs <- mx.symbol.MakeLoss(mx.symbol.abs(mx.symbol.Reshape(fc2, shape = 0) - label))
mx.set.seed(0)
model4 <- mx.model.FeedForward.create(lro_abs, X = train.x, y = train.y,
                                      ctx = mx.cpu(),
                                      num.round = 20,
                                      array.batch.size = 60,
                                      optimizer = "sgd",
                                      learning.rate = 0.001,
                                      verbose = TRUE,
                                      array.layout = "rowmajor",
                                      batch.end.callback = NULL,
                                      epoch.end.callback = NULL)
internals <- internals(model4$symbol)
fc_symbol <- internals[[match("fc2_output", outputs(internals))]]
model5 <- list(symbol = fc_symbol,
               arg.params = model4$arg.params,
               aux.params = model4$aux.params)
class(model5) <- "MXFeedForwardModel"
pred5 <- predict(model5, test.x)
sum(abs(test.y - pred5[1,])) / length(test.y)
```
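
The same composition extends to smoother robust losses. Here is a minimal sketch, not part of the original example, of a pseudo-Huber loss `sqrt(1 + (pred - label)^2) - 1`, which behaves like squared error near zero and like absolute error for large residuals; `diff` and `lro_huber` are hypothetical names:

```{r}
# Sketch: pseudo-Huber loss composed from elementwise symbol operations.
diff <- mx.symbol.Reshape(fc2, shape = 0) - label
lro_huber <- mx.symbol.MakeLoss(mx.symbol.sqrt(mx.symbol.square(diff) + 1) - 1,
                                name = "lro_huber")
```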


For comparison, the built-in `mx.symbol.MAERegressionOutput` optimizes the same `|pred - label|` objective:

```{r}
lro_mae <- mx.symbol.MAERegressionOutput(fc2, name = "lro")
mx.set.seed(0)
model6 <- mx.model.FeedForward.create(lro_mae, X = train.x, y = train.y,
                                      ctx = mx.cpu(),
                                      num.round = 20,
                                      array.batch.size = 60,
                                      optimizer = "sgd",
                                      learning.rate = 0.001,
                                      verbose = TRUE,
                                      array.layout = "rowmajor",
                                      batch.end.callback = NULL,
                                      epoch.end.callback = NULL)
pred6 <- predict(model6, test.x)
sum(abs(test.y - pred6[1,])) / length(test.y)
```
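
As a quick side-by-side check, a sketch reusing the predictions computed above compares the two MAE formulations directly:

```{r}
# Hypothetical comparison: MakeLoss-based MAE (model5) vs. the built-in
# MAERegressionOutput (model6), both evaluated on the held-out test set.
c(makeloss = sum(abs(test.y - pred5[1, ])) / length(test.y),
  builtin  = sum(abs(test.y - pred6[1, ])) / length(test.y))
```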

11 changes: 9 additions & 2 deletions cpp-package/example/alexnet.cpp
Original file line number Diff line number Diff line change
@@ -199,6 +199,7 @@ int main(int argc, char const *argv[]) {

  /*with data and label, executor can be generated automatically*/
  auto *exec = Net.SimpleBind(ctx, args_map);
+ auto arg_names = Net.ListArguments();
  aux_map = exec->aux_dict();
  args_map = exec->arg_dict();

@@ -240,7 +241,9 @@ int main(int argc, char const *argv[]) {
  Optimizer* opt = OptimizerRegistry::Find("ccsgd");
  opt->SetParam("momentum", 0.9)
      ->SetParam("rescale_grad", 1.0 / batch_size)
-     ->SetParam("clip_gradient", 10);
+     ->SetParam("clip_gradient", 10)
+     ->SetParam("lr", learning_rate)
+     ->SetParam("wd", weight_decay);

  Accuracy acu_train, acu_val;
  LogLoss logloss_val;
@@ -258,7 +261,11 @@ int main(int argc, char const *argv[]) {
    batch.label.CopyTo(&args_map["label"]);
    exec->Forward(true);
    exec->Backward();
-   exec->UpdateAll(opt, learning_rate, weight_decay);
+   for (size_t i = 0; i < arg_names.size(); ++i) {
+     if (arg_names[i] == "data" || arg_names[i] == "label") continue;
+     opt->Update(i, exec->arg_arrays[i], exec->grad_arrays[i]);
+   }

    NDArray::WaitAll();
    acu_train.Update(batch.label, exec->outputs[0]);
  }
11 changes: 9 additions & 2 deletions cpp-package/example/charRNN.cpp
Original file line number Diff line number Diff line change
@@ -451,6 +451,8 @@ void train(const string file, int batch_size, int max_epoch, int start_epoch) {
  mx_float learning_rate = 0.0002;
  mx_float weight_decay = 0.000002;
  Optimizer* opt = OptimizerRegistry::Find("ccsgd");
+ opt->SetParam("lr", learning_rate)
+    ->SetParam("wd", weight_decay);
  // opt->SetParam("momentum", 0.9)->SetParam("rescale_grad", 1.0 / batch_size)
  //    ->SetParam("clip_gradient", 10);

@@ -470,7 +472,10 @@

    exe->Forward(true);
    exe->Backward();
-   exe->UpdateAll(opt, learning_rate, weight_decay);
+   for (size_t i = 0; i < exe->arg_arrays.size(); ++i) {
+     opt->Update(i, exe->arg_arrays[i], exe->grad_arrays[i]);
+   }

    NDArray::WaitAll();
  }
  auto toc = chrono::system_clock::now();
@@ -547,7 +552,9 @@ void trainWithBuiltInRNNOp(const string file, int batch_size, int max_epoch, int

    exe->Forward(true);
    exe->Backward();
-   exe->UpdateAll(opt, learning_rate, weight_decay);
+   for (size_t i = 0; i < exe->arg_arrays.size(); ++i) {
+     opt->Update(i, exe->arg_arrays[i], exe->grad_arrays[i]);
+   }
    NDArray::WaitAll();
  }
auto toc = chrono::system_clock::now();