-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Optimize where_index_op(prefix sum) #30601
Optimize where_index_op(prefix sum) #30601
Conversation
update Paddle to newest version
Merge newest Paddle code
merge newest Paddle code
Thanks for your contribution! |
Sorry to inform you that cde9c19's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually. |
…f imperfect self-kernel, test=develop
… optimize-whereindexop-prefix
Sorry to inform you that 86f20e9's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually. |
… optimize-whereindexop-prefix
__host__ __device__ bool operator()(const T &val) { | ||
return static_cast<bool>(val); | ||
} | ||
}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
函数之间加个空行。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done!所有函数间均已加空行
for (int64_t i = 0; i < numel; i++) { | ||
if (static_cast<bool>(cond_data[i])) { | ||
h_true_index.push_back(i); | ||
struct CheckTrue { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
感觉这个功能封装成一个struct的意义不大,一般将一些操作定义成Fuctor,是希望将Functor作为函数的参数,从而支持设置多种同类型不同的操作。
另外,结构名不建议使用动宾结构来命名。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已删除
memory::Alloc(platform::CPUPlace(), (rank + 1) * sizeof(int64_t)); | ||
int64_t *ptr_stride = reinterpret_cast<int64_t *>(d_tmp_mem->ptr()); | ||
int64_t *ptr_true_num = ptr_stride + rank; | ||
int64_t *h_stride = reinterpret_cast<int64_t *>(h_tmp_mem->ptr()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
stride是什么含义?感觉变量的命名表义不够清晰。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已重命名。ptr_stride是个数组,其中每个元素存的是每个维度的步长值(也就是stride),现在已重命名为stride_array。
auto d_tmp_mem = memory::Alloc(dev_ctx, (numel + rank) * sizeof(int64_t)); | ||
auto h_tmp_mem = |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
尽量少用tmp命名,变量名应该是尽可能的望文知意。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改所有变量名,并添加若干注释
cub::DeviceScan::InclusiveSum(nullptr, cub_tmp_size, d_true_num, d_true_num, | ||
numel, dev_ctx.stream()); | ||
auto cub_tmp = memory::Alloc(dev_ctx, cub_tmp_size * sizeof(int64_t)); | ||
void *ptr_mem = cub_tmp->ptr(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ptr_mem
不能做到望文知意。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改为cub_data标识这是cub库函数用到的数据
…Paddle into optimize-whereindexop-prefix
const int64_t tid = blockIdx.x * blockDim.x + threadIdx.x; | ||
|
||
for (int64_t idx = tid; idx < numel; idx += gridDim.x * blockDim.x) { | ||
true_num_array[idx] = static_cast<bool>(cond_data[idx]) ? 1 : 0; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Improve performance: true_num_array[idx] = static_cast<bool>(cond_data[idx]) ? 1 : 0;
--> true_num_array[idx] = static_cast<int64_t>(static_cast<bool>(cond_data[idx]));
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
for (int64_t i = 0; i < numel; i++) { | ||
if (static_cast<bool>(cond_data[i])) { | ||
h_true_index.push_back(i); | ||
__global__ void KeGetTrueNum(const T *cond_data, const int64_t numel, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
KeGetTrueNum
: What is Ke
? If it means kernel
, just use GetTrueNum
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
} | ||
|
||
template <typename T> | ||
__global__ void KeSetTrueIndex(int64_t *out_ptr, const T *cond_data, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The same above.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
… optimize-whereindexop-prefix
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM for op benchmark ci
PR types
Performance optimization
PR changes
OPs
Describe
优化
优化效果:
竞品对比(测试数据大小
[100, 100, 100]
):V100-SXM2-32GB
机器循环100次tf.where
torch.nonzero
可读性性优化d7fd2da:
优化变量名并增加若干注释
优化345442eb:
优化方法:
cub::DeviceScan::InclusiveSum
方法替代自写的不完美的prefix sum kernel优化点:
alloc
操作移到kernel前KeScanPrefixSum
kernel替换为cub库带的cub::DeviceScan::InclusiveSum
,提高可靠性out->Resize
一次,用InclusiveSum
而非ExclusiveSum
的一个原因也是为了可以直接得到true_num
KeGetTrueNum
和KeSetTrueIndex
此时都做到了访存合并优化2cde9c19:
优化方法:
thrust::inclusive_scan
优化点:
thrust::inclusive_scan
的defalut stream
导致的不稳定性尽管这里起了4个kernel,但由于每个kernel都起来多个block,因此计算速度比单一kernel但仅1个block要快得多。
优化1
优化方法:
优化点:
true_index
计算过程分为三步,每步起一个kernelKeGetTrueNum
统计cond_data
是否为true,是则将true_num_array
设为1,否则设为0thrust::inclusive_scan
计算true_num_array
的前缀和(高优化点),由此得到true_index
对应的索引值KeSetTrueIndex
将对应索引值写入out_ptr
中待优化点:
thrust::inclusive_scan
在defalut stream
上的表现,自写prefix sum
kernel代替defalut stream
上的thrust::inclusive_scan
dev_ctx.Wait()
步骤先前优化版本PR30556
优化方法:
融合为一个kernel