-
Notifications
You must be signed in to change notification settings - Fork 0
/
intersect.py
executable file
·84 lines (67 loc) · 3.02 KB
/
intersect.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import torch
import box_utils
import ray_utils
def apply_mask_to_tensors(mask, tensors):
    """Apply the same mask to every tensor in a list.

    Args:
        mask: [R, ...] tensor used as an index (``t[mask]``) into each tensor.
        tensors: Iterable of [R, ...] tensors to mask.

    Returns:
        List of masked tensors of shape [R?, ...], one per input tensor, in
        the same order as ``tensors``.
    """
    return [t[mask] for t in tensors]
def get_full_intersection_tensors(ray_batch):
    """Test case that selects all rays.

    Produces the same kind of outputs as a real intersection step, but with an
    all-True mask so that every ray is kept.

    Args:
        ray_batch: [R, M] tensor. Batch of rays; columns 6:8 are read as the
            per-ray bounds (assumes M >= 8 -- TODO confirm ray layout).

    Returns:
        mask: [R] bool tensor, all True, on ray_batch's device.
        indices: [R, 1] int32 tensor holding 0..R-1, on ray_batch's device.
        bounds: [R, 2] tensor, a masked copy of ray_batch[:, 6:8].
    """
    num_rays = ray_batch.shape[0]  # R
    device = ray_batch.device
    # Build the [R] mask directly rather than allocating a full [R, M] bool
    # tensor and slicing one column out of it.
    mask = torch.ones(num_rays, dtype=torch.bool, device=device)  # [R]
    # Indices must live on the mask's device: indexing a CPU tensor with a
    # CUDA mask raises.
    indices = torch.arange(
        num_rays, dtype=torch.int, device=device).unsqueeze(1)  # [R, 1]
    indices = indices[mask]  # [R?, 1]
    bounds = ray_batch[:, 6:8][mask]  # [R?, 2]
    return mask, indices, bounds
def compute_object_intersect_tensors(ray_batch, box_center, box_dims):
    """Compute rays that intersect with an (unrotated) bounding box.

    Args:
        ray_batch: [R, M] float tensor. Batch of rays. Columns 0:3 are passed
            to box_utils as ray origins and columns 3:6 as ray directions.
        box_center: List (or tuple) of 3 ints/floats containing the (x, y, z)
            center of the bbox.
        box_dims: List (or tuple) of 3 ints/floats containing the x, y, z
            dimensions of the bbox (passed as box length, width and height).

    Returns:
        intersect_ray_batch: [R?, M] tensor. The intersecting rays, with their
            near/far bounds replaced via ray_utils.update_ray_batch_bounds
            using the ray-box intersection bounds.
        indices: Indices of the intersecting rays, exactly as returned by
            box_utils.compute_ray_bbox_bounds_pairwise. NOTE(review): the
            call-site comment suggests shape [R?]; confirm shape/dtype there.
        mask: [R] mask selecting the intersecting rays (used to index
            ray_batch).

    Raises:
        TypeError: If box_center or box_dims is not a list/tuple, or holds an
            entry that is not an int or float.
        ValueError: If box_center or box_dims does not have exactly 3 entries.
    """
    # Validate bbox params with explicit raises; `assert` is stripped under -O.
    for name, values in (("box_center", box_center), ("box_dims", box_dims)):
        if not isinstance(values, (list, tuple)):
            raise TypeError(
                f"{name} must be a list of 3 numbers, "
                f"got {type(values).__name__}")
        if len(values) != 3:
            raise ValueError(
                f"{name} must have exactly 3 entries, got {len(values)}")
        if not all(isinstance(v, (int, float)) for v in values):
            raise TypeError(
                f"{name} entries must be int or float, got {values!r}")

    num_rays = ray_batch.shape[0]  # R
    # Create the box tensors on ray_batch's device so box_utils never gets a
    # mix of CPU and accelerator tensors; float32 matches the old `.float()`.
    device = ray_batch.device
    box_center = torch.tile(
        torch.tensor(box_center, dtype=torch.float32, device=device),
        (num_rays, 1))  # [R, 3]
    box_dims = torch.tile(
        torch.tensor(box_dims, dtype=torch.float32, device=device),
        (num_rays, 1))  # [R, 3]
    # For now, we assume the bbox has no rotation: identity for every ray.
    box_rotation = torch.tile(
        torch.eye(3, dtype=torch.float32, device=device).unsqueeze(0),
        (num_rays, 1, 1))  # [R, 3, 3]

    # Compute ray-bbox intersections.
    bounds, indices, mask = box_utils.compute_ray_bbox_bounds_pairwise(  # [R', 2], [R',], [R,]
        rays_o=ray_batch[:, 0:3],  # [R, 3]
        rays_d=ray_batch[:, 3:6],  # [R, 3]
        box_length=box_dims[:, 0],  # [R,]
        box_width=box_dims[:, 1],  # [R,]
        box_height=box_dims[:, 2],  # [R,]
        box_center=box_center,  # [R, 3]
        box_rotation=box_rotation)  # [R, 3, 3]

    # Keep only the intersecting rays.
    intersect_ray_batch = ray_batch[mask]  # [R?, M]

    # Replace the near and far bounds of the kept rays with the intersection
    # bounds.
    intersect_ray_batch = ray_utils.update_ray_batch_bounds(  # [R?, M]
        ray_batch=intersect_ray_batch,  # [R?, M]
        bounds=bounds)  # [R?, 2]
    return intersect_ray_batch, indices, mask