Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon 5th No.3】为 Paddle 新增 masked_fill API RFC #616

Merged
merged 13 commits into from
Sep 18, 2023
263 changes: 263 additions & 0 deletions rfcs/APIs/20230913_api_design_for_masked_fill.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,263 @@
# paddle.masked_fill 设计文档

| API名称 | paddle.masked_fill |
| ------------ | -------------------------------------- |
| 提交作者 | AndSonder |
| 提交时间 | 2023-09-13 |
| 版本号 | V1.0 |
| 依赖飞桨版本 | develop |
| 文件名 | 20230913_api_design_for_masked_fill.md |

# 一、概述

## 1、相关背景

`masked_fill` 是一个常用的API,该 API 的作用是根据 `mask` 信息,将 `value` 中的值填充到 `Tensor` 中 `mask` 对应为 `True` 的位置。这个功能在语义分割、序列标注等任务中经常用到。因此,在Paddle中提供该API,方便用户使用。

## 2、功能目标

在 Paddle 框架中,新增 `paddle.masked_fill` 对于一个Tensor,根据mask信息,将 value 中的值填充到该Tensor中mask对应为True的位置。

## 3、意义

该API是一个常用的API,可以方便用户使用。让用户不用自己实现该功能,提高用户的使用效率。

# 二、飞桨现状

目前paddle缺少相关功能实现。只能通过 paddle 现有的 API 组合实现。

```python
# paddlepaddle >= 2.0
import paddle

paddle.seed(123)
x = paddle.ones([3, 3], dtype='float64')
# Tensor(shape=[3, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [[0.00276479, 0.45899123, 0.96637046],
# [0.66818708, 0.05855134, 0.33184195],
# [0.34202638, 0.95503175, 0.33745834]])

mask = paddle.randint(0, 2, [3, 3]).astype('bool')
# Tensor(shape=[3, 3], dtype=bool, place=CUDAPlace(0), stop_gradient=True,
# [[True , True , False],
# [True , True , True ],
# [True , True , True ]])

def masked_fill(x, mask, value):
return paddle.where(mask, value, x)

out = masked_fill(x, mask, 2.)
# Tensor(shape=[3, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [[2. , 2. , 0.96637046],
# [2. , 2. , 2. ],
# [2. , 2. , 2. ]])
```


where 支持在 CPU 和 GPU 上运行。

paddle.where 支持的 dtype:

```python
CPU Kernel
float, double, int, int64_t

GPU Kernel
float,double,int,int64_t,float16,bfloat16
```

使用 where 可以完成 masked_fill API,支持 broadcast 机制。

```python
x = paddle.ones([3, 3], dtype='float64')
mask = paddle.randint(0, 2, [1, 3]).astype('bool')

out = masked_fill(x, mask, 2.)
print(out)

# Tensor(shape=[3, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [[2., 1., 2.],
# [2., 1., 2.],
# [2., 1., 2.]])
```



# 三、业内方案调研

## Pytorch

Pytorch中 有 API `Tensor.masked_fill_(mask, value)`

在pytorch中,介绍为:

```
Fills elements of self tensor with value where mask is True. The shape of mask must be broadcastable with the shape of the underlying tensor.
```
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里辛苦再补充下pytorch这个API参数接收的类型、数据类型、shape要求等信息

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已补充~


其中输入参数的描述如下:

- mask (BoolTensor) – the boolean mask
- value (float) – the value to fill in with

### 实现方法


在实现方法上, Pytorch 设计了两种实现方式,一种是CPU实现,一种是GPU实现。



核心代码如下:

```cpp
// GPU 实现
void masked_fill_kernel(TensorIterator& iter, const Scalar& value) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
kBool, kHalf, kBFloat16, kComplexHalf, iter.common_dtype(), "masked_fill_", [&]() {
const auto value_ = value.to<scalar_t>();
gpu_kernel(
iter, [value_] GPU_LAMBDA(scalar_t self, bool mask) -> scalar_t {
if (mask) {
return value_;
}
return self;
});
});
}

// CPU 实现
template <typename scalar_t>
void cpu_masked_fill_kernel(TensorIterator& iter, scalar_t value) {
auto loop = [&](char** data, const int64_t* strides, int64_t n) {
char* dst = data[0];
char* mask = data[1];
for (const auto i : c10::irange(n)) {
bool mask_value = *reinterpret_cast<bool*>(mask + strides[1] * i);

if (mask_value) {
*(scalar_t*)(dst + strides[0] * i) = value;
}
}
};
iter.for_each(loop);
}

void masked_fill_kernel(TensorIterator& iter, const Scalar& value) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kBool, kBFloat16, kHalf,
iter.dtype(), "masked_fill", [&] {
scalar_t scalar_val = value.to<scalar_t>();
auto mask_dtype = iter.input_dtype(0);
TORCH_CHECK(mask_dtype == ScalarType::Bool, "masked_fill only supports boolean masks, "
"but got mask with dtype ", mask_dtype);
cpu_masked_fill_kernel<scalar_t>(iter, scalar_val);
});
}

```

