diff --git a/hades_extensions/ecs/__init__.pyi b/hades_extensions/ecs/__init__.pyi index e7fe865b..597a455b 100644 --- a/hades_extensions/ecs/__init__.pyi +++ b/hades_extensions/ecs/__init__.pyi @@ -1,7 +1,7 @@ # Builtin from collections.abc import Callable, Iterator from enum import Enum -from typing import Final, SupportsFloat, TypeAlias, TypeVar +from typing import Final, Literal, SupportsFloat, TypeAlias, TypeVar, overload # Define some type vars for the registry _C = TypeVar("_C") @@ -94,9 +94,28 @@ class Registry: ) -> list[int]: ... def get_system(self: Registry, system: type[_S]) -> _S: ... def update(self: Registry, delta_time: float) -> None: ... + @overload def add_callback( self: Registry, - event_type: EventType, + event_type: Literal[EventType.GameObjectCreation], + callback: Callable[[int], None], + ) -> None: ... + @overload + def add_callback( + self: Registry, + event_type: Literal[EventType.GameObjectDeath], + callback: Callable[[int], None], + ) -> None: ... + @overload + def add_callback( + self: Registry, + event_type: Literal[EventType.InventoryUpdate], + callback: Callable[[int], None], + ) -> None: ... + @overload + def add_callback( + self: Registry, + event_type: Literal[EventType.SpriteRemoval], callback: Callable[[int], None], ) -> None: ... diff --git a/src/hades_extensions/include/ecs/registry.hpp b/src/hades_extensions/include/ecs/registry.hpp index 0253f8a2..5494fd4f 100644 --- a/src/hades_extensions/include/ecs/registry.hpp +++ b/src/hades_extensions/include/ecs/registry.hpp @@ -2,6 +2,7 @@ #pragma once // Std headers +#include #ifdef __GNUC__ #include #endif @@ -222,19 +223,28 @@ class Registry { /// Add a callback to the registry to listen for events. /// - /// @param event_type - The type of event to listen for. + /// @tparam E - The type of event to listen for. + /// @tparam Func - The callback functions' signature /// @param callback - The callback to add. - void add_callback(const EventType event_type, const std::function &callback) { - callbacks_[event_type] = callback; + template + void add_callback(Func &&callback) { + listeners_[E].emplace_back([callback = std::forward(callback)](std::any args) { + std::apply(callback, std::any_cast::EventArgs>(args)); + }); } /// Notify all callbacks of an event. /// - /// @param event_type - The type of event to notify callbacks of. - /// @param game_object_id - The game object ID to pass to the callbacks. - void notify_callbacks(const EventType event_type, const GameObjectID game_object_id) { - if (callbacks_.contains(event_type)) { - callbacks_[event_type](game_object_id); + /// @tparam E - The type of event to notify callbacks of. + /// @tparam Args - The types of the arguments to pass to the callbacks. + /// @param args - The arguments to pass to the callbacks. + template + void notify(Args &&...args) { + using ExpectedArgs = typename EventTraits::EventArgs; + static_assert(std::is_same_v...>, ExpectedArgs>); + const ExpectedArgs tuple_args{std::forward(args)...}; + for (const auto &callback : listeners_[E]) { + callback(tuple_args); } } @@ -263,6 +273,6 @@ class Registry { /// The Chipmunk2D space. ChipmunkHandle space_{cpSpaceNew()}; - /// The callbacks registered with the registry to listen for events. - std::unordered_map> callbacks_; + /// The listeners registered for each event type. + std::unordered_map>> listeners_; }; diff --git a/src/hades_extensions/include/ecs/types.hpp b/src/hades_extensions/include/ecs/types.hpp index 731ed99a..f62c20c6 100644 --- a/src/hades_extensions/include/ecs/types.hpp +++ b/src/hades_extensions/include/ecs/types.hpp @@ -32,6 +32,34 @@ enum class EventType : std::uint8_t { SpriteRemoval, }; +/// A helper struct to provide the argument types for each event type. +template +struct EventTraits; + +/// Provides the argument types for the GameObjectCreation event. +template <> +struct EventTraits { + using EventArgs = std::tuple; +}; + +/// Provides the argument types for the GameObjectDeath event. +template <> +struct EventTraits { + using EventArgs = std::tuple; +}; + +/// Provides the argument types for the InventoryUpdate event. +template <> +struct EventTraits { + using EventArgs = std::tuple; +}; + +/// Provides the argument types for the SpriteRemoval event. +template <> +struct EventTraits { + using EventArgs = std::tuple; +}; + /// The base class for all components. struct ComponentBase { /// The copy assignment operator. diff --git a/src/hades_extensions/src/binding.cpp b/src/hades_extensions/src/binding.cpp index 60ff2008..ebee43a1 100644 --- a/src/hades_extensions/src/binding.cpp +++ b/src/hades_extensions/src/binding.cpp @@ -346,11 +346,31 @@ PYBIND11_MODULE(hades_extensions, module) { // NOLINT "Update all systems in the registry.\n\n" "Args:\n" " delta_time: The time interval since the last time the function was called.") - .def("add_callback", &Registry::add_callback, pybind11::arg("event_type"), pybind11::arg("callback"), - "Add a callback to the registry to listen for events.\n\n" - "Args:\n" - " event_type: The type of event to listen for.\n" - " callback: The callback to add."); + .def( + "add_callback", + [](Registry ®istry, const EventType event_type, const pybind11::function &callback) { + switch (event_type) { + case EventType::GameObjectCreation: + registry.add_callback(callback); + break; + case EventType::GameObjectDeath: + registry.add_callback(callback); + break; + case EventType::InventoryUpdate: + registry.add_callback(callback); + break; + case EventType::SpriteRemoval: + registry.add_callback(callback); + break; + default: + throw std::runtime_error("Unsupported event type."); + } + }, + pybind11::arg("event_type"), pybind11::arg("callback"), + "Add a callback to the registry to listen for events.\n\n" + "Args:\n" + " event_type: The type of event to listen for.\n" + " callback: The callback to add."); // Add the stat components pybind11::class_>( diff --git a/src/hades_extensions/src/ecs/registry.cpp b/src/hades_extensions/src/ecs/registry.cpp index f57255d8..411c8e38 100644 --- a/src/hades_extensions/src/ecs/registry.cpp +++ b/src/hades_extensions/src/ecs/registry.cpp @@ -62,7 +62,7 @@ auto Registry::create_game_object(const GameObjectType game_object_type, const c } // Increment the game object ID and return the current game object ID - notify_callbacks(EventType::GameObjectCreation, next_game_object_id_); + notify(next_game_object_id_); next_game_object_id_++; return next_game_object_id_ - 1; } @@ -80,7 +80,7 @@ void Registry::delete_game_object(const GameObjectID game_object_id) { } // Notify the callbacks then delete the game object - notify_callbacks(EventType::GameObjectDeath, game_object_id); + notify(game_object_id); std::erase(game_object_ids_[get_game_object_type(game_object_id)], game_object_id); game_objects_.erase(game_object_id); game_object_types_.erase(game_object_id); diff --git a/src/hades_extensions/src/ecs/systems/inventory.cpp b/src/hades_extensions/src/ecs/systems/inventory.cpp index 924491d2..c884cf05 100644 --- a/src/hades_extensions/src/ecs/systems/inventory.cpp +++ b/src/hades_extensions/src/ecs/systems/inventory.cpp @@ -35,8 +35,8 @@ auto InventorySystem::add_item_to_inventory(const GameObjectID game_object_id, c // Add the item to the inventory and notify the callbacks inventory->items.push_back(item); - get_registry()->notify_callbacks(EventType::InventoryUpdate, game_object_id); - get_registry()->notify_callbacks(EventType::SpriteRemoval, item); + get_registry()->notify(game_object_id); + get_registry()->notify(item); // If the item has a kinematic component, set the collected flag to true to prevent collision detection if (get_registry()->has_component(item, typeid(KinematicComponent))) { @@ -63,8 +63,8 @@ auto InventorySystem::remove_item_from_inventory(const GameObjectID game_object_ // Remove the item from the inventory, delete the game object, and notify the callbacks inventory->items.erase(inventory->items.begin() + index); get_registry()->delete_game_object(item_id); - get_registry()->notify_callbacks(EventType::InventoryUpdate, game_object_id); - get_registry()->notify_callbacks(EventType::SpriteRemoval, item_id); + get_registry()->notify(game_object_id); + get_registry()->notify(item_id); return true; } diff --git a/src/hades_extensions/tests/ecs/systems/test_attacks.cpp b/src/hades_extensions/tests/ecs/systems/test_attacks.cpp index a1c92fd0..b125a5e6 100644 --- a/src/hades_extensions/tests/ecs/systems/test_attacks.cpp +++ b/src/hades_extensions/tests/ecs/systems/test_attacks.cpp @@ -117,7 +117,7 @@ TEST_F(AttackSystemFixture, TestAttackSystemUpdateSteeringMovementZeroDeltaTime) auto game_object_created{-1}; auto game_object_creation_callback{[&](const GameObjectID game_object_id) { game_object_created = game_object_id; }}; create_attack_component({AttackAlgorithm::Ranged}, true); - registry.add_callback(EventType::GameObjectCreation, game_object_creation_callback); + registry.add_callback(game_object_creation_callback); get_attack_system()->update(0); ASSERT_EQ(game_object_created, -1); } @@ -127,7 +127,7 @@ TEST_F(AttackSystemFixture, TestAttackSystemUpdateSteeringMovementNotTarget) { auto game_object_created{-1}; auto game_object_creation_callback{[&](const GameObjectID game_object_id) { game_object_created = game_object_id; }}; create_attack_component({AttackAlgorithm::Ranged}, true); - registry.add_callback(EventType::GameObjectCreation, game_object_creation_callback); + registry.add_callback(game_object_creation_callback); get_attack_system()->update(5); ASSERT_EQ(game_object_created, -1); } @@ -137,7 +137,7 @@ TEST_F(AttackSystemFixture, TestAttackSystemUpdateSteeringMovement) { auto game_object_created{-1}; auto game_object_creation_callback{[&](const GameObjectID game_object_id) { game_object_created = game_object_id; }}; create_attack_component({AttackAlgorithm::Ranged}, true); - registry.add_callback(EventType::GameObjectCreation, game_object_creation_callback); + registry.add_callback(game_object_creation_callback); registry.get_component(8)->movement_state = SteeringMovementState::Target; get_attack_system()->update(5); ASSERT_EQ(game_object_created, 9); @@ -188,7 +188,7 @@ TEST_F(AttackSystemFixture, TestAttackSystemDoAttackRanged) { game_object_created = game_object_id; }}; create_attack_component({AttackAlgorithm::Ranged}); - registry.add_callback(EventType::GameObjectCreation, game_object_creation_callback); + registry.add_callback(game_object_creation_callback); get_attack_system()->update(5); get_attack_system()->do_attack(8, targets); ASSERT_EQ(game_object_created, 9); diff --git a/src/hades_extensions/tests/ecs/systems/test_inventory.cpp b/src/hades_extensions/tests/ecs/systems/test_inventory.cpp index daed82f6..d1e08b27 100644 --- a/src/hades_extensions/tests/ecs/systems/test_inventory.cpp +++ b/src/hades_extensions/tests/ecs/systems/test_inventory.cpp @@ -59,10 +59,10 @@ TEST_F(InventorySystemFixture, TestInventorySystemAddItemToInventoryValid) { // Add the callbacks to the registry auto inventory_update{-1}; auto inventory_update_callback{[&](const GameObjectID game_object_id) { inventory_update = game_object_id; }}; - registry.add_callback(EventType::InventoryUpdate, inventory_update_callback); + registry.add_callback(inventory_update_callback); auto sprite_removal{-1}; auto sprite_removal_callback{[&](const GameObjectID game_object_id) { sprite_removal = game_object_id; }}; - registry.add_callback(EventType::SpriteRemoval, sprite_removal_callback); + registry.add_callback(sprite_removal_callback); // Add the item to the inventory and check the results const auto game_object_id{create_item(GameObjectType::HealthPotion)}; @@ -103,10 +103,10 @@ TEST_F(InventorySystemFixture, TestInventorySystemRemoveItemFromInventoryValid) // Add the callbacks to the registry auto inventory_update{-1}; auto inventory_update_callback{[&](const GameObjectID game_object_id) { inventory_update = game_object_id; }}; - registry.add_callback(EventType::InventoryUpdate, inventory_update_callback); + registry.add_callback(inventory_update_callback); auto sprite_removal{-1}; auto sprite_removal_callback{[&](const GameObjectID game_object_id) { sprite_removal = game_object_id; }}; - registry.add_callback(EventType::SpriteRemoval, sprite_removal_callback); + registry.add_callback(sprite_removal_callback); // Add two items and remove one of them from the inventory and check the results const auto item_id_one{create_item(GameObjectType::HealthPotion)}; diff --git a/src/hades_extensions/tests/ecs/test_registry.cpp b/src/hades_extensions/tests/ecs/test_registry.cpp index 5a68d617..e33ddf7f 100644 --- a/src/hades_extensions/tests/ecs/test_registry.cpp +++ b/src/hades_extensions/tests/ecs/test_registry.cpp @@ -96,7 +96,7 @@ TEST(Tests, TestGridPosToPixelNegativeXYPosition){ TEST_F(RegistryFixture, TestRegistryEmptyGameObject) { // Create the callback for the game object death event int called{-1}; - registry.add_callback(EventType::GameObjectDeath, [&called](const auto event) { called = event; }); + registry.add_callback([&called](const GameObjectID event) { called = event; }); // Test that creating the game object works correctly ASSERT_EQ(registry.create_game_object(GameObjectType::Player, cpvzero, {}), 0); @@ -358,22 +358,22 @@ TEST_F(RegistryFixture, TestRegistryWallBulletCollision) { /// Test that an event is not notified if there are no callbacks added to the registry. TEST_F(RegistryFixture, TestRegistryNotifyCallbacksNoCallbacksAdded) { constexpr bool called{false}; - registry.notify_callbacks(EventType::GameObjectDeath, 0); + registry.notify(0); ASSERT_FALSE(called); } /// Test that an event is not notified if there are no callbacks listening for that event. TEST_F(RegistryFixture, TestRegistryNotifyCallbacksNoCallbacksListening) { auto called{-1}; - registry.add_callback(EventType::GameObjectCreation, [&called](const auto event) { called = event; }); - registry.notify_callbacks(EventType::GameObjectDeath, 0); + registry.add_callback([&called](const auto event) { called = event; }); + registry.notify(0); ASSERT_EQ(called, -1); } /// Test that an event is notified correctly if there is a callback listening for that event. TEST_F(RegistryFixture, TestRegistryNotifyCallbacksListeningCallback) { auto called{-1}; - registry.add_callback(EventType::GameObjectDeath, [&called](const auto event) { called = event; }); - registry.notify_callbacks(EventType::GameObjectDeath, 0); + registry.add_callback([&called](const auto event) { called = event; }); + registry.notify(0); ASSERT_EQ(called, 0); } diff --git a/src/hades_extensions/tests/test_game_engine.cpp b/src/hades_extensions/tests/test_game_engine.cpp index dd11e812..1bb64031 100644 --- a/src/hades_extensions/tests/test_game_engine.cpp +++ b/src/hades_extensions/tests/test_game_engine.cpp @@ -45,7 +45,7 @@ TEST_F(GameEngineFixture, TestGameEngineGenerateEnemy) { auto enemy_created{-1}; auto enemy_creation{[&](const GameObjectID enemy_id) { enemy_created = enemy_id; }}; game_engine.create_game_objects(); - game_engine.get_registry()->add_callback(EventType::GameObjectCreation, enemy_creation); + game_engine.get_registry()->add_callback(enemy_creation); game_engine.generate_enemy(); ASSERT_NE(enemy_created, -1); } @@ -72,7 +72,7 @@ TEST_F(GameEngineFixture, TestGameEngineGenerateEnemyLimit) { } auto enemy_created{-1}; auto enemy_creation{[&](const GameObjectID enemy_id) { enemy_created = enemy_id; }}; - game_engine.get_registry()->add_callback(EventType::GameObjectCreation, enemy_creation); + game_engine.get_registry()->add_callback(enemy_creation); game_engine.generate_enemy(); ASSERT_EQ(enemy_created, -1); }