diff --git a/code/chapter03_DL-basics/3.16_kaggle-house-price.ipynb b/code/chapter03_DL-basics/3.16_kaggle-house-price.ipynb
index bebe8798a..2c5c66a40 100644
--- a/code/chapter03_DL-basics/3.16_kaggle-house-price.ipynb
+++ b/code/chapter03_DL-basics/3.16_kaggle-house-price.ipynb
@@ -298,7 +298,7 @@
" with torch.no_grad():\n",
" # 将小于1的值设成1,使得取对数时数值更稳定\n",
" clipped_preds = torch.max(net(features), torch.tensor(1.0))\n",
- " rmse = torch.sqrt(2 * loss(clipped_preds.log(), labels.log()).mean())\n",
+ " rmse = torch.sqrt(loss(clipped_preds.log(), labels.log()))\n",
" return rmse.item()"
]
},
@@ -405,12 +405,12 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "fold 0, train rmse 0.240939, valid rmse 0.221437\n",
- "fold 1, train rmse 0.229326, valid rmse 0.267492\n",
- "fold 2, train rmse 0.231815, valid rmse 0.237722\n",
- "fold 3, train rmse 0.237550, valid rmse 0.219035\n",
- "fold 4, train rmse 0.230578, valid rmse 0.258887\n",
- "5-fold validation: avg train rmse 0.234042, avg valid rmse 0.240915\n"
+ "fold 0, train rmse 0.170585, valid rmse 0.156860\n",
+ "fold 1, train rmse 0.162552, valid rmse 0.190944\n",
+ "fold 2, train rmse 0.164199, valid rmse 0.168767\n",
+ "fold 3, train rmse 0.168698, valid rmse 0.154873\n",
+ "fold 4, train rmse 0.163213, valid rmse 0.183080\n",
+ "5-fold validation: avg train rmse 0.165849, avg valid rmse 0.170905\n"
]
},
{
@@ -450,10 +450,10 @@
" \n",
" \n",
+ "\" id=\"m9cbda39ac0\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n",
" \n",
" \n",
- " \n",
" \n",
" \n",
@@ -489,7 +489,7 @@
" \n",
" \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
" \n",
@@ -529,7 +529,7 @@
" \n",
" \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
" \n",
@@ -562,7 +562,7 @@
" \n",
" \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
" \n",
@@ -608,7 +608,7 @@
" \n",
" \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
" \n",
@@ -663,7 +663,7 @@
" \n",
" \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
" \n",
@@ -852,15 +852,15 @@
" \n",
" \n",
+ "\" id=\"m0df7c1c40c\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n",
" \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
" \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
" \n",
@@ -872,94 +872,80 @@
" \n",
" \n",
+ "\" id=\"m8231d37304\" style=\"stroke:#000000;stroke-width:0.6;\"/>\n",
" \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
" \n",
" \n",
" \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
" \n",
" \n",
" \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
" \n",
" \n",
" \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
" \n",
" \n",
" \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
" \n",
" \n",
" \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
" \n",
" \n",
" \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
" \n",
" \n",
" \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
" \n",
" \n",
" \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
" \n",
" \n",
" \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
" \n",
" \n",
" \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
" \n",
" \n",
" \n",
@@ -1021,210 +1007,210 @@
" \n",
" \n",
" \n",
- " \n",
- " \n",
+ " \n",
" \n",
- " \n",
- " \n",
+ " \n",
" \n",
" \n",
@@ -1261,12 +1247,12 @@
"z\n",
"\" style=\"fill:#ffffff;opacity:0.8;stroke:#cccccc;stroke-linejoin:miter;\"/>\n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
" \n",
@@ -1361,12 +1347,12 @@
" \n",
" \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
" \n",
@@ -1424,14 +1410,14 @@
" \n",
" \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
" \n",
"\n"
],
"text/plain": [
- ""
+ ""
]
},
"metadata": {},
@@ -1469,7 +1455,7 @@
" preds = net(test_features).detach().numpy()\n",
" test_data['SalePrice'] = pd.Series(preds.reshape(1, -1)[0])\n",
" submission = pd.concat([test_data['Id'], test_data['SalePrice']], axis=1)\n",
- " submission.to_csv('./submission.csv', index=False)"
+ " # submission.to_csv('./submission.csv', index=False)"
]
},
{
@@ -1481,7 +1467,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "train rmse 0.230200\n"
+ "train rmse 0.162085\n"
]
},
{
@@ -1521,10 +1507,10 @@
" \n",
" \n",
+ "\" id=\"me383947859\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n",
" \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
" \n",
@@ -1560,7 +1546,7 @@
" \n",
" \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
" \n",
@@ -1600,7 +1586,7 @@
" \n",
" \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
" \n",
@@ -1633,7 +1619,7 @@
" \n",
" \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
" \n",
@@ -1679,7 +1665,7 @@
" \n",
" \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
" \n",
@@ -1734,7 +1720,7 @@
" \n",
" \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
" \n",
@@ -1923,15 +1909,15 @@
" \n",
" \n",
+ "\" id=\"mf4b47cc8b8\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n",
" \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
" \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
" \n",
@@ -1943,87 +1929,80 @@
" \n",
" \n",
+ "\" id=\"m5bb3ee9e0a\" style=\"stroke:#000000;stroke-width:0.6;\"/>\n",
" \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
" \n",
" \n",
" \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
" \n",
" \n",
" \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
" \n",
" \n",
" \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
" \n",
" \n",
" \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
" \n",
" \n",
" \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
" \n",
" \n",
" \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
" \n",
" \n",
" \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
" \n",
" \n",
" \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
" \n",
" \n",
" \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
" \n",
" \n",
" \n",
" \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
+ " \n",
" \n",
" \n",
" \n",
@@ -2085,106 +2064,106 @@
" \n",
" \n",
" \n",
- " \n",
- " \n",
+ " \n",
" \n",
@@ -2211,14 +2190,14 @@
" \n",
" \n",
" \n",
- " \n",
+ " \n",
" \n",
" \n",
" \n",
"\n"
],
"text/plain": [
- ""
+ ""
]
},
"metadata": {},
diff --git a/docs/chapter03_DL-basics/3.16_kaggle-house-price.md b/docs/chapter03_DL-basics/3.16_kaggle-house-price.md
index bdb8715e7..751d6d770 100644
--- a/docs/chapter03_DL-basics/3.16_kaggle-house-price.md
+++ b/docs/chapter03_DL-basics/3.16_kaggle-house-price.md
@@ -131,7 +131,7 @@ def log_rmse(net, features, labels):
with torch.no_grad():
# 将小于1的值设成1,使得取对数时数值更稳定
clipped_preds = torch.max(net(features), torch.tensor(1.0))
- rmse = torch.sqrt(2 * loss(clipped_preds.log(), labels.log()).mean())
+ rmse = torch.sqrt(loss(clipped_preds.log(), labels.log()))
return rmse.item()
```
@@ -203,12 +203,12 @@ def k_fold(k, X_train, y_train, num_epochs,
```
输出:
```
-fold 0, train rmse 0.241054, valid rmse 0.221462
-fold 1, train rmse 0.229857, valid rmse 0.268489
-fold 2, train rmse 0.231413, valid rmse 0.238157
-fold 3, train rmse 0.237733, valid rmse 0.218747
-fold 4, train rmse 0.230720, valid rmse 0.258712
-5-fold validation: avg train rmse 0.234155, avg valid rmse 0.241113
+fold 0, train rmse 0.170585, valid rmse 0.156860
+fold 1, train rmse 0.162552, valid rmse 0.190944
+fold 2, train rmse 0.164199, valid rmse 0.168767
+fold 3, train rmse 0.168698, valid rmse 0.154873
+fold 4, train rmse 0.163213, valid rmse 0.183080
+5-fold validation: avg train rmse 0.165849, avg valid rmse 0.170905
```
@@ -250,7 +250,7 @@ train_and_pred(train_features, test_features, train_labels, test_data, num_epoch
```
输出:
```
-train rmse 0.229943
+train rmse 0.162085
```