Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
try to fix block
Browse files Browse the repository at this point in the history
fix bug

fix

fix

fix lint

fix

backward-compatible for inferring the ctx

fix lint

try to improve

try to fix

Update block.py

Revert "Update block.py"

This reverts commit bbcf41f.

Revert "try to fix"

This reverts commit 7d7f35c.

Revert "try to improve"

This reverts commit f510132.

Revert "Revert "try to improve""

This reverts commit 872a7abe34c9afa97eb631fab8c3ce8558a22af8.

Revert "Revert "try to fix""

This reverts commit 48e235dd4c9ee6a88d1a2515a8a1f3e57319c217.

Revert "Revert "Update block.py""

This reverts commit e0d3949245050a4c60c4834db73c764d87fde6f7.

fix

fix lint
  • Loading branch information
sxjscience committed Oct 18, 2019
1 parent 1e8cc90 commit feb7d55
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 36 deletions.
110 changes: 74 additions & 36 deletions python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,53 @@ def __exit__(self, ptype, value, trace):
_BlockScope._current.value = self._old_scope


def _gather_type_ctx_info(args):
    """Recursively walk a (possibly nested) argument structure and summarize it.

    Reports whether any NDArray or Symbol appears inside, and collects the
    contexts of the NDArrays that were seen.

    Parameters
    ----------
    args : list or NDArray or Symbol
        May be an arbitrarily nested list/tuple structure.

    Returns
    -------
    has_symbol : bool
        True if at least one Symbol appears anywhere inside ``args``.
    has_ndarray : bool
        True if at least one NDArray appears anywhere inside ``args``.
    ctx_set : set of mxnet.context.Context
        Contexts seen on the inner NDArrays. Empty when ``args`` holds no
        NDArray. May be partial when both symbols and ndarrays were found,
        because the scan stops early in that case (the caller treats mixing
        as an error anyway).
    first_ctx : mxnet.context.Context or None
        Context of the first NDArray encountered (kept for
        backward-compatibility).
    """
    if isinstance(args, NDArray):
        ctx = args.context
        return False, True, {ctx}, ctx
    if isinstance(args, Symbol):
        return True, False, set(), None
    if not isinstance(args, (list, tuple)):
        # Anything else (e.g. None, scalars) contributes no type/ctx info.
        return False, False, set(), None
    found_sym = False
    found_nd = False
    contexts = set()
    leading_ctx = None
    for item in args:
        sub_sym, sub_nd, sub_ctxs, sub_first = _gather_type_ctx_info(item)
        found_sym |= sub_sym
        found_nd |= sub_nd
        if leading_ctx is None and sub_first is not None:
            leading_ctx = sub_first
        contexts |= sub_ctxs
        # Mixing symbols and ndarrays is rejected upstream, so there is no
        # need to keep scanning once both kinds have been observed.
        if found_sym and found_nd:
            break
    return found_sym, found_nd, contexts, leading_ctx


