
Commit 4e89a42: Fix

co63oc committed Aug 13, 2024
1 parent d815238

Showing 9 changed files with 363 additions and 0 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
@@ -169,6 +169,7 @@
'c_reduce_sum',
'c_softmax_with_cross_entropy',
'c_split',
'comm_init_all',
'decayed_adagrad',
'distributed_fused_lamb',
'distributed_fused_lamb_',
1 change: 1 addition & 0 deletions paddle/fluid/primitive/codegen/gen.py
@@ -62,6 +62,7 @@
"full",
"partial_send",
"push_dense",
"comm_init_all",
]

# prim op with one input and one output, with no attribute
219 changes: 219 additions & 0 deletions paddle/phi/core/platform/device/xpu/bkcl_helper.h
@@ -0,0 +1,219 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef _WIN32
#if defined(PADDLE_WITH_XPU_BKCL)
#pragma once

#include <stdio.h>

#include <memory>
#include <string>
#include <thread> // NOLINT
#include <typeindex>
#include <unordered_map>
#include <vector>

#include "glog/logging.h"

#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/platform/collective_helper.h"
#include "paddle/phi/core/platform/device_context.h"
#include "xpu/bkcl.h"
#include "xpu/runtime.h"

#define BKCL_ID_VARNAME "BKCLID"

namespace paddle {
namespace platform {

// These accessors assume the opaque BKCLContext_t begins with three ints laid
// out as {rank, device id, nranks} and read that header directly.
inline int GetBKCLRankID(BKCLContext_t comm) {
  return reinterpret_cast<int *>(comm)[0];
}

inline int GetBKCLDevID(BKCLContext_t comm) {
return reinterpret_cast<int *>(comm)[1];
}

inline int GetBKCLNRanks(BKCLContext_t comm) {
return reinterpret_cast<int *>(comm)[2];
}

class BKCLGroupGuard {
public:
static std::mutex &BKCLMutex() {
static std::mutex mtx;
return mtx;
}

inline BKCLGroupGuard() {
BKCLMutex().lock();
PADDLE_ENFORCE_XPU_SUCCESS(bkcl_group_start());
}

inline ~BKCLGroupGuard() PADDLE_MAY_THROW {
PADDLE_ENFORCE_XPU_SUCCESS(bkcl_group_end());
BKCLMutex().unlock();
}
};
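
// Usage sketch for the guard (bkcl_all_reduce's exact signature is an
// assumption here, mirroring NCCL's grouped-call pattern):
//
//   {
//     BKCLGroupGuard guard;  // locks the global mutex, bkcl_group_start()
//     for (size_t i = 0; i < comms.size(); ++i) {
//       bkcl_all_reduce(comms[i], send[i], recv[i], count, BKCL_FLOAT,
//                       BKCL_ADD, streams[i]);
//     }
//   }  // ~BKCLGroupGuard(): bkcl_group_end(), then unlock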

struct BKCLContext {
std::unique_ptr<phi::XPUContext> ctx_;
BKCLContext_t comm_;

explicit BKCLContext(int dev_id)
: ctx_(new phi::XPUContext(phi::XPUPlace(dev_id))), comm_{nullptr} {}

XPUStream stream() const { return ctx_->stream(); }
BKCLContext_t comm() const { return comm_; }

int device_id() const { return ctx_->GetPlace().device; }
};

struct InitBKCLPara {
BKCLUniqueId *bkcl_id;
int rank;
int nranks;
int dev_id;
BKCLContext_t *ctx;
};

// Thread entry for per-device setup: bind the thread to its XPU device, then
// join the collective via bkcl_init_rank.
static void *init_bkcl_context_func(void *args) {
  struct InitBKCLPara *para = static_cast<struct InitBKCLPara *>(args);
  platform::SetXPUDeviceId(para->dev_id);
PADDLE_ENFORCE_XPU_SUCCESS(
bkcl_init_rank(para->ctx, para->rank, para->nranks, para->bkcl_id));
return nullptr;
}

struct BKCLContextMap {
std::unordered_map<int, BKCLContext> contexts_;
std::vector<int> order_;
std::vector<phi::Place> places_;
size_t num_trainers_;
size_t trainer_id_;
BKCLUniqueId *bkcl_id_;

explicit BKCLContextMap(const std::vector<phi::Place> &places,
BKCLUniqueId *bkcl_id = nullptr,
size_t num_trainers = 1,
size_t trainer_id = 0) {
places_ = places;
bkcl_id_ = bkcl_id;
num_trainers_ = num_trainers;
trainer_id_ = trainer_id;
}

// Initialization must run concurrently: bkcl_init_rank blocks until every
// rank in the group has joined, so ranks cannot be initialized one by one.
int init() {
PADDLE_ENFORCE_EQ(
!places_.empty(),
true,
common::errors::InvalidArgument("The BKCL place should not be empty."));
order_.reserve(places_.size());
for (auto &p : places_) {
int dev_id = p.device;
order_.emplace_back(dev_id);
contexts_.emplace(dev_id, BKCLContext(dev_id));
}
PADDLE_ENFORCE_EQ(
order_.size(),
contexts_.size(),
common::errors::Unavailable("BKCL Context Map does not support "
"contain two or more same device"));

std::unique_ptr<BKCLContext_t[]> comms(new BKCLContext_t[order_.size()]);
std::unique_ptr<InitBKCLPara[]> paras(new InitBKCLPara[order_.size()]);
std::unique_ptr<pthread_t[]> pids(new pthread_t[order_.size()]);
BKCLResult_t ret;
BKCLUniqueId id;
// If num_trainers == 1 and no id was passed in, create a fresh BKCL id for
// the purely local communicators.
if (num_trainers_ == 1 && bkcl_id_ == nullptr) {
ret = bkcl_get_unique_id(&id);
PADDLE_ENFORCE_EQ(BKCL_SUCCESS,
ret,
common::errors::PreconditionNotMet(
"bkcl get unique id failed [%d]", ret));
bkcl_id_ = &id;
}
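// Caveat: when defaulted here, bkcl_id_ points at init()'s stack-local
// `id`, so it must not be dereferenced after init() returns.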
PADDLE_ENFORCE_NOT_NULL(
bkcl_id_,
common::errors::InvalidArgument("The BKCL id should not be null."));
{
int nranks = num_trainers_ * order_.size();
for (size_t i = 0; i < order_.size(); ++i) {
int rank;
if (order_.size() > 1) {
rank = trainer_id_ * order_.size() + i;
} else {
rank = trainer_id_;
}
VLOG(1) << "init bkcl rank:" << rank << ", nranks:" << nranks
<< ", xpu_id:" << order_[i];
paras[i].rank = rank;
paras[i].nranks = nranks;
paras[i].dev_id = order_[i];
paras[i].bkcl_id = bkcl_id_;
paras[i].ctx = &comms[i];
PADDLE_ENFORCE_EQ(pthread_create(&pids[i],
nullptr,
init_bkcl_context_func,
reinterpret_cast<void *>(&paras[i])),
0,
common::errors::External("pthread_create failed"));
}
for (size_t i = 0; i < order_.size(); i++) {
pthread_join(pids[i], nullptr);
}
}
int i = 0;
for (auto &dev_id : order_) {
contexts_.at(dev_id).comm_ = comms[i++];
}
return 0;
}

BKCLContextMap(const BKCLContextMap &other) = delete;
BKCLContextMap &operator=(const BKCLContextMap &other) = delete;

phi::XPUContext *DevCtx(int dev_id) const { return at(dev_id).ctx_.get(); }

phi::XPUContext *DevCtx(phi::Place p) const { return DevCtx(p.device); }

const BKCLContext &at(phi::Place p) const { return this->at(p.device); }

const BKCLContext &at(int dev_id) const { return contexts_.at(dev_id); }

void WaitAll() {
for (auto &p : contexts_) {
p.second.ctx_->Wait();
}
}
};

// Returns "BKCLID" for position 0 and "BKCLID_<pos>" for later positions,
// e.g. GetFlatBKCLVarName(2) == "BKCLID_2".
inline std::string GetFlatBKCLVarName(size_t pos) {
if (pos == 0) {
return BKCL_ID_VARNAME;
}
return string::Sprintf("%s_%d", BKCL_ID_VARNAME, static_cast<int>(pos));
}

} // namespace platform
} // namespace paddle

#endif // PADDLE_WITH_XPU_BKCL
#endif
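
A minimal driver for this helper, as a usage sketch built only from the
interfaces above (not code from this commit):

    // Initialize BKCL communicators across two local XPU devices.
    std::vector<phi::Place> places{phi::XPUPlace(0), phi::XPUPlace(1)};
    paddle::platform::BKCLContextMap ctx_map(places);  // single trainer
    ctx_map.init();  // one pthread per device runs bkcl_init_rank
    phi::XPUContext *dev0 = ctx_map.DevCtx(0);  // per-device context
    (void)dev0;      // e.g. pass to a collective kernel
    ctx_map.WaitAll();  // block until every device stream drains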
2 changes: 2 additions & 0 deletions paddle/phi/infermeta/nullary.cc
@@ -41,6 +41,8 @@ void AssignValueInferMeta(const std::vector<int>& shape,
out->set_dtype(dtype);
}

// comm_init_all produces no outputs, so there is nothing to infer here.
void CommInitAllInferMeta(const std::vector<int>& devices, int ring_id) {}

void CreateArrayInferMeta(DataType dtype, MetaTensor* out) {
out->set_dtype(dtype);
}
2 changes: 2 additions & 0 deletions paddle/phi/infermeta/nullary.h
@@ -42,6 +42,8 @@ void AssignValueInferMeta(const std::vector<int>& shape,
DataType dtype,
MetaTensor* out);

void CommInitAllInferMeta(const std::vector<int>& devices, int ring_id);

void CreateVecShapeInferMeta(const std::vector<int64_t>& shape,
DataType dtype,
MetaTensor* out);
42 changes: 42 additions & 0 deletions paddle/phi/kernels/gpu/comm_init_all_kernel.cu
@@ -0,0 +1,42 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <string>
#include "glog/logging.h"
#include "paddle/phi/core/kernel_registry.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/phi/core/platform/collective_helper.h"
#endif

namespace phi {

template <typename T, typename Context>
void CommInitAllKernel(const Context& dev_ctx,
const std::vector<int>& devices_input,
int ring_id) {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
std::vector<int> devices = devices_input;
if (devices.empty()) {
  // No explicit device list: default to all user-selected GPUs.
  devices = phi::backends::gpu::GetSelectedDevices();
}

paddle::platform::NCCLCommContext::Instance().CreateAllNCCLComms(devices,
ring_id);
#endif
}

} // namespace phi

PD_REGISTER_KERNEL(
comm_init_all, GPU, ALL_LAYOUT, phi::CommInitAllKernel, float) {}
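
After comm_init_all runs, collective kernels look the communicator up from the
singleton. A hedged sketch, assuming NCCLCommContext exposes a
Get(ring_id, device_id) accessor symmetric to CreateAllNCCLComms:

    // Fetch the communicator created above for ring 0 on GPU 0
    // (NCCL/RCCL builds only).
    auto *comm = paddle::platform::NCCLCommContext::Instance().Get(0, 0);
    ncclComm_t raw = comm->comm();  // raw handle for nccl* collective calls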
76 changes: 76 additions & 0 deletions paddle/phi/kernels/xpu/comm_init_all_kernel.cc
@@ -0,0 +1,76 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <string>
#include "glog/logging.h"
#include "paddle/phi/core/kernel_registry.h"
#if defined(PADDLE_WITH_XPU_BKCL)
#include "paddle/phi/core/platform/collective_helper.h"
#include "paddle/phi/core/platform/device/xpu/bkcl_helper.h"
#endif

namespace phi {

template <typename T, typename Context>
void CommInitAllKernel(const Context& dev_ctx,
const std::vector<int>& devices_input,
int ring_id) {
#if defined(PADDLE_WITH_XPU_BKCL)
std::vector<int> devices = devices_input;

if (devices.empty()) {
  // No explicit device list: default to every visible XPU device.
  int count = phi::backends::xpu::GetXPUDeviceCount();
  for (int i = 0; i < count; ++i) {
    devices.push_back(i);
  }
}

// With a single device there is nothing to communicate, so skip BKCL setup.
if (devices.size() > 1) {
std::vector<phi::Place> place_list_;
for (size_t i = 0; i < devices.size(); ++i) {
auto p = phi::XPUPlace(devices[i]);
place_list_.push_back(p);
}

// Spawn one pthread per device so every rank calls bkcl_init_rank together.
auto ptr = new paddle::platform::BKCLContextMap(place_list_);
ptr->init();

for (size_t i = 0; i < devices.size(); ++i) {
paddle::platform::BKCLCommContext::Instance().AssignBKCLComm(
ptr->contexts_.at(devices[i]).comm_,
devices.size(),
devices[i],
devices[i],
ring_id);

VLOG(0) << "bkcl communicator of rank " << devices[i] << " in ring "
<< ring_id << " has been created on device " << devices[i];

// TODO(WorgenZhang): need release comm_map_ when quit
// std::call_once(once_flag_, []() {
// std::atexit([]() {
// platform::BKCLCommContext::Instance().ReleaseBKCLComms(); });
// });
}

VLOG(0) << "done bkcl_init_rank on all devices";
}
#endif
}

} // namespace phi

PD_REGISTER_KERNEL(
comm_init_all, XPU, ALL_LAYOUT, phi::CommInitAllKernel, float) {}
10 changes: 10 additions & 0 deletions paddle/phi/ops/yaml/inconsistent/static_ops.yaml
@@ -192,6 +192,16 @@
data_type : dtype
inplace: (input -> output)

- op : comm_init_all
args : (int[] devices={}, int ring_id=0)
output :
infer_meta :
func : CommInitAllInferMeta
param : [devices, ring_id]
kernel :
func : comm_init_all
data_type : DataType::FLOAT32
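  # comm_init_all has no tensor inputs or outputs, so the kernel dtype cannot
  # be inferred and is pinned to float32 above.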

- op : dequantize_linear
args : (Tensor x, Tensor scale, Tensor zero_point, Tensor in_accum, Tensor in_state, int quant_axis = 0, int bit_length = 8, int qmin = -128, int qmax = 127, int round_type = 0, bool is_test = true, bool only_observer = false)
output : Tensor(y), Tensor(out_state), Tensor(out_accum), Tensor(out_scale)