diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index fe6d957e79ed..9f28fbebccfc 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -485,6 +485,36 @@ inline ValueType OpMap::get(const Op& op, return map_.get(op, def_value); } +/*! + * \brief Check that an expression is a "primtive operator". + * + * Will return true if the expression is an operator which + * matches the form of primtive operators registered directly + * by the Relay codebase. + * + * That is the arguments are all type variables, and there is a single + * type relation applied to the input and output types. + */ +inline bool IsPrimitiveOp(const Expr& expr) { + const auto* op = expr.as(); + + if (!op) { + return false; + } + + const auto& fn_ty = op->op_type; + if (fn_ty->type_constraints.size() != 1) return false; + + const TypeRelationNode* rel = fn_ty->type_constraints[0].as(); + if (rel == nullptr) return false; + // validate if the type parameter matches up + for (size_t i = 0; i < fn_ty->type_params.size(); ++i) { + if (!fn_ty->type_params[i].same_as(rel->args[i])) return false; + } + + return true; +} + } // namespace relay } // namespace tvm #endif // TVM_RELAY_OP_H_ diff --git a/src/relay/ir/text_printer.cc b/src/relay/ir/text_printer.cc index 0ebe111ab6b2..3cbe1e00b9ca 100644 --- a/src/relay/ir/text_printer.cc +++ b/src/relay/ir/text_printer.cc @@ -278,10 +278,7 @@ class TextPrinter : } TextValue VisitExpr_(const CallNode* op) final { - // TODO(tqchen, M.K.): support generic call // possibly through meta-data - CHECK_EQ(op->type_args.size(), 0U) - << "generic call not yet supported"; TextValue call_op = GetValue(op->op); std::vector args; for (Expr arg : op->args) { @@ -289,7 +286,23 @@ class TextPrinter : } TextValue id = this->AllocTempVar(); this->PrintIndent(); - stream_ << id << " = " << call_op << "("; + + stream_ << id << " = " << call_op; + + auto type_args = op->type_args; + + if (!IsPrimitiveOp(op->op) && type_args.size() > 0U) { + stream_ << "<"; + for (size_t i = 0; i < op->type_args.size(); ++i) { + this->PrintType(type_args[i], stream_); + if (i + 1 != type_args.size()) { + stream_ << ", "; + } + } + stream_ << ">"; + } + + stream_ << "("; for (size_t i = 0; i < args.size(); ++i) { stream_ << args[i]; if (i + 1 != args.size()) { diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 0cbce833aed9..87fdb1c0ffba 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -61,6 +61,17 @@ TVM_REGISTER_API("tvm.relay.type_relation.TupleGetItem") .set_body_typed&, int, const Attrs&, const TypeReporter&)>( TupleGetItemRel); +struct ResolvedTypeInfo { + explicit ResolvedTypeInfo(Type checked_type, Array type_args) + : checked_type(checked_type), type_args(type_args) {} + ResolvedTypeInfo() {} + + Type checked_type; + // Only allocated when the expression is a call. + + Array type_args = Array(NodePtr(nullptr)); +}; + // // The inference algorithm can roughly be devided into three stages: // - Populate the constraints by visiting the expression (TypeInferencer.GetType) @@ -87,7 +98,8 @@ class TypeInferencer : private ExprFunctor { Environment env_; // map from expression to checked type // type inferencer will populate it up - std::unordered_map type_map_; + std::unordered_map type_map_; + // The solver used by the inferencer. TypeSolver solver_; // relation function @@ -111,11 +123,12 @@ class TypeInferencer : private ExprFunctor { // will call visit to deduce it if it is not in the type_map_ Type GetType(const Expr &expr) { auto it = type_map_.find(expr); - if (it != type_map_.end()) { - return it->second; + if (it != type_map_.end() && it->second.checked_type.defined()) { + return it->second.checked_type; } Type ret = this->VisitExpr(expr); - type_map_[expr] = ret; + ResolvedTypeInfo& rti = type_map_[expr]; + rti.checked_type = ret; return ret; } @@ -176,7 +189,7 @@ class TypeInferencer : private ExprFunctor { } CHECK(!type_map_.count(op->var)); // NOTE: no scoping is necessary because var are unique in program - type_map_[op->var] = vtype; + type_map_[op->var].checked_type = vtype; return GetType(op->body); } @@ -224,6 +237,7 @@ class TypeInferencer : private ExprFunctor { subst_map.Set(ty_param, fresh); ty_args->push_back(fresh); } + Type ret_type = fn_ty->ret_type; // If the function type is incomplete, place a new IncompleteType @@ -234,6 +248,7 @@ class TypeInferencer : private ExprFunctor { if (!ret_type.defined()) { ret_type = IncompleteTypeNode::make(TypeVarNode::Kind::kType); } + Type inst_ty = FuncTypeNode::make(fn_ty->arg_types, ret_type, {}, fn_ty->type_constraints); @@ -241,49 +256,74 @@ class TypeInferencer : private ExprFunctor { return Downcast(inst_ty); } + void AddTypeArgs(const Expr& expr, Array type_args) { + auto type_info = type_map_.find(expr); + if (type_info == type_map_.end()) { + type_map_.insert({expr, ResolvedTypeInfo(Type(), type_args)}); + } else { + CHECK(!type_info->second.type_args.defined()); + type_info->second.type_args = type_args; + } + } + // Handle general call node. - Type GeneralCall(const CallNode* op, Array arg_types) { - Type ftype = GetType(op->op); + Type GeneralCall(const CallNode* call, Array arg_types) { + Type ftype = GetType(call->op); auto* fn_ty_node = ftype.as(); + CHECK(fn_ty_node != nullptr) << "only expressions with function types can be called, at " - << op->span; + << call->span; Array type_args; FuncType fn_ty = Instantiate(fn_ty_node, &type_args); + + AddTypeArgs(GetRef(call), type_args); + size_t type_arity = fn_ty->arg_types.size(); size_t number_of_args = arg_types.size(); if (type_arity != number_of_args) { if (type_arity < number_of_args) { - LOG(FATAL) << "the function is provided too many arguments " << op->span; + LOG(FATAL) << "the function is provided too many arguments " << call->span; } else { - LOG(FATAL) << "the function is provided too few arguments" << op->span; + LOG(FATAL) << "the function is provided too few arguments" << call->span; } } + for (size_t i = 0; i < fn_ty->arg_types.size(); i++) { - this->Unify(fn_ty->arg_types[i], arg_types[i], op->args[i]->span); + this->Unify(fn_ty->arg_types[i], arg_types[i], call->args[i]->span); } for (auto cs : fn_ty->type_constraints) { - solver_.AddConstraint(cs); + if (auto tr = cs.as()) { + solver_.AddConstraint( + TypeRelationNode::make(tr->func, tr->args, tr->num_inputs, call->attrs)); + } else { + solver_.AddConstraint(cs); + } } + return fn_ty->ret_type; } - Type VisitExpr_(const CallNode* op) final { - // Fast path: well-formed primitive op + Type VisitExpr_(const CallNode* call) final { Array arg_types; - for (Expr arg : op->args) { + for (Expr arg : call->args) { arg_types.push_back(GetType(arg)); } - if (const OpNode* opnode = op->op.as()) { + + if (const OpNode* opnode = call->op.as()) { Type rtype = PrimitiveCall(opnode->op_type.as(), arg_types, - op->attrs); - if (rtype.defined()) return rtype; + call->attrs); + if (rtype.defined()) { + AddTypeArgs(GetRef(call), arg_types); + return rtype; + } } - return GeneralCall(op, arg_types); + + return GeneralCall(call, arg_types); } Type VisitExpr_(const FunctionNode* f) final { @@ -312,7 +352,7 @@ class TypeInferencer : private ExprFunctor { class TypeInferencer::Resolver : public ExprMutator { public: - Resolver(const std::unordered_map& tmap, + Resolver(const std::unordered_map& tmap, TypeSolver* solver) : tmap_(tmap), solver_(solver) { } @@ -362,7 +402,7 @@ class TypeInferencer::Resolver : public ExprMutator { Expr AttachCheckedType(const T* op) { auto it = tmap_.find(GetRef(op)); CHECK(it != tmap_.end()); - Type checked_type = solver_->Resolve(it->second); + Type checked_type = solver_->Resolve(it->second.checked_type); CHECK(checked_type.as() == nullptr) << "Cannot resolve type of " << GetRef(op) << " at " << op->span; @@ -376,25 +416,37 @@ class TypeInferencer::Resolver : public ExprMutator { } new_e->checked_type_ = checked_type; } + + if (it->second.type_args.defined()) { + Call call = Downcast(new_e); + const CallNode* const_call_ref = call.operator->(); + CallNode* call_ref = const_cast(const_call_ref); + call_ref->type_args = it->second.type_args; + + for (size_t i = 0; i < call->type_args.size(); i++) { + call_ref->type_args.Set(i, solver_->Resolve(call->type_args[i])); + } + } + return new_e; } - Type VisitType(const Type& t) final { + Type VisitType(const Type &t) final { return solver_->Resolve(t); } private: - const std::unordered_map& tmap_; + const std::unordered_map& tmap_; TypeSolver* solver_; }; Expr TypeInferencer::Infer(Expr expr) { - // step 0: populate the constraints + // Step 0: Populate the constraints. GetType(expr); - // step 1: solve the constraints + // Step 1: Solve the constraints. solver_.Solve(); - // step 2: attach resolved types to checked_type field + // Step 2: Attach resolved types to checked_type field. return Resolver(type_map_, &solver_).VisitExpr(expr); } diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index 2d8f98974639..e1d749e75863 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -91,6 +91,21 @@ def test_free_expr(): yy = relay.ir_pass.infer_type(y) assert yy.checked_type == relay.scalar_type("float32") +def test_type_args(): + x = relay.var("x", shape=(10, 10)) + y = relay.var("y", shape=(1, 10)) + z = relay.add(x, y) + ty_z = relay.ir_pass.infer_type(z) + ty_args = ty_z.type_args + assert len(ty_args) == 2 + assert ty_args[0].dtype == "float32" + assert ty_args[1].dtype == "float32" + sh1 = ty_args[0].shape + sh2 = ty_args[1].shape + assert sh1[0].value == 10 + assert sh1[1].value == 10 + assert sh2[0].value == 1 + assert sh2[1].value == 10 if __name__ == "__main__": test_free_expr() @@ -100,3 +115,5 @@ def test_free_expr(): test_decl() test_recursion() test_tuple() + test_free_expr() + test_type_args()