Skip to content

Commit

Permalink
fix(nyz): fix gym hybrid reward dtype bug (#664)
Browse files Browse the repository at this point in the history
  • Loading branch information
PaParaZz1 committed May 22, 2023
1 parent 4023c59 commit 164cb1a
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 14 deletions.
4 changes: 2 additions & 2 deletions dizoo/gym_hybrid/envs/gym-hybrid/gym_hybrid/environments.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def distance(self) -> float:

@staticmethod
def get_distance(x1: float, y1: float, x2: float, y2: float) -> float:
return np.sqrt(((x1 - x2) ** 2) + ((y1 - y2) ** 2))
return np.sqrt(((x1 - x2) ** 2) + ((y1 - y2) ** 2)).item()

def render(self, mode='human'):
screen_width = 400
Expand Down Expand Up @@ -397,7 +397,7 @@ def distance(self) -> float:

@staticmethod
def get_distance(x1: float, y1: float, x2: float, y2: float) -> float:
return np.sqrt(((x1 - x2) ** 2) + ((y1 - y2) ** 2))
return np.sqrt(((x1 - x2) ** 2) + ((y1 - y2) ** 2)).item()

def close(self):
if self.viewer:
Expand Down
24 changes: 12 additions & 12 deletions dizoo/gym_hybrid/envs/gym_hybrid_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,18 +88,6 @@ def step(self, action: Dict) -> BaseEnvTimestep:
if self._save_replay:
self._frames.append(self._env.render(mode='rgb_array'))
obs, rew, done, info = self._env.step(action)
self._eval_episode_return += rew
if done:
info['eval_episode_return'] = self._eval_episode_return
if self._save_replay:
if self._env_id == 'HardMove-v0':
self._env_id = f'hardmove_n{self._cfg.num_actuators}'
path = os.path.join(
self._replay_path, '{}_episode_{}.gif'.format(self._env_id, self._save_replay_count)
)
self.display_frames_as_gif(self._frames, path)
self._frames = []
self._save_replay_count += 1

obs = to_ndarray(obs)
if isinstance(obs, list): # corner case
Expand All @@ -114,6 +102,18 @@ def step(self, action: Dict) -> BaseEnvTimestep:
if isinstance(rew, list):
rew = rew[0]
assert isinstance(rew, np.ndarray) and rew.shape == (1, )
self._eval_episode_return += rew.item()
if done:
info['eval_episode_return'] = self._eval_episode_return
if self._save_replay:
if self._env_id == 'HardMove-v0':
self._env_id = f'hardmove_n{self._cfg.num_actuators}'
path = os.path.join(
self._replay_path, '{}_episode_{}.gif'.format(self._env_id, self._save_replay_count)
)
self.display_frames_as_gif(self._frames, path)
self._frames = []
self._save_replay_count += 1
info['action_args_mask'] = np.array([[1, 0], [0, 1], [0, 0]])
return BaseEnvTimestep(obs, rew, done, info)

Expand Down

0 comments on commit 164cb1a

Please sign in to comment.