Skip to content
This repository has been archived by the owner on Dec 13, 2024. It is now read-only.

Commit

Permalink
Merge remote-tracking branch 'origin/async-sam2' into pr-add-sam2-beh…
Browse files Browse the repository at this point in the history
…avior
  • Loading branch information
pac48 committed Nov 15, 2024
2 parents 23dc1f1 + 5890d0b commit 99dddf6
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 18 deletions.
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
#pragma once

#include <moveit_studio_behavior_interface/service_client_behavior_base.hpp>

#include <moveit_studio_behavior_interface/async_behavior_base.hpp>
#include <moveit_pro_ml/onnx_sam2.hpp>
#include <sensor_msgs/msg/image.hpp>
#include <fmt/format.h>

namespace custom_behaviors
{
/**
* @brief Segment an image using the SAM 2 model
*/
class SAM2Segmentation : public moveit_studio::behaviors::SharedResourcesNode<BT::SyncActionNode>
class SAM2Segmentation : public moveit_studio::behaviors::AsyncBehaviorBase
{
public:
/**
Expand All @@ -36,11 +38,9 @@ class SAM2Segmentation : public moveit_studio::behaviors::SharedResourcesNode<BT
*/
static BT::KeyValueVector metadata();

/**
* @brief Implementation of BT::SyncActionNode::tick() for StretchMtc.
* @details This function is where the Behavior performs its work when the behavior tree is being run. Since StretchMtc is derived from BT::SyncActionNode, it is very important that its tick() function always finishes very quickly. If tick() blocks before returning, it will block execution of the entire behavior tree, which may have undesirable consequences for other Behaviors that require a fast update rate to work correctly.
*/
BT::NodeStatus tick() override;
protected:
tl::expected<bool, std::string> doWork() override;


private:
void set_onnx_from_ros_image(const sensor_msgs::msg::Image& image_msg);
Expand All @@ -49,6 +49,15 @@ class SAM2Segmentation : public moveit_studio::behaviors::SharedResourcesNode<BT
std::shared_ptr<moveit_pro_ml::SAM2> sam2_;
moveit_pro_ml::ONNXImage onnx_image_;
sensor_msgs::msg::Image image_;

/** @brief Classes derived from AsyncBehaviorBase must implement getFuture() so that it returns a shared_future class member */
std::shared_future<tl::expected<bool, std::string>>& getFuture() override
{
return future_;
}

/** @brief Classes derived from AsyncBehaviorBase must have this shared_future as a class member */
std::shared_future<tl::expected<bool, std::string>> future_;

};
} // namespace sam2_segmentation
22 changes: 11 additions & 11 deletions src/picknik_ur_site_config/behaviors/src/sam2_segmentation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ namespace custom_behaviors
{
SAM2Segmentation::SAM2Segmentation(const std::string& name, const BT::NodeConfiguration& config,
const std::shared_ptr<moveit_studio::behaviors::BehaviorContext>& shared_resources)
: SharedResourcesNode<BT::SyncActionNode>(name, config, shared_resources)
: moveit_studio::behaviors::AsyncBehaviorBase(name, config, shared_resources)
{
std::filesystem::path package_path = ament_index_cpp::get_package_share_directory("picknik_ur_site_config");
const std::filesystem::path encoder_onnx_file = package_path / "models" / "sam2_hiera_large_encoder.onnx";
Expand Down Expand Up @@ -77,7 +77,7 @@ namespace custom_behaviors
}


BT::NodeStatus SAM2Segmentation::tick()
tl::expected<bool, std::string> SAM2Segmentation::doWork()
{
const auto ports = moveit_studio::behaviors::getRequiredInputs(getInput<sensor_msgs::msg::Image>(kPortImage),
getInput<std::vector<
Expand All @@ -86,15 +86,15 @@ namespace custom_behaviors
// Check that all required input data ports were set.
if (!ports.has_value())
{
spdlog::error("Failed to get required values from input data ports:\n{}", ports.error());
return BT::NodeStatus::FAILURE;
auto error_message = fmt::format("Failed to get required values from input data ports:\n{}", ports.error());
return tl::make_unexpected(error_message);
}
const auto& [image_msg, points_2d] = ports.value();

if (image_msg.encoding != "rgb8")
{
spdlog::error("Invalid image message format. Expected `rgb8` got :\n{}", image_msg.encoding);
return BT::NodeStatus::FAILURE;
auto error_message = fmt::format("Invalid image message format. Expected `rgb8` got :\n{}", image_msg.encoding);
return tl::make_unexpected(error_message);
}

// create ONNX formatted image tensor from ROS image
Expand All @@ -109,8 +109,8 @@ namespace custom_behaviors

if (image_msg.encoding != "rgb8")
{
spdlog::error("Invalid image message format. Expected `rgb8` got :\n{}", image_msg.encoding);
return BT::NodeStatus::FAILURE;
auto error_message = fmt::format("Invalid image message format. Expected `rgb8` got :\n{}", image_msg.encoding);
return tl::make_unexpected(error_message);
}
try
{
Expand All @@ -127,11 +127,11 @@ namespace custom_behaviors
}
catch (std::invalid_argument& e)
{
spdlog::error("Invalid argument: {}", e.what());
return BT::NodeStatus::FAILURE;
auto error_message = fmt::format("Invalid argument: {}", e.what());
return tl::make_unexpected(error_message);
}

return BT::NodeStatus::SUCCESS;
return true;
}

BT::KeyValueVector SAM2Segmentation::metadata()
Expand Down

0 comments on commit 99dddf6

Please sign in to comment.