# data.py
import os
import unicodedata

import torch
import torch.utils.data as Data
import numpy as np

from utils import device

# Special tokens used by the vocabulary.
SOS = '<s>'
EOS = '</s>'
UNK = '<unk>'
NUM = '<num>'

class Vocabulary(object):
    def __init__(self, vocfile, use_num=True):
        self.use_num = use_num
        self.word2idx = {}
        self.idx2word = {}
        self.word2idx[SOS] = 0
        self.idx2word[0] = SOS
        self.word2idx[EOS] = 1
        self.idx2word[1] = EOS
        with open(vocfile, 'r') as f:
            words = f.read().strip().split('\n')
        for word in words:
            if self.use_num and self.is_number(word):
                word = NUM
            if word not in self.word2idx and word != UNK:
                idx = len(self.word2idx)
                self.word2idx[word] = idx
                self.idx2word[idx] = word
        # Append <unk> as the last entry; compute its index before inserting
        # so word2idx and idx2word stay consistent.
        unk_idx = len(self.word2idx)
        self.word2idx[UNK] = unk_idx
        self.idx2word[unk_idx] = UNK
        self.vocsize = len(self.word2idx)

    def word2id(self, word):
        if self.use_num and self.is_number(word):
            word = NUM
        if word in self.word2idx:
            return self.word2idx[word]
        else:
            return self.word2idx[UNK]

    def id2word(self, idx):
        if idx in self.idx2word:
            return self.idx2word[idx]
        else:
            return UNK

    def is_number(self, word):
        # Strip common separators and suffixes so that dates, times and
        # ordinals are also recognised as numbers.
        word = word.replace(',', '')   # 10,000 -> 10000
        word = word.replace(':', '')   # 5:30 -> 530
        word = word.replace('-', '')   # 17-08 -> 1708
        word = word.replace('/', '')   # 17/08/1992 -> 17081992
        word = word.replace('th', '')  # 20th -> 20
        word = word.replace('rd', '')  # 93rd -> 93
        word = word.replace('nd', '')  # 22nd -> 22
        word = word.replace('m', '')   # 20m -> 20
        word = word.replace('s', '')   # 20s -> 20
        try:
            float(word)
            return True
        except ValueError:
            pass
        try:
            unicodedata.numeric(word)
            return True
        except (TypeError, ValueError):
            pass
        return False

    def __len__(self):
        return self.vocsize
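
# A minimal usage sketch (the path below matches the __main__ example at the
# bottom of this file): the vocabulary file has one word per line; numbers
# collapse to <num> when use_num is set, and out-of-vocabulary words fall
# back to <unk>.
#
#   voc = Vocabulary('data/ami/voc.txt')
#   voc.word2id('<s>')                 # 0
#   voc.word2id('10,000')              # index of <num>, accepted by is_number()
#   voc.id2word(voc.word2id('hello'))  # 'hello' if in the vocabulary, else '<unk>'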

# Dataset that loads a whitespace-tokenized text file as lists of word ids.
class TextDataset(Data.Dataset):
    def __init__(self, txtfile, voc):
        self.words, self.ids = self.tokenize(txtfile, voc)
        self.nline = len(self.ids)
        self.n_sents = len(self.ids)
        self.n_words = sum(len(ids) for ids in self.ids)
        self.n_unks = len([i for ids in self.ids for i in ids if i == voc.word2id(UNK)])

    def tokenize(self, txtfile, voc):
        assert os.path.exists(txtfile)
        with open(txtfile, 'r') as f:
            lines = f.readlines()
        words, ids = [], []
        unk_id = voc.word2id(UNK)
        for line in lines:
            tokens = line.strip().split()
            if len(tokens) == 0:
                continue
            words.append([SOS])
            ids.append([voc.word2id(SOS)])
            for token in tokens:
                token_id = voc.word2id(token)
                # Keep the original token if it is in the vocabulary,
                # otherwise record it as <unk>.
                if token_id != unk_id:
                    words[-1].append(token)
                else:
                    words[-1].append(UNK)
                ids[-1].append(token_id)
            words[-1].append(EOS)
            ids[-1].append(voc.word2id(EOS))
        return words, ids

    def __len__(self):
        return self.n_sents

    def __repr__(self):
        return '#Sents=%d, #Words=%d, #UNKs=%d' % (self.n_sents, self.n_words, self.n_unks)

    def __getitem__(self, index):
        return self.ids[index]
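
# Note: __getitem__ returns a plain, variable-length list of ids rather than a
# tensor; padding, sorting and batching are handled by collate_fn below.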

class Corpus(object):
    def __init__(self, data_dir, train_batch_size, valid_batch_size, test_batch_size):
        self.voc = Vocabulary(os.path.join(data_dir, 'voc.txt'))
        self.train_data = TextDataset(os.path.join(data_dir, 'train.txt'), self.voc)
        self.valid_data = TextDataset(os.path.join(data_dir, 'valid.txt'), self.voc)
        self.test_data = TextDataset(os.path.join(data_dir, 'test.txt'), self.voc)
        self.train_loader = Data.DataLoader(self.train_data, batch_size=train_batch_size,
                                            shuffle=True, num_workers=0, collate_fn=collate_fn, drop_last=False)
        self.valid_loader = Data.DataLoader(self.valid_data, batch_size=valid_batch_size,
                                            shuffle=False, num_workers=0, collate_fn=collate_fn, drop_last=False)
        self.test_loader = Data.DataLoader(self.test_data, batch_size=test_batch_size,
                                           shuffle=False, num_workers=0, collate_fn=collate_fn, drop_last=False)

    def __repr__(self):
        return 'Train: %s\n' % (self.train_data) + \
               'Valid: %s\n' % (self.valid_data) + \
               'Test: %s\n' % (self.test_data)
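
# Expected layout of data_dir (file names are the ones joined above):
#   data_dir/
#     voc.txt     one vocabulary word per line
#     train.txt   one sentence per line, whitespace-tokenized
#     valid.txt
#     test.txt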

def collate_fn(batch):
    # Pad each sentence with zeros up to the longest sentence in the batch.
    sent_lens = torch.LongTensor(list(map(len, batch)))
    max_len = sent_lens.max()
    batchsize = len(batch)
    sent_batch = sent_lens.new_zeros(batchsize, max_len)
    for idx, (sent, sent_len) in enumerate(zip(batch, sent_lens)):
        sent_batch[idx, :sent_len] = torch.LongTensor(sent)
    # Sort sentences by length, longest first.
    sent_lens, perm_idx = sent_lens.sort(0, descending=True)
    sent_batch = sent_batch[perm_idx]
    # Transpose to (max_len, batch) and shift by one token for language modelling:
    # inputs are tokens 0..max_len-2, targets are tokens 1..max_len-1.
    sent_batch = sent_batch.t().contiguous()
    inputs = sent_batch[0:max_len - 1]
    targets = sent_batch[1:max_len]
    sent_lens.sub_(1)
    return inputs.to(device), targets.to(device), sent_lens.to(device)
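
# Worked example (hypothetical id lists) of what collate_fn returns:
#   batch = [[0, 5, 7, 1], [0, 9, 1]]              # two sentences with <s>/</s> ids
#   padded + sorted:  [[0, 5, 7, 1], [0, 9, 1, 0]]
#   transposed to (max_len, batch) = (4, 2)
#   inputs  -> shape (3, 2): [[0, 0], [5, 9], [7, 1]]
#   targets -> shape (3, 2): [[5, 9], [7, 1], [1, 0]]
#   sent_lens -> [3, 2]      (original lengths minus one)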

if __name__ == '__main__':
    corpus = Corpus('data/ami', 8, 16, 1)
    print(len(corpus.voc))
    print(corpus.voc.id2word(len(corpus.voc)), corpus.voc.word2id('<unk>'),
          corpus.voc.word2id('<s>'), corpus.voc.word2id('</s>'))
    for i, (inputs, targets, sent_lens) in enumerate(corpus.train_loader):
        print(i, inputs, targets, sent_lens)
        break
    for i, (inputs, targets, sent_lens) in enumerate(corpus.valid_loader):
        print(i, inputs.size(), targets.size(), sent_lens)
        break
    for i, (inputs, targets, sent_lens) in enumerate(corpus.test_loader):
        print(i, inputs.size(), targets.size(), sent_lens)
        break