-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy patheval_model_mpnet.py
320 lines (281 loc) · 12.3 KB
/
eval_model_mpnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
'''A script for planning paths using MPNet.
'''
import skimage.io
import skimage.morphology as skim
import numpy as np
import pickle
import torch
try:
from ompl import base as ob
from ompl import geometric as og
from ompl import util as ou
except ImportError:
raise ImportError("Container does not have OMPL installed")
from os import path as osp
import argparse
import time
import json
from utils import ValidityChecker
from mpnet import Models
# Size of one map pixel in world units: get_path multiplies the map's pixel
# dimensions by this value to obtain the OMPL state-space bounds.
res = 5e-2
def normalize_state(state, worldBounds):
    '''
    Map a state from world coordinates onto the range [-1, 1].
    :param state: A np.array with the state to be normalized.
    :param worldBounds: A np.array (or scalar) with the upper world bound of
        each dimension; a value of 0 maps to -1 and worldBounds maps to 1.
    :return np.array: The normalized state.
    '''
    fraction = state / worldBounds
    return 2 * fraction - 1
def scale_state(state, length=24):
    '''
    Map a normalized state from [-1, 1] back onto [0, length].
    This inverts normalize_state when length equals the world bounds.
    :param state: The normalized state of the robot.
    :param length: The world bound (scalar or per-dimension array).
    :return: The state in world coordinates.
    '''
    shifted = state + 1
    return length * shifted / 2
def check_edge_collision(start, goal, checkerFn, step=0.5):
    '''
    Return True if the straight edge between start and goal is in collision.

    The edge is sampled so that consecutive samples are at most `step` apart,
    and both end points are always included, so even very short edges get the
    goal state checked.
    :param start: The start state of the robot.
    :param goal: The goal state of the robot.
    :param checkerFn: A function taking a 2d state that returns True when the
        state is valid (collision free).
    :param step: Maximum distance between consecutive sampled states.
    :returns bool: True if any sampled state is invalid, else False.
    '''
    start = np.asarray(start, dtype=float)
    goal = np.asarray(goal, dtype=float)
    # ceil(d / step) segments keep the spacing <= step; +1 adds the end point.
    # At least two samples so both ends are tested even when start == goal.
    numSamples = max(2, int(np.ceil(np.linalg.norm(goal - start) / step)) + 1)
    alpha = np.linspace(0, 1, numSamples)[:, None]
    interTraj = (1 - alpha) * start + alpha * goal
    return any(not checkerFn(pos) for pos in interTraj)
def check_trajectory_collision(traj, collisionFn):
    '''
    Return True if any edge of a piecewise-linear trajectory is in collision.
    :param traj: A numpy array with one state per row.
    :param collisionFn: A function returning True when a state is valid.
    :returns bool: True if some consecutive pair of states collides.
    '''
    heads = traj[:-1, :]
    tails = traj[1:, :]
    return any(check_edge_collision(a, b, collisionFn) for a, b in zip(heads, tails))
def plan_path_mpnet(model, enc, start, goal, device, collisionFn, worldBounds):
    '''
    Plan a path using MPNet, stepping greedily from the start towards the goal.
    :param model: The NN model used for planning.
    :param enc: The latent space representation of the map; it is concatenated
        with (1, 2) state tensors along dim 1, so it must be a 2-D tensor with
        a single row.
    :param start: A numpy array of size 2 with the start state in world coordinates.
    :param goal: A numpy array of size 2 with the goal state in world coordinates.
    :param device: The device used for torch computations.
    :param collisionFn: A function returning True when a world state is valid.
    :param worldBounds: A numpy array of world bounds used to (de)normalize states.
    :returns (np.array, bool, int): The trajectory in world coordinates, whether
        the goal was reached, and the number of valid samples added.
    '''
    goalS = torch.tensor(normalize_state(goal, worldBounds)[None, :], device=device, dtype=torch.float)
    startS = torch.tensor(normalize_state(start, worldBounds)[None, :], device=device, dtype=torch.float)
    # Scales normalized differences by the world bounds (equals the former
    # hard-coded factor 24 when worldBounds is 24 in every dimension).
    boundsS = torch.as_tensor(worldBounds, device=device, dtype=torch.float)
    reachedGoal = False
    normPredTraj = [startS.cpu().numpy().squeeze()]
    vertex = 0
    for _ in range(10):
        # Check if we can connect the current state directly to the goal.
        # De-normalize with worldBounds so this matches normalize_state above.
        currentWorld = scale_state(startS.cpu().numpy().squeeze(), worldBounds)
        if not check_edge_collision(currentWorld, goal, collisionFn):
            reachedGoal = True
            break
        inputs = torch.cat([enc, startS, goalS], dim=1)
        # Resample until the model proposes a valid next state (at most 10 tries).
        for _ in range(10):
            with torch.no_grad():
                tempS = model(inputs)
            temp = tempS.cpu().numpy().squeeze()
            if collisionFn(scale_state(temp, worldBounds)):
                vertex += 1
                startS = tempS
                normPredTraj.append(temp)
                break
        # Check if we have reached near the goal.
        if torch.linalg.norm((startS - goalS) * boundsS) < 0.1:
            reachedGoal = True
            break
    if reachedGoal:
        normPredTraj.append(normalize_state(goal, worldBounds))
    return scale_state(np.array(normPredTraj), worldBounds), reachedGoal, vertex
