"""
This module contains the skipgram/word2vec Dataset and Dataloader.
"""
from typing import Tuple, Any, List
import numpy as np
from datasets import load_dataset # hugging_faces
from torch.utils.data import DataLoader
from nlpmodels.utils.elt.dataset import AbstractNLPDataset
from nlpmodels.utils.tokenizer import tokenize_corpus_basic
from nlpmodels.utils.vocabulary import NLPVocabulary
class SkipGramDataset(AbstractNLPDataset):
"""
SkipGramDataset class for transforming and storing dataset for use in Skip-gram model.
"""
def __init__(self, data: List):
"""
Args:
data (List): List of (input_word_token, context_word_token) tuples.
"""
self._data = data
def __len__(self) -> int:
"""
Returns:
size of dataset.
"""
return len(self._data)
def __getitem__(self, idx: int) -> Tuple[int, int]:
"""
Args:
idx (int): index of dataset slice to grab.
Returns:
Tuple of tensors (target, text) for that index.
"""
input_word_token, context_word_token = self._data[idx]
return input_word_token, context_word_token

    @staticmethod
    def get_skipgram_context(sentence_tokens: List, context_size: int,
                             word_probas: np.ndarray, train=True) -> List:
        """
        Static method to take a list of token indices and convert it into sub-sampled (input, context) pairs.
        Note that sub-sampling only happens on the training dataset (see Mikolov et al. for details).
        Args:
            sentence_tokens (list): list of token indices used to derive (input, context) pairs.
            context_size (int): the window around each input word used to derive context pairings.
            word_probas (np.ndarray): array of per-word discard probabilities found in the corpus.
            train (bool): a "train" flag to indicate we want to sub-sample the training set.
        Returns:
            list of (input token, context token) pairs to be used in the negative sampling loss problem.
        """
        data_partitions = []
        sentence_len = len(sentence_tokens)
        # Calculate prob(w) per section 2.3 of the paper for frequent word sub-sampling:
        # frequently occurring words are randomly dropped -> higher frequency, higher chance to discard.
        # Take each input_word and collect its context to the left and right.
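        # Per Mikolov et al., each word w_i is discarded with probability
        # P(w_i) = 1 - sqrt(t / f(w_i)), where f(w_i) is the word's corpus frequency and t is a
        # chosen threshold. (Assumption: this is what dictionary.get_word_discard_probas() returns,
        # indexed by token id.)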
        for input_idx, input_word in enumerate(sentence_tokens):
            # +1 so the right edge of the window is included (range() excludes its upper bound).
            for context_idx in range(max(input_idx - context_size, 0),
                                     min(input_idx + context_size + 1, sentence_len)):
                if context_idx != input_idx:
                    if train:
                        # Sub-sampling for training: the discard probabilities are indexed
                        # by token id, not by position in the sentence.
                        if np.random.rand() < word_probas[sentence_tokens[context_idx]] or \
                                np.random.rand() < word_probas[input_word]:
                            continue
                    data_partitions.append((input_word, sentence_tokens[context_idx]))
return data_partitions
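
    # Illustrative example (not from the original source): with sentence_tokens = [4, 7, 2, 9],
    # context_size = 1, and train=False, get_skipgram_context() yields the pairs
    # (4, 7), (7, 4), (7, 2), (2, 7), (2, 9), (9, 2).
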
@classmethod
def get_target_context_data(cls, train_text: List, dictionary: NLPVocabulary, context_size: int,
train: bool) -> List:
"""
        Class method to take a list of tokenized sentences and convert them into sub-sampled (input, context) pairs.
Note that sub-sampling only happens on the training dataset (see Mikolov et al. for details).
Args:
train_text (list): list of tokenized data to be used to derive (input,context) pairs.
dictionary (NLPVocabulary): a dictionary built off of the training data to map tokens <-> idxs.
context_size (int): the window around each input word to derive context pairings.
train (bool): a "train" flag to indicate we want to sub-sample the training set.
Returns:
list of (input_idx, context_idx) pairs to be used for negative sampling loss problem.
"""
train_data = []
word_probas = dictionary.get_word_discard_probas()
for tokens in train_text:
tokens = [dictionary.lookup_token(x) for x in tokens]
train_data.extend(cls.get_skipgram_context(tokens, context_size, word_probas, train))
        return train_data

    @classmethod
def get_training_dataloader(cls, *args: Any) -> Tuple[DataLoader, NLPVocabulary]:
"""
Class method to take transformed dataset and package in torch.Dataloader
so that random batches could be used in training.
Args:
context_size (int): size of the window to derive context words
thresh (float): a hyper-parameter to be used in frequent word sub-sampling
Returns:
(torch.Dataloader, NLPVocabulary) tuple to be used downstream in training.
"""
context, thresh, batch_size = args
train_data, vocab = cls.get_training_data(context, thresh)
train_loader = DataLoader(train_data,
batch_size=batch_size,
shuffle=True,
num_workers=2,
pin_memory=True)
        return train_loader, vocab

    @classmethod
def get_training_data(cls, *args: Any) -> Tuple[AbstractNLPDataset, NLPVocabulary]:
"""
Class method to generate the training dataset (derived from hugging faces "ag_news").
This method grabs the raw text, tokenizes and cleans up the data, generates a dictionary,
and generates a sub-sampled (input,context) pair for training.
Returns:
(NLPDataset,NLPVocabulary) tuple to be used downstream in training.
"""
context_size, thresh = args
# Using the Ag News data via Hugging Faces
train_text = load_dataset("ag_news")['train']['text']
train_text = tokenize_corpus_basic(train_text)
dictionary = NLPVocabulary.build_vocabulary(train_text)
# for sub-sampling
dictionary.set_proba_thresh(thresh)
train_data = cls.get_target_context_data(train_text, dictionary, context_size, train=True)
return cls(train_data), dictionary
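

# A minimal usage sketch (not part of the original module), assuming the nlpmodels package and the
# Hugging Face "ag_news" dataset are available; the hyper-parameter values below are illustrative only.
if __name__ == "__main__":
    CONTEXT_SIZE = 5         # window of +/- 5 words around each input word (assumed value)
    SUBSAMPLE_THRESH = 1e-3  # frequent-word sub-sampling threshold from Mikolov et al. (assumed value)
    BATCH_SIZE = 256         # batch size for the DataLoader (assumed value)

    train_loader, vocab = SkipGramDataset.get_training_dataloader(CONTEXT_SIZE, SUBSAMPLE_THRESH, BATCH_SIZE)
    # Each batch is a pair of tensors: input word ids and their context word ids.
    input_batch, context_batch = next(iter(train_loader))
    print(input_batch.shape, context_batch.shape)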