Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

0.1.4 snapshot #4

Merged
merged 5 commits into from
Mar 26, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

setup(
name="tesse_gym",
version="0.1.3",
version="0.1.4",
description="TESSE OpenAI Gym python interface",
packages=find_packages("src"),
# tell setuptools that all packages will be under the 'src' directory
Expand Down
20 changes: 15 additions & 5 deletions src/tesse_gym/core/continuous_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from tesse.msgs import *
from tesse.utils import UdpListener


# gains 1: 150, 35, 1.6, 0.27
# gains 2: 200, 35, 1.6, 0.27

Expand Down Expand Up @@ -211,13 +210,12 @@ def transform(
data = self.get_data()
self.set_goal(data, translate_x, translate_z, rotate_y)

# track movement across steps to find collisions
last_z_err, last_z_rate_err = 0, 0
collision_count = 0
n_steps = 0

# Apply controls until at goal point, a collision occurs, or max steps reached
i = 0
while not self.at_goal(data) and i < self.max_steps:
while not self.at_goal(data) and n_steps < self.max_steps:
force_z, z_error = self.control(data)
data = self.get_data()

Expand All @@ -228,6 +226,7 @@ def transform(
break

last_z_err = z_error
n_steps += 1

self.set_goal(data)

Expand Down Expand Up @@ -264,7 +263,7 @@ def get_data(self) -> AgentState:
if self.last_metadata is None:
response = self.env.request(MetadataRequest()).metadata
else:
response = self.last_metadata
response = self.get_broadcast_metadata()
return parse_metadata(response)

def set_goal(
Expand Down Expand Up @@ -367,6 +366,17 @@ def control(self, data: AgentState) -> Tuple[float, float]:
self.env.send(StepWithForce(force_z, torque_y, force_x))
return force_z, z_error_body

def get_current_time(self) -> float:
""" Get current sim time. """
if self.last_metadata is None:
raise ValueError("Cannot get TESSE time, metadata is `NoneType`")
else:
return float(ET.fromstring(self.last_metadata).find("time").text)

def get_broadcast_metadata(self) -> str:
""" Get metadata provided by TESSE UDP broadcasts. """
return self.last_metadata

def close(self):
""" Called upon destruction, join UDP listener. """
self.udp_listener.join()
44 changes: 39 additions & 5 deletions src/tesse_gym/core/tesse_gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,13 @@
from tesse.msgs import *

from .continuous_control import ContinuousController
from .utils import NetworkConfig, get_network_config, set_all_camera_params
from .utils import (
NetworkConfig,
TesseConnectionError,
get_network_config,
response_nonetype_check,
set_all_camera_params,
)


class TesseGym(GymEnv):
Expand Down Expand Up @@ -210,7 +216,7 @@ def step(self, action: int) -> Tuple[np.ndarray, float, bool, Dict[str, Any]]:
def observe(self) -> DataResponse:
""" Observe state. """
cameras = [(Camera.RGB_LEFT, Compression.OFF, Channels.THREE)]
return self.env.request(DataRequest(metadata=True, cameras=cameras))
return self._data_request(DataRequest(metadata=True, cameras=cameras))

def reset(
self, scene_id: Optional[int] = None, random_seed: Optional[int] = None
Expand Down Expand Up @@ -338,11 +344,39 @@ def get_pose(self) -> np.ndarray:
"""
return self.relative_pose

def _data_request(self, request_type: DataRequest, n_attempts: int = 20):
""" Make a data request while handling potential network limitations.

If during the request, a `TesseConnectionError` is throw, this assumes
there is a spurrious bandwidth issue and re-requests `n_attempts` times.
If, after `n_attempts`, data cannot be recieved, a `TesseConnectionError`
is thrown.

Args:
request_type (DataRequest): Data request type.
n_attempts (int): Number of times to request data from TESSE.
Default is 20.

Returns:
DataResponse: Response from TESSE.

Raises:
TesseConnectionError: Raised if data cannot be read from TESSE.
"""
for _ in range(n_attempts):
try:
return response_nonetype_check(self.env.request(request_type))
except TesseConnectionError:
pass

raise TesseConnectionError()

def _init_pose(self):
""" Initialize agent's starting pose """
metadata = self.env.request(MetadataRequest()).metadata
position = self._get_agent_position(metadata)
rotation = self._get_agent_rotation(metadata)
metadata_response = self._data_request(MetadataRequest())

position = self._get_agent_position(metadata_response.metadata)
rotation = self._get_agent_rotation(metadata_response.metadata)

# initialize position in in agent frame
initial_yaw = rotation[2]
Expand Down
41 changes: 40 additions & 1 deletion src/tesse_gym/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,14 @@
###################################################################################################

from collections import namedtuple
from typing import Union

from tesse.msgs import Camera, SetCameraParametersRequest, SetCameraPositionRequest
from tesse.msgs import (
Camera,
DataResponse,
SetCameraParametersRequest,
SetCameraPositionRequest,
)

NetworkConfig = namedtuple(
"NetworkConfig",
Expand Down Expand Up @@ -174,3 +180,36 @@ def _adjust_camera_position(tesse_gym, camera, x=-0.05, y=0, z=0):
z (int): z position.
"""
tesse_gym.env.request(SetCameraPositionRequest(camera=camera, x=x, y=y, z=z))


def response_nonetype_check(obs: Union[DataResponse, None]) -> DataResponse:
""" Check that data from the sim is not `NoneType`.

`obs` being `NoneType` indicates that data could
not be read from TESSE. Raise an exception if this
is the case.

Args:
obs (Union[DataResponse, None]): Response from the simulator.

Returns:
DataResponse: `obs` if `obs` is not `None`.

Raises:
TesseConnectionError
"""
if obs is None:
raise TesseConnectionError()
return obs


class TesseConnectionError(Exception):
def __init__(self):
""" Indicates data cannot be read from TESSE. """
self.message = (
"Cannot receive data from the simulator. "
"The connection is blocked or the simulator is not running. "
)

def __str__(self):
return self.message
3 changes: 1 addition & 2 deletions src/tesse_gym/tasks/goseek/goseek.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,7 @@ def observe(self) -> DataResponse:
Returns:
DataResponse: The `DataResponse` object. """
cameras = [(Camera.RGB_LEFT, Compression.OFF, Channels.THREE)]
agent_data = self.env.request(DataRequest(metadata=True, cameras=cameras))
return agent_data
return self._data_request(DataRequest(metadata=True, cameras=cameras))

def reset(
self, scene_id: Optional[int] = None, random_seed: Optional[int] = None
Expand Down
12 changes: 9 additions & 3 deletions src/tesse_gym/tasks/goseek/goseek_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@ def evaluate(self, agent: Agent) -> Dict[str, Dict[str, float]]:
"""
results = {}
for episode in range(len(self.scenes)):
print(f"Evaluation episode on episode {episode}, scene {self.scenes[episode]}")
print(
f"Evaluation episode on episode {episode}, scene {self.scenes[episode]}"
)
n_found_targets = 0
n_predictions = 0
n_successful_predictions = 0
Expand All @@ -91,7 +93,9 @@ def evaluate(self, agent: Agent) -> Dict[str, Dict[str, float]]:

self.env.n_targets = self.n_targets[episode]
agent.reset()
obs = self.env.reset(scene_id=self.scenes[episode], random_seed=self.random_seeds[episode])
obs = self.env.reset(
scene_id=self.scenes[episode], random_seed=self.random_seeds[episode]
)

for step in tqdm.tqdm(range(self.episode_length[episode])):
action = agent.act(obs)
Expand All @@ -106,7 +110,9 @@ def evaluate(self, agent: Agent) -> Dict[str, Dict[str, float]]:
if done:
break

precision = 1 if n_predictions == 0 else n_successful_predictions / n_predictions
precision = (
1 if n_predictions == 0 else n_successful_predictions / n_predictions
)
recall = n_found_targets / self.env.n_targets
results[str(episode)] = {
"found_targets": n_found_targets,
Expand Down
3 changes: 1 addition & 2 deletions src/tesse_gym/tasks/goseek/goseek_full_perception.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,7 @@ def observe(self) -> DataResponse:
(Camera.SEGMENTATION, Compression.OFF, Channels.THREE),
(Camera.DEPTH, Compression.OFF, Channels.THREE),
]
agent_data = self.env.request(DataRequest(metadata=True, cameras=cameras))
return agent_data
return self._data_request(DataRequest(metadata=True, cameras=cameras))


def decode_observations(
Expand Down