Skip to content

Commit

Permalink
[Relay] Add pass for getting calibration data from a relay module (ap…
Browse files Browse the repository at this point in the history
…ache#5997)

* add simple pass to extract outputs

* complete pass that collects all function inputs/outputs

* add analysis pass for collecting outputs

* reorganize the files

* add the first test

* update test with tuples

* clean up Python code

* merge with upstream

* clean up transform.py

* add comments for cpp files

* fix lint issues

* update submodules

* modify files according to the review

* fix style and typo

* fix lint error

* add checks for repeated function calls

* fix lint error

* merge review comments

* small simplification

* revise the code according to the review comments

* add username in TODO

* use IRModule directly

* use better APIs according to the review

* apply comments from the reviewer

* retrigger ci
  • Loading branch information
seanlatias authored and trevor-m committed Jul 14, 2020
1 parent a5a88d0 commit 343232d
Show file tree
Hide file tree
Showing 4 changed files with 373 additions and 0 deletions.
18 changes: 18 additions & 0 deletions include/tvm/relay/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,24 @@ TVM_DLL Array<Pattern> UnmatchedCases(const Match& match, const IRModule& mod);
*/
TVM_DLL std::unordered_map<const Object*, size_t> GetExprRefCount(const Expr& body);

/*!
 * \brief Get the updated module for collecting calibration data.
 *
 * The "main" function of the returned module produces a flat tuple holding the
 * inputs and outputs of every called function that carries the Compiler
 * attribute. Those functions are additionally marked to be inlined and have
 * their Compiler attribute reset.
 *
 * \param mod The module to be updated.
 *
 * \return The updated module.
 */
TVM_DLL IRModule GetCalibrateModule(IRModule mod);

/*!
 * \brief Get the output map between subgraphs and their inputs/outputs.
 *
 * Each value is an array of three integers: the offset of the subgraph's first
 * value in the output tuple, the number of inputs, and the number of outputs.
 *
 * \param mod The module for running calibration.
 *
 * \return The mapping between a subgraph (keyed by its GlobalVar) and its position in the
 * output tuple.
 */
TVM_DLL Map<GlobalVar, Array<Integer>> GetCalibrateOutputMap(const IRModule& mod);

} // namespace relay
} // namespace tvm

Expand Down
48 changes: 48 additions & 0 deletions python/tvm/relay/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
configuring the passes and scripting them in Python.
"""
from tvm.ir import IRModule
from tvm.relay import transform, build_module
from tvm.runtime.ndarray import cpu

from . import _ffi_api
from .feature import Feature
Expand Down Expand Up @@ -351,3 +353,49 @@ def search_fc_transpose(expr):
"""
ret = _ffi_api.search_fc_transpose(expr)
return ret


def get_calibration_data(mod, data):
    """Get the calibration data of a given relay graph

    The module is run with the graph runtime and the input and output values of
    each function are gathered. The result is keyed by the GlobalVar of each
    function; the values of a function are reached through the `inputs` and
    `outputs` keys.

    Limitations:

    1. The input module (graph) cannot have control flows.
    2. The input arguments of each function cannot be tuples (outputs can be tuples).
    3. We only handle top-level functions (i.e., nested function is not handled).
    4. We only handle functions with `Compiler` attribute being set.

    Parameters
    ----------
    mod : tvm.IRModule
        The input module for collecting the calibration data

    data : Dict[str, NDArray]
        The input data for running the module

    Returns
    -------
    data : Dict[tvm.relay.GlobalVar, Dict[str, NDArray]]
    """
    # The position map must be computed on the original module, before it is rewritten.
    output_map = _ffi_api.get_calibrate_output_map(mod)

    calib_mod = transform.Inline()(_ffi_api.get_calibrate_module(mod))

    executor = build_module.create_executor("graph", mod=calib_mod, ctx=cpu(0))
    results = executor.evaluate()(**data)

    calib_data = {}
    for gvar, indices in output_map.items():
        # indices is (offset, number of inputs, number of outputs)
        start, num_inputs, num_outputs = (int(indices[pos]) for pos in range(3))
        split = start + num_inputs
        calib_data[gvar] = {
            "inputs": results[start:split],
            "outputs": results[split:split + num_outputs],
        }

    return calib_data
202 changes: 202 additions & 0 deletions src/relay/analysis/get_calibration_data.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file src/relay/analysis/get_calibration_data.cc
*
* \brief To get the calibration data, we need to perform two
* steps. First, we need to prepare the module that generates
* the tensor values (GetCalibrateModule). Second, we need to
* generate the mapping between the values and the functions
* (GetCalibrateOutputMap).
*/

