-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【complex op No.4】support complex for bmm #64603
base: develop
Are you sure you want to change the base?
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
Sorry to inform you that 84df26b's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually. |
test/legacy_test/test_bmm_op.py
Outdated
class TestBmmOpCaseComplex64(TestBmmOp): | ||
def setUp(self): | ||
self.op_type = "bmm" | ||
self.python_api = paddle.tensor.bmm | ||
self.public_python_api = paddle.tensor.bmm | ||
X = ( | ||
np.random.uniform(1, 5, (10, 3, 4)) | ||
+ 1j * np.random.uniform(1, 5, (10, 3, 4)) | ||
).astype("complex64") | ||
Y = ( | ||
np.random.uniform(1, 5, (10, 4, 2)) | ||
+ 1j * np.random.uniform(1, 5, (10, 4, 2)) | ||
).astype("complex64") | ||
self.inputs = {'X': X, 'Y': Y} | ||
Out = np.matmul(X, Y) | ||
self.outputs = {'Out': Out} | ||
|
||
def test_check_output(self): | ||
self.check_output(check_pir=True, check_prim=False) | ||
# pass | ||
|
||
def test_checkout_grad(self): | ||
self.check_grad(['X', 'Y'], 'Out', check_pir=True, check_prim=False) | ||
|
||
|
||
class TestBmmOpCaseComplex128(TestBmmOp): | ||
def setUp(self): | ||
self.op_type = "bmm" | ||
self.python_api = paddle.tensor.bmm | ||
self.public_python_api = paddle.tensor.bmm | ||
X = ( | ||
np.random.uniform(1, 5, (10, 3, 4)) | ||
+ 1j * np.random.uniform(1, 5, (10, 3, 4)) | ||
).astype("complex128") | ||
Y = ( | ||
np.random.uniform(1, 5, (10, 4, 2)) | ||
+ 1j * np.random.uniform(1, 5, (10, 4, 2)) | ||
).astype("complex128") | ||
self.inputs = {'X': X, 'Y': Y} | ||
Out = np.matmul(X, Y) | ||
self.outputs = {'Out': Out} | ||
|
||
def test_check_output(self): | ||
self.check_output(check_pir=True, check_prim=False) | ||
# pass | ||
|
||
def test_checkout_grad(self): | ||
self.check_grad(['X', 'Y'], 'Out', check_pir=True, check_prim=False) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
参考其他单测,补充一下self.dtype=np.complex64/np.complex218
这个设置
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done,但是本地仍然有较大的误差
Sorry to inform you that 6c1582f's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually. |
|
PR Category
Others
PR Types
New features
Description
支持bmm的复数类型,bmm前向是直接调用的blas的gemm,因此前向可以直接增加类型即可。反向参考
Paddle/paddle/phi/kernels/impl/matmul_grad_kernel_impl.h
中的matmulgradkernel新增了复数类型。