Skip to content

Commit

Permalink
Merge pull request #8991 from reyoung/feature/shuffle_reader
Browse files Browse the repository at this point in the history
Feature/shuffle reader
  • Loading branch information
reyoung authored Mar 14, 2018
2 parents 881c522 + 127b371 commit 48f213e
Show file tree
Hide file tree
Showing 12 changed files with 270 additions and 127 deletions.
20 changes: 2 additions & 18 deletions paddle/fluid/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -445,15 +445,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
}

std::vector<DDim> GetRepeatedDims(const std::string& name) const override {
Variable* var = scope_.FindVar(name);
if (var->IsType<ReaderHolder>()) {
return var->Get<ReaderHolder>().shapes();
} else {
PADDLE_THROW(
"Only ReaderHolder support 'GetRepeatedDims', but Variable %s's "
"type_id is %s.",
name, var->Type().name());
}
PADDLE_THROW("Only compile time support this method");
}

void SetDim(const std::string& name, const DDim& dim) override {
Expand All @@ -470,15 +462,7 @@ class RuntimeInferShapeContext : public InferShapeContext {

void SetRepeatedDims(const std::string& name,
const std::vector<DDim>& dims) override {
Variable* var = scope_.FindVar(name);
if (var->IsType<ReaderHolder>()) {
var->GetMutable<ReaderHolder>()->set_shapes(dims);
} else {
PADDLE_THROW(
"Only ReaderHolder support 'SetRepeatedDims', but Variable %s's "
"type_id is %s.",
name, var->Type().name());
}
PADDLE_THROW("Only compile time support this method");
}

proto::VarType::Type GetVarType(const std::string& name) const override {
Expand Down
22 changes: 15 additions & 7 deletions paddle/fluid/framework/reader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,22 @@

namespace paddle {
namespace framework {
ReaderBase::~ReaderBase() {}

DDim ReaderBase::shape(size_t idx) const {
PADDLE_ENFORCE_LT(
idx, shapes_.size(),
"Cannot get the %d'th shape, 'shapes_' only has %d elements.", idx,
shapes_.size());
return shapes_[idx];
}
FileReader::FileReader(const std::vector<DDim> &dims) : dims_(dims) {}

void FileReader::ReadNext(std::vector<LoDTensor> *out) {
ReadNextImpl(out);
PADDLE_ENFORCE_EQ(out->size(), dims_.size());
for (size_t i = 0; i < dims_.size(); ++i) {
auto &actual = out->at(i).dims();
auto &expect = dims_[i];

PADDLE_ENFORCE_EQ(actual.size(), expect.size());
for (int j = 0; j < actual.size(); ++j) {
PADDLE_ENFORCE(actual[i] == expect[i] || expect[i] == -1);
}
}
}
} // namespace framework
} // namespace paddle
51 changes: 20 additions & 31 deletions paddle/fluid/framework/reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,40 +16,29 @@

#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/platform/place.h"

#include <memory>
#include <thread>
#include <vector>

namespace paddle {
namespace framework {

class ReaderBase {
public:
explicit ReaderBase(const std::vector<DDim>& shapes) : shapes_(shapes) {
PADDLE_ENFORCE(!shapes_.empty());
}
virtual void ReadNext(std::vector<LoDTensor>* out) = 0;

virtual void ReInit() = 0;

DDim shape(size_t idx) const;
std::vector<DDim> shapes() const { return shapes_; }
void set_shapes(const std::vector<DDim>& shapes) { shapes_ = shapes; }

virtual bool HasNext() const = 0;

virtual ~ReaderBase() {}

protected:
std::vector<DDim> shapes_;
};

class FileReader : public ReaderBase {
public:
explicit FileReader(const std::vector<DDim>& shapes) : ReaderBase(shapes) {}
virtual ~ReaderBase();
};

class DecoratedReader : public ReaderBase {
public:
explicit DecoratedReader(ReaderBase* reader)
: ReaderBase(reader->shapes()), reader_(reader) {
explicit DecoratedReader(ReaderBase* reader) : ReaderBase(), reader_(reader) {
PADDLE_ENFORCE_NOT_NULL(reader_);
}

Expand All @@ -61,6 +50,19 @@ class DecoratedReader : public ReaderBase {
ReaderBase* reader_;
};

class FileReader : public ReaderBase {
public:
explicit FileReader(const std::vector<DDim>& dims);

void ReadNext(std::vector<LoDTensor>* out) override;

protected:
virtual void ReadNextImpl(std::vector<LoDTensor>* out) = 0;

private:
std::vector<DDim> dims_;
};

// The ReaderHolder is used as reader' unified wrapper,
// making it easier to access different type reader in Variables.
class ReaderHolder {
Expand All @@ -78,19 +80,6 @@ class ReaderHolder {
reader_->ReInit();
}

DDim shape(size_t idx) const {
PADDLE_ENFORCE_NOT_NULL(reader_);
return reader_->shape(idx);
}
std::vector<DDim> shapes() const {
PADDLE_ENFORCE_NOT_NULL(reader_);
return reader_->shapes();
}
void set_shapes(const std::vector<DDim>& shapes) {
PADDLE_ENFORCE_NOT_NULL(reader_);
reader_->set_shapes(shapes);
}

bool HasNext() const { return reader_->HasNext(); }

private:
Expand Down
111 changes: 88 additions & 23 deletions paddle/fluid/operators/reader/create_double_buffer_reader_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,31 @@ static constexpr size_t kDoubleBufferSize = 2;

class DoubleBufferReader : public framework::DecoratedReader {
public:
explicit DoubleBufferReader(ReaderBase* reader)
: DecoratedReader(reader),
buffer_(framework::MakeChannel<std::vector<framework::LoDTensor>>(
kDoubleBufferSize)) {
std::thread prefetch(&DoubleBufferReader::PrefetchThreadFunc, this);
struct Item {
Item() : ctx_(nullptr) {}

std::vector<framework::LoDTensor> payloads_;
platform::DeviceContext* ctx_;
};

explicit DoubleBufferReader(
ReaderBase* reader, platform::Place target_place = platform::CPUPlace())
: DecoratedReader(reader), place_(target_place) {
for (size_t i = 0; i < kDoubleBufferSize; ++i) {
if (platform::is_gpu_place(place_)) {
#ifdef PADDLE_WITH_CUDA
ctxs_.emplace_back(new platform::CUDADeviceContext(
boost::get<platform::CUDAPlace>(place_)));
#endif
}
}

start_thread();
}

void start_thread() {
buffer_ = framework::MakeChannel<Item>(kDoubleBufferSize);
std::thread prefetch([this] { PrefetchThreadFunc(); });
prefetch.detach();
}

Expand All @@ -42,7 +62,10 @@ class DoubleBufferReader : public framework::DecoratedReader {
private:
void PrefetchThreadFunc();

framework::Channel<std::vector<framework::LoDTensor>>* buffer_;
framework::Channel<Item>* buffer_;
platform::Place place_;
std::vector<std::unique_ptr<platform::DeviceContext>> ctxs_;
mutable Item local_buffer_;
};

class CreateDoubleBufferReaderOp : public framework::OperatorBase {
Expand All @@ -56,7 +79,20 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase {
->Get<framework::ReaderHolder>();
auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
out->Reset(new DoubleBufferReader(underlying_reader.Get()));

auto place_str = Attr<std::string>("place");
platform::Place place;
if (place_str == "CPU") {
place = platform::CPUPlace();
} else {
std::istringstream sin(place_str);
sin.seekg(std::string("CUDA:").size(), std::ios::beg);
size_t num;
sin >> num;
place = platform::CUDAPlace(static_cast<int>(num));
}

out->Reset(new DoubleBufferReader(underlying_reader.Get(), place));
}
};

Expand All @@ -71,44 +107,73 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
It launches another thread to execute the 'underlying reader' asynchronously,
which prevents reading process from blocking subsequent training.
)DOC");
std::unordered_set<std::string> enum_range;
constexpr size_t kMaxCUDADevs = 128;
for (size_t i = 0; i < kMaxCUDADevs; ++i) {
enum_range.insert(string::Sprintf("CUDA:%d", i));
}
enum_range.insert("CPU");
AddAttr<std::string>("place", "The double buffer place, default is CPU")
.SetDefault("CPU")
.InEnum({enum_range});
}
};

void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
out->clear();
buffer_->Receive(out);
if (local_buffer_.payloads_.empty()) {
buffer_->Receive(&local_buffer_);
}

*out = local_buffer_.payloads_;
local_buffer_.payloads_.clear();
if (local_buffer_.ctx_) {
local_buffer_.ctx_->Wait();
}
}

void DoubleBufferReader::ReInit() {
reader_->ReInit();
buffer_->Close();
// The existing prefetch thread will terminate for the buffer_ is closed.
buffer_ = framework::MakeChannel<std::vector<framework::LoDTensor>>(
kDoubleBufferSize);
std::thread prefetch(&DoubleBufferReader::PrefetchThreadFunc, this);
prefetch.detach();
start_thread();
}

void DoubleBufferReader::PrefetchThreadFunc() {
VLOG(5) << "A new prefetch thread starts.";
while (true) {
std::vector<framework::LoDTensor> batch;
reader_->ReadNext(&batch);
if (batch.empty()) {
// EOF
buffer_->Close();
VLOG(5) << "Reached the end of the file. The prefetch thread terminates.";
break;
size_t gpu_ctx_offset = 0;
while (reader_->HasNext()) {
Item batch;
reader_->ReadNext(&batch.payloads_);
if (platform::is_gpu_place(place_)) {
std::vector<framework::LoDTensor> gpu_batch;
auto& gpu_ctx = this->ctxs_[gpu_ctx_offset++];
gpu_ctx_offset %= this->ctxs_.size();
gpu_batch.resize(batch.payloads_.size());
for (size_t i = 0; i < batch.payloads_.size(); ++i) {
framework::TensorCopy(batch.payloads_[i], place_, *gpu_ctx,
&gpu_batch[i]);
gpu_batch[i].set_lod(batch.payloads_[i].lod());
}
batch.ctx_ = gpu_ctx.get();
std::swap(gpu_batch, batch.payloads_);
}

if (!buffer_->Send(&batch)) {
VLOG(5) << "WARNING: The double buffer channel has been closed. The "
"prefetch thread terminates.";
break;
}
}
buffer_->Close();
}

bool DoubleBufferReader::HasNext() const { PADDLE_THROW("Not Implemented"); }
bool DoubleBufferReader::HasNext() const {
if (local_buffer_.payloads_.empty()) {
bool ok = buffer_->Receive(&local_buffer_);
return ok;
} else {
return true;
}
}

} // namespace reader
} // namespace operators
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ namespace operators {
namespace reader {

template <typename T>
class RandomDataGenerator : public framework::FileReader {
class RandomDataGenerator : public framework::ReaderBase {
public:
RandomDataGenerator(const std::vector<framework::DDim>& shapes, float min,
float max)
: FileReader(shapes), min_(min), max_(max) {
: framework::ReaderBase(), min_(min), max_(max), shapes_(shapes) {
PADDLE_ENFORCE_LE(
min, max, "'min' shouldn't be greater than 'max'.(%f vs %f)", min, max);
unsigned int seed = std::random_device()();
Expand Down Expand Up @@ -59,6 +59,7 @@ class RandomDataGenerator : public framework::FileReader {
float max_;
std::minstd_rand engine_;
std::uniform_real_distribution<float> dist_;
std::vector<framework::DDim> shapes_;
};

template <typename T>
Expand Down
21 changes: 12 additions & 9 deletions paddle/fluid/operators/reader/create_recordio_file_reader_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,22 @@ namespace operators {
namespace reader {
class RecordIOFileReader : public framework::FileReader {
public:
RecordIOFileReader(const std::string& filename,
const std::vector<framework::DDim>& shapes)
: FileReader(shapes),
explicit RecordIOFileReader(const std::string& filename,
const std::vector<framework::DDim>& dims)
: FileReader(dims),
scanner_(filename),
dev_ctx_(*platform::DeviceContextPool::Instance().Get(
platform::CPUPlace())) {}

void ReadNext(std::vector<framework::LoDTensor>* out) override {
*out = framework::ReadFromRecordIO(scanner_, dev_ctx_);
}

bool HasNext() const override { return scanner_.HasNext(); }

void ReInit() override { scanner_.Reset(); }

protected:
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
*out = framework::ReadFromRecordIO(scanner_, dev_ctx_);
}

private:
recordio::Scanner scanner_;
const platform::DeviceContext& dev_ctx_;
Expand All @@ -54,12 +55,12 @@ class CreateRecordIOReaderOp : public framework::OperatorBase {
int(shape_concat.size()),
"The accumulate of all ranks should be equal to the "
"shape concat's length.");
std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
std::string filename = Attr<std::string>("filename");

auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
out->Reset(new RecordIOFileReader(filename, shapes));
out->Reset(
new RecordIOFileReader(filename, RestoreShapes(shape_concat, ranks)));
}
};

Expand All @@ -85,3 +86,5 @@ namespace reader = paddle::operators::reader;
REGISTER_FILE_READER_OPERATOR(create_recordio_file_reader,
reader::CreateRecordIOReaderOp,
reader::CreateRecordIOReaderOpMaker);

REGISTER_FILE_READER(recordio, reader::RecordIOFileReader);
Loading

0 comments on commit 48f213e

Please sign in to comment.