Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
lint

update

address comment
  • Loading branch information
MarisaKirisame committed Jul 3, 2019
1 parent 882ae12 commit 761b764
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 14 deletions.
33 changes: 25 additions & 8 deletions src/relay/ir/expr_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand All @@ -18,13 +18,14 @@
*/

/*!
* Copyright (c) 2018 by Contributors
* Copyright (c) 2019 by Contributors
* \file src/tvm/relay/expr_mutator.cc
* \brief A wrapper around ExprFunctor which functionally updates the AST.
*
* ExprMutator uses memoization and self return in order to amortize
* the cost of using functional updates.
*/
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include "type_functor.h"

Expand Down Expand Up @@ -400,11 +401,27 @@ Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
new_params.size() == func->params.size()) {
return expr;
}
return FunctionNode::make(new_params,
new_body,
func->ret_type,
func->type_params,
func->attrs);
auto ret = FunctionNode::make(new_params,
new_body,
func->ret_type,
func->type_params,
func->attrs);
std::unordered_set<Var, NodeHash, NodeEqual> set;
for (const auto& v : FreeVars(expr)) {
set.insert(v);
}
for (const auto& v : FreeVars(ret)) {
if (set.count(v) == 0) {
new_params.push_back(v);
}
}
ret = FunctionNode::make(new_params,
new_body,
func->ret_type,
func->type_params,
func->attrs);
CHECK_EQ(FreeVars(expr).size(), FreeVars(ret).size());
return ret;
} else {
return ExprBinder(args_map).Mutate(expr);
}
Expand Down
36 changes: 35 additions & 1 deletion src/relay/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,46 @@ GlobalTypeVar ModuleNode::GetGlobalTypeVar(const std::string& name) const {
return (*it).second;
}

template<typename T>
tvm::Array<T> concat(const tvm::Array<T>& l, const tvm::Array<T>& r) {
tvm::Array<T> ret(l);
for (const T& t : r) {
ret.push_back(t);
}
return ret;
}

void ModuleNode::Add(const GlobalVar& var,
const Function& f,
bool update) {
Function func = Downcast<Function>(DeDup(f));
// Type check the item before we add it to the module.
auto mod = GetRef<Module>(this);
auto fv = FreeVars(func);
auto ftv = FreeTypeVars(func, mod);
if (fv.size() != 0) {
LOG(WARNING)
<< "There are free variables: "
<< fv
<< " in function: "
<< AsText(func, false)
<< std::endl;
}
if (ftv.size() != 0) {
LOG(WARNING)
<< "There are free type variables: "
<< ftv
<< " in function: "
<< AsText(func, false)
<< std::endl;
}
func =
FunctionNode::make(concat(func->params, fv),
func->body,
func->ret_type,
concat(func->type_params, ftv),
func->attrs);
// Type check the item before we add it to the module.
Function checked_func = InferType(func, mod, var);
auto type = checked_func->checked_type();
CHECK(type.as<IncompleteTypeNode>() == nullptr);
Expand Down Expand Up @@ -174,7 +208,7 @@ Module ModuleNode::FromExpr(
if (func_node) {
func = GetRef<Function>(func_node);
} else {
func = FunctionNode::make({}, expr, Type(), {}, {});
func = FunctionNode::make(FreeVars(expr), expr, Type(), FreeTypeVars(expr, mod), {});
}
mod->Add(mod->entry_func, func);
return mod;
Expand Down
12 changes: 10 additions & 2 deletions src/relay/pass/quantize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -674,8 +674,16 @@ Pass QuantizeAnnotate() {

runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(
ForwardRewrite(f, "FQAnnotateRewrite", fmulti_ref));
auto func = Downcast<Function>(ForwardRewrite(f, "FQAnnotateRewrite", fmulti_ref));
auto new_params = func->params;
for (const auto& x : FreeVars(func)) {
new_params.push_back(x);
}
return FunctionNode::make(new_params,
func->body,
func->ret_type,
func->type_params,
func->attrs);
};
return CreateFunctionPass(pass_func, 1, "QuantizeAnnotate", {});
}
Expand Down
2 changes: 1 addition & 1 deletion tests/python/relay/test_typecall.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test_id_type():
make_id = relay.Var("make_id", relay.FuncType([b], id_type(b), [b]))
t = relay.scalar_type("float32")
b = relay.Var("b", t)
mod[mod.entry_func] = relay.Function([], make_id(b))
mod[mod.entry_func] = relay.Function([make_id, b], make_id(b))
mod = transform.InferType()(mod)
assert mod[mod.entry_func].body.checked_type == id_type(t)

Expand Down
4 changes: 2 additions & 2 deletions tests/python/unittest/test_graph_tuner_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def test_get_direct_ancestor():
visited_dict = {}
input_names = ["data"]
out = get_direct_ancestor(node_list, visited_dict, target_ops, 5, input_names)
assert out == [2, 0], "Output mismatch: expecting [2, 0] but got %s." % str(out)
assert out == [0], "Output mismatch: expecting [0] but got %s." % str(out)


def test_get_in_nodes():
Expand All @@ -125,7 +125,7 @@ def test_get_in_nodes():
node_dict = {}
expr2graph(net, target_ops, node_dict, node_list)
out = get_in_nodes(node_list, target_ops, input_names)
expected_out = {7: [3], 3: [2, 0], 2: [0]}
expected_out = {3: [0], 4: [3, 0], 7: [4]}
diff_set = set(out) ^ set(expected_out)
if len(diff_set) != 0:
raise RuntimeError("Output mismatch: expecting %s but got %s." % (str(expected_out), str(out)))
Expand Down

0 comments on commit 761b764

Please sign in to comment.