forked from apache/tvm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Co-authored-by: ZihengJiang <ziheng@apache.org>
Showing
19 changed files
with
1,185 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
/*! | ||
* \file tvm/relax/type.h | ||
* \brief Relax typed AST nodes. | ||
*/ | ||
#ifndef TVM_RELAX_TYPE_H_ | ||
#define TVM_RELAX_TYPE_H_ | ||
|
||
#include <tvm/ir/attrs.h> | ||
#include <tvm/ir/env_func.h> | ||
#include <tvm/ir/tensor_type.h> | ||
#include <tvm/ir/type.h> | ||
#include <tvm/ir/type_relation.h> | ||
#include <tvm/runtime/registry.h> | ||
#include <tvm/tir/expr.h> | ||
|
||
#include <string> | ||
|
||
namespace tvm { | ||
namespace relax { | ||
|
||
class ShapeTypeNode : public TypeNode { | ||
public: | ||
|
||
void VisitAttrs(tvm::AttrVisitor* v) { | ||
} | ||
|
||
bool SEqualReduce(const ShapeTypeNode* other, SEqualReducer equal) const { | ||
return true; | ||
} | ||
|
||
void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(0); } | ||
|
||
static constexpr const char* _type_key = "relax.ShapeType"; | ||
TVM_DECLARE_FINAL_OBJECT_INFO(ShapeTypeNode, TypeNode); | ||
}; | ||
|
||
class ShapeType : public Type { | ||
public: | ||
explicit ShapeType(); | ||
explicit ShapeType(runtime::ObjectPtr<runtime::Object> n) : Type(n) {} | ||
TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(ShapeType); | ||
const ShapeTypeNode* operator->() const { | ||
return static_cast<const ShapeTypeNode*>(data_.get()); | ||
} | ||
const ShapeTypeNode* get() const { | ||
return operator->(); | ||
} | ||
using ContainerType = ShapeTypeNode; | ||
}; | ||
|
||
|
||
class DynTensorTypeNode : public BaseTensorTypeNode { | ||
public: | ||
/*! | ||
* \brief The rank of the tensor, use -1 to denote dynamic rank tensor. | ||
*/ | ||
int rank; | ||
/*! \brief The content data type, use void to denote the dtype is unknown. */ | ||
DataType dtype; | ||
|
||
void VisitAttrs(tvm::AttrVisitor* v) { | ||
v->Visit("rank", &rank); | ||
v->Visit("dtype", &dtype); | ||
v->Visit("span", &span); | ||
} | ||
|
||
bool SEqualReduce(const DynTensorTypeNode* other, SEqualReducer equal) const { | ||
return equal(rank, other->rank) && equal(dtype, other->dtype); | ||
} | ||
|
||
void SHashReduce(SHashReducer hash_reduce) const { | ||
hash_reduce(rank); | ||
hash_reduce(dtype); | ||
} | ||
|
||
static constexpr const char* _type_key = "relax.DynTensorType"; | ||
TVM_DECLARE_FINAL_OBJECT_INFO(DynTensorTypeNode, BaseTensorTypeNode); | ||
}; | ||
|
||
/*! | ||
* \brief Managed reference to DynTensorTypeNode. | ||
* \sa DynTensorTypeNode. | ||
*/ | ||
class DynTensorType : public Type { | ||
public: | ||
/*! | ||
* \brief Constructor. | ||
* \param shape The shape of the tensor. | ||
* \param dtype The runtime dtype of the tensor's elements. | ||
*/ | ||
TVM_DLL DynTensorType(int rank, DataType dtype); | ||
|
||
TVM_DEFINE_OBJECT_REF_METHODS(DynTensorType, Type, DynTensorTypeNode); | ||
}; | ||
|
||
} // namespace relax | ||
} // namespace tvm | ||
#endif // TVM_RELAX_TYPE_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,51 @@ | ||
from .vm import VirtualMachine, load_exec_from_file | ||
from .builder import ExecBuilder | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
from . import exec_builder | ||
from . import expr | ||
from . import ty | ||
from . import vm | ||
|
||
|
||
# Expr | ||
Expr = expr.Expr | ||
Span = expr.Span | ||
SourceName = expr.SourceName | ||
Id = expr.Id | ||
GlobalVar = expr.GlobalVar | ||
Var = expr.Var | ||
DataflowVar = expr.DataflowVar | ||
Binding = expr.Binding | ||
MatchShape = expr.MatchShape | ||
VarBinding = expr.VarBinding | ||
BindingBlock = expr.BindingBlock | ||
DataflowBlock = expr.DataflowBlock | ||
SeqExpr = expr.SeqExpr | ||
ShapeExpr = expr.ShapeExpr | ||
Function = expr.Function | ||
|
||
# helper functions | ||
const = expr.const | ||
|
||
# Type | ||
ShapeType = ty.ShapeType | ||
DynTensorType = ty.DynTensorType | ||
|
||
# VM | ||
ExecBuilder = exec_builder.ExecBuilder | ||
VirtualMachine = vm.VirtualMachine | ||
load_exec_from_file = vm.load_exec_from_file |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,20 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
"""FFI API for Relax.""" | ||
import tvm._ffi | ||
|
||
tvm._ffi._init_api("relax", __name__) |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
from typing import List, Optional, Union, Dict | ||
import tvm._ffi | ||
from ..ir.base import Node, Span, SourceName | ||
from ..relay.base import Id | ||
from ..tir import PrimExpr | ||
from . import _ffi_api | ||
from .. import relay | ||
|
||
GlobalVar = relay.GlobalVar | ||
Expr = relay.Expr | ||
Type = relay.Type | ||
const = relay.const | ||
|
||
|
||
@tvm._ffi.register_object("relax.expr.ShapeExpr") | ||
class ShapeExpr(Expr): | ||
values: List[PrimExpr] | ||
|
||
def __init__(self, values: List[PrimExpr]) -> None: | ||
self.__init_handle_by_constructor__(_ffi_api.ShapeExpr, values) | ||
|
||
|
||
@tvm._ffi.register_object("relax.expr.Var") | ||
class Var(Expr): | ||
id: Id | ||
type_annotation: Optional[Type] | ||
|
||
def __init__(self, name_hint: str, | ||
shape_annotation: Optional[List[Type]] = None, | ||
type_annotation: Optional[Type] = None) -> None: | ||
self.__init_handle_by_constructor__(_ffi_api.Var, name_hint, | ||
shape_annotation, | ||
type_annotation) | ||
|
||
@property | ||
def name_hint(self): | ||
"""Get name hint of the current var.""" | ||
name = str(self.vid.name_hint) | ||
return name | ||
|
||
|
||
@tvm._ffi.register_object("relax.expr.DataflowVar") | ||
class DataflowVar(Var): | ||
pass | ||
|
||
|
||
@tvm._ffi.register_object("relax.expr.Binding") | ||
class Binding(Node): | ||
pass | ||
|
||
|
||
@tvm._ffi.register_object("relax.expr.MatchShape") | ||
class MatchShape(Binding): | ||
pattern: List[PrimExpr] | ||
value: Expr | ||
|
||
def __init__(self, pattern: List[PrimExpr], value: Expr) -> None: | ||
self.__init_handle_by_constructor__(_ffi_api.MatchShape, pattern, value) | ||
|
||
|
||
@tvm._ffi.register_object("relax.expr.VarBinding") | ||
class VarBinding(Binding): | ||
var: Var | ||
value: Expr | ||
|
||
def __init__(self, var: Var, value: Expr) -> None: | ||
self.__init_handle_by_constructor__(_ffi_api.VarBinding, var, value) | ||
|
||
|
||
@tvm._ffi.register_object("relax.expr.BindingBlock") | ||
class BindingBlock(Node): | ||
bindings: List[Binding] | ||
|
||
def __init__(self, bindings: List[Binding]) -> None: | ||
self.__init_handle_by_constructor__(_ffi_api.BindingBlock, bindings) | ||
|
||
|
||
@tvm._ffi.register_object("relax.expr.DataflowBlock") | ||
class DataflowBlock(BindingBlock): | ||
pass | ||
|
||
|
||
@tvm._ffi.register_object("relax.expr.SeqExpr") | ||
class SeqExpr(Expr): | ||
blocks: List[BindingBlock] | ||
body: Expr | ||
|
||
def __init__(self, blocks: List[BindingBlock], body: Expr) -> None: | ||
self.__init_handle_by_constructor__(_ffi_api.SeqExpr, blocks, body) | ||
|
||
|
||
@tvm._ffi.register_object("relax.expr.Function") | ||
class Function(Expr): | ||
name: Optional[GlobalVar] | ||
params: List[Var] | ||
body: Expr | ||
ret_type: Type | ||
|
||
def __init__(self, params: List[Var], body: Expr, | ||
ret_type: Type, name: Optional[GlobalVar] = None) -> None: | ||
self.__init_handle_by_constructor__(_ffi_api.Function, name, params, | ||
body, ret_type) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
# pylint: disable=invalid-name, unused-import | ||
"""The type nodes of the Relax language.""" | ||
import tvm._ffi | ||
from tvm.ir import Type, TensorType | ||
|
||
from . import _ffi_api | ||
|
||
|
||
@tvm._ffi.register_object("relax.ShapeType") | ||
class ShapeType(Type): | ||
def __init__(self): | ||
self.__init_handle_by_constructor__(_ffi_api.ShapeType) | ||
|
||
|
||
@tvm._ffi.register_object("relax.DynTensorType") | ||
class DynTensorType(TensorType): | ||
"""A dynamic TensorType in Relax. | ||
This is the type assigned to tensors with a known dtype and unknown shape. | ||
Parameters | ||
---------- | ||
rank : Optional[int] | ||
The rank of the Tensor | ||
dtype : Optional[str] | ||
The content data type. | ||
""" | ||
|
||
def __init__(self, rank=-1, dtype="float32"): | ||
self.__init_handle_by_constructor__(_ffi_api.DynTensorType, rank, dtype) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,181 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
#include <tvm/relax/expr.h> | ||
|
||
namespace tvm { | ||
namespace relax { | ||
|
||
using tvm::runtime::Optional; | ||
|
||
|
||
TVM_REGISTER_NODE_TYPE(ShapeExprNode); | ||
|
||
ShapeExpr::ShapeExpr(Array<PrimExpr> values) { | ||
ObjectPtr<ShapeExprNode> n = make_object<ShapeExprNode>(); | ||
n->values = std::move(values); | ||
data_ = std::move(n); | ||
} | ||
|
||
TVM_REGISTER_GLOBAL("relax.ShapeExpr") | ||
.set_body_typed([](Array<PrimExpr> values) { | ||
return ShapeExpr(values); | ||
}); | ||
|
||
|
||
TVM_REGISTER_NODE_TYPE(VarNode); | ||
|
||
Var::Var(Id vid, | ||
Optional<Array<PrimExpr>> shape_annotation, | ||
Optional<Type> type_annotation, | ||
Span span) { | ||
ObjectPtr<VarNode> n = make_object<VarNode>(); | ||
n->vid = std::move(vid); | ||
n->shape_ = std::move(shape_annotation); | ||
n->type_annotation = std::move(type_annotation); | ||
n->span = std::move(span); | ||
data_ = std::move(n); | ||
} | ||
|
||
TVM_REGISTER_GLOBAL("relax.Var") | ||
.set_body_typed([](String name_hint, | ||
Optional<Array<PrimExpr>> shape_annotation, | ||
Optional<Type> type_annotation) { | ||
return Var(name_hint, shape_annotation, type_annotation); | ||
}); | ||
|
||
|
||
TVM_REGISTER_NODE_TYPE(DataflowVarNode); | ||
|
||
TVM_REGISTER_GLOBAL("relax.DataflowVar") | ||
.set_body_typed([](String name_hint, | ||
Optional<Array<PrimExpr>> shape_annotation, | ||
Optional<Type> type_annotation) { | ||
return DataflowVar(name_hint, shape_annotation, type_annotation); | ||
}); | ||
|
||
|
||
TVM_REGISTER_NODE_TYPE(BindingNode); | ||
|
||
TVM_REGISTER_GLOBAL("relax.Binding") | ||
.set_body_typed([]() { | ||
return Binding(); | ||
}); | ||
|
||
|
||
TVM_REGISTER_NODE_TYPE(MatchShapeNode); | ||
|
||
MatchShape::MatchShape(Array<PrimExpr> pattern, | ||
Expr value) { | ||
ObjectPtr<MatchShapeNode> n = make_object<MatchShapeNode>(); | ||
n->pattern = std::move(pattern); | ||
n->value = std::move(value); | ||
data_ = std::move(n); | ||
} | ||
|
||
TVM_REGISTER_GLOBAL("relax.MatchShape") | ||
.set_body_typed([](Array<PrimExpr> pattern, Expr value) { | ||
return MatchShape(pattern, value); | ||
}); | ||
|
||
|
||
TVM_REGISTER_NODE_TYPE(VarBindingNode); | ||
|
||
VarBinding::VarBinding(Var var, | ||
Expr value) { | ||
ObjectPtr<VarBindingNode> n = make_object<VarBindingNode>(); | ||
n->var = std::move(var); | ||
n->value = std::move(value); | ||
data_ = std::move(n); | ||
} | ||
|
||
TVM_REGISTER_GLOBAL("relax.VarBinding") | ||
.set_body_typed([](Var var,Expr value) { | ||
return VarBinding(var,value); | ||
}); | ||
|
||
|
||
TVM_REGISTER_NODE_TYPE(BindingBlockNode); | ||
|
||
BindingBlock::BindingBlock(Array<Binding> bindings) { | ||
ObjectPtr<BindingBlockNode> n = make_object<BindingBlockNode>(); | ||
n->bindings = std::move(bindings); | ||
data_ = std::move(n); | ||
} | ||
|
||
TVM_REGISTER_GLOBAL("relax.BindingBlock") | ||
.set_body_typed([](Array<Binding> bindings) { | ||
return BindingBlock(bindings); | ||
}); | ||
|
||
|
||
TVM_REGISTER_NODE_TYPE(DataflowBlockNode); | ||
|
||
DataflowBlock::DataflowBlock(Array<Binding> bindings) { | ||
ObjectPtr<DataflowBlockNode> n = make_object<DataflowBlockNode>(); | ||
n->bindings = std::move(bindings); | ||
data_ = std::move(n); | ||
} | ||
|
||
TVM_REGISTER_GLOBAL("relax.DataflowBlock") | ||
.set_body_typed([](Array<Binding> bindings) { | ||
return DataflowBlock(bindings); | ||
}); | ||
|
||
|
||
TVM_REGISTER_NODE_TYPE(SeqExprNode); | ||
|
||
SeqExpr::SeqExpr(Array<BindingBlock> blocks, | ||
Expr body) { | ||
ObjectPtr<SeqExprNode> n = make_object<SeqExprNode>(); | ||
n->blocks = std::move(blocks); | ||
n->body = std::move(body); | ||
data_ = std::move(n); | ||
} | ||
|
||
TVM_REGISTER_GLOBAL("relax.SeqExpr") | ||
.set_body_typed([](Array<BindingBlock> blocks, Expr body) { | ||
return SeqExpr(blocks, body); | ||
}); | ||
|
||
|
||
Function::Function(runtime::Optional<GlobalVar> name, | ||
Array<Var> params, | ||
Expr body, | ||
Type ret_type) { | ||
ObjectPtr<FunctionNode> n = make_object<FunctionNode>(); | ||
n->name = std::move(name); | ||
n->params = std::move(params); | ||
n->body = std::move(body); | ||
n->ret_type = std::move(ret_type); | ||
data_ = std::move(n); | ||
} | ||
|
||
TVM_REGISTER_NODE_TYPE(FunctionNode); | ||
|
||
TVM_REGISTER_GLOBAL("relax.Function") | ||
.set_body_typed([](runtime::Optional<GlobalVar> name, | ||
Array<Var> params, | ||
Expr body, | ||
Type ret_type) { | ||
return Function(name, params, body, ret_type); | ||
}); | ||
|
||
} // namespace relax | ||
} // namespace tvm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
/*! | ||
* \file src/relax/type.cc | ||
* \brief Relax's type system AST nodes throughout the IR. | ||
*/ | ||
#include <tvm/relax/type.h> | ||
#include <tvm/runtime/registry.h> | ||
|
||
namespace tvm { | ||
namespace relax { | ||
|
||
TVM_REGISTER_NODE_TYPE(ShapeTypeNode); | ||
|
||
ShapeType::ShapeType() { | ||
ObjectPtr<ShapeTypeNode> n = make_object<ShapeTypeNode>(); | ||
data_ = std::move(n); | ||
} | ||
|
||
TVM_REGISTER_GLOBAL("relax.ShapeType") | ||
.set_body_typed([]() { | ||
return ShapeType(); | ||
}); | ||
|
||
DynTensorType::DynTensorType(int rank, DataType dtype) { | ||
ObjectPtr<DynTensorTypeNode> n = make_object<DynTensorTypeNode>(); | ||
n->rank = std::move(rank); | ||
n->dtype = std::move(dtype); | ||
data_ = std::move(n); | ||
} | ||
|
||
TVM_REGISTER_NODE_TYPE(DynTensorTypeNode); | ||
|
||
TVM_REGISTER_GLOBAL("relax.DynTensorType") | ||
.set_body_typed([](int rank, DataType dtype) { | ||
return DynTensorType(rank, dtype); | ||
}); | ||
|
||
|
||
} // namespace relax | ||
} // namespace tvm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
import tvm | ||
from tvm import tir | ||
from tvm import relax as rx | ||
from tvm.ir import TensorType | ||
import numpy as np | ||
|
||
|
||
def test_var() -> None: | ||
v0 = rx.Var("v0") | ||
assert v0.name_hint == "v0" | ||
assert v0.shape_ is None | ||
assert v0.type_annotation is None | ||
shape_anno = [54, 96] | ||
type_anno = TensorType(shape_anno, "float32") | ||
v1 = rx.Var("v1", shape_anno, type_anno) | ||
assert v1.name_hint == "v1" | ||
for s0, s1 in zip(v1.shape_, shape_anno): | ||
assert s0 == s1 | ||
assert v1.type_annotation == type_anno | ||
|
||
|
||
def test_dataflow_var() -> None: | ||
v0 = rx.DataflowVar("v0") | ||
assert v0.name_hint == "v0" | ||
assert v0.shape_ is None | ||
assert v0.type_annotation is None | ||
shape_anno = [54, 96] | ||
type_anno = TensorType(shape_anno, "float16") | ||
v1 = rx.DataflowVar("v1", shape_anno, type_anno) | ||
assert v1.name_hint == "v1" | ||
for s0, s1 in zip(v1.shape_, shape_anno): | ||
assert s0 == s1 | ||
assert v1.type_annotation == type_anno | ||
assert isinstance(v1, rx.DataflowVar) | ||
|
||
|
||
def test_match_shape() -> None: | ||
m = tir.Var("m", dtype="int32") | ||
n = tir.Var("n", dtype="int32") | ||
shape = rx.const([16, 8], "int32") | ||
b0 = rx.MatchShape([m, n], shape) | ||
assert b0.pattern[0] == m | ||
assert b0.pattern[1] == n | ||
assert b0.value == shape | ||
|
||
|
||
def test_var_binding() -> None: | ||
v0 = rx.Var("v0") | ||
val = rx.const(np.random.rand(24, 56)) | ||
b0 = rx.VarBinding(v0, val) | ||
assert b0.var.name_hint == "v0" | ||
assert b0.value == val | ||
|
||
|
||
def test_binding_block() -> None: | ||
m = tir.Var("m", dtype="int32") | ||
n = tir.Var("n", dtype="int32") | ||
shape = rx.const([16, 8], "int32") | ||
b0 = rx.MatchShape([m, n], shape) | ||
|
||
v0 = rx.Var("v0") | ||
val = rx.const(np.random.rand(24, 56)) | ||
b1 = rx.VarBinding(v0, val) | ||
|
||
block0 = rx.BindingBlock([b0, b1]) | ||
assert block0.bindings[0] == b0 | ||
assert block0.bindings[1] == b1 | ||
|
||
|
||
def test_dataflow_block() -> None: | ||
m = tir.Var("m", dtype="int32") | ||
n = tir.Var("n", dtype="int32") | ||
shape = rx.const([16, 8], "int32") | ||
b0 = rx.MatchShape([m, n], shape) | ||
|
||
v0 = rx.Var("v0") | ||
val = rx.const(np.random.rand(24, 56)) | ||
b1 = rx.VarBinding(v0, val) | ||
|
||
block0 = rx.DataflowBlock([b0, b1]) | ||
assert block0.bindings[0] == b0 | ||
assert block0.bindings[1] == b1 | ||
assert isinstance(block0, rx.DataflowBlock) | ||
|
||
|
||
def test_seq_expr() -> None: | ||
x = rx.Var("foo") | ||
bindings = [rx.VarBinding(x, rx.const(1))] | ||
blocks = [rx.BindingBlock(bindings)] | ||
seqe = rx.SeqExpr(blocks, x) | ||
assert seqe.blocks[0] == blocks[0] | ||
assert seqe.body == x | ||
|
||
|
||
def test_shape_expr() -> None: | ||
m = tir.Var("m", dtype="int32") | ||
n = tir.Var("n", dtype="int32") | ||
s = rx.ShapeExpr([m, n]) | ||
assert s.values[0] == m | ||
assert s.values[1] == n | ||
|
||
|
||
def test_func(): | ||
x = rx.Var("foo") | ||
bindings = [rx.VarBinding(x, rx.const(1))] | ||
blocks = [rx.BindingBlock(bindings)] | ||
seqe = rx.SeqExpr(blocks, x) | ||
ret_type = TensorType(None, "float32") | ||
func = rx.Function([x], seqe, ret_type, rx.GlobalVar("func")) | ||
assert func.params[0] == x | ||
assert func.body == seqe | ||
assert func.ret_type == ret_type | ||
assert func.name.name_hint == "func" | ||
|
||
|
||
if __name__ == "__main__": | ||
test_var() | ||
test_dataflow_var() | ||
test_match_shape() | ||
test_var_binding() | ||
test_binding_block() | ||
test_dataflow_block() | ||
test_seq_expr() | ||
test_shape_expr() | ||
test_func() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
import numpy as np | ||
import tvm | ||
from tvm import relax as rx | ||
|
||
def test_shape_type(): | ||
t0 = rx.ShapeType() | ||
t1 = rx.ShapeType() | ||
assert t0 == t1 | ||
|
||
def test_dyn_tensor_type(): | ||
t0 = rx.DynTensorType() | ||
assert t0.rank == -1 | ||
t1 = rx.DynTensorType(3, "int32") | ||
assert t1.rank == 3 | ||
assert t1.dtype == "int32" | ||
|
||
if __name__ == "__main__": | ||
test_shape_type() | ||
test_dyn_tensor_type() |
File renamed without changes.