This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
[Numpy] Backward error in mixed int64 + float32 #18084
Another failure case:

```python
import mxnet as mx
from mxnet.gluon import HybridBlock

mx.npx.set_np()


class Foo(HybridBlock):
    def hybrid_forward(self, F, query):
        # shape_array returns the shape of `query` as an integer (int64) array
        query_shape = F.npx.shape_array(query)
        # float16 query divided by a float32 scalar: a mixed-dtype division
        return query / F.np.sqrt(query_shape[-1].astype(mx.np.float32))


foo = Foo()
foo.hybridize()
a = mx.np.ones((5, 5, 5), dtype=mx.np.float16)
out = foo(a)  # forward pass only
print(out)

a.attach_grad()
with mx.autograd.record():
    out = foo(a)
out.backward()  # backward pass through the mixed-dtype division
print(a.grad)
```

Error:
Assignee: @BenjaminCHEN2016
Fixed by #18250.
This is related to #18022.
Reproducible example:
Error message:
Currently, I have to use:

```python
query / F.np.sqrt(query_shape[-1].astype(np.float32))
```
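To picture what a backward pass through a mixed-dtype division has to do, here is a minimal plain-NumPy sketch. It is not MXNet's implementation, and the helpers `div_forward_backward` and `_reduce_to` are hypothetical names. It assumes the forward pass computes in the promoted floating dtype (float16 with float32 gives float32), and that the backward pass must return each input's gradient in that input's own dtype and shape. That means float16 with shape `(5, 5, 5)` for the query, and float32 reduced over the broadcast axes for the divisor. Integer inputs, such as the int64 shape array, are treated as non-differentiable here. In plain NumPy, `np.result_type(np.int64, np.float32)` is `float64`, which is one reason to cast the shape explicitly before using it.

```python
import numpy as np


def _reduce_to(grad, ref):
    """Sum `grad` over broadcast axes so it matches ref.shape, then cast to ref.dtype."""
    # Drop leading axes that broadcasting prepended to `ref`.
    while grad.ndim > ref.ndim:
        grad = grad.sum(axis=0)
    # Collapse axes where `ref` had size 1 but the output did not.
    for axis, size in enumerate(ref.shape):
        if size == 1 and grad.shape[axis] != 1:
            grad = grad.sum(axis=axis, keepdims=True)
    return grad.astype(ref.dtype)


def div_forward_backward(a, b, grad_out):
    """Hypothetical forward + backward of out = a / b for inputs of different dtypes.

    Assumes both inputs are floating point; integer inputs (e.g. a shape
    array) are not differentiated in this sketch.
    """
    assert np.issubdtype(a.dtype, np.floating) and np.issubdtype(b.dtype, np.floating)
    # Forward: compute in the promoted dtype (float16 with float32 -> float32).
    out_dtype = np.result_type(a.dtype, b.dtype)
    a_c, b_c = a.astype(out_dtype), b.astype(out_dtype)
    out = a_c / b_c
    # Backward: d(a/b)/da = 1/b and d(a/b)/db = -a/b**2, computed in out_dtype ...
    g = np.asarray(grad_out, dtype=out_dtype)
    grad_a = g / b_c
    grad_b = -g * a_c / (b_c * b_c)
    # ... then reduced over broadcast axes and cast back to each input's dtype.
    return out, _reduce_to(grad_a, a), _reduce_to(grad_b, b)


# Mirrors the failing case: float16 query divided by a float32 sqrt(shape[-1]).
query = np.ones((5, 5, 5), dtype=np.float16)
last_dim = np.array([query.shape[-1]], dtype=np.int64)  # int64, like a shape array
divisor = np.sqrt(last_dim.astype(np.float32))  # cast to float32 before sqrt
out, grad_query, grad_divisor = div_forward_backward(
    query, divisor, np.ones(query.shape, dtype=np.float32)
)
print(out.dtype, grad_query.dtype, grad_query.shape, grad_divisor.shape)
```

The key design point in this sketch is the final cast in `_reduce_to`: without it, the float16 input would receive a float32 gradient, which is the kind of dtype mismatch a mixed-precision backward pass has to handle.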