diff --git a/gym/envs/toy_text/kellycoinflip.py b/gym/envs/toy_text/kellycoinflip.py index 1a47c0ae67a..4f305abe9b6 100644 --- a/gym/envs/toy_text/kellycoinflip.py +++ b/gym/envs/toy_text/kellycoinflip.py @@ -79,7 +79,7 @@ def step(self, action): return self._get_obs(), reward, done, {} def _get_obs(self): - return np.array([self.wealth]), self.rounds + return np.array([self.wealth], dtype=np.float32), self.rounds def reset(self): self.rounds = self.max_rounds @@ -236,11 +236,11 @@ def step(self, action): def _get_obs(self): return ( - np.array([float(self.wealth)]), + np.array([float(self.wealth)], dtype=np.float32), self.rounds_elapsed, self.wins, self.losses, - np.array([float(self.max_ever_wealth)]), + np.array([float(self.max_ever_wealth)], dtype=np.float32), ) def reset(self): diff --git a/gym/spaces/box.py b/gym/spaces/box.py index 32b5c937171..7051dd9a549 100644 --- a/gym/spaces/box.py +++ b/gym/spaces/box.py @@ -1,4 +1,5 @@ import numpy as np +import warnings from .space import Space from gym import logger @@ -138,10 +139,15 @@ def sample(self): return sample.astype(self.dtype) def contains(self, x): - if isinstance(x, list): - x = np.array(x) # Promote list to array for contains check + if not isinstance(x, np.ndarray): + warnings.warn("Casting input x to numpy array.") + x = np.asarray(x, dtype=self.dtype) + return ( - x.shape == self.shape and np.all(x >= self.low) and np.all(x <= self.high) + np.can_cast(x.dtype, self.dtype) + and x.shape == self.shape + and np.any(x >= self.low) + and np.any(x <= self.high) ) def to_jsonable(self, sample_n): diff --git a/gym/spaces/tests/test_spaces.py b/gym/spaces/tests/test_spaces.py index 9dd4c6b8965..2694f309102 100644 --- a/gym/spaces/tests/test_spaces.py +++ b/gym/spaces/tests/test_spaces.py @@ -178,3 +178,18 @@ def test_class_inequality(spaces): def test_bad_space_calls(space_fn): with pytest.raises(AssertionError): space_fn() + + +def test_box_dtype_check(): + # Related Issues: + # https://github.com/openai/gym/issues/2357 + # https://github.com/openai/gym/issues/2298 + + space = Box(0, 2, tuple(), dtype=np.float32) + + # casting will match the correct type + assert space.contains(0.5) + + # float64 is not in float32 space + assert not space.contains(np.array(0.5)) + assert not space.contains(np.array(1)) diff --git a/gym/wrappers/test_flatten_observation.py b/gym/wrappers/test_flatten_observation.py index f190081078b..967aa8f721d 100644 --- a/gym/wrappers/test_flatten_observation.py +++ b/gym/wrappers/test_flatten_observation.py @@ -19,12 +19,14 @@ def test_flatten_observation(env_id): space = spaces.Tuple( (spaces.Discrete(32), spaces.Discrete(11), spaces.Discrete(2)) ) - wrapped_space = spaces.Box(-np.inf, np.inf, [32 + 11 + 2], dtype=np.float32) + wrapped_space = spaces.Box(0, 1, [32 + 11 + 2], dtype=np.int64) elif env_id == "KellyCoinflip-v0": space = spaces.Tuple( (spaces.Box(0, 250.0, [1], dtype=np.float32), spaces.Discrete(300 + 1)) ) - wrapped_space = spaces.Box(-np.inf, np.inf, [1 + (300 + 1)], dtype=np.float32) + low = np.zeros((302,), dtype=np.float64) + high = np.array([250.0] + [1.0] * 301, dtype=np.float64) + wrapped_space = spaces.Box(low, high, [1 + (300 + 1)], dtype=np.float64) assert space.contains(obs) assert wrapped_space.contains(wrapped_obs)