Skip to content

Commit

Permalink
Updated GenericSubscription to AnySubscriptionCallback (ros2#1928)
Browse files Browse the repository at this point in the history
* added rclcpp::SerializedMessage support for AnySubscriptionCallback

Signed-off-by: Joshua Hampp <[email protected]>
Signed-off-by: Joshua Hampp <[email protected]>

* using AnySubscription callback for generic subscriptiion

Signed-off-by: Joshua Hampp <[email protected]>
Signed-off-by: Joshua Hampp <[email protected]>

* updated tests

Signed-off-by: Joshua Hampp <[email protected]>
Signed-off-by: Joshua Hampp <[email protected]>

* Remove comment

Signed-off-by: Joshua Hampp <[email protected]>

---------

Signed-off-by: Joshua Hampp <[email protected]>
Signed-off-by: Joshua Hampp <[email protected]>
Co-authored-by: Joshua Hampp <[email protected]>
Co-authored-by: Jacob Perron <[email protected]>
Signed-off-by: Oren Bell <[email protected]>
  • Loading branch information
3 people authored and nightduck committed Jan 25, 2024
1 parent 7afe91c commit 8945689
Show file tree
Hide file tree
Showing 7 changed files with 107 additions and 21 deletions.
34 changes: 29 additions & 5 deletions rclcpp/include/rclcpp/any_subscription_callback.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include "rclcpp/detail/subscription_callback_type_helper.hpp"
#include "rclcpp/function_traits.hpp"
#include "rclcpp/message_info.hpp"
#include "rclcpp/serialization.hpp"
#include "rclcpp/serialized_message.hpp"
#include "rclcpp/type_adapter.hpp"

Expand Down Expand Up @@ -158,13 +159,14 @@ struct AnySubscriptionCallbackPossibleTypes
template<
typename MessageT,
typename AllocatorT,
bool is_adapted_type = rclcpp::TypeAdapter<MessageT>::is_specialized::value
bool is_adapted_type = rclcpp::TypeAdapter<MessageT>::is_specialized::value,
bool is_serialized_type = serialization_traits::is_serialized_message_class<MessageT>::value
>
struct AnySubscriptionCallbackHelper;

/// Specialization for when MessageT is not a TypeAdapter.
template<typename MessageT, typename AllocatorT>
struct AnySubscriptionCallbackHelper<MessageT, AllocatorT, false>
struct AnySubscriptionCallbackHelper<MessageT, AllocatorT, false, false>
{
using CallbackTypes = AnySubscriptionCallbackPossibleTypes<MessageT, AllocatorT>;

Expand Down Expand Up @@ -194,7 +196,7 @@ struct AnySubscriptionCallbackHelper<MessageT, AllocatorT, false>

/// Specialization for when MessageT is a TypeAdapter.
template<typename MessageT, typename AllocatorT>
struct AnySubscriptionCallbackHelper<MessageT, AllocatorT, true>
struct AnySubscriptionCallbackHelper<MessageT, AllocatorT, true, false>
{
using CallbackTypes = AnySubscriptionCallbackPossibleTypes<MessageT, AllocatorT>;

Expand Down Expand Up @@ -232,6 +234,26 @@ struct AnySubscriptionCallbackHelper<MessageT, AllocatorT, true>
>;
};

/// Specialization for when MessageT is a SerializedMessage to exclude duplicated declarations.
template<typename MessageT, typename AllocatorT>
struct AnySubscriptionCallbackHelper<MessageT, AllocatorT, false, true>
{
using CallbackTypes = AnySubscriptionCallbackPossibleTypes<MessageT, AllocatorT>;

using variant_type = std::variant<
typename CallbackTypes::ConstRefSerializedMessageCallback,
typename CallbackTypes::ConstRefSerializedMessageWithInfoCallback,
typename CallbackTypes::UniquePtrSerializedMessageCallback,
typename CallbackTypes::UniquePtrSerializedMessageWithInfoCallback,
typename CallbackTypes::SharedConstPtrSerializedMessageCallback,
typename CallbackTypes::SharedConstPtrSerializedMessageWithInfoCallback,
typename CallbackTypes::ConstRefSharedConstPtrSerializedMessageCallback,
typename CallbackTypes::ConstRefSharedConstPtrSerializedMessageWithInfoCallback,
typename CallbackTypes::SharedPtrSerializedMessageCallback,
typename CallbackTypes::SharedPtrSerializedMessageWithInfoCallback
>;
};

} // namespace detail

template<
Expand Down Expand Up @@ -487,7 +509,9 @@ class AnySubscriptionCallback
}

