Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon No.6】implement nan_to_num #42469

Merged
merged 9 commits into from
Oct 25, 2022

Conversation

tiancaishaonvjituizi
Copy link
Contributor

@tiancaishaonvjituizi tiancaishaonvjituizi commented May 4, 2022

PR types

New features

PR changes

OPs

Describe

实现 nan_to_num

Hackathon issue:#40329
Hackathon RFC:PaddlePaddle/community#93

@paddle-bot-old
Copy link

paddle-bot-old bot commented May 4, 2022

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

paddle.disable_static(place=self.place)

with paddle.fluid.dygraph.guard():
x_np = np.array([[1, np.nan, -2], [np.inf, 0, -np.inf]]).astype(np.float32)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

为什么只测试了 np.float:

因为只要把 x_np 的 dtype 设置为 np.float64,在 paddle/fluid/imperative/prepared_operator.h:586 就会报错说 "nan" 这个属性的 dtype 不支持,设置 GLOG_v=10 看日志发现,x_np 为 np.float64 时,"nan" 的 dtype 也被认为是 double(这一点就很奇怪),我尝试在 paddle/fluid/imperative/prepared_operator.h:529 的 switch 里加入 FLOAT64 的支持,结果仍然会报错,错误信息是

ValueError: (InvalidArgument) boost::get failed, cannot get value (attr) by type double, its type is float.

所以这应该是 paddle 内部的一个 bug,超出了本 PR 的范围

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这样会不支持float64的kernel了,float64目前有OP bug需要修复。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是说 paddle 自己有 bug 吗

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不是paddle,应该是OP定义那里,定义那里如果还不支持double,就先用float吧

