Skip to content

Commit

Permalink
Various fixes w/ help from mypy's type checking
Browse files Browse the repository at this point in the history
I made `_get_matching_entity_no_optional` as a workaround to enable
`get_matching_entities` to construct its entity views without having to check
for `None` every time-- I'm not even sure *how* it would do that, considering
the structure of the dict comprehension. So that means `get_matching_entity`
no longer takes `_checked` as a keyword argument because the code for checked
entities has been delegated to `_get_matching_entity_no_optional`.

I also made a bunch of other changes:

- Various type annotation changes to accomodate the complaints of the type checker
(mypy in this case)
- Changed `Aspect.__eq__`s type annotation and implementation-- Turns out that I
should return `NotImplemented` if the object doesn't support operations (==, >, etc.)
with its input. BUT mypy complains when it sees that `NotImplemented` is able to be
returned, rather than a value of type `bool`. So I made a hack to trick mypy
into thinking that the operator (method) doesn't return `NotImplemented`, but
actually does so at runtime-- though it seems that the mypy team recognises this
and are working on a fix/whatever.

Issue Reference:
- python/mypy#4534
  • Loading branch information
Hoboneer committed Feb 16, 2018
1 parent f8ecf8d commit be9ae61
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 45 deletions.
28 changes: 18 additions & 10 deletions test_pygame_space_shooter/engine/core/ecs/aspect.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Iterable, FrozenSet, Set, Generator
from typing import Iterable, FrozenSet, Set, Generator, TYPE_CHECKING
from itertools import chain

from .types import ComponentName
Expand All @@ -23,7 +23,7 @@ def optional(self) -> FrozenSet[ComponentName]:
return self._optional_components

@property
def either_or(self) -> FrozenSet[ComponentName]:
def either_or(self) -> FrozenSet[FrozenSet[ComponentName]]:
return self._xor_components

