Skip to content

Commit

Permalink
Merge pull request #2098 from StFroese/tilted_frame_transformation
Browse files Browse the repository at this point in the history
Speed up rotations, closes #2097
  • Loading branch information
kosack authored Nov 17, 2022
2 parents e4787a9 + fdd14a6 commit 248dfd0
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 34 deletions.
71 changes: 43 additions & 28 deletions ctapipe/coordinates/ground_frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
frame_transform_graph,
)
from astropy.units.quantity import Quantity
from numpy import cos, sin

__all__ = [
"GroundFrame",
Expand Down Expand Up @@ -91,28 +90,27 @@ class TiltedGroundFrame(BaseCoordinateFrame):
pointing_direction = CoordinateAttribute(default=None, frame=AltAz)


def get_shower_trans_matrix(azimuth, altitude):
def _get_shower_trans_matrix(azimuth, altitude, inverse=False):
"""Get Transformation matrix for conversion from the ground system to
the Tilted system and back again (This function is directly lifted
from read_hess, probably could be streamlined using python
functionality)
Parameters
----------
azimuth: float
Azimuth angle of the tilted system used
altitude: float
Altitude angle of the tilted system used
azimuth: float or ndarray
Azimuth angle in radians of the tilted system used
altitude: float or ndarray
Altitude angle in radiuan of the tilted system used
Returns
-------
trans: 3x3 ndarray transformation matrix
"""

cos_z = sin(altitude)
sin_z = cos(altitude)
cos_az = cos(azimuth)
sin_az = sin(azimuth)
cos_z = np.sin(altitude) # this is the same as np.cos(zenith) but faster
sin_z = np.cos(altitude)
cos_az = np.cos(azimuth)
sin_az = np.sin(azimuth)

trans = np.array(
[
Expand All @@ -123,9 +121,28 @@ def get_shower_trans_matrix(azimuth, altitude):
dtype=np.float64,
)

if inverse:
return np.swapaxes(trans, 0, 1)

return trans


def _get_xyz(coord):
"""
Essentially the same as coord.cartesian.xyz, but much faster by
avoiding some astropy bottlenecks.
"""
# this is a speed optimization. Much faster to use data if already a
# Cartesian object
if isinstance(coord.data, CartesianRepresentation):
cart = coord.data
else:
cart = coord.cartesian

# this is ~5x faster then cart.xyz
return u.Quantity([cart.x, cart.y, cart.z])


@frame_transform_graph.transform(FunctionTransform, GroundFrame, TiltedGroundFrame)
def ground_to_tilted(ground_coord, tilted_frame):
"""
Expand All @@ -142,18 +159,17 @@ def ground_to_tilted(ground_coord, tilted_frame):
-------
SkyCoordinate transformed to `tilted_frame` coordinates
"""
x_grd, y_grd, z_grd = ground_coord.cartesian.xyz
xyz_grd = _get_xyz(ground_coord)

# convert to rad first and substract. Faster than .zen
altitude = tilted_frame.pointing_direction.alt.to_value(u.rad)
azimuth = tilted_frame.pointing_direction.az.to_value(u.rad)

trans = get_shower_trans_matrix(azimuth, altitude)
rotation_matrix = _get_shower_trans_matrix(azimuth, altitude)

x_tilt = trans[0, 0] * x_grd + trans[0, 1] * y_grd + trans[0, 2] * z_grd
y_tilt = trans[1, 0] * x_grd + trans[1, 1] * y_grd + trans[1, 2] * z_grd
z_tilt = trans[2, 0] * x_grd + trans[2, 1] * y_grd + trans[2, 2] * z_grd
vec = np.einsum("ij...,j...->i...", rotation_matrix, xyz_grd)

representation = CartesianRepresentation(x_tilt, y_tilt, z_tilt)
representation = CartesianRepresentation(*vec)

return tilted_frame.realize_frame(representation)

Expand All @@ -174,20 +190,19 @@ def tilted_to_ground(tilted_coord, ground_frame):
-------
GroundFrame coordinates
"""
x_tilt, y_tilt, z_tilt = tilted_coord.cartesian.xyz
xyz_tilt = _get_xyz(tilted_coord)

altitude = tilted_coord.pointing_direction.alt.to(u.rad)
azimuth = tilted_coord.pointing_direction.az.to(u.rad)
altitude = tilted_coord.pointing_direction.alt.to_value(u.rad)
azimuth = tilted_coord.pointing_direction.az.to_value(u.rad)

trans = get_shower_trans_matrix(azimuth, altitude)
rotation_matrix = _get_shower_trans_matrix(azimuth, altitude, inverse=True)

x_grd = trans[0][0] * x_tilt + trans[1][0] * y_tilt + trans[2][0] * z_tilt
y_grd = trans[0][1] * x_tilt + trans[1][1] * y_tilt + trans[2][1] * z_tilt
z_grd = trans[0][2] * x_tilt + trans[1][2] * y_tilt + trans[2][2] * z_tilt
vec = np.einsum("ij...,j...->i...", rotation_matrix, xyz_tilt)

representation = CartesianRepresentation(x_grd, y_grd, z_grd)
representation = CartesianRepresentation(*vec)

grd = ground_frame.realize_frame(representation)

return grd


Expand Down Expand Up @@ -215,9 +230,9 @@ def project_to_ground(tilt_system):
y_initial = ground_system.y.value
z_initial = ground_system.z.value

trans = get_shower_trans_matrix(
tilt_system.pointing_direction.az,
tilt_system.pointing_direction.alt,
trans = _get_shower_trans_matrix(
tilt_system.pointing_direction.az.to_value(u.rad),
tilt_system.pointing_direction.alt.to_value(u.rad),
)

x_projected = x_initial - trans[2][0] * z_initial / trans[2][2]
Expand Down
37 changes: 31 additions & 6 deletions ctapipe/coordinates/tests/test_coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,15 +237,19 @@ def test_ground_to_tilt_many_to_one():
def test_ground_to_tilt_many_to_many():
from ctapipe.coordinates import GroundFrame, TiltedGroundFrame

# define ground coordinate
grd_coord = GroundFrame(x=[1, 1] * u.m, y=[2, 2] * u.m, z=[0, 0] * u.m)
ground = GroundFrame(x=[1, 2] * u.m, y=[2, 1] * u.m, z=[3, 3] * u.m)
pointing_direction = SkyCoord(
alt=[90, 90, 90], az=[0, 0, 90], frame=AltAz(), unit=u.deg
alt=[90, 90, 90],
az=[0, 90, 180],
frame=AltAz(),
unit=u.deg,
)

with raises(ValueError):
# there will be a shape mismatch in matrix multiplication
grd_coord.transform_to(TiltedGroundFrame(pointing_direction=pointing_direction))
tilted = ground[:, np.newaxis].transform_to(
TiltedGroundFrame(pointing_direction=pointing_direction)
)

assert tilted.shape == (2, 3)


def test_camera_missing_focal_length():
Expand Down Expand Up @@ -282,6 +286,27 @@ def test_ground_frame_roundtrip():
assert u.isclose(coord.z, back.z, atol=1e-12 * u.m)


def test_ground_to_tilt_many_to_many_roundtrip():
from ctapipe.coordinates import GroundFrame, TiltedGroundFrame

ground = GroundFrame(x=[1, 2] * u.m, y=[2, 1] * u.m, z=[3, 3] * u.m)
pointing_direction = SkyCoord(
alt=[90, 90, 90],
az=[0, 0, 180],
frame=AltAz(),
unit=u.deg,
)

tilted = ground[:, np.newaxis].transform_to(
TiltedGroundFrame(pointing_direction=pointing_direction)
)
back = tilted[:, 0].transform_to(GroundFrame())

assert u.isclose(ground.x, back.x, atol=1e-12 * u.m).all()
assert u.isclose(ground.y, back.y, atol=1e-12 * u.m).all()
assert u.isclose(ground.z, back.z, atol=1e-12 * u.m).all()


def test_ground_to_eastnorth_roundtrip():
"""Check Ground to EastingNorthing and the round-trip"""
from ctapipe.coordinates import EastingNorthingFrame, GroundFrame
Expand Down

0 comments on commit 248dfd0

Please sign in to comment.