[XPUPS]add support for kunlun2 #40985

Merged (89 commits, Apr 13, 2022)

Commits
a9b2b04
refactor heter comm kernel
zmxdream Mar 26, 2022
862b618
update. test=develop
zmxdream Mar 27, 2022
dbf991f
update calc_shard_offset. test=develop
zmxdream Mar 27, 2022
d72995e
update xpu kernel. test=develop
zmxdream Mar 27, 2022
6e56704
update args of calc_shard_offset
zmxdream Mar 27, 2022
7bd6c38
update. test=develop
zmxdream Mar 27, 2022
f9b802f
remove customGradMerger
zmxdream Mar 27, 2022
eb3d0f7
update. test=develop
zmxdream Mar 28, 2022
c5b5f2f
update. test=develop
zmxdream Mar 28, 2022
dc2929e
fix. test=develop
zmxdream Mar 28, 2022
dc4bcb8
update. test=develop
zmxdream Mar 29, 2022
1bb6263
update. test=develop
zmxdream Mar 29, 2022
286e3a5
update optimizer kernel
zmxdream Mar 29, 2022
b4310c2
update. test=develop
zmxdream Mar 30, 2022
c197be5
update. test=develop
zmxdream Mar 30, 2022
4bb5774
update. test=develop
zmxdream Mar 30, 2022
5fd9fc5
update. test=develop
zmxdream Mar 30, 2022
e75a5da
update. test=develop
zmxdream Mar 30, 2022
fd32187
update. test=develop
zmxdream Mar 30, 2022
0948109
update. test=develop
zmxdream Mar 30, 2022
8030fc7
update. test=develop
zmxdream Mar 30, 2022
1241579
fix. test=develop
zmxdream Mar 30, 2022
7f15950
fix. test=develop
WorgenZhang Mar 30, 2022
3e77a11
add optimizer kernel. test=develop
zmxdream Mar 30, 2022
9a80ea5
fix. test=develop
zmxdream Mar 30, 2022
299dd85
Merge branch 'develop' into heter_comm_dev
zmxdream Mar 31, 2022
a5ef42e
fix. test=develop
zmxdream Mar 31, 2022
fe6a197
fix. test=develop
zmxdream Mar 31, 2022
cc35ae6
fix. test=develop
zmxdream Mar 31, 2022
f09061a
fix kunlun not support size_t. test=develop
zmxdream Mar 31, 2022
a3a6003
fix. test=develop
zmxdream Mar 31, 2022
cd736c6
fix. test=develop
zmxdream Mar 31, 2022
b95f50f
fix. test=develop
zmxdream Mar 31, 2022
f3a6abb
fix. test=develop
zmxdream Mar 31, 2022
8cb2de3
fix. test=develop
zmxdream Mar 31, 2022
45eb741
fix. test=develop
zmxdream Mar 31, 2022
b454886
fix. test=develop
zmxdream Mar 31, 2022
60ea5f4
fix. test=develop
zmxdream Mar 31, 2022
49342f6
fix. test=develop
zmxdream Mar 31, 2022
bde6f24
update hashtable. test=develop
zmxdream Mar 31, 2022
46eb54d
update. test=develop
zmxdream Mar 31, 2022
2c1dab9
fix. test=develop
zmxdream Apr 1, 2022
7412b87
fix. test=develop
zmxdream Apr 1, 2022
09b6a6b
fix. test=develop
zmxdream Apr 1, 2022
9151951
fix. test=develop
zmxdream Apr 1, 2022
accf945
fix. test=develop
zmxdream Apr 1, 2022
bc9a92f
fix. test=develop
zmxdream Apr 1, 2022
794d9dd
fix. test=develop
zmxdream Apr 1, 2022
c1c302c
update. test=develop
zmxdream Apr 1, 2022
19a715a
update. test=develop
zmxdream Apr 1, 2022
b16f3ed
fix. test=develop
zmxdream Apr 1, 2022
ae83b29
fix. test=develop
zmxdream Apr 1, 2022
21b5384
fix. test=develop
zmxdream Apr 1, 2022
0eaf06c
fix. test=develop
zmxdream Apr 1, 2022
38c4d43
fix. test=develop
zmxdream Apr 1, 2022
3edf3a8
fix. test=develop
zmxdream Apr 1, 2022
4a9f99c
fix. test=develop
zmxdream Apr 1, 2022
3316c1d
fix. test=develop
zmxdream Apr 1, 2022
2edf154
fix. test=develop
zmxdream Apr 1, 2022
467e8f6
fix. test=develop
zmxdream Apr 1, 2022
082f979
fix. test=develop
zmxdream Apr 1, 2022
5bcc90f
fix. test=develop
zmxdream Apr 1, 2022
d9410ab
fix. test=develop
zmxdream Apr 2, 2022
197c7ed
fix. test=develop
zmxdream Apr 2, 2022
598ae1b
fix. test=develop
zmxdream Apr 2, 2022
dfcc42f
fix. test=develop
zmxdream Apr 3, 2022
757a080
template init. test=develop
zmxdream Apr 6, 2022
503566b
hashtable template init. test=develop
zmxdream Apr 6, 2022
de5ccea
fix. test=develop
zmxdream Apr 6, 2022
301d529
fix. test=devlop
zmxdream Apr 6, 2022
dec4b01
fix. test=develop
zmxdream Apr 6, 2022
ba2f38d
fix. test=develop
zmxdream Apr 6, 2022
03d5d71
fix. test=develop
zmxdream Apr 6, 2022
43101d0
fix. test=develop
zmxdream Apr 6, 2022
a745570
fix. test=develop
zmxdream Apr 7, 2022
79a1ea6
merge develop. test=develop
zmxdream Apr 7, 2022
becfea4
fix. test=develop
zmxdream Apr 7, 2022
039f85a
fix. test=develop
zmxdream Apr 7, 2022
fecaad9
fix. test=develop
zmxdream Apr 7, 2022
fd7c98a
fix. test=develop
zmxdream Apr 10, 2022
102cdd4
fix. test=develop
zmxdream Apr 10, 2022
e99dee1
fix. test=develop
zmxdream Apr 10, 2022
757145d
fix. test=develop
zmxdream Apr 11, 2022
fbba72d
fix. test=develop
zmxdream Apr 11, 2022
d8366cb
fix. test=develop
zmxdream Apr 12, 2022
60336b8
fix. test=develop
zmxdream Apr 12, 2022
453282a
fix. test=develop
zmxdream Apr 12, 2022
8a28d11
fix. test=develop
zmxdream Apr 12, 2022
c16b575
fix. test=develop
zmxdream Apr 13, 2022
2 changes: 1 addition & 1 deletion paddle/fluid/framework/fleet/heter_context.h
@@ -22,7 +22,7 @@ limitations under the License. */
#include <vector>

#ifdef PADDLE_WITH_PSLIB
#include "common_value.h" // NOLINT
#include "common/common_value.h" // NOLINT
#endif

#ifdef PADDLE_WITH_PSCORE
4 changes: 3 additions & 1 deletion paddle/fluid/framework/fleet/heter_ps/CMakeLists.txt
@@ -7,7 +7,9 @@ IF(WITH_GPU)
get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS)
SET(HETERPS_DEPS ${HETERPS_DEPS} ${RPC_DEPS})
endif()
nv_library(heter_comm SRCS heter_comm.h feature_value.h heter_resource.cc heter_resource.h hashtable.h mem_pool.h DEPS ${HETERPS_DEPS})
nv_library(heter_comm_kernel SRCS heter_comm_kernel.cu feature_value.h DEPS ${HETERPS_DEPS})
nv_library(hashtable_kernel SRCS hashtable_kernel.cu feature_value.h DEPS ${HETERPS_DEPS})
nv_library(heter_comm SRCS heter_comm.h feature_value.h heter_resource.cc heter_resource.h mem_pool.h DEPS ${HETERPS_DEPS} heter_comm_kernel hashtable_kernel)
nv_test(test_heter_comm SRCS feature_value.h DEPS heter_comm)
nv_library(heter_ps SRCS heter_ps.cu DEPS heter_comm)
if(WITH_PSCORE)
24 changes: 12 additions & 12 deletions paddle/fluid/framework/fleet/heter_ps/feature_value.h
@@ -52,18 +52,18 @@ struct FeaturePushValue {
float lr_g;
float mf_g[MF_DIM];

__device__ __forceinline__ FeaturePushValue
operator+(const FeaturePushValue& a) const {
FeaturePushValue out;
out.slot = a.slot;
out.show = a.show + show;
out.clk = a.clk + clk;
out.lr_g = a.lr_g + lr_g;
for (int i = 0; i < MF_DIM; ++i) {
out.mf_g[i] = a.mf_g[i] + mf_g[i];
}
return out;
}
// __device__ __forceinline__ FeaturePushValue
// operator+(const FeaturePushValue& a) const {
// FeaturePushValue out;
// out.slot = a.slot;
// out.show = a.show + show;
// out.clk = a.clk + clk;
// out.lr_g = a.lr_g + lr_g;
// for (int i = 0; i < MF_DIM; ++i) {
// out.mf_g[i] = a.mf_g[i] + mf_g[i];
// }
// return out;
// }
};

} // end namespace framework
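Note: the operator+ disabled above was the device-side reduction used by the custom gradient merger (see the "remove customGradMerger" commit in this PR). It is CUDA-only (__device__ __forceinline__), so it is commented out to keep feature_value.h usable for the kunlun2 (XPU) build, with merging presumably handled inside the refactored device-specific kernels this PR adds. For reference, a minimal host-side C++ sketch of the same reduction; the struct name is hypothetical and MF_DIM = 8 is an assumed placeholder (the real constant comes from the Paddle build):

    // Host-side sketch of the reduction the disabled __device__ operator+
    // performed. MF_DIM = 8 is an assumption, not the PR's actual value.
    constexpr int MF_DIM = 8;

    struct PushValueSketch {
      float show;
      float clk;
      int slot;
      float lr_g;
      float mf_g[MF_DIM];
    };

    // Combine two gradient records for the same feature key: show/click
    // counters and gradients are summed; the slot id is carried over.
    PushValueSketch Merge(const PushValueSketch& a, const PushValueSketch& b) {
      PushValueSketch out;
      out.slot = b.slot;
      out.show = a.show + b.show;
      out.clk = a.clk + b.clk;
      out.lr_g = a.lr_g + b.lr_g;
      for (int i = 0; i < MF_DIM; ++i) {
        out.mf_g[i] = a.mf_g[i] + b.mf_g[i];
      }
      return out;
    }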
98 changes: 82 additions & 16 deletions paddle/fluid/framework/fleet/heter_ps/hashtable.h
(file mode changed: 100755 → 100644)
@@ -13,28 +13,38 @@ See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#ifdef PADDLE_WITH_HETERPS
#include <glog/logging.h>
#include <limits>
#include <memory>
#include <vector>

#ifdef PADDLE_WITH_PSLIB
#include "common_value.h" // NOLINT
#endif
#ifdef PADDLE_WITH_PSCORE

#if defined(PADDLE_WITH_PSCORE)
#include "paddle/fluid/distributed/ps/table/depends/feature_value.h"
#endif
#include "paddle/fluid/framework/fleet/heter_ps/feature_value.h"
#include "paddle/phi/core/utils/rw_lock.h"
#include "thrust/pair.h"
// #include "cudf/concurrent_unordered_map.cuh.h"

#if defined(PADDLE_WITH_CUDA)
#include "paddle/fluid/framework/fleet/heter_ps/cudf/concurrent_unordered_map.cuh.h"
#include "paddle/fluid/framework/fleet/heter_ps/feature_value.h"
#include "paddle/fluid/framework/fleet/heter_ps/mem_pool.h"
#ifdef PADDLE_WITH_HETERPS
#include "paddle/fluid/platform/device/gpu/gpu_types.h"
#include "thrust/pair.h"
#elif defined(__xpu__)
#include <xpu/runtime.h>
#include "xpu/kernel/cluster_header.h"
#include "xpu/kernel/math.h"
#include "xpu/kernel/simd.h"
#endif

namespace paddle {
namespace framework {

#if defined(PADDLE_WITH_CUDA)
template <typename KeyType, typename ValType>
class TableContainer
: public concurrent_unordered_map<KeyType, ValType,
@@ -45,31 +55,84 @@ class TableContainer
std::numeric_limits<KeyType>::max()>(
capacity, ValType()) {}
};
#elif defined(PADDLE_WITH_XPU_KP)

template <typename KeyType, typename ValType>
class XPUCacheArray {
public:
explicit XPUCacheArray(size_t capacity) : capacity_(capacity), size_(0) {
xpu_malloc(reinterpret_cast<void**>(&keys), capacity_ * sizeof(KeyType));
xpu_malloc(reinterpret_cast<void**>(&vals), capacity_ * sizeof(ValType));
}

virtual ~XPUCacheArray() {
xpu_free(keys);
xpu_free(vals);
}

void print() {}
// ValType* find(const KeyType& key) { return NULL; }
// bool insert(const KeyType& key, const ValType& val) { return true; }

int prefetch(const int dev_id, XPUStream stream = NULL) {}
size_t size() { return size_; }

private:
long long capacity_;
long long size_;
KeyType* keys;
ValType* vals;
};
#endif

template <typename KeyType, typename ValType>
class HashTable {
public:
HashTable(size_t capacity);
explicit HashTable(size_t capacity);
virtual ~HashTable();
HashTable(const HashTable&) = delete;
HashTable& operator=(const HashTable&) = delete;

template <typename StreamType>
void insert(const KeyType* d_keys, const ValType* d_vals, size_t len,
gpuStream_t stream);
StreamType stream);

template <typename StreamType>
void insert(const KeyType* d_keys, size_t len, char* pool, size_t start_index,
gpuStream_t stream);
StreamType stream);

template <typename StreamType>
void get(const KeyType* d_keys, ValType* d_vals, size_t len,
gpuStream_t stream);
void get(const KeyType* d_keys, char* d_vals, size_t len, gpuStream_t stream);
StreamType stream);

template <typename StreamType>
void get(const KeyType* d_keys, char* d_vals, size_t len, StreamType stream);

void show();
void dump_to_cpu(int devid, cudaStream_t stream);

template <typename GradType, typename Sgd>
template <typename StreamType>
void dump_to_cpu(int devid, StreamType stream);

#if defined(PADDLE_WITH_CUDA)

template <typename GradType, typename Sgd, typename StreamType>
void update(const KeyType* d_keys, const GradType* d_grads, size_t len,
Sgd sgd, gpuStream_t stream);
Sgd sgd, StreamType stream);

template <typename Sgd>
template <typename Sgd, typename StreamType>
void update(const KeyType* d_keys, const char* d_grads, size_t len, Sgd sgd,
gpuStream_t stream);
StreamType stream);

#elif defined(PADDLE_WITH_XPU_KP)
template <typename GradType, typename StreamType>
void update(const KeyType* d_keys, const GradType* d_grads, size_t len,
StreamType stream);

template <typename StreamType>
void update(const KeyType* d_keys, const char* d_grads, size_t len,
StreamType stream);

#endif

int size() { return container_->size(); }

@@ -84,7 +147,11 @@ class HashTable {
std::unique_ptr<phi::RWLock> rwlock_{nullptr};

private:
#if defined(PADDLE_WITH_CUDA)
TableContainer<KeyType, ValType>* container_;
#elif defined(PADDLE_WITH_XPU_KP)
XPUCacheArray<KeyType, ValType>* container_;
#endif
int BLOCK_SIZE_{256};
float LOAD_FACTOR{0.75f};
size_t capacity_;
Expand All @@ -94,5 +161,4 @@ class HashTable {
};
} // end namespace framework
} // end namespace paddle
#include "hashtable_inl.h"
#endif
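Taken together, the hashtable.h changes replace the hard-coded gpuStream_t in every method signature with a StreamType template parameter and select the backing container per backend: the cudf concurrent_unordered_map under PADDLE_WITH_CUDA, and the new XPUCacheArray under PADDLE_WITH_XPU_KP. A minimal caller-side sketch for a CUDA build (the function and variable names here are illustrative, not from this PR):

    #include <cstddef>
    #include <cuda_runtime.h>
    #include "paddle/fluid/framework/fleet/heter_ps/hashtable.h"

    using paddle::framework::FeatureValue;
    using paddle::framework::HashTable;

    // StreamType is deduced from the stream argument; an XPU build would
    // pass an XPUStream here instead of a cudaStream_t.
    void BulkInsertAndFetch(const unsigned long* d_keys, FeatureValue* d_vals,
                            size_t len, cudaStream_t stream) {
      HashTable<unsigned long, FeatureValue> table(/*capacity=*/len * 2);
      table.insert(d_keys, d_vals, len, stream);  // keys/values live on device
      table.get(d_keys, d_vals, len, stream);     // read the values back
    }

Because the kernels no longer live in an inline header, this only links when matching explicit instantiations exist in the kernel translation unit, which is what the instantiation block at the end of the next file provides for unsigned long / FeatureValue / cudaStream_t.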
paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu
@@ -1,4 +1,4 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -13,10 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License. */

#ifdef PADDLE_WITH_HETERPS
#include <thread>
#include "paddle/fluid/framework/fleet/heter_ps/hashtable.h"
#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h"

namespace paddle {
namespace framework {

#if defined(PADDLE_WITH_CUDA)

template <typename value_type>
struct ReplaceOp {
__host__ __device__ value_type operator()(value_type new_value,
@@ -87,6 +92,7 @@ __global__ void dy_mf_search_kernel(Table* table,
}
}
}

template <typename Table, typename GradType, typename Sgd>
__global__ void update_kernel(Table* table,
const typename Table::key_type* const keys,
@@ -135,8 +141,9 @@ void HashTable<KeyType, ValType>::show() {
}

template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::get(const KeyType* d_keys, ValType* d_vals,
size_t len, gpuStream_t stream) {
size_t len, StreamType stream) {
if (len == 0) {
return;
}
@@ -146,8 +153,9 @@ void HashTable<KeyType, ValType>::get(const KeyType* d_keys, ValType* d_vals,
}

template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::get(const KeyType* d_keys, char* d_vals,
size_t len, gpuStream_t stream) {
size_t len, StreamType stream) {
if (len == 0) {
return;
}
@@ -157,9 +165,10 @@ void HashTable<KeyType, ValType>::get(const KeyType* d_keys, char* d_vals,
}

template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::insert(const KeyType* d_keys,
const ValType* d_vals, size_t len,
gpuStream_t stream) {
StreamType stream) {
if (len == 0) {
return;
}
@@ -169,22 +178,24 @@ void HashTable<KeyType, ValType>::insert(const KeyType* d_keys,
}

template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::insert(const KeyType* d_keys, size_t len,
char* pool, size_t start_index,
gpuStream_t stream) {
StreamType stream) {
if (len == 0) {
return;
}
const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
if (pool == NULL) {
return;
}
const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
insert_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(container_, d_keys, len,
pool, start_index);
}

template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::dump_to_cpu(int devid, cudaStream_t stream) {
template <typename StreamType>
void HashTable<KeyType, ValType>::dump_to_cpu(int devid, StreamType stream) {
container_->prefetch(cudaCpuDeviceId, stream);
std::vector<std::thread> threads;
size_t num = container_->size();
@@ -260,10 +271,10 @@ void HashTable<KeyType, ValType>::dump_to_cpu(int devid, cudaStream_t stream) {
}

template <typename KeyType, typename ValType>
template <typename GradType, typename Sgd>
template <typename GradType, typename Sgd, typename StreamType>
void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
const GradType* d_grads, size_t len,
Sgd sgd, gpuStream_t stream) {
Sgd sgd, StreamType stream) {
if (len == 0) {
return;
}
@@ -273,19 +284,66 @@ void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
}

template <typename KeyType, typename ValType>
template <typename Sgd>
template <typename Sgd, typename StreamType>
void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
const char* d_grads, size_t len,
Sgd sgd, gpuStream_t stream) {
Sgd sgd, StreamType stream) {
if (len == 0) {
return;
}
const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;

dy_mf_update_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
container_, d_keys, d_grads, len, sgd, push_grad_value_size_);
}

template class HashTable<unsigned long, paddle::framework::FeatureValue>;

template void HashTable<unsigned long, paddle::framework::FeatureValue>::get<
cudaStream_t>(const unsigned long* d_keys,
paddle::framework::FeatureValue* d_vals, size_t len,
cudaStream_t stream);

// template void
// HashTable<unsigned long, paddle::framework::FeatureValue>::get<cudaStream_t>(
// const unsigned long* d_keys, char* d_vals, size_t len, cudaStream_t
// stream);

template void HashTable<unsigned long, paddle::framework::FeatureValue>::insert<
cudaStream_t>(const unsigned long* d_keys,
const paddle::framework::FeatureValue* d_vals, size_t len,
cudaStream_t stream);

// template void HashTable<unsigned long,
// paddle::framework::FeatureValue>::insert<
// cudaStream_t>(const unsigned long* d_keys, size_t len, char* pool,
// size_t start_index, cudaStream_t stream);

template void HashTable<unsigned long, paddle::framework::FeatureValue>::
dump_to_cpu<cudaStream_t>(int devid, cudaStream_t stream);

template void HashTable<unsigned long, paddle::framework::FeatureValue>::update<
paddle::framework::FeaturePushValue,
Optimizer<paddle::framework::FeatureValue,
paddle::framework::FeaturePushValue>,
cudaStream_t>(const unsigned long* d_keys,
const paddle::framework::FeaturePushValue* d_grads,
size_t len, Optimizer<paddle::framework::FeatureValue,
paddle::framework::FeaturePushValue>
sgd,
cudaStream_t stream);

// template void HashTable<unsigned long,
// paddle::framework::FeatureValue>::update<
// Optimizer<paddle::framework::FeatureValue,
// paddle::framework::FeaturePushValue>,
// cudaStream_t>(const unsigned long* d_keys, const char* d_grads, size_t
// len,
// Optimizer<paddle::framework::FeatureValue,
// paddle::framework::FeaturePushValue>
// sgd,
// cudaStream_t stream);

#endif
} // end namespace framework
} // end namespace paddle
#endif
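Because the kernel definitions moved out of the former hashtable_inl.h into this separately compiled file, every (KeyType, ValType, StreamType) combination used by other translation units must be explicitly instantiated, as done above for unsigned long with FeatureValue and cudaStream_t. A sketch of the pattern for a hypothetical new value type (MyValue is illustrative, not part of this PR); it would sit inside the PADDLE_WITH_CUDA block below the existing instantiations:

    // Hypothetical: explicit instantiations for a new POD value type so
    // that callers in other .cc/.cu files can link against these kernels.
    struct MyValue {
      float v;
    };

    template class HashTable<unsigned long, MyValue>;

    template void HashTable<unsigned long, MyValue>::get<cudaStream_t>(
        const unsigned long* d_keys, MyValue* d_vals, size_t len,
        cudaStream_t stream);

    template void HashTable<unsigned long, MyValue>::insert<cudaStream_t>(
        const unsigned long* d_keys, const MyValue* d_vals, size_t len,
        cudaStream_t stream);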