#include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>

namespace tvm {
namespace relay {

/*!
* \brief This function returns a module that will be used by
* the relay graph runtime for collecting the calibration data.
* To do that, we first make all inputs and outputs of each
* function into the final output (i.e., the final output is a
* tuple of tensors). Then, we change the compiler attribute of
* each function. Finally, we mark all functions to be inlined.
*/

/*!
 * \brief Rewriter that records the arguments and the result of every call
 * whose callee is a global function with the Compiler attribute set. The
 * visited expression is returned unchanged; the recorded expressions are
 * retrieved with GetNewOutputs().
 */
class Collector : public ExprRewriter {
 public:
  explicit Collector(const IRModule& module) : module_(module) {}

  Expr Rewrite_(const CallNode* call, const Expr& post) final {
    // Only calls to global functions have an implementation we can look up;
    // intrinsic functions are excluded for now.
    if (!call->op->IsInstance<GlobalVarNode>()) {
      return post;
    }
    auto gvar = Downcast<GlobalVar>(call->op);
    CHECK(module_->ContainGlobalVar(gvar->name_hint)) << "Function " << gvar << " is not defined";
    auto callee = Downcast<Function>(module_->Lookup(gvar));
    // Functions without the Compiler attribute are not collected.
    if (!callee->GetAttr<String>(attr::kCompiler)) {
      return post;
    }
    // Record all the inputs first, followed by the result of the call.
    for (const auto& arg : call->args) {
      new_outputs_.push_back(arg);
    }
    new_outputs_.push_back(post);
    return post;
  }

  /*! \brief The expressions recorded so far, in visiting order. */
  Array<Expr> GetNewOutputs() { return new_outputs_; }

 private:
  /*! \brief Module used to look up the called global functions. */
  const IRModule& module_;
  /*! \brief Recorded call arguments and call results. */
  Array<Expr> new_outputs_;
};

/*!
 * \brief Pack the given expressions into one tuple, expanding every
 * tuple-typed expression into one TupleGetItem per field.
 *
 * \param exprs The expressions to pack; each must have a defined checked type.
 *
 * \return A Tuple containing the flattened fields.
 */
Expr FlattenOutputTuple(const Array<Expr>& exprs) {
  Array<Expr> flattened;
  for (const auto& expr : exprs) {
    CHECK(expr->checked_type_.defined());
    const auto* tuple_type = expr->checked_type_.as<TupleTypeNode>();
    if (tuple_type == nullptr) {
      // Not tuple-typed: forwarded as is.
      flattened.push_back(expr);
      continue;
    }
    // TODO(seanlatias): for now input argument cannot be a tuple
    CHECK(expr->IsInstance<CallNode>());
    for (size_t idx = 0; idx < tuple_type->fields.size(); idx++) {
      flattened.push_back(TupleGetItem(expr, idx));
    }
  }
  return Tuple(flattened);
}

/*!
 * \brief Build the module used for collecting calibration data.
 *
 * The body of "main" is replaced by a flat tuple of the inputs and outputs of
 * every called function with the Compiler attribute. Each such function is
 * then marked to be inlined and has its Compiler attribute reset.
 *
 * \param module The module to be updated.
 *
 * \return The updated module.
 */
IRModule GetCalibrateModule(IRModule module) {
  // Snapshot of the functions taken before any update is applied.
  auto glob_funcs = module->functions;
  // module is mutable, hence, we make a copy of it.
  module.CopyOnWrite();

  // Step 1: rewrite the main function so that it returns the collected values.
  for (const auto& entry : glob_funcs) {
    const auto* fn_node = entry.second.as<FunctionNode>();
    if (fn_node == nullptr) continue;
    // we only collect the outputs for main function
    if (!(entry.first->name_hint == "main")) continue;

    auto main_func = GetRef<Function>(fn_node);
    Collector collector(module);
    PostOrderRewrite(main_func->body, &collector);
    Expr tuple = FlattenOutputTuple(collector.GetNewOutputs());
    main_func = Function(main_func->params, tuple, tuple->checked_type_, main_func->type_params,
                         main_func->attrs);
    module->Update(entry.first, main_func);
  }

  // Step 2: reset the attributes of the functions for running graph runtime.
  for (const auto& entry : glob_funcs) {
    const auto* fn_node = entry.second.as<FunctionNode>();
    if (fn_node == nullptr) continue;

    auto callee = GetRef<Function>(fn_node);
    if (!callee->GetAttr<String>(attr::kCompiler)) continue;

    // we need to inline the functions in order to run graph runtime
    callee = WithAttr(std::move(callee), attr::kInline, tvm::Integer(1));
    // reset the compiler attribute to null for llvm execution
    callee = WithAttr(std::move(callee), attr::kCompiler, NullValue<ObjectRef>());
    module->Update(entry.first, callee);
  }
  return module;
}

/*!
* \brief This function generates the output mapping between
* the calibration data and each function. The key is a
* GlobalVar that corresponds to each function and the value
* is an array of integers. The size of the array is always
* three. The first value is the offset that points to the start.
* The second value is the number of inputs. The third value
* is the number of outputs.
*/

/*!
 * \brief Rewriter that records, for every call to a global function with the
 * Compiler attribute, where the values of that call are located in the flat
 * calibration output tuple. The visited expression is returned unchanged.
 *
 * For each such function an entry {offset, number of inputs, number of
 * outputs} is stored in the output map, and the running offset is advanced by
 * the number of inputs plus the number of outputs.
 */
class OutputMapper : public ExprRewriter {
 public:
  /*!
   * \param output_map The map that receives one entry per collected function.
   * \param module The module used to look up the called global functions.
   * \param offset The running position in the output tuple; updated in place.
   */
  OutputMapper(Map<GlobalVar, Array<Integer>>* output_map, const IRModule& module, size_t* offset)
      : output_map_(output_map), module_(module), offset_(offset) {}

