Skip to content

Commit

Permalink
feat: allow returning images in json in base64 format
Browse files Browse the repository at this point in the history
  • Loading branch information
Bycob committed Aug 4, 2023
1 parent 28060f9 commit 05096fd
Show file tree
Hide file tree
Showing 7 changed files with 165 additions and 9 deletions.
2 changes: 2 additions & 0 deletions src/dto/ddtypes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ namespace dd

const oatpp::ClassId GpuIdsClass::CLASS_ID("GpuIds");

const oatpp::ClassId ImageClass::CLASS_ID("Image");

template <>
const oatpp::ClassId DTOVectorClass<double>::CLASS_ID("vector<double>");

Expand Down
85 changes: 85 additions & 0 deletions src/dto/ddtypes.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,12 @@
#ifndef DD_DTO_TYPES_HPP
#define DD_DTO_TYPES_HPP

#include <opencv2/opencv.hpp>

#include "oatpp/core/Types.hpp"
#include "oatpp/parser/json/mapping/ObjectMapper.hpp"
#include "apidata.h"
#include "utils/cv_utils.hpp"

namespace dd
{
Expand All @@ -46,10 +49,51 @@ namespace dd
}
};

struct VImage
{
cv::Mat _img;
#ifdef USE_CUDA_CV
cv::cuda::GpuMat _cuda_img;
#endif
std::string _ext = ".png";

VImage(const cv::Mat &img, const std::string &ext = ".png")
: _img(img), _ext(ext)
{
}
#ifdef USE_CUDA_CV
VImage(const cv::cuda::GpuMat &cuda_img, const std::string &ext = ".png")
: _cuda_img(cuda_img), _ext(ext)
{
}
#endif
bool is_cuda() const
{
#ifdef USE_CUDA_CV
return !_cuda_img.empty();
#else
return false;
#endif
}

/** get image on CPU whether it's on GPU or not */
const cv::Mat &get_img()
{
#ifdef USE_CUDA_CV
if (is_cuda())
{
_cuda_img.download(_img);
}
#endif
return _img;
}
};

namespace __class
{
class APIDataClass;
class GpuIdsClass;
class ImageClass;
template <typename T> class DTOVectorClass;
}

Expand All @@ -59,6 +103,8 @@ namespace dd
typedef oatpp::data::mapping::type::Primitive<VGpuIds,
__class::GpuIdsClass>
GpuIds;
typedef oatpp::data::mapping::type::Primitive<VImage, __class::ImageClass>
DTOImage;
template <typename T>
using DTOVector
= oatpp::data::mapping::type::Primitive<std::vector<T>,
Expand Down Expand Up @@ -89,6 +135,18 @@ namespace dd
}
};

class ImageClass
{
public:
static const oatpp::ClassId CLASS_ID;

static oatpp::Type *getType()
{
static oatpp::Type type(CLASS_ID);
return &type;
}
};

template <typename T> class DTOVectorClass
{
public:
Expand All @@ -113,6 +171,9 @@ namespace dd
{
(void)type;
(void)deserializer;
// XXX: this has a failure case if the stream contains excaped "{" or "}"
// Since this is a temporary workaround until we use DTO everywhere, it
// might not be required to be fixed
if (caret.isAtChar('{'))
{
auto start = caret.getCurrData();
Expand Down Expand Up @@ -221,6 +282,30 @@ namespace dd
}
}

static inline oatpp::Void
imageDeserialize(oatpp::parser::json::mapping::Deserializer *deserializer,
oatpp::parser::Caret &caret,
const oatpp::Type *const type)
{
(void)type;
auto str_base64
= deserializer->deserialize(caret, oatpp::String::Class::getType())
.cast<oatpp::String>();
return DTOImage(VImage{ cv_utils::base64_to_image(*str_base64) });
}

static inline void
imageSerialize(oatpp::parser::json::mapping::Serializer *serializer,
oatpp::data::stream::ConsistentOutputStream *stream,
const oatpp::Void &obj)
{
(void)serializer;
auto img_dto = obj.cast<DTOImage>();
std::string encoded
= cv_utils::image_to_base64(img_dto->get_img(), img_dto->_ext);
stream->writeSimple(encoded);
}

// Inspired by oatpp json deserializer
template <typename T> inline T readVecElement(oatpp::parser::Caret &caret);

Expand Down
20 changes: 17 additions & 3 deletions src/dto/predict_out.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,24 @@ namespace dd

DTO_FIELD_INFO(vals)
{
info->description = "[Unsupervised] Array containing model output "
"values. Can be in different formats: double, "
"binarized double, booleans, binarized string";
info->description
= "[Unsupervised] Array containing model output "
"values. Can be in different formats: double, "
"binarized double, booleans, binarized string, base64 image";
}
DTO_FIELD(Any, vals);

DTO_FIELD_INFO(images)
{
info->description
= "[Unsupervised] Array of images returned by the model";
}
DTO_FIELD(Vector<DTOImage>, images);

DTO_FIELD_INFO(imgsize)
{
info->description = "[Unsupervised] Image size";
}
DTO_FIELD(Object<Dimensions>, imgsize);

DTO_FIELD_INFO(confidences)
Expand All @@ -140,6 +153,7 @@ namespace dd
DTO_FIELD(String, index_uri);

public:
// XXX: Legacy & deprecated
std::vector<cv::Mat> _images; /**<allow to pass images in the DTO */
};

