diff --git a/configs/mmpose/pose-detection_sdk_static.py b/configs/mmpose/pose-detection_sdk_static.py new file mode 100644 index 0000000000..b93c858044 --- /dev/null +++ b/configs/mmpose/pose-detection_sdk_static.py @@ -0,0 +1,14 @@ +_base_ = ['./pose-detection_static.py', '../_base_/backends/sdk.py'] + +codebase_config = dict(model_type='sdk') + +backend_config = dict(pipeline=[ + dict(type='LoadImageFromFile', channel_order='bgr'), + dict( + type='Collect', + keys=['img'], + meta_keys=[ + 'image_file', 'center', 'scale', 'rotation', 'bbox_score', + 'flip_pairs' + ]) +]) diff --git a/csrc/apis/c/CMakeLists.txt b/csrc/apis/c/CMakeLists.txt index f1809995bb..5709e0c57a 100644 --- a/csrc/apis/c/CMakeLists.txt +++ b/csrc/apis/c/CMakeLists.txt @@ -5,7 +5,7 @@ project(capis) include(${CMAKE_SOURCE_DIR}/cmake/MMDeploy.cmake) if ("all" IN_LIST MMDEPLOY_CODEBASES) - set(TASK_LIST "classifier;detector;segmentor;text_detector;text_recognizer;restorer;model") + set(TASK_LIST "classifier;detector;segmentor;text_detector;text_recognizer;pose_detector;restorer;model") else () set(TASK_LIST "model") if ("mmcls" IN_LIST MMDEPLOY_CODEBASES) @@ -24,6 +24,9 @@ else () list(APPEND TASK_LIST "text_detector") list(APPEND TASK_LIST "text_recognizer") endif () + if ("mmpose" IN_LIST MMDEPLOY_CODEBASES) + list(APPEND TASK_LIST "pose_detector") + endif () endif () foreach (TASK ${TASK_LIST}) diff --git a/csrc/apis/c/pose_detector.cpp b/csrc/apis/c/pose_detector.cpp new file mode 100644 index 0000000000..6c5ef426ef --- /dev/null +++ b/csrc/apis/c/pose_detector.cpp @@ -0,0 +1,190 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include "pose_detector.h" + +#include + +#include "codebase/mmpose/mmpose.h" +#include "core/device.h" +#include "core/graph.h" +#include "core/mat.h" +#include "core/tensor.h" +#include "core/utils/formatter.h" +#include "handle.h" + +using namespace std; +using namespace mmdeploy; + +namespace { + +const Value& config_template() { + // clang-format off + static Value v { + { + "pipeline", { + {"input", {"img_with_boxes"}}, + {"output", {"key_points_unflat"}}, + { + "tasks", { + { + {"name", "flatten"}, + {"type", "Flatten"}, + {"input", {"img_with_boxes"}}, + {"output", {"patch_flat", "patch_index"}}, + }, + { + {"name", "pose-detector"}, + {"type", "Inference"}, + {"params", {{"model", "TBD"},{"batch_size", 1}}}, + {"input", {"patch_flat"}}, + {"output", {"key_points"}} + }, + { + {"name", "unflatten"}, + {"type", "Unflatten"}, + {"input", {"key_points", "patch_index"}}, + {"output", {"key_points_unflat"}}, + } + } + } + } + } + }; + // clang-format on + return v; +} + +template +int mmdeploy_pose_detector_create_impl(ModelType&& m, const char* device_name, int device_id, + mm_handle_t* handle) { + try { + auto value = config_template(); + value["pipeline"]["tasks"][1]["params"]["model"] = std::forward(m); + + auto pose_estimator = std::make_unique(device_name, device_id, std::move(value)); + + *handle = pose_estimator.release(); + return MM_SUCCESS; + + } catch (const std::exception& e) { + MMDEPLOY_ERROR("exception caught: {}", e.what()); + } catch (...) { + MMDEPLOY_ERROR("unknown exception caught"); + } + return MM_E_FAIL; +} + +} // namespace + +int mmdeploy_pose_detector_create(mm_model_t model, const char* device_name, int device_id, + mm_handle_t* handle) { + return mmdeploy_pose_detector_create_impl(*static_cast(model), device_name, device_id, + handle); +} + +int mmdeploy_pose_detector_create_by_path(const char* model_path, const char* device_name, + int device_id, mm_handle_t* handle) { + return mmdeploy_pose_detector_create_impl(model_path, device_name, device_id, handle); +} + +int mmdeploy_pose_detector_apply(mm_handle_t handle, const mm_mat_t* mats, int mat_count, + mm_pose_detect_t** results) { + return mmdeploy_pose_detector_apply_bbox(handle, mats, mat_count, nullptr, nullptr, results); +} + +int mmdeploy_pose_detector_apply_bbox(mm_handle_t handle, const mm_mat_t* mats, int mat_count, + const mm_rect_t* bboxes, const int* bbox_count, + mm_pose_detect_t** results) { + if (handle == nullptr || mats == nullptr || mat_count == 0 || results == nullptr) { + return MM_E_INVALID_ARG; + } + + try { + auto pose_detector = static_cast(handle); + Value input{Value::kArray}; + auto result_count = 0; + for (int i = 0; i < mat_count; ++i) { + mmdeploy::Mat _mat{mats[i].height, mats[i].width, PixelFormat(mats[i].format), + DataType(mats->type), mats[i].data, Device{"cpu"}}; + + Value img_with_boxes; + if (bboxes && bbox_count) { + for (int j = 0; j < bbox_count[i]; ++j) { + Value obj; + obj["ori_img"] = _mat; + float width = bboxes[j].right - bboxes[j].left + 1; + float height = bboxes[j].bottom - bboxes[j].top + 1; + obj["box"] = {bboxes[j].left, bboxes[j].top, width, height, 1.0}; + obj["rotation"] = 0.f; + img_with_boxes.push_back(obj); + } + bboxes += bbox_count[i]; + result_count += bbox_count[i]; + } else { + // inference whole image + Value obj; + obj["ori_img"] = _mat; + obj["box"] = {0, 0, _mat.width(), _mat.height(), 1.0}; + obj["rotation"] = 0.f; + img_with_boxes.push_back(obj); + result_count += 1; + } + input.front().push_back(img_with_boxes); + } + + auto output = pose_detector->Run(std::move(input)).value().front(); + + auto pose_outputs = from_value>>(output); + + std::vector counts; + if (bboxes && bbox_count) { + counts = std::vector(bbox_count, bbox_count + mat_count); + } else { + counts.resize(mat_count, 1); + } + std::vector offsets{0}; + std::partial_sum(begin(counts), end(counts), back_inserter(offsets)); + + auto deleter = [&](mm_pose_detect_t* p) { + mmdeploy_pose_detector_release_result(p, offsets.back()); + }; + + std::unique_ptr _results( + new mm_pose_detect_t[result_count]{}, deleter); + + for (int i = 0; i < mat_count; ++i) { + auto& pose_output = pose_outputs[i]; + for (int j = 0; j < pose_output.size(); ++j) { + auto& res = _results[offsets[i] + j]; + auto& box_result = pose_output[j]; + int sz = box_result.key_points.size(); + + res.point = new mm_pointf_t[sz]; + res.score = new float[sz]; + res.length = sz; + for (int k = 0; k < sz; k++) { + res.point[k].x = box_result.key_points[k].bbox[0]; + res.point[k].y = box_result.key_points[k].bbox[1]; + res.score[k] = box_result.key_points[k].score; + } + } + } + *results = _results.release(); + return MM_SUCCESS; + + } catch (const std::exception& e) { + MMDEPLOY_ERROR("exception caught: {}", e.what()); + } catch (...) { + MMDEPLOY_ERROR("unknown exception caught"); + } + return MM_E_FAIL; +} + +void mmdeploy_pose_detector_release_result(mm_pose_detect_t* results, int count) { + for (int i = 0; i < count; ++i) { + delete[] results[i].point; + delete[] results[i].score; + } + delete[] results; +} +void mmdeploy_pose_detector_destroy(mm_handle_t handle) { delete static_cast(handle); } diff --git a/csrc/apis/c/pose_detector.h b/csrc/apis/c/pose_detector.h new file mode 100644 index 0000000000..16e3e23d26 --- /dev/null +++ b/csrc/apis/c/pose_detector.h @@ -0,0 +1,97 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +/** + * @file pose_detector.h + * @brief Interface to MMPose task + */ + +#ifndef MMDEPLOY_SRC_APIS_C_POSE_DETECTOR_H_ +#define MMDEPLOY_SRC_APIS_C_POSE_DETECTOR_H_ + +#include "common.h" + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct mm_pose_detect_t { + mm_pointf_t* point; ///< keypoint + float* score; ///< keypoint score + int length; ///< number of keypoint +} mm_pose_detect_t; + +/** + * @brief Create a pose detector instance + * @param[in] model an instance of mmpose model created by + * \ref mmdeploy_model_create_by_path or \ref mmdeploy_model_create in \ref model.h + * @param[in] device_name name of device, such as "cpu", "cuda", etc. + * @param[in] device_id id of device. + * @param[out] handle handle of the created pose detector, which must be destroyed + * by \ref mmdeploy_pose_detector_destroy + * @return status code of the operation + */ +MMDEPLOY_API int mmdeploy_pose_detector_create(mm_model_t model, const char* device_name, + int device_id, mm_handle_t* handle); + +/** + * @brief Create a pose detector instance + * @param[in] model_path path to pose detection model + * @param[in] device_name name of device, such as "cpu", "cuda", etc. + * @param[in] device_id id of device. + * @param[out] handle handle of the created pose detector, which must be destroyed + * by \ref mmdeploy_pose_detector_destroy + * @return status code of the operation + */ +MMDEPLOY_API int mmdeploy_pose_detector_create_by_path(const char* model_path, + const char* device_name, int device_id, + mm_handle_t* handle); + +/** + * @brief Apply pose detector to a batch of images with full image roi + * @param[in] handle pose detector's handle created by \ref + * mmdeploy_pose_detector_create_by_path + * @param[in] images a batch of images + * @param[in] count number of images in the batch + * @param[out] results a linear buffer contains the pose result, must be release + * by \ref mmdeploy_pose_detector_release_result + * @return status code of the operation + */ +MMDEPLOY_API int mmdeploy_pose_detector_apply(mm_handle_t handle, const mm_mat_t* mats, + int mat_count, mm_pose_detect_t** results); + +/** + * @brief Apply pose detector to a batch of images supplied with bboxes(roi) + * @param[in] handle pose detector's handle created by \ref + * mmdeploy_pose_detector_create_by_path + * @param[in] images a batch of images + * @param[in] image_count number of images in the batch + * @param[in] bboxes bounding boxes(roi) detected by mmdet + * @param[in] bbox_count number of bboxes of each \p images, must be same length as \p images + * @param[out] results a linear buffer contains the pose result, which has the same length as \p + * bboxes, must be release by \ref mmdeploy_pose_detector_release_result + * @return status code of the operation + */ +MMDEPLOY_API int mmdeploy_pose_detector_apply_bbox(mm_handle_t handle, const mm_mat_t* mats, + int mat_count, const mm_rect_t* bboxes, + const int* bbox_count, + mm_pose_detect_t** results); + +/** @brief Release result buffer returned by \ref mmdeploy_pose_detector_apply or \ref + * mmdeploy_pose_detector_apply_bbox + * @param[in] results result buffer by pose detector + * @param[in] count length of \p result + */ +MMDEPLOY_API void mmdeploy_pose_detector_release_result(mm_pose_detect_t* results, int count); + +/** + * @brief destroy pose_detector + * @param[in] handle handle of pose_detector created by \ref + * mmdeploy_pose_detector_create_by_path or \ref mmdeploy_pose_detector_create + */ +MMDEPLOY_API void mmdeploy_pose_detector_destroy(mm_handle_t handle); + +#ifdef __cplusplus +} +#endif + +#endif // MMDEPLOY_SRC_APIS_C_POSE_DETECTOR_H_ diff --git a/csrc/apis/python/CMakeLists.txt b/csrc/apis/python/CMakeLists.txt index 4421995733..ce86ed2796 100644 --- a/csrc/apis/python/CMakeLists.txt +++ b/csrc/apis/python/CMakeLists.txt @@ -24,6 +24,7 @@ mmdeploy_python_add_module(segmentor) mmdeploy_python_add_module(text_detector) mmdeploy_python_add_module(text_recognizer) mmdeploy_python_add_module(restorer) +mmdeploy_python_add_module(pose_detector) pybind11_add_module(${PROJECT_NAME} ${MMDEPLOY_PYTHON_SRCS}) diff --git a/csrc/apis/python/pose_detector.cpp b/csrc/apis/python/pose_detector.cpp new file mode 100644 index 0000000000..36e024f1a1 --- /dev/null +++ b/csrc/apis/python/pose_detector.cpp @@ -0,0 +1,83 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include "pose_detector.h" + +#include "common.h" +#include "core/logger.h" + +namespace mmdeploy { + +class PyPoseDedector { + public: + PyPoseDedector(const char *model_path, const char *device_name, int device_id) { + MMDEPLOY_INFO("{}, {}, {}", model_path, device_name, device_id); + auto status = + mmdeploy_pose_detector_create_by_path(model_path, device_name, device_id, &handle_); + if (status != MM_SUCCESS) { + throw std::runtime_error("failed to create pose_detedtor"); + } + } + py::list Apply(const std::vector &imgs, const std::vector> &_boxes) { + std::vector mats; + std::vector boxes; + mats.reserve(imgs.size()); + for (const auto &img : imgs) { + auto mat = GetMat(img); + mats.push_back(mat); + } + for (const auto &_box : _boxes) { + mm_rect_t box = {_box[0], _box[1], _box[2], _box[3]}; + boxes.push_back(box); + } + mm_pose_detect_t *detection{}; + int num_box = boxes.size(); + auto status = mmdeploy_pose_detector_apply_bbox(handle_, mats.data(), (int)mats.size(), + boxes.data(), &num_box, &detection); + if (status != MM_SUCCESS) { + throw std::runtime_error("failed to apply pose_detector, code: " + std::to_string(status)); + } + auto output = py::list{}; + auto result = detection; + for (int i = 0; i < mats.size(); i++) { + int n_point = result->length; + auto pred = py::array_t({1, n_point, 3}); + auto dst = pred.mutable_data(); + for (int j = 0; j < n_point; j++) { + dst[0] = result->point[j].x; + dst[1] = result->point[j].y; + dst[2] = result->score[j]; + dst += 3; + } + output.append(std::move(pred)); + result++; + } + mmdeploy_pose_detector_release_result(detection, (int)mats.size()); + return output; + } + ~PyPoseDedector() { + mmdeploy_pose_detector_destroy(handle_); + handle_ = {}; + } + + private: + mm_handle_t handle_{}; +}; + +static void register_python_pose_detector(py::module &m) { + py::class_(m, "PoseDetector") + .def(py::init([](const char *model_path, const char *device_name, int device_id) { + return std::make_unique(model_path, device_name, device_id); + })) + .def("__call__", &PyPoseDedector::Apply); +} + +class PythonPoseDetectorRegisterer { + public: + PythonPoseDetectorRegisterer() { + gPythonBindings().emplace("pose_detector", register_python_pose_detector); + } +}; + +static PythonPoseDetectorRegisterer python_pose_detector_registerer; + +} // namespace mmdeploy diff --git a/csrc/codebase/CMakeLists.txt b/csrc/codebase/CMakeLists.txt index 9ef6490a8c..a0b98594aa 100644 --- a/csrc/codebase/CMakeLists.txt +++ b/csrc/codebase/CMakeLists.txt @@ -9,6 +9,7 @@ if ("all" IN_LIST MMDEPLOY_CODEBASES) list(APPEND CODEBASES "mmseg") list(APPEND CODEBASES "mmocr") list(APPEND CODEBASES "mmedit") + list(APPEND CODEBASES "mmpose") else () set(CODEBASES ${MMDEPLOY_CODEBASES}) endif () diff --git a/csrc/codebase/mmpose/CMakeLists.txt b/csrc/codebase/mmpose/CMakeLists.txt new file mode 100644 index 0000000000..6d4c7dd562 --- /dev/null +++ b/csrc/codebase/mmpose/CMakeLists.txt @@ -0,0 +1,11 @@ +# Copyright (c) OpenMMLab. All rights reserved. +cmake_minimum_required(VERSION 3.14) +project(mmdeploy_mmpose) + +include(${CMAKE_SOURCE_DIR}/cmake/opencv.cmake) +include(${CMAKE_SOURCE_DIR}/cmake/MMDeploy.cmake) + +file(GLOB_RECURSE SRCS ${CMAKE_CURRENT_SOURCE_DIR} "*.cpp") +mmdeploy_add_module(${PROJECT_NAME} "${SRCS}") +target_link_libraries(${PROJECT_NAME} PRIVATE mmdeploy_opencv_utils) +add_library(mmdeploy::mmpose ALIAS ${PROJECT_NAME}) diff --git a/csrc/codebase/mmpose/keypoints_from_heatmap.cpp b/csrc/codebase/mmpose/keypoints_from_heatmap.cpp new file mode 100644 index 0000000000..72c6a3cf07 --- /dev/null +++ b/csrc/codebase/mmpose/keypoints_from_heatmap.cpp @@ -0,0 +1,390 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include +#include +#include + +#include "core/device.h" +#include "core/registry.h" +#include "core/serialization.h" +#include "core/tensor.h" +#include "core/utils/device_utils.h" +#include "core/utils/formatter.h" +#include "core/value.h" +#include "experimental/module_adapter.h" +#include "mmpose.h" +#include "opencv_utils.h" + +namespace mmdeploy::mmpose { + +using std::string; +using std::vector; + +template +struct _LoopBody : public cv::ParallelLoopBody { + F f_; + _LoopBody(F f) : f_(std::move(f)) {} + void operator()(const cv::Range& range) const override { f_(range); } +}; + +std::string to_lower(const std::string& s) { + std::string t = s; + std::transform(t.begin(), t.end(), t.begin(), [](unsigned char c) { return std::tolower(c); }); + return t; +} + +class TopdownHeatmapBaseHeadDecode : public MMPose { + public: + explicit TopdownHeatmapBaseHeadDecode(const Value& config) : MMPose(config) { + if (config.contains("params")) { + auto& params = config["params"]; + flip_test_ = params.value("flip_test", flip_test_); + use_udp_ = params.value("use_udp", use_udp_); + target_type_ = params.value("target_type", target_type_); + valid_radius_factor_ = params.value("valid_radius_factor", valid_radius_factor_); + unbiased_decoding_ = params.value("unbiased_decoding", unbiased_decoding_); + post_process_ = params.value("post_process", post_process_); + shift_heatmap_ = params.value("shift_heatmap", shift_heatmap_); + modulate_kernel_ = params.value("modulate_kernel", modulate_kernel_); + } + } + + Result operator()(const Value& _data, const Value& _prob) { + MMDEPLOY_DEBUG("preprocess_result: {}", _data); + MMDEPLOY_DEBUG("inference_result: {}", _prob); + + Device cpu_device{"cpu"}; + OUTCOME_TRY(auto heatmap, + MakeAvailableOnDevice(_prob["output"].get(), cpu_device, stream())); + OUTCOME_TRY(stream().Wait()); + if (!(heatmap.shape().size() == 4 && heatmap.data_type() == DataType::kFLOAT)) { + MMDEPLOY_ERROR("unsupported `output` tensor, shape: {}, dtype: {}", heatmap.shape(), + (int)heatmap.data_type()); + return Status(eNotSupported); + } + + auto& img_metas = _data["img_metas"]; + + vector center; + vector scale; + from_value(img_metas["center"], center); + from_value(img_metas["scale"], scale); + Tensor pred = + keypoints_from_heatmap(heatmap, center, scale, unbiased_decoding_, post_process_, + modulate_kernel_, valid_radius_factor_, use_udp_, target_type_); + + return GetOutput(pred); + } + + Value GetOutput(Tensor& pred) { + PoseDetectorOutput output; + int K = pred.shape(1); + float* data = pred.data(); + for (int i = 0; i < K; i++) { + float x = *(data + 0); + float y = *(data + 1); + float s = *(data + 2); + output.key_points.push_back({{x, y}, s}); + data += 3; + } + return to_value(std::move(output)); + } + + Tensor keypoints_from_heatmap(const Tensor& _heatmap, const vector& center, + const vector& scale, bool unbiased_decoding, + const string& post_process, int modulate_kernel, + float valid_radius_factor, bool use_udp, + const string& target_type) { + Tensor heatmap(_heatmap.desc()); + heatmap.CopyFrom(_heatmap, stream()).value(); + stream().Wait().value(); + + int K = heatmap.shape(1); + int H = heatmap.shape(2); + int W = heatmap.shape(3); + + if (post_process == "megvii") { + heatmap = gaussian_blur(heatmap, modulate_kernel); + } + + Tensor pred; + + if (use_udp) { + if (to_lower(target_type) == to_lower(string("GaussianHeatMap"))) { + pred = get_max_pred(heatmap); + post_dark_udp(pred, heatmap, modulate_kernel); + } else if (to_lower(target_type) == to_lower(string("CombinedTarget"))) { + // output channel = 3 * channel_cfg['num_output_channels'] + assert(K % 3 == 0); + cv::parallel_for_(cv::Range(0, K), _LoopBody{[&](const cv::Range& r) { + for (int i = r.start; i < r.end; i++) { + int kt = (i % 3 == 0) ? 2 * modulate_kernel + 1 : modulate_kernel; + float* data = heatmap.data() + i * H * W; + cv::Mat work = cv::Mat(H, W, CV_32FC(1), data); + cv::GaussianBlur(work, work, {kt, kt}, 0); // inplace + } + }}); + float valid_radius = valid_radius_factor_ * H; + TensorDesc desc = {Device{"cpu"}, DataType::kFLOAT, {1, K / 3, H, W}}; + Tensor offset_x(desc); + Tensor offset_y(desc); + Tensor heatmap_(desc); + { + // split heatmap + float* src = heatmap.data(); + float* dst0 = heatmap_.data(); + float* dst1 = offset_x.data(); + float* dst2 = offset_y.data(); + for (int i = 0; i < K / 3; i++) { + std::copy_n(src, H * W, dst0); + std::transform(src + H * W, src + 2 * H * W, dst1, + [=](float& x) { return x * valid_radius; }); + std::transform(src + 2 * H * W, src + 3 * H * W, dst2, + [=](float& x) { return x * valid_radius; }); + src += 3 * H * W; + dst0 += H * W; + dst1 += H * W; + dst2 += H * W; + } + } + pred = get_max_pred(heatmap_); + for (int i = 0; i < K / 3; i++) { + float* data = pred.data() + i * 3; + int index = *(data + 0) + *(data + 1) * W + H * W * i; + float* offx = offset_x.data() + index; + float* offy = offset_y.data() + index; + *(data + 0) += *offx; + *(data + 1) += *offy; + } + } + } else { + pred = get_max_pred(heatmap); + if (post_process == "unbiased") { + heatmap = gaussian_blur(heatmap, modulate_kernel); + float* data = heatmap.data(); + std::for_each(data, data + K * H * W, [](float& v) { + double _v = std::max((double)v, 1e-10); + v = std::log(_v); + }); + for (int i = 0; i < K; i++) { + taylor(heatmap, pred, i); + } + + } else if (post_process != "null") { + for (int i = 0; i < K; i++) { + float* data = heatmap.data() + i * W * H; + auto _data = [&](int y, int x) { return *(data + y * W + x); }; + int px = *(pred.data() + i * 3 + 0); + int py = *(pred.data() + i * 3 + 1); + if (1 < px && px < W - 1 && 1 < py && py < H - 1) { + float v1 = _data(py, px + 1) - _data(py, px - 1); + float v2 = _data(py + 1, px) - _data(py - 1, px); + *(pred.data() + i * 3 + 0) += (v1 > 0) ? 0.25 : ((v1 < 0) ? -0.25 : 0); + *(pred.data() + i * 3 + 1) += (v2 > 0) ? 0.25 : ((v2 < 0) ? -0.25 : 0); + if (post_process_ == "megvii") { + *(pred.data() + i * 3 + 0) += 0.5; + *(pred.data() + i * 3 + 1) += 0.5; + } + } + } + } + } + + K = pred.shape(1); // changed if target_type is CombinedTarget + + // Transform back to the image + for (int i = 0; i < K; i++) { + transform_pred(pred, i, center, scale, {W, H}, use_udp); + } + + if (post_process_ == "megvii") { + for (int i = 0; i < K; i++) { + float* data = pred.data() + i * 3 + 2; + *data = *data / 255.0 + 0.5; + } + } + + return pred; + } + + void post_dark_udp(Tensor& pred, Tensor& heatmap, int kernel) { + int K = heatmap.shape(1); + int H = heatmap.shape(2); + int W = heatmap.shape(3); + cv::parallel_for_(cv::Range(0, K), _LoopBody{[&](const cv::Range& r) { + for (int i = r.start; i < r.end; i++) { + float* data = heatmap.data() + i * H * W; + cv::Mat work = cv::Mat(H, W, CV_32FC(1), data); + cv::GaussianBlur(work, work, {kernel, kernel}, 0); // inplace + } + }}); + std::for_each(heatmap.data(), heatmap.data() + K * H * W, [](float& x) { + x = std::max(0.001f, std::min(50.f, x)); + x = std::log(x); + }); + auto _heatmap_data = [&](int index, int c) -> float { + int y = index / (W + 2); + int x = index % (W + 2); + y = std::max(0, y - 1); + x = std::max(0, x - 1); + return *(heatmap.data() + c * H * W + y * W + x); + }; + for (int i = 0; i < K; i++) { + float* data = pred.data() + i * 3; + int index = *(data + 0) + 1 + (*(data + 1) + 1) * (W + 2); + float i_ = _heatmap_data(index, i); + float ix1 = _heatmap_data(index + 1, i); + float iy1 = _heatmap_data(index + W + 2, i); + float ix1y1 = _heatmap_data(index + W + 3, i); + float ix1_y1_ = _heatmap_data(index - W - 3, i); + float ix1_ = _heatmap_data(index - 1, i); + float iy1_ = _heatmap_data(index - 2 - W, i); + float dx = 0.5 * (ix1 - ix1_); + float dy = 0.5 * (iy1 - iy1_); + float dxx = ix1 - 2 * i_ + ix1_; + float dyy = iy1 - 2 * i_ + iy1_; + float dxy = 0.5 * (ix1y1 - ix1 - iy1 + i_ + i_ - ix1_ - iy1_ + ix1_y1_); + vector _data0 = {dx, dy}; + vector _data1 = {dxx, dxy, dxy, dyy}; + cv::Mat derivative = cv::Mat(2, 1, CV_32FC1, _data0.data()); + cv::Mat hessian = cv::Mat(2, 2, CV_32FC1, _data1.data()); + cv::Mat hessianinv = hessian.inv(); + cv::Mat offset = -hessianinv * derivative; + *(data + 0) += offset.at(0, 0); + *(data + 1) += offset.at(1, 0); + } + } + + void transform_pred(Tensor& pred, int k, const vector& center, const vector& _scale, + const vector& output_size, bool use_udp = false) { + auto scale = _scale; + scale[0] *= 200; + scale[1] *= 200; + + float scale_x, scale_y; + if (use_udp) { + scale_x = scale[0] / (output_size[0] - 1.0); + scale_y = scale[1] / (output_size[1] - 1.0); + } else { + scale_x = scale[0] / output_size[0]; + scale_y = scale[1] / output_size[1]; + } + + float* data = pred.data() + k * 3; + *(data + 0) = *(data + 0) * scale_x + center[0] - scale[0] * 0.5; + *(data + 1) = *(data + 1) * scale_y + center[1] - scale[1] * 0.5; + } + + void taylor(const Tensor& heatmap, Tensor& pred, int k) { + int K = heatmap.shape(1); + int H = heatmap.shape(2); + int W = heatmap.shape(3); + int px = *(pred.data() + k * 3 + 0); + int py = *(pred.data() + k * 3 + 1); + if (1 < px && px < W - 2 && 1 < py && py < H - 2) { + float* data = const_cast(heatmap.data() + k * H * W); + auto get_data = [&](int r, int c) { return *(data + r * W + c); }; + float dx = 0.5 * (get_data(py, px + 1) - get_data(py, px - 1)); + float dy = 0.5 * (get_data(py + 1, px) - get_data(py - 1, px)); + float dxx = 0.25 * (get_data(py, px + 2) - 2 * get_data(py, px) + get_data(py, px - 2)); + float dxy = 0.25 * (get_data(py + 1, px + 1) - get_data(py - 1, px + 1) - + get_data(py + 1, px - 1) + get_data(py - 1, px - 1)); + float dyy = 0.25 * (get_data(py + 2, px) - 2 * get_data(py, px) + get_data(py - 2, px)); + + vector _data0 = {dx, dy}; + vector _data1 = {dxx, dxy, dxy, dyy}; + cv::Mat derivative = cv::Mat(2, 1, CV_32FC1, _data0.data()); + cv::Mat hessian = cv::Mat(2, 2, CV_32FC1, _data1.data()); + if (std::fabs(dxx * dyy - dxy * dxy) > 1e-6) { + cv::Mat hessianinv = hessian.inv(); + cv::Mat offset = -hessianinv * derivative; + *(pred.data() + k * 3 + 0) += offset.at(0, 0); + *(pred.data() + k * 3 + 1) += offset.at(1, 0); + } + } + } + + Tensor gaussian_blur(const Tensor& _heatmap, int kernel) { + assert(kernel % 2 == 1); + + auto desc = _heatmap.desc(); + Tensor heatmap(desc); + + int K = _heatmap.shape(1); + int H = _heatmap.shape(2); + int W = _heatmap.shape(3); + int num_points = H * W; + + int border = (kernel - 1) / 2; + + for (int i = 0; i < K; i++) { + int offset = i * H * W; + float* data = const_cast(_heatmap.data()) + offset; + float origin_max = *std::max_element(data, data + num_points); + cv::Mat work = cv::Mat(H + 2 * border, W + 2 * border, CV_32FC1, cv::Scalar{}); + cv::Mat curr = cv::Mat(H, W, CV_32FC1, data); + cv::Rect roi = {border, border, W, H}; + curr.copyTo(work(roi)); + cv::GaussianBlur(work, work, {kernel, kernel}, 0); + cv::Mat valid = work(roi).clone(); + float cur_max = *std::max_element((float*)valid.data, (float*)valid.data + num_points); + float* dst = heatmap.data() + offset; + std::transform((float*)valid.data, (float*)valid.data + num_points, dst, + [&](float v) { return v * origin_max / cur_max; }); + } + return heatmap; + } + + Tensor get_max_pred(const Tensor& heatmap) { + int K = heatmap.shape(1); + int H = heatmap.shape(2); + int W = heatmap.shape(3); + int num_points = H * W; + TensorDesc pred_desc = {Device{"cpu"}, DataType::kFLOAT, {1, K, 3}}; + Tensor pred(pred_desc); + + cv::parallel_for_(cv::Range(0, K), _LoopBody{[&](const cv::Range& r) { + for (int i = r.start; i < r.end; i++) { + float* src_data = const_cast(heatmap.data()) + i * H * W; + cv::Mat mat = cv::Mat(H, W, CV_32FC1, src_data); + double min_val, max_val; + cv::Point min_loc, max_loc; + cv::minMaxLoc(mat, &min_val, &max_val, &min_loc, &max_loc); + float* dst_data = pred.data() + i * 3; + *(dst_data + 0) = -1; + *(dst_data + 1) = -1; + *(dst_data + 2) = max_val; + if (max_val > 0.0) { + *(dst_data + 0) = max_loc.x; + *(dst_data + 1) = max_loc.y; + } + } + }}); + + return pred; + } + + private: + bool flip_test_{true}; + bool shift_heatmap_{true}; + string post_process_ = {"default"}; + int modulate_kernel_{11}; + bool unbiased_decoding_{false}; + float valid_radius_factor_{0.0546875f}; + bool use_udp_{false}; + string target_type_{"GaussianHeatmap"}; +}; + +REGISTER_CODEBASE_COMPONENT(MMPose, TopdownHeatmapBaseHeadDecode); + +// decode process is same +using TopdownHeatmapSimpleHeadDecode = TopdownHeatmapBaseHeadDecode; +REGISTER_CODEBASE_COMPONENT(MMPose, TopdownHeatmapSimpleHeadDecode); +using TopdownHeatmapMultiStageHeadDecode = TopdownHeatmapBaseHeadDecode; +REGISTER_CODEBASE_COMPONENT(MMPose, TopdownHeatmapMultiStageHeadDecode); +using ViPNASHeatmapSimpleHeadDecode = TopdownHeatmapBaseHeadDecode; +REGISTER_CODEBASE_COMPONENT(MMPose, ViPNASHeatmapSimpleHeadDecode); +using TopdownHeatmapMSMUHeadDecode = TopdownHeatmapBaseHeadDecode; +REGISTER_CODEBASE_COMPONENT(MMPose, TopdownHeatmapMSMUHeadDecode); + +} // namespace mmdeploy::mmpose diff --git a/csrc/codebase/mmpose/keypoints_from_regression.cpp b/csrc/codebase/mmpose/keypoints_from_regression.cpp new file mode 100644 index 0000000000..a484b670e8 --- /dev/null +++ b/csrc/codebase/mmpose/keypoints_from_regression.cpp @@ -0,0 +1,115 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include +#include + +#include "core/device.h" +#include "core/registry.h" +#include "core/serialization.h" +#include "core/tensor.h" +#include "core/utils/device_utils.h" +#include "core/utils/formatter.h" +#include "core/value.h" +#include "experimental/module_adapter.h" +#include "mmpose.h" +#include "opencv_utils.h" + +namespace mmdeploy::mmpose { + +using std::string; +using std::vector; + +class DeepposeRegressionHeadDecode : public MMPose { + public: + explicit DeepposeRegressionHeadDecode(const Value& config) : MMPose(config) {} + + Result operator()(const Value& _data, const Value& _prob) { + MMDEPLOY_DEBUG("preprocess_result: {}", _data); + MMDEPLOY_DEBUG("inference_result: {}", _prob); + + Device cpu_device{"cpu"}; + OUTCOME_TRY(auto output, + MakeAvailableOnDevice(_prob["output"].get(), cpu_device, stream())); + OUTCOME_TRY(stream().Wait()); + if (!(output.shape().size() == 3 && output.data_type() == DataType::kFLOAT)) { + MMDEPLOY_ERROR("unsupported `output` tensor, shape: {}, dtype: {}", output.shape(), + (int)output.data_type()); + return Status(eNotSupported); + } + + auto& img_metas = _data["img_metas"]; + + vector center; + vector scale; + from_value(img_metas["center"], center); + from_value(img_metas["scale"], scale); + vector img_size = {img_metas["img_shape"][2].get(), + img_metas["img_shape"][1].get()}; + + Tensor pred = keypoints_from_regression(output, center, scale, img_size); + + return GetOutput(pred); + } + + Value GetOutput(Tensor& pred) { + PoseDetectorOutput output; + int K = pred.shape(1); + float* data = pred.data(); + for (int i = 0; i < K; i++) { + float x = *(data + 0); + float y = *(data + 1); + float s = *(data + 2); + output.key_points.push_back({{x, y}, s}); + data += 3; + } + return to_value(std::move(output)); + } + + Tensor keypoints_from_regression(const Tensor& output, const vector& center, + const vector& scale, const vector& img_size) { + int K = output.shape(1); + TensorDesc pred_desc = {Device{"cpu"}, DataType::kFLOAT, {1, K, 3}}; + Tensor pred(pred_desc); + + float* src = const_cast(output.data()); + float* dst = pred.data(); + for (int i = 0; i < K; i++) { + *(dst + 0) = *(src + 0) * img_size[0]; + *(dst + 1) = *(src + 1) * img_size[1]; + *(dst + 2) = 1.f; + src += 2; + dst += 3; + } + + // Transform back to the image + for (int i = 0; i < K; i++) { + transform_pred(pred, i, center, scale, img_size, false); + } + + return pred; + } + + void transform_pred(Tensor& pred, int k, const vector& center, const vector& _scale, + const vector& output_size, bool use_udp = false) { + auto scale = _scale; + scale[0] *= 200; + scale[1] *= 200; + + float scale_x, scale_y; + if (use_udp) { + scale_x = scale[0] / (output_size[0] - 1.0); + scale_y = scale[1] / (output_size[1] - 1.0); + } else { + scale_x = scale[0] / output_size[0]; + scale_y = scale[1] / output_size[1]; + } + + float* data = pred.data() + k * 3; + *(data + 0) = *(data + 0) * scale_x + center[0] - scale[0] * 0.5; + *(data + 1) = *(data + 1) * scale_y + center[1] - scale[1] * 0.5; + } +}; + +REGISTER_CODEBASE_COMPONENT(MMPose, DeepposeRegressionHeadDecode); + +} // namespace mmdeploy::mmpose diff --git a/csrc/codebase/mmpose/mmpose.cpp b/csrc/codebase/mmpose/mmpose.cpp new file mode 100644 index 0000000000..7d5e048b11 --- /dev/null +++ b/csrc/codebase/mmpose/mmpose.cpp @@ -0,0 +1,15 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include "codebase/mmpose/mmpose.h" + +using namespace std; + +namespace mmdeploy { +namespace mmpose { + +REGISTER_CODEBASE(MMPose); + +} + +MMDEPLOY_DEFINE_REGISTRY(mmpose::MMPose); +} // namespace mmdeploy diff --git a/csrc/codebase/mmpose/mmpose.h b/csrc/codebase/mmpose/mmpose.h new file mode 100644 index 0000000000..ed66f53a8e --- /dev/null +++ b/csrc/codebase/mmpose/mmpose.h @@ -0,0 +1,30 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#ifndef MMDEPLOY_MMPOSE_H +#define MMDEPLOY_MMPOSE_H + +#include "codebase/common.h" +#include "core/device.h" +#include "core/module.h" + +namespace mmdeploy { +namespace mmpose { + +struct PoseDetectorOutput { + struct KeyPoint { + std::array bbox; // x, y + float score; + MMDEPLOY_ARCHIVE_MEMBERS(bbox, score); + }; + std::vector key_points; + MMDEPLOY_ARCHIVE_MEMBERS(key_points); +}; + +DECLARE_CODEBASE(MMPose, mmpose); + +} // namespace mmpose + +MMDEPLOY_DECLARE_REGISTRY(mmpose::MMPose); +} // namespace mmdeploy + +#endif // MMDEPLOY_MMPOSE_H diff --git a/csrc/codebase/mmpose/topdown_affine.cpp b/csrc/codebase/mmpose/topdown_affine.cpp new file mode 100644 index 0000000000..e3effd0e21 --- /dev/null +++ b/csrc/codebase/mmpose/topdown_affine.cpp @@ -0,0 +1,191 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include + +#include "archive/json_archive.h" +#include "archive/value_archive.h" +#include "core/registry.h" +#include "core/tensor.h" +#include "core/utils/device_utils.h" +#include "core/utils/formatter.h" +#include "opencv2/imgproc.hpp" +#include "opencv_utils.h" +#include "preprocess/transform/resize.h" +#include "preprocess/transform/transform.h" + +using namespace std; + +namespace mmdeploy { + +cv::Point2f operator*(cv::Point2f a, cv::Point2f b) { + cv::Point2f c; + c.x = a.x * b.x; + c.y = a.y * b.y; + return c; +} + +class TopDownAffineImpl : public Module { + public: + explicit TopDownAffineImpl(const Value& args) noexcept { + use_udp_ = args.value("use_udp", use_udp_); + backend_ = args.contains("backend") && args["backend"].is_string() + ? args["backend"].get() + : backend_; + stream_ = args["context"]["stream"].get(); + assert(args.contains("image_size")); + from_value(args["image_size"], image_size_); + } + + ~TopDownAffineImpl() override = default; + + Result Process(const Value& input) override { + MMDEPLOY_DEBUG("top_down_affine input: {}", input); + + Device host{"cpu"}; + auto _img = input["img"].get(); + OUTCOME_TRY(auto img, MakeAvailableOnDevice(_img, host, stream_)); + stream_.Wait().value(); + auto src = cpu::Tensor2CVMat(img); + + // prepare data + vector box; + from_value(input["box"], box); + vector c; // center + vector s; // scale + Box2cs(box, c, s); + auto r = input["rotation"].get(); + + cv::Mat dst; + if (use_udp_) { + cv::Mat trans = + GetWarpMatrix(r, {c[0] * 2.f, c[1] * 2.f}, {image_size_[0] - 1.f, image_size_[1] - 1.f}, + {s[0] * 200.f, s[1] * 200.f}); + + cv::warpAffine(src, dst, trans, {image_size_[0], image_size_[1]}, cv::INTER_LINEAR); + } else { + cv::Mat trans = + GetAffineTransform({c[0], c[1]}, {s[0], s[1]}, r, {image_size_[0], image_size_[1]}); + cv::warpAffine(src, dst, trans, {image_size_[0], image_size_[1]}, cv::INTER_LINEAR); + } + + Value output = input; + output["img"] = cpu::CVMat2Tensor(dst); + output["img_shape"] = {1, image_size_[1], image_size_[0], dst.channels()}; + output["center"] = to_value(c); + output["scale"] = to_value(s); + MMDEPLOY_DEBUG("output: {}", to_json(output).dump(2)); + return output; + } + + void Box2cs(vector& box, vector& center, vector& scale) { + float x = box[0]; + float y = box[1]; + float w = box[2]; + float h = box[3]; + float aspect_ratio = image_size_[0] * 1.0 / image_size_[1]; + center.push_back(x + w * 0.5); + center.push_back(y + h * 0.5); + if (w > aspect_ratio * h) { + h = w * 1.0 / aspect_ratio; + } else if (w < aspect_ratio * h) { + w = h * aspect_ratio; + } + scale.push_back(w / 200 * 1.25); + scale.push_back(h / 200 * 1.25); + } + + cv::Mat GetWarpMatrix(float theta, cv::Size2f size_input, cv::Size2f size_dst, + cv::Size2f size_target) { + theta = theta * 3.1415926 / 180; + float scale_x = size_dst.width / size_target.width; + float scale_y = size_dst.height / size_target.height; + cv::Mat matrix = cv::Mat(2, 3, CV_32FC1); + matrix.at(0, 0) = std::cos(theta) * scale_x; + matrix.at(0, 1) = -std::sin(theta) * scale_x; + matrix.at(0, 2) = + scale_x * (-0.5f * size_input.width * std::cos(theta) + + 0.5f * size_input.height * std::sin(theta) + 0.5f * size_target.width); + matrix.at(1, 0) = std::sin(theta) * scale_y; + matrix.at(1, 1) = std::cos(theta) * scale_y; + matrix.at(1, 2) = + scale_y * (-0.5f * size_input.width * std::sin(theta) - + 0.5f * size_input.height * std::cos(theta) + 0.5f * size_target.height); + return matrix; + } + + cv::Mat GetAffineTransform(cv::Point2f center, cv::Point2f scale, float rot, cv::Size output_size, + cv::Point2f shift = {0.f, 0.f}, bool inv = false) { + cv::Point2f scale_tmp = scale * 200; + float src_w = scale_tmp.x; + int dst_w = output_size.width; + int dst_h = output_size.height; + float rot_rad = 3.1415926 * rot / 180; + cv::Point2f src_dir = rotate_point({0.f, src_w * -0.5f}, rot_rad); + cv::Point2f dst_dir = {0.f, dst_w * -0.5f}; + + cv::Point2f src_points[3]; + src_points[0] = center + scale_tmp * shift; + src_points[1] = center + src_dir + scale_tmp * shift; + src_points[2] = Get3rdPoint(src_points[0], src_points[1]); + + cv::Point2f dst_points[3]; + dst_points[0] = {dst_w * 0.5f, dst_h * 0.5f}; + dst_points[1] = dst_dir + cv::Point2f(dst_w * 0.5f, dst_h * 0.5f); + dst_points[2] = Get3rdPoint(dst_points[0], dst_points[1]); + + cv::Mat trans = inv ? cv::getAffineTransform(dst_points, src_points) + : cv::getAffineTransform(src_points, dst_points); + return trans; + } + + cv::Point2f rotate_point(cv::Point2f pt, float angle_rad) { + float sn = std::sin(angle_rad); + float cs = std::cos(angle_rad); + float new_x = pt.x * cs - pt.y * sn; + float new_y = pt.x * sn + pt.y * cs; + return {new_x, new_y}; + } + + cv::Point2f Get3rdPoint(cv::Point2f a, cv::Point2f b) { + cv::Point2f direction = a - b; + cv::Point2f third_pt = b + cv::Point2f(-direction.y, direction.x); + return third_pt; + } + + protected: + bool use_udp_{false}; + vector image_size_; + std::string backend_; + Stream stream_; +}; + +class TopDownAffineImplCreator : public Creator { + public: + const char* GetName() const override { return "cpu"; } + int GetVersion() const override { return 1; } + ReturnType Create(const Value& args) override { + return std::make_unique(args); + } +}; + +MMDEPLOY_DEFINE_REGISTRY(TopDownAffineImpl); + +REGISTER_MODULE(TopDownAffineImpl, TopDownAffineImplCreator); + +class TopDownAffine : public Transform { + public: + explicit TopDownAffine(const Value& args) : Transform(args) { + impl_ = Instantiate("TopDownAffine", args); + } + ~TopDownAffine() override = default; + + Result Process(const Value& input) override { return impl_->Process(input); } + + private: + std::unique_ptr impl_; + static const std::string name_; +}; + +DECLARE_AND_REGISTER_MODULE(Transform, TopDownAffine, 1); + +} // namespace mmdeploy diff --git a/demo/csrc/CMakeLists.txt b/demo/csrc/CMakeLists.txt index 3e1bdcc6fb..71d49f3199 100644 --- a/demo/csrc/CMakeLists.txt +++ b/demo/csrc/CMakeLists.txt @@ -20,4 +20,5 @@ add_example(image_classification) add_example(object_detection) add_example(image_restorer) add_example(image_segmentation) +add_example(pose_detection) add_example(ocr) diff --git a/demo/csrc/pose_detection.cpp b/demo/csrc/pose_detection.cpp new file mode 100644 index 0000000000..14fa9c7391 --- /dev/null +++ b/demo/csrc/pose_detection.cpp @@ -0,0 +1,50 @@ +#include +#include +#include +#include +#include + +#include "pose_detector.h" + +int main(int argc, char *argv[]) { + if (argc != 4) { + fprintf(stderr, "usage:\n pose_detection device_name model_path image_path\n"); + return 1; + } + auto device_name = argv[1]; + auto model_path = argv[2]; + auto image_path = argv[3]; + cv::Mat img = cv::imread(image_path); + if (!img.data) { + fprintf(stderr, "failed to load image: %s\n", image_path); + return 1; + } + + mm_handle_t pose_estimator{}; + int status{}; + status = mmdeploy_pose_detector_create_by_path(model_path, device_name, 0, &pose_estimator); + if (status != MM_SUCCESS) { + fprintf(stderr, "failed to create pose_estimator, code: %d\n", (int)status); + return 1; + } + + mm_mat_t mat{img.data, img.rows, img.cols, 3, MM_BGR, MM_INT8}; + + mm_pose_detect_t *res{}; + int *res_count{}; + status = mmdeploy_pose_detector_apply(pose_estimator, &mat, 1, &res, &res_count); + if (status != MM_SUCCESS) { + fprintf(stderr, "failed to apply pose estimator, code: %d\n", (int)status); + return 1; + } + + for (int i = 0; i < res->length; i++) { + cv::circle(img, {(int)res->point[i].x, (int)res->point[i].y}, 1, {0, 255, 0}, 2); + } + cv::imwrite("output_pose.png", img); + + mmdeploy_pose_detector_release_result(res, 1); + mmdeploy_pose_detector_destroy(pose_estimator); + + return 0; +} diff --git a/mmdeploy/codebase/mmpose/deploy/pose_detection.py b/mmdeploy/codebase/mmpose/deploy/pose_detection.py index 4dee279a8c..0405523400 100644 --- a/mmdeploy/codebase/mmpose/deploy/pose_detection.py +++ b/mmdeploy/codebase/mmpose/deploy/pose_detection.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. +import copy import logging import os from typing import Any, Dict, Optional, Sequence, Tuple, Union @@ -12,7 +13,60 @@ from mmdeploy.codebase.base import BaseTask from mmdeploy.codebase.mmpose.deploy.mmpose import MMPOSE_TASK -from mmdeploy.utils import Task +from mmdeploy.utils import Task, get_input_shape + + +def process_model_config( + model_cfg: mmcv.Config, + imgs: Union[Sequence[str], Sequence[np.ndarray]], + input_shape: Optional[Sequence[int]] = None, +): + """Process the model config. + + Args: + model_cfg (mmcv.Config): The model config. + imgs (Sequence[str] | Sequence[np.ndarray]): Input image(s), accepted + data type are List[str], List[np.ndarray]. + input_shape (list[int]): A list of two integer in (width, height) + format specifying input shape. Default: None. + + Returns: + mmcv.Config: the model config after processing. + """ + cfg = copy.deepcopy(model_cfg) + test_pipeline = cfg.data.test.pipeline + sdk_pipeline = [] + color_type = 'color' + channel_order = 'rgb' + + idx = 0 + while idx < len(test_pipeline): + trans = test_pipeline[idx] + if trans.type == 'ToTensor': + assert idx + 1 < len(test_pipeline) and \ + test_pipeline[idx + 1].type == 'NormalizeTensor' + trans = test_pipeline[idx + 1] + trans.type = 'Normalize' + trans['to_rgb'] = (channel_order == 'rgb') + trans['mean'] = [x * 255 for x in trans['mean']] + trans['std'] = [x * 255 for x in trans['std']] + sdk_pipeline.append(trans) + sdk_pipeline.append({'type': 'ImageToTensor', 'keys': ['img']}) + idx = idx + 2 + continue + + if trans.type == 'LoadImageFromFile': + if 'color_type' in trans: + color_type = trans['color_type'] # NOQA + if 'channel_order' in trans: + channel_order = trans['channel_order'] + if trans.type == 'TopDownAffine': + trans['image_size'] = input_shape + + sdk_pipeline.append(trans) + idx = idx + 1 + cfg.data.test.pipeline = sdk_pipeline + return cfg @MMPOSE_TASK.register_module(Task.POSE_DETECTION.value) @@ -130,7 +184,7 @@ def create_input(self, 'rotation': 0, 'ann_info': { - 'image_size': image_size, + 'image_size': np.array(image_size), 'num_joints': cfg.data_cfg['num_joints'], 'flip_pairs': flip_pairs } @@ -257,12 +311,24 @@ def get_partition_cfg(partition_type: str, **kwargs) -> Dict: raise NotImplementedError('Not supported yet.') def get_preprocess(self) -> Dict: - """Get the preprocess information for SDK.""" - raise NotImplementedError('Not supported yet.') + """Get the preprocess information for SDK. + + Return: + dict: Composed of the preprocess information. + """ + input_shape = get_input_shape(self.deploy_cfg) + model_cfg = process_model_config(self.model_cfg, [''], input_shape) + preprocess = model_cfg.data.test.pipeline + return preprocess def get_postprocess(self) -> Dict: """Get the postprocess information for SDK.""" - raise NotImplementedError('Not supported yet.') + postprocess = {'type': 'UNKNOWN'} + if self.model_cfg.model.type == 'TopDown': + postprocess[ + 'type'] = self.model_cfg.model.keypoint_head.type + 'Decode' + postprocess.update(self.model_cfg.model.test_cfg) + return postprocess @staticmethod def get_tensor_from_input(input_data: Dict[str, Any], diff --git a/mmdeploy/codebase/mmpose/deploy/pose_detection_model.py b/mmdeploy/codebase/mmpose/deploy/pose_detection_model.py index e54a2f9494..1844c5cc10 100644 --- a/mmdeploy/codebase/mmpose/deploy/pose_detection_model.py +++ b/mmdeploy/codebase/mmpose/deploy/pose_detection_model.py @@ -4,11 +4,22 @@ import mmcv import numpy as np import torch +from mmcv.utils import Registry from mmdeploy.codebase.base import BaseBackendModel -from mmdeploy.utils import Backend, get_backend, load_config +from mmdeploy.utils import (Backend, get_backend, get_codebase_config, + load_config) +def __build_backend_model(cls_name: str, registry: Registry, *args, **kwargs): + return registry.module_dict[cls_name](*args, **kwargs) + + +__BACKEND_MODEL = mmcv.utils.Registry( + 'backend_pose_detectors', build_func=__build_backend_model) + + +@__BACKEND_MODEL.register_module('end2end') class End2EndModel(BaseBackendModel): """End to end model for inference of pose detection. @@ -31,15 +42,14 @@ def __init__(self, model_cfg: Union[str, mmcv.Config] = None, **kwargs): super(End2EndModel, self).__init__(deploy_cfg=deploy_cfg) - from mmpose.models.heads.topdown_heatmap_base_head import \ - TopdownHeatmapBaseHead + from mmpose.models import builder self.deploy_cfg = deploy_cfg self.model_cfg = model_cfg self._init_wrapper( backend=backend, backend_files=backend_files, device=device) # create base_head for decoding heatmap - base_head = TopdownHeatmapBaseHead() + base_head = builder.build_head(model_cfg.model.keypoint_head) base_head.test_cfg = model_cfg.model.test_cfg self.base_head = base_head @@ -57,7 +67,9 @@ def _init_wrapper(self, backend, backend_files, device): backend=backend, backend_files=backend_files, device=device, - output_names=output_names) + input_names=[self.input_name], + output_names=output_names, + deploy_cfg=self.deploy_cfg) def forward(self, img: torch.Tensor, img_metas: Sequence[Sequence[dict]], *args, **kwargs): @@ -73,10 +85,12 @@ def forward(self, img: torch.Tensor, img_metas: Sequence[Sequence[dict]], Returns: list: A list contains predictions. """ + batch_size, _, img_height, img_width = img.shape input_img = img.contiguous() outputs = self.forward_test(input_img, img_metas, *args, **kwargs) heatmaps = outputs[0] - key_points = self.base_head.decode(img_metas, heatmaps) + key_points = self.base_head.decode( + img_metas, heatmaps, img_size=[img_width, img_height]) return key_points def forward_test(self, imgs: torch.Tensor, *args, **kwargs) -> \ @@ -136,6 +150,80 @@ def show_result(self, win_name=win_name) +@__BACKEND_MODEL.register_module('sdk') +class SDKEnd2EndModel(End2EndModel): + """SDK inference class, converts SDK output to mmcls format.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def _cs2xyxy(self, + _center: np.ndarray, + _scale: np.ndarray, + padding: float = 1.25): + """This encodes (center, scale) to fake bbox(x,y,x,y) The dataloader in + mmpose convert the bbox of image to (center, scale) and use these + information in the pre/post process of model. Some setting of + dataloader will not collect bbox key. While in practice, we receive + image and bbox as input. Therefore this method try to convert the + (center, scale) back to bbox. It can not restore the real box with just + (center, scale) information, but sdk can handle the fake bbox normally. + + Args: + _center: (np.ndarray[float32](2,)) Center of the bbox (x, y) + _scale: (np.ndarray[float32](2,)) Scale of the bbox w & h + + Returns: + - np.ndarray[float32](4,): fake box if keypoint, the process in + topdown_affine will calculate original center, scale. + """ + scale = _scale.copy() + scale = scale / padding * 200 + center = _center.copy() + # fake box + box = np.array([center - 0.5 * scale, + center + 0.5 * scale - 1]).flatten() + return box + + def forward(self, img: List[torch.Tensor], *args, **kwargs) -> list: + """Run forward inference. + + Args: + img (List[torch.Tensor]): A list contains input image(s) + in [N x C x H x W] format. + *args: Other arguments. + **kwargs: Other key-pair arguments. + + Returns: + list: A list contains predictions. + """ + image_paths = [] + boxes = np.zeros(shape=(img.shape[0], 6)) + bbox_ids = [] + sdk_boxes = [] + for i, img_meta in enumerate(kwargs['img_metas']): + center = img_meta['center'] + scale = img_meta['scale'] + boxes[i, :2] = center + boxes[i, 2:4] = scale + boxes[i, 4] = np.prod(scale * 200.0) + boxes[i, 5] = img_meta[ + 'bbox_score'] if 'bbox_score' in img_meta else 1.0 + sdk_boxes.append(self._cs2xyxy(center, scale)) + image_paths.append(img_meta['image_file']) + bbox_ids.append(img_meta['bbox_id']) + + pred = self.wrapper.handle( + [img[0].contiguous().detach().cpu().numpy()], sdk_boxes)[0] + + result = dict( + preds=pred, + boxes=boxes, + image_paths=image_paths, + bbox_ids=bbox_ids) + return result + + def build_pose_detection_model(model_files: Sequence[str], model_cfg: Union[str, mmcv.Config], deploy_cfg: Union[str, mmcv.Config], @@ -157,12 +245,14 @@ def build_pose_detection_model(model_files: Sequence[str], deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg) backend = get_backend(deploy_cfg) - backend_pose_model = End2EndModel( - backend, - model_files, - device, - deploy_cfg=deploy_cfg, + model_type = get_codebase_config(deploy_cfg).get('model_type', 'end2end') + + backend_pose_model = __BACKEND_MODEL.build( + model_type, + backend=backend, + backend_files=model_files, + device=device, model_cfg=model_cfg, - **kwargs) + deploy_cfg=deploy_cfg) return backend_pose_model diff --git a/mmdeploy/codebase/mmpose/models/heads/__init__.py b/mmdeploy/codebase/mmpose/models/heads/__init__.py index f462d37c75..f48ba2e343 100644 --- a/mmdeploy/codebase/mmpose/models/heads/__init__.py +++ b/mmdeploy/codebase/mmpose/models/heads/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .deeppose_regression_head import deeppose_regression_head__inference_model from .topdown_heatmap_multi_stage_head import \ topdown_heatmap_msmu_head__inference_model from .topdown_heatmap_simple_head import \ @@ -6,5 +7,6 @@ __all__ = [ 'topdown_heatmap_simple_head__inference_model', - 'topdown_heatmap_msmu_head__inference_model' + 'topdown_heatmap_msmu_head__inference_model', + 'deeppose_regression_head__inference_model' ] diff --git a/mmdeploy/codebase/mmpose/models/heads/deeppose_regression_head.py b/mmdeploy/codebase/mmpose/models/heads/deeppose_regression_head.py new file mode 100644 index 0000000000..c484fa05da --- /dev/null +++ b/mmdeploy/codebase/mmpose/models/heads/deeppose_regression_head.py @@ -0,0 +1,24 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from mmdeploy.core import FUNCTION_REWRITER + + +@FUNCTION_REWRITER.register_rewriter( + 'mmpose.models.heads.DeepposeRegressionHead.inference_model') +def deeppose_regression_head__inference_model(ctx, self, x, flip_pairs=None): + """Rewrite `forward_test` of TopDown for default backend. + + Rewrite this function to run forward directly. And we don't need to + transform result to np.ndarray. + + Args: + x (torch.Tensor[N,K,H,W]): Input features. + flip_pairs (None | list[tuple]): + Pairs of keypoints which are mirrored. + + Returns: + output_heatmap (torch.Tensor): Output heatmaps. + """ + assert flip_pairs is None + output = self.forward(x) + return output diff --git a/mmdeploy/codebase/mmpose/models/heads/topdown_heatmap_multi_stage_head.py b/mmdeploy/codebase/mmpose/models/heads/topdown_heatmap_multi_stage_head.py index 5bb1014a43..b00d4af460 100644 --- a/mmdeploy/codebase/mmpose/models/heads/topdown_heatmap_multi_stage_head.py +++ b/mmdeploy/codebase/mmpose/models/heads/topdown_heatmap_multi_stage_head.py @@ -24,3 +24,29 @@ def topdown_heatmap_msmu_head__inference_model(ctx, self, x, flip_pairs=None): assert isinstance(output, list) output = output[-1] return output + + +@FUNCTION_REWRITER.register_rewriter( + 'mmpose.models.heads.TopdownHeatmapMultiStageHead.inference_model') +def topdown_heatmap_multi_stage_head__inference_model(ctx, + self, + x, + flip_pairs=None): + """Rewrite ``inference_model`` for default backend. + + Rewrite this function to run forward directly. And we don't need to + transform result to np.ndarray. + + Args: + x (list[torch.Tensor[N,K,H,W]]): Input features. + flip_pairs (None | list[tuple]): + Pairs of keypoints which are mirrored. + + Returns: + output_heatmap (torch.Tensor): Output heatmaps. + """ + assert flip_pairs is None + output = self.forward(x) + assert isinstance(output, list) + output = output[-1] + return output diff --git a/mmdeploy/utils/constants.py b/mmdeploy/utils/constants.py index 094f7b7ab5..bddd09de85 100644 --- a/mmdeploy/utils/constants.py +++ b/mmdeploy/utils/constants.py @@ -73,5 +73,7 @@ class Backend(AdvancedEnum): Task.TEXT_DETECTION: dict(component='TextDetHead', cls_name='TextDetector'), Task.TEXT_RECOGNITION: - dict(component='CTCConvertor', cls_name='TextRecognizer') + dict(component='CTCConvertor', cls_name='TextRecognizer'), + Task.POSE_DETECTION: + dict(component='Detector', cls_name='PoseDetector') }