Skip to content

Commit

Permalink
Implement type checking for Any
Browse files Browse the repository at this point in the history
Remove code generation related changes

Remove compile changes

Remove more

Remove unification hack

Add some code back that was needed, and clean up test

Refactor test cases

WIP

Implement TypeHint AST

Add test case which should fail

Remove unification changes, and fix bug with let rec

Restore unification for shapes

Improve error reporting while debugging

All examples type check

All examples type check

WIP

First version that works with hints, needs clean up

Remove dead code

Tweaks

Remove type hint

Remove unecessary type hint stuff

Remove more type hints

Clean up

Expose Any expression node

Address CR

Fix

Fix solver

Kill unecessary code

Fix

PyLint

Fix

Relocate loops

Fix license and test

Lint again

Lint again

Fix loops

Fix docstring

Fix template error

Fix compiler issue

Fix compile err

Remove more runtime changes

Restore buffer

Fix segfault

Fix

Fix arange
  • Loading branch information
jroesch committed Jul 3, 2019
1 parent 988ea2a commit d2b9d87
Show file tree
Hide file tree
Showing 33 changed files with 757 additions and 166 deletions.
9 changes: 9 additions & 0 deletions include/tvm/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,15 @@ struct Reduce : public ExprNode<Reduce> {
static constexpr const char* _type_key = "Reduce";
};

/*! \brief Any shape. */
struct Any : public ExprNode<Any> {
TVM_DLL static Expr make();

void VisitAttrs(AttrVisitor* v) final {}
static const IRNodeType _type_info = IRNodeType::ExtensionExpr;
static constexpr const char* _type_key = "Any";
};

