From 9f8e5a74a7b784ca43eb62a67335fd856e68aa35 Mon Sep 17 00:00:00 2001 From: zmxdream Date: Sat, 9 Jul 2022 21:38:02 +0800 Subject: [PATCH] [GPUPS]fix feasign = 0 in dy_mf_fill_dvals (#40) * fix feasign = 0 in fill_dvals * Update hashtable_inl.h * Update heter_comm_inl.h --- .../framework/fleet/heter_ps/hashtable_inl.h | 11 +++++++--- .../framework/fleet/heter_ps/heter_comm_inl.h | 21 ++++++++++--------- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable_inl.h b/paddle/fluid/framework/fleet/heter_ps/hashtable_inl.h index a196cdd12375b..48a744a36c0b3 100644 --- a/paddle/fluid/framework/fleet/heter_ps/hashtable_inl.h +++ b/paddle/fluid/framework/fleet/heter_ps/hashtable_inl.h @@ -101,8 +101,8 @@ __global__ void dy_mf_search_kernel(Table* table, *(uint64_t*)(cur_p + k * 4) = *(uint64_t*)(input_p + k * 4); } else { - int len_per_thread = (len - 9) / (blockDim.y - 9); - int remain = (len - 9) % (blockDim.y - 9); + int len_per_thread = (len - 9) / (blockDim.x - 9); + int remain = (len - 9) % (blockDim.x - 9); int real_len = len_per_thread; if ((k - 9) < remain) real_len++; int left = -1, right = -1; @@ -116,7 +116,12 @@ __global__ void dy_mf_search_kernel(Table* table, for(int j = left; j < right; j++) *(float*)(cur_p + (j + 1) * 4) = *(float*)(input_p + (j + 1) * 4); } } else { - if (keys[i] != 0) printf("pull miss key: %llu",keys[i]); + if (keys[i] != 0 && k == 0) printf("pull miss key: %llu",keys[i]); + if (keys[i] == 0 && k == 0) { + uint64_t offset = i * pull_feature_value_size; + FeatureValue* cur = (FeatureValue*)(vals + offset); + cur->mf_dim = 0; + } } } } diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h index aa2b5adca2aca..212970fb08cf6 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h @@ -119,8 +119,8 @@ __global__ void dy_mf_fill_shard_grads(Featur if (k == 2 || k == 4) *(int*)(cur_p + k * 4) = *(int*)(input_p + k * 4); else if (k < 5) *(float*)(cur_p + k * 4) = *(float*)(input_p + k * 4); else { - int len_per_thread = (len - 5) / (blockDim.y - 5); - int remain = (len - 5) % (blockDim.y - 5); + int len_per_thread = (len - 5) / (blockDim.x - 5); + int remain = (len - 5) % (blockDim.x - 5); int real_len = len_per_thread; if ((k - 5) < remain) real_len++; int left = -1, right = -1; @@ -224,15 +224,15 @@ __global__ void dy_mf_fill_dvals(FeatureValue* d_shard_vals, FeatureValue& input = *(FeatureValue*)((char*)d_shard_vals + i * val_size); char* cur_p = (char*)cur; char* input_p = (char*)(&input); - int len = 9 + input.mf_dim + 1; - - if (k == 3 || k == 6 || k == 7) *(int*)(cur_p + k * 4) = *(int*)(input_p + k * 4); - else if (k < 8) *(float*)(cur_p + k * 4) = *(float*)(input_p + k * 4); - else if (k == 8) { - *(uint64_t*)(cur_p + k * 4) = *(uint64_t*)(input_p + k * 4); - } else { + if (input.mf_dim != 0) { // for feasign 0, mf_dim = 0 + int len = 9 + input.mf_dim + 1; + if (k == 3 || k == 6 || k == 7) *(int*)(cur_p + k * 4) = *(int*)(input_p + k * 4); + else if (k < 8) *(float*)(cur_p + k * 4) = *(float*)(input_p + k * 4); + else if (k == 8) { + *(uint64_t*)(cur_p + k * 4) = *(uint64_t*)(input_p + k * 4); + } else { int len_per_thread = (len - 9) / (blockDim.x - 9); - int remain = (len - 9) % (blockDim.y - 9); + int remain = (len - 9) % (blockDim.x - 9); int real_len = len_per_thread; if ((k - 9) < remain) real_len++; int left = -1, right = -1; @@ -244,6 +244,7 @@ __global__ void dy_mf_fill_dvals(FeatureValue* d_shard_vals, right = left + real_len; } for(int j = left; j < right; j++) *(float*)(cur_p + (j + 1) * 4) = *(float*)(input_p + (j + 1) * 4); + } } } }