Skip to content
This repository has been archived by the owner on Feb 1, 2020. It is now read-only.

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add sanity check to input shape type
Browse files Browse the repository at this point in the history
kazum committed May 8, 2018
1 parent 8c0a103 commit c40d53e
Showing 1 changed file with 14 additions and 2 deletions.
16 changes: 14 additions & 2 deletions python/nnvm/compiler/build_module.py
Original file line number Diff line number Diff line change
@@ -123,6 +123,18 @@ def _build(funcs, target, target_host):
return tvm.build(funcs, target=target, target_host=target_host)


def _check_shape_type(shape):
"""Check whether the input shape has an appropriate type."""
if not isinstance(shape, dict):
raise TypeError("require shape to be dict")

for key, value in shape.items():
if not isinstance(key, basestring):
raise TypeError("shape key must be str")
if not all(isinstance(x, int) for x in value):
raise TypeError("shape value must be int iterator")


def _update_shape_dtype(shape, dtype, params):
"""Update shape dtype given params information"""
if not params:
@@ -239,8 +251,8 @@ def build(graph, target=None, shape=None, dtype="float32",
target = tvm.target.create(target)

shape = shape if shape else {}
if not isinstance(shape, dict):
raise TypeError("require shape to be dict")
_check_shape_type(shape)

cfg = BuildConfig.current
graph = graph if isinstance(graph, _graph.Graph) else _graph.create(graph)
shape, dtype = _update_shape_dtype(shape, dtype, params)

0 comments on commit c40d53e

Please sign in to comment.