diff --git a/python/nnvm/compiler/build_module.py b/python/nnvm/compiler/build_module.py index 3080baf1d..460cc1c75 100644 --- a/python/nnvm/compiler/build_module.py +++ b/python/nnvm/compiler/build_module.py @@ -123,6 +123,18 @@ def _build(funcs, target, target_host): return tvm.build(funcs, target=target, target_host=target_host) +def _check_shape_type(shape): + """Check whether the input shape has an appropriate type.""" + if not isinstance(shape, dict): + raise TypeError("require shape to be dict") + + for key, value in shape.items(): + if not isinstance(key, basestring): + raise TypeError("shape key must be str") + if not all(isinstance(x, int) for x in value): + raise TypeError("shape value must be int iterator") + + def _update_shape_dtype(shape, dtype, params): """Update shape dtype given params information""" if not params: @@ -239,8 +251,8 @@ def build(graph, target=None, shape=None, dtype="float32", target = tvm.target.create(target) shape = shape if shape else {} - if not isinstance(shape, dict): - raise TypeError("require shape to be dict") + _check_shape_type(shape) + cfg = BuildConfig.current graph = graph if isinstance(graph, _graph.Graph) else _graph.create(graph) shape, dtype = _update_shape_dtype(shape, dtype, params)