Skip to content

Commit

Permalink
fix bug ShusenTang#84
Browse files Browse the repository at this point in the history
  • Loading branch information
ShusenTang committed Jan 2, 2020
1 parent b3401dd commit 3932834
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"1.0.0 cuda\n"
"1.1.0 cuda\n"
]
}
],
Expand All @@ -39,10 +39,10 @@
"sys.path.append(\"..\") \n",
"import d2lzh_pytorch as d2l\n",
"\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"7\"\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"2\"\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"\n",
"DATA_ROOT = \"/S1/CSCL/tangss/Datasets\"\n",
"DATA_ROOT = \"/data1/tangss/Datasets\"\n",
"\n",
"print(torch.__version__, device)"
]
Expand Down Expand Up @@ -88,10 +88,10 @@
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 12500/12500 [00:04<00:00, 2930.03it/s]\n",
"100%|██████████| 12500/12500 [00:04<00:00, 3008.48it/s]\n",
"100%|██████████| 12500/12500 [00:03<00:00, 3365.08it/s]\n",
"100%|██████████| 12500/12500 [00:03<00:00, 3305.63it/s]\n"
"100%|██████████| 12500/12500 [00:00<00:00, 34211.42it/s]\n",
"100%|██████████| 12500/12500 [00:00<00:00, 38506.48it/s]\n",
"100%|██████████| 12500/12500 [00:00<00:00, 31316.61it/s]\n",
"100%|██████████| 12500/12500 [00:00<00:00, 29664.72it/s]\n"
]
}
],
Expand All @@ -108,7 +108,8 @@
" random.shuffle(data)\n",
" return data\n",
"\n",
"train_data, test_data = read_imdb('train'), read_imdb('test')"
"data_root = os.path.join(DATA_ROOT, \"aclImdb\")\n",
"train_data, test_data = read_imdb('train', data_root), read_imdb('test', data_root)"
]
},
{
Expand Down Expand Up @@ -152,7 +153,7 @@
{
"data": {
"text/plain": [
"('# words in vocab:', 46151)"
"('# words in vocab:', 46152)"
]
},
"execution_count": 5,
Expand Down Expand Up @@ -330,8 +331,7 @@
"ExecuteTime": {
"end_time": "2019-07-03T04:26:47.895604Z",
"start_time": "2019-07-03T04:26:47.685801Z"
},
"collapsed": true
}
},
"outputs": [],
"source": [
Expand All @@ -345,10 +345,17 @@
"ExecuteTime": {
"end_time": "2019-07-03T04:26:48.102388Z",
"start_time": "2019-07-03T04:26:47.897582Z"
},
"collapsed": true
}
},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"There are 21202 oov words.\n"
]
}
],
"source": [
"def load_pretrained_embedding(words, pretrained_vocab):\n",
" \"\"\"从预训练好的vocab中提取出words对应的词向量\"\"\"\n",
Expand All @@ -359,9 +366,9 @@
" idx = pretrained_vocab.stoi[word]\n",
" embed[i, :] = pretrained_vocab.vectors[idx]\n",
" except KeyError:\n",
" oov_count += 0\n",
" oov_count += 1\n",
" if oov_count > 0:\n",
" print(\"There are %d oov words.\")\n",
" print(\"There are %d oov words.\" % oov_count)\n",
" return embed\n",
"\n",
"net.embedding.weight.data.copy_(load_pretrained_embedding(vocab.itos, glove_vocab))\n",
Expand Down Expand Up @@ -390,11 +397,11 @@
"output_type": "stream",
"text": [
"training on cuda\n",
"epoch 1, loss 0.5759, train acc 0.666, test acc 0.832, time 250.8 sec\n",
"epoch 2, loss 0.1785, train acc 0.842, test acc 0.852, time 253.3 sec\n",
"epoch 3, loss 0.1042, train acc 0.866, test acc 0.856, time 253.7 sec\n",
"epoch 4, loss 0.0682, train acc 0.888, test acc 0.868, time 254.2 sec\n",
"epoch 5, loss 0.0483, train acc 0.901, test acc 0.862, time 251.4 sec\n"
"epoch 1, loss 0.5415, train acc 0.719, test acc 0.819, time 48.7 sec\n",
"epoch 2, loss 0.1897, train acc 0.837, test acc 0.852, time 53.0 sec\n",
"epoch 3, loss 0.1105, train acc 0.857, test acc 0.844, time 51.6 sec\n",
"epoch 4, loss 0.0719, train acc 0.881, test acc 0.865, time 52.1 sec\n",
"epoch 5, loss 0.0519, train acc 0.894, test acc 0.852, time 51.2 sec\n"
]
}
],
Expand Down Expand Up @@ -488,9 +495,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:py36]",
"display_name": "Python [conda env:py36_pytorch]",
"language": "python",
"name": "conda-env-py36-py"
"name": "conda-env-py36_pytorch-py"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -502,7 +509,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.3"
"version": "3.6.2"
},
"varInspector": {
"cols": {
Expand Down
4 changes: 2 additions & 2 deletions code/d2lzh_pytorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1203,9 +1203,9 @@ def load_pretrained_embedding(words, pretrained_vocab):
idx = pretrained_vocab.stoi[word]
embed[i, :] = pretrained_vocab.vectors[idx]
except KeyError:
oov_count += 0
oov_count += 1
if oov_count > 0:
print("There are %d oov words.")
print("There are %d oov words." % oov_count)
return embed

def predict_sentiment(net, vocab, sentence):
Expand Down

0 comments on commit 3932834

Please sign in to comment.