forked from apache/tvm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add initial version of evaluator and tests WIP Work towards simple examples in the evaluator Requires implementation of lowering ops and monomorph Evaluator now works on simple cases Restore Function case in Evaluator WIP Fix rebase issues working towards working version RTS is now working again RTS can add numbers now Fix some rebase issues Fix up tests post rebase WIP Issue type checking MLP Remove dead file Clean up evaluator Remove accidental change Reset changes from apache#1962 Rename Evaluator A little clean up WIP Clean up tests WIP WIP Repair from rebase and refactor to not use MM Remove testing code which is now in apache#1969 WIP
- Loading branch information
Showing
21 changed files
with
1,671 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
/*! | ||
* Copyright (c) 2018 by Contributors | ||
* \file tvm/relay/interpreter.h | ||
* \brief An interpreter for Relay. | ||
* | ||
* This file implements a simple reference interpreter for Relay programs. | ||
* Given a Relay environment, an a Relay expression it produces a value. | ||
* | ||
* This is intended as an implementation of the reference semantics for | ||
* the Relay IR, as well as for debugging and testing. | ||
*/ | ||
#ifndef TVM_RELAY_INTERPRETER_H_ | ||
#define TVM_RELAY_INTERPRETER_H_ | ||
|
||
#include <tvm/relay/environment.h> | ||
#include <tvm/relay/expr.h> | ||
|
||
namespace tvm { | ||
namespace relay { | ||
|
||
/*! | ||
* \brief A Relay value. | ||
*/ | ||
class Value; | ||
|
||
/*! \brief Evaluate an expression using the interpreter producing a value. | ||
* | ||
* This implements the reference semantics of Relay, giving us a tool | ||
* for debugging and testing, especially in the development of alternative | ||
* backends/runtimes. | ||
* | ||
* The resulting value can be passed to Python, making it easy to use | ||
* for testing. | ||
* | ||
* The interpreter interprets the program pieces between TVM operators | ||
* using TVM to back all Relay operator's evaluation. | ||
* | ||
* This is not intended to be an efficient implementation of Relay's | ||
* semantics, eventually the TVM runtime will grow to support Relay's | ||
* features. | ||
*/ | ||
Value Evaluate(Environment env, Expr e); | ||
|
||
/*! \brief The base container type of Relay values. */ | ||
class ValueNode : public RelayNode { | ||
public: | ||
static constexpr const char* _type_key = "relay.Value"; | ||
TVM_DECLARE_BASE_NODE_INFO(ValueNode, RelayNode); | ||
}; | ||
|
||
class Value : public NodeRef { | ||
public: | ||
Value() {} | ||
explicit Value(NodePtr<Node> n) : NodeRef(n) {} | ||
const ValueNode* operator->() const { | ||
return static_cast<const ValueNode*>(node_.get()); | ||
} | ||
|
||
using ContainerType = ValueNode; | ||
}; | ||
|
||
/*! \brief A Relay closure, i.e a scope and a function. */ | ||
class Closure; | ||
|
||
/*! \brief The container type of Closures. */ | ||
class ClosureNode : public ValueNode { | ||
public: | ||
/*! \brief The set of free variables in the closure. | ||
* | ||
* These are the captured variables which are required for | ||
* evaluation when we call the closure. | ||
*/ | ||
tvm::Map<Var, Value> env; | ||
/*! \brief The function which implements the closure. | ||
* | ||
* \note May reference the variables contained in the env. | ||
*/ | ||
Function func; | ||
|
||
ClosureNode() {} | ||
|
||
void VisitAttrs(tvm::AttrVisitor* v) final { | ||
v->Visit("env", &env); | ||
v->Visit("func", &func); | ||
} | ||
|
||
TVM_DLL static Closure make(tvm::Map<Var, Value> env, Function func); | ||
|
||
static constexpr const char* _type_key = "relay.Closure"; | ||
TVM_DECLARE_NODE_TYPE_INFO(ClosureNode, ValueNode); | ||
}; | ||
|
||
RELAY_DEFINE_NODE_REF(Closure, ClosureNode, Value); | ||
|
||
/*! \brief A tuple value. */ | ||
class TupleValue; | ||
|
||
/*! \brief Tuple (x, ... y). */ | ||
struct TupleValueNode : ValueNode { | ||
tvm::Array<Value> fields; | ||
|
||
TupleValueNode() {} | ||
|
||
void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("fields", &fields); } | ||
|
||
TVM_DLL static TupleValue make(tvm::Array<Value> value); | ||
|
||
static constexpr const char* _type_key = "relay.TupleValue"; | ||
TVM_DECLARE_NODE_TYPE_INFO(TupleValueNode, ValueNode); | ||
}; | ||
|
||
RELAY_DEFINE_NODE_REF(TupleValue, TupleValueNode, Value); | ||
|
||
/*! \brief A tensor value. */ | ||
class TensorValue; | ||
|
||
/*! \brief The tensor value container, wrapping an NDArray. */ | ||
struct TensorValueNode : ValueNode { | ||
runtime::NDArray data; | ||
|
||
TensorValueNode() {} | ||
|
||
void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("data", &data); } | ||
|
||
/*! \brief Build a value from an NDArray. */ | ||
TVM_DLL static TensorValue make(runtime::NDArray data); | ||
|
||
/*! \brief Construct an empty tensor value from t. */ | ||
TVM_DLL static TensorValue FromType(const Type& t); | ||
|
||
static constexpr const char* _type_key = "relay.TensorValue"; | ||
TVM_DECLARE_NODE_TYPE_INFO(TensorValueNode, ValueNode); | ||
}; | ||
|
||
RELAY_DEFINE_NODE_REF(TensorValue, TensorValueNode, Value); | ||
|
||
|
||
} // namespace relay | ||
} // namespace tvm | ||
#endif // TVM_RELAY_INTERPRETER_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
"""The interface to the Evaluator exposed from C++.""" | ||
from tvm._ffi.function import _init_api | ||
|
||
_init_api("relay._eval", __name__) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
from __future__ import absolute_import | ||
import numpy as np | ||
from .. import register_func, nd | ||
from .base import NodeBase, register_relay_node | ||
from . import _make | ||
from . import _eval | ||
from . import ir_pass | ||
from .expr import Call, Constant | ||
from . import const | ||
|
||
class Value(NodeBase): | ||
"""Base class of all values. | ||
""" | ||
pass | ||
|
||
@staticmethod | ||
@register_func("relay.from_scalar") | ||
def from_scalar(i, dtype=None): | ||
if dtype is None: | ||
if isinstance(i, int): | ||
dtype = 'int32' | ||
elif isinstance(i, float): | ||
dtype = 'float32' | ||
elif isinstance(i, bool): | ||
dtype = 'uint8' | ||
else: | ||
raise Exception("unable to infer dtype {0}".format(type(i))) | ||
|
||
return TensorValue(nd.array(np.array(i, dtype=dtype))) | ||
|
||
|
||
@register_relay_node | ||
class TupleValue(Value): | ||
def __init__(self, *fields): | ||
self.__init_handle_by_constructor__( | ||
_make.TupleValue, fields) | ||
|
||
def __getitem__(self, field_no): | ||
return self.fields[field_no] | ||
|
||
|
||
@register_relay_node | ||
class Closure(Value): | ||
pass | ||
|
||
|
||
@register_relay_node | ||
class TensorValue(Value): | ||
"""A Tensor value produced by the evaluator.""" | ||
|
||
def __init__(self, data): | ||
"""Allocate a new TensorValue and copy the data from `array` into | ||
the new array. | ||
""" | ||
if isinstance(data, np.ndarray): | ||
data = nd.array(data) | ||
|
||
self.__init_handle_by_constructor__( | ||
_make.TensorValue, data) | ||
|
||
def as_ndarray(self): | ||
"""Convert a Relay TensorValue into a tvm.ndarray.""" | ||
return self.data | ||
def asnumpy(self): | ||
"""Convert a Relay TensorValue into a numpy.ndarray.""" | ||
return self.data.asnumpy() | ||
|
||
def __eq__(self, other): | ||
return self.data == other.data | ||
|
||
def _arg_to_ast(arg): | ||
if isinstance(arg, TensorValue): | ||
return Constant(arg.data) | ||
elif isinstance(arg, np.ndarray): | ||
return Constant(nd.array(arg)) | ||
elif isinstance(arg, Constant): | ||
return arg | ||
else: | ||
return const(arg) | ||
|
||
def apply_passes(expr, env=None): | ||
ck_expr = ir_pass.infer_type(expr, env) | ||
fused_expr = ir_pass.fuse_ops(ck_expr, env) | ||
return fused_expr | ||
|
||
def evaluate(env, expr, *args): | ||
# assert len(args) == 0 | ||
relay_args = [] | ||
for arg in args: | ||
relay_args.append(_arg_to_ast(arg)) | ||
|
||
expr = Call(expr, relay_args) | ||
opt_expr = apply_passes(expr, env) | ||
return _eval.evaluate(env, opt_expr) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,16 @@ | ||
#pylint: disable=invalid-name | ||
"""Backend compiler related feature registration""" | ||
import tvm | ||
import topi | ||
from . import register | ||
|
||
def add_compiler(attrs, inputs, output_type): | ||
assert len(inputs) == 2 | ||
return [topi.add(inputs[0], inputs[1])] | ||
|
||
def add_schedule(outputs, target): | ||
assert len(outputs) == 1 | ||
return tvm.create_schedule(outputs[0].op) | ||
|
||
register("add", "FTVMCompute", add_compiler) | ||
register("add", "FTVMSchedule", add_schedule) |
Oops, something went wrong.