Skip to content

Commit

Permalink
[R] im2rec in R. close apache#7273
Browse files Browse the repository at this point in the history
  • Loading branch information
thirdwing committed Aug 8, 2017
1 parent 1a617fa commit 63c53cb
Show file tree
Hide file tree
Showing 8 changed files with 401 additions and 22 deletions.
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,7 @@ rcpplint:
rpkg:
mkdir -p R-package/inst
mkdir -p R-package/inst/libs
cp src/io/image_recordio.h R-package/src
cp -rf lib/libmxnet.so R-package/inst/libs
mkdir -p R-package/inst/include
cp -rf include/* R-package/inst/include
Expand Down
49 changes: 48 additions & 1 deletion R-package/R/util.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,52 @@ mx.util.filter.null <- function(lst) {
#'
#' @export
mxnet.export <- function(path) {
mxnet.internal.export(path.expand(path))
mx.internal.export(path.expand(path))
}

#' Convert images into image recordio format
#' @param image_lst
#' The image lst file
#' @param root
#' The root folder for image files
#' @param output_rec
#' The output rec file
#' @param label_width
#' The label width in the list file. Default is 1.
#' @param pack_label
#' Whether to also pack multi dimenional label in the record file. Default is 0.
#' @param new_size
#' The shorter edge of image will be resized to the newsize.
#' Original images will be packed by default.
#' @param nsplit
#' It is used for part generation, logically split the image.lst to NSPLIT parts by position.
#' Default is 1.
#' @param partid
#' It is used for part generation, pack the images from the specific part in image.lst.
#' Default is 0.
#' @param center_crop
#' Whether to crop the center image to make it square. Default is 0.
#' @param quality
#' JPEG quality for encoding (1-100, default: 95) or PNG compression for encoding (1-9, default: 3).
#' @param color_mode
#' Force color (1), gray image (0) or keep source unchanged (-1). Default is 1.
#' @param unchanged
#' Keep the original image encoding, size and color. If set to 1, it will ignore the others parameters.
#' @param inter_method
#' NN(0), BILINEAR(1), CUBIC(2), AREA(3), LANCZOS4(4), AUTO(9), RAND(10). Default is 1.
#' @param encoding
#' The encoding type for images. It can be '.jpg' or '.png'. Default is '.jpg'.
#' @export
im2rec <- function(image_lst, root, output_rec, label_width = 1L,
pack_label = 0L, new_size = -1L, nsplit = 1L,
partid = 0L, center_crop = 0L, quality = 95L,
color_mode = 1L, unchanged = 0L, inter_method = 1L,
encoding = ".jpg") {
image_lst <- path.expand(image_lst)
root <- path.expand(root)
output_rec <- path.expand(output_rec)
mx.internal.im2rec(image_lst, root, output_rec, label_width,
pack_label, new_size, nsplit, partid,
center_crop, quality, color_mode, unchanged,
inter_method, encoding)
}
2 changes: 1 addition & 1 deletion R-package/src/Makevars
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@

CXX_STD = CXX11
PKG_CPPFLAGS = -I../inst/include
PKG_LIBS = $(LAPACK_LIBS) $(BLAS_LIBS)
2 changes: 1 addition & 1 deletion R-package/src/export.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ Exporter* Exporter::Get() {
void Exporter::InitRcppModule() {
using namespace Rcpp; // NOLINT(*)
Exporter::Get()->scope_ = ::getCurrentScope();
function("mxnet.internal.export", &Exporter::Export,
function("mx.internal.export", &Exporter::Export,
Rcpp::List::create(_["path"]),
"Internal function of mxnet, used to export generated functions file.");
}
Expand Down
269 changes: 269 additions & 0 deletions R-package/src/im2rec.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,269 @@
/*!
* Copyright (c) 2017 by Contributors
* \file export.h
* \brief Export module that takes charge of code generation and document
* Generation for functions exported from R-side
*/

#include <cctype>
#include <cstring>
#include <string>
#include <vector>
#include <iomanip>
#include <sstream>
#include <random>
#include "dmlc/base.h"
#include "dmlc/io.h"
#include "dmlc/timer.h"
#include "dmlc/logging.h"
#include "dmlc/recordio.h"
#include <opencv2/opencv.hpp>
#include "image_recordio.h"
#include "base.h"
#include "im2rec.h"

