diff --git a/setup.py b/setup.py index 3cc673b..12d9195 100644 --- a/setup.py +++ b/setup.py @@ -23,7 +23,7 @@ setup( name="tesse_gym", - version="0.1.3", + version="0.1.4", description="TESSE OpenAI Gym python interface", packages=find_packages("src"), # tell setuptools that all packages will be under the 'src' directory diff --git a/src/tesse_gym/core/continuous_control.py b/src/tesse_gym/core/continuous_control.py index 64027c1..df48384 100644 --- a/src/tesse_gym/core/continuous_control.py +++ b/src/tesse_gym/core/continuous_control.py @@ -29,7 +29,6 @@ from tesse.msgs import * from tesse.utils import UdpListener - # gains 1: 150, 35, 1.6, 0.27 # gains 2: 200, 35, 1.6, 0.27 @@ -211,13 +210,12 @@ def transform( data = self.get_data() self.set_goal(data, translate_x, translate_z, rotate_y) - # track movement across steps to find collisions last_z_err, last_z_rate_err = 0, 0 collision_count = 0 + n_steps = 0 # Apply controls until at goal point, a collision occurs, or max steps reached - i = 0 - while not self.at_goal(data) and i < self.max_steps: + while not self.at_goal(data) and n_steps < self.max_steps: force_z, z_error = self.control(data) data = self.get_data() @@ -228,6 +226,7 @@ def transform( break last_z_err = z_error + n_steps += 1 self.set_goal(data) @@ -264,7 +263,7 @@ def get_data(self) -> AgentState: if self.last_metadata is None: response = self.env.request(MetadataRequest()).metadata else: - response = self.last_metadata + response = self.get_broadcast_metadata() return parse_metadata(response) def set_goal( @@ -367,6 +366,17 @@ def control(self, data: AgentState) -> Tuple[float, float]: self.env.send(StepWithForce(force_z, torque_y, force_x)) return force_z, z_error_body + def get_current_time(self) -> float: + """ Get current sim time. """ + if self.last_metadata is None: + raise ValueError("Cannot get TESSE time, metadata is `NoneType`") + else: + return float(ET.fromstring(self.last_metadata).find("time").text) + + def get_broadcast_metadata(self) -> str: + """ Get metadata provided by TESSE UDP broadcasts. """ + return self.last_metadata + def close(self): """ Called upon destruction, join UDP listener. """ self.udp_listener.join() diff --git a/src/tesse_gym/core/tesse_gym.py b/src/tesse_gym/core/tesse_gym.py index 3f4d198..442656c 100644 --- a/src/tesse_gym/core/tesse_gym.py +++ b/src/tesse_gym/core/tesse_gym.py @@ -35,7 +35,13 @@ from tesse.msgs import * from .continuous_control import ContinuousController -from .utils import NetworkConfig, get_network_config, set_all_camera_params +from .utils import ( + NetworkConfig, + TesseConnectionError, + get_network_config, + response_nonetype_check, + set_all_camera_params, +) class TesseGym(GymEnv): @@ -210,7 +216,7 @@ def step(self, action: int) -> Tuple[np.ndarray, float, bool, Dict[str, Any]]: def observe(self) -> DataResponse: """ Observe state. """ cameras = [(Camera.RGB_LEFT, Compression.OFF, Channels.THREE)] - return self.env.request(DataRequest(metadata=True, cameras=cameras)) + return self._data_request(DataRequest(metadata=True, cameras=cameras)) def reset( self, scene_id: Optional[int] = None, random_seed: Optional[int] = None @@ -338,11 +344,39 @@ def get_pose(self) -> np.ndarray: """ return self.relative_pose + def _data_request(self, request_type: DataRequest, n_attempts: int = 20): + """ Make a data request while handling potential network limitations. + + If during the request, a `TesseConnectionError` is throw, this assumes + there is a spurrious bandwidth issue and re-requests `n_attempts` times. + If, after `n_attempts`, data cannot be recieved, a `TesseConnectionError` + is thrown. + + Args: + request_type (DataRequest): Data request type. + n_attempts (int): Number of times to request data from TESSE. + Default is 20. + + Returns: + DataResponse: Response from TESSE. + + Raises: + TesseConnectionError: Raised if data cannot be read from TESSE. + """ + for _ in range(n_attempts): + try: + return response_nonetype_check(self.env.request(request_type)) + except TesseConnectionError: + pass + + raise TesseConnectionError() + def _init_pose(self): """ Initialize agent's starting pose """ - metadata = self.env.request(MetadataRequest()).metadata - position = self._get_agent_position(metadata) - rotation = self._get_agent_rotation(metadata) + metadata_response = self._data_request(MetadataRequest()) + + position = self._get_agent_position(metadata_response.metadata) + rotation = self._get_agent_rotation(metadata_response.metadata) # initialize position in in agent frame initial_yaw = rotation[2] diff --git a/src/tesse_gym/core/utils.py b/src/tesse_gym/core/utils.py index db28ec7..2fda9be 100644 --- a/src/tesse_gym/core/utils.py +++ b/src/tesse_gym/core/utils.py @@ -20,8 +20,14 @@ ################################################################################################### from collections import namedtuple +from typing import Union -from tesse.msgs import Camera, SetCameraParametersRequest, SetCameraPositionRequest +from tesse.msgs import ( + Camera, + DataResponse, + SetCameraParametersRequest, + SetCameraPositionRequest, +) NetworkConfig = namedtuple( "NetworkConfig", @@ -174,3 +180,36 @@ def _adjust_camera_position(tesse_gym, camera, x=-0.05, y=0, z=0): z (int): z position. """ tesse_gym.env.request(SetCameraPositionRequest(camera=camera, x=x, y=y, z=z)) + + +def response_nonetype_check(obs: Union[DataResponse, None]) -> DataResponse: + """ Check that data from the sim is not `NoneType`. + + `obs` being `NoneType` indicates that data could + not be read from TESSE. Raise an exception if this + is the case. + + Args: + obs (Union[DataResponse, None]): Response from the simulator. + + Returns: + DataResponse: `obs` if `obs` is not `None`. + + Raises: + TesseConnectionError + """ + if obs is None: + raise TesseConnectionError() + return obs + + +class TesseConnectionError(Exception): + def __init__(self): + """ Indicates data cannot be read from TESSE. """ + self.message = ( + "Cannot receive data from the simulator. " + "The connection is blocked or the simulator is not running. " + ) + + def __str__(self): + return self.message diff --git a/src/tesse_gym/tasks/goseek/goseek.py b/src/tesse_gym/tasks/goseek/goseek.py index 761f97e..c6f967b 100644 --- a/src/tesse_gym/tasks/goseek/goseek.py +++ b/src/tesse_gym/tasks/goseek/goseek.py @@ -119,8 +119,7 @@ def observe(self) -> DataResponse: Returns: DataResponse: The `DataResponse` object. """ cameras = [(Camera.RGB_LEFT, Compression.OFF, Channels.THREE)] - agent_data = self.env.request(DataRequest(metadata=True, cameras=cameras)) - return agent_data + return self._data_request(DataRequest(metadata=True, cameras=cameras)) def reset( self, scene_id: Optional[int] = None, random_seed: Optional[int] = None diff --git a/src/tesse_gym/tasks/goseek/goseek_benchmark.py b/src/tesse_gym/tasks/goseek/goseek_benchmark.py index bdfbeb1..41eb2ee 100644 --- a/src/tesse_gym/tasks/goseek/goseek_benchmark.py +++ b/src/tesse_gym/tasks/goseek/goseek_benchmark.py @@ -82,7 +82,9 @@ def evaluate(self, agent: Agent) -> Dict[str, Dict[str, float]]: """ results = {} for episode in range(len(self.scenes)): - print(f"Evaluation episode on episode {episode}, scene {self.scenes[episode]}") + print( + f"Evaluation episode on episode {episode}, scene {self.scenes[episode]}" + ) n_found_targets = 0 n_predictions = 0 n_successful_predictions = 0 @@ -91,7 +93,9 @@ def evaluate(self, agent: Agent) -> Dict[str, Dict[str, float]]: self.env.n_targets = self.n_targets[episode] agent.reset() - obs = self.env.reset(scene_id=self.scenes[episode], random_seed=self.random_seeds[episode]) + obs = self.env.reset( + scene_id=self.scenes[episode], random_seed=self.random_seeds[episode] + ) for step in tqdm.tqdm(range(self.episode_length[episode])): action = agent.act(obs) @@ -106,7 +110,9 @@ def evaluate(self, agent: Agent) -> Dict[str, Dict[str, float]]: if done: break - precision = 1 if n_predictions == 0 else n_successful_predictions / n_predictions + precision = ( + 1 if n_predictions == 0 else n_successful_predictions / n_predictions + ) recall = n_found_targets / self.env.n_targets results[str(episode)] = { "found_targets": n_found_targets, diff --git a/src/tesse_gym/tasks/goseek/goseek_full_perception.py b/src/tesse_gym/tasks/goseek/goseek_full_perception.py index ccdeb38..4a9d264 100644 --- a/src/tesse_gym/tasks/goseek/goseek_full_perception.py +++ b/src/tesse_gym/tasks/goseek/goseek_full_perception.py @@ -96,8 +96,7 @@ def observe(self) -> DataResponse: (Camera.SEGMENTATION, Compression.OFF, Channels.THREE), (Camera.DEPTH, Compression.OFF, Channels.THREE), ] - agent_data = self.env.request(DataRequest(metadata=True, cameras=cameras)) - return agent_data + return self._data_request(DataRequest(metadata=True, cameras=cameras)) def decode_observations(