-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Support sparse for custom python operators #8620
Conversation
python/mxnet/operator.py
Outdated
return in_stype, [in_stype[0]]*len(self.list_outputs()), \ | ||
[in_stype[0]]*len(self.list_auxiliary_states()) | ||
|
||
def infer_storage_type(self, in_stype): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: would it be more natural to put def infer_storage_type
before def infer_storage_type_backward
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed this
python/mxnet/operator.py
Outdated
Returns | ||
------- | ||
in_stype : list | ||
list of argument stypes. Can be modified from in_stype. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What does Can be modified from in_stype
mean? Can they be modified when it's already set?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No it cannot be modified if it is already set. Can be modified only if it is undefined storage type. I have removed the line to avoid confusion
src/operator/custom/custom.cc
Outdated
for (size_t i = 0; i < oattr->size(); i++) { | ||
STORAGE_TYPE_ASSIGN_CHECK(*oattr, i, kDefaultStorage); | ||
} | ||
dispatch_mode_assign(dispatch_mode, DispatchMode::kFComputeEx); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we use DISPATCH_MODE_ASSIGN_CHECK
instead?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes Changed.
if in_data[0].stype == 'default': | ||
aux[0][:] = 1 | ||
self.assign(out_data[0], req[0], in_data[0]*in_data[0]) | ||
else: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could you add check here:
if data.stype == 'csr':
assert(isinstance(data, CSRNDArray))
check_numeric_gradient(op, [x], [aux]) | ||
|
||
x = mx.nd.array(np.random.uniform(-1, 1, size=(4, 10))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this line duplicated on purpose?
@@ -3570,12 +3570,19 @@ def test_rcbrt_op(): | |||
def test_custom_op(): | |||
class Sqr(mx.operator.CustomOp): | |||
def forward(self, is_train, req, in_data, out_data, aux): | |||
self.assign(out_data[0], req[0], in_data[0]*in_data[0]) | |||
aux[0][:] = 1 | |||
#self.assign(out_data[0], req[0], in_data[0]*in_data[0]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove unused code?
This reverts commit 938eda9.
* Add asscipy support and coo format support * Comment misalignment change * Add documentation for Sparse NDarray * Change comment * Adding comments and support for dtype * Modifying tests * Add spsp None check * Fix lint * Custom operators for sparse * Use DISPATCH_MODE_ASSIGN_CHECK * Change NDArray to _ndarray_cls * Remove redundant code * Add a test to make sure the NDArray is an instance of CSRNDArray * Fix lint * Fix test * Trigger CI
…ache#8733) This reverts commit 938eda9.
…ache#8733) This reverts commit 938eda9.
* Add asscipy support and coo format support * Comment misalignment change * Add documentation for Sparse NDarray * Change comment * Adding comments and support for dtype * Modifying tests * Add spsp None check * Fix lint * Custom operators for sparse * Use DISPATCH_MODE_ASSIGN_CHECK * Change NDArray to _ndarray_cls * Remove redundant code * Add a test to make sure the NDArray is an instance of CSRNDArray * Fix lint * Fix test * Trigger CI
…ache#8733) This reverts commit 938eda9.
Description
This PR is to support sparse for custom python operators. Specifically, one should be able to use sparse storage types with custom python operators, right now only default storage types are supported.
Checklist
Essentials
make lint
)Changes
Comments
@piiswrong @eric-haibin-lin