diff --git a/CMakeLists.txt b/CMakeLists.txt index 6a42695..4d18c7e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -15,7 +15,7 @@ find_package(phnx_msgs REQUIRED) find_package(OpenCV 4.2.0 REQUIRED) # Add source for node executable (link non-ros dependencies here) -add_executable(obj_tracker src/ObjTrackerNode.cpp src/ObjTrackerNode_node.cpp src/Hungarian.cpp) +add_executable(obj_tracker src/ObjTrackerNode.cpp src/ObjTrackerNode_node.cpp src/Hungarian.cpp src/tracker.cpp) target_include_directories(obj_tracker PUBLIC $ $) @@ -28,7 +28,7 @@ set(dependencies nav_msgs phnx_msgs OpenCV - ) +) # Link ros dependencies ament_target_dependencies( @@ -57,7 +57,8 @@ if (BUILD_TESTING) tests/unit.cpp # Remember to add node source files src/ObjTrackerNode_node.cpp - ) + src/tracker.cpp + ) ament_target_dependencies(${PROJECT_NAME}-test ${dependencies}) target_include_directories(${PROJECT_NAME}-test PUBLIC $ diff --git a/include/obj_tracker/tracker.hpp b/include/obj_tracker/tracker.hpp new file mode 100644 index 0000000..a6b89fd --- /dev/null +++ b/include/obj_tracker/tracker.hpp @@ -0,0 +1,24 @@ +#pragma once + +#include "cstdint" +#include "opencv2/video/tracking.hpp" +#include "optional" + +class Tracker { + uint64_t id; + cv::KalmanFilter filter{}; + std::optional last_stamp; + uint8_t missed_frames = 0; + + /// Updates the time between measurements + void update_dt(double stamp); + +public: + Tracker(uint64_t id, const cv::Point3f& inital_Point); + + /// Predicts the next location of the track + cv::Mat predict(double stamp); + + /// Corrects the filter + cv::Mat correct(const cv::Point3f& point); +}; \ No newline at end of file diff --git a/src/ObjTrackerNode_node.cpp b/src/ObjTrackerNode_node.cpp index 9040927..f4490e0 100644 --- a/src/ObjTrackerNode_node.cpp +++ b/src/ObjTrackerNode_node.cpp @@ -1,8 +1,23 @@ #include "obj_tracker/ObjTrackerNode_node.hpp" +#include "obj_tracker/tracker.hpp" + // For _1 using namespace std::placeholders; +cv::Point3f pose_to_cv(const geometry_msgs::msg::Pose& point) { + return cv::Point3f{(float)point.position.x, (float)point.position.y, (float)point.position.z}; +} + +geometry_msgs::msg::Pose mat_to_pose(const cv::Mat& state) { + geometry_msgs::msg::Pose pose{}; + pose.position.x = state.at(0); + pose.position.y = state.at(1); + pose.position.z = state.at(2); + + return pose; +} + ObjTrackerNode::ObjTrackerNode(const rclcpp::NodeOptions& options) : Node("ObjTrackerNode", options) { // Pub Sub this->pose_sub = this->create_subscription( @@ -10,4 +25,24 @@ ObjTrackerNode::ObjTrackerNode(const rclcpp::NodeOptions& options) : Node("ObjTr this->pose_pub = this->create_publisher("/tracks", 1); } -void ObjTrackerNode::pose_cb(geometry_msgs::msg::PoseArray::SharedPtr msg) {} +void ObjTrackerNode::pose_cb(geometry_msgs::msg::PoseArray::SharedPtr msg) { + if (msg->poses.empty()) return; + + static Tracker tracker{0, pose_to_cv(msg->poses[0])}; + + // Predict + rclcpp::Time t = msg->header.stamp; + tracker.predict(t.seconds()); + + // Correct + auto state = tracker.correct(pose_to_cv(msg->poses[0])); + + // Create filtered poses to publish + auto pose = mat_to_pose(state); + geometry_msgs::msg::PoseArray filtered_arr{}; + filtered_arr.poses.push_back(pose); + + filtered_arr.header = msg->header; + + this->pose_pub->publish(filtered_arr); +} diff --git a/src/tracker.cpp b/src/tracker.cpp new file mode 100644 index 0000000..e118107 --- /dev/null +++ b/src/tracker.cpp @@ -0,0 +1,72 @@ +#include "obj_tracker/tracker.hpp" + +Tracker::Tracker(uint64_t id, const cv::Point3f& inital_Point) { + this->id = id; + // state: (x,y,z,vx,vy,vz) measure: (x,y,z) + this->filter.init(6, 3); + + // clang-format off + + // Maps 3d points to the state + cv::Mat_ measure = (cv::Mat_(3, 6) << + 1, 0, 0, 0, 0, 0, + 0, 1, 0, 0, 0, 0, + 0, 0, 1, 0, 0, 0); + + // Predicts the next location of tracks + // This assumes linear motion, and velocities are estimated + cv::Mat_ trans = (cv::Mat_(6, 6) << + 1, 0, 0, 1, 0, 0, // x + 0, 1, 0, 0, 1, 0, // y + 0, 0, 1, 0, 0, 1, // z + 0, 0, 0, 1, 0, 0, // vx + 0, 0, 0, 0, 1, 0, // vy + 0, 0, 0, 0, 0, 1); // vz + + // Prediction covariance TODO tune + cv::Mat_ proc_noise = cv::Mat::eye(6, 6, CV_32F) * 0.01; + + cv::Mat noise_pre = cv::Mat::eye(6, 6, CV_32F); + cv::Mat state_pre = cv::Mat::zeros(6, 1, CV_32F); + cv::Mat state_post = cv::Mat::zeros(6, 1, CV_32F); + + // clang-format on + + // Set initial values + state_pre.at(0, 0) = inital_Point.x; + state_pre.at(0, 1) = inital_Point.y; + state_pre.at(0, 2) = inital_Point.z; + + this->filter.measurementMatrix = measure; + this->filter.transitionMatrix = trans; + this->filter.processNoiseCov = proc_noise; + this->filter.statePost = state_post; + this->filter.statePre = state_pre; + this->filter.errorCovPre = noise_pre; + cv::setIdentity(filter.measurementNoiseCov, cv::Scalar(1e-1)); +} + +void Tracker::update_dt(double stamp) { + float dt = 0.01; + if (!last_stamp) { + last_stamp = stamp; + } else { + dt = (float)(stamp - *last_stamp); + last_stamp = stamp; + } + + // Update dt for each x, y, z + this->filter.transitionMatrix.at(0, 3) = dt; + this->filter.transitionMatrix.at(1, 4) = dt; + this->filter.transitionMatrix.at(2, 5) = dt; +} + +cv::Mat Tracker::predict(double stamp) { + this->update_dt(stamp); + + return this->filter.predict(); +} + +cv::Mat Tracker::correct(const cv::Point3f& point) { + return this->filter.correct((cv::Mat_(3, 1) << point.x, point.y, point.z)); +}