Dist transpiler support prefetch #9714
Conversation
paddle/fluid/operators/concat_op.cc
Outdated
@@ -33,7 +35,7 @@ class ConcatOp : public framework::OperatorWithKernel {
     size_t axis = static_cast<size_t>(ctx->Attrs().Get<int>("axis"));
     const size_t n = ins.size();

-    PADDLE_ENFORCE_GT(n, 1, "Input tensors count should > 1.");
+    // PADDLE_ENFORCE_GT(n, 1, "Input tensors count should > 1.");
You may delete this commented-out line.
     }
-    auto prepared = executor.Prepare(*program, block_list);
+    auto optimize_prepared = executor.Prepare(*program, block_list);
We need to prepare all the blocks of the program, so maybe the name prepared is more suitable?
optimize_prepared is named this way to distinguish it from prefetch_prepared.
     rpc_service_->SetScope(&recv_scope);
     rpc_service_->SetDevCtx(&dev_ctx);
     // TODO(qiao) set proper fields for table lookup and update
     rpc_service_->SetExecutor(&executor);
     rpc_service_->SetPrefetchBlkdId(0);
     VLOG(3) << "prefetch block id is " << prefetch_block->ID();
     auto prefetch_prepared = executor.Prepare(*program, prefetch_block->ID());
The code at L106 has already prepared all the blocks, so we don't need to prepare the prefetch_block again.
I skipped the prefetch_block here https://github.com/PaddlePaddle/Paddle/pull/9714/files#diff-64ee97d744659db61dc8ae72bfc103b5R102
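For context, a minimal sketch of the split being discussed, assuming the two Prepare overloads visible in this diff (one taking a list of block ids, one taking a single block id); the loop and the Size() call are illustrative, not the exact PR code:

// Illustrative sketch, not the exact PR code.
// Collect every block except block 0 (the main block) and the prefetch block.
std::vector<int> block_list;
for (size_t i = 1; i < program->Size(); ++i) {
  if (static_cast<int>(i) != prefetch_block->ID()) {
    block_list.push_back(static_cast<int>(i));
  }
}
// Prepare all optimize blocks in one call ...
auto optimize_prepared = executor.Prepare(*program, block_list);
// ... and prepare the prefetch block separately, so it can be run on demand
// when a prefetch request arrives.
auto prefetch_prepared = executor.Prepare(*program, prefetch_block->ID());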
@@ -71,7 +71,7 @@ class PrefetchOpMaker : public framework::OpProtoAndCheckerMaker {
              "(RPCClient) The RPC client object which will be"
              "initialized at most once.");
     AddOutput("Out",
-              "(SelectedRows) result "
+              "(LoDTensor) result "
I think the type of the output variable should be SelectedRows, simply because its shape is not a fixed value.
Here it should be LoDTensor, because the following op is not known in advance and most ops can only process LoDTensor; SelectedRows is only constructed during backward.
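To make the distinction concrete, a conceptual sketch using the accessors that already appear in this PR (the scope lookup and variable names are assumed for illustration only):

// Conceptual sketch: how a following op would see the two variable types.
auto* out_var = scope.FindVar("Out");
if (out_var->IsType<framework::LoDTensor>()) {
  // Dense tensor with a concrete shape -- what most following ops can consume.
  const auto& dense = out_var->Get<framework::LoDTensor>();
  VLOG(3) << "dense dims: " << dense.dims();
} else if (out_var->IsType<framework::SelectedRows>()) {
  // Row ids plus a value tensor; usually constructed for sparse gradients
  // during backward rather than consumed by forward ops.
  const auto& sparse = out_var->Get<framework::SelectedRows>();
  VLOG(3) << "stored value dims: " << sparse.value().dims();
}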
paddle/fluid/operators/sum_op.cc
Outdated
@@ -36,8 +38,8 @@ class SumOp : public framework::OperatorWithKernel {
     }

     auto x_dims = ctx->GetInputsDim("X");
-    size_t N = x_dims.size();
-    PADDLE_ENFORCE_GT(N, 1, "Input tensors count should > 1.");
+    // size_t N = x_dims.size();
Please delete these comments.
Will add a TODO here instead, since this check may need to be added back in the future.
@@ -252,12 +315,114 @@ def transpile(self,
                          outputs={"Out": [orig_param]},
                          attrs={"axis": 0})

+        if self.has_distributed_lookup_table:
Can we move the following code into an independent function?
done
Awesome! Thanks for the PR and for making it work!
paddle/fluid/framework/operator.cc
Outdated
@@ -55,7 +55,7 @@ static DDim GetDims(const Scope& scope, const std::string& name) {
   if (var->IsType<LoDTensor>()) {
     return var->Get<LoDTensor>().dims();
   } else if (var->IsType<SelectedRows>()) {
-    return var->Get<SelectedRows>().GetCompleteDims();
+    return var->Get<SelectedRows>().value().dims();
Will this affect other places like optimization ops?
ok, will optimize this code.
done
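If the concern about optimization ops turns out to matter, a defensive variant could look like the sketch below. It assumes the value tensor exposes IsInitialized(); this is only an illustration, not the change made in this PR.

static DDim GetDims(const Scope& scope, const std::string& name) {
  Variable* var = scope.FindVar(name);
  if (var == nullptr) {
    return DDim({-1});
  }
  if (var->IsType<LoDTensor>()) {
    return var->Get<LoDTensor>().dims();
  } else if (var->IsType<SelectedRows>()) {
    const auto& sr = var->Get<SelectedRows>();
    // Fall back to the complete (height x width) dims when the value tensor
    // has not been written yet, e.g. before an optimize op fills it.
    return sr.value().IsInitialized() ? sr.value().dims() : sr.GetCompleteDims();
  } else {
    return DDim({-1});
  }
}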
        # 2. add split_ids_op and send_vars_op to send gradient to pservers
        # there should only be one table_name
        all_ops = program.global_block().ops
        table_grad_name = framework.grad_var_name(self.table_name)
grad_var_name sometimes may not return the "real" gradient variable name, because backward may create a different name.
Yes, but here the name of the table parameter's gradient will always be table_name@GRAD; any table_name@GRAD@RENAME variables will be merged back into table_name@GRAD.
I see, thanks
LGTM!
project: #9597
task list: #9211
test code: https://github.com/jacquesqiao/models/tree/dist-lookup-table/dist_lookup_table
remaining problem:
the prefetch block has to be the last block, or RunPreparedContext will fail.