def _flatten(args, inout_str):
"""Parse the arguments into a flattened list + an additional format array.
The format array stores the structure of the original arguments to help reconstruct the inputs.
Expand Down Expand Up @@ -120,9 +167,11 @@ def _flatten(args, inout_str):
if args is None:
return [None], int(-1)

assert isinstance(args, (list, tuple)), \
"HybridBlock {} must be (nested) list of Symbol or NDArray, " \
"but got {} of type {}".format(inout_str, str(args), str(type(args)))
if not isinstance(args, (list, tuple)):
raise ValueError("When hybridized, the input of HybridBlock {}"
" must be (nested) list of Symbol"
" or NDArray, "
"but got {} of type {}".format(inout_str, str(args), str(type(args))))
flat = []
fmts = []
for i in args:
Expand Down Expand Up @@ -164,9 +213,10 @@ def _merger(args, fmt):
else:
return args[:fmt], args[fmt:]

assert isinstance(args, (list, tuple)), \
"HybridBlock output must be (nested) list of Symbol or NDArray, " \
"but got {} of type {}".format(args, type(args))
if not isinstance(args, (list, tuple)):
raise ValueError("When hybridized, the output of HybridBlock must be (nested)"
" list of Symbol or NDArray, "
"but got {} of type {}".format(args, type(args)))
ret = []
for i in fmt:
res, args = _merger(args, i)
Expand Down Expand Up @@ -1054,38 +1104,26 @@ def register_op_hook(self, callback, monitor_all=False):
def forward(self, x, *args):
"""Defines the forward computation. Arguments can be either
:py:class:`NDArray` or :py:class:`Symbol`."""
flatten_args = _flatten([x] + list(args), 'inputs')[0]
is_ndarray = None
ctx = None
exist_sym_nd = False
for ele in flatten_args:
if isinstance(ele, NDArray):
if is_ndarray is False:
raise ValueError('In HybridBlock, we do not support mixed NDArrays and Symbols'
' types for the input.\n'
'Received types are: {}.'
.format([type(ele) for ele in flatten_args]))
is_ndarray = True
exist_sym_nd = True
ctx = ele.context
elif isinstance(ele, Symbol):
if is_ndarray:
raise ValueError('In HybridBlock, we do not support mixed NDArrays and Symbols'
' types for the input.\n'
'Received types are: {}.'
.format([type(ele) for ele in flatten_args]))
is_ndarray = False
exist_sym_nd = True
else:
assert ele is None, 'Only support None, NDArray and Symbol as the input'
if not exist_sym_nd:
raise ValueError('There must at least one NDArray or Symbol in the input, received')

if is_ndarray:
with ctx:
if self._active:
has_symbol, has_ndarray, ctx_set, first_ctx = _gather_type_ctx_info([x] + list(args))
if has_symbol and has_ndarray:
raise ValueError('In HybridBlock, we do not support mixed NDArrays and Symbols'
' types for the input. Please check the type of the args.\n')
if not has_symbol and not has_ndarray:
raise ValueError('In HybridBlock, there must be one NDArray or one Symbol in the input.'
' Please check the type of the args.\n')
if has_ndarray:
ctx = first_ctx
if self._active:
if len(ctx_set) > 1:
raise ValueError('Find multiple contexts in the input, '
'After hybridized, the HybridBlock only supports one input '
'context. You can print the ele.context in the '
'input arguments to inspect their contexts. '
'Find all contexts = {}'.format(ctx_set))
with ctx:
return self._call_cached_op(x, *args)

with ctx:
try:
params = {k: v.data(ctx) for k, v in self._reg_params.items()}
except DeferredInitializationError:
Expand Down
15 changes: 15 additions & 0 deletions tests/python/gpu/test_gluon_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,21 @@ def test_bulking():
.format(fully_bulked_time - fastest_half_bulked_time, times_str)


@with_seed()
def test_hybridblock_mix_ctx_raise():
    """A hybridized HybridBlock must reject inputs on different contexts."""
    class MixCtxHybrid(gluon.HybridBlock):
        def hybrid_forward(self, F, a, b):
            a = sum(a) if isinstance(a, (list, tuple)) else a
            b = sum(b) if isinstance(b, (list, tuple)) else b
            return a + b

    net = MixCtxHybrid()
    net.hybridize()
    gpu_input = mx.nd.ones((10,), ctx=mx.gpu())
    cpu_input = mx.nd.ones((10,), ctx=mx.cpu())
    # Feeding one GPU ndarray and one CPU ndarray must raise after hybridize().
    assert_raises(ValueError, net, gpu_input, cpu_input)


# Allow running this test file directly; nose discovers and runs the tests above.
if __name__ == '__main__':
    import nose
    nose.runmodule()
44 changes: 44 additions & 0 deletions tests/python/unittest/test_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,50 @@ def hybrid_forward(self, F, a, b):
assert_raises(ValueError, lambda: foo1(mx.nd.ones((10,)), mx.nd.ones((10,))))


@with_seed()
def test_hybrid_block_hybrid_no_hybrid():
    """A non-hybridized HybridBlock behaves like a plain Block; once
    hybridized it must raise on unsupported input combinations."""
    def _collapse(val):
        # Nested list/tuple inputs are reduced to a single value by summation.
        return sum(val) if isinstance(val, (list, tuple)) else val

    class HybridAdder(gluon.HybridBlock):
        def hybrid_forward(self, F, a, b):
            return _collapse(a) + _collapse(b)

    class PlainAdder(gluon.Block):
        def forward(self, a, b):
            return _collapse(a) + _collapse(b)

    # When hybridize is not called, HybridBlock acts the same as Block.
    hybrid_net = HybridAdder()
    plain_net = PlainAdder()
    cases = [
        (mx.nd.ones((10,)), 1),
        (mx.nd.ones((20,)), 2),
        ([mx.nd.ones((10,)), mx.nd.ones((10,))],
         [mx.nd.ones((10,)), mx.nd.ones((10,)), mx.nd.ones((10,))]),
        ([mx.nd.ones((10,)), mx.nd.ones((10,))], 3),
    ]
    for a, b in cases:
        assert_almost_equal(hybrid_net(a, b).asnumpy(),
                            plain_net(a, b).asnumpy())

    # When hybridize is called, the model must raise for the unsupported cases:
    # 1. Scalar values in the input
    # 2. No mixing of sym/ndarray
    # 3. No mixing of cpu ndarray and gpu ndarray (tested in gpu/test_gluon_gpu.py)
    # 4. Allow mixing of cpu_pinned and cpu
    failing_inputs = [
        (mx.nd.ones((10,)), 1),
        (mx.nd.ones((10,)), mx.sym.var('a')),
        (mx.nd.ones((10,), ctx=mx.cpu(1)),
         mx.nd.ones((10,), ctx=mx.cpu(2))),
    ]
    for bad_a, bad_b in failing_inputs:
        net = HybridAdder()
        net.hybridize()
        assert_raises(ValueError, net, bad_a, bad_b)


@with_seed()
def check_layer_forward(layer, dshape):
Expand Down

0 comments on commit feb7d55

Please sign in to comment.