Skip to content

Commit

Permalink
box.contains check dtype and promote non-ndarrays (#2374)
Browse files Browse the repository at this point in the history
* box.contains check dtype and promote non-ndarrays

Closes: #2357 and #2298

Instead of only casting list to ndarray, cast any class to ndarray (if possible) and emit a warning when casting. Also, check if the dtype of the input matches the dtype of the space.

* use import warnings

* blackify

* changes from code review

* fix wrapped space

Co-authored-by: Tristan Deleu <[email protected]>

* fix box boundaries

Co-authored-by: Tristan Deleu <[email protected]>

* TEST: add regression test.

* STY: black

Co-authored-by: Tristan Deleu <[email protected]>
  • Loading branch information
FirefoxMetzger and tristandeleu authored Sep 1, 2021
1 parent 0b07221 commit 7573c57
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 8 deletions.
6 changes: 3 additions & 3 deletions gym/envs/toy_text/kellycoinflip.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def step(self, action):
return self._get_obs(), reward, done, {}

def _get_obs(self):
return np.array([self.wealth]), self.rounds
return np.array([self.wealth], dtype=np.float32), self.rounds

def reset(self):
self.rounds = self.max_rounds
Expand Down Expand Up @@ -236,11 +236,11 @@ def step(self, action):

def _get_obs(self):
return (
np.array([float(self.wealth)]),
np.array([float(self.wealth)], dtype=np.float32),
self.rounds_elapsed,
self.wins,
self.losses,
np.array([float(self.max_ever_wealth)]),
np.array([float(self.max_ever_wealth)], dtype=np.float32),
)

def reset(self):
Expand Down
12 changes: 9 additions & 3 deletions gym/spaces/box.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import warnings

from .space import Space
from gym import logger
Expand Down Expand Up @@ -138,10 +139,15 @@ def sample(self):
return sample.astype(self.dtype)

def contains(self, x):
    """Return whether ``x`` is a valid member of this box.

    ``x`` is a member when its dtype can be safely cast to the space's
    dtype, its shape equals ``self.shape``, and *every* element lies in the
    closed interval ``[self.low, self.high]``.

    Args:
        x: Candidate value. Anything that is not already an ``np.ndarray``
            is converted with ``np.asarray(x, dtype=self.dtype)`` and a
            warning is emitted.

    Returns:
        bool: ``True`` if ``x`` belongs to the space, ``False`` otherwise
        (including when ``x`` cannot be converted to an array).
    """
    if not isinstance(x, np.ndarray):
        warnings.warn("Casting input x to numpy array.")
        try:
            x = np.asarray(x, dtype=self.dtype)
        except (TypeError, ValueError):
            # Not convertible to the space's dtype, so it cannot be a member.
            return False

    # The shape test precedes the bound tests so that a mismatched shape
    # short-circuits before any broadcasting comparison is attempted.
    # np.all (not np.any): a single out-of-range element disqualifies x.
    return bool(
        np.can_cast(x.dtype, self.dtype)
        and x.shape == self.shape
        and np.all(x >= self.low)
        and np.all(x <= self.high)
    )

def to_jsonable(self, sample_n):
Expand Down
15 changes: 15 additions & 0 deletions gym/spaces/tests/test_spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,3 +178,18 @@ def test_class_inequality(spaces):
def test_bad_space_calls(space_fn):
with pytest.raises(AssertionError):
space_fn()


def test_box_dtype_check():
    """Regression test: ``Box.contains`` must take the space's dtype into account.

    Related issues:
        https://github.com/openai/gym/issues/2357
        https://github.com/openai/gym/issues/2298
    """
    scalar_box = Box(0, 2, (), dtype=np.float32)

    # A plain Python number is not an ndarray, so it is cast to the space's
    # dtype before the check and therefore matches.
    assert scalar_box.contains(0.5)

    # ndarrays are checked as they are: numpy's default float64 and default
    # integer dtypes cannot be safely cast to float32, so neither is a member.
    for mismatched in (np.array(0.5), np.array(1)):
        assert not scalar_box.contains(mismatched)
6 changes: 4 additions & 2 deletions gym/wrappers/test_flatten_observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@ def test_flatten_observation(env_id):
space = spaces.Tuple(
(spaces.Discrete(32), spaces.Discrete(11), spaces.Discrete(2))
)
wrapped_space = spaces.Box(-np.inf, np.inf, [32 + 11 + 2], dtype=np.float32)
wrapped_space = spaces.Box(0, 1, [32 + 11 + 2], dtype=np.int64)
elif env_id == "KellyCoinflip-v0":
space = spaces.Tuple(
(spaces.Box(0, 250.0, [1], dtype=np.float32), spaces.Discrete(300 + 1))
)
wrapped_space = spaces.Box(-np.inf, np.inf, [1 + (300 + 1)], dtype=np.float32)
low = np.zeros((302,), dtype=np.float64)
high = np.array([250.0] + [1.0] * 301, dtype=np.float64)
wrapped_space = spaces.Box(low, high, [1 + (300 + 1)], dtype=np.float64)

assert space.contains(obs)
assert wrapped_space.contains(wrapped_obs)

0 comments on commit 7573c57

Please sign in to comment.