Relax AST (apache#2)
Co-authored-by: ZihengJiang <ziheng@apache.org>
2 people authored and YuchenJin committed Mar 2, 2022
1 parent 27f4927 commit 25f8d83
Showing 19 changed files with 1,185 additions and 7 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -290,6 +290,7 @@ tvm_file_glob(GLOB_RECURSE RELAY_IR_SRCS
tvm_file_glob(GLOB_RECURSE RELAY_QNN_SRCS
src/relay/qnn/*.cc
)

list(APPEND COMPILER_SRCS ${RELAY_OP_SRCS})
list(APPEND COMPILER_SRCS ${RELAY_PASS_SRCS})
list(APPEND COMPILER_SRCS ${RELAY_BACKEND_SRCS})
18 changes: 18 additions & 0 deletions include/tvm/ir/expr.h
@@ -133,6 +133,7 @@ class PrimExpr : public BaseExpr {
TVM_DLL static PrimExpr FromObject_(ObjectRef ref);
};

class RelayExpr;
/*!
* \brief Base node of all non-primitive expressions.
*
@@ -151,10 +152,27 @@ class RelayExprNode : public BaseExprNode {
* This value is discarded during serialization.
*/
mutable Type checked_type_ = Type(nullptr);

/*!
* \brief Stores the result of static shape analysis.
*
* \note The value is undefined if a static shape cannot be inferred.
* Use .shape() instead to access an always-defined shape expression.
*/
Optional<Array<PrimExpr>> shape_ = Optional<Array<PrimExpr>>();

/*!
* \return The checked_type
*/
inline const Type& checked_type() const;

/*!
* \return An expression which corresponds to the shape of the expression.
*
* Only valid when the expression's type is a Tensor.
*/
inline RelayExpr shape() const;

/*!
* \brief Check if the inferred (checked) type of the Expr
* is backed by a TTypeNode and return it.
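
At the Python level, the new shape_ field surfaces as an optional attribute: it is None whenever no shape has been inferred or annotated. A minimal sketch of the intended usage, mirroring test_var later in this commit:

from tvm import relax as rx
from tvm.ir import TensorType

v0 = rx.Var("v0")
assert v0.shape_ is None  # no static shape inferred or annotated

v1 = rx.Var("v1", [54, 96], TensorType([54, 96], "float32"))
for dim, expected in zip(v1.shape_, [54, 96]):
    assert dim == expected  # annotated dims are stored as PrimExprs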
4 changes: 2 additions & 2 deletions include/tvm/node/structural_hash.h
@@ -17,7 +17,7 @@
* under the License.
*/
/*!
- * \file tvm/node/structural_equal.h
+ * \file tvm/node/structural_hash.h
* \brief Structural hash class.
*/
#ifndef TVM_NODE_STRUCTURAL_HASH_H_
@@ -174,7 +174,7 @@ class SHashReducer {
/*!
* \brief Push hash of key to the current sequence of hash values.
* \param key The key to be hashed.
- * \note This function indicate key could contain var defintions.
+ * \note This function indicates key could contain variable definitions.
*/
void DefHash(const ObjectRef& key) const { return handler_->SHashReduce(key, true); }
/*!
404 changes: 404 additions & 0 deletions include/tvm/relax/expr.h

Large diffs are not rendered by default.

117 changes: 117 additions & 0 deletions include/tvm/relax/type.h
@@ -0,0 +1,117 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/relax/type.h
* \brief Relax typed AST nodes.
*/
#ifndef TVM_RELAX_TYPE_H_
#define TVM_RELAX_TYPE_H_

#include <tvm/ir/attrs.h>
#include <tvm/ir/env_func.h>
#include <tvm/ir/tensor_type.h>
#include <tvm/ir/type.h>
#include <tvm/ir/type_relation.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>

#include <string>

namespace tvm {
namespace relax {

class ShapeTypeNode : public TypeNode {
public:

void VisitAttrs(tvm::AttrVisitor* v) {}

bool SEqualReduce(const ShapeTypeNode* other, SEqualReducer equal) const {
return true;
}

void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(0); }

static constexpr const char* _type_key = "relax.ShapeType";
TVM_DECLARE_FINAL_OBJECT_INFO(ShapeTypeNode, TypeNode);
};

class ShapeType : public Type {
public:
explicit ShapeType();
explicit ShapeType(runtime::ObjectPtr<runtime::Object> n) : Type(n) {}
TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(ShapeType);
const ShapeTypeNode* operator->() const {
return static_cast<const ShapeTypeNode*>(data_.get());
}
const ShapeTypeNode* get() const {
return operator->();
}
using ContainerType = ShapeTypeNode;
};


class DynTensorTypeNode : public BaseTensorTypeNode {
public:
/*!
* \brief The rank of the tensor; use -1 to denote a tensor of dynamic rank.
*/
int rank;
/*! \brief The content data type; use void to denote an unknown dtype. */
DataType dtype;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("rank", &rank);
v->Visit("dtype", &dtype);
v->Visit("span", &span);
}

bool SEqualReduce(const DynTensorTypeNode* other, SEqualReducer equal) const {
return equal(rank, other->rank) && equal(dtype, other->dtype);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(rank);
hash_reduce(dtype);
}

static constexpr const char* _type_key = "relax.DynTensorType";
TVM_DECLARE_FINAL_OBJECT_INFO(DynTensorTypeNode, BaseTensorTypeNode);
};

/*!
* \brief Managed reference to DynTensorTypeNode.
* \sa DynTensorTypeNode.
*/
class DynTensorType : public Type {
public:
/*!
* \brief Constructor.
* \param rank The rank of the tensor.
* \param dtype The runtime dtype of the tensor's elements.
*/
TVM_DLL DynTensorType(int rank, DataType dtype);

TVM_DEFINE_OBJECT_REF_METHODS(DynTensorType, Type, DynTensorTypeNode);
};

} // namespace relax
} // namespace tvm
#endif // TVM_RELAX_TYPE_H_
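
Because ShapeTypeNode carries no fields, SEqualReduce makes any two ShapeType instances structurally equal, and a rank of -1 marks a dynamic-rank tensor. A quick illustration through the Python bindings added later in this commit:

from tvm import relax as rx

t0 = rx.ShapeType()
t1 = rx.ShapeType()
assert t0 == t1  # fieldless node: all ShapeTypes compare structurally equal

t2 = rx.DynTensorType()  # rank defaults to -1, i.e. dynamic rank
assert t2.rank == -1
t3 = rx.DynTensorType(3, "int32")  # rank-3 tensor with int32 elements
assert t3.rank == 3 and t3.dtype == "int32"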
3 changes: 2 additions & 1 deletion python/tvm/ir/module.py
@@ -20,6 +20,7 @@

from .base import Node
from . import expr as _expr
from ..ir.function import BaseFunc
from . import type as _ty
from . import _ffi_api

@@ -75,7 +76,7 @@ def __setitem__(self, var, val):
return self._add(var, val, True)

def _add(self, var, val, update=True):
- if isinstance(val, _expr.RelayExpr):
+ if isinstance(val, (_expr.RelayExpr, BaseFunc)):
if isinstance(var, string_types):
if _ffi_api.Module_ContainGlobalVar(self, var):
var = _ffi_api.Module_GetGlobalVar(self, var)
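
With BaseFunc accepted here, a relax.Function can be bound into an IRModule the same way a Relay function is. A hedged sketch (whether Module_Add accepts a relax.Function end to end at this commit is an assumption):

import tvm
from tvm import relax as rx
from tvm.ir import TensorType

x = rx.Var("x")
seq = rx.SeqExpr([rx.BindingBlock([rx.VarBinding(x, rx.const(1))])], x)
func = rx.Function([x], seq, TensorType(None, "float32"), rx.GlobalVar("main"))

mod = tvm.IRModule()
mod["main"] = func  # _add no longer rejects BaseFunc subclasses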
53 changes: 51 additions & 2 deletions python/tvm/relax/__init__.py
@@ -1,2 +1,51 @@
- from .vm import VirtualMachine, load_exec_from_file
- from .builder import ExecBuilder
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from . import exec_builder
from . import expr
from . import ty
from . import vm


# Expr
Expr = expr.Expr
Span = expr.Span
SourceName = expr.SourceName
Id = expr.Id
GlobalVar = expr.GlobalVar
Var = expr.Var
DataflowVar = expr.DataflowVar
Binding = expr.Binding
MatchShape = expr.MatchShape
VarBinding = expr.VarBinding
BindingBlock = expr.BindingBlock
DataflowBlock = expr.DataflowBlock
SeqExpr = expr.SeqExpr
ShapeExpr = expr.ShapeExpr
Function = expr.Function

# helper functions
const = expr.const

# Type
ShapeType = ty.ShapeType
DynTensorType = ty.DynTensorType

# VM
ExecBuilder = exec_builder.ExecBuilder
VirtualMachine = vm.VirtualMachine
load_exec_from_file = vm.load_exec_from_file
17 changes: 17 additions & 0 deletions python/tvm/relax/_ffi_api.py
@@ -1,3 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""FFI API for Relax."""
import tvm._ffi

tvm._ffi._init_api("relax", __name__)
File renamed without changes.
118 changes: 118 additions & 0 deletions python/tvm/relax/expr.py
@@ -0,0 +1,118 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import List, Optional
import tvm._ffi
from ..ir.base import Node, Span, SourceName
from ..relay.base import Id
from ..tir import PrimExpr
from . import _ffi_api
from .. import relay

GlobalVar = relay.GlobalVar
Expr = relay.Expr
Type = relay.Type
const = relay.const


@tvm._ffi.register_object("relax.expr.ShapeExpr")
class ShapeExpr(Expr):
values: List[PrimExpr]

def __init__(self, values: List[PrimExpr]) -> None:
self.__init_handle_by_constructor__(_ffi_api.ShapeExpr, values)


@tvm._ffi.register_object("relax.expr.Var")
class Var(Expr):
id: Id
type_annotation: Optional[Type]

def __init__(self, name_hint: str,
shape_annotation: Optional[List[PrimExpr]] = None,
type_annotation: Optional[Type] = None) -> None:
self.__init_handle_by_constructor__(_ffi_api.Var, name_hint,
shape_annotation,
type_annotation)

@property
def name_hint(self):
"""Get name hint of the current var."""
name = str(self.vid.name_hint)
return name


@tvm._ffi.register_object("relax.expr.DataflowVar")
class DataflowVar(Var):
pass


@tvm._ffi.register_object("relax.expr.Binding")
class Binding(Node):
pass


@tvm._ffi.register_object("relax.expr.MatchShape")
class MatchShape(Binding):
pattern: List[PrimExpr]
value: Expr

def __init__(self, pattern: List[PrimExpr], value: Expr) -> None:
self.__init_handle_by_constructor__(_ffi_api.MatchShape, pattern, value)


@tvm._ffi.register_object("relax.expr.VarBinding")
class VarBinding(Binding):
var: Var
value: Expr

def __init__(self, var: Var, value: Expr) -> None:
self.__init_handle_by_constructor__(_ffi_api.VarBinding, var, value)


@tvm._ffi.register_object("relax.expr.BindingBlock")
class BindingBlock(Node):
bindings: List[Binding]

def __init__(self, bindings: List[Binding]) -> None:
self.__init_handle_by_constructor__(_ffi_api.BindingBlock, bindings)


@tvm._ffi.register_object("relax.expr.DataflowBlock")
class DataflowBlock(BindingBlock):
pass


@tvm._ffi.register_object("relax.expr.SeqExpr")
class SeqExpr(Expr):
blocks: List[BindingBlock]
body: Expr

def __init__(self, blocks: List[BindingBlock], body: Expr) -> None:
self.__init_handle_by_constructor__(_ffi_api.SeqExpr, blocks, body)


@tvm._ffi.register_object("relax.expr.Function")
class Function(Expr):
name: Optional[GlobalVar]
params: List[Var]
body: Expr
ret_type: Type

def __init__(self, params: List[Var], body: Expr,
ret_type: Type, name: Optional[GlobalVar] = None) -> None:
self.__init_handle_by_constructor__(_ffi_api.Function, name, params,
body, ret_type)
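
Taken together, these nodes describe a function body as a sequence of binding blocks. A small construction sketch, mirroring the tests later in this commit:

from tvm import tir, relax as rx

# Bind the symbolic dims (m, n) against the shape of a constant tensor.
m = tir.Var("m", dtype="int32")
n = tir.Var("n", dtype="int32")
match = rx.MatchShape([m, n], rx.const([16, 8], "int32"))

# Bind a var to a value, group the bindings, and wrap them in a SeqExpr.
v = rx.Var("v")
block = rx.BindingBlock([match, rx.VarBinding(v, rx.const(1))])
body = rx.SeqExpr([block], v)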
47 changes: 47 additions & 0 deletions python/tvm/relax/ty.py
@@ -0,0 +1,47 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-import
"""The type nodes of the Relax language."""
import tvm._ffi
from tvm.ir import Type, TensorType

from . import _ffi_api


@tvm._ffi.register_object("relax.ShapeType")
class ShapeType(Type):
def __init__(self):
self.__init_handle_by_constructor__(_ffi_api.ShapeType)


@tvm._ffi.register_object("relax.DynTensorType")
class DynTensorType(TensorType):
"""A dynamic TensorType in Relax.
This is the type assigned to tensors with a known dtype and unknown shape.
Parameters
----------
rank : Optional[int]
The rank of the Tensor
dtype : Optional[str]
The content data type.
"""

def __init__(self, rank=-1, dtype="float32"):
self.__init_handle_by_constructor__(_ffi_api.DynTensorType, rank, dtype)
5 changes: 3 additions & 2 deletions python/tvm/relay/base.py
@@ -21,6 +21,7 @@

from tvm.runtime import Object
from tvm.ir import SourceName, Span, Node as RelayNode
from . import _ffi_api


__STD_PATH__ = os.path.join(os.path.dirname(os.path.realpath(__file__)), "std")
@@ -37,5 +38,5 @@ class Id(Object):
Guaranteed to be stable across all passes.
"""

- def __init__(self):
-     raise RuntimeError("Cannot directly construct Id")
+ def __init__(self, string):
+     self.__init_handle_by_constructor__(_ffi_api.Id, string)
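
Id is now constructible from Python via the relay.ir.Id packed function registered in src/relay/ir/base.cc below. For example:

from tvm.relay.base import Id

vid = Id("x")  # previously raised RuntimeError
assert vid.name_hint == "x"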
2 changes: 2 additions & 0 deletions python/tvm/script/utils.py
@@ -67,6 +67,7 @@ def get_param_list(

def tvm_span_from_synr(span: synr.ast.Span) -> Span:
"""Convert a synr span to a TVM span"""
+ assert isinstance(span, synr.ast.Span), "Expected span to be synr.ast.Span, but got " + str(type(span))
return Span(
SourceName(span.filename),
span.start_line,
@@ -78,6 +79,7 @@ def tvm_span_from_synr(span: synr.ast.Span) -> Span:

def synr_span_from_tvm(span: Span) -> synr.ast.Span:
"""Convert a TVM span to a synr span"""
+ assert isinstance(span, Span), "Expected span to be tvm.ir.Span, but got " + str(type(span))
return synr.ast.Span(
span.source_name.name,
span.line,
181 changes: 181 additions & 0 deletions src/relax/expr.cc
@@ -0,0 +1,181 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include <tvm/relax/expr.h>

namespace tvm {
namespace relax {

using tvm::runtime::Optional;


TVM_REGISTER_NODE_TYPE(ShapeExprNode);

ShapeExpr::ShapeExpr(Array<PrimExpr> values) {
ObjectPtr<ShapeExprNode> n = make_object<ShapeExprNode>();
n->values = std::move(values);
data_ = std::move(n);
}

TVM_REGISTER_GLOBAL("relax.ShapeExpr")
.set_body_typed([](Array<PrimExpr> values) {
return ShapeExpr(values);
});


TVM_REGISTER_NODE_TYPE(VarNode);

Var::Var(Id vid,
Optional<Array<PrimExpr>> shape_annotation,
Optional<Type> type_annotation,
Span span) {
ObjectPtr<VarNode> n = make_object<VarNode>();
n->vid = std::move(vid);
n->shape_ = std::move(shape_annotation);
n->type_annotation = std::move(type_annotation);
n->span = std::move(span);
data_ = std::move(n);
}

TVM_REGISTER_GLOBAL("relax.Var")
.set_body_typed([](String name_hint,
Optional<Array<PrimExpr>> shape_annotation,
Optional<Type> type_annotation) {
return Var(name_hint, shape_annotation, type_annotation);
});


TVM_REGISTER_NODE_TYPE(DataflowVarNode);

TVM_REGISTER_GLOBAL("relax.DataflowVar")
.set_body_typed([](String name_hint,
Optional<Array<PrimExpr>> shape_annotation,
Optional<Type> type_annotation) {
return DataflowVar(name_hint, shape_annotation, type_annotation);
});


TVM_REGISTER_NODE_TYPE(BindingNode);

TVM_REGISTER_GLOBAL("relax.Binding")
.set_body_typed([]() {
return Binding();
});


TVM_REGISTER_NODE_TYPE(MatchShapeNode);

MatchShape::MatchShape(Array<PrimExpr> pattern,
Expr value) {
ObjectPtr<MatchShapeNode> n = make_object<MatchShapeNode>();
n->pattern = std::move(pattern);
n->value = std::move(value);
data_ = std::move(n);
}

TVM_REGISTER_GLOBAL("relax.MatchShape")
.set_body_typed([](Array<PrimExpr> pattern, Expr value) {
return MatchShape(pattern, value);
});


TVM_REGISTER_NODE_TYPE(VarBindingNode);

VarBinding::VarBinding(Var var,
Expr value) {
ObjectPtr<VarBindingNode> n = make_object<VarBindingNode>();
n->var = std::move(var);
n->value = std::move(value);
data_ = std::move(n);
}

TVM_REGISTER_GLOBAL("relax.VarBinding")
.set_body_typed([](Var var, Expr value) {
return VarBinding(var, value);
});


TVM_REGISTER_NODE_TYPE(BindingBlockNode);

BindingBlock::BindingBlock(Array<Binding> bindings) {
ObjectPtr<BindingBlockNode> n = make_object<BindingBlockNode>();
n->bindings = std::move(bindings);
data_ = std::move(n);
}

TVM_REGISTER_GLOBAL("relax.BindingBlock")
.set_body_typed([](Array<Binding> bindings) {
return BindingBlock(bindings);
});


TVM_REGISTER_NODE_TYPE(DataflowBlockNode);

DataflowBlock::DataflowBlock(Array<Binding> bindings) {
ObjectPtr<DataflowBlockNode> n = make_object<DataflowBlockNode>();
n->bindings = std::move(bindings);
data_ = std::move(n);
}

TVM_REGISTER_GLOBAL("relax.DataflowBlock")
.set_body_typed([](Array<Binding> bindings) {
return DataflowBlock(bindings);
});


TVM_REGISTER_NODE_TYPE(SeqExprNode);

SeqExpr::SeqExpr(Array<BindingBlock> blocks,
Expr body) {
ObjectPtr<SeqExprNode> n = make_object<SeqExprNode>();
n->blocks = std::move(blocks);
n->body = std::move(body);
data_ = std::move(n);
}

TVM_REGISTER_GLOBAL("relax.SeqExpr")
.set_body_typed([](Array<BindingBlock> blocks, Expr body) {
return SeqExpr(blocks, body);
});


Function::Function(runtime::Optional<GlobalVar> name,
Array<Var> params,
Expr body,
Type ret_type) {
ObjectPtr<FunctionNode> n = make_object<FunctionNode>();
n->name = std::move(name);
n->params = std::move(params);
n->body = std::move(body);
n->ret_type = std::move(ret_type);
data_ = std::move(n);
}

TVM_REGISTER_NODE_TYPE(FunctionNode);

TVM_REGISTER_GLOBAL("relax.Function")
.set_body_typed([](runtime::Optional<GlobalVar> name,
Array<Var> params,
Expr body,
Type ret_type) {
return Function(name, params, body, ret_type);
});

} // namespace relax
} // namespace tvm
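
Each TVM_REGISTER_GLOBAL above exposes a constructor as a packed function under the "relax." prefix, which python/tvm/relax/_ffi_api.py picks up via _init_api. A small sketch of the round trip:

import tvm
from tvm import tir

ctor = tvm.get_global_func("relax.ShapeExpr")
s = ctor([tir.IntImm("int32", 16), tir.IntImm("int32", 8)])
assert s.values[0].value == 16  # returns the registered relax.expr.ShapeExpr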
58 changes: 58 additions & 0 deletions src/relax/type.cc
@@ -0,0 +1,58 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file src/relax/type.cc
* \brief Relax's type system AST nodes throughout the IR.
*/
#include <tvm/relax/type.h>
#include <tvm/runtime/registry.h>

namespace tvm {
namespace relax {

TVM_REGISTER_NODE_TYPE(ShapeTypeNode);

ShapeType::ShapeType() {
ObjectPtr<ShapeTypeNode> n = make_object<ShapeTypeNode>();
data_ = std::move(n);
}

TVM_REGISTER_GLOBAL("relax.ShapeType")
.set_body_typed([]() {
return ShapeType();
});

DynTensorType::DynTensorType(int rank, DataType dtype) {
ObjectPtr<DynTensorTypeNode> n = make_object<DynTensorTypeNode>();
n->rank = std::move(rank);
n->dtype = std::move(dtype);
data_ = std::move(n);
}

TVM_REGISTER_NODE_TYPE(DynTensorTypeNode);

TVM_REGISTER_GLOBAL("relax.DynTensorType")
.set_body_typed([](int rank, DataType dtype) {
return DynTensorType(rank, dtype);
});


} // namespace relax
} // namespace tvm
4 changes: 4 additions & 0 deletions src/relay/ir/base.cc
@@ -39,6 +39,10 @@ Id::Id(String name_hint) {
data_ = std::move(n);
}

TVM_REGISTER_GLOBAL("relay.ir.Id").set_body_typed([](String name_hint) {
return Id(name_hint);
});

TVM_REGISTER_GLOBAL("ir.NodeSetSpan").set_body_typed([](ObjectRef node_ref, Span sp) {
if (auto* rn = node_ref.as<RelayNode>()) {
rn->span = sp;
125 changes: 125 additions & 0 deletions tests/python/relax/test_ast.py
@@ -0,0 +1,125 @@
import tvm
from tvm import tir
from tvm import relax as rx
from tvm.ir import TensorType
import numpy as np


def test_var() -> None:
v0 = rx.Var("v0")
assert v0.name_hint == "v0"
assert v0.shape_ is None
assert v0.type_annotation is None
shape_anno = [54, 96]
type_anno = TensorType(shape_anno, "float32")
v1 = rx.Var("v1", shape_anno, type_anno)
assert v1.name_hint == "v1"
for s0, s1 in zip(v1.shape_, shape_anno):
assert s0 == s1
assert v1.type_annotation == type_anno


def test_dataflow_var() -> None:
v0 = rx.DataflowVar("v0")
assert v0.name_hint == "v0"
assert v0.shape_ is None
assert v0.type_annotation is None
shape_anno = [54, 96]
type_anno = TensorType(shape_anno, "float16")
v1 = rx.DataflowVar("v1", shape_anno, type_anno)
assert v1.name_hint == "v1"
for s0, s1 in zip(v1.shape_, shape_anno):
assert s0 == s1
assert v1.type_annotation == type_anno
assert isinstance(v1, rx.DataflowVar)


def test_match_shape() -> None:
m = tir.Var("m", dtype="int32")
n = tir.Var("n", dtype="int32")
shape = rx.const([16, 8], "int32")
b0 = rx.MatchShape([m, n], shape)
assert b0.pattern[0] == m
assert b0.pattern[1] == n
assert b0.value == shape


def test_var_binding() -> None:
v0 = rx.Var("v0")
val = rx.const(np.random.rand(24, 56))
b0 = rx.VarBinding(v0, val)
assert b0.var.name_hint == "v0"
assert b0.value == val


def test_binding_block() -> None:
m = tir.Var("m", dtype="int32")
n = tir.Var("n", dtype="int32")
shape = rx.const([16, 8], "int32")
b0 = rx.MatchShape([m, n], shape)

v0 = rx.Var("v0")
val = rx.const(np.random.rand(24, 56))
b1 = rx.VarBinding(v0, val)

block0 = rx.BindingBlock([b0, b1])
assert block0.bindings[0] == b0
assert block0.bindings[1] == b1


def test_dataflow_block() -> None:
m = tir.Var("m", dtype="int32")
n = tir.Var("n", dtype="int32")
shape = rx.const([16, 8], "int32")
b0 = rx.MatchShape([m, n], shape)

v0 = rx.Var("v0")
val = rx.const(np.random.rand(24, 56))
b1 = rx.VarBinding(v0, val)

block0 = rx.DataflowBlock([b0, b1])
assert block0.bindings[0] == b0
assert block0.bindings[1] == b1
assert isinstance(block0, rx.DataflowBlock)


def test_seq_expr() -> None:
x = rx.Var("foo")
bindings = [rx.VarBinding(x, rx.const(1))]
blocks = [rx.BindingBlock(bindings)]
seqe = rx.SeqExpr(blocks, x)
assert seqe.blocks[0] == blocks[0]
assert seqe.body == x


def test_shape_expr() -> None:
m = tir.Var("m", dtype="int32")
n = tir.Var("n", dtype="int32")
s = rx.ShapeExpr([m, n])
assert s.values[0] == m
assert s.values[1] == n


def test_func():
x = rx.Var("foo")
bindings = [rx.VarBinding(x, rx.const(1))]
blocks = [rx.BindingBlock(bindings)]
seqe = rx.SeqExpr(blocks, x)
ret_type = TensorType(None, "float32")
func = rx.Function([x], seqe, ret_type, rx.GlobalVar("func"))
assert func.params[0] == x
assert func.body == seqe
assert func.ret_type == ret_type
assert func.name.name_hint == "func"


if __name__ == "__main__":
test_var()
test_dataflow_var()
test_match_shape()
test_var_binding()
test_binding_block()
test_dataflow_block()
test_seq_expr()
test_shape_expr()
test_func()
35 changes: 35 additions & 0 deletions tests/python/relax/test_type.py
@@ -0,0 +1,35 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np
import tvm
from tvm import relax as rx

def test_shape_type():
t0 = rx.ShapeType()
t1 = rx.ShapeType()
assert t0 == t1

def test_dyn_tensor_type():
t0 = rx.DynTensorType()
assert t0.rank == -1
t1 = rx.DynTensorType(3, "int32")
assert t1.rank == 3
assert t1.dtype == "int32"

if __name__ == "__main__":
test_shape_type()
test_dyn_tensor_type()
File renamed without changes.
