From c40d53effac8c5c42fbe0efe7d29c41061d00154 Mon Sep 17 00:00:00 2001 From: MORITA Kazutaka Date: Sun, 6 May 2018 05:56:52 +0900 Subject: [PATCH] add sanity check to input shape type --- python/nnvm/compiler/build_module.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/python/nnvm/compiler/build_module.py b/python/nnvm/compiler/build_module.py index 3080baf1d..460cc1c75 100644 --- a/python/nnvm/compiler/build_module.py +++ b/python/nnvm/compiler/build_module.py @@ -123,6 +123,18 @@ def _build(funcs, target, target_host): return tvm.build(funcs, target=target, target_host=target_host) +def _check_shape_type(shape): + """Check whether the input shape has an appropriate type.""" + if not isinstance(shape, dict): + raise TypeError("require shape to be dict") + + for key, value in shape.items(): + if not isinstance(key, basestring): + raise TypeError("shape key must be str") + if not all(isinstance(x, int) for x in value): + raise TypeError("shape value must be int iterator") + + def _update_shape_dtype(shape, dtype, params): """Update shape dtype given params information""" if not params: @@ -239,8 +251,8 @@ def build(graph, target=None, shape=None, dtype="float32", target = tvm.target.create(target) shape = shape if shape else {} - if not isinstance(shape, dict): - raise TypeError("require shape to be dict") + _check_shape_type(shape) + cfg = BuildConfig.current graph = graph if isinstance(graph, _graph.Graph) else _graph.create(graph) shape, dtype = _update_shape_dtype(shape, dtype, params)