diff --git a/byteps/common/global.cc b/byteps/common/global.cc index 2fba86fad..3e5fe15ff 100644 --- a/byteps/common/global.cc +++ b/byteps/common/global.cc @@ -41,6 +41,9 @@ uint32_t BytePSGlobal::_partition_bytes = 4096000; std::shared_ptr BytePSGlobal::_basic_comm; std::shared_ptr BytePSGlobal::_shm_obj; std::unordered_map BytePSGlobal::ps_kv_; +std::hash BytePSGlobal::_hash_fn; +std::vector BytePSGlobal::_server_accumulated_len; +bool BytePSGlobal::_use_hash; volatile BytePSScheduledQueue* BytePSGlobal::_queues[QueueNum] = {NULL}; std::mutex BytePSGlobal::_queues_mutex[QueueNum]; @@ -111,6 +114,9 @@ void BytePSGlobal::Init() { if (_is_distributed_job) { BPS_CHECK(getenv("DMLC_NUM_SERVER")) << "error: launch distributed job, but env DMLC_NUM_SERVER not set"; + _use_hash = getenv("BYTEPS_USE_HASH_KEY") ? atoi(getenv("BYTEPS_USE_HASH_KEY")) : false; + int num_server = atoi(getenv("DMLC_NUM_SERVER")); + for (int i = 0; i < num_server; ++i) _server_accumulated_len.push_back(0); } BPS_LOG(DEBUG) << "Number of worker=" << _num_worker << ", launching " @@ -307,8 +313,16 @@ PSKV& BytePSGlobal::EncodeDefaultKey(uint64_t key, size_t len) { const int num_servers = krs.size(); BPS_CHECK_GT(num_servers, 0); // send it to a single random picked server - int server = (((key >> 16) + key) * 9973) % num_servers; - BPS_LOG(DEBUG) << "key " << key << " assigned to server " << server; + int server; + if (_use_hash) { + server = _hash_fn(key) % num_servers; + } else { + server = (((key >> 16) + (key % 65536)) * 9973) % num_servers; + } + _server_accumulated_len[server] += len; + BPS_LOG(DEBUG) << "key " << key << " assigned to server " << server + << ", accumulated workload for this server is " + << _server_accumulated_len[server]; ps::Key ps_key = krs[server].begin() + key; BPS_CHECK_LT(ps_key, krs[server].end()); pskv.keys.push_back(ps_key); diff --git a/byteps/common/global.h b/byteps/common/global.h index 2f55b325e..4d9a87b2a 100644 --- a/byteps/common/global.h +++ b/byteps/common/global.h @@ -84,6 +84,9 @@ class BytePSGlobal { static BPSContext& GetContextFromName(const std::string& name); static uint32_t GetTensorCount(); + static bool _use_hash; + static std::hash _hash_fn; + static std::vector _server_accumulated_len; static std::unordered_map ps_kv_; static PSKV& EncodeDefaultKey(uint64_t key, size_t len);