Skip to content

Commit

Permalink
Infer dtype in SymbolBlock import from input symbol (apache#12412)
Browse files Browse the repository at this point in the history
* Infer dtype in SymbolBlock import from input symbol

* Fix lint issues and make existing tests pass

* Add tests for importing a fp64 model into symbol block

* Fixing failing test for test symbol block

* Set context in unit tests

* Add tests for fp16, add default dtype in infer_param_types

* Use tmp directory as root for loading from model zoo to avoid race condition

* Fixing naming and parameter selection in test case

* Fixing failing GPU tests

* Make unit test more deterministic to get param name

* Override cast in symbol block, handle grouped symbol

* Handle multiple symbolic input usecase

* Add tests to verify behavior of SymbolBlock.cast
  • Loading branch information
sandeep-krishnamurthy authored Sep 18, 2018
1 parent 1744b0c commit acf309e
Show file tree
Hide file tree
Showing 4 changed files with 151 additions and 6 deletions.
86 changes: 80 additions & 6 deletions python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import re
from collections import OrderedDict

from ..base import mx_real_t
from .. import symbol, ndarray, initializer
from ..symbol import Symbol
from ..ndarray import NDArray
Expand Down Expand Up @@ -1053,13 +1054,20 @@ def __init__(self, outputs, inputs, params=None):
"SymbolBlock doesn't support Parameter '%s' because its storage " \
"type is 'row_sparse'." % j.name

for i in out.list_arguments():
if i not in input_names:
self.params.get(i, allow_deferred_init=True)
# Infer type of parameters. Without this, every parameter will be created with
# default type i.e., fp32
arg_params = out.list_arguments()
aux_params = out.list_auxiliary_states()

for i in out.list_auxiliary_states():
if i not in input_names:
self.params.get(i, grad_req='null', allow_deferred_init=True)
arg_types, aux_types = _infer_param_types(syms, out, arg_params, aux_params)

for i, arg in enumerate(arg_params):
if arg not in input_names:
self.params.get(arg, allow_deferred_init=True, dtype=arg_types[i])

for i, aux in enumerate(aux_params):
if aux not in input_names:
self.params.get(aux, grad_req='null', allow_deferred_init=True, dtype=aux_types[i])

self._cached_graph = syms, out
len_prefix = len(_common_prefix(list(self._params.keys())))
Expand All @@ -1084,5 +1092,71 @@ def _clear_cached_op(self):
super(SymbolBlock, self)._clear_cached_op()
self._cached_graph = tmp

def cast(self, dtype):
    """Cast this block to ``dtype``.

    Parameters
    ----------
    dtype : str or numpy.dtype
        Target data type; forwarded unchanged to the parent class's ``cast``.
    """
    # The cached op is cleared before delegating.
    # NOTE(review): presumably so it is rebuilt for the new dtype - confirm.
    self._clear_cached_op()
    parent = super(SymbolBlock, self)
    parent.cast(dtype)

def hybrid_forward(self, F, x, *args, **kwargs):
    """Unsupported entry point; always raises ``NotImplementedError``.

    Raises
    ------
    NotImplementedError
        Unconditionally, whatever arguments are passed.
    """
    raise NotImplementedError

def _infer_param_types(in_params, out_params, arg_params, aux_params, default_dtype=mx_real_t):
    """Infer the dtypes of argument and auxiliary parameters from the input symbols.

    The dtype of every input symbol is inferred first. If all of them can be
    inferred, they are fed to ``out_params.infer_type`` to obtain the dtypes of
    the remaining parameters. Whenever inference fails, or yields a list whose
    length does not match the corresponding name list, ``default_dtype`` is
    used for every parameter of that kind.

    Parameters
    ----------
    in_params : list of Symbol
        List of input symbol variables.
    out_params : Symbol
        Output symbol variable.
    arg_params : list of str
        List of names of argument parameters.
    aux_params : list of str
        List of names of auxiliary parameters.
    default_dtype : numpy.dtype or str, default ``mx_real_t``
        Fallback data type for arg_params and aux_params when the type
        cannot be inferred.

    Returns
    -------
    arg_types : list of numpy.dtype
        Types of arg_params, in the same order as arg_params.
        Every entry is ``default_dtype`` if the types cannot be inferred.
    aux_types : list of numpy.dtype
        Types of aux_params, in the same order as aux_params.
        Every entry is ``default_dtype`` if the types cannot be inferred.
    """
    arg_types = None
    aux_types = None

    # Map each input symbol name to its inferred dtype. Each input is queried
    # exactly once and the result is reused.
    input_types = {}
    for in_param in in_params:
        inferred = in_param.infer_type()[0]
        if not inferred:
            # At least one input dtype is unknown: skip graph-wide inference
            # and fall back to the default dtype below.
            break
        input_types[in_param.name] = inferred[0]
    else:
        # Runs only when the loop did not break, i.e. every input dtype is
        # known (this includes the case of an empty ``in_params``).
        arg_types, _, aux_types = out_params.infer_type(**input_types)

    if arg_types is None or len(arg_types) != len(arg_params):
        arg_types = [default_dtype] * len(arg_params)

    if aux_types is None or len(aux_types) != len(aux_params):
        aux_types = [default_dtype] * len(aux_params)

    return (arg_types, aux_types)
2 changes: 2 additions & 0 deletions python/mxnet/gluon/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,8 @@ def get(self, name, **kwargs):
if matched:
param._shape = tuple(inferred_shape)
continue
elif k == 'dtype' and np.dtype(v) == np.dtype(existing):
continue

assert v is None or v == existing, \
"Cannot retrieve Parameter '%s' because desired attribute " \
Expand Down
31 changes: 31 additions & 0 deletions tests/python/gpu/test_gluon_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import print_function
import sys
import os
import tempfile
import time
import multiprocessing as mp
import unittest
Expand Down Expand Up @@ -202,6 +203,36 @@ def get_num_devices():
_check_batchnorm_result(mx.nd.random.uniform(shape=(4, 1, 4, 4)),
num_devices=ndev, cuda=True)

@with_seed()
def test_symbol_block_fp16():
    """Verify that a SymbolBlock built from an exported fp16 model has fp16 params.

    The model is downloaded and exported into a private temporary directory,
    which is removed when the test finishes, whether it passes or fails.
    """
    # Function-scope import: only this test needs shutil.
    import shutil

    # A private tmp directory is used as the model zoo root to avoid a race
    # condition between tests loading the same pretrained model.
    tmp = tempfile.mkdtemp()
    try:
        tmpfile = os.path.join(tmp, 'resnet34_fp16')
        ctx = mx.gpu(0)

        # 1. Load a resnet model, cast it to fp16 and export it.
        net = mx.gluon.model_zoo.vision.resnet34_v2(pretrained=True, ctx=ctx, root=tmp)
        net.cast('float16')
        net.hybridize()
        data = mx.nd.zeros((1, 3, 224, 224), dtype='float16', ctx=ctx)
        net.forward(data)
        net.export(tmpfile, 0)

        # 2. Load the saved model into a SymbolBlock whose input is declared fp16.
        sm = mx.sym.load(tmpfile + '-symbol.json')
        inputs = mx.sym.var('data', dtype='float16')
        net_fp16 = mx.gluon.SymbolBlock(sm, inputs)
        net_fp16.collect_params().load(tmpfile + '-0000.params', ctx=ctx)

        # 3. Pick the first conv layer weight and check that it is fp16.
        #    Fail loudly if no such parameter exists instead of silently
        #    checking whichever parameter happened to be iterated last.
        conv_weights = [name for name in net_fp16.params.keys()
                        if 'conv' in name and 'weight' in name]
        assert conv_weights, "No conv weight parameter found in the imported SymbolBlock"
        assert np.dtype(net_fp16.params[conv_weights[0]].dtype) == np.dtype(np.float16)
    finally:
        # Remove the downloaded model and the exported files.
        shutil.rmtree(tmp, ignore_errors=True)

# Run this module's tests through nose when the file is executed directly.
if __name__ == '__main__':
    import nose

    nose.runmodule()
38 changes: 38 additions & 0 deletions tests/python/unittest/test_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
# specific language governing permissions and limitations
# under the License.

import os
import tempfile

import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn
Expand Down Expand Up @@ -336,6 +339,41 @@ def hybrid_forward(self, F, x):
net.hybridize()
assert isinstance(net(mx.nd.zeros((16, 10))), mx.nd.NDArray)

# Test case to verify that a SymbolBlock can be initialized from a model whose
# params have a dtype other than fp32.

# 1. Load a resnet model, cast it to fp64 and export
tmp = tempfile.mkdtemp()
tmpfile = os.path.join(tmp, 'resnet34_fp64')
ctx = mx.cpu(0)

net_fp32 = mx.gluon.model_zoo.vision.resnet34_v2(pretrained=True, ctx=ctx, root=tmp)
net_fp32.cast('float64')
net_fp32.hybridize()
data = mx.nd.zeros((1,3,224,224), dtype='float64', ctx=ctx)
net_fp32.forward(data)
net_fp32.export(tmpfile, 0)

# 2. Load the saved model and verify that all the params are loaded correctly,
# then choose one of the params to verify that its type is fp64.
sm = mx.sym.load(tmpfile + '-symbol.json')
inputs = mx.sym.var('data', dtype='float64')
net_fp64 = mx.gluon.SymbolBlock(sm, inputs)
net_fp64.collect_params().load(tmpfile + '-0000.params', ctx=ctx)
# 3. Get a conv layer's weight parameter name. The conv layer's weight param is
# expected to have the cast dtype, fp64.
for param_name in net_fp64.params.keys():
if 'conv' in param_name and 'weight' in param_name:
break
assert np.dtype(net_fp64.params[param_name].dtype) == np.dtype(np.float64)

# Cast the symbol block to FP32 and try to forward a FP32 data.
# This will verify SymbolBlock.cast() functionality.
net_fp64.cast('float32')
fp32_data = mx.nd.zeros((1,3,224,224), dtype='float32', ctx=ctx)
prediction = net_fp64.forward(fp32_data)
assert np.dtype(prediction.dtype) == np.dtype(np.float32)

@with_seed()
@raises(AssertionError)
def test_sparse_symbol_block():
Expand Down

0 comments on commit acf309e

Please sign in to comment.