Skip to content

Commit

Permalink
[IRBuilder] Minor tweaks (apache#49)
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao authored Jun 24, 2022
1 parent 4bc1e67 commit f96e789
Show file tree
Hide file tree
Showing 16 changed files with 252 additions and 185 deletions.
8 changes: 3 additions & 5 deletions python/tvm/script/builder/tir/block_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@
# specific language governing permissions and limitations
# under the License.
"""TVM Script TIR Block Frame"""
from typing import Any, Dict, List, Union

from tvm._ffi import register_object as _register_object
from tvm.tir import Buffer, BufferLoad, BufferRegion

from . import _ffi_api
from .base import TIRFrame

from typing import List, Dict, Any, Union
from tvm.tir import Buffer, BufferLoad, BufferRegion


@_register_object("script.builder.tir.BlockFrame")
class BlockFrame(TIRFrame):
Expand Down Expand Up @@ -73,7 +73,6 @@ def alloc_buffer(
offset_factor=0,
buffer_type="default",
axis_separators=None,
span=None,
) -> Buffer:
return _ffi_api.AllocBuffer(
shape,
Expand All @@ -86,5 +85,4 @@ def alloc_buffer(
offset_factor,
buffer_type,
axis_separators,
span,
)
90 changes: 52 additions & 38 deletions python/tvm/script/builder/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,45 +16,65 @@
# under the License.
"""TVM Script TIR Op"""

from . import _ffi_api


from tvm.tir.op import abs, popcount, nextafter, copysign, fmod
from tvm.tir.expr import Broadcast, Ramp, Select, Shuffle
from tvm.tir.generic import cast
from tvm.tir.op import (
abs,
acos,
acosh,
asin,
asinh,
atan,
atan2,
atanh,
call_extern,
call_packed,
ceil,
clz,
comm_reducer,
copysign,
cos,
cosh,
erf,
exp,
exp2,
exp10,
floor,
floordiv,
floormod,
ceil,
round,
trunc,
truncdiv,
truncmod,
nearbyint,
)
from tvm.tir.op import (
fmod,
hypot,
if_then_else,
infinity,
isfinite,
isinf,
isnan,
ldexp,
power,
exp,
exp2,
exp10,
erf,
sqrt,
rsqrt,
log,
log1p,
log2,
log10,
log1p,
max_value,
min_value,
nearbyint,
nextafter,
popcount,
power,
reinterpret,
round,
rsqrt,
sigmoid,
sin,
sinh,
sqrt,
tan,
tanh,
trunc,
truncdiv,
truncmod,
)
from tvm.tir.op import isnan, isfinite, isinf
from tvm.tir.op import cos, cosh, sin, sinh, tan, tanh
from tvm.tir.op import acos, acosh, asin, asinh, atan, atanh
from tvm.tir.op import atan2, clz, comm_reducer, infinity, reinterpret
from tvm.tir.op import min_value, max_value, if_then_else
from tvm.tir.op import call_packed, call_extern
from tvm.tir.expr import Select, Ramp, Broadcast, Shuffle
from tvm.tir.generic import cast

from . import _ffi_api


def boolean(expr):
Expand Down Expand Up @@ -113,7 +133,7 @@ def handle():
return _ffi_api.Handle()


def min(a, b, span=None):
def min(a, b):
"""Compute the minimum value of two expressions.
Parameters
Expand All @@ -124,9 +144,6 @@ def min(a, b, span=None):
b : PrimExpr
The right hand operand
span : Optional[Span]
The location of this operator in the source.
Returns
-------
res : PrimExpr
Expand All @@ -136,10 +153,10 @@ def min(a, b, span=None):
----
This is the default integer division behavior in C.
"""
return _ffi_api.min(a, b, span) # type: ignore
return _ffi_api.min(a, b) # type: ignore


def max(a, b, span=None):
def max(a, b):
"""Compute the maximum value of two expressions.
Parameters
Expand All @@ -150,9 +167,6 @@ def max(a, b, span=None):
b : PrimExpr
The right hand operand
span : Optional[Span]
The location of this operator in the source.
Returns
-------
res : PrimExpr
Expand All @@ -162,4 +176,4 @@ def max(a, b, span=None):
----
This is the default integer division behavior in C.
"""
return _ffi_api.max(a, b, span) # type: ignore
return _ffi_api.max(a, b) # type: ignore
4 changes: 0 additions & 4 deletions python/tvm/script/builder/tir/prim_func_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ def match_buffer(
offset_factor=0,
buffer_type="default",
axis_separators=None,
span=None,
) -> Buffer:
return _ffi_api.MatchBuffer( # pylint: disable=no-member # type: ignore
param,
Expand All @@ -84,7 +83,6 @@ def match_buffer(
offset_factor,
buffer_type,
axis_separators,
span,
)


Expand All @@ -100,7 +98,6 @@ def preflattened_buffer(
offset_factor=0,
buffer_type="default",
axis_separators=None,
span=None,
) -> None:
_ffi_api.PreflattenedBuffer( # pylint: disable=no-member # type: ignore
postflattened,
Expand All @@ -114,5 +111,4 @@ def preflattened_buffer(
offset_factor,
buffer_type,
axis_separators,
span,
)
20 changes: 0 additions & 20 deletions src/script/builder/frame.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,27 +34,7 @@ void FrameNode::ExitWithScope() {
Builder::Current()->frames.pop_back();
}

IRModuleFrame::IRModuleFrame() {
ObjectPtr<IRModuleFrameNode> n = make_object<IRModuleFrameNode>();
n->global_vars.clear();
n->functions.clear();
data_ = std::move(n);
}

void IRModuleFrameNode::ExitWithScope() {
ICHECK_EQ(functions.size(), global_vars.size());
int n = functions.size();
Map<GlobalVar, BaseFunc> func_map;
for (int i = 0; i < n; ++i) {
func_map.Set(global_vars[i], functions[i]);
}
Builder builder = Builder::Current();
ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set";
builder->result = tvm::IRModule(func_map);
}

TVM_REGISTER_NODE_TYPE(FrameNode);
TVM_REGISTER_NODE_TYPE(IRModuleFrameNode);
TVM_REGISTER_GLOBAL("script.builder.FrameEnter").set_body_method<Frame>(&FrameNode::EnterWithScope);
TVM_REGISTER_GLOBAL("script.builder.FrameExit").set_body_method<Frame>(&FrameNode::ExitWithScope);

Expand Down
24 changes: 0 additions & 24 deletions src/script/builder/frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,30 +56,6 @@ class Frame : public runtime::ObjectRef {
inline void ExitWithScope();
};

class IRModuleFrameNode : public FrameNode {
public:
Array<GlobalVar> global_vars;
Array<BaseFunc> functions;

void VisitAttrs(tvm::AttrVisitor* v) {
FrameNode::VisitAttrs(v);
v->Visit("global_vars", &global_vars);
v->Visit("functions", &functions);
}

static constexpr const char* _type_key = "script.builder.IRModuleFrame";
TVM_DECLARE_FINAL_OBJECT_INFO(IRModuleFrameNode, FrameNode);

public:
void ExitWithScope() final;
};

class IRModuleFrame : public Frame {
public:
IRModuleFrame();
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRModuleFrame, Frame, IRModuleFrameNode);
};

inline void Frame::EnterWithScope() {
ICHECK(data_ != nullptr);
static_cast<FrameNode*>(data_.get())->EnterWithScope();
Expand Down
52 changes: 52 additions & 0 deletions src/script/builder/ir/ir.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "./ir.h"

#include "../builder.h"

namespace tvm {
namespace script {
namespace builder {
namespace ir {

IRModuleFrame::IRModuleFrame() {
ObjectPtr<IRModuleFrameNode> n = make_object<IRModuleFrameNode>();
n->global_vars.clear();
n->functions.clear();
data_ = std::move(n);
}

void IRModuleFrameNode::ExitWithScope() {
ICHECK_EQ(functions.size(), global_vars.size());
int n = functions.size();
Map<GlobalVar, BaseFunc> func_map;
for (int i = 0; i < n; ++i) {
func_map.Set(global_vars[i], functions[i]);
}
Builder builder = Builder::Current();
ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set";
builder->result = tvm::IRModule(func_map);
}

TVM_REGISTER_NODE_TYPE(IRModuleFrameNode);

} // namespace ir
} // namespace builder
} // namespace script
} // namespace tvm
60 changes: 60 additions & 0 deletions src/script/builder/ir/ir.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_SCRIPT_BUILDER_IR_IR_H_
#define TVM_SCRIPT_BUILDER_IR_IR_H_

#include "../frame.h"

namespace tvm {
namespace script {
namespace builder {
namespace ir {

class IRModuleFrameNode : public FrameNode {
public:
Array<GlobalVar> global_vars;
Array<BaseFunc> functions;

void VisitAttrs(tvm::AttrVisitor* v) {
FrameNode::VisitAttrs(v);
v->Visit("global_vars", &global_vars);
v->Visit("functions", &functions);
}

static constexpr const char* _type_key = "script.builder.ir.IRModuleFrame";
TVM_DECLARE_FINAL_OBJECT_INFO(IRModuleFrameNode, FrameNode);

public:
void ExitWithScope() final;
};

class IRModuleFrame : public Frame {
public:
IRModuleFrame();
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRModuleFrame, Frame, IRModuleFrameNode);
};

IRModuleFrame ir_module();

} // namespace ir
} // namespace builder
} // namespace script
} // namespace tvm

#endif // TVM_SCRIPT_BUILDER_IR_IR_H_
Loading

0 comments on commit f96e789

Please sign in to comment.