diff --git a/pl_bolts/models/rl/common/gym_wrappers.py b/pl_bolts/models/rl/common/gym_wrappers.py index 37f8ee5b50..cbf95bcbdc 100644 --- a/pl_bolts/models/rl/common/gym_wrappers.py +++ b/pl_bolts/models/rl/common/gym_wrappers.py @@ -4,21 +4,28 @@ """ import collections -import gym -import gym.spaces import numpy as np import torch -from pl_bolts.utils import _OPENCV_AVAILABLE +from pl_bolts.utils import _GYM_AVAILABLE, _OPENCV_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg +if _GYM_AVAILABLE: + import gym.spaces + from gym import ObservationWrapper, Wrapper + from gym import make as gym_make +else: # pragma: no-cover + warn_missing_pkg('gym') + Wrapper = object + ObservationWrapper = object + if _OPENCV_AVAILABLE: import cv2 else: warn_missing_pkg('cv2', pypi_name='opencv-python') # pragma: no-cover -class ToTensor(gym.Wrapper): +class ToTensor(Wrapper): """For environments where the user need to press FIRE for the game to start.""" def __init__(self, env=None): @@ -34,7 +41,7 @@ def reset(self): return torch.tensor(self.env.reset()) -class FireResetEnv(gym.Wrapper): +class FireResetEnv(Wrapper): """For environments where the user need to press FIRE for the game to start.""" def __init__(self, env=None): @@ -58,7 +65,7 @@ def reset(self): return obs -class MaxAndSkipEnv(gym.Wrapper): +class MaxAndSkipEnv(Wrapper): """Return only every `skip`-th frame""" def __init__(self, env=None, skip=4): @@ -88,7 +95,7 @@ def reset(self): return obs -class ProcessFrame84(gym.ObservationWrapper): +class ProcessFrame84(ObservationWrapper): """preprocessing images from env""" def __init__(self, env=None): @@ -121,7 +128,7 @@ def process(frame): return x_t.astype(np.uint8) -class ImageToPyTorch(gym.ObservationWrapper): +class ImageToPyTorch(ObservationWrapper): """converts image to pytorch format""" def __init__(self, env): @@ -142,7 +149,7 @@ def observation(observation): return np.moveaxis(observation, 2, 0) -class ScaledFloatFrame(gym.ObservationWrapper): +class ScaledFloatFrame(ObservationWrapper): """scales the pixels""" @staticmethod @@ -150,7 +157,7 @@ def observation(obs): return np.array(obs).astype(np.float32) / 255.0 -class BufferWrapper(gym.ObservationWrapper): +class BufferWrapper(ObservationWrapper): """"Wrapper for image stacking""" def __init__(self, env, n_steps, dtype=np.float32): @@ -176,7 +183,7 @@ def observation(self, observation): return self.buffer -class DataAugmentation(gym.ObservationWrapper): +class DataAugmentation(ObservationWrapper): """ Carries out basic data augmentation on the env observations - ToTensor @@ -197,7 +204,7 @@ def observation(self, obs): def make_environment(env_name): """Convert environment with wrappers""" - env = gym.make(env_name) + env = gym_make(env_name) env = MaxAndSkipEnv(env) env = FireResetEnv(env) env = ProcessFrame84(env) diff --git a/pl_bolts/models/rl/dqn_model.py b/pl_bolts/models/rl/dqn_model.py index 0635ead24c..6f5e958a36 100644 --- a/pl_bolts/models/rl/dqn_model.py +++ b/pl_bolts/models/rl/dqn_model.py @@ -18,15 +18,17 @@ from pl_bolts.datamodules.experience_source import Experience, ExperienceSourceDataset from pl_bolts.losses.rl import dqn_loss from pl_bolts.models.rl.common.agents import ValueAgent +from pl_bolts.models.rl.common.gym_wrappers import make_environment from pl_bolts.models.rl.common.memory import MultiStepBuffer from pl_bolts.models.rl.common.networks import CNN from pl_bolts.utils import _GYM_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg if _GYM_AVAILABLE: - from pl_bolts.models.rl.common.gym_wrappers import gym, make_environment + from gym import Env else: warn_missing_pkg('gym') # pragma: no-cover + Env = object class DQN(pl.LightningModule): @@ -336,7 +338,7 @@ def test_dataloader(self) -> DataLoader: return self._dataloader() @staticmethod - def make_environment(env_name: str, seed: Optional[int] = None) -> gym.Env: + def make_environment(env_name: str, seed: Optional[int] = None) -> Env: """ Initialise gym environment