From 8401554e52f4a7e36b88d9e0f34724f20e1a38a1 Mon Sep 17 00:00:00 2001 From: Nolwen Date: Mon, 16 Dec 2024 14:40:09 +0100 Subject: [PATCH] Fix return hint for EnumerableSpace.get_elements() and simplify implementations In code, the result of `EnumerableSpace.get_elements()` is used for: - `for a in elements` - `if a in elements` - `elements[i]` - `len(elements)` - `random.sample(elements, 2)` So this is actually expected to be `Sequence`. We fix the annotation and simplify the implementation by removing the wrapping in numpy.array() as it is not expected and even sometimes re-wrap in a numpy.array() in the code. --- examples/nocycle_grid_goal_mdp.py | 4 +-- .../domain/scheduling/scheduling_domains.py | 10 +++---- skdecide/core.py | 4 +-- .../hub/domain/graph_domain/GraphDomain.py | 4 +-- skdecide/hub/space/gym/gym.py | 26 +++++++------------ 5 files changed, 21 insertions(+), 27 deletions(-) diff --git a/examples/nocycle_grid_goal_mdp.py b/examples/nocycle_grid_goal_mdp.py index 2dd82ba6dc..d3c2d2867c 100644 --- a/examples/nocycle_grid_goal_mdp.py +++ b/examples/nocycle_grid_goal_mdp.py @@ -4,7 +4,7 @@ import getopt import sys -from collections.abc import Iterable +from collections.abc import Sequence from enum import IntEnum from math import sqrt from typing import NamedTuple, Optional @@ -45,7 +45,7 @@ def __init__(self, state=None, num_cols=0, num_rows=0): self.num_cols = num_cols self.num_rows = num_rows - def get_elements(self) -> Iterable[int]: + def get_elements(self) -> Sequence[int]: if self.state is None: return [a for a in MyActions] else: diff --git a/skdecide/builders/domain/scheduling/scheduling_domains.py b/skdecide/builders/domain/scheduling/scheduling_domains.py index 3819b5c17e..d66c9c7439 100644 --- a/skdecide/builders/domain/scheduling/scheduling_domains.py +++ b/skdecide/builders/domain/scheduling/scheduling_domains.py @@ -5,7 +5,7 @@ from __future__ import annotations import random -from collections.abc import Iterable +from collections.abc import Sequence from enum import Enum from itertools import product from typing import Optional @@ -1219,7 +1219,7 @@ def __init__(self, domain: SchedulingDomain, state: State): self.state = state self.elements = self._get_elements() - def _get_elements(self) -> Iterable[T]: + def _get_elements(self) -> Sequence[T]: choices = [ SchedulingActionEnum.START, SchedulingActionEnum.PAUSE, @@ -1282,7 +1282,7 @@ def _get_elements(self) -> Iterable[T]: ) return list_action - def get_elements(self) -> Iterable[T]: + def get_elements(self) -> Sequence[T]: return self.elements def sample(self) -> T: @@ -1299,7 +1299,7 @@ def __init__(self, domain: SchedulingDomain, state: State): self.state = state self.elements = self._get_elements() - def _get_elements(self) -> Iterable[T]: + def _get_elements(self) -> Sequence[T]: choices = [ SchedulingActionEnum.START, SchedulingActionEnum.PAUSE, @@ -1382,7 +1382,7 @@ def _get_elements(self) -> Iterable[T]: ) return list_action - def get_elements(self) -> Iterable[T]: + def get_elements(self) -> Sequence[T]: return self.elements def sample(self) -> T: diff --git a/skdecide/core.py b/skdecide/core.py index bddfa849db..c1cd0f10ea 100644 --- a/skdecide/core.py +++ b/skdecide/core.py @@ -117,7 +117,7 @@ def contains(self, x: T) -> bool: class EnumerableSpace(Space[T]): """A space which elements can be enumerated.""" - def get_elements(self) -> Iterable[T]: + def get_elements(self) -> Sequence[T]: """Get the elements of this space. # Returns @@ -132,7 +132,7 @@ def contains(self, x: T) -> bool: class EmptySpace(EnumerableSpace[T]): """An (enumerable) empty space.""" - def get_elements(self) -> Iterable[T]: + def get_elements(self) -> Sequence[T]: return () diff --git a/skdecide/hub/domain/graph_domain/GraphDomain.py b/skdecide/hub/domain/graph_domain/GraphDomain.py index 2087bbecec..5489c55a04 100644 --- a/skdecide/hub/domain/graph_domain/GraphDomain.py +++ b/skdecide/hub/domain/graph_domain/GraphDomain.py @@ -5,7 +5,7 @@ from __future__ import annotations import random -from collections.abc import Iterable +from collections.abc import Sequence from typing import Optional import networkx as nx @@ -44,7 +44,7 @@ def contains(self, x: T) -> bool: def __init__(self, l: list[object]): self.l = l - def get_elements(self) -> Iterable[object]: + def get_elements(self) -> Sequence[object]: return self.l diff --git a/skdecide/hub/space/gym/gym.py b/skdecide/hub/space/gym/gym.py index 2df488c060..5098ea1d76 100644 --- a/skdecide/hub/space/gym/gym.py +++ b/skdecide/hub/space/gym/gym.py @@ -84,13 +84,13 @@ def __init__(self, n, element_class=int): super().__init__(gym_space=gym_spaces.Discrete(n)) self._element_class = element_class - def get_elements(self) -> Iterable[T]: + def get_elements(self) -> Sequence[T]: """Get the elements of this space. # Returns The elements of this space. """ - return np.array(list(range(self._gym_space.n)), dtype=np.int64) + return range(self._gym_space.n) def to_unwrapped(self, sample_n: Iterable[T]) -> Iterable: return ( @@ -118,16 +118,13 @@ def __init__(self, nvec, element_class=np.ndarray): super().__init__(gym_space=gym_spaces.MultiDiscrete(nvec)) self._element_class = element_class - def get_elements(self) -> Iterable[T]: + def get_elements(self) -> Sequence[T]: """Get the elements of this space. # Returns The elements of this space. """ - return np.array( - list(itertools.product(*[list(range(n)) for n in self._gym_space.nvec])), - dtype=np.int64, - ) + return tuple(itertools.product(*(range(n) for n in self._gym_space.nvec))) def to_unwrapped(self, sample_n: Iterable[T]) -> Iterable: return ( @@ -155,16 +152,13 @@ def __init__(self, n, element_class=np.ndarray): super().__init__(gym_space=gym_spaces.MultiBinary(n)) self._element_class = element_class - def get_elements(self) -> Iterable[T]: + def get_elements(self) -> Sequence[T]: """Get the elements of this space. # Returns The elements of this space. """ - return np.array( - list(itertools.product(*[(1, 0) for _ in range(self._gym_space.n)])), - dtype=np.int8, - ) + return tuple(itertools.product(*((1, 0) for _ in range(self._gym_space.n)))) def to_unwrapped(self, sample_n: Iterable[T]) -> Iterable: return ( @@ -311,14 +305,14 @@ def __init__(self, enum_class: EnumMeta) -> None: enum_class: The enumeration class for creating the Gym Discrete space (gym.spaces.Discrete) to wrap. """ self._enum_class = enum_class - self._list_enum = list(enum_class) + self._list_enum = tuple(enum_class) gym_space = gym_spaces.Discrete(len(enum_class)) super().__init__(gym_space) def contains(self, x: T) -> bool: return isinstance(x, self._enum_class) - def get_elements(self) -> Iterable[T]: + def get_elements(self) -> Sequence[T]: return self._list_enum def sample(self) -> T: @@ -367,7 +361,7 @@ def __init__(self, elements: Iterable[T]) -> None: def contains(self, x: T) -> bool: return x in self._elements - def get_elements(self) -> Iterable[T]: + def get_elements(self) -> Sequence[T]: return self._elements def sample(self) -> T: @@ -417,7 +411,7 @@ def __init__(self, elements: Iterable[T]) -> None: def contains(self, x: T) -> bool: return x in self._elements - def get_elements(self) -> Iterable[T]: + def get_elements(self) -> Sequence[T]: return self._elements def sample(self) -> T: