feature/Add Broadcast and Gather op handle #9825
Conversation
}
return nullptr;
}
BCastOpHandle::BCastOpHandle(const std::vector<Scope *> &local_scopes,
BCastOpHandle => BroadcastOpHandle
Don't use abbreviations unless they are universally known: a reader's background may be very different from the author's, and an abbreviation the author finds self-explanatory may read like gibberish to the reader.
Also, broadcast is a single English word; if it were to be abbreviated at all, it would be B, not BCast.
Done
}
}

BroadCastDestroy();
BroadCast => Broadcast
Broadcast is a single English word, meaning "to transmit widely"; it is not a combination of two words.
Done
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/details/broad_cast_op_handle.h"
broad_cast => broadcast
Done
Thanks very much @wangkuiyi
Some comments. Looks good overall.
#include "paddle/fluid/platform/device_context.h"

namespace f = paddle::framework;
I prefer the style in operator_test.cc:
namespace paddle {
namespace framework {
TEST...
}
}
Done
namespace framework {
namespace details {

Tensor *GetTensorFromVar(Variable *in_var) {
this method seems duplicated. Can we put it in a common place?
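A shared helper in a common header could look like the following. This is a simplified standalone sketch: the `std::variant`-based `Variable`, `Tensor`, `LoDTensor`, and `SelectedRows` here are stand-ins (Paddle's real `Variable` uses type erasure via `IsType<T>()`/`Get<T>()`), shown only to illustrate factoring the lookup into one place.

```cpp
#include <variant>
#include <vector>

// Simplified stand-ins for the paddle::framework types; the real Variable
// uses type erasure, not std::variant.
struct Tensor { std::vector<float> data; };
struct LoDTensor { Tensor tensor; };
struct SelectedRows { Tensor value; };
using Variable = std::variant<LoDTensor, SelectedRows>;

// One shared helper, declared once in a common header, so both the
// broadcast and the gather op handles resolve a Variable to its Tensor
// the same way instead of each keeping a private copy.
inline Tensor *GetTensorFromVar(Variable *in_var) {
  if (auto *lod = std::get_if<LoDTensor>(in_var)) return &lod->tensor;
  if (auto *sr = std::get_if<SelectedRows>(in_var)) return &sr->value;
  return nullptr;
}
```

Both op handle translation units would then include the shared header rather than redefining the function.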
s = e;
}
} else if (pre_in_var->IsType<framework::LoDTensor>()) {
//
Why is this branch empty? If the case is undefined, it shouldn't be allowed, should it? Maybe the op should be called SelectedRowsGather?
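For the SelectedRows case being discussed, the gather amounts to concatenating the row index lists and the row data of all inputs. A simplified standalone sketch of that semantics, using a hypothetical `SelectedRowsLite` type (not Paddle's real `SelectedRows`, which stores a full Tensor value):

```cpp
#include <cstdint>
#include <vector>

// Hypothetical, simplified stand-in for framework::SelectedRows: a list of
// global row indices plus the corresponding row data (one float per row).
struct SelectedRowsLite {
  std::vector<int64_t> rows;
  std::vector<float> data;
};

// Gather = concatenate the row indices and row data of every input.
// Duplicate indices are allowed; downstream ops can merge them if needed.
SelectedRowsLite GatherSelectedRows(const std::vector<SelectedRowsLite> &ins) {
  SelectedRowsLite out;
  for (const auto &in : ins) {
    out.rows.insert(out.rows.end(), in.rows.begin(), in.rows.end());
    out.data.insert(out.data.end(), in.data.begin(), in.data.end());
  }
  return out;
}
```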
}
}

std::string GatherOpHandle::Name() const { return "broadcast"; }
gather?
Done, thanks!
}
}

void TestGatherLodTensor() {
clean up this test?
f::details::GatherOpHandle* gather_op_handle_;
};

// TEST_F(GatherTester, TestCPUGatherTestLodTensor) {
clean up?
GatherLodTensor is also necessary when running the parallel_exe with multiple threads on CPU, but not on GPU. I am developing it.
}

std::vector<int64_t> out_rows;
std::vector<Tensor *> in_tensors;
Use std::vector here
Done
void BroadcastInitOp(int input_scope_idx) {
  for (size_t j = 0; j < gpu_list_.size(); ++j) {
    local_scope_.push_back(&g_scope_.NewScope());
It seems that local_scope_ is not cleaned for each unit test?
struct TestFixture {
  std::vector<std::unique_ptr<DeviceContext>> contexts_;
  std::vector<Scope *> local_scopes_;
  Scope global_scope_;
  std::unique_ptr<OpHandleBase> op_handle_;
  std::vector<std::unique_ptr<VarHandle>> vars_;
};
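With a fixture laid out like the suggestion above, each test can build a fresh fixture on the stack, so every scope and handle is torn down when the test returns and nothing leaks into the next test. A minimal standalone sketch of that ownership pattern, using mock `Scope`/`OpHandleBase` types (not Paddle's real ones) and an assumed `InitScopes` helper:

```cpp
#include <memory>
#include <vector>

// Mock stand-ins for the paddle::framework types, just to show ownership.
struct Scope {
  std::vector<std::unique_ptr<Scope>> children_;
  Scope *NewScope() {
    children_.emplace_back(new Scope);
    return children_.back().get();
  }
};
struct OpHandleBase {};

// The fixture owns everything: global_scope_ owns the local scopes, and
// op_handle_ is a unique_ptr. Constructing a fresh fixture per test means
// the destructor at end of scope cleans local_scopes_ automatically.
struct TestFixture {
  Scope global_scope_;
  std::vector<Scope *> local_scopes_;  // non-owning views into global_scope_
  std::unique_ptr<OpHandleBase> op_handle_;

  void InitScopes(size_t n) {  // hypothetical helper for illustration
    for (size_t i = 0; i < n; ++i)
      local_scopes_.push_back(global_scope_.NewScope());
  }
};
```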
// Wait until the input is generated; this Wait is an asynchronous operation
auto &in_place = in_var_handle[0]->place_;
if (in_var_handle[0]->generated_op_) {
  in_var_handle[0]->generated_op_->Wait(dev_ctxes_[in_place]);
This line is not needed.
Done
in_var_handle[0]->generated_op_->Wait(dev_ctxes_[in_place]);
for (auto *out : out_var_handles) {
  auto &out_p = out->place_;
  if (platform::is_same_place(in_place, out_p)) continue;
This line is not needed.
Done
"%s is not in the local_scopes ", out->name_);

auto *s = local_scopes_[out_scope_idx];
auto out_var = s->FindVar(out->name_);
L83 - L88 can be rewritten as:
auto *out_var = local_scopes_.at(out->scope_idx_)->FindVar(out->name_);
Done
auto *s = local_scopes_[out_scope_idx];
auto out_var = s->FindVar(out->name_);
PADDLE_ENFORCE_EQ(out_p.which(), in_place.which(),
                  "The place of input and output should be the same.");
"The place of input and output should be the same." --> "Places must be all on CPU or all on CUDA."
Done
Excellent, except some nitpicking comments.