forked from apache/tvm
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* init onnx finish onnx frontend add onnx tests fix various backup use transformer [Frontend] graph passed add test forward test forward fix doc and lint fix test graph tuple from_onnx now take 2 args, output (sym, params) fix rename fix input names fix multiple fix lint fix lint check * better doc
- Loading branch information
Showing
20 changed files
with
1,105 additions
and
223 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
"""NNVM frontends.""" | ||
from __future__ import absolute_import | ||
from .mxnet import from_mxnet | ||
from .onnx import from_onnx |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
"""Shared functions and classes for frontends.""" | ||
from __future__ import absolute_import as _abs | ||
import warnings | ||
from .._base import string_types | ||
|
||
class Renamer(object): | ||
"""A simply renamer for operators. | ||
Parameters | ||
---------- | ||
new_name : str | ||
The new name for the operator | ||
""" | ||
def __init__(self, new_name): | ||
self._new_name = new_name | ||
|
||
def __call__(self, attrs): | ||
return self._new_name, attrs | ||
|
||
|
||
class AttrConverter(object): | ||
"""Common attribute conveter. An AttrConverter instance is a callable: | ||
``` | ||
attr_converter = AttrConverter(op_name, transforms={'a':'b', 'c':('d', 1)}) | ||
new_op_name, new_attr = attr_converter(attrs) | ||
``` | ||
Parameters | ||
---------- | ||
op_name : str or callable | ||
If set as str, returned operator name is the str. | ||
If set as callable, returned operator is the str returned by calling: | ||
`op_name = func(attr)` | ||
transforms : dict of `new_name, or (new_name, default_value, transform function)` | ||
If only a new_name is provided, it's like renaming the attribute name. | ||
If default_value if provded, then the attribute is considered as optional. | ||
If transform function is provided, the original attribute value is handled | ||
by transform function. | ||
excludes : list | ||
A list of excluded attributes that should `NOT` appear. | ||
Raise NotImplementedError if occured. | ||
disables : list | ||
A list of attributes that is disabled in nnvm. Raise warnings. | ||
ignores : list | ||
A list of attributes that is ignored in nnvm. Silent. | ||
extras : dict | ||
A series of additional attributes should be added anyway to the returned | ||
attribute dict. | ||
custom_check : callable | ||
A custom function takes attribute, and return True/False. | ||
Raise RuntimeError if not bool(True) returned. | ||
""" | ||
def __init__(self, op_name, transforms=None, | ||
excludes=None, disables=None, ignores=None, | ||
extras=None, custom_check=None): | ||
self._op_name = op_name | ||
self._transforms = transforms if transforms else {} | ||
self._excludes = excludes if excludes else [] | ||
self._disables = disables if disables else [] | ||
self._ignores = ignores if ignores else [] | ||
self._extras = extras if extras else {} | ||
self._custom_check = custom_check | ||
|
||
def __call__(self, attrs): | ||
# apply custom check | ||
if self._custom_check: | ||
func, msg = self._custom_check | ||
if not func(attrs): | ||
raise RuntimeError("Check failed: {}".format(msg)) | ||
# get new op_name | ||
if isinstance(self._op_name, string_types): | ||
op_name = self._op_name | ||
else: | ||
assert callable(self._op_name), "op_name can either be string or callable" | ||
op_name = self._op_name(attrs) | ||
# convert attributes | ||
new_attrs = {} | ||
for k in attrs.keys(): | ||
if k in self._excludes: | ||
raise NotImplementedError("Attribute {} not supported yet.".format(k)) | ||
elif k in self._disables: | ||
warnings.warn("Attribute {} is disabled in nnvm.sym.{}".format(k, op_name)) | ||
elif k in self._ignores: | ||
pass | ||
elif k in self._transforms: | ||
new_name, defaults, transform = self._parse_default(self._transforms[k]) | ||
if defaults is None: | ||
new_attr = self._required_attr(attrs, k) | ||
else: | ||
new_attr = attrs.get(k, None) | ||
if new_attr is None: | ||
new_attrs[new_name] = defaults | ||
else: | ||
new_attrs[new_name] = transform(new_attr) | ||
else: | ||
# copy | ||
new_attrs[k] = attrs[k] | ||
# add extras | ||
new_attrs.update(self._extras) | ||
return op_name, new_attrs | ||
|
||
def _parse_default(self, target): | ||
"""Helper function to parse default values.""" | ||
if not isinstance(target, (list, tuple)): | ||
k, v, t = target, None, lambda x: x | ||
elif len(target) == 1: | ||
k, v, t = target[0], None, lambda x: x | ||
elif len(target) == 2: | ||
k, v, t = target[0], target[1], lambda x: x | ||
elif len(target) > 2: | ||
k, v, t = target[0], target[1], target[2] | ||
else: | ||
k = None # should raise | ||
if not isinstance(k, string_types): | ||
msg = "{} is not a valid target, (name, default) expected.".format(target) | ||
raise ValueError(msg) | ||
return k, v, t | ||
|
||
def _parse_bool(self, value): | ||
"""Helper function to parse default boolean values.""" | ||
if isinstance(value, string_types): | ||
return value.strip().lower() in ['true', '1', 't', 'y', 'yes'] | ||
return bool(value) | ||
|
||
def _required_attr(self, attr, key): | ||
"""Wrapper for getting required attributes.""" | ||
assert isinstance(attr, dict) | ||
if key not in attr: | ||
raise AttributeError("Required attribute {} not found.".format(key)) | ||
return attr[key] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.