Skip to content
This repository has been archived by the owner on Dec 13, 2024. It is now read-only.

changed to async #398

Merged
merged 1 commit into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
#pragma once

#include <moveit_studio_behavior_interface/service_client_behavior_base.hpp>

#include <moveit_studio_behavior_interface/async_behavior_base.hpp>
#include <moveit_pro_ml/onnx_sam2.hpp>
#include <sensor_msgs/msg/image.hpp>
#include <fmt/format.h>

namespace custom_behaviors
{
/**
* @brief Segment an image using the SAM 2 model
*/
class SAM2Segmentation : public moveit_studio::behaviors::SharedResourcesNode<BT::SyncActionNode>
class SAM2Segmentation : public moveit_studio::behaviors::AsyncBehaviorBase
{
public:
/**
Expand All @@ -36,11 +38,9 @@ class SAM2Segmentation : public moveit_studio::behaviors::SharedResourcesNode<BT
*/
static BT::KeyValueVector metadata();

/**
* @brief Implementation of BT::SyncActionNode::tick() for StretchMtc.
* @details This function is where the Behavior performs its work when the behavior tree is being run. Since StretchMtc is derived from BT::SyncActionNode, it is very important that its tick() function always finishes very quickly. If tick() blocks before returning, it will block execution of the entire behavior tree, which may have undesirable consequences for other Behaviors that require a fast update rate to work correctly.
*/
BT::NodeStatus tick() override;
protected:
tl::expected<bool, std::string> doWork() override;


private:
void set_onnx_from_ros_image(const sensor_msgs::msg::Image& image_msg);
Expand All @@ -49,6 +49,15 @@ class SAM2Segmentation : public moveit_studio::behaviors::SharedResourcesNode<BT
std::shared_ptr<moveit_pro_ml::SAM2> sam2_;
moveit_pro_ml::ONNXImage onnx_image_;
sensor_msgs::msg::Image image_;

/** @brief Classes derived from AsyncBehaviorBase must implement getFuture() so that it returns a shared_future class member */
std::shared_future<tl::expected<bool, std::string>>& getFuture() override
{
return future_;
}

/** @brief Classes derived from AsyncBehaviorBase must have this shared_future as a class member */
std::shared_future<tl::expected<bool, std::string>> future_;

};
} // namespace sam2_segmentation
22 changes: 11 additions & 11 deletions src/picknik_ur_site_config/behaviors/src/sam2_segmentation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ namespace custom_behaviors
{
SAM2Segmentation::SAM2Segmentation(const std::string& name, const BT::NodeConfiguration& config,
const std::shared_ptr<moveit_studio::behaviors::BehaviorContext>& shared_resources)
: SharedResourcesNode<BT::SyncActionNode>(name, config, shared_resources)
: moveit_studio::behaviors::AsyncBehaviorBase(name, config, shared_resources)
{
std::filesystem::path package_path = ament_index_cpp::get_package_share_directory("picknik_ur_site_config");
const std::filesystem::path encoder_onnx_file = package_path / "models" / "sam2_hiera_large_encoder.onnx";
Expand Down Expand Up @@ -79,7 +79,7 @@ namespace custom_behaviors
}


BT::NodeStatus SAM2Segmentation::tick()
tl::expected<bool, std::string> SAM2Segmentation::doWork()
{
const auto ports = moveit_studio::behaviors::getRequiredInputs(getInput<sensor_msgs::msg::Image>(kPortImage),
getInput<std::vector<
Expand All @@ -88,15 +88,15 @@ namespace custom_behaviors
// Check that all required input data ports were set.
if (!ports.has_value())
{
spdlog::error("Failed to get required values from input data ports:\n{}", ports.error());
return BT::NodeStatus::FAILURE;
auto error_message = fmt::format("Failed to get required values from input data ports:\n{}", ports.error());
return tl::make_unexpected(error_message);
}
const auto& [image_msg, points_2d] = ports.value();

if (image_msg.encoding != "rgb8" && image_msg.encoding != "rgba8")
{
spdlog::error("Invalid image message format. Expected (rgb8, rgba8) got :\n{}", image_msg.encoding);
return BT::NodeStatus::FAILURE;
auto error_message = fmt::format("Invalid image message format. Expected (rgb8, rgba8) got :\n{}", image_msg.encoding);
return tl::make_unexpected(error_message);
}

// create ONNX formatted image tensor from ROS image
Expand All @@ -111,8 +111,8 @@ namespace custom_behaviors

if (image_msg.encoding != "rgb8")
{
spdlog::error("Invalid image message format. Expected `rgb8` got :\n{}", image_msg.encoding);
return BT::NodeStatus::FAILURE;
auto error_message = fmt::format("Invalid image message format. Expected `rgb8` got :\n{}", image_msg.encoding);
return tl::make_unexpected(error_message);
}
try
{
Expand All @@ -129,11 +129,11 @@ namespace custom_behaviors
}
catch (std::invalid_argument& e)
{
spdlog::error("Invalid argument: {}", e.what());
return BT::NodeStatus::FAILURE;
auto error_message = fmt::format("Invalid argument: {}", e.what());
return tl::make_unexpected(error_message);
}

return BT::NodeStatus::SUCCESS;
return true;
}

BT::KeyValueVector SAM2Segmentation::metadata()
Expand Down
Loading