forked from apache/mxnet
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
merge with upstream/master. resolve conflict in c_api_ndarray.cc
- Loading branch information
Showing
78 changed files
with
2,371 additions
and
471 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
Apache MXNet (incubating) is an effort undergoing incubation at The | ||
Apache Software Foundation (ASF), sponsored by the Apache Incubator PMC. | ||
|
||
Incubation is required of all newly accepted | ||
projects until a further review indicates that the | ||
infrastructure, communications, and decision making process have | ||
stabilized in a manner consistent with other successful ASF | ||
projects. | ||
|
||
While incubation status is not necessarily a reflection | ||
of the completeness or stability of the code, it does indicate | ||
that the project has yet to be fully endorsed by the ASF. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
This file contains the PGP keys of various developers. | ||
Please don't use them for email unless you have to. Their main | ||
purpose is code signing. | ||
|
||
Examples of importing this file in your keystore: | ||
gpg --import KEYS.txt | ||
(need pgp and other examples here) | ||
|
||
Examples of adding your key to this file: | ||
pgp -kxa <your name> and append it to this file. | ||
(pgpk -ll <your name> && pgpk -xa <your name>) >> this file. | ||
(gpg --list-sigs <your name> | ||
&& gpg --armor --export <your name>) >> this file. | ||
|
||
----------------------------------------------------------------------------------- | ||
pub 4096R/D3541808 2014-01-09 | ||
uid [ultimate] Suneel Marthi (CODE SIGNING KEY) <[email protected]> | ||
sig 3 D3541808 2014-01-09 Suneel Marthi (CODE SIGNING KEY) <[email protected]> | ||
sub 4096R/AF46E2DE 2014-01-09 | ||
sig D3541808 2014-01-09 Suneel Marthi (CODE SIGNING KEY) <[email protected]> | ||
|
||
-----BEGIN PGP PUBLIC KEY BLOCK----- | ||
Comment: GPGTools - https://gpgtools.org | ||
|
||
mQINBFLPJmEBEAC9d/dUZCXeyhB0fVGmJAjdjXfLebav4VqGdNZC+M1T9C3dcVsh | ||
X/JGme5bjJeIgVwiH5UsdNceYn1+hyxs8jXuRAWEWKP76gD+pNrp8Az0ZdBkJoAy | ||
zCywOPtJV2PCOz7+S5ri2nUA2+1Kgcu6IlSLMmYAGO0IAmRrjBEzxy9iGaxiNGTc | ||
LvQt/iVtIXWkKKI8yvpoJ8iFf3TGhpjgaC/h7cJP3zpy0SScmhJJASLXRsfocLv9 | ||
sle6ndN9IPbDtRW8cL7Fk3VQlzp1ToVjmnQTyZZ6S1WafsjzCZ9hLN+k++o8VbvY | ||
v3icY6Sy0BKz0J6KwaxTkuZ6w1K7oUkVOQboKaWFIEdO+jwrEmU+Puyd8Np8jLnF | ||
Q0Y5GPfyMlqM3S/zaDm1t4D1eb5FLciStkxfg5wPVK6TkqB325KVD3aio5C7E7kt | ||
aQechHxaJXCQOtCtVY4X+L4iClnMSuk+hcSc8W8MYRTSVansItK0vI9eQZXMnpan | ||
w9/jk5rS4Gts1rHB7+kdjT3QRJmkyk6fEFT0fz5tfMC7N8waeEUhCaRW6lAoiqDW | ||
NW1h+0UGxJw+9YcGxBC0kkt3iofNOWQWmuf/BS3DHPKT7XV/YtBHe44wW0sF5L5P | ||
nfQUHpnA3pcZ0En6bXAvepKVZTNdOWWJqMyHV+436DA+33h45QL6lWb/GwARAQAB | ||
tDVTdW5lZWwgTWFydGhpIChDT0RFIFNJR05JTkcgS0VZKSA8c21hcnRoaUBhcGFj | ||
aGUub3JnPokCNwQTAQoAIQUCUs8mYQIbAwULCQgHAwUVCgkICwUWAgMBAAIeAQIX | ||
gAAKCRC08czE01QYCOKKEAChRtHBoYNTX+RZbFO0Kl1GlN+i1Ik0shEm5ZJ56XHv | ||
AnFx/gRK7CfZzJswWo7kf2s/dvJiFfs+rrolYVuO6E8gNhAaTEomSuvWQAMHdPcR | ||
9G5APRKCSkbZYugElqplEbSphk78FKoFO+sml52M7Pr9jj88ApBjoFVVY8njdnNq | ||
6DVlaDsg8YninCD78Z7PNFnRGwxyZ8Qd4Dh0rG+MUTfAWopZu6/MxpQxU7QpeVeX | ||
SIMLg7ClFrGfXnZcszYF4dnav1aa0i7W88PAdYNPko7tC5qz5yv2ep7t2gRbcYKf | ||
RXhYC2FHQey3wPhMKjA8V436lAqmfYnY/YdmhEy9Xq/1EdX1nHsQ7OEkfgXK14WM | ||
F+rnqXRAl/0cwiyb41eocdg5kpZFIKgCYT02usLWxwNnd3jOCe109Ze3y3acN/G8 | ||
+xOf9YRfNVAe6pD8H6ieRbv9gRjBmsbz9bXQCmxFnDqxNri5Me6gBAQPNmYTJD0h | ||
jgJTK6o0vJ0pwjBLauasJsLu+1tR3Cb0dxPE+JVaTF26FCd7pM7W6KdVfod9ZfrN | ||
cSyJ/cECc2KvYVGmTjQNVo1dYG0awBachlWnYNt+0Qx4opLsczZOLtPKtFY4BJA7 | ||
aZoXT4Qf9yB8km7x2/cgNExVbFummToJ/IP3M39/EaryspsQQuM5Qu5Q5lZp8Qnn | ||
ybkCDQRSzyZhARAA7bAawFzbJaghYnm6mTZyGG5hQmfAynbF6cPAE+g2SnXcNQjP | ||
6kjYx3tSpb7rEzmjQqs46ztqdec6PIVBMhakON6z27Zz+IviAtO/TcaZHWNuCAjw | ||
FXVQZ+tYsSeiKInttfkrQc8jXAHWwSkSjLqNpvQpBdBEX80MYkFB6ZPOeON2+/Ta | ||
GC1H/HU2YngF0qQSmG33KKG6ezihBJdKxU6t2tsQfTlCmZW6R6MGpS9fVurYMKBk | ||
vR+7RGZ/H6dSjWPcpxhusGg92J9uz7r5SopN1wSdyPMUCMAFGeyoxcAuBDl38quU | ||
H/ENG3x5LDPq2aEH2AJ6yvZfIXbeJ1zmXf2cAHv+HbmvZaTSp0XIjq8Yxh8NkYEC | ||
ZdfRWmsGLIpU16TkBijpK3Dn9MDXjHGT3V8/qfdpURtMvIaL8WFrq9ejcy/vGRFn | ||
mCYqxIIPH+vLiMXKWtuMc61GN3ES21msKQH6IuQxxfQLyhK44L/pv7FpF4E+6LaE | ||
8uRwAex5HIDpR1v4aJq089rRtye9VXTJJLZ7lYs0HctdZ30QbBRWT4jS9d9rj3cr | ||
HgQ7mIGO9TAfK2kWc6AJN/EvxPWNbOwptsTUzAF/adiy9ax8C18iw7nKczC+2eN6 | ||
UcbxXiPdytuKYK7O9A8S9e1w89GwpxYN7Xfn2o6QfpSbL9cLKiinOeV+xikAEQEA | ||
AYkCHwQYAQoACQUCUs8mYQIbDAAKCRC08czE01QYCG7yD/471dmyOD+go8cZkdqR | ||
3CHhjH03odtI0EJNVy4VGEC0r9paz3BWYTy18LqWYkw3ygphOIU1r8/7QK3H5Ke3 | ||
c4yCSUxaMk5SlAJ+iVRek5TABkR8+zI+ZN5pQtqRH+ya5JxV4F/Sx5Q3KWMzpvgY | ||
n6AgSSc3hEfkgdI7SalIeyLaLDWv+RFdGZ5JU5gD28C0G8BeH8L62x6sixZcqoGT | ||
oy9rwkjs45/ZmmvBZhd1wLvC/au8l2Ecou6O8+8m26W8Z7vCuGKxuWn0KV3DLLWe | ||
66uchDVlakGoMJSPIK06JWYUlE+gL0CW+U2ekt/v2qb8hGgMVET3CBAMq+bFWuJ6 | ||
juX7hJd7wHtCFfjnFDDAkdp2IIIZAlBW6FZGv7pJ82xsW6pSAg0A7VrV6nTtMtDv | ||
T8esOfo/t4t0gaL7bivy9DVVdATbUBcJJFpoVoe5MxiyjptveqPzIRwzt04n52Ph | ||
ordVWAnX5AokXWTg+Glem/EWEuf7jUuZArfqCSl/sZoQdXGTjR7G4iFscispji4+ | ||
kNjVQsItqFbgDpuc6n+GcFxlKQ7YMCnu5MVtTV01U4lFs0qy0NTUqsuR35DM4z14 | ||
DkFmj1upWAayCoXTpKzsHBvJZPC+Wqf9Pl3O47apelg7KxU3S011YfXpVPvCTKBv | ||
kD2o/5GKWS5QkSUEUXXY1oDiLg== | ||
=f8kJ | ||
-----END PGP PUBLIC KEY BLOCK----- | ||
pub rsa4096 2017-07-12 [SC] | ||
406DCA257CD2BE237B79AE6BC9D353CA4AFF2E24 | ||
uid [ultimate] Ly Nguyen (CODE SIGNING KEY) <[email protected]> | ||
sig 3 C9D353CA4AFF2E24 2017-07-12 Ly Nguyen (CODE SIGNING KEY) <[email protected]> | ||
sub rsa4096 2017-07-12 [E] | ||
sig C9D353CA4AFF2E24 2017-07-12 Ly Nguyen (CODE SIGNING KEY) <[email protected]> | ||
|
||
-----BEGIN PGP PUBLIC KEY BLOCK----- | ||
|
||
mQINBFlmSIMBEADIr6FzNJ6o/owjqgqWdOtreIRuU47/uzNRZw8c2lEys2Fw+3CI | ||
iUitkWpb7jR0BGLk+8yUk+1VGdXPuJ+zj8XWcCnCJ7TUy3Hudp/BrX7y388m9hP9 | ||
3LP5yx+AUKbXRZiEr5EG2lyTmJBB5lmreVlRMs74Ie3uFtH6US/DVZMqULEtumcH | ||
yCL30kKugUjfftO1mbx901kB0WpB705od3Wrde0Jd9sniMz4HkXMsd93gExh/s1H | ||
3XApXes+yDIEILiUJRawgzgcPIuTyOq4bbafoiFd8ipZU0G7AQPtNUAnpTUtrUaJ | ||
5CDGzOiqGUgwi+M3zwsRcW2MjDi9MyNTmlW2P6Gifzn3EaJ0EVdz4fPmIokC5h+H | ||
6nMHqSPUEu0WA/qpirVrOiUku34lpkP0vZwb8UOyjgBCFTxDMPX70DuUmCbij1rr | ||
vGM0rKLV+LFclEQFpnXckUnza8f/Zbk9T3yWcPQykXyi7+1Z1WJSPVkF4l8ynpDy | ||
4DdUnLGdF8HZAGHdroi/jGVrH2NYy42XQqOZoLfk2BTGiFYpQem/Bfzo3OdEPBT7 | ||
zpZUVqixtXbnGseL1sdHao1BdinIbvSpPOPEbObINenk65NtXWc+9YbauGkJ5kwd | ||
opAkBmZC4IycFWkpmHecbGXJN61eYvARuXKAev7DeYH7g6Zuzp4n07rtIwARAQAB | ||
tC5MeSBOZ3V5ZW4gKENPREUgU0lHTklORyBLRVkpIDxseG4yQGFwYWNoZS5vcmc+ | ||
iQJOBBMBCgA4FiEEQG3KJXzSviN7ea5rydNTykr/LiQFAllmSIMCGwMFCwkIBwMF | ||
FQoJCAsFFgIDAQACHgECF4AACgkQydNTykr/LiT2/Q//aW1qOLX7msuJDqhlHFIM | ||
hCUZzWClljfCHMHZJooJY5YOcvzE5mVgwVdWjgAgZfgk/bFsNhuOb+jIqlatsNfI | ||
Eg7sm6VjfHRo3pP1W7NN+CQNu5JnEEZAIVLy2gn+Eq1rQc7g2pfylVh/HV14TGon | ||
OWbk7BfaZubGLtLJTIimHAPd+TrRsGsLnd9JiDZj0gsPPKV6HHXHgZoAeStIUPNX | ||
13mN/WMDAAqroPPUfMEMXPbmJgNf/ukIFxsS/y8MwU32BjVCBvvh8ojN3RIgUJnX | ||
chdjT9i/QVKi9TyoF20R7mR80x/P9CBwqKoN9+QuHjTPDuZkol4xD3jyzOsKHPwZ | ||
CpltwdhI2JCYJzEIFtrZ0R59fXJ+8NNXZzIOqnx83qarC+eSf8cunqPS/ZBIvEJ0 | ||
qM1adZlJiY96La10wXSjYnEc+XEw+dad3D3ChVsvDceJirelaAVrRS2Dz4ugNShy | ||
W0cZFFUL0aCTNNJnF9sHAfexbbg06BTzSSAeYrEWLmmpjEYHXAtFyToHzk0jTUr4 | ||
66SeIUVHIqBLk8yx1L9zQK38JS9usYj1PFJri9J6iYyqiIS7zRinoO8MIySZOOGp | ||
Z3Q5xJbnwzjwl4frGaXg2/zyD7rfQGG3P23WOselgNWMKuYtVAA+AHo/CxLIinKk | ||
JAMljesV3vfeawK5HHnfcgK5Ag0EWWZIgwEQAMsmr5lOFe4n9iGdTciYFXxZYSEX | ||
ZqmtWyxNsXkih2icfohygx/YLFBSkdXSfIywS7w7+Na4OYdhp3uaRdU+yA4ianY7 | ||
qH5guni98KtyZmsRnnjT1DgyR0pNNqAdAyfWeCglMx5SWLLtzKxHazqF0t6Jb6M/ | ||
sAew+KdoTXsYzKb9d/R81spvefJoBopaxKLF1tijaX98RiquKLlFBD+88XP6pxSB | ||
nwNxNybgJVlGT/RdxPiRiRj0CySuvx27i8w8Rc2HaT9CFumzdy6moz+RJbuuIjDN | ||
QzIOpNy4+LJKSysPGh8AwRu6xCl9gnfbJ9thiFwYGZ7S3lVvS23/poI1YzLZZY+5 | ||
XvpiiogF7j5Aj/zTTli8BI/CiNVrGKJuzeJJyLFfBMmrbysi9mV/fR8wC7xd5P9g | ||
LjElkA4j1Xv5I47AVsILAbHLhphpxNDoKBmr1EbP/CJitEYjRmdjn4Mo6sYwMlVN | ||
CA+rl/VMS3Nc0Iixu/Y070H3kE9IfitksiuXIJfeX5RW/uWegEO1e1dSpi+rreb8 | ||
lvVtQk4tMUHyM16qPqO08tPGSunt6J0HiPi7J+xDwbJjJS7gNDW4AYHG5q4/dZsx | ||
PtpcZC7zFOlFV0BwFftYnluccDhsWPc48mDmmhOe9p42irMAx6ms/Y42jgh4OmgD | ||
bjMzKIyYFI40URGnABEBAAGJAjYEGAEKACAWIQRAbcolfNK+I3t5rmvJ01PKSv8u | ||
JAUCWWZIgwIbDAAKCRDJ01PKSv8uJCAtD/97SuVGnCP3kbWfI/qfTTVKwuWTdbIg | ||
rPvOjGo5F57l1PAgARt8N1ccqREbR3JwhRdsU3ewz5eDQEyEZVffPgufhqZr8liI | ||
EP783m83VgRSMKYt6HzORX0os2BapsHHuejvlME9XpN0UG5AnvbzXDxP3wJufB1K | ||
GkmC+rlpqfyMu60xFXzym9QuePksbdf/xXZduvLGaB1u+AYtvHp3+NGV382vat7C | ||
xwRShVJTb8Zr9y5tA+JDqfhDDb5CepcPH6Uk2frU8aV7vZ3hmVmGcDcUddu3U9hg | ||
L7Lcpr1E0D7xOuQ4QMAFhcDO+aB8aPv+JRkH4Y6wDFPrEgcEJ1YK6hhW5KSdslyK | ||
QrKHKMSl+hwPmh9fKX4wC+FjMMXJ/PHtEG3N3f7/TyyO4iza5xDIJkYcyKkDXc0l | ||
VcHLJvtjsJziMJNV3lKAeTp/uzbaJHRhLmpPHukQPnlpjfhnmsYh3wydnd03pfzQ | ||
k6XJ4iGeSSQqtW6T14yqkCl5HDH2ms1ufhe4Os217CMXnaRbM/K6Zl4iGGozzXgd | ||
no02+jTN3NqmUw0hUBR/9ZEn+IKmZ6f0Azsgio0M9ez1T0CCDZvo19kJw9b3VdOF | ||
TZQhIRekaaV+bCQQxnwDOJ31bIUUpxaMdvygjq55Gri/5C75TsMNcgbhqYWLGKe2 | ||
kRsGTxyO+fQ6/Q== | ||
=FuXU | ||
-----END PGP PUBLIC KEY BLOCK----- |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,159 @@ | ||
--- | ||
title: "Customized loss function" | ||
output: | ||
md_document: | ||
variant: markdown_github | ||
--- | ||
|
||
```{r setup, include=FALSE} | ||
knitr::opts_chunk$set(echo = TRUE) | ||
``` | ||
|
||
This tutorial provides guidelines for using customized loss function in network construction. | ||
|
||
Model Training Example | ||
---------- | ||
|
||
Let's begin with a small regression example. We can build and train a regression model with the following code: | ||
|
||
```{r} | ||
data(BostonHousing, package = "mlbench") | ||
BostonHousing[, sapply(BostonHousing, is.factor)] <- | ||
as.numeric(as.character(BostonHousing[, sapply(BostonHousing, is.factor)])) | ||
BostonHousing <- data.frame(scale(BostonHousing)) | ||
test.ind = seq(1, 506, 5) # 1 pt in 5 used for testing | ||
train.x = data.matrix(BostonHousing[-test.ind,-14]) | ||
train.y = BostonHousing[-test.ind, 14] | ||
test.x = data.matrix(BostonHousing[--test.ind,-14]) | ||
test.y = BostonHousing[--test.ind, 14] | ||
require(mxnet) | ||
data <- mx.symbol.Variable("data") | ||
label <- mx.symbol.Variable("label") | ||
fc1 <- mx.symbol.FullyConnected(data, num_hidden = 14, name = "fc1") | ||
tanh1 <- mx.symbol.Activation(fc1, act_type = "tanh", name = "tanh1") | ||
fc2 <- mx.symbol.FullyConnected(tanh1, num_hidden = 1, name = "fc2") | ||
lro <- mx.symbol.LinearRegressionOutput(fc2, name = "lro") | ||
mx.set.seed(0) | ||
model <- mx.model.FeedForward.create(lro, X = train.x, y = train.y, | ||
ctx = mx.cpu(), | ||
num.round = 5, | ||
array.batch.size = 60, | ||
optimizer = "rmsprop", | ||
verbose = TRUE, | ||
array.layout = "rowmajor", | ||
batch.end.callback = NULL, | ||
epoch.end.callback = NULL) | ||
pred <- predict(model, test.x) | ||
sum((test.y - pred[1,])^2) / length(test.y) | ||
``` | ||
|
||
Besides the `LinearRegressionOutput`, we also provide `LogisticRegressionOutput` and `MAERegressionOutput`. | ||
However, this might not be enough for real-world models. You can provide your own loss function | ||
by using `mx.symbol.MakeLoss` when constructing the network. | ||
|
||
How to Use Your Own Loss Function | ||
--------- | ||
|
||
We still use our previous example, but this time we use `mx.symbol.MakeLoss` to minimize the `(pred-label)^2` | ||
|
||
```{r} | ||
data <- mx.symbol.Variable("data") | ||
label <- mx.symbol.Variable("label") | ||
fc1 <- mx.symbol.FullyConnected(data, num_hidden = 14, name = "fc1") | ||
tanh1 <- mx.symbol.Activation(fc1, act_type = "tanh", name = "tanh1") | ||
fc2 <- mx.symbol.FullyConnected(tanh1, num_hidden = 1, name = "fc2") | ||
lro2 <- mx.symbol.MakeLoss(mx.symbol.square(mx.symbol.Reshape(fc2, shape = 0) - label), name="lro2") | ||
``` | ||
|
||
Then we can train the network just as usual. | ||
|
||
```{r} | ||
mx.set.seed(0) | ||
model2 <- mx.model.FeedForward.create(lro2, X = train.x, y = train.y, | ||
ctx = mx.cpu(), | ||
num.round = 5, | ||
array.batch.size = 60, | ||
optimizer = "rmsprop", | ||
verbose = TRUE, | ||
array.layout = "rowmajor", | ||
batch.end.callback = NULL, | ||
epoch.end.callback = NULL) | ||
``` | ||
|
||
We should get very similar results because we are actually minimizing the same loss function. | ||
However, the result is quite different. | ||
|
||
```{r} | ||
pred2 <- predict(model2, test.x) | ||
sum((test.y - pred2)^2) / length(test.y) | ||
``` | ||
|
||
This is because output of `mx.symbol.MakeLoss` is the gradient of loss with respect to the input data. | ||
We can get the real prediction as below. | ||
|
||
```{r} | ||
internals = internals(model2$symbol) | ||
fc_symbol = internals[[match("fc2_output", outputs(internals))]] | ||
model3 <- list(symbol = fc_symbol, | ||
arg.params = model2$arg.params, | ||
aux.params = model2$aux.params) | ||
class(model3) <- "MXFeedForwardModel" | ||
pred3 <- predict(model3, test.x) | ||
sum((test.y - pred3[1,])^2) / length(test.y) | ||
``` | ||
|
||
We have provided many operations on the symbols. An example of `|pred-label|` can be found below. | ||
|
||
```{r} | ||
lro_abs <- mx.symbol.MakeLoss(mx.symbol.abs(mx.symbol.Reshape(fc2, shape = 0) - label)) | ||
mx.set.seed(0) | ||
model4 <- mx.model.FeedForward.create(lro_abs, X = train.x, y = train.y, | ||
ctx = mx.cpu(), | ||
num.round = 20, | ||
array.batch.size = 60, | ||
optimizer = "sgd", | ||
learning.rate = 0.001, | ||
verbose = TRUE, | ||
array.layout = "rowmajor", | ||
batch.end.callback = NULL, | ||
epoch.end.callback = NULL) | ||
internals = internals(model4$symbol) | ||
fc_symbol = internals[[match("fc2_output", outputs(internals))]] | ||
model5 <- list(symbol = fc_symbol, | ||
arg.params = model4$arg.params, | ||
aux.params = model4$aux.params) | ||
class(model5) <- "MXFeedForwardModel" | ||
pred5 <- predict(model5, test.x) | ||
sum(abs(test.y - pred5[1,])) / length(test.y) | ||
``` | ||
|
||
|
||
```{r} | ||
lro_mae <- mx.symbol.MAERegressionOutput(fc2, name = "lro") | ||
mx.set.seed(0) | ||
model6 <- mx.model.FeedForward.create(lro_mae, X = train.x, y = train.y, | ||
ctx = mx.cpu(), | ||
num.round = 20, | ||
array.batch.size = 60, | ||
optimizer = "sgd", | ||
learning.rate = 0.001, | ||
verbose = TRUE, | ||
array.layout = "rowmajor", | ||
batch.end.callback = NULL, | ||
epoch.end.callback = NULL) | ||
pred6 <- predict(model6, test.x) | ||
sum(abs(test.y - pred6[1,])) / length(test.y) | ||
``` | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.