diff --git a/CMakeLists.txt b/CMakeLists.txt index 4d18c7e..432c34a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -15,7 +15,7 @@ find_package(phnx_msgs REQUIRED) find_package(OpenCV 4.2.0 REQUIRED) # Add source for node executable (link non-ros dependencies here) -add_executable(obj_tracker src/ObjTrackerNode.cpp src/ObjTrackerNode_node.cpp src/Hungarian.cpp src/tracker.cpp) +add_executable(obj_tracker src/ObjTrackerNode.cpp src/ObjTrackerNode_node.cpp src/Hungarian.cpp src/tracker.cpp src/mot.cpp) target_include_directories(obj_tracker PUBLIC $ $) @@ -58,6 +58,8 @@ if (BUILD_TESTING) # Remember to add node source files src/ObjTrackerNode_node.cpp src/tracker.cpp + src/Hungarian.cpp + src/mot.cpp ) ament_target_dependencies(${PROJECT_NAME}-test ${dependencies}) target_include_directories(${PROJECT_NAME}-test PUBLIC diff --git a/include/obj_tracker/ObjTrackerNode_node.hpp b/include/obj_tracker/ObjTrackerNode_node.hpp index 64a5a1a..d5ffd84 100644 --- a/include/obj_tracker/ObjTrackerNode_node.hpp +++ b/include/obj_tracker/ObjTrackerNode_node.hpp @@ -1,7 +1,8 @@ #pragma once -#include "rclcpp/rclcpp.hpp" #include "geometry_msgs/msg/pose_array.hpp" +#include "mot.hpp" +#include "rclcpp/rclcpp.hpp" class ObjTrackerNode : public rclcpp::Node { private: @@ -11,6 +12,9 @@ class ObjTrackerNode : public rclcpp::Node { /// Poses to be filtered rclcpp::Subscription::SharedPtr pose_sub; + /// Multiple object tracker + MOT mot; + public: ObjTrackerNode(const rclcpp::NodeOptions& options); diff --git a/include/obj_tracker/mot.hpp b/include/obj_tracker/mot.hpp new file mode 100644 index 0000000..b2e93d2 --- /dev/null +++ b/include/obj_tracker/mot.hpp @@ -0,0 +1,63 @@ +#pragma once + +#include "Hungarian.h" +#include "map" +#include "stack" +#include "tracker.hpp" + +#include "rclcpp/rclcpp.hpp" + +/// Data structure to manage allocating tracker ids +class TrackerManager { + /// Stack of ids to recycle + std::stack free_ids; + /// Contiguous map of trackers. Holes in ids are tracked by free_ids, but these holes are erased in the map. + /// This allows us to iterate over the trackers easily. + std::map trackers; + /// Max allowed missed frames before deallocation + uint64_t max_missed_frames; + +public: + explicit TrackerManager(uint64_t max_missed_frames); + + /// Allocates a new tracker + void add_tracker(cv::Point3f inital_point); + + /// Predicts the state of all tracks at some timestamp + std::map predict_all(double stamp); + + /// Removes all trackers over the max frame count + void prune(); + + /// Corrects the given tracker with a measurement + void correct_tracker(uint64_t id, const cv::Point3f& measurement); + + /// Returns all (id, state) of the trackers + [[nodiscard]] std::vector> get_all_states() const; +}; + +/// Multi object tracking implementation +class MOT { + TrackerManager tracks; + HungarianAlgorithm solver{}; + /// Largest distance between detection and track for it to be considered a match + double max_cost; + + //TODO remove + rclcpp::Node& node; + +public: + explicit MOT(uint64_t max_missed_frames, double max_cost, rclcpp::Node& node); + + /// Continuously filters detections over time via tracking. Returns (id, state). + std::vector> filter(const std::vector& detections, double stamp); +}; + +template +std::vector extract_keys(std::map const& input_map) { + std::vector retval; + for (auto const& element : input_map) { + retval.push_back(element.first); + } + return retval; +} \ No newline at end of file diff --git a/include/obj_tracker/tracker.hpp b/include/obj_tracker/tracker.hpp index a6b89fd..c5b59b9 100644 --- a/include/obj_tracker/tracker.hpp +++ b/include/obj_tracker/tracker.hpp @@ -7,7 +7,9 @@ class Tracker { uint64_t id; cv::KalmanFilter filter{}; + /// Last timestamp predicted at std::optional last_stamp; + /// Number of frames this tracker has predicted with no correction uint8_t missed_frames = 0; /// Updates the time between measurements @@ -21,4 +23,9 @@ class Tracker { /// Corrects the filter cv::Mat correct(const cv::Point3f& point); + + /// Returns the current state matrix + [[nodiscard]] cv::Mat get_state() const; + + [[nodiscard]] uint64_t get_missed_frames() const; }; \ No newline at end of file diff --git a/src/ObjTrackerNode_node.cpp b/src/ObjTrackerNode_node.cpp index f4490e0..2067214 100644 --- a/src/ObjTrackerNode_node.cpp +++ b/src/ObjTrackerNode_node.cpp @@ -9,16 +9,16 @@ cv::Point3f pose_to_cv(const geometry_msgs::msg::Pose& point) { return cv::Point3f{(float)point.position.x, (float)point.position.y, (float)point.position.z}; } -geometry_msgs::msg::Pose mat_to_pose(const cv::Mat& state) { - geometry_msgs::msg::Pose pose{}; - pose.position.x = state.at(0); - pose.position.y = state.at(1); - pose.position.z = state.at(2); - - return pose; +geometry_msgs::msg::Pose cv_to_pose(const cv::Point3f& point) { + geometry_msgs::msg::Pose p{}; + p.position.x = point.x; + p.position.y = point.y; + p.position.z = point.z; + return p; } -ObjTrackerNode::ObjTrackerNode(const rclcpp::NodeOptions& options) : Node("ObjTrackerNode", options) { +ObjTrackerNode::ObjTrackerNode(const rclcpp::NodeOptions& options) + : Node("ObjTrackerNode", options), mot(10, 0.02, *this) { // Pub Sub this->pose_sub = this->create_subscription( "/object_poses", 10, std::bind(&ObjTrackerNode::pose_cb, this, _1)); @@ -28,21 +28,23 @@ ObjTrackerNode::ObjTrackerNode(const rclcpp::NodeOptions& options) : Node("ObjTr void ObjTrackerNode::pose_cb(geometry_msgs::msg::PoseArray::SharedPtr msg) { if (msg->poses.empty()) return; - static Tracker tracker{0, pose_to_cv(msg->poses[0])}; - - // Predict - rclcpp::Time t = msg->header.stamp; - tracker.predict(t.seconds()); + // Convert to opencv types + rclcpp::Time stamp = msg->header.stamp; + std::vector detections; + for (auto& pose : msg->poses) { + detections.push_back(pose_to_cv(pose)); + } - // Correct - auto state = tracker.correct(pose_to_cv(msg->poses[0])); + // Filter poses + auto filtered = this->mot.filter(detections, stamp.seconds()); - // Create filtered poses to publish - auto pose = mat_to_pose(state); + // Convert from cv to ros geometry_msgs::msg::PoseArray filtered_arr{}; - filtered_arr.poses.push_back(pose); + for (auto& [id, point] : filtered) { + RCLCPP_INFO(this->get_logger(), "id: %lu pose: %f %f %f", id, point.x, point.y, point.z); + filtered_arr.poses.push_back(cv_to_pose(point)); + } filtered_arr.header = msg->header; - this->pose_pub->publish(filtered_arr); } diff --git a/src/mot.cpp b/src/mot.cpp new file mode 100644 index 0000000..fd010e7 --- /dev/null +++ b/src/mot.cpp @@ -0,0 +1,145 @@ +#include "obj_tracker/mot.hpp" + +TrackerManager::TrackerManager(uint64_t max_missed_frames) : max_missed_frames(max_missed_frames) {} + +void TrackerManager::add_tracker(cv::Point3f inital_point) { + uint64_t id; + + // If no ids to recycle, then allocate next highest id + if (this->free_ids.empty()) { + id = this->trackers.size(); + } else { + id = this->free_ids.top(); + this->free_ids.pop(); + } + + this->trackers.try_emplace(id, id, inital_point); +} + +std::map TrackerManager::predict_all(double stamp) { + std::map pred; + + for (auto& [id, track] : this->trackers) { + auto p = track.predict(stamp); + + pred.try_emplace(id, p); + } + + return pred; +} + +void TrackerManager::prune() { + std::vector to_remove; + + for (auto& [id, tracker] : this->trackers) { + if (tracker.get_missed_frames() > this->max_missed_frames) { + // Deallocate later to avoid breaking iter + to_remove.push_back(id); + } + } + + for (auto id : to_remove) { + this->trackers.erase(id); + this->free_ids.push(id); + } +} +void TrackerManager::correct_tracker(uint64_t id, const cv::Point3f& measurement) { + this->trackers.at(id).correct(measurement); +} +std::vector> TrackerManager::get_all_states() const { + std::vector> states; + + for (auto& [id, tracker] : this->trackers) { + auto state = tracker.get_state(); + cv::Point3f state_v{state.at(0), state.at(1), state.at(2)}; + + states.emplace_back(id, state_v); + } + + return states; +} + +MOT::MOT(uint64_t max_missed_frames, double max_cost, rclcpp::Node& node) + : tracks(max_missed_frames), max_cost(max_cost), node(node) {} + +std::vector> MOT::filter(const std::vector& detections, double stamp) { + // Forward predict all tracks + auto pred = this->tracks.predict_all(stamp); + + // TODO if pred and detec are diff sizes, we may need to pad + + // Build C_ij cost matrix as dist between each track and detection + std::vector> cost{}; + for (auto& [i, state] : pred) { + // Extract point from state + cv::Point3f state_v{state.at(0), state.at(1), state.at(2)}; + + std::vector track_to_detect; + for (auto& det : detections) { + auto dist = cv::norm(state_v - det); + track_to_detect.push_back(dist); + } + + cost.push_back(track_to_detect); + } + + // Solve the assignment problem between tracks and detections + std::vector assign{}; + if (!pred.empty()) { + this->solver.Solve(cost, assign); + } + + // Find all missed detections + std::vector skip_set{}; //TODO this definitely doesnt work + for (uint64_t i = 0; i < assign.size(); ++i) { + auto paired_detect = assign[i]; + auto paired_cost = cost[i][paired_detect]; + + if (paired_cost > this->max_cost) { + RCLCPP_INFO(node.get_logger(), "Skipping: %lu", i); + skip_set.push_back((int)i); + } + } + + // Map index of assignment vector to actual id + auto assign_to_id = extract_keys(pred); + + // Correct trackers with assigned detections + for (uint64_t i = 0; i < assign.size(); ++i) { + // Skip if tracker missed a detection + if (std::find(skip_set.begin(), skip_set.end(), i) != skip_set.end()) { + continue; + } + + RCLCPP_INFO(node.get_logger(), "Correcting: %lu", i); + this->tracks.correct_tracker(assign_to_id[i], detections[assign[i]]); + } + + // Allocate new tracks for unmatched detections + for (uint64_t j = 0; j < detections.size(); ++j) { + bool new_track = true; + + // Try to find detection in assignment - skip + for (uint64_t i = 0; i < assign.size(); ++i) { + if (std::find(skip_set.begin(), skip_set.end(), i) != skip_set.end()) { + continue; + } + + if ((uint64_t)assign[i] == j) { + new_track = false; + break; + } + } + + if (new_track) { + RCLCPP_INFO(node.get_logger(), "Adding new track!"); + this->tracks.add_tracker(detections[j]); + } + } + + // Remove invalidated trackers + this->tracks.prune(); + + // Output states of all tracks + return this->tracks.get_all_states(); +} diff --git a/src/tracker.cpp b/src/tracker.cpp index e118107..9e0edaf 100644 --- a/src/tracker.cpp +++ b/src/tracker.cpp @@ -33,9 +33,13 @@ Tracker::Tracker(uint64_t id, const cv::Point3f& inital_Point) { // clang-format on // Set initial values - state_pre.at(0, 0) = inital_Point.x; - state_pre.at(0, 1) = inital_Point.y; - state_pre.at(0, 2) = inital_Point.z; + state_pre.at(0) = inital_Point.x; + state_pre.at(1) = inital_Point.y; + state_pre.at(2) = inital_Point.z; + + state_post.at(0) = inital_Point.x; + state_post.at(1) = inital_Point.y; + state_post.at(2) = inital_Point.z; this->filter.measurementMatrix = measure; this->filter.transitionMatrix = trans; @@ -63,10 +67,15 @@ void Tracker::update_dt(double stamp) { cv::Mat Tracker::predict(double stamp) { this->update_dt(stamp); + this->missed_frames++; return this->filter.predict(); } cv::Mat Tracker::correct(const cv::Point3f& point) { + this->missed_frames = 0; return this->filter.correct((cv::Mat_(3, 1) << point.x, point.y, point.z)); } +cv::Mat Tracker::get_state() const { return filter.statePost; } + +uint64_t Tracker::get_missed_frames() const { return this->missed_frames; }