-
Notifications
You must be signed in to change notification settings - Fork 141
/
utils.lua
65 lines (57 loc) · 1.6 KB
/
utils.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
--- Load sample descriptors from an index file.
-- Every line of `file_name` is appended to `config.data_root` to form a
-- sample stem. The descriptor's file paths are derived from that stem by
-- fixed suffixes.
-- @tparam string file_name path to the index file, one entry per line
-- @tparam table config must provide `data_root`, a string prefix that is
--   concatenated directly (no separator is inserted) with each line
-- @treturn table array of descriptors with fields `input_file`,
--   `input_valid`, `boundary` and `name`
-- @raise error if `file_name` cannot be opened
function loadData(file_name, config)
  print(string.format("Loading data from %s", file_name))
  local file, err = io.open(file_name, "r")
  if not file then
    -- Fail with the OS reason instead of "attempt to index a nil value".
    error("loadData: " .. tostring(err), 2)
  end
  local output = {}
  -- file:lines() (method form) does NOT close the handle at EOF,
  -- so it is closed explicitly after the loop.
  for line in file:lines() do
    -- define data here, will be used in BatchIterator
    local stem = config.data_root .. line
    output[#output + 1] = {
      input_file = stem .. '_mlt.png',
      input_valid = stem .. '_valid.png',
      boundary = stem .. '_3d_boundary.png',
      name = stem,
    }
  end
  file:close()
  print(string.format("%d data loaded.", #output))
  return output
end
--- Round a number to a given count of decimal places.
-- The value is scaled, shifted by one half and floored, so exact halves
-- round toward positive infinity (e.g. -2.5 -> -2). Results are subject to
-- ordinary floating-point representation.
-- @tparam number num value to round
-- @tparam[opt=0] number idp number of decimal places
-- @treturn number the rounded value
function round(num, idp)
  local scale = 10 ^ (idp or 0)
  local shifted = num * scale + 0.5
  return math.floor(shifted) / scale
end
--- Split a string on a Lua pattern delimiter.
-- `pat` is a Lua pattern, not a plain string, so magic characters must be
-- escaped by the caller. An empty leading field (delimiter at position 1)
-- is dropped, and so is an empty trailing field. Empty fields between two
-- delimiters are kept.
-- @tparam string str input string
-- @tparam string pat delimiter pattern; must not match the empty string
-- @treturn table array of the pieces
-- @raise error if `pat` produces an empty match (which would never advance)
function split(str, pat)
  local t = {}
  local fpat = "(.-)" .. pat
  local last_end = 1
  local s, e, cap = str:find(fpat, 1)
  while s do
    -- An empty whole match (e < s) makes last_end == s, so the next find
    -- returns the very same match: guard against the infinite loop.
    if e < s then
      error("split: delimiter pattern '" .. pat .. "' matched the empty string", 2)
    end
    if s ~= 1 or cap ~= "" then
      t[#t + 1] = cap
    end
    last_end = e + 1
    s, e, cap = str:find(fpat, last_end)
  end
  if last_end <= #str then
    t[#t + 1] = str:sub(last_end)
  end
  return t
end
--- Check whether a path can be opened for reading.
-- The probe handle is closed immediately.
-- @tparam string name path to test
-- @treturn boolean true if `io.open(name, "r")` succeeded, false otherwise
function file_exists(name)
  local handle = io.open(name, "r")
  if handle == nil then
    return false
  end
  handle:close()
  return true
end
--- Convert Lua truthiness to a number.
-- @param var any value
-- @treturn number 1 if `var` is truthy, 0 if it is `nil` or `false`
function bool2num(var)
  if var then
    return 1
  end
  return 0
end
--- Read newline-separated values from a text file into a 1-D torch.Tensor.
-- @tparam string file_name path to the weight file
-- @treturn torch.Tensor tensor of length equal to the number of split pieces;
--   element i is assigned the i-th piece
-- @raise error if `file_name` cannot be opened
function read_weight(file_name)
  print(string.format("Loading weight from %s", file_name))
  local file, err = io.open(file_name, "r")
  if not file then
    error("read_weight: " .. tostring(err), 2)
  end
  local content = file:read("*all")
  file:close()
  -- NOTE(review): string.split is not part of stock Lua; presumably it is
  -- provided by the surrounding (Torch) environment. TODO confirm how it
  -- treats a trailing newline.
  local pieces = string.split(content, '\n')
  -- `#` replaces table.getn (removed in Lua 5.2+); identical on a sequence.
  local count = #pieces
  local w = torch.Tensor(count)
  for i = 1, count do
    -- NOTE(review): assigns the string piece as-is, as before; relies on
    -- the tensor's numeric coercion of numeric strings.
    w[i] = pieces[i]
  end
  return w
end