Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon No.16】add RFC for take API #186

Merged
merged 13 commits into from
Jul 27, 2022

Conversation

S-HuaBomb
Copy link
Contributor

为 paddle 新增 take API

@paddle-bot
Copy link

paddle-bot bot commented Jul 14, 2022

你的PR提交成功,感谢你对开源项目的贡献!
请检查PR提交格式和内容是否完备,具体请参考示例模版
Your PR has been submitted. Thanks for your contribution!
Please check its format and content. For this, you can refer to Template and Demo.


3. 通过 `Tensor.reshape(index.shape)` 将输出的 Tensor 形状转成 index 的形状。

# 三、业内方案调研
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

请补充下tf的调研情况

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

}
```

在 [代码位置](https://github.com/pytorch/pytorch/blob/master/torch/onnx/symbolic_opset9.py#L4932) 中也定义了 `take` 方法:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

和49行的定义方法有什么不同呢,前面是C++实现,后面是python实现?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是的,因为 torch/onnx/symbolic_opset*.py 中的方法是 Aten 运算符中已经存在的方法。

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

可以将两者的区别补充到RFC中

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done


## API 实现方案

该 API 需要添加在 Paddle repo 的 `python/paddle/tensor/math.py` 文件中;并在 `python/paddle/tensor/init.py` 中添加 `take` API,以支持 Tensor.take 的调用方式。
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

python/paddle/tensor/__init__.py,下同

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done


- `index` 索引越界时直接报错。

- 在动态图、静态图下的都能得到正确的结果。
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

覆盖CPU、GPU两种测试场景

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done


# 八、影响面

增加了一个 `paddle.take` API。
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

为独立新增API,对其他模块没有影响

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done


- 在维度支持上,`Numpy.take` 支持指定轴,`torch.take` 不支持。

- `Numpy.take` 支持通过 `mode` 参数指定索引越界的 3 种处理方式,默认直接报错;`torch.take` 在索引越界时直接报错。
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jeff41404 讨论:numpy的设计更好且能覆盖pytorch的功能。可以增加axismode两个参数,来支持指定轴和索引越界的3种方式么?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

由于 numpy 在指定轴索引后得到的结果不能保证与 index 的 shape 一致,会破坏 take 方法的输出结果形状与 index 一致的特性。因此我们决定新增的 paddle.take 的功能与 torch.takenumpy.take 的默认形式保持一致,即,不增加 axis 参数指定索引轴;在 torch.take 的基础上增加 mode 参数提供三种 index 索引越界的处理方式。尽可能保持 take 索引方法简洁、易理解的特性。

见修改后的 RFC:#217

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

破坏take 方法的输出结果形状与 index 一致的特性

@jeff41404 讨论,只有在设置了axis情况下才有影响,默认不设置不会影响。但axis是兼容升级,可以后续有需求再增加。本次RFC和对应的PR可以只增加mode参数。


```python
paddle.take(
input: Tensor,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

根据Paddle API设计规范,函数操作只有一个待操作的张量参数时,用 x 命名。(输入input要改成x)https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/dev_guides/api_contributing_guides/api_design_guidelines_standard_cn.html#leimingyufangfamingdeguifan
image

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thx,Done

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants