-
Notifications
You must be signed in to change notification settings - Fork 29
/
Copy pathprocess_cornell_data.py
94 lines (75 loc) · 3.16 KB
/
process_cornell_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import os
import sys
"""
Processing of movie dialogue dataset from Cornell
http://www.cs.cornell.edu/~cristian/Cornell_Movie-Dialogs_Corpus.html
"""
class CornellData:
    """Turn the Cornell Movie-Dialogs corpus into query/answer text files.

    The constructor reads ``movie_lines.txt`` from
    ``<dirName>/cornell movie-dialogs corpus/`` into :attr:`lines`.
    :meth:`writeToFile` then writes normalized query/answer pairs into train
    and dev files whose paths are relative to ``dirName``.
    """

    # Character rewrites applied by spruceUpLine. Each mapping emits only
    # spaces and/or the character itself. That makes a single translate() pass
    # equivalent to applying the replacements one after another.
    _SPRUCE_TABLE = str.maketrans({
        "'": " ' ",
        ".": " . ",
        "!": " !",
        "?": " ?",
        '"': "",
        ",": "",
        "-": " ",
    })

    def __init__(self, dirName):
        """
        Args:
            dirName (string): directory where to load the corpus. The corpus
                is read from its 'cornell movie-dialogs corpus' subdirectory,
                and the output file paths below are relative to it.
        """
        self.dirName = dirName
        # Output files, relative to dirName.
        self.queryTrainFile = 'gen_data/chitchat.train.query'
        self.answerTrainFile = 'gen_data/chitchat.train.answer'
        self.queryDevFile = 'gen_data/chitchat.dev.query'
        self.answerDevFile = 'gen_data/chitchat.dev.answer'
        # Column order of the ' +++$+++ '-separated records in movie_lines.txt.
        self.MOVIE_LINES_FIELDS = ["lineID", "characterID",
                                   "movieID", "character", "text"]
        self.lines = self.loadLines(os.path.join(
            self.dirName, 'cornell movie-dialogs corpus', "movie_lines.txt"))

    # TODO: Cleaner program (merge copy-paste) !!
    def loadLines(self, fileName):
        """Parse a movie_lines.txt file into a dict keyed by line ID.

        Args:
            fileName (str): file to load (decoded as ISO-8859-1)
        Return:
            dict<dict<str>>: for each line ID, a dict mapping every name in
            MOVIE_LINES_FIELDS to its raw string value, in file order. The
            'text' value keeps any trailing newline.
        Raises:
            ValueError: if a line has fewer fields than MOVIE_LINES_FIELDS.
        """
        fields = self.MOVIE_LINES_FIELDS
        lines = {}
        with open(fileName, 'r', encoding='iso-8859-1') as f:
            for lineNo, line in enumerate(f, 1):
                # Limiting the split means a separator appearing inside the
                # utterance stays in 'text' instead of truncating it.
                values = line.split(" +++$+++ ", len(fields) - 1)
                if len(values) < len(fields):
                    raise ValueError("%s:%d: expected %d fields, got %d"
                                     % (fileName, lineNo, len(fields), len(values)))
                lineObj = dict(zip(fields, values))
                lines[lineObj['lineID']] = lineObj
        return lines

    def spruceUpLine(self, line):
        """Normalize one utterance into lowercase, space-separated tokens.

        Apostrophes and periods become standalone tokens. '!' and '?' are
        split from the preceding word. Double quotes and commas are dropped,
        and hyphens become spaces. Finally, all whitespace (including a
        trailing newline) collapses to single spaces with none at either end.
        """
        return ' '.join(line.translate(self._SPRUCE_TABLE).lower().split())

    def writeToFile(self):
        """Write normalized query/answer pairs to the train and dev files.

        Lines are taken two at a time in file order. For each index
        i = 0, 2, 4, ..., the line at i + 1 is the query and the line at i is
        the answer. Indices divisible by 1000 go to the dev files and all
        others go to the train files. A trailing unpaired line is dropped.
        Missing output directories are created. Files are written as UTF-8.

        NOTE(review): pairing is purely positional. Nothing here checks that
        the two lines belong to the same conversation or even the same movie.
        The (i + 1 -> i) direction presumably relies on movie_lines.txt listing
        each movie's lines in descending line-ID order -- confirm.
        """
        paths = [os.path.join(self.dirName, name) for name in (
            self.queryTrainFile, self.answerTrainFile,
            self.queryDevFile, self.answerDevFile)]
        for path in paths:
            parent = os.path.dirname(path)
            # gen_data/ does not necessarily exist on a fresh checkout.
            if parent:
                os.makedirs(parent, exist_ok=True)

        keys = list(self.lines)
        # Explicit UTF-8: the corpus is decoded as ISO-8859-1 and can contain
        # characters that the platform's default encoding cannot represent.
        with open(paths[0], 'w', encoding='utf-8') as queryTrain, \
                open(paths[1], 'w', encoding='utf-8') as answerTrain, \
                open(paths[2], 'w', encoding='utf-8') as queryDev, \
                open(paths[3], 'w', encoding='utf-8') as answerDev:
            for i in range(0, len(keys) - 1, 2):
                query = self.spruceUpLine(self.lines[keys[i + 1]]['text'])
                answer = self.spruceUpLine(self.lines[keys[i]]['text'])
                if i % 1000 == 0:
                    queryDev.write(query + '\n')
                    answerDev.write(answer + '\n')
                else:
                    queryTrain.write(query + '\n')
                    answerTrain.write(answer + '\n')
def main():
    """Generate the chitchat train/dev files from the corpus under the CWD."""
    CornellData(os.getcwd()).writeToFile()


if __name__ == '__main__':
    main()