Skip to content

Commit

Permalink
feat(torch): add translation and bbox duplication to data augmentation
Browse files Browse the repository at this point in the history
  • Loading branch information
royale authored and mergify[bot] committed Jan 3, 2023
1 parent e26a775 commit 8752e1f
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 5 deletions.
63 changes: 61 additions & 2 deletions src/backends/torch/torchdataaug.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,39 @@ namespace dd
applyCrop(src, _crop_params, crop_x, crop_y, true, true);
}

/**
 * \brief Replaces every bounding box by its 9 copies over a 3x3 grid of
 *        image-sized tiles centred on the original image.
 *
 * The centre tile (offset 0,0) keeps the box unchanged. For a tile shifted
 * horizontally by -img_width or +img_width the box is mirrored around the
 * shared vertical image edge (x = 0 or x = img_width); likewise for a tile
 * shifted vertically by -img_height or +img_height (y = 0 or
 * y = img_height). Corner tiles get both mirrors.
 *
 * Output order: for each input box, tiles are emitted with the horizontal
 * offset as the outer loop (-W, 0, +W) and the vertical offset as the inner
 * loop (-H, 0, +H). The class of each input box is repeated for its copies.
 *
 * \param bboxes in/out list of boxes; each box is indexed as
 *        (xmin, ymin, xmax, ymax) and must hold at least 4 values.
 * \param classes in/out class ids, one per box (same order as bboxes).
 * \param img_width width of the source image, in the boxes' units.
 * \param img_height height of the source image, in the boxes' units.
 * \throws std::out_of_range if classes holds fewer entries than bboxes.
 */
void TorchImgRandAugCV::applyDuplicateBBox(
    std::vector<std::vector<float>> &bboxes, std::vector<int> &classes,
    const float &img_width, const float &img_height)
{
  std::vector<std::vector<float>> nbboxes;
  std::vector<int> nclasses;
  nbboxes.reserve(bboxes.size() * 9);
  nclasses.reserve(bboxes.size() * 9);

  for (size_t i = 0; i < bboxes.size(); ++i)
    {
      const std::vector<float> &bbox = bboxes.at(i);
      const int bbox_class = classes.at(i);

      // Iterate over tile indices rather than over pixel offsets: stepping
      // an int counter by a float dimension never terminates when the
      // dimension is 0 (or truncates to 0), and truncates fractional sizes.
      for (int kx = -1; kx <= 1; ++kx)
        {
          const float tx = static_cast<float>(kx) * img_width;
          for (int ky = -1; ky <= 1; ++ky)
            {
              const float ty = static_cast<float>(ky) * img_height;
              std::vector<float> nbox = bbox;
              if (kx != 0) // horizontal mirror and translation
                {
                  // xmin/xmax swap roles once mirrored
                  nbox[0] = tx + img_width - bbox[2];
                  nbox[2] = tx + img_width - bbox[0];
                }
              if (ky != 0) // vertical mirror and translation
                {
                  // ymin/ymax swap roles once mirrored
                  nbox[1] = ty + img_height - bbox[3];
                  nbox[3] = ty + img_height - bbox[1];
                }
              nbboxes.push_back(nbox);
              nclasses.push_back(bbox_class);
            }
        }
    }

  // hand the freshly built lists over without an extra deep copy
  bboxes.swap(nbboxes);
  classes.swap(nclasses);
}

void
TorchImgRandAugCV::augment_with_bbox(cv::Mat &src,
std::vector<torch::Tensor> &targets)
Expand All @@ -82,6 +115,15 @@ namespace dd
bboxes.push_back(bbox); // add (xmin, ymin, xmax, ymax)
classes.push_back(c[bb].item<int>());
}
GeometryParams geoparams = _geometry_params;
bool duplicated = false;
// in mirrored padding mode, duplicate bboxes in the enlarged image
if (geoparams._geometry_pad_mode == 2)
{
duplicated = true;
applyDuplicateBBox(bboxes, classes, static_cast<float>(src.cols),
static_cast<float>(src.rows));
}

bool mirror = applyMirror(src);
if (mirror)
Expand All @@ -104,7 +146,6 @@ namespace dd
static_cast<float>(src.rows), crop_x, crop_y);
}
applyCutout(src, _cutout_params);
GeometryParams geoparams = _geometry_params;
cv::Mat src_c = src.clone();
applyGeometry(src_c, geoparams, true);
if (!geoparams._lambda.empty())
Expand All @@ -122,7 +163,7 @@ namespace dd