@property
Expand Down Expand Up @@ -55,15 +55,23 @@ def xor(self, component_names: Set[ComponentName]) -> Generator[ComponentName, N
def __hash__(self) -> int:
return hash(self._and_components | self._optional_components | self._xor_components)

def __eq__(self, other: "Aspect") -> bool:
try:
this_intersects = self._and_components | self._optional_components | self._xor_components
other_intersects = other._and_components | other._optional_components | other._xor_components
def __eq__(self, other: object) -> bool:
# I'd really like to enable ducktyping here though... why, mypy
if isinstance(other, Aspect):
this_intersects = self.mandatory | self.optional | self.either_or
other_intersects = other.mandatory | other.optional | other.either_or
return this_intersects == other_intersects
except AttributeError:
raise TypeError(
"<class Aspect> does not support equality operation (==) with classes other than <class Aspect>."
"Got <class {}> instead.".format(other.__class__.__name__))
else:
# A dirty hack to trick mypy into thinking this method is OK and complies
# with the type signature-- even though `NotImplemented` should be OK here.
# Though it looks like they're dealing with it. Reference:
# https://github.com/python/mypy/issues/4534
if TYPE_CHECKING:
return False
else:
# This is returned at runtime
return NotImplemented


def __repr__(self) -> str:
return "<class Aspect ({}|{}|{})>".format(len(self.mandatory), len(self.optional), len(self.either_or))
Expand Down
75 changes: 40 additions & 35 deletions test_pygame_space_shooter/engine/core/ecs/entity_manager.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
from typing import Set, Dict, Iterable, Mapping, Any, Optional, Type, List, Union, cast
from typing import (
Set, Dict, List,
Iterable, Iterator, Mapping, Callable, AbstractSet,
Any, Optional, Type, Union,
cast
)
from mypy_extensions import TypedDict
from .events import EntityManagerEvent, RemoveEntity, EntityAdded
from .types import (Entity, EntityID, EntityManagerEventID,
ComponentName, ComponentObject, NewComponentInfo)
from .types import (
Entity, EntityID, EntityManagerEventID,
ComponentName, ComponentObject, NewComponentInfo
)
from .aspect import Aspect


Expand Down Expand Up @@ -61,7 +68,7 @@ def registered_components(self) -> Set[ComponentName]:
def live_entities(self) -> Set[EntityID]:
return set(self.entities.keys())

def _is_component_names_valid(self, component_names: Set[ComponentName]) -> bool:
def _is_component_names_valid(self, component_names: AbstractSet[ComponentName]) -> bool:
try:
return component_names <= self.registered_components
except TypeError:
Expand All @@ -83,7 +90,7 @@ def add_component_to_entity(self, entity_id: EntityID, component_name: Component
self.components[component_name][entity_id] = component_cls(*args, **kwargs)

def create_entity(self, components: Dict[ComponentName, Union[NewComponentInfo, ComponentObject]], instantiated: bool = False) -> EntityID:
if not self._is_component_names_valid(set(components.keys()))
if not self._is_component_names_valid(set(components.keys())):
raise InvalidComponentNameError("Failed to add new entity to EntityManager because `components` has keys of non-existent components`")

current_entity_id = self._next_entity_id
Expand Down Expand Up @@ -120,7 +127,7 @@ def create_entity(self, components: Dict[ComponentName, Union[NewComponentInfo,
return current_entity_id

def _group_entity_pattern_match(self, pattern: Aspect) -> Set[EntityID]:
filter_func = (lambda c, p: self._is_component_names_valid(c | p) and c >= p)
filter_func = (lambda c, p: self._is_component_names_valid(c | p) and c >= p) # type: Callable[[AbstractSet[ComponentName], AbstractSet[ComponentName]], bool]

matching_mandatory_entities = {entity_id for entity_id, component_names in self.entities.items() if filter_func(component_names, pattern.mandatory)} # type: Set[EntityID]
matching_xor_entities = {entity_id for entity_id, component_names in self.entities.items() if all(pattern._xor(component_names))} # type: Set[EntityID]
Expand All @@ -132,46 +139,44 @@ def get_matching_entities(self, pattern: Aspect) -> Dict[EntityID, Entity]:
if not self._is_component_names_valid(pattern.all):
raise InvalidComponentNameError("Failed to get matching entities because `pattern` has keys of non-existent components")

entity_views = {} # type: Dict[EntityID, Entity]

matching_entities = self._group_entity_pattern_match(pattern)
for entity_id in matching_entities:
entity_views[entity_id] = get_matching_entity(entity_id, pattern, _checked=True)
entity_views = {entity for entity in for get_matching_entity} # type: Dict[EntityID, Entity]

entity_views = {entity_id:self._get_matching_entity_no_optional(entity_id, pattern) for entity_id in matching_entities} # type: Dict[EntityID, Entity]

return entity_views

def get_matching_entity(self, entity_id: EntityID, pattern: Aspect, _checked: bool = False) -> Optional[Entity]:
def get_matching_entity(self, entity_id: EntityID, pattern: Aspect) -> Optional[Entity]:
# Used to get a *specific* entity's *specific* components
if not _checked:
if not self._is_component_names_valid(pattern.all):
raise InvalidComponentNameError("Failed to get matching entity because `pattern` has keys of non-existent components")

try:
entity_component_names = self.entities[entity_id] # type: Set[ComponentName]
except KeyError:
raise InvalidEntityIDError("Failed to get matching entity because `entity_id` is {}, which does not exist".format(entity_id))
else:
if pattern.is_matched(entity_component_names):
# try replace this one below with 1 or 2 set operations (if too slow)
# This does `s1 | s2` because we already know that it has the mandatory components, but any optionals are OK
entity_view = {name:(self.components[name][entity_id]) for name in entity_component_names if name in pattern.mandatory | pattern.optional or name in pattern.xor()} # type: Entity
if not self._is_component_names_valid(pattern.all):
raise InvalidComponentNameError("Failed to get matching entity because `pattern` has keys of non-existent components")

return entity_view
else:
return None
try:
entity_component_names = self.entities[entity_id] # type: Set[ComponentName]
except KeyError:
raise InvalidEntityIDError("Failed to get matching entity because `entity_id` is {}, which does not exist".format(entity_id))
else:
# This branch should only be executed when called by `get_matching_entities`-- which lets it assume that the entity matches the pattern
try:
entity_component_names = self.entities[entity_id] # type: Set[ComponentName]
except KeyError:
raise InvalidEntityIDError("Failed to get matching entity because `entity_id` is {}, which does not exist".format(entity_id))
else:
if pattern.is_matched(entity_component_names):
# try replace this one below with 1 or 2 set operations (if too slow)
# This does `s1 | s2` because we already know that it has the mandatory components, but any optionals are OK
entity_view = {name:(self.components[name][entity_id]) for name in entity_component_names if name in pattern.mandatory | pattern.optional or name in pattern.xor()} # type: Entity
entity_view = {name:(self.components[name][entity_id]) for name in entity_component_names if name in pattern.mandatory | pattern.optional or name in pattern.xor(entity_component_names)} # type: Entity

return entity_view
else:
return None

def _get_matching_entity_no_optional(self, entity_id: EntityID, pattern: Aspect) -> Entity:
# Used by `get_matching_entities` to avoid a type check error when constructing entity view
try:
entity_component_names = self.entities[entity_id]
except KeyError:
# Raise `RuntimeError` instead, to indicate that it shouldn't have happened?
raise InvalidEntityIDError("Failed to get matching entity because `entity_id` is {}, which does not exist (You should not have gotten here!)".format(entity_id))
else:
# try replace this one below with 1 or 2 set operations (if too slow)
# This does `s1 | s2` because we already know that it has the mandatory components, but any optionals are OK
entity_view = {name:(self.components[name][entity_id]) for name in entity_component_names if name in pattern.mandatory | pattern.optional or name in pattern.xor(entity_component_names)}

return entity_view

def get_entity(self, entity_id: EntityID) -> Entity:
try:
Expand Down

0 comments on commit be9ae61

Please sign in to comment.