-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
API improvement for paddle.linalg.svd_lowrank #62876
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
C = self.random_matrix(rank, columns, *batch_dims, **kwargs) | ||
return B.matmul(C) | ||
|
||
def run_subtest( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
numpy有计算svd_lowrank的API吗,如果有的话,单测对比结果时 可以简洁一些
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
numpy没有直接计算svd_lowrank的API,这个是复用了test_pca_lowrank的代码,是不是直接在test_pca_lowrank增加svd_lowrank测试案例更好一点
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这样写也可以
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Sorry to inform you that ccdb50c's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually. |
@NKNaN 看一下CI未通过的原因 |
@zhwesky2010 是超过7天,Paddle-bot自动标记fail了,重新rerun或者merge下develop就好了 |
python/paddle/tensor/linalg.py
Outdated
zero or more batch dimensions. N and M can be arbitrary positive number. | ||
The data type of ``x`` should be float32 or float64. | ||
q (int, optional): A slightly overestimated rank of :math:`X`. | ||
Default value is :math:`q=min(6,N,M)`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里的默认值,和在函数定义 def svd_lowrank(x, q=None, niter=2, M=None, name=None):
里 q=None
是不是冲突了?辛苦看一下
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
* add svd lowrank api * add test * fix param M * fix test timeout * update docs
PR Category
User Experience
PR Types
Improvements
Description
将 paddle 的 linalg.py 中已有的 svd_lowrank 函数计算逻辑公开出来,对照 pytorch: https://pytorch.org/docs/stable/generated/torch.svd_lowrank.html