/*!
* \brief Auxiliary data structure used in IR Pass to indicate a tensor.
*/
Expand Down
12 changes: 6 additions & 6 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,19 +123,19 @@ struct InitOpAttrs : public tvm::AttrsNode<InitOpAttrs> {

/*! \brief Attributes used in arange operators */
struct ArangeAttrs : public tvm::AttrsNode<ArangeAttrs> {
tvm::Expr start;
tvm::Expr stop;
tvm::Expr step;
Expr start;
Expr stop;
Expr step;
DataType dtype;

TVM_DECLARE_ATTRS(ArangeAttrs, "relay.attrs.ArangeAttrs") {
TVM_ATTR_FIELD(start).set_default(make_const(Float(32), 0))
TVM_ATTR_FIELD(start)
.describe("Start of interval. The interval includes this value.");
TVM_ATTR_FIELD(stop)
.describe("Stop of interval. The interval does not include this value.");
TVM_ATTR_FIELD(step).set_default(make_const(Float(32), 1))
TVM_ATTR_FIELD(step)
.describe("Spacing between values.");
TVM_ATTR_FIELD(dtype).set_default(NullValue<DataType>())
TVM_ATTR_FIELD(dtype)
.describe("Target data type.");
}
}; // struct ArangeAttrs
Expand Down
10 changes: 6 additions & 4 deletions include/tvm/relay/error.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,10 @@ struct RelayErrorStream {

struct Error : public dmlc::Error {
Span sp;
explicit Error(const std::string& msg) : dmlc::Error(msg), sp() {}
Error(const std::stringstream& msg) : dmlc::Error(msg.str()), sp() {} // NOLINT(*)
Error(const RelayErrorStream& msg) : dmlc::Error(msg.str()), sp() {} // NOLINT(*)
explicit Error(const std::string& msg) : dmlc::Error(msg), sp(nullptr) {}
Error(const RelayErrorStream& msg) : dmlc::Error(msg.str()), sp(nullptr) {} // NOLINT(*)
Error(const Error& err) : dmlc::Error(err.what()), sp(nullptr) {}
Error() : dmlc::Error(""), sp(nullptr) {}
};

/*! \brief An abstraction around how errors are stored and reported.
Expand Down Expand Up @@ -118,7 +119,8 @@ class ErrorReporter {
* \param err The error message to report.
*/
inline void ReportAt(const GlobalVar& global, const NodeRef& node, std::stringstream& err) {
this->ReportAt(global, node, Error(err));
std::string err_msg = err.str();
this->ReportAt(global, node, Error(err_msg));
}

/*! \brief Report an error against a program, using the full program
Expand Down
3 changes: 3 additions & 0 deletions include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,9 @@ inline const TTypeNode* ExprNode::type_as() const {
return node;
}

/*! \brief Pretty print a Relay node, producing a fragment of the Relay text format. */
std::string PrettyPrint(const NodeRef& node);

/*!
* \brief Render the node as a string in the Relay text format.
* \param node The node to be rendered.
Expand Down
11 changes: 11 additions & 0 deletions include/tvm/relay/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,17 @@ using FForwardRewrite = runtime::TypedPackedFunc<
using FPrimalGradient = runtime::TypedPackedFunc<tvm::Array<Expr>(const Expr& orig_call,
const Expr& output_grad)>;

enum AnyCodegenStrategy {
kVariableDimensions
};

using Shape = Array<IndexExpr>;

using FShapeFunc = runtime::TypedPackedFunc<
Array<Tensor>(const Attrs& attrs,
const Array<Tensor>& inputs,
const Array<Shape>& out_shapes)>;

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_OP_ATTR_TYPES_H_
3 changes: 3 additions & 0 deletions include/tvm/relay/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
namespace tvm {
namespace relay {

using Any = tvm::ir::Any;

/*! \brief Base type of the Relay type hiearchy. */
class TypeNode : public RelayNode {
public:
Expand Down Expand Up @@ -384,6 +386,7 @@ class TypeReporterNode : public Node {
* But it is possible for the solver to resolve src by dst as well.
*/
TVM_DLL virtual void Assign(const Type& dst, const Type& src) = 0;

/*!
* \brief assert shape expression comparison.
* \note Use assert only if any of the condition input is symbolic.
Expand Down
2 changes: 2 additions & 0 deletions include/tvm/runtime/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,8 @@ class NDArray {
TVM_DLL static void CopyFromTo(
DLTensor* from, DLTensor* to, TVMStreamHandle stream = nullptr);

TVM_DLL std::vector<int64_t> Shape() const;

// internal namespace
struct Internal;
protected:
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/_ffi/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ def get_last_ffi_error():
"""
c_err_msg = py_str(_LIB.TVMGetLastError())
py_err_msg, err_type = c2pyerror(c_err_msg)
if err_type.startswith("tvm.error."):
if err_type is not None and err_type.startswith("tvm.error."):
err_type = err_type[10:]
return ERROR_TYPE.get(err_type, TVMError)(py_err_msg)

Expand Down
3 changes: 2 additions & 1 deletion python/tvm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,8 @@ def extern(shape,
raise ValueError("nested tag is not allowed for now")
tag = _tag.TagScope.get_current().tag
shape = (shape,) if isinstance(shape, (_expr.Expr, _Integral)) else shape
shape = [shape] if isinstance(shape[0], (_expr.Expr, _Integral)) else shape
if shape == () or isinstance(shape[0], (_expr.Expr, _Integral)):
shape = [shape]
if in_buffers is not None:
in_buffers = [in_buffers] if not isinstance(in_buffers, list) else in_buffers
if len(inputs) != len(in_buffers):
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/relay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
TensorType = ty.TensorType
Kind = ty.Kind
TypeVar = ty.TypeVar
ShapeVar = ty.ShapeVar
TypeConstraint = ty.TypeConstraint
FuncType = ty.FuncType
TypeRelation = ty.TypeRelation
Expand All @@ -71,6 +72,7 @@
RefType = ty.RefType
GlobalTypeVar = ty.GlobalTypeVar
TypeCall = ty.TypeCall
Any = ty.Any

# Expr
Expr = expr.Expr
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/relay/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,7 @@ def const(value, dtype=None):
"""
if isinstance(value, (_base.numeric_types, (bool, list))):
value = _np.array(value, dtype=dtype)

if not dtype:
# when dtype is None: int maps to "int32", float maps to "float32"
map_dtype = {
Expand All @@ -578,6 +579,7 @@ def const(value, dtype=None):
}.get(value.dtype, None)
if map_dtype:
value = value.astype(map_dtype)

if isinstance(value, (_np.ndarray, _np.generic)):
value = _nd.array(value)

Expand Down
57 changes: 57 additions & 0 deletions python/tvm/relay/loops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""
Utilities for building Relay loops.
"""
from .scope_builder import ScopeBuilder
from . import expr as _expr

def while_loop(cond, loop_vars, loop_bodies):
"""
Construct a while loop.
cond: Callable[Tuple[relay.Expr], relay.Expr]
The condition of the loop.
loop_vars: Tuple[relay.Expr]
The variables being looped over.
The initial values of the loop, will be used to
construct the loop variables.
loop_bodies: Callable[Tuple[relay.Expr], Tuple[relay.Expr]]
The body of the loop, should be a function which
given loop variables produces the output result
also as a tuple
"""
sb = ScopeBuilder()
loop = _expr.Var("while_loop")
fresh_vars = []

for i, loop_var in enumerate(loop_vars):
name = loop_var.name_hint if isinstance(loop_var, _expr.Var) else "arg{}".format(i)
new_var = _expr.var(name, type_annotation=sb.type_of(loop_var))
fresh_vars.append(new_var)

with sb.if_scope(cond(*fresh_vars)):
sb.ret(loop(*loop_bodies(*fresh_vars)))
with sb.else_scope():
sb.ret(_expr.Tuple(fresh_vars))

func = _expr.Function(fresh_vars, sb.get())
let = _expr.Let(loop, func, loop)
return let
6 changes: 3 additions & 3 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"""Transform operators."""

from . import _make
from ..expr import TupleWrapper
from ..expr import TupleWrapper, const


def cast(data, dtype):
Expand Down Expand Up @@ -272,7 +272,7 @@ def full_like(data, fill_value):
return _make.full_like(data, fill_value)


def arange(start, stop=None, step=1, dtype="float32"):
def arange(start, stop=None, step=const(1, dtype="float32"), dtype="float32"):
"""Return evenly spaced values within a given interval.
.. note::
Expand Down Expand Up @@ -312,7 +312,7 @@ def arange(start, stop=None, step=1, dtype="float32"):
"""
if stop is None:
stop = start
start = 0
start = const(0, dtype=dtype)
return _make.arange(start, stop, step, dtype)


Expand Down
11 changes: 10 additions & 1 deletion python/tvm/relay/scope_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ def __exit__(self, ptype, value, trace):
else:
self._exit_cb()


def _make_lets(bindings, ret_value):
"""Make a nested let expressions.
Expand Down Expand Up @@ -176,6 +175,16 @@ def _on_exit():
false_branch)
return WithScope(None, _on_exit)


def type_of(self, expr):
if isinstance(expr, _expr.Var):
return expr.type_annotation

ity = _ty.IncompleteType()
var = _expr.var("unify", ity)
self.let(var, expr)
return ity

def ret(self, value):
"""Set the return value of this scope.
Expand Down
14 changes: 14 additions & 0 deletions python/tvm/relay/ty.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from .base import RelayNode, register_relay_node
from . import _make

Any = _make.Any

class Type(RelayNode):
"""The base type for all Relay types."""
Expand Down Expand Up @@ -137,6 +138,19 @@ def __init__(self, var, kind=Kind.Type):
"""
self.__init_handle_by_constructor__(_make.TypeVar, var, kind)

def ShapeVar(name):
"""A helper which constructs a type var of which the shape kind.
Parameters
----------
name : str
Returns
-------
type_var : tvm.relay.TypeVar
The shape variable.
"""
return TypeVar(name, kind=Kind.ShapeVar)

@register_relay_node
class GlobalTypeVar(Type):
Expand Down
4 changes: 3 additions & 1 deletion src/codegen/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -972,7 +972,9 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Call* op) {
op->call_type == Call::PureExtern) {
return CreateCallExtern(op);
} else {
LOG(FATAL) << "Unknown call type ";
LOG(FATAL) << "Unknown call type " <<
"name= " << op->name <<
" call_type= " << op->call_type;
return nullptr;
}
}
Expand Down
19 changes: 13 additions & 6 deletions src/lang/buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -246,13 +246,20 @@ inline Expr MergeMulMod(const Expr &base) {
inline Expr ElemOffset(const BufferNode* n, Array<Expr> index) {
Expr base = n->elem_offset;
if (n->strides.size() == 0) {
CHECK_EQ(n->shape.size(), index.size());
if (index.size() > 0) {
Expr offset = index[0];
for (size_t i = 1; i < index.size(); ++i) {
offset = MergeMulMod(offset * n->shape[i] + index[i]);
// Scalar case
if (n->shape.size() == 0 && index.size() == 1) {
auto is_int = index[0].as<IntImm>();
CHECK(is_int && is_int->value == 0);
base = base + index[0];
} else {
CHECK_EQ(n->shape.size(), index.size());
if (index.size() > 0) {
Expr offset = index[0];
for (size_t i = 1; i < index.size(); ++i) {
offset = MergeMulMod(offset * n->shape[i] + index[i]);
}
base = base + offset;
}
base = base + offset;
}
} else {
CHECK_EQ(n->strides.size(), index.size());
Expand Down
23 changes: 20 additions & 3 deletions src/lang/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand All @@ -35,13 +35,24 @@ namespace Internal {

using tvm::ir::CommReducerNode;
using tvm::ir::Reduce;
using tvm::ir::Any;
using tvm::ir::AttrStmt;

template<>
void ExprNode<Reduce>::accept(IRVisitor *v, const Expr&) const {
LOG(FATAL) << "Reduce do not work with old Visitor, use IRFunctor style visitor";
LOG(FATAL) << "Reduce does not work with old Visitor, use IRFunctor style visitor";
}

template<>
void ExprNode<Any>::accept(IRVisitor *v, const Expr&) const {
LOG(FATAL) << "Any does not work with old Visitor, use IRFunctor style visitor";
}

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Any>([](const Any *op, IRPrinter *p) {
p->stream << "?";
});

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Reduce>([](const Reduce *op, IRPrinter *p) {
p->stream << "reduce(combiner="
Expand Down Expand Up @@ -116,8 +127,14 @@ Expr Reduce::make(CommReducer combiner, Array<Expr> source,
return Expr(n);
}

Expr Any::make() {
auto n = make_node<Any>();
return Expr(n);
}

TVM_REGISTER_NODE_TYPE(CommReducerNode);
TVM_REGISTER_NODE_TYPE(Reduce);
TVM_REGISTER_NODE_TYPE(Any);
TVM_REGISTER_NODE_TYPE(AttrStmt);

TVM_REGISTER_NODE_TYPE(FloatImm);
Expand Down
Loading

0 comments on commit d2b9d87

Please sign in to comment.