namespace mxnet {
namespace R {

int GetInterMethod(int inter_method, int old_width, int old_height,
int new_width, int new_height, std::mt19937& prnd) { // NOLINT(*)
if (inter_method == 9) {
if (new_width > old_width && new_height > old_height) {
return 2; // CV_INTER_CUBIC for enlarge
} else if (new_width <old_width && new_height < old_height) {
return 3; // CV_INTER_AREA for shrink
} else {
return 1; // CV_INTER_LINEAR for others
}
} else if (inter_method == 10) {
std::uniform_int_distribution<size_t> rand_uniform_int(0, 4);
return rand_uniform_int(prnd);
} else {
return inter_method;
}
}

IM2REC* IM2REC::Get() {
static IM2REC inst;
return &inst;
}

void IM2REC::InitRcppModule() {
using namespace Rcpp; // NOLINT(*)
IM2REC::Get()->scope_ = ::getCurrentScope();
function("mx.internal.im2rec", &IM2REC::im2rec,
Rcpp::List::create(_["image_lst"],
_["root"],
_["output_rec"],
_["label_width"],
_["pack_label"],
_["new_size"],
_["nsplit"],
_["partid"],
_["center_crop"],
_["quality"],
_["color_mode"],
_["unchanged"],
_["inter_method"],
_["encoding"]),
"");
}

void IM2REC::im2rec(const std::string & image_lst, const std::string & root,
const std::string & output_rec,
int label_width, int pack_label, int new_size, int nsplit,
int partid, int center_crop, int quality,
int color_mode, int unchanged,
int inter_method, std::string encoding) {
// Check parameters ranges
if (color_mode != -1 && color_mode != 0 && color_mode != 1) {
Rcpp::stop("Color mode must be -1, 0 or 1.");
}
if (encoding != std::string(".jpg") && encoding != std::string(".png")) {
Rcpp::stop("Encoding mode must be .jpg or .png.");
}
if (label_width <= 1 && pack_label) {
Rcpp::stop("pack_label can only be used when label_width > 1");
}
if (new_size > 0) {
LOG(INFO) << "New Image Size: Short Edge " << new_size;
} else {
LOG(INFO) << "Keep origin image size";
}
if (center_crop) {
LOG(INFO) << "Center cropping to square";
}
if (color_mode == 0) {
LOG(INFO) << "Use gray images";
}
if (color_mode == -1) {
LOG(INFO) << "Keep original color mode";
}
LOG(INFO) << "Encoding is " << encoding;

if (encoding == std::string(".png") && quality > 9) {
quality = 3;
}
if (inter_method != 1) {
switch (inter_method) {
case 0:
LOG(INFO) << "Use inter_method CV_INTER_NN";
break;
case 2:
LOG(INFO) << "Use inter_method CV_INTER_CUBIC";
break;
case 3:
LOG(INFO) << "Use inter_method CV_INTER_AREA";
break;
case 4:
LOG(INFO) << "Use inter_method CV_INTER_LANCZOS4";
break;
case 9:
LOG(INFO) << "Use inter_method mod auto(cubic for enlarge, area for shrink)";
break;
case 10:
LOG(INFO) << "Use inter_method mod rand(nn/bilinear/cubic/area/lanczos4)";
break;
}
}
std::random_device rd;
std::mt19937 prnd(rd());
using namespace dmlc;
static const size_t kBufferSize = 1 << 20UL;
mxnet::io::ImageRecordIO rec;
size_t imcnt = 0;
double tstart = dmlc::GetTime();
dmlc::InputSplit *flist =
dmlc::InputSplit::Create(image_lst.c_str(), partid, nsplit, "text");
std::ostringstream os;
if (nsplit == 1) {
os << output_rec;
} else {
os << output_rec << ".part" << std::setw(3) << std::setfill('0') << partid;
}
LOG(INFO) << "Write to output: " << os.str();
dmlc::Stream *fo = dmlc::Stream::Create(os.str().c_str(), "w");
LOG(INFO) << "Output: " << os.str();
dmlc::RecordIOWriter writer(fo);
std::string fname, path, blob;
std::vector<unsigned char> decode_buf;
std::vector<unsigned char> encode_buf;
std::vector<int> encode_params;
if (encoding == std::string(".png")) {
encode_params.push_back(CV_IMWRITE_PNG_COMPRESSION);
encode_params.push_back(quality);
LOG(INFO) << "PNG encoding compression: " << quality;
} else {
encode_params.push_back(CV_IMWRITE_JPEG_QUALITY);
encode_params.push_back(quality);
LOG(INFO) << "JPEG encoding quality: " << quality;
}
dmlc::InputSplit::Blob line;
std::vector<float> label_buf(label_width, 0.f);

while (flist->NextRecord(&line)) {
std::string sline(static_cast<char*>(line.dptr), line.size);
std::istringstream is(sline);
if (!(is >> rec.header.image_id[0] >> rec.header.label)) continue;
label_buf[0] = rec.header.label;
for (int k = 1; k < label_width; ++k) {
RCHECK(is >> label_buf[k])
<< "Invalid ImageList, did you provide the correct label_width?";
}
if (pack_label) rec.header.flag = label_width;
rec.SaveHeader(&blob);
if (pack_label) {
size_t bsize = blob.size();
blob.resize(bsize + label_buf.size()*sizeof(float));
memcpy(BeginPtr(blob) + bsize,
BeginPtr(label_buf), label_buf.size()*sizeof(float));
}
RCHECK(std::getline(is, fname));
// eliminate invalid chars in the end
while (fname.length() != 0 &&
(isspace(*fname.rbegin()) || !isprint(*fname.rbegin()))) {
fname.resize(fname.length() - 1);
}
// eliminate invalid chars in beginning.
const char *p = fname.c_str();
while (isspace(*p)) ++p;
path = root + p;
// use "r" is equal to rb in dmlc::Stream
dmlc::Stream *fi = dmlc::Stream::Create(path.c_str(), "r");
decode_buf.clear();
size_t imsize = 0;
while (true) {
decode_buf.resize(imsize + kBufferSize);
size_t nread = fi->Read(BeginPtr(decode_buf) + imsize, kBufferSize);
imsize += nread;
decode_buf.resize(imsize);
if (nread != kBufferSize) break;
}
delete fi;


if (unchanged != 1) {
cv::Mat img = cv::imdecode(decode_buf, color_mode);
RCHECK(img.data != NULL) << "OpenCV decode fail:" << path;
cv::Mat res = img;
if (new_size > 0) {
if (center_crop) {
if (img.rows > img.cols) {
int margin = (img.rows - img.cols)/2;
img = img(cv::Range(margin, margin+img.cols), cv::Range(0, img.cols));
} else {
int margin = (img.cols - img.rows)/2;
img = img(cv::Range(0, img.rows), cv::Range(margin, margin + img.rows));
}
}
int interpolation_method = 1;
if (img.rows > img.cols) {
if (img.cols != new_size) {
interpolation_method = GetInterMethod(inter_method, img.cols, img.rows,
new_size,
img.rows * new_size / img.cols, prnd);
cv::resize(img, res, cv::Size(new_size,
img.rows * new_size / img.cols),
0, 0, interpolation_method);
} else {
res = img.clone();
}
} else {
if (img.rows != new_size) {
interpolation_method = GetInterMethod(inter_method, img.cols,
img.rows, new_size * img.cols / img.rows,
new_size, prnd);
cv::resize(img, res, cv::Size(new_size * img.cols / img.rows,
new_size), 0, 0, interpolation_method);
} else {
res = img.clone();
}
}
}
encode_buf.clear();
RCHECK(cv::imencode(encoding, res, encode_buf, encode_params));

// write buffer
size_t bsize = blob.size();
blob.resize(bsize + encode_buf.size());
memcpy(BeginPtr(blob) + bsize,
BeginPtr(encode_buf), encode_buf.size());
} else {
size_t bsize = blob.size();
blob.resize(bsize + decode_buf.size());
memcpy(BeginPtr(blob) + bsize,
BeginPtr(decode_buf), decode_buf.size());
}
writer.WriteRecord(BeginPtr(blob), blob.size());
// write header
++imcnt;
if (imcnt % 1000 == 0) {
LOG(INFO) << imcnt << " images processed, " << GetTime() - tstart << " sec elapsed";
}
}
LOG(INFO) << "Total: " << imcnt << " images processed, " << GetTime() - tstart << " sec elapsed";
delete fo;
delete flist;
}
} // namespace R
} // namespace mxnet
42 changes: 42 additions & 0 deletions R-package/src/im2rec.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*!
* Copyright (c) 2017 by Contributors
* \file export.h
* \brief Export module that takes charge of code generation and document
* Generation for functions exported from R-side
*/

#ifndef MXNET_RCPP_IM2REC_H_
#define MXNET_RCPP_IM2REC_H_

#include <Rcpp.h>
#include <string>

namespace mxnet {
namespace R {

class IM2REC {
public:
/*!
* \brief Export the generated file into path.
* \param path The path to be exported.
*/
static void im2rec(const std::string & image_lst, const std::string & root,
const std::string & output_rec,
int label_width = 1, int pack_label = 0, int new_size = -1, int nsplit = 1,
int partid = 0, int center_crop = 0, int quality = 95,
int color_mode = 1, int unchanged = 0,
int inter_method = 1, std::string encoding = ".jpg");
// intialize the Rcpp module
static void InitRcppModule();

public:
// get the singleton of exporter
static IM2REC* Get();
/*! \brief The scope of current module to export */
Rcpp::Module* scope_;
};

} // namespace R
} // namespace mxnet

#endif // MXNET_RCPP_IM2REC_H_
Loading

0 comments on commit 63c53cb

Please sign in to comment.