template <typename T, typename Context>
void NanToNumKernel(const Context& ctx,
const DenseTensor& x,
T nan,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里的nan/posinf/neginf应该和nan_to_num_op.cc中的OP定义保持一致,是float类型,所以导致了下面的这个问题

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

当时是因为 paddle 新版的 codegen 不支持 double 类型,所以 op 定义用了 float。我试试把 codegen 的 bug 修一下然后都设置成 double

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如果AddAttr("nan", "(float) the value to replace NaNs with."),用了float,这里也需要用float。这些参数全部用float就可以吧

paddle.disable_static(place=self.place)

with paddle.fluid.dygraph.guard():
x_np = np.array([[1, np.nan, -2], [np.inf, 0, -np.inf]]).astype(np.float32)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这样会不支持float64的kernel了,float64目前有OP bug需要修复。

@zhwesky2010
Copy link
Contributor

GPU下现在编不过,需要修一下
infoflow 2022-05-09 19-19-00

@tiancaishaonvjituizi
Copy link
Contributor Author

GPU下现在编不过,需要修一下

这个错误我完全没有头绪,只在 windows 下会出现,涉及到的函数 ”_fdtest“ 并不是我写的,而且我的写法和 eye_kernel_impl.h 是一样的,不知道为什么会出错。paddle 官方开发者对这个错误有什么想法吗

@zhwesky2010
Copy link
Contributor

GPU下现在编不过,需要修一下

这个错误我完全没有头绪,只在 windows 下会出现,涉及到的函数 ”_fdtest“ 并不是我写的,而且我的写法和 eye_kernel_impl.h 是一样的,不知道为什么会出错。paddle 官方开发者对这个错误有什么想法吗

这个是说不能从__global__里调__host__主机函数,推测是std::isnanstd::numeric_limits<T>::infinity 这两个函数的问题

@tiancaishaonvjituizi
Copy link
Contributor Author

@zhouwei25 我已经参考 pytorch 修改了 isnan 相关的实现,现在应该没有问题了。但现在 ci 有很多奇怪的错误(网络问题、test_parallel_dygraph_unused_variables_gloo 超时等),该怎么操作呢

@tiancaishaonvjituizi
Copy link
Contributor Author

@zhouwei25 现在 CI 几乎都过了

Copy link
Contributor

@zhwesky2010 zhwesky2010 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

zhwesky2010
zhwesky2010 previously approved these changes May 16, 2022
@tiancaishaonvjituizi
Copy link
Contributor Author

@zhouwei25 所有 CI 已过,只差 api 改动所需的 approve 了

@Ligoml
Copy link
Contributor

Ligoml commented May 17, 2022

@tiancaishaonvjituizi 需要在paddle/docs下同步添加中文文档哈,这里有一些参考文档:API书写规范文档预览工具中文API文档复制英文API文档示例代码

python/paddle/tensor/math.py Outdated Show resolved Hide resolved
python/paddle/tensor/math.py Outdated Show resolved Hide resolved
python/paddle/tensor/math.py Outdated Show resolved Hide resolved
@tiancaishaonvjituizi
Copy link
Contributor Author

@Ligoml 本 PR 中的 review 意见已修复,中文文档稍后提交

@zhwesky2010
Copy link
Contributor

嗯如果不赞同C++的当前实现方案,那建议就还是由paddle.where组装吧,这个inf的API bug目前已经通知去修复了

嗯,那就先等待修复吧

equal判断 inf==inf 的API bug已修复,麻烦看看还有没有其他问题

Signed-off-by: tiancaishaonvjituizi <[email protected]>
@tiancaishaonvjituizi
Copy link
Contributor Author

tiancaishaonvjituizi commented Aug 13, 2022

嗯如果不赞同C++的当前实现方案,那建议就还是由paddle.where组装吧,这个inf的API bug目前已经通知去修复了

嗯,那就先等待修复吧

equal判断 inf==inf 的API bug已修复,麻烦看看还有没有其他问题

真的还有(哭笑不得),用 np.finfo(np.float64).max 给 paddle float64 tensor 赋值变成了 inf:

>>> x=torch.empty(2, dtype=torch.float64)
>>> torch.full_like(x, np.finfo(np.float64).max)
tensor([1.7977e+308, 1.7977e+308], dtype=torch.float64)
>>> x=paddle.to_tensor([1,1], dtype=paddle.float64)
>>> paddle.full_like(x, np.finfo(np.float64).max)
Tensor(shape=[2], dtype=float64, place=Place(gpu:0), stop_gradient=True,
       [inf., inf.])

可能是哪里的中间变量数据类型被误用了 float32 而不是 float64,因为我观察到下面的现象:

>>> paddle.full_like(x, np.finfo(np.float32).max)
Tensor(shape=[2], dtype=float64, place=Place(gpu:0), stop_gradient=True,
       [340282346638528859811704183484516925440.,
        340282346638528859811704183484516925440.])
>>> paddle.full_like(x, np.finfo(np.float32).max*2)
Tensor(shape=[2], dtype=float64, place=Place(gpu:0), stop_gradient=True,
       [inf., inf.])

现在 nan_to_num float32 的测试已经过了,代码已上传。有一说一,我觉得我真的仁至义尽了。一个简简单单的 api,approve 后又一刀切要求做巧妇难为无米之炊的重写,接着经过反复拉扯终于确定了可行的重写方案,又接二连三遇到 paddle 的各种 bug。要不您跟您的领导反映一下,我已经把这个 api 基本实现了,剩下的问题只有 paddle 自己的 bug,您或者您的领导 @jeff41404 @zhouwei25 自己继续推动这个 api 吧,我仁至义尽了,一开始就表达过,我现在在做的事情是情分不是本分

Signed-off-by: tiancaishaonvjituizi <[email protected]>
@zhwesky2010
Copy link
Contributor

zhwesky2010 commented Aug 16, 2022

OK,我们来处理这些问题

@zhwesky2010
Copy link
Contributor

zhwesky2010 commented Sep 9, 2022

@tiancaishaonvjituizi 你好,我们正在开展OP attribute不支持double类型的工作,但由于涉及到底层OP架构修改和API/OP的不兼容升级改动,时间跨度可能较长。经过内部讨论,当前nan_to_num API开发先暂时绕过,支持float32类型的单测即可,float64单测case先写上但是先注释掉就可以。

@tiancaishaonvjituizi
Copy link
Contributor Author

@zhouwei25 ok 了,留了一个 float32 的测试

self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
else paddle.CPUPlace()

# def test_static(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

静态图float32可以通过单测吗?float64暂不用管,可以屏蔽单测


paddle.enable_static()

# def test_check_grad(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

反向float32可以通过单测吗?float64暂不用管,可以屏蔽单测

Signed-off-by: tiancaishaonvjituizi <[email protected]>
@Ligoml
Copy link
Contributor

Ligoml commented Oct 24, 2022

image

重新跑一下pre-commit吧

@tiancaishaonvjituizi
Copy link
Contributor Author

重新跑一下pre-commit吧

好(

Signed-off-by: tiancaishaonvjituizi <[email protected]>
@tiancaishaonvjituizi
Copy link
Contributor Author

tiancaishaonvjituizi commented Oct 24, 2022

@Ligoml 该怎么让 pre-commit 重新安装格式化工具呢?我这里的 yapf 版本似乎和 CI 里用的 yapf 版本不同,得不到和 CI 一样的格式化结果,我暂时按照 CI 输出的 diff 手动格式化了

Signed-off-by: tiancaishaonvjituizi <[email protected]>
@luotao1
Copy link
Contributor

luotao1 commented Oct 24, 2022

2022-10-24 15:51:31 Your PR code style check failed.
2022-10-24 15:51:31 Please install pre-commit locally and set up git hook scripts:
2022-10-24 15:51:31 
2022-10-24 15:51:31     pip install pre-commit==2.17.0
2022-10-24 15:51:31     pre-commit install
2022-10-24 15:51:31 
2022-10-24 15:51:31 Then, run pre-commit to check codestyle issues in your PR:
2022-10-24 15:51:31 
2022-10-24 15:51:31     pre-commit run --files python/paddle/__init__.py python/paddle/fluid/tests/unittests/test_nan_to_num_op.py python/paddle/tensor/__init__.py python/paddle/tensor/math.py
2022-10-24 15:51:31 

我这里的 yapf 版本似乎和 CI 里用的 yapf 版本

develop分支于昨天刚刚从yapf版本升级到black版本。需要先merge下develop分支,然后按上面的步骤操作下,会自动格式化的。

@tiancaishaonvjituizi
Copy link
Contributor Author

develop分支于昨天刚刚从yapf版本升级到black版本。需要先merge下develop分支,然后按上面的步骤操作下,会自动格式化的。

好的,给力

Copy link
Contributor

@zhwesky2010 zhwesky2010 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@jeff41404 jeff41404 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@Ligoml Ligoml left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM for docs

@tiancaishaonvjituizi
Copy link
Contributor Author

image

Copy link
Contributor

@luotao1 luotao1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

有两个小问题可以在下一个PR里补充

paddle.enable_static()


# class BaseTestCases:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

可以在下一个PR里加一个注释说明下这些单测注释的原因

dx = paddle.grad(y, x_tensor)[0].numpy()

np_grad = np_nan_to_num_grad(x_np, np.ones_like(x_np))
self.assertTrue(np.allclose(np_grad, dx))
Copy link
Contributor

@luotao1 luotao1 Oct 25, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2022-10-25 11:07:01 0. It is recommended to use 'np.testing.assert_allclose' and 'np.testing.array_equal' instead of 'self.assertTrue(np.allclose(...))' and 'self.assertTrue(np.array_equal(...))'.
2022-10-25 11:07:01 Please modify the code below. If anything is unclear, please read the specification [ https://github.com/PaddlePaddle/community/blob/master/rfcs/CodeStyle/20220805_code_style_improvement_for_unittest.md#background ]. If it is a mismatch, please request qili93 (Recommend) or luotao1 review and approve.

image

单测报错信息进行了升级,可以在下一个PR里根据提示使用np.testing.assert_allclose代替单测里的self.assertTrue(np.allclose

@luotao1 luotao1 merged commit 9507969 into PaddlePaddle:develop Oct 25, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants