Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RELAY][TypeSystem] Add support for populating type args #1962

Merged
merged 7 commits into from
Oct 24, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions include/tvm/relay/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,36 @@ inline ValueType OpMap<ValueType>::get(const Op& op,
return map_.get<ValueType>(op, def_value);
}

/*!
* \brief Check that an expression is a "primtive operator".
*
* Will return true if the expression is an operator which
* matches the form of primtive operators registered directly
* by the Relay codebase.
*
* That is the arguments are all type variables, and there is a single
* type relation applied to the input and output types.
*/
inline bool IsPrimitiveOp(const Expr& expr) {
jroesch marked this conversation as resolved.
Show resolved Hide resolved
const auto* op = expr.as<OpNode>();

if (!op) {
return false;
}

const auto& fn_ty = op->op_type;
if (fn_ty->type_constraints.size() != 1) return false;

const TypeRelationNode* rel = fn_ty->type_constraints[0].as<TypeRelationNode>();
if (rel == nullptr) return false;
// validate if the type parameter matches up
for (size_t i = 0; i < fn_ty->type_params.size(); ++i) {
if (!fn_ty->type_params[i].same_as(rel->args[i])) return false;
}

return true;
}

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_OP_H_
21 changes: 17 additions & 4 deletions src/relay/ir/text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -275,18 +275,31 @@ class TextPrinter :
}

TextValue VisitExpr_(const CallNode* op) final {
// TODO(tqchen, M.K.): support generic call
// possibly through meta-data
CHECK_EQ(op->type_args.size(), 0U)
<< "generic call not yet supported";
TextValue call_op = GetValue(op->op);
std::vector<TextValue> args;
for (Expr arg : op->args) {
args.emplace_back(GetValue(arg));
}
TextValue id = this->AllocTempVar();
this->PrintIndent();
stream_ << id << " = " << call_op << "(";

stream_ << id << " = " << call_op;

auto type_args = op->type_args;
jroesch marked this conversation as resolved.
Show resolved Hide resolved

if (!IsPrimitiveOp(op->op) && type_args.size() > 0U) {
stream_ << "<";
for (size_t i = 0; i < op->type_args.size(); ++i) {
this->PrintType(type_args[i], stream_);
if (i + 1 != type_args.size()) {
stream_ << ", ";
}
}
stream_ << ">";
}

stream_ << "(";
for (size_t i = 0; i < args.size(); ++i) {
stream_ << args[i];
if (i + 1 != args.size()) {
Expand Down
104 changes: 78 additions & 26 deletions src/relay/pass/type_infer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,17 @@ TVM_REGISTER_API("tvm.relay.type_relation.TupleGetItem")
.set_body_typed<bool(const Array<Type>&, int, const Attrs&, const TypeReporter&)>(
TupleGetItemRel);

struct ResolvedTypeInfo {
jroesch marked this conversation as resolved.
Show resolved Hide resolved
explicit ResolvedTypeInfo(Type checked_type, Array<Type> type_args)
: checked_type(checked_type), type_args(type_args) {}
ResolvedTypeInfo() {}

Type checked_type;
// Only allocated when the expression is a call.

Array<Type> type_args = Array<Type>(NodePtr<Node>(nullptr));
};

//
// The inference algorithm can roughly be devided into three stages:
// - Populate the constraints by visiting the expression (TypeInferencer.GetType)
Expand All @@ -87,7 +98,8 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
Environment env_;
// map from expression to checked type
// type inferencer will populate it up
std::unordered_map<Expr, Type, NodeHash, NodeEqual> type_map_;
std::unordered_map<Expr, ResolvedTypeInfo, NodeHash, NodeEqual> type_map_;

// The solver used by the inferencer.
TypeSolver solver_;
// relation function
Expand All @@ -111,11 +123,12 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
// will call visit to deduce it if it is not in the type_map_
Type GetType(const Expr &expr) {
auto it = type_map_.find(expr);
if (it != type_map_.end()) {
return it->second;
if (it != type_map_.end() && it->second.checked_type.defined()) {
return it->second.checked_type;
}
Type ret = this->VisitExpr(expr);
type_map_[expr] = ret;
ResolvedTypeInfo& rti = type_map_[expr];
rti.checked_type = ret;
return ret;
}

Expand Down Expand Up @@ -176,7 +189,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
}
CHECK(!type_map_.count(op->var));
// NOTE: no scoping is necessary because var are unique in program
type_map_[op->var] = vtype;
type_map_[op->var].checked_type = vtype;
return GetType(op->body);
}

Expand Down Expand Up @@ -224,6 +237,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
subst_map.Set(ty_param, fresh);
ty_args->push_back(fresh);
}

Type ret_type = fn_ty->ret_type;

// If the function type is incomplete, place a new IncompleteType
Expand All @@ -234,56 +248,82 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
if (!ret_type.defined()) {
ret_type = IncompleteTypeNode::make(TypeVarNode::Kind::kType);
}

Type inst_ty = FuncTypeNode::make(fn_ty->arg_types,
ret_type, {},
fn_ty->type_constraints);
inst_ty = TypeSubst(inst_ty, subst_map);
return Downcast<FuncType>(inst_ty);
}

void AddTypeArgs(const Expr& expr, Array<Type> type_args) {
auto type_info = type_map_.find(expr);
if (type_info == type_map_.end()) {
type_map_.insert({expr, ResolvedTypeInfo(Type(), type_args)});
} else {
CHECK(!type_info->second.type_args.defined());
type_info->second.type_args = type_args;
}
}

