-
Notifications
You must be signed in to change notification settings - Fork 272
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【Hackathon No.16】add RFC for take API #186
Conversation
|
||
3. 通过 `Tensor.reshape(index.shape)` 将输出的 Tensor 形状转成 index 的形状。 | ||
|
||
# 三、业内方案调研 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
请补充下tf的调研情况
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
} | ||
``` | ||
|
||
在 [代码位置](https://github.com/pytorch/pytorch/blob/master/torch/onnx/symbolic_opset9.py#L4932) 中也定义了 `take` 方法: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
和49行的定义方法有什么不同呢,前面是C++实现,后面是python实现?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
是的,因为 torch/onnx/symbolic_opset*.py 中的方法是 Aten
运算符中已经存在的方法。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
可以将两者的区别补充到RFC中
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
|
||
## API 实现方案 | ||
|
||
该 API 需要添加在 Paddle repo 的 `python/paddle/tensor/math.py` 文件中;并在 `python/paddle/tensor/init.py` 中添加 `take` API,以支持 Tensor.take 的调用方式。 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
python/paddle/tensor/__init__.py
,下同
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
|
||
- `index` 索引越界时直接报错。 | ||
|
||
- 在动态图、静态图下的都能得到正确的结果。 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
覆盖CPU、GPU两种测试场景
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
|
||
# 八、影响面 | ||
|
||
增加了一个 `paddle.take` API。 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
为独立新增API,对其他模块没有影响
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
|
||
- 在维度支持上,`Numpy.take` 支持指定轴,`torch.take` 不支持。 | ||
|
||
- `Numpy.take` 支持通过 `mode` 参数指定索引越界的 3 种处理方式,默认直接报错;`torch.take` 在索引越界时直接报错。 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
和 @jeff41404 讨论:numpy的设计更好且能覆盖pytorch的功能。可以增加axis
和mode
两个参数,来支持指定轴和索引越界的3种方式么?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
由于 numpy 在指定轴索引后得到的结果不能保证与 index 的 shape 一致,会破坏 take 方法的输出结果形状与 index 一致的特性。因此我们决定新增的
paddle.take
的功能与torch.take
和numpy.take
的默认形式保持一致,即,不增加 axis 参数指定索引轴;在torch.take
的基础上增加 mode 参数提供三种 index 索引越界的处理方式。尽可能保持 take 索引方法简洁、易理解的特性。
见修改后的 RFC:#217
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
破坏take 方法的输出结果形状与 index 一致的特性
和 @jeff41404 讨论,只有在设置了axis情况下才有影响,默认不设置不会影响。但axis是兼容升级,可以后续有需求再增加。本次RFC和对应的PR可以只增加mode参数。
|
||
```python | ||
paddle.take( | ||
input: Tensor, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
根据Paddle API设计规范,函数操作只有一个待操作的张量参数时,用 x 命名。(输入input要改成x)https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/dev_guides/api_contributing_guides/api_design_guidelines_standard_cn.html#leimingyufangfamingdeguifan
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thx,Done
为 paddle 新增 take API