forked from OpenNMT/OpenNMT
-
Notifications
You must be signed in to change notification settings - Fork 2
/
preprocess.lua
124 lines (100 loc) · 3.63 KB
/
preprocess.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
require('onmt.init')
-- Command-line parser for this script.
local cmd = onmt.utils.ExtendedCmdLine.new('preprocess.lua')

-- The data type is looked up in the raw arguments ahead of the full parse,
-- because the option declarations below take it as a parameter.
-- Falls back to 'bitext' when the lookup yields nil or false.
local dataType = cmd.getArgument(arg, '-data_type') or 'bitext'

-- Dependency check attached to `-data_type`: passes unless the type is
-- 'feattext' while `opt.idx_files` is nil or false.
local function checkDataType(opt)
  return opt.data_type ~= 'feattext' or opt.idx_files
end

-- Dependency check attached to `-save_data`: passes on a dry run or when a
-- non-empty output path was given. The message is always returned as the
-- second value (presumably only reported when the check fails - confirm
-- against ExtendedCmdLine).
local function checkSaveData(opt)
  return opt.dry_run or opt.save_data ~= '', "option `-save_data` is required"
end

-- Options declaration.
-- NOTE: continuation lines of the long-bracket help strings are kept at
-- column 0 on purpose; indenting them would change the string contents.
local preprocessOptions = {
  {
    '-data_type', 'bitext',
    [[Type of data to preprocess. Use 'monotext' for monolingual data.
This option impacts all options choices.]],
    {
      enum = {'bitext', 'monotext', 'feattext'},
      depends = checkDataType
    }
  },
  {
    '-dry_run', false,
    [[If set, this will only prepare the preprocessor. Useful when using file sampling to
test distribution rules.]]
  },
  {
    '-save_data', '',
    [[Output file for the prepared data.]],
    {
      depends = checkSaveData
    }
  }
}

cmd:setCmdLineOptions(preprocessOptions, 'Preprocess')

-- Options declared by the preprocessor for the selected data type.
onmt.data.Preprocessor.declareOpts(cmd, dataType)
-- Insert options on the fly, depending on whether a hook is selected.
onmt.utils.HookManager.updateOpt(arg, cmd)
-- Expand options depending on source or target (tokenization, preprocessing).
onmt.data.Preprocessor.expandOpts(cmd, dataType)
onmt.utils.HookManager.declareOpts(cmd)
onmt.utils.Logger.declareOpts(cmd)

cmd:setCmdLineOptions({
  {
    '-seed', 3425,
    [[Random seed.]],
    {
      valid = onmt.utils.ExtendedCmdLine.isUInt()
    }
  }
}, 'Other')

-- Parsed options, shared with `main` below.
local opt = cmd:parse(arg)
-- True when a string-valued option was left empty.
local function isBlank(value)
  return value:len() == 0
end

-- Saves the word and feature vocabularies held in `dicts`, using
-- `opt.save_data` as the file prefix. Each vocabulary is written only when
-- its matching option (`-vocab`, `-src_vocab`, `-tgt_vocab`,
-- `-features_vocabs_prefix`) is an empty string.
local function saveVocabularies(dicts)
  local Vocabulary = onmt.data.Vocabulary
  local prefix = opt.save_data

  if dataType == 'monotext' then
    if isBlank(opt.vocab) then
      Vocabulary.save('source', dicts.src.words, prefix .. '.dict')
    end
    if isBlank(opt.features_vocabs_prefix) then
      Vocabulary.saveFeatures('source', dicts.src.features, prefix)
    end
    return
  end

  if dataType == 'feattext' then
    if isBlank(opt.tgt_vocab) then
      Vocabulary.save('target', dicts.tgt.words, prefix .. '.tgt.dict')
    end
    if isBlank(opt.features_vocabs_prefix) then
      Vocabulary.saveFeatures('target', dicts.tgt.features, prefix)
    end
    return
  end

  -- Any other data type (the default is 'bitext'): both sides are saved.
  if isBlank(opt.src_vocab) then
    Vocabulary.save('source', dicts.src.words, prefix .. '.src.dict')
  end
  if isBlank(opt.tgt_vocab) then
    Vocabulary.save('target', dicts.tgt.words, prefix .. '.tgt.dict')
  end
  if isBlank(opt.features_vocabs_prefix) then
    Vocabulary.saveFeatures('source', dicts.src.features, prefix .. '.source')
    Vocabulary.saveFeatures('target', dicts.tgt.features, prefix .. '.target')
  end
end

-- Entry point: seeds torch's random generator, installs the global logger and
-- hook manager, builds the preprocessor and, unless `-dry_run` is set,
-- prepares the vocabularies and the training/validation data and saves them
-- to `<save_data>-train.t7`. The logger is shut down in both cases.
local function main()
  torch.manualSeed(opt.seed)

  _G.logger = onmt.utils.Logger.new(opt.log_file, opt.disable_logs, opt.log_level)
  _G.hookManager = onmt.utils.HookManager.new(opt)

  local preprocessor = onmt.data.Preprocessor.new(opt, dataType)

  if not opt.dry_run then
    local dataset = { dataType = dataType }
    -- Keep the processing options in the structure for further traceability.
    dataset.opt = opt

    _G.logger:info('Preparing vocabulary...')
    dataset.dicts = preprocessor:getVocabulary()

    _G.logger:info('Preparing training data...')
    dataset.train = preprocessor:makeData('train', dataset.dicts)
    _G.logger:info('')

    _G.logger:info('Preparing validation data...')
    dataset.valid = preprocessor:makeData('valid', dataset.dicts)
    _G.logger:info('')

    saveVocabularies(dataset.dicts)

    local outputPath = opt.save_data .. '-train.t7'
    _G.logger:info('Saving data to \'' .. outputPath .. '\'...')
    torch.save(outputPath, dataset, 'binary', false)
  end

  _G.logger:shutDown()
end

main()