Skip to content

Commit

Permalink
Fix broadcasting in pytorch
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 561648133
  • Loading branch information
Conchylicultor authored and The visu3d Authors committed Aug 31, 2023
1 parent ac1685f commit c3655de
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions visu3d/dc_arrays/transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def from_look_at(
"""
return cls(
t=pos,
R=_get_r_look_at_(pos=pos, target=target),
R=_get_r_look_at(pos=pos, target=target),
)

@dca.vectorize_method
Expand All @@ -148,7 +148,7 @@ def look_at(self, target: FloatArray['*shape 3']) -> Transform:
)
# TODO(epot): Rather than overwriting R, should only apply the rotation
# to the existing R.
return self.replace(R=_get_r_look_at_(pos=self.t, target=target))
return self.replace(R=_get_r_look_at(pos=self.t, target=target))

@property
@dca.vectorize_method
Expand Down Expand Up @@ -536,7 +536,7 @@ def _apply_to(self: ComposedTransform, other: _T) -> _T:
return self.left_tr @ (self.right_tr @ other)


def _get_r_look_at_(
def _get_r_look_at(
*,
pos: FloatArray['*shape 3'],
target: FloatArray['*shape 3'],
Expand All @@ -554,6 +554,10 @@ def _get_r_look_at_(

# In world coordinates, `z` is pointing up
world_up = xnp.asarray([0, 0, 1.0], dtype=xnp.float32)

# `torch.cross` do not support broadcasting
world_up = xnp.broadcast_to(world_up, cam_forward.shape)

# The width of the cam is parallel to the ground (prependicular to z), so
# use cross-product.
cam_w = xnp.cross(cam_forward, world_up)
Expand Down

0 comments on commit c3655de

Please sign in to comment.