diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 90270f776456..a0e842c21765 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -585,24 +585,20 @@ MXNET_DLL int MXAutogradBackward(mx_uint num_output,
 /*!
  * \brief create cached operator
  */
-MXNET_DLL int MXCachedCreateOp(AtomicSymbolCreator creator,
-                               int num_inputs,
-                               int num_params,
-                               const char **param_keys,
-                               const char **param_vals,
+MXNET_DLL int MXCreateCachedOp(SymbolHandle handle,
                                CachedOpHandle *out);
 /*!
  * \brief free cached operator
  */
-MXNET_DLL int MXCachedFree(CachedOpHandle handle);
+MXNET_DLL int MXFreeCachedOp(CachedOpHandle handle);
 /*!
  * \brief invoke cached operator
  */
-MXNET_DLL int MXCachedInvoke(CachedOpHandle handle,
-                             int num_inputs,
-                             NDArrayHandle *inputs,
-                             int *num_outputs,
-                             NDArrayHandle **outputs);
+MXNET_DLL int MXInvokeCachedOp(CachedOpHandle handle,
+                               int num_inputs,
+                               NDArrayHandle *inputs,
+                               int *num_outputs,
+                               NDArrayHandle **outputs);
 //--------------------------------------------
 // Part 3: symbolic configuration generation
 //--------------------------------------------
@@ -670,19 +666,6 @@ MXNET_DLL int MXSymbolCreateAtomicSymbol(AtomicSymbolCreator creator,
                                          const char **keys,
                                          const char **vals,
                                          SymbolHandle *out);
-/*!
- * \brief Create an AtomicSymbol from cached op.
- * \param handle cached node attribute.
- * \param name name of new symbol.
- * \param num_args the number of symbol arguments
- * \param args symbol arguments
- * \return 0 when success, -1 when failure happens
- */
-MXNET_DLL int MXCachedCreateSymbol(CachedOpHandle handle,
-                                   const char* name,
-                                   mx_uint num_args,
-                                   SymbolHandle* args,
-                                   SymbolHandle* out);
 /*!
  * \brief Create a Variable Symbol.
  * \param name name of the variable
diff --git a/perl-package/AI-MXNet/lib/AI/MXNet.pm b/perl-package/AI-MXNet/lib/AI/MXNet.pm
index 530b6eca23a4..41bb1a18b493 100644
--- a/perl-package/AI-MXNet/lib/AI/MXNet.pm
+++ b/perl-package/AI-MXNet/lib/AI/MXNet.pm
@@ -3,7 +3,6 @@ use v5.14.0;
 use strict;
 use warnings;
 use AI::MXNet::Base;
-use AI::MXNet::CachedOp;
 use AI::MXNet::Callback;
 use AI::MXNet::NDArray;
 use AI::MXNet::Symbol;
diff --git a/perl-package/AI-MXNet/lib/AI/MXNet/CachedOp.pm b/perl-package/AI-MXNet/lib/AI/MXNet/CachedOp.pm
deleted file mode 100644
index bec3f5029c33..000000000000
--- a/perl-package/AI-MXNet/lib/AI/MXNet/CachedOp.pm
+++ /dev/null
@@ -1,41 +0,0 @@
-package AI::MXNet::CachedOp;
-
-=head1 NAME
-
-    AI::MXNet::CachedOp - A wrapper around CachedOpHandle
-=cut
-
-use strict;
-use warnings;
-use AI::MXNet::Base;
-use Mouse;
-
-has 'op'     => (is => 'ro', isa => 'Str', required => 1);
-has 'handle' => (is => 'ro', isa => 'CachedOpHandle', required => 1);
-around BUILDARGS => sub {
-    my $orig  = shift;
-    my $class = shift;
-    my ($op, $num_input, %kwargs) = @_;
-    for my $key (keys %kwargs)
-    {
-        $kwargs{ $key } = "(" .join(", ", @{ $kwargs{ $key } }) .")"
-            if ref $kwargs{ $key } eq 'ARRAY';
-    }
-    my $AtomicSymbolCreator = check_call(AI::NNVMCAPI::GetOpHandle($op));
-    my $handle = check_call(
-        AI::MXNetCAPI::CachedCreateOp(
-            $AtomicSymbolCreator,
-            $num_input,
-            scalar(keys %kwargs),
-            \%kwargs
-        )
-    );
-    return $class->$orig(op => $op, handle => $handle);
-};
-
-sub DEMOLISH
-{
-    check_call(AI::MXNetCAPI::CachedFree(shift->handle));
-}
-
-1;
\ No newline at end of file
diff --git a/perl-package/AI-MXNet/lib/AI/MXNet/NDArray.pm b/perl-package/AI-MXNet/lib/AI/MXNet/NDArray.pm
index 68c4e7061ec3..53579b2f1caf 100644
--- a/perl-package/AI-MXNet/lib/AI/MXNet/NDArray.pm
+++ b/perl-package/AI-MXNet/lib/AI/MXNet/NDArray.pm
@@ -1372,7 +1372,6 @@ method backward(Maybe[AI::MXNet::NDArray] $out_grad=, Bool $retain_graph=0)
     )
 }

-method CachedOp(@args) { AI::MXNet::CachedOp->new(@args) }

 my $lvalue_methods = join "\n", map {"use attributes 'AI::MXNet::NDArray', \\&AI::MXNet::NDArray::$_, 'lvalue';"}
 qw/at slice aspdl asmpdl reshape copy sever T astype as_in_context copyto empty zero ones full
diff --git a/perl-package/AI-MXNet/lib/AI/MXNet/NDArray/Base.pm b/perl-package/AI-MXNet/lib/AI/MXNet/NDArray/Base.pm
index 0c48336c2aae..7fb6d0e61110 100644
--- a/perl-package/AI-MXNet/lib/AI/MXNet/NDArray/Base.pm
+++ b/perl-package/AI-MXNet/lib/AI/MXNet/NDArray/Base.pm
@@ -140,44 +140,6 @@ method _init_ndarray_module()
     }
 }

-method invoke(
-    AI::MXNet::CachedOp $cached_op,
-    ArrayRef[AI::MXNet::NDArray] $args,
-    Maybe[AI::MXNet::NDArray|ArrayRef[AI::MXNet::NDArray]] $out=,
-    Maybe[Str] $name=
-)
-{
-    my $original_output;
-    if(defined $out)
-    {
-        $original_output = $out;
-        if(not ref($out) eq 'ARRAY')
-        {
-            $out = [$out];
-        }
-    }
-    else
-    {
-        $out = [];
-    }
-    my $output = check_call(
-        AI::MXNetCAPI::CachedInvoke(
-            $cached_op->handle,
-            scalar(@$args),
-            [map { $_->handle } @$args],
-            [map { $_->handle } @$out]
-        )
-    );
-    return $original_output if defined $original_output;
-    if(@$output == 1)
-    {
-        return $self->new(handle => $output->[0]);
-    }
-    else
-    {
-        return [map { $self->new(handle => $_) } @$output];
-    }
-}

 __PACKAGE__->_init_ndarray_module;
diff --git a/perl-package/AI-MXNet/lib/AI/MXNet/Symbol.pm b/perl-package/AI-MXNet/lib/AI/MXNet/Symbol.pm
index eec32640953c..e22e4189721a 100644
--- a/perl-package/AI-MXNet/lib/AI/MXNet/Symbol.pm
+++ b/perl-package/AI-MXNet/lib/AI/MXNet/Symbol.pm
@@ -1347,7 +1347,6 @@ method arange(Index :$start=0, Index :$stop=, Num :$step=1.0, Index :$repeat=1,
     });
 }

-method CachedOp(@args) { AI::MXNet::CachedOp->new(@args) }

 sub _parse_arguments
 {
diff --git a/perl-package/AI-MXNet/lib/AI/MXNet/Symbol/Base.pm b/perl-package/AI-MXNet/lib/AI/MXNet/Symbol/Base.pm
index 3eaee237bed0..69ff952eca1a 100644
--- a/perl-package/AI-MXNet/lib/AI/MXNet/Symbol/Base.pm
+++ b/perl-package/AI-MXNet/lib/AI/MXNet/Symbol/Base.pm
@@ -167,20 +167,6 @@ method _init_symbol_module()
     }
 }

-method invoke(AI::MXNet::CachedOp $cached_op, ArrayRef[AI::MXNet::Symbol] $args, Maybe[Str] $name=)
-{
-    my $hint = lc($cached_op->op);
-    $name = AI::MXNet::Symbol::NameManager->current->get($name, $hint);
-    my $handle = check_call(
-        AI::MXNetCAPI::CachedCreateSymbol(
-            $cached_op->handle,
-            $name,
-            scalar(@$args),
-            [map { $_->handle } @$args]
-        )
-    );
-    return $self->new(handle => $handle);
-}

 __PACKAGE__->_init_symbol_module;
diff --git a/perl-package/AI-MXNet/t/test_ndarray.t b/perl-package/AI-MXNet/t/test_ndarray.t
index 53e5749d00f6..f4ecebcc8d18 100644
--- a/perl-package/AI-MXNet/t/test_ndarray.t
+++ b/perl-package/AI-MXNet/t/test_ndarray.t
@@ -36,17 +36,6 @@ sub test_moveaxis
     is_deeply($X->moveaxis(2, 0)->shape, [3, 2, 2]);
 }

-sub test_cached
-{
-    my $op = mx->nd->CachedOp('Convolution', 3, kernel=>[3, 3], num_filter=>10);
-    my $data = mx->nd->ones([3, 4, 10, 10]);
-    my $weight = mx->nd->ones([10, 4, 3, 3]);
-    my $bias = mx->nd->ones([10]);
-    my $o1 = mx->nd->invoke($op, [$data, $weight, $bias]);
-    $bias .= 2;
-    my $o2 = mx->nd->invoke($op, [$data, $weight, $bias]);
-    ok(almost_equal($o2->aspdl, $o1->aspdl + 1));
-}

 sub test_output
 {
@@ -64,5 +53,4 @@ sub test_output
 test_ndarray_reshape();
 test_moveaxis();
-test_cached();
-test_output();
\ No newline at end of file
+test_output();
diff --git a/perl-package/AI-MXNet/t/test_symbol.t b/perl-package/AI-MXNet/t/test_symbol.t
index bf9e90598929..7b42e1b6cd5e 100644
--- a/perl-package/AI-MXNet/t/test_symbol.t
+++ b/perl-package/AI-MXNet/t/test_symbol.t
@@ -221,24 +221,6 @@ sub test_load_000800

 test_load_000800();

-sub test_cached
-{
-    my $op = mx->sym->CachedOp('Convolution', 3, kernel=>[3, 3], num_filter=>10);
-    my $data = mx->sym->var('data');
-    my $weight = mx->sym->var('weight');
-    my $bias = mx->sym->var('bias');
-    my $out = mx->sym->invoke($op, [$data, $weight, $bias], 'conv');
-    is_deeply($out->list_arguments, ['data', 'weight', 'bias']);
-    is_deeply($out->list_outputs, ['conv_output']);
-    {
-        local($mx::NameManager) = mx->name->Prefix('test_');
-        is(mx->sym->invoke($op, [$data, $weight, $bias])->name, 'test_convolution0');
-        is(mx->sym->invoke($op, [$data, $weight, $bias])->name, 'test_convolution1');
-    }
-}
-
-test_cached();
-
 __DATA__
 {
     "nodes": [
@@ -427,4 +409,4 @@ __DATA__
     ],
     "arg_nodes": [0, 1, 2, 5, 6, 9, 10, 12, 13, 15],
     "heads": [[16, 0]]
-}
\ No newline at end of file
+}
diff --git a/perl-package/AI-MXNetCAPI/mxnet.i b/perl-package/AI-MXNetCAPI/mxnet.i
index 295832eb24dc..d0705d5acc72 100644
--- a/perl-package/AI-MXNetCAPI/mxnet.i
+++ b/perl-package/AI-MXNetCAPI/mxnet.i
@@ -119,7 +119,6 @@ static void ExecutorMonitor_callback(const char* name, NDArrayHandle handle, voi
     SWIG_TypeClientData(SWIGTYPE_p_MXKVStore, (void *)"KVStoreHandle");
     SWIG_TypeClientData(SWIGTYPE_p_MXRecordIO, (void *)"RecordIOHandle");
     SWIG_TypeClientData(SWIGTYPE_p_MXRtc, (void *)"RtcHandle");
-    SWIG_TypeClientData(SWIGTYPE_p_MXCachedOp, (void *)"CachedOpHandle");
 %}

 /*! \brief manually define unsigned int */
@@ -151,8 +150,6 @@ typedef MXKVStore *KVStoreHandle;
 typedef MXRecordIO *RecordIOHandle;
 /*! \brief handle to MXRtc*/
 typedef MXRtc *RtcHandle;
-/*! \brief handle to cached operator */
-typedef MXCachedOp *CachedOpHandle;

 typedef void (*ExecutorMonitorCallback)(const char*,
                                         NDArrayHandle,
@@ -628,30 +625,6 @@ int MXAutogradBackward(mx_uint num_output,
                        NDArrayHandle* in,
                        int retain_graph);

-/*!
- * \brief create cached operator
- */
-int MXCachedCreateOp(AtomicSymbolCreator in,
-                     int num_inputs,
-                     int num_params,
-                     const char **keys,
-                     const char **vals,
-                     CachedOpHandle *out);
-
-/*!
- * \brief free cached operator
- */
-int MXCachedFree(CachedOpHandle handle);
-
-/*!
- * \brief invoke cached operator
- */
-int MXCachedInvoke(CachedOpHandle handle,
-                   int num_inputs,
-                   NDArrayHandle *in,
-                   int *out_size,
-                   NDArrayHandle** out_array);
-
 //--------------------------------------------
 // Part 3: symbolic configuration generation
 //--------------------------------------------
@@ -719,20 +692,6 @@ int MXSymbolCreateAtomicSymbol(AtomicSymbolCreator in,
                                const char **keys,
                                const char **vals,
                                SymbolHandle *out);
-/*!
- * \brief Create an AtomicSymbol from cached op.
- * \param handle cached node attribute.
- * \param name name of new symbol.
- * \param num_args the number of symbol arguments
- * \param args symbol arguments
- * \return 0 when success, -1 when failure happens
- */
-int MXCachedCreateSymbol(CachedOpHandle handle,
-                         const char* name,
-                         mx_uint num_args,
-                         SymbolHandle* in,
-                         SymbolHandle* out);
-
 /*!
  * \brief Create a Variable Symbol.
  * \param name name of the variable
diff --git a/perl-package/AI-MXNetCAPI/mxnet_typemaps.i b/perl-package/AI-MXNetCAPI/mxnet_typemaps.i
index 8574647512f5..792f8472d05a 100644
--- a/perl-package/AI-MXNetCAPI/mxnet_typemaps.i
+++ b/perl-package/AI-MXNetCAPI/mxnet_typemaps.i
@@ -311,13 +311,12 @@
               (DataIterHandle *out) (ExecutorHandle temp),
               (KVStoreHandle *out) (KVStoreHandle temp),
               (RecordIOHandle *out) (RecordIOHandle temp),
-              (RtcHandle *out) (RtcHandle temp),
-              (CachedOpHandle *out) (CachedOpHandle temp)
+              (RtcHandle *out) (RtcHandle temp)
 {
     $1 = &temp;
 }
 %typemap(argout) (NDArrayHandle *out), (FunctionHandle* out),
                  (SymbolHandle *out), (ExecutorHandle *out), (DataIterHandle *out),
-                 (KVStoreHandle *out), (RecordIOHandle *out), (RtcHandle *out) (RtcHandle temp), (CachedOpHandle *out) (CachedOpHandle temp)
+                 (KVStoreHandle *out), (RecordIOHandle *out), (RtcHandle *out) (RtcHandle temp)
 {
     if(!result)
     {
diff --git a/python/mxnet/_ctypes/common.py b/python/mxnet/_ctypes/common.py
deleted file mode 100644
index 24e2048eee4c..000000000000
--- a/python/mxnet/_ctypes/common.py
+++ /dev/null
@@ -1,30 +0,0 @@
-# coding: utf-8
-"""Common code between symbolic and ndarray."""
-from __future__ import absolute_import as _abs
-
-import ctypes
-
-from ..base import _LIB
-from ..base import c_array, c_str
-from ..base import OpHandle, CachedOpHandle
-from ..base import check_call
-
-
-class CachedOp(object):
-    """Cached operator handle."""
-    __slots__ = ["handle", "op"]
-    def __init__(self, op, num_input, **kwargs):
-        self.op = op
-        op_handle = OpHandle()
-        check_call(_LIB.NNGetOpHandle(c_str(op), ctypes.byref(op_handle)))
-        self.handle = CachedOpHandle()
-        check_call(_LIB.MXCachedCreateOp(
-            op_handle,
-            ctypes.c_int(num_input),
-            ctypes.c_int(len(kwargs)),
-            c_array(ctypes.c_char_p, [c_str(key) for key in kwargs]),
-            c_array(ctypes.c_char_p, [c_str(str(val)) for val in kwargs.values()]),
-            ctypes.byref(self.handle)))
-
-    def __del__(self):
-        check_call(_LIB.MXCachedFree(self.handle))
diff --git a/python/mxnet/_ctypes/ndarray.py b/python/mxnet/_ctypes/ndarray.py
index a678e1726f02..396c57a41dfb 100644
--- a/python/mxnet/_ctypes/ndarray.py
+++ b/python/mxnet/_ctypes/ndarray.py
@@ -10,10 +10,9 @@
 from ..base import _LIB
 from ..base import c_array, py_str, c_str, mx_uint, _Null
-from ..base import NDArrayHandle, OpHandle
+from ..base import NDArrayHandle, OpHandle, CachedOpHandle
 from ..base import check_call
 from ..ndarray_doc import _build_doc
-from .common import CachedOp


 class NDArrayBase(object):
@@ -81,31 +80,48 @@ def _imperative_invoke(handle, ndargs, keys, vals, out):
                 for i in range(num_output.value)]


-def invoke(cached_op, args, out=None, name=None):  # pylint: disable=unused-argument
-    """ctypes implementation of imperative invoke wrapper"""
-    if out is not None:
-        original_output = out
-        if isinstance(out, NDArrayBase):
-            out = (out,)
-        num_output = ctypes.c_int(len(out))
-        output_vars = c_array(NDArrayHandle, [i.handle for i in out])
-        output_vars = ctypes.cast(output_vars, ctypes.POINTER(NDArrayHandle))
-    else:
-        original_output = None
-        output_vars = ctypes.POINTER(NDArrayHandle)()
-        num_output = ctypes.c_int(0)
-
-    check_call(_LIB.MXCachedInvoke(
-        cached_op.handle,
-        ctypes.c_int(len(args)),
-        c_array(NDArrayHandle, [arr.handle for arr in args]),
-        ctypes.byref(num_output),
-        ctypes.byref(output_vars)))
+class CachedOp(object):
+    """Cached operator handle."""
+    __slots__ = ["handle"]
+    def __init__(self, sym):
+        self.handle = CachedOpHandle()
+        check_call(_LIB.MXCreateCachedOp(
+            sym.handle,
+            ctypes.byref(self.handle)))

-    if original_output is not None:
-        return original_output
-    if num_output.value == 1:
-        return _ndarray_cls(ctypes.cast(output_vars[0], NDArrayHandle))
-    else:
-        return [_ndarray_cls(ctypes.cast(output_vars[i], NDArrayHandle))
-                for i in range(num_output.value)]
+    def __del__(self):
+        check_call(_LIB.MXFreeCachedOp(self.handle))
+
+    def __call__(self, *args, **kwargs):
+        """ctypes implementation of imperative invoke wrapper"""
+        out = kwargs.pop('out', None)
+        if out is not None:
+            original_output = out
+            if isinstance(out, NDArrayBase):
+                out = (out,)
+            num_output = ctypes.c_int(len(out))
+            output_vars = c_array(NDArrayHandle, [i.handle for i in out])
+            output_vars = ctypes.cast(output_vars, ctypes.POINTER(NDArrayHandle))
+        else:
+            original_output = None
+            output_vars = ctypes.POINTER(NDArrayHandle)()
+            num_output = ctypes.c_int(0)
+        if kwargs:
+            raise TypeError(
+                "CachedOp.__call__ got unexpected keyword argument(s): " + \
+                ', '.join(kwargs.keys()))
+
+        check_call(_LIB.MXInvokeCachedOp(
+            self.handle,
+            ctypes.c_int(len(args)),
+            c_array(NDArrayHandle, [arr.handle for arr in args]),
+            ctypes.byref(num_output),
+            ctypes.byref(output_vars)))
+
+        if original_output is not None:
+            return original_output
+        if num_output.value == 1:
+            return _ndarray_cls(ctypes.cast(output_vars[0], NDArrayHandle))
+        else:
+            return [_ndarray_cls(ctypes.cast(output_vars[i], NDArrayHandle))
+                    for i in range(num_output.value)]
diff --git a/python/mxnet/_ctypes/symbol.py b/python/mxnet/_ctypes/symbol.py
index 9026b20cd7db..5cbff551cf55 100644
--- a/python/mxnet/_ctypes/symbol.py
+++ b/python/mxnet/_ctypes/symbol.py
@@ -8,8 +8,6 @@
 from ..base import c_array, c_str, mx_uint
 from ..base import SymbolHandle
 from ..base import check_call
-from ..name import NameManager
-from .common import CachedOp  # pylint: disable=unused-import

 _symbol_cls = None
@@ -102,20 +100,6 @@ def _set_symbol_class(cls):
     _symbol_cls = cls


-def invoke(cached_op, args, name=None):
-    """Call cached symbolic operator"""
-    ret = SymbolHandle()
-    hint = cached_op.op.lower()
-    name = c_str(NameManager.current.get(name, hint))
-    check_call(_LIB.MXCachedCreateSymbol(
-        cached_op.handle,
-        name,
-        mx_uint(len(args)),
-        c_array(SymbolHandle, [s.handle for s in args]),
-        ctypes.byref(ret)))
-    return _symbol_cls(ret)
-
-
 def _symbol_creator(handle, args, kwargs, keys, vals, name):
     sym_handle = SymbolHandle()
     check_call(_LIB.MXSymbolCreateAtomicSymbol(
diff --git a/python/mxnet/cython/base.pyi b/python/mxnet/cython/base.pyi
index 651258135ef3..d73e1a7d0194 100644
--- a/python/mxnet/cython/base.pyi
+++ b/python/mxnet/cython/base.pyi
@@ -99,75 +99,11 @@ cdef extern from "mxnet/c_api.h":
                           const char **param_keys,
                           const char **param_vals);
     int MXNDArrayFree(NDArrayHandle handle);
-    int MXCachedCreateOp(OpHandle creator,
-                         int num_inputs,
-                         int num_params,
-                         const char **param_keys,
-                         const char **param_vals,
+    int MXCreateCachedOp(SymbolHandle handle,
                          CachedOpHandle *out);
-    int MXCachedFree(CachedOpHandle handle);
-    int MXCachedInvoke(CachedOpHandle handle,
+    int MXFreeCachedOp(CachedOpHandle handle);
+    int MXInvokeCachedOp(CachedOpHandle handle,
                        int num_inputs,
                        NDArrayHandle *inputs,
                        int *num_outputs,
                        NDArrayHandle **outputs);
-    int MXCachedCreateSymbol(CachedOpHandle handle,
-                             const char* name,
-                             unsigned num_args,
-                             SymbolHandle* args,
-                             SymbolHandle* out);
-
-
-cdef class CachedOp:
-    """Cached operator handle."""
-    cdef CachedOpHandle chandle
-    cdef string cop
-
-    cdef _set_handle(self, handle):
-        cdef unsigned long long ptr
-        if handle is None:
-            self.chandle = NULL
-        else:
-            ptr = handle.value
-            self.chandle = <CachedOpHandle>(ptr)
-
-    property handle:
-        def __get__(self):
-            if self.chandle == NULL:
-                return None
-            else:
-                return _ctypes.cast(<unsigned long long>self.chandle, _ctypes.c_void_p)
-        def __set__(self, value):
-            self._set_handle(value)
-
-    property op:
-        def __get__(self):
-            return py_str(self.cop.c_str())
-        def __set__(self, value):
-            self.cop = c_str(value)
-
-    def __init__(self, op, num_input, **kwargs):
-        cdef OpHandle op_handle
-        cdef vector[string] ckeys
-        cdef vector[string] cvals
-
-        self.op = op
-        CALL(NNGetOpHandle(self.cop.c_str(), &op_handle))
-
-        for k, v in kwargs.items():
-            ckeys.push_back(c_str(k))
-            cvals.push_back(c_str(str(v)))
-
-        cdef vector[const char*] param_keys = SVec2Ptr(ckeys)
-        cdef vector[const char*] param_vals = SVec2Ptr(cvals)
-
-        CALL(MXCachedCreateOp(
-            op_handle,
-            num_input,
-            len(kwargs),
-            CBeginPtr(param_keys),
-            CBeginPtr(param_vals),
-            &self.chandle))
-
-    def __del__(self):
-        CALL(MXCachedFree(self.chandle))
diff --git a/python/mxnet/cython/ndarray.pyx b/python/mxnet/cython/ndarray.pyx
index 24e37b54c7be..a861ae661b45 100644
--- a/python/mxnet/cython/ndarray.pyx
+++ b/python/mxnet/cython/ndarray.pyx
@@ -61,46 +61,76 @@ cdef NewArray(NDArrayHandle handle):
     return nd


-def invoke(cached_op, args, out=None, name=None):
-    """ctypes implementation of imperative invoke wrapper"""
-    cdef vector[NDArrayHandle] ndvars
-    cdef vector[NDArrayHandle] output_vars
-    cdef NDArrayHandle* p_output_vars
-    cdef NDArrayHandle ret_handle
-    cdef int num_output
-
-    for i in args:
-        ndvars.push_back((<NDArrayBase>i).chandle)
+cdef class CachedOp:
+    """Cached operator handle."""
+    cdef CachedOpHandle chandle

-    original_output = None
-    if out is not None:
-        original_output = out
-        if isinstance(out, NDArrayBase):
-            output_vars.push_back((<NDArrayBase>out).chandle)
+    cdef _set_handle(self, handle):
+        cdef unsigned long long ptr
+        if handle is None:
+            self.chandle = NULL
         else:
-            for i in out:
-                output_vars.push_back((<NDArrayBase>i).chandle)
+            ptr = handle.value
+            self.chandle = <CachedOpHandle>(ptr)

-    num_output = output_vars.size()
-    if output_vars.size() == 0:
-        output_vars.resize(1)
-        p_output_vars = NULL
-    else:
-        p_output_vars = &output_vars[0]
+    property handle:
+        def __get__(self):
+            if self.chandle == NULL:
+                return None
+            else:
+                return _ctypes.cast(<unsigned long long>self.chandle, _ctypes.c_void_p)
+        def __set__(self, value):
+            self._set_handle(value)

-    CALL(MXCachedInvoke(
-        (<CachedOp>cached_op).chandle,
-        len(args),
-        &ndvars[0] if ndvars.size() != 0 else NULL,
-        &num_output,
-        &p_output_vars))
+    def __init__(self, sym):
+        cdef unsigned long long ptr = sym.handle.value
+        CALL(MXCreateCachedOp(
+            <SymbolHandle>(ptr),
+            &self.chandle))
+
+    def __del__(self):
+        CALL(MXFreeCachedOp(self.chandle))
+
+    def __call__(self, *args, out=None):
+        """ctypes implementation of imperative invoke wrapper"""
+        cdef vector[NDArrayHandle] ndvars
+        cdef vector[NDArrayHandle] output_vars
+        cdef NDArrayHandle* p_output_vars
+        cdef NDArrayHandle ret_handle
+        cdef int num_output
+
+        for i in args:
+            ndvars.push_back((<NDArrayBase>i).chandle)
+
+        original_output = None
+        if out is not None:
+            original_output = out
+            if isinstance(out, NDArrayBase):
+                output_vars.push_back((<NDArrayBase>out).chandle)
+            else:
+                for i in out:
+                    output_vars.push_back((<NDArrayBase>i).chandle)

-    if original_output is not None:
-        return original_output
-    if num_output == 1:
-        return NewArray(p_output_vars[0])
-    else:
-        return tuple(NewArray(p_output_vars[i]) for i in range(num_output))
+        num_output = output_vars.size()
+        if output_vars.size() == 0:
+            output_vars.resize(1)
+            p_output_vars = NULL
+        else:
+            p_output_vars = &output_vars[0]
+
+        CALL(MXInvokeCachedOp(
+            (<CachedOp>self).chandle,
+            len(args),
+            &ndvars[0] if ndvars.size() != 0 else NULL,
+            &num_output,
+            &p_output_vars))
+
+        if original_output is not None:
+            return original_output
+        if num_output == 1:
+            return NewArray(p_output_vars[0])
+        else:
+            return tuple(NewArray(p_output_vars[i]) for i in range(num_output))


 def _imperative_invoke(handle, ndargs, keys, vals, out):
diff --git a/python/mxnet/cython/symbol.pyx b/python/mxnet/cython/symbol.pyx
index e8787fba77a3..aea0aa9f4809 100644
--- a/python/mxnet/cython/symbol.pyx
+++ b/python/mxnet/cython/symbol.pyx
@@ -79,22 +79,6 @@ cdef NewSymbol(SymbolHandle handle):
     return sym


-def invoke(cached_op, args, name=None):
-    cdef SymbolHandle ret
-    cdef vector[SymbolHandle] sym_args
-    hint = cached_op.op.lower()
-    cdef string cname = c_str(NameManager.current.get(name, hint))
-    for i in args:
-        sym_args.push_back((<SymbolBase>i).chandle)
-    CALL(MXCachedCreateSymbol(
-        (<CachedOp>cached_op).chandle,
-        cname.c_str(),
-        len(args),
-        &sym_args[0] if sym_args.size() != 0 else NULL,
-        &ret))
-    return NewSymbol(ret)
-
-
 def _symbol_creator(handle, args, kwargs, keys, vals, name):
     cdef unsigned long long ihandle = handle
     cdef OpHandle chandle = <OpHandle>ihandle
diff --git a/python/mxnet/ndarray.py b/python/mxnet/ndarray.py
index 8900843f5937..55f4b17b86d3 100644
--- a/python/mxnet/ndarray.py
+++ b/python/mxnet/ndarray.py
@@ -32,18 +32,18 @@
 try:
     if int(_os.environ.get("MXNET_ENABLE_CYTHON", True)) == 0:
         from ._ctypes.ndarray import NDArrayBase, _set_ndarray_class
-        from ._ctypes.ndarray import invoke, CachedOp, _imperative_invoke
+        from ._ctypes.ndarray import CachedOp, _imperative_invoke
    elif _sys.version_info >= (3, 0):
         from ._cy3.ndarray import NDArrayBase, _set_ndarray_class, _imperative_invoke
-        from ._cy3.ndarray import invoke, CachedOp, _imperative_invoke
+        from ._cy3.ndarray import CachedOp, _imperative_invoke
     else:
         from ._cy2.ndarray import NDArrayBase, _set_ndarray_class, _imperative_invoke
-        from ._cy2.ndarray import invoke, CachedOp, _imperative_invoke
+        from ._cy2.ndarray import CachedOp, _imperative_invoke
 except ImportError:
     if int(_os.environ.get("MXNET_ENFORCE_CYTHON", False)) != 0:
         raise ImportError("Cython Module cannot be loaded but MXNET_ENFORCE_CYTHON=1")
     from ._ctypes.ndarray import NDArrayBase, _set_ndarray_class, _imperative_invoke
-    from ._ctypes.ndarray import invoke, CachedOp, _imperative_invoke
+    from ._ctypes.ndarray import CachedOp, _imperative_invoke
 # pylint: enable=unused-import

 # pylint: disable= no-member
diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py
index 14203e59862d..fa43c34afe19 100644
--- a/python/mxnet/symbol.py
+++ b/python/mxnet/symbol.py
@@ -29,18 +29,18 @@
 try:
     if int(_os.environ.get("MXNET_ENABLE_CYTHON", True)) == 0:
         from ._ctypes.symbol import SymbolBase, _set_symbol_class
-        from ._ctypes.symbol import CachedOp, invoke, _symbol_creator  # pylint: disable=unused-import
+        from ._ctypes.symbol import _symbol_creator  # pylint: disable=unused-import
     elif _sys.version_info >= (3, 0):
         from ._cy3.symbol import SymbolBase, _set_symbol_class
-        from ._cy3.symbol import CachedOp, invoke, _symbol_creator  # pylint: disable=unused-import
+        from ._cy3.symbol import _symbol_creator  # pylint: disable=unused-import
     else:
         from ._cy2.symbol import SymbolBase, _set_symbol_class
-        from ._cy2.symbol import CachedOp, invoke, _symbol_creator  # pylint: disable=unused-import
+        from ._cy2.symbol import _symbol_creator  # pylint: disable=unused-import
 except ImportError:
     if int(_os.environ.get("MXNET_ENFORCE_CYTHON", False)) != 0:
         raise ImportError("Cython Module cannot be loaded but MXNET_ENFORCE_CYTHON=1")
     from ._ctypes.symbol import SymbolBase, _set_symbol_class
-    from ._ctypes.symbol import CachedOp, invoke, _symbol_creator  # pylint: disable=unused-import
+    from ._ctypes.symbol import _symbol_creator  # pylint: disable=unused-import

 _GRAD_REQ_MAP = {'null': 0, 'write': 1, 'add': 3}
@@ -705,7 +705,7 @@ def list_auxiliary_states(self):

         Returns
         -------
-        aux_states : list of string
+        aux_states : list of str
             List of the auxiliary states in input symbol.

         Notes
@@ -721,6 +721,30 @@ def list_auxiliary_states(self):
             self.handle, ctypes.byref(size), ctypes.byref(sarr)))
         return [py_str(sarr[i]) for i in range(size.value)]

+    def list_inputs(self):
+        """Lists all arguments and auxiliary states of this Symbol.
+
+        Returns
+        -------
+        inputs : list of str
+            List of all inputs.
+
+        Examples
+        --------
+        >>> bn = mx.sym.BatchNorm(name='bn')
+        >>> bn.list_arguments()
+        ['bn_data', 'bn_gamma', 'bn_beta']
+        >>> bn.list_auxiliary_states()
+        ['bn_moving_mean', 'bn_moving_var']
+        >>> bn.list_inputs()
+        ['bn_data', 'bn_gamma', 'bn_beta', 'bn_moving_mean', 'bn_moving_var']
+        """
+        size = ctypes.c_uint()
+        sarr = ctypes.POINTER(ctypes.c_char_p)()
+        check_call(_LIB.NNSymbolListInputNames(
+            self.handle, 0, ctypes.byref(size), ctypes.byref(sarr)))
+        return [py_str(sarr[i]) for i in range(size.value)]
+
     def infer_type(self, *args, **kwargs):
         """Infers the type of all arguments and all outputs, given the known types
         for some arguments.
diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc
index 0be1d3574dd9..dfdd46b6aa90 100644
--- a/src/c_api/c_api_ndarray.cc
+++ b/src/c_api/c_api_ndarray.cc
@@ -103,31 +103,43 @@ void SetNDInputsOutputs(const nnvm::Op* op,

 void SetContext(Context* p_ctx,
                 const nnvm::NodeAttrs& attrs,
-                const int& num_inputs,
                 const std::vector<NDArray>& ndinputs,
-                const int& infered_num_outputs,
-                const std::vector<NDArray>& ndoutputs) {
+                const std::vector<NDArray>& ndoutputs,
+                const Context& default_ctx) {
   Context& ctx = *p_ctx;
-  if (num_inputs) {
+  if (ndinputs.size()) {
     ctx = ndinputs[0].ctx();
-  } else if (infered_num_outputs && !ndoutputs[0].is_none()) {
+    for (size_t i = 1; i < ndinputs.size(); ++i) {
+      CHECK_EQ(ndinputs[i].ctx().dev_mask(), ctx.dev_mask())
+          << "All inputs must live on the same context. "
+          << "But the first argument is on "
+          << (ctx.dev_mask() == gpu::kDevMask ? "GPU" : "CPU")
+          << " while the " << i+1 << "-th argument is on "
+          << (ndinputs[i].ctx().dev_mask() == gpu::kDevMask ? "GPU" : "CPU");
+    }
+  } else if (ndoutputs.size() && !ndoutputs[0].is_none()) {
     ctx = ndoutputs[0].ctx();
   } else if (attrs.dict.find("ctx") != attrs.dict.end()) {
     ctx = Context::FromString(attrs.dict.at("ctx"));
   } else {
-    ctx = Context::CPU();
+    ctx = default_ctx;
   }
   // Pinned context doesn't propagate
   if (ctx.dev_type == Context::kCPUPinned) {
     ctx = Context::CPU();
   }
+#if !MXNET_USE_CUDA
+  if (ctx.dev_mask() == gpu::kDevMask) {
+    LOG(INFO) << "GPU support is disabled. Compile MXNet with "
+              << "USE_CUDA=1 to enable GPU support.";
+  }
+#endif  // MXNET_USE_CUDA
 }

 void SetShapeType(const nnvm::Op* op,
                   const nnvm::NodeAttrs& attrs,
                   const Context& ctx,
                   const std::vector<NDArray>& ndinputs,
-                  const int& infered_num_outputs,
                   std::vector<NDArray>* p_ndoutputs) {
   std::vector<NDArray>& ndoutputs = *p_ndoutputs;
   static auto& infershape = nnvm::Op::GetAttr<nnvm::FInferShape>("FInferShape");
@@ -148,7 +160,7 @@ void SetShapeType(const nnvm::Op* op,
   CHECK(infershape.count(op))
     << "Operator " << op->name << " is missing FInferShape attribute";
   CHECK(infershape[op](attrs, &in_shapes, &out_shapes));
-  CHECK_EQ(out_shapes.size(), static_cast<size_t>(infered_num_outputs));
+  CHECK_EQ(out_shapes.size(), ndoutputs.size());

   // infer type
   std::vector<int>& in_types = ret->arg_types;
@@ -165,9 +177,9 @@ void SetShapeType(const nnvm::Op* op,
   CHECK(infertype.count(op))
     << "Operator " << op->name << " is missing FInferType attribute";
   CHECK(infertype[op](attrs, &in_types, &out_types));
-  CHECK_EQ(out_types.size(), static_cast<size_t>(infered_num_outputs));
+  CHECK_EQ(out_types.size(), ndoutputs.size());

-  for (int i = 0; i < infered_num_outputs; ++i) {
+  for (size_t i = 0; i < ndoutputs.size(); ++i) {
     if (ndoutputs[i].is_none()) {
       ndoutputs[i] = NDArray(out_shapes[i], ctx, true, out_types[i]);
     } else {
@@ -322,35 +334,28 @@ void PushOperator(std::shared_ptr<Operator> opr,
                   0, PROFILER_MESSAGE(op->name.c_str()));
 }

-void ImperativeInvokeImpl(const nnvm::NodeAttrs& attrs,
-                          int num_inputs,
-                          NDArrayHandle *inputs,
-                          int *num_outputs,
-                          NDArrayHandle **outputs) {
+void ImperativeInvokeImpl(const Context& default_ctx,
+                          const nnvm::NodeAttrs& attrs,
+                          std::vector<NDArray>* p_ndinputs,
+                          std::vector<NDArray>* p_ndoutputs) {
   static auto& fcpu = nnvm::Op::GetAttr<FCompute>("FCompute<cpu>");
   static auto& fgpu = nnvm::Op::GetAttr<FCompute>("FCompute<gpu>");
   static auto& ndfunc = nnvm::Op::GetAttr<FNDArrayFunction>("FNDArrayFunction");
   static auto& createop = nnvm::Op::GetAttr<FCreateLayerOp>("FCreateLayerOp");
   MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
-  NDArray** outarray = *reinterpret_cast<NDArray***>(outputs);

-  const nnvm::Op *op = attrs.op;
-  int infered_num_outputs;
-  int num_visible_outputs;
-  SetNumOutputs(op, attrs, num_inputs,
-                &infered_num_outputs, &num_visible_outputs);
+  const nnvm::Op *op = attrs.op;
+  std::vector<NDArray>& ndinputs = *p_ndinputs;
+  std::vector<NDArray>& ndoutputs = *p_ndoutputs;

-  std::vector<NDArray> ndinputs, ndoutputs;
-  SetNDInputsOutputs(op, &ndinputs, &ndoutputs, num_inputs, inputs,
-                     num_outputs, infered_num_outputs, num_visible_outputs, outarray);
   if (ndfunc.count(op)) {
     ndfunc[op](attrs, ndinputs, &ndoutputs);
   } else {
     // TODO(piiswrong): infer ctx
     Context ctx;
-    SetContext(&ctx, attrs, num_inputs, ndinputs, infered_num_outputs, ndoutputs);
-    SetShapeType(op, attrs, ctx, ndinputs, infered_num_outputs, &ndoutputs);
+    SetContext(&ctx, attrs, ndinputs, ndoutputs, default_ctx);
+    SetShapeType(op, attrs, ctx, ndinputs, &ndoutputs);

     std::vector<engine::VarHandle> read_vars, write_vars;
     std::vector<Resource> requested;
@@ -383,22 +388,8 @@ void ImperativeInvokeImpl(const nnvm::NodeAttrs& attrs,
                 requested, auxidx, ndinputs, ndoutputs);
     } else {
       LOG(FATAL)
-        << "Operator " << op->name
-        << " cannot be run; requires at least one of"
-        << " FCompute<xpu>, NDArrayFunction, FCreateOperator be registered";
-    }
-  }
-
-  if (outarray == nullptr) {
-    ret->ret_handles.clear();
-    for (int i = 0; i < num_visible_outputs; ++i) {
-      ret->ret_handles.push_back(
-        reinterpret_cast<NDArrayHandle>(new NDArray(std::move(ndoutputs[i]))));
-    }
-    *outputs = dmlc::BeginPtr(ret->ret_handles);
-  } else {
-    for (int i = 0; i < *num_outputs; ++i) {
-      *outarray[i] = std::move(ndoutputs[i]);
+        << "Operator " << op->name << " is not implemented for "
+        << (ctx.dev_mask() == gpu::kDevMask ? "GPU." : "CPU.");
     }
   }
 }
@@ -412,46 +403,114 @@ int MXImperativeInvoke(AtomicSymbolCreator creator,
                        const char **param_keys,
                        const char **param_vals) {
   const nnvm::Op* op = static_cast<nnvm::Op*>(creator);
+  MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
+  NDArray** outarray = *reinterpret_cast<NDArray***>(outputs);

   API_BEGIN();
   nnvm::NodeAttrs attrs;
   SetOpAttrs(op, &attrs, num_inputs, num_params, param_keys, param_vals);
-  ImperativeInvokeImpl(attrs, num_inputs, inputs, num_outputs, outputs);
+
+  int infered_num_outputs;
+  int num_visible_outputs;
+  SetNumOutputs(op, attrs, num_inputs, &infered_num_outputs, &num_visible_outputs);
+
+  std::vector<NDArray> ndinputs, ndoutputs;
+  SetNDInputsOutputs(op, &ndinputs, &ndoutputs, num_inputs, inputs,
+                     num_outputs, infered_num_outputs, num_visible_outputs, outarray);
+
+  ImperativeInvokeImpl(Context::CPU(), attrs, &ndinputs, &ndoutputs);
+
+  if (outarray == nullptr) {
+    ret->ret_handles.clear();
+    for (int i = 0; i < num_visible_outputs; ++i) {
+      ret->ret_handles.push_back(
+        reinterpret_cast<NDArrayHandle>(new NDArray(std::move(ndoutputs[i]))));
+    }
+    *outputs = dmlc::BeginPtr(ret->ret_handles);
+  } else {
+    for (int i = 0; i < *num_outputs; ++i) {
+      *outarray[i] = std::move(ndoutputs[i]);
+    }
+  }
   API_END();
 }

-int MXCachedCreateOp(AtomicSymbolCreator creator,
-                     int num_inputs,
-                     int num_params,
-                     const char **param_keys,
-                     const char **param_vals,
+int MXCreateCachedOp(SymbolHandle handle,
                      CachedOpHandle *out) {
-  const nnvm::Op* op = static_cast<nnvm::Op*>(creator);
+  nnvm::Symbol* sym = static_cast<nnvm::Symbol*>(handle);

   API_BEGIN();
-  nnvm::NodeAttrs *attrs = new nnvm::NodeAttrs;
-  SetOpAttrs(op, attrs, num_inputs, num_params, param_keys, param_vals);
-  *out = attrs;
+  nnvm::Graph *g = new nnvm::Graph;
+  g->outputs = sym->outputs;
+  auto vars = sym->ListInputs(nnvm::Symbol::kAll);
+  CHECK_GE(vars.size(), 1) << "CachedOp must have at least 1 input.";
+  g->attrs["vars"] = std::make_shared<dmlc::any>(std::move(vars));
+  *out = g;
   API_END();
 }

-int MXCachedFree(CachedOpHandle handle) {
-  nnvm::NodeAttrs *attrs = static_cast<nnvm::NodeAttrs*>(handle);
-
+int MXFreeCachedOp(CachedOpHandle handle) {
+  nnvm::Graph *g = static_cast<nnvm::Graph*>(handle);
   API_BEGIN();
-  delete attrs;
+  delete g;
   API_END();
 }

-int MXCachedInvoke(CachedOpHandle handle,
-                   int num_inputs,
-                   NDArrayHandle *inputs,
-                   int *num_outputs,
-                   NDArrayHandle **outputs) {
-  nnvm::NodeAttrs *attrs = static_cast<nnvm::NodeAttrs*>(handle);
+int MXInvokeCachedOp(CachedOpHandle handle,
+                     int num_inputs,
+                     NDArrayHandle *inputs,
+                     int *num_outputs,
+                     NDArrayHandle **outputs) {
+  nnvm::Graph *g = static_cast<nnvm::Graph*>(handle);
+  MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
+  NDArray** outarray = *reinterpret_cast<NDArray***>(outputs);

   API_BEGIN();
-  ImperativeInvokeImpl(*attrs, num_inputs, inputs, num_outputs, outputs);
+  const std::vector<nnvm::NodePtr>& vars =
+      g->GetAttr<std::vector<nnvm::NodePtr> >("vars");
+  const nnvm::IndexedGraph& idx = g->indexed_graph();
+  CHECK_EQ(static_cast<size_t>(num_inputs), vars.size())
+      << "Actual number of inputs differs from expected number of inputs";
+  Context default_ctx = static_cast<NDArray*>(inputs[0])->ctx();
+
+  std::vector<NDArray> buff(idx.num_node_entries());
+  for (size_t i = 0; i < vars.size(); ++i) {
+    buff[idx.entry_id(idx.node_id(vars[i].get()), 0)] =
+        *static_cast<NDArray*>(inputs[i]);
+  }
+
+  for (size_t i = 0; i < idx.num_nodes(); ++i) {
+    const nnvm::IndexedGraph::Node& node = idx[i];
+    if (node.source->attrs.op == nullptr) continue;
+    std::vector<NDArray> in;
+    in.reserve(node.inputs.size());
+    for (const auto& j : node.inputs) {
+      in.emplace_back(buff[idx.entry_id(j)]);
+    }
+    std::vector<NDArray> out(node.source->num_outputs());
+    ImperativeInvokeImpl(default_ctx, node.source->attrs, &in, &out);
+
+    for (size_t j = 0; j < node.source->num_outputs(); ++j) {
+      buff[idx.entry_id(i, j)] = std::move(out[j]);
+    }
+  }
+
+  if (outarray == nullptr) {
+    ret->ret_handles.clear();
+    for (const auto& i : idx.outputs()) {
+      ret->ret_handles.push_back(
+        reinterpret_cast<NDArrayHandle>(
+            new NDArray(std::move(buff[idx.entry_id(i)]))));
+    }
+    *num_outputs = idx.outputs().size();
+    *outputs = dmlc::BeginPtr(ret->ret_handles);
+  } else {
+    CHECK_EQ(static_cast<size_t>(*num_outputs), idx.outputs().size())
+        << "Specified number of outputs differs from expected number of outputs";
+    for (size_t i = 0; i < idx.outputs().size(); ++i) {
+      *outarray[i] = std::move(buff[idx.entry_id(idx.outputs()[i])]);
+    }
+  }
   API_END();
 }
diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc
index cad9e604df60..d3603e94b2a1 100644
--- a/src/c_api/c_api_symbolic.cc
+++ b/src/c_api/c_api_symbolic.cc
@@ -124,22 +124,6 @@ int MXSymbolCreateAtomicSymbol(AtomicSymbolCreator creator,
   API_END_HANDLE_ERROR(delete s;);
 }

-int MXCachedCreateSymbol(CachedOpHandle handle,
-                         const char* name,
-                         mx_uint num_args,
-                         SymbolHandle* args,
-                         SymbolHandle* out) {
-  nnvm::Symbol *s = new nnvm::Symbol();
-  const nnvm::NodeAttrs *attrs = static_cast<const nnvm::NodeAttrs*>(handle);
-  API_BEGIN();
-  *s = nnvm::Symbol::CreateFunctor(*attrs);
-  nnvm::array_view<const nnvm::Symbol*> parg(
-      (nnvm::Symbol**)args, (nnvm::Symbol**)args + num_args);  // NOLINT(*)
-  s->Compose(parg, std::unordered_map<std::string, const nnvm::Symbol*>(), name);
-  *out = s;
-  API_END_HANDLE_ERROR(delete s;)
-}
-
 int MXSymbolCreateVariable(const char *name, SymbolHandle *out) {
   return NNSymbolCreateVariable(name, out);
 }
diff --git a/src/initialize.cc b/src/initialize.cc
index d57fec84f72b..29a687671cee 100644
--- a/src/initialize.cc
+++ b/src/initialize.cc
@@ -11,6 +11,8 @@

 namespace mxnet {

+static void (*prev_handler)(int) = nullptr;
+
 void segfault_logger(int sig) {
   const int MAX_STACK_SIZE = 10;
   void *stack[MAX_STACK_SIZE];
@@ -28,14 +30,20 @@ void segfault_logger(int sig) {
   }
 #endif  // DMLC_LOG_STACK_TRACE

-  exit(1);
+  if (prev_handler == nullptr ||
+      prev_handler == SIG_DFL) {
+    exit(1);
+  } else if (prev_handler == SIG_IGN) {
+  } else {
+    prev_handler(sig);
+  }
 }

 class LibraryInitializer {
  public:
   LibraryInitializer() {
     dmlc::InitLogging("mxnet");
-    // signal(SIGSEGV, segfault_logger);
+    prev_handler = signal(SIGSEGV, segfault_logger);
 #if MXNET_USE_PROFILER
     // ensure profiler's constructor is called before atexit.
     engine::Profiler::Get();
diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py
index dd38bdf98606..8c58d3b47a69 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -628,13 +628,17 @@ def test_iter():

 def test_cached():
-    op = mx.nd.CachedOp('Convolution', 3, kernel=(3, 3), num_filter=10)
+    sym = mx.sym.Convolution(kernel=(3, 3), num_filter=10) + 2
+    op = mx.nd.CachedOp(sym)
     data = mx.nd.ones((3, 4, 10, 10))
     weight = mx.nd.ones((10, 4, 3, 3))
     bias = mx.nd.ones((10,))
-    o1 = mx.nd.invoke(op, [data, weight, bias])
+    o1 = op(data, weight, bias)
     bias[:] = 2
-    o2 = mx.nd.invoke(op, [data, weight, bias])
+    o2 = op(data, weight, bias)
+    assert_almost_equal(o2.asnumpy(), o1.asnumpy()+1)
+    o2[:] = 0
+    op(data, weight, bias, out=o2)
     assert_almost_equal(o2.asnumpy(), o1.asnumpy()+1)

 def test_output():
diff --git a/tests/python/unittest/test_symbol.py b/tests/python/unittest/test_symbol.py
index 28fc8a4fc77b..093a8f3a40e0 100644
--- a/tests/python/unittest/test_symbol.py
+++ b/tests/python/unittest/test_symbol.py
@@ -225,19 +225,6 @@ def test_zero_prop2():
         assert False


-def test_cached():
-    op = mx.sym.CachedOp('Convolution', 3, kernel=(3, 3), num_filter=10)
-    data = mx.sym.var('data')
-    weight = mx.sym.var('weight')
-    bias = mx.sym.var('bias')
-    out = mx.sym.invoke(op, [data, weight, bias], 'conv')
-    assert out.list_arguments() == ['data', 'weight', 'bias']
-    assert out.list_outputs() == ['conv_output']
-    with mx.name.Prefix('test_'):
-        assert mx.sym.invoke(op, [data, weight, bias]).name == 'test_convolution0'
-        assert mx.sym.invoke(op, [data, weight, bias]).name == 'test_convolution1'
-
-
 if __name__ == '__main__':
     import nose
     nose.runmodule()
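
Usage note: after this change a cached op is built from a full Symbol graph and invoked by calling it, rather than being constructed from an operator name plus parameters and run through mx.nd.invoke. A minimal sketch mirroring the updated test_cached in tests/python/unittest/test_ndarray.py; it assumes an mxnet build that includes this patch, imported as mx:

    import mxnet as mx

    # A CachedOp now wraps an arbitrary Symbol graph, so composite
    # expressions work, not just a single operator.
    sym = mx.sym.Convolution(kernel=(3, 3), num_filter=10) + 2
    op = mx.nd.CachedOp(sym)

    data = mx.nd.ones((3, 4, 10, 10))
    weight = mx.nd.ones((10, 4, 3, 3))
    bias = mx.nd.ones((10,))

    # Inputs are passed positionally, in the order reported by
    # sym.list_inputs() (arguments followed by auxiliary states).
    o1 = op(data, weight, bias)

    # Preallocated outputs still go through the `out=` keyword.
    o2 = mx.nd.zeros(o1.shape)
    op(data, weight, bias, out=o2)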
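For reviewers, the new MXInvokeCachedOp is essentially a one-pass interpreter over the cached graph: user inputs are bound to the variable entries of an NDArray buffer sized to the graph's node entries, each non-variable node is executed eagerly through ImperativeInvokeImpl in topological order, and the buffer slots for the graph outputs are returned. Below is an illustrative Python sketch of that loop; the node layout and the invoke callback are hypothetical stand-ins for the nnvm IndexedGraph types, not a real mxnet API:

    def invoke_cached(nodes, num_entries, input_entry_ids, inputs,
                      output_entry_ids, invoke):
        """nodes must be topologically sorted; each node carries .op (None
        for variables), .input_entry_ids and .output_entry_ids.
        invoke(op, ins) runs one operator eagerly, returning its outputs."""
        buff = [None] * num_entries              # one slot per graph entry
        for eid, arr in zip(input_entry_ids, inputs):
            buff[eid] = arr                      # bind user inputs to variables
        for node in nodes:
            if node.op is None:                  # variable node, already bound
                continue
            ins = [buff[eid] for eid in node.input_entry_ids]
            outs = invoke(node.op, ins)          # ImperativeInvokeImpl analogue
            for eid, out in zip(node.output_entry_ids, outs):
                buff[eid] = out
        return [buff[eid] for eid in output_entry_ids]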
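The src/initialize.cc change also re-enables the SIGSEGV logger, but now chains to whatever handler was installed before it instead of unconditionally calling exit(1). The same pattern expressed in Python purely for comparison (signal.signal, like C's signal(), returns the previous handler); this is an analogy, not code from the patch:

    import signal

    prev_handler = None

    def segfault_logger(signum, frame):
        # ... log the stack trace here, as segfault_logger does in C++ ...
        if prev_handler in (None, signal.SIG_DFL):
            raise SystemExit(1)          # no useful previous handler: exit(1)
        elif prev_handler is signal.SIG_IGN:
            pass                         # previous handler ignored the signal
        else:
            prev_handler(signum, frame)  # chain to the previous handler

    prev_handler = signal.signal(signal.SIGSEGV, segfault_logger)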