-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[NPU] Support npu op index_select #34611
Conversation
# case 1: | ||
with program_guard(Program(), Program()): | ||
x = fluid.layers.data(name='x', shape=[-1, 4], dtype='float32') | ||
index = fluid.layers.data( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
目前不推荐使用fluid下的API了,例如,fluid.layers.data替换为paddle.static.data,下面其他的接口都类似,查阅下文档
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done. 所有fluid相关API已remove并替换。
index = fluid.layers.data( | ||
name='index', shape=[3], dtype='int32', append_batch_size=False) | ||
z = paddle.index_select(x, index, axis=1) | ||
#exe = fluid.Executor(fluid.CPUPlace()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
无用代码记得删除
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
|
||
import paddle.fluid as fluid | ||
import paddle.fluid.core as core | ||
from paddle.fluid import Program, program_guard |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
使用paddle.static.program_guard
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
self.input_data() | ||
|
||
# case 1: | ||
with fluid.dygraph.guard(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
目前框架默认使用动态图执行,可以通过enable_static或者disable_static接口来控制。别用fluid.dygraph.guard()了
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
# case 1: | ||
with fluid.dygraph.guard(): | ||
x = fluid.dygraph.to_variable(self.data_x) | ||
index = fluid.dygraph.to_variable(self.data_index) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
动态图组网参考下paddle.index_select,to_variable -> to_tensor
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
#include <memory> | ||
#include <string> | ||
#include <vector> | ||
#include "paddle/fluid/framework/tensor_util.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
16-19行的4个头文件都不需要,可以删掉
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
|
||
def config(self): | ||
self.x_shape = (100, 4, 5) | ||
self.x_type = np.float64 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
确认下你的index_select算子没有注册double类型的数据,为什么这里设置的x_type是float64的情况下单测可以跑过呢?看看是不是其实fall back到CPU算子上了?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
笔误,已改正,float64->float32。PR描述中的测试结果是对应float32的测试,没有问题。
d9f9a33
to
705514d
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR types
New features
PR changes
OPs
Describe
[NPU] Support npu op index_select (forward op)
说明:
1.关于数据类型
cpu端的index_select算子支持的数据类型有float, double, int, int64;相比cpu端,npu端没有支持double,因为index_select前向是通过调用CANN GatherV2 API 实现的,该 API 不支持double类型。
2.该PR在npu端的单测覆盖了cpu端的所有测试情况。
3.本PR实现了index_select的前向计算,添加反向时,可参考代码中标注'todo'的注释。
运行结果
单测运行结果
调用 index_select npu kernel