diff --git a/libraries/Network/src/NetworkEvents.cpp b/libraries/Network/src/NetworkEvents.cpp index 3643734b413..bb02282e9b3 100644 --- a/libraries/Network/src/NetworkEvents.cpp +++ b/libraries/Network/src/NetworkEvents.cpp @@ -134,10 +134,73 @@ void NetworkEvents::checkForEvent() { free(event); } +uint32_t NetworkEvents::findEvent(NetworkEventCb cbEvent, arduino_event_id_t event) { + uint32_t i; + + if (!cbEvent) { + return cbEventList.size(); + } + + for (i = 0; i < cbEventList.size(); i++) { + NetworkEventCbList_t entry = cbEventList[i]; + if (entry.cb == cbEvent && entry.event == event) { + break; + } + } + return i; +} + +template static size_t getStdFunctionAddress(std::function f) { + typedef T(fnType)(U...); + fnType **fnPointer = f.template target(); + if (fnPointer != nullptr) { + return (size_t)*fnPointer; + } + return (size_t)fnPointer; +} + +uint32_t NetworkEvents::findEvent(NetworkEventFuncCb cbEvent, arduino_event_id_t event) { + uint32_t i; + + if (!cbEvent) { + return cbEventList.size(); + } + + for (i = 0; i < cbEventList.size(); i++) { + NetworkEventCbList_t entry = cbEventList[i]; + if (getStdFunctionAddress(entry.fcb) == getStdFunctionAddress(cbEvent) && entry.event == event) { + break; + } + } + return i; +} + +uint32_t NetworkEvents::findEvent(NetworkEventSysCb cbEvent, arduino_event_id_t event) { + uint32_t i; + + if (!cbEvent) { + return cbEventList.size(); + } + + for (i = 0; i < cbEventList.size(); i++) { + NetworkEventCbList_t entry = cbEventList[i]; + if (entry.scb == cbEvent && entry.event == event) { + break; + } + } + return i; +} + network_event_handle_t NetworkEvents::onEvent(NetworkEventCb cbEvent, arduino_event_id_t event) { if (!cbEvent) { return 0; } + + if (findEvent(cbEvent, event) < cbEventList.size()) { + log_w("Attempt to add duplicate event handler!"); + return 0; + } + NetworkEventCbList_t newEventHandler; newEventHandler.cb = cbEvent; newEventHandler.fcb = NULL; @@ -151,6 +214,12 @@ network_event_handle_t NetworkEvents::onEvent(NetworkEventFuncCb cbEvent, arduin if (!cbEvent) { return 0; } + + if (findEvent(cbEvent, event) < cbEventList.size()) { + log_w("Attempt to add duplicate event handler!"); + return 0; + } + NetworkEventCbList_t newEventHandler; newEventHandler.cb = NULL; newEventHandler.fcb = cbEvent; @@ -164,6 +233,12 @@ network_event_handle_t NetworkEvents::onEvent(NetworkEventSysCb cbEvent, arduino if (!cbEvent) { return 0; } + + if (findEvent(cbEvent, event) < cbEventList.size()) { + log_w("Attempt to add duplicate event handler!"); + return 0; + } + NetworkEventCbList_t newEventHandler; newEventHandler.cb = NULL; newEventHandler.fcb = NULL; @@ -177,6 +252,12 @@ network_event_handle_t NetworkEvents::onSysEvent(NetworkEventCb cbEvent, arduino if (!cbEvent) { return 0; } + + if (findEvent(cbEvent, event) < cbEventList.size()) { + log_w("Attempt to add duplicate event handler!"); + return 0; + } + NetworkEventCbList_t newEventHandler; newEventHandler.cb = cbEvent; newEventHandler.fcb = NULL; @@ -190,6 +271,12 @@ network_event_handle_t NetworkEvents::onSysEvent(NetworkEventFuncCb cbEvent, ard if (!cbEvent) { return 0; } + + if (findEvent(cbEvent, event) < cbEventList.size()) { + log_w("Attempt to add duplicate event handler!"); + return 0; + } + NetworkEventCbList_t newEventHandler; newEventHandler.cb = NULL; newEventHandler.fcb = cbEvent; @@ -203,6 +290,12 @@ network_event_handle_t NetworkEvents::onSysEvent(NetworkEventSysCb cbEvent, ardu if (!cbEvent) { return 0; } + + if (findEvent(cbEvent, event) < cbEventList.size()) { + log_w("Attempt to add duplicate event handler!"); + return 0; + } + NetworkEventCbList_t newEventHandler; newEventHandler.cb = NULL; newEventHandler.fcb = NULL; @@ -213,51 +306,51 @@ network_event_handle_t NetworkEvents::onSysEvent(NetworkEventSysCb cbEvent, ardu } void NetworkEvents::removeEvent(NetworkEventCb cbEvent, arduino_event_id_t event) { + uint32_t i; + if (!cbEvent) { return; } - for (uint32_t i = 0; i < cbEventList.size(); i++) { - NetworkEventCbList_t entry = cbEventList[i]; - if (entry.cb == cbEvent && entry.event == event) { - cbEventList.erase(cbEventList.begin() + i); - } + i = findEvent(cbEvent, event); + if (i >= cbEventList.size()) { + log_w("Event handler not found!"); + return; } -} -template static size_t getStdFunctionAddress(std::function f) { - typedef T(fnType)(U...); - fnType **fnPointer = f.template target(); - if (fnPointer != nullptr) { - return (size_t)*fnPointer; - } - return (size_t)fnPointer; + cbEventList.erase(cbEventList.begin() + i); } void NetworkEvents::removeEvent(NetworkEventFuncCb cbEvent, arduino_event_id_t event) { + uint32_t i; + if (!cbEvent) { return; } - for (uint32_t i = 0; i < cbEventList.size(); i++) { - NetworkEventCbList_t entry = cbEventList[i]; - if (getStdFunctionAddress(entry.fcb) == getStdFunctionAddress(cbEvent) && entry.event == event) { - cbEventList.erase(cbEventList.begin() + i); - } + i = findEvent(cbEvent, event); + if (i >= cbEventList.size()) { + log_w("Event handler not found!"); + return; } + + cbEventList.erase(cbEventList.begin() + i); } void NetworkEvents::removeEvent(NetworkEventSysCb cbEvent, arduino_event_id_t event) { + uint32_t i; + if (!cbEvent) { return; } - for (uint32_t i = 0; i < cbEventList.size(); i++) { - NetworkEventCbList_t entry = cbEventList[i]; - if (entry.scb == cbEvent && entry.event == event) { - cbEventList.erase(cbEventList.begin() + i); - } + i = findEvent(cbEvent, event); + if (i >= cbEventList.size()) { + log_w("Event handler not found!"); + return; } + + cbEventList.erase(cbEventList.begin() + i); } void NetworkEvents::removeEvent(network_event_handle_t id) { @@ -265,8 +358,10 @@ void NetworkEvents::removeEvent(network_event_handle_t id) { NetworkEventCbList_t entry = cbEventList[i]; if (entry.id == id) { cbEventList.erase(cbEventList.begin() + i); + return; } } + log_w("Event handler not found!"); } int NetworkEvents::setStatusBits(int bits) { diff --git a/libraries/Network/src/NetworkEvents.h b/libraries/Network/src/NetworkEvents.h index a68fde59572..ac324d19841 100644 --- a/libraries/Network/src/NetworkEvents.h +++ b/libraries/Network/src/NetworkEvents.h @@ -155,6 +155,9 @@ class NetworkEvents { protected: bool initNetworkEvents(); + uint32_t findEvent(NetworkEventCb cbEvent, arduino_event_id_t event); + uint32_t findEvent(NetworkEventFuncCb cbEvent, arduino_event_id_t event); + uint32_t findEvent(NetworkEventSysCb cbEvent, arduino_event_id_t event); network_event_handle_t onSysEvent(NetworkEventCb cbEvent, arduino_event_id_t event = ARDUINO_EVENT_MAX); network_event_handle_t onSysEvent(NetworkEventFuncCb cbEvent, arduino_event_id_t event = ARDUINO_EVENT_MAX); network_event_handle_t onSysEvent(NetworkEventSysCb cbEvent, arduino_event_id_t event = ARDUINO_EVENT_MAX);