From feca0b1664b4bfaf238ac8c525aa3a3ab80c5ea2 Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Mon, 23 Sep 2019 18:20:59 +0800 Subject: [PATCH 1/3] common: improve server load balancing --- byteps/common/global.cc | 11 +++++++++-- byteps/common/global.h | 2 ++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/byteps/common/global.cc b/byteps/common/global.cc index 2fba86fad..fd45e7bc1 100644 --- a/byteps/common/global.cc +++ b/byteps/common/global.cc @@ -41,6 +41,8 @@ uint32_t BytePSGlobal::_partition_bytes = 4096000; std::shared_ptr BytePSGlobal::_basic_comm; std::shared_ptr BytePSGlobal::_shm_obj; std::unordered_map BytePSGlobal::ps_kv_; +std::hash BytePSGlobal::key_hash_; +std::vector BytePSGlobal::server_accumulated_len_; volatile BytePSScheduledQueue* BytePSGlobal::_queues[QueueNum] = {NULL}; std::mutex BytePSGlobal::_queues_mutex[QueueNum]; @@ -111,6 +113,8 @@ void BytePSGlobal::Init() { if (_is_distributed_job) { BPS_CHECK(getenv("DMLC_NUM_SERVER")) << "error: launch distributed job, but env DMLC_NUM_SERVER not set"; + int num_server = atoi(getenv("DMLC_NUM_SERVER")); + for (int i = 0; i < num_server; ++i) server_accumulated_len_.push_back(0); } BPS_LOG(DEBUG) << "Number of worker=" << _num_worker << ", launching " @@ -307,13 +311,16 @@ PSKV& BytePSGlobal::EncodeDefaultKey(uint64_t key, size_t len) { const int num_servers = krs.size(); BPS_CHECK_GT(num_servers, 0); // send it to a single random picked server - int server = (((key >> 16) + key) * 9973) % num_servers; - BPS_LOG(DEBUG) << "key " << key << " assigned to server " << server; + int server = key_hash_(key) % num_servers; + server_accumulated_len_[server] += len; + BPS_LOG(DEBUG) << "key " << key << " assigned to server " << server + << ", current workload for this server is " << server_accumulated_len_[server]; ps::Key ps_key = krs[server].begin() + key; BPS_CHECK_LT(ps_key, krs[server].end()); pskv.keys.push_back(ps_key); pskv.lens.push_back(len); pskv.size = len; + } BPS_LOG(TRACE) << 
"key " << key << " is encoded to " << pskv.keys[0]; return pskv; diff --git a/byteps/common/global.h b/byteps/common/global.h index 2f55b325e..263b3f39a 100644 --- a/byteps/common/global.h +++ b/byteps/common/global.h @@ -84,6 +84,8 @@ class BytePSGlobal { static BPSContext& GetContextFromName(const std::string& name); static uint32_t GetTensorCount(); + static std::hash key_hash_; + static std::vector server_accumulated_len_; static std::unordered_map ps_kv_; static PSKV& EncodeDefaultKey(uint64_t key, size_t len); From c0eb146f9b1259449e037abf9c48c4ac6bfec4c1 Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Mon, 23 Sep 2019 21:55:21 +0800 Subject: [PATCH 2/3] add hash as an option --- byteps/common/global.cc | 21 ++++++++++++++------- byteps/common/global.h | 5 +++-- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/byteps/common/global.cc b/byteps/common/global.cc index fd45e7bc1..3fa1267ae 100644 --- a/byteps/common/global.cc +++ b/byteps/common/global.cc @@ -41,8 +41,9 @@ uint32_t BytePSGlobal::_partition_bytes = 4096000; std::shared_ptr BytePSGlobal::_basic_comm; std::shared_ptr BytePSGlobal::_shm_obj; std::unordered_map BytePSGlobal::ps_kv_; -std::hash BytePSGlobal::key_hash_; -std::vector BytePSGlobal::server_accumulated_len_; +std::hash BytePSGlobal::_hash_fn; +std::vector BytePSGlobal::_server_accumulated_len; +bool BytePSGlobal::_use_hash; volatile BytePSScheduledQueue* BytePSGlobal::_queues[QueueNum] = {NULL}; std::mutex BytePSGlobal::_queues_mutex[QueueNum]; @@ -113,8 +114,9 @@ void BytePSGlobal::Init() { if (_is_distributed_job) { BPS_CHECK(getenv("DMLC_NUM_SERVER")) << "error: launch distributed job, but env DMLC_NUM_SERVER not set"; + _use_hash = getenv("BYTEPS_USE_HASH_KEY") ? 
atoi(getenv("BYTEPS_USE_HASH_KEY")) : false; int num_server = atoi(getenv("DMLC_NUM_SERVER")); - for (int i = 0; i < num_server; ++i) server_accumulated_len_.push_back(0); + for (int i = 0; i < num_server; ++i) _server_accumulated_len.push_back(0); } BPS_LOG(DEBUG) << "Number of worker=" << _num_worker << ", launching " @@ -311,16 +313,21 @@ PSKV& BytePSGlobal::EncodeDefaultKey(uint64_t key, size_t len) { const int num_servers = krs.size(); BPS_CHECK_GT(num_servers, 0); // send it to a single random picked server - int server = key_hash_(key) % num_servers; - server_accumulated_len_[server] += len; + int server; + if (_use_hash) { + server = _hash_fn(key) % num_servers; + } else { + server = ((key >> 16) + (key % 65536)) % num_servers; + } + _server_accumulated_len[server] += len; BPS_LOG(DEBUG) << "key " << key << " assigned to server " << server - << ", current workload for this server is " << server_accumulated_len_[server]; + << ", accumulated workload for this server is " + << _server_accumulated_len[server]; ps::Key ps_key = krs[server].begin() + key; BPS_CHECK_LT(ps_key, krs[server].end()); pskv.keys.push_back(ps_key); pskv.lens.push_back(len); pskv.size = len; - } BPS_LOG(TRACE) << "key " << key << " is encoded to " << pskv.keys[0]; return pskv; diff --git a/byteps/common/global.h b/byteps/common/global.h index 263b3f39a..4d9a87b2a 100644 --- a/byteps/common/global.h +++ b/byteps/common/global.h @@ -84,8 +84,9 @@ class BytePSGlobal { static BPSContext& GetContextFromName(const std::string& name); static uint32_t GetTensorCount(); - static std::hash key_hash_; - static std::vector server_accumulated_len_; + static bool _use_hash; + static std::hash _hash_fn; + static std::vector _server_accumulated_len; static std::unordered_map ps_kv_; static PSKV& EncodeDefaultKey(uint64_t key, size_t len); From 9ef0313d8f8d12659bd2b14dfdad760dcb3307b1 Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Mon, 23 Sep 2019 21:58:02 +0800 Subject: [PATCH 3/3] a small improvement --- 
byteps/common/global.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/byteps/common/global.cc b/byteps/common/global.cc index 3fa1267ae..3e5fe15ff 100644 --- a/byteps/common/global.cc +++ b/byteps/common/global.cc @@ -317,7 +317,7 @@ PSKV& BytePSGlobal::EncodeDefaultKey(uint64_t key, size_t len) { if (_use_hash) { server = _hash_fn(key) % num_servers; } else { - server = ((key >> 16) + (key % 65536)) % num_servers; + server = (((key >> 16) + (key % 65536)) * 9973) % num_servers; } _server_accumulated_len[server] += len; BPS_LOG(DEBUG) << "key " << key << " assigned to server " << server