Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

common: improve server load balancing #116

Merged
merged 3 commits into from
Sep 23, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions byteps/common/global.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ uint32_t BytePSGlobal::_partition_bytes = 4096000;
std::shared_ptr<BytePSComm> BytePSGlobal::_basic_comm;
std::shared_ptr<BytePSSharedMemory> BytePSGlobal::_shm_obj;
std::unordered_map<uint64_t, PSKV> BytePSGlobal::ps_kv_;
std::hash<ps::Key> BytePSGlobal::_hash_fn;
std::vector<unsigned long> BytePSGlobal::_server_accumulated_len;
bool BytePSGlobal::_use_hash;

volatile BytePSScheduledQueue* BytePSGlobal::_queues[QueueNum] = {NULL};
std::mutex BytePSGlobal::_queues_mutex[QueueNum];
Expand Down Expand Up @@ -111,6 +114,9 @@ void BytePSGlobal::Init() {
if (_is_distributed_job) {
BPS_CHECK(getenv("DMLC_NUM_SERVER"))
<< "error: launch distributed job, but env DMLC_NUM_SERVER not set";
_use_hash = getenv("BYTEPS_USE_HASH_KEY") ? atoi(getenv("BYTEPS_USE_HASH_KEY")) : false;
int num_server = atoi(getenv("DMLC_NUM_SERVER"));
for (int i = 0; i < num_server; ++i) _server_accumulated_len.push_back(0);
}

BPS_LOG(DEBUG) << "Number of worker=" << _num_worker << ", launching "
Expand Down Expand Up @@ -307,8 +313,16 @@ PSKV& BytePSGlobal::EncodeDefaultKey(uint64_t key, size_t len) {
const int num_servers = krs.size();
BPS_CHECK_GT(num_servers, 0);
// send it to a single random picked server
int server = (((key >> 16) + key) * 9973) % num_servers;
BPS_LOG(DEBUG) << "key " << key << " assigned to server " << server;
int server;
if (_use_hash) {
server = _hash_fn(key) % num_servers;
} else {
server = (((key >> 16) + (key % 65536)) * 9973) % num_servers;
}
_server_accumulated_len[server] += len;
BPS_LOG(DEBUG) << "key " << key << " assigned to server " << server
<< ", accumulated workload for this server is "
<< _server_accumulated_len[server];
ps::Key ps_key = krs[server].begin() + key;
BPS_CHECK_LT(ps_key, krs[server].end());
pskv.keys.push_back(ps_key);
Expand Down
3 changes: 3 additions & 0 deletions byteps/common/global.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ class BytePSGlobal {
static BPSContext& GetContextFromName(const std::string& name);
static uint32_t GetTensorCount();

static bool _use_hash;
static std::hash<ps::Key> _hash_fn;
static std::vector<unsigned long> _server_accumulated_len;
static std::unordered_map<uint64_t, PSKV> ps_kv_;
static PSKV& EncodeDefaultKey(uint64_t key, size_t len);

Expand Down