Skip to content

Commit

Permalink
[Relay] Fix format (#1957)
Browse files Browse the repository at this point in the history
* save

* fix format
  • Loading branch information
MarisaKirisame authored and yzhliu committed Oct 23, 2018
1 parent 390acc5 commit 09efdc9
Show file tree
Hide file tree
Showing 9 changed files with 61 additions and 56 deletions.
6 changes: 3 additions & 3 deletions src/relay/ir/base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ SourceName SourceName::Get(const std::string& name) {
}

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<SourceNameNode>([](const SourceNameNode *node, tvm::IRPrinter *p) {
.set_dispatch<SourceNameNode>([](const SourceNameNode* node, tvm::IRPrinter* p) {
p->stream << "SourceName(" << node->name << ", " << node << ")";
});

Expand All @@ -54,12 +54,12 @@ Span SpanNode::make(SourceName source, int lineno, int col_offset) {
TVM_REGISTER_NODE_TYPE(SpanNode);

TVM_REGISTER_API("relay._make.Span")
.set_body([](TVMArgs args, TVMRetValue *ret) {
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = SpanNode::make(args[0], args[1], args[2]);
});

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<SpanNode>([](const SpanNode *node, tvm::IRPrinter *p) {
.set_dispatch<SpanNode>([](const SpanNode* node, tvm::IRPrinter* p) {
p->stream << "SpanNode(" << node->source << ", " << node->lineno << ", "
<< node->col_offset << ")";
});
Expand Down
4 changes: 2 additions & 2 deletions src/relay/ir/environment.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,12 @@ Function EnvironmentNode::Lookup(const GlobalVar& var) {
return (*it).second;
}

Function EnvironmentNode::Lookup(const std::string &name) {
Function EnvironmentNode::Lookup(const std::string& name) {
GlobalVar id = this->GetGlobalVar(name);
return this->Lookup(id);
}

void EnvironmentNode::Update(const Environment &env) {
void EnvironmentNode::Update(const Environment& env) {
for (auto pair : env->functions) {
this->Update(pair.first, pair.second);
}
Expand Down
34 changes: 17 additions & 17 deletions src/relay/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@ Constant ConstantNode::make(runtime::NDArray data) {
TVM_REGISTER_NODE_TYPE(ConstantNode);

TVM_REGISTER_API("relay._make.Constant")
.set_body([](TVMArgs args, TVMRetValue *ret) {
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = ConstantNode::make(args[0]);
});

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<ConstantNode>([](const ConstantNode *node, tvm::IRPrinter *p) {
.set_dispatch<ConstantNode>([](const ConstantNode* node, tvm::IRPrinter* p) {
p->stream << "Constant(TODO)";
});

Expand All @@ -49,12 +49,12 @@ Tuple TupleNode::make(tvm::Array<relay::Expr> fields) {
TVM_REGISTER_NODE_TYPE(TupleNode);

TVM_REGISTER_API("relay._make.Tuple")
.set_body([](TVMArgs args, TVMRetValue *ret) {
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = TupleNode::make(args[0]);
});

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TupleNode>([](const TupleNode *node, tvm::IRPrinter *p) {
.set_dispatch<TupleNode>([](const TupleNode* node, tvm::IRPrinter* p) {
p->stream << "Tuple(" << node->fields << ")";
});

Expand All @@ -68,12 +68,12 @@ Var VarNode::make(std::string name_hint, Type type_annotation) {
TVM_REGISTER_NODE_TYPE(VarNode);

TVM_REGISTER_API("relay._make.Var")
.set_body([](TVMArgs args, TVMRetValue *ret) {
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = VarNode::make(args[0], args[1]);
});

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<VarNode>([](const VarNode *node, tvm::IRPrinter *p) {
.set_dispatch<VarNode>([](const VarNode* node, tvm::IRPrinter* p) {
p->stream << "Var(" << node->name_hint;
if (node->type_annotation.defined()) {
p->stream << ", ty=";
Expand All @@ -91,12 +91,12 @@ GlobalVar GlobalVarNode::make(std::string name_hint) {
TVM_REGISTER_NODE_TYPE(GlobalVarNode);

TVM_REGISTER_API("relay._make.GlobalVar")
.set_body([](TVMArgs args, TVMRetValue *ret) {
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = GlobalVarNode::make(args[0]);
});

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<GlobalVarNode>([](const GlobalVarNode *node, tvm::IRPrinter *p) {
.set_dispatch<GlobalVarNode>([](const GlobalVarNode* node, tvm::IRPrinter* p) {
p->stream << "GlobalVar(" << node->name_hint << ")";
});

Expand Down Expand Up @@ -124,13 +124,13 @@ FuncType FunctionNode::func_type_annotation() const {
TVM_REGISTER_NODE_TYPE(FunctionNode);

TVM_REGISTER_API("relay._make.Function")
.set_body([](TVMArgs args, TVMRetValue *ret) {
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = FunctionNode::make(args[0], args[1], args[2], args[3]);
});

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<FunctionNode>([](const FunctionNode *node,
tvm::IRPrinter *p) {
.set_dispatch<FunctionNode>([](const FunctionNode* node,
tvm::IRPrinter* p) {
p->stream << "FunctionNode(" << node->params << ", " << node->ret_type
<< ", " << node->body << ", " << node->type_params << ")";
});
Expand All @@ -148,12 +148,12 @@ Call CallNode::make(Expr op, Array<Expr> args, Attrs attrs,
TVM_REGISTER_NODE_TYPE(CallNode);

TVM_REGISTER_API("relay._make.Call")
.set_body([](TVMArgs args, TVMRetValue *ret) {
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = CallNode::make(args[0], args[1], args[2], args[3]);
});

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<CallNode>([](const CallNode *node, tvm::IRPrinter *p) {
.set_dispatch<CallNode>([](const CallNode* node, tvm::IRPrinter* p) {
p->stream << "CallNode(" << node->op << ", " << node->args << ", "
<< node->attrs << ", " << node->type_args << ")";
});
Expand All @@ -169,12 +169,12 @@ Let LetNode::make(Var var, Expr value, Expr body) {
TVM_REGISTER_NODE_TYPE(LetNode);

TVM_REGISTER_API("relay._make.Let")
.set_body([](TVMArgs args, TVMRetValue *ret) {
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = LetNode::make(args[0], args[1], args[2]);
});

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<LetNode>([](const LetNode *node, tvm::IRPrinter *p) {
.set_dispatch<LetNode>([](const LetNode* node, tvm::IRPrinter* p) {
p->stream << "LetNode(" << node->var << ", " << node->value
<< ", " << node->body << ")";
});
Expand All @@ -189,12 +189,12 @@ If IfNode::make(Expr cond, Expr true_branch, Expr false_branch) {

TVM_REGISTER_NODE_TYPE(IfNode);

TVM_REGISTER_API("relay._make.If").set_body([](TVMArgs args, TVMRetValue *ret) {
TVM_REGISTER_API("relay._make.If").set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = IfNode::make(args[0], args[1], args[2]);
});

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<IfNode>([](const IfNode *node, tvm::IRPrinter *p) {
.set_dispatch<IfNode>([](const IfNode* node, tvm::IRPrinter* p) {
p->stream << "IfNode(" << node->cond << ", " << node->true_branch
<< ", " << node->false_branch << ")";
});
Expand Down
28 changes: 14 additions & 14 deletions src/relay/ir/type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@ TensorType TensorTypeNode::Scalar(DataType dtype) {
TVM_REGISTER_NODE_TYPE(TensorTypeNode);

TVM_REGISTER_API("relay._make.TensorType")
.set_body([](TVMArgs args, TVMRetValue *ret) {
.set_body([](TVMArgs args, TVMRetValue* ret) {
Array<IndexExpr> shape = args[0];
*ret = TensorTypeNode::make(shape, args[1]);
});

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TensorTypeNode>([](const TensorTypeNode *node,
tvm::IRPrinter *p) {
.set_dispatch<TensorTypeNode>([](const TensorTypeNode* node,
tvm::IRPrinter* p) {
p->stream << "TensorType(" << node->shape << ", " << node->dtype << ")";
});

Expand All @@ -46,15 +46,15 @@ TypeVar TypeVarNode::make(std::string name, TypeVarNode::Kind kind) {
TVM_REGISTER_NODE_TYPE(TypeVarNode);

TVM_REGISTER_API("relay._make.TypeVar")
.set_body([](TVMArgs args, TVMRetValue *ret) {
.set_body([](TVMArgs args, TVMRetValue* ret) {
int kind = args[1];
*ret =
TypeVarNode::make(args[0], static_cast<TypeVarNode::Kind>(kind));
});

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TypeVarNode>([](const TypeVarNode *node,
tvm::IRPrinter *p) {
.set_dispatch<TypeVarNode>([](const TypeVarNode* node,
tvm::IRPrinter* p) {
p->stream << "TypeVarNode(" << node->var->name_hint << ", "
<< node->kind << ")";
});
Expand Down Expand Up @@ -95,13 +95,13 @@ FuncType FuncTypeNode::make(tvm::Array<Type> arg_types,
TVM_REGISTER_NODE_TYPE(FuncTypeNode);

TVM_REGISTER_API("relay._make.FuncType")
.set_body([](TVMArgs args, TVMRetValue *ret) {
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = FuncTypeNode::make(args[0], args[1], args[2], args[3]);
});

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<FuncTypeNode>([](const FuncTypeNode *node,
tvm::IRPrinter *p) {
.set_dispatch<FuncTypeNode>([](const FuncTypeNode* node,
tvm::IRPrinter* p) {
p->stream << "FuncTypeNode(" << node->type_params << ", "
<< node->arg_types << ", " << node->ret_type << ", "
<< node->type_constraints << ")";
Expand All @@ -122,12 +122,12 @@ TypeRelation TypeRelationNode::make(TypeRelationFn func,
TVM_REGISTER_NODE_TYPE(TypeRelationNode);

TVM_REGISTER_API("relay._make.TypeRelation")
.set_body([](TVMArgs args, TVMRetValue *ret) {
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = TypeRelationNode::make(args[0], args[1], args[2], args[3]);
});

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TypeRelationNode>([](const TypeRelationNode *node, tvm::IRPrinter *p) {
.set_dispatch<TypeRelationNode>([](const TypeRelationNode* node, tvm::IRPrinter* p) {
p->stream << "TypeRelationNode("
<< node->func->name
<< ", " << node->args << ")";
Expand All @@ -142,13 +142,13 @@ TupleType TupleTypeNode::make(Array<Type> fields) {
TVM_REGISTER_NODE_TYPE(TupleTypeNode);

TVM_REGISTER_API("relay._make.TupleType")
.set_body([](TVMArgs args, TVMRetValue *ret) {
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = TupleTypeNode::make(args[0]);
});

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TupleTypeNode>([](const TupleTypeNode *node,
tvm::IRPrinter *p) {
.set_dispatch<TupleTypeNode>([](const TupleTypeNode* node,
tvm::IRPrinter* p) {
p->stream << "TupleTypeNode(" << node->fields << ")";
});

Expand Down
11 changes: 8 additions & 3 deletions src/relay/pass/alpha_eq.cc
Original file line number Diff line number Diff line change
Expand Up @@ -193,11 +193,13 @@ struct TypeAlphaEq : TypeVisitor<const Type&> {
};

bool AlphaEqual(const Type& t1, const Type& t2) {
if (t1.defined() != t2.defined())
if (t1.defined() != t2.defined()) {
return false;
}

if (!t1.defined())
if (!t1.defined()) {
return true;
}

TypeAlphaEq aeq;
aeq.VisitType(t1, t2);
Expand Down Expand Up @@ -273,7 +275,10 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
for (size_t i = 0; i < func1->params.size(); ++i) {
MergeVarDecl(func1->params[i], func2->params[i]);
}
if (!equal) return;

if (!equal) {
return;
}

for (size_t i = 0U; i < func1->type_params.size(); i++) {
equal = equal && AlphaEqual(func1->type_params[i], func2->type_params[i]);
Expand Down
6 changes: 3 additions & 3 deletions src/relay/pass/kind_check.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ struct KindChecker : TypeVisitor<> {

// checks if t is an incomplete node of kind k or a type param of kind k
bool MatchKind(const Type& t, Kind k) {
if (const IncompleteTypeNode *tv = t.as<IncompleteTypeNode>()) {
if (const IncompleteTypeNode* tv = t.as<IncompleteTypeNode>()) {
return tv->kind == k;
}

if (const TypeVarNode *tp = t.as<TypeVarNode>()) {
if (const TypeVarNode* tp = t.as<TypeVarNode>()) {
return tp->kind == k;
}

Expand Down Expand Up @@ -93,7 +93,7 @@ struct KindChecker : TypeVisitor<> {
}
}

bool Check(const Type &t) {
bool Check(const Type& t) {
this->VisitType(t);
return valid;
}
Expand Down
2 changes: 1 addition & 1 deletion src/relay/pass/type_infer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ class TypeInferencer::Resolver : public ExprMutator {
return new_e;
}

Type VisitType(const Type &t) final {
Type VisitType(const Type& t) final {
return solver_->Resolve(t);
}

Expand Down
14 changes: 7 additions & 7 deletions src/relay/pass/util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ namespace relay {

class FreeVar;
class FreeTypeVar : private TypeVisitor<> {
std::unordered_set<TypeVar, NodeHash, NodeEqual> * free_vars;
std::unordered_set<TypeVar, NodeHash, NodeEqual> * bound_vars;
FreeTypeVar(std::unordered_set<TypeVar, NodeHash, NodeEqual> * free_vars,
std::unordered_set<TypeVar, NodeHash, NodeEqual> * bound_vars) :
std::unordered_set<TypeVar, NodeHash, NodeEqual>* free_vars;
std::unordered_set<TypeVar, NodeHash, NodeEqual>* bound_vars;
FreeTypeVar(std::unordered_set<TypeVar, NodeHash, NodeEqual>* free_vars,
std::unordered_set<TypeVar, NodeHash, NodeEqual>* bound_vars) :
free_vars(free_vars), bound_vars(bound_vars) { }

void VisitType_(const TypeVarNode* tp) final {
Expand Down Expand Up @@ -45,7 +45,7 @@ class FreeTypeVar : private TypeVisitor<> {
};

class FreeVar : public ExprVisitor {
void VisitExpr_(const VarNode *v) final {
void VisitExpr_(const VarNode* v) final {
auto var = GetRef<Var>(v);
if (bound_vars.count(var) == 0) {
free_vars.insert(var);
Expand All @@ -55,7 +55,7 @@ class FreeVar : public ExprVisitor {
}
}

void VisitExpr_(const FunctionNode *f) final {
void VisitExpr_(const FunctionNode* f) final {
for (const auto& tp : f->type_params) {
bound_types.insert(tp);
}
Expand All @@ -66,7 +66,7 @@ class FreeVar : public ExprVisitor {
VisitType(f->ret_type);
}

void VisitExpr_(const LetNode *l) final {
void VisitExpr_(const LetNode* l) final {
bound_vars.insert(l->var);
VisitExpr(l->value);
VisitExpr(l->body);
Expand Down
12 changes: 6 additions & 6 deletions src/relay/pass/well_formed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,36 +18,36 @@ class WellFormedChecker : private ExprVisitor {

std::unordered_set<Var, NodeHash, NodeEqual> s;

void Check(const Var & v) {
void Check(const Var& v) {
if (s.count(v) != 0) {
well_formed = false;
}
s.insert(v);
}

void VisitExpr_(const LetNode * l) final {
void VisitExpr_(const LetNode* l) final {
// we do letrec only for FunctionNode,
// but shadowing let in let binding is likely programming error, and we should forbidden it.
Check(l->var);
CheckWellFormed(l->value);
CheckWellFormed(l->body);
}

void VisitExpr_(const FunctionNode * f) final {
for (const Var & param : f->params) {
void VisitExpr_(const FunctionNode* f) final {
for (const Var& param : f->params) {
Check(param);
}
CheckWellFormed(f->body);
}

public:
bool CheckWellFormed(const Expr & e) {
bool CheckWellFormed(const Expr& e) {
this->VisitExpr(e);
return well_formed;
}
};

bool WellFormed(const Expr & e) {
bool WellFormed(const Expr& e) {
return WellFormedChecker().CheckWellFormed(e);
}

Expand Down

0 comments on commit 09efdc9

Please sign in to comment.