diff --git a/LICENSE b/LICENSE index c3d32e5..55e1d36 100644 --- a/LICENSE +++ b/LICENSE @@ -84,7 +84,14 @@ https://github.com/waymo-research/waymax/blob/main/LICENSE, and your access and use of the Waymx Licensed Materials are governed by the terms and conditions contained therein. -@inproceedings{waymax, title={Waymax: An Accelerated, Data-Driven Simulator for Large-Scale Autonomous Driving Research}, author={Cole Gulino and Justin Fu and Wenjie Luo and George Tucker and Eli Bronstein and Yiren Lu and Jean Harb and Xinlei Pan and Yan Wang and Xiangyu Chen and John D. Co-Reyes and Rishabh Agarwal and Rebecca Roelofs and Yao Lu and Nico Montali and Paul Mougin and Zoey Yang and Brandyn White and Aleksandra Faust, and Rowan McAllister and Dragomir Anguelov and Benjamin Sapp}, booktitle={Proceedings of the Neural Information Processing Systems Track on Datasets and Benchmarks},year={2023}} +@inproceedings{waymax, title={Waymax: An Accelerated, Data-Driven Simulator for +Large-Scale Autonomous Driving Research}, author={Cole Gulino and Justin Fu and +Wenjie Luo and George Tucker and Eli Bronstein and Yiren Lu and Jean Harb and +Xinlei Pan and Yan Wang and Xiangyu Chen and John D. Co-Reyes and Rishabh +Agarwal and Rebecca Roelofs and Yao Lu and Nico Montali and Paul Mougin and +Zoey Yang and Brandyn White and Aleksandra Faust, and Rowan McAllister and +Dragomir Anguelov and Benjamin Sapp}, booktitle={Proceedings of the Neural +Information Processing Systems Track on Datasets and Benchmarks},year={2023}} ii. In any license granting or any agreement governing use or access to Your Derivative IP, You must, and must require recipients of Your Derivative IP to diff --git a/README.md b/README.md index e82c351..d5ef5d0 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,9 @@ distill behavior research into its simplest form. As all components are entirely written in JAX, Waymax is easily distributed and deployed on hardware accelerators, such as GPUs and -[TPUs](https://cloud.google.com/tpu). +[TPUs](https://cloud.google.com/tpu). Waymax is provided free of charge under +the terms of the [Waymax License Agreement for Non-Commercial Use](https://github.com/waymo-research/waymax/blob/main/LICENSE). + ## Installation @@ -36,7 +38,7 @@ instructions on how to setup JAX with GPU/CUDA support if needed. Waymax is designed to work with the Waymo Open Motion dataset out of the box. -A simple way to configure access is the following: +A simple way to configure access via command line is the following: 1. Apply for [Waymo Open Dataset](https://waymo.com/open) access. @@ -46,6 +48,13 @@ A simple way to configure access is the following: 4. Run `gcloud auth application-default login`. +If you are using [colab](https://colab.google), run the following inside of the colab after registering in step 1: + +```python +from google.colab import auth +auth.authenticate_user() +``` + Please reference [TF Datasets](https://www.tensorflow.org/datasets/gcs#authentication) for alternative methods to authentication. @@ -151,3 +160,7 @@ Brandyn White and Aleksandra Faust, and Rowan McAllister and Dragomir Anguelov a booktitle={Proceedings of the Neural Information Processing Systems Track on Datasets and Benchmarks},year={2023}} ``` + +## Contact + +Please email any questions to [waymax@google.com](mailto:waymax@google.com), or raise an issue on Github. diff --git a/waymax/agents/actor_core.py b/waymax/agents/actor_core.py index 32bc878..8e74f6d 100644 --- a/waymax/agents/actor_core.py +++ b/waymax/agents/actor_core.py @@ -14,12 +14,11 @@ """Abstract definition of a Waymax actor for use at inference-time.""" import abc -from typing import Callable, TypeVar, Sequence +from typing import Callable, Sequence, TypeVar import chex import jax import jax.numpy as jnp - from waymax import datatypes # This is the internal state for whatever the agent needs to keep as its state. @@ -28,6 +27,7 @@ # This is the dictionary of parameters passed into the model which represents # the parameters to run the network. Params = datatypes.PyTree +Action = datatypes.PyTree @chex.dataclass(frozen=True) @@ -45,7 +45,7 @@ class WaymaxActorOutput: """ actor_state: ActorState - action: datatypes.Action + action: Action is_controlled: jax.Array def validate(self): diff --git a/waymax/agents/constant_speed.py b/waymax/agents/constant_speed.py index 6393dcd..0554015 100644 --- a/waymax/agents/constant_speed.py +++ b/waymax/agents/constant_speed.py @@ -108,7 +108,7 @@ def __init__(self, speed: float = 0.0): Args: speed: Speed in m/s to set as the speed for all agents. """ - super().__init__() + super().__init__(invalidate_on_end=True) self._speed = speed def update_speed( diff --git a/waymax/agents/waypoint_following_agent.py b/waymax/agents/waypoint_following_agent.py index 835c490..9ae090a 100644 --- a/waymax/agents/waypoint_following_agent.py +++ b/waymax/agents/waypoint_following_agent.py @@ -60,8 +60,10 @@ def __init__( is_controlled_func: Optional[ Callable[[datatypes.SimulatorState], jax.Array] ] = None, + invalidate_on_end: bool = False, ): super().__init__(is_controlled_func=is_controlled_func) + self.invalidate_on_end = invalidate_on_end def update_trajectory( self, state: datatypes.SimulatorState @@ -129,19 +131,25 @@ def _get_next_trajectory_by_projection( next_xy, next_yaw, reached_last_waypoint = _project_to_a_trajectory( jnp.stack([next_x, next_y], axis=-1), log_traj, - extrapolate_traj=False, + extrapolate_traj=not self.invalidate_on_end, ) # Freeze the speed for agents that have reached the last waypoint to # prevent drift. + if self.invalidate_on_end: + default_x_vel = jnp.zeros_like(cur_sim_traj.vel_x) + default_y_vel = jnp.zeros_like(cur_sim_traj.vel_y) + else: + default_x_vel = cur_sim_traj.vel_x + default_y_vel = cur_sim_traj.vel_y new_vel_x = jnp.where( reached_last_waypoint, - cur_sim_traj.vel_x, + default_x_vel, new_speed * jnp.cos(cur_sim_traj.yaw), ) new_vel_y = jnp.where( reached_last_waypoint, - cur_sim_traj.vel_y, + default_y_vel, new_speed * jnp.sin(cur_sim_traj.yaw), ) @@ -149,10 +157,11 @@ def _get_next_trajectory_by_projection( # This is to avoid invalidating parked cars. Use a threshold velocity # since some sim agents will tell the parked cars to move forward since # nothing is in front (e.g. IDM). - moving_after_last_waypoint = reached_last_waypoint & ( - new_speed > _STATIC_SPEED_THRESHOLD - ) - valid = valid & ~moving_after_last_waypoint + if self.invalidate_on_end: + moving_after_last_waypoint = reached_last_waypoint & ( + new_speed > _STATIC_SPEED_THRESHOLD + ) + valid = valid & ~moving_after_last_waypoint next_traj = cur_sim_traj.replace( x=next_xy[..., 0], @@ -204,23 +213,28 @@ def __init__( min_spacing: float = 2.0, safe_time_headway: float = 2.0, max_accel: float = 2.0, - max_deccel: float = 4.0, + max_decel: float = 4.0, delta: float = 4.0, max_lookahead: int = 10, + lookahead_from_current_position: bool = True, additional_lookahead_points: int = 10, additional_lookahead_distance: float = 10.0, + invalidate_on_end: bool = False, ): - super().__init__(is_controlled_func=is_controlled_func) + super().__init__( + is_controlled_func=is_controlled_func, + invalidate_on_end=invalidate_on_end, + ) self.desired_vel = desired_vel self.min_spacing_s0 = min_spacing self.safe_time_headway = safe_time_headway self.max_accel = max_accel - self.max_deccel = max_deccel + self.max_decel = max_decel self.delta = delta self.max_lookahead = max_lookahead + self.lookahead_from_current_position = lookahead_from_current_position self.additional_lookahead_distance = additional_lookahead_distance self.additional_headway_points = additional_lookahead_points - self.total_lookahead = max_lookahead + additional_lookahead_points def update_speed( self, state: datatypes.SimulatorState, dt: float = _DEFAULT_TIME_DELTA @@ -287,9 +301,18 @@ def _get_accel( log_waypoints.validate() obj_curr_traj.validate() # 1. Find the closest waypoint and slice the future from that waypoint. - traj = _find_reference_traj_from_log_traj( - cur_position, log_waypoints, self.max_lookahead - ) + if self.lookahead_from_current_position: + traj = _find_reference_traj_from_log_traj(cur_position, obj_curr_traj, 1) + chex.assert_shape(traj.xyz, prefix_shape + (num_obj, 1, 3)) + total_lookahead = 1 + self.additional_headway_points + else: + traj = _find_reference_traj_from_log_traj( + cur_position, log_waypoints, self.max_lookahead + ) + chex.assert_shape( + traj.xyz, prefix_shape + (num_obj, self.max_lookahead, 3) + ) + total_lookahead = self.max_lookahead + self.additional_headway_points if self.additional_headway_points > 0: @@ -303,7 +326,7 @@ def _get_accel( # max_lookahead) between traj (..., num_objects, max_lookahead) and # obj_curr_traj (..., num_objects, 1). Make common shape for bboxes: # (..., num_objects, num_objects, max_lookahead, 5). - broadcast_shape = prefix_shape + (num_obj, num_obj, self.total_lookahead, 5) + broadcast_shape = prefix_shape + (num_obj, num_obj, total_lookahead, 5) traj_5dof = traj.stack_fields(['x', 'y', 'length', 'width', 'yaw']) traj_bbox = jnp.broadcast_to( jnp.expand_dims(traj_5dof, axis=-3), broadcast_shape @@ -356,7 +379,7 @@ def _get_accel( cur_speed * self.safe_time_headway + cur_speed * (cur_speed - lead_vel) - / (2 * jnp.sqrt(self.max_accel * self.max_deccel)), + / (2 * jnp.sqrt(self.max_accel * self.max_decel)), ) # Set 0 for free-road behaviour. s_star = jnp.where( @@ -521,21 +544,21 @@ def project_point_to_traj( src_yaw = traj.yaw[idx] src_dir = jnp.stack([jnp.cos(src_yaw), jnp.sin(src_yaw)], axis=-1) + last_valid_idx = jnp.where(traj.valid, jnp.arange(traj.shape[0]), 0) + last_valid_idx = jnp.argmax(last_valid_idx, axis=-1) + last_point = traj.xy[last_valid_idx, :] + reached_last_point = ( + jnp.linalg.norm(last_point - src_xy, axis=-1) + < _REACHED_END_OF_TRAJECTORY_THRESHOLD + ) + # Secondary detection: If a vehicle strays too far from the traj, + # also mark it as reaching the end. + reached_last_point = jnp.logical_or( + reached_last_point, dist[idx] > _DISTANCE_TO_REF_THRESHOLD + ) + # Prevent points from extrapolating beyond traj. - reached_last_point = jnp.zeros_like(idx, dtype=jnp.bool_) if not extrapolate_traj: - last_valid_idx = jnp.where(traj.valid, jnp.arange(traj.shape[0]), 0) - last_valid_idx = jnp.argmax(last_valid_idx, axis=-1) - last_point = traj.xy[last_valid_idx, :] - reached_last_point = ( - jnp.linalg.norm(last_point - src_xy, axis=-1) - < _REACHED_END_OF_TRAJECTORY_THRESHOLD - ) - # Secondary detection: If a vehicle strays too far from the traj, - # also mark it as reaching the end. - reached_last_point = jnp.logical_or( - reached_last_point, dist[idx] > _DISTANCE_TO_REF_THRESHOLD - ) src_dir = jnp.where(reached_last_point, jnp.zeros_like(src_dir), src_dir) # Shape: (2). diff --git a/waymax/agents/waypoint_following_agent_test.py b/waymax/agents/waypoint_following_agent_test.py index 112d8e0..ca04ab9 100644 --- a/waymax/agents/waypoint_following_agent_test.py +++ b/waymax/agents/waypoint_following_agent_test.py @@ -285,17 +285,19 @@ def test_decelerates_near_collision(self): cur_speed = jnp.array([10.0, 10.0]) cur_position = jax.tree_util.tree_map(lambda x: x[..., :1], objects) max_accel = 1.13 - max_deccel = 1.78 + max_decel = 1.78 delta = 4.0 desired_vel = 30.0 result = waypoint_following_agent.IDMRoutePolicy( max_accel=max_accel, - max_deccel=max_deccel, + max_decel=max_decel, desired_vel=desired_vel, min_spacing=1.0, safe_time_headway=1.0, max_lookahead=6, delta=delta, + lookahead_from_current_position=False, + invalidate_on_end=True, )._get_accel(objects, objects.xyz[:, 0, :], cur_speed, cur_position) # First agent should yield to second agent. # Second agent for free-road behavior. @@ -313,7 +315,7 @@ def test_free_road_behavior(self, max_accel, cur_speed, desired_speed): cur_position = jax.tree_util.tree_map(lambda x: x[..., :1], waypoints) result = waypoint_following_agent.IDMRoutePolicy( max_accel=max_accel, - max_deccel=max_accel, + max_decel=max_accel, desired_vel=desired_speed, max_lookahead=6, delta=delta, diff --git a/waymax/config.py b/waymax/config.py index 357934c..6a34d3c 100644 --- a/waymax/config.py +++ b/waymax/config.py @@ -120,32 +120,12 @@ class MetricsConfig: """Config for the built-in Waymax Metrics functions. Attributes: - run_log_divergence: Whether log_divergence metric will be computed in the - `step` function. - run_overlap: Whether overlap metric will be computed in the `step` function. - run_offroad: Whether offroad metric will be computed in the `step` function. - run_sdc_wrongway: Whether wrong-way metric will be computed for SDC in the - `step` function. Note this is only for single-agent env currently since - there is no route for sim-agents in data. - run_sdc_progression: Whether progression metric will be computed for SDC in - the `step` function. Note this is only for single-agent env currently - since there is no route for sim-agents in data. - run_sdc_off_route: Whether the off-route metric will be computed for SDC in - the `step` function. Note this is only for single-agent env currently - since there is no route for sim-agents in data. - run_sdc_kinematic_infeasibility: Whether the kinematics infeasibility metric - will be computed for SDC in the `step` function. Note this is only for - single-agent env currently since other agents may have different dynamics - and cannot be evaluated using the current kinematics infeasibility metrics + metrics_to_run: A list of metric names to run. Available metrics are: + log_divergence, overlap, offroad, sdc_wrongway, sdc_off_route, + sdc_progression, kinematic_infeasibility. Additional custom metrics can be + registered with `metric_factory.register_metric`. """ - - run_log_divergence: bool = True - run_overlap: bool = True - run_offroad: bool = True - run_sdc_wrongway: bool = False - run_sdc_progression: bool = False - run_sdc_off_route: bool = False - run_sdc_kinematic_infeasibility: bool = False + metrics_to_run: tuple[str, ...] = ('log_divergence', 'overlap', 'offroad') @dataclasses.dataclass(frozen=True) @@ -154,11 +134,7 @@ class LinearCombinationRewardConfig: Attributes: rewards: Dictionary of metric names to floats indicating the weight of each - metric to create a reward of a linear combination. Valid metric names are - taken from the MetricConfig and removing 'run_'. For example, to create a - reward using the progression metric, the name would have to be - 'sdc_progression', since 'run_sdc_progression' is used in the config - above. + metric to create a reward of a linear combination. """ rewards: dict[str, float] @@ -263,9 +239,9 @@ class WaymaxConfig: def __post_init__(self): if not self.data_config.include_sdc_paths and ( - self.env_config.metrics.run_sdc_wrongway - | self.env_config.metrics.run_sdc_progression - | self.env_config.metrics.run_sdc_off_route + ('sdc_wrongway' in self.env_config.metrics.metrics_to_run) + | ('sdc_progression' in self.env_config.metrics.metrics_to_run) + | ('sdc_off_route' in self.env_config.metrics.metrics_to_run) ): raise ValueError( 'Need to set data_config.include_sdc_paths True in ' diff --git a/waymax/datatypes/object_state_test.py b/waymax/datatypes/object_state_test.py index 77fc971..747b486 100644 --- a/waymax/datatypes/object_state_test.py +++ b/waymax/datatypes/object_state_test.py @@ -279,88 +279,40 @@ def test_zeros_returns_valid_datastructure(self): self.assertAllEqual(traj, zeros_traj) def test_trajectory_validate_asserts_if_improperly_created(self): - error_prefix = ( - '[Chex] Assertion assert_type failed: Error in type ' - 'compatibility check:' - ) with self.subTest('IdsWrongType'): - error = ( - f'{error_prefix} input 0 has type int32 but expected ." - ) - with self.assertRaisesWithLiteralMatch(AssertionError, error): + with self.assertRaises(AssertionError): self.traj.replace(x=jnp.zeros((1), dtype=jnp.int32)).validate() with self.subTest('ObjectTypesWrongType'): - error = ( - f'{error_prefix} input 1 has type int32 but expected ." - ) - with self.assertRaisesWithLiteralMatch(AssertionError, error): + with self.assertRaises(AssertionError): self.traj.replace(y=jnp.zeros((1), dtype=jnp.int32)).validate() with self.subTest('ObjectTypesWrongType'): - error = ( - f'{error_prefix} input 2 has type int32 but expected ." - ) - with self.assertRaisesWithLiteralMatch(AssertionError, error): + with self.assertRaises(AssertionError): self.traj.replace(z=jnp.zeros((1), dtype=jnp.int32)).validate() with self.subTest('ObjectTypesWrongType'): - error = ( - f'{error_prefix} input 3 has type int32 but expected ." - ) - with self.assertRaisesWithLiteralMatch(AssertionError, error): + with self.assertRaises(AssertionError): self.traj.replace(vel_x=jnp.zeros((1), dtype=jnp.int32)).validate() with self.subTest('ObjectTypesWrongType'): - error = ( - f'{error_prefix} input 4 has type int32 but expected ." - ) - with self.assertRaisesWithLiteralMatch(AssertionError, error): + with self.assertRaises(AssertionError): self.traj.replace(vel_y=jnp.zeros((1), dtype=jnp.int32)).validate() with self.subTest('ObjectTypesWrongType'): - error = ( - f'{error_prefix} input 5 has type int32 but expected ." - ) - with self.assertRaisesWithLiteralMatch(AssertionError, error): + with self.assertRaises(AssertionError): self.traj.replace(yaw=jnp.zeros((1), dtype=jnp.int32)).validate() with self.subTest('ObjectTypesWrongType'): - error = ( - f'{error_prefix} input 6 has type float32 but expected ." - ) - with self.assertRaisesWithLiteralMatch(AssertionError, error): + with self.assertRaises(AssertionError): self.traj.replace(valid=jnp.zeros((1), dtype=jnp.float32)).validate() with self.subTest('ObjectTypesWrongType'): - error = ( - f'{error_prefix} input 7 has type float32 but expected ." - ) - with self.assertRaisesWithLiteralMatch(AssertionError, error): + with self.assertRaises(AssertionError): self.traj.replace( timestamp_micros=jnp.zeros((1), dtype=jnp.float32) ).validate() with self.subTest('ObjectTypesWrongType'): - error = ( - f'{error_prefix} input 8 has type int32 but expected ." - ) - with self.assertRaisesWithLiteralMatch(AssertionError, error): + with self.assertRaises(AssertionError): self.traj.replace(length=jnp.zeros((1), dtype=jnp.int32)).validate() with self.subTest('ObjectTypesWrongType'): - error = ( - f'{error_prefix} input 9 has type int32 but expected ." - ) - with self.assertRaisesWithLiteralMatch(AssertionError, error): + with self.assertRaises(AssertionError): self.traj.replace(width=jnp.zeros((1), dtype=jnp.int32)).validate() with self.subTest('ObjectTypesWrongType'): - error = ( - f'{error_prefix} input 10 has type int32 but expected ." - ) - with self.assertRaisesWithLiteralMatch(AssertionError, error): + with self.assertRaises(AssertionError): self.traj.replace(height=jnp.zeros((1), dtype=jnp.int32)).validate() with self.subTest('ShapesNotTheSame'): diff --git a/waymax/datatypes/roadgraph_test.py b/waymax/datatypes/roadgraph_test.py index 4d65575..6c8b839 100644 --- a/waymax/datatypes/roadgraph_test.py +++ b/waymax/datatypes/roadgraph_test.py @@ -157,72 +157,32 @@ def test_roadgraph_equality_returns_correctly(self): self.assertNotEqual(self.rg, self.rg.replace(x=jnp.array([1]))) def test_roadgraph_validate_asserts_if_improperly_created(self): - error_prefix = ( - '[Chex] Assertion assert_type failed: Error in type ' - 'compatibility check:' - ) with self.subTest('IdsWrongType'): - error = ( - f'{error_prefix} input 0 has type int32 but expected ." - ) - with self.assertRaisesWithLiteralMatch(AssertionError, error): + with self.assertRaises(AssertionError): self.rg.replace(x=jnp.zeros((1), dtype=jnp.int32)).validate() with self.subTest('ObjectTypesWrongType'): - error = ( - f'{error_prefix} input 1 has type int32 but expected ." - ) - with self.assertRaisesWithLiteralMatch(AssertionError, error): + with self.assertRaises(AssertionError): self.rg.replace(y=jnp.zeros((1), dtype=jnp.int32)).validate() with self.subTest('ObjectTypesWrongType'): - error = ( - f'{error_prefix} input 2 has type int32 but expected ." - ) - with self.assertRaisesWithLiteralMatch(AssertionError, error): + with self.assertRaises(AssertionError): self.rg.replace(z=jnp.zeros((1), dtype=jnp.int32)).validate() with self.subTest('ObjectTypesWrongType'): - error = ( - f'{error_prefix} input 3 has type int32 but expected ." - ) - with self.assertRaisesWithLiteralMatch(AssertionError, error): + with self.assertRaises(AssertionError): self.rg.replace(dir_x=jnp.zeros((1), dtype=jnp.int32)).validate() with self.subTest('ObjectTypesWrongType'): - error = ( - f'{error_prefix} input 4 has type int32 but expected ." - ) - with self.assertRaisesWithLiteralMatch(AssertionError, error): + with self.assertRaises(AssertionError): self.rg.replace(dir_y=jnp.zeros((1), dtype=jnp.int32)).validate() with self.subTest('ObjectTypesWrongType'): - error = ( - f'{error_prefix} input 5 has type int32 but expected ." - ) - with self.assertRaisesWithLiteralMatch(AssertionError, error): + with self.assertRaises(AssertionError): self.rg.replace(dir_z=jnp.zeros((1), dtype=jnp.int32)).validate() with self.subTest('ObjectTypesWrongType'): - error = ( - f'{error_prefix} input 6 has type float32 but expected ." - ) - with self.assertRaisesWithLiteralMatch(AssertionError, error): + with self.assertRaises(AssertionError): self.rg.replace(types=jnp.zeros((1), dtype=jnp.float32)).validate() with self.subTest('ObjectTypesWrongType'): - error = ( - f'{error_prefix} input 7 has type float32 but expected ." - ) - with self.assertRaisesWithLiteralMatch(AssertionError, error): + with self.assertRaises(AssertionError): self.rg.replace(ids=jnp.zeros((1), dtype=jnp.float32)).validate() with self.subTest('ObjectTypesWrongType'): - error = ( - f'{error_prefix} input 8 has type float32 but expected ." - ) - with self.assertRaisesWithLiteralMatch(AssertionError, error): + with self.assertRaises(AssertionError): self.rg.replace(valid=jnp.zeros((1), dtype=jnp.float32)).validate() with self.subTest('ShapesNotTheSame'): diff --git a/waymax/datatypes/route_test.py b/waymax/datatypes/route_test.py index 6d5ba15..d3e2427 100644 --- a/waymax/datatypes/route_test.py +++ b/waymax/datatypes/route_test.py @@ -48,33 +48,33 @@ def test_route_equality_works_correctly(self): def test_route_paths_validation_raises_when_necessary(self): with self.subTest('XWrongType'): - error = "input 0 has type int32 but expected ." + error = 'input 0 has type int32 but expected .*float32.*.' with self.assertRaisesRegex(AssertionError, error): self.routes.replace(x=jnp.zeros((1), dtype=jnp.int32)).validate() with self.subTest('YWrongType'): - error = "input 1 has type int32 but expected ." + error = 'input 1 has type int32 but expected .*float32.*.' with self.assertRaisesRegex(AssertionError, error): self.routes.replace(y=jnp.zeros((1), dtype=jnp.int32)).validate() with self.subTest('ZWrongType'): - error = "input 2 has type int32 but expected ." + error = 'input 2 has type int32 but expected .*float32.*.' with self.assertRaisesRegex(AssertionError, error): self.routes.replace(z=jnp.zeros((1), dtype=jnp.int32)).validate() with self.subTest('IdsWrongType'): - error = "input 3 has type float32 but expected ." + error = 'input 3 has type float32 but expected .*int32.*.' with self.assertRaisesRegex(AssertionError, error): self.routes.replace(ids=jnp.zeros((1), dtype=jnp.float32)).validate() with self.subTest('ValidWrongType'): - error = "input 4 has type int32 but expected ." + error = 'input 4 has type int32 but expected .*bool.*.' with self.assertRaisesRegex(AssertionError, error): self.routes.replace(valid=jnp.zeros((1), dtype=jnp.int32)).validate() with self.subTest('ArcLengthWrongType'): - error = "input 5 has type int32 but expected ." + error = 'input 5 has type int32 but expected .*float32.*.' with self.assertRaisesRegex(AssertionError, error): self.routes.replace( arc_length=jnp.zeros((1), dtype=jnp.int32) ).validate() with self.subTest('OnRouteWrongType'): - error = "input 6 has type int32 but expected ." + error = 'input 6 has type int32 but expected .*bool.*.' with self.assertRaisesRegex(AssertionError, error): self.routes.replace(on_route=jnp.zeros((1), dtype=jnp.int32)).validate() diff --git a/waymax/datatypes/traffic_lights_test.py b/waymax/datatypes/traffic_lights_test.py index 1287d4f..f97a0a9 100644 --- a/waymax/datatypes/traffic_lights_test.py +++ b/waymax/datatypes/traffic_lights_test.py @@ -46,51 +46,23 @@ def test_traffic_lights_equality_works_properly(self): self.assertNotEqual(self.tls, self.tls.replace(x=jnp.array([1.0]))) def test_traffic_lights_validity_works_properly(self): - error_prefix = ( - '[Chex] Assertion assert_type failed: Error in type ' - 'compatibility check:' - ) with self.subTest('XWrongType'): - error = ( - f'{error_prefix} input 0 has type int32 but expected ." - ) - with self.assertRaisesWithLiteralMatch(AssertionError, error): + with self.assertRaises(AssertionError): self.tls.replace(x=jnp.zeros((1), dtype=jnp.int32)).validate() with self.subTest('YWrongType'): - error = ( - f'{error_prefix} input 1 has type int32 but expected ." - ) - with self.assertRaisesWithLiteralMatch(AssertionError, error): + with self.assertRaises(AssertionError): self.tls.replace(y=jnp.zeros((1), dtype=jnp.int32)).validate() with self.subTest('ZWrongType'): - error = ( - f'{error_prefix} input 2 has type int32 but expected ." - ) - with self.assertRaisesWithLiteralMatch(AssertionError, error): + with self.assertRaises(AssertionError): self.tls.replace(z=jnp.zeros((1), dtype=jnp.int32)).validate() with self.subTest('StateWrongType'): - error = ( - f'{error_prefix} input 3 has type float32 but expected ." - ) - with self.assertRaisesWithLiteralMatch(AssertionError, error): + with self.assertRaises(AssertionError): self.tls.replace(state=jnp.zeros((1), dtype=jnp.float32)).validate() with self.subTest('LaneIdsWrongType'): - error = ( - f'{error_prefix} input 4 has type float32 but expected ." - ) - with self.assertRaisesWithLiteralMatch(AssertionError, error): + with self.assertRaises(AssertionError): self.tls.replace(lane_ids=jnp.zeros((1), dtype=jnp.float32)).validate() with self.subTest('ValidWrongType'): - error = ( - f'{error_prefix} input 5 has type int32 but expected ." - ) - with self.assertRaisesWithLiteralMatch(AssertionError, error): + with self.assertRaises(AssertionError): self.tls.replace(valid=jnp.zeros((1), dtype=jnp.int32)).validate() with self.subTest('ShapesNotTheSame'): diff --git a/waymax/dynamics/abstract_dynamics.py b/waymax/dynamics/abstract_dynamics.py index 7ef57d9..10a1b81 100644 --- a/waymax/dynamics/abstract_dynamics.py +++ b/waymax/dynamics/abstract_dynamics.py @@ -140,6 +140,7 @@ def apply_trajectory_update_to_state( is_controlled: jax.Array, timestep: int, allow_object_injection: bool = False, + use_fallback: bool = False, ) -> datatypes.Trajectory: """Applies a TrajectoryUpdate to the sim trajectory at the next timestep. @@ -150,7 +151,7 @@ def apply_trajectory_update_to_state( For objects not in is_controlled, reference_trajectory is used. For objects in is_controlled, but not valid in trajectory_update, fall back to - constant speed behaviour. + constant speed behaviour if the use_fallback flag is on. Args: trajectory_update: Updated trajectory fields for all objects after the @@ -168,6 +169,8 @@ def apply_trajectory_update_to_state( allow_object_injection: Whether to allow new objects to enter the scene. If this is set to False, all objects that are not valid at the current timestep will not be valid at the next timestep and visa versa. + use_fallback: Whether to fall back to constant speed if a controlled agent + is given an invalid action. Otherwise, the agent will be invalidated. Returns: Updated trajectory given update from a dynamics model at `timestep` + 1. @@ -201,16 +204,22 @@ def apply_trajectory_update_to_state( # TODO: Update z using the (x, y) coordinates of the vehicle. replacement_dict = {} for field in CONTROLLABLE_FIELDS: - # Use fallback trajectory if user doesn't not provide valid action. - new_value = jnp.where( - trajectory_update.valid, - trajectory_update[field], - fallback_trajectory[field], - ) - # Only update for is_controlled objects from users. - replacement_dict[field] = jnp.where( - is_controlled, new_value, default_next_traj[field] - ) + if use_fallback: + # Use fallback trajectory if user doesn't not provide valid action. + new_value = jnp.where( + trajectory_update.valid, + trajectory_update[field], + fallback_trajectory[field], + ) + # Only update for is_controlled objects from users. + replacement_dict[field] = jnp.where( + is_controlled, new_value, default_next_traj[field] + ) + else: + new_value = jnp.where( + is_controlled, trajectory_update[field], default_next_traj[field] + ) + replacement_dict[field] = new_value exist_and_controlled = is_controlled & current_traj.valid # For exist_and_controlled objects, valid flags should remain the same as diff --git a/waymax/dynamics/abstract_dynamics_test.py b/waymax/dynamics/abstract_dynamics_test.py index ecbec2b..28ef094 100644 --- a/waymax/dynamics/abstract_dynamics_test.py +++ b/waymax/dynamics/abstract_dynamics_test.py @@ -281,6 +281,66 @@ def test_update_state_with_dynamics_trajectory_handles_valid( updated_sim_traj.valid[:, sim_state.timestep + 1], expected_valid ) + @parameterized.named_parameters( + ('UseFallback', True), + ('DontUseFallback', False), + ) + def test_apply_trajectory_update_with_fallback(self, use_fallback): + data_config = _config.DatasetConfig(path=TEST_DATA_PATH, max_num_objects=5) + sim_state = test_utils.make_zeros_state(data_config) + sim_state = datatypes.update_state_by_log(sim_state, num_steps=10) + + current_valids = jnp.array([True, False, True, True, False]) + next_valids = jnp.array([True, True, False, False, False]) + is_controlled = jnp.array([True, True, True, False, True]) + action_valid = jnp.array([True, False, True, False, False]) + + sim_current_valids = sim_state.sim_trajectory.valid.at[ + ..., sim_state.timestep + ].set(current_valids) + log_next_valids = sim_state.log_trajectory.valid.at[ + ..., sim_state.timestep + 1 + ].set(next_valids) + sim_state = sim_state.replace( + sim_trajectory=sim_state.sim_trajectory.replace( + valid=sim_current_valids + ), + log_trajectory=sim_state.log_trajectory.replace(valid=log_next_valids), + ) + current_traj = sim_state.current_sim_trajectory + + trajectory_update = datatypes.TrajectoryUpdate( + x=jnp.ones_like(current_traj.x), + y=jnp.ones_like(current_traj.y), + vel_x=jnp.ones_like(current_traj.vel_x), + vel_y=jnp.ones_like(current_traj.vel_y), + yaw=jnp.ones_like(current_traj.yaw), + valid=action_valid[..., jnp.newaxis], + ) + updated_sim_traj = abstract_dynamics.apply_trajectory_update_to_state( + trajectory_update, + sim_state.sim_trajectory, + sim_state.log_trajectory, + is_controlled=is_controlled, + timestep=int(sim_state.timestep), + use_fallback=use_fallback, + allow_object_injection=False, + ) + + base_valid = (is_controlled & current_valids) | ( + ~is_controlled & next_valids + ) + if use_fallback: + # With fallback, agents are not invalidated when an action is invalid. + expected_valid = base_valid + else: + # Without fallback, agents are invalidated when the action is invalid. + expected_valid = base_valid & action_valid + + self.assertAllEqual( + updated_sim_traj.valid[:, sim_state.timestep + 1], expected_valid + ) + if __name__ == '__main__': tf.test.main() diff --git a/waymax/env/__init__.py b/waymax/env/__init__.py index f7665e7..44e37fc 100644 --- a/waymax/env/__init__.py +++ b/waymax/env/__init__.py @@ -20,6 +20,7 @@ from waymax.env.errors import SimulationNotInitializedError from waymax.env.planning_agent_environment import PlanningAgentDynamics from waymax.env.planning_agent_environment import PlanningAgentEnvironment +from waymax.env.planning_agent_environment import PlanningAgentSimulatorState from waymax.env.rollout import rollout from waymax.env.rollout import rollout_log_by_expert_sdc from waymax.env.rollout import RolloutOutput diff --git a/waymax/env/abstract_environment.py b/waymax/env/abstract_environment.py index 3ef9374..b98c847 100644 --- a/waymax/env/abstract_environment.py +++ b/waymax/env/abstract_environment.py @@ -25,7 +25,9 @@ class AbstractEnvironment(abc.ABC): """A stateless environment interface for Waymax.""" @abc.abstractmethod - def reset(self, scenario: types.GenericScenario) -> types.GenericState: + def reset( + self, scenario: types.GenericScenario, rng: jax.Array | None = None + ) -> types.GenericState: """Initializes a simulation state. This method allows the environment to perform optional postprocessing @@ -34,6 +36,7 @@ def reset(self, scenario: types.GenericScenario) -> types.GenericState: Args: scenario: Scenario used to generate the initial state. + rng: Optional random number generator for stochastic environments. Returns: The initialized simulation state. @@ -41,7 +44,10 @@ def reset(self, scenario: types.GenericScenario) -> types.GenericState: @abc.abstractmethod def step( - self, state: types.GenericState, actions: types.GenericAction + self, + state: types.GenericState, + actions: types.GenericAction, + rng: jax.Array | None = None, ) -> types.GenericState: """Advances the simulation by one timestep. @@ -49,6 +55,7 @@ def step( state: The current state of the simulator. actions: Action to apply to the state to produce the updated simulator state. + rng: Optional random number generator for stochastic environments. Returns: The next simulation state after taking an action. diff --git a/waymax/env/base_environment.py b/waymax/env/base_environment.py index c7dd5cb..65162b0 100644 --- a/waymax/env/base_environment.py +++ b/waymax/env/base_environment.py @@ -61,7 +61,9 @@ def metrics(self, state: datatypes.SimulatorState) -> types.Metrics: simulator_state=state, metrics_config=self.config.metrics ) - def reset(self, state: datatypes.SimulatorState) -> datatypes.SimulatorState: + def reset( + self, state: datatypes.SimulatorState, rng: jax.Array | None = None + ) -> datatypes.SimulatorState: """Initializes the simulation state. This initializer sets the initial timestep and fills the initial simulation @@ -69,6 +71,7 @@ def reset(self, state: datatypes.SimulatorState) -> datatypes.SimulatorState: Args: state: An uninitialized state of shape (...). + rng: Optional random number generator for stochastic environments. Returns: The initialized simulation state of shape (...). @@ -109,7 +112,10 @@ def observe(self, state: datatypes.SimulatorState) -> types.Observation: @jax.named_scope('BaseEnvironment.step') def step( - self, state: datatypes.SimulatorState, action: datatypes.Action + self, + state: datatypes.SimulatorState, + action: datatypes.Action, + rng: jax.Array | None = None, ) -> datatypes.SimulatorState: """Advances simulation by one timestep using the dynamics model. @@ -119,6 +125,7 @@ def step( actions.valid field is used to denote which objects are being controlled - objects whose valid is False will fallback to default behavior specified by self.dynamics. + rng: Optional random number generator for stochastic environments. Returns: The next simulation state after taking an action of shape (...). diff --git a/waymax/env/planning_agent_environment.py b/waymax/env/planning_agent_environment.py index 144863d..9b21a46 100644 --- a/waymax/env/planning_agent_environment.py +++ b/waymax/env/planning_agent_environment.py @@ -16,7 +16,7 @@ from typing import Sequence -from absl import logging +import chex from dm_env import specs import jax import jax.numpy as jnp @@ -24,7 +24,9 @@ from waymax import datatypes from waymax import dynamics as _dynamics from waymax import metrics +from waymax import rewards from waymax.agents import actor_core +from waymax.env import abstract_environment from waymax.env import base_environment as _env from waymax.env import typedefs as types from waymax.utils import geometry @@ -143,7 +145,21 @@ def inverse( ) -class PlanningAgentEnvironment(_env.BaseEnvironment): +@chex.dataclass +class PlanningAgentSimulatorState(datatypes.SimulatorState): + """Simulator state for the planning agent environment. + + Attributes: + sim_agent_actor_states: State of the sim agents that are being run inside of + the environment `step` function. If sim agents state is provided, this + will be updated. The list of sim agent states should be as long as and in + the same order as the number of sim agents run in the environment. + """ + + sim_agent_actor_states: Sequence[actor_core.ActorState] = () + + +class PlanningAgentEnvironment(abstract_environment.AbstractEnvironment): """An environment wrapper allowing for controlling a single agent. The PlanningAgentEnvironment inherits from a multi-agent BaseEnvironment @@ -161,6 +177,7 @@ def __init__( dynamics_model: _dynamics.DynamicsModel, config: _config.EnvironmentConfig, sim_agent_actors: Sequence[actor_core.WaymaxActorCore] = (), + sim_agent_params: Sequence[actor_core.Params] = (), ) -> None: """Constructs the single agent wrapper. @@ -171,21 +188,91 @@ def __init__( sim_agent_actors: Sim agents as Waymax actors used to update other agents in the scene besides the ADV. Note the actions generated by the sim agents correspond to abstract_dynamics.TrajectoryUpdate. + sim_agent_params: Parameters for the sim agents corresponding to the + `sim_agent_actors` which are added in the step function. """ self._planning_agent_dynamics = PlanningAgentDynamics(dynamics_model) self._state_dynamics = _dynamics.StateDynamics() + self._reward_function = rewards.LinearCombinationReward(config.rewards) + self.config = config if config.controlled_object != _config.ObjectType.SDC: raise ValueError( f'controlled_object {config.controlled_object} must be SDC for' ' planning agent environment.' ) self._sim_agent_actors = sim_agent_actors - super().__init__( - dynamics_model=self._planning_agent_dynamics, config=config + self._sim_agent_params = sim_agent_params + if len(self._sim_agent_actors) != len(self._sim_agent_params): + raise ValueError( + 'Number of sim agents must match number of sim agent params.' + ) + + @property + def dynamics(self) -> _dynamics.DynamicsModel: + return self._planning_agent_dynamics + + def reset( + self, state: datatypes.SimulatorState, rng: jax.Array | None = None + ) -> PlanningAgentSimulatorState: + """Initializes the simulation state. + + This initializer sets the initial timestep and fills the initial simulation + trajectory with invalid values. + + Args: + state: An uninitialized state of shape (...). + rng: Optional random number generator for stochastic environments. + + Returns: + The initialized simulation state of shape (...). + """ + chex.assert_equal( + self.config.max_num_objects, state.log_trajectory.num_objects ) + # Fills with invalid values (i.e. -1.) and False. + sim_traj_uninitialized = datatypes.fill_invalid_trajectory( + state.log_trajectory + ) + state_uninitialized = state.replace( + timestep=jnp.array(-1), sim_trajectory=sim_traj_uninitialized + ) + state = datatypes.update_state_by_log( + state_uninitialized, self.config.init_steps + ) + state = PlanningAgentSimulatorState(**state) + if rng is not None: + keys = jax.random.split(rng, len(self._sim_agent_actors)) + else: + keys = [None] * len(self._sim_agent_actors) + init_actor_states = [ + actor_core.init(key, state) + for key, actor_core in zip(keys, self._sim_agent_actors) + ] + state = state.replace(sim_agent_actor_states=init_actor_states) + return state + + def observe(self, state: PlanningAgentSimulatorState) -> types.Observation: + """Computes the observation for the given simulation state. + + Here we assume that the default observation is just the simulator state. We + leave this for the user to override in order to provide a user-specific + observation function. A user can use this to move some of their model + specific post-processing into the environment rollout in the actor nodes. If + they want this post-processing on the accelerator, they can keep this the + same and implement it on the learner side. We provide some helper functions + at datatypes.observation.py to help write your own observation functions. + + Args: + state: Current state of the simulator of shape (...). + + Returns: + Simulator state as an observation without modifications of shape (...). + """ + return state + @jax.named_scope('PlanningAgentEnvironment.metrics') - def metrics(self, state: datatypes.SimulatorState) -> types.Metrics: + def metrics(self, state: PlanningAgentSimulatorState) -> types.Metrics: """Computes the metrics for the single agent wrapper. The metrics to be computed are based on those specified by the configuration @@ -215,7 +302,7 @@ def metrics(self, state: datatypes.SimulatorState) -> types.Metrics: ) metric_dict[metric_name] = one_hot_metric[metric_name] - if self.config.metrics.run_sdc_kinematic_infeasibility: + if 'kinematic_infeasibility' in self.config.metrics.metrics_to_run: # Since initially the first state has a time step of # self.config.init_steps - 1, and the transition from # self.config.init_steps - 2 to self.config.init_steps - 1 is not @@ -223,19 +310,19 @@ def metrics(self, state: datatypes.SimulatorState) -> types.Metrics: # state's sdc_kim value and set it to 0 (kinematically feasible) because # the action is not chosen by the actor and is thus not clipped. kim_metric_valid = state.timestep > self.config.init_steps - 1 - kim_metric = metric_dict['sdc_kinematic_infeasibility'] + kim_metric = metric_dict['kinematic_infeasibility'] kim_metric = kim_metric.replace( value=kim_metric.value * kim_metric_valid, valid=kim_metric.valid & kim_metric_valid, ) - metric_dict['sdc_kinematic_infeasibility'] = datatypes.select_by_onehot( + metric_dict['kinematic_infeasibility'] = datatypes.select_by_onehot( kim_metric, state.object_metadata.is_sdc, keepdims=False ) return metric_dict @jax.named_scope('PlanningAgentEnvironment.reward') def reward( - self, state: datatypes.SimulatorState, action: datatypes.Action + self, state: PlanningAgentSimulatorState, action: datatypes.Action ) -> jax.Array: """Computes the reward for a transition. @@ -249,37 +336,33 @@ def reward( A float (...) tensor of rewards for the single agent. """ # Shape: (..., num_objects). - multi_agent_reward = super().reward(state, action) - # After onehot, shape: (...) - return datatypes.select_by_onehot( - multi_agent_reward, state.object_metadata.is_sdc, keepdims=False - ) + if self.config.compute_reward: + agent_mask = datatypes.get_control_mask( + state.object_metadata, self.config.controlled_object + ) + multi_agent_reward = self._reward_function.compute( + state, action, agent_mask + ) + # After onehot, shape: (...) + return datatypes.select_by_onehot( + multi_agent_reward, state.object_metadata.is_sdc, keepdims=False + ) + else: + reward_spec = specs.Array(shape=(), dtype=jnp.float32) + return jnp.zeros(state.shape + reward_spec.shape, dtype=reward_spec.dtype) def action_spec(self) -> datatypes.Action: - # Shape: (..., num_objects, ndim). - multi_agent_spec = super().action_spec() - prefix_shape = multi_agent_spec.data.shape[:-2] - - # Remove the object dimension from the action spec. - data_spec = specs.BoundedArray( - shape=prefix_shape + multi_agent_spec.data.shape[-1:], - dtype=multi_agent_spec.data.dtype, - # Note that `multi_agent_spec` are same for all objects, thus we take - # the first one's value for min/max. - minimum=multi_agent_spec.data.minimum[..., 0, :], # pytype: disable=attribute-error # jax-ndarray - maximum=multi_agent_spec.data.maximum[..., 0, :], # pytype: disable=attribute-error # jax-ndarray - ) - # Shape: (..., num_objects, 1). - valid_spec = specs.Array( - shape=prefix_shape + (1,), - dtype=multi_agent_spec.valid.dtype, - ) + data_spec = self.dynamics.action_spec() # rank 1 + valid_spec = specs.Array(shape=(1,), dtype=jnp.bool_) return datatypes.Action(data=data_spec, valid=valid_spec) # pytype: disable=wrong-arg-types # jax-ndarray @jax.named_scope('PlanningAgentEnvironment.step') def step( - self, state: datatypes.SimulatorState, action: datatypes.Action - ) -> datatypes.SimulatorState: + self, + state: PlanningAgentSimulatorState, + action: datatypes.Action, + rng: jax.Array | None = None, + ) -> PlanningAgentSimulatorState: """Advances simulation by one timestep using the dynamics model. Args: @@ -288,6 +371,7 @@ def step( actions.valid field is used to denote which objects are being controlled - objects whose valid is False will fallback to default behavior specified by self.dynamics. + rng: Optional random number generator for stochastic environments. Returns: The next simulation state after taking an action of shape (...). @@ -303,15 +387,25 @@ def step( # (likely an articulated bus). is_controllable = ~_initialized_overlap(state.log_trajectory) - for agent in self._sim_agent_actors: - agent_output = agent.select_action({}, state, None, None) # pytype: disable=wrong-arg-types - if agent_output.actor_state: - logging.log_first_n( - logging.WARNING, - 'Agent output returned actor_state but using actor_state is not' - ' currently implemented.', - 1, - ) + if len(self._sim_agent_actors) != len(state.sim_agent_actor_states): + raise ValueError( + f'The number of sim agents ({len(self._sim_agent_actors)}) must' + ' match the number of sim actor states' + f' ({len(state.sim_agent_actor_states)}).' + ) + updated_sim_agent_actor_states = [] + if rng is not None: + keys = jax.random.split(rng, len(self._sim_agent_actors)) + else: + keys = [None] * len(self._sim_agent_actors) + for agent, actor_state, params, key in zip( + self._sim_agent_actors, + state.sim_agent_actor_states, + self._sim_agent_params, + keys, + ): + agent_output = agent.select_action(params, state, actor_state, key) # pytype: disable=wrong-arg-types + updated_sim_agent_actor_states.append(agent_output.actor_state) action = agent_output.action controlled_by_sim = agent_output.is_controlled & is_controllable merged_action_data = jnp.where( @@ -333,12 +427,24 @@ def step( timestep=state.timestep, allow_object_injection=self.config.allow_new_objects_after_warmup, ) - return state.replace(sim_trajectory=new_traj, timestep=state.timestep + 1) + return state.replace( + sim_trajectory=new_traj, + timestep=state.timestep + 1, + sim_agent_actor_states=updated_sim_agent_actor_states, + ) def reward_spec(self) -> specs.Array: """Specify the reward spec as just for one object.""" return specs.Array(shape=(), dtype=jnp.float32) + def discount_spec(self) -> specs.BoundedArray: + return specs.BoundedArray( + shape=tuple(), minimum=0.0, maximum=1.0, dtype=jnp.float32 + ) + + def observation_spec(self) -> types.Observation: + raise NotImplementedError() + def _initialized_overlap(log_trajectory: datatypes.Trajectory) -> jax.Array: """Return a mask for objects initialized in a overlap state. diff --git a/waymax/env/planning_agent_environment_test.py b/waymax/env/planning_agent_environment_test.py index 7b656cb..973e5af 100644 --- a/waymax/env/planning_agent_environment_test.py +++ b/waymax/env/planning_agent_environment_test.py @@ -14,11 +14,10 @@ from typing import Optional +import chex import jax from jax import numpy as jnp import tensorflow as tf - -from absl.testing import parameterized from waymax import config as _config from waymax import dataloader from waymax import datatypes @@ -28,6 +27,8 @@ from waymax.env import planning_agent_environment from waymax.utils import test_utils +from absl.testing import parameterized + TEST_DATA_PATH = test_utils.ROUTE_DATA_PATH ROUTE_NUM_PATHS = test_utils.ROUTE_NUM_PATHS ROUTE_NUM_POINTS_PER_PATH = test_utils.ROUTE_NUM_POINTS_PER_PATH @@ -85,7 +86,8 @@ def test_reward_has_correct_shape(self, compute_reward: bool): env = planning_agent_environment.PlanningAgentEnvironment( dynamics_model=dynamics.DeltaGlobal(), config=env_config ) - reward = env.reward(self.sample_state, self.sample_action) + sample_state = env.reset(self.sample_state) + reward = env.reward(sample_state, self.sample_action) self.assertAllEqual(reward.shape, ()) self.assertAllEqual(reward.dtype, jnp.float32) @@ -93,16 +95,22 @@ def test_metric_has_correct_shape(self): env_config = _config.EnvironmentConfig( init_steps=10, metrics=_config.MetricsConfig( - run_sdc_wrongway=True, - run_sdc_progression=True, - run_sdc_off_route=True, - run_sdc_kinematic_infeasibility=True, + metrics_to_run=( + 'log_divergence', + 'overlap', + 'offroad', + 'sdc_wrongway', + 'sdc_progression', + 'sdc_off_route', + 'kinematic_infeasibility', + ), ), ) env = planning_agent_environment.PlanningAgentEnvironment( dynamics_model=dynamics.DeltaGlobal(), config=env_config ) - metrics_dict = env.metrics(self.sample_state) + sample_state = env.reset(self.sample_state) + metrics_dict = env.metrics(sample_state) num_metrics = 7 metric_shape_targets = tuple([() for _ in range(num_metrics)]) metric_shapes = [v.shape for _, v in metrics_dict.items()] @@ -111,20 +119,25 @@ def test_metric_has_correct_shape(self): with self.subTest('all_metrics_are_populated'): self.assertAllEqual(metric_shapes, metric_shape_targets) - def test_planning_agent_environment_with_sim_agents_works(self): + @parameterized.named_parameters(('without_keys', None), ('with_keys', 100)) + def test_planning_agent_environment_with_sim_agents_works(self, key): state = test_utils.make_zeros_state(self.dataset_config) + state = planning_agent_environment.PlanningAgentSimulatorState(**state) state = state.replace(timestep=0) env = planning_agent_environment.PlanningAgentEnvironment( dynamics_model=dynamics.DeltaGlobal(), config=self.env_config, sim_agent_actors=[constant_velocity_actor()], + sim_agent_params=[{}], ) + key = jax.random.PRNGKey(key) if key is not None else key + state = env.reset(state, rng=key) action_spec = self.env.action_spec() action = datatypes.Action( data=jnp.array([0.7, 0.8, 0.05]), valid=jnp.ones(action_spec.valid.shape, dtype=jnp.bool_), ) - state = env.step(state, action) + state = env.step(state, action, rng=key) traj = state.current_sim_trajectory is_sdc = state.object_metadata.is_sdc @@ -140,6 +153,21 @@ def test_planning_agent_environment_with_sim_agents_works(self): self.assertAllClose( traj.yaw[~is_sdc], jnp.ones_like(traj.x[~is_sdc]) * 0.1 ) + self.assertEqual( + state.sim_agent_actor_states, [ConstantSimAgentState(state_num=1)] + ) + + def test_planning_agent_environment_raises_with_sim_actor_params(self): + with self.assertRaisesWithLiteralMatch( + ValueError, + 'Number of sim agents must match number of sim agent params.', + ): + planning_agent_environment.PlanningAgentEnvironment( + dynamics_model=dynamics.DeltaGlobal(), + config=self.env_config, + sim_agent_actors=[constant_velocity_actor()], + sim_agent_params=[{}, {}], + ) def test_initialized_overlap_mask(self): log_traj = datatypes.Trajectory.zeros(shape=(3, 2)) @@ -178,27 +206,36 @@ def update_trajectory( ) +@chex.dataclass(frozen=True) +class ConstantSimAgentState: + state_num: int = 0 + + def constant_velocity_actor() -> actor_core.WaymaxActorCore: agent = ConstantSimAgentActor() + def init(rng, state: datatypes.SimulatorState) -> ConstantSimAgentState: + del rng, state + return ConstantSimAgentState() + def select_action( params: Optional[actor_core.Params], state: datatypes.SimulatorState, - actor_state=None, + actor_state: ConstantSimAgentState, rng: Optional[jax.Array] = None, ) -> actor_core.WaymaxActorOutput: - del params, actor_state, rng + del params, rng action = agent.update_trajectory(state).as_action() output = actor_core.WaymaxActorOutput( action=action, - actor_state=None, + actor_state=actor_state.replace(state_num=actor_state.state_num + 1), is_controlled=~state.object_metadata.is_sdc, ) output.validate() return output return actor_core.actor_core_factory( - init=lambda rng, state: None, + init=init, select_action=select_action, name='constant_vel', ) diff --git a/waymax/env/wrappers/brax_wrapper.py b/waymax/env/wrappers/brax_wrapper.py index 7add769..2b92739 100644 --- a/waymax/env/wrappers/brax_wrapper.py +++ b/waymax/env/wrappers/brax_wrapper.py @@ -28,9 +28,10 @@ from flax import struct import jax from jax import numpy as jnp +from waymax import config as _config from waymax import datatypes from waymax import dynamics -from waymax.env import base_environment +from waymax.env import abstract_environment from waymax.env import typedefs as types @@ -69,18 +70,23 @@ def __eq__(self, other: Any) -> bool: class BraxWrapper: """Brax-like interface wrapper for the Waymax environment.""" - def __init__(self, wrapped_env: base_environment.BaseEnvironment) -> None: + def __init__( + self, + wrapped_env: abstract_environment.AbstractEnvironment, + dynamics_model: dynamics.DynamicsModel, + config: _config.EnvironmentConfig, + ) -> None: """Constracts the Brax wrapper over a Waymax environment. Args: wrapped_env: Waymax environment to wrap with the Brax interface. + dynamics_model: Dynamics model to use which transitions the simulator + state to the next timestep given an action. + config: Waymax environment configs. """ self._wrapped_env = wrapped_env - self.config = self._wrapped_env.config - - @property - def dynamics(self) -> dynamics.DynamicsModel: - return self._wrapped_env.dynamics + self.dynamics = dynamics_model + self.config = config def metrics(self, state: datatypes.SimulatorState) -> types.Metrics: """Computes metrics (lower is better) from state.""" diff --git a/waymax/env/wrappers/brax_wrapper_test.py b/waymax/env/wrappers/brax_wrapper_test.py index 1e9b8a1..342f08d 100644 --- a/waymax/env/wrappers/brax_wrapper_test.py +++ b/waymax/env/wrappers/brax_wrapper_test.py @@ -46,11 +46,19 @@ def setUp(self): multi_stateless_env = base_environment.MultiAgentEnvironment( dynamics_model=dynamics.DeltaGlobal(), config=self.env_config ) - self.multi_env = brax_wrapper.BraxWrapper(multi_stateless_env) + self.multi_env = brax_wrapper.BraxWrapper( + multi_stateless_env, + multi_stateless_env.dynamics, + multi_stateless_env.config, + ) single_stateless_env = planning_agent_environment.PlanningAgentEnvironment( dynamics_model=dynamics.DeltaGlobal(), config=self.env_config ) - self.single_env = brax_wrapper.BraxWrapper(single_stateless_env) + self.single_env = brax_wrapper.BraxWrapper( + single_stateless_env, + single_stateless_env.dynamics, + single_stateless_env.config, + ) @parameterized.parameters(True, False) def test_reset_returns_first_timestep(self, multi=False): diff --git a/waymax/metrics/__init__.py b/waymax/metrics/__init__.py index 2b21bdd..86952e1 100644 --- a/waymax/metrics/__init__.py +++ b/waymax/metrics/__init__.py @@ -13,9 +13,12 @@ # limitations under the License. """Metrics for agent evaluation.""" + from waymax.metrics.abstract_metric import AbstractMetric from waymax.metrics.abstract_metric import MetricResult from waymax.metrics.imitation import LogDivergenceMetric +from waymax.metrics.metric_factory import get_metric_names +from waymax.metrics.metric_factory import register_metric from waymax.metrics.metric_factory import run_metrics from waymax.metrics.overlap import OverlapMetric from waymax.metrics.roadgraph import OffroadMetric diff --git a/waymax/metrics/metric_factory.py b/waymax/metrics/metric_factory.py index ad8f613..b6fb89a 100644 --- a/waymax/metrics/metric_factory.py +++ b/waymax/metrics/metric_factory.py @@ -13,6 +13,8 @@ # limitations under the License. """Utility function that runs all metrics according to an environment config.""" +from collections.abc import Iterable + from waymax import config as _config from waymax import datatypes from waymax.metrics import abstract_metric @@ -23,12 +25,25 @@ from waymax.metrics import route +_METRICS_REGISTRY: dict[str, abstract_metric.AbstractMetric] = { + 'log_divergence': imitation.LogDivergenceMetric(), + 'overlap': overlap.OverlapMetric(), + 'offroad': roadgraph.OffroadMetric(), + 'kinematic_infeasibility': comfort.KinematicsInfeasibilityMetric(), + 'sdc_wrongway': roadgraph.WrongWayMetric(), + 'sdc_progression': route.ProgressionMetric(), + 'sdc_off_route': route.OffRouteMetric(), +} + + def run_metrics( simulator_state: datatypes.SimulatorState, metrics_config: _config.MetricsConfig, ) -> dict[str, abstract_metric.MetricResult]: """Runs all metrics with config flags set to True. + User-defined metrics must be registered using the `register_metric` function. + Args: simulator_state: The current simulator state of shape (...). metrics_config: Waymax metrics config. @@ -37,20 +52,33 @@ def run_metrics( A dictionary of metric names mapping to metric result arrays where each metric is of shape (..., num_objects). """ - name_to_metric = { - 'log_divergence': imitation.LogDivergenceMetric, - 'overlap': overlap.OverlapMetric, - 'offroad': roadgraph.OffroadMetric, - 'sdc_wrongway': roadgraph.WrongWayMetric, - 'sdc_progression': route.ProgressionMetric, - 'sdc_off_route': route.OffRouteMetric, - 'sdc_kinematic_infeasibility': comfort.KinematicsInfeasibilityMetric, - } - results = {} - for metric_name, metric_fn in name_to_metric.items(): - # If flag set to True, compute and store metric. - if getattr(metrics_config, f'run_{metric_name}'): - results[metric_name] = metric_fn().compute(simulator_state) + for metric_name in metrics_config.metrics_to_run: + if metric_name in _METRICS_REGISTRY: + results[metric_name] = _METRICS_REGISTRY[metric_name].compute( + simulator_state + ) + else: + raise ValueError(f'Metric {metric_name} not registered.') return results + + +def register_metric(metric_name: str, metric: abstract_metric.AbstractMetric): + """Register a metric. + + This function registers a metric so that it can be included in a MetricsConfig + and computed by `run_metrics`. + + Args: + metric_name: String name to register the metric with. + metric: The metric to register. + """ + if metric_name in _METRICS_REGISTRY: + raise ValueError(f'Metric {metric_name} has already been registered.') + _METRICS_REGISTRY[metric_name] = metric + + +def get_metric_names() -> Iterable[str]: + """Returns the names of all registered metrics.""" + return _METRICS_REGISTRY.keys() diff --git a/waymax/metrics/metric_factory_test.py b/waymax/metrics/metric_factory_test.py index f407e9d..77059a8 100644 --- a/waymax/metrics/metric_factory_test.py +++ b/waymax/metrics/metric_factory_test.py @@ -16,13 +16,15 @@ from jax import numpy as jnp import tensorflow as tf - -from absl.testing import parameterized from waymax import config as _config from waymax import dataloader +from waymax.metrics import abstract_metric from waymax.metrics import metric_factory from waymax.utils import test_utils +from absl.testing import parameterized + + TEST_DATA_PATH = test_utils.ROUTE_DATA_PATH @@ -31,9 +33,7 @@ class MetricFactoryTest(tf.test.TestCase, parameterized.TestCase): @parameterized.parameters(((),), ((2, 1),)) def test_all_false_flags_results_in_empty_results_dict(self, batch_dims): config = _config.EnvironmentConfig( - metrics=_config.MetricsConfig( - run_log_divergence=False, run_overlap=False, run_offroad=False - ) + metrics=_config.MetricsConfig(metrics_to_run=tuple()) ) dataset = test_utils.make_test_dataset(batch_dims=batch_dims) @@ -50,10 +50,9 @@ def test_all_false_flags_results_in_empty_results_dict(self, batch_dims): @parameterized.parameters(((),), ((2, 1),)) def test_true_flags_results_in_correct_number_of_results(self, batch_dims): + metric_names = ('log_divergence', 'overlap', 'offroad') config = _config.EnvironmentConfig( - metrics=_config.MetricsConfig( - run_log_divergence=True, run_overlap=True, run_offroad=True - ) + metrics=_config.MetricsConfig(metrics_to_run=metric_names) ) dataset = test_utils.make_test_dataset(batch_dims=batch_dims) @@ -65,7 +64,6 @@ def test_true_flags_results_in_correct_number_of_results(self, batch_dims): metric_results = metric_factory.run_metrics( simulator_state=sim_state, metrics_config=config.metrics ) - metric_names = ['log_divergence', 'overlap', 'offroad'] with self.subTest('check_correct_number_of_elements'): self.assertLen(metric_results, 3) @@ -80,9 +78,7 @@ def test_true_flags_results_in_correct_number_of_results(self, batch_dims): def test_offroad_is_detected_in_metric(self): config = _config.EnvironmentConfig( - metrics=_config.MetricsConfig( - run_log_divergence=False, run_overlap=False, run_offroad=True - ) + metrics=_config.MetricsConfig(metrics_to_run=('offroad',)) ) sim_state = test_utils.simulator_state_with_offroad() @@ -112,9 +108,7 @@ def test_offroad_is_detected_in_metric(self): ) def test_overlap_is_detected_in_metric(self, valid, expected): config = _config.EnvironmentConfig( - metrics=_config.MetricsConfig( - run_log_divergence=False, run_overlap=True, run_offroad=False - ) + metrics=_config.MetricsConfig(metrics_to_run=('overlap',)) ) sim_state = test_utils.simulator_state_with_overlap() @@ -137,6 +131,26 @@ def test_overlap_is_detected_in_metric(self, valid, expected): jnp.full(sim_traj.num_objects, valid, dtype=jnp.bool_), ) + def test_user_defined_metric_detected(self): + config = _config.EnvironmentConfig( + metrics=_config.MetricsConfig(metrics_to_run=('custom_metric',)) + ) + + class CustomMetric(abstract_metric.AbstractMetric): + + def compute(self, _) -> abstract_metric.MetricResult: + return abstract_metric.MetricResult( + value=jnp.array([123]), valid=jnp.array([True]) + ) + + metric_factory.register_metric('custom_metric', CustomMetric()) + + sim_state = test_utils.simulator_state_with_overlap() + metric_results = metric_factory.run_metrics( + simulator_state=sim_state, metrics_config=config.metrics + ) + custom_metric_result = metric_results['custom_metric'] + self.assertAllEqual(custom_metric_result.value, jnp.array([123])) if __name__ == '__main__': diff --git a/waymax/metrics/roadgraph.py b/waymax/metrics/roadgraph.py index b92deb9..ea8a44c 100644 --- a/waymax/metrics/roadgraph.py +++ b/waymax/metrics/roadgraph.py @@ -16,7 +16,6 @@ import jax from jax import numpy as jnp - from waymax import datatypes from waymax.metrics import abstract_metric @@ -114,9 +113,9 @@ def is_offroad( Args: trajectory: Agent trajectories to test to see if they are on or off road of - shape (..., num_objects). The bounding boxes derived from center and shape - of the trajectory will be used to determine if any point in the box is - offroad. + shape (..., num_objects, num_timesteps). The bounding boxes derived from + center and shape of the trajectory will be used to determine if any point + in the box is offroad. The num_timesteps dimension size should be 1. roadgraph_points: All of the roadgraph points in the run segment of shape (..., num_points). Roadgraph points of type `ROAD_EDGE_BOUNDARY` and `ROAD_EDGE_MEDIAN` are used to do the check. @@ -126,7 +125,7 @@ def is_offroad( True if the bbox is offroad. """ # Shape: (..., num_objects, num_corners=4, 2). - bbox_corners = trajectory.bbox_corners[..., 0, :] + bbox_corners = jnp.squeeze(trajectory.bbox_corners, axis=-3) # Add in the Z dimension from the current center. This assumption will help # disambiguate between different levels of the roadgraph (i.e. under and over # passes). @@ -165,7 +164,7 @@ def compute_signed_distance_to_nearest_road_edge_point( Args: query_points: A set of query points for the metric of shape - (num_query_points, 3). + (..., num_query_points, 3). roadgraph_points: A set of roadgraph points of shape (num_points). z_stretch: Tolerance in the z dimension which determines how close to associate points in the roadgraph. This is used to fix problems with @@ -179,40 +178,46 @@ def compute_signed_distance_to_nearest_road_edge_point( """ # Shape: (..., num_points, 3). sampled_points = roadgraph_points.xyz - # Shape: (num_query_points, num_points, 3). - differences = sampled_points - query_points[:, jnp.newaxis] + # Shape: (..., num_query_points, num_points, 3). + differences = sampled_points - jnp.expand_dims(query_points, axis=-2) # Stretch difference in altitude to avoid over/underpasses. + # Shape: (..., num_query_points, num_points, 3). z_stretched_differences = differences * jnp.array([[[1.0, 1.0, z_stretch]]]) + # Shape: (..., num_query_points, num_points). square_distances = jnp.sum(z_stretched_differences**2, axis=-1) # Do not consider invalid points. # Shape: (num_points). is_road_edge = datatypes.is_road_edge(roadgraph_points.types) + # Shape: (..., num_query_points, num_points). square_distances = jnp.where( roadgraph_points.valid & is_road_edge, square_distances, float('inf') ) - # Shape: (num_query_points). + # Shape: (..., num_query_points). nearest_indices = jnp.argmin(square_distances, axis=-1) + # Shape: (..., num_query_points). prior_indices = jnp.maximum( jnp.zeros_like(nearest_indices), nearest_indices - 1 ) + # Shape: (..., num_query_points, 2). nearest_xys = sampled_points[nearest_indices, :2] # Direction of the road edge at the nearest points. Should be normed and # tangent to the road edge. - # Shape: (num_points, 2). + # Shape: (..., num_query_points, 2). nearest_vector_xys = roadgraph_points.dir_xyz[nearest_indices, :2] # Direction of the road edge at the points that precede the nearest points. - # Shape: (num_points, 2). + # Shape: (..., num_query_points, 2). prior_vector_xys = roadgraph_points.dir_xyz[prior_indices, :2] - # Shape: (num_query_points, 2). + # Shape: (..., num_query_points, 2). points_to_edge = query_points[..., :2] - nearest_xys # Get the signed distance to the half-plane boundary with a cross product. cross_product = jnp.cross(points_to_edge, nearest_vector_xys) cross_product_prior = jnp.cross(points_to_edge, prior_vector_xys) # If the prior point is contiguous, consider both half-plane distances. - # Shape: (num_points). + # Shape: (..., num_query_points). prior_point_in_same_curve = jnp.equal( roadgraph_points.ids[nearest_indices], roadgraph_points.ids[prior_indices] ) + # Shape: (..., num_query_points). offroad_sign = jnp.sign( jnp.where( jnp.logical_and( @@ -222,7 +227,7 @@ def compute_signed_distance_to_nearest_road_edge_point( cross_product, ) ) - # Shape: (num_query_points). + # Shape: (..., num_query_points). return ( jnp.linalg.norm(nearest_xys - query_points[:, :2], axis=-1) * offroad_sign ) diff --git a/waymax/rewards/linear_combination_reward.py b/waymax/rewards/linear_combination_reward.py index 94434fe..2bf1d7e 100644 --- a/waymax/rewards/linear_combination_reward.py +++ b/waymax/rewards/linear_combination_reward.py @@ -15,7 +15,6 @@ """Reward functions for the Waymax environment.""" import jax import jax.numpy as jnp - from waymax import config as _config from waymax import datatypes from waymax import metrics @@ -64,9 +63,7 @@ def compute( def _validate_reward_metrics(config: _config.LinearCombinationRewardConfig): """Checks that all metrics in the RewardConfigs are valid.""" - metrics_config = _config.MetricsConfig() - metric_names_with_run = metrics_config.__dict__.keys() - metric_names = set(name[4:] for name in metric_names_with_run) + metric_names = metrics.get_metric_names() for reward_metric_name in config.rewards.keys(): if reward_metric_name not in metric_names: raise ValueError( @@ -80,13 +77,4 @@ def _linear_config_to_metric_config( config: _config.LinearCombinationRewardConfig, ) -> _config.MetricsConfig: """Converts a LinearCombinationRewardConfig into a MetricsConfig.""" - reward_metric_names = config.rewards.keys() - temp_metrics_configs = _config.MetricsConfig() - metric_flags = {} - for metric_name in temp_metrics_configs.__dict__.keys(): - # MetricsConfig attributes are stored as f'run_{metric}'. The following line - # removes 'run_' from the name and checks if the metric is present in the - # reward config. If so, the metric is stored in the dictionary as True, - # otherwise False. - metric_flags[metric_name] = metric_name[4:] in reward_metric_names - return _config.MetricsConfig(**metric_flags) + return _config.MetricsConfig(metrics_to_run=tuple(config.rewards.keys())) diff --git a/waymax/rewards/linear_combination_reward_test.py b/waymax/rewards/linear_combination_reward_test.py index 50bcce8..d74b71c 100644 --- a/waymax/rewards/linear_combination_reward_test.py +++ b/waymax/rewards/linear_combination_reward_test.py @@ -63,13 +63,13 @@ def test_linear_combination_results_correct(self): self.assertAllClose(combination_reward, manual_combination_reward) with self.subTest('combination_equals_manual_entry'): - self.assertAllClose(combination_reward, [0, 0, -1, 0, 0, -1, -1]) + self.assertAllClose(combination_reward, [0, 0, -1, 0, 0, -0, -0]) with self.subTest('overlap_reward_is_correct'): self.assertAllClose(negative_overlap_reward, [-1] * 7) with self.subTest('offroad_reward_is_correct'): - self.assertAllClose(offroad_reward, [1, 1, 0, 1, 1, 0, 0]) + self.assertAllClose(offroad_reward, [1, 1, 0, 1, 1, 1, 1]) def test_returns_zero_for_invalid_trajectory(self): combination_config = _config.LinearCombinationRewardConfig({ diff --git a/waymax/visualization/utils.py b/waymax/visualization/utils.py index d80c0b2..23a9cf1 100644 --- a/waymax/visualization/utils.py +++ b/waymax/visualization/utils.py @@ -118,6 +118,7 @@ def plot_numpy_bounding_boxes( color: np.ndarray, alpha: Optional[float] = 1.0, as_center_pts: bool = False, + label: Optional[str] = None, ) -> None: """Plots multiple bounding boxes. @@ -129,6 +130,7 @@ def plot_numpy_bounding_boxes( alpha: Alpha value for drawing, i.e. 0 means fully transparent. as_center_pts: If set to True, bboxes will be drawn as center points, instead of full bboxes. + label: String, represents the meaning of the color for different boxes. """ if bboxes.ndim != 2 or bboxes.shape[1] != 5 or color.shape != (3,): raise ValueError( @@ -139,7 +141,15 @@ def plot_numpy_bounding_boxes( ) if as_center_pts: - ax.plot(bboxes[:, 0], bboxes[:, 1], 'o', color=color, ms=2, alpha=alpha) + ax.plot( + bboxes[:, 0], + bboxes[:, 1], + 'o', + color=color, + ms=2, + alpha=alpha, + label=label, + ) else: c = np.cos(bboxes[:, 4]) s = np.sin(bboxes[:, 4]) @@ -166,6 +176,7 @@ def plot_numpy_bounding_boxes( color=color, zorder=4, alpha=alpha, + label=label, ) # Draw heading arrow. @@ -175,4 +186,5 @@ def plot_numpy_bounding_boxes( color=color, zorder=4, alpha=alpha, + label=label, ) diff --git a/waymax/visualization/viz.py b/waymax/visualization/viz.py index edc6871..33f902f 100644 --- a/waymax/visualization/viz.py +++ b/waymax/visualization/viz.py @@ -39,6 +39,7 @@ def _plot_bounding_boxes( time_idx: int, is_controlled: np.ndarray, valid: np.ndarray, + add_label: bool = False, ) -> None: """Helper function to plot multiple bounding boxes across time.""" # Plots bounding boxes (traj_5dof) with shape: (A, T) @@ -57,6 +58,7 @@ def _plot_bounding_boxes( ax=ax, bboxes=traj_5dof[(time_indices >= time_idx) & valid_controlled], color=color.COLOR_DICT['controlled'], + label='controlled' if add_label else None, ) utils.plot_numpy_bounding_boxes( @@ -64,12 +66,14 @@ def _plot_bounding_boxes( bboxes=traj_5dof[(time_indices < time_idx) & valid], color=color.COLOR_DICT['history'], as_center_pts=True, + label='history' if add_label else None, ) utils.plot_numpy_bounding_boxes( ax=ax, bboxes=traj_5dof[(time_indices >= time_idx) & valid_context], color=color.COLOR_DICT['context'], + label='context' if add_label else None, ) # Shows current overlap @@ -87,6 +91,7 @@ def _plot_bounding_boxes( ax=ax, bboxes=traj_5dof[:, time_idx][overlap_mask & valid[:, time_idx]], color=color.COLOR_DICT['overlap'], + label='overlap' if add_label else None, ) @@ -108,6 +113,7 @@ def plot_trajectory( is_controlled: np.ndarray, time_idx: Optional[int] = None, indices: Optional[np.ndarray] = None, + add_label: bool = False, ) -> None: """Plots a Trajectory with different color for controlled and context. @@ -124,6 +130,9 @@ def plot_trajectory( time_idx: step index to highlight bbox, -1 for last step. Default(None) for not showing bbox. indices: ids to show for each agents if not None, shape (A,). + add_label: a boolean that indicates whether or not to plot labels that + indicates different agent types, including 'controlled', 'overlap', + 'history', 'context'. """ if len(traj.shape) != 2: raise ValueError('traj should have shape (A, T)') @@ -150,7 +159,14 @@ def plot_trajectory( f'{indices[i]}', zorder=10, ) - _plot_bounding_boxes(ax, traj_5dof, time_idx, is_controlled, traj.valid) # pytype: disable=wrong-arg-types # jax-ndarray + _plot_bounding_boxes( + ax=ax, + traj_5dof=traj_5dof, + time_idx=time_idx, + is_controlled=is_controlled, + valid=traj.valid, + add_label=add_label, + ) # pytype: disable=wrong-arg-types # jax-ndarray def plot_roadgraph_points(