diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 8ea2b0e0e5dc..0726566da55d 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -146,7 +146,9 @@ enum CustomOpPropCallbacks {
   kCustomOpPropInferShape,
   kCustomOpPropDeclareBackwardDependency,
   kCustomOpPropCreateOperator,
-  kCustomOpPropInferType
+  kCustomOpPropInferType,
+  kCustomOpPropInferStorageType,
+  kCustomOpPropBackwardInferStorageType
 };
 
@@ -158,6 +160,10 @@ typedef int (*CustomOpListFunc)(char*** /*args*/, void* /*state*/);
 typedef int (*CustomOpInferShapeFunc)(int /*num_input*/, int* /*ndims*/,
                                       unsigned** /*shapes*/, void* /*state*/);
 typedef int (*CustomOpInferTypeFunc)(int /*num_input*/, int* /*types*/, void* /*state*/);
+typedef int (*CustomOpInferStorageTypeFunc)(int /*num_input*/, int* /*stypes*/, void* /*state*/);
+typedef int (*CustomOpBackwardInferStorageTypeFunc)(int /*num_input*/,
+                                                    int * /*stypes*/,
+                                                    void * /*state*/);
 typedef int (*CustomOpBwdDepFunc)(const int* /*out_grad*/, const int* /*in_data*/,
                                   const int* /*out_data*/, int* /*num_deps*/,
                                   int** /*rdeps*/, void* /*state*/);
diff --git a/python/mxnet/operator.py b/python/mxnet/operator.py
index 1337bbccc3c8..8fcf127d6259 100644
--- a/python/mxnet/operator.py
+++ b/python/mxnet/operator.py
@@ -16,7 +16,7 @@
 # under the License.
 
 # coding: utf-8
-# pylint: disable=invalid-name, protected-access, too-many-arguments, no-self-use, too-many-locals, broad-except
+# pylint: disable=invalid-name, protected-access, too-many-arguments, no-self-use, too-many-locals, broad-except, too-many-lines
 """numpy interface for operators."""
 from __future__ import absolute_import
 
@@ -30,6 +30,9 @@
 from .base import c_array, c_str, mx_uint, mx_float, ctypes2numpy_shared, NDArrayHandle, py_str
 from . import symbol, context
 from .ndarray import NDArray, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP
+from .ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID, _STORAGE_TYPE_ID_TO_STR
+from .ndarray import _ndarray_cls
+
 
 c_int_p = POINTER(c_int)
 
@@ -513,6 +516,51 @@ def infer_type(self, in_type):
         return in_type, [in_type[0]]*len(self.list_outputs()), \
             [in_type[0]]*len(self.list_auxiliary_states())
 
+    def infer_storage_type(self, in_stype):
+        """infer_storage_type interface. Used to infer the storage type of
+        inputs and outputs in the forward pass.
+
+        Parameters
+        ----------
+        in_stype : list of str
+            list of argument stypes. Valid stypes are 'default',
+            'row_sparse' and 'csr'.
+
+        Returns
+        -------
+        in_stype : list
+            list of argument stypes.
+        out_stype : list
+            list of output stypes calculated from in_stype,
+            in the same order as declared in list_outputs.
+        aux_stype : list, optional
+            list of aux stypes calculated from in_stype,
+            in the same order as declared in list_auxiliary_states.
+        """
+        return in_stype, [in_stype[0]]*len(self.list_outputs()), \
+            [in_stype[0]]*len(self.list_auxiliary_states())
+
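For context (not part of the patch): a minimal sketch of how a user-level CustomOpProp might override the new forward hook. The operator name 'pass_through' and its single input/output are hypothetical; returning two lists is enough here because the op declares no auxiliary states.

import mxnet as mx

@mx.operator.register("pass_through")  # hypothetical op name
class PassThroughProp(mx.operator.CustomOpProp):
    def list_arguments(self):
        return ['data']

    def list_outputs(self):
        return ['output']

    def infer_storage_type(self, in_stype):
        # Keep CSR inputs sparse; fall back to dense otherwise.
        if in_stype[0] == 'csr':
            return ['csr'], ['csr']
        return ['default'], ['default']
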
+ """ + return in_stype, [in_stype[0]]*len(self.list_outputs()), \ + [in_stype[0]]*len(self.list_auxiliary_states()) + def list_outputs(self): """list_outputs interface. Can override when creating new operators. @@ -601,6 +649,8 @@ def do_register(prop_cls): infershape_functype = CFUNCTYPE(c_int, c_int, POINTER(c_int), POINTER(POINTER(mx_uint)), c_void_p) infertype_functype = CFUNCTYPE(c_int, c_int, POINTER(c_int), c_void_p) + inferstorage_functype = CFUNCTYPE(c_int, c_int, POINTER(c_int), c_void_p) + inferstorage_backward_functype = CFUNCTYPE(c_int, c_int, POINTER(c_int), c_void_p) list_functype = CFUNCTYPE(c_int, POINTER(POINTER(POINTER(c_char))), c_void_p) deps_functype = CFUNCTYPE(c_int, c_int_p, c_int_p, c_int_p, c_int_p, POINTER(c_int_p), c_void_p) @@ -654,6 +704,81 @@ def infer_shape_entry(num_tensor, tensor_dims, return False return True + def infer_storage_type_backward_entry(num_tensor, tensor_stypes, _): + """C Callback for CustomOpProp::InferStorageTypeBackward""" + try: + n_in = len(op_prop.list_arguments()) + n_out = len(op_prop.list_outputs()) + n_aux = len(op_prop.list_auxiliary_states()) + total_inputs = n_in + 2 * n_out + total_aux = n_aux + total_outputs = n_in + assert num_tensor == (2 * n_in + 2 * n_out + n_aux) + + stypes = [_STORAGE_TYPE_ID_TO_STR[tensor_stypes[i]] \ + for i in range(total_inputs + total_aux)] + ret = op_prop.infer_storage_type_backward(stypes) + if len(ret) == 2: + istype, ostype = ret + astype = [] + elif len(ret) == 3: + istype, ostype, astype = ret + else: + raise AssertionError("infer_storage_type backward must return 2 or 3 lists") + assert len(ostype) == total_outputs, \ + "InferStorageTypeBackward Error: expecting %d entries in returned output " \ + "stypes, got %d."%(total_outputs, len(ostype)) + assert len(istype) == (total_inputs), \ + "InferStorageTypeBackward Error: expecting %d entries in returned output " \ + "stypes, got %d."%(total_inputs, len(istype)) + rtype = list(istype) + list(ostype) + list(astype) + for i, dtype in enumerate(rtype): + tensor_stypes[i] = _STORAGE_TYPE_STR_TO_ID[dtype] + infer_storage_type_backward_entry._ref_holder = [tensor_stypes] + except Exception: + print('Error in %s.infer_type: %s' % (reg_name, traceback.format_exc())) + return False + return True + + + def infer_storage_type_entry(num_tensor, tensor_stypes, _): + """C Callback for CustomOpProp::InferStorageType""" + try: + n_in = len(op_prop.list_arguments()) + n_out = len(op_prop.list_outputs()) + n_aux = len(op_prop.list_auxiliary_states()) + assert num_tensor == n_in + n_out + n_aux + + stypes = [_STORAGE_TYPE_ID_TO_STR[tensor_stypes[i]] for i in range(n_in)] + ret = op_prop.infer_storage_type(stypes) + if len(ret) == 2: + istype, ostype = ret + astype = [] + elif len(ret) == 3: + istype, ostype, astype = ret + else: + raise AssertionError("infer_storage_type must return 2 or 3 lists") + + assert len(ostype) == n_out, \ + "InferStorageType Error: expecting %d entries in returned output " \ + "stypes, got %d."%(n_out, len(ostype)) + assert len(istype) == n_in, \ + "InferStorageType Error: expecting %d entries in returned input " \ + "stypes, got %d."%(n_in, len(istype)) + assert len(astype) == n_aux, \ + "InferStorageType Error: expecting %d entries in returned aux state " \ + "stypes, got %d."%(n_aux, len(astype)) + rtype = list(istype) + list(ostype) + list(astype) + for i, dtype in enumerate(rtype): + tensor_stypes[i] = _STORAGE_TYPE_STR_TO_ID[dtype] + + infer_storage_type_entry._ref_holder = [tensor_stypes] + except Exception: + print('Error in 
+        def infer_storage_type_entry(num_tensor, tensor_stypes, _):
+            """C Callback for CustomOpProp::InferStorageType"""
+            try:
+                n_in = len(op_prop.list_arguments())
+                n_out = len(op_prop.list_outputs())
+                n_aux = len(op_prop.list_auxiliary_states())
+                assert num_tensor == n_in + n_out + n_aux
+
+                stypes = [_STORAGE_TYPE_ID_TO_STR[tensor_stypes[i]] for i in range(n_in)]
+                ret = op_prop.infer_storage_type(stypes)
+                if len(ret) == 2:
+                    istype, ostype = ret
+                    astype = []
+                elif len(ret) == 3:
+                    istype, ostype, astype = ret
+                else:
+                    raise AssertionError("infer_storage_type must return 2 or 3 lists")
+
+                assert len(ostype) == n_out, \
+                    "InferStorageType Error: expecting %d entries in returned output " \
+                    "stypes, got %d."%(n_out, len(ostype))
+                assert len(istype) == n_in, \
+                    "InferStorageType Error: expecting %d entries in returned input " \
+                    "stypes, got %d."%(n_in, len(istype))
+                assert len(astype) == n_aux, \
+                    "InferStorageType Error: expecting %d entries in returned aux state " \
+                    "stypes, got %d."%(n_aux, len(astype))
+                rtype = list(istype) + list(ostype) + list(astype)
+                for i, dtype in enumerate(rtype):
+                    tensor_stypes[i] = _STORAGE_TYPE_STR_TO_ID[dtype]
+
+                infer_storage_type_entry._ref_holder = [tensor_stypes]
+            except Exception:
+                print('Error in %s.infer_storage_type: %s' % (reg_name, traceback.format_exc()))
+                return False
+            return True
+
         def infer_type_entry(num_tensor, tensor_types, _):
             """C Callback for CustomOpProp::InferType"""
             try:
@@ -673,13 +798,13 @@ def infer_type_entry(num_tensor, tensor_types, _):
                     raise AssertionError("infer_type must return 2 or 3 lists")
                 assert len(otype) == n_out, \
                     "InferType Error: expecting %d entries in returned output " \
-                    "shapes, got %d."%(n_out, len(otype))
+                    "types, got %d."%(n_out, len(otype))
                 assert len(itype) == n_in, \
                     "InferType Error: expecting %d entries in returned input " \
-                    "shapes, got %d."%(n_in, len(itype))
+                    "types, got %d."%(n_in, len(itype))
                 assert len(atype) == n_aux, \
                     "InferType Error: expecting %d entries in returned aux state " \
-                    "shapes, got %d."%(n_aux, len(atype))
+                    "types, got %d."%(n_aux, len(atype))
                 rtype = list(itype) + list(otype) + list(atype)
                 for i, dtype in enumerate(rtype):
                     tensor_types[i] = _DTYPE_NP_TO_MX[dtype]
@@ -768,13 +893,13 @@ def forward_entry(num_ndarray, ndarraies, tags, reqs, is_train, _):
                 tensors = [[] for i in range(5)]
                 for i in range(num_ndarray):
                     if tags[i] == 1 or tags[i] == 4:
-                        tensors[tags[i]].append(NDArray(cast(ndarraies[i],
-                                                             NDArrayHandle),
-                                                        writable=True))
+                        tensors[tags[i]].append(_ndarray_cls(cast(ndarraies[i],
+                                                                  NDArrayHandle),
+                                                             writable=True))
                     else:
-                        tensors[tags[i]].append(NDArray(cast(ndarraies[i],
-                                                             NDArrayHandle),
-                                                        writable=False))
+                        tensors[tags[i]].append(_ndarray_cls(cast(ndarraies[i],
+                                                                  NDArrayHandle),
+                                                             writable=False))
                 reqs = [req_enum[reqs[i]] for i in range(len(tensors[1]))]
                 with ctx:
                     op.forward(is_train=is_train, req=reqs,
@@ -792,13 +917,13 @@ def backward_entry(num_ndarray, ndarraies, tags, reqs, is_train, _):
                 tensors = [[] for i in range(5)]
                 for i in range(num_ndarray):
                     if tags[i] == 2 or tags[i] == 4:
-                        tensors[tags[i]].append(NDArray(cast(ndarraies[i],
-                                                             NDArrayHandle),
-                                                        writable=True))
+                        tensors[tags[i]].append(_ndarray_cls(cast(ndarraies[i],
+                                                                  NDArrayHandle),
+                                                             writable=True))
                     else:
-                        tensors[tags[i]].append(NDArray(cast(ndarraies[i],
-                                                             NDArrayHandle),
-                                                        writable=False))
+                        tensors[tags[i]].append(_ndarray_cls(cast(ndarraies[i],
+                                                                  NDArrayHandle),
+                                                             writable=False))
                 reqs = [req_enum[reqs[i]] for i in range(len(tensors[2]))]
                 with ctx:
                     op.backward(req=reqs,
@@ -856,7 +981,9 @@ def delete_entry(_):
                      infershape_functype(infer_shape_entry),
                      deps_functype(declare_backward_dependency_entry),
                      createop_functype(create_operator_entry),
-                     infertype_functype(infer_type_entry)]
+                     infertype_functype(infer_type_entry),
+                     inferstorage_functype(infer_storage_type_entry),
+                     inferstorage_backward_functype(infer_storage_type_backward_entry)]
         callbacks = [cast(i, CFUNCTYPE(c_int)) for i in callbacks]
         contexts = [None]*len(callbacks)
         ret[0] = MXCallbackList(c_int(len(callbacks)),
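Why the forward/backward entries above switch from the NDArray constructor to _ndarray_cls: the helper inspects the handle's storage type and returns the matching subclass, so sparse-only attributes are usable inside forward/backward. A quick illustrative check (array contents are arbitrary):

import mxnet as mx

x = mx.nd.array([[0, 1], [2, 0]]).tostype('csr')
assert isinstance(x, mx.nd.sparse.CSRNDArray)
# CSR-specific views that a plain NDArray would not expose:
print(x.data, x.indices, x.indptr)
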
diff --git a/src/operator/custom/custom.cc b/src/operator/custom/custom.cc
index 683423f96920..5e35e908447a 100644
--- a/src/operator/custom/custom.cc
+++ b/src/operator/custom/custom.cc
@@ -350,20 +350,100 @@ void Backward(const OpStatePtr& state,
   Imperative::Get()->set_is_recording(prev_recording);
 }
 
+inline bool BackwardInferStorageType(const nnvm::NodeAttrs& attrs,
+                                     const int dev_mask,
+                                     DispatchMode* dispatch_mode,
+                                     std::vector<int>* iattr,
+                                     std::vector<int>* oattr) {
+  const CustomParam& params = nnvm::get<CustomParam>(attrs.parsed);
+
+  if (params.info->num_callbacks <= kCustomOpPropBackwardInferStorageType) {
+    for (size_t i = 0; i < iattr->size(); i++) {
+      STORAGE_TYPE_ASSIGN_CHECK(*iattr, i, kDefaultStorage);
+    }
+    for (size_t i = 0; i < oattr->size(); i++) {
+      STORAGE_TYPE_ASSIGN_CHECK(*oattr, i, kDefaultStorage);
+    }
+    DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
+    return true;
+  }
+
+  std::vector<int> stypes;
+  stypes.reserve(params.num_outs * 2 + params.num_args * 2 + params.num_auxs);
+  for (size_t i = 0; i < iattr->size(); ++i) {
+    stypes.push_back((*iattr)[i]);
+  }
+  for (size_t i = 0; i < oattr->size(); ++i) {
+    stypes.push_back((*oattr)[i]);
+  }
+
+  CHECK(reinterpret_cast<CustomOpBackwardInferStorageTypeFunc>(
+      params.info->callbacks[kCustomOpPropBackwardInferStorageType])(
+      stypes.size(), stypes.data(),
+      params.info->contexts[kCustomOpPropBackwardInferStorageType]));
+  for (size_t i = 0; i < 2 * params.num_outs + params.num_args; ++i) {
+    STORAGE_TYPE_ASSIGN_CHECK(*iattr, i, stypes[i]);
+  }
+  for (size_t i = 0; i < params.num_args; ++i) {
+    STORAGE_TYPE_ASSIGN_CHECK(
+        *oattr, i, stypes[i + 2 * params.num_outs + params.num_args]);
+  }
+  for (size_t i = 0; i < params.num_auxs; ++i) {
+    STORAGE_TYPE_ASSIGN_CHECK(
+        *iattr, i + 2 * params.num_outs + params.num_args,
+        stypes[i + 2 * params.num_outs + 2 * params.num_args]);
+  }
+
+  DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
+  return true;
+}
+
 // infer storage function for custom op, which assigns kDefaultStorage for
 // all undefined stypes, and dispatch on DispatchMode::kFComputeEx.
-inline bool InferStorageType(const nnvm::NodeAttrs& attrs,
-                             const int dev_mask,
+inline bool InferStorageType(const nnvm::NodeAttrs& attrs, const int dev_mask,
                              DispatchMode* dispatch_mode,
-                             std::vector<int> *iattr,
-                             std::vector<int> *oattr) {
-  for (int& v : *oattr) {
-    if (v == -1) v = kDefaultStorage;
+                             std::vector<int>* iattr, std::vector<int>* oattr) {
+  const CustomParam& params = nnvm::get<CustomParam>(attrs.parsed);
+
+  if (params.info->num_callbacks <= kCustomOpPropInferStorageType) {
+    for (size_t i = 0; i < iattr->size(); i++) {
+      STORAGE_TYPE_ASSIGN_CHECK(*iattr, i, kDefaultStorage);
+    }
+    for (size_t i = 0; i < oattr->size(); i++) {
+      STORAGE_TYPE_ASSIGN_CHECK(*oattr, i, kDefaultStorage);
+    }
+    DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
+    return true;
   }
-  for (int& v : *iattr) {
-    if (v == -1) v = kDefaultStorage;
+
+  std::vector<int> stypes;
+  stypes.reserve(params.num_args + params.num_outs + params.num_auxs);
+  for (size_t i = 0; i < params.num_args; ++i) {
+    stypes.push_back((*iattr)[i]);
+  }
+  for (const auto& i : *oattr) {
+    stypes.push_back(i);
+  }
+  for (size_t i = 0; i < params.num_auxs; ++i) {
+    stypes.push_back((*iattr)[params.num_args + i]);
   }
-  dispatch_mode_assign(dispatch_mode, DispatchMode::kFComputeEx);
+
+  CHECK(reinterpret_cast<CustomOpInferStorageTypeFunc>(
+      params.info->callbacks[kCustomOpPropInferStorageType])(
+      stypes.size(), stypes.data(),
+      params.info->contexts[kCustomOpPropInferStorageType]));
+  for (size_t i = 0; i < params.num_args; ++i) {
+    STORAGE_TYPE_ASSIGN_CHECK(*iattr, i, stypes[i]);
+  }
+  for (size_t i = 0; i < params.num_outs; ++i) {
+    STORAGE_TYPE_ASSIGN_CHECK(*oattr, i, stypes[params.num_args + i]);
+  }
+  for (size_t i = 0; i < params.num_auxs; ++i) {
+    STORAGE_TYPE_ASSIGN_CHECK(*iattr, params.num_args + i,
+                              stypes[params.num_args + params.num_outs + i]);
+  }
+
+  DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
   return true;
 }
 
@@ -429,7 +509,7 @@ NNVM_REGISTER_OP(_backward_Custom)
   })
 .set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", Backward)
 .set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", Backward)
-.set_attr<FInferStorageType>("FInferStorageType", InferStorageType);
+.set_attr<FInferStorageType>("FInferStorageType", BackwardInferStorageType);
 
 }  // namespace custom
 }  // namespace op
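A worked example of the index arithmetic in BackwardInferStorageType above (illustrative, not part of the patch), for num_args = num_outs = num_auxs = 1:

# Buffer handed to the backward callback (iattr entries, then oattr):
#   [out_grad, in_data, out_data, aux, in_grad]          (5 stypes)
# Buffer after the Python entry writes its result back:
#   [out_grad, in_data, out_data, in_grad, aux]
# The three read-back loops then copy:
#   stypes[0:3] -> iattr  (out_grad, in_data, out_data)
#   stypes[3:4] -> oattr  (in_grad), offset 2*num_outs + num_args
#   stypes[4:5] -> iattr aux slot, offset 2*num_outs + 2*num_args
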
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 3484b18d0276..d322fa4c2af0 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -3570,12 +3570,18 @@ def test_rcbrt_op():
 def test_custom_op():
     class Sqr(mx.operator.CustomOp):
         def forward(self, is_train, req, in_data, out_data, aux):
-            self.assign(out_data[0], req[0], in_data[0]*in_data[0])
-            aux[0][:] = 1
+            if in_data[0].stype == 'default':
+                aux[0][:] = 1
+                self.assign(out_data[0], req[0], in_data[0]*in_data[0])
+            else:
+                self.assign(out_data[0], req[0], mx.nd.sparse.square(in_data[0]))
+                if in_data[0].stype == 'csr':
+                    assert(isinstance(in_data[0], mx.nd.sparse.CSRNDArray))
 
         def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
             self.assign(in_grad[0], req[0], 2*in_data[0]*out_grad[0])
-            assert (aux[0].asnumpy() == 1).all()
+            if in_data[0].stype == 'default':
+                assert (aux[0].asnumpy() == 1).all()
 
     @mx.operator.register("sqr")
     class SqrProp(mx.operator.CustomOpProp):
@@ -3597,6 +3603,16 @@ def infer_shape(self, in_shape):
         def infer_type(self, in_type):
             return in_type, [in_type[0]], [in_type[0]]
 
+        def infer_storage_type(self, in_stype):
+            if in_stype[0] == 'default':
+                return ['default'], ['default'], ['default']
+            return ['csr'], ['csr'], ['csr']
+
+        def infer_storage_type_backward(self, in_stype):
+            if in_stype[1] == 'default':
+                return ['default', 'default', 'default'], ['default'], ['default']
+            return ['default', 'csr', 'csr'], ['csr'], ['csr']
+
         def create_operator(self, ctx, shapes, dtypes):
             return Sqr()
 
@@ -3609,15 +3625,18 @@ def create_operator(self, ctx, shapes, dtypes):
 
     data = mx.symbol.cast(data, dtype='float64')
     op = mx.symbol.cast(op, dtype='float32')
-    x = mx.nd.array(np.random.uniform(-1, 1, size=(4, 10)))
-    aux = mx.nd.zeros_like(x)
     check_numeric_gradient(op, [x], [aux])
 
+    x = x.tostype('csr')
+    aux = mx.nd.zeros_like(x)
     x.attach_grad()
     with mx.contrib.autograd.train_section():
         y = mx.nd.Custom(x, aux, op_type='sqr')
     y.backward()
-
+    mx.nd.waitall()
+    assert (x.grad.stype == 'csr')
+    assert (y.stype == 'csr')
+    assert (aux.stype == 'csr')
 
 def test_psroipooling():
     for num_rois in [1, 2]:
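As a usage sketch (not part of the patch): once SqrProp above is registered, the sparse path can be exercised imperatively, mirroring what the test asserts. Shapes and values here are arbitrary.

import mxnet as mx
import numpy as np

x = mx.nd.array(np.random.uniform(-1, 1, size=(4, 10))).tostype('csr')
aux = mx.nd.zeros_like(x)
y = mx.nd.Custom(x, aux, op_type='sqr')
# infer_storage_type picked the csr branch, so the output stays sparse.
assert y.stype == 'csr'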