Merge pull request #1 from anijain2305/tf
Fixes for Tuple and index0 error
rohanmukh authored Sep 30, 2020
2 parents 6dab226 + e89094c commit b2c5a04
Showing 3 changed files with 33 additions and 15 deletions.
34 changes: 25 additions & 9 deletions python/tvm/relay/tensorrt.py
@@ -24,7 +24,7 @@
 import tvm.ir
 import tvm.relay.transform as transform
 from tvm import relay
-from tvm.relay.expr import Call, Constant, Tuple, GlobalVar
+from tvm.relay.expr import Call, Constant, Tuple, GlobalVar, Var, TupleGetItem
 from tvm.relay.build_module import bind_params_by_name
 from tvm.relay.transform import _ffi_api
 from tvm.relay.expr_functor import ExprMutator
@@ -93,16 +93,25 @@ def IsTrtRuntimeAvailable():
         return False
     return GetTrtVersion() != ()
 
-def check_dynamism(args):
+def check_dynamism(args, op_name):
     for arg in args:
         # TODO: Tuple Type
         # if isinstance(arg, relay.TupleType):
         #     print("TupleType inputs are not supported for TensorRT.")
         #     return False
-        for arg_shape in arg.checked_type.shape:
-            if isinstance(arg_shape, tvm.tir.expr.Any):
-                print("Dynamic inputs are not supported for TensorRT.")
-                return False
+        try:
+            if isinstance(arg, (Call, Var, Constant, TupleGetItem)):
+                for dim_shape in arg.checked_type.shape:
+                    if isinstance(dim_shape, tvm.tir.expr.Any):
+                        print("Dynamic inputs are not supported for TensorRT for ", op_name, arg.checked_type.shape)
+                        return False
+            elif isinstance(arg, Tuple):
+                return check_dynamism(arg.fields, op_name)
+            else:
+                raise NotImplementedError(type(arg))
+        except:
+            print(args[0])
+            assert False, "failed for the {}".format(op_name)
     return True


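The rewritten check_dynamism walks each argument's checked_type.shape and now recurses into Tuple arguments via their fields. The standalone sketch below mirrors that control flow outside of TVM; the Any, Leaf and Tup classes are illustrative stand-ins (for tvm.tir.expr.Any, for the Call/Var/Constant/TupleGetItem case, and for relay.Tuple), not real TVM types:

    # Stand-in classes, invented for illustration only.
    class Any:                      # plays the role of tvm.tir.expr.Any (a dynamic dimension)
        pass

    class Leaf:                     # plays the role of Call / Var / Constant / TupleGetItem
        def __init__(self, shape):
            self.shape = shape

    class Tup:                      # plays the role of relay.Tuple
        def __init__(self, fields):
            self.fields = fields

    def check_dynamism(args, op_name):
        for arg in args:
            if isinstance(arg, Leaf):
                if any(isinstance(dim, Any) for dim in arg.shape):
                    print("Dynamic inputs are not supported for TensorRT for", op_name)
                    return False
            elif isinstance(arg, Tup):
                # Mirrors the committed code: the result for the Tuple's fields is
                # returned directly, so arguments after the first Tuple are not visited.
                return check_dynamism(arg.fields, op_name)
            else:
                raise NotImplementedError(type(arg))
        return True

    print(check_dynamism([Leaf((1, 3, 224, 224))], "nn.conv2d"))     # True  (all shapes static)
    print(check_dynamism([Tup([Leaf((Any(), 3))])], "concatenate"))  # False (dynamic dim in a tuple field)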
@@ -111,7 +120,7 @@ def _register_external_op_helper(op_name, supported=True):
     def _func_wrapper(attrs, args):
         # TODO Rohan: Code Repetition for dynamic checks in multiple wrappers
         print("Working with op {}".format(op_name))
-        t = check_dynamism(args)
+        t = check_dynamism(args, op_name)
         if not t:
             return t
         if any([x.checked_type.dtype != "float32" for x in args]):
@@ -125,7 +134,9 @@ def _func_wrapper(attrs, args):
 def _register_external_op_helper_func(op_name, func, trt_version):
     @tvm.ir.register_op_attr(op_name, "target.tensorrt")
     def _func_wrapper(attrs, args):
-        t = check_dynamism(args)
+        print("Working with op {}".format(op_name))
+        t = check_dynamism(args, op_name)
+
         if not t:
             return t
         if any([x.checked_type.dtype != "float32" for x in args]):
@@ -138,7 +149,8 @@ def _func_wrapper(attrs, args):
 def _register_external_dynamic_check_func(op_name, func):
     @tvm.ir.register_op_attr(op_name, "target.tensorrt")
     def _func_wrapper(attrs, args):
-        t = check_dynamism(args)
+        print("Working with op {}".format(op_name))
+        t = check_dynamism(args, op_name)
         if not t:
             return t
         return func(attrs, args)
@@ -804,8 +816,12 @@ def EnableTrt(mod, params=None, trt_version=None, use_implicit_batch=True,
                                   'nn.conv3d': ['NCDHW', 'default']}),
         transform.FoldConstant(),
         LegalizeLayoutTranformPass(),
+        transform.InferType(),
+        # tvm.transform.PrintIR("A1"),
         transform.AnnotateTarget('tensorrt'),
+        # tvm.transform.PrintIR("A2"),
         transform.MergeCompilerRegions(),
+        # tvm.transform.PrintIR("A3"),
         transform.PartitionGraph(),
         transform.InferType()])
     with tvm.transform.PassContext(opt_level=3):
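For orientation, a hedged usage sketch of the pipeline this hunk modifies: InferType now runs before AnnotateTarget, so checked_type is populated when the per-op predicates and check_dynamism inspect argument shapes. The sketch assumes this branch exposes tvm.relay.tensorrt with IsTrtRuntimeAvailable and EnableTrt as shown above, that EnableTrt returns the partitioned module, and uses a ResNet-18 test workload purely for illustration:

    import tvm
    from tvm import relay
    from tvm.relay import tensorrt          # module patched in this commit
    from tvm.relay.testing import resnet    # illustrative workload only

    mod, params = resnet.get_workload(num_layers=18, batch_size=1)

    if tensorrt.IsTrtRuntimeAvailable():
        mod = tensorrt.EnableTrt(mod, params)   # assumed to return the partitioned module
        with tvm.transform.PassContext(opt_level=3):
            graph, lib, params = relay.build(mod, target="cuda", params=params)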
12 changes: 7 additions & 5 deletions src/relay/transforms/annotate_target.cc
@@ -129,11 +129,13 @@ class AnnotateTargetRewriter : public ExprRewriter {
 
     // Peek the first argument. If it is compiler begin then this node had annotated by
     // another target before, so we also consider that target as a supported target.
-    const CallNode* first_arg_call = pre->args[0].as<CallNode>();
-    if (first_arg_call && first_arg_call->op == CompilerBeginOp()) {
-      std::string arg_target = first_arg_call->attrs.as<CompilerAttrs>()->compiler;
-      if (arg_target != "default") {
-        supported_targets.push_back(arg_target);
+    if (pre->args.size() > 0) {
+      const CallNode* first_arg_call = pre->args[0].as<CallNode>();
+      if (first_arg_call && first_arg_call->op == CompilerBeginOp()) {
+        std::string arg_target = first_arg_call->attrs.as<CompilerAttrs>()->compiler;
+        if (arg_target != "default") {
+          supported_targets.push_back(arg_target);
+        }
       }
     }
 
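The "index0 error" in the commit message corresponds to the previously unguarded pre->args[0] above: a call with no arguments made the annotator index an empty argument array. A hypothetical Python reproduction sketch, assuming relay.zeros with a constant shape lowers to a zero-argument CallNode on this branch (the op choice is illustrative, not taken from the PR):

    import tvm
    from tvm import relay

    # zeros with a constant shape carries the shape as an attribute, so the
    # resulting CallNode is assumed to have no arguments here.
    f = relay.Function([], relay.zeros(shape=(2, 2), dtype="float32"))
    mod = tvm.IRModule.from_expr(f)
    mod = relay.transform.InferType()(mod)
    # Before this commit AnnotateTarget read pre->args[0] unconditionally;
    # with the size() guard above, zero-argument calls are handled safely.
    mod = relay.transform.AnnotateTarget("tensorrt")(mod)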
2 changes: 1 addition & 1 deletion src/relay/transforms/type_infer.cc
@@ -326,7 +326,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
     Type vtype = GetType(let->value);
     let_type = Unify(let_type, vtype, GetRef<Let>(let));
 
-    CHECK(is_functional_literal || !type_map_.count(let->var));
+    CHECK(is_functional_literal || !type_map_.count(let->var)) << AsText(GetRef<Expr>(let), false);
     // NOTE: no scoping is necessary because var are unique in program
     type_map_[let->var].checked_type = let_type;
     return GetType(let->body);