// Handle general call node.
Type GeneralCall(const CallNode* op, Array<Type> arg_types) {
Type ftype = GetType(op->op);
Type GeneralCall(const CallNode* call, Array<Type> arg_types) {
jroesch marked this conversation as resolved.
Show resolved Hide resolved
Type ftype = GetType(call->op);
auto* fn_ty_node = ftype.as<FuncTypeNode>();

CHECK(fn_ty_node != nullptr)
<< "only expressions with function types can be called, at "
<< op->span;
<< call->span;

Array<Type> type_args;
FuncType fn_ty = Instantiate(fn_ty_node, &type_args);

AddTypeArgs(GetRef<Call>(call), type_args);

size_t type_arity = fn_ty->arg_types.size();
size_t number_of_args = arg_types.size();

if (type_arity != number_of_args) {
if (type_arity < number_of_args) {
LOG(FATAL) << "the function is provided too many arguments " << op->span;
LOG(FATAL) << "the function is provided too many arguments " << call->span;
} else {
LOG(FATAL) << "the function is provided too few arguments" << op->span;
LOG(FATAL) << "the function is provided too few arguments" << call->span;
}
}

for (size_t i = 0; i < fn_ty->arg_types.size(); i++) {
this->Unify(fn_ty->arg_types[i], arg_types[i], op->args[i]->span);
this->Unify(fn_ty->arg_types[i], arg_types[i], call->args[i]->span);
}

for (auto cs : fn_ty->type_constraints) {
solver_.AddConstraint(cs);
if (auto tr = cs.as<TypeRelationNode>()) {
solver_.AddConstraint(
TypeRelationNode::make(tr->func, tr->args, tr->num_inputs, call->attrs));
} else {
solver_.AddConstraint(cs);
}
}

return fn_ty->ret_type;
}

Type VisitExpr_(const CallNode* op) final {
// Fast path: well-formed primitive op
Type VisitExpr_(const CallNode* call) final {
Array<Type> arg_types;
for (Expr arg : op->args) {
for (Expr arg : call->args) {
arg_types.push_back(GetType(arg));
}
if (const OpNode* opnode = op->op.as<OpNode>()) {

if (const OpNode* opnode = call->op.as<OpNode>()) {
Type rtype = PrimitiveCall(opnode->op_type.as<FuncTypeNode>(),
arg_types,
op->attrs);
if (rtype.defined()) return rtype;
call->attrs);
if (rtype.defined()) {
AddTypeArgs(GetRef<Call>(call), arg_types);
return rtype;
}
}
return GeneralCall(op, arg_types);

return GeneralCall(call, arg_types);
}

Type VisitExpr_(const FunctionNode* f) final {
Expand Down Expand Up @@ -312,7 +352,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {

class TypeInferencer::Resolver : public ExprMutator {
public:
Resolver(const std::unordered_map<Expr, Type, NodeHash, NodeEqual>& tmap,
Resolver(const std::unordered_map<Expr, ResolvedTypeInfo, NodeHash, NodeEqual>& tmap,
TypeSolver* solver)
: tmap_(tmap), solver_(solver) {
}
Expand Down Expand Up @@ -362,7 +402,7 @@ class TypeInferencer::Resolver : public ExprMutator {
Expr AttachCheckedType(const T* op) {
auto it = tmap_.find(GetRef<Expr>(op));
CHECK(it != tmap_.end());
Type checked_type = solver_->Resolve(it->second);
Type checked_type = solver_->Resolve(it->second.checked_type);
CHECK(checked_type.as<IncompleteTypeNode>() == nullptr)
<< "Cannot resolve type of " << GetRef<Expr>(op)
<< " at " << op->span;
Expand All @@ -376,25 +416,37 @@ class TypeInferencer::Resolver : public ExprMutator {
}
new_e->checked_type_ = checked_type;
}

if (it->second.type_args.defined()) {
Call call = Downcast<Call>(new_e);
const CallNode* const_call_ref = call.operator->();
CallNode* call_ref = const_cast<CallNode*>(const_call_ref);
call_ref->type_args = it->second.type_args;

for (size_t i = 0; i < call->type_args.size(); i++) {
call_ref->type_args.Set(i, solver_->Resolve(call->type_args[i]));
}
}

return new_e;
}

Type VisitType(const Type& t) final {
Type VisitType(const Type &t) final {
return solver_->Resolve(t);
}

private:
const std::unordered_map<Expr, Type, NodeHash, NodeEqual>& tmap_;
const std::unordered_map<Expr, ResolvedTypeInfo, NodeHash, NodeEqual>& tmap_;
TypeSolver* solver_;
};


Expr TypeInferencer::Infer(Expr expr) {
// step 0: populate the constraints
// Step 0: Populate the constraints.
GetType(expr);
// step 1: solve the constraints
// Step 1: Solve the constraints.
solver_.Solve();
// step 2: attach resolved types to checked_type field
// Step 2: Attach resolved types to checked_type field.
return Resolver(type_map_, &solver_).VisitExpr(expr);
}

Expand Down
17 changes: 17 additions & 0 deletions tests/python/relay/test_type_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,21 @@ def test_free_expr():
yy = relay.ir_pass.infer_type(y)
assert yy.checked_type == relay.scalar_type("float32")

def test_type_args():
x = relay.var("x", shape=(10, 10))
y = relay.var("y", shape=(1, 10))
z = relay.add(x, y)
ty_z = relay.ir_pass.infer_type(z)
ty_args = ty_z.type_args
assert len(ty_args) == 2
assert ty_args[0].dtype == "float32"
assert ty_args[1].dtype == "float32"
sh1 = ty_args[0].shape
sh2 = ty_args[1].shape
assert sh1[0].value == 10
assert sh1[1].value == 10
assert sh2[0].value == 1
assert sh2[1].value == 10

if __name__ == "__main__":
test_free_expr()
Expand All @@ -100,3 +115,5 @@ def test_free_expr():
test_decl()
test_recursion()
test_tuple()
test_free_expr()
test_type_args()