-
Notifications
You must be signed in to change notification settings - Fork 6
/
eval_model.py
316 lines (269 loc) · 11.3 KB
/
eval_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
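'''A script for evaluating a patch-prediction model (MPT transformer or UNet).

For each stored validation path it predicts a patch map, plans with OMPL
RRT* or Informed RRT* guided by that patch map, and pickles the planning
times, success flags and planner vertex counts.
'''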
import torch
import torch.nn as nn
import torch.nn.functional as F
import skimage.io
import numpy as np
import pickle
from os import path as osp
import argparse
import json
import time
try:
from ompl import base as ob
from ompl import geometric as og
from ompl import util as ou
except ImportError:
raise ImportError("Container does not have OMPL installed")
from transformer import Models as tfModel
from unet import Models as unetModel
from utils import geom2pix, ValidityChecker
from dataLoader import get_encoder_input
# Distance covered by one map pixel; get_path uses it to size the planning bounds.
res = 0.05


def pix2geom(pos, res=0.05, length=24):
    """
    Convert pixel co-ordinates to geometric positions.

    The vertical axis is flipped, so pixel row 0 maps to ``length``.

    :param pos: Indexable (x, y) pixel co-ordinates; only the first two
        entries are read.
    :param res: The distance represented by each pixel.
    :param length: The length of the map in meters.
    :returns (float, float): The associated Euclidean co-ordinates.
    """
    col, row = pos[0], pos[1]
    x = col * res
    y = length - row * res
    return (x, y)
# Side length, in pixels, of the square window marked around each predicted anchor.
receptive_field = 32


def getHashTable(mapSize):
    '''
    Return the hashTable for the given map.

    NOTE: This hashtable only works for the patch_embedding network defined in
    the transformers/Models.py file.
    :param mapSize: The (height, width) of the map in pixels.
    :returns list: the hashTable converting a 1D token index to a 2D (x, y)
        image position; tokens are enumerated row by row.
    '''
    height, width = mapSize
    # Closed-form token-grid size along each axis for that network.
    numRows = int(((height - 8) // 4 - 1 + 6) // 5)
    numCols = int(((width - 8) // 4 - 1 + 6) // 5)
    hashTable = []
    for row in range(numRows):
        for col in range(numCols):
            # Token centres are 20 pixels apart, starting 4 pixels in.
            hashTable.append((20 * col + 4, 20 * row + 4))
    return hashTable
def getPathLengthObjective(cost, si):
    '''
    Build a path-length objective whose cost threshold enables early termination.

    :param cost: The cost of the original RRT* path, used as the threshold.
    :param si: An object of class ob.SpaceInformation.
    :returns : An object of class ob.PathLengthOptimizationObjective
    '''
    threshold = ob.Cost(cost)
    objective = ob.PathLengthOptimizationObjective(si)
    objective.setCostThreshold(threshold)
    return objective
def get_path(start, goal, input_map, patch_map, plannerType, cost, exp=False):
    '''
    Plan a path between start and goal with OMPL, guided by patch_map.

    :param start: (x, y) start position in geometric co-ordinates; the state
        space spans [0, width*res] x [0, height*res].
    :param goal: (x, y) goal position in geometric co-ordinates.
    :param input_map: 2D map array; its shape sets the planning bounds and it
        is passed to ValidityChecker.
    :param patch_map: Mask passed to ValidityChecker together with input_map
        (presumably restricts valid states to the predicted patches - see
        utils.ValidityChecker).
    :param plannerType: The planner type to use: 'rrtstar' or
        'informedrrtstar'.
    :param cost: Path-length cost threshold handed to the optimization
        objective for early termination.
    :param exp: If True, plan for 1 s with patch_map; if no exact solution was
        found, swap in a checker built without patch_map and plan for a
        further 89 s.
    :returns (list, float, int, bool): The solution path as a list of [x, y]
        states (just [start, goal] when no exact solution was found), the
        wall-clock planning time in seconds, the number of vertices in the
        planner graph, and whether an exact solution was found.
    :raises TypeError: If plannerType is not recognised.
    '''
    mapSize = input_map.shape
    # Planning parameters: a 2D box covering the whole map.
    space = ob.RealVectorStateSpace(2)
    bounds = ob.RealVectorBounds(2)
    bounds.setLow(0.0)
    bounds.setHigh(0, mapSize[1] * res)  # Set width bounds (x)
    bounds.setHigh(1, mapSize[0] * res)  # Set height bounds (y)
    space.setBounds(bounds)
    si = ob.SpaceInformation(space)
    # Keep named references to objects handed to OMPL; its Python bindings
    # may not take ownership of them.
    ValidityCheckerObj = ValidityChecker(si, input_map, patch_map)
    si.setStateValidityChecker(ValidityCheckerObj)

    StartState = ob.State(space)
    StartState[0] = start[0]
    StartState[1] = start[1]
    GoalState = ob.State(space)
    GoalState[0] = goal[0]
    GoalState[1] = goal[1]

    # Define planning problem (goal threshold 0.1) and its objective.
    pdef = ob.ProblemDefinition(si)
    pdef.setStartAndGoalStates(StartState, GoalState, 0.1)
    obj = getPathLengthObjective(cost, si)
    pdef.setOptimizationObjective(obj)

    if plannerType == 'rrtstar':
        planner = og.RRTstar(si)
    elif plannerType == 'informedrrtstar':
        planner = og.InformedRRTstar(si)
    else:
        raise TypeError(f"Planner Type {plannerType} not found")
    # Set the problem instance the planner has to solve
    planner.setProblemDefinition(pdef)
    planner.setup()

    # Attempt to solve the planning problem within a 90 s total budget.
    startTime = time.perf_counter()
    if exp:
        planner.solve(1.0)
        if not pdef.hasExactSolution():
            # Fall back to planning without the patch mask.
            NewValidityCheckerObj = ValidityChecker(si, input_map)
            si.setStateValidityChecker(NewValidityCheckerObj)
            planner.solve(89.0)
    else:
        planner.solve(90)
    planTime = time.perf_counter() - startTime

    plannerData = ob.PlannerData(si)
    planner.getPlannerData(plannerData)
    numVertices = plannerData.numVertices()

    success = bool(pdef.hasExactSolution())
    if success:
        print("Found Solution")
        # Fetch the solution once rather than once per co-ordinate access.
        solutionPath = pdef.getSolutionPath()
        path = []
        for i in range(solutionPath.getStateCount()):
            state = solutionPath.getState(i)
            path.append([state[0], state[1]])
    else:
        path = [[start[0], start[1]], [goal[0], goal[1]]]
    return path, planTime, numVertices, success
device = 'cuda' if torch.cuda.is_available() else 'cpu'


def get_patch(model, start_pos, goal_pos, input_map):
    '''
    Return the patch map predicted by the transformer for a start/goal pair.

    :param model: The patch-prediction network; its output is indexed as
        predVal[0, token, class] (assumed (batch, tokens, classes) - TODO
        confirm against transformer/Models.py).
    :param start_pos: Start position, as produced by geom2pix in this script.
    :param goal_pos: Goal position, as produced by geom2pix in this script.
    :param input_map: 2D map array; its shape determines the token hash table.
    :returns (np.ndarray, torch.Tensor): A mask shaped like input_map set to
        1.0 inside a receptive_field-sized window around every token predicted
        as class 1, and the per-token softmax class probabilities.
    '''
    # Identify anchor points
    encoder_input = get_encoder_input(input_map, goal_pos, start_pos)
    hashTable = getHashTable(input_map.shape)
    # Inference only: skip autograd, and use the configured device instead of
    # assuming CUDA is present.
    with torch.no_grad():
        predVal = model(encoder_input[None, :].float().to(device))
        predProb = F.softmax(predVal[0, :, :], dim=1)
        # Move the labels to the host once instead of comparing device
        # scalars token by token.
        predClass = predVal[0, :, :].max(1)[1].cpu().tolist()
    possAnchor = [hashTable[i] for i, label in enumerate(predClass) if label == 1]

    # Generate patch map: mark a window around each anchor, clipped to the map.
    patch_map = np.zeros_like(input_map)
    map_size = input_map.shape
    half = receptive_field // 2
    for x, y in possAnchor:
        x0 = max(0, x - half)
        y0 = max(0, y - half)
        x1 = min(map_size[1], x + half)
        y1 = min(map_size[0], y + half)
        patch_map[y0:y1, x0:x1] = 1.0
    return patch_map, predProb
def get_patch_unet(model, start_pos, goal_pos, input_map):
    '''
    Return the patch map predicted by a UNet for the given start and goal.

    :param model: A UNetModel; its output is reduced with argmax over dim 1
        (assumed (batch, class, H, W) - TODO confirm).
    :param start_pos: Start position, as produced by geom2pix in this script.
    :param goal_pos: Goal position, as produced by geom2pix in this script.
    :param input_map: 2D map array passed to get_encoder_input.
    :returns np.ndarray: Per-pixel predicted class index for the single batch
        item.
    '''
    encoder_input = get_encoder_input(input_map, goal_pos, start_pos)
    # Inference only; send the input to the device the script moves the model
    # to instead of assuming CUDA is present.
    with torch.no_grad():
        predVal = model(encoder_input[None, :].float().to(device))
    # Select the batch item explicitly: squeeze() would also drop any size-1
    # spatial axis.
    patch_map = torch.argmax(predVal, dim=1)[0].cpu().numpy()
    return patch_map
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--segmentType',
        help='The underlying segmentation method to use',
        required=True,
        choices=['mpt', 'unet']
    )
    parser.add_argument(
        '--plannerType',
        help='The underlying sampler to use',
        required=True,
        choices=['rrtstar', 'informedrrtstar']
    )
    parser.add_argument('--modelFolder', help='Directory where model_params.json exists', required=True)
    parser.add_argument('--valDataFolder', help='Directory where the validation data exists', required=True)
    parser.add_argument('--start', help='Start of environment number', required=True, type=int)
    parser.add_argument('--numEnv', help='Number of environments', required=True, type=int)
    parser.add_argument('--epoch', help='Model epoch number to test', required=True, type=int)
    parser.add_argument('--numPaths', help='Number of start and goal pairs for each env', default=1, type=int)
    parser.add_argument('--explore', help='Explore the environment w/o the mask', dest='explore', action='store_true')
    parser.add_argument('--mapSize', help='The size of the input map', default='')
    args = parser.parse_args()

    modelFolder = args.modelFolder
    modelFile = osp.join(modelFolder, 'model_params.json')
    # Explicit check instead of assert, which is stripped under `python -O`.
    if not osp.isfile(modelFile):
        raise FileNotFoundError(f"Cannot find the model_params.json file in {modelFolder}")
    start = args.start
    with open(modelFile) as f:
        model_param = json.load(f)
    if args.segmentType == 'mpt':
        model = tfModel.Transformer(**model_param)
    elif args.segmentType == 'unet':
        model = unetModel.UNet(**model_param)
    model.to(device)

    # Load model parameters; map_location lets a GPU-saved checkpoint load on
    # a CPU-only machine.
    epoch = args.epoch
    checkpoint = torch.load(osp.join(modelFolder, f'model_epoch_{epoch}.pkl'), map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    valDataFolder = args.valDataFolder
    # Only do evaluation - Need this for the problem to work with maps of different sizes.
    model.eval()
    # Inference only: no autograd graphs are needed for the rest of the script.
    torch.set_grad_enabled(False)

    # Per-path results, in environment/path order.
    pathSuccess = []
    pathTime = []
    pathVertices = []
    for env_num in range(start, start + args.numEnv):
        envFolder = osp.join(valDataFolder, f'env{env_num:06d}')
        small_map = skimage.io.imread(osp.join(envFolder, f'map_{env_num}.png'), as_gray=True)
        mapSize = small_map.shape
        for pathNum in range(args.numPaths):
            pathFile = osp.join(envFolder, f'path_{pathNum}.p')
            # NOTE(review): pickle.load can execute arbitrary code; only use
            # with trusted data files.
            with open(pathFile, 'rb') as f:
                data = pickle.load(f)
            if not data['success']:
                # No reference path for this pair: record a failed attempt.
                pathSuccess.append(False)
                pathTime.append(0)
                pathVertices.append(0)
                continue
            path = data['path_interpolated']
            # NOTE(review): path[0] is labelled the goal for the network while
            # the planner below starts at path[0]; presumably this matches the
            # convention used during training - confirm against dataLoader.
            goal_pos = geom2pix(path[0, :], size=mapSize)
            start_pos = geom2pix(path[-1, :], size=mapSize)
            if args.segmentType == 'mpt':
                # NOTE: This still needs testing; earlier results were
                # gathered with a hard-coded inline version of this logic.
                patch_map, _ = get_patch(model, start_pos, goal_pos, small_map)
            elif args.segmentType == 'unet':
                patch_map = get_patch_unet(model, start_pos, goal_pos, small_map)
            # Length of the stored path, used as the planner's cost threshold.
            cost = np.linalg.norm(np.diff(path, axis=0), axis=1).sum()
            _, t, v, s = get_path(path[0, :], path[-1, :], small_map, patch_map, args.plannerType, cost, exp=args.explore)
            pathSuccess.append(s)
            pathTime.append(t)
            pathVertices.append(v)

    pathData = {'Time': pathTime, 'Success': pathSuccess, 'Vertices': pathVertices}
    expTag = 'exp_' if args.explore else ''
    fileName = osp.join(
        modelFolder,
        f'eval_val{args.mapSize}_plan_{expTag}{args.segmentType}_{args.plannerType}_{start:06d}.p'
    )
    with open(fileName, 'wb') as f:
        pickle.dump(pathData, f)