Skip to content

Commit

Permalink
Fix return hint for EnumerableSpace.get_elements() and simplify imple…
Browse files Browse the repository at this point in the history
…mentations

In code, the result of `EnumerableSpace.get_elements()` is used for:
- `for a in elements`
- `if a in elements`
- `elements[i]`
- `len(elements)`
- `random.sample(elements, 2)`

So this is actually expected to be `Sequence`.
We fix the annotation and simplify the implementation by removing the
wrapping in numpy.array() as it is not expected and even sometimes
re-wrap in a numpy.array() in the code.
  • Loading branch information
nhuet authored and fteicht committed Jan 10, 2025
1 parent dcf2120 commit 8401554
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 27 deletions.
4 changes: 2 additions & 2 deletions examples/nocycle_grid_goal_mdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import getopt
import sys
from collections.abc import Iterable
from collections.abc import Sequence
from enum import IntEnum
from math import sqrt
from typing import NamedTuple, Optional
Expand Down Expand Up @@ -45,7 +45,7 @@ def __init__(self, state=None, num_cols=0, num_rows=0):
self.num_cols = num_cols
self.num_rows = num_rows

def get_elements(self) -> Iterable[int]:
def get_elements(self) -> Sequence[int]:
if self.state is None:
return [a for a in MyActions]
else:
Expand Down
10 changes: 5 additions & 5 deletions skdecide/builders/domain/scheduling/scheduling_domains.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from __future__ import annotations

import random
from collections.abc import Iterable
from collections.abc import Sequence
from enum import Enum
from itertools import product
from typing import Optional
Expand Down Expand Up @@ -1219,7 +1219,7 @@ def __init__(self, domain: SchedulingDomain, state: State):
self.state = state
self.elements = self._get_elements()

def _get_elements(self) -> Iterable[T]:
def _get_elements(self) -> Sequence[T]:
choices = [
SchedulingActionEnum.START,
SchedulingActionEnum.PAUSE,
Expand Down Expand Up @@ -1282,7 +1282,7 @@ def _get_elements(self) -> Iterable[T]:
)
return list_action

def get_elements(self) -> Iterable[T]:
def get_elements(self) -> Sequence[T]:
return self.elements

def sample(self) -> T:
Expand All @@ -1299,7 +1299,7 @@ def __init__(self, domain: SchedulingDomain, state: State):
self.state = state
self.elements = self._get_elements()

def _get_elements(self) -> Iterable[T]:
def _get_elements(self) -> Sequence[T]:
choices = [
SchedulingActionEnum.START,
SchedulingActionEnum.PAUSE,
Expand Down Expand Up @@ -1382,7 +1382,7 @@ def _get_elements(self) -> Iterable[T]:
)
return list_action

def get_elements(self) -> Iterable[T]:
def get_elements(self) -> Sequence[T]:
return self.elements

def sample(self) -> T:
Expand Down
4 changes: 2 additions & 2 deletions skdecide/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def contains(self, x: T) -> bool:
class EnumerableSpace(Space[T]):
"""A space which elements can be enumerated."""

def get_elements(self) -> Iterable[T]:
def get_elements(self) -> Sequence[T]:
"""Get the elements of this space.
# Returns
Expand All @@ -132,7 +132,7 @@ def contains(self, x: T) -> bool:
class EmptySpace(EnumerableSpace[T]):
"""An (enumerable) empty space."""

def get_elements(self) -> Iterable[T]:
def get_elements(self) -> Sequence[T]:
return ()


Expand Down
4 changes: 2 additions & 2 deletions skdecide/hub/domain/graph_domain/GraphDomain.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from __future__ import annotations

import random
from collections.abc import Iterable
from collections.abc import Sequence
from typing import Optional

import networkx as nx
Expand Down Expand Up @@ -44,7 +44,7 @@ def contains(self, x: T) -> bool:
def __init__(self, l: list[object]):
self.l = l

def get_elements(self) -> Iterable[object]:
def get_elements(self) -> Sequence[object]:
return self.l


Expand Down
26 changes: 10 additions & 16 deletions skdecide/hub/space/gym/gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,13 @@ def __init__(self, n, element_class=int):
super().__init__(gym_space=gym_spaces.Discrete(n))
self._element_class = element_class

def get_elements(self) -> Iterable[T]:
def get_elements(self) -> Sequence[T]:
"""Get the elements of this space.
# Returns
The elements of this space.
"""
return np.array(list(range(self._gym_space.n)), dtype=np.int64)
return range(self._gym_space.n)

def to_unwrapped(self, sample_n: Iterable[T]) -> Iterable:
return (
Expand Down Expand Up @@ -118,16 +118,13 @@ def __init__(self, nvec, element_class=np.ndarray):
super().__init__(gym_space=gym_spaces.MultiDiscrete(nvec))
self._element_class = element_class

def get_elements(self) -> Iterable[T]:
def get_elements(self) -> Sequence[T]:
"""Get the elements of this space.
# Returns
The elements of this space.
"""
return np.array(
list(itertools.product(*[list(range(n)) for n in self._gym_space.nvec])),
dtype=np.int64,
)
return tuple(itertools.product(*(range(n) for n in self._gym_space.nvec)))

def to_unwrapped(self, sample_n: Iterable[T]) -> Iterable:
return (
Expand Down Expand Up @@ -155,16 +152,13 @@ def __init__(self, n, element_class=np.ndarray):
super().__init__(gym_space=gym_spaces.MultiBinary(n))
self._element_class = element_class

def get_elements(self) -> Iterable[T]:
def get_elements(self) -> Sequence[T]:
"""Get the elements of this space.
# Returns
The elements of this space.
"""
return np.array(
list(itertools.product(*[(1, 0) for _ in range(self._gym_space.n)])),
dtype=np.int8,
)
return tuple(itertools.product(*((1, 0) for _ in range(self._gym_space.n))))

def to_unwrapped(self, sample_n: Iterable[T]) -> Iterable:
return (
Expand Down Expand Up @@ -311,14 +305,14 @@ def __init__(self, enum_class: EnumMeta) -> None:
enum_class: The enumeration class for creating the Gym Discrete space (gym.spaces.Discrete) to wrap.
"""
self._enum_class = enum_class
self._list_enum = list(enum_class)
self._list_enum = tuple(enum_class)
gym_space = gym_spaces.Discrete(len(enum_class))
super().__init__(gym_space)

def contains(self, x: T) -> bool:
return isinstance(x, self._enum_class)

def get_elements(self) -> Iterable[T]:
def get_elements(self) -> Sequence[T]:
return self._list_enum

def sample(self) -> T:
Expand Down Expand Up @@ -367,7 +361,7 @@ def __init__(self, elements: Iterable[T]) -> None:
def contains(self, x: T) -> bool:
return x in self._elements

def get_elements(self) -> Iterable[T]:
def get_elements(self) -> Sequence[T]:
return self._elements

def sample(self) -> T:
Expand Down Expand Up @@ -417,7 +411,7 @@ def __init__(self, elements: Iterable[T]) -> None:
def contains(self, x: T) -> bool:
return x in self._elements

def get_elements(self) -> Iterable[T]:
def get_elements(self) -> Sequence[T]:
return self._elements

def sample(self) -> T:
Expand Down

0 comments on commit 8401554

Please sign in to comment.