Skip to content

Commit

Permalink
Revert "Support sparse for custom python operators (apache#8620)" (ap…
Browse files Browse the repository at this point in the history
…ache#8733)

This reverts commit 938eda9.
  • Loading branch information
piiswrong authored and zhreshold committed Dec 14, 2017
1 parent ae97b45 commit 5d59dab
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 266 deletions.
8 changes: 1 addition & 7 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,7 @@ enum CustomOpPropCallbacks {
kCustomOpPropInferShape,
kCustomOpPropDeclareBackwardDependency,
kCustomOpPropCreateOperator,
kCustomOpPropInferType,
kCustomOpPropInferStorageType,
kCustomOpPropBackwardInferStorageType
kCustomOpPropInferType
};


Expand All @@ -161,10 +159,6 @@ typedef int (*CustomOpListFunc)(char*** /*args*/, void* /*state*/);
typedef int (*CustomOpInferShapeFunc)(int /*num_input*/, int* /*ndims*/,
unsigned** /*shapes*/, void* /*state*/);
typedef int (*CustomOpInferTypeFunc)(int /*num_input*/, int* /*types*/, void* /*state*/);
typedef int (*CustomOpInferStorageTypeFunc)(int /*num_input*/, int* /*stypes*/, void* /*state*/);
typedef int (*CustomOpBackwardInferStorageTypeFunc)(int /*num_input*/,
int * /*stypes*/,
void * /*state*/);
typedef int (*CustomOpBwdDepFunc)(const int* /*out_grad*/, const int* /*in_data*/,
const int* /*out_data*/, int* /*num_deps*/,
int** /*rdeps*/, void* /*state*/);
Expand Down
161 changes: 17 additions & 144 deletions python/mxnet/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.

# coding: utf-8
# pylint: disable=invalid-name, protected-access, too-many-arguments, no-self-use, too-many-locals, broad-except, too-many-lines
# pylint: disable=invalid-name, protected-access, too-many-arguments, no-self-use, too-many-locals, broad-except
"""numpy interface for operators."""
from __future__ import absolute_import

Expand All @@ -31,9 +31,6 @@
from .base import c_str, mx_uint, mx_float, ctypes2numpy_shared, NDArrayHandle, py_str
from . import symbol, context
from .ndarray import NDArray, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP
from .ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID, _STORAGE_TYPE_ID_TO_STR
from .ndarray import _ndarray_cls


c_int_p = POINTER(c_int)

Expand Down Expand Up @@ -521,51 +518,6 @@ def infer_type(self, in_type):
return in_type, [in_type[0]]*len(self.list_outputs()), \
[in_type[0]]*len(self.list_auxiliary_states())

def infer_storage_type(self, in_stype):
"""infer_storage_type interface. Used to infer storage type of
inputs and outputs in the forward pass.
Parameters
----------
in_stype : list of stypes, Valid stypes are default, row_sparse and
csr
Returns
-------
in_stype : list
list of argument stypes.
out_stype : list
list of output types calculated from in_stype,
in the same order as declared in list_outputs.
aux_type : Optional, list
list of aux types calculated from in_stype,
in the same order as declared in list_auxiliary_states.
"""
return in_stype, [in_stype[0]]*len(self.list_outputs()), \
[in_stype[0]]*len(self.list_auxiliary_states())

def infer_storage_type_backward(self, in_stype):
"""infer_storage_type_backward interface. Used to infer storage
type of inputs and outputs in the backward pass.
Parameters
----------
in_stype : list of stypes. Provide the in_stypes in the
following order: output_grads, in_data, out_data, aux_data(optional)
Returns
-------
in_stype : list
list of input stypes.
out_stype : list
list of output stypes calculated from in_stype.
aux_stype : list
list of aux stypes calculated from in_stype,
in the same order as declared in list_auxiliary_states.
"""
return in_stype, [in_stype[0]]*len(self.list_outputs()), \
[in_stype[0]]*len(self.list_auxiliary_states())

def list_outputs(self):
"""list_outputs interface. Can override when creating new operators.
Expand Down Expand Up @@ -654,8 +606,6 @@ def do_register(prop_cls):
infershape_functype = CFUNCTYPE(c_int, c_int, POINTER(c_int),
POINTER(POINTER(mx_uint)), c_void_p)
infertype_functype = CFUNCTYPE(c_int, c_int, POINTER(c_int), c_void_p)
inferstorage_functype = CFUNCTYPE(c_int, c_int, POINTER(c_int), c_void_p)
inferstorage_backward_functype = CFUNCTYPE(c_int, c_int, POINTER(c_int), c_void_p)
list_functype = CFUNCTYPE(c_int, POINTER(POINTER(POINTER(c_char))), c_void_p)
deps_functype = CFUNCTYPE(c_int, c_int_p, c_int_p, c_int_p,
c_int_p, POINTER(c_int_p), c_void_p)
Expand Down Expand Up @@ -711,81 +661,6 @@ def infer_shape_entry(num_tensor, tensor_dims,
return False
return True

def infer_storage_type_backward_entry(num_tensor, tensor_stypes, _):
"""C Callback for CustomOpProp::InferStorageTypeBackward"""
try:
n_in = len(op_prop.list_arguments())
n_out = len(op_prop.list_outputs())
n_aux = len(op_prop.list_auxiliary_states())
total_inputs = n_in + 2 * n_out
total_aux = n_aux
total_outputs = n_in
assert num_tensor == (2 * n_in + 2 * n_out + n_aux)

stypes = [_STORAGE_TYPE_ID_TO_STR[tensor_stypes[i]] \
for i in range(total_inputs + total_aux)]
ret = op_prop.infer_storage_type_backward(stypes)
if len(ret) == 2:
istype, ostype = ret
astype = []
elif len(ret) == 3:
istype, ostype, astype = ret
else:
raise AssertionError("infer_storage_type backward must return 2 or 3 lists")
assert len(ostype) == total_outputs, \
"InferStorageTypeBackward Error: expecting %d entries in returned output " \
"stypes, got %d."%(total_outputs, len(ostype))
assert len(istype) == (total_inputs), \
"InferStorageTypeBackward Error: expecting %d entries in returned output " \
"stypes, got %d."%(total_inputs, len(istype))
rtype = list(istype) + list(ostype) + list(astype)
for i, dtype in enumerate(rtype):
tensor_stypes[i] = _STORAGE_TYPE_STR_TO_ID[dtype]
infer_storage_type_backward_entry._ref_holder = [tensor_stypes]
except Exception:
print('Error in %s.infer_type: %s' % (reg_name, traceback.format_exc()))
return False
return True


def infer_storage_type_entry(num_tensor, tensor_stypes, _):
"""C Callback for CustomOpProp::InferStorageType"""
try:
n_in = len(op_prop.list_arguments())
n_out = len(op_prop.list_outputs())
n_aux = len(op_prop.list_auxiliary_states())
assert num_tensor == n_in + n_out + n_aux

stypes = [_STORAGE_TYPE_ID_TO_STR[tensor_stypes[i]] for i in range(n_in)]
ret = op_prop.infer_storage_type(stypes)
if len(ret) == 2:
istype, ostype = ret
astype = []
elif len(ret) == 3:
istype, ostype, astype = ret
else:
raise AssertionError("infer_storage_type must return 2 or 3 lists")

assert len(ostype) == n_out, \
"InferStorageType Error: expecting %d entries in returned output " \
"stypes, got %d."%(n_out, len(ostype))
assert len(istype) == n_in, \
"InferStorageType Error: expecting %d entries in returned input " \
"stypes, got %d."%(n_in, len(istype))
assert len(astype) == n_aux, \
"InferStorageType Error: expecting %d entries in returned aux state " \
"stypes, got %d."%(n_aux, len(astype))
rtype = list(istype) + list(ostype) + list(astype)
for i, dtype in enumerate(rtype):
tensor_stypes[i] = _STORAGE_TYPE_STR_TO_ID[dtype]

infer_storage_type_entry._ref_holder = [tensor_stypes]
except Exception:
print('Error in %s.infer_type: %s' % (reg_name, traceback.format_exc()))
return False
return True


def infer_type_entry(num_tensor, tensor_types, _):
"""C Callback for CustomOpProp::InferType"""
try:
Expand All @@ -805,13 +680,13 @@ def infer_type_entry(num_tensor, tensor_types, _):
raise AssertionError("infer_type must return 2 or 3 lists")
assert len(otype) == n_out, \
"InferType Error: expecting %d entries in returned output " \
"types, got %d."%(n_out, len(otype))
"shapes, got %d."%(n_out, len(otype))
assert len(itype) == n_in, \
"InferType Error: expecting %d entries in returned input " \
"types, got %d."%(n_in, len(itype))
"shapes, got %d."%(n_in, len(itype))
assert len(atype) == n_aux, \
"InferType Error: expecting %d entries in returned aux state " \
"types, got %d."%(n_aux, len(atype))
"shapes, got %d."%(n_aux, len(atype))
rtype = list(itype) + list(otype) + list(atype)
for i, dtype in enumerate(rtype):
tensor_types[i] = _DTYPE_NP_TO_MX[dtype]
Expand Down Expand Up @@ -900,13 +775,13 @@ def forward_entry(num_ndarray, ndarraies, tags, reqs, is_train, _):
tensors = [[] for i in range(5)]
for i in range(num_ndarray):
if tags[i] == 1 or tags[i] == 4:
tensors[tags[i]].append(_ndarray_cls(cast(ndarraies[i],
NDArrayHandle),
writable=True))
tensors[tags[i]].append(NDArray(cast(ndarraies[i],
NDArrayHandle),
writable=True))
else:
tensors[tags[i]].append(_ndarray_cls(cast(ndarraies[i],
NDArrayHandle),
writable=False))
tensors[tags[i]].append(NDArray(cast(ndarraies[i],
NDArrayHandle),
writable=False))
reqs = [req_enum[reqs[i]] for i in range(len(tensors[1]))]
with ctx:
op.forward(is_train=is_train, req=reqs,
Expand All @@ -924,13 +799,13 @@ def backward_entry(num_ndarray, ndarraies, tags, reqs, is_train, _):
tensors = [[] for i in range(5)]
for i in range(num_ndarray):
if tags[i] == 2 or tags[i] == 4:
tensors[tags[i]].append(_ndarray_cls(cast(ndarraies[i],
NDArrayHandle),
writable=True))
tensors[tags[i]].append(NDArray(cast(ndarraies[i],
NDArrayHandle),
writable=True))
else:
tensors[tags[i]].append(_ndarray_cls(cast(ndarraies[i],
NDArrayHandle),
writable=False))
tensors[tags[i]].append(NDArray(cast(ndarraies[i],
NDArrayHandle),
writable=False))
reqs = [req_enum[reqs[i]] for i in range(len(tensors[2]))]
with ctx:
op.backward(req=reqs,
Expand Down Expand Up @@ -988,9 +863,7 @@ def delete_entry(_):
infershape_functype(infer_shape_entry),
deps_functype(declare_backward_dependency_entry),
createop_functype(create_operator_entry),
infertype_functype(infer_type_entry),
inferstorage_functype(infer_storage_type_entry),
inferstorage_backward_functype(infer_storage_type_backward_entry)]
infertype_functype(infer_type_entry)]
callbacks = [cast(i, CFUNCTYPE(c_int)) for i in callbacks]
contexts = [None]*len(callbacks)
ret[0] = MXCallbackList(c_int(len(callbacks)),
Expand Down
100 changes: 10 additions & 90 deletions src/operator/custom/custom.cc
Original file line number Diff line number Diff line change
Expand Up @@ -351,100 +351,20 @@ void Backward(const OpStatePtr& state,
Imperative::Get()->set_is_recording(prev_recording);
}

inline bool BackwardInferStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int>* iattr,
std::vector<int>* oattr) {
const CustomParam& params = nnvm::get<CustomParam>(attrs.parsed);

if (params.info->num_callbacks <= kCustomOpPropBackwardInferStorageType) {
for (size_t i = 0; i < iattr->size(); i++) {
STORAGE_TYPE_ASSIGN_CHECK(*iattr, i, kDefaultStorage);
}
for (size_t i = 0; i < oattr->size(); i++) {
STORAGE_TYPE_ASSIGN_CHECK(*oattr, i, kDefaultStorage);
}
DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
return true;
}

std::vector<int> stypes;
stypes.reserve(params.num_outs * 2 + params.num_args * 2 + params.num_auxs);
for (size_t i = 0; i < iattr->size(); ++i) {
stypes.push_back((*iattr)[i]);
}
for (size_t i = 0; i < oattr->size(); ++i) {
stypes.push_back((*oattr)[i]);
}

CHECK(reinterpret_cast<CustomOpBackwardInferStorageTypeFunc>(
params.info->callbacks[kCustomOpPropBackwardInferStorageType])(
stypes.size(), stypes.data(),
params.info->contexts[kCustomOpPropBackwardInferStorageType]));
for (size_t i = 0; i < 2 * params.num_outs + params.num_args; ++i) {
STORAGE_TYPE_ASSIGN_CHECK(*iattr, i, stypes[i]);
}
for (size_t i = 0; i < params.num_args; ++i) {
STORAGE_TYPE_ASSIGN_CHECK(
*oattr, i, stypes[i + 2 * params.num_outs + params.num_args]);
}
for (size_t i = 0; i < params.num_auxs; ++i) {
STORAGE_TYPE_ASSIGN_CHECK(
*iattr, i + 2 * params.num_outs + params.num_args,
stypes[i + 2 * params.num_outs + 2 * params.num_args]);
}

DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
return true;
}

// infer storage function for custom op, which assigns kDefaultStorage for
// all undefined stypes, and dispatch on DispatchMode::kFComputeEx.
inline bool InferStorageType(const nnvm::NodeAttrs& attrs, const int dev_mask,
inline bool InferStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int>* iattr, std::vector<int>* oattr) {
const CustomParam& params = nnvm::get<CustomParam>(attrs.parsed);

if (params.info->num_callbacks <= kCustomOpPropInferStorageType) {
for (size_t i = 0; i < iattr->size(); i++) {
STORAGE_TYPE_ASSIGN_CHECK(*iattr, i, kDefaultStorage);
}
for (size_t i = 0; i < oattr->size(); i++) {
STORAGE_TYPE_ASSIGN_CHECK(*oattr, i, kDefaultStorage);
}
DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
return true;
std::vector<int> *iattr,
std::vector<int> *oattr) {
for (int& v : *oattr) {
if (v == -1) v = kDefaultStorage;
}

std::vector<int> stypes;
stypes.reserve(params.num_args + params.num_outs + params.num_auxs);
for (size_t i = 0; i < params.num_args; ++i) {
stypes.push_back((*iattr)[i]);
}
for (const auto& i : *oattr) {
stypes.push_back(i);
}
for (size_t i = 0; i < params.num_auxs; ++i) {
stypes.push_back((*iattr)[params.num_args + i]);
for (int& v : *iattr) {
if (v == -1) v = kDefaultStorage;
}

CHECK(reinterpret_cast<CustomOpInferStorageTypeFunc>(
params.info->callbacks[kCustomOpPropInferStorageType])(
stypes.size(), stypes.data(),
params.info->contexts[kCustomOpPropInferStorageType]));
for (size_t i = 0; i < params.num_args; ++i) {
STORAGE_TYPE_ASSIGN_CHECK(*iattr, i, stypes[i]);
}
for (size_t i = 0; i < params.num_outs; ++i) {
STORAGE_TYPE_ASSIGN_CHECK(*oattr, i, stypes[params.num_args + i]);
}
for (size_t i = 0; i < params.num_auxs; ++i) {
STORAGE_TYPE_ASSIGN_CHECK(*iattr, params.num_args + i,
stypes[params.num_args + params.num_outs + i]);
}

DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
dispatch_mode_assign(dispatch_mode, DispatchMode::kFComputeEx);
return true;
}

Expand Down Expand Up @@ -510,7 +430,7 @@ NNVM_REGISTER_OP(_backward_Custom)
})
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", Backward)
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", Backward)
.set_attr<FInferStorageType>("FInferStorageType", BackwardInferStorageType);
.set_attr<FInferStorageType>("FInferStorageType", InferStorageType);

} // namespace custom
} // namespace op
Expand Down
Loading

0 comments on commit 5d59dab

Please sign in to comment.