This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Simplifying mxnet.gluon.block APIs #18413

Merged
42 commits merged on Jun 19, 2020
Changes from all commits
Commits (42)
a0a0862
modify gluon (not tested yet)
acphile May 20, 2020
7464668
add prefix
acphile May 20, 2020
4d3277a
add set_prefix(); fix some naming problems
acphile May 22, 2020
f441b4f
simple fix; remove output_prefix in birnn
acphile May 22, 2020
25baf9b
add name_scope for block
acphile May 25, 2020
af2deb7
replace '.' with '_' in the prefix format
acphile May 25, 2020
1641fde
modify params in Symbolblock; simple fixes according to prefix
acphile May 26, 2020
bacac4b
rough modifications for tests/ (not tested yet)
acphile May 26, 2020
a4a1946
fix wrong changes for Symbolblock
acphile May 27, 2020
2c2ca88
fix load, collect_params; partially fix tests
acphile May 28, 2020
b8f5b0d
partially fix tests
acphile May 29, 2020
2b78f6e
fix some parts
acphile Jun 1, 2020
70368ad
fix symbolblock
acphile Jun 2, 2020
0a35834
fix some tests
acphile Jun 2, 2020
a192a27
fix dataloader and some small issues
acphile Jun 2, 2020
9ca3d6c
fix SyncBatchNorm issue
acphile Jun 3, 2020
72cb878
merge upstream
acphile Jun 3, 2020
491e06f
fix sanity
acphile Jun 4, 2020
04a40cf
Optimize the APIs
acphile Jun 4, 2020
1c2ed49
try to remove set_prefix()
acphile Jun 8, 2020
951c63e
remove parameter name
acphile Jun 8, 2020
3d2c224
fix Constant and Symbolblock
acphile Jun 8, 2020
983b37a
optimize save, load, export, imports
acphile Jun 9, 2020
71fccaf
improve Symbolblock
acphile Jun 9, 2020
567e4ca
use '.' as the linking character
acphile Jun 9, 2020
3c9ce5f
revert 'remove parameter name'
acphile Jun 9, 2020
c69bfcb
fix Parameter
acphile Jun 9, 2020
1899641
fix some bugs (Predictor has not been fixed)
acphile Jun 10, 2020
d2473eb
revert some bad changes
acphile Jun 10, 2020
c6ae5ec
fix some tests
acphile Jun 10, 2020
da91b40
resolve conflicts
acphile Jun 11, 2020
4e95940
fix sanity
acphile Jun 11, 2020
b0dabbf
resolve conflicts
acphile Jun 11, 2020
dae9f70
fix test_profiler
acphile Jun 11, 2020
8794676
revert unexpected typechange
acphile Jun 12, 2020
ace7d9e
resolve conflicts
acphile Jun 18, 2020
2946f31
remove ParameterDict
acphile Jun 18, 2020
892e55a
a simple fix
acphile Jun 18, 2020
558c004
remove filename parameter in load_dict
acphile Jun 18, 2020
a14cc1f
fix sanity
acphile Jun 18, 2020
be6b646
Merge remote-tracking branch 'upstream/master' into phile_block
acphile Jun 19, 2020
49302bb
remove prefix in mxnet_export_test
acphile Jun 19, 2020
2 changes: 1 addition & 1 deletion example/gluon/style_transfer/main.py
@@ -24,7 +24,7 @@
from PIL import Image

from mxnet import autograd, gluon
from mxnet.gluon import nn, Block, HybridBlock, Parameter, ParameterDict
from mxnet.gluon import nn, Block, HybridBlock, Parameter
import mxnet.ndarray as F

import net
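This import change reflects the removal of ParameterDict in this PR: downstream scripts now work with plain Parameter objects and ordinary dicts. A minimal sketch of the post-PR pattern (the Dense layer and its sizes are purely for illustration):

```python
from mxnet.gluon import nn, Parameter

net = nn.Dense(16, in_units=8)   # hypothetical layer for illustration
net.initialize()

# collect_params() now returns an ordinary dict of Parameter objects,
# so no ParameterDict import is needed.
for name, param in net.collect_params().items():
    print(name, param.shape)
```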
5 changes: 3 additions & 2 deletions python/mxnet/contrib/amp/amp.py
@@ -686,7 +686,8 @@ def convert_hybrid_block(block, target_dtype="float16", target_dtype_ops=None,
# If dtype for the param was set in the json, cast the
# param to this dtype
attr_dict = converted_sym.attr_dict()
for name, param in block.collect_params().items():
for param in block.collect_params().values():
name = param.name
if name in arg_names:
arg_dict['arg:%s'%name] = param._reduce()
if name in attr_dict and "__dtype__" in attr_dict[name]:
@@ -719,7 +720,7 @@ def convert_hybrid_block(block, target_dtype="float16", target_dtype_ops=None,
if aux_param_name in arg_dict and param.dtype != arg_dict[aux_param_name].dtype:
param.cast(arg_dict[aux_param_name].dtype)

ret.collect_params().load_dict(arg_dict, ctx=ctx)
ret.load_dict(arg_dict, ctx=ctx)
return ret

def list_lp16_ops(target_dtype):
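The amp hunks above show the new idiom: the loop iterates over Parameter objects and reads param.name explicitly, and the converted block loads the casted arrays through Block.load_dict rather than through a ParameterDict. A rough sketch of that pattern on a toy block (the block and the key handling here are assumptions for illustration):

```python
import mxnet as mx
from mxnet.gluon import nn

net = nn.Dense(4, in_units=3)   # hypothetical block for illustration
net.initialize()

# Keys of collect_params() are structural paths; param.name is the
# parameter's own name, which is why the amp loop reads it explicitly.
for param in net.collect_params().values():
    print(param.name, param.dtype)

# Build a name -> NDArray dict and load it back on the block itself.
arg_dict = {name: param._reduce() for name, param in net.collect_params().items()}
net.load_dict(arg_dict, ctx=mx.cpu())
```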
455 changes: 253 additions & 202 deletions python/mxnet/gluon/block.py

Large diffs are not rendered by default.

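block.py carries the core of this PR: the prefix/params constructor arguments, name_scope(), and self.params.get() are removed, and parameters are created directly as block attributes. A rough before/after sketch, assuming the post-PR API visible in the other hunks of this diff (the layer sizes are arbitrary):

```python
from mxnet.gluon import nn, HybridBlock, Parameter

class Model(HybridBlock):
    # Before this PR the constructor took prefix/params, children were created
    # inside `with self.name_scope():`, and parameters came from self.params.get().
    # After this PR none of that is needed:
    def __init__(self):
        super(Model, self).__init__()
        self.dense = nn.Dense(16)
        self.scale = Parameter('scale', shape=(16,), init='ones')

    def hybrid_forward(self, F, x, scale):
        # registered Parameter attributes are passed in as keyword arguments
        return F.broadcast_mul(self.dense(x), scale)

net = Model()
net.initialize()
```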
297 changes: 148 additions & 149 deletions python/mxnet/gluon/contrib/cnn/conv_layers.py
@@ -23,6 +23,7 @@

from .... import symbol
from ...block import HybridBlock
from ...parameter import Parameter
from ....base import numeric_types
from ...nn import Activation

@@ -103,80 +104,79 @@ def __init__(self, channels, kernel_size=(1, 1), strides=(1, 1), padding=(0, 0),
num_deformable_group=1, layout='NCHW', use_bias=True, in_channels=0, activation=None,
weight_initializer=None, bias_initializer='zeros',
offset_weight_initializer='zeros', offset_bias_initializer='zeros', offset_use_bias=True,
op_name='DeformableConvolution', adj=None, prefix=None, params=None):
super(DeformableConvolution, self).__init__(prefix=prefix, params=params)
with self.name_scope():
self._channels = channels
self._in_channels = in_channels

assert layout in ('NCHW', 'NHWC'), "Only supports 'NCHW' and 'NHWC' layout for now"
if isinstance(kernel_size, numeric_types):
kernel_size = (kernel_size,) * 2
if isinstance(strides, numeric_types):
strides = (strides,) * len(kernel_size)
if isinstance(padding, numeric_types):
padding = (padding,) * len(kernel_size)
if isinstance(dilation, numeric_types):
dilation = (dilation,) * len(kernel_size)
self._op_name = op_name

offset_channels = 2 * kernel_size[0] * kernel_size[1] * num_deformable_group
self._kwargs_offset = {
'kernel': kernel_size, 'stride': strides, 'dilate': dilation,
'pad': padding, 'num_filter': offset_channels, 'num_group': groups,
'no_bias': not offset_use_bias, 'layout': layout}

self._kwargs_deformable_conv = {
'kernel': kernel_size, 'stride': strides, 'dilate': dilation,
'pad': padding, 'num_filter': channels, 'num_group': groups,
'num_deformable_group': num_deformable_group,
'no_bias': not use_bias, 'layout': layout}

if adj:
self._kwargs_offset['adj'] = adj
self._kwargs_deformable_conv['adj'] = adj

dshape = [0] * (len(kernel_size) + 2)
dshape[layout.find('N')] = 1
dshape[layout.find('C')] = in_channels

op = getattr(symbol, 'Convolution')
offset = op(symbol.var('data', shape=dshape), **self._kwargs_offset)

offsetshapes = offset.infer_shape_partial()[0]

self.offset_weight = self.params.get('offset_weight', shape=offsetshapes[1],
init=offset_weight_initializer,
allow_deferred_init=True)

if offset_use_bias:
self.offset_bias = self.params.get('offset_bias', shape=offsetshapes[2],
init=offset_bias_initializer,
allow_deferred_init=True)
else:
self.offset_bias = None

deformable_conv_weight_shape = [0] * (len(kernel_size) + 2)
deformable_conv_weight_shape[0] = channels
deformable_conv_weight_shape[2] = kernel_size[0]
deformable_conv_weight_shape[3] = kernel_size[1]

self.deformable_conv_weight = self.params.get('deformable_conv_weight',
shape=deformable_conv_weight_shape,
init=weight_initializer,
allow_deferred_init=True)

if use_bias:
self.deformable_conv_bias = self.params.get('deformable_conv_bias', shape=(channels,),
init=bias_initializer,
allow_deferred_init=True)
else:
self.deformable_conv_bias = None

if activation:
self.act = Activation(activation, prefix=activation + '_')
else:
self.act = None
op_name='DeformableConvolution', adj=None):
super(DeformableConvolution, self).__init__()
self._channels = channels
self._in_channels = in_channels

assert layout in ('NCHW', 'NHWC'), "Only supports 'NCHW' and 'NHWC' layout for now"
if isinstance(kernel_size, numeric_types):
kernel_size = (kernel_size,) * 2
if isinstance(strides, numeric_types):
strides = (strides,) * len(kernel_size)
if isinstance(padding, numeric_types):
padding = (padding,) * len(kernel_size)
if isinstance(dilation, numeric_types):
dilation = (dilation,) * len(kernel_size)
self._op_name = op_name

offset_channels = 2 * kernel_size[0] * kernel_size[1] * num_deformable_group
self._kwargs_offset = {
'kernel': kernel_size, 'stride': strides, 'dilate': dilation,
'pad': padding, 'num_filter': offset_channels, 'num_group': groups,
'no_bias': not offset_use_bias, 'layout': layout}

self._kwargs_deformable_conv = {
'kernel': kernel_size, 'stride': strides, 'dilate': dilation,
'pad': padding, 'num_filter': channels, 'num_group': groups,
'num_deformable_group': num_deformable_group,
'no_bias': not use_bias, 'layout': layout}

if adj:
self._kwargs_offset['adj'] = adj
self._kwargs_deformable_conv['adj'] = adj

dshape = [0] * (len(kernel_size) + 2)
dshape[layout.find('N')] = 1
dshape[layout.find('C')] = in_channels

op = getattr(symbol, 'Convolution')
offset = op(symbol.var('data', shape=dshape), **self._kwargs_offset)

offsetshapes = offset.infer_shape_partial()[0]

self.offset_weight = Parameter('offset_weight', shape=offsetshapes[1],
init=offset_weight_initializer,
allow_deferred_init=True)

if offset_use_bias:
self.offset_bias = Parameter('offset_bias', shape=offsetshapes[2],
init=offset_bias_initializer,
allow_deferred_init=True)
else:
self.offset_bias = None

deformable_conv_weight_shape = [0] * (len(kernel_size) + 2)
deformable_conv_weight_shape[0] = channels
deformable_conv_weight_shape[2] = kernel_size[0]
deformable_conv_weight_shape[3] = kernel_size[1]

self.deformable_conv_weight = Parameter('deformable_conv_weight',
shape=deformable_conv_weight_shape,
init=weight_initializer,
allow_deferred_init=True)

if use_bias:
self.deformable_conv_bias = Parameter('deformable_conv_bias', shape=(channels,),
init=bias_initializer,
allow_deferred_init=True)
else:
self.deformable_conv_bias = None

if activation:
self.act = Activation(activation)
else:
self.act = None

def hybrid_forward(self, F, x, offset_weight, deformable_conv_weight, offset_bias=None, deformable_conv_bias=None):
if offset_bias is None:
@@ -296,81 +296,80 @@ def __init__(self, channels, kernel_size=(1, 1), strides=(1, 1), padding=(0, 0),
num_deformable_group=1, layout='NCHW', use_bias=True, in_channels=0, activation=None,
weight_initializer=None, bias_initializer='zeros',
offset_weight_initializer='zeros', offset_bias_initializer='zeros', offset_use_bias=True,
op_name='ModulatedDeformableConvolution', adj=None, prefix=None, params=None):
super(ModulatedDeformableConvolution, self).__init__(prefix=prefix, params=params)
with self.name_scope():
self._channels = channels
self._in_channels = in_channels

assert layout in ('NCHW', 'NHWC'), "Only supports 'NCHW' and 'NHWC' layout for now"
if isinstance(kernel_size, numeric_types):
kernel_size = (kernel_size,) * 2
if isinstance(strides, numeric_types):
strides = (strides,) * len(kernel_size)
if isinstance(padding, numeric_types):
padding = (padding,) * len(kernel_size)
if isinstance(dilation, numeric_types):
dilation = (dilation,) * len(kernel_size)
self._op_name = op_name

offset_channels = num_deformable_group * 3 * kernel_size[0] * kernel_size[1]
self.offset_split_index = num_deformable_group * 2 * kernel_size[0] * kernel_size[1]
self._kwargs_offset = {
'kernel': kernel_size, 'stride': strides, 'dilate': dilation,
'pad': padding, 'num_filter': offset_channels, 'num_group': groups,
'no_bias': not offset_use_bias, 'layout': layout}

self._kwargs_deformable_conv = {
'kernel': kernel_size, 'stride': strides, 'dilate': dilation,
'pad': padding, 'num_filter': channels, 'num_group': groups,
'num_deformable_group': num_deformable_group,
'no_bias': not use_bias, 'layout': layout}

if adj:
self._kwargs_offset['adj'] = adj
self._kwargs_deformable_conv['adj'] = adj

deformable_conv_weight_shape = [0] * (len(kernel_size) + 2)
deformable_conv_weight_shape[0] = channels
deformable_conv_weight_shape[2] = kernel_size[0]
deformable_conv_weight_shape[3] = kernel_size[1]

self.deformable_conv_weight = self.params.get('deformable_conv_weight',
shape=deformable_conv_weight_shape,
init=weight_initializer,
allow_deferred_init=True)

if use_bias:
self.deformable_conv_bias = self.params.get('deformable_conv_bias', shape=(channels,),
init=bias_initializer,
allow_deferred_init=True)
else:
self.deformable_conv_bias = None

dshape = [0] * (len(kernel_size) + 2)
dshape[layout.find('N')] = 1
dshape[layout.find('C')] = in_channels

op = getattr(symbol, 'Convolution')
offset = op(symbol.var('data', shape=dshape), **self._kwargs_offset)

offsetshapes = offset.infer_shape_partial()[0]

self.offset_weight = self.params.get('offset_weight', shape=offsetshapes[1],
init=offset_weight_initializer,
allow_deferred_init=True)

if offset_use_bias:
self.offset_bias = self.params.get('offset_bias', shape=offsetshapes[2],
init=offset_bias_initializer,
allow_deferred_init=True)
else:
self.offset_bias = None

if activation:
self.act = Activation(activation, prefix=activation + '_')
else:
self.act = None
op_name='ModulatedDeformableConvolution', adj=None):
super(ModulatedDeformableConvolution, self).__init__()
self._channels = channels
self._in_channels = in_channels

assert layout in ('NCHW', 'NHWC'), "Only supports 'NCHW' and 'NHWC' layout for now"
if isinstance(kernel_size, numeric_types):
kernel_size = (kernel_size,) * 2
if isinstance(strides, numeric_types):
strides = (strides,) * len(kernel_size)
if isinstance(padding, numeric_types):
padding = (padding,) * len(kernel_size)
if isinstance(dilation, numeric_types):
dilation = (dilation,) * len(kernel_size)
self._op_name = op_name

offset_channels = num_deformable_group * 3 * kernel_size[0] * kernel_size[1]
self.offset_split_index = num_deformable_group * 2 * kernel_size[0] * kernel_size[1]
self._kwargs_offset = {
'kernel': kernel_size, 'stride': strides, 'dilate': dilation,
'pad': padding, 'num_filter': offset_channels, 'num_group': groups,
'no_bias': not offset_use_bias, 'layout': layout}

self._kwargs_deformable_conv = {
'kernel': kernel_size, 'stride': strides, 'dilate': dilation,
'pad': padding, 'num_filter': channels, 'num_group': groups,
'num_deformable_group': num_deformable_group,
'no_bias': not use_bias, 'layout': layout}

if adj:
self._kwargs_offset['adj'] = adj
self._kwargs_deformable_conv['adj'] = adj

deformable_conv_weight_shape = [0] * (len(kernel_size) + 2)
deformable_conv_weight_shape[0] = channels
deformable_conv_weight_shape[2] = kernel_size[0]
deformable_conv_weight_shape[3] = kernel_size[1]

self.deformable_conv_weight = Parameter('deformable_conv_weight',
shape=deformable_conv_weight_shape,
init=weight_initializer,
allow_deferred_init=True)

if use_bias:
self.deformable_conv_bias = Parameter('deformable_conv_bias', shape=(channels,),
init=bias_initializer,
allow_deferred_init=True)
else:
self.deformable_conv_bias = None

dshape = [0] * (len(kernel_size) + 2)
dshape[layout.find('N')] = 1
dshape[layout.find('C')] = in_channels

op = getattr(symbol, 'Convolution')
offset = op(symbol.var('data', shape=dshape), **self._kwargs_offset)

offsetshapes = offset.infer_shape_partial()[0]

self.offset_weight = Parameter('offset_weight', shape=offsetshapes[1],
init=offset_weight_initializer,
allow_deferred_init=True)

if offset_use_bias:
self.offset_bias = Parameter('offset_bias', shape=offsetshapes[2],
init=offset_bias_initializer,
allow_deferred_init=True)
else:
self.offset_bias = None

if activation:
self.act = Activation(activation)
else:
self.act = None

def hybrid_forward(self, F, x, offset_weight, deformable_conv_weight, offset_bias=None, deformable_conv_bias=None):
if offset_bias is None:
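The pattern in conv_layers.py is representative of the whole PR: parameters previously fetched with self.params.get(...) are now constructed as Parameter objects and assigned as attributes, which registers them on the block, and Activation no longer takes a prefix. A minimal sketch of the same idea on a hypothetical block (shapes chosen only for illustration):

```python
from mxnet.gluon import HybridBlock, Parameter

class TinyConv(HybridBlock):
    """Hypothetical block showing the new parameter-creation pattern."""
    def __init__(self, channels=8):
        super(TinyConv, self).__init__()
        self._channels = channels
        # Old: self.weight = self.params.get('weight', shape=..., allow_deferred_init=True)
        # New: construct the Parameter directly; attribute assignment registers it.
        self.weight = Parameter('weight', shape=(channels, 0, 3, 3),
                                allow_deferred_init=True)
        self.bias = Parameter('bias', shape=(channels,),
                              init='zeros', allow_deferred_init=True)

    def hybrid_forward(self, F, x, weight, bias):
        return F.Convolution(x, weight, bias, kernel=(3, 3),
                             num_filter=self._channels, pad=(1, 1))
```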
14 changes: 7 additions & 7 deletions python/mxnet/gluon/contrib/data/vision/dataloader.py
@@ -92,7 +92,7 @@ def create_image_augment(data_shape, resize=0, rand_crop=False, rand_resize=Fals
"""
if inter_method == 10:
inter_method = np.random.randint(0, 5)
augmenter = HybridSequential('default_img_augment_')
augmenter = HybridSequential()
if resize > 0:
augmenter.add(transforms.image.Resize(resize, interpolation=inter_method))
crop_size = (data_shape[2], data_shape[1])
@@ -220,9 +220,9 @@ def __init__(self, batch_size, data_shape, path_imgrec=None, path_imglist=None,
augmenter = create_image_augment(data_shape, **kwargs)
elif isinstance(aug_list, list):
if all([isinstance(a, HybridBlock) for a in aug_list]):
augmenter = HybridSequential('user_img_augment_')
augmenter = HybridSequential()
else:
augmenter = Sequential('user_img_augment_')
augmenter = Sequential()
for aug in aug_list:
augmenter.add(aug)
elif isinstance(aug_list, Block):
@@ -316,7 +316,7 @@ def create_bbox_augment(data_shape, rand_crop=0, rand_pad=0, rand_gray=0,
"""
if inter_method == 10:
inter_method = np.random.randint(0, 5)
augmenter = Sequential('default_bbox_aug_')
augmenter = Sequential()
if rand_crop > 0:
augmenter.add(bbox.ImageBboxRandomCropWithConstraints(
p=rand_crop, min_scale=area_range[0], max_scale=1.0,
@@ -439,17 +439,17 @@ def __init__(self, batch_size, data_shape, path_imgrec=None, path_imglist=None,
augmenter = create_bbox_augment(data_shape, **kwargs)
elif isinstance(aug_list, list):
if all([isinstance(a, HybridBlock) for a in aug_list]):
augmenter = HybridSequential('user_bbox_augment_')
augmenter = HybridSequential()
else:
augmenter = Sequential('user_bbox_augment_')
augmenter = Sequential()
for aug in aug_list:
augmenter.add(aug)
elif isinstance(aug_list, Block):
augmenter = aug_list
else:
raise ValueError('aug_list must be a list of Blocks')
augmenter.hybridize()
wrapper_aug = Sequential('wrapper_bbox_aug_')
wrapper_aug = Sequential()
wrapper_aug.add(BboxLabelTransform(coord_normalized))
wrapper_aug.add(augmenter)

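Sequential and HybridSequential no longer take a prefix string, so the augmentation pipelines above are built with bare constructors. A minimal sketch of the post-PR construction (the transforms are chosen only for illustration):

```python
from mxnet.gluon.nn import HybridSequential
from mxnet.gluon.data.vision import transforms

augmenter = HybridSequential()          # no 'user_img_augment_'-style prefix any more
augmenter.add(transforms.Resize(256))
augmenter.add(transforms.CenterCrop(224))
augmenter.add(transforms.ToTensor())
augmenter.hybridize()
```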