Skip to content

Commit

Permalink
[GPUPS]fix feasign = 0 in dy_mf_fill_dvals (PaddlePaddle#40)
Browse files Browse the repository at this point in the history
* fix feasign = 0 in fill_dvals

* Update hashtable_inl.h

* Update heter_comm_inl.h
  • Loading branch information
zmxdream authored Jul 9, 2022
1 parent 7d7fb8e commit 9f8e5a7
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 13 deletions.
11 changes: 8 additions & 3 deletions paddle/fluid/framework/fleet/heter_ps/hashtable_inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ __global__ void dy_mf_search_kernel(Table* table,
*(uint64_t*)(cur_p + k * 4) = *(uint64_t*)(input_p + k * 4);
}
else {
int len_per_thread = (len - 9) / (blockDim.y - 9);
int remain = (len - 9) % (blockDim.y - 9);
int len_per_thread = (len - 9) / (blockDim.x - 9);
int remain = (len - 9) % (blockDim.x - 9);
int real_len = len_per_thread;
if ((k - 9) < remain) real_len++;
int left = -1, right = -1;
Expand All @@ -116,7 +116,12 @@ __global__ void dy_mf_search_kernel(Table* table,
for(int j = left; j < right; j++) *(float*)(cur_p + (j + 1) * 4) = *(float*)(input_p + (j + 1) * 4);
}
} else {
if (keys[i] != 0) printf("pull miss key: %llu",keys[i]);
if (keys[i] != 0 && k == 0) printf("pull miss key: %llu",keys[i]);
if (keys[i] == 0 && k == 0) {
uint64_t offset = i * pull_feature_value_size;
FeatureValue* cur = (FeatureValue*)(vals + offset);
cur->mf_dim = 0;
}
}
}
}
Expand Down
21 changes: 11 additions & 10 deletions paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@ __global__ void dy_mf_fill_shard_grads<FeatureKey, FeaturePushValue, int>(Featur
if (k == 2 || k == 4) *(int*)(cur_p + k * 4) = *(int*)(input_p + k * 4);
else if (k < 5) *(float*)(cur_p + k * 4) = *(float*)(input_p + k * 4);
else {
int len_per_thread = (len - 5) / (blockDim.y - 5);
int remain = (len - 5) % (blockDim.y - 5);
int len_per_thread = (len - 5) / (blockDim.x - 5);
int remain = (len - 5) % (blockDim.x - 5);
int real_len = len_per_thread;
if ((k - 5) < remain) real_len++;
int left = -1, right = -1;
Expand Down Expand Up @@ -224,15 +224,15 @@ __global__ void dy_mf_fill_dvals<FeatureValue, int>(FeatureValue* d_shard_vals,
FeatureValue& input = *(FeatureValue*)((char*)d_shard_vals + i * val_size);
char* cur_p = (char*)cur;
char* input_p = (char*)(&input);
int len = 9 + input.mf_dim + 1;

if (k == 3 || k == 6 || k == 7) *(int*)(cur_p + k * 4) = *(int*)(input_p + k * 4);
else if (k < 8) *(float*)(cur_p + k * 4) = *(float*)(input_p + k * 4);
else if (k == 8) {
*(uint64_t*)(cur_p + k * 4) = *(uint64_t*)(input_p + k * 4);
} else {
if (input.mf_dim != 0) { // for feasign 0, mf_dim = 0
int len = 9 + input.mf_dim + 1;
if (k == 3 || k == 6 || k == 7) *(int*)(cur_p + k * 4) = *(int*)(input_p + k * 4);
else if (k < 8) *(float*)(cur_p + k * 4) = *(float*)(input_p + k * 4);
else if (k == 8) {
*(uint64_t*)(cur_p + k * 4) = *(uint64_t*)(input_p + k * 4);
} else {
int len_per_thread = (len - 9) / (blockDim.x - 9);
int remain = (len - 9) % (blockDim.y - 9);
int remain = (len - 9) % (blockDim.x - 9);
int real_len = len_per_thread;
if ((k - 9) < remain) real_len++;
int left = -1, right = -1;
Expand All @@ -244,6 +244,7 @@ __global__ void dy_mf_fill_dvals<FeatureValue, int>(FeatureValue* d_shard_vals,
right = left + real_len;
}
for(int j = left; j < right; j++) *(float*)(cur_p + (j + 1) * 4) = *(float*)(input_p + (j + 1) * 4);
}
}
}
}
Expand Down

0 comments on commit 9f8e5a7

Please sign in to comment.