From 954d496460ec2d0afa5bd81c10194274310f6584 Mon Sep 17 00:00:00 2001 From: zhanpenghe Date: Fri, 8 Jun 2018 12:01:27 -0700 Subject: [PATCH] fix arg name --- rllab/envs/normalized_gym_env.py | 6 +++--- tests/test_normalized_gym.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/rllab/envs/normalized_gym_env.py b/rllab/envs/normalized_gym_env.py index 1fd203554..e42ea8651 100644 --- a/rllab/envs/normalized_gym_env.py +++ b/rllab/envs/normalized_gym_env.py @@ -52,7 +52,7 @@ def __init__( scale_reward=1., normalize_obs=False, normalize_reward=False, - flatten=True, + flatten_obs=True, obs_alpha=0.001, reward_alpha=0.001, ): @@ -61,7 +61,7 @@ def __init__( self._scale_reward = scale_reward self._normalize_obs = normalize_obs self._normalize_reward = normalize_reward - self._flatten = flatten + self._flatten_obs = flatten_obs self._obs_alpha = obs_alpha flat_obs_dim = gym_space_flatten_dim(env.observation_space) @@ -92,7 +92,7 @@ def _apply_normalize_obs(self, obs): self._update_obs_estimate(obs) normalized_obs = (gym_space_flatten(self.env.observation_space, obs) - self._obs_mean) / (np.sqrt(self._obs_var) + 1e-8) - if not self._flatten: + if not self._flatten_obs: normalized_obs = gym_space_unflatten(self.env.observation_space, normalized_obs) return normalized_obs diff --git a/tests/test_normalized_gym.py b/tests/test_normalized_gym.py index 54b8b416c..397807afe 100644 --- a/tests/test_normalized_gym.py +++ b/tests/test_normalized_gym.py @@ -9,7 +9,7 @@ def test_flatten(): gym.make('Pendulum-v0'), normalize_reward=True, normalize_obs=True, - flatten=True) + flatten_obs=True) for i in range(100): env.reset() for e in range(100): @@ -28,7 +28,7 @@ def test_unflatten(): gym.make('Blackjack-v0'), normalize_reward=True, normalize_obs=True, - flatten=False) + flatten_obs=False) for i in range(100): env.reset() for e in range(100):