Pytorch 在 CPU 和 GPU 上对 masked_fill 的实现方式有些不同:

CPU 实现:

1. 使用模板函数 cpu_masked_fill_kernel 来实现标量值填充逻辑。

2. TensorIterator::for_each 启动循环,对每组数据调用 lambda 函数。

3. lambda 中直接访问指针进行填充判断和赋值。

4. 使用宏生成不同数据类型的特化模板。

5. 调用入口做参数校验。

GPU 实现:

1. 用 gpu_kernel 启动 CUDA kernel。

2. 用 GPU Lambda 编写 kernel 函数体。

3. kernel 函数签名为 (value, mask) -> output, 直接在 GPU 上判断和赋值。

4. 用宏生成不同数据类型的 kernel。

5. 调用入口转换 value 为模板类型。



## Tensorflow

Tensorflow 并没有直接提供 `masked_fill` 的API,但是可以通过 `tf.where` 来实现。相关讨论PR: https://github.com/tensorflow/tensorflow/pull/41975

讨论结果为使用 `tf.where` 实现 `masked_fill` 的功能更加高效,因此没有提供 `masked_fill` 的API。

# 四、对比分析

- Pytorch 自定义Kernel的方式更加高效
- Tensorflow 通过 `tf.where` 实现 `masked_fill`


# 五、方案设计

## 命名与参数设计

paddle.masked_fill(input, mask, value)

paddle.masked_fill_(input, mask, value)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

抱歉之前comment漏了,这个地方input命名和目前paddle的其他API命名习惯有差异,请参考下https://www.paddlepaddle.org.cn/documentation/docs/zh/dev_guides/api_contributing_guides/api_design_guidelines_standard_cn.html 这个文档,再提PR修改下

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

在这里修复了,麻烦您再review一下 #637 @zoooo0820


Tensor.masked_fill(mask, value)

Tensor.masked_fill_(mask, value)

masked_fill_支持inplace方式修改输入张量。

- `input (Tensor, float, double, int, int64_t, float16, bfloat16)`: 输入的张量,需要进行填充操作。
- `mask (Tensor, bool)`: 用于指定填充位置的布尔值掩码张量,与 input 张量形状相同。
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mask参数这里需要明确,是必须与input相同,还是只要满足可广播的条件就行。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

- `value (Tensor, float, double, int, int64_t, float16, bfloat16)`: 待填充的数据,参数类型支持布尔值、整数、浮点数以及0维的张量。
- `name (str,可选)` - 具体用法请参见 [Name](https://www.paddlepaddle.org.cn/documentation/docs/zh/api_guides/low_level/program.html#api-guide-name),一般无需设置,默认值为 None。
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这几个建议用这样的格式描述:
input(Tensor) : 参数描述,支持的数据类型有xxx, 目前的写法容易有表示同时可接收Tensor和其他python类型的歧义。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done



## 底层OP设计

依赖python实现,无需底层op支持。
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里描述不太准确,辛苦再修改一下。本质上现在也是使用已有的OP(where / full等),只是不需要额外开发新的OP

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


## API实现方案

在 python/paddle/tensor/manipulation.py 中增加 masked_fill 以及 masked_fill_ 函数。

通过 `paddle.where` 实现。

```python
out = paddle.where(mask, value, x)
```

## 代码实现文件路径

函数API实现路径: python/paddle/tensor/manipulation.py

单元测试路径:在 Paddle repo 的 test/ 目录, 同时在 paddle/test/legacy_test/test_inplace.py 中新增对应的inplace api 单测


# 六、测试和验收的考量

测试考虑的case如下:

- 输入的mask和input的形状不一致,但是可以broadcast
- 校验参数 value 的正确性, 是否是支持的数据类型,当 value 是0维 tensor 时梯度正确回传
- 测试在进行反向梯度计算时结果的正确性
- 错误检查:输入x不是Tensor时,能否正确抛出错误


# 七、可行性分析及规划排期

方案实施难度可控,工期上可以满足在当前版本周期内开发完成。

# 八、影响面

为独立新增API,对其他模块没有影响

# 名词解释


# 附件及参考资料