def plan_path_mpnet_bidirection(model, enc, start, goal, device, collisionFn, worldBounds):
    '''
    Return a path planned with MPNet in a bi-directional manner.

    Two chains are grown alternately: one from the start and one from the goal.
    Each new node is a model step from the tip being grown towards the other
    chain's tip. Planning stops once the two tips can be joined by a
    collision-free edge or are very close.
    :param model: The NN model used for planning.
    :param enc: The latent space representation of the map; it is concatenated
        with (1, 2) state tensors along dim 1, so it must be a 2-D tensor with
        a single row.
    :param start: A numpy array of size 2 with the start state in world coordinates.
    :param goal: A numpy array of size 2 with the goal state in world coordinates.
    :param device: The device used for torch computations.
    :param collisionFn: A function returning True when a world state is valid.
    :param worldBounds: A numpy array of world bounds used to (de)normalize states.
    :returns (np.array, bool, int): The trajectory in world coordinates (just
        [start, goal] if the chains were not joined), whether they were joined,
        and the number of valid samples added.
    '''
    startS = torch.tensor(normalize_state(start, worldBounds)[None, :], device=device, dtype=torch.float)
    goalS = torch.tensor(normalize_state(goal, worldBounds)[None, :], device=device, dtype=torch.float)
    # Scales normalized differences by the world bounds (equals the former
    # hard-coded factor 24 when worldBounds is 24 in every dimension).
    boundsS = torch.as_tensor(worldBounds, device=device, dtype=torch.float)
    # Tips of the chain grown from the start (forward) and from the goal (backward).
    forwardTip, backwardTip = startS, goalS
    normPredTrajF = [startS.cpu().numpy().squeeze()]
    normPredTrajB = [goalS.cpu().numpy().squeeze()]
    normPredTraj = [normalize_state(start, worldBounds), normalize_state(goal, worldBounds)]
    reachedGoal = False
    forward = True
    vertex = 0
    for _ in range(10):
        # Check if we can connect the two tips directly.
        if not check_edge_collision(
            scale_state(forwardTip.cpu().numpy().squeeze(), worldBounds),
            scale_state(backwardTip.cpu().numpy().squeeze(), worldBounds),
            collisionFn,
        ):
            reachedGoal = True
            break
        # Step from the tip being grown towards the other chain's tip.
        if forward:
            inputs = torch.cat([enc, forwardTip, backwardTip], dim=1)
        else:
            inputs = torch.cat([enc, backwardTip, forwardTip], dim=1)
        with torch.no_grad():
            tempS = model(inputs)
        temp = tempS.cpu().numpy().squeeze()
        if collisionFn(scale_state(temp, worldBounds)):
            vertex += 1
            if forward:
                forwardTip = tempS
                normPredTrajF.append(temp)
            else:
                backwardTip = tempS
                # The backward chain is stored in start->goal order.
                normPredTrajB.insert(0, temp)
            # Alternate chains only after a valid step; an invalid sample
            # retries the same chain with the same inputs.
            forward = not forward
            # Check if the two tips have reached near each other.
            if torch.linalg.norm((forwardTip - backwardTip) * boundsS) < 0.1:
                reachedGoal = True
                break
    if reachedGoal:
        normPredTraj = normPredTrajF + normPredTrajB
    return scale_state(np.array(normPredTraj), worldBounds), reachedGoal, vertex
