-
Notifications
You must be signed in to change notification settings - Fork 6.8k
[WIP] Fallback mechanism for mx.np operators #16923
[WIP] Fallback mechanism for mx.np operators #16923
Conversation
# try to fallback to official NumPy op | ||
onp_op = _get_np_op(name) | ||
new_inputs = [arg.asnumpy() if isinstance(arg, ndarray) else arg for arg in inputs] | ||
out = onp_op(*new_inputs, **kwargs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It will break the computational graph, and could not compute the gradient.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We are aware of this. More sophisticated fallback mechanism is illustrated in #16698 by leveraging CustomOp
. To reach 100% NumPy op coverage within a month, this the simplest and fastest pathway though. In the future, we will gradually replace those fallback ops with native implementation in backend.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It may be better to use mx.autograd.Function
to wrap these numpy operators.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@wkcn
mx.autograd.Function
currently does not support DeepNumpy, some extra infrastructure is required.
Also, I am not sure if mx.autograd.Function
could be integrated into HybridBlock, I cannot find corresponding cases covered in the unit tests.
Fix lint Fix
fe362cc
to
2cd2094
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
* Add fallback mechanism Fix lint Fix * Add unit tests for linalg.cond and heaviside * Add spacing * Fix lint * Skip python2 for dispatching array function Co-authored-by: Hao Jin <[email protected]>
* Add fallback mechanism Fix lint Fix * Add unit tests for linalg.cond and heaviside * Add spacing * Fix lint * Skip python2 for dispatching array function Co-authored-by: Hao Jin <[email protected]>
* Add fallback mechanism Fix lint Fix * Add unit tests for linalg.cond and heaviside * Add spacing * Fix lint * Skip python2 for dispatching array function Co-authored-by: Hao Jin <[email protected]>
Description
Fallback mechanism for
mx.np
operators.@haojin2