[Hackathon 5th No.3] Add masked_fill API RFC for Paddle #616
Conversation
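Hello. To reduce the difficulty of hardware adaptation, operator composition, and compiler adaptation, we do not recommend adding a new OP unless composition is so hard that a low-level implementation is required, or a very large performance gap is expected. Judging from this RFC's description, the API can also be composed from OPs (or APIs) such as where. Please add the following information:
- Given Paddle's current state, whether the existing where and full can meet the composition requirements (covering broadcast, dtype, backward support, whether in-place is possible, and so on).
- The details of the PyTorch API, mainly whether it supports broadcast, the dtype of each Tensor, and the type of value.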
paddle.masked_fill(input, mask, value, inplace=False)
paddle.masked_fill_(input, mask, value, inplace=False)
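Whether the operation is in-place should be determined solely by whether the API name ends with an underscore; there is no need for an extra inplace parameter.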
Tensor.masked_fill(input, mask, value)
Tensor.masked_fill_(input, mask, value)
In this case, input is the Tensor the method is called on, so it should not appear among the input parameters.
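The two naming points raised so far (a trailing underscore marks the in-place variant, and the method form takes no input argument because the tensor is the receiver) can be illustrated with a small framework-free sketch. `ToyTensor` is a hypothetical stand-in backed by a flat Python list; it is not Paddle's Tensor class and only mirrors the proposed calling convention.

```python
class ToyTensor:
    """Hypothetical stand-in for a tensor: a flat list of numbers."""

    def __init__(self, data):
        self.data = list(data)

    def masked_fill(self, mask, value):
        # Out-of-place: build and return a new object, leave self untouched.
        return ToyTensor(value if m else v for v, m in zip(self.data, mask))

    def masked_fill_(self, mask, value):
        # In-place (trailing underscore): overwrite self.data and return self.
        for i, m in enumerate(mask):
            if m:
                self.data[i] = value
        return self


x = ToyTensor([1.0, 1.0, 1.0])
y = x.masked_fill([True, False, True], 2.0)    # x is unchanged here
z = x.masked_fill_([True, False, False], 5.0)  # x is modified; z is x
```

Neither method needs an inplace flag or an input parameter: the method name selects the behavior and `self` plays the role of the input.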
```
Fills elements of self tensor with value where mask is True. The shape of mask must be broadcastable with the shape of the underlying tensor.
```
Please also add information here about this PyTorch API: the types its parameters accept, the data types, the shape requirements, and so on.
Added.
- `input (Tensor)`: The input tensor to be filled.
- `mask (Tensor, bool)`: A boolean mask tensor specifying the positions to fill; same shape as input.
- `value (Tensor, bool, int, float)`: The fill value; accepted types are bool, int, float, and 0-D Tensor.
Ideally, value should also support complex types.
Added.
Would an implementation composed like this meet the requirement?
def masked_fill(x, mask, value):
    y = paddle.full(x.shape, value, x.dtype)
    return paddle.where(mask, y, x)
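To make the composition concrete without depending on Paddle, here is a minimal sketch on flat Python lists. The helpers `full` and `where` below are simplified, hypothetical analogues of `paddle.full` and `paddle.where` (no dtype handling, no broadcasting), written only to show the data flow of the proposal.

```python
def full(n, value):
    # Analogue of paddle.full for a flat list: n copies of value.
    return [value] * n


def where(cond, a, b):
    # Analogue of paddle.where: take a[i] where cond[i] is true, else b[i].
    return [ai if c else bi for c, ai, bi in zip(cond, a, b)]


def masked_fill(x, mask, value):
    # Step 1: materialize the fill value at the shape of x.
    y = full(len(x), value)
    # Step 2: select the fill value where the mask is true, x elsewhere.
    return where(mask, y, x)


out = masked_fill([1, 2, 3, 4], [True, False, False, True], 0)
```

The intermediate `y` is the cost of this approach: a full-size temporary is allocated just to be selected from.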
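- `input (Tensor)`: The input tensor to be filled.
- `mask (Tensor, bool)`: A boolean mask tensor specifying the positions to fill; same shape as input.
- `value (Tensor, bool, int, float, complex)`: The fill value; accepted types are bool, int, float, and 0-D Tensor.
- `inplace (bool, optional)`: Whether to perform the operation in place. If True, the input tensor is modified directly; otherwise a new tensor is returned. Defaults to False.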
The inplace parameter here still needs to be removed. Also, following other APIs, please add the name parameter.
done
```python
y = paddle.full(x.shape, value, x.dtype)
out = paddle.where(mask, y, x)
```
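At a glance, the computation logic of this approach is correct, but the points raised earlier need to be confirmed: the dtype support, broadcast behavior, backward support, and so on of the two APIs it depends on. Please investigate these first and add them to the current-status section.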
Everything except backward support has been added. How should I check backward support? Could you help clarify? @zoooo0820
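Using full_like directly creates a new Tensor, and if value itself is a scalar it needs no backward pass, so that is fine. If value is a Tensor, consider skipping this API; where should itself support broadcast, which you can verify.
In addition, where itself already performs several full_like operations on its values; please check whether those can be reused to avoid extra operations.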
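The backward question can be reasoned about from the definition alone. For out[i] = value if mask[i] else x[i], the gradient with respect to x passes through only where the mask is false, and the gradient with respect to a single (scalar or 0-D) value is the sum of the upstream gradient over the masked positions. The sketch below works on flat lists; `masked_fill_backward` is a hypothetical name and the code is derived from the formula, not taken from Paddle's kernels.

```python
def masked_fill_backward(grad_out, mask):
    # d(out)/d(x): identity where mask is false, zero where the value was written.
    grad_x = [0.0 if m else g for g, m in zip(grad_out, mask)]
    # d(out)/d(value): every masked position received the same value,
    # so its gradient accumulates the upstream gradient at those positions.
    grad_value = sum(g for g, m in zip(grad_out, mask) if m)
    return grad_x, grad_value


grad_x, grad_value = masked_fill_backward([0.5, 1.0, 2.0], [True, False, True])
```

When value is a Python scalar, `grad_value` is simply never needed, which matches the remark that the scalar case requires no backward pass for the fill value.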
Regarding "if value is a Tensor, consider skipping this API": I am not quite clear on how exactly to skip it. For broadcast, I have already written an example in the document; please take a look.
x = paddle.ones([3, 3], dtype='float32')
mask = paddle.randint(0, 2, [1, 3]).astype('bool')
out = masked_fill(x, mask, 2)
print(out)
# Tensor(shape=[3, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [[2., 1., 2.],
# [2., 1., 2.],
# [2., 1., 2.]])
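The output quoted above depends on the randomly generated mask; the values shown correspond to a mask of [[True, False, True]]. A framework-free sketch with that mask fixed reproduces the same pattern on nested lists. `masked_fill_2d` is a hypothetical helper that only handles a mask with either one row or as many rows as x.

```python
def masked_fill_2d(x, mask, value):
    # A single-row mask (shape [1, cols]) is reused for every row of x,
    # which is what broadcasting along the first dimension means here.
    out = []
    for r in range(len(x)):
        mask_row = mask[0] if len(mask) == 1 else mask[r]
        out.append([value if m else v for v, m in zip(x[r], mask_row)])
    return out


x = [[1.0] * 3 for _ in range(3)]
mask = [[True, False, True]]
out = masked_fill_2d(x, mask, 2.0)
# out == [[2.0, 1.0, 2.0], [2.0, 1.0, 2.0], [2.0, 1.0, 2.0]]
```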
where supports a non-Tensor input for x or y, so could you check whether full_like can be dropped?
masked_fill_ modifies the input tensor in place.
- `input (Tensor)`: The input tensor to be filled.
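Based on the dtype support of the dependent APIs summarized above, please state explicitly which dtypes this API supports. Also check whether the Tensor supports complex numbers; otherwise, if value is to support complex, the behavior needs to be defined.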
Added. The where operator does not support complex numbers, so this API should not be able to support them either.
def masked_fill(x, mask, value):
    y = paddle.full_like(x, value, x.dtype)
    return paddle.where(mask, y, x)
If it is confirmed that full_like is not needed, please remove it here.
done
When value is a scalar, it needs to be converted to a Paddle tensor; paddle.where contains similar logic:
if np.isscalar(value):
    value = paddle.full([1], value, x.dtype)
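The scalar check can be sketched without NumPy or Paddle. In the sketch below, `normalize_value` is a hypothetical helper: it wraps a Python scalar into a one-element list, standing in for the `paddle.full([1], value, x.dtype)` call, and passes anything else through on the assumption that it is already tensor-like.

```python
import numbers


def normalize_value(value):
    # bool, int, float and complex all count as numbers.Number,
    # so each is wrapped into a one-element container.
    if isinstance(value, numbers.Number):
        return [value]
    # Anything else is assumed to be tensor-like already and is returned as-is.
    return value


wrapped = normalize_value(2)
passed_through = normalize_value([3.0])
```

A one-element container is enough because the later selection step is expected to broadcast it against x.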
float,double,int8_t,uint8_t,int16_t,int,int64_t,bool,float16,bfloat16,complex32,complex64
GPU Kernel
float,double,int8_t,uint8_t,int16_t,int,int64_t,bool,float16,bfloat16,complex32,complex64
Same as above.
done
## Underlying OP design
Implemented in Python; no underlying op support is needed.
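The description here is not quite accurate; please revise it. In essence, the implementation still uses existing OPs (where, full, etc.); it simply does not require developing a new OP.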
done
- `input (Tensor, float, double, int, int64_t, float16, bfloat16)`: The input tensor to be filled.
- `mask (Tensor, bool)`: A boolean mask tensor specifying the positions to fill; same shape as input.
- `value (Tensor, float, double, int, int64_t, float16, bfloat16)`: The fill value; accepted types are bool, int, float, and 0-D Tensor.
- `name (str, optional)`: For usage, see [Name](https://www.paddlepaddle.org.cn/documentation/docs/zh/api_guides/low_level/program.html#api-guide-name). Usually does not need to be set; defaults to None.
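I suggest describing these parameters in the following format:
input (Tensor): parameter description; the supported data types are xxx
The current wording is ambiguous: it can be read as accepting both a Tensor and other Python types.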
done
masked_fill_ modifies the input tensor in place.
- `input (Tensor, float, double, int, int64_t, float16, bfloat16)`: The input tensor to be filled.
- `mask (Tensor, bool)`: A boolean mask tensor specifying the positions to fill; same shape as input.
The mask parameter needs to be made explicit here: must its shape be identical to that of input, or is it enough for it to be broadcastable?
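The difference between "same shape" and "broadcastable" can be stated precisely with the usual trailing-dimension broadcasting rule. The sketch below is framework-free; `broadcast_shape` and `mask_fits` are hypothetical helper names. `mask_fits` encodes a stricter, assumed requirement that the mask broadcasts to exactly the shape of x, which is what an in-place variant would need since it cannot change the shape of x.

```python
from itertools import zip_longest


def broadcast_shape(a, b):
    # Align shapes from the trailing dimension; missing dimensions count as 1.
    out = []
    for da, db in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if da != db and da != 1 and db != 1:
            raise ValueError(f"shapes {a} and {b} are not broadcastable")
        # A dimension of size 1 stretches to match the other side.
        out.append(db if da == 1 else da)
    return tuple(reversed(out))


def mask_fits(x_shape, mask_shape):
    # True only if the mask can be stretched to x's shape without changing it.
    try:
        return broadcast_shape(x_shape, mask_shape) == tuple(x_shape)
    except ValueError:
        return False


result_shape = broadcast_shape((3, 3), (1, 3))
```

With these helpers, a [1, 3] mask fits a [3, 3] input, while a [3, 3] mask does not fit a [1, 3] input even though the two shapes are mutually broadcastable.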
done
LGTM
paddle.masked_fill(input, mask, value)
paddle.masked_fill_(input, mask, value)
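Sorry, I missed this in my earlier comments: the parameter name input here differs from the naming convention of Paddle's other APIs. Please refer to https://www.paddlepaddle.org.cn/documentation/docs/zh/dev_guides/api_contributing_guides/api_design_guidelines_standard_cn.html and open another PR to fix it.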
OK.
Fixed here; could you please review it again? #637 @zoooo0820
Add masked_fill API RFC document for Paddle