def plan_path_rrt(start, goal, space, si, numAttempts=20, solveTime=0.25):
    '''
    Returns a path planned with RRT* for the given start and goal state.

    The planner is run up to numAttempts times for solveTime seconds each and
    returns as soon as an exact solution exists.
    :param start: The start state of the robot (2 values).
    :param goal: The goal state of the robot (2 values).
    :param space: An ompl.base state space object.
    :param si: An ompl.base.SpaceInformation object.
    :param numAttempts: Maximum number of planner.solve calls.
    :param solveTime: Time budget (seconds) passed to each planner.solve call.
    :returns (np.array, bool, int): On success the path as an (N, 2) array,
        True, and the number of vertices in the planner data; otherwise
        ([], False, 0).
    '''
    startState = ob.State(space)
    goalState = ob.State(space)
    for dim in range(2):
        startState[dim] = float(start[dim])
        goalState[dim] = float(goal[dim])
    pdef = ob.ProblemDefinition(si)
    # 0.1 is the goal tolerance: a path ending within it counts as reaching the goal.
    pdef.setStartAndGoalStates(startState, goalState, 0.1)
    planner = og.RRTstar(si)
    planner.setProblemDefinition(pdef)
    planner.setup()
    for _ in range(numAttempts):
        # RRT* keeps refining across calls; only accept exact solutions.
        planner.solve(solveTime)
        if pdef.hasExactSolution():
            solution = pdef.getSolutionPath()
            path = np.array([
                [solution.getState(i)[0], solution.getState(i)[1]]
                for i in range(solution.getStateCount())
            ])
            plannerData = ob.PlannerData(si)
            planner.getPlannerData(plannerData)
            return path, True, plannerData.numVertices()
    return [], False, 0
def simplify_path(traj, collisionFn):
    '''
    Simplify a given trajectory by removing unnecessary nodes.

    Greedy shortcutting: from the current node, jump to the farthest later node
    (at least two ahead) that is reachable by a collision-free straight edge.
    If there is none, move to the next node. Each kept node appears once.
    :param traj: A numpy array with the trajectory to be simplified (one state per row).
    :param collisionFn: A function returning True when a state is valid.
    :return np.array: A simplified trajectory.
    '''
    traj = np.asarray(traj)
    # Paths with 0, 1 or 2 nodes cannot be shortened.
    if len(traj) <= 2:
        return traj
    last = len(traj) - 1
    keep = [0]
    current = 0
    while current < last:
        nextIdx = current + 1
        # Search from the far end so the longest shortcut wins.
        for j in range(last, current + 1, -1):
            if not check_edge_collision(traj[current], traj[j], collisionFn):
                nextIdx = j
                break
        keep.append(nextIdx)
        current = nextIdx
    return traj[keep]
def _make_space_information(small_map):
    '''
    Build the 2D OMPL state space, space information and validity checker for a map.
    :param small_map: The map used for planning; columns span x and rows span y.
    :returns (space, si, validityChecker): The OMPL objects; keep validityChecker
        referenced for as long as si is used.
    '''
    mapSize = small_map.shape
    space = ob.RealVectorStateSpace(2)
    bounds = ob.RealVectorBounds(2)
    bounds.setLow(0.0)
    bounds.setHigh(0, mapSize[1] * res)  # Set width bounds (x)
    bounds.setHigh(1, mapSize[0] * res)  # Set height bounds (y)
    space.setBounds(bounds)
    si = ob.SpaceInformation(space)
    validityChecker = ValidityChecker(si, small_map)
    si.setStateValidityChecker(validityChecker)
    return space, si, validityChecker


def _repair_edge(model, enc, pos, nextPos, device, space, si, isValid, worldBounds, numAttempts=10):
    '''
    Replan a colliding edge: try MPNet up to numAttempts times, then fall back to RRT*.
    :returns (trajectory, bool, int): The trajectory from pos to nextPos, whether
        it was found, and the vertices added by successful planner calls.
    '''
    vertices = 0
    for _ in range(numAttempts):
        newTraj, success, tmpvertex = plan_path_mpnet_bidirection(
            model, enc, pos, nextPos, device, isValid, worldBounds)
        if success:
            vertices += tmpvertex
            # MPNet only checks the final connection; verify every edge.
            if not check_trajectory_collision(newTraj, isValid):
                return newTraj, True, vertices
    newTraj, success, tmpvertex = plan_path_rrt(pos, nextPos, space, si)
    return newTraj, success, vertices + tmpvertex


