Skip to content

Commit

Permalink
[RELAY][TypeSystem] Add support for populating type args (#1962)
Browse files Browse the repository at this point in the history
  • Loading branch information
jroesch authored and tqchen committed Oct 24, 2018
1 parent 3a1bb8c commit 3bfa5fc
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 30 deletions.
30 changes: 30 additions & 0 deletions include/tvm/relay/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,36 @@ inline ValueType OpMap<ValueType>::get(const Op& op,
return map_.get<ValueType>(op, def_value);
}

/*!
* \brief Check that an expression is a "primtive operator".
*
* Will return true if the expression is an operator which
* matches the form of primtive operators registered directly
* by the Relay codebase.
*
* That is the arguments are all type variables, and there is a single
* type relation applied to the input and output types.
*/
inline bool IsPrimitiveOp(const Expr& expr) {
const auto* op = expr.as<OpNode>();

if (!op) {
return false;
}

const auto& fn_ty = op->op_type;
if (fn_ty->type_constraints.size() != 1) return false;

const TypeRelationNode* rel = fn_ty->type_constraints[0].as<TypeRelationNode>();
if (rel == nullptr) return false;
// validate if the type parameter matches up
for (size_t i = 0; i < fn_ty->type_params.size(); ++i) {
if (!fn_ty->type_params[i].same_as(rel->args[i])) return false;
}

return true;
}

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_OP_H_
21 changes: 17 additions & 4 deletions src/relay/ir/text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -278,18 +278,31 @@ class TextPrinter :
}

TextValue VisitExpr_(const CallNode* op) final {
// TODO(tqchen, M.K.): support generic call
// possibly through meta-data
CHECK_EQ(op->type_args.size(), 0U)
<< "generic call not yet supported";
TextValue call_op = GetValue(op->op);
std::vector<TextValue> args;
for (Expr arg : op->args) {
args.emplace_back(GetValue(arg));
}
TextValue id = this->AllocTempVar();
this->PrintIndent();
stream_ << id << " = " << call_op << "(";

stream_ << id << " = " << call_op;

auto type_args = op->type_args;

if (!IsPrimitiveOp(op->op) && type_args.size() > 0U) {
stream_ << "<";
for (size_t i = 0; i < op->type_args.size(); ++i) {
this->PrintType(type_args[i], stream_);
if (i + 1 != type_args.size()) {
stream_ << ", ";
}
}
stream_ << ">";
}

stream_ << "(";
for (size_t i = 0; i < args.size(); ++i) {
stream_ << args[i];
if (i + 1 != args.size()) {
Expand Down
104 changes: 78 additions & 26 deletions src/relay/pass/type_infer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,17 @@ TVM_REGISTER_API("tvm.relay.type_relation.TupleGetItem")
.set_body_typed<bool(const Array<Type>&, int, const Attrs&, const TypeReporter&)>(
TupleGetItemRel);

struct ResolvedTypeInfo {
explicit ResolvedTypeInfo(Type checked_type, Array<Type> type_args)
: checked_type(checked_type), type_args(type_args) {}
ResolvedTypeInfo() {}

Type checked_type;
// Only allocated when the expression is a call.

Array<Type> type_args = Array<Type>(NodePtr<Node>(nullptr));
};

//
// The inference algorithm can roughly be devided into three stages:
// - Populate the constraints by visiting the expression (TypeInferencer.GetType)
Expand All @@ -87,7 +98,8 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
Environment env_;
// map from expression to checked type
// type inferencer will populate it up
std::unordered_map<Expr, Type, NodeHash, NodeEqual> type_map_;
std::unordered_map<Expr, ResolvedTypeInfo, NodeHash, NodeEqual> type_map_;

// The solver used by the inferencer.
TypeSolver solver_;
// relation function
Expand All @@ -111,11 +123,12 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
// will call visit to deduce it if it is not in the type_map_
Type GetType(const Expr &expr) {
auto it = type_map_.find(expr);
if (it != type_map_.end()) {
return it->second;
if (it != type_map_.end() && it->second.checked_type.defined()) {
return it->second.checked_type;
}
Type ret = this->VisitExpr(expr);
type_map_[expr] = ret;
ResolvedTypeInfo& rti = type_map_[expr];
rti.checked_type = ret;
return ret;
}

Expand Down Expand Up @@ -176,7 +189,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
}
CHECK(!type_map_.count(op->var));
// NOTE: no scoping is necessary because var are unique in program
type_map_[op->var] = vtype;
type_map_[op->var].checked_type = vtype;
return GetType(op->body);
}

Expand Down Expand Up @@ -224,6 +237,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
subst_map.Set(ty_param, fresh);
ty_args->push_back(fresh);
}

Type ret_type = fn_ty->ret_type;

// If the function type is incomplete, place a new IncompleteType
Expand All @@ -234,56 +248,82 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
if (!ret_type.defined()) {
ret_type = IncompleteTypeNode::make(TypeVarNode::Kind::kType);
}

Type inst_ty = FuncTypeNode::make(fn_ty->arg_types,
ret_type, {},
fn_ty->type_constraints);
inst_ty = TypeSubst(inst_ty, subst_map);
return Downcast<FuncType>(inst_ty);
}

