Dist transpiler support prefetch #9714

Merged
Changes from 20 commits (35 commits total)
edcfcad
init
jacquesqiao Apr 5, 2018
66ab88a
add some check
jacquesqiao Apr 5, 2018
29174df
add dist transpile logic
jacquesqiao Apr 7, 2018
54656a1
add insert op for block
jacquesqiao Apr 7, 2018
171560b
init change get_pserver_program
jacquesqiao Apr 7, 2018
3605922
Merge branch 'develop' into dist-transpiler-support-prefetch
jacquesqiao Apr 7, 2018
d3f2d4c
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Apr 8, 2018
6973bcb
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Apr 8, 2018
38ed3e8
optimize code
jacquesqiao Apr 8, 2018
eb31b66
fix a bug
jacquesqiao Apr 8, 2018
b4e974a
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Apr 8, 2018
d672592
can run now
jacquesqiao Apr 8, 2018
2e69b77
start to do table split
jacquesqiao Apr 9, 2018
3ad3eea
start to process table gradient
jacquesqiao Apr 9, 2018
a07a063
complete pserver part
jacquesqiao Apr 10, 2018
53d6459
can send_vars now
jacquesqiao Apr 10, 2018
e0fca82
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Apr 10, 2018
b1e398d
revert cpplint
jacquesqiao Apr 10, 2018
cf9d25f
fix a bug
jacquesqiao Apr 10, 2018
064a913
optimize code
jacquesqiao Apr 10, 2018
f81d6b3
move dist test to models
jacquesqiao Apr 10, 2018
f467b18
revert the interface of distribute_transpiler.transpile
jacquesqiao Apr 11, 2018
4b8189f
fix prefetch_block
jacquesqiao Apr 11, 2018
d1c8f4b
optimize trainspiler code
jacquesqiao Apr 11, 2018
9d3ecca
Merge branch 'develop' into dist-transpiler-support-prefetch
jacquesqiao Apr 11, 2018
dff691c
add comment to sum_op
jacquesqiao Apr 11, 2018
bb27df1
add warning log
jacquesqiao Apr 11, 2018
356b9e6
fix comment
jacquesqiao Apr 11, 2018
2f4962d
fix test_send_recv
jacquesqiao Apr 11, 2018
e2674e8
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Apr 11, 2018
063a956
fix test_send_recv
jacquesqiao Apr 11, 2018
fde5445
Merge branch 'develop' into dist-transpiler-support-prefetch
jacquesqiao Apr 11, 2018
8eea574
fix train with no distributed table
jacquesqiao Apr 11, 2018
193be56
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Apr 12, 2018
4554e7b
optimize GetDims
jacquesqiao Apr 12, 2018
2 changes: 1 addition & 1 deletion paddle/fluid/framework/block_desc.h
@@ -92,7 +92,7 @@ class BlockDesc {

/*
* Remove Op and its input/output variables.
* Note that for either input or ouput variable, if it is also an input or
* Note that for either input or output variable, if it is also an input or
* output variable of other ops, we should remain it.
*/
void RemoveOp(size_t s, size_t e);
2 changes: 1 addition & 1 deletion paddle/fluid/framework/operator.cc
@@ -55,7 +55,7 @@ static DDim GetDims(const Scope& scope, const std::string& name) {
if (var->IsType<LoDTensor>()) {
return var->Get<LoDTensor>().dims();
} else if (var->IsType<SelectedRows>()) {
return var->Get<SelectedRows>().GetCompleteDims();
return var->Get<SelectedRows>().value().dims();
Contributor:
Will this affect other places like optimization ops?

Member Author:
ok, will optimize this code.

Member Author:
done

} else {
return DDim({-1});
}
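To make the question in the review thread above concrete, here is a minimal standalone sketch (a hypothetical FakeSelectedRows type, not Paddle's actual SelectedRows class) of the assumed difference between the two calls: GetCompleteDims() reports the logical height of the full table, while value().dims() reports only the rows actually stored, which is what the debug log here wants to show.

```cpp
// Hypothetical sketch of the two shapes a SelectedRows-like container can report.
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

struct FakeSelectedRows {
  int64_t height;             // logical number of rows (full table height)
  std::vector<int64_t> rows;  // indices of the rows actually stored
  int64_t row_width;          // width of each stored row

  // Analogue of GetCompleteDims(): {height, row_width}.
  std::pair<int64_t, int64_t> complete_dims() const { return {height, row_width}; }
  // Analogue of value().dims(): {rows.size(), row_width}.
  std::pair<int64_t, int64_t> value_dims() const {
    return {static_cast<int64_t>(rows.size()), row_width};
  }
};

int main() {
  FakeSelectedRows sr{10000, {3, 42, 97}, 64};
  auto c = sr.complete_dims();
  auto v = sr.value_dims();
  std::cout << "complete dims: " << c.first << " x " << c.second << "\n";  // 10000 x 64
  std::cout << "value dims:    " << v.first << " x " << v.second << "\n";  // 3 x 64
  return 0;
}
```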
4 changes: 3 additions & 1 deletion paddle/fluid/operators/concat_op.cc
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/concat_op.h"

#include <string>
#include <vector>

namespace paddle {
Expand All @@ -33,7 +35,7 @@ class ConcatOp : public framework::OperatorWithKernel {
size_t axis = static_cast<size_t>(ctx->Attrs().Get<int>("axis"));
const size_t n = ins.size();

PADDLE_ENFORCE_GT(n, 1, "Input tensors count should > 1.");
// PADDLE_ENFORCE_GT(n, 1, "Input tensors count should > 1.");
Contributor:
May delete the comment.

auto out_dims = ins[0];
size_t in_zero_dims_size = out_dims.size();
1 change: 1 addition & 0 deletions paddle/fluid/operators/detail/grpc_server.cc
@@ -161,6 +161,7 @@ class RequestPrefetch final : public RequestBase {
::grpc::ByteBuffer reply;

std::string var_name = request_->OutVarname();
VLOG(3) << "prefetch var " << var_name;
auto var_desc = program_->Block(0).FindVar(var_name);
framework::Scope* local_scope = &scope_->NewScope();
auto* var = local_scope->FindVar(var_name);
45 changes: 29 additions & 16 deletions paddle/fluid/operators/listen_and_serv_op.cc
@@ -13,7 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include <ostream>
#include <thread>
#include <thread> // NOLINT
#include <vector>

#include "paddle/fluid/operators/listen_and_serv_op.h"

@@ -88,27 +89,35 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,

auto ins = Inputs("X");
auto fan_in = Attr<int>("Fanin");
auto *block = Attr<framework::BlockDesc *>(kOptimizeBlock);
auto *program = block->Program();
auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock);
auto *prefetch_block = Attr<framework::BlockDesc *>(kPrefetchBlock);
auto *program = optimize_block->Program();
size_t num_blocks = program->Size();
PADDLE_ENFORCE_GE(num_blocks, 2,
"server program should have at least 2 blocks");

framework::Executor executor(dev_place);
std::vector<int> block_list;
for (size_t blkid = 1; blkid < num_blocks; ++blkid) {
block_list.push_back(blkid);
if (blkid != prefetch_block->ID()) {
block_list.push_back(blkid);
}
}
auto prepared = executor.Prepare(*program, block_list);
auto optimize_prepared = executor.Prepare(*program, block_list);
Contributor:
We need to prepare all the blocks of the program, so maybe the name prepared is more suitable?

Member Author:
optimize_prepared is named to distinguish it from prefetch_prepared.

// Insert placeholder for block0 which holds current op itself.
prepared.insert(prepared.begin(),
std::shared_ptr<framework::ExecutorPrepareContext>(nullptr));
optimize_prepared.insert(
optimize_prepared.begin(),
std::shared_ptr<framework::ExecutorPrepareContext>(nullptr));

rpc_service_->SetScope(&recv_scope);
rpc_service_->SetDevCtx(&dev_ctx);
// TODO(qiao) set proper fields for table lookup and update
rpc_service_->SetExecutor(&executor);
rpc_service_->SetPrefetchBlkdId(0);
VLOG(3) << "prefetch block id is " << prefetch_block->ID();
auto prefetch_prepared = executor.Prepare(*program, prefetch_block->ID());
Contributor:
The code at L106 has already prepared all the blocks, so we don't need to prepare the prefetch_block again.

rpc_service_->SetPrefetchBlkdId(prefetch_block->ID());
rpc_service_->SetPrefetchPreparedCtx(prefetch_prepared.get());
prefetch_prepared.release();
rpc_service_->SetProgram(program);
// start the server listening after all member initialized.
server_thread_.reset(new std::thread(RunServer, rpc_service_));
@@ -166,16 +175,18 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
parallel_blkids.push_back(1);
double ts = detail::GetTimestamp();
for (size_t blkid = 2; blkid < num_blocks; ++blkid) {
if (program->Block(blkid).Parent() != last_parent_blkid) {
ParallelExecuteBlocks(parallel_blkids, &executor, prepared, program,
&recv_scope);
parallel_blkids.clear();
last_parent_blkid = program->Block(blkid).Parent();
if (blkid != prefetch_block->ID()) {
if (program->Block(blkid).Parent() != last_parent_blkid) {
ParallelExecuteBlocks(parallel_blkids, &executor, optimize_prepared,
program, &recv_scope);
parallel_blkids.clear();
last_parent_blkid = program->Block(blkid).Parent();
}
parallel_blkids.push_back(blkid);
}
parallel_blkids.push_back(blkid);
}
ParallelExecuteBlocks(parallel_blkids, &executor, prepared, program,
&recv_scope);
ParallelExecuteBlocks(parallel_blkids, &executor, optimize_prepared,
program, &recv_scope);
VLOG(2) << "run all blocks spent " << detail::GetTimestamp() - ts << "(ms)";

// Reset the received sparse variables, the sum operator would not
Expand Down Expand Up @@ -211,6 +222,8 @@ from send_op and send back variables to recv_op.
.AddCustomChecker([](const std::string &ip) { return !ip.empty(); });
AddAttr<framework::BlockDesc *>(kOptimizeBlock,
"BlockID to run on server side.");
AddAttr<framework::BlockDesc *>(kPrefetchBlock,
"prefetch block to run on server side.");
AddAttr<int>("Fanin", "How many clients send to this server.")
.SetDefault(1);
}
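As a rough illustration of the control flow this diff adds (hypothetical ids and names, not the real op), the sketch below shows the assumed bookkeeping: block 0 holds the listen_and_serv op itself, the prefetch block is prepared on its own and handed to the RPC service, and every remaining block is treated as an optimize block and run in the parallel-execution path.

```cpp
// Simplified sketch of partitioning server-program blocks into
// "optimize" blocks and a separately prepared prefetch block.
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  const size_t num_blocks = 5;         // assumed total blocks in the server program
  const size_t prefetch_block_id = 3;  // assumed id of the prefetch block

  // Block 0 holds the current op itself; it only gets a null placeholder context.
  std::vector<size_t> optimize_block_ids;
  for (size_t blkid = 1; blkid < num_blocks; ++blkid) {
    if (blkid != prefetch_block_id) {  // the prefetch block is prepared separately
      optimize_block_ids.push_back(blkid);
    }
  }

  std::cout << "optimize blocks:";
  for (size_t id : optimize_block_ids) std::cout << " " << id;  // 1 2 4
  std::cout << "\nprefetch block: " << prefetch_block_id << "\n";
  return 0;
}
```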
2 changes: 2 additions & 0 deletions paddle/fluid/operators/listen_and_serv_op.h
@@ -16,6 +16,7 @@ limitations under the License. */

#include <stdint.h>
#include <ostream>
#include <string>

#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
@@ -27,6 +28,7 @@ namespace paddle {
namespace operators {

constexpr char kOptimizeBlock[] = "OptimizeBlock";
constexpr char kPrefetchBlock[] = "PrefetchBlock";

void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service);

3 changes: 3 additions & 0 deletions paddle/fluid/operators/lookup_table_op.cc
@@ -78,6 +78,9 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
"(boolean, default false) "
"Sparse update.")
.SetDefault(false);
AddAttr<bool>("is_distributed",
"(boolean, default false) distributed lookup table.")
.SetDefault(false);
AddAttr<int64_t>("padding_idx",
"(int64, default -1) "
"If the value is -1, it makes no effect to lookup. "
8 changes: 4 additions & 4 deletions paddle/fluid/operators/prefetch_op.cc
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <future>
#include <future> // NOLINT
#include <ostream>

#include "paddle/fluid/framework/data_type.h"
@@ -50,8 +50,8 @@ class PrefetchOp : public framework::OperatorBase {

for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) {
VLOG(3) << "sending " << ins[i] << " to " << epmap[i] << "to get "
<< outs[i] << "back";
VLOG(3) << "sending " << ins[i] << " to " << epmap[i] << " to get "
<< outs[i] << " back";
rpc_client->AsyncPrefetchVariable(epmap[i], ctx, scope, ins[i],
outs[i]);
} else {
Expand All @@ -71,7 +71,7 @@ class PrefetchOpMaker : public framework::OpProtoAndCheckerMaker {
"(RPCClient) The RPC client object which will be"
"initialized at most once.");
AddOutput("Out",
"(SelectedRows) result "
"(LoDTensor) result "
Contributor:
I think the type of the output variable should be SelectedRows, just because the shape is not a certain value.

Member Author:
Here it should be LoDTensor, because the following op is not certain and most ops can only process LoDTensor; SelectedRows is constructed during backward.

"to be fetched from parameter server")
.AsDuplicable();
AddAttr<std::vector<std::string>>(
4 changes: 2 additions & 2 deletions paddle/fluid/operators/send_vars_op.cc
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <future>
#include <future> // NOLINT
#include <ostream>

#include "paddle/fluid/framework/data_type.h"
@@ -36,7 +36,7 @@ class SendVarsOp : public framework::OperatorBase {
auto ins = Inputs("X");

std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
int sync_send = Attr<int>("sync_sent");
int sync_send = Attr<int>("sync_send");

platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
4 changes: 2 additions & 2 deletions paddle/fluid/operators/sgd_op.cc
@@ -35,8 +35,8 @@ class SGDOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
"Learning rate should have 1 element");
auto param_dim = ctx->GetInputDim("Param");
// TODO(qijun): check dimensions of Param and Grad at complie
// and run time.
// TODO(qijun): check dimensions of Param and Grad at compile
// and runtime.
ctx->SetOutputDim("ParamOut", param_dim);
}

14 changes: 8 additions & 6 deletions paddle/fluid/operators/split_ids_op.cc
@@ -48,20 +48,21 @@ class SplitIdsOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasOutputs("Out"), "SplitIdsOp must has output Out.");

auto ids_var_type = ctx->GetInputsVarType("Ids").front();
PADDLE_ENFORCE_EQ(ids_var_type, framework::proto::VarType::LOD_TENSOR);

auto ids_dims = ctx->GetInputDim("Ids");
PADDLE_ENFORCE_EQ(ids_dims.size(), 2);
PADDLE_ENFORCE_EQ(ids_dims[1], 1);
if (ids_var_type == framework::proto::VarType::LOD_TENSOR) {
PADDLE_ENFORCE_EQ(ids_dims.size(), 2);
PADDLE_ENFORCE_EQ(ids_dims[1], 1);
}
}
};

class SplitIdsOpInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
auto *input_var = block->Var(op_desc.Input("Ids")[0]);
for (auto &out_var : op_desc.Output("Out")) {
block->Var(out_var)->SetType(framework::proto::VarType::LOD_TENSOR);
block->Var(out_var)->SetType(input_var->GetType());
}
}
};
@@ -73,4 +74,5 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(split_ids, ops::SplitIdsOp, ops::SplitIdsOpMaker,
ops::SplitIdsOpInferVarType);
REGISTER_OP_CPU_KERNEL(
split_ids, ops::SplitIdsOpKernel<paddle::platform::CPUPlace, int64_t>);
split_ids, ops::SplitIdsOpKernel<paddle::platform::CPUPlace, int64_t>,
ops::SplitIdsOpKernel<paddle::platform::CPUPlace, float>);
70 changes: 49 additions & 21 deletions paddle/fluid/operators/split_ids_op.h
@@ -24,35 +24,63 @@ namespace operators {
template <typename DeviceContext, typename T>
class SplitIdsOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
void Compute(const framework::ExecutionContext &ctx) const override {
auto place = ctx.GetPlace();
if (!platform::is_cpu_place(place)) {
PADDLE_THROW("SplitIds do not support GPU kernel");
}

auto& ids_dims = ctx.Input<framework::LoDTensor>("Ids")->dims();
const T* ids = ctx.Input<framework::LoDTensor>("Ids")->data<T>();
auto outs = ctx.MultiOutput<framework::LoDTensor>("Out");
const size_t shard_num = outs.size();
const auto *ids_var = ctx.InputVar("Ids");
if (ids_var->IsType<framework::LoDTensor>()) {
const auto &ids_dims = ctx.Input<framework::LoDTensor>("Ids")->dims();
const T *ids = ctx.Input<framework::LoDTensor>("Ids")->data<T>();
auto outs = ctx.MultiOutput<framework::LoDTensor>("Out");
const size_t shard_num = outs.size();

std::vector<std::vector<T>> out_ids;
out_ids.resize(outs.size());
std::vector<std::vector<T>> out_ids;
out_ids.resize(outs.size());

// split id by their shard_num.
for (int i = 0; i < ids_dims[0]; ++i) {
T id = ids[i];
size_t shard_id = static_cast<size_t>(id) % shard_num;
out_ids[shard_id].push_back(id);
}
// split id by their shard_num.
for (int i = 0; i < ids_dims[0]; ++i) {
T id = ids[i];
size_t shard_id = static_cast<size_t>(id) % shard_num;
out_ids[shard_id].push_back(id);
}

// create tensor for each shard and send to parameter server
for (size_t i = 0; i < out_ids.size(); ++i) {
auto *shard_t = outs[i];
std::vector<T> ids = out_ids[i];
auto *shard_data = shard_t->mutable_data<T>(
framework::make_ddim({static_cast<int64_t>(ids.size()), 1}), place);
for (size_t i = 0; i < ids.size(); ++i) {
shard_data[i] = ids[i];
}
}
} else if (ids_var->IsType<framework::SelectedRows>()) {
const auto *ids_selected_rows = ctx.Input<framework::SelectedRows>("Ids");
auto &ids_dims = ids_selected_rows->value().dims();
PADDLE_ENFORCE_EQ(ids_dims[0], ids_selected_rows->rows().size(), "");
const T *ids = ids_selected_rows->value().data<T>();
const auto &ids_rows = ids_selected_rows->rows();
auto outs = ctx.MultiOutput<framework::SelectedRows>("Out");
const size_t shard_num = outs.size();
// get rows for outputs
for (auto &id : ids_rows) {
size_t shard_id = static_cast<size_t>(id) % shard_num;
outs[shard_id]->mutable_rows()->push_back(id);
}

// create tensor for each shard and send to parameter server
for (size_t i = 0; i < out_ids.size(); ++i) {
auto* shard_t = outs[i];
std::vector<T> ids = out_ids[i];
auto* shard_data = shard_t->mutable_data<T>(
framework::make_ddim({static_cast<int64_t>(ids.size()), 1}), place);
for (size_t i = 0; i < ids.size(); ++i) {
shard_data[i] = ids[i];
int64_t row_width = ids_dims[1];
for (auto &out : outs) {
out->set_height(ids_selected_rows->height());
framework::DDim ddim = framework::make_ddim(
{static_cast<int64_t>(out->rows().size()), row_width});
T *output = out->mutable_value()->mutable_data<T>(ddim, place);
for (size_t i = 0; i < ddim[0]; ++i) {
memcpy(output + i * row_width, ids + out->rows()[i] * row_width,
row_width * sizeof(T));
}
}
}
}
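The kernel above splits ids across parameter-server shards by `id % shard_num`, and on the SelectedRows path copies whole rows of width `row_width` per selected id. Below is a standalone sketch of just the sharding rule under those assumptions (plain vectors instead of LoDTensor/SelectedRows).

```cpp
// Standalone sketch of the sharding rule: each id goes to shard id % shard_num.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const size_t shard_num = 2;                // assumed number of pserver shards
  std::vector<int64_t> ids = {1, 4, 7, 10};  // incoming ids (LoDTensor path)

  std::vector<std::vector<int64_t>> out_ids(shard_num);
  for (int64_t id : ids) {
    out_ids[static_cast<size_t>(id) % shard_num].push_back(id);
  }

  for (size_t i = 0; i < shard_num; ++i) {
    std::cout << "shard " << i << ":";
    for (int64_t id : out_ids[i]) std::cout << " " << id;
    std::cout << "\n";  // shard 0: 4 10, shard 1: 1 7
  }
  return 0;
}
```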
6 changes: 4 additions & 2 deletions paddle/fluid/operators/sum_op.cc
@@ -10,9 +10,11 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/sum_op.h"

#include <algorithm>
#include <string>
#include <vector>

#include "paddle/fluid/framework/var_type_inference.h"
#include "paddle/fluid/operators/detail/safe_ref.h"

Expand All @@ -36,8 +38,8 @@ class SumOp : public framework::OperatorWithKernel {
}

auto x_dims = ctx->GetInputsDim("X");
size_t N = x_dims.size();
PADDLE_ENFORCE_GT(N, 1, "Input tensors count should > 1.");
// size_t N = x_dims.size();
Contributor:
Please delete these comments.

Member Author:
Will add a TODO here; maybe this check needs to be added back in the future.

// PADDLE_ENFORCE_GT(N, 1, "Input tensors count should > 1.");

framework::DDim in_dim({0});
for (auto& x_dim : x_dims) {