def get_path(start, goal, small_map, model, device, worldBounds):
    '''
    Plan a path using MPNet with local replanning.

    A coarse path is first predicted with bidirectional MPNet (up to 50 tries).
    Every edge of that path that is in collision is then replanned, first with
    MPNet and, if that does not give a collision-free trajectory, with RRT*.
    :param start: The start state of the robot (world coordinates).
    :param goal: The goal state of the robot (world coordinates).
    :param small_map: The map used for planning.
    :param model: The NN model for MPNet.
    :param device: The device for torch computations.
    :param worldBounds: The bounds of the training data.
    :returns (np.array, float, int, bool): The planned trajectory, the planning
        time in seconds, the number of vertices added, and whether the coarse
        plan reached the goal and every colliding edge was successfully replanned.
    '''
    space, si, validityChecker = _make_space_information(small_map)
    isValid = validityChecker.isValid
    v = 0
    startTime = time.time()
    with torch.no_grad():
        enc = model.get_environment_encoding(
            torch.tensor(small_map[None, None, :, :], dtype=torch.float, device=device))
    # Coarse plan.
    for _ in range(50):
        predTraj, reachedGoal, vertex = plan_path_mpnet_bidirection(
            model, enc, start, goal, device, isValid, worldBounds)
        if reachedGoal:
            v += vertex
            break
    # Repair every colliding edge of the coarse plan.
    validTraj = predTraj[:1, :]
    allEdgesValid = True
    for pos, nextPos in zip(predTraj[:-1], predTraj[1:]):
        if not check_edge_collision(pos, nextPos, isValid):
            validTraj = np.r_[validTraj, nextPos[None, :]]
            continue
        newTraj, success, tmpvertex = _repair_edge(
            model, enc, pos, nextPos, device, space, si, isValid, worldBounds)
        v += tmpvertex
        if success:
            validTraj = np.r_[validTraj, newTraj[1:, :]]
        else:
            # The returned trajectory now has a gap, so it is not a valid path.
            allEdgesValid = False
    reachedGoal = reachedGoal and allEdgesValid
    if reachedGoal:
        validTraj = simplify_path(validTraj, isValid)
    planTime = time.time() - startTime
    return validTraj, planTime, v, reachedGoal
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--modelFolder', help='Directory where model_params.json exists', required=True)
    parser.add_argument('--valDataFolder', help='Directory where training data exists', required=True)
    parser.add_argument('--start', help='Start of environment number', required=True, type=int)
    parser.add_argument('--samples', help='Number of envs', required=True, type=int)
    parser.add_argument('--numPaths', help='Number of start and goal pairs for each env', default=1, type=int)
    parser.add_argument('--epoch', help='Model epoch number to test', required=True, type=int)
    args = parser.parse_args()

    modelFolder = args.modelFolder
    modelFile = osp.join(modelFolder, 'model_params.json')
    # Explicit check instead of assert so it is not stripped under `python -O`.
    if not osp.isfile(modelFile):
        raise FileNotFoundError(f"Cannot find the model_params.json file in {modelFolder}")
    start = args.start
    samples = args.samples
    with open(modelFile) as f:
        model_param = json.load(f)
    model = Models.MPNet(**model_param)
    model.to(device)

    # Load model parameters; map_location lets a GPU-saved checkpoint load on CPU-only machines.
    epoch = args.epoch
    checkpoint = torch.load(osp.join(modelFolder, f'model_epoch_{epoch}.pkl'), map_location=device)
    model.load_state_dict(checkpoint['state_dict'])

    valDataFolder = args.valDataFolder
    # Bounds of the training data, used to (de)normalize states.
    worldBounds = np.array([24.0, 24.0])
    pathSuccess = []
    pathTime = []
    pathVertices = []
    for env_num in range(start, start + samples):
        temp_map = osp.join(valDataFolder, f'env{env_num:06d}/map_{env_num}.png')
        small_map = skimage.io.imread(temp_map, as_gray=True)
        for pathNum in range(args.numPaths):
            pathFile = osp.join(valDataFolder, f'env{env_num:06d}/path_{pathNum}.p')
            # NOTE: pickle can execute code on load; only use trusted data files.
            with open(pathFile, 'rb') as f:
                data = pickle.load(f)
            path = data['path_interpolated']
            print(f"Env Num: {env_num}: {data['success']}")
            if data['success']:
                _, t, v, s = get_path(path[0, :], path[-1, :], small_map, model, device, worldBounds)
                pathSuccess.append(s)
                pathTime.append(t)
                pathVertices.append(v)
            else:
                pathSuccess.append(False)
                pathTime.append(0)
                pathVertices.append(0)
    pathData = {'Time': pathTime, 'Success': pathSuccess, 'Vertices': pathVertices}
    with open(osp.join(modelFolder, f'eval_val_plan_mpnet_forest_{start:06d}.p'), 'wb') as f:
        pickle.dump(pathData, f)
    print(sum(pathSuccess))