Expand Down
20 changes: 18 additions & 2 deletions src/unsupervisedoutputconnector.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
#ifndef UNSUPERVISEDOUTPUTCONNECTOR_H
#define UNSUPERVISEDOUTPUTCONNECTOR_H

#include <vector>
#include <map>

#include "dto/predict_out.hpp"

namespace dd
Expand Down Expand Up @@ -119,6 +122,9 @@ namespace dd
_bool_binarized = ad_out.get("bool_binarized").get<bool>();
else if (ad_out.has("string_binarized"))
_string_binarized = ad_out.get("string_binarized").get<bool>();

if (ad_out.has("encoding"))
_image_encoding = ad_out.get("encoding").get<std::string>();
}

void set_results(std::vector<UnsupervisedResult> &&results)
Expand Down Expand Up @@ -321,8 +327,16 @@ namespace dd
auto pred_dto = DTO::Prediction::createShared();
pred_dto->uri = _vvres.at(i)._uri.c_str();
if (_vvres.at(i)._images.size() != 0)
pred_dto->_images = _vvres.at(i)._images;
if (_bool_binarized)
{
// XXX: legacy
pred_dto->_images = _vvres.at(i)._images;

pred_dto->images = oatpp::Vector<DTO::DTOImage>::createShared();
for (auto &image : _vvres.at(i)._images)
pred_dto->images->push_back(
DTO::VImage{ image, _image_encoding });
}
else if (_bool_binarized)
pred_dto->vals
= DTO::DTOVector<bool>(std::move(_vvres.at(i)._bvals));
else if (_string_binarized)
Expand Down Expand Up @@ -368,6 +382,8 @@ namespace dd
= false; /**< boolean binary representation of output values. */
bool _string_binarized = false; /**< boolean string as binary
representation of output values. */
std::string _image_encoding
= ".png"; /**< encoding used for output images */
#ifdef USE_SIMSEARCH
int _search_nn = 10; /**< default nearest neighbors per search. */
#endif
Expand Down
11 changes: 11 additions & 0 deletions src/utils/oatpp.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ namespace dd
DTO::apiDataDeserialize);
deser->setDeserializerMethod(DTO::GpuIds::Class::CLASS_ID,
DTO::gpuIdsDeserialize);
deser->setDeserializerMethod(DTO::DTOImage::Class::CLASS_ID,
DTO::imageDeserialize);
deser->setDeserializerMethod(DTO::DTOVector<double>::Class::CLASS_ID,
DTO::vectorDeserialize<double>);
deser->setDeserializerMethod(DTO::DTOVector<uint8_t>::Class::CLASS_ID,
Expand All @@ -49,6 +51,8 @@ namespace dd
DTO::apiDataSerialize);
ser->setSerializerMethod(DTO::GpuIds::Class::CLASS_ID,
DTO::gpuIdsSerialize);
ser->setSerializerMethod(DTO::DTOImage::Class::CLASS_ID,
DTO::imageSerialize);
ser->setSerializerMethod(DTO::DTOVector<double>::Class::CLASS_ID,
DTO::vectorSerialize<double>);
ser->setSerializerMethod(DTO::DTOVector<uint8_t>::Class::CLASS_ID,
Expand Down Expand Up @@ -191,6 +195,13 @@ namespace dd
jval.PushBack(dto_gpuid->_ids[i], jdoc.GetAllocator());
}
}
else if (polymorph.getValueType() == DTO::DTOImage::Class::getType())
{
auto dto_img = polymorph.cast<DTO::DTOImage>();
std::string img_str
= cv_utils::image_to_base64(dto_img->get_img(), dto_img->_ext);
jval.SetString(img_str.c_str(), jdoc.GetAllocator());
}
else if (polymorph.getValueType()->classId.id
== oatpp::data::mapping::type::__class::AbstractVector::
CLASS_ID.id
Expand Down
25 changes: 25 additions & 0 deletions tests/ut-tensorrtapi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,31 @@ TEST(tensorrtapi, service_predict_gan_onnx)
ASSERT_TRUE(jd["body"]["predictions"][0]["vals"].IsArray());
ASSERT_EQ(jd["body"]["predictions"][0]["vals"].Size(), 360 * 360 * 3);

// predict to image
jpredictstr
= "{\"service\":\"" + sname
+ "\",\"parameters\":{\"input\":{\"height\":360,"
"\"width\":360,\"rgb\":true,\"scale\":0.00392,\"mean\":[0.5,0.5,0.5]"
",\"std\":[0.5,0.5,0.5]},\"output\":{\"image\":true},\"mllib\":{"
"\"extract_layer\":\"last\"}},\"data\":[\""
+ cyclegan_onnx_repo + "horse.jpg\"]}";
joutstr = japi.jrender(japi.service_predict(jpredictstr));
jd = JDoc();
// std::cout << "joutstr=" << joutstr << std::endl;
jd.Parse<rapidjson::kParseNanAndInfFlag>(joutstr.c_str());
ASSERT_TRUE(!jd.HasParseError());
ASSERT_EQ(200, jd["status"]["code"]);
ASSERT_TRUE(jd["body"]["predictions"].IsArray());
ASSERT_TRUE(jd["body"]["predictions"][0]["images"].IsArray());
ASSERT_EQ(jd["body"]["predictions"][0]["images"].Size(), 1);
// png image
std::string base64_img
= jd["body"]["predictions"][0]["images"][0].GetString();
// may be small differences between machines, versions of libpng/jpeg?
ASSERT_NEAR(base64_img.size(), 388292, 100);
// cv::imwrite("onnx_gan_base64.jpg", cv_utils::base64_to_image(base64_img));

// delete
ASSERT_TRUE(fileops::file_exists(cyclegan_onnx_repo + "TRTengine_arch"
+ get_trt_archi() + "_fp16_bs1"));

Expand Down
11 changes: 7 additions & 4 deletions tests/ut-torchapi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ static std::string iterations_ttransformer_cpu = "100";
static std::string iterations_ttransformer_gpu = "1000";

static std::string iterations_resnet50 = "200";
/// different values to mitigate failure due to randomness
static std::string iterations_resnet50_split = "300";
static std::string iterations_vit = "200";
static std::string iterations_detection = "200";
static std::string iterations_deeplabv3 = "200";
Expand Down Expand Up @@ -831,7 +833,7 @@ TEST(torchapi, service_train_images_split)
std::string jtrainstr
= "{\"service\":\"imgserv\",\"async\":false,\"parameters\":{"
"\"mllib\":{\"solver\":{\"iterations\":"
+ iterations_resnet50 + ",\"base_lr\":" + torch_lr
+ iterations_resnet50_split + ",\"base_lr\":" + torch_lr
+ ",\"iter_size\":4,\"solver_type\":\"ADAM\",\"test_"
"interval\":200},\"net\":{\"batch_size\":4},\"nclasses\":2,"
"\"resume\":false},"
Expand All @@ -846,7 +848,8 @@ TEST(torchapi, service_train_images_split)
ASSERT_TRUE(!jd.HasParseError());
ASSERT_EQ(201, jd["status"]["code"]);

ASSERT_TRUE(jd["body"]["measure"]["iteration"] == 200) << "iterations";
int it_count = std::stoi(iterations_resnet50_split);
ASSERT_TRUE(jd["body"]["measure"]["iteration"] == it_count) << "iterations";
ASSERT_TRUE(jd["body"]["measure"]["train_loss"].GetDouble() <= 3.0)
<< "loss";

Expand All @@ -862,9 +865,9 @@ TEST(torchapi, service_train_images_split)
remove(ff.c_str());
}
ASSERT_TRUE(!fileops::file_exists(resnet50_train_repo + "checkpoint-"
+ iterations_resnet50 + ".ptw"));
+ iterations_resnet50_split + ".ptw"));
ASSERT_TRUE(!fileops::file_exists(resnet50_train_repo + "checkpoint-"
+ iterations_resnet50 + ".pt"));
+ iterations_resnet50_split + ".pt"));
fileops::clear_directory(resnet50_train_repo + "train.lmdb");
fileops::clear_directory(resnet50_train_repo + "test_0.lmdb");
fileops::remove_dir(resnet50_train_repo + "train.lmdb");
Expand Down

0 comments on commit 05096fd

Please sign in to comment.