diff --git a/test_pygame_space_shooter/engine/core/ecs/aspect.py b/test_pygame_space_shooter/engine/core/ecs/aspect.py index a3eff38..8664f2f 100644 --- a/test_pygame_space_shooter/engine/core/ecs/aspect.py +++ b/test_pygame_space_shooter/engine/core/ecs/aspect.py @@ -1,4 +1,4 @@ -from typing import Iterable, FrozenSet, Set, Generator +from typing import Iterable, FrozenSet, Set, Generator, TYPE_CHECKING from itertools import chain from .types import ComponentName @@ -23,7 +23,7 @@ def optional(self) -> FrozenSet[ComponentName]: return self._optional_components @property - def either_or(self) -> FrozenSet[ComponentName]: + def either_or(self) -> FrozenSet[FrozenSet[ComponentName]]: return self._xor_components @property @@ -55,15 +55,23 @@ def xor(self, component_names: Set[ComponentName]) -> Generator[ComponentName, N def __hash__(self) -> int: return hash(self._and_components | self._optional_components | self._xor_components) - def __eq__(self, other: "Aspect") -> bool: - try: - this_intersects = self._and_components | self._optional_components | self._xor_components - other_intersects = other._and_components | other._optional_components | other._xor_components + def __eq__(self, other: object) -> bool: + # I'd really like to enable ducktyping here though... why, mypy + if isinstance(other, Aspect): + this_intersects = self.mandatory | self.optional | self.either_or + other_intersects = other.mandatory | other.optional | other.either_or return this_intersects == other_intersects - except AttributeError: - raise TypeError( - " does not support equality operation (==) with classes other than ." - "Got instead.".format(other.__class__.__name__)) + else: + # A dirty hack to trick mypy into thinking this method is OK and complies + # with the type signature-- even though `NotImplemented` should be OK here. + # Though it looks like they're dealing with it. Reference: + # https://github.com/python/mypy/issues/4534 + if TYPE_CHECKING: + return False + else: + # This is returned at runtime + return NotImplemented + def __repr__(self) -> str: return "".format(len(self.mandatory), len(self.optional), len(self.either_or)) diff --git a/test_pygame_space_shooter/engine/core/ecs/entity_manager.py b/test_pygame_space_shooter/engine/core/ecs/entity_manager.py index 4105d69..bee3a3b 100644 --- a/test_pygame_space_shooter/engine/core/ecs/entity_manager.py +++ b/test_pygame_space_shooter/engine/core/ecs/entity_manager.py @@ -1,8 +1,15 @@ -from typing import Set, Dict, Iterable, Mapping, Any, Optional, Type, List, Union, cast +from typing import ( + Set, Dict, List, + Iterable, Iterator, Mapping, Callable, AbstractSet, + Any, Optional, Type, Union, + cast +) from mypy_extensions import TypedDict from .events import EntityManagerEvent, RemoveEntity, EntityAdded -from .types import (Entity, EntityID, EntityManagerEventID, - ComponentName, ComponentObject, NewComponentInfo) +from .types import ( + Entity, EntityID, EntityManagerEventID, + ComponentName, ComponentObject, NewComponentInfo +) from .aspect import Aspect @@ -61,7 +68,7 @@ def registered_components(self) -> Set[ComponentName]: def live_entities(self) -> Set[EntityID]: return set(self.entities.keys()) - def _is_component_names_valid(self, component_names: Set[ComponentName]) -> bool: + def _is_component_names_valid(self, component_names: AbstractSet[ComponentName]) -> bool: try: return component_names <= self.registered_components except TypeError: @@ -83,7 +90,7 @@ def add_component_to_entity(self, entity_id: EntityID, component_name: Component self.components[component_name][entity_id] = component_cls(*args, **kwargs) def create_entity(self, components: Dict[ComponentName, Union[NewComponentInfo, ComponentObject]], instantiated: bool = False) -> EntityID: - if not self._is_component_names_valid(set(components.keys())) + if not self._is_component_names_valid(set(components.keys())): raise InvalidComponentNameError("Failed to add new entity to EntityManager because `components` has keys of non-existent components`") current_entity_id = self._next_entity_id @@ -120,7 +127,7 @@ def create_entity(self, components: Dict[ComponentName, Union[NewComponentInfo, return current_entity_id def _group_entity_pattern_match(self, pattern: Aspect) -> Set[EntityID]: - filter_func = (lambda c, p: self._is_component_names_valid(c | p) and c >= p) + filter_func = (lambda c, p: self._is_component_names_valid(c | p) and c >= p) # type: Callable[[AbstractSet[ComponentName], AbstractSet[ComponentName]], bool] matching_mandatory_entities = {entity_id for entity_id, component_names in self.entities.items() if filter_func(component_names, pattern.mandatory)} # type: Set[EntityID] matching_xor_entities = {entity_id for entity_id, component_names in self.entities.items() if all(pattern._xor(component_names))} # type: Set[EntityID] @@ -132,46 +139,44 @@ def get_matching_entities(self, pattern: Aspect) -> Dict[EntityID, Entity]: if not self._is_component_names_valid(pattern.all): raise InvalidComponentNameError("Failed to get matching entities because `pattern` has keys of non-existent components") - entity_views = {} # type: Dict[EntityID, Entity] - matching_entities = self._group_entity_pattern_match(pattern) - for entity_id in matching_entities: - entity_views[entity_id] = get_matching_entity(entity_id, pattern, _checked=True) - entity_views = {entity for entity in for get_matching_entity} # type: Dict[EntityID, Entity] + + entity_views = {entity_id:self._get_matching_entity_no_optional(entity_id, pattern) for entity_id in matching_entities} # type: Dict[EntityID, Entity] return entity_views - def get_matching_entity(self, entity_id: EntityID, pattern: Aspect, _checked: bool = False) -> Optional[Entity]: + def get_matching_entity(self, entity_id: EntityID, pattern: Aspect) -> Optional[Entity]: # Used to get a *specific* entity's *specific* components - if not _checked: - if not self._is_component_names_valid(pattern.all): - raise InvalidComponentNameError("Failed to get matching entity because `pattern` has keys of non-existent components") - - try: - entity_component_names = self.entities[entity_id] # type: Set[ComponentName] - except KeyError: - raise InvalidEntityIDError("Failed to get matching entity because `entity_id` is {}, which does not exist".format(entity_id)) - else: - if pattern.is_matched(entity_component_names): - # try replace this one below with 1 or 2 set operations (if too slow) - # This does `s1 | s2` because we already know that it has the mandatory components, but any optionals are OK - entity_view = {name:(self.components[name][entity_id]) for name in entity_component_names if name in pattern.mandatory | pattern.optional or name in pattern.xor()} # type: Entity + if not self._is_component_names_valid(pattern.all): + raise InvalidComponentNameError("Failed to get matching entity because `pattern` has keys of non-existent components") - return entity_view - else: - return None + try: + entity_component_names = self.entities[entity_id] # type: Set[ComponentName] + except KeyError: + raise InvalidEntityIDError("Failed to get matching entity because `entity_id` is {}, which does not exist".format(entity_id)) else: - # This branch should only be executed when called by `get_matching_entities`-- which lets it assume that the entity matches the pattern - try: - entity_component_names = self.entities[entity_id] # type: Set[ComponentName] - except KeyError: - raise InvalidEntityIDError("Failed to get matching entity because `entity_id` is {}, which does not exist".format(entity_id)) - else: + if pattern.is_matched(entity_component_names): # try replace this one below with 1 or 2 set operations (if too slow) # This does `s1 | s2` because we already know that it has the mandatory components, but any optionals are OK - entity_view = {name:(self.components[name][entity_id]) for name in entity_component_names if name in pattern.mandatory | pattern.optional or name in pattern.xor()} # type: Entity + entity_view = {name:(self.components[name][entity_id]) for name in entity_component_names if name in pattern.mandatory | pattern.optional or name in pattern.xor(entity_component_names)} # type: Entity return entity_view + else: + return None + + def _get_matching_entity_no_optional(self, entity_id: EntityID, pattern: Aspect) -> Entity: + # Used by `get_matching_entities` to avoid a type check error when constructing entity view + try: + entity_component_names = self.entities[entity_id] + except KeyError: + # Raise `RuntimeError` instead, to indicate that it shouldn't have happened? + raise InvalidEntityIDError("Failed to get matching entity because `entity_id` is {}, which does not exist (You should not have gotten here!)".format(entity_id)) + else: + # try replace this one below with 1 or 2 set operations (if too slow) + # This does `s1 | s2` because we already know that it has the mandatory components, but any optionals are OK + entity_view = {name:(self.components[name][entity_id]) for name in entity_component_names if name in pattern.mandatory | pattern.optional or name in pattern.xor(entity_component_names)} + + return entity_view def get_entity(self, entity_id: EntityID) -> Entity: try: