Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay][RFC] Implement type checking for Any #3221

Merged
merged 6 commits into from
Jul 10, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions include/tvm/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,15 @@ struct Reduce : public ExprNode<Reduce> {
static constexpr const char* _type_key = "Reduce";
};

/*! \brief Any shape. Placeholder expression for a dimension whose extent is not known statically. */
struct Any : public ExprNode<Any> {
/*! \brief Construct an Any expression node. */
TVM_DLL static Expr make();

// Any carries no payload, so there are no attributes to visit.
void VisitAttrs(AttrVisitor* v) final {}
static const IRNodeType _type_info = IRNodeType::ExtensionExpr;
static constexpr const char* _type_key = "Any";
};

/*!
* \brief Auxiliary data structure used in IR Pass to indicate a tensor.
*/
Expand Down
12 changes: 6 additions & 6 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,19 +123,19 @@ struct InitOpAttrs : public tvm::AttrsNode<InitOpAttrs> {

/*! \brief Attributes used in arange operators */
struct ArangeAttrs : public tvm::AttrsNode<ArangeAttrs> {
tvm::Expr start;
tvm::Expr stop;
tvm::Expr step;
Expr start;
Expr stop;
Expr step;
DataType dtype;

TVM_DECLARE_ATTRS(ArangeAttrs, "relay.attrs.ArangeAttrs") {
TVM_ATTR_FIELD(start).set_default(make_const(Float(32), 0))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why removing the defaults?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can't easily make Relay constant values in C++.

TVM_ATTR_FIELD(start)
.describe("Start of interval. The interval includes this value.");
TVM_ATTR_FIELD(stop)
.describe("Stop of interval. The interval does not include this value.");
TVM_ATTR_FIELD(step).set_default(make_const(Float(32), 1))
TVM_ATTR_FIELD(step)
.describe("Spacing between values.");
TVM_ATTR_FIELD(dtype).set_default(NullValue<DataType>())
TVM_ATTR_FIELD(dtype)
.describe("Target data type.");
}
}; // struct ArangeAttrs
Expand Down
10 changes: 6 additions & 4 deletions include/tvm/relay/error.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,10 @@ struct RelayErrorStream {

struct Error : public dmlc::Error {
Span sp;
explicit Error(const std::string& msg) : dmlc::Error(msg), sp() {}
Error(const std::stringstream& msg) : dmlc::Error(msg.str()), sp() {} // NOLINT(*)
Error(const RelayErrorStream& msg) : dmlc::Error(msg.str()), sp() {} // NOLINT(*)
explicit Error(const std::string& msg) : dmlc::Error(msg), sp(nullptr) {}
Error(const RelayErrorStream& msg) : dmlc::Error(msg.str()), sp(nullptr) {} // NOLINT(*)
Error(const Error& err) : dmlc::Error(err.what()), sp(nullptr) {}
Error() : dmlc::Error(""), sp(nullptr) {}
};

/*! \brief An abstraction around how errors are stored and reported.
Expand Down Expand Up @@ -118,7 +119,8 @@ class ErrorReporter {
* \param err The error message to report.
*/
inline void ReportAt(const GlobalVar& global, const NodeRef& node, std::stringstream& err) {
this->ReportAt(global, node, Error(err));
std::string err_msg = err.str();
this->ReportAt(global, node, Error(err_msg));
}

/*! \brief Report an error against a program, using the full program
Expand Down
3 changes: 3 additions & 0 deletions include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,9 @@ inline const TTypeNode* ExprNode::type_as() const {
return node;
}

/*! \brief Pretty print a Relay node, producing a fragment of the Relay text format. */
std::string PrettyPrint(const NodeRef& node);
jroesch marked this conversation as resolved.
Show resolved Hide resolved

/*!
* \brief Render the node as a string in the Relay text format.
* \param node The node to be rendered.
Expand Down
16 changes: 16 additions & 0 deletions include/tvm/relay/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,22 @@ using FForwardRewrite = runtime::TypedPackedFunc<
using FPrimalGradient = runtime::TypedPackedFunc<tvm::Array<Expr>(const Expr& orig_call,
const Expr& output_grad)>;

/*!
 * \brief The code generation strategy for dynamic dimensions.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

code generation

*/
enum AnyCodegenStrategy {
  /*! \brief Default strategy: treat every unknown (Any) dimension as a fully variable extent. */
  kVariableDimensions
};

/*! \brief A runtime representation of shape. */
using Shape = Array<IndexExpr>;

/*!
 * \brief Shape function signature for an operator.
 *        Given the operator attributes, its input tensors and the (symbolic)
 *        output shapes, produces an array of tensors.
 *        NOTE(review): presumably the returned tensors hold the concrete
 *        output shapes at runtime — confirm against the shape-function pass.
 */
using FShapeFunc = runtime::TypedPackedFunc<
Array<Tensor>(const Attrs& attrs,
const Array<Tensor>& inputs,
const Array<Shape>& out_shapes)>;

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_OP_ATTR_TYPES_H_
3 changes: 3 additions & 0 deletions include/tvm/relay/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
namespace tvm {
namespace relay {

using Any = tvm::ir::Any;

/*! \brief Base type of the Relay type hierarchy. */
class TypeNode : public RelayNode {
public:
Expand Down Expand Up @@ -384,6 +386,7 @@ class TypeReporterNode : public Node {
* But it is possible for the solver to resolve src by dst as well.
*/
TVM_DLL virtual void Assign(const Type& dst, const Type& src) = 0;

/*!
* \brief assert shape expression comparison.
* \note Use assert only if any of the condition input is symbolic.
Expand Down
2 changes: 2 additions & 0 deletions include/tvm/runtime/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,8 @@ class NDArray {
TVM_DLL static void CopyFromTo(
DLTensor* from, DLTensor* to, TVMStreamHandle stream = nullptr);

TVM_DLL std::vector<int64_t> Shape() const;

// internal namespace
struct Internal;
protected:
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/_ffi/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ def get_last_ffi_error():
"""
c_err_msg = py_str(_LIB.TVMGetLastError())
py_err_msg, err_type = c2pyerror(c_err_msg)
if err_type.startswith("tvm.error."):
if err_type is not None and err_type.startswith("tvm.error."):
err_type = err_type[10:]
return ERROR_TYPE.get(err_type, TVMError)(py_err_msg)

Expand Down
3 changes: 2 additions & 1 deletion python/tvm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,8 @@ def extern(shape,
raise ValueError("nested tag is not allowed for now")
tag = _tag.TagScope.get_current().tag
shape = (shape,) if isinstance(shape, (_expr.Expr, _Integral)) else shape
shape = [shape] if isinstance(shape[0], (_expr.Expr, _Integral)) else shape
if shape == () or isinstance(shape[0], (_expr.Expr, _Integral)):
shape = [shape]
if in_buffers is not None:
in_buffers = [in_buffers] if not isinstance(in_buffers, list) else in_buffers
if len(inputs) != len(in_buffers):
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/relay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
TensorType = ty.TensorType
Kind = ty.Kind
TypeVar = ty.TypeVar
ShapeVar = ty.ShapeVar
TypeConstraint = ty.TypeConstraint
FuncType = ty.FuncType
TypeRelation = ty.TypeRelation
Expand All @@ -71,6 +72,7 @@
RefType = ty.RefType
GlobalTypeVar = ty.GlobalTypeVar
TypeCall = ty.TypeCall
Any = ty.Any

# Expr
Expr = expr.Expr
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/relay/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,7 @@ def const(value, dtype=None):
"""
if isinstance(value, (_base.numeric_types, (bool, list))):
value = _np.array(value, dtype=dtype)

if not dtype:
# when dtype is None: int maps to "int32", float maps to "float32"
map_dtype = {
Expand All @@ -578,6 +579,7 @@ def const(value, dtype=None):
}.get(value.dtype, None)
if map_dtype:
value = value.astype(map_dtype)

if isinstance(value, (_np.ndarray, _np.generic)):
value = _nd.array(value)

Expand Down
6 changes: 3 additions & 3 deletions python/tvm/relay/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,9 +491,9 @@ def _mx_arange(inputs, attrs):
raise tvm.error.OpAttributeUnimplemented(
'Attribute "repeat" is not supported in operator arange.')
new_attrs = {}
new_attrs["start"] = attrs.get_float("start", 0)
new_attrs["stop"] = attrs.get_float("stop")
new_attrs["step"] = attrs.get_float("step", 1)
new_attrs["start"] = _expr.const(attrs.get_float("start", 0.0))
new_attrs["stop"] = _expr.const(attrs.get_float("stop"))
new_attrs["step"] = _expr.const(attrs.get_float("step", 1.0))
new_attrs["dtype"] = attrs.get_str("dtype", "float32")
return _op.arange(**new_attrs)

Expand Down
10 changes: 5 additions & 5 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1059,9 +1059,9 @@ def _impl(inputs, attr, params):
return AttrCvt(
op_name="arange",
ignores=['Tidx'],
extras={'start': start,
"stop": limit,
'step': delta,
extras={'start': _expr.const(start),
"stop": _expr.const(limit),
'step': _expr.const(delta),
'dtype': dtype})([], attr)
return _impl

Expand Down Expand Up @@ -1269,8 +1269,8 @@ def _impl(inputs, attr, params):
crop = crops[axis - 1]
if crop != [0, 0]:
indices = tvm.relay.arange(
crop[0],
reshaped_permuted_shape[axis] - crop[1],
_expr.const(crop[0]),
_expr.const(reshaped_permuted_shape[axis] - crop[1]),
dtype='int32'
)
cropped = tvm.relay.take(cropped, indices=indices, axis=axis)
Expand Down
65 changes: 65 additions & 0 deletions python/tvm/relay/loops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""
Utilities for building Relay loops.
"""
from .scope_builder import ScopeBuilder
from . import expr as _expr

def while_loop(cond, loop_vars, loop_bodies):
    """Construct a Relay while loop as a let-bound recursive function.

    Parameters
    ----------
    cond: Callable[Tuple[relay.Expr], relay.Expr]
        The condition of the loop.

    loop_vars: Tuple[relay.Expr]
        The initial values of the loop; their names and types are used
        to construct the fresh loop parameters.

    loop_bodies: Callable[Tuple[relay.Expr], Tuple[relay.Expr]]
        The body of the loop: a function which, given the loop
        variables, produces the next iteration's values as a tuple.

    Returns
    -------
    loop: relay.Expr
        The loop expression.
    """
    sb = ScopeBuilder()
    loop_ref = _expr.Var("while_loop")

    def _fresh_var(index, original):
        # Reuse the user's variable name when one is available.
        hint = original.name_hint if isinstance(original, _expr.Var) else "arg{}".format(index)
        return _expr.var(hint, type_annotation=sb.type_of(original))

    fresh_vars = [_fresh_var(i, v) for i, v in enumerate(loop_vars)]

    # While the condition holds, recurse with the updated variables;
    # otherwise yield the current variables as the loop's result.
    with sb.if_scope(cond(*fresh_vars)):
        sb.ret(loop_ref(*loop_bodies(*fresh_vars)))
    with sb.else_scope():
        sb.ret(_expr.Tuple(fresh_vars))

    body_fn = _expr.Function(fresh_vars, sb.get())
    return _expr.Let(loop_ref, body_fn, loop_ref)
10 changes: 7 additions & 3 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"""Transform operators."""

from . import _make
from ..expr import TupleWrapper
from ..expr import TupleWrapper, const


def cast(data, dtype):
Expand Down Expand Up @@ -272,7 +272,7 @@ def full_like(data, fill_value):
return _make.full_like(data, fill_value)


def arange(start, stop=None, step=1, dtype="float32"):
def arange(start, stop=None, step=None, dtype="float32"):
"""Return evenly spaced values within a given interval.

.. note::
Expand Down Expand Up @@ -310,9 +310,13 @@ def arange(start, stop=None, step=1, dtype="float32"):
relay.arange(1, 5) = [1, 2, 3, 4]
relay.arange(1, 5, 1.5) = [1, 2.5, 4]
"""
if step is None:
step = const(1, dtype)

if stop is None:
stop = start
start = 0
start = const(0, dtype=dtype)

return _make.arange(start, stop, step, dtype)


Expand Down
19 changes: 18 additions & 1 deletion python/tvm/relay/scope_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ def __exit__(self, ptype, value, trace):
else:
self._exit_cb()


def _make_lets(bindings, ret_value):
"""Make a nested let expressions.

Expand Down Expand Up @@ -176,6 +175,24 @@ def _on_exit():
false_branch)
return WithScope(None, _on_exit)


def type_of(self, expr):
    """Compute the type of an expression.

    Parameters
    ----------
    expr: relay.Expr
        The expression whose type is wanted.

    Returns
    -------
    ty: relay.Type
        The expression's type annotation if it is a variable, otherwise
        a fresh incomplete type to be solved by type inference.
    """
    # Variables already carry an explicit annotation.
    if isinstance(expr, _expr.Var):
        return expr.type_annotation

    # Bind the expression to a fresh variable annotated with an
    # incomplete type, so type inference fills in the hole.
    hole = _ty.IncompleteType()
    self.let(_expr.var("unify", hole), expr)
    return hole

def ret(self, value):
"""Set the return value of this scope.

Expand Down
14 changes: 14 additions & 0 deletions python/tvm/relay/ty.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from .base import RelayNode, register_relay_node
from . import _make

Any = _make.Any

class Type(RelayNode):
"""The base type for all Relay types."""
Expand Down Expand Up @@ -137,6 +138,19 @@ def __init__(self, var, kind=Kind.Type):
"""
self.__init_handle_by_constructor__(_make.TypeVar, var, kind)

def ShapeVar(name):
    """Construct a type variable of shape kind.

    Parameters
    ----------
    name : str
        The name of the shape variable.

    Returns
    -------
    type_var : tvm.relay.TypeVar
        A TypeVar whose kind is Kind.ShapeVar.
    """
    shape_kind = Kind.ShapeVar
    return TypeVar(name, kind=shape_kind)

@register_relay_node
class GlobalTypeVar(Type):
Expand Down
4 changes: 3 additions & 1 deletion src/codegen/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -972,7 +972,9 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Call* op) {
op->call_type == Call::PureExtern) {
return CreateCallExtern(op);
} else {
LOG(FATAL) << "Unknown call type ";
LOG(FATAL) << "Unknown call type " <<
"name= " << op->name <<
" call_type= " << op->call_type;
return nullptr;
}
}
Expand Down
Loading