// Dispatch when input is a ros message and the output could be anything.
void
template<typename TMsg = ROSMessageType>
typename std::enable_if<!serialization_traits::is_serialized_message_class<TMsg>::value,
void>::type
dispatch(
std::shared_ptr<ROSMessageType> message,
const rclcpp::MessageInfo & message_info)
Expand Down Expand Up @@ -589,7 +613,7 @@ class AnySubscriptionCallback
// Dispatch when input is a serialized message and the output could be anything.
void
dispatch(
std::shared_ptr<rclcpp::SerializedMessage> serialized_message,
std::shared_ptr<const rclcpp::SerializedMessage> serialized_message,
const rclcpp::MessageInfo & message_info)
{
TRACETOOLS_TRACEPOINT(callback_start, static_cast<const void *>(this), false);
Expand Down
15 changes: 12 additions & 3 deletions rclcpp/include/rclcpp/create_generic_subscription.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,15 @@ namespace rclcpp
* Not all publisher options are currently respected, the only relevant options for this
* publisher are `event_callbacks`, `use_default_callbacks`, and `%callback_group`.
*/
template<typename AllocatorT = std::allocator<void>>
template<
typename CallbackT,
typename AllocatorT = std::allocator<void>>
std::shared_ptr<GenericSubscription> create_generic_subscription(
rclcpp::node_interfaces::NodeTopicsInterface::SharedPtr topics_interface,
const std::string & topic_name,
const std::string & topic_type,
const rclcpp::QoS & qos,
std::function<void(std::shared_ptr<rclcpp::SerializedMessage>)> callback,
CallbackT && callback,
const rclcpp::SubscriptionOptionsWithAllocator<AllocatorT> & options = (
rclcpp::SubscriptionOptionsWithAllocator<AllocatorT>()
)
Expand All @@ -60,13 +62,20 @@ std::shared_ptr<GenericSubscription> create_generic_subscription(
auto ts_lib = rclcpp::get_typesupport_library(
topic_type, "rosidl_typesupport_cpp");

auto allocator = options.get_allocator();

using rclcpp::AnySubscriptionCallback;
AnySubscriptionCallback<rclcpp::SerializedMessage, AllocatorT>
any_subscription_callback(*allocator);
any_subscription_callback.set(std::forward<CallbackT>(callback));

auto subscription = std::make_shared<GenericSubscription>(
topics_interface->get_node_base_interface(),
std::move(ts_lib),
topic_name,
topic_type,
qos,
callback,
any_subscription_callback,
options);

topics_interface->add_subscription(subscription, options.callback_group);
Expand Down
13 changes: 9 additions & 4 deletions rclcpp/include/rclcpp/generic_subscription.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,7 @@ class GenericSubscription : public rclcpp::SubscriptionBase
const std::string & topic_name,
const std::string & topic_type,
const rclcpp::QoS & qos,
// TODO(nnmm): Add variant for callback with message info. See issue #1604.
std::function<void(std::shared_ptr<rclcpp::SerializedMessage>)> callback,
AnySubscriptionCallback<rclcpp::SerializedMessage, AllocatorT> callback,
const rclcpp::SubscriptionOptionsWithAllocator<AllocatorT> & options)
: SubscriptionBase(
node_base,
Expand All @@ -85,7 +84,11 @@ class GenericSubscription : public rclcpp::SubscriptionBase
options.event_callbacks,
options.use_default_callbacks,
DeliveredMessageKind::SERIALIZED_MESSAGE),
callback_(callback),
callback_([callback](
std::shared_ptr<const rclcpp::SerializedMessage> serialized_message,
const rclcpp::MessageInfo & message_info) mutable {
callback.dispatch(serialized_message, message_info);
}),
ts_lib_(ts_lib)
{}

Expand Down Expand Up @@ -151,7 +154,9 @@ class GenericSubscription : public rclcpp::SubscriptionBase
private:
RCLCPP_DISABLE_COPY(GenericSubscription)

std::function<void(std::shared_ptr<rclcpp::SerializedMessage>)> callback_;
std::function<void(
std::shared_ptr<const rclcpp::SerializedMessage>,
const rclcpp::MessageInfo)> callback_;
// The type support library should stay loaded, so it is stored in the GenericSubscription
std::shared_ptr<rcpputils::SharedLibrary> ts_lib_;
};
Expand Down
6 changes: 4 additions & 2 deletions rclcpp/include/rclcpp/node.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -358,12 +358,14 @@ class Node : public std::enable_shared_from_this<Node>
* `%callback_group`.
* \return Shared pointer to the created generic subscription.
*/
template<typename AllocatorT = std::allocator<void>>
template<
typename CallbackT,
typename AllocatorT = std::allocator<void>>
std::shared_ptr<rclcpp::GenericSubscription> create_generic_subscription(
const std::string & topic_name,
const std::string & topic_type,
const rclcpp::QoS & qos,
std::function<void(std::shared_ptr<rclcpp::SerializedMessage>)> callback,
CallbackT && callback,
const rclcpp::SubscriptionOptionsWithAllocator<AllocatorT> & options = (
rclcpp::SubscriptionOptionsWithAllocator<AllocatorT>()
)
Expand Down
6 changes: 3 additions & 3 deletions rclcpp/include/rclcpp/node_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,21 +221,21 @@ Node::create_generic_publisher(
);
}

template<typename AllocatorT>
template<typename CallbackT, typename AllocatorT>
std::shared_ptr<rclcpp::GenericSubscription>
Node::create_generic_subscription(
const std::string & topic_name,
const std::string & topic_type,
const rclcpp::QoS & qos,
std::function<void(std::shared_ptr<rclcpp::SerializedMessage>)> callback,
CallbackT && callback,
const rclcpp::SubscriptionOptionsWithAllocator<AllocatorT> & options)
{
return rclcpp::create_generic_subscription(
node_topics_,
extend_name_with_sub_namespace(topic_name, this->get_sub_namespace()),
topic_type,
qos,
std::move(callback),
std::forward<CallbackT>(callback),
options
);
}
Expand Down
4 changes: 2 additions & 2 deletions rclcpp/src/rclcpp/generic_subscription.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ GenericSubscription::handle_message(
void
GenericSubscription::handle_serialized_message(
const std::shared_ptr<rclcpp::SerializedMessage> & message,
const rclcpp::MessageInfo &)
const rclcpp::MessageInfo & message_info)
{
callback_(message);
callback_(message, message_info);
}

void
Expand Down
50 changes: 48 additions & 2 deletions rclcpp/test/rclcpp/test_generic_pubsub.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class RclcppGenericNodeFixture : public Test
size_t counter = 0;
auto subscription = node_->create_generic_subscription(
topic_name, type, rclcpp::QoS(1),
[&counter, &messages, this](std::shared_ptr<rclcpp::SerializedMessage> message) {
[&counter, &messages, this](const std::shared_ptr<const rclcpp::SerializedMessage> message) {
T2 deserialized_message;
rclcpp::Serialization<T2> serializer;
serializer.deserialize_message(message.get(), &deserialized_message);
Expand Down Expand Up @@ -236,7 +236,7 @@ TEST_F(RclcppGenericNodeFixture, generic_subscription_uses_qos)
auto publisher = node_->create_publisher<test_msgs::msg::Strings>(topic_name, qos);
auto subscription = node_->create_generic_subscription(
topic_name, topic_type, qos,
[](std::shared_ptr<rclcpp::SerializedMessage>/* message */) {});
[](std::shared_ptr<const rclcpp::SerializedMessage>/* message */) {});
auto connected = [publisher, subscription]() -> bool {
return publisher->get_subscription_count() && subscription->get_publisher_count();
};
Expand All @@ -263,3 +263,49 @@ TEST_F(RclcppGenericNodeFixture, generic_publisher_uses_qos)
// It normally takes < 20ms, 5s chosen as "a very long time"
ASSERT_TRUE(wait_for(connected, 5s));
}

TEST_F(RclcppGenericNodeFixture, generic_subscription_different_callbacks)
{
using namespace std::chrono_literals;
std::string topic_name = "string_topic";
std::string topic_type = "test_msgs/msg/Strings";
rclcpp::QoS qos = rclcpp::QoS(1);

auto publisher = node_->create_publisher<test_msgs::msg::Strings>(topic_name, qos);

// Test shared_ptr for const messages
{
auto subscription = node_->create_generic_subscription(
topic_name, topic_type, qos,
[](const std::shared_ptr<const rclcpp::SerializedMessage>/* message */) {});
auto connected = [publisher, subscription]() -> bool {
return publisher->get_subscription_count() && subscription->get_publisher_count();
};
// It normally takes < 20ms, 5s chosen as "a very long time"
ASSERT_TRUE(wait_for(connected, 5s));
}

// Test unique_ptr
{
auto subscription = node_->create_generic_subscription(
topic_name, topic_type, qos,
[](std::unique_ptr<rclcpp::SerializedMessage>/* message */) {});
auto connected = [publisher, subscription]() -> bool {
return publisher->get_subscription_count() && subscription->get_publisher_count();
};
// It normally takes < 20ms, 5s chosen as "a very long time"
ASSERT_TRUE(wait_for(connected, 5s));
}

// Test message callback
{
auto subscription = node_->create_generic_subscription(
topic_name, topic_type, qos,
[](rclcpp::SerializedMessage /* message */) {});
auto connected = [publisher, subscription]() -> bool {
return publisher->get_subscription_count() && subscription->get_publisher_count();
};
// It normally takes < 20ms, 5s chosen as "a very long time"
ASSERT_TRUE(wait_for(connected, 5s));
}
}

0 comments on commit 8945689

Please sign in to comment.