Skip to content

Commit

Permalink
Merge pull request #4 from MIT-TESSE/0.1.4-SNAPSHOT
Browse files Browse the repository at this point in the history
0.1.4 snapshot
  • Loading branch information
ZacRavichandran authored Mar 26, 2020
2 parents 1943c0b + 911aaac commit 2a9ad6e
Show file tree
Hide file tree
Showing 7 changed files with 106 additions and 19 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

setup(
name="tesse_gym",
version="0.1.3",
version="0.1.4",
description="TESSE OpenAI Gym python interface",
packages=find_packages("src"),
# tell setuptools that all packages will be under the 'src' directory
Expand Down
20 changes: 15 additions & 5 deletions src/tesse_gym/core/continuous_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from tesse.msgs import *
from tesse.utils import UdpListener


# gains 1: 150, 35, 1.6, 0.27
# gains 2: 200, 35, 1.6, 0.27

Expand Down Expand Up @@ -211,13 +210,12 @@ def transform(
data = self.get_data()
self.set_goal(data, translate_x, translate_z, rotate_y)

# track movement across steps to find collisions
last_z_err, last_z_rate_err = 0, 0
collision_count = 0
n_steps = 0

# Apply controls until at goal point, a collision occurs, or max steps reached
i = 0
while not self.at_goal(data) and i < self.max_steps:
while not self.at_goal(data) and n_steps < self.max_steps:
force_z, z_error = self.control(data)
data = self.get_data()

Expand All @@ -228,6 +226,7 @@ def transform(
break

last_z_err = z_error
n_steps += 1

self.set_goal(data)

Expand Down Expand Up @@ -264,7 +263,7 @@ def get_data(self) -> AgentState:
if self.last_metadata is None:
response = self.env.request(MetadataRequest()).metadata
else:
response = self.last_metadata
response = self.get_broadcast_metadata()
return parse_metadata(response)

def set_goal(
Expand Down Expand Up @@ -367,6 +366,17 @@ def control(self, data: AgentState) -> Tuple[float, float]:
self.env.send(StepWithForce(force_z, torque_y, force_x))
return force_z, z_error_body

def get_current_time(self) -> float:
""" Get current sim time. """
if self.last_metadata is None:
raise ValueError("Cannot get TESSE time, metadata is `NoneType`")
else:
return float(ET.fromstring(self.last_metadata).find("time").text)

def get_broadcast_metadata(self) -> str:
""" Get metadata provided by TESSE UDP broadcasts. """
return self.last_metadata

def close(self):
""" Called upon destruction, join UDP listener. """
self.udp_listener.join()
44 changes: 39 additions & 5 deletions src/tesse_gym/core/tesse_gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,13 @@
from tesse.msgs import *

from .continuous_control import ContinuousController
from .utils import NetworkConfig, get_network_config, set_all_camera_params
from .utils import (
NetworkConfig,
TesseConnectionError,
get_network_config,
response_nonetype_check,
set_all_camera_params,
)


class TesseGym(GymEnv):
Expand Down Expand Up @@ -210,7 +216,7 @@ def step(self, action: int) -> Tuple[np.ndarray, float, bool, Dict[str, Any]]:
def observe(self) -> DataResponse:
""" Observe state. """
cameras = [(Camera.RGB_LEFT, Compression.OFF, Channels.THREE)]
return self.env.request(DataRequest(metadata=True, cameras=cameras))
return self._data_request(DataRequest(metadata=True, cameras=cameras))

def reset(
self, scene_id: Optional[int] = None, random_seed: Optional[int] = None
Expand Down Expand Up @@ -338,11 +344,39 @@ def get_pose(self) -> np.ndarray:
"""
return self.relative_pose

def _data_request(self, request_type: DataRequest, n_attempts: int = 20):
""" Make a data request while handling potential network limitations.
If during the request, a `TesseConnectionError` is throw, this assumes
there is a spurrious bandwidth issue and re-requests `n_attempts` times.
If, after `n_attempts`, data cannot be recieved, a `TesseConnectionError`
is thrown.
Args:
request_type (DataRequest): Data request type.
n_attempts (int): Number of times to request data from TESSE.
Default is 20.
Returns:
DataResponse: Response from TESSE.
Raises:
TesseConnectionError: Raised if data cannot be read from TESSE.
"""
for _ in range(n_attempts):
try:
return response_nonetype_check(self.env.request(request_type))
except TesseConnectionError:
pass

raise TesseConnectionError()

def _init_pose(self):
""" Initialize agent's starting pose """
metadata = self.env.request(MetadataRequest()).metadata
position = self._get_agent_position(metadata)
rotation = self._get_agent_rotation(metadata)
metadata_response = self._data_request(MetadataRequest())

position = self._get_agent_position(metadata_response.metadata)
rotation = self._get_agent_rotation(metadata_response.metadata)

# initialize position in in agent frame
initial_yaw = rotation[2]
Expand Down
41 changes: 40 additions & 1 deletion src/tesse_gym/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,14 @@
###################################################################################################

from collections import namedtuple
from typing import Union

from tesse.msgs import Camera, SetCameraParametersRequest, SetCameraPositionRequest
from tesse.msgs import (
Camera,
DataResponse,
SetCameraParametersRequest,
SetCameraPositionRequest,
)

NetworkConfig = namedtuple(
"NetworkConfig",
Expand Down Expand Up @@ -174,3 +180,36 @@ def _adjust_camera_position(tesse_gym, camera, x=-0.05, y=0, z=0):
z (int): z position.
"""
tesse_gym.env.request(SetCameraPositionRequest(camera=camera, x=x, y=y, z=z))


def response_nonetype_check(obs: Union[DataResponse, None]) -> DataResponse:
""" Check that data from the sim is not `NoneType`.
`obs` being `NoneType` indicates that data could
not be read from TESSE. Raise an exception if this
is the case.
Args:
obs (Union[DataResponse, None]): Response from the simulator.
Returns:
DataResponse: `obs` if `obs` is not `None`.
Raises:
TesseConnectionError
"""
if obs is None:
raise TesseConnectionError()
return obs


class TesseConnectionError(Exception):
def __init__(self):
""" Indicates data cannot be read from TESSE. """
self.message = (
"Cannot receive data from the simulator. "
"The connection is blocked or the simulator is not running. "
)

def __str__(self):
return self.message
3 changes: 1 addition & 2 deletions src/tesse_gym/tasks/goseek/goseek.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,7 @@ def observe(self) -> DataResponse:
Returns:
DataResponse: The `DataResponse` object. """
cameras = [(Camera.RGB_LEFT, Compression.OFF, Channels.THREE)]
agent_data = self.env.request(DataRequest(metadata=True, cameras=cameras))
return agent_data
return self._data_request(DataRequest(metadata=True, cameras=cameras))

def reset(
self, scene_id: Optional[int] = None, random_seed: Optional[int] = None
Expand Down
12 changes: 9 additions & 3 deletions src/tesse_gym/tasks/goseek/goseek_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@ def evaluate(self, agent: Agent) -> Dict[str, Dict[str, float]]:
"""
results = {}
for episode in range(len(self.scenes)):
print(f"Evaluation episode on episode {episode}, scene {self.scenes[episode]}")
print(
f"Evaluation episode on episode {episode}, scene {self.scenes[episode]}"
)
n_found_targets = 0
n_predictions = 0
n_successful_predictions = 0
Expand All @@ -91,7 +93,9 @@ def evaluate(self, agent: Agent) -> Dict[str, Dict[str, float]]:

self.env.n_targets = self.n_targets[episode]
agent.reset()
obs = self.env.reset(scene_id=self.scenes[episode], random_seed=self.random_seeds[episode])
obs = self.env.reset(
scene_id=self.scenes[episode], random_seed=self.random_seeds[episode]
)

for step in tqdm.tqdm(range(self.episode_length[episode])):
action = agent.act(obs)
Expand All @@ -106,7 +110,9 @@ def evaluate(self, agent: Agent) -> Dict[str, Dict[str, float]]:
if done:
break

precision = 1 if n_predictions == 0 else n_successful_predictions / n_predictions
precision = (
1 if n_predictions == 0 else n_successful_predictions / n_predictions
)
recall = n_found_targets / self.env.n_targets
results[str(episode)] = {
"found_targets": n_found_targets,
Expand Down
3 changes: 1 addition & 2 deletions src/tesse_gym/tasks/goseek/goseek_full_perception.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,7 @@ def observe(self) -> DataResponse:
(Camera.SEGMENTATION, Compression.OFF, Channels.THREE),
(Camera.DEPTH, Compression.OFF, Channels.THREE),
]
agent_data = self.env.request(DataRequest(metadata=True, cameras=cameras))
return agent_data
return self._data_request(DataRequest(metadata=True, cameras=cameras))


def decode_observations(
Expand Down

0 comments on commit 2a9ad6e

Please sign in to comment.