forked from apache/mxnet
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
401 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
|
||
CXX_STD = CXX11 | ||
PKG_CPPFLAGS = -I../inst/include | ||
PKG_LIBS = $(LAPACK_LIBS) $(BLAS_LIBS) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,269 @@ | ||
/*!
 *  Copyright (c) 2017 by Contributors
 * \file im2rec.cc
 * \brief Implementation of the im2rec routine exported to the R side, which
 *  packs the images named in an image list into a RecordIO file
 */
|
||
#include <cctype>
#include <cstring>
#include <string>
#include <vector>
#include <iomanip>
#include <memory>
#include <sstream>
#include <random>
#include "dmlc/base.h"
#include "dmlc/io.h"
#include "dmlc/timer.h"
#include "dmlc/logging.h"
#include "dmlc/recordio.h"
#include <opencv2/opencv.hpp>
#include "image_recordio.h"
#include "base.h"
#include "im2rec.h"
|
||
namespace mxnet { | ||
namespace R { | ||
|
||
/*!
 * \brief Resolve a requested interpolation option into a concrete method code.
 *
 * Option 9 ("auto") picks by comparing the old and new sizes: cubic when both
 * dimensions grow, area when both shrink, linear otherwise. Option 10
 * ("random") draws a code uniformly from [0, 4]. Any other value is returned
 * unchanged.
 *
 * \param inter_method Requested option (a concrete code, 9, or 10).
 * \param old_width Width before resizing.
 * \param old_height Height before resizing.
 * \param new_width Width after resizing.
 * \param new_height Height after resizing.
 * \param prnd Random engine used only for option 10.
 * \return The interpolation code to use for the resize.
 */
int GetInterMethod(int inter_method, int old_width, int old_height,
                   int new_width, int new_height, std::mt19937& prnd) {  // NOLINT(*)
  // Numeric codes; the names mirror the CV_INTER_* labels used in this file.
  enum {
    kInterLinear = 1,
    kInterCubic = 2,
    kInterArea = 3,
    kModeAuto = 9,
    kModeRandom = 10
  };
  switch (inter_method) {
    case kModeAuto:
      if (new_width > old_width && new_height > old_height) {
        return kInterCubic;  // both dimensions enlarged
      }
      if (new_width < old_width && new_height < old_height) {
        return kInterArea;  // both dimensions shrunk
      }
      return kInterLinear;  // mixed or unchanged
    case kModeRandom: {
      // Draw directly as int so the result is not narrowed on return.
      std::uniform_int_distribution<int> pick_method(0, 4);
      return pick_method(prnd);
    }
    default:
      return inter_method;
  }
}
|
||
/*! \brief Return a pointer to the single function-local static IM2REC object. */
IM2REC* IM2REC::Get() {
  static IM2REC singleton;
  return &singleton;
}
|
||
void IM2REC::InitRcppModule() { | ||
using namespace Rcpp; // NOLINT(*) | ||
IM2REC::Get()->scope_ = ::getCurrentScope(); | ||
function("mx.internal.im2rec", &IM2REC::im2rec, | ||
Rcpp::List::create(_["image_lst"], | ||
_["root"], | ||
_["output_rec"], | ||
_["label_width"], | ||
_["pack_label"], | ||
_["new_size"], | ||
_["nsplit"], | ||
_["partid"], | ||
_["center_crop"], | ||
_["quality"], | ||
_["color_mode"], | ||
_["unchanged"], | ||
_["inter_method"], | ||
_["encoding"]), | ||
""); | ||
} | ||
|
||
void IM2REC::im2rec(const std::string & image_lst, const std::string & root, | ||
const std::string & output_rec, | ||
int label_width, int pack_label, int new_size, int nsplit, | ||
int partid, int center_crop, int quality, | ||
int color_mode, int unchanged, | ||
int inter_method, std::string encoding) { | ||
// Check parameters ranges | ||
if (color_mode != -1 && color_mode != 0 && color_mode != 1) { | ||
Rcpp::stop("Color mode must be -1, 0 or 1."); | ||
} | ||
if (encoding != std::string(".jpg") && encoding != std::string(".png")) { | ||
Rcpp::stop("Encoding mode must be .jpg or .png."); | ||
} | ||
if (label_width <= 1 && pack_label) { | ||
Rcpp::stop("pack_label can only be used when label_width > 1"); | ||
} | ||
if (new_size > 0) { | ||
LOG(INFO) << "New Image Size: Short Edge " << new_size; | ||
} else { | ||
LOG(INFO) << "Keep origin image size"; | ||
} | ||
if (center_crop) { | ||
LOG(INFO) << "Center cropping to square"; | ||
} | ||
if (color_mode == 0) { | ||
LOG(INFO) << "Use gray images"; | ||
} | ||
if (color_mode == -1) { | ||
LOG(INFO) << "Keep original color mode"; | ||
} | ||
LOG(INFO) << "Encoding is " << encoding; | ||
|
||
if (encoding == std::string(".png") && quality > 9) { | ||
quality = 3; | ||
} | ||
if (inter_method != 1) { | ||
switch (inter_method) { | ||
case 0: | ||
LOG(INFO) << "Use inter_method CV_INTER_NN"; | ||
break; | ||
case 2: | ||
LOG(INFO) << "Use inter_method CV_INTER_CUBIC"; | ||
break; | ||
case 3: | ||
LOG(INFO) << "Use inter_method CV_INTER_AREA"; | ||
break; | ||
case 4: | ||
LOG(INFO) << "Use inter_method CV_INTER_LANCZOS4"; | ||
break; | ||
case 9: | ||
LOG(INFO) << "Use inter_method mod auto(cubic for enlarge, area for shrink)"; | ||
break; | ||
case 10: | ||
LOG(INFO) << "Use inter_method mod rand(nn/bilinear/cubic/area/lanczos4)"; | ||
break; | ||
} | ||
} | ||
std::random_device rd; | ||
std::mt19937 prnd(rd()); | ||
using namespace dmlc; | ||
static const size_t kBufferSize = 1 << 20UL; | ||
mxnet::io::ImageRecordIO rec; | ||
size_t imcnt = 0; | ||
double tstart = dmlc::GetTime(); | ||
dmlc::InputSplit *flist = | ||
dmlc::InputSplit::Create(image_lst.c_str(), partid, nsplit, "text"); | ||
std::ostringstream os; | ||
if (nsplit == 1) { | ||
os << output_rec; | ||
} else { | ||
os << output_rec << ".part" << std::setw(3) << std::setfill('0') << partid; | ||
} | ||
LOG(INFO) << "Write to output: " << os.str(); | ||
dmlc::Stream *fo = dmlc::Stream::Create(os.str().c_str(), "w"); | ||
LOG(INFO) << "Output: " << os.str(); | ||
dmlc::RecordIOWriter writer(fo); | ||
std::string fname, path, blob; | ||
std::vector<unsigned char> decode_buf; | ||
std::vector<unsigned char> encode_buf; | ||
std::vector<int> encode_params; | ||
if (encoding == std::string(".png")) { | ||
encode_params.push_back(CV_IMWRITE_PNG_COMPRESSION); | ||
encode_params.push_back(quality); | ||
LOG(INFO) << "PNG encoding compression: " << quality; | ||
} else { | ||
encode_params.push_back(CV_IMWRITE_JPEG_QUALITY); | ||
encode_params.push_back(quality); | ||
LOG(INFO) << "JPEG encoding quality: " << quality; | ||
} | ||
dmlc::InputSplit::Blob line; | ||
std::vector<float> label_buf(label_width, 0.f); | ||
|
||
while (flist->NextRecord(&line)) { | ||
std::string sline(static_cast<char*>(line.dptr), line.size); | ||
std::istringstream is(sline); | ||
if (!(is >> rec.header.image_id[0] >> rec.header.label)) continue; | ||
label_buf[0] = rec.header.label; | ||
for (int k = 1; k < label_width; ++k) { | ||
RCHECK(is >> label_buf[k]) | ||
<< "Invalid ImageList, did you provide the correct label_width?"; | ||
} | ||
if (pack_label) rec.header.flag = label_width; | ||
rec.SaveHeader(&blob); | ||
if (pack_label) { | ||
size_t bsize = blob.size(); | ||
blob.resize(bsize + label_buf.size()*sizeof(float)); | ||
memcpy(BeginPtr(blob) + bsize, | ||
BeginPtr(label_buf), label_buf.size()*sizeof(float)); | ||
} | ||
RCHECK(std::getline(is, fname)); | ||
// eliminate invalid chars in the end | ||
while (fname.length() != 0 && | ||
(isspace(*fname.rbegin()) || !isprint(*fname.rbegin()))) { | ||
fname.resize(fname.length() - 1); | ||
} | ||
// eliminate invalid chars in beginning. | ||
const char *p = fname.c_str(); | ||
while (isspace(*p)) ++p; | ||
path = root + p; | ||
// use "r" is equal to rb in dmlc::Stream | ||
dmlc::Stream *fi = dmlc::Stream::Create(path.c_str(), "r"); | ||
decode_buf.clear(); | ||
size_t imsize = 0; | ||
while (true) { | ||
decode_buf.resize(imsize + kBufferSize); | ||
size_t nread = fi->Read(BeginPtr(decode_buf) + imsize, kBufferSize); | ||
imsize += nread; | ||
decode_buf.resize(imsize); | ||
if (nread != kBufferSize) break; | ||
} | ||
delete fi; | ||
|
||
|
||
if (unchanged != 1) { | ||
cv::Mat img = cv::imdecode(decode_buf, color_mode); | ||
RCHECK(img.data != NULL) << "OpenCV decode fail:" << path; | ||
cv::Mat res = img; | ||
if (new_size > 0) { | ||
if (center_crop) { | ||
if (img.rows > img.cols) { | ||
int margin = (img.rows - img.cols)/2; | ||
img = img(cv::Range(margin, margin+img.cols), cv::Range(0, img.cols)); | ||
} else { | ||
int margin = (img.cols - img.rows)/2; | ||
img = img(cv::Range(0, img.rows), cv::Range(margin, margin + img.rows)); | ||
} | ||
} | ||
int interpolation_method = 1; | ||
if (img.rows > img.cols) { | ||
if (img.cols != new_size) { | ||
interpolation_method = GetInterMethod(inter_method, img.cols, img.rows, | ||
new_size, | ||
img.rows * new_size / img.cols, prnd); | ||
cv::resize(img, res, cv::Size(new_size, | ||
img.rows * new_size / img.cols), | ||
0, 0, interpolation_method); | ||
} else { | ||
res = img.clone(); | ||
} | ||
} else { | ||
if (img.rows != new_size) { | ||
interpolation_method = GetInterMethod(inter_method, img.cols, | ||
img.rows, new_size * img.cols / img.rows, | ||
new_size, prnd); | ||
cv::resize(img, res, cv::Size(new_size * img.cols / img.rows, | ||
new_size), 0, 0, interpolation_method); | ||
} else { | ||
res = img.clone(); | ||
} | ||
} | ||
} | ||
encode_buf.clear(); | ||
RCHECK(cv::imencode(encoding, res, encode_buf, encode_params)); | ||
|
||
// write buffer | ||
size_t bsize = blob.size(); | ||
blob.resize(bsize + encode_buf.size()); | ||
memcpy(BeginPtr(blob) + bsize, | ||
BeginPtr(encode_buf), encode_buf.size()); | ||
} else { | ||
size_t bsize = blob.size(); | ||
blob.resize(bsize + decode_buf.size()); | ||
memcpy(BeginPtr(blob) + bsize, | ||
BeginPtr(decode_buf), decode_buf.size()); | ||
} | ||
writer.WriteRecord(BeginPtr(blob), blob.size()); | ||
// write header | ||
++imcnt; | ||
if (imcnt % 1000 == 0) { | ||
LOG(INFO) << imcnt << " images processed, " << GetTime() - tstart << " sec elapsed"; | ||
} | ||
} | ||
LOG(INFO) << "Total: " << imcnt << " images processed, " << GetTime() - tstart << " sec elapsed"; | ||
delete fo; | ||
delete flist; | ||
} | ||
} // namespace R | ||
} // namespace mxnet |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
/*!
 *  Copyright (c) 2017 by Contributors
 * \file im2rec.h
 * \brief Declaration of the IM2REC class, which exposes the image-list to
 *  RecordIO packing routine to the R side
 */
|
||
#ifndef MXNET_RCPP_IM2REC_H_ | ||
#define MXNET_RCPP_IM2REC_H_ | ||
|
||
#include <Rcpp.h> | ||
#include <string> | ||
|
||
namespace mxnet { | ||
namespace R { | ||
|
||
class IM2REC { | ||
public: | ||
/*! | ||
* \brief Export the generated file into path. | ||
* \param path The path to be exported. | ||
*/ | ||
static void im2rec(const std::string & image_lst, const std::string & root, | ||
const std::string & output_rec, | ||
int label_width = 1, int pack_label = 0, int new_size = -1, int nsplit = 1, | ||
int partid = 0, int center_crop = 0, int quality = 95, | ||
int color_mode = 1, int unchanged = 0, | ||
int inter_method = 1, std::string encoding = ".jpg"); | ||
// intialize the Rcpp module | ||
static void InitRcppModule(); | ||
|
||
public: | ||
// get the singleton of exporter | ||
static IM2REC* Get(); | ||
/*! \brief The scope of current module to export */ | ||
Rcpp::Module* scope_; | ||
}; | ||
|
||
} // namespace R | ||
} // namespace mxnet | ||
|
||
#endif // MXNET_RCPP_IM2REC_H_ |
Oops, something went wrong.