<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>07_Seq2Seq</title>
<link rel="stylesheet" href="https://stackedit.io/style.css" />
</head>
<body class="stackedit">
<div class="stackedit__html"><h1 id="seq2seq">07 Seq2Seq</h1>
<h2 id="assignment">Assignment</h2>
<ol start="2">
<li>Assignment 2 (300 points):
<ol>
<li>Train the <strong>model we wrote in class</strong> on the following two datasets taken from <a href="https://kili-technology.com/blog/chatbot-training-datasets/">this link</a>:
<ol>
<li><a href="http://www.cs.cmu.edu/~ark/QA-data/">http://www.cs.cmu.edu/~ark/QA-data/</a></li>
<li><a href="https://quoradata.quora.com/First-Quora-Dataset-Release-Question-Pairs">https://quoradata.quora.com/First-Quora-Dataset-Release-Question-Pairs</a></li>
</ol>
</li>
<li>Once done, please upload the file to GitHub and answer these questions in the S7 - Assignment Solutions:
<ol>
<li>Share the link to your GitHub repo (100 pts for code quality/file structure/model accuracy)</li>
<li>Share the link to your readme file (100 pts for a proper readme file); this file can be the second part of your Part 1 readme (you can have a single readme describing both assignments if you want)</li>
<li>Copy-paste the code related to your dataset preparation for both datasets. (100 pts)</li>
</ol>
</li>
</ol>
</li>
</ol>
<p>[Update]</p>
<p>400 points for each successful attempt on any additional dataset available on this link: <a href="https://kili-technology.com/blog/chatbot-training-datasets">https://kili-technology.com/blog/chatbot-training-datasets</a></p>
<p>Share the notebook in which you successfully attempted the additional datasets as your response on the assignment-solution page. Please make sure that you have explained the task and the dataset that you used.</p>
<h2 id="solution">Solution</h2>
<table>
<thead>
<tr>
<th></th>
<th>NBViewer</th>
<th>Google Colab</th>
<th>TensorBoard Logs</th>
</tr>
</thead>
<tbody>
<tr>
<td>Wiki-QA Dataset</td>
<td><a href="https://nbviewer.jupyter.org/github/satyajitghana/TSAI-DeepNLP-END2.0/blob/main/07_Seq2Seq/WikiQA_Dataset.ipynb"><img alt="Open In NBViewer" src="https://img.shields.io/badge/render-nbviewer-orange?logo=Jupyter"></a></td>
<td><a href="https://githubtocolab.com/satyajitghana/TSAI-DeepNLP-END2.0/blob/main/07_Seq2Seq/WikiQA_Dataset.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg"></a></td>
<td></td>
</tr>
<tr>
<td>Wiki-QA Model</td>
<td><a href="https://nbviewer.jupyter.org/github/satyajitghana/TSAI-DeepNLP-END2.0/blob/main/07_Seq2Seq/WikiQA_Model.ipynb"><img alt="Open In NBViewer" src="https://img.shields.io/badge/render-nbviewer-orange?logo=Jupyter"></a></td>
<td><a href="https://githubtocolab.com/satyajitghana/TSAI-DeepNLP-END2.0/blob/main/07_Seq2Seq/WikiQA_Model.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg"></a></td>
<td><a href="https://tensorboard.dev/experiment/ilMEHBPqQv6Tqh5DWC2SKw/"><img src="https://img.shields.io/badge/logs-tensorboard-orange?logo=Tensorflow"></a></td>
</tr>
<tr>
<td>Quora-SQ Dataset</td>
<td><a href="https://nbviewer.jupyter.org/github/satyajitghana/TSAI-DeepNLP-END2.0/blob/13876eec3594235c4df8545cf14bfd6328e01b8e/07_Seq2Seq/Quora_Question_Dataset.ipynb"><img alt="Open In NBViewer" src="https://img.shields.io/badge/render-nbviewer-orange?logo=Jupyter"></a></td>
<td><a href="https://githubtocolab.com/satyajitghana/TSAI-DeepNLP-END2.0/blob/main/07_Seq2Seq/Quora_Question_Dataset.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg"></a></td>
<td></td>
</tr>
<tr>
<td>Quora-SQ Model</td>
<td><a href="https://nbviewer.jupyter.org/github/satyajitghana/TSAI-DeepNLP-END2.0/blob/main/07_Seq2Seq/QuoraSQ_Model.ipynb"><img alt="Open In NBViewer" src="https://img.shields.io/badge/render-nbviewer-orange?logo=Jupyter"></a></td>
<td><a href="https://githubtocolab.com/satyajitghana/TSAI-DeepNLP-END2.0/blob/main/07_Seq2Seq/QuoraSQ_Model.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg"></a></td>
<td><a href="https://tensorboard.dev/experiment/9zZYZKtLQQmnOk4hpvgzsA/"><img src="https://img.shields.io/badge/logs-tensorboard-orange?logo=Tensorflow"></a></td>
</tr>
</tbody>
</table><h3 id="wiki-qa-dataset">Wiki QA Dataset</h3>
<p><strong>Dataset Preparation</strong></p>
<pre class=" language-python"><code class="prism language-python">S08 = pd.read_csv('/content/Question_Answer_Dataset_v1.2/S08/question_answer_pairs.txt', sep='\t', encoding='ISO-8859-1')
S09 = pd.read_csv('/content/Question_Answer_Dataset_v1.2/S09/question_answer_pairs.txt', sep='\t', encoding='ISO-8859-1')
S10 = pd.read_csv('/content/Question_Answer_Dataset_v2.2/S10/question_answer_pairs.txt', sep='\t', encoding='ISO-8859-1')
combined_qa = pd.concat([S08, S09, S10])
</code></pre>
<p><img src="https://github.com/satyajitghana/TSAI-DeepNLP-END2.0/blob/main/07_Seq2Seq/assets/wikiqa_df.png?raw=true" alt="wiki-qa"></p>
<p>The dataset is pretty straightforward: each of the <code>S08</code>, <code>S09</code> and <code>S10</code> sub-folders has <code>Question</code> and <code>Answer</code> columns, and we simply combine all of them. In total we have <code>3998</code> rows, and after dropping <code>NA</code> values we are left with <code>3422</code> rows.</p>
<p>Which results in</p>
<p><img src="https://github.com/satyajitghana/TSAI-DeepNLP-END2.0/blob/main/07_Seq2Seq/assets/wikiqa_cleaned.png?raw=true" alt="wiki-qa-cleaned"></p>
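<p>The cleaning step described above is just a <code>concat</code> followed by a <code>dropna</code>. A minimal self-contained sketch (the toy rows below are illustrative stand-ins for the real <code>S08</code>/<code>S09</code>/<code>S10</code> frames, not actual dataset content):</p>
<pre class=" language-python"><code class="prism language-python">import pandas as pd

# toy stand-ins for the three per-year frames (illustrative values only)
s08 = pd.DataFrame({'Question': ['Was Lincoln a president?'], 'Answer': ['Yes']})
s09 = pd.DataFrame({'Question': ['Who wrote Hamlet?'], 'Answer': [None]})  # a row with a missing answer
s10 = pd.DataFrame({'Question': ['What is an otter?'], 'Answer': ['A mammal']})

# combine the three frames, then drop rows with missing values
combined_qa = pd.concat([s08, s09, s10], ignore_index=True)
cleaned_qa = combined_qa.dropna().reset_index(drop=True)

print(len(combined_qa), len(cleaned_qa))  # 3 2
</code></pre>
<p>On the real data this is exactly the <code>3998</code>-to-<code>3422</code> row reduction shown above.</p>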
<p>Below is the PyTorch <code>Dataset</code> implementation of the same.</p>
<pre class=" language-python"><code class="prism language-python">class WikiQADataset(Dataset):
    """
    Wiki QA Dataset
    """

    URL = 'https://drive.google.com/uc?id=1FFTtPmxu63Dljelg8YsRRn8Yz475MWyv'
    OUTPUT = 'wikiqa_dataset.csv'

    def __init__(self, root, split='train', vocab=None, vectors=None, text_transforms=None, label_transforms=None, ngrams=1):
        """Initiate dataset.

        Args:
            vocab: Vocabulary object used for dataset.
        """
        super().__init__()

        if vectors:
            raise NotImplementedError('vectors not supported for this dataset as of now')
        if split not in ['train', 'test']:
            raise ValueError(f'split must be either "train" or "test", unknown split {split}')
        if vocab and vectors:
            raise ValueError('both vocab and vectors cannot be provided')

        self.vocab = vocab
        self.vectors = vectors

        gdown.cached_download(self.URL, Path(root) / self.OUTPUT)
        self.generate_qa_dataset(Path(root) / self.OUTPUT)

        self.train_dset, self.test_dset = train_test_split(self.full_dataset_, test_size=0.3, random_state=42)

        if split == 'train':
            self.dataset = self.train_dset
        else:
            self.dataset = self.test_dset

        # create the tokenizer, here we use spacy
        self.tokenizer = get_tokenizer("spacy", language="en_core_web_sm")

        # the text transform can only work at the sentence level;
        # the rest of tokenization and vocab is done by this class
        self.text_transform = text_f.sequential_transforms(self.tokenizer, text_f.ngrams_func(ngrams))
        self.vocab_transforms = text_f.sequential_transforms()
        self.vector_transforms = text_f.sequential_transforms()

        def build_vocab(data, transforms):
            def apply_transforms(data):
                for line in data:
                    yield transforms(line)
            return build_vocab_from_iterator(apply_transforms(data), len(data), specials=['<unk>', '<pad>', '<bos>', '<eos>'])

        if self.vectors:
            self.vector_transforms = text_f.sequential_transforms(
                partial(vectors.get_vecs_by_tokens, lower_case_backup=True)
            )
        elif self.vocab is None:
            # build the vocab over both questions and answers of the train split
            self.vocab = build_vocab(
                pd.concat([self.train_dset['Question'], self.train_dset['Answer']]),
                self.text_transform
            )

        self.PAD_IDX = self.vocab['<pad>']
        self.BOS_IDX = self.vocab['<bos>']
        self.EOS_IDX = self.vocab['<eos>']

        # if the user is using vocab, instead of vectors
        if self.vocab:
            self.vocab_transforms = text_f.sequential_transforms(
                text_f.vocab_func(self.vocab), text_f.totensor(dtype=torch.long)
            )

        # label transform is similar to text_transform for this dataset, except it does not have vectors
        self.label_transform = text_f.sequential_transforms(
            self.text_transform, self.vocab_transforms
        )

        if text_transforms is not None:
            self.text_transform = text_f.sequential_transforms(
                self.text_transform, text_transforms, self.vocab_transforms, self.vector_transforms
            )
        else:
            self.text_transform = text_f.sequential_transforms(
                self.text_transform, self.vocab_transforms, self.vector_transforms
            )

    def generate_qa_dataset(self, dataset_file):
        self.full_dataset_ = pd.read_csv(dataset_file)

    def __getitem__(self, idx):
        text = self.text_transform(self.dataset['Question'].iloc[idx])
        label = self.label_transform(self.dataset['Answer'].iloc[idx])
        return label, text

    def __len__(self):
        return len(self.dataset)

    def get_vocab(self):
        return self.vocab

    def get_vectors(self):
        return self.vectors

    def batch_sampler_fn(self):
        def batch_sampler():
            # note: `train_list` and `batch_size` come from the enclosing notebook scope
            indices = [(i, len(self.tokenizer(s[1]))) for i, s in enumerate(train_list)]
            random.shuffle(indices)
            pooled_indices = []
            # create pools of indices with similar lengths
            for i in range(0, len(indices), batch_size * 100):
                pooled_indices.extend(sorted(indices[i:i + batch_size * 100], key=lambda x: x[1]))
            pooled_indices = [x[0] for x in pooled_indices]
            # yield indices for current batch
            for i in range(0, len(pooled_indices), batch_size):
                yield pooled_indices[i:i + batch_size]
        return batch_sampler()

    def collator_fn(self):
        def collate_fn(batch):
            targets, sequences = zip(*batch)
            lengths = torch.LongTensor([len(sequence) for sequence in sequences])
            # wrap every sequence with <bos> ... <eos>
            targets = [torch.cat([torch.tensor([self.BOS_IDX]), item, torch.tensor([self.EOS_IDX])]) for item in targets]
            sequences = [torch.cat([torch.tensor([self.BOS_IDX]), item, torch.tensor([self.EOS_IDX])]) for item in sequences]
<span class="token keyword">if</span> <span class="token operator">not</span> self<span class="token punctuation">.</span>vectors<span class="token punctuation">:</span>
pad_idx <span class="token operator">=</span> self<span class="token punctuation">.</span>PAD_IDX
sequences <span class="token operator">=</span> torch<span class="token punctuation">.</span>nn<span class="token punctuation">.</span>utils<span class="token punctuation">.</span>rnn<span class="token punctuation">.</span>pad_sequence<span class="token punctuation">(</span>sequences<span class="token punctuation">,</span>
padding_value <span class="token operator">=</span> pad_idx<span class="token punctuation">,</span>
batch_first<span class="token operator">=</span><span class="token boolean">True</span>
<span class="token punctuation">)</span>
targets <span class="token operator">=</span> torch<span class="token punctuation">.</span>nn<span class="token punctuation">.</span>utils<span class="token punctuation">.</span>rnn<span class="token punctuation">.</span>pad_sequence<span class="token punctuation">(</span>targets<span class="token punctuation">,</span>
padding_value <span class="token operator">=</span> pad_idx<span class="token punctuation">,</span>
batch_first<span class="token operator">=</span><span class="token boolean">True</span>
<span class="token punctuation">)</span>
<span class="token keyword">return</span> targets<span class="token punctuation">,</span> sequences<span class="token punctuation">,</span> lengths
<span class="token keyword">return</span> collate_fn
</code></pre>
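<p>The <code>collate_fn</code> above wraps every sequence with BOS/EOS tokens and then pads the batch to its longest item. A minimal standalone sketch of that padding step, using hypothetical index values for the special tokens:</p>

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# hypothetical special-token indices, standing in for BOS_IDX / EOS_IDX / PAD_IDX
BOS_IDX, EOS_IDX, PAD_IDX = 2, 3, 1

# two token-index sequences of unequal length
sequences = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]

# wrap each sequence with the BOS and EOS markers
wrapped = [torch.cat([torch.tensor([BOS_IDX]), s, torch.tensor([EOS_IDX])])
           for s in sequences]

# pad to the longest sequence; batch_first gives shape (batch, max_len)
padded = pad_sequence(wrapped, padding_value=PAD_IDX, batch_first=True)
print(padded.tolist())  # [[2, 5, 6, 7, 3], [2, 8, 9, 3, 1]]
```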
<h3 id="wikiqa-model---training-log">WikiQA Model - Training Log</h3>
<pre><code> | Name | Type | Params
---------------------------------------------
0 | encoder | Encoder | 5.1 M
1 | decoder | Decoder | 7.8 M
2 | loss | CrossEntropyLoss | 0
---------------------------------------------
12.9 M Trainable params
0 Non-trainable params
12.9 M Total params
51.496 Total estimated model params size (MB)
Validating: 100%
9/9 [00:01<00:00, 5.60it/s]
Epoch: 0, Test PPL: 329.2035827636719, Test Loss: 5.630014896392822
Validating: 100%
9/9 [00:01<00:00, 5.33it/s]
Epoch: 1, Test PPL: 323.3686828613281, Test Loss: 5.590897560119629
Validating: 100%
9/9 [00:01<00:00, 5.32it/s]
Epoch: 2, Test PPL: 320.0765380859375, Test Loss: 5.55886173248291
Validating: 100%
9/9 [00:01<00:00, 5.89it/s]
Epoch: 3, Test PPL: 308.7172546386719, Test Loss: 5.521171569824219
Validating: 100%
9/9 [00:01<00:00, 6.69it/s]
Epoch: 4, Test PPL: 328.9159851074219, Test Loss: 5.575353622436523
Validating: 100%
9/9 [00:01<00:00, 5.37it/s]
Epoch: 5, Test PPL: 323.0148620605469, Test Loss: 5.544531345367432
Validating: 100%
9/9 [00:01<00:00, 5.33it/s]
Epoch: 6, Test PPL: 321.3564147949219, Test Loss: 5.539650917053223
Validating: 100%
9/9 [00:01<00:00, 5.77it/s]
Epoch: 7, Test PPL: 338.40802001953125, Test Loss: 5.586311340332031
Validating: 100%
9/9 [00:01<00:00, 6.56it/s]
Epoch: 8, Test PPL: 344.0347595214844, Test Loss: 5.59952449798584
Validating: 100%
9/9 [00:01<00:00, 5.30it/s]
Epoch: 9, Test PPL: 352.4696960449219, Test Loss: 5.6175217628479
</code></pre>
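<p>A note on the numbers in the log: perplexity is the exponential of the cross-entropy loss. The logged PPL is somewhat higher than <code>exp(loss)</code>, most likely because perplexity here is averaged per validation batch, and by Jensen's inequality the mean of exponentials is at least the exponential of the mean:</p>

```python
import math

# epoch-0 test loss from the log above (nats per token)
loss = 5.630014896392822
ppl_from_mean_loss = math.exp(loss)
print(round(ppl_from_mean_loss, 1))  # 278.7, vs. the logged PPL of 329.2

# averaging per-batch perplexities instead of per-batch losses inflates PPL:
# mean(exp(l_i)) is always at least exp(mean(l_i))  (Jensen's inequality)
batch_losses = [5.2, 5.6, 6.1]
assert sum(math.exp(l) for l in batch_losses) / 3 >= math.exp(sum(batch_losses) / 3)
```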
<h3 id="quora-sq-dataset">Quora SQ Dataset</h3>
<p>This dataset is a <code>.tsv</code> (tab-separated values) file:</p>
<pre><code>id qid1 qid2 question1 question2 is_duplicate
0 1 2 What is the step by step guide to invest in share market in india? What is the step by step guide to invest in share market? 0
1 3 4 What is the story of Kohinoor (Koh-i-Noor) Diamond? What would happen if the Indian government stole the Kohinoor (Koh-i-Noor) diamond back? 0
2 5 6 How can I increase the speed of my internet connection while using a VPN? How can Internet speed be increased by hacking through DNS? 0
3 7 8 Why am I mentally very lonely? How can I solve it? Find the remainder when [math]23^{24}[/math] is divided by 24,23? 0
4 9 10 Which one dissolve in water quikly sugar, salt, methane and carbon di oxide? Which fish would survive in salt water? 0
5 11 12 Astrology: I am a Capricorn Sun Cap moon and cap rising...what does that say about me? I'm a triple Capricorn (Sun, Moon and ascendant in Capricorn) What does this say about me? 1
6 13 14 Should I buy tiago? What keeps childern active and far from phone and video games? 0
7 15 16 How can I be a good geologist? What should I do to be a great geologist? 1
8 17 18 When do you use シ instead of し? "When do you use ""&"" instead of ""and""?" 0
</code></pre>
<p>The only columns we are interested in are <code>question1</code>, <code>question2</code> and <code>is_duplicate</code>.</p>
<p>We also want only the question pairs marked as duplicates, because our model will take in <code>question1</code> and try to generate <code>question2</code>.</p>
<pre class=" language-python"><code class="prism language-python">duplicate_df <span class="token operator">=</span> quora_df<span class="token punctuation">[</span>quora_df<span class="token punctuation">[</span><span class="token string">'is_duplicate'</span><span class="token punctuation">]</span> <span class="token operator">==</span> <span class="token number">1</span><span class="token punctuation">]</span>
</code></pre>
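<p>On a toy version of the table, the same filter keeps only the duplicate pairs and the columns the model needs (toy data below, not the real dataset; the real file is read with <code>pd.read_csv(path, sep='\t')</code>):</p>

```python
import pandas as pd

# toy stand-in for the Quora TSV
quora_df = pd.DataFrame({
    'question1': ['How can I be a good geologist?', 'Should I buy tiago?'],
    'question2': ['What should I do to be a great geologist?',
                  'What keeps children active and far from phone and video games?'],
    'is_duplicate': [1, 0],
})

# keep only duplicate pairs, and only the two question columns
duplicate_df = quora_df[quora_df['is_duplicate'] == 1][['question1', 'question2']]
print(len(duplicate_df))  # 1
```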
<p><img src="https://github.com/satyajitghana/TSAI-DeepNLP-END2.0/blob/main/07_Seq2Seq/assets/quorasq_df.png?raw=true" alt="quora_sq"></p>
<p>This dataset was huge <code>\(〇_o)/</code>, with over <code>149263 rows</code>, and it was also a pain to train compared to the other model.</p>
<p>The PyTorch-style <code>Dataset</code> for this is exactly the same as the previous dataset; only the <code>.csv</code> file is different. It would have been better to create a new class, <code>CSVDataset</code>, for these kinds of datasets.</p>
<p>TorchText does have a <code>TabularDataset</code> implementation, but it is part of the legacy API, and I refuse to use it. I will find a way to create a normal PyTorch <code>Dataset</code> for tabular data.</p>
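<p>A minimal sketch of what such a generic class could look like (hypothetical <code>CSVDataset</code>, plain PyTorch, pandas-backed; the per-field transforms are passed in rather than hard-coded):</p>

```python
import pandas as pd
from torch.utils.data import Dataset

class CSVDataset(Dataset):
    """Hypothetical generic Dataset over a CSV/TSV file: one column is the
    source text, another the target, with optional per-field transforms."""

    def __init__(self, csv_file, source_col, target_col, sep=',',
                 source_transform=None, target_transform=None):
        self.df = pd.read_csv(csv_file, sep=sep)
        self.source_col, self.target_col = source_col, target_col
        # default to identity transforms
        self.source_transform = source_transform or (lambda x: x)
        self.target_transform = target_transform or (lambda x: x)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        source = self.source_transform(self.df[self.source_col].iloc[idx])
        target = self.target_transform(self.df[self.target_col].iloc[idx])
        # same (target, source) ordering as the datasets above
        return target, source
```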
<pre class=" language-python"><code class="prism language-python"><span class="token keyword">class</span> <span class="token class-name">QuoraSQDataset</span><span class="token punctuation">(</span>Dataset<span class="token punctuation">)</span><span class="token punctuation">:</span>
<span class="token triple-quoted-string string">"""
Quora Similar Questions Dataset
"""</span>
URL <span class="token operator">=</span> <span class="token string">'https://drive.google.com/uc?id=1g2YqSPXBWdCU1SjkCb69ENUuoEuxPbLg'</span>
OUTPUT <span class="token operator">=</span> <span class="token string">'quora_duplicate_only_questions.csv'</span>
<span class="token keyword">def</span> <span class="token function">__init__</span><span class="token punctuation">(</span>self<span class="token punctuation">,</span> root<span class="token punctuation">,</span> split<span class="token operator">=</span><span class="token string">'train'</span><span class="token punctuation">,</span> vocab<span class="token operator">=</span><span class="token boolean">None</span><span class="token punctuation">,</span> vectors<span class="token operator">=</span><span class="token boolean">None</span><span class="token punctuation">,</span> text_transforms<span class="token operator">=</span><span class="token boolean">None</span><span class="token punctuation">,</span> label_transforms<span class="token operator">=</span><span class="token boolean">None</span><span class="token punctuation">,</span> ngrams<span class="token operator">=</span><span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">:</span>
<span class="token triple-quoted-string string">"""Initiate dataset.
Args:
vocab: Vocabulary object used for dataset.
"""</span>
<span class="token builtin">super</span><span class="token punctuation">(</span>self<span class="token punctuation">.</span>__class__<span class="token punctuation">,</span> self<span class="token punctuation">)</span><span class="token punctuation">.</span>__init__<span class="token punctuation">(</span><span class="token punctuation">)</span>
<span class="token keyword">if</span> vectors<span class="token punctuation">:</span>
<span class="token keyword">raise</span> NotImplementedError<span class="token punctuation">(</span>f<span class="token string">'vectors not supported for this dataset as of now'</span><span class="token punctuation">)</span>
<span class="token keyword">if</span> split <span class="token operator">not</span> <span class="token keyword">in</span> <span class="token punctuation">[</span><span class="token string">'train'</span><span class="token punctuation">,</span> <span class="token string">'test'</span><span class="token punctuation">]</span><span class="token punctuation">:</span>
<span class="token keyword">raise</span> ValueError<span class="token punctuation">(</span>f<span class="token string">'split must be either "train" or "test" unknown split {split}'</span><span class="token punctuation">)</span>
<span class="token keyword">if</span> vocab <span class="token operator">and</span> vectors<span class="token punctuation">:</span>
<span class="token keyword">raise</span> ValueError<span class="token punctuation">(</span>f<span class="token string">'both vocab and vectors cannot be provided'</span><span class="token punctuation">)</span>
self<span class="token punctuation">.</span>vocab <span class="token operator">=</span> vocab
self<span class="token punctuation">.</span>vectors <span class="token operator">=</span> vectors
gdown<span class="token punctuation">.</span>cached_download<span class="token punctuation">(</span>self<span class="token punctuation">.</span>URL<span class="token punctuation">,</span> Path<span class="token punctuation">(</span>root<span class="token punctuation">)</span> <span class="token operator">/</span> self<span class="token punctuation">.</span>OUTPUT<span class="token punctuation">)</span>
self<span class="token punctuation">.</span>generate_tweet_dataset<span class="token punctuation">(</span>Path<span class="token punctuation">(</span>root<span class="token punctuation">)</span> <span class="token operator">/</span> self<span class="token punctuation">.</span>OUTPUT<span class="token punctuation">)</span>
self<span class="token punctuation">.</span>train_dset<span class="token punctuation">,</span> self<span class="token punctuation">.</span>test_dset <span class="token operator">=</span> train_test_split<span class="token punctuation">(</span>self<span class="token punctuation">.</span>full_dataset_<span class="token punctuation">,</span> test_size<span class="token operator">=</span><span class="token number">0.3</span><span class="token punctuation">,</span> random_state<span class="token operator">=</span><span class="token number">42</span><span class="token punctuation">)</span>
<span class="token keyword">if</span> split <span class="token operator">==</span> <span class="token string">'train'</span><span class="token punctuation">:</span>
self<span class="token punctuation">.</span>dataset <span class="token operator">=</span> self<span class="token punctuation">.</span>train_dset
<span class="token keyword">elif</span> split <span class="token operator">==</span> <span class="token string">'test'</span><span class="token punctuation">:</span>
self<span class="token punctuation">.</span>dataset <span class="token operator">=</span> self<span class="token punctuation">.</span>test_dset
<span class="token keyword">else</span><span class="token punctuation">:</span>
<span class="token keyword">raise</span> ValueError<span class="token punctuation">(</span><span class="token string">"What did you do? you stupid potato?"</span><span class="token punctuation">)</span>
<span class="token comment"># create the tokenizer, here we use spacy</span>
tokenizer <span class="token operator">=</span> get_tokenizer<span class="token punctuation">(</span><span class="token string">"spacy"</span><span class="token punctuation">,</span> language<span class="token operator">=</span><span class="token string">"en_core_web_sm"</span><span class="token punctuation">)</span>
self<span class="token punctuation">.</span>tokenizer <span class="token operator">=</span> tokenizer
<span class="token comment"># the text transform can only work at the sentence level</span>
<span class="token comment"># the rest of tokenization and vocab is done by this class</span>
self<span class="token punctuation">.</span>text_transform <span class="token operator">=</span> text_f<span class="token punctuation">.</span>sequential_transforms<span class="token punctuation">(</span>tokenizer<span class="token punctuation">,</span> text_f<span class="token punctuation">.</span>ngrams_func<span class="token punctuation">(</span>ngrams<span class="token punctuation">)</span><span class="token punctuation">)</span>
self<span class="token punctuation">.</span>vocab_transforms <span class="token operator">=</span> text_f<span class="token punctuation">.</span>sequential_transforms<span class="token punctuation">(</span><span class="token punctuation">)</span>
self<span class="token punctuation">.</span>vector_transforms <span class="token operator">=</span> text_f<span class="token punctuation">.</span>sequential_transforms<span class="token punctuation">(</span><span class="token punctuation">)</span>
<span class="token keyword">def</span> <span class="token function">build_vocab</span><span class="token punctuation">(</span>data<span class="token punctuation">,</span> transforms<span class="token punctuation">)</span><span class="token punctuation">:</span>
<span class="token keyword">def</span> <span class="token function">apply_transforms</span><span class="token punctuation">(</span>data<span class="token punctuation">)</span><span class="token punctuation">:</span>
<span class="token keyword">for</span> line <span class="token keyword">in</span> data<span class="token punctuation">:</span>
<span class="token keyword">yield</span> transforms<span class="token punctuation">(</span>line<span class="token punctuation">)</span>
<span class="token keyword">return</span> build_vocab_from_iterator<span class="token punctuation">(</span>apply_transforms<span class="token punctuation">(</span>data<span class="token punctuation">)</span><span class="token punctuation">,</span> <span class="token builtin">len</span><span class="token punctuation">(</span>data<span class="token punctuation">)</span><span class="token punctuation">,</span> specials<span class="token operator">=</span><span class="token punctuation">[</span><span class="token string">'<unk>'</span><span class="token punctuation">,</span> <span class="token string">'<pad>'</span><span class="token punctuation">,</span> <span class="token string">'<bos>'</span><span class="token punctuation">,</span> <span class="token string">'<eos>'</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
<span class="token keyword">if</span> self<span class="token punctuation">.</span>vectors<span class="token punctuation">:</span>
self<span class="token punctuation">.</span>vector_transforms <span class="token operator">=</span> text_f<span class="token punctuation">.</span>sequential_transforms<span class="token punctuation">(</span>
partial<span class="token punctuation">(</span>vectors<span class="token punctuation">.</span>get_vecs_by_tokens<span class="token punctuation">,</span> lower_case_backup<span class="token operator">=</span><span class="token boolean">True</span><span class="token punctuation">)</span>
<span class="token punctuation">)</span>
<span class="token keyword">elif</span> self<span class="token punctuation">.</span>vocab <span class="token keyword">is</span> <span class="token boolean">None</span><span class="token punctuation">:</span>
self<span class="token punctuation">.</span>vocab <span class="token operator">=</span> build_vocab<span class="token punctuation">(</span>
pd<span class="token punctuation">.</span>concat<span class="token punctuation">(</span><span class="token punctuation">[</span>self<span class="token punctuation">.</span>train_dset<span class="token punctuation">[</span><span class="token string">'question1'</span><span class="token punctuation">]</span><span class="token punctuation">,</span> self<span class="token punctuation">.</span>train_dset<span class="token punctuation">[</span><span class="token string">'question2'</span><span class="token punctuation">]</span><span class="token punctuation">]</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
self<span class="token punctuation">.</span>text_transform
<span class="token punctuation">)</span>
self<span class="token punctuation">.</span>PAD_IDX <span class="token operator">=</span> self<span class="token punctuation">.</span>vocab<span class="token punctuation">[</span><span class="token string">'<pad>'</span><span class="token punctuation">]</span>
self<span class="token punctuation">.</span>BOS_IDX <span class="token operator">=</span> self<span class="token punctuation">.</span>vocab<span class="token punctuation">[</span><span class="token string">'<bos>'</span><span class="token punctuation">]</span>
self<span class="token punctuation">.</span>EOS_IDX <span class="token operator">=</span> self<span class="token punctuation">.</span>vocab<span class="token punctuation">[</span><span class="token string">'<eos>'</span><span class="token punctuation">]</span>
<span class="token comment"># if the user is using vocab, instead of vector</span>
<span class="token keyword">if</span> self<span class="token punctuation">.</span>vocab<span class="token punctuation">:</span>
self<span class="token punctuation">.</span>vocab_transforms <span class="token operator">=</span> text_f<span class="token punctuation">.</span>sequential_transforms<span class="token punctuation">(</span>
text_f<span class="token punctuation">.</span>vocab_func<span class="token punctuation">(</span>self<span class="token punctuation">.</span>vocab<span class="token punctuation">)</span><span class="token punctuation">,</span> text_f<span class="token punctuation">.</span>totensor<span class="token punctuation">(</span>dtype<span class="token operator">=</span>torch<span class="token punctuation">.</span><span class="token builtin">long</span><span class="token punctuation">)</span>
<span class="token punctuation">)</span>
<span class="token comment"># label transform is similar to text_transform for this dataset except this does not have vectors</span>
self<span class="token punctuation">.</span>label_transform <span class="token operator">=</span> text_f<span class="token punctuation">.</span>sequential_transforms<span class="token punctuation">(</span>
self<span class="token punctuation">.</span>text_transform<span class="token punctuation">,</span> self<span class="token punctuation">.</span>vocab_transforms
<span class="token punctuation">)</span>
<span class="token keyword">if</span> text_transforms <span class="token keyword">is</span> <span class="token operator">not</span> <span class="token boolean">None</span><span class="token punctuation">:</span>
self<span class="token punctuation">.</span>text_transform <span class="token operator">=</span> text_f<span class="token punctuation">.</span>sequential_transforms<span class="token punctuation">(</span>
self<span class="token punctuation">.</span>text_transform<span class="token punctuation">,</span> text_transforms<span class="token punctuation">,</span> self<span class="token punctuation">.</span>vocab_transforms<span class="token punctuation">,</span> self<span class="token punctuation">.</span>vector_transforms
<span class="token punctuation">)</span>
<span class="token keyword">else</span><span class="token punctuation">:</span>
self<span class="token punctuation">.</span>text_transform <span class="token operator">=</span> text_f<span class="token punctuation">.</span>sequential_transforms<span class="token punctuation">(</span>
self<span class="token punctuation">.</span>text_transform<span class="token punctuation">,</span> self<span class="token punctuation">.</span>vocab_transforms<span class="token punctuation">,</span> self<span class="token punctuation">.</span>vector_transforms
<span class="token punctuation">)</span>
<span class="token keyword">def</span> <span class="token function">generate_tweet_dataset</span><span class="token punctuation">(</span>self<span class="token punctuation">,</span> dataset_file<span class="token punctuation">)</span><span class="token punctuation">:</span>
self<span class="token punctuation">.</span>full_dataset_ <span class="token operator">=</span> pd<span class="token punctuation">.</span>read_csv<span class="token punctuation">(</span>dataset_file<span class="token punctuation">)</span>
<span class="token keyword">def</span> <span class="token function">__getitem__</span><span class="token punctuation">(</span>self<span class="token punctuation">,</span> idx<span class="token punctuation">)</span><span class="token punctuation">:</span>
text <span class="token operator">=</span> self<span class="token punctuation">.</span>text_transform<span class="token punctuation">(</span>self<span class="token punctuation">.</span>dataset<span class="token punctuation">[</span><span class="token string">'question1'</span><span class="token punctuation">]</span><span class="token punctuation">.</span>iloc<span class="token punctuation">[</span>idx<span class="token punctuation">]</span><span class="token punctuation">)</span>
label <span class="token operator">=</span> self<span class="token punctuation">.</span>label_transform<span class="token punctuation">(</span>self<span class="token punctuation">.</span>dataset<span class="token punctuation">[</span><span class="token string">'question2'</span><span class="token punctuation">]</span><span class="token punctuation">.</span>iloc<span class="token punctuation">[</span>idx<span class="token punctuation">]</span><span class="token punctuation">)</span>
<span class="token keyword">return</span> label<span class="token punctuation">,</span> text
<span class="token keyword">def</span> <span class="token function">__len__</span><span class="token punctuation">(</span>self<span class="token punctuation">)</span><span class="token punctuation">:</span>
<span class="token keyword">return</span> <span class="token builtin">len</span><span class="token punctuation">(</span>self<span class="token punctuation">.</span>dataset<span class="token punctuation">)</span>
<span class="token keyword">def</span> <span class="token function">get_vocab</span><span class="token punctuation">(</span>self<span class="token punctuation">)</span><span class="token punctuation">:</span>
<span class="token keyword">return</span> self<span class="token punctuation">.</span>vocab
<span class="token keyword">def</span> <span class="token function">get_vectors</span><span class="token punctuation">(</span>self<span class="token punctuation">)</span><span class="token punctuation">:</span>
<span class="token keyword">return</span> self<span class="token punctuation">.</span>vectors
<span class="token keyword">def</span> <span class="token function">batch_sampler_fn</span><span class="token punctuation">(</span>self<span class="token punctuation">)</span><span class="token punctuation">:</span>
<span class="token keyword">def</span> <span class="token function">batch_sampler</span><span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">:</span>
indices <span class="token operator">=</span> <span class="token punctuation">[</span><span class="token punctuation">(</span>i<span class="token punctuation">,</span> <span class="token builtin">len</span><span class="token punctuation">(</span>self<span class="token punctuation">.</span>tokenizer<span class="token punctuation">(</span>s<span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">)</span><span class="token punctuation">)</span><span class="token punctuation">)</span> <span class="token keyword">for</span> i<span class="token punctuation">,</span> s <span class="token keyword">in</span> <span class="token builtin">enumerate</span><span class="token punctuation">(</span>train_list<span class="token punctuation">)</span><span class="token punctuation">]</span>
random<span class="token punctuation">.</span>shuffle<span class="token punctuation">(</span>indices<span class="token punctuation">)</span>
pooled_indices <span class="token operator">=</span> <span class="token punctuation">[</span><span class="token punctuation">]</span>
<span class="token comment"># create pool of indices with similar lengths </span>
<span class="token keyword">for</span> i <span class="token keyword">in</span> <span class="token builtin">range</span><span class="token punctuation">(</span><span class="token number">0</span><span class="token punctuation">,</span> <span class="token builtin">len</span><span class="token punctuation">(</span>indices<span class="token punctuation">)</span><span class="token punctuation">,</span> batch_size <span class="token operator">*</span> <span class="token number">100</span><span class="token punctuation">)</span><span class="token punctuation">:</span>
pooled_indices<span class="token punctuation">.</span>extend<span class="token punctuation">(</span><span class="token builtin">sorted</span><span class="token punctuation">(</span>indices<span class="token punctuation">[</span>i<span class="token punctuation">:</span>i <span class="token operator">+</span> batch_size <span class="token operator">*</span> <span class="token number">100</span><span class="token punctuation">]</span><span class="token punctuation">,</span> key<span class="token operator">=</span><span class="token keyword">lambda</span> x<span class="token punctuation">:</span> x<span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">)</span><span class="token punctuation">)</span>
pooled_indices <span class="token operator">=</span> <span class="token punctuation">[</span>x<span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">]</span> <span class="token keyword">for</span> x <span class="token keyword">in</span> pooled_indices<span class="token punctuation">]</span>
<span class="token comment"># yield indices for current batch</span>
<span class="token keyword">for</span> i <span class="token keyword">in</span> <span class="token builtin">range</span><span class="token punctuation">(</span><span class="token number">0</span><span class="token punctuation">,</span> <span class="token builtin">len</span><span class="token punctuation">(</span>pooled_indices<span class="token punctuation">)</span><span class="token punctuation">,</span> batch_size<span class="token punctuation">)</span><span class="token punctuation">:</span>
<span class="token keyword">yield</span> pooled_indices<span class="token punctuation">[</span>i<span class="token punctuation">:</span>i <span class="token operator">+</span> batch_size<span class="token punctuation">]</span>
<span class="token keyword">return</span> batch_sampler<span class="token punctuation">(</span><span class="token punctuation">)</span>
<span class="token keyword">def</span> <span class="token function">collator_fn</span><span class="token punctuation">(</span>self<span class="token punctuation">)</span><span class="token punctuation">:</span>
<span class="token keyword">def</span> <span class="token function">collate_fn</span><span class="token punctuation">(</span>batch<span class="token punctuation">)</span><span class="token punctuation">:</span>
targets<span class="token punctuation">,</span> sequences <span class="token operator">=</span> <span class="token builtin">zip</span><span class="token punctuation">(</span><span class="token operator">*</span>batch<span class="token punctuation">)</span>
lengths <span class="token operator">=</span> torch<span class="token punctuation">.</span>LongTensor<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token builtin">len</span><span class="token punctuation">(</span>sequence<span class="token punctuation">)</span> <span class="token keyword">for</span> sequence <span class="token keyword">in</span> sequences<span class="token punctuation">]</span><span class="token punctuation">)</span>
targets <span class="token operator">=</span> <span class="token punctuation">[</span>torch<span class="token punctuation">.</span>cat<span class="token punctuation">(</span><span class="token punctuation">[</span>torch<span class="token punctuation">.</span>tensor<span class="token punctuation">(</span><span class="token punctuation">[</span>self<span class="token punctuation">.</span>BOS_IDX<span class="token punctuation">]</span><span class="token punctuation">)</span><span class="token punctuation">,</span> item<span class="token punctuation">,</span> torch<span class="token punctuation">.</span>tensor<span class="token punctuation">(</span><span class="token punctuation">[</span>self<span class="token punctuation">.</span>EOS_IDX<span class="token punctuation">]</span><span class="token punctuation">)</span><span class="token punctuation">]</span><span class="token punctuation">)</span> <span class="token keyword">for</span> item <span class="token keyword">in</span> targets<span class="token punctuation">]</span>
sequences <span class="token operator">=</span> <span class="token punctuation">[</span>torch<span class="token punctuation">.</span>cat<span class="token punctuation">(</span><span class="token punctuation">[</span>torch<span class="token punctuation">.</span>tensor<span class="token punctuation">(</span><span class="token punctuation">[</span>self<span class="token punctuation">.</span>BOS_IDX<span class="token punctuation">]</span><span class="token punctuation">)</span><span class="token punctuation">,</span> item<span class="token punctuation">,</span> torch<span class="token punctuation">.</span>tensor<span class="token punctuation">(</span><span class="token punctuation">[</span>self<span class="token punctuation">.</span>EOS_IDX<span class="token punctuation">]</span><span class="token punctuation">)</span><span class="token punctuation">]</span><span class="token punctuation">)</span> <span class="token keyword">for</span> item <span class="token keyword">in</span> sequences<span class="token punctuation">]</span>
<span class="token comment"># pad_idx must be defined unconditionally: it is used below regardless of self.vectors</span>
pad_idx <span class="token operator">=</span> self<span class="token punctuation">.</span>PAD_IDX
sequences <span class="token operator">=</span> torch<span class="token punctuation">.</span>nn<span class="token punctuation">.</span>utils<span class="token punctuation">.</span>rnn<span class="token punctuation">.</span>pad_sequence<span class="token punctuation">(</span>sequences<span class="token punctuation">,</span>
padding_value <span class="token operator">=</span> pad_idx<span class="token punctuation">,</span>
batch_first<span class="token operator">=</span><span class="token boolean">True</span>
<span class="token punctuation">)</span>
targets <span class="token operator">=</span> torch<span class="token punctuation">.</span>nn<span class="token punctuation">.</span>utils<span class="token punctuation">.</span>rnn<span class="token punctuation">.</span>pad_sequence<span class="token punctuation">(</span>targets<span class="token punctuation">,</span>
padding_value <span class="token operator">=</span> pad_idx<span class="token punctuation">,</span>
batch_first<span class="token operator">=</span><span class="token boolean">True</span>
<span class="token punctuation">)</span>
<span class="token keyword">return</span> targets<span class="token punctuation">,</span> sequences<span class="token punctuation">,</span> lengths
<span class="token keyword">return</span> collate_fn
</code></pre>
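<p>The collate logic above can be sketched framework-free, using plain Python lists so the resulting shapes are easy to inspect. This is a simplified illustration of the same pattern (wrap each sequence with BOS/EOS, then right-pad to the batch max length); the special-token ids here are illustrative, not the ones the vocabulary above actually assigns:</p>

```python
BOS_IDX, EOS_IDX, PAD_IDX = 2, 3, 1  # illustrative special-token ids

def collate(sequences):
    """Wrap each sequence with BOS/EOS, then right-pad to the batch max length,
    mirroring what torch.nn.utils.rnn.pad_sequence(batch_first=True) produces."""
    wrapped = [[BOS_IDX] + seq + [EOS_IDX] for seq in sequences]
    lengths = [len(s) for s in wrapped]          # true lengths before padding
    max_len = max(lengths)
    padded = [s + [PAD_IDX] * (max_len - len(s)) for s in wrapped]
    return padded, lengths

batch, lengths = collate([[7, 8, 9], [4, 5]])
# batch   -> [[2, 7, 8, 9, 3], [2, 4, 5, 3, 1]]
# lengths -> [5, 4]
```

<p>Keeping the true lengths alongside the padded batch is what lets the encoder pack the sequences and ignore the pad positions.</p>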
<h3 id="quorasq-model---training-log">QuoraSQ Model - Training Log</h3>
<pre><code> | Name | Type | Params
---------------------------------------------
0 | encoder | Encoder | 2.2 M
1 | decoder | Decoder | 4.3 M
2 | loss | CrossEntropyLoss | 0
---------------------------------------------
6.5 M Trainable params
0 Non-trainable params
6.5 M Total params
26.076 Total estimated model params size (MB)
Epoch: 0, Test PPL: 251.0346221923828, Test Loss: 5.5212016105651855
Epoch: 1, Test PPL: 214.76112365722656, Test Loss: 5.36377477645874
Epoch: 2, Test PPL: 187.0980987548828, Test Loss: 5.2250542640686035
Epoch: 3, Test PPL: 152.59373474121094, Test Loss: 5.019686222076416
Epoch: 4, Test PPL: 133.257568359375, Test Loss: 4.8832526206970215
</code></pre>
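<p>The two columns in the log are directly related: perplexity is the exponential of the mean cross-entropy loss. A quick check (the small gap to the logged 251.03 is presumably batch-averaging order, i.e. averaging per-batch perplexities rather than exponentiating the mean loss):</p>

```python
import math

def perplexity(ce_loss: float) -> float:
    """Perplexity = exp(mean cross-entropy loss)."""
    return math.exp(ce_loss)

print(perplexity(5.5212016105651855))  # close to the logged epoch-0 Test PPL
```
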
<h2 id="takeaways">Takeaways</h2>
<ul>
<li>There is no code-directory structure. I had planned to add one, but I was too lazy to follow through. The idea is to have a Python package that stores all the models and datasets, so you can simply import and train them. I&#39;ve done this before for vision-based models/datasets, so it should be straightforward; it&#39;s just time-consuming and I feel lazy doing the same thing over <code>≧ ﹏ ≦</code></li>
<li>I need to fix the <code>batch_sampler</code> so that it batches by similar length, just like torchtext&#39;s <code>BucketIterator</code></li>
</ul>
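<p>The <code>BucketIterator</code>-style sampler mentioned above can be sketched as follows: sort example indices by length, slice into batches, then shuffle the batch order so similar-length examples share a batch while epoch order still varies. This is a minimal illustration (<code>bucket_batches</code> is a hypothetical helper, not part of torch or the repo); a real version would subclass <code>torch.utils.data.Sampler</code> and be passed to the <code>DataLoader</code> as <code>batch_sampler</code>:</p>

```python
import random

def bucket_batches(lengths, batch_size, shuffle=True, seed=0):
    """Group example indices into batches of similar length.
    lengths[i] is the token length of example i."""
    # Sort indices by example length so each slice holds similar lengths,
    # which minimizes padding inside a batch.
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    batches = [order[i:i + batch_size] for i in range(0, len(order), batch_size)]
    if shuffle:
        # Shuffle batch order (not batch contents) so epochs differ.
        random.Random(seed).shuffle(batches)
    return batches

print(bucket_batches([5, 2, 9, 3, 7, 1], batch_size=2, shuffle=False))
# -> [[5, 1], [3, 0], [4, 2]]
```
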
<hr>
<center>
<iframe src="https://giphy.com/embed/bOwOAey4MDO3ivBkgK" width="480" height="480" class="giphy-embed" allowfullscreen=""></iframe>
</center>
</div>
</body>
</html>