diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py index b42fefd065e8..d474337772a0 100644 --- a/python/mxnet/gluon/block.py +++ b/python/mxnet/gluon/block.py @@ -93,6 +93,53 @@ def __exit__(self, ptype, value, trace): _BlockScope._current.value = self._old_scope +def _gather_type_ctx_info(args): + """Analyze the elements inside the nested args object and find: + - If there exists ndarray + - If there exists symbol + - All contexts appearing in args + + Parameters + ---------- + args : list or NDArray or Symbol + Could be a nested architecture. + + Returns + ------- + has_symbol : bool + Whether the elements in args contains symbols + has_ndarray : bool + Whether the elements in args contains ndarrays + ctx_set : set of mxnet.context.Context + Contains all possible contexts of the inner ndarrays in args. Can be empty if there is no + ndarray inside args. + first_ctx : mxnet.context.Context or None + Context of the first appeared NDArray (for backward-compatibility) + """ + if isinstance(args, NDArray): + return False, True, {args.context}, args.context + elif isinstance(args, Symbol): + return True, False, set(), None + elif isinstance(args, (list, tuple)): + has_symbol = False + has_ndarray = False + ctx_set = set() + first_ctx = None + for ele in args: + ele_has_sym, ele_has_nd, ele_ctx_set, ele_first_ctx =\ + _gather_type_ctx_info(ele) + has_symbol = has_symbol or ele_has_sym + has_ndarray = has_ndarray or ele_has_nd + if first_ctx is None and ele_first_ctx is not None: + first_ctx = ele_first_ctx + ctx_set = ctx_set | ele_ctx_set + if has_symbol and has_ndarray: + break + return has_symbol, has_ndarray, ctx_set, first_ctx + else: + return False, False, set(), None + + def _flatten(args, inout_str): """Parse the arguments into a flattened list + an additional format array. The format array stores the structure of the original arguments to help reconstruct the inputs. @@ -120,9 +167,11 @@ def _flatten(args, inout_str): if args is None: return [None], int(-1) - assert isinstance(args, (list, tuple)), \ - "HybridBlock {} must be (nested) list of Symbol or NDArray, " \ - "but got {} of type {}".format(inout_str, str(args), str(type(args))) + if not isinstance(args, (list, tuple)): + raise ValueError("When hybridized, the input of HybridBlock {}" + " must be (nested) list of Symbol" + " or NDArray, " + "but got {} of type {}".format(inout_str, str(args), str(type(args)))) flat = [] fmts = [] for i in args: @@ -164,9 +213,10 @@ def _merger(args, fmt): else: return args[:fmt], args[fmt:] - assert isinstance(args, (list, tuple)), \ - "HybridBlock output must be (nested) list of Symbol or NDArray, " \ - "but got {} of type {}".format(args, type(args)) + if not isinstance(args, (list, tuple)): + raise ValueError("When hybridized, the output of HybridBlock must be (nested)" + " list of Symbol or NDArray, " + "but got {} of type {}".format(args, type(args))) ret = [] for i in fmt: res, args = _merger(args, i) @@ -1054,38 +1104,26 @@ def register_op_hook(self, callback, monitor_all=False): def forward(self, x, *args): """Defines the forward computation. Arguments can be either :py:class:`NDArray` or :py:class:`Symbol`.""" - flatten_args = _flatten([x] + list(args), 'inputs')[0] - is_ndarray = None - ctx = None - exist_sym_nd = False - for ele in flatten_args: - if isinstance(ele, NDArray): - if is_ndarray is False: - raise ValueError('In HybridBlock, we do not support mixed NDArrays and Symbols' - ' types for the input.\n' - 'Received types are: {}.' - .format([type(ele) for ele in flatten_args])) - is_ndarray = True - exist_sym_nd = True - ctx = ele.context - elif isinstance(ele, Symbol): - if is_ndarray: - raise ValueError('In HybridBlock, we do not support mixed NDArrays and Symbols' - ' types for the input.\n' - 'Received types are: {}.' - .format([type(ele) for ele in flatten_args])) - is_ndarray = False - exist_sym_nd = True - else: - assert ele is None, 'Only support None, NDArray and Symbol as the input' - if not exist_sym_nd: - raise ValueError('There must at least one NDArray or Symbol in the input, received') - if is_ndarray: - with ctx: - if self._active: + has_symbol, has_ndarray, ctx_set, first_ctx = _gather_type_ctx_info([x] + list(args)) + if has_symbol and has_ndarray: + raise ValueError('In HybridBlock, we do not support mixed NDArrays and Symbols' + ' types for the input. Please check the type of the args.\n') + if not has_symbol and not has_ndarray: + raise ValueError('In HybridBlock, there must be one NDArray or one Symbol in the input.' + ' Please check the type of the args.\n') + if has_ndarray: + ctx = first_ctx + if self._active: + if len(ctx_set) > 1: + raise ValueError('Find multiple contexts in the input, ' + 'After hybridized, the HybridBlock only supports one input ' + 'context. You can print the ele.context in the ' + 'input arguments to inspect their contexts. ' + 'Find all contexts = {}'.format(ctx_set)) + with ctx: return self._call_cached_op(x, *args) - + with ctx: try: params = {k: v.data(ctx) for k, v in self._reg_params.items()} except DeferredInitializationError: diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py index fc650294a538..3454e0b55816 100644 --- a/tests/python/gpu/test_gluon_gpu.py +++ b/tests/python/gpu/test_gluon_gpu.py @@ -510,6 +510,21 @@ def test_bulking(): .format(fully_bulked_time - fastest_half_bulked_time, times_str) +@with_seed() +def test_hybridblock_mix_ctx_raise(): + class FooHybrid(gluon.HybridBlock): + def hybrid_forward(self, F, a, b): + if isinstance(a, (list, tuple)): + a = sum(a) + if isinstance(b, (list, tuple)): + b = sum(b) + return a + b + foo_hybrid = FooHybrid() + foo_hybrid.hybridize() + assert_raises(ValueError, lambda: foo_hybrid(mx.nd.ones((10,), ctx=mx.gpu()), + mx.nd.ones((10,), ctx=mx.cpu()))) + + if __name__ == '__main__': import nose nose.runmodule() diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index 380ce762a9f7..ef9e8990259f 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -528,6 +528,50 @@ def hybrid_forward(self, F, a, b): assert_raises(ValueError, lambda: foo1(mx.nd.ones((10,)), mx.nd.ones((10,)))) +@with_seed() +def test_hybrid_block_hybrid_no_hybrid(): + class FooHybrid(gluon.HybridBlock): + def hybrid_forward(self, F, a, b): + if isinstance(a, (list, tuple)): + a = sum(a) + if isinstance(b, (list, tuple)): + b = sum(b) + return a + b + + class Foo(gluon.Block): + def forward(self, a, b): + if isinstance(a, (list, tuple)): + a = sum(a) + if isinstance(b, (list, tuple)): + b = sum(b) + return a + b + # When hybridize is not called, HybridBlock acts the same as Block + foo_hybrid = FooHybrid() + foo = Foo() + for a, b in [(mx.nd.ones((10,)), 1), + (mx.nd.ones((20,)), 2), + ([mx.nd.ones((10,)), mx.nd.ones((10,))], + [mx.nd.ones((10)), mx.nd.ones((10,)), mx.nd.ones((10,))]), + ([mx.nd.ones((10,)), mx.nd.ones((10,))], 3)]: + hybrid_block_out = foo_hybrid(a, b) + block_out = foo(a, b) + assert_almost_equal(hybrid_block_out.asnumpy(), block_out.asnumpy()) + # When hybridize is called, we need to make sure that the model raises for the unsupported cases + # 1. Scalar values in the input + # 2. No mixing of sym/ndarray + # 3. No mixing of cpu ndarray and gpu ndarray (Tested in gpu/test_gluon_gpu.py) + # 4. Allow mixing of cpu_pinned and cpu + foo_hybrid = FooHybrid() + foo_hybrid.hybridize() + assert_raises(ValueError, lambda: foo_hybrid(mx.nd.ones((10,)), 1)) + foo_hybrid = FooHybrid() + foo_hybrid.hybridize() + assert_raises(ValueError, lambda: foo_hybrid(mx.nd.ones((10,)), mx.sym.var('a'))) + foo_hybrid = FooHybrid() + foo_hybrid.hybridize() + assert_raises(ValueError, lambda: foo_hybrid(mx.nd.ones((10,), ctx=mx.cpu(1)), + mx.nd.ones((10,), ctx=mx.cpu(2)))) + @with_seed() def check_layer_forward(layer, dshape):