  Expr Rewrite_(const CallNode* call, const Expr& post) final {
    if (call->op->IsInstance<GlobalVarNode>()) {
      auto var = Downcast<GlobalVar>(call->op);
      CHECK(module_->ContainGlobalVar(var->name_hint)) << "Function " << var << " is not defined";
      CHECK_EQ(output_map_->count(var), 0)
          << "Repeated function call " << var << " is not supported.";
      auto func = Downcast<Function>(module_->Lookup(var));
      // we only handle functions with Compiler attribute set
      if (func->GetAttr<String>(attr::kCompiler)) {
        const size_t in_size = call->args.size();
        // The number of outputs must agree with FlattenOutputTuple, which
        // expands a call result according to its checked type. Counting only
        // literal Tuple bodies would under-count functions whose body is
        // tuple-typed without being a Tuple node and shift every later offset.
        size_t out_size = 1;
        if (const auto* tuple_type = post->checked_type_.as<TupleTypeNode>()) {
          out_size = tuple_type->fields.size();
        } else if (const auto* tn = func->body.as<TupleNode>()) {
          // No tuple type available on the call: fall back to the body shape.
          out_size = tn->fields.size();
        }

        Array<Integer> info;
        // the first value is the offset
        info.push_back(Integer(*offset_));
        // the second value is the number of inputs
        info.push_back(Integer(in_size));
        // the third value is the number of outputs
        info.push_back(Integer(out_size));
        output_map_->Set(var, info);
        // calculate the offset for the next function
        *offset_ = *offset_ + in_size + out_size;
      }
    }
    return post;
  }

 private:
  /*! \brief Destination map, owned by the caller. */
  Map<GlobalVar, Array<Integer>>* output_map_;
  /*! \brief Module used to look up the called global functions. */
  const IRModule& module_;
  /*! \brief Running offset into the output tuple, owned by the caller. */
  size_t* offset_;
};

/*!
 * \brief Compute, for each function with the Compiler attribute called from
 * "main", where its inputs and outputs are located in the calibration output.
 *
 * \param module The module for running calibration.
 *
 * \return A map from the GlobalVar of each collected function to
 * {offset, number of inputs, number of outputs}.
 */
Map<GlobalVar, Array<Integer>> GetCalibrateOutputMap(const IRModule& module) {
  Map<GlobalVar, Array<Integer>> output_map;
  size_t offset = 0;
  auto glob_funcs = module->functions;
  for (const auto& entry : glob_funcs) {
    const auto* fn_node = entry.second.as<FunctionNode>();
    if (fn_node == nullptr) continue;
    // Only the calls made from the main function are mapped.
    if (!(entry.first->name_hint == "main")) continue;

    OutputMapper mapper(&output_map, module, &offset);
    PostOrderRewrite(fn_node->body, &mapper);
  }

  return output_map;
}

// Expose GetCalibrateModule under the global name "relay.analysis.get_calibrate_module".
TVM_REGISTER_GLOBAL("relay.analysis.get_calibrate_module").set_body_typed([](IRModule mod) {
  return GetCalibrateModule(mod);
});

// Expose GetCalibrateOutputMap under the global name "relay.analysis.get_calibrate_output_map".
TVM_REGISTER_GLOBAL("relay.analysis.get_calibrate_output_map")
    .set_body_typed([](const IRModule& mod) { return GetCalibrateOutputMap(mod); });

} // namespace relay
} // namespace tvm
105 changes: 105 additions & 0 deletions tests/python/relay/test_analysis_get_calibration_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import numpy as np

