Skip to content

Commit

Permalink
WIP MOT
Browse files Browse the repository at this point in the history
  • Loading branch information
andyblarblar committed Nov 10, 2023
1 parent 148af06 commit 5df9abc
Show file tree
Hide file tree
Showing 7 changed files with 256 additions and 24 deletions.
4 changes: 3 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ find_package(phnx_msgs REQUIRED)
find_package(OpenCV 4.2.0 REQUIRED)

# Add source for node executable (link non-ros dependencies here)
add_executable(obj_tracker src/ObjTrackerNode.cpp src/ObjTrackerNode_node.cpp src/Hungarian.cpp src/tracker.cpp)
add_executable(obj_tracker src/ObjTrackerNode.cpp src/ObjTrackerNode_node.cpp src/Hungarian.cpp src/tracker.cpp src/mot.cpp)
target_include_directories(obj_tracker PUBLIC
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
$<INSTALL_INTERFACE:include>)
Expand Down Expand Up @@ -58,6 +58,8 @@ if (BUILD_TESTING)
# Remember to add node source files
src/ObjTrackerNode_node.cpp
src/tracker.cpp
src/Hungarian.cpp
src/mot.cpp
)
ament_target_dependencies(${PROJECT_NAME}-test ${dependencies})
target_include_directories(${PROJECT_NAME}-test PUBLIC
Expand Down
6 changes: 5 additions & 1 deletion include/obj_tracker/ObjTrackerNode_node.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
#pragma once

#include "rclcpp/rclcpp.hpp"
#include "geometry_msgs/msg/pose_array.hpp"
#include "mot.hpp"
#include "rclcpp/rclcpp.hpp"

class ObjTrackerNode : public rclcpp::Node {
private:
Expand All @@ -11,6 +12,9 @@ class ObjTrackerNode : public rclcpp::Node {
/// Poses to be filtered
rclcpp::Subscription<geometry_msgs::msg::PoseArray>::SharedPtr pose_sub;

/// Multiple object tracker
MOT mot;

public:
ObjTrackerNode(const rclcpp::NodeOptions& options);

Expand Down
63 changes: 63 additions & 0 deletions include/obj_tracker/mot.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
#pragma once

#include "Hungarian.h"
#include "map"
#include "stack"
#include "tracker.hpp"

#include "rclcpp/rclcpp.hpp"

/// Data structure to manage allocating tracker ids
class TrackerManager {
/// Stack of ids to recycle
std::stack<uint64_t> free_ids;
/// Contiguous map of trackers. Holes in ids are tracked by free_ids, but these holes are erased in the map.
/// This allows us to iterate over the trackers easily.
std::map<uint64_t, Tracker> trackers;
/// Max allowed missed frames before deallocation
uint64_t max_missed_frames;

public:
explicit TrackerManager(uint64_t max_missed_frames);

/// Allocates a new tracker
void add_tracker(cv::Point3f inital_point);

/// Predicts the state of all tracks at some timestamp
std::map<uint64_t, cv::Mat> predict_all(double stamp);

/// Removes all trackers over the max frame count
void prune();

/// Corrects the given tracker with a measurement
void correct_tracker(uint64_t id, const cv::Point3f& measurement);

/// Returns all (id, state) of the trackers
[[nodiscard]] std::vector<std::pair<uint64_t, cv::Point3f>> get_all_states() const;
};

/// Multi object tracking implementation
class MOT {
TrackerManager tracks;
HungarianAlgorithm solver{};
/// Largest distance between detection and track for it to be considered a match
double max_cost;

//TODO remove
rclcpp::Node& node;

public:
explicit MOT(uint64_t max_missed_frames, double max_cost, rclcpp::Node& node);

/// Continuously filters detections over time via tracking. Returns (id, state).
std::vector<std::pair<uint64_t, cv::Point3f>> filter(const std::vector<cv::Point3f>& detections, double stamp);
};

template <typename TK, typename TV>
std::vector<TK> extract_keys(std::map<TK, TV> const& input_map) {
std::vector<TK> retval;
for (auto const& element : input_map) {
retval.push_back(element.first);
}
return retval;
}
7 changes: 7 additions & 0 deletions include/obj_tracker/tracker.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
class Tracker {
uint64_t id;
cv::KalmanFilter filter{};
/// Last timestamp predicted at
std::optional<double> last_stamp;
/// Number of frames this tracker has predicted with no correction
uint8_t missed_frames = 0;

/// Updates the time between measurements
Expand All @@ -21,4 +23,9 @@ class Tracker {

/// Corrects the filter
cv::Mat correct(const cv::Point3f& point);

/// Returns the current state matrix
[[nodiscard]] cv::Mat get_state() const;

[[nodiscard]] uint64_t get_missed_frames() const;
};
40 changes: 21 additions & 19 deletions src/ObjTrackerNode_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,16 @@ cv::Point3f pose_to_cv(const geometry_msgs::msg::Pose& point) {
return cv::Point3f{(float)point.position.x, (float)point.position.y, (float)point.position.z};
}

geometry_msgs::msg::Pose mat_to_pose(const cv::Mat& state) {
geometry_msgs::msg::Pose pose{};
pose.position.x = state.at<float>(0);
pose.position.y = state.at<float>(1);
pose.position.z = state.at<float>(2);

return pose;
geometry_msgs::msg::Pose cv_to_pose(const cv::Point3f& point) {
geometry_msgs::msg::Pose p{};
p.position.x = point.x;
p.position.y = point.y;
p.position.z = point.z;
return p;
}

ObjTrackerNode::ObjTrackerNode(const rclcpp::NodeOptions& options) : Node("ObjTrackerNode", options) {
ObjTrackerNode::ObjTrackerNode(const rclcpp::NodeOptions& options)
: Node("ObjTrackerNode", options), mot(10, 0.02, *this) {
// Pub Sub
this->pose_sub = this->create_subscription<geometry_msgs::msg::PoseArray>(
"/object_poses", 10, std::bind(&ObjTrackerNode::pose_cb, this, _1));
Expand All @@ -28,21 +28,23 @@ ObjTrackerNode::ObjTrackerNode(const rclcpp::NodeOptions& options) : Node("ObjTr
void ObjTrackerNode::pose_cb(geometry_msgs::msg::PoseArray::SharedPtr msg) {
if (msg->poses.empty()) return;

static Tracker tracker{0, pose_to_cv(msg->poses[0])};

// Predict
rclcpp::Time t = msg->header.stamp;
tracker.predict(t.seconds());
// Convert to opencv types
rclcpp::Time stamp = msg->header.stamp;
std::vector<cv::Point3f> detections;
for (auto& pose : msg->poses) {
detections.push_back(pose_to_cv(pose));
}

// Correct
auto state = tracker.correct(pose_to_cv(msg->poses[0]));
// Filter poses
auto filtered = this->mot.filter(detections, stamp.seconds());

// Create filtered poses to publish
auto pose = mat_to_pose(state);
// Convert from cv to ros
geometry_msgs::msg::PoseArray filtered_arr{};
filtered_arr.poses.push_back(pose);
for (auto& [id, point] : filtered) {
RCLCPP_INFO(this->get_logger(), "id: %lu pose: %f %f %f", id, point.x, point.y, point.z);
filtered_arr.poses.push_back(cv_to_pose(point));
}

filtered_arr.header = msg->header;

this->pose_pub->publish(filtered_arr);
}
145 changes: 145 additions & 0 deletions src/mot.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
#include "obj_tracker/mot.hpp"

TrackerManager::TrackerManager(uint64_t max_missed_frames) : max_missed_frames(max_missed_frames) {}

void TrackerManager::add_tracker(cv::Point3f inital_point) {
uint64_t id;

// If no ids to recycle, then allocate next highest id
if (this->free_ids.empty()) {
id = this->trackers.size();
} else {
id = this->free_ids.top();
this->free_ids.pop();
}

this->trackers.try_emplace(id, id, inital_point);
}

std::map<uint64_t, cv::Mat> TrackerManager::predict_all(double stamp) {
std::map<uint64_t, cv::Mat> pred;

for (auto& [id, track] : this->trackers) {
auto p = track.predict(stamp);

pred.try_emplace(id, p);
}

return pred;
}

void TrackerManager::prune() {
std::vector<uint64_t> to_remove;

for (auto& [id, tracker] : this->trackers) {
if (tracker.get_missed_frames() > this->max_missed_frames) {
// Deallocate later to avoid breaking iter
to_remove.push_back(id);
}
}

for (auto id : to_remove) {
this->trackers.erase(id);
this->free_ids.push(id);
}
}
void TrackerManager::correct_tracker(uint64_t id, const cv::Point3f& measurement) {
this->trackers.at(id).correct(measurement);
}
std::vector<std::pair<uint64_t, cv::Point3f>> TrackerManager::get_all_states() const {
std::vector<std::pair<uint64_t, cv::Point3f>> states;

for (auto& [id, tracker] : this->trackers) {
auto state = tracker.get_state();
cv::Point3f state_v{state.at<float>(0), state.at<float>(1), state.at<float>(2)};

states.emplace_back(id, state_v);
}

return states;
}

MOT::MOT(uint64_t max_missed_frames, double max_cost, rclcpp::Node& node)
: tracks(max_missed_frames), max_cost(max_cost), node(node) {}

std::vector<std::pair<uint64_t, cv::Point3f>> MOT::filter(const std::vector<cv::Point3f>& detections, double stamp) {
// Forward predict all tracks
auto pred = this->tracks.predict_all(stamp);

// TODO if pred and detec are diff sizes, we may need to pad

// Build C_ij cost matrix as dist between each track and detection
std::vector<std::vector<double>> cost{};
for (auto& [i, state] : pred) {
// Extract point from state
cv::Point3f state_v{state.at<float>(0), state.at<float>(1), state.at<float>(2)};

std::vector<double> track_to_detect;
for (auto& det : detections) {
auto dist = cv::norm(state_v - det);
track_to_detect.push_back(dist);
}

cost.push_back(track_to_detect);
}

// Solve the assignment problem between tracks and detections
std::vector<int> assign{};
if (!pred.empty()) {
this->solver.Solve(cost, assign);
}

// Find all missed detections
std::vector<int> skip_set{}; //TODO this definitely doesnt work
for (uint64_t i = 0; i < assign.size(); ++i) {
auto paired_detect = assign[i];
auto paired_cost = cost[i][paired_detect];

if (paired_cost > this->max_cost) {
RCLCPP_INFO(node.get_logger(), "Skipping: %lu", i);
skip_set.push_back((int)i);
}
}

// Map index of assignment vector to actual id
auto assign_to_id = extract_keys(pred);

// Correct trackers with assigned detections
for (uint64_t i = 0; i < assign.size(); ++i) {
// Skip if tracker missed a detection
if (std::find(skip_set.begin(), skip_set.end(), i) != skip_set.end()) {
continue;
}

RCLCPP_INFO(node.get_logger(), "Correcting: %lu", i);
this->tracks.correct_tracker(assign_to_id[i], detections[assign[i]]);
}

// Allocate new tracks for unmatched detections
for (uint64_t j = 0; j < detections.size(); ++j) {
bool new_track = true;

// Try to find detection in assignment - skip
for (uint64_t i = 0; i < assign.size(); ++i) {
if (std::find(skip_set.begin(), skip_set.end(), i) != skip_set.end()) {
continue;
}

if ((uint64_t)assign[i] == j) {
new_track = false;
break;
}
}

if (new_track) {
RCLCPP_INFO(node.get_logger(), "Adding new track!");
this->tracks.add_tracker(detections[j]);
}
}

// Remove invalidated trackers
this->tracks.prune();

// Output states of all tracks
return this->tracks.get_all_states();
}
15 changes: 12 additions & 3 deletions src/tracker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,13 @@ Tracker::Tracker(uint64_t id, const cv::Point3f& inital_Point) {
// clang-format on

// Set initial values
state_pre.at<float>(0, 0) = inital_Point.x;
state_pre.at<float>(0, 1) = inital_Point.y;
state_pre.at<float>(0, 2) = inital_Point.z;
state_pre.at<float>(0) = inital_Point.x;
state_pre.at<float>(1) = inital_Point.y;
state_pre.at<float>(2) = inital_Point.z;

state_post.at<float>(0) = inital_Point.x;
state_post.at<float>(1) = inital_Point.y;
state_post.at<float>(2) = inital_Point.z;

this->filter.measurementMatrix = measure;
this->filter.transitionMatrix = trans;
Expand Down Expand Up @@ -63,10 +67,15 @@ void Tracker::update_dt(double stamp) {

cv::Mat Tracker::predict(double stamp) {
this->update_dt(stamp);
this->missed_frames++;

return this->filter.predict();
}

cv::Mat Tracker::correct(const cv::Point3f& point) {
this->missed_frames = 0;
return this->filter.correct((cv::Mat_<float>(3, 1) << point.x, point.y, point.z));
}
cv::Mat Tracker::get_state() const { return filter.statePost; }

uint64_t Tracker::get_missed_frames() const { return this->missed_frames; }

0 comments on commit 5df9abc

Please sign in to comment.