void AddTypeArgs(const Expr& expr, Array<Type> type_args) {
auto type_info = type_map_.find(expr);
if (type_info == type_map_.end()) {
type_map_.insert({expr, ResolvedTypeInfo(Type(), type_args)});
} else {
CHECK(!type_info->second.type_args.defined());
type_info->second.type_args = type_args;
}
}

// Handle general call node.
Type GeneralCall(const CallNode* op, Array<Type> arg_types) {
Type ftype = GetType(op->op);
Type GeneralCall(const CallNode* call, Array<Type> arg_types) {
Type ftype = GetType(call->op);
auto* fn_ty_node = ftype.as<FuncTypeNode>();

CHECK(fn_ty_node != nullptr)
<< "only expressions with function types can be called, at "
<< op->span;
<< call->span;

Array<Type> type_args;
FuncType fn_ty = Instantiate(fn_ty_node, &type_args);

AddTypeArgs(GetRef<Call>(call), type_args);

size_t type_arity = fn_ty->arg_types.size();
size_t number_of_args = arg_types.size();

if (type_arity != number_of_args) {
if (type_arity < number_of_args) {
LOG(FATAL) << "the function is provided too many arguments " << op->span;
LOG(FATAL) << "the function is provided too many arguments " << call->span;
} else {
LOG(FATAL) << "the function is provided too few arguments" << op->span;
LOG(FATAL) << "the function is provided too few arguments" << call->span;
}
}

for (size_t i = 0; i < fn_ty->arg_types.size(); i++) {
this->Unify(fn_ty->arg_types[i], arg_types[i], op->args[i]->span);
this->Unify(fn_ty->arg_types[i], arg_types[i], call->args[i]->span);
}

for (auto cs : fn_ty->type_constraints) {
solver_.AddConstraint(cs);
if (auto tr = cs.as<TypeRelationNode>()) {
solver_.AddConstraint(
TypeRelationNode::make(tr->func, tr->args, tr->num_inputs, call->attrs));
} else {
solver_.AddConstraint(cs);
}
}

return fn_ty->ret_type;
}

Type VisitExpr_(const CallNode* op) final {
// Fast path: well-formed primitive op
Type VisitExpr_(const CallNode* call) final {
Array<Type> arg_types;
for (Expr arg : op->args) {
for (Expr arg : call->args) {
arg_types.push_back(GetType(arg));
}
if (const OpNode* opnode = op->op.as<OpNode>()) {

if (const OpNode* opnode = call->op.as<OpNode>()) {
Type rtype = PrimitiveCall(opnode->op_type.as<FuncTypeNode>(),
arg_types,
op->attrs);
if (rtype.defined()) return rtype;
call->attrs);
if (rtype.defined()) {
AddTypeArgs(GetRef<Call>(call), arg_types);
return rtype;
}
}
return GeneralCall(op, arg_types);

return GeneralCall(call, arg_types);
}

Type VisitExpr_(const FunctionNode* f) final {
Expand Down Expand Up @@ -312,7 +352,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {

class TypeInferencer::Resolver : public ExprMutator {
public:
Resolver(const std::unordered_map<Expr, Type, NodeHash, NodeEqual>& tmap,
Resolver(const std::unordered_map<Expr, ResolvedTypeInfo, NodeHash, NodeEqual>& tmap,
TypeSolver* solver)
: tmap_(tmap), solver_(solver) {
}
Expand Down Expand Up @@ -362,7 +402,7 @@ class TypeInferencer::Resolver : public ExprMutator {
Expr AttachCheckedType(const T* op) {
auto it = tmap_.find(GetRef<Expr>(op));
CHECK(it != tmap_.end());
Type checked_type = solver_->Resolve(it->second);
Type checked_type = solver_->Resolve(it->second.checked_type);
CHECK(checked_type.as<IncompleteTypeNode>() == nullptr)
<< "Cannot resolve type of " << GetRef<Expr>(op)
<< " at " << op->span;
Expand All @@ -376,25 +416,37 @@ class TypeInferencer::Resolver : public ExprMutator {
}
new_e->checked_type_ = checked_type;
}

if (it->second.type_args.defined()) {
Call call = Downcast<Call>(new_e);
const CallNode* const_call_ref = call.operator->();
CallNode* call_ref = const_cast<CallNode*>(const_call_ref);
call_ref->type_args = it->second.type_args;

for (size_t i = 0; i < call->type_args.size(); i++) {
call_ref->type_args.Set(i, solver_->Resolve(call->type_args[i]));
}
}

return new_e;
}

Type VisitType(const Type& t) final {
Type VisitType(const Type &t) final {
return solver_->Resolve(t);
}

private:
const std::unordered_map<Expr, Type, NodeHash, NodeEqual>& tmap_;
const std::unordered_map<Expr, ResolvedTypeInfo, NodeHash, NodeEqual>& tmap_;
TypeSolver* solver_;
};


Expr TypeInferencer::Infer(Expr expr) {
// step 0: populate the constraints
// Step 0: Populate the constraints.
GetType(expr);
// step 1: solve the constraints
// Step 1: Solve the constraints.
solver_.Solve();
// step 2: attach resolved types to checked_type field
// Step 2: Attach resolved types to checked_type field.
return Resolver(type_map_, &solver_).VisitExpr(expr);
}

Expand Down
17 changes: 17 additions & 0 deletions tests/python/relay/test_type_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,21 @@ def test_free_expr():
yy = relay.ir_pass.infer_type(y)
assert yy.checked_type == relay.scalar_type("float32")

def test_type_args():
x = relay.var("x", shape=(10, 10))
y = relay.var("y", shape=(1, 10))
z = relay.add(x, y)
ty_z = relay.ir_pass.infer_type(z)
ty_args = ty_z.type_args
assert len(ty_args) == 2
assert ty_args[0].dtype == "float32"
assert ty_args[1].dtype == "float32"
sh1 = ty_args[0].shape
sh2 = ty_args[1].shape
assert sh1[0].value == 10
assert sh1[1].value == 10
assert sh2[0].value == 1
assert sh2[1].value == 10

if __name__ == "__main__":
test_free_expr()
Expand All @@ -100,3 +115,5 @@ def test_free_expr():
test_decl()
test_recursion()
test_tuple()
test_free_expr()
test_type_args()

0 comments on commit 3bfa5fc

Please sign in to comment.