diff --git a/example/gluon/style_transfer/main.py b/example/gluon/style_transfer/main.py index dde992ae7005..816487ae9fd5 100644 --- a/example/gluon/style_transfer/main.py +++ b/example/gluon/style_transfer/main.py @@ -24,7 +24,7 @@ from PIL import Image from mxnet import autograd, gluon -from mxnet.gluon import nn, Block, HybridBlock, Parameter, ParameterDict +from mxnet.gluon import nn, Block, HybridBlock, Parameter import mxnet.ndarray as F import net diff --git a/python/mxnet/contrib/amp/amp.py b/python/mxnet/contrib/amp/amp.py index ac70cfa08850..5b45be63de6f 100644 --- a/python/mxnet/contrib/amp/amp.py +++ b/python/mxnet/contrib/amp/amp.py @@ -686,7 +686,8 @@ def convert_hybrid_block(block, target_dtype="float16", target_dtype_ops=None, # If dtype for the param was set in the json, cast the # param to this dtype attr_dict = converted_sym.attr_dict() - for name, param in block.collect_params().items(): + for param in block.collect_params().values(): + name = param.name if name in arg_names: arg_dict['arg:%s'%name] = param._reduce() if name in attr_dict and "__dtype__" in attr_dict[name]: @@ -719,7 +720,7 @@ def convert_hybrid_block(block, target_dtype="float16", target_dtype_ops=None, if aux_param_name in arg_dict and param.dtype != arg_dict[aux_param_name].dtype: param.cast(arg_dict[aux_param_name].dtype) - ret.collect_params().load_dict(arg_dict, ctx=ctx) + ret.load_dict(arg_dict, ctx=ctx) return ret def list_lp16_ops(target_dtype): diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py index e3f30f8fe9ec..2fda08067e0f 100644 --- a/python/mxnet/gluon/block.py +++ b/python/mxnet/gluon/block.py @@ -36,7 +36,7 @@ from ..ndarray import NDArray from .. import name as _name from .. import profiler as _profiler -from .parameter import Parameter, ParameterDict, DeferredInitializationError +from .parameter import Parameter, DeferredInitializationError from .utils import _indent, _brief_print_list, HookHandle, shape_is_known from .utils import _check_same_symbol_type, _check_all_np_ndarrays from .. import numpy_extension as _mx_npx @@ -57,57 +57,39 @@ def __init__(self, block): self._local._name_scope = None @staticmethod - def create(prefix, params, hint): + def count(hint): """ - Creates prefix, params, and profiler scope name for new `Block`. + Creates unique name for new `Block`. The profiler scope is to support the GPU memory profiler. 
""" current = getattr(_BlockScope._current, "value", None) - block = current._block() if current is not None else None - if current is None or block is None: - if prefix is None: - if not hasattr(_name.NameManager._current, "value"): - _name.NameManager._current.value = _name.NameManager() - prefix = _name.NameManager._current.value.get(None, hint) + '_' - # replace the trailing underscore with colon - profiler_scope_name = (prefix[:-1] if prefix.endswith('_') \ - else prefix) + ":" - if params is None: - params = ParameterDict(prefix) - else: - params = ParameterDict(params.prefix, params) - return prefix, params, profiler_scope_name + if current is None: + if not hasattr(_name.NameManager._current, "value"): + _name.NameManager._current.value = _name.NameManager() + block_name = _name.NameManager._current.value.get(None, hint) + return block_name - if prefix is None: - count = current._counter.get(hint, 0) - prefix = '%s%d_'%(hint, count) - current._counter[hint] = count + 1 - if params is None: - parent = block.params - params = ParameterDict(parent.prefix+prefix, parent._shared) - else: - params = ParameterDict(params.prefix, params) - # replace the trailing underscore with colon - profiler_scope_name = (prefix[:-1] if prefix.endswith('_') \ - else prefix) + ":" - return block.prefix + prefix, params, \ - block._profiler_scope_name + profiler_scope_name + count = current._counter.get(hint, 0) + block_name = '%s%d'%(hint, count) + current._counter[hint] = count + 1 + return block_name def __enter__(self): block = self._block() - if block is None or block._empty_prefix: + if block is None or block.name == '': return self self._local._old_scope = getattr(_BlockScope._current, "value", None) _BlockScope._current.value = self - self._local._name_scope = _name.Prefix(block.prefix) + self._local._name_scope = _name.Prefix(block.name + '_') self._local._name_scope.__enter__() - self._local._profiler_scope = _profiler.Scope(block._profiler_scope_name) + _profiler_scope_name = block.name + ":" + self._local._profiler_scope = _profiler.Scope(_profiler_scope_name) self._local._profiler_scope.__enter__() return self def __exit__(self, ptype, value, trace): block = self._block() - if block is None or block._empty_prefix: + if block is None or block.name == '': return self._local._name_scope.__exit__(ptype, value, trace) self._local._name_scope = None @@ -261,10 +243,8 @@ class Block(object): class Model(Block): def __init__(self, **kwargs): super(Model, self).__init__(**kwargs) - # use name_scope to give child Blocks appropriate names. - with self.name_scope(): - self.dense0 = nn.Dense(20) - self.dense1 = nn.Dense(20) + self.dense0 = nn.Dense(20) + self.dense1 = nn.Dense(20) def forward(self, x): x = mx.nd.relu(self.dense0(x)) @@ -279,31 +259,14 @@ def forward(self, x): will collect their Parameters recursively. You can also manually register child blocks with :py:meth:`register_child`. - Parameters - ---------- - prefix : str - Prefix acts like a name space. All children blocks created in parent block's - :py:meth:`name_scope` will have parent block's prefix in their name. - Please refer to - `naming tutorial `_ - for more info on prefix and naming. - params : ParameterDict or None - :py:class:`ParameterDict` for sharing weights with the new :py:class:`Block`. 
For example, - if you want ``dense1`` to share ``dense0``'s weights, you can do:: - - dense0 = nn.Dense(20) - dense1 = nn.Dense(20, params=dense0.collect_params()) """ - def __init__(self, prefix=None, params=None): - self._empty_prefix = prefix == '' - self._prefix, self._params, self._profiler_scope_name = \ - _BlockScope.create(prefix, params, self._alias()) - self._name = self._prefix[:-1] if self._prefix.endswith('_') else self._prefix - self._scope = _BlockScope(self) + def __init__(self): self._children = OrderedDict() self._reg_params = {} self._forward_hooks = OrderedDict() self._forward_pre_hooks = OrderedDict() + self._name = _BlockScope.count(self._alias()) + self._scope = _BlockScope(self) def __repr__(self): s = '{name}(\n{modstr}\n)' @@ -325,10 +288,6 @@ def __setattr__(self, name, value): if isinstance(value, Block): self.register_child(value, name) elif isinstance(value, Parameter): - assert name not in self._reg_params, \ - "Overriding Parameter attribute %s is not allowed. " \ - "If you want to share parameters between blocks, please set " \ - "'params' at Block construction instead." self._reg_params[name] = value super(Block, self).__setattr__(name, value) @@ -364,44 +323,26 @@ def _find_unregistered_block_in_container(data): def _alias(self): return self.__class__.__name__.lower() - @property - def prefix(self): - """Prefix of this :py:class:`Block`.""" - return self._prefix - @property def name(self): - """Name of this :py:class:`Block`, without '_' in the end.""" + """Name of this :py:class:`Block`, class name + counter """ return self._name - def name_scope(self): - """Returns a name space object managing a child :py:class:`Block` and parameter - names. Should be used within a ``with`` statement:: - - with self.name_scope(): - self.dense = nn.Dense(20) - - Please refer to - `the naming tutorial `_ - for more info on prefix and naming. - """ - return self._scope - @property def params(self): """Returns this :py:class:`Block`'s parameter dictionary (does not include its children's parameters).""" - return self._params + return self._reg_params def collect_params(self, select=None): - """Returns a :py:class:`ParameterDict` containing this :py:class:`Block` and all of its - children's Parameters(default), also can returns the select :py:class:`ParameterDict` + """Returns a :py:class:`Dict` containing this :py:class:`Block` and all of its + children's Parameters(default), also can returns the select :py:class:`Dict` which match some given regular expressions. - For example, collect the specified parameters in ['conv1_weight', 'conv1_bias', 'fc_weight', - 'fc_bias']:: + For example, collect the specified parameters in ['conv1.weight', 'conv1.bias', 'fc.weight', + 'fc.bias']:: - model.collect_params('conv1_weight|conv1_bias|fc_weight|fc_bias') + model.collect_params('conv1.weight|conv1.bias|fc.weight|fc.bias') or collect all parameters whose names end with 'weight' or 'bias', this can be done using regular expressions:: @@ -415,26 +356,23 @@ def collect_params(self, select=None): Returns ------- - The selected :py:class:`ParameterDict` + The selected :py:class:`Dict` """ # We need to check here because blocks inside containers are not supported. 
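# --- Editor's note: hedged illustration, not part of this patch --------------------
# After this change, collect_params() returns a plain dict keyed by structural,
# '.'-separated names, and the optional `select` regex is matched against those
# names. A toy sketch (the two-layer Sequential model below is illustrative only):
from mxnet.gluon import nn

net = nn.HybridSequential()
net.add(nn.Dense(20), nn.Dense(10))
net.initialize()
print(net.collect_params())            # keys look like '0.weight', '0.bias', '1.weight', ...
print(net.collect_params('.*weight'))  # regex selection against the structural names
# ------------------------------------------------------------------------------------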
self._check_container_with_block() - ret = ParameterDict(self._params.prefix) - if not select: - ret.update(self.params) - else: - pattern = re.compile(select) - ret.update({name:value for name, value in self.params.items() if pattern.match(name)}) - for cld in self._children.values(): - ret.update(cld().collect_params(select=select)) - return ret + return self._collect_params_with_prefix(select=select) - def _collect_params_with_prefix(self, prefix=''): + def _collect_params_with_prefix(self, prefix='', select=None): if prefix: prefix += '.' - ret = {prefix + key : val for key, val in self._reg_params.items()} + if select is None: + ret = {prefix + key : val for key, val in self._reg_params.items()} + else: + pattern = re.compile(select) + ret = {prefix + key : val for key, val in self._reg_params.items() if pattern.match(prefix + key)} + for name, child in self._children.items(): - ret.update(child()._collect_params_with_prefix(prefix + name)) + ret.update(child()._collect_params_with_prefix(prefix + name, select)) return ret def save_parameters(self, filename, deduplicate=False): @@ -473,26 +411,6 @@ def save_parameters(self, filename, deduplicate=False): save_fn = _mx_npx.save if is_np_array() else ndarray.save save_fn(filename, arg_dict) - def save_params(self, filename): - """[Deprecated] Please use save_parameters. Note that if you want load - from SymbolBlock later, please use export instead. - - Save parameters to file. - - filename : str - Path to file. - """ - warnings.warn("save_params is deprecated. Please use save_parameters. " - "Note that if you want load from SymbolBlock later, please " - "use export instead. For details, see " - "https://mxnet.apache.org/tutorials/gluon/save_lo" - "ad_params.html") - try: - self.collect_params().save(filename, strip_prefix=self.prefix) - except ValueError as e: - raise ValueError('%s\nsave_params is deprecated. Using ' \ - 'save_parameters may resolve this error.'%e.message) - def load_parameters(self, filename, ctx=None, allow_missing=False, ignore_extra=False, cast_dtype=False, dtype_source='current'): """Load parameters from file previously saved by `save_parameters`. @@ -542,61 +460,66 @@ def load_parameters(self, filename, ctx=None, allow_missing=False, raise ValueError(err_msg) else: loaded = ndarray.load(filename) - params = self._collect_params_with_prefix() - if not loaded and not params: - return - if not any('.' in i for i in loaded.keys()): - # legacy loading - loaded = None # This should be changed to `del loaded` when dropping Python 2 - self.collect_params().load( - filename, ctx, allow_missing, ignore_extra, self.prefix, - cast_dtype=cast_dtype, dtype_source=dtype_source) + if not loaded: return + full_dict = {'params': loaded, 'filename': filename} + self.load_dict(full_dict, ctx, allow_missing, ignore_extra, cast_dtype, dtype_source) + + def load_dict(self, param_dict, ctx=None, allow_missing=False, + ignore_extra=False, cast_dtype=False, dtype_source="current"): + """Load parameters from dict + + Parameters + ---------- + param_dict : dict + Dictionary containing model parameters + ctx : Context or list of Context + Context(s) initialize loaded parameters on. + allow_missing : bool, default False + Whether to silently skip loading parameters not represented in the file. + ignore_extra : bool, default False + Whether to silently ignore parameters from the file that are not + present in this dict. 
+ cast_dtype : bool, default False + Cast the data type of the NDArray loaded from the checkpoint to the dtype + provided by the Parameter if any + dtype_source : str, default 'current' + must be in {'current', 'saved'} + Only valid if cast_dtype=True, specify the source of the dtype for casting + the parameters + """ + if isinstance(param_dict.get('filename'), str): + # pass from load_parameters + filename = param_dict['filename'] + param_dict = param_dict['params'] + else: + filename = None + params = self.collect_params() + error_str = "file: %s" % (filename) if filename else "param_dict" + loaded = {k[4:] if k.startswith('arg:') or k.startswith('aux:') else k: v \ + for k, v in param_dict.items()} if not allow_missing: - # Shared parameters are stored only a single time as of MXNet 1.6. - # We thus retrieve all prefixes (through _collect_params_with_prefix) - # that a shared parameter is used with. Check that there are no - # missing parameters that were not yet already loaded from the - # shared version. params_inv = defaultdict(list) for k, v in params.items(): params_inv[v].append(k) for name, param in params.items(): assert any(p in loaded for p in params_inv[param]), \ - "Parameter '%s' is missing in file '%s', which contains parameters: %s. " \ + "Parameter '%s' is missing in '%s', which contains parameters: %s. " \ "Set allow_missing=True to ignore missing parameters."%( - name, filename, _brief_print_list(loaded.keys())) + name, error_str, _brief_print_list(loaded.keys())) + for name in loaded: if not ignore_extra and name not in params: raise ValueError( - "Parameter '%s' loaded from file '%s' is not present in ParameterDict, " \ + "Parameter '%s' loaded from '%s' is not present in Dict, " \ "which contains parameters %s. Set ignore_extra=True to ignore. "%( - name, filename, _brief_print_list(self._params.keys()))) + name, error_str, _brief_print_list(params.keys()))) if name in params: params[name]._load_init(loaded[name], ctx, cast_dtype=cast_dtype, dtype_source=dtype_source) - def load_params(self, filename, ctx=None, allow_missing=False, - ignore_extra=False): - """[Deprecated] Please use load_parameters. - - Load parameters from file. - - filename : str - Path to parameter file. - ctx : Context or list of Context, default cpu() - Context(s) to initialize loaded parameters on. - allow_missing : bool, default False - Whether to silently skip loading parameters not represents in the file. - ignore_extra : bool, default False - Whether to silently ignore parameters from the file that are not - present in this Block. - """ - warnings.warn("load_params is deprecated. Please use load_parameters.") - self.load_parameters(filename, ctx, allow_missing, ignore_extra) - def register_child(self, block, name=None): """Registers block as a child of self. :py:class:`Block` s assigned to self as attributes will be registered automatically.""" @@ -662,7 +585,6 @@ def apply(self, fn): def initialize(self, init=initializer.Uniform(), ctx=None, verbose=False, force_reinit=False): """Initializes :py:class:`Parameter` s of this :py:class:`Block` and its children. - Equivalent to ``block.collect_params().initialize(...)`` Parameters ---------- @@ -676,7 +598,11 @@ def initialize(self, init=initializer.Uniform(), ctx=None, verbose=False, force_reinit : bool, default False Whether to force re-initialization if parameter is already initialized. 
""" - self.collect_params().initialize(init, ctx, verbose, force_reinit) + params = self.collect_params() + if verbose: + init.set_verbosity(verbose=verbose) + for k, v in params.items(): + v.initialize(None, ctx, init, force_reinit=force_reinit, structural_name=k) def hybridize(self, active=True, **kwargs): """ Please refer description of HybridBlock hybridize(). @@ -697,6 +623,111 @@ def cast(self, dtype): for _, param in self.params.items(): param.cast(dtype) + def zero_grad(self): + """Sets all Parameters' gradient buffer to 0.""" + # collect gradient arrays for each ctx + arrays = defaultdict(list) + params = self.collect_params() + for p in params.values(): + if p.grad_req == 'null' or p._grad is None: + continue + for g in p.list_grad(): + if g.stype == 'row_sparse': + ndarray.zeros_like(g, out=g) + else: + arrays[g.ctx].append(g) + + if len(arrays) == 0: + return + + if is_np_array(): + for arr in arrays.values(): + for ele in arr: + ele[()] = 0 + else: + for arr in arrays.values(): + ndarray.reset_arrays(*arr, num_arrays=len(arr)) + + def reset_ctx(self, ctx): + """Re-assign all Parameters to other contexts. + + Parameters + ---------- + ctx : Context or list of Context, default :py:meth:`context.current_context()`. + Assign Parameter to given context. If ctx is a list of Context, a + copy will be made for each context. + """ + params = self.collect_params() + for i in params.values(): + i.reset_ctx(ctx) + + def setattr(self, name, value): + """Set an attribute to a new value for all Parameters. + + For example, set grad_req to null if you don't need gradient w.r.t a + model's Parameters:: + + model.setattr('grad_req', 'null') + + or change the learning rate multiplier:: + + model.setattr('lr_mult', 0.5) + + Parameters + ---------- + name : str + Name of the attribute. + value : valid type for attribute name + The new value for the attribute. + """ + params = self.collect_params() + for i in params.values(): + setattr(i, name, value) + + def share_parameters(self, shared): + """Share parameters recursively inside the model. + + For example, if you want ``dense1`` to share ``dense0``'s weights, you can do:: + + dense0 = nn.Dense(20) + dense1 = nn.Dense(20) + dense1.share_parameters(dense0.collect_params()) + + which equals to + dense1.weight = dense0.weight + dense1.bias = dense0.bias + + Parameters + ---------- + shared : Dict + Dict of the shared parameters. + + Returns + ------- + this block + """ + if shared is None: + return self + if not isinstance(shared, (dict, OrderedDict)): + raise ValueError("'shared' should be in type of Dict. Get type {}!".format(type(shared))) + shared_set = set(shared.keys()) + self._shared_parameters(shared, shared_set) + if len(shared_set) > 0: + for name in shared_set: + warnings.warn("Parameter name {} is not in the current model!".format(name)) + return self + + def _shared_parameters(self, shared, shared_set, prefix=""): + if prefix: + prefix += '.' + for name in self._reg_params: + key = prefix + name + if shared.get(key) is not None: + setattr(self, name, shared[key]) + shared_set.remove(key) + for name, child in self._children.items(): + child()._shared_parameters(shared, shared_set, prefix + name) + def __call__(self, *args): """Calls forward. Only accepts positional arguments.""" for hook in self._forward_pre_hooks.values(): @@ -862,10 +893,8 @@ class HybridBlock(Block): class Model(HybridBlock): def __init__(self, **kwargs): super(Model, self).__init__(**kwargs) - # use name_scope to give child Blocks appropriate names. 
- with self.name_scope(): - self.dense0 = nn.Dense(20) - self.dense1 = nn.Dense(20) + self.dense0 = nn.Dense(20) + self.dense1 = nn.Dense(20) def forward(self, x): x = nd.relu(self.dense0(x)) @@ -895,8 +924,8 @@ def forward(self, x): `Hybrid - Faster training and easy deployment `_ """ - def __init__(self, prefix=None, params=None): - super(HybridBlock, self).__init__(prefix=prefix, params=params) + def __init__(self): + super(HybridBlock, self).__init__() self._cached_graph = () self._cached_op = None self._out_format = None @@ -913,6 +942,10 @@ def __setattr__(self, name, value): """Registers parameters.""" super(HybridBlock, self).__setattr__(name, value) if isinstance(value, HybridBlock): + if self._active: + warnings.warn("Currently the model has been hybridized. Automatically deactivate the hybridization \ + when changing the children blocks.") + self._active = False self._clear_cached_op() def _get_graph_v1(self, *args): @@ -940,7 +973,7 @@ def _get_graph_v1(self, *args): flatten_inputs.append(None) grouped_inputs = _regroup(flatten_inputs, self._in_format) params = {i: j.var() for i, j in self._reg_params.items()} - with self.name_scope(): + with self._scope: out = self.hybrid_forward(symbol, *grouped_inputs, **params) # pylint: disable=no-value-for-parameter out, self._out_format = _flatten(out, "output") @@ -986,6 +1019,7 @@ def _build_cache(self, *args): data, out = self._get_graph(*args) data_names = {data.name: i for i, data in enumerate(data)} params = self.collect_params() + params = {p.name: p for p in params.values()} input_names = out.list_inputs() param_names = set(params.keys()) expected_names = set(input_names) @@ -1167,6 +1201,10 @@ def register_child(self, block, name=None): "please try HybridSequential instead."%( str(block), str(type(block)))) super(HybridBlock, self).register_child(block, name) + if self._active: + warnings.warn("Currently the model has been hybridized. Automatically deactivate the hybridization \ + when adding new children block.") + self._active = False self._clear_cached_op() def hybridize(self, active=True, backend=None, backend_opts=None, **kwargs): @@ -1205,6 +1243,10 @@ def hybridize(self, active=True, backend=None, backend_opts=None, **kwargs): super(HybridBlock, self).hybridize(active, **kwargs) def cast(self, dtype): + if self._active: + warnings.warn("Currently the model has been hybridized. Automatically deactivate the hybridization \ + when cast the block to use another data type.") + self._active = False self._clear_cached_op() super(HybridBlock, self).cast(dtype) @@ -1259,7 +1301,8 @@ def export(self, path, epoch=0, remove_amp_cast=True): will be created, where xxxx is the 4 digits epoch number. epoch : int Epoch number of saved model. - + remove_amp_cast : bool, optional + Whether to remove the amp_cast and amp_multicast operators, before saving the model. 
Returns ------- symbol_filename : str @@ -1278,12 +1321,12 @@ def export(self, path, epoch=0, remove_amp_cast=True): arg_names = set(sym.list_arguments()) aux_names = set(sym.list_auxiliary_states()) arg_dict = {} - for name, param in self.collect_params().items(): - if name in arg_names: - arg_dict['arg:%s'%name] = param._reduce() + for param in self.collect_params().values(): + if param.name in arg_names: + arg_dict['arg:%s'%param.name] = param._reduce() else: - assert name in aux_names - arg_dict['aux:%s'%name] = param._reduce() + assert param.name in aux_names + arg_dict['aux:%s'%param.name] = param._reduce() save_fn = _mx_npx.save if is_np_array() else ndarray.save params_filename = '%s-%04d.params'%(path, epoch) save_fn(params_filename, arg_dict) @@ -1378,7 +1421,7 @@ def forward(self, x, *args): return self.hybrid_forward(ndarray, x, *args, **params) params = {i: j.var() for i, j in self._reg_params.items()} - with self.name_scope(): + with self._scope: return self.hybrid_forward(symbol, x, *args, **params) def hybrid_forward(self, F, x, *args, **kwargs): @@ -1394,18 +1437,6 @@ def hybrid_forward(self, F, x, *args, **kwargs): # pylint: disable= invalid-name raise NotImplementedError -def _common_prefix(names): - """Get the common prefix for all names""" - if not names: - return '' - prefix = names[0] - for name in names: - i = 0 - while i < len(prefix) and i < len(name) and prefix[i] == name[i]: - i += 1 - prefix = prefix[:i] - return prefix - class SymbolBlock(HybridBlock): """Construct block from symbol. This is useful for using pre-trained models @@ -1418,22 +1449,21 @@ class SymbolBlock(HybridBlock): The desired output for SymbolBlock. inputs : Symbol or list of Symbol The Variables in output's argument that should be used as inputs. - params : ParameterDict + params : dict Parameter dictionary for arguments and auxililary states of outputs that are not inputs. Examples -------- >>> # To extract the feature from fc1 and fc2 layers of AlexNet: - >>> alexnet = gluon.model_zoo.vision.alexnet(pretrained=True, ctx=mx.cpu(), - prefix='model_') + >>> alexnet = gluon.model_zoo.vision.alexnet(pretrained=True, ctx=mx.cpu()) >>> inputs = mx.sym.var('data') >>> out = alexnet(inputs) >>> internals = out.get_internals() >>> print(internals.list_outputs()) - ['data', ..., 'model_dense0_relu_fwd_output', ..., 'model_dense1_relu_fwd_output', ...] - >>> outputs = [internals['model_dense0_relu_fwd_output'], - internals['model_dense1_relu_fwd_output']] + ['data', ..., 'features_9_act_fwd_output', ..., 'features_11_act_fwd_output', ...] + >>> outputs = [internals['features_9_act_fwd_output'], + internals['features_11_act_fwd_output']] >>> # Create SymbolBlock that shares parameters with alexnet >>> feat_model = gluon.SymbolBlock(outputs, inputs, params=alexnet.collect_params()) >>> x = mx.nd.random.normal(shape=(16, 3, 224, 224)) @@ -1462,8 +1492,7 @@ def imports(symbol_file, input_names, param_file=None, ctx=None): Examples -------- - >>> net1 = gluon.model_zoo.vision.resnet18_v1( - ... 
prefix='resnet', pretrained=True) + >>> net1 = gluon.model_zoo.vision.resnet18_v1(pretrained=True) >>> net1.hybridize() >>> x = mx.nd.random.normal(shape=(1, 3, 32, 32)) >>> out1 = net1(x) @@ -1487,7 +1516,7 @@ def imports(symbol_file, input_names, param_file=None, ctx=None): inputs = [symbol.var(i).as_np_ndarray() if is_np_array() else symbol.var(i) for i in input_names] ret = SymbolBlock(sym, inputs) if param_file is not None: - ret.collect_params().load(param_file, ctx=ctx, cast_dtype=True, dtype_source='saved') + ret.load_parameters(param_file, ctx=ctx, cast_dtype=True, dtype_source='saved') return ret def __repr__(self): @@ -1500,9 +1529,17 @@ def __repr__(self): modstr=modstr) def __init__(self, outputs, inputs, params=None): - super(SymbolBlock, self).__init__(prefix=None, params=None) - self._prefix = '' - self._params = ParameterDict('', params) + super(SymbolBlock, self).__init__() + structure = defaultdict(list) + if params is None: + params = {} + self._structured_named = False + elif any(k.find('.') != -1 for k in params): + self._structured_named = True + for k, v in params.items(): + structure[v.name].append(k) + params = {p.name : p for p in params.values()} + if isinstance(inputs, symbol.Symbol) and len(inputs.list_outputs()) == 1: inputs = [inputs] if isinstance(outputs, (list, tuple)) and len(outputs) == 1: @@ -1536,17 +1573,31 @@ def __init__(self, outputs, inputs, params=None): arg_types, aux_types = _infer_param_types(syms, out, arg_params, aux_params) + def _set_params_attr(name, **kwargs): + if params.get(name) is None: + param = Parameter(**kwargs) + param._name = name + else: + param = params[name] + param._check_and_setattr(**kwargs) + if self._structured_named: + lis = structure[name] + assert len(lis) > 0, "Can not find structured name for Parameter %s in 'params'. " \ + "Please check 'params' is complete!" % name + for structured_name in lis: + self._reg_params[structured_name] = param + else: + self._reg_params[name] = param + for i, arg in enumerate(arg_params): if arg not in input_names: - self.params.get(arg, allow_deferred_init=True, dtype=arg_types[i]) + _set_params_attr(name=arg, allow_deferred_init=True, dtype=arg_types[i]) for i, aux in enumerate(aux_params): if aux not in input_names: - self.params.get(aux, grad_req='null', allow_deferred_init=True, dtype=aux_types[i]) + _set_params_attr(name=aux, grad_req='null', allow_deferred_init=True, dtype=aux_types[i]) self._cached_graph = syms, out - len_prefix = len(_common_prefix(list(self._params.keys()))) - self._reg_params = {key[len_prefix:]: val for key, val in self._params.items()} def forward(self, x, *args): if dc.is_deferred_compute(): diff --git a/python/mxnet/gluon/contrib/cnn/conv_layers.py b/python/mxnet/gluon/contrib/cnn/conv_layers.py index c4924c130a28..ef74c23d52e3 100644 --- a/python/mxnet/gluon/contrib/cnn/conv_layers.py +++ b/python/mxnet/gluon/contrib/cnn/conv_layers.py @@ -23,6 +23,7 @@ from .... 
import symbol from ...block import HybridBlock +from ...parameter import Parameter from ....base import numeric_types from ...nn import Activation @@ -103,80 +104,79 @@ def __init__(self, channels, kernel_size=(1, 1), strides=(1, 1), padding=(0, 0), num_deformable_group=1, layout='NCHW', use_bias=True, in_channels=0, activation=None, weight_initializer=None, bias_initializer='zeros', offset_weight_initializer='zeros', offset_bias_initializer='zeros', offset_use_bias=True, - op_name='DeformableConvolution', adj=None, prefix=None, params=None): - super(DeformableConvolution, self).__init__(prefix=prefix, params=params) - with self.name_scope(): - self._channels = channels - self._in_channels = in_channels - - assert layout in ('NCHW', 'NHWC'), "Only supports 'NCHW' and 'NHWC' layout for now" - if isinstance(kernel_size, numeric_types): - kernel_size = (kernel_size,) * 2 - if isinstance(strides, numeric_types): - strides = (strides,) * len(kernel_size) - if isinstance(padding, numeric_types): - padding = (padding,) * len(kernel_size) - if isinstance(dilation, numeric_types): - dilation = (dilation,) * len(kernel_size) - self._op_name = op_name - - offset_channels = 2 * kernel_size[0] * kernel_size[1] * num_deformable_group - self._kwargs_offset = { - 'kernel': kernel_size, 'stride': strides, 'dilate': dilation, - 'pad': padding, 'num_filter': offset_channels, 'num_group': groups, - 'no_bias': not offset_use_bias, 'layout': layout} - - self._kwargs_deformable_conv = { - 'kernel': kernel_size, 'stride': strides, 'dilate': dilation, - 'pad': padding, 'num_filter': channels, 'num_group': groups, - 'num_deformable_group': num_deformable_group, - 'no_bias': not use_bias, 'layout': layout} - - if adj: - self._kwargs_offset['adj'] = adj - self._kwargs_deformable_conv['adj'] = adj - - dshape = [0] * (len(kernel_size) + 2) - dshape[layout.find('N')] = 1 - dshape[layout.find('C')] = in_channels - - op = getattr(symbol, 'Convolution') - offset = op(symbol.var('data', shape=dshape), **self._kwargs_offset) - - offsetshapes = offset.infer_shape_partial()[0] - - self.offset_weight = self.params.get('offset_weight', shape=offsetshapes[1], - init=offset_weight_initializer, - allow_deferred_init=True) - - if offset_use_bias: - self.offset_bias = self.params.get('offset_bias', shape=offsetshapes[2], - init=offset_bias_initializer, - allow_deferred_init=True) - else: - self.offset_bias = None - - deformable_conv_weight_shape = [0] * (len(kernel_size) + 2) - deformable_conv_weight_shape[0] = channels - deformable_conv_weight_shape[2] = kernel_size[0] - deformable_conv_weight_shape[3] = kernel_size[1] - - self.deformable_conv_weight = self.params.get('deformable_conv_weight', - shape=deformable_conv_weight_shape, - init=weight_initializer, - allow_deferred_init=True) - - if use_bias: - self.deformable_conv_bias = self.params.get('deformable_conv_bias', shape=(channels,), - init=bias_initializer, - allow_deferred_init=True) - else: - self.deformable_conv_bias = None - - if activation: - self.act = Activation(activation, prefix=activation + '_') - else: - self.act = None + op_name='DeformableConvolution', adj=None): + super(DeformableConvolution, self).__init__() + self._channels = channels + self._in_channels = in_channels + + assert layout in ('NCHW', 'NHWC'), "Only supports 'NCHW' and 'NHWC' layout for now" + if isinstance(kernel_size, numeric_types): + kernel_size = (kernel_size,) * 2 + if isinstance(strides, numeric_types): + strides = (strides,) * len(kernel_size) + if isinstance(padding, numeric_types): + 
padding = (padding,) * len(kernel_size) + if isinstance(dilation, numeric_types): + dilation = (dilation,) * len(kernel_size) + self._op_name = op_name + + offset_channels = 2 * kernel_size[0] * kernel_size[1] * num_deformable_group + self._kwargs_offset = { + 'kernel': kernel_size, 'stride': strides, 'dilate': dilation, + 'pad': padding, 'num_filter': offset_channels, 'num_group': groups, + 'no_bias': not offset_use_bias, 'layout': layout} + + self._kwargs_deformable_conv = { + 'kernel': kernel_size, 'stride': strides, 'dilate': dilation, + 'pad': padding, 'num_filter': channels, 'num_group': groups, + 'num_deformable_group': num_deformable_group, + 'no_bias': not use_bias, 'layout': layout} + + if adj: + self._kwargs_offset['adj'] = adj + self._kwargs_deformable_conv['adj'] = adj + + dshape = [0] * (len(kernel_size) + 2) + dshape[layout.find('N')] = 1 + dshape[layout.find('C')] = in_channels + + op = getattr(symbol, 'Convolution') + offset = op(symbol.var('data', shape=dshape), **self._kwargs_offset) + + offsetshapes = offset.infer_shape_partial()[0] + + self.offset_weight = Parameter('offset_weight', shape=offsetshapes[1], + init=offset_weight_initializer, + allow_deferred_init=True) + + if offset_use_bias: + self.offset_bias = Parameter('offset_bias', shape=offsetshapes[2], + init=offset_bias_initializer, + allow_deferred_init=True) + else: + self.offset_bias = None + + deformable_conv_weight_shape = [0] * (len(kernel_size) + 2) + deformable_conv_weight_shape[0] = channels + deformable_conv_weight_shape[2] = kernel_size[0] + deformable_conv_weight_shape[3] = kernel_size[1] + + self.deformable_conv_weight = Parameter('deformable_conv_weight', + shape=deformable_conv_weight_shape, + init=weight_initializer, + allow_deferred_init=True) + + if use_bias: + self.deformable_conv_bias = Parameter('deformable_conv_bias', shape=(channels,), + init=bias_initializer, + allow_deferred_init=True) + else: + self.deformable_conv_bias = None + + if activation: + self.act = Activation(activation) + else: + self.act = None def hybrid_forward(self, F, x, offset_weight, deformable_conv_weight, offset_bias=None, deformable_conv_bias=None): if offset_bias is None: @@ -296,81 +296,80 @@ def __init__(self, channels, kernel_size=(1, 1), strides=(1, 1), padding=(0, 0), num_deformable_group=1, layout='NCHW', use_bias=True, in_channels=0, activation=None, weight_initializer=None, bias_initializer='zeros', offset_weight_initializer='zeros', offset_bias_initializer='zeros', offset_use_bias=True, - op_name='ModulatedDeformableConvolution', adj=None, prefix=None, params=None): - super(ModulatedDeformableConvolution, self).__init__(prefix=prefix, params=params) - with self.name_scope(): - self._channels = channels - self._in_channels = in_channels - - assert layout in ('NCHW', 'NHWC'), "Only supports 'NCHW' and 'NHWC' layout for now" - if isinstance(kernel_size, numeric_types): - kernel_size = (kernel_size,) * 2 - if isinstance(strides, numeric_types): - strides = (strides,) * len(kernel_size) - if isinstance(padding, numeric_types): - padding = (padding,) * len(kernel_size) - if isinstance(dilation, numeric_types): - dilation = (dilation,) * len(kernel_size) - self._op_name = op_name - - offset_channels = num_deformable_group * 3 * kernel_size[0] * kernel_size[1] - self.offset_split_index = num_deformable_group * 2 * kernel_size[0] * kernel_size[1] - self._kwargs_offset = { - 'kernel': kernel_size, 'stride': strides, 'dilate': dilation, - 'pad': padding, 'num_filter': offset_channels, 'num_group': groups, - 'no_bias': 
not offset_use_bias, 'layout': layout} - - self._kwargs_deformable_conv = { - 'kernel': kernel_size, 'stride': strides, 'dilate': dilation, - 'pad': padding, 'num_filter': channels, 'num_group': groups, - 'num_deformable_group': num_deformable_group, - 'no_bias': not use_bias, 'layout': layout} - - if adj: - self._kwargs_offset['adj'] = adj - self._kwargs_deformable_conv['adj'] = adj - - deformable_conv_weight_shape = [0] * (len(kernel_size) + 2) - deformable_conv_weight_shape[0] = channels - deformable_conv_weight_shape[2] = kernel_size[0] - deformable_conv_weight_shape[3] = kernel_size[1] - - self.deformable_conv_weight = self.params.get('deformable_conv_weight', - shape=deformable_conv_weight_shape, - init=weight_initializer, - allow_deferred_init=True) - - if use_bias: - self.deformable_conv_bias = self.params.get('deformable_conv_bias', shape=(channels,), - init=bias_initializer, - allow_deferred_init=True) - else: - self.deformable_conv_bias = None - - dshape = [0] * (len(kernel_size) + 2) - dshape[layout.find('N')] = 1 - dshape[layout.find('C')] = in_channels - - op = getattr(symbol, 'Convolution') - offset = op(symbol.var('data', shape=dshape), **self._kwargs_offset) - - offsetshapes = offset.infer_shape_partial()[0] - - self.offset_weight = self.params.get('offset_weight', shape=offsetshapes[1], - init=offset_weight_initializer, - allow_deferred_init=True) - - if offset_use_bias: - self.offset_bias = self.params.get('offset_bias', shape=offsetshapes[2], - init=offset_bias_initializer, - allow_deferred_init=True) - else: - self.offset_bias = None - - if activation: - self.act = Activation(activation, prefix=activation + '_') - else: - self.act = None + op_name='ModulatedDeformableConvolution', adj=None): + super(ModulatedDeformableConvolution, self).__init__() + self._channels = channels + self._in_channels = in_channels + + assert layout in ('NCHW', 'NHWC'), "Only supports 'NCHW' and 'NHWC' layout for now" + if isinstance(kernel_size, numeric_types): + kernel_size = (kernel_size,) * 2 + if isinstance(strides, numeric_types): + strides = (strides,) * len(kernel_size) + if isinstance(padding, numeric_types): + padding = (padding,) * len(kernel_size) + if isinstance(dilation, numeric_types): + dilation = (dilation,) * len(kernel_size) + self._op_name = op_name + + offset_channels = num_deformable_group * 3 * kernel_size[0] * kernel_size[1] + self.offset_split_index = num_deformable_group * 2 * kernel_size[0] * kernel_size[1] + self._kwargs_offset = { + 'kernel': kernel_size, 'stride': strides, 'dilate': dilation, + 'pad': padding, 'num_filter': offset_channels, 'num_group': groups, + 'no_bias': not offset_use_bias, 'layout': layout} + + self._kwargs_deformable_conv = { + 'kernel': kernel_size, 'stride': strides, 'dilate': dilation, + 'pad': padding, 'num_filter': channels, 'num_group': groups, + 'num_deformable_group': num_deformable_group, + 'no_bias': not use_bias, 'layout': layout} + + if adj: + self._kwargs_offset['adj'] = adj + self._kwargs_deformable_conv['adj'] = adj + + deformable_conv_weight_shape = [0] * (len(kernel_size) + 2) + deformable_conv_weight_shape[0] = channels + deformable_conv_weight_shape[2] = kernel_size[0] + deformable_conv_weight_shape[3] = kernel_size[1] + + self.deformable_conv_weight = Parameter('deformable_conv_weight', + shape=deformable_conv_weight_shape, + init=weight_initializer, + allow_deferred_init=True) + + if use_bias: + self.deformable_conv_bias = Parameter('deformable_conv_bias', shape=(channels,), + init=bias_initializer, + 
allow_deferred_init=True) + else: + self.deformable_conv_bias = None + + dshape = [0] * (len(kernel_size) + 2) + dshape[layout.find('N')] = 1 + dshape[layout.find('C')] = in_channels + + op = getattr(symbol, 'Convolution') + offset = op(symbol.var('data', shape=dshape), **self._kwargs_offset) + + offsetshapes = offset.infer_shape_partial()[0] + + self.offset_weight = Parameter('offset_weight', shape=offsetshapes[1], + init=offset_weight_initializer, + allow_deferred_init=True) + + if offset_use_bias: + self.offset_bias = Parameter('offset_bias', shape=offsetshapes[2], + init=offset_bias_initializer, + allow_deferred_init=True) + else: + self.offset_bias = None + + if activation: + self.act = Activation(activation) + else: + self.act = None def hybrid_forward(self, F, x, offset_weight, deformable_conv_weight, offset_bias=None, deformable_conv_bias=None): if offset_bias is None: diff --git a/python/mxnet/gluon/contrib/data/vision/dataloader.py b/python/mxnet/gluon/contrib/data/vision/dataloader.py index 3213398b2214..f5cbf6197790 100644 --- a/python/mxnet/gluon/contrib/data/vision/dataloader.py +++ b/python/mxnet/gluon/contrib/data/vision/dataloader.py @@ -92,7 +92,7 @@ def create_image_augment(data_shape, resize=0, rand_crop=False, rand_resize=Fals """ if inter_method == 10: inter_method = np.random.randint(0, 5) - augmenter = HybridSequential('default_img_augment_') + augmenter = HybridSequential() if resize > 0: augmenter.add(transforms.image.Resize(resize, interpolation=inter_method)) crop_size = (data_shape[2], data_shape[1]) @@ -220,9 +220,9 @@ def __init__(self, batch_size, data_shape, path_imgrec=None, path_imglist=None, augmenter = create_image_augment(data_shape, **kwargs) elif isinstance(aug_list, list): if all([isinstance(a, HybridBlock) for a in aug_list]): - augmenter = HybridSequential('user_img_augment_') + augmenter = HybridSequential() else: - augmenter = Sequential('user_img_augment_') + augmenter = Sequential() for aug in aug_list: augmenter.add(aug) elif isinstance(aug_list, Block): @@ -316,7 +316,7 @@ def create_bbox_augment(data_shape, rand_crop=0, rand_pad=0, rand_gray=0, """ if inter_method == 10: inter_method = np.random.randint(0, 5) - augmenter = Sequential('default_bbox_aug_') + augmenter = Sequential() if rand_crop > 0: augmenter.add(bbox.ImageBboxRandomCropWithConstraints( p=rand_crop, min_scale=area_range[0], max_scale=1.0, @@ -439,9 +439,9 @@ def __init__(self, batch_size, data_shape, path_imgrec=None, path_imglist=None, augmenter = create_bbox_augment(data_shape, **kwargs) elif isinstance(aug_list, list): if all([isinstance(a, HybridBlock) for a in aug_list]): - augmenter = HybridSequential('user_bbox_augment_') + augmenter = HybridSequential() else: - augmenter = Sequential('user_bbox_augment_') + augmenter = Sequential() for aug in aug_list: augmenter.add(aug) elif isinstance(aug_list, Block): @@ -449,7 +449,7 @@ def __init__(self, batch_size, data_shape, path_imgrec=None, path_imglist=None, else: raise ValueError('aug_list must be a list of Blocks') augmenter.hybridize() - wrapper_aug = Sequential('wrapper_bbox_aug_') + wrapper_aug = Sequential() wrapper_aug.add(BboxLabelTransform(coord_normalized)) wrapper_aug.add(augmenter) diff --git a/python/mxnet/gluon/contrib/estimator/estimator.py b/python/mxnet/gluon/contrib/estimator/estimator.py index c47e02b7213f..8bdecccc844c 100644 --- a/python/mxnet/gluon/contrib/estimator/estimator.py +++ b/python/mxnet/gluon/contrib/estimator/estimator.py @@ -71,7 +71,8 @@ class Estimator(object): the training net is given 
below: >>> net = _get_train_network() - >>> val_net = _get_test_network(params=net.collect_params()) + >>> val_net = _get_test_network() + >>> val_net.share_parameters(net.collect_params()) >>> net.initialize(ctx=ctx) >>> est = Estimator(net, loss, val_net=val_net) diff --git a/python/mxnet/gluon/contrib/nn/basic_layers.py b/python/mxnet/gluon/contrib/nn/basic_layers.py index a0d3c38a9e04..6bb2147f0d7c 100644 --- a/python/mxnet/gluon/contrib/nn/basic_layers.py +++ b/python/mxnet/gluon/contrib/nn/basic_layers.py @@ -26,6 +26,7 @@ import warnings from .... import ndarray as nd, context from ...block import HybridBlock, Block +from ...parameter import Parameter from ...nn import Sequential, HybridSequential, BatchNorm class Concurrent(Sequential): @@ -38,19 +39,17 @@ class Concurrent(Sequential): Example:: net = Concurrent() - # use net's name_scope to give children blocks appropriate names. - with net.name_scope(): - net.add(nn.Dense(10, activation='relu')) - net.add(nn.Dense(20)) - net.add(Identity()) + net.add(nn.Dense(10, activation='relu')) + net.add(nn.Dense(20)) + net.add(Identity()) Parameters ---------- axis : int, default -1 The axis on which to concatenate the outputs. """ - def __init__(self, axis=-1, prefix=None, params=None): - super(Concurrent, self).__init__(prefix=prefix, params=params) + def __init__(self, axis=-1): + super(Concurrent, self).__init__() self.axis = axis def forward(self, x): @@ -71,19 +70,17 @@ class HybridConcurrent(HybridSequential): Example:: net = HybridConcurrent() - # use net's name_scope to give children blocks appropriate names. - with net.name_scope(): - net.add(nn.Dense(10, activation='relu')) - net.add(nn.Dense(20)) - net.add(Identity()) + net.add(nn.Dense(10, activation='relu')) + net.add(nn.Dense(20)) + net.add(Identity()) Parameters ---------- axis : int, default -1 The axis on which to concatenate the outputs. """ - def __init__(self, axis=-1, prefix=None, params=None): - super(HybridConcurrent, self).__init__(prefix=prefix, params=params) + def __init__(self, axis=-1): + super(HybridConcurrent, self).__init__() self.axis = axis def hybrid_forward(self, F, x): @@ -103,14 +100,12 @@ class Identity(HybridBlock): Example:: net = HybridConcurrent() - # use net's name_scope to give child Blocks appropriate names. 
- with net.name_scope(): - net.add(nn.Dense(10, activation='relu')) - net.add(nn.Dense(20)) - net.add(Identity()) + net.add(nn.Dense(10, activation='relu')) + net.add(nn.Dense(20)) + net.add(Identity()) """ - def __init__(self, prefix=None, params=None): - super(Identity, self).__init__(prefix=prefix, params=params) + def __init__(self): + super(Identity, self).__init__() def hybrid_forward(self, F, x): return x @@ -149,9 +144,9 @@ def __init__(self, input_dim, output_dim, dtype='float32', super(SparseEmbedding, self).__init__(**kwargs) self._kwargs = {'input_dim': input_dim, 'output_dim': output_dim, 'dtype': dtype, 'sparse_grad': True} - self.weight = self.params.get('weight', shape=(input_dim, output_dim), - init=weight_initializer, dtype=dtype, - grad_stype='row_sparse', stype='row_sparse') + self.weight = Parameter('weight', shape=(input_dim, output_dim), + init=weight_initializer, dtype=dtype, + grad_stype='row_sparse', stype='row_sparse') def forward(self, x): weight = self.weight.row_sparse_data(x) @@ -232,7 +227,7 @@ def __init__(self, in_channels=0, num_devices=None, momentum=0.9, epsilon=1e-5, num_devices = self._get_num_devices() if num_devices is None else num_devices self._kwargs = {'eps': epsilon, 'momentum': momentum, 'fix_gamma': not scale, 'use_global_stats': use_global_stats, - 'ndev': num_devices, 'key': self.prefix} + 'ndev': num_devices, 'key': self.name} def _get_num_devices(self): warnings.warn("Caution using SyncBatchNorm: " diff --git a/python/mxnet/gluon/contrib/rnn/conv_rnn_cell.py b/python/mxnet/gluon/contrib/rnn/conv_rnn_cell.py index ab40c3fc95d7..745a0c518fcd 100644 --- a/python/mxnet/gluon/contrib/rnn/conv_rnn_cell.py +++ b/python/mxnet/gluon/contrib/rnn/conv_rnn_cell.py @@ -27,6 +27,7 @@ from ....base import numeric_types from ...rnn import HybridRecurrentCell +from ...parameter import Parameter def _get_conv_out_size(dimensions, kernels, paddings, dilations): @@ -42,9 +43,8 @@ def __init__(self, input_shape, hidden_channels, i2h_weight_initializer, h2h_weight_initializer, i2h_bias_initializer, h2h_bias_initializer, dims, - conv_layout, activation, - prefix=None, params=None): - super(_BaseConvRNNCell, self).__init__(prefix=prefix, params=params) + conv_layout, activation): + super(_BaseConvRNNCell, self).__init__() self._hidden_channels = hidden_channels self._input_shape = input_shape @@ -79,18 +79,18 @@ def __init__(self, input_shape, hidden_channels, self._h2h_pad, \ self._state_shape = self._decide_shapes() - self.i2h_weight = self.params.get('i2h_weight', shape=i2h_param_shape, - init=i2h_weight_initializer, - allow_deferred_init=True) - self.h2h_weight = self.params.get('h2h_weight', shape=h2h_param_shape, - init=h2h_weight_initializer, - allow_deferred_init=True) - self.i2h_bias = self.params.get('i2h_bias', shape=(hidden_channels*self._num_gates,), - init=i2h_bias_initializer, - allow_deferred_init=True) - self.h2h_bias = self.params.get('h2h_bias', shape=(hidden_channels*self._num_gates,), - init=h2h_bias_initializer, - allow_deferred_init=True) + self.i2h_weight = Parameter('i2h_weight', shape=i2h_param_shape, + init=i2h_weight_initializer, + allow_deferred_init=True) + self.h2h_weight = Parameter('h2h_weight', shape=h2h_param_shape, + init=h2h_weight_initializer, + allow_deferred_init=True) + self.i2h_bias = Parameter('i2h_bias', shape=(hidden_channels*self._num_gates,), + init=i2h_bias_initializer, + allow_deferred_init=True) + self.h2h_bias = Parameter('h2h_bias', shape=(hidden_channels*self._num_gates,), + init=h2h_bias_initializer, + 
allow_deferred_init=True) def _decide_shapes(self): channel_axis = self._conv_layout.find('C') @@ -179,7 +179,7 @@ def __init__(self, input_shape, hidden_channels, i2h_kernel, h2h_kernel, i2h_pad, i2h_dilate, h2h_dilate, i2h_weight_initializer, h2h_weight_initializer, i2h_bias_initializer, h2h_bias_initializer, - dims, conv_layout, activation, prefix, params): + dims, conv_layout, activation): super(_ConvRNNCell, self).__init__(input_shape=input_shape, hidden_channels=hidden_channels, activation=activation, @@ -191,8 +191,7 @@ def __init__(self, input_shape, hidden_channels, i2h_bias_initializer=i2h_bias_initializer, h2h_bias_initializer=h2h_bias_initializer, dims=dims, - conv_layout=conv_layout, - prefix=prefix, params=params) + conv_layout=conv_layout) def state_info(self, batch_size=0): return [{'shape': (batch_size,)+self._state_shape, '__layout__': self._conv_layout}] @@ -255,18 +254,13 @@ class Conv1DRNNCell(_ConvRNNCell): If argument type is string, it's equivalent to nn.Activation(act_type=str). See :func:`~mxnet.ndarray.Activation` for available choices. Alternatively, other activation blocks such as nn.LeakyReLU can be used. - prefix : str, default ``'conv_rnn_``' - Prefix for name of layers (and name of weight if params is None). - params : RNNParams, default None - Container for weight sharing between cells. Created if None. """ def __init__(self, input_shape, hidden_channels, i2h_kernel, h2h_kernel, i2h_pad=(0,), i2h_dilate=(1,), h2h_dilate=(1,), i2h_weight_initializer=None, h2h_weight_initializer=None, i2h_bias_initializer='zeros', h2h_bias_initializer='zeros', - conv_layout='NCW', activation='tanh', - prefix=None, params=None): + conv_layout='NCW', activation='tanh'): super(Conv1DRNNCell, self).__init__(input_shape=input_shape, hidden_channels=hidden_channels, i2h_kernel=i2h_kernel, h2h_kernel=h2h_kernel, @@ -278,8 +272,7 @@ def __init__(self, input_shape, hidden_channels, h2h_bias_initializer=h2h_bias_initializer, dims=1, conv_layout=conv_layout, - activation=activation, - prefix=prefix, params=params) + activation=activation) class Conv2DRNNCell(_ConvRNNCell): @@ -322,18 +315,13 @@ class Conv2DRNNCell(_ConvRNNCell): If argument type is string, it's equivalent to nn.Activation(act_type=str). See :func:`~mxnet.ndarray.Activation` for available choices. Alternatively, other activation blocks such as nn.LeakyReLU can be used. - prefix : str, default ``'conv_rnn_``' - Prefix for name of layers (and name of weight if params is None). - params : RNNParams, default None - Container for weight sharing between cells. Created if None. """ def __init__(self, input_shape, hidden_channels, i2h_kernel, h2h_kernel, i2h_pad=(0, 0), i2h_dilate=(1, 1), h2h_dilate=(1, 1), i2h_weight_initializer=None, h2h_weight_initializer=None, i2h_bias_initializer='zeros', h2h_bias_initializer='zeros', - conv_layout='NCHW', activation='tanh', - prefix=None, params=None): + conv_layout='NCHW', activation='tanh'): super(Conv2DRNNCell, self).__init__(input_shape=input_shape, hidden_channels=hidden_channels, i2h_kernel=i2h_kernel, h2h_kernel=h2h_kernel, @@ -345,8 +333,7 @@ def __init__(self, input_shape, hidden_channels, h2h_bias_initializer=h2h_bias_initializer, dims=2, conv_layout=conv_layout, - activation=activation, - prefix=prefix, params=params) + activation=activation) class Conv3DRNNCell(_ConvRNNCell): @@ -389,10 +376,6 @@ class Conv3DRNNCell(_ConvRNNCell): If argument type is string, it's equivalent to nn.Activation(act_type=str). See :func:`~mxnet.ndarray.Activation` for available choices. 
Alternatively, other activation blocks such as nn.LeakyReLU can be used. - prefix : str, default ``'conv_rnn_``' - Prefix for name of layers (and name of weight if params is None). - params : RNNParams, default None - Container for weight sharing between cells. Created if None. """ def __init__(self, input_shape, hidden_channels, i2h_kernel, h2h_kernel, @@ -400,8 +383,7 @@ def __init__(self, input_shape, hidden_channels, i2h_dilate=(1, 1, 1), h2h_dilate=(1, 1, 1), i2h_weight_initializer=None, h2h_weight_initializer=None, i2h_bias_initializer='zeros', h2h_bias_initializer='zeros', - conv_layout='NCDHW', activation='tanh', - prefix=None, params=None): + conv_layout='NCDHW', activation='tanh'): super(Conv3DRNNCell, self).__init__(input_shape=input_shape, hidden_channels=hidden_channels, i2h_kernel=i2h_kernel, h2h_kernel=h2h_kernel, @@ -413,8 +395,7 @@ def __init__(self, input_shape, hidden_channels, h2h_bias_initializer=h2h_bias_initializer, dims=3, conv_layout=conv_layout, - activation=activation, - prefix=prefix, params=params) + activation=activation) class _ConvLSTMCell(_BaseConvRNNCell): @@ -423,7 +404,7 @@ def __init__(self, input_shape, hidden_channels, i2h_pad, i2h_dilate, h2h_dilate, i2h_weight_initializer, h2h_weight_initializer, i2h_bias_initializer, h2h_bias_initializer, - dims, conv_layout, activation, prefix, params): + dims, conv_layout, activation): super(_ConvLSTMCell, self).__init__(input_shape=input_shape, hidden_channels=hidden_channels, i2h_kernel=i2h_kernel, h2h_kernel=h2h_kernel, @@ -435,8 +416,7 @@ def __init__(self, input_shape, hidden_channels, h2h_bias_initializer=h2h_bias_initializer, dims=dims, conv_layout=conv_layout, - activation=activation, - prefix=prefix, params=params) + activation=activation) def state_info(self, batch_size=0): return [{'shape': (batch_size,)+self._state_shape, '__layout__': self._conv_layout}, @@ -519,10 +499,6 @@ class Conv1DLSTMCell(_ConvLSTMCell): If argument type is string, it's equivalent to nn.Activation(act_type=str). See :func:`~mxnet.ndarray.Activation` for available choices. Alternatively, other activation blocks such as nn.LeakyReLU can be used. - prefix : str, default ``'conv_lstm_``' - Prefix for name of layers (and name of weight if params is None). - params : RNNParams, default None - Container for weight sharing between cells. Created if None. """ def __init__(self, input_shape, hidden_channels, i2h_kernel, h2h_kernel, @@ -530,8 +506,7 @@ def __init__(self, input_shape, hidden_channels, i2h_dilate=(1,), h2h_dilate=(1,), i2h_weight_initializer=None, h2h_weight_initializer=None, i2h_bias_initializer='zeros', h2h_bias_initializer='zeros', - conv_layout='NCW', activation='tanh', - prefix=None, params=None): + conv_layout='NCW', activation='tanh'): super(Conv1DLSTMCell, self).__init__(input_shape=input_shape, hidden_channels=hidden_channels, i2h_kernel=i2h_kernel, h2h_kernel=h2h_kernel, @@ -543,8 +518,7 @@ def __init__(self, input_shape, hidden_channels, h2h_bias_initializer=h2h_bias_initializer, dims=1, conv_layout=conv_layout, - activation=activation, - prefix=prefix, params=params) + activation=activation) class Conv2DLSTMCell(_ConvLSTMCell): @@ -596,10 +570,6 @@ class Conv2DLSTMCell(_ConvLSTMCell): If argument type is string, it's equivalent to nn.Activation(act_type=str). See :func:`~mxnet.ndarray.Activation` for available choices. Alternatively, other activation blocks such as nn.LeakyReLU can be used. - prefix : str, default ``'conv_lstm_``' - Prefix for name of layers (and name of weight if params is None). 
- params : RNNParams, default None - Container for weight sharing between cells. Created if None. """ def __init__(self, input_shape, hidden_channels, i2h_kernel, h2h_kernel, @@ -607,8 +577,7 @@ def __init__(self, input_shape, hidden_channels, i2h_dilate=(1, 1), h2h_dilate=(1, 1), i2h_weight_initializer=None, h2h_weight_initializer=None, i2h_bias_initializer='zeros', h2h_bias_initializer='zeros', - conv_layout='NCHW', activation='tanh', - prefix=None, params=None): + conv_layout='NCHW', activation='tanh'): super(Conv2DLSTMCell, self).__init__(input_shape=input_shape, hidden_channels=hidden_channels, i2h_kernel=i2h_kernel, h2h_kernel=h2h_kernel, @@ -620,8 +589,7 @@ def __init__(self, input_shape, hidden_channels, h2h_bias_initializer=h2h_bias_initializer, dims=2, conv_layout=conv_layout, - activation=activation, - prefix=prefix, params=params) + activation=activation) class Conv3DLSTMCell(_ConvLSTMCell): @@ -673,10 +641,6 @@ class Conv3DLSTMCell(_ConvLSTMCell): If argument type is string, it's equivalent to nn.Activation(act_type=str). See :func:`~mxnet.ndarray.Activation` for available choices. Alternatively, other activation blocks such as nn.LeakyReLU can be used. - prefix : str, default ``'conv_lstm_``' - Prefix for name of layers (and name of weight if params is None). - params : RNNParams, default None - Container for weight sharing between cells. Created if None. """ def __init__(self, input_shape, hidden_channels, i2h_kernel, h2h_kernel, @@ -684,8 +648,7 @@ def __init__(self, input_shape, hidden_channels, i2h_dilate=(1, 1, 1), h2h_dilate=(1, 1, 1), i2h_weight_initializer=None, h2h_weight_initializer=None, i2h_bias_initializer='zeros', h2h_bias_initializer='zeros', - conv_layout='NCDHW', activation='tanh', - prefix=None, params=None): + conv_layout='NCDHW', activation='tanh'): super(Conv3DLSTMCell, self).__init__(input_shape=input_shape, hidden_channels=hidden_channels, i2h_kernel=i2h_kernel, h2h_kernel=h2h_kernel, @@ -697,8 +660,7 @@ def __init__(self, input_shape, hidden_channels, h2h_bias_initializer=h2h_bias_initializer, dims=3, conv_layout=conv_layout, - activation=activation, - prefix=prefix, params=params) + activation=activation) class _ConvGRUCell(_BaseConvRNNCell): @@ -706,7 +668,7 @@ def __init__(self, input_shape, hidden_channels, i2h_kernel, h2h_kernel, i2h_pad, i2h_dilate, h2h_dilate, i2h_weight_initializer, h2h_weight_initializer, i2h_bias_initializer, h2h_bias_initializer, - dims, conv_layout, activation, prefix, params): + dims, conv_layout, activation): super(_ConvGRUCell, self).__init__(input_shape=input_shape, hidden_channels=hidden_channels, i2h_kernel=i2h_kernel, h2h_kernel=h2h_kernel, @@ -718,8 +680,7 @@ def __init__(self, input_shape, hidden_channels, h2h_bias_initializer=h2h_bias_initializer, dims=dims, conv_layout=conv_layout, - activation=activation, - prefix=prefix, params=params) + activation=activation) def state_info(self, batch_size=0): return [{'shape': (batch_size,)+self._state_shape, '__layout__': self._conv_layout}] @@ -803,10 +764,6 @@ class Conv1DGRUCell(_ConvGRUCell): If argument type is string, it's equivalent to nn.Activation(act_type=str). See :func:`~mxnet.ndarray.Activation` for available choices. Alternatively, other activation blocks such as nn.LeakyReLU can be used. - prefix : str, default ``'conv_gru_``' - Prefix for name of layers (and name of weight if params is None). - params : RNNParams, default None - Container for weight sharing between cells. Created if None. 
""" def __init__(self, input_shape, hidden_channels, i2h_kernel, h2h_kernel, @@ -814,8 +771,7 @@ def __init__(self, input_shape, hidden_channels, i2h_dilate=(1,), h2h_dilate=(1,), i2h_weight_initializer=None, h2h_weight_initializer=None, i2h_bias_initializer='zeros', h2h_bias_initializer='zeros', - conv_layout='NCW', activation='tanh', - prefix=None, params=None): + conv_layout='NCW', activation='tanh'): super(Conv1DGRUCell, self).__init__(input_shape=input_shape, hidden_channels=hidden_channels, i2h_kernel=i2h_kernel, h2h_kernel=h2h_kernel, @@ -827,8 +783,7 @@ def __init__(self, input_shape, hidden_channels, h2h_bias_initializer=h2h_bias_initializer, dims=1, conv_layout=conv_layout, - activation=activation, - prefix=prefix, params=params) + activation=activation) class Conv2DGRUCell(_ConvGRUCell): @@ -875,10 +830,6 @@ class Conv2DGRUCell(_ConvGRUCell): If argument type is string, it's equivalent to nn.Activation(act_type=str). See :func:`~mxnet.ndarray.Activation` for available choices. Alternatively, other activation blocks such as nn.LeakyReLU can be used. - prefix : str, default ``'conv_gru_``' - Prefix for name of layers (and name of weight if params is None). - params : RNNParams, default None - Container for weight sharing between cells. Created if None. """ def __init__(self, input_shape, hidden_channels, i2h_kernel, h2h_kernel, @@ -886,8 +837,7 @@ def __init__(self, input_shape, hidden_channels, i2h_dilate=(1, 1), h2h_dilate=(1, 1), i2h_weight_initializer=None, h2h_weight_initializer=None, i2h_bias_initializer='zeros', h2h_bias_initializer='zeros', - conv_layout='NCHW', activation='tanh', - prefix=None, params=None): + conv_layout='NCHW', activation='tanh'): super(Conv2DGRUCell, self).__init__(input_shape=input_shape, hidden_channels=hidden_channels, i2h_kernel=i2h_kernel, h2h_kernel=h2h_kernel, @@ -899,8 +849,7 @@ def __init__(self, input_shape, hidden_channels, h2h_bias_initializer=h2h_bias_initializer, dims=2, conv_layout=conv_layout, - activation=activation, - prefix=prefix, params=params) + activation=activation) class Conv3DGRUCell(_ConvGRUCell): @@ -947,10 +896,6 @@ class Conv3DGRUCell(_ConvGRUCell): If argument type is string, it's equivalent to nn.Activation(act_type=str). See :func:`~mxnet.ndarray.Activation` for available choices. Alternatively, other activation blocks such as nn.LeakyReLU can be used. - prefix : str, default ``'conv_gru_``' - Prefix for name of layers (and name of weight if params is None). - params : RNNParams, default None - Container for weight sharing between cells. Created if None. 
""" def __init__(self, input_shape, hidden_channels, i2h_kernel, h2h_kernel, @@ -958,8 +903,7 @@ def __init__(self, input_shape, hidden_channels, i2h_dilate=(1, 1, 1), h2h_dilate=(1, 1, 1), i2h_weight_initializer=None, h2h_weight_initializer=None, i2h_bias_initializer='zeros', h2h_bias_initializer='zeros', - conv_layout='NCDHW', activation='tanh', - prefix=None, params=None): + conv_layout='NCDHW', activation='tanh'): super(Conv3DGRUCell, self).__init__(input_shape=input_shape, hidden_channels=hidden_channels, i2h_kernel=i2h_kernel, h2h_kernel=h2h_kernel, @@ -971,5 +915,4 @@ def __init__(self, input_shape, hidden_channels, h2h_bias_initializer=h2h_bias_initializer, dims=3, conv_layout=conv_layout, - activation=activation, - prefix=prefix, params=params) + activation=activation) diff --git a/python/mxnet/gluon/contrib/rnn/rnn_cell.py b/python/mxnet/gluon/contrib/rnn/rnn_cell.py index a161ec75b76e..155b6bd54207 100644 --- a/python/mxnet/gluon/contrib/rnn/rnn_cell.py +++ b/python/mxnet/gluon/contrib/rnn/rnn_cell.py @@ -22,6 +22,7 @@ from ...rnn import BidirectionalCell, SequentialRNNCell, ModifierCell, HybridRecurrentCell from ...rnn.rnn_cell import _format_sequence, _get_begin_state, _mask_sequence_variable_length from ... import tensor_types +from ...parameter import Parameter from ....base import _as_list class VariationalDropoutCell(ModifierCell): @@ -239,12 +240,6 @@ class LSTMPCell(HybridRecurrentCell): to zero. h2h_bias_initializer : str or Initializer Initializer for the bias vector. - prefix : str, default ``'lstmp_``' - Prefix for name of `Block`s - (and name of weight if params is `None`). - params : Parameter or None - Container for weight sharing between cells. - Created if `None`. Inputs: - **data**: input tensor with shape `(batch_size, input_size)`. 
- **states**: a list of two initial recurrent state tensors, with shape @@ -258,27 +253,27 @@ def __init__(self, hidden_size, projection_size, i2h_weight_initializer=None, h2h_weight_initializer=None, h2r_weight_initializer=None, i2h_bias_initializer='zeros', h2h_bias_initializer='zeros', - input_size=0, prefix=None, params=None): - super(LSTMPCell, self).__init__(prefix=prefix, params=params) + input_size=0): + super(LSTMPCell, self).__init__() self._hidden_size = hidden_size self._input_size = input_size self._projection_size = projection_size - self.i2h_weight = self.params.get('i2h_weight', shape=(4*hidden_size, input_size), - init=i2h_weight_initializer, - allow_deferred_init=True) - self.h2h_weight = self.params.get('h2h_weight', shape=(4*hidden_size, projection_size), - init=h2h_weight_initializer, - allow_deferred_init=True) - self.h2r_weight = self.params.get('h2r_weight', shape=(projection_size, hidden_size), - init=h2r_weight_initializer, - allow_deferred_init=True) - self.i2h_bias = self.params.get('i2h_bias', shape=(4*hidden_size,), - init=i2h_bias_initializer, - allow_deferred_init=True) - self.h2h_bias = self.params.get('h2h_bias', shape=(4*hidden_size,), - init=h2h_bias_initializer, - allow_deferred_init=True) + self.i2h_weight = Parameter('i2h_weight', shape=(4*hidden_size, input_size), + init=i2h_weight_initializer, + allow_deferred_init=True) + self.h2h_weight = Parameter('h2h_weight', shape=(4*hidden_size, projection_size), + init=h2h_weight_initializer, + allow_deferred_init=True) + self.h2r_weight = Parameter('h2r_weight', shape=(projection_size, hidden_size), + init=h2r_weight_initializer, + allow_deferred_init=True) + self.i2h_bias = Parameter('i2h_bias', shape=(4*hidden_size,), + init=i2h_bias_initializer, + allow_deferred_init=True) + self.h2h_bias = Parameter('h2h_bias', shape=(4*hidden_size,), + init=h2h_bias_initializer, + allow_deferred_init=True) def state_info(self, batch_size=0): return [{'shape': (batch_size, self._projection_size), '__layout__': 'NC'}, @@ -369,7 +364,7 @@ def dynamic_unroll(cell, inputs, begin_state, drop_inputs=0, drop_outputs=0, >>> seq_len = 3 >>> batch_size = 2 >>> input_size = 5 - >>> cell = mx.gluon.rnn.LSTMCell(input_size, prefix='rnn_') + >>> cell = mx.gluon.rnn.LSTMCell(input_size) >>> cell.initialize(ctx=mx.cpu()) >>> rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len, batch_size, input_size)) >>> state_shape = (batch_size, input_size) diff --git a/python/mxnet/gluon/model_zoo/vision/alexnet.py b/python/mxnet/gluon/model_zoo/vision/alexnet.py index daf4617cd12e..7bdacc915fb3 100644 --- a/python/mxnet/gluon/model_zoo/vision/alexnet.py +++ b/python/mxnet/gluon/model_zoo/vision/alexnet.py @@ -38,29 +38,27 @@ class AlexNet(HybridBlock): """ def __init__(self, classes=1000, **kwargs): super(AlexNet, self).__init__(**kwargs) - with self.name_scope(): - self.features = nn.HybridSequential(prefix='') - with self.features.name_scope(): - self.features.add(nn.Conv2D(64, kernel_size=11, strides=4, - padding=2, activation='relu')) - self.features.add(nn.MaxPool2D(pool_size=3, strides=2)) - self.features.add(nn.Conv2D(192, kernel_size=5, padding=2, - activation='relu')) - self.features.add(nn.MaxPool2D(pool_size=3, strides=2)) - self.features.add(nn.Conv2D(384, kernel_size=3, padding=1, - activation='relu')) - self.features.add(nn.Conv2D(256, kernel_size=3, padding=1, - activation='relu')) - self.features.add(nn.Conv2D(256, kernel_size=3, padding=1, - activation='relu')) - self.features.add(nn.MaxPool2D(pool_size=3, strides=2)) - 
self.features.add(nn.Flatten()) - self.features.add(nn.Dense(4096, activation='relu')) - self.features.add(nn.Dropout(0.5)) - self.features.add(nn.Dense(4096, activation='relu')) - self.features.add(nn.Dropout(0.5)) + self.features = nn.HybridSequential() + self.features.add(nn.Conv2D(64, kernel_size=11, strides=4, + padding=2, activation='relu')) + self.features.add(nn.MaxPool2D(pool_size=3, strides=2)) + self.features.add(nn.Conv2D(192, kernel_size=5, padding=2, + activation='relu')) + self.features.add(nn.MaxPool2D(pool_size=3, strides=2)) + self.features.add(nn.Conv2D(384, kernel_size=3, padding=1, + activation='relu')) + self.features.add(nn.Conv2D(256, kernel_size=3, padding=1, + activation='relu')) + self.features.add(nn.Conv2D(256, kernel_size=3, padding=1, + activation='relu')) + self.features.add(nn.MaxPool2D(pool_size=3, strides=2)) + self.features.add(nn.Flatten()) + self.features.add(nn.Dense(4096, activation='relu')) + self.features.add(nn.Dropout(0.5)) + self.features.add(nn.Dense(4096, activation='relu')) + self.features.add(nn.Dropout(0.5)) - self.output = nn.Dense(classes) + self.output = nn.Dense(classes) def hybrid_forward(self, F, x): x = self.features(x) diff --git a/python/mxnet/gluon/model_zoo/vision/densenet.py b/python/mxnet/gluon/model_zoo/vision/densenet.py index 83febd3658c4..51779b282353 100644 --- a/python/mxnet/gluon/model_zoo/vision/densenet.py +++ b/python/mxnet/gluon/model_zoo/vision/densenet.py @@ -29,15 +29,14 @@ from .... import base # Helpers -def _make_dense_block(num_layers, bn_size, growth_rate, dropout, stage_index): - out = nn.HybridSequential(prefix='stage%d_'%stage_index) - with out.name_scope(): - for _ in range(num_layers): - out.add(_make_dense_layer(growth_rate, bn_size, dropout)) +def _make_dense_block(num_layers, bn_size, growth_rate, dropout): + out = nn.HybridSequential() + for _ in range(num_layers): + out.add(_make_dense_layer(growth_rate, bn_size, dropout)) return out def _make_dense_layer(growth_rate, bn_size, dropout): - new_features = nn.HybridSequential(prefix='') + new_features = nn.HybridSequential() new_features.add(nn.BatchNorm()) new_features.add(nn.Activation('relu')) new_features.add(nn.Conv2D(bn_size * growth_rate, kernel_size=1, use_bias=False)) @@ -47,14 +46,14 @@ def _make_dense_layer(growth_rate, bn_size, dropout): if dropout: new_features.add(nn.Dropout(dropout)) - out = HybridConcurrent(axis=1, prefix='') + out = HybridConcurrent(axis=1) out.add(Identity()) out.add(new_features) return out def _make_transition(num_output_features): - out = nn.HybridSequential(prefix='') + out = nn.HybridSequential() out.add(nn.BatchNorm()) out.add(nn.Activation('relu')) out.add(nn.Conv2D(num_output_features, kernel_size=1, use_bias=False)) @@ -86,27 +85,26 @@ def __init__(self, num_init_features, growth_rate, block_config, bn_size=4, dropout=0, classes=1000, **kwargs): super(DenseNet, self).__init__(**kwargs) - with self.name_scope(): - self.features = nn.HybridSequential(prefix='') - self.features.add(nn.Conv2D(num_init_features, kernel_size=7, - strides=2, padding=3, use_bias=False)) - self.features.add(nn.BatchNorm()) - self.features.add(nn.Activation('relu')) - self.features.add(nn.MaxPool2D(pool_size=3, strides=2, padding=1)) - # Add dense blocks - num_features = num_init_features - for i, num_layers in enumerate(block_config): - self.features.add(_make_dense_block(num_layers, bn_size, growth_rate, dropout, i+1)) - num_features = num_features + num_layers * growth_rate - if i != len(block_config) - 1: - 
self.features.add(_make_transition(num_features // 2)) - num_features = num_features // 2 - self.features.add(nn.BatchNorm()) - self.features.add(nn.Activation('relu')) - self.features.add(nn.AvgPool2D(pool_size=7)) - self.features.add(nn.Flatten()) - - self.output = nn.Dense(classes) + self.features = nn.HybridSequential() + self.features.add(nn.Conv2D(num_init_features, kernel_size=7, + strides=2, padding=3, use_bias=False)) + self.features.add(nn.BatchNorm()) + self.features.add(nn.Activation('relu')) + self.features.add(nn.MaxPool2D(pool_size=3, strides=2, padding=1)) + # Add dense blocks + num_features = num_init_features + for i, num_layers in enumerate(block_config): + self.features.add(_make_dense_block(num_layers, bn_size, growth_rate, dropout)) + num_features = num_features + num_layers * growth_rate + if i != len(block_config) - 1: + self.features.add(_make_transition(num_features // 2)) + num_features = num_features // 2 + self.features.add(nn.BatchNorm()) + self.features.add(nn.Activation('relu')) + self.features.add(nn.AvgPool2D(pool_size=7)) + self.features.add(nn.Flatten()) + + self.output = nn.Dense(classes) def hybrid_forward(self, F, x): x = self.features(x) diff --git a/python/mxnet/gluon/model_zoo/vision/inception.py b/python/mxnet/gluon/model_zoo/vision/inception.py index 6bdc526a6a13..b13b0acf5949 100644 --- a/python/mxnet/gluon/model_zoo/vision/inception.py +++ b/python/mxnet/gluon/model_zoo/vision/inception.py @@ -30,14 +30,14 @@ # Helpers def _make_basic_conv(**kwargs): - out = nn.HybridSequential(prefix='') + out = nn.HybridSequential() out.add(nn.Conv2D(use_bias=False, **kwargs)) out.add(nn.BatchNorm(epsilon=0.001)) out.add(nn.Activation('relu')) return out def _make_branch(use_pool, *conv_settings): - out = nn.HybridSequential(prefix='') + out = nn.HybridSequential() if use_pool == 'avg': out.add(nn.AvgPool2D(pool_size=3, strides=1, padding=1)) elif use_pool == 'max': @@ -51,102 +51,97 @@ def _make_branch(use_pool, *conv_settings): out.add(_make_basic_conv(**kwargs)) return out -def _make_A(pool_features, prefix): - out = HybridConcurrent(axis=1, prefix=prefix) - with out.name_scope(): - out.add(_make_branch(None, - (64, 1, None, None))) - out.add(_make_branch(None, - (48, 1, None, None), - (64, 5, None, 2))) - out.add(_make_branch(None, - (64, 1, None, None), - (96, 3, None, 1), - (96, 3, None, 1))) - out.add(_make_branch('avg', - (pool_features, 1, None, None))) +def _make_A(pool_features): + out = HybridConcurrent(axis=1) + out.add(_make_branch(None, + (64, 1, None, None))) + out.add(_make_branch(None, + (48, 1, None, None), + (64, 5, None, 2))) + out.add(_make_branch(None, + (64, 1, None, None), + (96, 3, None, 1), + (96, 3, None, 1))) + out.add(_make_branch('avg', + (pool_features, 1, None, None))) return out -def _make_B(prefix): - out = HybridConcurrent(axis=1, prefix=prefix) - with out.name_scope(): - out.add(_make_branch(None, - (384, 3, 2, None))) - out.add(_make_branch(None, - (64, 1, None, None), - (96, 3, None, 1), - (96, 3, 2, None))) - out.add(_make_branch('max')) +def _make_B(): + out = HybridConcurrent(axis=1) + out.add(_make_branch(None, + (384, 3, 2, None))) + out.add(_make_branch(None, + (64, 1, None, None), + (96, 3, None, 1), + (96, 3, 2, None))) + out.add(_make_branch('max')) return out -def _make_C(channels_7x7, prefix): - out = HybridConcurrent(axis=1, prefix=prefix) - with out.name_scope(): - out.add(_make_branch(None, - (192, 1, None, None))) - out.add(_make_branch(None, - (channels_7x7, 1, None, None), - (channels_7x7, (1, 7), None, 
(0, 3)), - (192, (7, 1), None, (3, 0)))) - out.add(_make_branch(None, - (channels_7x7, 1, None, None), - (channels_7x7, (7, 1), None, (3, 0)), - (channels_7x7, (1, 7), None, (0, 3)), - (channels_7x7, (7, 1), None, (3, 0)), - (192, (1, 7), None, (0, 3)))) - out.add(_make_branch('avg', - (192, 1, None, None))) +def _make_C(channels_7x7): + out = HybridConcurrent(axis=1) + out.add(_make_branch(None, + (192, 1, None, None))) + out.add(_make_branch(None, + (channels_7x7, 1, None, None), + (channels_7x7, (1, 7), None, (0, 3)), + (192, (7, 1), None, (3, 0)))) + out.add(_make_branch(None, + (channels_7x7, 1, None, None), + (channels_7x7, (7, 1), None, (3, 0)), + (channels_7x7, (1, 7), None, (0, 3)), + (channels_7x7, (7, 1), None, (3, 0)), + (192, (1, 7), None, (0, 3)))) + out.add(_make_branch('avg', + (192, 1, None, None))) return out -def _make_D(prefix): - out = HybridConcurrent(axis=1, prefix=prefix) - with out.name_scope(): - out.add(_make_branch(None, - (192, 1, None, None), - (320, 3, 2, None))) - out.add(_make_branch(None, - (192, 1, None, None), - (192, (1, 7), None, (0, 3)), - (192, (7, 1), None, (3, 0)), - (192, 3, 2, None))) - out.add(_make_branch('max')) +def _make_D(): + out = HybridConcurrent(axis=1) + out.add(_make_branch(None, + (192, 1, None, None), + (320, 3, 2, None))) + out.add(_make_branch(None, + (192, 1, None, None), + (192, (1, 7), None, (0, 3)), + (192, (7, 1), None, (3, 0)), + (192, 3, 2, None))) + out.add(_make_branch('max')) return out -def _make_E(prefix): - out = HybridConcurrent(axis=1, prefix=prefix) - with out.name_scope(): - out.add(_make_branch(None, - (320, 1, None, None))) - - branch_3x3 = nn.HybridSequential(prefix='') - out.add(branch_3x3) - branch_3x3.add(_make_branch(None, - (384, 1, None, None))) - branch_3x3_split = HybridConcurrent(axis=1, prefix='') - branch_3x3_split.add(_make_branch(None, - (384, (1, 3), None, (0, 1)))) - branch_3x3_split.add(_make_branch(None, - (384, (3, 1), None, (1, 0)))) - branch_3x3.add(branch_3x3_split) - - branch_3x3dbl = nn.HybridSequential(prefix='') - out.add(branch_3x3dbl) - branch_3x3dbl.add(_make_branch(None, - (448, 1, None, None), - (384, 3, None, 1))) - branch_3x3dbl_split = HybridConcurrent(axis=1, prefix='') - branch_3x3dbl.add(branch_3x3dbl_split) - branch_3x3dbl_split.add(_make_branch(None, - (384, (1, 3), None, (0, 1)))) - branch_3x3dbl_split.add(_make_branch(None, - (384, (3, 1), None, (1, 0)))) - - out.add(_make_branch('avg', - (192, 1, None, None))) +def _make_E(): + out = HybridConcurrent(axis=1) + out.add(_make_branch(None, + (320, 1, None, None))) + + branch_3x3 = nn.HybridSequential() + out.add(branch_3x3) + branch_3x3.add(_make_branch(None, + (384, 1, None, None))) + branch_3x3_split = HybridConcurrent(axis=1) + branch_3x3_split.add(_make_branch(None, + (384, (1, 3), None, (0, 1)))) + branch_3x3_split.add(_make_branch(None, + (384, (3, 1), None, (1, 0)))) + branch_3x3.add(branch_3x3_split) + + branch_3x3dbl = nn.HybridSequential() + out.add(branch_3x3dbl) + branch_3x3dbl.add(_make_branch(None, + (448, 1, None, None), + (384, 3, None, 1))) + branch_3x3dbl_split = HybridConcurrent(axis=1) + branch_3x3dbl.add(branch_3x3dbl_split) + branch_3x3dbl_split.add(_make_branch(None, + (384, (1, 3), None, (0, 1)))) + branch_3x3dbl_split.add(_make_branch(None, + (384, (3, 1), None, (1, 0)))) + + out.add(_make_branch('avg', + (192, 1, None, None))) return out def make_aux(classes): - out = nn.HybridSequential(prefix='') + out = nn.HybridSequential() out.add(nn.AvgPool2D(pool_size=5, strides=3)) 
out.add(_make_basic_conv(channels=128, kernel_size=1)) out.add(_make_basic_conv(channels=768, kernel_size=5)) @@ -168,30 +163,29 @@ class Inception3(HybridBlock): def __init__(self, classes=1000, **kwargs): super(Inception3, self).__init__(**kwargs) # self.use_aux_logits = use_aux_logits - with self.name_scope(): - self.features = nn.HybridSequential(prefix='') - self.features.add(_make_basic_conv(channels=32, kernel_size=3, strides=2)) - self.features.add(_make_basic_conv(channels=32, kernel_size=3)) - self.features.add(_make_basic_conv(channels=64, kernel_size=3, padding=1)) - self.features.add(nn.MaxPool2D(pool_size=3, strides=2)) - self.features.add(_make_basic_conv(channels=80, kernel_size=1)) - self.features.add(_make_basic_conv(channels=192, kernel_size=3)) - self.features.add(nn.MaxPool2D(pool_size=3, strides=2)) - self.features.add(_make_A(32, 'A1_')) - self.features.add(_make_A(64, 'A2_')) - self.features.add(_make_A(64, 'A3_')) - self.features.add(_make_B('B_')) - self.features.add(_make_C(128, 'C1_')) - self.features.add(_make_C(160, 'C2_')) - self.features.add(_make_C(160, 'C3_')) - self.features.add(_make_C(192, 'C4_')) - self.features.add(_make_D('D_')) - self.features.add(_make_E('E1_')) - self.features.add(_make_E('E2_')) - self.features.add(nn.AvgPool2D(pool_size=8)) - self.features.add(nn.Dropout(0.5)) - - self.output = nn.Dense(classes) + self.features = nn.HybridSequential() + self.features.add(_make_basic_conv(channels=32, kernel_size=3, strides=2)) + self.features.add(_make_basic_conv(channels=32, kernel_size=3)) + self.features.add(_make_basic_conv(channels=64, kernel_size=3, padding=1)) + self.features.add(nn.MaxPool2D(pool_size=3, strides=2)) + self.features.add(_make_basic_conv(channels=80, kernel_size=1)) + self.features.add(_make_basic_conv(channels=192, kernel_size=3)) + self.features.add(nn.MaxPool2D(pool_size=3, strides=2)) + self.features.add(_make_A(32)) + self.features.add(_make_A(64)) + self.features.add(_make_A(64)) + self.features.add(_make_B()) + self.features.add(_make_C(128)) + self.features.add(_make_C(160)) + self.features.add(_make_C(160)) + self.features.add(_make_C(192)) + self.features.add(_make_D()) + self.features.add(_make_E()) + self.features.add(_make_E()) + self.features.add(nn.AvgPool2D(pool_size=8)) + self.features.add(nn.Dropout(0.5)) + + self.output = nn.Dense(classes) def hybrid_forward(self, F, x): x = self.features(x) diff --git a/python/mxnet/gluon/model_zoo/vision/mobilenet.py b/python/mxnet/gluon/model_zoo/vision/mobilenet.py index 88610571252e..69cd1c03ba10 100644 --- a/python/mxnet/gluon/model_zoo/vision/mobilenet.py +++ b/python/mxnet/gluon/model_zoo/vision/mobilenet.py @@ -80,13 +80,12 @@ class LinearBottleneck(nn.HybridBlock): def __init__(self, in_channels, channels, t, stride, **kwargs): super(LinearBottleneck, self).__init__(**kwargs) self.use_shortcut = stride == 1 and in_channels == channels - with self.name_scope(): - self.out = nn.HybridSequential() + self.out = nn.HybridSequential() - _add_conv(self.out, in_channels * t, relu6=True) - _add_conv(self.out, in_channels * t, kernel=3, stride=stride, - pad=1, num_group=in_channels * t, relu6=True) - _add_conv(self.out, channels, active=False, relu6=True) + _add_conv(self.out, in_channels * t, relu6=True) + _add_conv(self.out, in_channels * t, kernel=3, stride=stride, + pad=1, num_group=in_channels * t, relu6=True) + _add_conv(self.out, channels, active=False, relu6=True) def hybrid_forward(self, F, x): out = self.out(x) @@ -113,21 +112,19 @@ class 
MobileNet(HybridBlock): def __init__(self, multiplier=1.0, classes=1000, **kwargs): super(MobileNet, self).__init__(**kwargs) - with self.name_scope(): - self.features = nn.HybridSequential(prefix='') - with self.features.name_scope(): - _add_conv(self.features, channels=int(32 * multiplier), kernel=3, pad=1, stride=2) - dw_channels = [int(x * multiplier) for x in [32, 64] + [128] * 2 - + [256] * 2 + [512] * 6 + [1024]] - channels = [int(x * multiplier) for x in [64] + [128] * 2 + [256] * 2 - + [512] * 6 + [1024] * 2] - strides = [1, 2] * 3 + [1] * 5 + [2, 1] - for dwc, c, s in zip(dw_channels, channels, strides): - _add_conv_dw(self.features, dw_channels=dwc, channels=c, stride=s) - self.features.add(nn.GlobalAvgPool2D()) - self.features.add(nn.Flatten()) - - self.output = nn.Dense(classes) + self.features = nn.HybridSequential() + _add_conv(self.features, channels=int(32 * multiplier), kernel=3, pad=1, stride=2) + dw_channels = [int(x * multiplier) for x in [32, 64] + [128] * 2 + + [256] * 2 + [512] * 6 + [1024]] + channels = [int(x * multiplier) for x in [64] + [128] * 2 + [256] * 2 + + [512] * 6 + [1024] * 2] + strides = [1, 2] * 3 + [1] * 5 + [2, 1] + for dwc, c, s in zip(dw_channels, channels, strides): + _add_conv_dw(self.features, dw_channels=dwc, channels=c, stride=s) + self.features.add(nn.GlobalAvgPool2D()) + self.features.add(nn.Flatten()) + + self.output = nn.Dense(classes) def hybrid_forward(self, F, x): x = self.features(x) @@ -152,34 +149,31 @@ class MobileNetV2(nn.HybridBlock): def __init__(self, multiplier=1.0, classes=1000, **kwargs): super(MobileNetV2, self).__init__(**kwargs) - with self.name_scope(): - self.features = nn.HybridSequential(prefix='features_') - with self.features.name_scope(): - _add_conv(self.features, int(32 * multiplier), kernel=3, - stride=2, pad=1, relu6=True) - - in_channels_group = [int(x * multiplier) for x in [32] + [16] + [24] * 2 - + [32] * 3 + [64] * 4 + [96] * 3 + [160] * 3] - channels_group = [int(x * multiplier) for x in [16] + [24] * 2 + [32] * 3 - + [64] * 4 + [96] * 3 + [160] * 3 + [320]] - ts = [1] + [6] * 16 - strides = [1, 2] * 2 + [1, 1, 2] + [1] * 6 + [2] + [1] * 3 - - for in_c, c, t, s in zip(in_channels_group, channels_group, ts, strides): - self.features.add(LinearBottleneck(in_channels=in_c, channels=c, - t=t, stride=s)) - - last_channels = int(1280 * multiplier) if multiplier > 1.0 else 1280 - _add_conv(self.features, last_channels, relu6=True) - - self.features.add(nn.GlobalAvgPool2D()) - - self.output = nn.HybridSequential(prefix='output_') - with self.output.name_scope(): - self.output.add( - nn.Conv2D(classes, 1, use_bias=False, prefix='pred_'), - nn.Flatten() - ) + self.features = nn.HybridSequential() + _add_conv(self.features, int(32 * multiplier), kernel=3, + stride=2, pad=1, relu6=True) + + in_channels_group = [int(x * multiplier) for x in [32] + [16] + [24] * 2 + + [32] * 3 + [64] * 4 + [96] * 3 + [160] * 3] + channels_group = [int(x * multiplier) for x in [16] + [24] * 2 + [32] * 3 + + [64] * 4 + [96] * 3 + [160] * 3 + [320]] + ts = [1] + [6] * 16 + strides = [1, 2] * 2 + [1, 1, 2] + [1] * 6 + [2] + [1] * 3 + + for in_c, c, t, s in zip(in_channels_group, channels_group, ts, strides): + self.features.add(LinearBottleneck(in_channels=in_c, channels=c, + t=t, stride=s)) + + last_channels = int(1280 * multiplier) if multiplier > 1.0 else 1280 + _add_conv(self.features, last_channels, relu6=True) + + self.features.add(nn.GlobalAvgPool2D()) + + self.output = nn.HybridSequential() + self.output.add( + nn.Conv2D(classes, 
1, use_bias=False), + nn.Flatten() + ) def hybrid_forward(self, F, x): x = self.features(x) diff --git a/python/mxnet/gluon/model_zoo/vision/resnet.py b/python/mxnet/gluon/model_zoo/vision/resnet.py index 493bb17a969f..147ad6af2d9e 100644 --- a/python/mxnet/gluon/model_zoo/vision/resnet.py +++ b/python/mxnet/gluon/model_zoo/vision/resnet.py @@ -59,14 +59,14 @@ class BasicBlockV1(HybridBlock): """ def __init__(self, channels, stride, downsample=False, in_channels=0, **kwargs): super(BasicBlockV1, self).__init__(**kwargs) - self.body = nn.HybridSequential(prefix='') + self.body = nn.HybridSequential() self.body.add(_conv3x3(channels, stride, in_channels)) self.body.add(nn.BatchNorm()) self.body.add(nn.Activation('relu')) self.body.add(_conv3x3(channels, 1, channels)) self.body.add(nn.BatchNorm()) if downsample: - self.downsample = nn.HybridSequential(prefix='') + self.downsample = nn.HybridSequential() self.downsample.add(nn.Conv2D(channels, kernel_size=1, strides=stride, use_bias=False, in_channels=in_channels)) self.downsample.add(nn.BatchNorm()) @@ -105,7 +105,7 @@ class BottleneckV1(HybridBlock): """ def __init__(self, channels, stride, downsample=False, in_channels=0, **kwargs): super(BottleneckV1, self).__init__(**kwargs) - self.body = nn.HybridSequential(prefix='') + self.body = nn.HybridSequential() self.body.add(nn.Conv2D(channels//4, kernel_size=1, strides=stride)) self.body.add(nn.BatchNorm()) self.body.add(nn.Activation('relu')) @@ -115,7 +115,7 @@ def __init__(self, channels, stride, downsample=False, in_channels=0, **kwargs): self.body.add(nn.Conv2D(channels, kernel_size=1, strides=1)) self.body.add(nn.BatchNorm()) if downsample: - self.downsample = nn.HybridSequential(prefix='') + self.downsample = nn.HybridSequential() self.downsample.add(nn.Conv2D(channels, kernel_size=1, strides=stride, use_bias=False, in_channels=in_channels)) self.downsample.add(nn.BatchNorm()) @@ -253,31 +253,28 @@ class ResNetV1(HybridBlock): def __init__(self, block, layers, channels, classes=1000, thumbnail=False, **kwargs): super(ResNetV1, self).__init__(**kwargs) assert len(layers) == len(channels) - 1 - with self.name_scope(): - self.features = nn.HybridSequential(prefix='') - if thumbnail: - self.features.add(_conv3x3(channels[0], 1, 0)) - else: - self.features.add(nn.Conv2D(channels[0], 7, 2, 3, use_bias=False)) - self.features.add(nn.BatchNorm()) - self.features.add(nn.Activation('relu')) - self.features.add(nn.MaxPool2D(3, 2, 1)) - - for i, num_layer in enumerate(layers): - stride = 1 if i == 0 else 2 - self.features.add(self._make_layer(block, num_layer, channels[i+1], - stride, i+1, in_channels=channels[i])) - self.features.add(nn.GlobalAvgPool2D()) - - self.output = nn.Dense(classes, in_units=channels[-1]) - - def _make_layer(self, block, layers, channels, stride, stage_index, in_channels=0): - layer = nn.HybridSequential(prefix='stage%d_'%stage_index) - with layer.name_scope(): - layer.add(block(channels, stride, channels != in_channels, in_channels=in_channels, - prefix='')) - for _ in range(layers-1): - layer.add(block(channels, 1, False, in_channels=channels, prefix='')) + self.features = nn.HybridSequential() + if thumbnail: + self.features.add(_conv3x3(channels[0], 1, 0)) + else: + self.features.add(nn.Conv2D(channels[0], 7, 2, 3, use_bias=False)) + self.features.add(nn.BatchNorm()) + self.features.add(nn.Activation('relu')) + self.features.add(nn.MaxPool2D(3, 2, 1)) + + for i, num_layer in enumerate(layers): + stride = 1 if i == 0 else 2 + self.features.add(self._make_layer(block, 
num_layer, channels[i+1], + stride, in_channels=channels[i])) + self.features.add(nn.GlobalAvgPool2D()) + + self.output = nn.Dense(classes, in_units=channels[-1]) + + def _make_layer(self, block, layers, channels, stride, in_channels=0): + layer = nn.HybridSequential() + layer.add(block(channels, stride, channels != in_channels, in_channels=in_channels)) + for _ in range(layers-1): + layer.add(block(channels, 1, False, in_channels=channels)) return layer def hybrid_forward(self, F, x): @@ -308,37 +305,34 @@ class ResNetV2(HybridBlock): def __init__(self, block, layers, channels, classes=1000, thumbnail=False, **kwargs): super(ResNetV2, self).__init__(**kwargs) assert len(layers) == len(channels) - 1 - with self.name_scope(): - self.features = nn.HybridSequential(prefix='') - self.features.add(nn.BatchNorm(scale=False, center=False)) - if thumbnail: - self.features.add(_conv3x3(channels[0], 1, 0)) - else: - self.features.add(nn.Conv2D(channels[0], 7, 2, 3, use_bias=False)) - self.features.add(nn.BatchNorm()) - self.features.add(nn.Activation('relu')) - self.features.add(nn.MaxPool2D(3, 2, 1)) - - in_channels = channels[0] - for i, num_layer in enumerate(layers): - stride = 1 if i == 0 else 2 - self.features.add(self._make_layer(block, num_layer, channels[i+1], - stride, i+1, in_channels=in_channels)) - in_channels = channels[i+1] + self.features = nn.HybridSequential() + self.features.add(nn.BatchNorm(scale=False, center=False)) + if thumbnail: + self.features.add(_conv3x3(channels[0], 1, 0)) + else: + self.features.add(nn.Conv2D(channels[0], 7, 2, 3, use_bias=False)) self.features.add(nn.BatchNorm()) self.features.add(nn.Activation('relu')) - self.features.add(nn.GlobalAvgPool2D()) - self.features.add(nn.Flatten()) - - self.output = nn.Dense(classes, in_units=in_channels) - - def _make_layer(self, block, layers, channels, stride, stage_index, in_channels=0): - layer = nn.HybridSequential(prefix='stage%d_'%stage_index) - with layer.name_scope(): - layer.add(block(channels, stride, channels != in_channels, in_channels=in_channels, - prefix='')) - for _ in range(layers-1): - layer.add(block(channels, 1, False, in_channels=channels, prefix='')) + self.features.add(nn.MaxPool2D(3, 2, 1)) + + in_channels = channels[0] + for i, num_layer in enumerate(layers): + stride = 1 if i == 0 else 2 + self.features.add(self._make_layer(block, num_layer, channels[i+1], + stride, in_channels=in_channels)) + in_channels = channels[i+1] + self.features.add(nn.BatchNorm()) + self.features.add(nn.Activation('relu')) + self.features.add(nn.GlobalAvgPool2D()) + self.features.add(nn.Flatten()) + + self.output = nn.Dense(classes, in_units=in_channels) + + def _make_layer(self, block, layers, channels, stride, in_channels=0): + layer = nn.HybridSequential() + layer.add(block(channels, stride, channels != in_channels, in_channels=in_channels)) + for _ in range(layers-1): + layer.add(block(channels, 1, False, in_channels=channels)) return layer def hybrid_forward(self, F, x): diff --git a/python/mxnet/gluon/model_zoo/vision/squeezenet.py b/python/mxnet/gluon/model_zoo/vision/squeezenet.py index b97d1274a6f0..81d1e9a43871 100644 --- a/python/mxnet/gluon/model_zoo/vision/squeezenet.py +++ b/python/mxnet/gluon/model_zoo/vision/squeezenet.py @@ -30,10 +30,10 @@ # Helpers def _make_fire(squeeze_channels, expand1x1_channels, expand3x3_channels): - out = nn.HybridSequential(prefix='') + out = nn.HybridSequential() out.add(_make_fire_conv(squeeze_channels, 1)) - paths = HybridConcurrent(axis=1, prefix='') + paths = 
HybridConcurrent(axis=1) paths.add(_make_fire_conv(expand1x1_channels, 1)) paths.add(_make_fire_conv(expand3x3_channels, 3, 1)) out.add(paths) @@ -41,7 +41,7 @@ def _make_fire(squeeze_channels, expand1x1_channels, expand3x3_channels): return out def _make_fire_conv(channels, kernel_size, padding=0): - out = nn.HybridSequential(prefix='') + out = nn.HybridSequential() out.add(nn.Conv2D(channels, kernel_size, padding=padding)) out.add(nn.Activation('relu')) return out @@ -66,43 +66,42 @@ def __init__(self, version, classes=1000, **kwargs): super(SqueezeNet, self).__init__(**kwargs) assert version in ['1.0', '1.1'], ("Unsupported SqueezeNet version {version}:" "1.0 or 1.1 expected".format(version=version)) - with self.name_scope(): - self.features = nn.HybridSequential(prefix='') - if version == '1.0': - self.features.add(nn.Conv2D(96, kernel_size=7, strides=2)) - self.features.add(nn.Activation('relu')) - self.features.add(nn.MaxPool2D(pool_size=3, strides=2, ceil_mode=True)) - self.features.add(_make_fire(16, 64, 64)) - self.features.add(_make_fire(16, 64, 64)) - self.features.add(_make_fire(32, 128, 128)) - self.features.add(nn.MaxPool2D(pool_size=3, strides=2, ceil_mode=True)) - self.features.add(_make_fire(32, 128, 128)) - self.features.add(_make_fire(48, 192, 192)) - self.features.add(_make_fire(48, 192, 192)) - self.features.add(_make_fire(64, 256, 256)) - self.features.add(nn.MaxPool2D(pool_size=3, strides=2, ceil_mode=True)) - self.features.add(_make_fire(64, 256, 256)) - else: - self.features.add(nn.Conv2D(64, kernel_size=3, strides=2)) - self.features.add(nn.Activation('relu')) - self.features.add(nn.MaxPool2D(pool_size=3, strides=2, ceil_mode=True)) - self.features.add(_make_fire(16, 64, 64)) - self.features.add(_make_fire(16, 64, 64)) - self.features.add(nn.MaxPool2D(pool_size=3, strides=2, ceil_mode=True)) - self.features.add(_make_fire(32, 128, 128)) - self.features.add(_make_fire(32, 128, 128)) - self.features.add(nn.MaxPool2D(pool_size=3, strides=2, ceil_mode=True)) - self.features.add(_make_fire(48, 192, 192)) - self.features.add(_make_fire(48, 192, 192)) - self.features.add(_make_fire(64, 256, 256)) - self.features.add(_make_fire(64, 256, 256)) - self.features.add(nn.Dropout(0.5)) - - self.output = nn.HybridSequential(prefix='') - self.output.add(nn.Conv2D(classes, kernel_size=1)) - self.output.add(nn.Activation('relu')) - self.output.add(nn.AvgPool2D(13)) - self.output.add(nn.Flatten()) + self.features = nn.HybridSequential() + if version == '1.0': + self.features.add(nn.Conv2D(96, kernel_size=7, strides=2)) + self.features.add(nn.Activation('relu')) + self.features.add(nn.MaxPool2D(pool_size=3, strides=2, ceil_mode=True)) + self.features.add(_make_fire(16, 64, 64)) + self.features.add(_make_fire(16, 64, 64)) + self.features.add(_make_fire(32, 128, 128)) + self.features.add(nn.MaxPool2D(pool_size=3, strides=2, ceil_mode=True)) + self.features.add(_make_fire(32, 128, 128)) + self.features.add(_make_fire(48, 192, 192)) + self.features.add(_make_fire(48, 192, 192)) + self.features.add(_make_fire(64, 256, 256)) + self.features.add(nn.MaxPool2D(pool_size=3, strides=2, ceil_mode=True)) + self.features.add(_make_fire(64, 256, 256)) + else: + self.features.add(nn.Conv2D(64, kernel_size=3, strides=2)) + self.features.add(nn.Activation('relu')) + self.features.add(nn.MaxPool2D(pool_size=3, strides=2, ceil_mode=True)) + self.features.add(_make_fire(16, 64, 64)) + self.features.add(_make_fire(16, 64, 64)) + self.features.add(nn.MaxPool2D(pool_size=3, strides=2, ceil_mode=True)) + 
self.features.add(_make_fire(32, 128, 128)) + self.features.add(_make_fire(32, 128, 128)) + self.features.add(nn.MaxPool2D(pool_size=3, strides=2, ceil_mode=True)) + self.features.add(_make_fire(48, 192, 192)) + self.features.add(_make_fire(48, 192, 192)) + self.features.add(_make_fire(64, 256, 256)) + self.features.add(_make_fire(64, 256, 256)) + self.features.add(nn.Dropout(0.5)) + + self.output = nn.HybridSequential() + self.output.add(nn.Conv2D(classes, kernel_size=1)) + self.output.add(nn.Activation('relu')) + self.output.add(nn.AvgPool2D(13)) + self.output.add(nn.Flatten()) def hybrid_forward(self, F, x): x = self.features(x) diff --git a/python/mxnet/gluon/model_zoo/vision/vgg.py b/python/mxnet/gluon/model_zoo/vision/vgg.py index 8934c16da8e2..4dd04f8b03b2 100644 --- a/python/mxnet/gluon/model_zoo/vision/vgg.py +++ b/python/mxnet/gluon/model_zoo/vision/vgg.py @@ -50,22 +50,21 @@ class VGG(HybridBlock): def __init__(self, layers, filters, classes=1000, batch_norm=False, **kwargs): super(VGG, self).__init__(**kwargs) assert len(layers) == len(filters) - with self.name_scope(): - self.features = self._make_features(layers, filters, batch_norm) - self.features.add(nn.Dense(4096, activation='relu', - weight_initializer='normal', - bias_initializer='zeros')) - self.features.add(nn.Dropout(rate=0.5)) - self.features.add(nn.Dense(4096, activation='relu', - weight_initializer='normal', - bias_initializer='zeros')) - self.features.add(nn.Dropout(rate=0.5)) - self.output = nn.Dense(classes, + self.features = self._make_features(layers, filters, batch_norm) + self.features.add(nn.Dense(4096, activation='relu', weight_initializer='normal', - bias_initializer='zeros') + bias_initializer='zeros')) + self.features.add(nn.Dropout(rate=0.5)) + self.features.add(nn.Dense(4096, activation='relu', + weight_initializer='normal', + bias_initializer='zeros')) + self.features.add(nn.Dropout(rate=0.5)) + self.output = nn.Dense(classes, + weight_initializer='normal', + bias_initializer='zeros') def _make_features(self, layers, filters, batch_norm): - featurizer = nn.HybridSequential(prefix='') + featurizer = nn.HybridSequential() for i, num in enumerate(layers): for _ in range(num): featurizer.add(nn.Conv2D(filters[i], kernel_size=3, padding=1, diff --git a/python/mxnet/gluon/nn/activations.py b/python/mxnet/gluon/nn/activations.py index 3cccc851e39b..991a60b00fc6 100644 --- a/python/mxnet/gluon/nn/activations.py +++ b/python/mxnet/gluon/nn/activations.py @@ -22,6 +22,7 @@ from ... import initializer from ..block import HybridBlock +from ..parameter import Parameter from ...util import is_np_array @@ -134,9 +135,7 @@ class PReLU(HybridBlock): def __init__(self, alpha_initializer=initializer.Constant(0.25), in_channels=1, **kwargs): super(PReLU, self).__init__(**kwargs) - with self.name_scope(): - self.alpha = self.params.get('alpha', shape=(in_channels,), - init=alpha_initializer) + self.alpha = Parameter('alpha', shape=(in_channels,), init=alpha_initializer) def hybrid_forward(self, F, x, alpha): leaky_relu = F.npx.leaky_relu if is_np_array() else F.LeakyReLU diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py index 8e364532a2f7..0f035eec31d3 100644 --- a/python/mxnet/gluon/nn/basic_layers.py +++ b/python/mxnet/gluon/nn/basic_layers.py @@ -29,6 +29,7 @@ from ..utils import _indent from ... 
import ndarray as nd, symbol as sym from ...util import is_np_array +from ..parameter import Parameter class Sequential(Block): @@ -37,13 +38,11 @@ class Sequential(Block): Example:: net = nn.Sequential() - # use net's name_scope to give child Blocks appropriate names. - with net.name_scope(): - net.add(nn.Dense(10, activation='relu')) - net.add(nn.Dense(20)) + net.add(nn.Dense(10, activation='relu')) + net.add(nn.Dense(20)) """ - def __init__(self, prefix=None, params=None): - super(Sequential, self).__init__(prefix=prefix, params=params) + def __init__(self): + super(Sequential, self).__init__() self._layers = [] def add(self, *blocks): @@ -73,9 +72,8 @@ def __repr__(self): def __getitem__(self, key): layers = list(self._children.values())[key] if isinstance(layers, list): - net = type(self)(prefix=self._prefix) - with net.name_scope(): - net.add(*(l() for l in layers)) + net = type(self)() + net.add(*(l() for l in layers)) return net else: return layers() @@ -96,8 +94,8 @@ def hybridize(self, active=True, **kwargs): """ if self._children and all(isinstance(c(), HybridBlock) for c in self._children.values()): warnings.warn( - "All children of this Sequential layer '%s' are HybridBlocks. Consider " - "using HybridSequential for the best performance."%self.prefix, stacklevel=2) + "All children of this Sequential layer '%s'\n are HybridBlocks. Consider " + "using HybridSequential for the best performance."%repr(self), stacklevel=2) super(Sequential, self).hybridize(active, **kwargs) @@ -107,14 +105,12 @@ class HybridSequential(HybridBlock): Example:: net = nn.HybridSequential() - # use net's name_scope to give child Blocks appropriate names. - with net.name_scope(): - net.add(nn.Dense(10, activation='relu')) - net.add(nn.Dense(20)) + net.add(nn.Dense(10, activation='relu')) + net.add(nn.Dense(20)) net.hybridize() """ - def __init__(self, prefix=None, params=None): - super(HybridSequential, self).__init__(prefix=prefix, params=params) + def __init__(self): + super(HybridSequential, self).__init__() self._layers = [] def add(self, *blocks): @@ -144,9 +140,8 @@ def __repr__(self): def __getitem__(self, key): layers = list(self._children.values())[key] if isinstance(layers, list): - net = type(self)(prefix=self._prefix) - with net.name_scope(): - net.add(*(l() for l in layers)) + net = type(self)() + net.add(*(l() for l in layers)) return net else: return layers() @@ -194,10 +189,6 @@ class Dense(HybridBlock): Size of the input data. If not specified, initialization will be deferred to the first time `forward` is called and `in_units` will be inferred from the shape of input data. - prefix : str or None - See document of `Block`. - params : ParameterDict or None - See document of `Block`. 
Inputs: @@ -216,22 +207,21 @@ def __init__(self, units, activation=None, use_bias=True, flatten=True, in_units=0, **kwargs): super(Dense, self).__init__(**kwargs) self._flatten = flatten - with self.name_scope(): - self._units = units - self._in_units = in_units - self.weight = self.params.get('weight', shape=(units, in_units), - init=weight_initializer, dtype=dtype, - allow_deferred_init=True) - if use_bias: - self.bias = self.params.get('bias', shape=(units,), - init=bias_initializer, dtype=dtype, - allow_deferred_init=True) - else: - self.bias = None - if activation is not None: - self.act = Activation(activation, prefix=activation+'_') - else: - self.act = None + self._units = units + self._in_units = in_units + self.weight = Parameter('weight', shape=(units, in_units), + init=weight_initializer, dtype=dtype, + allow_deferred_init=True) + if use_bias: + self.bias = Parameter('bias', shape=(units,), + init=bias_initializer, dtype=dtype, + allow_deferred_init=True) + else: + self.bias = None + if activation is not None: + self.act = Activation(activation) + else: + self.act = None def hybrid_forward(self, F, x, weight, bias=None): fc = F.npx.fully_connected if is_np_array() else F.FullyConnected @@ -356,24 +346,24 @@ def __init__(self, axis=1, momentum=0.9, epsilon=1e-5, center=True, scale=True, if in_channels != 0: self.in_channels = in_channels - self.gamma = self.params.get('gamma', grad_req='write' if scale else 'null', - shape=(in_channels,), init=gamma_initializer, + self.gamma = Parameter('gamma', grad_req='write' if scale else 'null', + shape=(in_channels,), init=gamma_initializer, + allow_deferred_init=True, + differentiable=scale) + self.beta = Parameter('beta', grad_req='write' if center else 'null', + shape=(in_channels,), init=beta_initializer, + allow_deferred_init=True, + differentiable=center) + self.running_mean = Parameter('running_mean', grad_req='null', + shape=(in_channels,), + init=running_mean_initializer, + allow_deferred_init=True, + differentiable=False) + self.running_var = Parameter('running_var', grad_req='null', + shape=(in_channels,), + init=running_variance_initializer, allow_deferred_init=True, - differentiable=scale) - self.beta = self.params.get('beta', grad_req='write' if center else 'null', - shape=(in_channels,), init=beta_initializer, - allow_deferred_init=True, - differentiable=center) - self.running_mean = self.params.get('running_mean', grad_req='null', - shape=(in_channels,), - init=running_mean_initializer, - allow_deferred_init=True, - differentiable=False) - self.running_var = self.params.get('running_var', grad_req='null', - shape=(in_channels,), - init=running_variance_initializer, - allow_deferred_init=True, - differentiable=False) + differentiable=False) def cast(self, dtype): if np.dtype(dtype).name == 'float16': @@ -561,9 +551,9 @@ def __init__(self, input_dim, output_dim, dtype='float32', grad_stype = 'row_sparse' if sparse_grad else 'default' self._kwargs = {'input_dim': input_dim, 'output_dim': output_dim, 'dtype': dtype, 'sparse_grad': sparse_grad} - self.weight = self.params.get('weight', shape=(input_dim, output_dim), - init=weight_initializer, dtype=dtype, - allow_deferred_init=True, grad_stype=grad_stype) + self.weight = Parameter('weight', shape=(input_dim, output_dim), + init=weight_initializer, dtype=dtype, + allow_deferred_init=True, grad_stype=grad_stype) def hybrid_forward(self, F, x, weight): embedding = F.npx.embedding if is_np_array() else F.Embedding @@ -666,12 +656,12 @@ def __init__(self, axis=1, epsilon=1e-5, 
center=True, scale=False, self._kwargs = {'eps': epsilon, 'axis': axis, 'center': center, 'scale': scale} self._axis = axis self._epsilon = epsilon - self.gamma = self.params.get('gamma', grad_req='write' if scale else 'null', - shape=(in_channels,), init=gamma_initializer, - allow_deferred_init=True) - self.beta = self.params.get('beta', grad_req='write' if center else 'null', - shape=(in_channels,), init=beta_initializer, - allow_deferred_init=True) + self.gamma = Parameter('gamma', grad_req='write' if scale else 'null', + shape=(in_channels,), init=gamma_initializer, + allow_deferred_init=True) + self.beta = Parameter('beta', grad_req='write' if center else 'null', + shape=(in_channels,), init=beta_initializer, + allow_deferred_init=True) def hybrid_forward(self, F, x, gamma, beta): if self._axis == 1: @@ -747,19 +737,19 @@ class LayerNorm(HybridBlock): """ def __init__(self, axis=-1, epsilon=1e-5, center=True, scale=True, beta_initializer='zeros', gamma_initializer='ones', - in_channels=0, prefix=None, params=None): - super(LayerNorm, self).__init__(prefix=prefix, params=params) + in_channels=0): + super(LayerNorm, self).__init__() self._kwargs = {'eps': epsilon, 'axis': axis, 'center': center, 'scale': scale} self._axis = axis self._epsilon = epsilon self._center = center self._scale = scale - self.gamma = self.params.get('gamma', grad_req='write' if scale else 'null', - shape=(in_channels,), init=gamma_initializer, - allow_deferred_init=True) - self.beta = self.params.get('beta', grad_req='write' if center else 'null', - shape=(in_channels,), init=beta_initializer, - allow_deferred_init=True) + self.gamma = Parameter('gamma', grad_req='write' if scale else 'null', + shape=(in_channels,), init=gamma_initializer, + allow_deferred_init=True) + self.beta = Parameter('beta', grad_req='write' if center else 'null', + shape=(in_channels,), init=beta_initializer, + allow_deferred_init=True) def hybrid_forward(self, F, data, gamma, beta): layer_norm = F.npx.layer_norm if is_np_array() else F.LayerNorm @@ -838,19 +828,19 @@ class GroupNorm(HybridBlock): """ def __init__(self, num_groups=1, epsilon=1e-5, center=True, scale=True, beta_initializer='zeros', gamma_initializer='ones', - in_channels=0, prefix=None, params=None): - super(GroupNorm, self).__init__(prefix=prefix, params=params) + in_channels=0): + super(GroupNorm, self).__init__() self._kwargs = {'eps': epsilon, 'num_groups': num_groups, 'center': center, 'scale': scale} self._num_groups = num_groups self._epsilon = epsilon self._center = center self._scale = scale - self.gamma = self.params.get('gamma', grad_req='write' if scale else 'null', - shape=(in_channels,), init=gamma_initializer, - allow_deferred_init=True) - self.beta = self.params.get('beta', grad_req='write' if center else 'null', - shape=(in_channels,), init=beta_initializer, - allow_deferred_init=True) + self.gamma = Parameter('gamma', grad_req='write' if scale else 'null', + shape=(in_channels,), init=gamma_initializer, + allow_deferred_init=True) + self.beta = Parameter('beta', grad_req='write' if center else 'null', + shape=(in_channels,), init=beta_initializer, + allow_deferred_init=True) def hybrid_forward(self, F, data, gamma, beta): norm_data = F.GroupNorm(data, gamma=gamma, beta=beta, num_groups=self._num_groups, eps=self._epsilon) @@ -888,8 +878,8 @@ class Lambda(Block): Output: - ** *outputs **: one or more output data. Their shapes depend on the function. 
""" - def __init__(self, function, prefix=None): - super(Lambda, self).__init__(prefix=prefix) + def __init__(self, function): + super(Lambda, self).__init__() if isinstance(function, str): assert hasattr(nd, function), \ "Function name %s is not found in ndarray." % function @@ -932,8 +922,8 @@ class HybridLambda(HybridBlock): - ** *outputs **: one or more output data. Their shapes depend on the function. """ - def __init__(self, function, prefix=None): - super(HybridLambda, self).__init__(prefix=prefix) + def __init__(self, function): + super(HybridLambda, self).__init__() if isinstance(function, str): assert hasattr(nd, function) and hasattr(sym, function), \ "Function name %s is not found in symbol/ndarray." % function diff --git a/python/mxnet/gluon/nn/conv_layers.py b/python/mxnet/gluon/nn/conv_layers.py index 4682684662cd..1ed0ae4f1914 100644 --- a/python/mxnet/gluon/nn/conv_layers.py +++ b/python/mxnet/gluon/nn/conv_layers.py @@ -27,6 +27,7 @@ 'ReflectionPad2D'] from ..block import HybridBlock +from ..parameter import Parameter from ... import symbol from ...base import numeric_types from .activations import Activation @@ -96,47 +97,46 @@ class _Conv(HybridBlock): def __init__(self, channels, kernel_size, strides, padding, dilation, groups, layout, in_channels=0, activation=None, use_bias=True, weight_initializer=None, bias_initializer='zeros', - op_name='Convolution', adj=None, prefix=None, params=None): - super(_Conv, self).__init__(prefix=prefix, params=params) - with self.name_scope(): - self._channels = channels - self._in_channels = in_channels - if isinstance(strides, numeric_types): - strides = (strides,)*len(kernel_size) - if isinstance(padding, numeric_types): - padding = (padding,)*len(kernel_size) - if isinstance(dilation, numeric_types): - dilation = (dilation,)*len(kernel_size) - self._op_name = op_name - self._kwargs = { - 'kernel': kernel_size, 'stride': strides, 'dilate': dilation, - 'pad': padding, 'num_filter': channels, 'num_group': groups, - 'no_bias': not use_bias, 'layout': layout} - if adj is not None: - self._kwargs['adj'] = adj - - if is_np_array(): - dshape = [-1]*(len(kernel_size) + 2) - else: - dshape = [0]*(len(kernel_size) + 2) - - dshape[layout.find('N')] = 1 - dshape[layout.find('C')] = in_channels - wshapes = _infer_weight_shape(op_name, dshape, self._kwargs) - self.weight = self.params.get('weight', shape=wshapes[1], - init=weight_initializer, - allow_deferred_init=True) - if use_bias: - self.bias = self.params.get('bias', shape=wshapes[2], - init=bias_initializer, - allow_deferred_init=True) - else: - self.bias = None - - if activation is not None: - self.act = Activation(activation, prefix=activation+'_') - else: - self.act = None + op_name='Convolution', adj=None): + super(_Conv, self).__init__() + self._channels = channels + self._in_channels = in_channels + if isinstance(strides, numeric_types): + strides = (strides,)*len(kernel_size) + if isinstance(padding, numeric_types): + padding = (padding,)*len(kernel_size) + if isinstance(dilation, numeric_types): + dilation = (dilation,)*len(kernel_size) + self._op_name = op_name + self._kwargs = { + 'kernel': kernel_size, 'stride': strides, 'dilate': dilation, + 'pad': padding, 'num_filter': channels, 'num_group': groups, + 'no_bias': not use_bias, 'layout': layout} + if adj is not None: + self._kwargs['adj'] = adj + + if is_np_array(): + dshape = [-1]*(len(kernel_size) + 2) + else: + dshape = [0]*(len(kernel_size) + 2) + + dshape[layout.find('N')] = 1 + dshape[layout.find('C')] = in_channels + 
wshapes = _infer_weight_shape(op_name, dshape, self._kwargs) + self.weight = Parameter('weight', shape=wshapes[1], + init=weight_initializer, + allow_deferred_init=True) + if use_bias: + self.bias = Parameter('bias', shape=wshapes[2], + init=bias_initializer, + allow_deferred_init=True) + else: + self.bias = None + + if activation is not None: + self.act = Activation(activation) + else: + self.act = None def hybrid_forward(self, F, x, weight, bias=None): if is_np_array(): diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py index 9f6f540f3a7f..37d5140e7939 100644 --- a/python/mxnet/gluon/parameter.py +++ b/python/mxnet/gluon/parameter.py @@ -20,10 +20,10 @@ """Neural network parameter.""" __all__ = ['DeferredInitializationError', 'Parameter', 'Constant', - 'ParameterDict', 'tensor_types'] + 'tensor_types'] -from collections import OrderedDict, defaultdict +import uuid import warnings import weakref import numpy as np @@ -32,7 +32,7 @@ from .. import symbol, ndarray, initializer, context, _deferred_compute as dc from ..context import Context, cpu from .. import autograd -from .utils import _indent, _brief_print_list, shape_is_known +from .utils import shape_is_known from ..util import is_np_shape, is_np_array from .. import numpy as _mx_np # pylint: disable=reimported @@ -61,8 +61,8 @@ class Parameter(object): Parameters ---------- - name : str - Name of this parameter. + name : str, default 'weight' + Name of this parameter. It decides the corresponding default initializer. grad_req : {'write', 'add', 'null'}, default 'write' Specifies how to update gradient to grad arrays. @@ -103,7 +103,7 @@ class Parameter(object): wd_mult : float Local weight decay multiplier for this Parameter. """ - def __init__(self, name, grad_req='write', shape=None, dtype=mx_real_t, + def __init__(self, name='weight', grad_req='write', shape=None, dtype=mx_real_t, lr_mult=1.0, wd_mult=1.0, init=None, allow_deferred_init=False, differentiable=True, stype='default', grad_stype='default'): self._var = None @@ -119,7 +119,8 @@ def __init__(self, name, grad_req='write', shape=None, dtype=mx_real_t, if isinstance(shape, int): shape = (shape,) self._shape = shape - self.name = name + self._name = 'param_{}_{}'.format(str(uuid.uuid4()).replace('-', '_'), name) + self._structured_name = '' self._dtype = dtype self.lr_mult = lr_mult self.wd_mult = wd_mult @@ -127,21 +128,25 @@ def __init__(self, name, grad_req='write', shape=None, dtype=mx_real_t, self.init = init # sparse related storage type information valid_stypes = ['default', 'row_sparse', 'csr'] - assert grad_stype in valid_stypes, "grad_stype for Parameter '%s' must be " \ - "one of 'default', 'row_sparse', or 'csr', but got '%s'" % (name, grad_stype) - assert stype in valid_stypes, "stype for Parameter '%s' must be " \ - "one of 'default', 'row_sparse', or 'csr', but got '%s'" % (name, stype) + assert grad_stype in valid_stypes, "grad_stype for Parameter must be " \ + "one of 'default', 'row_sparse', or 'csr', but got '%s'" % (grad_stype) + assert stype in valid_stypes, "stype for Parameter must be " \ + "one of 'default', 'row_sparse', or 'csr', but got '%s'" % (stype) self._grad_stype = grad_stype self._stype = stype def __repr__(self): - s = 'Parameter {name} (shape={shape}, dtype={dtype})' - return s.format(name=self.name, shape=self.shape, dtype=self.dtype) + s = 'Parameter (shape={shape}, dtype={dtype})' + return s.format(shape=self.shape, dtype=self.dtype) @property def grad_req(self): return self._grad_req + @property + def 
name(self): + return self._name + @grad_req.setter def grad_req(self, req): assert req in ['write', 'add', 'null'], \ @@ -353,7 +358,7 @@ def _finish_deferred_init(self): zeros_fn = ndarray.zeros data = zeros_fn(**kwargs) initializer.create(default_init)( - initializer.InitDesc(self.name, {'__init__': init}), data) + initializer.InitDesc(self.name, {'__init__': init, 'structure': self._structural_name}), data) self._init_impl(data, ctx) @@ -410,7 +415,7 @@ def _reduce(self): return data def initialize(self, init=None, ctx=None, default_init=initializer.Uniform(), - force_reinit=False): + force_reinit=False, structural_name=''): """Initializes parameter and gradient arrays. Only used for :py:class:`NDArray` API. Parameters @@ -431,7 +436,10 @@ def initialize(self, init=None, ctx=None, default_init=initializer.Uniform(), and :py:meth:`Parameter.init` are ``None``. force_reinit : bool, default False Whether to force re-initialization if parameter is already initialized. - + structural_name : str, default "" + The structural name for the parameter in the block. + The value would be accessed in InitDesc.attrs['structure'] by self-defined initializers. + Users may want to initialize parameters based on the block's structure Examples -------- >>> weight = mx.gluon.Parameter('weight', shape=(2, 2)) @@ -460,7 +468,7 @@ def initialize(self, init=None, ctx=None, default_init=initializer.Uniform(), stacklevel=2) return self._data = self._grad = None - + self._structural_name = structural_name if ctx is None: ctx = [context.current_context()] if isinstance(ctx, Context): @@ -664,6 +672,38 @@ def cast(self, dtype): self._grad = [i.astype(dtype) for i in self._grad] autograd.mark_variables(self._data, self._grad, self.grad_req) + def _check_and_setattr(self, **kwargs): + """check and set attributes for parameter""" + for k, v in kwargs.items(): + if hasattr(self, k) and getattr(self, k) is not None: + existing = getattr(self, k) + if k == 'shape' and len(v) == len(existing): + inferred_shape = [] + matched = True + for dim1, dim2 in zip(v, existing): + if dim1 != dim2 and dim1 > 0 and dim2 > 0: + matched = False + break + elif dim1 == dim2: + inferred_shape.append(dim1) + elif dim1 in (0, -1): # -1 means unknown dim size in np_shape mode + inferred_shape.append(dim2) + else: + inferred_shape.append(dim1) + + if matched: + self._shape = tuple(inferred_shape) + continue + elif k == 'dtype' and np.dtype(v) == np.dtype(existing): + continue + + assert v is None or v == existing, \ + "Cannot retrieve Parameter '%s' because desired attribute " \ + "does not match with stored for attribute '%s': " \ + "desired '%s' vs stored '%s'."%( + self.name, k, str(v), str(getattr(self, k))) + else: + setattr(self, k, v) class Constant(Parameter): """A constant parameter for holding immutable tensors. @@ -673,23 +713,21 @@ class Constant(Parameter): `Constant` s can be created with either:: - const = mx.gluon.Constant('const', [[1,2],[3,4]]) + const = mx.gluon.Constant([[1,2],[3,4]]) or:: class Block(gluon.Block): def __init__(self, **kwargs): super(Block, self).__init__(**kwargs) - self.const = self.params.get_constant('const', [[1,2],[3,4]]) + self.const = mx.gluon.Constant([[1,2],[3,4]]) Parameters ---------- - name : str - Name of the parameter. value : array-like Initial value for the constant. 
""" - def __init__(self, name, value): + def __init__(self, value): if not isinstance(value, ndarray.NDArray): array_fn = _mx_np.array if is_np_array() else ndarray.array value = array_fn(value) @@ -698,16 +736,16 @@ def __init__(self, name, value): class Init(initializer.Initializer): def _init_weight(self, _, arr): value.copyto(arr) - init_name = 'Constant_{}_{}'.format(name, id(self)) + init_name = 'Constant_{}'.format(id(self)) initializer.alias(init_name)(Init) super(Constant, self).__init__( - name, grad_req='null', shape=value.shape, dtype=value.dtype, + name='const', grad_req='null', shape=value.shape, dtype=value.dtype, init=init_name) def __repr__(self): - s = 'Constant {name} (shape={shape}, dtype={dtype})' - return s.format(name=self.name, shape=self.shape, dtype=self.dtype) + s = 'Constant (shape={shape}, dtype={dtype})' + return s.format(shape=self.shape, dtype=self.dtype) @property def grad_req(self): @@ -719,363 +757,3 @@ def grad_req(self, req): warnings.warn('Constant parameter "{}" does not support ' 'grad_req other than "null", and new value "{}" ' 'is ignored.'.format(self.name, req)) - - -class ParameterDict(object): - """A dictionary managing a set of parameters. - - Parameters - ---------- - prefix : str, default ``''`` - The prefix to be prepended to all Parameters' names created by this dict. - shared : ParameterDict or None - If not ``None``, when this dict's :py:meth:`get` method creates a new parameter, will - first try to retrieve it from "shared" dict. Usually used for sharing - parameters with another Block. - """ - def __init__(self, prefix='', shared=None): - self._prefix = prefix - self._params = OrderedDict() - self._shared = shared - - def __repr__(self): - s = '{name}(\n{content}\n)' - name = self._prefix+' ' if self._prefix else '' - return s.format(name=name, - content='\n'.join([_indent(' {0}'.format(v), 2) - for v in self.values()])) - - def __getitem__(self, key): - return self._params[key] - - def __iter__(self): - return iter(self._params) - - def items(self): - return self._params.items() - - def keys(self): - return self._params.keys() - - def values(self): - return self._params.values() - - @property - def prefix(self): - """Prefix of this dict. It will be prepended to :py:class:`Parameter`s' name created - with :py:func:`get`.""" - return self._prefix - - def _get_impl(self, name): - if name in self._params: - return self._params[name] - if self._shared is not None and name in self._shared._params: - self._params[name] = self._shared._params[name] - return self._shared._params[name] - return None - - def get(self, name, **kwargs): - """Retrieves a :py:class:`Parameter` with name ``self.prefix+name``. If not found, - :py:func:`get` will first try to retrieve it from "shared" dict. If still not - found, :py:func:`get` will create a new :py:class:`Parameter` with key-word arguments and - insert it to self. - - Parameters - ---------- - name : str - Name of the desired Parameter. It will be prepended with this dictionary's - prefix. - **kwargs : dict - The rest of key-word arguments for the created :py:class:`Parameter`. - - Returns - ------- - Parameter - The created or retrieved :py:class:`Parameter`. 
- """ - name = self.prefix + name - param = self._get_impl(name) - if param is None: # pylint: disable=too-many-nested-blocks - param = Parameter(name, **kwargs) - self._params[name] = param - else: - for k, v in kwargs.items(): - if hasattr(param, k) and getattr(param, k) is not None: - existing = getattr(param, k) - if k == 'shape' and len(v) == len(existing): - inferred_shape = [] - matched = True - for dim1, dim2 in zip(v, existing): - if dim1 != dim2 and dim1 > 0 and dim2 > 0: - matched = False - break - elif dim1 == dim2: - inferred_shape.append(dim1) - elif dim1 in (0, -1): # -1 means unknown dim size in np_shape mode - inferred_shape.append(dim2) - else: - inferred_shape.append(dim1) - - if matched: - param._shape = tuple(inferred_shape) - continue - elif k == 'dtype' and np.dtype(v) == np.dtype(existing): - continue - - assert v is None or v == existing, \ - "Cannot retrieve Parameter '%s' because desired attribute " \ - "does not match with stored for attribute '%s': " \ - "desired '%s' vs stored '%s'."%( - name, k, str(v), str(getattr(param, k))) - else: - setattr(param, k, v) - return param - - def get_constant(self, name, value=None): - """Retrieves a :py:class:`.Constant` with name ``self.prefix+name``. If not found, - :py:func:`get` will first try to retrieve it from "shared" dict. If still not - found, :py:func:`get` will create a new :py:class:`.Constant` with key-word - arguments and insert it to self. - - Parameters - ---------- - name : str - Name of the desired Constant. It will be prepended with this dictionary's - prefix. - value : array-like - Initial value of constant. - - Returns - ------- - :py:class:`.Constant` - The created or retrieved :py:class:`.Constant`. - """ - name = self.prefix + name - param = self._get_impl(name) - if param is None: - if value is None: - raise KeyError("No constant named '{}'. Please specify value " \ - "if you want to create a new constant.".format( - name)) - param = Constant(name, value) - self._params[name] = param - elif value is not None: - assert isinstance(param, Constant), \ - "Parameter '{}' already exists but it is not a constant.".format( - name) - if isinstance(value, ndarray.NDArray): - value = value.asnumpy() - assert param.shape == value.shape and \ - (param.value.asnumpy() == value).all(), \ - "Constant '{}' already exists but it's value doesn't match new " \ - "value".format(name) - return param - - def update(self, other): - """Copies all Parameters in ``other`` to self.""" - for k, v in other.items(): - if k in self._params: - assert self._params[k] is v, \ - "Cannot update self with other because they have different " \ - "Parameters with the same name '%s'"%k - - for k, v in other.items(): - self._params[k] = v - - def initialize(self, init=initializer.Uniform(), ctx=None, verbose=False, - force_reinit=False): - """Initializes all Parameters managed by this dictionary to be used for :py:class:`NDArray` - API. It has no effect when using :py:class:`Symbol` API. - - Parameters - ---------- - init : Initializer - Global default Initializer to be used when :py:meth:`Parameter.init` is ``None``. - Otherwise, :py:meth:`Parameter.init` takes precedence. - ctx : Context or list of Context - Keeps a copy of Parameters on one or many context(s). - verbose : bool, default False - Whether to verbosely print out details on initialization. - force_reinit : bool, default False - Whether to force re-initialization if parameter is already initialized. 
- """ - if verbose: - init.set_verbosity(verbose=verbose) - for _, v in self.items(): - v.initialize(None, ctx, init, force_reinit=force_reinit) - - def zero_grad(self): - """Sets all Parameters' gradient buffer to 0.""" - # collect gradient arrays for each ctx - arrays = defaultdict(list) - for p in self.values(): - if p.grad_req == 'null' or p._grad is None: - continue - for g in p.list_grad(): - if g.stype == 'row_sparse': - ndarray.zeros_like(g, out=g) - else: - arrays[g.ctx].append(g) - - if len(arrays) == 0: - return - - if is_np_array(): - for arr in arrays.values(): - for ele in arr: - ele[()] = 0 - else: - for arr in arrays.values(): - ndarray.reset_arrays(*arr, num_arrays=len(arr)) - - def reset_ctx(self, ctx): - """Re-assign all Parameters to other contexts. - - Parameters - ---------- - ctx : Context or list of Context, default :py:meth:`context.current_context()`. - Assign Parameter to given context. If ctx is a list of Context, a - copy will be made for each context. - """ - for i in self.values(): - i.reset_ctx(ctx) - - def list_ctx(self): - """Returns a list of all the contexts on which the underlying Parameters - are initialized.""" - s = set() - for i in self.values(): - s.update(i.list_ctx()) - return list(s) - - def setattr(self, name, value): - """Set an attribute to a new value for all Parameters. - - For example, set grad_req to null if you don't need gradient w.r.t a - model's Parameters:: - - model.collect_params().setattr('grad_req', 'null') - - or change the learning rate multiplier:: - - model.collect_params().setattr('lr_mult', 0.5) - - Parameters - ---------- - name : str - Name of the attribute. - value : valid type for attribute name - The new value for the attribute. - """ - for i in self.values(): - setattr(i, name, value) - - def save(self, filename, strip_prefix=''): - """Save parameters to file. - - Parameters - ---------- - filename : str - Path to parameter file. - strip_prefix : str, default '' - Strip prefix from parameter names before saving. - """ - arg_dict = {} - for param in self.values(): - weight = param._reduce() - if not param.name.startswith(strip_prefix): - raise ValueError( - "Prefix '%s' is to be striped before saving, but Parameter's " - "name '%s' does not start with '%s'. " - "this may be due to your Block shares parameters from other " - "Blocks or you forgot to use 'with name_scope()' when creating " - "child blocks. For more info on naming, please see " - "https://mxnet.io/api/python/docs/tutorials/packages/gluon/blocks/naming.html"%( - strip_prefix, param.name, strip_prefix)) - arg_dict[param.name[len(strip_prefix):]] = weight - ndarray.save(filename, arg_dict) - - def load(self, filename, ctx=None, allow_missing=False, - ignore_extra=False, restore_prefix='', cast_dtype=False, - dtype_source="current"): - """Load parameters from file. - - Parameters - ---------- - filename : str - Path to parameter file. - ctx : Context or list of Context - Context(s) initialize loaded parameters on. - allow_missing : bool, default False - Whether to silently skip loading parameters not represents in the file. - ignore_extra : bool, default False - Whether to silently ignore parameters from the file that are not - present in this ParameterDict. - restore_prefix : str, default '' - prepend prefix to names of stored parameters before loading. 
- cast_dtype : bool, default False - Cast the data type of the parameter - dtype_source : str, default 'current' - must be in {'current', 'saved'} - Only valid if cast_dtype=True, specify the source of the dtype for casting - the parameters - """ - if restore_prefix: - for name in self.keys(): - assert name.startswith(restore_prefix), \ - "restore_prefix is '%s' but Parameters name '%s' does not start " \ - "with '%s'. For more info on naming, please see " \ - "https://mxnet.io/api/python/docs/tutorials/packages/gluon/blocks/naming.html"%( - restore_prefix, name, restore_prefix) - ndarray_load = ndarray.load(filename) - self.load_dict(ndarray_load, ctx, allow_missing, - ignore_extra, restore_prefix, filename, cast_dtype, dtype_source) - - def load_dict(self, param_dict, ctx=None, allow_missing=False, - ignore_extra=False, restore_prefix='', filename=None, cast_dtype=False, - dtype_source="current"): - """Load parameters from dict - - Parameters - ---------- - param_dict : dict - Dictionary containing model parameters, preprended with arg: and aux: names - ctx : Context or list of Context - Context(s) initialize loaded parameters on. - allow_missing : bool, default False - Whether to silently skip loading parameters not represented in the file. - ignore_extra : bool, default False - Whether to silently ignore parameters from the file that are not - present in this ParameterDict. - restore_prefix : str, default '' - prepend prefix to names of stored parameters before loading - filename : str, default None - cast_dtype : bool, default False - Cast the data type of the NDArray loaded from the checkpoint to the dtype - provided by the Parameter if any - """ - lprefix = len(restore_prefix) - loaded = [(k[4:] if k.startswith('arg:') or k.startswith('aux:') else k, v) \ - for k, v in param_dict.items()] if isinstance(param_dict, dict) else param_dict - arg_dict = {restore_prefix+k: v for k, v in loaded} - error_str = "file: %s" % (filename) if filename else "param_dict" - if not allow_missing: - for name in self.keys(): - assert name in arg_dict, \ - "Parameter '%s' is missing in %s, which contains parameters: %s. " \ - "Please make sure source and target networks have the same prefix." \ - "For more info on naming, please see " \ - "https://mxnet.io/api/python/docs/tutorials/packages/gluon/blocks/naming.html"%( - name[lprefix:], error_str, _brief_print_list(arg_dict.keys())) - for name in arg_dict: - if name not in self._params: - assert ignore_extra, \ - "Parameter '%s' loaded from %s is not present in ParameterDict, " \ - "choices are: %s. Set ignore_extra to True to ignore. " \ - "Please make sure source and target networks have the same prefix." \ - "For more info on naming, please see " \ - "https://mxnet.io/api/python/docs/tutorials/packages/gluon/blocks/naming.html"%( - name[lprefix:], error_str, _brief_print_list(self._params.keys())) - continue - self[name]._load_init(arg_dict[name], ctx, cast_dtype=cast_dtype, - dtype_source=dtype_source) diff --git a/python/mxnet/gluon/rnn/rnn_cell.py b/python/mxnet/gluon/rnn/rnn_cell.py index fe72c6af8d3b..380fb200d83c 100644 --- a/python/mxnet/gluon/rnn/rnn_cell.py +++ b/python/mxnet/gluon/rnn/rnn_cell.py @@ -29,6 +29,7 @@ from ... import symbol, ndarray from ...base import string_types, numeric_types, _as_list from ..block import Block, HybridBlock +from ..parameter import Parameter from ..utils import _indent from .. 
import tensor_types from ..nn import LeakyReLU @@ -125,18 +126,9 @@ def _reverse_sequences(sequences, unroll_step, valid_length=None): class RecurrentCell(Block): """Abstract base class for RNN cells - Parameters - ---------- - prefix : str, optional - Prefix for names of `Block`s - (this prefix is also used for names of weights if `params` is `None` - i.e. if `params` are being created and not reused) - params : Parameter or None, default None - Container for weight sharing between cells. - A new Parameter container is created if `params` is `None`. """ - def __init__(self, prefix=None, params=None): - super(RecurrentCell, self).__init__(prefix=prefix, params=params) + def __init__(self): + super(RecurrentCell, self).__init__() self._modified = False self.reset() @@ -187,7 +179,7 @@ def begin_state(self, batch_size=0, func=ndarray.zeros, **kwargs): info.update(kwargs) else: info = kwargs - state = func(name='%sbegin_state_%d'%(self._prefix, self._init_counter), + state = func(name='begin_state_%d'%(self._init_counter), **info) states.append(state) return states @@ -317,8 +309,8 @@ def forward(self, inputs, states): class HybridRecurrentCell(RecurrentCell, HybridBlock): """HybridRecurrentCell supports hybridize.""" - def __init__(self, prefix=None, params=None): - super(HybridRecurrentCell, self).__init__(prefix=prefix, params=params) + def __init__(self): + super(HybridRecurrentCell, self).__init__() def hybrid_forward(self, F, x, *args, **kwargs): raise NotImplementedError @@ -356,12 +348,6 @@ class RNNCell(HybridRecurrentCell): input_size: int, default 0 The number of expected features in the input x. If not specified, it will be inferred from input. - prefix : str, default ``'rnn_'`` - Prefix for name of `Block`s - (and name of weight if params is `None`). - params : Parameter or None - Container for weight sharing between cells. - Created if `None`. 
Inputs: @@ -377,23 +363,23 @@ class RNNCell(HybridRecurrentCell): def __init__(self, hidden_size, activation='tanh', i2h_weight_initializer=None, h2h_weight_initializer=None, i2h_bias_initializer='zeros', h2h_bias_initializer='zeros', - input_size=0, prefix=None, params=None): - super(RNNCell, self).__init__(prefix=prefix, params=params) + input_size=0): + super(RNNCell, self).__init__() self._hidden_size = hidden_size self._activation = activation self._input_size = input_size - self.i2h_weight = self.params.get('i2h_weight', shape=(hidden_size, input_size), - init=i2h_weight_initializer, - allow_deferred_init=True) - self.h2h_weight = self.params.get('h2h_weight', shape=(hidden_size, hidden_size), - init=h2h_weight_initializer, - allow_deferred_init=True) - self.i2h_bias = self.params.get('i2h_bias', shape=(hidden_size,), - init=i2h_bias_initializer, - allow_deferred_init=True) - self.h2h_bias = self.params.get('h2h_bias', shape=(hidden_size,), - init=h2h_bias_initializer, - allow_deferred_init=True) + self.i2h_weight = Parameter('i2h_weight', shape=(hidden_size, input_size), + init=i2h_weight_initializer, + allow_deferred_init=True) + self.h2h_weight = Parameter('h2h_weight', shape=(hidden_size, hidden_size), + init=h2h_weight_initializer, + allow_deferred_init=True) + self.i2h_bias = Parameter('i2h_bias', shape=(hidden_size,), + init=i2h_bias_initializer, + allow_deferred_init=True) + self.h2h_bias = Parameter('h2h_bias', shape=(hidden_size,), + init=h2h_bias_initializer, + allow_deferred_init=True) def state_info(self, batch_size=0): return [{'shape': (batch_size, self._hidden_size), '__layout__': 'NC'}] @@ -466,12 +452,6 @@ class LSTMCell(HybridRecurrentCell): input_size: int, default 0 The number of expected features in the input x. If not specified, it will be inferred from input. - prefix : str, default ``'lstm_'`` - Prefix for name of `Block`s - (and name of weight if params is `None`). - params : Parameter or None, default None - Container for weight sharing between cells. - Created if `None`. activation : str, default 'tanh' Activation type to use. See nd/symbol Activation for supported types. 
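[Editor's note] With `prefix` and `params` removed from the recurrent cells, construction reduces to plain keyword arguments and parameters become ordinary attributes. A minimal usage sketch against the patched API (illustrative only, not part of the patch; the printed keys assume the structured names registered in the constructors, e.g. 'i2h_weight'):

import mxnet as mx
from mxnet.gluon import rnn

cell = rnn.RNNCell(100, input_size=32)   # no prefix=/params= arguments any more
cell.initialize()
# collect_params() now returns a plain dict keyed by structured names such as
# 'i2h_weight' and 'h2h_bias' rather than prefixed names like 'rnn0_i2h_weight'.
print(sorted(cell.collect_params()))
out, states = cell(mx.nd.ones((8, 32)), cell.begin_state(batch_size=8))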
@@ -493,24 +473,23 @@ class LSTMCell(HybridRecurrentCell): def __init__(self, hidden_size, i2h_weight_initializer=None, h2h_weight_initializer=None, i2h_bias_initializer='zeros', h2h_bias_initializer='zeros', - input_size=0, prefix=None, params=None, activation='tanh', - recurrent_activation='sigmoid'): - super(LSTMCell, self).__init__(prefix=prefix, params=params) + input_size=0, activation='tanh', recurrent_activation='sigmoid'): + super(LSTMCell, self).__init__() self._hidden_size = hidden_size self._input_size = input_size - self.i2h_weight = self.params.get('i2h_weight', shape=(4*hidden_size, input_size), - init=i2h_weight_initializer, - allow_deferred_init=True) - self.h2h_weight = self.params.get('h2h_weight', shape=(4*hidden_size, hidden_size), - init=h2h_weight_initializer, - allow_deferred_init=True) - self.i2h_bias = self.params.get('i2h_bias', shape=(4*hidden_size,), - init=i2h_bias_initializer, - allow_deferred_init=True) - self.h2h_bias = self.params.get('h2h_bias', shape=(4*hidden_size,), - init=h2h_bias_initializer, - allow_deferred_init=True) + self.i2h_weight = Parameter('i2h_weight', shape=(4*hidden_size, input_size), + init=i2h_weight_initializer, + allow_deferred_init=True) + self.h2h_weight = Parameter('h2h_weight', shape=(4*hidden_size, hidden_size), + init=h2h_weight_initializer, + allow_deferred_init=True) + self.i2h_bias = Parameter('i2h_bias', shape=(4*hidden_size,), + init=i2h_bias_initializer, + allow_deferred_init=True) + self.h2h_bias = Parameter('h2h_bias', shape=(4*hidden_size,), + init=h2h_bias_initializer, + allow_deferred_init=True) self._activation = activation self._recurrent_activation = recurrent_activation @@ -594,12 +573,6 @@ class GRUCell(HybridRecurrentCell): input_size: int, default 0 The number of expected features in the input x. If not specified, it will be inferred from input. - prefix : str, default ``'gru_'`` - prefix for name of `Block`s - (and name of weight if params is `None`). - params : Parameter or None, default None - Container for weight sharing between cells. - Created if `None`. activation : str, default 'tanh' Activation type to use. See nd/symbol Activation for supported types. 
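[Editor's note] Tying back to the `structural_name` plumbing added in parameter.py above: a user-defined initializer can select values by a parameter's structured name through `InitDesc.attrs['structure']`, the same pattern the test_utils.py change further down relies on. A hedged sketch (the class name and the `np_params` dict are invented for illustration):

import mxnet as mx

class StructureAwareInit(mx.init.Initializer):
    """Fill each array from a dict keyed by the parameter's structured name."""
    def __init__(self, np_params):
        super(StructureAwareInit, self).__init__()
        self._np_params = np_params

    def _init_weight(self, desc, arr):
        # `desc` is an InitDesc; with this patch its attrs carry the structural
        # name forwarded by Parameter.initialize(..., structural_name=...).
        arr[()] = self._np_params[desc.attrs['structure']]

Blocks that forward structural names when initializing (as the patched test utility does) would then hit this path, e.g. net.initialize(StructureAwareInit(np_params)).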
@@ -621,23 +594,22 @@ class GRUCell(HybridRecurrentCell): def __init__(self, hidden_size, i2h_weight_initializer=None, h2h_weight_initializer=None, i2h_bias_initializer='zeros', h2h_bias_initializer='zeros', - input_size=0, prefix=None, params=None, activation='tanh', - recurrent_activation='sigmoid'): - super(GRUCell, self).__init__(prefix=prefix, params=params) + input_size=0, activation='tanh', recurrent_activation='sigmoid'): + super(GRUCell, self).__init__() self._hidden_size = hidden_size self._input_size = input_size - self.i2h_weight = self.params.get('i2h_weight', shape=(3*hidden_size, input_size), - init=i2h_weight_initializer, - allow_deferred_init=True) - self.h2h_weight = self.params.get('h2h_weight', shape=(3*hidden_size, hidden_size), - init=h2h_weight_initializer, - allow_deferred_init=True) - self.i2h_bias = self.params.get('i2h_bias', shape=(3*hidden_size,), - init=i2h_bias_initializer, - allow_deferred_init=True) - self.h2h_bias = self.params.get('h2h_bias', shape=(3*hidden_size,), - init=h2h_bias_initializer, - allow_deferred_init=True) + self.i2h_weight = Parameter('i2h_weight', shape=(3*hidden_size, input_size), + init=i2h_weight_initializer, + allow_deferred_init=True) + self.h2h_weight = Parameter('h2h_weight', shape=(3*hidden_size, hidden_size), + init=h2h_weight_initializer, + allow_deferred_init=True) + self.i2h_bias = Parameter('i2h_bias', shape=(3*hidden_size,), + init=i2h_bias_initializer, + allow_deferred_init=True) + self.h2h_bias = Parameter('h2h_bias', shape=(3*hidden_size,), + init=h2h_bias_initializer, + allow_deferred_init=True) self._activation = activation self._recurrent_activation = recurrent_activation @@ -702,8 +674,8 @@ def hybrid_forward(self, F, inputs, states, i2h_weight, class SequentialRNNCell(RecurrentCell): """Sequentially stacking multiple RNN cells.""" - def __init__(self, prefix=None, params=None): - super(SequentialRNNCell, self).__init__(prefix=prefix, params=params) + def __init__(self): + super(SequentialRNNCell, self).__init__() self._layers = [] def __repr__(self): @@ -782,8 +754,8 @@ def hybrid_forward(self, *args, **kwargs): class HybridSequentialRNNCell(HybridRecurrentCell): """Sequentially stacking multiple HybridRNN cells.""" - def __init__(self, prefix=None, params=None): - super(HybridSequentialRNNCell, self).__init__(prefix=prefix, params=params) + def __init__(self): + super(HybridSequentialRNNCell, self).__init__() self._layers = [] def __repr__(self): @@ -877,8 +849,8 @@ class DropoutCell(HybridRecurrentCell): - **out**: output tensor with shape `(batch_size, size)`. - **next_states**: returns input `states` directly. """ - def __init__(self, rate, axes=(), prefix=None, params=None): - super(DropoutCell, self).__init__(prefix, params) + def __init__(self, rate, axes=()): + super(DropoutCell, self).__init__() assert isinstance(rate, numeric_types), "rate must be a number" self._rate = rate self._axes = axes @@ -925,8 +897,7 @@ def __init__(self, base_cell): assert not base_cell._modified, \ "Cell %s is already modified. 
One cell cannot be modified twice"%base_cell.name base_cell._modified = True - super(ModifierCell, self).__init__(prefix=base_cell.prefix+self._alias(), - params=None) + super(ModifierCell, self).__init__() self.base_cell = base_cell @property @@ -1050,11 +1021,10 @@ class BidirectionalCell(HybridRecurrentCell): r_cell : RecurrentCell Cell for backward unrolling """ - def __init__(self, l_cell, r_cell, output_prefix='bi_'): - super(BidirectionalCell, self).__init__(prefix='', params=None) + def __init__(self, l_cell, r_cell): + super(BidirectionalCell, self).__init__() self.l_cell = l_cell self.r_cell = r_cell - self._output_prefix = output_prefix def __call__(self, inputs, states): raise NotImplementedError("Bidirectional cannot be stepped. Please use unroll") @@ -1105,10 +1075,10 @@ def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=N if merge_outputs: reversed_r_outputs = F.stack(*reversed_r_outputs, axis=axis) outputs = F.concat(l_outputs, reversed_r_outputs, dim=2, - name='%sout'%self._output_prefix) + name='out') else: - outputs = [F.concat(l_o, r_o, dim=1, name='%st%d'%(self._output_prefix, i)) + outputs = [F.concat(l_o, r_o, dim=1, name='t%d'%(i)) for i, (l_o, r_o) in enumerate(zip(l_outputs, reversed_r_outputs))] if valid_length is not None: outputs = _mask_sequence_variable_length(F, outputs, length, valid_length, axis, diff --git a/python/mxnet/gluon/rnn/rnn_layer.py b/python/mxnet/gluon/rnn/rnn_layer.py index fa5d0831561a..c6d23e5edabf 100644 --- a/python/mxnet/gluon/rnn/rnn_layer.py +++ b/python/mxnet/gluon/rnn/rnn_layer.py @@ -20,13 +20,12 @@ # pylint: disable=too-many-branches, too-many-arguments, no-self-use # pylint: disable=too-many-lines, arguments-differ """Definition of various recurrent neural network layers.""" -import re __all__ = ['RNN', 'LSTM', 'GRU'] from ... import ndarray, symbol from .. import HybridBlock, tensor_types -from . import rnn_cell +from ..parameter import Parameter from ...util import is_np_array @@ -103,8 +102,7 @@ def __init__(self, hidden_size, num_layers, layout, ni = np * self._dir def _register_param(self, name, shape, init, dtype): - p = self.params.get(name, shape=shape, init=init, - allow_deferred_init=True, dtype=dtype) + p = Parameter(name, shape=shape, init=init, allow_deferred_init=True, dtype=dtype) setattr(self, name, p) return p @@ -123,66 +121,9 @@ def __repr__(self): mapping=mapping, **self.__dict__) - def _collect_params_with_prefix(self, prefix=''): - if prefix: - prefix += '.' - pattern = re.compile(r'(l|r)(\d)_(i2h|h2h|h2r)_(weight|bias)\Z') - def convert_key(m, bidirectional): # for compatibility with old parameter format - d, l, g, t = [m.group(i) for i in range(1, 5)] - if bidirectional: - return '_unfused.{}.{}_cell.{}_{}'.format(l, d, g, t) - else: - return '_unfused.{}.{}_{}'.format(l, g, t) - bidirectional = any(pattern.match(k).group(1) == 'r' for k in self._reg_params) - - ret = {prefix + convert_key(pattern.match(key), bidirectional) : val - for key, val in self._reg_params.items()} - for name, child in self._children.items(): - ret.update(child._collect_params_with_prefix(prefix + name)) - return ret - def state_info(self, batch_size=0): raise NotImplementedError - def _unfuse(self): - """Unfuses the fused RNN in to a stack of rnn cells.""" - assert not self._projection_size, "_unfuse does not support projection layer yet!" - assert not self._lstm_state_clip_min and not self._lstm_state_clip_max, \ - "_unfuse does not support state clipping yet!" 
- get_cell = {'rnn_relu': lambda **kwargs: rnn_cell.RNNCell(self._hidden_size, - activation='relu', - **kwargs), - 'rnn_tanh': lambda **kwargs: rnn_cell.RNNCell(self._hidden_size, - activation='tanh', - **kwargs), - 'lstm': lambda **kwargs: rnn_cell.LSTMCell(self._hidden_size, - **kwargs), - 'gru': lambda **kwargs: rnn_cell.GRUCell(self._hidden_size, - **kwargs)}[self._mode] - - stack = rnn_cell.HybridSequentialRNNCell(prefix=self.prefix, params=self.params) - with stack.name_scope(): - ni = self._input_size - for i in range(self._num_layers): - kwargs = {'input_size': ni, - 'i2h_weight_initializer': self._i2h_weight_initializer, - 'h2h_weight_initializer': self._h2h_weight_initializer, - 'i2h_bias_initializer': self._i2h_bias_initializer, - 'h2h_bias_initializer': self._h2h_bias_initializer} - if self._dir == 2: - stack.add(rnn_cell.BidirectionalCell( - get_cell(prefix='l%d_'%i, **kwargs), - get_cell(prefix='r%d_'%i, **kwargs))) - else: - stack.add(get_cell(prefix='l%d_'%i, **kwargs)) - - if self._dropout > 0 and i != self._num_layers - 1: - stack.add(rnn_cell.DropoutCell(self._dropout)) - - ni = self._hidden_size * self._dir - - return stack - def cast(self, dtype): super(_RNNLayer, self).cast(dtype) self._dtype = dtype @@ -219,7 +160,7 @@ def begin_state(self, batch_size=0, func=ndarray.zeros, **kwargs): info.update(kwargs) else: info = kwargs - state = func(name='%sh0_%d' % (self.prefix, i), **info) + state = func(name='h0_%d' % (i), **info) if is_np_array(): state = state.as_np_ndarray() states.append(state) @@ -348,10 +289,6 @@ class RNN(_RNNLayer): If not specified, it will be inferred from input. dtype : str, default 'float32' Type to initialize the parameters and default states to - prefix : str or None - Prefix of this `Block`. - params : ParameterDict or None - Shared Parameters for this `Block`. Inputs: @@ -468,10 +405,6 @@ class LSTM(_RNNLayer): input_size: int, default 0 The number of expected features in the input x. If not specified, it will be inferred from input. - prefix : str or None - Prefix of this `Block`. - params : `ParameterDict` or `None` - Shared Parameters for this `Block`. Inputs: @@ -582,10 +515,6 @@ class GRU(_RNNLayer): input_size: int, default 0 The number of expected features in the input x. If not specified, it will be inferred from input. - prefix : str or None - Prefix of this `Block`. - params : ParameterDict or None - Shared Parameters for this `Block`. Inputs: diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py index fd03393b6374..66db9235528a 100644 --- a/python/mxnet/gluon/trainer.py +++ b/python/mxnet/gluon/trainer.py @@ -20,9 +20,11 @@ """Parameter optimizer.""" __all__ = ['Trainer'] +from collections import OrderedDict + from .. import optimizer as opt from ..model import _create_kvstore, _create_sparse_kvstore -from .parameter import ParameterDict, Parameter +from .parameter import Parameter from ..kvstore import KVStore @@ -41,7 +43,7 @@ class Trainer(object): Parameters ---------- - params : ParameterDict + params : Dict The set of parameters to optimize. optimizer : str or Optimizer The optimizer to use. 
See @@ -76,7 +78,7 @@ class Trainer(object): def __init__(self, params, optimizer, optimizer_params=None, kvstore='device', compression_params=None, update_on_kvstore=None): param_list = [] - if isinstance(params, (dict, ParameterDict)): + if isinstance(params, (dict, OrderedDict)): for key in sorted(list(params.keys())): param_list.append(params[key]) params = param_list diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index 86f36877dbf8..a8a797df7de7 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -2432,14 +2432,14 @@ def __init__(self, np_params): self._np_params = np_params def _init_weight(self, name, arr): - arr[()] = self._np_params[name] + arr[()] = self._np_params[name.attrs['structure']] saved_out_np = None saved_grad_np_l = None params_init = None use_autograd_flags = [False, True] if test_grad else [False] for hybridize in [False, True]: for use_autograd in use_autograd_flags: - net = net_builder(prefix='net_') + net = net_builder() if params_init is None: net.initialize() else: diff --git a/tests/nightly/dist_async_kvstore.py b/tests/nightly/dist_async_kvstore.py index b990b6b3f13e..f1bf13d93d37 100644 --- a/tests/nightly/dist_async_kvstore.py +++ b/tests/nightly/dist_async_kvstore.py @@ -28,11 +28,10 @@ def test_gluon_trainer_type(): def check_trainer_kv_update(weight_stype, update_on_kv): - params = mx.gluon.ParameterDict() - x = params.get('x', shape=(10,1), lr_mult=1.0, stype=weight_stype) - params.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros') + x = mx.gluon.Parameter('x', shape=(10,1), lr_mult=1.0, stype=weight_stype) + x.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros') try: - trainer = mx.gluon.Trainer(params, 'sgd', {'learning_rate': 0.1}, + trainer = mx.gluon.Trainer([x], 'sgd', {'learning_rate': 0.1}, kvstore=kv, update_on_kvstore=update_on_kv) trainer._init_kvstore() assert trainer._kv_initialized diff --git a/tests/nightly/dist_device_sync_kvstore.py b/tests/nightly/dist_device_sync_kvstore.py index a69687605798..b7b4e4c71f0e 100644 --- a/tests/nightly/dist_device_sync_kvstore.py +++ b/tests/nightly/dist_device_sync_kvstore.py @@ -106,11 +106,10 @@ def check_init(kv, cur_keys, cur_shape, device=False): def test_gluon_trainer_type(): def check_trainer_kv_update(update_on_kv): - params = mx.gluon.ParameterDict() - x = params.get('x', shape=(10,1), lr_mult=1.0) - params.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros') + x = mx.gluon.Parameter('x', shape=(10,1), lr_mult=1.0) + x.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros') try: - trainer = mx.gluon.Trainer(params, 'sgd', {'learning_rate': 0.1}, + trainer = mx.gluon.Trainer([x], 'sgd', {'learning_rate': 0.1}, kvstore=kv, update_on_kvstore=update_on_kv) trainer._init_kvstore() assert trainer._kv_initialized diff --git a/tests/nightly/dist_sync_kvstore.py b/tests/nightly/dist_sync_kvstore.py index 4523a361cf88..3f5137ba09b9 100644 --- a/tests/nightly/dist_sync_kvstore.py +++ b/tests/nightly/dist_sync_kvstore.py @@ -353,17 +353,19 @@ def check_init(kv, cur_keys, cur_shape, device=False): def test_invalid_operations(): def check_invalid_gluon_trainer_reset(): - params = mx.gluon.ParameterDict() - x = params.get('x', shape=(4, 2), lr_mult=1.0, stype='row_sparse') - params.initialize(ctx=mx.cpu(0), init='zeros') + x = mx.gluon.Parameter('x', shape=(4, 2), lr_mult=1.0, stype='row_sparse') + params = {'x': x} + x.initialize(ctx=mx.cpu(0), init='zeros') trainer = mx.gluon.Trainer(params, 'sgd', {'learning_rate': 0.1}, kvstore=kv) - 
params.save('test_gluon_trainer_reset_' + str(my_rank) + '.params') + params = {'x': x._reduce()} + mx.nd.save('test_gluon_trainer_reset_' + str(my_rank) + '.params', params) row_id = mx.nd.arange(0, 4) w = x.row_sparse_data(row_id) assert trainer._kv_initialized and trainer._update_on_kvstore mx.nd.waitall() # load would fail to reset kvstore since update_on_kvstore is True - assert_exception(params.load, RuntimeError, 'test_gluon_trainer_reset_' + str(my_rank) + '.params') + params = mx.nd.load('test_gluon_trainer_reset_' + str(my_rank) + '.params') + assert_exception(x._load_init, RuntimeError, params['x'], None) print('worker ' + str(my_rank) + ' passed check_invalid_gluon_trainer_reset') def check_invalid_pull(): @@ -377,10 +379,9 @@ def check_invalid_pull(): def test_gluon_trainer_type(): def check_trainer_kv_type(stype, grad_stype, update_on_kv, expected): - params = mx.gluon.ParameterDict() - x = params.get('x', shape=(10,1), lr_mult=1.0, stype=stype, grad_stype=grad_stype) - params.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros') - trainer = mx.gluon.Trainer(params, 'sgd', {'learning_rate': 0.1}, + x = mx.gluon.Parameter('x', shape=(10,1), lr_mult=1.0, stype=stype, grad_stype=grad_stype) + x.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros') + trainer = mx.gluon.Trainer([x], 'sgd', {'learning_rate': 0.1}, kvstore=kv, update_on_kvstore=update_on_kv) try: trainer._init_kvstore() diff --git a/tests/nightly/estimator/test_sentiment_rnn.py b/tests/nightly/estimator/test_sentiment_rnn.py index 69380389d48e..30f5114b2c10 100644 --- a/tests/nightly/estimator/test_sentiment_rnn.py +++ b/tests/nightly/estimator/test_sentiment_rnn.py @@ -274,7 +274,7 @@ def test_estimator_gpu(): 'glove', pretrained_file_name='glove.6B.100d.txt', vocabulary=vocab) net.embedding.weight.set_data(glove_embedding.idx_to_vec) - net.embedding.collect_params().setattr('grad_req', 'null') + net.embedding.setattr('grad_req', 'null') acc = run(net, train_dataloader, test_dataloader, num_epochs=num_epochs, ctx=ctx, lr=lr) diff --git a/tests/nightly/model_backwards_compatibility_check/common.py b/tests/nightly/model_backwards_compatibility_check/common.py index 1791b732df22..69f9ffd468ca 100644 --- a/tests/nightly/model_backwards_compatibility_check/common.py +++ b/tests/nightly/model_backwards_compatibility_check/common.py @@ -138,15 +138,12 @@ def create_model_folder(model_name): class Net(gluon.Block): def __init__(self, **kwargs): super(Net, self).__init__(**kwargs) - with self.name_scope(): - # layers created in name_scope will inherit name space - # from parent layer. - self.conv1 = nn.Conv2D(20, kernel_size=(5, 5)) - self.pool1 = nn.MaxPool2D(pool_size=(2, 2), strides=(2, 2)) - self.conv2 = nn.Conv2D(50, kernel_size=(5, 5)) - self.pool2 = nn.MaxPool2D(pool_size=(2, 2), strides=(2, 2)) - self.fc1 = nn.Dense(500) - self.fc2 = nn.Dense(2) + self.conv1 = nn.Conv2D(20, kernel_size=(5, 5)) + self.pool1 = nn.MaxPool2D(pool_size=(2, 2), strides=(2, 2)) + self.conv2 = nn.Conv2D(50, kernel_size=(5, 5)) + self.pool2 = nn.MaxPool2D(pool_size=(2, 2), strides=(2, 2)) + self.fc1 = nn.Dense(500) + self.fc2 = nn.Dense(2) def forward(self, x): x = self.pool1(F.tanh(self.conv1(x))) @@ -162,15 +159,12 @@ def forward(self, x): class HybridNet(gluon.HybridBlock): def __init__(self, **kwargs): super(HybridNet, self).__init__(**kwargs) - with self.name_scope(): - # layers created in name_scope will inherit name space - # from parent layer. 
- self.conv1 = nn.Conv2D(20, kernel_size=(5, 5)) - self.pool1 = nn.MaxPool2D(pool_size=(2, 2), strides=(2, 2)) - self.conv2 = nn.Conv2D(50, kernel_size=(5, 5)) - self.pool2 = nn.MaxPool2D(pool_size=(2, 2), strides=(2, 2)) - self.fc1 = nn.Dense(500) - self.fc2 = nn.Dense(2) + self.conv1 = nn.Conv2D(20, kernel_size=(5, 5)) + self.pool1 = nn.MaxPool2D(pool_size=(2, 2), strides=(2, 2)) + self.conv2 = nn.Conv2D(50, kernel_size=(5, 5)) + self.pool2 = nn.MaxPool2D(pool_size=(2, 2), strides=(2, 2)) + self.fc1 = nn.Dense(500) + self.fc2 = nn.Dense(2) def hybrid_forward(self, F, x): x = self.pool1(F.tanh(self.conv1(x))) @@ -186,14 +180,12 @@ def hybrid_forward(self, F, x): class SimpleLSTMModel(gluon.Block): def __init__(self, **kwargs): super(SimpleLSTMModel, self).__init__(**kwargs) - with self.name_scope(): - self.model = mx.gluon.nn.Sequential(prefix='') - with self.model.name_scope(): - self.model.add(mx.gluon.nn.Embedding(30, 10)) - self.model.add(mx.gluon.rnn.LSTM(20)) - self.model.add(mx.gluon.nn.Dense(100)) - self.model.add(mx.gluon.nn.Dropout(0.5)) - self.model.add(mx.gluon.nn.Dense(2, flatten=True, activation='tanh')) + self.model = mx.gluon.nn.Sequential() + self.model.add(mx.gluon.nn.Embedding(30, 10)) + self.model.add(mx.gluon.rnn.LSTM(20)) + self.model.add(mx.gluon.nn.Dense(100)) + self.model.add(mx.gluon.nn.Dropout(0.5)) + self.model.add(mx.gluon.nn.Dense(2, flatten=True, activation='tanh')) def forward(self, x): return self.model(x) diff --git a/tests/python/gpu/test_fusion.py b/tests/python/gpu/test_fusion.py index eb39fb1777ea..57337dc9c5cf 100644 --- a/tests/python/gpu/test_fusion.py +++ b/tests/python/gpu/test_fusion.py @@ -274,8 +274,8 @@ def test_fusion_boolean_inputs(): from mxnet.gluon import HybridBlock class Foo(HybridBlock): - def __init__(self, prefix=None, params=None): - super(Foo, self).__init__(prefix=prefix, params=params) + def __init__(self): + super(Foo, self).__init__() def hybrid_forward(self, F, valid_length): mask = valid_length.astype(np.float32) @@ -293,8 +293,8 @@ def test_fusion_different_dimensions(): from mxnet.gluon import HybridBlock class Foo(HybridBlock): - def __init__(self, prefix=None, params=None): - super(Foo, self).__init__(prefix=prefix, params=params) + def __init__(self): + super(Foo, self).__init__() def hybrid_forward(self, F, x): mask2 = x.astype(np.float32) diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py index a777546327ae..27d5f5e0ec3d 100644 --- a/tests/python/gpu/test_gluon_gpu.py +++ b/tests/python/gpu/test_gluon_gpu.py @@ -39,7 +39,7 @@ def check_rnn_layer(layer): - layer.collect_params().initialize(ctx=[mx.cpu(0), mx.gpu(0)]) + layer.initialize(ctx=[mx.cpu(0), mx.gpu(0)]) with mx.gpu(0): x = mx.nd.ones((10, 16, 30)) states = layer.begin_state(16) @@ -58,7 +58,7 @@ def check_rnn_layer(layer): @with_seed() def check_rnn_layer_w_rand_inputs(layer): - layer.collect_params().initialize(ctx=[mx.cpu(0), mx.gpu(0)]) + layer.initialize(ctx=[mx.cpu(0), mx.gpu(0)]) x = mx.nd.uniform(shape=(10, 16, 30)) with mx.gpu(0): x = x.copyto(mx.gpu(0)) @@ -92,18 +92,17 @@ def test_lstmp(): 'h2r_weight': (projection_size, hidden_size)} weights = {k: rand_ndarray(v) for k, v in shapes.items()} lstm_layer = gluon.rnn.LSTM(hidden_size, projection_size=projection_size, - input_size=input_size, prefix='lstm0_') + input_size=input_size) lstm_cell = gluon.contrib.rnn.LSTMPCell(hidden_size=hidden_size, projection_size=projection_size, - input_size=input_size, - prefix='lstm0_l0_') + input_size=input_size) 
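        # [Editor's note, comment only -- not part of the patch] With the prefixes removed,
        # the fused layer keys its parameters per layer/direction ('l0_i2h_weight', ...),
        # while the standalone LSTMPCell keys them with no prefix at all ('i2h_weight', ...).
        # Hence the lookups below pair layer_params['l0_' + k] with cell_params[k].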
lstm_layer.initialize(ctx=ctx) lstm_cell.initialize(ctx=ctx) layer_params = lstm_layer.collect_params() cell_params = lstm_cell.collect_params() for k, v in weights.items(): - layer_params['lstm0_l0_' + k].set_data(v.copy()) - cell_params['lstm0_l0_' + k].set_data(v.copy()) + layer_params['l0_' + k].set_data(v.copy()) + cell_params[k].set_data(v.copy()) with autograd.record(): layer_output = lstm_layer(lstm_input.copy()) cell_output = lstm_cell.unroll(seq_len, lstm_input.copy(), layout='TNC', @@ -113,8 +112,8 @@ def test_lstmp(): layer_output.backward() cell_output.backward() for k, v in weights.items(): - layer_grad = layer_params['lstm0_l0_' + k].grad() - cell_grad = cell_params['lstm0_l0_' + k].grad() + layer_grad = layer_params['l0_' + k].grad() + cell_grad = cell_params[k].grad() print('checking gradient for {}'.format('lstm0_l0_' + k)) assert_almost_equal(layer_grad, cell_grad, rtol=rtol, atol=atol) check_rnn_layer_forward(gluon.rnn.LSTM( @@ -142,7 +141,7 @@ def test_lstm_clip(): lstm_states = [mx.nd.uniform(shape=(2, batch_size, projection_size), ctx=mx.gpu(0)), mx.nd.uniform(shape=(2, batch_size, hidden_size), ctx=mx.gpu(0))] lstm_layer = gluon.rnn.LSTM(hidden_size, projection_size=projection_size, - input_size=input_size, prefix='lstm0_', + input_size=input_size, bidirectional=True, state_clip_min=clip_min, state_clip_max=clip_max, @@ -172,11 +171,10 @@ def check_layer_bidirectional(size, in_size, proj_size): class RefBiLSTM(gluon.Block): def __init__(self, size, proj_size, **kwargs): super(RefBiLSTM, self).__init__(**kwargs) - with self.name_scope(): - self._lstm_fwd = gluon.rnn.LSTM( - size, projection_size=proj_size, bidirectional=False, prefix='l0') - self._lstm_bwd = gluon.rnn.LSTM( - size, projection_size=proj_size, bidirectional=False, prefix='r0') + self._lstm_fwd = gluon.rnn.LSTM( + size, projection_size=proj_size, bidirectional=False) + self._lstm_bwd = gluon.rnn.LSTM( + size, projection_size=proj_size, bidirectional=False) def forward(self, inpt): fwd = self._lstm_fwd(inpt) @@ -186,32 +184,32 @@ def forward(self, inpt): return nd.concat(fwd, bwd, dim=2) weights = {} for d in ['l', 'r']: - weights['lstm_{}0_i2h_weight'.format(d)] = mx.random.uniform( + weights['{}0_i2h_weight'.format(d)] = mx.random.uniform( shape=(size * 4, in_size)) if proj_size: - weights['lstm_{}0_h2h_weight'.format(d)] = mx.random.uniform( + weights['{}0_h2h_weight'.format(d)] = mx.random.uniform( shape=(size * 4, proj_size)) - weights['lstm_{}0_h2r_weight'.format(d)] = mx.random.uniform( + weights['{}0_h2r_weight'.format(d)] = mx.random.uniform( shape=(proj_size, size)) else: - weights['lstm_{}0_h2h_weight'.format( + weights['{}0_h2h_weight'.format( d)] = mx.random.uniform(shape=(size * 4, size)) - weights['lstm_{}0_i2h_bias'.format( + weights['{}0_i2h_bias'.format( d)] = mx.random.uniform(shape=(size * 4,)) - weights['lstm_{}0_h2h_bias'.format( + weights['{}0_h2h_bias'.format( d)] = mx.random.uniform(shape=(size * 4,)) net = gluon.rnn.LSTM(size, projection_size=proj_size, - bidirectional=True, prefix='lstm_') - ref_net = RefBiLSTM(size, proj_size, prefix='lstm_') + bidirectional=True) + ref_net = RefBiLSTM(size, proj_size) net.initialize() ref_net.initialize() net_params = net.collect_params() ref_net_params = ref_net.collect_params() for k in weights: net_params[k].set_data(weights[k]) - ref_net_params[k.replace('l0', 'l0l0').replace( - 'r0', 'r0l0')].set_data(weights[k]) + ref_net_params[k.replace('l0', '_lstm_fwd.l0').replace( + 'r0', '_lstm_bwd.l0')].set_data(weights[k]) data = 
mx.random.uniform(shape=(11, 10, in_size)) mx.test_utils.assert_allclose(net(data), ref_net(data), rtol=1e-6) @@ -221,20 +219,20 @@ def forward(self, inpt): def check_layer_bidirectional_varseqlen(size, in_size): weights = {} for d in ['l', 'r']: - weights['lstm_{}0_i2h_weight'.format(d)] = mx.random.uniform(shape=(size*4, in_size)) - weights['lstm_{}0_h2h_weight'.format(d)] = mx.random.uniform(shape=(size*4, size)) - weights['lstm_{}0_i2h_bias'.format(d)] = mx.random.uniform(shape=(size*4,)) - weights['lstm_{}0_h2h_bias'.format(d)] = mx.random.uniform(shape=(size*4,)) + weights['{}0_i2h_weight'.format(d)] = mx.random.uniform(shape=(size*4, in_size)) + weights['{}0_h2h_weight'.format(d)] = mx.random.uniform(shape=(size*4, size)) + weights['{}0_i2h_bias'.format(d)] = mx.random.uniform(shape=(size*4,)) + weights['{}0_h2h_bias'.format(d)] = mx.random.uniform(shape=(size*4,)) - net = gluon.rnn.LSTM(size, bidirectional=True, use_sequence_length=True, prefix='lstm_') - ref_net = gluon.rnn.LSTM(size, bidirectional=True, use_sequence_length=False, prefix='lstm_ref_') + net = gluon.rnn.LSTM(size, bidirectional=True, use_sequence_length=True) + ref_net = gluon.rnn.LSTM(size, bidirectional=True, use_sequence_length=False) net.initialize() ref_net.initialize() net_params = net.collect_params() ref_net_params = ref_net.collect_params() for k in weights: net_params[k].set_data(weights[k]) - ref_net_params[k.replace("lstm_", "lstm_ref_")].set_data(weights[k]) + ref_net_params[k].set_data(weights[k]) batch_size = 10 num_timesteps = 11 @@ -276,7 +274,7 @@ def check_layer_bidirectional_varseqlen(size, in_size): for k in weights: net_grad = net_params[k].grad() - ref_net_grad = ref_net_params[k.replace('lstm_', 'lstm_ref_')].grad() + ref_net_grad = ref_net_params[k].grad() assert_almost_equal(net_grad.asnumpy(), ref_net_grad.asnumpy(), rtol=1e-2, atol=1e-6) @@ -452,13 +450,15 @@ def test_symbol_block_fp16(tmpdir): sm = mx.sym.load(tmpfile + '-symbol.json') inputs = mx.sym.var('data', dtype='float16') net_fp16 = mx.gluon.SymbolBlock(sm, inputs) - net_fp16.collect_params().load(tmpfile + '-0000.params', ctx=ctx) + net_fp16.load_parameters(tmpfile + '-0000.params', ctx=ctx) # 3. Get a conv layer's weight parameter name. Conv layer's weight param is # expected to be of dtype casted, fp16. 
- for param_name in net_fp16.params.keys(): + name = None + for param_name, param in net_fp32.collect_params().items(): if 'conv' in param_name and 'weight' in param_name: + name = param.name break - assert np.dtype(net_fp16.params[param_name].dtype) == np.dtype(np.float16) + assert np.dtype(net_fp16.params[name].dtype) == np.dtype(np.float16) @with_seed() @@ -469,8 +469,7 @@ def test_large_models(): net = gluon.nn.HybridSequential() largest_num_features = 256 - with net.name_scope(): - net.add(nn.Conv2D(largest_num_features, 3)) + net.add(nn.Conv2D(largest_num_features, 3)) net.hybridize() net.initialize(mx.init.Normal(sigma=0.01), ctx=ctx) @@ -515,9 +514,8 @@ def hybrid_forward(self, F, x): def get_net(num_ops): net = nn.HybridSequential() - with net.name_scope(): - for _ in range(num_ops): - net.add(Flip()) + for _ in range(num_ops): + net.add(Flip()) return net data_shape = (10,) diff --git a/tests/python/gpu/test_gluon_model_zoo_gpu.py b/tests/python/gpu/test_gluon_model_zoo_gpu.py index e505cd037621..427297173bae 100644 --- a/tests/python/gpu/test_gluon_model_zoo_gpu.py +++ b/tests/python/gpu/test_gluon_model_zoo_gpu.py @@ -65,17 +65,16 @@ def test_inference(model_name): # This is to create a model and run the model once to initialize # all parameters. cpu_model = get_model(model_name) - cpu_model.collect_params().initialize(ctx=mx.cpu()) + cpu_model.initialize(ctx=mx.cpu()) cpu_model(mx.nd.array(data, ctx=mx.cpu())) gpu_model = get_model(model_name) - gpu_model.collect_params().initialize(ctx=mx.gpu()) + gpu_model.initialize(ctx=mx.gpu()) gpu_model(mx.nd.array(data, ctx=mx.gpu())) # Force the two models have the same parameters. cpu_params = cpu_model.collect_params() gpu_params = gpu_model.collect_params() for k in cpu_params.keys(): - k = k.replace(cpu_params.prefix, '') cpu_param = cpu_params.get(k) gpu_param = gpu_params.get(k) gpu_param.set_data(cpu_param.data().as_in_context(mx.gpu())) @@ -135,17 +134,16 @@ def test_training(): # This is to create a model and run the model once to initialize # all parameters. cpu_model = get_nn_model(model_name) - cpu_model.collect_params().initialize(ctx=mx.cpu()) + cpu_model.initialize(ctx=mx.cpu()) cpu_model(mx.nd.array(data, ctx=mx.cpu())) gpu_model = get_nn_model(model_name) - gpu_model.collect_params().initialize(ctx=mx.gpu()) + gpu_model.initialize(ctx=mx.gpu()) gpu_model(mx.nd.array(data, ctx=mx.gpu())) # Force the two models have the same parameters. 
cpu_params = cpu_model.collect_params() gpu_params = gpu_model.collect_params() for k in cpu_params.keys(): - k = k.replace(cpu_params.prefix, '') cpu_param = cpu_params.get(k) gpu_param = gpu_params.get(k) gpu_param.set_data(cpu_param.data().as_in_context(mx.gpu())) diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index 90bd21f09892..934a10edc782 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -1894,7 +1894,7 @@ def test_deformable_convolution_options(): def check_rnn_layer(layer): - layer.collect_params().initialize(ctx=[mx.cpu(0), mx.gpu(0)]) + layer.initialize(ctx=[mx.cpu(0), mx.gpu(0)]) with mx.gpu(0): x = mx.nd.ones((10, 16, 30)) states = layer.begin_state(16) @@ -1911,7 +1911,7 @@ def check_rnn_layer(layer): assert_almost_equal(g, c, rtol=1e-2, atol=1e-6) def check_rnn_layer_w_rand_inputs(layer): - layer.collect_params().initialize(ctx=[mx.cpu(0), mx.gpu(0)]) + layer.initialize(ctx=[mx.cpu(0), mx.gpu(0)]) x = mx.nd.uniform(shape=(10, 16, 30)) with mx.gpu(0): x = x.copyto(mx.gpu(0)) diff --git a/tests/python/mkl/test_mkldnn.py b/tests/python/mkl/test_mkldnn.py index 26ec0818cf4b..e9f9e0171834 100644 --- a/tests/python/mkl/test_mkldnn.py +++ b/tests/python/mkl/test_mkldnn.py @@ -35,9 +35,8 @@ def test_mkldnn_ndarray_slice(): ctx = mx.cpu() net = gluon.nn.HybridSequential() - with net.name_scope(): - net.add(gluon.nn.Conv2D(channels=32, kernel_size=3, activation=None)) - net.collect_params().initialize(ctx=ctx) + net.add(gluon.nn.Conv2D(channels=32, kernel_size=3, activation=None)) + net.initialize(ctx=ctx) x = mx.nd.array(np.ones([32, 3, 224, 224]), ctx) y = net(x) @@ -47,9 +46,8 @@ def test_mkldnn_ndarray_slice(): @with_seed(1234) def test_mkldnn_engine_threading(): net = gluon.nn.HybridSequential() - with net.name_scope(): - net.add(gluon.nn.Conv2D(channels=32, kernel_size=3, activation=None)) - net.collect_params().initialize(ctx=mx.cpu()) + net.add(gluon.nn.Conv2D(channels=32, kernel_size=3, activation=None)) + net.initialize(ctx=mx.cpu()) class Dummy(gluon.data.Dataset): def __len__(self): return 2 @@ -110,9 +108,8 @@ class Net(gluon.HybridBlock): """ def __init__(self, **kwargs): super(Net, self).__init__(**kwargs) - with self.name_scope(): - self.conv0 = nn.Conv2D(10, (3, 3)) - self.conv1 = nn.Conv2D(5, (3, 3)) + self.conv0 = nn.Conv2D(10, (3, 3)) + self.conv1 = nn.Conv2D(5, (3, 3)) def hybrid_forward(self, F, x, *args, **kwargs): x_reshape = x.reshape((0, 0, 20, 5)) @@ -123,7 +120,7 @@ def hybrid_forward(self, F, x, *args, **kwargs): x = mx.nd.random.uniform(shape=(2, 4, 10, 10)) x.attach_grad() net = Net() - net.collect_params().initialize() + net.initialize() with mx.autograd.record(): out1 = net(x) out1.backward() @@ -144,9 +141,8 @@ class Net(gluon.HybridBlock): """ def __init__(self, **kwargs): super(Net, self).__init__(**kwargs) - with self.name_scope(): - self.conv0 = nn.Conv2D(4, (3, 3)) - self.conv1 = nn.Conv2D(4, (3, 3)) + self.conv0 = nn.Conv2D(4, (3, 3)) + self.conv1 = nn.Conv2D(4, (3, 3)) def hybrid_forward(self, F, x, *args, **kwargs): x_slice = x.slice(begin=(0, 0, 0, 0), end=(2, 4, 10, 10)) @@ -157,7 +153,7 @@ def hybrid_forward(self, F, x, *args, **kwargs): x = mx.nd.random.uniform(shape=(2, 10, 10, 10)) x.attach_grad() net = Net() - net.collect_params().initialize() + net.initialize() with mx.autograd.record(): out1 = net(x) out1.backward() @@ -178,9 +174,8 @@ class Net(gluon.HybridBlock): """ def __init__(self, **kwargs): super(Net, self).__init__(**kwargs) - with 
self.name_scope(): - self.conv0 = nn.Conv2D(4, (3, 3)) - self.conv1 = nn.Conv2D(4, (3, 3)) + self.conv0 = nn.Conv2D(4, (3, 3)) + self.conv1 = nn.Conv2D(4, (3, 3)) def hybrid_forward(self, F, x, *args, **kwargs): x_slice = x.slice(begin=(0, 0, 0, 0), end=(2, 4, 8, 9)) @@ -191,7 +186,7 @@ def hybrid_forward(self, F, x, *args, **kwargs): x = mx.nd.random.uniform(shape=(2, 10, 10, 10)) x.attach_grad() net = Net() - net.collect_params().initialize() + net.initialize() with mx.autograd.record(): out1 = net(x) out1.backward() @@ -298,12 +293,11 @@ class BNNet(gluon.HybridBlock): def __init__(self, fuse_relu): super(BNNet, self).__init__() self.fuse_relu = fuse_relu - with self.name_scope(): - if self.fuse_relu: - self.bn = gluon.nn.BatchNormReLU() - else: - self.bn = gluon.nn.BatchNorm() - self.relu = gluon.nn.Activation('relu') + if self.fuse_relu: + self.bn = gluon.nn.BatchNormReLU() + else: + self.bn = gluon.nn.BatchNorm() + self.relu = gluon.nn.Activation('relu') def forward(self, x): y = self.bn(x) @@ -312,8 +306,8 @@ def forward(self, x): return y fused_net = BNNet(fuse_relu=True) unfused_net = BNNet(fuse_relu=False) - fused_net.collect_params().initialize() - unfused_net.collect_params().initialize() + fused_net.initialize() + unfused_net.initialize() in_data = mx.nd.random.normal(shape=shape) no_fuse_outputs = unfused_net.forward(in_data) fuse_outputs = fused_net.forward(in_data) @@ -571,9 +565,8 @@ def hybrid_forward(self, F, x): class Net(gluon.HybridBlock): def __init__(self, **kwargs): super(Net, self).__init__(**kwargs) - with self.name_scope(): - self.conv1 = nn.Conv2D(8, kernel_size=5) - self.reshape2D = Reshape2D(2) + self.conv1 = nn.Conv2D(8, kernel_size=5) + self.reshape2D = Reshape2D(2) def hybrid_forward(self, F, x): x = self.conv1(x) diff --git a/tests/python/mkl/test_subgraph.py b/tests/python/mkl/test_subgraph.py index d884e911bd3e..e7efbca4e1b1 100644 --- a/tests/python/mkl/test_subgraph.py +++ b/tests/python/mkl/test_subgraph.py @@ -92,7 +92,7 @@ def check_qsym_gluon_forward(path, qsym, qarg_params, qaux_params, data_shape): mx.nd.save(params_path, save_dict) # load back with SymbolBlock net = mx.gluon.SymbolBlock.imports(json_path, ['data'], params_path) - net.collect_params().reset_ctx(ctx = mx.current_context()) + net.reset_ctx(ctx = mx.current_context()) net.hybridize() data = mx.random.uniform(-1.0, 1.0, shape=data_shape) diff --git a/tests/python/profiling/simple_forward.py b/tests/python/profiling/simple_forward.py index 0ad43c89a6f5..4202532cceb5 100644 --- a/tests/python/profiling/simple_forward.py +++ b/tests/python/profiling/simple_forward.py @@ -26,10 +26,9 @@ def simple_forward(): # define simple gluon network with random weights net = nn.Sequential() - with net.name_scope(): - net.add(nn.Dense(128, activation='relu')) - net.add(nn.Dense(64, activation='relu')) - net.add(nn.Dense(10)) + net.add(nn.Dense(128, activation='relu')) + net.add(nn.Dense(64, activation='relu')) + net.add(nn.Dense(10)) net.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx) input = mx.nd.zeros((128,), ctx=ctx) diff --git a/tests/python/train/test_autograd.py b/tests/python/train/test_autograd.py index f0fdc5ea2576..87548489a5b0 100644 --- a/tests/python/train/test_autograd.py +++ b/tests/python/train/test_autograd.py @@ -32,11 +32,10 @@ def test_autograd(tmpdir): # define network def get_net(): net = nn.Sequential() - net.add(nn.Dense(128, activation='relu', prefix='fc1_')) - net.add(nn.Dense(64, activation='relu', prefix='fc2_')) - net.add(nn.Dense(10, prefix='fc3_')) + 
net.add(nn.Dense(128, activation='relu')) + net.add(nn.Dense(64, activation='relu')) + net.add(nn.Dense(10)) return net - path = str(tmpdir) get_mnist_ubyte(path) @@ -67,7 +66,7 @@ def score(net, ctx_list): return metric.get()[1] def train(net, epoch, ctx_list): - net.collect_params().initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx_list) + net.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx_list) trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.5}) metric = gluon.metric.Accuracy() loss = gluon.loss.SoftmaxCrossEntropyLoss() @@ -97,9 +96,9 @@ def train(net, epoch, ctx_list): acc2 = score(net1, [mx.cpu(0), mx.cpu(1)]) assert acc1 > 0.95 assert abs(acc1 - acc2) < 0.01 - net1.collect_params().save('mnist.params') + net1.save_parameters('mnist.params') net2 = get_net() - net2.collect_params().load('mnist.params', ctx=[mx.cpu(0)]) + net2.load_parameters('mnist.params', ctx=[mx.cpu(0)]) acc3 = score(net2, [mx.cpu(0)]) assert abs(acc3 - acc1) < 0.0001 diff --git a/tests/python/unittest/onnx/mxnet_export_test.py b/tests/python/unittest/onnx/mxnet_export_test.py index 3b3f1c5ba1ee..61d00f2fca3c 100644 --- a/tests/python/unittest/onnx/mxnet_export_test.py +++ b/tests/python/unittest/onnx/mxnet_export_test.py @@ -58,7 +58,7 @@ def _check_onnx_export(net, group_outputs=False, shape_type=tuple, extra_params= data = nd.random.uniform(0, 1, (1, 1024)) output = _force_list(net(data)) # initialize weights net_sym = _optional_group(net(sym.Variable('data')), group_outputs) - net_params = {name: param._reduce() for name, param in net.collect_params().items()} + net_params = {param.name: param._reduce() for name, param in net.collect_params().items()} net_params.update(extra_params) with tempfile.TemporaryDirectory() as tmpdirname: onnx_file_path = os.path.join(tmpdirname, 'net.onnx') @@ -83,8 +83,8 @@ def _check_onnx_export(net, group_outputs=False, shape_type=tuple, extra_params= class SplitConcatBlock(HybridBlock): """Block which creates two splits and later concatenates them""" - def __init__(self, name): - super(SplitConcatBlock, self).__init__(name) + def __init__(self): + super(SplitConcatBlock, self).__init__() def hybrid_forward(self, F, x): splits = F.split(x, axis=1, num_outputs=2) @@ -100,9 +100,8 @@ def setUp(self): @with_seed() def test_onnx_export_single_output(self): - net = nn.HybridSequential(prefix='single_output_net') - with net.name_scope(): - net.add(nn.Dense(100, activation='relu'), nn.Dense(10)) + net = nn.HybridSequential() + net.add(nn.Dense(100, activation='relu'), nn.Dense(10)) _check_onnx_export(net) @with_seed() @@ -110,10 +109,9 @@ def test_onnx_export_multi_output(self): class MultiOutputBlock(nn.HybridBlock): def __init__(self): super(MultiOutputBlock, self).__init__() - with self.name_scope(): - self.net = nn.HybridSequential() - for i in range(10): - self.net.add(nn.Dense(100 + i * 10, activation='relu')) + self.net = nn.HybridSequential() + for i in range(10): + self.net.add(nn.Dense(100 + i * 10, activation='relu')) def hybrid_forward(self, F, x): out = tuple(block()(x) for block in self.net._children.values()) @@ -125,29 +123,25 @@ def hybrid_forward(self, F, x): @with_seed() def test_onnx_export_list_shape(self): - net = nn.HybridSequential(prefix='list_shape_net') - with net.name_scope(): - net.add(nn.Dense(100, activation='relu'), nn.Dense(10)) + net = nn.HybridSequential() + net.add(nn.Dense(100, activation='relu'), nn.Dense(10)) _check_onnx_export(net, shape_type=list) @with_seed() def test_onnx_export_extra_params(self): - net = 
nn.HybridSequential(prefix='extra_params_net') - with net.name_scope(): - net.add(nn.Dense(100, activation='relu'), nn.Dense(10)) + net = nn.HybridSequential() + net.add(nn.Dense(100, activation='relu'), nn.Dense(10)) _check_onnx_export(net, extra_params={'extra_param': nd.array([1, 2])}) @with_seed() def test_onnx_export_slice(self): - net = nn.HybridSequential(prefix='slice_net') - with net.name_scope(): - net.add(nn.Dense(100, activation='relu'), SplitConcatBlock("splitConcat"), nn.Dense(10)) + net = nn.HybridSequential() + net.add(nn.Dense(100, activation='relu'), SplitConcatBlock(), nn.Dense(10)) _check_onnx_export(net) @with_seed() def test_onnx_export_slice_changing_shape(self): - net = nn.HybridSequential(prefix='slice_net_changing_shape') - with net.name_scope(): - net.add(nn.Dense(100, activation='relu'), SplitConcatBlock("splitConcat"), - nn.Dense(50, activation='relu'), SplitConcatBlock("splitConcat2"), nn.Dense(10)) + net = nn.HybridSequential() + net.add(nn.Dense(100, activation='relu'), SplitConcatBlock(), + nn.Dense(50, activation='relu'), SplitConcatBlock(), nn.Dense(10)) _check_onnx_export(net) diff --git a/tests/python/unittest/test_contrib_control_flow.py b/tests/python/unittest/test_contrib_control_flow.py index 3c8950d6de4e..4de075c8fad2 100644 --- a/tests/python/unittest/test_contrib_control_flow.py +++ b/tests/python/unittest/test_contrib_control_flow.py @@ -1057,9 +1057,9 @@ def cond(inputs, free): ) class RNNLayer(gluon.HybridBlock): - def __init__(self, cell_type, hidden_size, prefix=None, params=None): - super(RNNLayer, self).__init__(prefix=prefix, params=params) - self.cell = cell_type(hidden_size, prefix='rnn_') + def __init__(self, cell_type, hidden_size): + super(RNNLayer, self).__init__() + self.cell = cell_type(hidden_size) def hybrid_forward(self, F, inputs, states): out, states = F.contrib.foreach(self.cell, inputs, states) @@ -1494,8 +1494,8 @@ def step_nd(in1, states): @with_seed() def test_cut_subgraph_foreach(): class TestLayer(gluon.HybridBlock): - def __init__(self, prefix=None, params=None): - super(TestLayer, self).__init__(prefix=prefix, params=params) + def __init__(self): + super(TestLayer, self).__init__() def hybrid_forward(self, F, inputs, states): def step1(data, states): @@ -1529,8 +1529,8 @@ def step2(data, states): @with_seed() def test_uniq_name(): class ForeachLayer1(gluon.HybridBlock): - def __init__(self, prefix=None, params=None): - super(ForeachLayer1, self).__init__(prefix=prefix, params=params) + def __init__(self): + super(ForeachLayer1, self).__init__() def hybrid_forward(self, F, inputs, states): def step1(data, states): @@ -1541,8 +1541,8 @@ def step1(data, states): return out class ForeachLayer2(gluon.HybridBlock): - def __init__(self, prefix=None, params=None): - super(ForeachLayer2, self).__init__(prefix=prefix, params=params) + def __init__(self): + super(ForeachLayer2, self).__init__() def hybrid_forward(self, F, inputs, states): def step1(data, states): @@ -1556,8 +1556,8 @@ def step2(data, states): return out class WhileLayer1(gluon.HybridBlock): - def __init__(self, prefix=None, params=None): - super(WhileLayer1, self).__init__(prefix=prefix, params=params) + def __init__(self): + super(WhileLayer1, self).__init__() def hybrid_forward(self, F, inputs, states): def cond(state1, state2): @@ -1572,8 +1572,8 @@ def step(state1, state2): return out class WhileLayer2(gluon.HybridBlock): - def __init__(self, prefix=None, params=None): - super(WhileLayer2, self).__init__(prefix=prefix, params=params) + def __init__(self): + 
super(WhileLayer2, self).__init__() def hybrid_forward(self, F, inputs, states): def cond(state1, state2): @@ -1615,8 +1615,8 @@ def step2(state1, state2): @with_seed() def test_cut_subgraph_while_loop(): class TestLayer(gluon.HybridBlock): - def __init__(self, prefix=None, params=None): - super(TestLayer, self).__init__(prefix=prefix, params=params) + def __init__(self): + super(TestLayer, self).__init__() def hybrid_forward(self, F, data): out1, data1 = F.contrib.while_loop( cond=lambda i: i <= 5, @@ -1649,8 +1649,8 @@ def hybrid_forward(self, F, data): @with_seed() def test_cut_subgraph_cond(): class TestLayer(gluon.HybridBlock): - def __init__(self, prefix=None, params=None): - super(TestLayer, self).__init__(prefix=prefix, params=params) + def __init__(self): + super(TestLayer, self).__init__() def hybrid_forward(self, F, data): data1 = F.contrib.cond( data > 0.5, @@ -1680,8 +1680,8 @@ def hybrid_forward(self, F, data): def test_scope(): class TestBlock1(gluon.HybridBlock): - def __init__(self, prefix=None, params=None): - super(TestBlock1, self).__init__(prefix=prefix, params=params) + def __init__(self): + super(TestBlock1, self).__init__() def hybrid_forward(self, F, data): (new_data, ) = F.contrib.cond( data > 0.5, @@ -1691,8 +1691,8 @@ def hybrid_forward(self, F, data): ) return new_data class TestBlock2(gluon.HybridBlock): - def __init__(self, prefix=None, params=None): - super(TestBlock2, self).__init__(prefix=prefix, params=params) + def __init__(self): + super(TestBlock2, self).__init__() def hybrid_forward(self, F, data): (new_data, ) = F.contrib.cond( data > 0.5, @@ -1719,8 +1719,8 @@ def hybrid_forward(self, F, data): def test_output_format_foreach(): class TestLayer1(gluon.HybridBlock): - def __init__(self, step, prefix=None, params=None): - super(TestLayer1, self).__init__(prefix=prefix, params=params) + def __init__(self, step): + super(TestLayer1, self).__init__() self.step = step def hybrid_forward(self, F, ins, states): out, states = F.contrib.foreach(self.step, ins, states) @@ -1818,8 +1818,8 @@ def step4(data, state): def test_output_format_while(): class TestLayer1(gluon.HybridBlock): - def __init__(self, step, use_list, nested_list=False, prefix=None, params=None): - super(TestLayer1, self).__init__(prefix=prefix, params=params) + def __init__(self, step, use_list, nested_list=False): + super(TestLayer1, self).__init__() self.step = step self.use_list = use_list self.nested_list = nested_list @@ -1929,8 +1929,8 @@ def step6(state, state2): def test_output_format_cond(): class TestLayer1(gluon.HybridBlock): - def __init__(self, func, prefix=None, params=None): - super(TestLayer1, self).__init__(prefix=prefix, params=params) + def __init__(self, func): + super(TestLayer1, self).__init__() self.func = func def hybrid_forward(self, F, data): def then_func(): diff --git a/tests/python/unittest/test_contrib_stes_op.py b/tests/python/unittest/test_contrib_stes_op.py index 9b09033ed522..26ab6f9491e4 100644 --- a/tests/python/unittest/test_contrib_stes_op.py +++ b/tests/python/unittest/test_contrib_stes_op.py @@ -24,8 +24,7 @@ class RoundSTENET(gluon.HybridBlock): def __init__(self, w_init, **kwargs): super(RoundSTENET, self).__init__(**kwargs) - with self.name_scope(): - self.w = self.params.get('w', shape=30, init=mx.initializer.Constant(w_init), grad_req='write') + self.w = gluon.Parameter('w', shape=30, init=mx.initializer.Constant(w_init), grad_req='write') @staticmethod def expected_grads(in_data, w_init): @@ -48,8 +47,7 @@ def hybrid_forward(self, F, x, w): class 
SignSTENET(gluon.HybridBlock): def __init__(self, w_init, **kwargs): super(SignSTENET, self).__init__(**kwargs) - with self.name_scope(): - self.w = self.params.get('w', shape=30, init=mx.initializer.Constant(w_init), grad_req='write') + self.w = gluon.Parameter('w', shape=30, init=mx.initializer.Constant(w_init), grad_req='write') @staticmethod def expected_grads(in_data, w_init): @@ -76,7 +74,7 @@ def check_ste(net_type_str, w_init, hybridize, in_data, ctx=None): if hybridize: net.hybridize() # Init - net.collect_params().initialize(mx.init.Constant([w_init]), ctx=ctx) + net.initialize(mx.init.Constant([w_init]), ctx=ctx) # Test: in_data = in_data.as_in_context(ctx) diff --git a/tests/python/unittest/test_deferred_compute.py b/tests/python/unittest/test_deferred_compute.py index 84df4d2ca55f..18ed0b4c5103 100644 --- a/tests/python/unittest/test_deferred_compute.py +++ b/tests/python/unittest/test_deferred_compute.py @@ -430,11 +430,10 @@ def _dc_gluon_simple_setup(shape=(8, 10), *, nd): def test_dc_hybridblock(): class MyBlock(mx.gluon.HybridBlock): - def __init__(self, *, prefix=None, params=None): - super().__init__(prefix, params) - with self.name_scope(): - self.dense = mx.gluon.nn.Dense(units=10, in_units=10) - self.weight = self.params.get('weight', shape=(10, )) + def __init__(self): + super().__init__() + self.dense = mx.gluon.nn.Dense(units=10, in_units=10) + self.weight = mx.gluon.Parameter('weight', shape=(10, )) def forward(self, x): assert x.shape[1] == 10 # due to in_units=10 above @@ -451,11 +450,10 @@ def forward(self, x): def test_dc_hybridblock_deferred_init_no_infer_shape_error(): class MyBlock(mx.gluon.HybridBlock): - def __init__(self, *, prefix=None, params=None): - super().__init__(prefix, params) - with self.name_scope(): - self.dense = mx.gluon.nn.Dense(units=10) - self.weight = self.params.get('weight', allow_deferred_init=True) + def __init__(self): + super().__init__() + self.dense = mx.gluon.nn.Dense(units=10) + self.weight = mx.gluon.Parameter('weight', allow_deferred_init=True) def forward(self, x): return self.dense(x) + self.weight.data(x.context) @@ -469,11 +467,10 @@ def forward(self, x): def test_dc_hybridblock_deferred_init(): class MyBlock(mx.gluon.HybridBlock): - def __init__(self, *, prefix=None, params=None): - super().__init__(prefix, params) - with self.name_scope(): - self.dense = mx.gluon.nn.Dense(units=10) - self.weight = self.params.get('weight', allow_deferred_init=True) + def __init__(self): + super().__init__() + self.dense = mx.gluon.nn.Dense(units=10) + self.weight = mx.gluon.Parameter('weight', allow_deferred_init=True) def infer_shape(self, x): self.weight.shape = (x.shape[1], ) @@ -496,10 +493,9 @@ def test_dc_hybridblock_dynamic_shape(): return class MyBlock(mx.gluon.HybridBlock): - def __init__(self, *, prefix=None, params=None): - super().__init__(prefix, params) - with self.name_scope(): - self.dense = mx.gluon.nn.Dense(units=10) + def __init__(self): + super().__init__() + self.dense = mx.gluon.nn.Dense(units=10) def forward(self, x, idx): return x[idx].reshape((2, 2)), mx.np.flatnonzero(self.dense(x)) diff --git a/tests/python/unittest/test_exc_handling.py b/tests/python/unittest/test_exc_handling.py index be4e643d6890..f544ab5d6510 100644 --- a/tests/python/unittest/test_exc_handling.py +++ b/tests/python/unittest/test_exc_handling.py @@ -88,9 +88,9 @@ def gluon(exec_wait=True, waitall=False): model.add(nn.Dropout(1)) model.add(nn.Dense(64, activation='tanh', in_units=256), nn.Dense(32, in_units=64)) + 
model.initialize(ctx=[default_context()]) x = mx.sym.var('data') y = model(x) - model.collect_params().initialize(ctx=[default_context()]) z = model(mx.nd.random.normal(10, -10, (32, 2, 10), ctx=default_context())) if waitall: mx.nd.waitall() @@ -178,11 +178,10 @@ def run_training_iteration(data): output = net(data) net = gluon.nn.HybridSequential() - with net.name_scope(): - net.add(gluon.nn.Dense(10)) + net.add(gluon.nn.Dense(10)) ctx = default_context() - net.collect_params().initialize(mx.init.Xavier(), ctx=ctx) + net.initialize(mx.init.Xavier(), ctx=ctx) data = mx.nd.ones((3, 4)) mx.profiler.set_state("run") run_training_iteration(data) diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index 5b8c7cf967e3..47ef86ff58ed 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -46,7 +46,6 @@ def test_parameter(): assert len(p.list_grad()) == 2 assert p.data(mx.cpu(1)).context == mx.cpu(1) assert p.data(mx.cpu(0)).shape == (10, 10) - assert p.var().name == 'weight' assert p.grad(mx.cpu(0)).stype == 'default' assert p.data(mx.cpu(0)).stype == 'default' @@ -77,7 +76,6 @@ def test_sparse_parameter(): assert weight.context == mx.cpu(1) assert weight.shape == (10, 10) assert weight.stype == 'row_sparse' - assert p.var().name == 'weight' assert p.var().attr('__storage_type__') == str(_STORAGE_TYPE_STR_TO_ID['row_sparse']) assert p.grad(mx.cpu(0)).stype == 'row_sparse' @@ -98,80 +96,6 @@ def test_parameter_invalid_access(): assertRaises(RuntimeError, p1.row_sparse_data, row_id.copyto(mx.cpu(0))) assertRaises(RuntimeError, p1.list_row_sparse_data, row_id) -@with_seed() -@pytest.mark.usefixtures("check_leak_ndarray") -def test_parameter_dict(): - ctx = mx.cpu(1) - params0 = gluon.ParameterDict('net_') - params0.get('w0', shape=(10, 10)) - params0.get('w1', shape=(10, 10), stype='row_sparse') - all_row_ids = mx.nd.arange(0, 10, ctx=ctx) - # check param names - assert list(params0.keys()) == ['net_w0', 'net_w1'] - params0.initialize(ctx=ctx) - trainer0 = mx.gluon.Trainer(params0, 'sgd') - prev_w0 = params0.get('w0').data(ctx) - prev_w1 = params0.get('w1').row_sparse_data(all_row_ids) - # save params - params0.save('test_parameter_dict.params') - - # load params - params1 = gluon.ParameterDict('net_') - params1.get('w0', shape=(10, 10)) - params1.get('w1', shape=(10, 10), stype='row_sparse') - params1.load('test_parameter_dict.params', ctx) - trainer1 = mx.gluon.Trainer(params1, 'sgd') - - # compare the values before and after save/load - cur_w0 = params1.get('w0').data(ctx) - cur_w1 = params1.get('w1').row_sparse_data(all_row_ids) - mx.test_utils.assert_almost_equal(prev_w0.asnumpy(), cur_w0.asnumpy()) - mx.test_utils.assert_almost_equal(prev_w1.asnumpy(), cur_w1.asnumpy()) - - # create a new param dict with dense params, and load from the checkpoint - # of sparse & dense params - params2 = gluon.ParameterDict('net_') - params2.get('w0', shape=(10, 10)) - params2.get('w1', shape=(10, 10)) - params2.load('test_parameter_dict.params', ctx) - - # compare the values before and after save/load - cur_w0 = params2.get('w0').data(ctx) - cur_w1 = params2.get('w1').data(ctx) - mx.test_utils.assert_almost_equal(prev_w0.asnumpy(), cur_w0.asnumpy()) - mx.test_utils.assert_almost_equal(prev_w1.asnumpy(), cur_w1.asnumpy()) - - # test reset_ctx - params3 = gluon.ParameterDict('net_') - params3.get('w0', shape=(10, 10)) - params3.get('w1', shape=(10, 10)) - params3.initialize(ctx=ctx) - list_contexts = [mx.cpu(42), mx.cpu(24)] - 
params3.reset_ctx(list_contexts) - for p in params3.values(): - assert set(p.list_ctx()) == set(list_contexts) - - # and test list_ctx - assert set(params3.list_ctx()) == set(list_contexts) - - - # test the dtype casting functionality - params0 = gluon.ParameterDict('') - params0.get('w0', shape=(10, 10), dtype='float32') - params0.get('w1', shape=(10, 10), dtype='int8') - params0.initialize(mx.init.One(), ctx=ctx) - params0.save('test_parameter_dict.params') - - params1 = gluon.ParameterDict('') - params1.get('w0', shape=(10, 10), dtype='float16') - params1.get('w1', shape=(10, 10), dtype='float64') - params1.load('test_parameter_dict.params', cast_dtype=True, dtype_source='current') - assert params1['w0'].data().dtype == np.float16 - assert params1['w1'].data().dtype == np.float64 - params1.load('test_parameter_dict.params', cast_dtype=True, dtype_source='saved') - assert params1['w0'].data().dtype == np.float32 - assert params1['w1'].data().dtype == np.int8 - @with_seed() def test_parameter_row_sparse_data(): @@ -205,7 +129,7 @@ class Test(gluon.HybridBlock): def __init__(self, **kwargs): super(Test, self).__init__(**kwargs) self.value = np.asarray([[1,2], [3,4]]) - self.const = self.params.get_constant('const', self.value) + self.const = gluon.Constant(self.value) def hybrid_forward(self, F, x, const): return x + const @@ -232,31 +156,30 @@ def test_parameter_sharing(): class Net(gluon.Block): def __init__(self, in_units=0, **kwargs): super(Net, self).__init__(**kwargs) - with self.name_scope(): - self.dense0 = nn.Dense(5, in_units=in_units) - self.dense1 = nn.Dense(5, in_units=in_units) + self.dense0 = nn.Dense(5, in_units=in_units) + self.dense1 = nn.Dense(5, in_units=in_units) def forward(self, x): return self.dense1(self.dense0(x)) - net1 = Net(prefix='net1_', in_units=5) - net2 = Net(prefix='net2_', params=net1.collect_params()) - net1.collect_params().initialize() + net1 = Net(in_units=5) + net2 = Net().share_parameters(net1.collect_params()) + net1.initialize() net2(mx.nd.zeros((3, 5))) net1.save_parameters('net1.params') - net3 = Net(prefix='net3_') + net3 = Net() net3.load_parameters('net1.params', mx.cpu()) - net4 = Net(prefix='net4_') - net5 = Net(prefix='net5_', in_units=5, params=net4.collect_params()) - net4.collect_params().initialize() + net4 = Net() + net5 = Net(in_units=5).share_parameters(net4.collect_params()) + net4.initialize() net5(mx.nd.zeros((3, 5))) net4.save_parameters('net4.params') - net6 = Net(prefix='net6_') + net6 = Net() net6.load_parameters('net4.params', mx.cpu()) @@ -265,31 +188,27 @@ def test_parameter_str(): class Net(gluon.Block): def __init__(self, **kwargs): super(Net, self).__init__(**kwargs) - with self.name_scope(): - self.dense0 = nn.Dense(10, in_units=5, use_bias=False) + self.dense0 = nn.Dense(10, in_units=5, use_bias=False) - net = Net(prefix='net1_') + net = Net() lines = str(net.collect_params()).splitlines() - - assert lines[0] == 'net1_ (' - assert 'net1_dense0_weight' in lines[1] - assert '(10, 5)' in lines[1] - assert 'float32' in lines[1] - assert lines[2] == ')' - + + assert 'dense0.weight' in lines[0] + assert '(10, 5)' in lines[0] + assert 'float32' in lines[0] + @with_seed() def test_collect_parameters(): - net = nn.HybridSequential(prefix="test_") - with net.name_scope(): - net.add(nn.Conv2D(10, 3)) - net.add(nn.Dense(10, activation='relu')) + net = nn.HybridSequential() + net.add(nn.Conv2D(10, 3)) + net.add(nn.Dense(10, activation='relu')) assert set(net.collect_params().keys()) == \ - set(['test_conv0_weight', 
'test_conv0_bias','test_dense0_weight','test_dense0_bias']) + set(['0.weight', '0.bias','1.weight','1.bias']) assert set(net.collect_params('.*weight').keys()) == \ - set(['test_conv0_weight', 'test_dense0_weight']) - assert set(net.collect_params('test_conv0_bias|test_dense0_bias').keys()) == \ - set(['test_conv0_bias', 'test_dense0_bias']) + set(['0.weight', '1.weight']) + assert set(net.collect_params('0.bias|1.bias').keys()) == \ + set(['0.bias', '1.bias']) @with_seed() def test_basic(): @@ -299,39 +218,36 @@ def test_basic(): model.add(nn.Dense(64, activation='tanh', in_units=256), nn.Dense(32, in_units=64)) model.add(nn.Activation('relu')) - # symbol x = mx.sym.var('data') y = model(x) assert len(y.list_arguments()) == 7 # ndarray - model.collect_params().initialize(mx.init.Xavier(magnitude=2.24)) + model.initialize(mx.init.Xavier(magnitude=2.24)) x = model(mx.nd.zeros((32, 2, 10))) assert x.shape == (32, 32) x.wait_to_read() - model.collect_params().setattr('grad_req', 'null') + model.setattr('grad_req', 'null') assert list(model.collect_params().values())[0]._grad is None - model.collect_params().setattr('grad_req', 'write') + model.setattr('grad_req', 'write') assert list(model.collect_params().values())[0]._grad is not None @with_seed() def test_dense(): - model = nn.Dense(128, activation='tanh', in_units=10, flatten=False, prefix='test_') + model = nn.Dense(128, activation='tanh', in_units=10, flatten=False) inputs = mx.sym.Variable('data') outputs = model(inputs) - assert set(model.collect_params().keys()) == set(['test_weight', 'test_bias']) - assert outputs.list_outputs() == ['test_tanh_fwd_output'] + assert set(model.collect_params().keys()) == set(['weight', 'bias']) args, outs, auxs = outputs.infer_shape(data=(2, 3, 10)) assert outs == [(2, 3, 128)] - model = nn.Dense(128, activation='relu', in_units=30, flatten=True, prefix='test2_') + model = nn.Dense(128, activation='relu', in_units=30, flatten=True) inputs = mx.sym.Variable('data') outputs = model(inputs) - assert set(model.collect_params().keys()) == set(['test2_weight', 'test2_bias']) - assert outputs.list_outputs() == ['test2_relu_fwd_output'] + assert set(model.collect_params().keys()) == set(['weight', 'bias']) args, outs, auxs = outputs.infer_shape(data=(17, 2, 5, 3)) assert outs == [(17, 128)] @@ -399,7 +315,7 @@ def hybrid_forward(self, F, x): sm = mx.sym.load(sym_file) inputs = mx.sym.var('data', dtype='float64') net_fp64 = mx.gluon.SymbolBlock(sm, inputs) - net_fp64.collect_params().load(params_file, ctx=ctx) + net_fp64.load_parameters(params_file, ctx=ctx) # Get a conv layer's weight parameter name. Conv layer's weight param is # expected to be of dtype casted, fp64. 
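# Illustrative sketch (not part of the test file): the updated tests above lean on
# two conventions of the refactored Gluon API — blocks are initialized, saved and
# loaded directly (net.initialize(), net.save_parameters(), net.load_parameters()
# instead of going through net.collect_params()), and parameters are keyed by
# attribute path rather than by a name prefix. The snippet uses only calls already
# exercised elsewhere in this diff; the file name 'tmp.params' is a placeholder.
import mxnet as mx
from mxnet.gluon import nn

net = nn.HybridSequential()
net.add(nn.Conv2D(10, 3), nn.Dense(10, activation='relu'))
net.initialize()                      # was: net.collect_params().initialize()
net(mx.nd.ones((1, 3, 8, 8)))         # first call resolves deferred shapes

# parameters are named by attribute path, e.g. '0.weight' for the first child
assert set(net.collect_params().keys()) == {'0.weight', '0.bias', '1.weight', '1.bias'}

net.save_parameters('tmp.params')     # was: net.collect_params().save(...)
net.load_parameters('tmp.params')     # was: net.collect_params().load(...)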
for param_name in net_fp64.params.keys(): @@ -436,10 +352,10 @@ def test_sparse_symbol_block(): @with_seed() def test_sparse_hybrid_block(): - params = gluon.ParameterDict('net_') - params.get('weight', shape=(5,5), stype='row_sparse', dtype='float32') - params.get('bias', shape=(5), dtype='float32') - net = gluon.nn.Dense(5, params=params) + params = {} + params['weight'] = gluon.Parameter('weight', shape=(5,5), stype='row_sparse', dtype='float32') + params['bias'] = gluon.Parameter('bias', shape=(5), dtype='float32') + net = gluon.nn.Dense(5).share_parameters(params) net.initialize() x = mx.nd.ones((2,5)) with pytest.raises(RuntimeError): @@ -472,11 +388,11 @@ def hybrid_forward(self, F, a, b=None): class FooNested(gluon.HybridBlock): - def __init__(self, prefix=None, params=None): - super(FooNested, self).__init__(prefix=prefix, params=params) - self.f1 = Foo(prefix='foo1') - self.f2 = Foo(prefix='foo2') - self.f3 = Foo(prefix='foo3') + def __init__(self): + super(FooNested, self).__init__() + self.f1 = Foo() + self.f2 = Foo() + self.f3 = Foo() def hybrid_forward(self, F, a, b): data = self.f1(a, b) @@ -487,9 +403,9 @@ def hybrid_forward(self, F, a, b): for arg_inputs in [(None, mx.nd.ones((10,))), (mx.nd.ones((10,)), mx.nd.ones((10,))), (mx.nd.ones((10,)), None)]: - foo1 = FooNested(prefix='foo_nested_hybridized') + foo1 = FooNested() foo1.hybridize() - foo2 = FooNested(prefix='foo_nested_nohybrid') + foo2 = FooNested() for _ in range(2): # Loop for 2 times to trigger forwarding of the cached version out1 = foo1(*arg_inputs) out2 = foo2(*arg_inputs) @@ -581,7 +497,7 @@ def forward(self, a, b): @with_seed() def check_layer_forward(layer, dshape): print("checking layer {}\nshape: {}.".format(layer, dshape)) - layer.collect_params().initialize() + layer.initialize() x = mx.nd.ones(shape=dshape) x.attach_grad() with mx.autograd.record(): @@ -741,11 +657,11 @@ def transpose(shape): x = mx.nd.zeros(xshape) layer = nn.MaxPool2D(3, ceil_mode=False, layout=layout) - layer.collect_params().initialize() + layer.initialize() assert (layer(x).shape==noceil_out_shape) layer = nn.MaxPool2D(3, ceil_mode=True, layout=layout) - layer.collect_params().initialize() + layer.initialize() assert (layer(x).shape==ceil_out_shape) @@ -914,7 +830,7 @@ def test_reflectionpad(): def test_reshape(): x = mx.nd.ones((2, 4, 10, 10)) layer = nn.Conv2D(10, 2, in_channels=4) - layer.collect_params().initialize() + layer.initialize() with mx.autograd.record(): x = layer(x) x = x.reshape((-1,)) @@ -926,7 +842,7 @@ def test_reshape(): def test_slice(): x = mx.nd.ones((5, 4, 10, 10)) layer = nn.Conv2D(10, 2, in_channels=4) - layer.collect_params().initialize() + layer.initialize() with mx.autograd.record(): x = layer(x) x = x[1:3] @@ -938,7 +854,7 @@ def test_slice(): def test_at(): x = mx.nd.ones((5, 4, 10, 10)) layer = nn.Conv2D(10, 2, in_channels=4) - layer.collect_params().initialize() + layer.initialize() with mx.autograd.record(): x = layer(x) x = x[1] @@ -950,7 +866,7 @@ def test_at(): def test_deferred_init(): x = mx.nd.ones((5, 4, 10, 10)) layer = nn.Conv2D(10, 2) - layer.collect_params().initialize() + layer.initialize() layer(x) @@ -1052,28 +968,24 @@ def test_block_attr_list_of_block(): class Model1(gluon.Block): def __init__(self, **kwargs): super(Model1, self).__init__(**kwargs) - with self.name_scope(): - self.layers = [nn.Dense(i * 10) for i in range(6)] + self.layers = [nn.Dense(i * 10) for i in range(6)] class Model2(gluon.Block): def __init__(self, **kwargs): super(Model2, self).__init__(**kwargs) - with 
self.name_scope(): - self.layers = dict() - self.layers['a'] = [nn.Dense(10), nn.Dense(10)] + self.layers = dict() + self.layers['a'] = [nn.Dense(10), nn.Dense(10)] class Model3(gluon.Block): def __init__(self, **kwargs): super(Model3, self).__init__(**kwargs) - with self.name_scope(): - self.layers = nn.Sequential() - self.layers.add(*[nn.Dense(i * 10) for i in range(6)]) + self.layers = nn.Sequential() + self.layers.add(*[nn.Dense(i * 10) for i in range(6)]) class Model4(gluon.Block): def __init__(self, **kwargs): super(Model4, self).__init__(**kwargs) - with self.name_scope(): - self.data = {'a': '4', 'b': 123} + self.data = {'a': '4', 'b': 123} with warnings.catch_warnings(record=True) as w: warnings.simplefilter('always') @@ -1179,7 +1091,7 @@ def check_embedding_large_input(sparse_grad): def test_export(): ctx = mx.context.current_context() model = gluon.model_zoo.vision.resnet18_v1( - prefix='resnet', ctx=ctx, pretrained=True) + ctx=ctx, pretrained=True) model.hybridize() data = mx.nd.random.normal(shape=(1, 3, 32, 32)) out = model(data) @@ -1188,17 +1100,11 @@ def test_export(): assert symbol_filename == 'gluon-symbol.json' assert params_filename == 'gluon-0000.params' - model2 = gluon.model_zoo.vision.resnet18_v1(prefix='resnet', ctx=ctx) - model2.collect_params().load('gluon-0000.params', ctx) - out2 = model2(data) - - assert_almost_equal(out.asnumpy(), out2.asnumpy()) - @with_seed() def test_import(): ctx = mx.context.current_context() net1 = gluon.model_zoo.vision.resnet18_v1( - prefix='resnet', ctx=ctx, pretrained=True) + ctx=ctx, pretrained=True) net1.hybridize() data = mx.nd.random.normal(shape=(1, 3, 32, 32)) out1 = net1(data) @@ -1219,8 +1125,7 @@ def test_import(): @with_seed() def test_hybrid_stale_cache(): net = mx.gluon.nn.HybridSequential() - with net.name_scope(): - net.add(mx.gluon.nn.Dense(10, weight_initializer='zeros', bias_initializer='ones', flatten=False)) + net.add(mx.gluon.nn.Dense(10, weight_initializer='zeros', bias_initializer='ones', flatten=False)) net.hybridize() net.initialize() @@ -1230,11 +1135,10 @@ def test_hybrid_stale_cache(): assert net(mx.nd.ones((2,3,5))).shape == (2, 30) net = mx.gluon.nn.HybridSequential() - with net.name_scope(): - net.fc1 = mx.gluon.nn.Dense(10, weight_initializer='zeros', - bias_initializer='ones', flatten=False) - net.fc2 = mx.gluon.nn.Dense(10, weight_initializer='zeros', - bias_initializer='ones', flatten=False) + net.fc1 = mx.gluon.nn.Dense(10, weight_initializer='zeros', + bias_initializer='ones', flatten=False) + net.fc2 = mx.gluon.nn.Dense(10, weight_initializer='zeros', + bias_initializer='ones', flatten=False) net.hybridize() net.initialize() net(mx.nd.ones((2,3,5))) @@ -1270,10 +1174,10 @@ def test_lambda(): @with_seed() def test_fill_shape_deferred(): net = nn.HybridSequential() - with net.name_scope(): - net.add(nn.Conv2D(64, kernel_size=2, padding=1), - nn.BatchNorm(), - nn.Dense(10)) + net.add(nn.Conv2D(64, kernel_size=2, padding=1), + nn.BatchNorm(), + nn.Dense(10)) + net net.hybridize() net.initialize() net(mx.nd.ones((2,3,5,7))) @@ -1304,9 +1208,8 @@ def test_dtype(): class Net(gluon.Block): def __init__(self, in_dim, output_dim): super(Net, self).__init__() - with self.name_scope(): - self.embed = gluon.nn.Embedding(input_dim=in_dim, output_dim=output_dim,dtype=np.float64) - self.dense = gluon.nn.Dense(2, dtype=np.float64) + self.embed = gluon.nn.Embedding(input_dim=in_dim, output_dim=output_dim,dtype=np.float64) + self.dense = gluon.nn.Dense(2, dtype=np.float64) def forward(self, x): e = self.embed(x) 
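# Illustrative sketch (not part of the test file): several tests above migrate from
# the prefix/params constructor arguments to explicit gluon.Parameter attributes
# plus Block.share_parameters(). The snippet below shows that idiom in isolation,
# using only APIs that appear elsewhere in this diff.
import mxnet as mx
from mxnet import gluon

class TinyBlock(gluon.Block):
    def __init__(self):
        super(TinyBlock, self).__init__()
        # explicit Parameter attribute; collected under the key 'weight'
        self.weight = gluon.Parameter('weight', shape=(4, 4))

    def forward(self, x):
        return mx.nd.dot(x, self.weight.data(x.context))

a = TinyBlock()
b = TinyBlock().share_parameters(a.collect_params())  # b reuses a's parameter
a.initialize()                                        # initializing a also covers b
assert b(mx.nd.ones((2, 4))).shape == (2, 4)
assert a.weight is b.weight                           # same Parameter object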
@@ -1324,20 +1227,19 @@ def forward(self, x): def test_fill_shape_load(): ctx = mx.context.current_context() net1 = nn.HybridSequential() - with net1.name_scope(): - net1.add(nn.Conv2D(64, kernel_size=2, padding=1), - nn.BatchNorm(), - nn.Dense(10)) + net1.add(nn.Conv2D(64, kernel_size=2, padding=1), + nn.BatchNorm(), + nn.Dense(10)) + net1 net1.hybridize() net1.initialize(ctx=ctx) net1(mx.nd.ones((2,3,5,7), ctx)) net1.save_parameters('net_fill.params') net2 = nn.HybridSequential() - with net2.name_scope(): - net2.add(nn.Conv2D(64, kernel_size=2, padding=1), - nn.BatchNorm(), - nn.Dense(10)) + net2.add(nn.Conv2D(64, kernel_size=2, padding=1), + nn.BatchNorm(), + nn.Dense(10)) net2.hybridize() net2.initialize() net2.load_parameters('net_fill.params', ctx) @@ -1349,10 +1251,9 @@ def test_fill_shape_load(): @with_seed() def test_inline(): net = mx.gluon.nn.HybridSequential() - with net.name_scope(): - net.add(mx.gluon.nn.Dense(10)) - net.add(mx.gluon.nn.Dense(10)) - net.add(mx.gluon.nn.Dense(10)) + net.add(mx.gluon.nn.Dense(10)) + net.add(mx.gluon.nn.Dense(10)) + net.add(mx.gluon.nn.Dense(10)) net.initialize() net.hybridize(inline_limit=3) @@ -1491,7 +1392,7 @@ def test_req(): for v in net.collect_params().values(): v.grad_req = 'add' - net.collect_params().zero_grad() + net.zero_grad() with mx.autograd.record(): pred = net(data) l = loss(pred, label) @@ -1519,12 +1420,10 @@ def test_save_load(tmpdir): class Network(gluon.Block): def __init__(self, **kwargs): super(Network, self).__init__(**kwargs) - with self.name_scope(): - self.encoders = gluon.nn.Sequential() - with self.encoders.name_scope(): - for _ in range(2): - lstm = mx.gluon.rnn.LSTM(200, 1, bidirectional=True) - self.encoders.add(lstm) + self.encoders = gluon.nn.Sequential() + for _ in range(2): + lstm = mx.gluon.rnn.LSTM(200, 1, bidirectional=True) + self.encoders.add(lstm) def forward(self, x): for i in range(2): @@ -1544,11 +1443,9 @@ def forward(self, x): @with_seed() def test_save_load_deduplicate_with_shared_params(tmpdir): class B(mx.gluon.Block): - def __init__(self, params=None): - super(B, self).__init__(params=params) - - with self.name_scope(): - self.weight = self.params.get('weight', shape=(10, 10)) + def __init__(self): + super(B, self).__init__() + self.weight = gluon.Parameter('weight', shape=(10, 10)) class C(mx.gluon.Block): def __init__(self, b1, b2): @@ -1557,7 +1454,7 @@ def __init__(self, b1, b2): self.b2 = b2 b1 = B() - b2 = B(b1.collect_params()) + b2 = B().share_parameters(b1.collect_params()) c = C(b1, b2) c.initialize() _, param_path = tempfile.mkstemp(suffix='.params', dir=str(tmpdir)) @@ -1567,7 +1464,7 @@ def __init__(self, b1, b2): assert len(params) == 1 # Only a single copy of the shared parameter is saved b1 = B() - b2 = B(b1.collect_params()) + b2 = B().share_parameters(b1.collect_params()) c = C(b1, b2) c.load_parameters(param_path) @@ -1578,7 +1475,7 @@ def __init__(self, b1, b2): assert len(params) == 2 # Only a single copy of the shared parameter is saved b1 = B() - b2 = B(b1.collect_params()) + b2 = B().share_parameters(b1.collect_params()) c = C(b1, b2) c.load_parameters(param_path) @@ -1587,15 +1484,14 @@ def test_symbol_block_save_load(): class Net(gluon.HybridBlock): def __init__(self): super(Net, self).__init__() - with self.name_scope(): - backbone = gluon.model_zoo.vision.resnet18_v1() - data = mx.sym.var('data') - featnames = ['stage1_activation0', 'stage2_activation0', 'stage3_activation0'] - out_names = ['_'.join([backbone.name, featname, 'output']) for featname in featnames] - 
internals = backbone(data).get_internals() - outs = [internals[out_name] for out_name in out_names] - self.backbone = gluon.SymbolBlock(outs, data, params=backbone.collect_params()) - self.body = nn.Conv2D(3, 1) + backbone = gluon.model_zoo.vision.resnet18_v1() + data = mx.sym.var('data') + featnames = [backbone.features[i][1].name for i in range(4, 7)] + out_names = ['_'.join([featname, 'activation0_output']) for featname in featnames] + internals = backbone(data).get_internals() + outs = [internals[out_name] for out_name in out_names] + self.backbone = gluon.SymbolBlock(outs, data, params=backbone.collect_params()) + self.body = nn.Conv2D(3, 1) def hybrid_forward(self, F, x): x = self.body(x) @@ -1624,13 +1520,13 @@ def _test_grad_reset(ctx, dtype='float32', sparse=False, embeddingType=None): data = mx.nd.random.uniform(shape=(3,3), dtype=dtype, ctx=ctx) if embeddingType is None: embeddingType = dtype - net = nn.Embedding(3, 4, sparse_grad=sparse, prefix='test_zero_grad_', dtype=embeddingType) + net = nn.Embedding(3, 4, sparse_grad=sparse, dtype=embeddingType) net.initialize(ctx=ctx) with mx.autograd.record(): l = net(data) l.backward() - net.collect_params().zero_grad() - grad = net.collect_params()['test_zero_grad_weight'].grad() + net.zero_grad() + grad = net.collect_params()['weight'].grad() assert_almost_equal(grad.asnumpy(), grad.asnumpy() * 0) def _test_multi_reset(nArrays, dtype, ctx): @@ -1682,9 +1578,9 @@ def check_hybrid_static_memory(**kwargs): x.attach_grad() net1 = gluon.model_zoo.vision.get_resnet( - 1, 18, pretrained=True, prefix='net_', ctx=mx.context.current_context()) + 1, 18, pretrained=True, ctx=mx.context.current_context()) net2 = gluon.model_zoo.vision.get_resnet( - 1, 18, pretrained=True, prefix='net_', ctx=mx.context.current_context()) + 1, 18, pretrained=True, ctx=mx.context.current_context()) net2.hybridize(**kwargs) net1(x) net2(x) @@ -1795,49 +1691,44 @@ def mon_callback(node_name, opr_name, arr): assert opr_name == expected_opr_name # Test with Dense layer - model = mx.gluon.nn.HybridSequential(prefix="dense_") - with model.name_scope(): - model.add(mx.gluon.nn.Dense(2)) + model = mx.gluon.nn.HybridSequential() + model.add(mx.gluon.nn.Dense(2)) model.initialize() model.hybridize() - check_name(model, ["dense_dense0_fwd_output"]) + check_name(model, [model[0].name + "_fwd_output"]) # Test with Activation, FListInputNames not registered, input name will have _input appended - model = mx.gluon.nn.HybridSequential(prefix="relu_") - with model.name_scope(): - model.add(mx.gluon.nn.Activation("relu")) + model = mx.gluon.nn.HybridSequential() + model.add(mx.gluon.nn.Activation("relu")) model.initialize() model.hybridize() - check_name(model, ["relu_relu0_fwd_output"]) + check_name(model, [model[0].name + "_fwd_output"]) # Test with Pooling, monitor_all is set to True - model = mx.gluon.nn.HybridSequential("pool_") - with model.name_scope(): - model.add(mx.gluon.nn.AvgPool1D()) + model = mx.gluon.nn.HybridSequential() + model.add(mx.gluon.nn.AvgPool1D()) model.initialize() model.hybridize() - check_name(model, ['pool_pool0_fwd_data', 'pool_pool0_fwd_output'], expected_opr_names=["Pooling"], - monitor_all=True) + check_name(model, [model[0].name + '_fwd_data', model[0].name + '_fwd_output'], + expected_opr_names=["Pooling"], monitor_all=True) # stack two layers and test - model = mx.gluon.nn.HybridSequential("dense_") - with model.name_scope(): - model.add(mx.gluon.nn.Dense(2)) - model.add(mx.gluon.nn.Activation("relu")) + model = mx.gluon.nn.HybridSequential() + 
model.add(mx.gluon.nn.Dense(2)) + model.add(mx.gluon.nn.Activation("relu")) model.initialize() model.hybridize() check_name(model, - ['dense_dense0_fwd_data', 'dense_dense0_fwd_weight', - 'dense_dense0_fwd_bias', 'dense_dense0_fwd_output', - 'dense_relu0_fwd_input0', 'dense_relu0_fwd_output'], monitor_all=True) + [model[0].name + '_fwd_data', model[0].name + '_fwd_weight', + model[0].name + '_fwd_bias', model[0].name + '_fwd_output', + model[1].name + '_fwd_input0', model[1].name + '_fwd_output'], monitor_all=True) # check with different hybridize modes model.hybridize(static_alloc=True) check_name(model, - ['dense_dense0_fwd_data', 'dense_dense0_fwd_weight', - 'dense_dense0_fwd_bias', 'dense_dense0_fwd_output', - 'dense_relu0_fwd_input0', 'dense_relu0_fwd_output'], monitor_all=True) - + [model[0].name + '_fwd_data', model[0].name + '_fwd_weight', + model[0].name + '_fwd_bias', model[0].name + '_fwd_output', + model[1].name + '_fwd_input0', model[1].name + '_fwd_output'], monitor_all=True) @with_seed() def test_apply(): @@ -1846,15 +1737,14 @@ def test_apply(): def record_name(block): global called_blocks - called_blocks.append(block.name) + called_blocks.append(type(block)) - block = nn.HybridSequential(prefix='test_') - with block.name_scope(): - block.add(nn.Dense(10)) - block.add(nn.Dropout(0.5)) + block = nn.HybridSequential() + block.add(nn.Dense(10)) + block.add(nn.Dropout(0.5)) block.apply(record_name) - assert called_blocks == ['test_dense0', 'test_dropout0', 'test'] + assert called_blocks == [type(block[0]), type(block[1]), type(block)] @with_seed() @@ -1865,10 +1755,9 @@ def test_summary(): net.summary(mx.nd.ones((32, 3, 224, 224))) net2 = nn.Sequential() - with net2.name_scope(): - net2.add(nn.Embedding(40, 30)) - net2.add(gluon.rnn.LSTM(30)) - net2.add(nn.Dense(40, flatten=False, params=net2[0].params)) + net2.add(nn.Embedding(40, 30)) + net2.add(gluon.rnn.LSTM(30)) + net2.add(nn.Dense(40, flatten=False).share_parameters(net2[0].params)) net2.initialize() net2.summary(mx.nd.ones((80, 32))) @@ -1880,23 +1769,6 @@ def test_summary(): net.hybridize() pytest.raises(AssertionError, net.summary, mx.nd.ones((32, 3, 224, 224))) - -@with_seed() -def test_legacy_save_params(): - net = gluon.nn.HybridSequential(prefix='') - with net.name_scope(): - net.add(gluon.nn.Conv2D(10, (3, 3))) - net.add(gluon.nn.Dense(50)) - net.initialize() - net(mx.nd.ones((1,1,50,50))) - a = net(mx.sym.var('data')) - a.save('test.json') - net.save_params('test.params') - model = gluon.nn.SymbolBlock(outputs=mx.sym.load_json(open('test.json', 'r').read()), - inputs=mx.sym.var('data')) - model.load_params('test.params', ctx=mx.cpu()) - - @with_seed() def test_sparse_hybrid_block_grad(): class Embedding(mx.gluon.HybridBlock): @@ -1904,9 +1776,8 @@ def __init__(self, num_tokens, embedding_size): super(Embedding, self).__init__() self.num_tokens = num_tokens - with self.name_scope(): - self.embedding = mx.gluon.nn.Embedding( - num_tokens, embedding_size, sparse_grad=True) + self.embedding = mx.gluon.nn.Embedding( + num_tokens, embedding_size, sparse_grad=True) def hybrid_forward(self, F, words): emb = self.embedding(words) @@ -1930,8 +1801,7 @@ def test_sparse_hybrid_block(): class Linear(mx.gluon.HybridBlock): def __init__(self, units): super(Linear, self).__init__() - with self.name_scope(): - self.w = self.params.get('w', shape=(units, units)) + self.w = gluon.Parameter('w', shape=(units, units)) def hybrid_forward(self, F, x, w): return F.dot(x, w) @@ -1939,8 +1809,7 @@ def hybrid_forward(self, F, x, w): class 
SparseBlock(mx.gluon.HybridBlock): def __init__(self, units): super(SparseBlock, self).__init__() - with self.name_scope(): - self.net = Linear(units) + self.net = Linear(units) def hybrid_forward(self, F, x): return self.net(x) * x @@ -1967,15 +1836,15 @@ def test_hybrid_static_memory_recording(): def test_share_inputs_outputs(): class TestIOBackward(gluon.HybridBlock): - def __init__(self, prefix=None, params=None): - super(TestIOBackward, self).__init__(prefix=prefix, params=params) + def __init__(self): + super(TestIOBackward, self).__init__() def hybrid_forward(self, F, in1, in2): return in1 + in2 class TestIOForward(gluon.HybridBlock): - def __init__(self, prefix=None, params=None): - super(TestIOForward, self).__init__(prefix=prefix, params=params) + def __init__(self): + super(TestIOForward, self).__init__() def hybrid_forward(self, F, in1): return in1 @@ -2034,7 +1903,7 @@ def check_layer_forward_withinput(net, x): x_hybrid = x.copy() x.attach_grad() x_hybrid.attach_grad() - net.collect_params().initialize() + net.initialize() with mx.autograd.record(): out1 = net(x) out1.backward() @@ -2056,8 +1925,7 @@ def __init__(self, kernel, **kwargs): super(Net, self).__init__(**kwargs) - with self.name_scope(): - self.conv0 = gluon.nn.Conv2D(chn_num, (kernel, kernel)) + self.conv0 = gluon.nn.Conv2D(chn_num, (kernel, kernel)) def hybrid_forward(self, F, x): out = self.conv0(x) @@ -2079,9 +1947,8 @@ def __init__(self, kernel, **kwargs): super(Net, self).__init__(**kwargs) - with self.name_scope(): - self.conv0 = gluon.nn.Conv2D(chn_num, (1, 1)) - self.conv1 = gluon.nn.Conv2D(chn_num, (kernel, kernel), groups=chn_num) + self.conv0 = gluon.nn.Conv2D(chn_num, (1, 1)) + self.conv1 = gluon.nn.Conv2D(chn_num, (kernel, kernel), groups=chn_num) def hybrid_forward(self, F, x): y = self.conv0(x) @@ -2106,8 +1973,7 @@ def test_deconv2d_16c(): class Net(gluon.HybridBlock): def __init__(self, chn_num, kernel, **kwargs): super(Net, self).__init__(**kwargs) - with self.name_scope(): - self.deconv0 = gluon.nn.Conv2DTranspose(chn_num, (kernel, kernel)) + self.deconv0 = gluon.nn.Conv2DTranspose(chn_num, (kernel, kernel)) def hybrid_forward(self, F, x): out = self.deconv0(x) @@ -2135,9 +2001,8 @@ def __init__(self, axis, **kwargs): super(Net, self).__init__(**kwargs) - with self.name_scope(): - self.conv0 = gluon.nn.Conv2D(chn_num, (kernel, kernel)) - self.bn0 = gluon.nn.BatchNorm(axis=axis) + self.conv0 = gluon.nn.Conv2D(chn_num, (kernel, kernel)) + self.bn0 = gluon.nn.BatchNorm(axis=axis) def hybrid_forward(self, F, x): conv = self.conv0(x) @@ -2169,11 +2034,10 @@ def __init__(self, kernel, **kwargs): super(Net, self).__init__(**kwargs) - with self.name_scope(): - from mxnet.gluon.contrib.nn import HybridConcurrent - self.concat = HybridConcurrent(axis=check_dim) - for i in range(input_num): - self.concat.add(gluon.nn.Conv2D(chn_num, (kernel, kernel))) + from mxnet.gluon.contrib.nn import HybridConcurrent + self.concat = HybridConcurrent(axis=check_dim) + for i in range(input_num): + self.concat.add(gluon.nn.Conv2D(chn_num, (kernel, kernel))) def hybrid_forward(self, F, x): return self.concat(x) @@ -2191,8 +2055,7 @@ def test_reshape_conv(): class Net(gluon.HybridBlock): def __init__(self, **kwargs): super(Net, self).__init__(**kwargs) - with self.name_scope(): - self.conv0 = nn.Conv2D(64, (3, 3)) + self.conv0 = nn.Conv2D(64, (3, 3)) def hybrid_forward(self, F, x): x_reshape = x.reshape((0, 0, 128, 32)) @@ -2209,9 +2072,8 @@ def test_reshape_conv_reshape_conv(): class Net(gluon.HybridBlock): def 
__init__(self, **kwargs): super(Net, self).__init__(**kwargs) - with self.name_scope(): - self.conv0 = nn.Conv2D(64, (3, 3)) - self.conv1 = nn.Conv2D(128, (3, 3)) + self.conv0 = nn.Conv2D(64, (3, 3)) + self.conv1 = nn.Conv2D(128, (3, 3)) def hybrid_forward(self, F, x): x_reshape = x.reshape((0, 0, 128, 32)) @@ -2229,8 +2091,7 @@ def test_slice_conv(): class Net(gluon.HybridBlock): def __init__(self, **kwargs): super(Net, self).__init__(**kwargs) - with self.name_scope(): - self.conv0 = nn.Conv2D(16, (3, 3)) + self.conv0 = nn.Conv2D(16, (3, 3)) def hybrid_forward(self, F, x): x_slice = x.slice(begin=(0, 2, 0, 0), end=(4, 5, 32, 32)) @@ -2246,9 +2107,8 @@ def test_slice_conv_slice_conv(): class Net(gluon.HybridBlock): def __init__(self, **kwargs): super(Net, self).__init__(**kwargs) - with self.name_scope(): - self.conv0 = nn.Conv2D(32, (3, 3)) - self.conv1 = nn.Conv2D(16, (1, 1)) + self.conv0 = nn.Conv2D(32, (3, 3)) + self.conv1 = nn.Conv2D(16, (1, 1)) def hybrid_forward(self, F, x): x_slice = x.slice(begin=(0, 0, 0, 0), end=(4, 16, 16, 16)) @@ -2268,9 +2128,8 @@ def test_slice_conv_reshape_conv(): class Net(gluon.HybridBlock): def __init__(self, **kwargs): super(Net, self).__init__(**kwargs) - with self.name_scope(): - self.conv0 = nn.Conv2D(64, (3, 3)) - self.conv1 = nn.Conv2D(128, (3, 3)) + self.conv0 = nn.Conv2D(64, (3, 3)) + self.conv1 = nn.Conv2D(128, (3, 3)) def hybrid_forward(self, F, x): x_slice = x.slice(begin=(0, 0, 1, 1), end=(4, 16, 33, 33)) @@ -2292,9 +2151,8 @@ def test_reshape_conv_slice_conv(): class Net(gluon.HybridBlock): def __init__(self, **kwargs): super(Net, self).__init__(**kwargs) - with self.name_scope(): - self.conv0 = nn.Conv2D(16, (3, 3)) - self.conv1 = nn.Conv2D(32, (3, 3)) + self.conv0 = nn.Conv2D(16, (3, 3)) + self.conv1 = nn.Conv2D(32, (3, 3)) def hybrid_forward(self, F, x): x_reshape = x.reshape((0, 0, 64, 16)) @@ -2312,9 +2170,8 @@ def test_reshape_dense(): class Net(gluon.HybridBlock): def __init__(self, **kwargs): super(Net, self).__init__(**kwargs) - with self.name_scope(): - channel0 = np.random.randint(1, 17) - self.dense0 = nn.Dense(channel0) + channel0 = np.random.randint(1, 17) + self.dense0 = nn.Dense(channel0) def hybrid_forward(self, F, x): x_reshape = x.reshape((8, 64, 128, -1)) @@ -2331,10 +2188,9 @@ def test_slice_dense(): class Net(gluon.HybridBlock): def __init__(self, slice, **kwargs): super(Net, self).__init__(**kwargs) - with self.name_scope(): - channel0 = np.random.randint(1, 17) - self.dense0 = nn.Dense(channel0) - self.slice = slice + channel0 = np.random.randint(1, 17) + self.dense0 = nn.Dense(channel0) + self.slice = slice def hybrid_forward(self, F, x): x_slice = x.slice(begin=tuple(self.slice[0]), @@ -2352,12 +2208,11 @@ def test_slice_dense_slice_dense(): class Net(gluon.HybridBlock): def __init__(self, slice, **kwargs): super(Net, self).__init__(**kwargs) - with self.name_scope(): - channel0 = 32 - channel1 = np.random.randint(1, 17) - self.dense0 = nn.Dense(channel0) - self.dense1 = nn.Dense(channel1) - self.slice = slice + channel0 = 32 + channel1 = np.random.randint(1, 17) + self.dense0 = nn.Dense(channel0) + self.dense1 = nn.Dense(channel1) + self.slice = slice def hybrid_forward(self, F, x): x_slice = x.slice(begin=tuple(self.slice[0]), end=tuple(self.slice[1])) @@ -2376,11 +2231,10 @@ def test_reshape_dense_reshape_dense(): class Net(gluon.HybridBlock): def __init__(self, **kwargs): super(Net, self).__init__(**kwargs) - with self.name_scope(): - channel0 = np.random.randint(1, 17) - channel1 = np.random.randint(1, 33) - 
self.dense0 = nn.Dense(channel0) - self.dense1 = nn.Dense(channel1) + channel0 = np.random.randint(1, 17) + channel1 = np.random.randint(1, 33) + self.dense0 = nn.Dense(channel0) + self.dense1 = nn.Dense(channel1) def hybrid_forward(self, F, x): x_reshape = x.reshape((4, 16, 128, 32)) @@ -2399,12 +2253,11 @@ def test_slice_dense_reshape_dense(): class Net(gluon.HybridBlock): def __init__(self, slice, **kwargs): super(Net, self).__init__(**kwargs) - with self.name_scope(): - channel0 = np.random.randint(1, 17) - channel1 = np.random.randint(1, 17) - self.dense0 = nn.Dense(channel0) - self.dense1 = nn.Dense(channel1) - self.slice = slice + channel0 = np.random.randint(1, 17) + channel1 = np.random.randint(1, 17) + self.dense0 = nn.Dense(channel0) + self.dense1 = nn.Dense(channel1) + self.slice = slice def hybrid_forward(self, F, x): x_slice = x.slice(begin=tuple(self.slice[0]), end=tuple(self.slice[1])) @@ -2424,11 +2277,10 @@ def test_reshape_dense_slice_dense(): class Net(gluon.HybridBlock): def __init__(self, **kwargs): super(Net, self).__init__(**kwargs) - with self.name_scope(): - channel0 = 64 - channel1 = np.random.randint(1, 17) - self.dense0 = nn.Dense(channel0) - self.dense1 = nn.Dense(channel1) + channel0 = 64 + channel1 = np.random.randint(1, 17) + self.dense0 = nn.Dense(channel0) + self.dense1 = nn.Dense(channel1) def hybrid_forward(self, F, x): x_reshape = x.reshape((4, 16, 128, 32)) @@ -2448,10 +2300,9 @@ def test_reshape_batchnorm(): class Net(gluon.HybridBlock): def __init__(self, shape, **kwargs): super(Net, self).__init__(**kwargs) - with self.name_scope(): - self.conv0 = nn.Conv2D(96, (1, 1)) - self.bn0 = nn.BatchNorm() - self.reshape = shape + self.conv0 = nn.Conv2D(96, (1, 1)) + self.bn0 = nn.BatchNorm() + self.reshape = shape def hybrid_forward(self, F, x): x_in = self.conv0(x) @@ -2471,10 +2322,9 @@ def test_slice_batchnorm(): class Net(gluon.HybridBlock): def __init__(self, slice, **kwargs): super(Net, self).__init__(**kwargs) - with self.name_scope(): - self.conv0 = nn.Conv2D(128, (1, 1)) - self.bn0 = nn.BatchNorm() - self.slice = slice + self.conv0 = nn.Conv2D(128, (1, 1)) + self.bn0 = nn.BatchNorm() + self.slice = slice def hybrid_forward(self, F, x): x_in = self.conv0(x) @@ -2496,11 +2346,10 @@ def test_slice_batchnorm_slice_batchnorm(): class Net(gluon.HybridBlock): def __init__(self, slice, **kwargs): super(Net, self).__init__(**kwargs) - with self.name_scope(): - self.conv0 = nn.Conv2D(128, (1, 1)) - self.bn0 = nn.BatchNorm() - self.bn1 = nn.BatchNorm() - self.slice = slice + self.conv0 = nn.Conv2D(128, (1, 1)) + self.bn0 = nn.BatchNorm() + self.bn1 = nn.BatchNorm() + self.slice = slice def hybrid_forward(self, F, x): x_in = self.conv0(x) @@ -2522,11 +2371,10 @@ def test_reshape_batchnorm_reshape_batchnorm(): class Net(gluon.HybridBlock): def __init__(self, shape, **kwargs): super(Net, self).__init__(**kwargs) - with self.name_scope(): - self.conv0 = nn.Conv2D(128, (1, 1)) - self.bn0 = nn.BatchNorm() - self.bn1 = nn.BatchNorm() - self.reshape = shape + self.conv0 = nn.Conv2D(128, (1, 1)) + self.bn0 = nn.BatchNorm() + self.bn1 = nn.BatchNorm() + self.reshape = shape def hybrid_forward(self, F, x): x_in = self.conv0(x) @@ -2548,12 +2396,11 @@ def test_slice_batchnorm_reshape_batchnorm(): class Net(gluon.HybridBlock): def __init__(self, shape, slice, **kwargs): super(Net, self).__init__(**kwargs) - with self.name_scope(): - self.conv0 = nn.Conv2D(128, (1, 1)) - self.bn0 = nn.BatchNorm() - self.bn1 = nn.BatchNorm() - self.reshape = shape - self.slice = slice + 
self.conv0 = nn.Conv2D(128, (1, 1)) + self.bn0 = nn.BatchNorm() + self.bn1 = nn.BatchNorm() + self.reshape = shape + self.slice = slice def hybrid_forward(self, F, x): x_in = self.conv0(x) @@ -2576,12 +2423,11 @@ def test_reshape_batchnorm_slice_batchnorm(): class Net(gluon.HybridBlock): def __init__(self, shape, slice, **kwargs): super(Net, self).__init__(**kwargs) - with self.name_scope(): - self.conv0 = nn.Conv2D(128, (1, 1)) - self.bn0 = nn.BatchNorm() - self.bn1 = nn.BatchNorm() - self.reshape = shape - self.slice = slice + self.conv0 = nn.Conv2D(128, (1, 1)) + self.bn0 = nn.BatchNorm() + self.bn1 = nn.BatchNorm() + self.reshape = shape + self.slice = slice def hybrid_forward(self, F, x): x_in = self.conv0(x) @@ -2611,9 +2457,8 @@ def __init__(self, pooling_layer, **kwargs): super(Net, self).__init__(**kwargs) - with self.name_scope(): - self.reshape = shape - self.pool0 = pooling_layer + self.reshape = shape + self.pool0 = pooling_layer def hybrid_forward(self, F, x): x_reshape = x.reshape(self.reshape) @@ -2645,9 +2490,8 @@ def __init__(self, pooling_layer, **kwargs): super(Net, self).__init__(**kwargs) - with self.name_scope(): - self.slice = slice - self.pool0 = pooling_layer + self.slice = slice + self.pool0 = pooling_layer def hybrid_forward(self, F, x): x_slice = x.slice(begin=self.slice[0], end=self.slice[1]) @@ -2680,10 +2524,9 @@ def __init__(self, pooling_layer2, **kwargs): super(Net, self).__init__(**kwargs) - with self.name_scope(): - self.reshape = shape - self.pool0 = pooling_layer1 - self.pool1 = pooling_layer2 + self.reshape = shape + self.pool0 = pooling_layer1 + self.pool1 = pooling_layer2 def hybrid_forward(self, F, x): x_reshape = x.reshape(self.reshape[0]) @@ -2716,10 +2559,9 @@ def __init__(self, pooling_layer2, **kwargs): super(Net, self).__init__(**kwargs) - with self.name_scope(): - self.slice = slice - self.pool0 = pooling_layer1 - self.pool1 = pooling_layer2 + self.slice = slice + self.pool0 = pooling_layer1 + self.pool1 = pooling_layer2 def hybrid_forward(self, F, x): x_slice = x.slice(begin=self.slice[0][0], end=self.slice[0][1]) @@ -2753,11 +2595,10 @@ def __init__(self, pooling_layer2, **kwargs): super(Net, self).__init__(**kwargs) - with self.name_scope(): - self.reshape = shape - self.slice = slice - self.pool0 = pooling_layer1 - self.pool1 = pooling_layer2 + self.reshape = shape + self.slice = slice + self.pool0 = pooling_layer1 + self.pool1 = pooling_layer2 def hybrid_forward(self, F, x): x_slice = x.slice(begin=self.slice[0], end=self.slice[1]) @@ -2791,11 +2632,10 @@ def __init__(self, pooling_layer2, **kwargs): super(Net, self).__init__(**kwargs) - with self.name_scope(): - self.reshape = shape - self.slice = slice - self.pool0 = pooling_layer1 - self.pool1 = pooling_layer2 + self.reshape = shape + self.slice = slice + self.pool0 = pooling_layer1 + self.pool1 = pooling_layer2 def hybrid_forward(self, F, x): x_reshape = x.reshape(self.reshape) @@ -2821,9 +2661,8 @@ def test_reshape_deconv(): class Net(gluon.HybridBlock): def __init__(self, shape, **kwargs): super(Net, self).__init__(**kwargs) - with self.name_scope(): - self.reshape = shape - self.conv0 = nn.Conv2DTranspose(64, (3, 3)) + self.reshape = shape + self.conv0 = nn.Conv2DTranspose(64, (3, 3)) def hybrid_forward(self, F, x): x_reshape = x.reshape(self.reshape) @@ -2841,9 +2680,8 @@ def test_slice_deconv(): class Net(gluon.HybridBlock): def __init__(self, slice, **kwargs): super(Net, self).__init__(**kwargs) - with self.name_scope(): - self.slice = slice - self.conv0 = 
nn.Conv2DTranspose(64, (3, 3)) + self.slice = slice + self.conv0 = nn.Conv2DTranspose(64, (3, 3)) def hybrid_forward(self, F, x): x_slice = x.slice(begin=self.slice[0], end=self.slice[1]) @@ -2861,10 +2699,9 @@ def test_reshape_deconv_reshape_deconv(): class Net(gluon.HybridBlock): def __init__(self, shape, **kwargs): super(Net, self).__init__(**kwargs) - with self.name_scope(): - self.reshape = shape - self.conv0 = nn.Conv2DTranspose(32, (3, 3)) - self.conv1 = nn.Conv2DTranspose(64, (3, 3), strides=(2, 2)) + self.reshape = shape + self.conv0 = nn.Conv2DTranspose(32, (3, 3)) + self.conv1 = nn.Conv2DTranspose(64, (3, 3), strides=(2, 2)) def hybrid_forward(self, F, x): x_reshape = x.reshape(self.reshape[0]) @@ -2885,10 +2722,9 @@ def test_slice_deconv_slice_deconv(): class Net(gluon.HybridBlock): def __init__(self, slice, **kwargs): super(Net, self).__init__(**kwargs) - with self.name_scope(): - self.slice = slice - self.conv0 = nn.Conv2DTranspose(32, (3, 3)) - self.conv1 = nn.Conv2DTranspose(64, (3, 3), strides=(2, 2)) + self.slice = slice + self.conv0 = nn.Conv2DTranspose(32, (3, 3)) + self.conv1 = nn.Conv2DTranspose(64, (3, 3), strides=(2, 2)) def hybrid_forward(self, F, x): x_slice = x.slice(begin=self.slice[0][0], end=self.slice[0][1]) @@ -2909,11 +2745,10 @@ def test_reshape_deconv_slice_deconv(): class Net(gluon.HybridBlock): def __init__(self, shape, slice, **kwargs): super(Net, self).__init__(**kwargs) - with self.name_scope(): - self.reshape = shape - self.slice = slice - self.conv0 = nn.Conv2DTranspose(32, (3, 3)) - self.conv1 = nn.Conv2DTranspose(64, (3, 3), strides=(2, 2)) + self.reshape = shape + self.slice = slice + self.conv0 = nn.Conv2DTranspose(32, (3, 3)) + self.conv1 = nn.Conv2DTranspose(64, (3, 3), strides=(2, 2)) def hybrid_forward(self, F, x): x_reshape = x.reshape(self.reshape) @@ -2935,11 +2770,10 @@ def test_slice_deconv_reshape_deconv(): class Net(gluon.HybridBlock): def __init__(self, shape, slice, **kwargs): super(Net, self).__init__(**kwargs) - with self.name_scope(): - self.reshape = shape - self.slice = slice - self.conv0 = nn.Conv2DTranspose(32, (3, 3)) - self.conv1 = nn.Conv2DTranspose(96, (3, 3), strides=(2, 2)) + self.reshape = shape + self.slice = slice + self.conv0 = nn.Conv2DTranspose(32, (3, 3)) + self.conv1 = nn.Conv2DTranspose(96, (3, 3), strides=(2, 2)) def hybrid_forward(self, F, x): x_slice = x.slice(begin=self.slice[0], end=self.slice[1]) @@ -2960,9 +2794,8 @@ def test_reshape_activation(): class Net(gluon.HybridBlock): def __init__(self, act, shape, **kwargs): super(Net, self).__init__(**kwargs) - with self.name_scope(): - self.reshape = shape - self.act = nn.Activation(act) + self.reshape = shape + self.act = nn.Activation(act) def hybrid_forward(self, F, x): x_reshape = x.reshape(self.reshape) @@ -2982,9 +2815,8 @@ def test_slice_activation(): class Net(gluon.HybridBlock): def __init__(self, act, slice, **kwargs): super(Net, self).__init__(**kwargs) - with self.name_scope(): - self.slice = slice - self.act = nn.Activation(act) + self.slice = slice + self.act = nn.Activation(act) def hybrid_forward(self, F, x): x_slice = x.slice(begin=self.slice[0], end=self.slice[1]) @@ -3005,10 +2837,9 @@ def test_reshape_activation_reshape_activation(): class Net(gluon.HybridBlock): def __init__(self, act0, act1, shape, **kwargs): super(Net, self).__init__(**kwargs) - with self.name_scope(): - self.reshape = shape - self.act0 = nn.Activation(act0) - self.act1 = nn.Activation(act1) + self.reshape = shape + self.act0 = nn.Activation(act0) + self.act1 = 
nn.Activation(act1) def hybrid_forward(self, F, x): x_reshape = x.reshape(self.reshape[0]) @@ -3033,10 +2864,9 @@ def test_slice_activation_slice_activation(): class Net(gluon.HybridBlock): def __init__(self, act0, act1, slice, **kwargs): super(Net, self).__init__(**kwargs) - with self.name_scope(): - self.slice = slice - self.act0 = nn.Activation(act0) - self.act1 = nn.Activation(act1) + self.slice = slice + self.act0 = nn.Activation(act0) + self.act1 = nn.Activation(act1) def hybrid_forward(self, F, x): x_slice = x.slice(begin=self.slice[0][0], end=self.slice[0][1]) @@ -3061,11 +2891,10 @@ def test_reshape_activation_slice_activation(): class Net(gluon.HybridBlock): def __init__(self, act0, act1, shape, slice, **kwargs): super(Net, self).__init__(**kwargs) - with self.name_scope(): - self.reshape = shape - self.slice = slice - self.act0 = nn.Activation(act0) - self.act1 = nn.Activation(act1) + self.reshape = shape + self.slice = slice + self.act0 = nn.Activation(act0) + self.act1 = nn.Activation(act1) def hybrid_forward(self, F, x): x_reshape = x.reshape(self.reshape) @@ -3091,11 +2920,10 @@ def test_slice_activation_reshape_activation(): class Net(gluon.HybridBlock): def __init__(self, act0, act1, shape, slice, **kwargs): super(Net, self).__init__(**kwargs) - with self.name_scope(): - self.reshape = shape - self.slice = slice - self.act0 = nn.Activation(act0) - self.act1 = nn.Activation(act1) + self.reshape = shape + self.slice = slice + self.act0 = nn.Activation(act0) + self.act1 = nn.Activation(act1) def hybrid_forward(self, F, x): x_slice = x.slice(begin=self.slice[0], end=self.slice[1]) @@ -3172,11 +3000,10 @@ class MyBlock(gluon.HybridBlock): def __init__(self, **kwargs): super(MyBlock, self).__init__(**kwargs) - with self.name_scope(): - self.param = self.params.get("param", shape=(1, ), init=mx.init.Constant(-10.0)) + self.param = gluon.Parameter(shape=(1, ), init=mx.init.Constant(-10.0)) bl = MyBlock() - bl2 = MyBlock(params=bl.collect_params()) + bl2 = MyBlock().share_parameters(bl.collect_params()) assert bl.param is bl2.param bl3 = MyBlock() assert bl.param is not bl3.param diff --git a/tests/python/unittest/test_gluon_contrib.py b/tests/python/unittest/test_gluon_contrib.py index a69b0230c3c8..d7356575b927 100644 --- a/tests/python/unittest/test_gluon_contrib.py +++ b/tests/python/unittest/test_gluon_contrib.py @@ -29,23 +29,23 @@ import numpy as np -def check_rnn_cell(cell, prefix, in_shape=(10, 50), out_shape=(10, 100), begin_state=None): +def check_rnn_cell(cell, in_shape=(10, 50), out_shape=(10, 100), begin_state=None): inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(3)] outputs, _ = cell.unroll(3, inputs, begin_state=begin_state) outputs = mx.sym.Group(outputs) - assert sorted(cell.collect_params().keys()) == [prefix+'h2h_bias', prefix+'h2h_weight', - prefix+'i2h_bias', prefix+'i2h_weight'] - assert outputs.list_outputs() == [prefix+'t0_out_output', prefix+'t1_out_output', prefix+'t2_out_output'] + assert sorted(cell.collect_params().keys()) == ['h2h_bias', 'h2h_weight', + 'i2h_bias', 'i2h_weight'] + assert outputs.list_outputs() == [cell.name + name for name in ['_t0_out_output', '_t1_out_output', '_t2_out_output']] args, outs, auxs = outputs.infer_shape(rnn_t0_data=in_shape, rnn_t1_data=in_shape, rnn_t2_data=in_shape) - assert outs == [out_shape]*3 + assert outs == [out_shape] * 3 def check_rnn_forward(layer, inputs): inputs.attach_grad() - layer.collect_params().initialize() + layer.initialize() with mx.autograd.record(): layer.unroll(3, inputs, 
merge_outputs=True)[0].backward() mx.autograd.backward(layer.unroll(3, inputs, merge_outputs=False)[0]) @@ -70,38 +70,38 @@ def test_rnn_cells(): @with_seed() def test_convrnn(): - cell = contrib.rnn.Conv1DRNNCell((10, 50), 100, 3, 3, prefix='rnn_') - check_rnn_cell(cell, prefix='rnn_', in_shape=(1, 10, 50), out_shape=(1, 100, 48)) + cell = contrib.rnn.Conv1DRNNCell((10, 50), 100, 3, 3) + check_rnn_cell(cell, in_shape=(1, 10, 50), out_shape=(1, 100, 48)) - cell = contrib.rnn.Conv2DRNNCell((10, 20, 50), 100, 3, 3, prefix='rnn_') - check_rnn_cell(cell, prefix='rnn_', in_shape=(1, 10, 20, 50), out_shape=(1, 100, 18, 48)) + cell = contrib.rnn.Conv2DRNNCell((10, 20, 50), 100, 3, 3) + check_rnn_cell(cell, in_shape=(1, 10, 20, 50), out_shape=(1, 100, 18, 48)) - cell = contrib.rnn.Conv3DRNNCell((10, 20, 30, 50), 100, 3, 3, prefix='rnn_') - check_rnn_cell(cell, prefix='rnn_', in_shape=(1, 10, 20, 30, 50), out_shape=(1, 100, 18, 28, 48)) + cell = contrib.rnn.Conv3DRNNCell((10, 20, 30, 50), 100, 3, 3) + check_rnn_cell(cell, in_shape=(1, 10, 20, 30, 50), out_shape=(1, 100, 18, 28, 48)) @with_seed() def test_convlstm(): - cell = contrib.rnn.Conv1DLSTMCell((10, 50), 100, 3, 3, prefix='rnn_') - check_rnn_cell(cell, prefix='rnn_', in_shape=(1, 10, 50), out_shape=(1, 100, 48)) + cell = contrib.rnn.Conv1DLSTMCell((10, 50), 100, 3, 3) + check_rnn_cell(cell, in_shape=(1, 10, 50), out_shape=(1, 100, 48)) - cell = contrib.rnn.Conv2DLSTMCell((10, 20, 50), 100, 3, 3, prefix='rnn_') - check_rnn_cell(cell, prefix='rnn_', in_shape=(1, 10, 20, 50), out_shape=(1, 100, 18, 48)) + cell = contrib.rnn.Conv2DLSTMCell((10, 20, 50), 100, 3, 3) + check_rnn_cell(cell, in_shape=(1, 10, 20, 50), out_shape=(1, 100, 18, 48)) - cell = contrib.rnn.Conv3DLSTMCell((10, 20, 30, 50), 100, 3, 3, prefix='rnn_') - check_rnn_cell(cell, prefix='rnn_', in_shape=(1, 10, 20, 30, 50), out_shape=(1, 100, 18, 28, 48)) + cell = contrib.rnn.Conv3DLSTMCell((10, 20, 30, 50), 100, 3, 3) + check_rnn_cell(cell, in_shape=(1, 10, 20, 30, 50), out_shape=(1, 100, 18, 28, 48)) @with_seed() def test_convgru(): - cell = contrib.rnn.Conv1DGRUCell((10, 50), 100, 3, 3, prefix='rnn_') - check_rnn_cell(cell, prefix='rnn_', in_shape=(1, 10, 50), out_shape=(1, 100, 48)) + cell = contrib.rnn.Conv1DGRUCell((10, 50), 100, 3, 3) + check_rnn_cell(cell, in_shape=(1, 10, 50), out_shape=(1, 100, 48)) - cell = contrib.rnn.Conv2DGRUCell((10, 20, 50), 100, 3, 3, prefix='rnn_') - check_rnn_cell(cell, prefix='rnn_', in_shape=(1, 10, 20, 50), out_shape=(1, 100, 18, 48)) + cell = contrib.rnn.Conv2DGRUCell((10, 20, 50), 100, 3, 3) + check_rnn_cell(cell, in_shape=(1, 10, 20, 50), out_shape=(1, 100, 18, 48)) - cell = contrib.rnn.Conv3DGRUCell((10, 20, 30, 50), 100, 3, 3, prefix='rnn_') - check_rnn_cell(cell, prefix='rnn_', in_shape=(1, 10, 20, 30, 50), out_shape=(1, 100, 18, 28, 48)) + cell = contrib.rnn.Conv3DGRUCell((10, 20, 30, 50), 100, 3, 3) + check_rnn_cell(cell, in_shape=(1, 10, 20, 30, 50), out_shape=(1, 100, 18, 28, 48)) @with_seed() @@ -116,12 +116,12 @@ def test_conv_fill_shape(): def test_lstmp(): nhid = 100 nproj = 64 - cell = contrib.rnn.LSTMPCell(nhid, nproj, prefix='rnn_') + cell = contrib.rnn.LSTMPCell(nhid, nproj) inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(3)] outputs, _ = cell.unroll(3, inputs) outputs = mx.sym.Group(outputs) - expected_params = ['rnn_h2h_bias', 'rnn_h2h_weight', 'rnn_h2r_weight', 'rnn_i2h_bias', 'rnn_i2h_weight'] - expected_outputs = ['rnn_t0_out_output', 'rnn_t1_out_output', 'rnn_t2_out_output'] + expected_params = ['h2h_bias', 
'h2h_weight', 'h2r_weight', 'i2h_bias', 'i2h_weight'] + expected_outputs = [cell.name + name for name in ['_t0_out_output', '_t1_out_output', '_t2_out_output']] assert sorted(cell.collect_params().keys()) == expected_params assert outputs.list_outputs() == expected_outputs, outputs.list_outputs() @@ -132,11 +132,11 @@ def test_lstmp(): @with_seed() def test_vardrop(): def check_vardrop(drop_inputs, drop_states, drop_outputs): - cell = contrib.rnn.VariationalDropoutCell(mx.gluon.rnn.RNNCell(100, prefix='rnn_'), + cell = contrib.rnn.VariationalDropoutCell(mx.gluon.rnn.RNNCell(100), drop_outputs=drop_outputs, drop_states=drop_states, drop_inputs=drop_inputs) - cell.collect_params().initialize(init='xavier') + cell.initialize(init='xavier') input_data = mx.nd.random_uniform(shape=(10, 3, 50), ctx=mx.context.current_context()) with mx.autograd.record(): outputs1, _ = cell.unroll(3, input_data, merge_outputs=True) @@ -315,9 +315,9 @@ def test_sampler(): class RNNLayer(gluon.HybridBlock): - def __init__(self, cell_type, hidden_size, layout, prefix=None, params=None): - super(RNNLayer, self).__init__(prefix=prefix, params=params) - self.cell = cell_type(hidden_size, prefix='rnn_') + def __init__(self, cell_type, hidden_size, layout): + super(RNNLayer, self).__init__() + self.cell = cell_type(hidden_size) self.layout = layout def hybrid_forward(self, F, inputs, states, valid_length): @@ -343,7 +343,7 @@ def check_unroll(cell_type, num_states, layout): state_shape = (batch_size, hidden_size) states = [mx.nd.normal(loc=0, scale=1, shape=state_shape) for i in range(num_states)] - cell = cell_type(hidden_size, prefix='rnn_') + cell = cell_type(hidden_size) cell.initialize(ctx=default_context()) if layout == 'TNC': cell(rnn_data[0], states) @@ -376,7 +376,7 @@ def check_unroll(cell_type, num_states, layout): res2, states2 = layer(rnn_data, states, valid_length) params2 = layer.collect_params() for key, val in orig_params1.items(): - params2[key].set_data(copy.deepcopy(val.data())) + params2['cell.' + key].set_data(copy.deepcopy(val.data())) trainer = gluon.Trainer(params2, 'sgd', {'learning_rate' : 0.03}) with mx.autograd.record(): @@ -390,7 +390,7 @@ def check_unroll(cell_type, num_states, layout): for key, val in params1.items(): weight1 = val.data() - weight2 = params2[key].data() + weight2 = params2['cell.' 
+ key].data() assert_almost_equal(weight1, weight2, rtol=0.001, atol=0.0001) diff --git a/tests/python/unittest/test_gluon_estimator.py b/tests/python/unittest/test_gluon_estimator.py index 844c8b2b857f..a18dce054744 100644 --- a/tests/python/unittest/test_gluon_estimator.py +++ b/tests/python/unittest/test_gluon_estimator.py @@ -30,14 +30,9 @@ def _get_test_network(params=None): - net = nn.Sequential(params=params) + net = nn.Sequential() net.add(nn.Dense(4, activation='relu', flatten=False)) - return net - -def _get_test_network_with_namescope(params=None): - net = nn.Sequential(params=params) - with net.name_scope(): - net.add(nn.Dense(4, activation='relu', flatten=False)) + net.share_parameters(params) return net def _get_test_data(): @@ -375,7 +370,6 @@ def test_default_handlers(): def test_val_net(): ''' test estimator with different training and validation networks ''' - ''' test weight sharing of sequential networks without namescope ''' net = _get_test_network() val_net = _get_test_network(params=net.collect_params()) dataloader, dataiter = _get_test_data() @@ -394,24 +388,6 @@ def test_val_net(): val_loss=val_loss, val_net=val_net) - with pytest.raises(RuntimeError): - est.fit(train_data=dataloader, - val_data=dataloader, - epochs=num_epochs) - - ''' test weight sharing of sequential networks with namescope ''' - net = _get_test_network_with_namescope() - val_net = _get_test_network_with_namescope(params=net.collect_params()) - net.initialize(ctx=ctx) - trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001}) - est = Estimator(net=net, - loss=loss, - train_metrics=acc, - trainer=trainer, - context=ctx, - val_loss=val_loss, - val_net=val_net) - est.fit(train_data=dataloader, val_data=dataloader, epochs=num_epochs) @@ -420,7 +396,7 @@ def test_val_net(): net = gluon.model_zoo.vision.resnet18_v1(pretrained=False, ctx=ctx) net.output = gluon.nn.Dense(10) val_net = gluon.model_zoo.vision.resnet18_v1(pretrained=False, ctx=ctx) - val_net.output = gluon.nn.Dense(10, params=net.output.collect_params()) + val_net.output = net.output dataset = gluon.data.ArrayDataset(mx.nd.zeros((10, 3, 224, 224)), mx.nd.zeros((10, 10))) dataloader = gluon.data.DataLoader(dataset=dataset, batch_size=5) net.initialize(ctx=ctx) diff --git a/tests/python/unittest/test_gluon_model_zoo.py b/tests/python/unittest/test_gluon_model_zoo.py index 2c7d73487d0e..191a070be287 100644 --- a/tests/python/unittest/test_gluon_model_zoo.py +++ b/tests/python/unittest/test_gluon_model_zoo.py @@ -41,7 +41,7 @@ def eprint(*args, **kwargs): 'mobilenetv2_1.0', 'mobilenetv2_0.75', 'mobilenetv2_0.5', 'mobilenetv2_0.25' ]) def test_models(model_name): - pretrained_to_test = set(['squeezenet1.1']) + pretrained_to_test = set(['vgg19_bn']) test_pretrain = model_name in pretrained_to_test model = get_model(model_name, pretrained=test_pretrain, root='model/') @@ -49,7 +49,7 @@ def test_models(model_name): eprint('testing forward for %s' % model_name) print(model) if not test_pretrain: - model.collect_params().initialize() + model.initialize() model(mx.nd.random.uniform(shape=data_shape)).wait_to_read() def parallel_download(model_name): diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py index 73ac038c60c4..933c2c17d95f 100644 --- a/tests/python/unittest/test_gluon_rnn.py +++ b/tests/python/unittest/test_gluon_rnn.py @@ -46,27 +46,29 @@ def check_rnn_states(fused_states, stack_states, num_layers, bidirectional=False def test_rnn(): - cell = gluon.rnn.RNNCell(100, 
prefix='rnn_') - inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(3)] + cell = gluon.rnn.RNNCell(100) + inputs = [mx.sym.Variable('t%d_data'%i) for i in range(3)] outputs, _ = cell.unroll(3, inputs) outputs = mx.sym.Group(outputs) - assert sorted(cell.collect_params().keys()) == ['rnn_h2h_bias', 'rnn_h2h_weight', - 'rnn_i2h_bias', 'rnn_i2h_weight'] - assert outputs.list_outputs() == ['rnn_t0_out_output', 'rnn_t1_out_output', 'rnn_t2_out_output'] + assert sorted(cell.collect_params().keys()) == ['h2h_bias', 'h2h_weight', + 'i2h_bias', 'i2h_weight'] + assert outputs.list_outputs() == \ + [cell.name + name for name in ['_t0_out_output', '_t1_out_output', '_t2_out_output']] - args, outs, auxs = outputs.infer_shape(rnn_t0_data=(10,50), rnn_t1_data=(10,50), rnn_t2_data=(10,50)) + args, outs, auxs = outputs.infer_shape(t0_data=(10,50), t1_data=(10,50), t2_data=(10,50)) assert outs == [(10, 100), (10, 100), (10, 100)] def test_lstm(): - cell = gluon.rnn.LSTMCell(100, prefix='rnn_') - inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(3)] + cell = gluon.rnn.LSTMCell(100) + inputs = [mx.sym.Variable('t%d_data'%i) for i in range(3)] outputs, _ = cell.unroll(3, inputs) outputs = mx.sym.Group(outputs) - assert sorted(cell.collect_params().keys()) == ['rnn_h2h_bias', 'rnn_h2h_weight', 'rnn_i2h_bias', 'rnn_i2h_weight'] - assert outputs.list_outputs() == ['rnn_t0_out_output', 'rnn_t1_out_output', 'rnn_t2_out_output'] + assert sorted(cell.collect_params().keys()) == ['h2h_bias', 'h2h_weight', 'i2h_bias', 'i2h_weight'] + assert outputs.list_outputs() == \ + [cell.name + name for name in ['_t0_out_output', '_t1_out_output', '_t2_out_output']] - args, outs, auxs = outputs.infer_shape(rnn_t0_data=(10,50), rnn_t1_data=(10,50), rnn_t2_data=(10,50)) + args, outs, auxs = outputs.infer_shape(t0_data=(10,50), t1_data=(10,50), t2_data=(10,50)) assert outs == [(10, 100), (10, 100), (10, 100)] @@ -83,15 +85,12 @@ def test_lstmp(): # ==== Unidirectional Layer ==== for num_layers in [1, 3]: fused_layer = gluon.rnn.LSTM(hidden_size, projection_size=projection_size, - num_layers=num_layers, layout='TNC', bidirectional=False, - prefix='lstm0_') - - stack_layer = mx.gluon.rnn.HybridSequentialRNNCell(prefix='lstm0_') - with stack_layer.name_scope(): - for i in range(num_layers): - stack_layer.add(gluon.contrib.rnn.LSTMPCell(hidden_size, - projection_size=projection_size, - prefix='l%d_' % i)) + num_layers=num_layers, layout='TNC', bidirectional=False) + + stack_layer = mx.gluon.rnn.HybridSequentialRNNCell() + for i in range(num_layers): + stack_layer.add(gluon.contrib.rnn.LSTMPCell(hidden_size, + projection_size=projection_size)) fused_layer.initialize() stack_layer.initialize() @@ -104,7 +103,7 @@ def test_lstmp(): for name, value in fused_layer_params.items(): w = mx.nd.random.uniform(shape=value.shape) value.set_data(w.copy()) - stack_layer_params[name].set_data(w.copy()) + stack_layer_params[name[1:].replace('_', '.', 1)].set_data(w.copy()) fused_output, fused_states = fused_layer(lstm_input.copy(), fused_begin_state) stack_output, stack_states = stack_layer.unroll(seq_len, lstm_input.copy(), begin_state=stack_begin_state, @@ -117,19 +116,15 @@ def test_lstmp(): # ==== Bidirectional Layer ==== for num_layers in [1, 3]: fused_layer = gluon.rnn.LSTM(hidden_size, projection_size=projection_size, - num_layers=num_layers, layout='TNC', bidirectional=True, - prefix='lstm0_') - - stack_layer = mx.gluon.rnn.HybridSequentialRNNCell(prefix='lstm0_') - with stack_layer.name_scope(): - for i in range(num_layers): - 
stack_layer.add( - gluon.rnn.BidirectionalCell(gluon.contrib.rnn.LSTMPCell(hidden_size, - projection_size=projection_size, - prefix='l%d_' % i), - gluon.contrib.rnn.LSTMPCell(hidden_size, - projection_size=projection_size, - prefix='r%d_' % i))) + num_layers=num_layers, layout='TNC', bidirectional=True) + + stack_layer = mx.gluon.rnn.HybridSequentialRNNCell() + for i in range(num_layers): + stack_layer.add( + gluon.rnn.BidirectionalCell(gluon.contrib.rnn.LSTMPCell(hidden_size, + projection_size=projection_size), + gluon.contrib.rnn.LSTMPCell(hidden_size, + projection_size=projection_size))) fused_layer.initialize() stack_layer.initialize() @@ -142,7 +137,8 @@ def test_lstmp(): for name, value in fused_layer_params.items(): w = mx.nd.random.uniform(shape=value.shape) value.set_data(w.copy()) - stack_layer_params[name].set_data(w.copy()) + cur = name.split("_")[0] + stack_layer_params["{}.{}_cell.{}".format(cur[1:], name[0], name[len(cur)+1:])].set_data(w.copy()) fused_output, fused_states = fused_layer(lstm_input.copy(), fused_begin_state) stack_output, stack_states = stack_layer.unroll(seq_len, lstm_input.copy(), begin_state=stack_begin_state, @@ -162,49 +158,46 @@ def test_lstm_cpu_inference(): [0.95215213, 0.95215213, 0.72045636, 0.72045636]]]) x = mx.nd.ones(shape=(2, 2, 2)) model = mx.gluon.rnn.LSTM(2, num_layers=6, bidirectional=True) - model_cell = model._unfuse() model.initialize(mx.init.One()) y = model(x).asnumpy() - y_cell = model_cell.unroll(2, x, layout='TNC', merge_outputs=True)[0].asnumpy() - - mx.test_utils.assert_almost_equal(y_cell, EXPECTED_LSTM_OUTPUT, - rtol=1e-3, atol=1e-5) mx.test_utils.assert_almost_equal(y, EXPECTED_LSTM_OUTPUT, rtol=1e-3, atol=1e-5) def test_gru(): - cell = gluon.rnn.GRUCell(100, prefix='rnn_', activation='relu', recurrent_activation='tanh') - inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(3)] + cell = gluon.rnn.GRUCell(100, activation='relu', recurrent_activation='tanh') + inputs = [mx.sym.Variable('t%d_data'%i) for i in range(3)] outputs, _ = cell.unroll(3, inputs) outputs = mx.sym.Group(outputs) - assert sorted(cell.collect_params().keys()) == ['rnn_h2h_bias', 'rnn_h2h_weight', 'rnn_i2h_bias', 'rnn_i2h_weight'] - assert outputs.list_outputs() == ['rnn_t0_out_output', 'rnn_t1_out_output', 'rnn_t2_out_output'] + assert sorted(cell.collect_params().keys()) == ['h2h_bias', 'h2h_weight', 'i2h_bias', 'i2h_weight'] + assert outputs.list_outputs() == \ + [cell.name + name for name in ['_t0_out_output', '_t1_out_output', '_t2_out_output']] - args, outs, auxs = outputs.infer_shape(rnn_t0_data=(10,50), rnn_t1_data=(10,50), rnn_t2_data=(10,50)) + args, outs, auxs = outputs.infer_shape(t0_data=(10,50), t1_data=(10,50), t2_data=(10,50)) assert outs == [(10, 100), (10, 100), (10, 100)] @pytest.mark.serial def test_residual(): - cell = gluon.rnn.ResidualCell(gluon.rnn.GRUCell(50, prefix='rnn_')) - inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(2)] + cell = gluon.rnn.ResidualCell(gluon.rnn.GRUCell(50)) + inputs = [mx.sym.Variable('t%d_data'%i) for i in range(2)] outputs, _ = cell.unroll(2, inputs) outputs = mx.sym.Group(outputs) - assert sorted(cell.collect_params().keys()) == \ - ['rnn_h2h_bias', 'rnn_h2h_weight', 'rnn_i2h_bias', 'rnn_i2h_weight'] + params = cell.collect_params() + assert sorted(params.keys()) == \ + ['base_cell.h2h_bias', 'base_cell.h2h_weight', 'base_cell.i2h_bias', 'base_cell.i2h_weight'] # assert outputs.list_outputs() == \ # ['rnn_t0_out_plus_residual_output', 'rnn_t1_out_plus_residual_output'] - args, outs, auxs = 
outputs.infer_shape(rnn_t0_data=(10, 50), rnn_t1_data=(10, 50)) + args, outs, auxs = outputs.infer_shape(t0_data=(10, 50), t1_data=(10, 50)) assert outs == [(10, 50), (10, 50)] - outputs = outputs.eval(rnn_t0_data=mx.nd.ones((10, 50)), - rnn_t1_data=mx.nd.ones((10, 50)), - rnn_i2h_weight=mx.nd.zeros((150, 50)), - rnn_i2h_bias=mx.nd.zeros((150,)), - rnn_h2h_weight=mx.nd.zeros((150, 50)), - rnn_h2h_bias=mx.nd.zeros((150,))) + outputs = outputs.eval(**{'t0_data':mx.nd.ones((10, 50)), + 't1_data':mx.nd.ones((10, 50)), + params['base_cell.i2h_weight'].name:mx.nd.zeros((150, 50)), + params['base_cell.i2h_bias'].name:mx.nd.zeros((150,)), + params['base_cell.h2h_weight'].name:mx.nd.zeros((150, 50)), + params['base_cell.h2h_bias'].name:mx.nd.zeros((150,))}) expected_outputs = np.ones((10, 50)) assert np.array_equal(outputs[0].asnumpy(), expected_outputs) assert np.array_equal(outputs[1].asnumpy(), expected_outputs) @@ -214,30 +207,32 @@ def test_residual(): def test_residual_bidirectional(): cell = gluon.rnn.ResidualCell( gluon.rnn.BidirectionalCell( - gluon.rnn.GRUCell(25, prefix='rnn_l_'), - gluon.rnn.GRUCell(25, prefix='rnn_r_'))) - + gluon.rnn.GRUCell(25), + gluon.rnn.GRUCell(25))) inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(2)] outputs, _ = cell.unroll(2, inputs, merge_outputs=False) outputs = mx.sym.Group(outputs) - assert sorted(cell.collect_params().keys()) == \ - ['rnn_l_h2h_bias', 'rnn_l_h2h_weight', 'rnn_l_i2h_bias', 'rnn_l_i2h_weight', - 'rnn_r_h2h_bias', 'rnn_r_h2h_weight', 'rnn_r_i2h_bias', 'rnn_r_i2h_weight'] + params = cell.collect_params() + assert sorted(params.keys()) == \ + ['base_cell.l_cell.h2h_bias', 'base_cell.l_cell.h2h_weight', + 'base_cell.l_cell.i2h_bias', 'base_cell.l_cell.i2h_weight', + 'base_cell.r_cell.h2h_bias', 'base_cell.r_cell.h2h_weight', + 'base_cell.r_cell.i2h_bias', 'base_cell.r_cell.i2h_weight'] # assert outputs.list_outputs() == \ # ['bi_t0_plus_residual_output', 'bi_t1_plus_residual_output'] args, outs, auxs = outputs.infer_shape(rnn_t0_data=(10, 50), rnn_t1_data=(10, 50)) assert outs == [(10, 50), (10, 50)] - outputs = outputs.eval(rnn_t0_data=mx.nd.ones((10, 50))+5, - rnn_t1_data=mx.nd.ones((10, 50))+5, - rnn_l_i2h_weight=mx.nd.zeros((75, 50)), - rnn_l_i2h_bias=mx.nd.zeros((75,)), - rnn_l_h2h_weight=mx.nd.zeros((75, 25)), - rnn_l_h2h_bias=mx.nd.zeros((75,)), - rnn_r_i2h_weight=mx.nd.zeros((75, 50)), - rnn_r_i2h_bias=mx.nd.zeros((75,)), - rnn_r_h2h_weight=mx.nd.zeros((75, 25)), - rnn_r_h2h_bias=mx.nd.zeros((75,))) + outputs = outputs.eval(**{'rnn_t0_data':mx.nd.ones((10, 50))+5, + 'rnn_t1_data':mx.nd.ones((10, 50))+5, + params['base_cell.l_cell.i2h_weight'].name:mx.nd.zeros((75, 50)), + params['base_cell.l_cell.i2h_bias'].name:mx.nd.zeros((75,)), + params['base_cell.l_cell.h2h_weight'].name:mx.nd.zeros((75, 25)), + params['base_cell.l_cell.h2h_bias'].name:mx.nd.zeros((75,)), + params['base_cell.r_cell.i2h_weight'].name:mx.nd.zeros((75, 50)), + params['base_cell.r_cell.i2h_bias'].name:mx.nd.zeros((75,)), + params['base_cell.r_cell.h2h_weight'].name:mx.nd.zeros((75, 25)), + params['base_cell.r_cell.h2h_bias'].name:mx.nd.zeros((75,))}) expected_outputs = np.ones((10, 50))+5 assert np.array_equal(outputs[0].asnumpy(), expected_outputs) assert np.array_equal(outputs[1].asnumpy(), expected_outputs) @@ -247,21 +242,28 @@ def test_stack(): cell = gluon.rnn.SequentialRNNCell() for i in range(5): if i == 1: - cell.add(gluon.rnn.ResidualCell(gluon.rnn.LSTMCell(100, prefix='rnn_stack%d_' % i))) + 
cell.add(gluon.rnn.ResidualCell(gluon.rnn.LSTMCell(100, prefix='rnn_stack%d_' % i))) +
cell.add(gluon.rnn.ResidualCell(gluon.rnn.LSTMCell(100))) else: - cell.add(gluon.rnn.LSTMCell(100, prefix='rnn_stack%d_'%i)) - inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(3)] + cell.add(gluon.rnn.LSTMCell(100)) + inputs = [mx.sym.Variable('t%d_data'%i) for i in range(3)] outputs, _ = cell.unroll(3, inputs) outputs = mx.sym.Group(outputs) keys = sorted(cell.collect_params().keys()) for i in range(5): - assert 'rnn_stack%d_h2h_weight'%i in keys - assert 'rnn_stack%d_h2h_bias'%i in keys - assert 'rnn_stack%d_i2h_weight'%i in keys - assert 'rnn_stack%d_i2h_bias'%i in keys - assert outputs.list_outputs() == ['rnn_stack4_t0_out_output', 'rnn_stack4_t1_out_output', 'rnn_stack4_t2_out_output'] - - args, outs, auxs = outputs.infer_shape(rnn_t0_data=(10,50), rnn_t1_data=(10,50), rnn_t2_data=(10,50)) + if i==1: + continue + assert '%d.h2h_weight'%i in keys + assert '%d.h2h_bias'%i in keys + assert '%d.i2h_weight'%i in keys + assert '%d.i2h_bias'%i in keys + assert '1.base_cell.h2h_weight' in keys + assert '1.base_cell.h2h_bias' in keys + assert '1.base_cell.i2h_weight' in keys + assert '1.base_cell.i2h_bias' in keys + assert outputs.list_outputs() == \ + [cell[4].name + name for name in ['_t0_out_output', '_t1_out_output', '_t2_out_output']] + + args, outs, auxs = outputs.infer_shape(t0_data=(10,50), t1_data=(10,50), t2_data=(10,50)) assert outs == [(10, 100), (10, 100), (10, 100)] @@ -270,21 +272,28 @@ def test_hybridstack(): cell = gluon.rnn.HybridSequentialRNNCell() for i in range(5): if i == 1: - cell.add(gluon.rnn.ResidualCell(gluon.rnn.LSTMCell(100, prefix='rnn_stack%d_' % i))) + cell.add(gluon.rnn.ResidualCell(gluon.rnn.LSTMCell(100))) else: - cell.add(gluon.rnn.LSTMCell(100, prefix='rnn_stack%d_'%i)) - inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(3)] + cell.add(gluon.rnn.LSTMCell(100)) + inputs = [mx.sym.Variable('t%d_data'%i) for i in range(3)] outputs, _ = cell.unroll(3, inputs) outputs = mx.sym.Group(outputs) keys = sorted(cell.collect_params().keys()) for i in range(5): - assert 'rnn_stack%d_h2h_weight'%i in keys - assert 'rnn_stack%d_h2h_bias'%i in keys - assert 'rnn_stack%d_i2h_weight'%i in keys - assert 'rnn_stack%d_i2h_bias'%i in keys - assert outputs.list_outputs() == ['rnn_stack4_t0_out_output', 'rnn_stack4_t1_out_output', 'rnn_stack4_t2_out_output'] - - args, outs, auxs = outputs.infer_shape(rnn_t0_data=(10,50), rnn_t1_data=(10,50), rnn_t2_data=(10,50)) + if i==1: + continue + assert '%d.h2h_weight'%i in keys + assert '%d.h2h_bias'%i in keys + assert '%d.i2h_weight'%i in keys + assert '%d.i2h_bias'%i in keys + assert '1.base_cell.h2h_weight' in keys + assert '1.base_cell.h2h_bias' in keys + assert '1.base_cell.i2h_weight' in keys + assert '1.base_cell.i2h_bias' in keys + assert outputs.list_outputs() == \ + [cell[4].name + name for name in ['_t0_out_output', '_t1_out_output', '_t2_out_output']] + + args, outs, auxs = outputs.infer_shape(t0_data=(10,50), t1_data=(10,50), t2_data=(10,50)) assert outs == [(10, 100), (10, 100), (10, 100)] # Test HybridSequentialRNNCell nested in nn.HybridBlock, SequentialRNNCell will fail in this case @@ -292,23 +301,22 @@ class BidirectionalOfSequential(gluon.HybridBlock): def __init__(self): super(BidirectionalOfSequential, self).__init__() - with self.name_scope(): - cell0 = gluon.rnn.HybridSequentialRNNCell() - cell0.add(gluon.rnn.LSTMCell(100)) - cell0.add(gluon.rnn.LSTMCell(100)) + cell0 = gluon.rnn.HybridSequentialRNNCell() + cell0.add(gluon.rnn.LSTMCell(100)) + cell0.add(gluon.rnn.LSTMCell(100)) - cell1 = 
gluon.rnn.HybridSequentialRNNCell() - cell1.add(gluon.rnn.LSTMCell(100)) - cell1.add(gluon.rnn.LSTMCell(100)) + cell1 = gluon.rnn.HybridSequentialRNNCell() + cell1.add(gluon.rnn.LSTMCell(100)) + cell1.add(gluon.rnn.LSTMCell(100)) - self.rnncell = gluon.rnn.BidirectionalCell(cell0, cell1) + self.rnncell = gluon.rnn.BidirectionalCell(cell0, cell1) def hybrid_forward(self, F, x): return self.rnncell.unroll(3, x, layout="NTC", merge_outputs=True) x = mx.nd.random.uniform(shape=(10, 3, 100)) net = BidirectionalOfSequential() - net.collect_params().initialize() + net.initialize() outs, _ = net(x) assert outs.shape == (10, 3, 200) @@ -316,15 +324,14 @@ def hybrid_forward(self, F, x): def test_bidirectional(): cell = gluon.rnn.BidirectionalCell( - gluon.rnn.LSTMCell(100, prefix='rnn_l0_'), - gluon.rnn.LSTMCell(100, prefix='rnn_r0_'), - output_prefix='rnn_bi_') - inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(3)] + gluon.rnn.LSTMCell(100), + gluon.rnn.LSTMCell(100)) + inputs = [mx.sym.Variable('t%d_data'%i) for i in range(3)] outputs, _ = cell.unroll(3, inputs) outputs = mx.sym.Group(outputs) - assert outputs.list_outputs() == ['rnn_bi_t0_output', 'rnn_bi_t1_output', 'rnn_bi_t2_output'] + assert outputs.list_outputs() == ['t0_output', 't1_output', 't2_output'] - args, outs, auxs = outputs.infer_shape(rnn_t0_data=(10,50), rnn_t1_data=(10,50), rnn_t2_data=(10,50)) + args, outs, auxs = outputs.infer_shape(t0_data=(10,50), t1_data=(10,50), t2_data=(10,50)) assert outs == [(10, 200), (10, 200), (10, 200)] @@ -335,9 +342,8 @@ def test_layer_bidirectional(): class RefBiLSTM(gluon.Block): def __init__(self, size, **kwargs): super(RefBiLSTM, self).__init__(**kwargs) - with self.name_scope(): - self._lstm_fwd = gluon.rnn.LSTM(size, bidirectional=False, prefix='l0') - self._lstm_bwd = gluon.rnn.LSTM(size, bidirectional=False, prefix='r0') + self._lstm_fwd = gluon.rnn.LSTM(size, bidirectional=False) + self._lstm_bwd = gluon.rnn.LSTM(size, bidirectional=False) def forward(self, inpt): fwd = self._lstm_fwd(inpt) @@ -350,20 +356,20 @@ def forward(self, inpt): in_size = 5 weights = {} for d in ['l', 'r']: - weights['lstm_{}0_i2h_weight'.format(d)] = mx.random.uniform(shape=(size*4, in_size)) - weights['lstm_{}0_h2h_weight'.format(d)] = mx.random.uniform(shape=(size*4, size)) - weights['lstm_{}0_i2h_bias'.format(d)] = mx.random.uniform(shape=(size*4,)) - weights['lstm_{}0_h2h_bias'.format(d)] = mx.random.uniform(shape=(size*4,)) + weights['{}0_i2h_weight'.format(d)] = mx.random.uniform(shape=(size*4, in_size)) + weights['{}0_h2h_weight'.format(d)] = mx.random.uniform(shape=(size*4, size)) + weights['{}0_i2h_bias'.format(d)] = mx.random.uniform(shape=(size*4,)) + weights['{}0_h2h_bias'.format(d)] = mx.random.uniform(shape=(size*4,)) - net = gluon.rnn.LSTM(size, bidirectional=True, prefix='lstm_') - ref_net = RefBiLSTM(size, prefix='lstm_') + net = gluon.rnn.LSTM(size, bidirectional=True) + ref_net = RefBiLSTM(size) net.initialize() ref_net.initialize() net_params = net.collect_params() ref_net_params = ref_net.collect_params() for k in weights: net_params[k].set_data(weights[k]) - ref_net_params[k.replace('l0', 'l0l0').replace('r0', 'r0l0')].set_data(weights[k]) + ref_net_params[k.replace('l0', '_lstm_fwd.l0').replace('r0', '_lstm_bwd.l0')].set_data(weights[k]) data = mx.random.uniform(shape=(11, 10, in_size)) assert_allclose(net(data).asnumpy(), ref_net(data).asnumpy(), rtol=1e-04, atol=1e-02) @@ -371,8 +377,8 @@ def forward(self, inpt): def test_zoneout(): - cell = 
gluon.rnn.ZoneoutCell(gluon.rnn.RNNCell(100, prefix='rnn_'), zoneout_outputs=0.5, - zoneout_states=0.5) + cell = gluon.rnn.ZoneoutCell(gluon.rnn.RNNCell(100), zoneout_outputs=0.5, + zoneout_states=0.5) inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(3)] outputs, _ = cell.unroll(3, inputs) outputs = mx.sym.Group(outputs) @@ -386,10 +392,10 @@ def test_unroll_layout(): cell = gluon.rnn.HybridSequentialRNNCell() for i in range(5): if i == 1: - cell.add(gluon.rnn.ResidualCell(gluon.rnn.LSTMCell(100, prefix='rnn_stack%d_' % i))) + cell.add(gluon.rnn.ResidualCell(gluon.rnn.LSTMCell(100))) else: - cell.add(gluon.rnn.LSTMCell(100, prefix='rnn_stack%d_'%i)) - cell.collect_params().initialize() + cell.add(gluon.rnn.LSTMCell(100)) + cell.initialize() inputs = [mx.nd.random.uniform(shape=(10,50)) for _ in range(3)] outputs, _ = cell.unroll(3, inputs, layout='TNC') assert outputs[0].shape == (10, 100) @@ -493,8 +499,7 @@ def test_rnn_cells_export_import(): class RNNLayer(gluon.HybridBlock): def __init__(self): super(RNNLayer, self).__init__() - with self.name_scope(): - self.cell = gluon.rnn.RNNCell(hidden_size=1) + self.cell = gluon.rnn.RNNCell(hidden_size=1) def hybrid_forward(self, F, seq): outputs, state = self.cell.unroll(inputs=seq, length=2, merge_outputs=True) @@ -503,8 +508,7 @@ def hybrid_forward(self, F, seq): class LSTMLayer(gluon.HybridBlock): def __init__(self): super(LSTMLayer, self).__init__() - with self.name_scope(): - self.cell = gluon.rnn.LSTMCell(hidden_size=1) + self.cell = gluon.rnn.LSTMCell(hidden_size=1) def hybrid_forward(self, F, seq): outputs, state = self.cell.unroll(inputs=seq, length=2, merge_outputs=True) @@ -513,8 +517,7 @@ def hybrid_forward(self, F, seq): class GRULayer(gluon.HybridBlock): def __init__(self): super(GRULayer, self).__init__() - with self.name_scope(): - self.cell = gluon.rnn.GRUCell(hidden_size=1) + self.cell = gluon.rnn.GRUCell(hidden_size=1) def hybrid_forward(self, F, seq): outputs, state = self.cell.unroll(inputs=seq, length=2, merge_outputs=True) @@ -537,7 +540,7 @@ def hybrid_forward(self, F, seq): def check_rnn_layer_forward(layer, inputs, states=None, run_only=False, ctx=mx.cpu()): - layer.collect_params().initialize(ctx=ctx) + layer.initialize(ctx=ctx) inputs = inputs.as_in_context(ctx) inputs.attach_grad() if states is not None: @@ -612,7 +615,7 @@ def run_rnn_layers(dtype, dtype2, ctx=mx.cpu()): net.add(gluon.nn.BatchNorm(axis=2)) net.add(gluon.nn.Flatten()) net.add(gluon.nn.Dense(3, activation='relu')) - net.collect_params().initialize(ctx=ctx) + net.initialize(ctx=ctx) net.cast(dtype) with mx.autograd.record(): out = net(mx.nd.ones((2, 3, 10), dtype=dtype, ctx=ctx)) @@ -625,7 +628,7 @@ def run_rnn_layers(dtype, dtype2, ctx=mx.cpu()): net2.add(gluon.nn.Flatten()) net2.add(gluon.nn.Dense(3, activation='relu')) net2.hybridize() - net2.collect_params().initialize(ctx=ctx) + net2.initialize(ctx=ctx) net2.cast(dtype) with mx.autograd.record(): out = net2(mx.nd.ones((2, 3, 10), dtype=dtype, ctx=ctx)) @@ -638,7 +641,7 @@ def run_rnn_layers(dtype, dtype2, ctx=mx.cpu()): net3.add(gluon.nn.Flatten()) net3.add(gluon.nn.Dense(3, activation='relu')) net3.hybridize() - net3.collect_params().initialize(ctx=ctx) + net3.initialize(ctx=ctx) net3.cast(dtype2) with mx.autograd.record(): out = net3(mx.nd.ones((2, 3, 10), dtype=dtype2, ctx=ctx)) @@ -665,12 +668,15 @@ def check_rnn_consistency(fused_layer, stack_layer, loss, input_size, hidden_siz stack_layer_params = stack_layer.collect_params() for name, value in fused_layer_params.items(): - if 
'rnn' in fused_layer.prefix and 'weight' in name: + if 'weight' in name: w = mx.nd.zeros(shape=value.shape) else: w = mx.nd.random.normal(shape=value.shape) value.set_data(w.copy()) - stack_layer_params[name].set_data(w.copy()) + cur = name.split('_')[0] + num = cur[1:] + stack_name = ('{}.{}_cell.'.format(num, name[0]) if bidirectional else num + '.' ) + name[len(cur)+1:] + stack_layer_params[stack_name].set_data(w.copy()) fx = x.copy() sx = x.copy() @@ -694,8 +700,11 @@ def check_rnn_consistency(fused_layer, stack_layer, loss, input_size, hidden_siz assert_allclose(fused_out.asnumpy(), stack_out.asnumpy(), rtol=rtol, atol=atol) assert_allclose(fused_input_grad, stack_input_grad, rtol=rtol, atol=atol) - for key, value in fused_grads.items(): - assert_allclose(value.asnumpy(), stack_grads[key].asnumpy(), rtol=rtol, atol=atol) + for name, value in fused_grads.items(): + cur = name.split('_')[0] + num = cur[1:] + stack_name = ('{}.{}_cell.'.format(num, name[0]) if bidirectional else num + '.' ) + name[len(cur)+1:] + assert_allclose(value.asnumpy(), stack_grads[stack_name].asnumpy(), rtol=rtol, atol=atol) num_layers = fused_begin_state[0].shape[0] // (2 if bidirectional else 1) check_rnn_states(fused_states, stack_states, num_layers, bidirectional, len(fused_begin_state) == 2) @@ -725,13 +734,12 @@ def create_op_by_mode(mode): def check_rnn_unidir_layer_gradients(mode, input_size, hidden_size, num_layers, loss): fused_op, stack_op, recurrent_block_prefix = create_op_by_mode(mode) - fused_layer = fused_op(hidden_size, num_layers=num_layers, layout='NTC', bidirectional=False, prefix=recurrent_block_prefix) + fused_layer = fused_op(hidden_size, num_layers=num_layers, layout='NTC', bidirectional=False) fused_layer.initialize() - stack_layer = mx.gluon.rnn.HybridSequentialRNNCell(prefix=recurrent_block_prefix) - with stack_layer.name_scope(): - for n in range(num_layers): - stack_layer.add(stack_op(hidden_size, prefix=f'l{n}_')) + stack_layer = mx.gluon.rnn.HybridSequentialRNNCell() + for n in range(num_layers): + stack_layer.add(stack_op(hidden_size)) stack_layer.initialize() check_rnn_consistency(fused_layer, stack_layer, loss, input_size, hidden_size) @@ -739,15 +747,14 @@ def check_rnn_unidir_layer_gradients(mode, input_size, hidden_size, num_layers, def check_rnn_bidir_layer_gradients(mode, input_size, hidden_size, num_layers, loss): fused_op, stack_op, recurrent_block_prefix = create_op_by_mode(mode) - fused_layer = fused_op(hidden_size, num_layers=num_layers, layout='NTC', bidirectional=True, prefix=recurrent_block_prefix) + fused_layer = fused_op(hidden_size, num_layers=num_layers, layout='NTC', bidirectional=True) fused_layer.initialize() - stack_layer = mx.gluon.rnn.HybridSequentialRNNCell(prefix=recurrent_block_prefix) - with stack_layer.name_scope(): - for n in range(num_layers): - stack_layer.add(gluon.rnn.BidirectionalCell(stack_op(hidden_size, prefix=f'l{n}_'), - stack_op(hidden_size, prefix=f'r{n}_'))) - stack_layer.initialize() + stack_layer = mx.gluon.rnn.HybridSequentialRNNCell() + for n in range(num_layers): + stack_layer.add(gluon.rnn.BidirectionalCell(stack_op(hidden_size), + stack_op(hidden_size))) + stack_layer.initialize() check_rnn_consistency(fused_layer, stack_layer, loss, input_size, hidden_size, bidirectional=True) @@ -818,8 +825,9 @@ def test_rnn_unroll_variant_length(): valid_length = [3, 10, 5, 6] valid_length_nd = mx.nd.array(valid_length) for cell in cell_list: - cell.collect_params().initialize() + cell.initialize() cell.hybridize() + 
print(cell.collect_params()) # Test for NTC layout data_nd = mx.nd.random.normal(0, 1, shape=(batch_size, max_length, 20)) outs, states = cell.unroll(length=max_length, inputs=data_nd, @@ -882,11 +890,9 @@ class BiLSTM(gluon.nn.HybridBlock): def __init__(self, rnn_size, time_step, **kwargs): super(BiLSTM, self).__init__(**kwargs) self.time_step = time_step - with self.name_scope(): - self.bi_lstm = gluon.rnn.BidirectionalCell( - gluon.rnn.LSTMCell(rnn_size, prefix='rnn_l0_'), - gluon.rnn.LSTMCell(rnn_size, prefix='rnn_r0_'), - output_prefix='lstm_bi_') + self.bi_lstm = gluon.rnn.BidirectionalCell( + gluon.rnn.LSTMCell(rnn_size), + gluon.rnn.LSTMCell(rnn_size)) def hybrid_forward(self, F, inputs, valid_len): outputs, states = self.bi_lstm.unroll(self.time_step, inputs, valid_length=valid_len, diff --git a/tests/python/unittest/test_gluon_trainer.py b/tests/python/unittest/test_gluon_trainer.py index 874ab8c9468a..8cf78042411e 100644 --- a/tests/python/unittest/test_gluon_trainer.py +++ b/tests/python/unittest/test_gluon_trainer.py @@ -177,12 +177,11 @@ def test_trainer_multi_layer_init(): class Net(gluon.Block): def __init__(self, **kwargs): super(Net, self).__init__(**kwargs) - with self.name_scope(): - # sparse param - self.embed_weight = self.params.get('embed_weight', stype='row_sparse', - shape=(4,3), grad_stype='row_sparse') - # dense param from a hybrid block - self.dense0 = nn.Dense(2) + # sparse param + self.embed_weight = gluon.Parameter('embed_weight', stype='row_sparse', + shape=(4,3), grad_stype='row_sparse') + # dense param from a hybrid block + self.dense0 = nn.Dense(2) def forward(self, x): embed_weight = self.embed_weight.row_sparse_data(x) @@ -191,7 +190,7 @@ def forward(self, x): return self.dense0(embed) def check_init(ctxes): - net = Net(prefix='net_') + net = Net() net.initialize(mx.init.One(), ctx=ctxes) trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 1}) data = mx.nd.array([[0,2], [1,2]]) @@ -221,11 +220,11 @@ def check_init(ctxes): @with_seed() def test_trainer_reset_kv(): def check_trainer_reset_kv(kv): - params = gluon.ParameterDict() - x = params.get('x', shape=(10,), lr_mult=1.0) - params.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros') + x = gluon.Parameter('x', shape=(10,), lr_mult=1.0) + params = {'x': x} + x.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros') trainer = gluon.Trainer(params, 'sgd', {'learning_rate': 0.1}, kvstore=kv) - params.save('test_trainer_reset_kv.params') + mx.nd.save('test_trainer_reset_kv.params', {k: v._reduce() for k, v in params.items()}) with mx.autograd.record(): for w in x.list_data(): y = w + 1 @@ -234,7 +233,8 @@ def check_trainer_reset_kv(kv): assert trainer._kvstore.type == kv # load would reset kvstore mx.nd.waitall() - params.load('test_trainer_reset_kv.params') + params = mx.nd.load('test_trainer_reset_kv.params') + x._load_init(params['x'], None) if trainer._update_on_kvstore: # drop kvstore state if new parameters are loaded assert trainer._kvstore is None @@ -255,10 +255,9 @@ def check_trainer_reset_kv(kv): @with_seed() def test_trainer_sparse_kv(): def check_trainer_sparse_kv(kv, stype, grad_stype, update_on_kv, expected): - params = gluon.ParameterDict() - x = params.get('x', shape=(10,1), lr_mult=1.0, stype=stype, grad_stype=grad_stype) - params.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros') - trainer = gluon.Trainer(params, 'sgd', {'learning_rate': 0.1}, + x = mx.gluon.Parameter('x', shape=(10,1), lr_mult=1.0, stype=stype, grad_stype=grad_stype) + x.initialize(ctx=[mx.cpu(0), 
mx.cpu(1)], init='zeros') + trainer = gluon.Trainer([x], 'sgd', {'learning_rate': 0.1}, kvstore=kv, update_on_kvstore=update_on_kv) all_rows = mx.nd.arange(0, 10, ctx=mx.cpu(0)) try: @@ -337,11 +336,11 @@ def test_gluon_trainer_param_order(): layers = {'ones_': 1, 'zeros_': 0} for name, init in layers.items(): net.add(mx.gluon.nn.Dense(10, in_units=10, weight_initializer=mx.init.Constant(init), - use_bias=False, prefix=name)) - params = net.collect_params() + use_bias=False)) net.initialize() + params = net.collect_params() trainer = gluon.Trainer(params, 'sgd') for name, init in layers.items(): expected_idx = 0 if name == 'ones_' else 1 - expected_name = name + 'weight' - assert trainer._params[expected_idx].name == expected_name + expected_name = '{}.weight'.format(expected_idx) + assert trainer._params[expected_idx].name == params[expected_name].name diff --git a/tests/python/unittest/test_higher_order_grad.py b/tests/python/unittest/test_higher_order_grad.py index 1c465d43539d..ae3c33a4d9b7 100644 --- a/tests/python/unittest/test_higher_order_grad.py +++ b/tests/python/unittest/test_higher_order_grad.py @@ -628,8 +628,7 @@ def test_dense_backward_flatten(): for x in NDArrayGenerator(4,2): hidden = random.randrange(1, 4) net = gluon.nn.Sequential() - with net.name_scope(): - net.add(gluon.nn.Dense(hidden, flatten=True)) + net.add(gluon.nn.Dense(hidden, flatten=True)) net.initialize(mxnet.initializer.Constant(.5)) x.attach_grad() with autograd.record(): @@ -673,8 +672,7 @@ def test_dense_backward_no_flatten(): for x in NDArrayGenerator(5,3): hidden = random.randrange(1, 4) net = gluon.nn.Sequential() - with net.name_scope(): - net.add(gluon.nn.Dense(hidden, flatten=False)) + net.add(gluon.nn.Dense(hidden, flatten=False)) net.initialize(mxnet.initializer.Constant(.5)) x.attach_grad() with autograd.record(): diff --git a/tests/python/unittest/test_loss.py b/tests/python/unittest/test_loss.py index 5e6c0f798d9e..9c1d496d715e 100644 --- a/tests/python/unittest/test_loss.py +++ b/tests/python/unittest/test_loss.py @@ -130,7 +130,7 @@ def test_sdml_loss(): # Init model and trainer sdml_loss = gluon.loss.SDMLLoss() model = gluon.nn.Dense(DIM, activation='tanh') # Simple NN encoder - model.collect_params().initialize(mx.init.Xavier(), ctx=mx.current_context()) + model.initialize(mx.init.Xavier(), ctx=mx.current_context()) trainer = gluon.Trainer(model.collect_params(), 'adam', {'learning_rate' : 0.1}) for i in range(EPOCHS): # Training loop diff --git a/tests/python/unittest/test_numpy_gluon.py b/tests/python/unittest/test_numpy_gluon.py index d2021666a908..a3adad68f985 100644 --- a/tests/python/unittest/test_numpy_gluon.py +++ b/tests/python/unittest/test_numpy_gluon.py @@ -47,8 +47,7 @@ def check_block_params(x, TestBlock, hybridize, expected_type, initializer): class TestBlock1(gluon.HybridBlock): def __init__(self): super(TestBlock1, self).__init__() - with self.name_scope(): - self.w = self.params.get('w', shape=(K, N), allow_deferred_init=True) + self.w = gluon.Parameter('w', shape=(K, N), allow_deferred_init=True) def hybrid_forward(self, F, x, w): return F.dot(x, w) @@ -57,8 +56,7 @@ def hybrid_forward(self, F, x, w): class TestBlock2(gluon.HybridBlock): def __init__(self): super(TestBlock2, self).__init__() - with self.name_scope(): - self.w = self.params.get('w', shape=(K, N), allow_deferred_init=True) + self.w = gluon.Parameter('w', shape=(K, N), allow_deferred_init=True) def hybrid_forward(self, F, x, w): return F.np.dot(x, w) @@ -77,11 +75,10 @@ def 
test_optimizer_with_np_ndarrays(): class LinearRegression(gluon.HybridBlock): def __init__(self, num_input_dim=0, num_hidden_dim=100, num_output_dim=10): super(LinearRegression, self).__init__() - with self.name_scope(): - self.w1 = self.params.get('w1', shape=(num_input_dim, num_hidden_dim), - allow_deferred_init=True) - self.w2 = self.params.get('w2', shape=(num_hidden_dim, num_output_dim), - allow_deferred_init=True) + self.w1 = gluon.Parameter('w1', shape=(num_input_dim, num_hidden_dim), + allow_deferred_init=True) + self.w2 = gluon.Parameter('w2', shape=(num_hidden_dim, num_output_dim), + allow_deferred_init=True) def hybrid_forward(self, F, x, w1, w2): h = x.dot(w1) # equivalent to F.np.dot(x, w1) @@ -166,9 +163,9 @@ def test_np_get_constant(): const_arr = _np.random.uniform(0, 100, size=(10, 10)).astype(_np.float32) class Foo(gluon.HybridBlock): - def __init__(self, prefix=None, params=None): - super(Foo, self).__init__(prefix=prefix, params=params) - self.weight = self.params.get_constant('const', const_arr) + def __init__(self): + super(Foo, self).__init__() + self.weight = gluon.Constant(const_arr) def hybrid_forward(self, F, x, weight): return x + weight.astype(np.float32) @@ -195,7 +192,7 @@ def test_parameters_zero_grad(): out = net(mx.np.ones((32, 8))) for v in net.collect_params().values(): v.grad()[()] = 1 - net.collect_params().zero_grad() + net.zero_grad() for v in net.collect_params().values(): assert_almost_equal(v.grad().asnumpy(), mx.np.zeros_like(v.grad()).asnumpy()) @@ -328,10 +325,9 @@ def hybrid_forward(self, F, x): return x class TestSlicingWithSplit2(gluon.HybridBlock): - def __init__(self, prefix=None, params=None): - super(TestSlicingWithSplit2, self).__init__(prefix=prefix, params=params) - with self.name_scope(): - self.layer = gluon.nn.Dense(16, flatten=False, params=params) + def __init__(self): + super(TestSlicingWithSplit2, self).__init__() + self.layer = gluon.nn.Dense(16, flatten=False) def hybrid_forward(self, F, x, y): x = F.np.split(x, 1) @@ -384,10 +380,9 @@ def hybrid_forward(self, F, x): @use_np def test_net_symbol_save_load(): class Case1(gluon.HybridBlock): - def __init__(self, prefix=None, params=None): - super(Case1, self).__init__(prefix=prefix, params=params) - with self.name_scope(): - self.layer = gluon.nn.Dense(64, flatten=False, params=params) + def __init__(self): + super(Case1, self).__init__() + self.layer = gluon.nn.Dense(64, flatten=False) def hybrid_forward(self, F, x, y): x = F.np.split(x, 1) @@ -397,11 +392,10 @@ def hybrid_forward(self, F, x, y): mx.np.random.normal(0, 1, (10, 5, 8, 6))]) class Case2(gluon.HybridBlock): - def __init__(self, prefix=None, params=None): - super(Case2, self).__init__(prefix=prefix, params=params) - with self.name_scope(): - self.layer1 = gluon.nn.Dense(64, flatten=False, params=params) - self.layer2 = gluon.nn.Dense(64, flatten=False, params=params) + def __init__(self): + super(Case2, self).__init__() + self.layer1 = gluon.nn.Dense(64, flatten=False) + self.layer2 = gluon.nn.Dense(64, flatten=False) def hybrid_forward(self, F, x, y): x = F.np.split(x, 1) @@ -415,8 +409,8 @@ def hybrid_forward(self, F, x, y): @use_np def test_hybridize_boolean_dtype(): class Foo(gluon.HybridBlock): - def __init__(self, prefix=None, params=None): - super(Foo, self).__init__(prefix=prefix, params=params) + def __init__(self): + super(Foo, self).__init__() def hybrid_forward(self, F, valid_length): mask = ((F.np.ones((10,)) / 2) < valid_length) diff --git a/tests/python/unittest/test_profiler.py 
b/tests/python/unittest/test_profiler.py index e026d57a4402..2bdfb0c7f9d2 100644 --- a/tests/python/unittest/test_profiler.py +++ b/tests/python/unittest/test_profiler.py @@ -539,13 +539,12 @@ def test_gpu_memory_profiler_gluon(): run=True, continuous_dump=True) profiler.set_state('run') - model = nn.HybridSequential(prefix='net_') - with model.name_scope(): - model.add(nn.Dense(128, activation='tanh')) - model.add(nn.Dropout(0.5)) - model.add(nn.Dense(64, activation='tanh'), - nn.Dense(32, in_units=64)) - model.add(nn.Activation('relu')) + model = nn.HybridSequential() + model.add(nn.Dense(128, activation='tanh')) + model.add(nn.Dropout(0.5)) + model.add(nn.Dense(64, activation='tanh'), + nn.Dense(32, in_units=64)) + model.add(nn.Activation('relu')) model.initialize(ctx=mx.gpu()) model.hybridize() @@ -558,42 +557,15 @@ def test_gpu_memory_profiler_gluon(): profiler.set_state('stop') profiler.dump(True) - # Sample gpu_memory_profiler.csv - # "Attribute Name","Requested Size","Device","Actual Size","Reuse?" - # ":in_arg:data","640","0","4096","0" - # "net:arg_grad:net_dense0_bias","512","0","4096","0" - # "net:arg_grad:net_dense0_weight","5120","0","8192","0" - # "net:arg_grad:net_dense1_bias","256","0","4096","0" - # "net:arg_grad:net_dense1_weight","32768","0","32768","0" - # "net:arg_grad:net_dense2_bias","128","0","4096","0" - # "net:arg_grad:net_dense2_weight","8192","0","8192","0" - # "net:dense0:net_dense0_fwd","8192","0","8192","0" - # "net:dense0:tanh:net_dense0_tanh_fwd","8192","0","8192","0" - # "net:dense1:net_dense1_fwd","4096","0","4096","0" - # "net:dense1:tanh:net_dense1_tanh_fwd","4096","0","4096","0" - # "net:dense2:net_dense2_fwd","2048","0","4096","0" - # "net:dense2:net_dense2_fwd_backward","4096","0","4096","0" - # "net:dropout0:net_dropout0_fwd","8192","0","8192","0" - # "net:dropout0:net_dropout0_fwd","8192","0","8192","0" - # "net:in_arg:net_dense0_bias","512","0","4096","0" - # "net:in_arg:net_dense0_weight","5120","0","8192","0" - # "net:in_arg:net_dense1_bias","256","0","4096","0" - # "net:in_arg:net_dense1_weight","32768","0","32768","0" - # "net:in_arg:net_dense2_bias","128","0","4096","0" - # "net:in_arg:net_dense2_weight","8192","0","8192","0" - # "net:relu0:net_relu0_fwd","2048","0","4096","0" - # "net:relu0:net_relu0_fwd_backward","8192","0","8192","0" - # "net:relu0:net_relu0_fwd_head_grad","2048","0","4096","0" - # "resource:cudnn_dropout_state (dropout-inl.h +258)","1671168","0","1671168","0" - # "resource:temp_space (fully_connected-inl.h +316)","34816","0","36864","0" - # We are only checking for weight parameters here, also making sure that # there is no unknown entries in the memory profile. 
     with open('gpu_memory_profile-pid_%d.csv' % (os.getpid()), mode='r') as csv_file:
         csv_reader = csv.DictReader(csv_file)
+        for row in csv_reader:
+            print(",".join(list(row.values())))
         for scope in ['in_arg', 'arg_grad']:
             for key, nd in model.collect_params().items():
-                expected_arg_name = "net:%s:" % scope + key
+                expected_arg_name = "%s:%s:" % (model.name, scope) + nd.name
                 expected_arg_size = str(4 * np.prod(nd.shape))
                 csv_file.seek(0)
                 entry_found = False
diff --git a/tests/python/unittest/test_subgraph_op.py b/tests/python/unittest/test_subgraph_op.py
index d27295e584e0..bb693f57f415 100644
--- a/tests/python/unittest/test_subgraph_op.py
+++ b/tests/python/unittest/test_subgraph_op.py
@@ -423,7 +423,7 @@ def get_net():
     # regular inference
     x = nd.random.normal(shape=(1, 512),ctx=mx.current_context())
     net = get_net()
-    net.collect_params().initialize(ctx=mx.current_context())
+    net.initialize(ctx=mx.current_context())
     outputs1 = net(x)
     param_path = os.path.join(str(tmpdir), 'test_subgraph_backend_gluon_ext1.params')
     net.save_parameters(param_path)
@@ -450,10 +450,9 @@ def test_subgraph_backend_gluon_ext2(tmpdir):
     class Net(gluon.HybridBlock):
         def __init__(self, **kwargs):
             super(Net, self).__init__(**kwargs)
-            with self.name_scope():
-                self.fc1 = nn.Dense(256)
-                self.fc2 = nn.Dense(128)
-                self.fc3 = nn.Dense(2)
+            self.fc1 = nn.Dense(256)
+            self.fc2 = nn.Dense(128)
+            self.fc3 = nn.Dense(2)
 
         def hybrid_forward(self, F, x):
             x = F.relu(self.fc1(x))
@@ -462,7 +461,7 @@ def hybrid_forward(self, F, x):
     # regular inference
     x = nd.random.normal(shape=(1, 512),ctx=mx.current_context())
     net = Net()
-    net.collect_params().initialize(ctx=mx.current_context())
+    net.initialize(ctx=mx.current_context())
     outputs1 = net(x)
     param_path = os.path.join(str(tmpdir), 'test_subgraph_backend_gluon_ext2.params')
     net.save_parameters(param_path)
diff --git a/tests/python/unittest/test_thread_local.py b/tests/python/unittest/test_thread_local.py
index 975ad2a34873..41c7ff5f49c6 100644
--- a/tests/python/unittest/test_thread_local.py
+++ b/tests/python/unittest/test_thread_local.py
@@ -118,14 +118,14 @@ def g():
 def test_blockscope():
     class dummy_block(object):
         def __init__(self, prefix):
-            self.prefix = prefix
+            self.name = prefix
             self._empty_prefix = False
             self._profiler_scope_name = ':'
     blockscope_list = []
     status = [False]
     event = threading.Event()
     def f():
-        net = dummy_block("spawned_") # BlockScope only keeps a weakref to the Block
+        net = dummy_block("spawned") # BlockScope only keeps a weakref to the Block
         with block._BlockScope(net):
             x = NameManager.current.get(None, "hello")
             event.wait()
@@ -133,7 +133,6 @@ def f():
             status[0] = True
     thread = threading.Thread(target=f)
     thread.start()
-    block._BlockScope.create("main_thread", None, "hi")
     event.set()
     thread.join()
     event.clear()
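
The test updates above all exercise the same small set of prefix-free Gluon idioms: child blocks are keyed by their attribute names with no name_scope()/prefix, Parameter objects are constructed directly instead of through self.params.get(), blocks are initialized with net.initialize(), collect_params() returns dot-separated structured keys, and weight sharing goes through share_parameters() rather than a params= constructor argument. The following minimal sketch is illustrative only; it assumes an MXNet build that already contains this patch, and Model together with its shapes is made up for the example:

import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn

class Model(gluon.HybridBlock):
    def __init__(self, **kwargs):
        super(Model, self).__init__(**kwargs)
        # children are registered under their attribute names; no name_scope()/prefix
        self.dense0 = nn.Dense(20)
        self.dense1 = nn.Dense(20)
        # parameters are plain gluon.Parameter objects instead of self.params.get(...)
        self.w = gluon.Parameter('w', shape=(20, 20))

    def hybrid_forward(self, F, x, w):
        # registered parameters are still passed to hybrid_forward by attribute name
        return F.dot(F.relu(self.dense1(F.relu(self.dense0(x)))), w)

net = Model()
net.initialize()                     # replaces net.collect_params().initialize()
net(mx.nd.ones((2, 30)))
print(sorted(net.collect_params()))  # structured keys such as 'dense0.weight' and 'w'

# sharing weights now goes through share_parameters() instead of params=
net2 = Model().share_parameters(net.collect_params())
assert net2.w is net.w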
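
The test_lstmp and check_rnn_consistency changes above inline the same translation between fused-layer parameter names such as 'l0_i2h_weight' or 'r1_h2h_bias' and the structured keys of the equivalent stack of cells. A standalone sketch of that mapping follows; the helper name and the sample strings are illustrative, not part of the patch:

def fused_to_stacked(name, bidirectional=False):
    # 'l0_i2h_weight' -> direction 'l', layer index '0', suffix 'i2h_weight'
    cur = name.split('_')[0]
    num = cur[1:]
    suffix = name[len(cur) + 1:]
    if bidirectional:
        # each stacked layer is a BidirectionalCell wrapping an l_cell and an r_cell
        return '{}.{}_cell.{}'.format(num, name[0], suffix)
    # unidirectional stacks simply use the layer index as the child key
    return '{}.{}'.format(num, suffix)

assert fused_to_stacked('l0_i2h_weight') == '0.i2h_weight'
assert fused_to_stacked('l0_i2h_weight', bidirectional=True) == '0.l_cell.i2h_weight'
assert fused_to_stacked('r1_h2h_bias', bidirectional=True) == '1.r_cell.h2h_bias'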