// replacing the initial bboxes with the transformed ones.
nbbox = bboxes.size();
if (!cropped)
if (!cropped && !duplicated)
{
for (int bb = 0; bb < nbbox; ++bb)
{
Expand Down Expand Up @@ -583,6 +624,24 @@ namespace dd
outputQuad[1].x = cols - outputQuad[0].x;
}
}
if (cp._geometry_transl_horizontal)
{
float tx = cols * cp._geometry_transl_factor
* (2 * _uniform_real_1(_rnd_gen) - 1);
for (int i = 0; i < 4; ++i)
{
outputQuad[i].x += tx;
}
}
if (cp._geometry_transl_vertical)
{
float ty = rows * cp._geometry_transl_factor
* (2 * _uniform_real_1(_rnd_gen) - 1);
for (int i = 0; i < 4; ++i)
{
outputQuad[i].y += ty;
}
}
}

void TorchImgRandAugCV::applyGeometry(cv::Mat &src, GeometryParams &cp,
Expand Down
18 changes: 15 additions & 3 deletions src/backends/torch/torchdataaug.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,14 @@ namespace dd

GeometryParams(const float &prob, const bool &geometry_persp_horizontal,
const bool &geometry_persp_vertical,
const bool &geometry_transl_horizontal,
const bool &geometry_transl_vertical,
const bool &geometry_zoom_out, const bool &geometry_zoom_in,
const std::string &geometry_pad_mode_str)
: _prob(prob), _geometry_persp_horizontal(geometry_persp_horizontal),
_geometry_persp_vertical(geometry_persp_vertical),
_geometry_transl_horizontal(geometry_transl_horizontal),
_geometry_transl_vertical(geometry_transl_vertical),
_geometry_zoom_out(geometry_zoom_out),
_geometry_zoom_in(geometry_zoom_in)
{
Expand All @@ -161,14 +165,19 @@ namespace dd
bool _geometry_persp_horizontal
= true; /**< horizontal perspective change. */
bool _geometry_persp_vertical = true; /**< vertical perspective change. */
bool _geometry_transl_horizontal
= false; /**< horizontal translation change. */
bool _geometry_transl_vertical
= false; /**< vertical translation change. */
bool _geometry_zoom_out
= true; /**< distance change: look from further away. */
bool _geometry_zoom_in = true; /**< distance change: look closer. */
float _geometry_zoom_factor = 0.25; /**< zoom factor: 0.25 means that image
can be *1.25 or /1.25. */
float _geometry_persp_factor
= 0.25; /**< persp factor: 0.25 means that new
image corners be in 1.25 or 0.75. */
float _geometry_persp_factor = 0.25; /**< persp factor: 0.25 means that new
image corners be in 1.25 or 0.75. */
float _geometry_transl_factor = 0.5; /**< transl factor: 0.5 means that new
image corners be in 1.5 or 0.5. */
uint8_t _geometry_pad_mode = 1; /**< filling around images, 1: constant, 2:
repeat nearest (replicate). */
float _geometry_bbox_intersect
Expand Down Expand Up @@ -306,6 +315,9 @@ namespace dd

protected:
bool roll_weighted_dice(const float &prob);
void applyDuplicateBBox(std::vector<std::vector<float>> &bboxes,
std::vector<int> &classes, const float &img_width,
const float &img_height);
bool applyMirror(cv::Mat &src, const bool &sample = true);
void applyMirrorBBox(std::vector<std::vector<float>> &bboxes,
const float &img_width);
Expand Down
9 changes: 9 additions & 0 deletions src/backends/torch/torchlib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -707,6 +707,12 @@ namespace dd
if (ad_geometry.has("persp_horizontal"))
geometry_params._geometry_persp_horizontal
= ad_geometry.get("persp_horizontal").get<bool>();
if (ad_geometry.has("transl_vertical"))
geometry_params._geometry_transl_vertical
= ad_geometry.get("transl_vertical").get<bool>();
if (ad_geometry.has("transl_horizontal"))
geometry_params._geometry_transl_horizontal
= ad_geometry.get("transl_horizontal").get<bool>();
if (ad_geometry.has("zoom_out"))
geometry_params._geometry_zoom_out
= ad_geometry.get("zoom_out").get<bool>();
Expand All @@ -716,6 +722,9 @@ namespace dd
if (ad_geometry.has("persp_factor"))
geometry_params._geometry_persp_factor
= ad_geometry.get("persp_factor").get<double>();
if (ad_geometry.has("transl_factor"))
geometry_params._geometry_transl_factor
= ad_geometry.get("transl_factor").get<double>();
if (ad_geometry.has("zoom_factor"))
geometry_params._geometry_zoom_factor
= ad_geometry.get("zoom_factor").get<double>();
Expand Down

0 comments on commit 8752e1f

Please sign in to comment.