Skip to content

Commit

Permalink
nullptr bugfix for XPU pg mode (#49043)
Browse files Browse the repository at this point in the history
* nullptr bugfix for XPU pg mode

Also, a few kernels are added to the XPU whitelist

* increase error msg length
  • Loading branch information
XiaociZhang authored Dec 14, 2022
1 parent f2a8dd5 commit f0dab19
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 3 deletions.
25 changes: 24 additions & 1 deletion paddle/fluid/distributed/collective/ProcessGroupBKCL.cc
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Collective(
const auto& place = in_tensor.place();
const auto& key = GetKeyFromPlace(place);

if (!calc_event_) {
if (!calc_event_ ||
(place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end())) {
CreateBKCLEnvCache(place, key);
}

Expand All @@ -170,6 +171,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Collective(
fn(out_tensor, in_tensor, comm_ctx->bkcl_context(), bkcl_stream);

if (!use_calc_stream) {
PADDLE_ENFORCE_NOT_NULL(
comm_ctx.get(), platform::errors::Fatal("comm context is nullptr."));
task->comm_event_->Record(*comm_ctx.get());
}

Expand Down Expand Up @@ -369,6 +372,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllReduce(
1,
platform::errors::InvalidArgument(
"BKCL only support single tensor collective communication."));
PADDLE_ENFORCE_EQ(
CheckTensorsInXPUPlace(in_tensors),
true,
platform::errors::InvalidArgument("All inputs should be in XPUPlace."));
return Collective(
&out_tensors[0],
in_tensors[0],
Expand Down Expand Up @@ -406,6 +413,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllReduce(
1,
platform::errors::InvalidArgument(
"BKCL only support single tensor collective communication."));
PADDLE_ENFORCE_EQ(
CheckTensorsInXPUPlace(in_tensors),
true,
platform::errors::InvalidArgument("All inputs should be in XPUPlace."));
return Collective(
&out_tensors[0],
in_tensors[0],
Expand Down Expand Up @@ -442,6 +453,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Broadcast(
1,
platform::errors::InvalidArgument(
"BKCL only support single tensor collective communication."));
PADDLE_ENFORCE_EQ(
CheckTensorsInXPUPlace(in_tensors),
true,
platform::errors::InvalidArgument("All inputs should be in XPUPlace."));

return Collective(
&out_tensors[0],
Expand Down Expand Up @@ -481,6 +496,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Broadcast(
1,
platform::errors::InvalidArgument(
"BKCL only support single tensor collective communication."));
PADDLE_ENFORCE_EQ(
CheckTensorsInXPUPlace(in_tensors),
true,
platform::errors::InvalidArgument("All inputs should be in XPUPlace."));

return Collective(
&out_tensors[0],
Expand Down Expand Up @@ -518,6 +537,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllGather(
1,
platform::errors::InvalidArgument(
"BKCL only support single tensor collective communication."));
PADDLE_ENFORCE_EQ(
CheckTensorsInXPUPlace(in_tensors),
true,
platform::errors::InvalidArgument("All inputs should be in XPUPlace."));
PADDLE_ENFORCE_EQ(
CheckTensorsInXPUPlace(out_tensors),
true,
Expand Down
16 changes: 16 additions & 0 deletions paddle/phi/backends/xpu/xpu2_op_list.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ XPUOpMap& get_kl2_ops() {
{"abs", XPUKernelSet({phi::DataType::FLOAT32})},
{"abs_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"accuracy", XPUKernelSet({phi::DataType::FLOAT32})},
{"adadelta", XPUKernelSet({phi::DataType::FLOAT32})},
{"adamw", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"adam", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
Expand Down Expand Up @@ -402,6 +403,13 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::INT32,
phi::DataType::BOOL,
phi::DataType::FLOAT32})},
{"reshape_with_xshape",
XPUKernelSet({phi::DataType::FLOAT64,
phi::DataType::FLOAT16,
phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::BOOL,
phi::DataType::FLOAT32})},
{"resnet_unit",
XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::FLOAT32})},
{"resnet_unit_grad",
Expand Down Expand Up @@ -485,6 +493,14 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::INT8,
phi::DataType::UINT8,
phi::DataType::FLOAT32})},
{"squeeze_with_xshape",
XPUKernelSet({phi::DataType::FLOAT64,
phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::BOOL,
phi::DataType::INT8,
phi::DataType::UINT8,
phi::DataType::FLOAT32})},
{"squeeze_grad",
XPUKernelSet({phi::DataType::FLOAT64,
phi::DataType::INT64,
Expand Down
4 changes: 3 additions & 1 deletion paddle/phi/kernels/elementwise_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,9 @@ PD_REGISTER_KERNEL(multiply,
ALL_LAYOUT,
phi::MultiplyKernel,
phi::dtype::float16,
float) {}
float,
int,
int64_t) {}
PD_REGISTER_KERNEL(subtract,
XPU,
ALL_LAYOUT,
Expand Down
4 changes: 3 additions & 1 deletion paddle/phi/kernels/xpu/elementwise_multiply_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,6 @@ PD_REGISTER_KERNEL(multiply_raw,
ALL_LAYOUT,
phi::MultiplyRawKernel,
phi::dtype::float16,
float) {}
float,
int,
int64_t) {}

0 comments on commit f0dab19

Please sign in to comment.