import tvm
import tvm.relay.testing
from tvm import relay
from tvm.relay import transform
from tvm.relay.analysis import get_calibration_data


def check_data_size(mod, data):
    """Check that `data` has one entry per non-main function of `mod` and that
    each entry holds as many inputs as the function has parameters and as many
    outputs as the function body has tuple fields (one if it is not a tuple)."""
    assert len(data) == len(mod.functions) - 1
    for gvar, func in mod.functions.items():
        if gvar.name_hint == "main":
            continue
        assert len(data[gvar]["inputs"]) == len(func.params)
        num_outputs = len(data[gvar]["outputs"])
        if isinstance(func.body, relay.Tuple):
            assert num_outputs == len(func.body.fields)
        else:
            assert num_outputs == 1

def test_simple_graph():
    """Collect calibration data from a module with two subgraphs.

    `g0` returns the tuple (x0 + y0, x0 - y0) and `g1` returns x1 - y1. The main
    function feeds the first output of `g0`, together with `z`, into `g1`.
    """
    shape = (8, 8)
    mod = tvm.IRModule()

    # First subgraph: tuple output.
    x0 = relay.var('x0', shape=shape)
    y0 = relay.var('y0', shape=shape)
    f0 = relay.Function([x0, y0], relay.Tuple((x0 + y0, x0 - y0)))
    g0 = relay.GlobalVar("g0")
    mod[g0] = f0.with_attr("Compiler", "test_graph")

    # Second subgraph: single tensor output.
    x1 = relay.var('x1', shape=shape)
    y1 = relay.var('y1', shape=shape)
    f1 = relay.Function([x1, y1], x1 - y1)
    g1 = relay.GlobalVar("g1")
    mod[g1] = f1.with_attr("Compiler", "test_graph")

    # Main function chaining the two subgraphs.
    x = relay.var('x', shape=shape)
    y = relay.var('y', shape=shape)
    z = relay.var('z', shape=shape)
    call_g0 = relay.Call(g0, [x, y])
    call_g1 = relay.Call(g1, [relay.TupleGetItem(call_g0, 0), z])
    mod["main"] = relay.Function([x, y, z], call_g1)

    x_data, y_data, z_data = (np.random.rand(8, 8).astype('float32') for _ in range(3))
    data = get_calibration_data(mod, {"x": x_data, "y": y_data, "z": z_data})

    def expect(gvar, kind, index, expected):
        tvm.testing.assert_allclose(data[gvar][kind][index].asnumpy(), expected)

    # Check the number and orders
    check_data_size(mod, data)
    expect(g0, "inputs", 0, x_data)
    expect(g0, "inputs", 1, y_data)
    expect(g0, "outputs", 0, x_data + y_data)
    expect(g0, "outputs", 1, x_data - y_data)
    expect(g1, "inputs", 0, x_data + y_data)
    expect(g1, "inputs", 1, z_data)
    expect(g1, "outputs", 0, x_data + y_data - z_data)

def test_mobilenet_dnnl():
    """Collect calibration data from MobileNet partitioned for the "dnnl" target
    and check the number of collected inputs/outputs. Skipped when the
    "relay.ext.dnnl" global function is not registered."""
    if not tvm.get_global_func("relay.ext.dnnl", True):
        print("skip because DNNL codegen is not available")
        return

    dtype = 'float32'
    ishape = (1, 3, 224, 224)
    mod, params = relay.testing.mobilenet.get_workload(batch_size=1, dtype=dtype)

    # Annotate, merge and partition the graph, in that order.
    mod = transform.AnnotateTarget(["dnnl"])(mod)
    mod = transform.MergeCompilerRegions()(mod)
    mod = transform.PartitionGraph()(mod)

    inputs = {"data": np.random.uniform(0, 1, ishape).astype(dtype)}
    inputs.update(params)
    data = get_calibration_data(mod, inputs)

    # Check the number and orders
    check_data_size(mod, data)

# Run every test in this file when it is executed as a script.
if __name__ == "__main__":
    for _run_test in (test_simple_graph, test_mobilenet_dnnl):
        _run_test()

0 comments on commit 343232d

Please sign in to comment.