From e3c95ffe2117f78620b7ce7e8bdb368157500a74 Mon Sep 17 00:00:00 2001
From: hiha3456 <744762298@qq.com>
Date: Fri, 10 Jun 2022 16:16:18 +0800
Subject: [PATCH] make old test runnable

---
 .../distar/envs/tests/test_distar_env_data.py | 13 +++++-----
 .../envs/tests/test_distar_env_time_space.py  | 25 ++++++++++---------
 2 files changed, 20 insertions(+), 18 deletions(-)

diff --git a/dizoo/distar/envs/tests/test_distar_env_data.py b/dizoo/distar/envs/tests/test_distar_env_data.py
index 38f1dc83a6..85caa88973 100644
--- a/dizoo/distar/envs/tests/test_distar_env_data.py
+++ b/dizoo/distar/envs/tests/test_distar_env_data.py
@@ -8,6 +8,9 @@
 import random
 import time
 
+from dizoo.distar.envs import DIStarEnv
+import traceback
+
 
 class TestDIstarEnv:
     def __init__(self):
@@ -16,8 +19,6 @@ def __init__(self):
         self._whole_cfg.env.map_name = 'KingsCove'
 
     def _inference_loop(self, job={}):
-        from dizoo.distar.envs import DIStarEnv
-        import traceback
 
         torch.set_num_threads(1)
 
@@ -26,15 +27,15 @@ def _inference_loop(self, job={}):
         with torch.no_grad():
             for _ in range(5):
                 try:
-                    observations, game_info, map_name = self._env.reset()
+                    observations = self._env.reset()
 
                     for iter in range(1000):  # one episode loop
                         # agent step
                         actions = self._env.random_action(observations)
                         # env step
-                        next_observations, reward, done = self._env.step(actions)
-                        if not done:
-                            observations = next_observations
+                        timestep = self._env.step(actions)
+                        if not timestep.done:
+                            observations = timestep.obs
                         else:
                             break
 
diff --git a/dizoo/distar/envs/tests/test_distar_env_time_space.py b/dizoo/distar/envs/tests/test_distar_env_time_space.py
index 3cc3612e45..6d99173aa5 100644
--- a/dizoo/distar/envs/tests/test_distar_env_time_space.py
+++ b/dizoo/distar/envs/tests/test_distar_env_time_space.py
@@ -9,6 +9,9 @@
 import time
 import sys
 
+from dizoo.distar.envs import DIStarEnv
+import traceback
+
 
 class TestDIstarEnv:
     def __init__(self):
@@ -20,8 +23,6 @@ def __init__(self):
         self._total_space = 0
 
     def _inference_loop(self, job={}):
-        from distar_env import DIStarEnv
-        import traceback
 
         torch.set_num_threads(1)
 
@@ -30,32 +31,32 @@ def _inference_loop(self, job={}):
         with torch.no_grad():
             for _ in range(5):
                 try:
-                    observations, game_info, map_name = self._env.reset()
+                    observations = self._env.reset()
 
                     for iter in range(1000):  # one episode loop
                         # agent step
                         actions = self._env.random_action(observations)
                         # env step
                         before_step_time = time.time()
-                        next_observations, reward, done = self._env.step(actions)
+                        timestep = self._env.step(actions)
                         after_step_time = time.time()
 
                         self._total_time += after_step_time - before_step_time
                         self._total_iters += 1
-                        self._total_space += sys.getsizeof((actions,observations,next_observations,reward,done))
+                        self._total_space += sys.getsizeof((actions,observations,timestep.obs,timestep.reward,timestep.done))
                         print('observations: ', sys.getsizeof(observations), ' Byte')
                         print('actions: ', sys.getsizeof(actions), ' Byte')
-                        print('reward: ', sys.getsizeof(reward), ' Byte')
-                        print('done: ', sys.getsizeof(done), ' Byte')
-                        print('total: ', sys.getsizeof((actions,observations,next_observations,reward,done)),' Byte')
+                        print('reward: ', sys.getsizeof(timestep.reward), ' Byte')
+                        print('done: ', sys.getsizeof(timestep.done), ' Byte')
+                        print('total: ', sys.getsizeof((actions,observations,timestep.obs,timestep.reward,timestep.done)),' Byte')
 
                         print(type(observations)) # dict
-                        print(type(reward)) # list
-                        print(type(done)) # bool
+                        print(type(timestep.reward)) # list
+                        print(type(timestep.done)) # bool
                         print(type(actions)) # dict
 
-                        if not done:
-                            observations = next_observations
+                        if not timestep.done:
+                            observations = timestep.obs
                         else:
                             break
 
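Note on the interface the updated tests assume: DIStarEnv.reset() now returns only the
observations, and DIStarEnv.step() returns a single timestep object exposing .obs, .reward
and .done. The sketch below illustrates that contract with a hypothetical stand-in env and a
namedtuple modelled on DI-engine's BaseEnvTimestep layout; the ToyEnv name and its dummy
observations/rewards are assumptions for illustration, not code from this patch.

from collections import namedtuple
import random

# Assumed timestep container, mirroring the BaseEnvTimestep layout (obs, reward, done, info).
BaseEnvTimestep = namedtuple('BaseEnvTimestep', ['obs', 'reward', 'done', 'info'])


class ToyEnv:
    """Hypothetical stand-in with the reset()/step() shape the tests expect."""

    def reset(self):
        # Observations only, no (observations, game_info, map_name) tuple.
        return {'agent_0': 0}

    def random_action(self, observations):
        return {'agent_0': random.randint(0, 1)}

    def step(self, actions):
        # End the episode with small probability so the loop below terminates.
        done = random.random() < 0.1
        return BaseEnvTimestep(obs={'agent_0': 0}, reward=[0.0], done=done, info={})


env = ToyEnv()
observations = env.reset()
for _ in range(1000):  # one episode loop
    actions = env.random_action(observations)
    timestep = env.step(actions)
    if not timestep.done:
        observations = timestep.obs
    else:
        break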