Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【PaddlePaddle Hackathon 4】:为maxout算子支持 float16 数据类型 #50976

Merged
merged 12 commits into from
Apr 27, 2023

Conversation

Patrick-Star125
Copy link
Contributor

@Patrick-Star125 Patrick-Star125 commented Feb 27, 2023

PR types

Others

PR changes

Others

Description

性能数据(op benchmark)

input_shape groups axis fp32 fp16 diff
32, 12, 128, 128 2 -1 0.12460938 0.11993335 0.0047

@paddle-bot
Copy link

paddle-bot bot commented Feb 27, 2023

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

template class MaxOutGradFunctor<phi::CPUContext, double>;
template class MaxOutFunctor<phi::CPUContext, float>;
template class MaxOutFunctor<phi::CPUContext, phi::dtype::float16>;
template class MaxOutFunctor<phi::CPUContext, double>;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我们目前仅需要为GPU支持fp16。CPU的实现不需要修改

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=1e-3)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test_check_output关于place及fp16支持情况的判断和TestMaxOutOpFP16的装饰器的使用保留一处应该就可以,上面的装饰器会在非GPU的测试环境自动跳过单测,所以下面的内容应该是执行不到的。

前向应该没有涉及到计算,只是数据的搬运?这里单测的阈值不设置能否通过?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

可以通过

if core.is_float16_supported(place):
self.check_grad_with_place(
place, ['X'], 'Out', max_relative_error=0.5
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如果使用装饰器的话,这里的place的判断应该不需要了。

这个max_relative_error需要设置这么大吗?需要结合反向kernel实现分析下是否有降低误差的可能

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

@@ -784,7 +784,7 @@ def maxout(x, groups, axis=1, name=None):

Parameters:
x (Tensor): The input is 4-D Tensor with shape [N, C, H, W] or [N, H, W, C], the data type
of input is float32 or float64.
of input is float16, float32 or float64.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个API实现中,有动静态图2个分支。静态图分支能否正常运行?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已增加静态图分支的测试

@luotao1
Copy link
Contributor

luotao1 commented Mar 10, 2023

请更新下代码格式来通过 PR-CI-Codestyle-Check 流水线


def set_attrs(self):
pass

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里的FP16单测可以继承TestMaxOutOp,对TestMaxOutOp做一些小的改动,比如支持设置dtype,shape,attrs,这样可以简化代码。

可以参考低精度单测规范中的介绍。https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/dev_guides/amp_precision/amp_test_dev_guide_cn.html

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

exe = paddle.static.Executor(self.place)
res = exe.run(feed={'X': self.x_np}, fetch_list=[out])
out_ref = maxout_forward_naive(self.x_np, self.groups, self.axis)
np.testing.assert_allclose(out_ref, res[0], rtol=1e-05)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里不推荐使用fluid的api。可以参考#50832中的PR静态图的单测写法

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done,原本的测试也用了fluid.data,需要一并修改吗

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

可以不修改

phi::MaxOutGradKernel,
float,
phi::dtype::float16,
double) {}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

反向kernel可能也需要调整为FP32计算精度,已降低精度的损失。

Copy link
Contributor Author

@Patrick-Star125 Patrick-Star125 Mar 17, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1.意思是直接去掉phi::dtype::float16吗?这样做测试反向算子似乎会出错
2.请问如何判断是否会导致精度损失过大,能否改进计算逻辑减少损失

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

当前的修改只是给算子注册了fp16类型,但是看你并没有对kernel的实现做修改。
需要分析下前、反向的计算,里面的一些计算过程在fp16下是否会损失精度。单测因为运行时间的限制设置的shape都比较小,在自己开发环境上可以尝试把shape调大到比如1000+以上的数据规模,再看看单测里这几个fp16的case精度检查是否能达标呢?

关于问题2,在官网文档中都有详细介绍。https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/dev_guides/amp_precision/amp_op_dev_guide_cn.html

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

理解了,已经将fp16单测与fp32单测对齐,测试方式和误差要求一致
1.maxout函数的逻辑为对tensor按指定组大小遍历取最大值,只有比较操作,不涉及计算,对于MaxOutFunctor和MaxOutGradFunctor的参数input_tensor的处理和output_tensor的计算都不含有规约计算,无溢出风险。
2.在线下的测试中我尝试了[32, 12, 128, 128]、[320, 12, 128, 128]、[320, 120, 128, 128]形式均可以通过,更大的tensor因为设备显存不足暂时无法测试,,但应该精度可以达标。

place = core.CUDAPlace(0)
self.check_grad_with_place(
place, ['X'], 'Out', max_relative_error=0.001
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里的装饰器不需要使用,OpTest会自动为FP16的单测跳过不支持的设备。test_check_output可以使用父类的方法。test_check_grad使用check_grad接口,指定max_relative_error即可。(可以参考数据类型扩展任务中已经合入的单测写法)

Copy link
Contributor Author

@Patrick-Star125 Patrick-Star125 Mar 17, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

在线下测试中,装饰器可以去除,但是不指定place似乎无法找到kernel
image

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里应该是可以不用装饰器的,你可以参考下任务列表里面已经合入的一些PR的单测写法

exe = paddle.static.Executor(self.place)
res = exe.run(feed={'X': self.x_np}, fetch_list=[out])
out_ref = maxout_forward_naive(self.x_np, self.groups, self.axis)
np.testing.assert_allclose(out_ref, res[0], rtol=1e-05)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

可以不修改

@luotao1
Copy link
Contributor

luotao1 commented Apr 26, 2023

需要解决下 CodeStyle 流水线的问题

@Patrick-Star125
Copy link
Contributor Author

Done

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
contributor External developers
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants