diff --git a/python/tvm/script/builder/tir/block_frame.py b/python/tvm/script/builder/tir/block_frame.py index 890b3c44eaa1..488d896147ea 100644 --- a/python/tvm/script/builder/tir/block_frame.py +++ b/python/tvm/script/builder/tir/block_frame.py @@ -15,14 +15,14 @@ # specific language governing permissions and limitations # under the License. """TVM Script TIR Block Frame""" +from typing import Any, Dict, List, Union + from tvm._ffi import register_object as _register_object +from tvm.tir import Buffer, BufferLoad, BufferRegion from . import _ffi_api from .base import TIRFrame -from typing import List, Dict, Any, Union -from tvm.tir import Buffer, BufferLoad, BufferRegion - @_register_object("script.builder.tir.BlockFrame") class BlockFrame(TIRFrame): @@ -73,7 +73,6 @@ def alloc_buffer( offset_factor=0, buffer_type="default", axis_separators=None, - span=None, ) -> Buffer: return _ffi_api.AllocBuffer( shape, @@ -86,5 +85,4 @@ def alloc_buffer( offset_factor, buffer_type, axis_separators, - span, ) diff --git a/python/tvm/script/builder/tir/op.py b/python/tvm/script/builder/tir/op.py index d75e242ba71a..c0fce23bacd1 100644 --- a/python/tvm/script/builder/tir/op.py +++ b/python/tvm/script/builder/tir/op.py @@ -16,45 +16,65 @@ # under the License. """TVM Script TIR Op""" -from . import _ffi_api - - -from tvm.tir.op import abs, popcount, nextafter, copysign, fmod +from tvm.tir.expr import Broadcast, Ramp, Select, Shuffle +from tvm.tir.generic import cast from tvm.tir.op import ( + abs, + acos, + acosh, + asin, + asinh, + atan, + atan2, + atanh, + call_extern, + call_packed, + ceil, + clz, + comm_reducer, + copysign, + cos, + cosh, + erf, + exp, + exp2, + exp10, floor, floordiv, floormod, - ceil, - round, - trunc, - truncdiv, - truncmod, - nearbyint, -) -from tvm.tir.op import ( + fmod, hypot, + if_then_else, + infinity, + isfinite, + isinf, + isnan, ldexp, - power, - exp, - exp2, - exp10, - erf, - sqrt, - rsqrt, log, + log1p, log2, log10, - log1p, + max_value, + min_value, + nearbyint, + nextafter, + popcount, + power, + reinterpret, + round, + rsqrt, sigmoid, + sin, + sinh, + sqrt, + tan, + tanh, + trunc, + truncdiv, + truncmod, ) -from tvm.tir.op import isnan, isfinite, isinf -from tvm.tir.op import cos, cosh, sin, sinh, tan, tanh -from tvm.tir.op import acos, acosh, asin, asinh, atan, atanh -from tvm.tir.op import atan2, clz, comm_reducer, infinity, reinterpret -from tvm.tir.op import min_value, max_value, if_then_else -from tvm.tir.op import call_packed, call_extern -from tvm.tir.expr import Select, Ramp, Broadcast, Shuffle -from tvm.tir.generic import cast + +from . import _ffi_api def boolean(expr): @@ -113,7 +133,7 @@ def handle(): return _ffi_api.Handle() -def min(a, b, span=None): +def min(a, b): """Compute the minimum value of two expressions. Parameters @@ -124,9 +144,6 @@ def min(a, b, span=None): b : PrimExpr The right hand operand - span : Optional[Span] - The location of this operator in the source. - Returns ------- res : PrimExpr @@ -136,10 +153,10 @@ def min(a, b, span=None): ---- This is the default integer division behavior in C. """ - return _ffi_api.min(a, b, span) # type: ignore + return _ffi_api.min(a, b) # type: ignore -def max(a, b, span=None): +def max(a, b): """Compute the maximum value of two expressions. Parameters @@ -150,9 +167,6 @@ def max(a, b, span=None): b : PrimExpr The right hand operand - span : Optional[Span] - The location of this operator in the source. - Returns ------- res : PrimExpr @@ -162,4 +176,4 @@ def max(a, b, span=None): ---- This is the default integer division behavior in C. """ - return _ffi_api.max(a, b, span) # type: ignore + return _ffi_api.max(a, b) # type: ignore diff --git a/python/tvm/script/builder/tir/prim_func_frame.py b/python/tvm/script/builder/tir/prim_func_frame.py index 386dc974d2cd..53ac43eb1ab3 100644 --- a/python/tvm/script/builder/tir/prim_func_frame.py +++ b/python/tvm/script/builder/tir/prim_func_frame.py @@ -70,7 +70,6 @@ def match_buffer( offset_factor=0, buffer_type="default", axis_separators=None, - span=None, ) -> Buffer: return _ffi_api.MatchBuffer( # pylint: disable=no-member # type: ignore param, @@ -84,7 +83,6 @@ def match_buffer( offset_factor, buffer_type, axis_separators, - span, ) @@ -100,7 +98,6 @@ def preflattened_buffer( offset_factor=0, buffer_type="default", axis_separators=None, - span=None, ) -> None: _ffi_api.PreflattenedBuffer( # pylint: disable=no-member # type: ignore postflattened, @@ -114,5 +111,4 @@ def preflattened_buffer( offset_factor, buffer_type, axis_separators, - span, ) diff --git a/src/script/builder/frame.cc b/src/script/builder/frame.cc index ab2bf7774e27..fc1991fa3fb8 100644 --- a/src/script/builder/frame.cc +++ b/src/script/builder/frame.cc @@ -34,27 +34,7 @@ void FrameNode::ExitWithScope() { Builder::Current()->frames.pop_back(); } -IRModuleFrame::IRModuleFrame() { - ObjectPtr n = make_object(); - n->global_vars.clear(); - n->functions.clear(); - data_ = std::move(n); -} - -void IRModuleFrameNode::ExitWithScope() { - ICHECK_EQ(functions.size(), global_vars.size()); - int n = functions.size(); - Map func_map; - for (int i = 0; i < n; ++i) { - func_map.Set(global_vars[i], functions[i]); - } - Builder builder = Builder::Current(); - ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set"; - builder->result = tvm::IRModule(func_map); -} - TVM_REGISTER_NODE_TYPE(FrameNode); -TVM_REGISTER_NODE_TYPE(IRModuleFrameNode); TVM_REGISTER_GLOBAL("script.builder.FrameEnter").set_body_method(&FrameNode::EnterWithScope); TVM_REGISTER_GLOBAL("script.builder.FrameExit").set_body_method(&FrameNode::ExitWithScope); diff --git a/src/script/builder/frame.h b/src/script/builder/frame.h index e3465a9e30b8..b8ee1c487642 100644 --- a/src/script/builder/frame.h +++ b/src/script/builder/frame.h @@ -56,30 +56,6 @@ class Frame : public runtime::ObjectRef { inline void ExitWithScope(); }; -class IRModuleFrameNode : public FrameNode { - public: - Array global_vars; - Array functions; - - void VisitAttrs(tvm::AttrVisitor* v) { - FrameNode::VisitAttrs(v); - v->Visit("global_vars", &global_vars); - v->Visit("functions", &functions); - } - - static constexpr const char* _type_key = "script.builder.IRModuleFrame"; - TVM_DECLARE_FINAL_OBJECT_INFO(IRModuleFrameNode, FrameNode); - - public: - void ExitWithScope() final; -}; - -class IRModuleFrame : public Frame { - public: - IRModuleFrame(); - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRModuleFrame, Frame, IRModuleFrameNode); -}; - inline void Frame::EnterWithScope() { ICHECK(data_ != nullptr); static_cast(data_.get())->EnterWithScope(); diff --git a/src/script/builder/ir/ir.cc b/src/script/builder/ir/ir.cc new file mode 100644 index 000000000000..38d8eb1098bc --- /dev/null +++ b/src/script/builder/ir/ir.cc @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "./ir.h" + +#include "../builder.h" + +namespace tvm { +namespace script { +namespace builder { +namespace ir { + +IRModuleFrame::IRModuleFrame() { + ObjectPtr n = make_object(); + n->global_vars.clear(); + n->functions.clear(); + data_ = std::move(n); +} + +void IRModuleFrameNode::ExitWithScope() { + ICHECK_EQ(functions.size(), global_vars.size()); + int n = functions.size(); + Map func_map; + for (int i = 0; i < n; ++i) { + func_map.Set(global_vars[i], functions[i]); + } + Builder builder = Builder::Current(); + ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set"; + builder->result = tvm::IRModule(func_map); +} + +TVM_REGISTER_NODE_TYPE(IRModuleFrameNode); + +} // namespace ir +} // namespace builder +} // namespace script +} // namespace tvm diff --git a/src/script/builder/ir/ir.h b/src/script/builder/ir/ir.h new file mode 100644 index 000000000000..890a06e3c76a --- /dev/null +++ b/src/script/builder/ir/ir.h @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_BUILDER_IR_IR_H_ +#define TVM_SCRIPT_BUILDER_IR_IR_H_ + +#include "../frame.h" + +namespace tvm { +namespace script { +namespace builder { +namespace ir { + +class IRModuleFrameNode : public FrameNode { + public: + Array global_vars; + Array functions; + + void VisitAttrs(tvm::AttrVisitor* v) { + FrameNode::VisitAttrs(v); + v->Visit("global_vars", &global_vars); + v->Visit("functions", &functions); + } + + static constexpr const char* _type_key = "script.builder.ir.IRModuleFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(IRModuleFrameNode, FrameNode); + + public: + void ExitWithScope() final; +}; + +class IRModuleFrame : public Frame { + public: + IRModuleFrame(); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRModuleFrame, Frame, IRModuleFrameNode); +}; + +IRModuleFrame ir_module(); + +} // namespace ir +} // namespace builder +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_BUILDER_IR_IR_H_ diff --git a/src/script/builder/tir/block_frame.cc b/src/script/builder/tir/block_frame.cc index 6566809d6eba..9fa0cce6fa0c 100644 --- a/src/script/builder/tir/block_frame.cc +++ b/src/script/builder/tir/block_frame.cc @@ -30,7 +30,7 @@ namespace script { namespace builder { namespace tir { -BlockFrame Block_(String name, bool no_realize) { +BlockFrame Block(String name, bool no_realize) { ObjectPtr n = make_object(); n->name = name; n->iter_vars.clear(); @@ -49,8 +49,8 @@ BlockFrame Block_(String name, bool no_realize) { void BlockFrameNode::ExitWithScope() { using namespace tvm::tir; TIRFrameNode::ExitWithScope(); - Block block = Block(iter_vars, reads, writes, name, AsStmt(stmts), init, alloc_buffers, - match_buffers, annotations); + tvm::tir::Block block(iter_vars, reads, writes, name, AsStmt(stmts), init, alloc_buffers, + match_buffers, annotations); if (no_realize) { CHECK(iter_values.empty()) << "ValueError: Block bindings are not allowed when `no_realize=True`"; @@ -145,10 +145,10 @@ void BlockAttrs(Map attrs) { tvm::tir::Buffer AllocBuffer(Array shape, DataType dtype, Optional data, Array strides, PrimExpr elem_offset, String storage_scope, int align, int offset_factor, String buffer_type_str, - Array axis_separators, Span span) { + Array axis_separators) { using namespace tvm::tir; - Buffer buffer = DeclBuffer(shape, dtype, "", data, strides, elem_offset, storage_scope, align, - offset_factor, buffer_type_str, axis_separators, span); + tvm::tir::Buffer buffer = DeclBuffer(shape, dtype, "", data, strides, elem_offset, storage_scope, + align, offset_factor, buffer_type_str, axis_separators); BlockFrame frame = FindBlockFrame("T.alloc_buffer"); frame->alloc_buffers.push_back(buffer); return buffer; @@ -236,7 +236,7 @@ Array Remap(String kinds, Array bindings, DataType TVM_REGISTER_NODE_TYPE(BlockFrameNode); TVM_REGISTER_NODE_TYPE(BlockInitFrameNode); -TVM_REGISTER_GLOBAL("script.builder.tir.BlockFrame").set_body_typed(Block_); +TVM_REGISTER_GLOBAL("script.builder.tir.BlockFrame").set_body_typed(Block); TVM_REGISTER_GLOBAL("script.builder.tir.BlockInitFrame").set_body_typed(Init); TVM_REGISTER_GLOBAL("script.builder.tir.Where").set_body_typed(Where); TVM_REGISTER_GLOBAL("script.builder.tir.Reads").set_body_typed(Reads); diff --git a/src/script/builder/tir/block_frame.h b/src/script/builder/tir/block_frame.h index 7137a9a2bfca..c69cc45f5328 100644 --- a/src/script/builder/tir/block_frame.h +++ b/src/script/builder/tir/block_frame.h @@ -68,7 +68,7 @@ class BlockFrame : public TIRFrame { TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockFrame, TIRFrame, BlockFrameNode); }; -BlockFrame Block_(String name, bool no_realize = false); +BlockFrame Block(String name, bool no_realize = false); class BlockInitFrameNode : public TIRFrameNode { public: @@ -97,8 +97,8 @@ tvm::tir::Buffer AllocBuffer(Array shape, DataType dtype = DataType::F Optional data = NullOpt, Array strides = {}, PrimExpr elem_offset = PrimExpr(), String storage_scope = "", int align = -1, int offset_factor = 0, - String buffer_type_str = "default", Array axis_separators = {}, - Span span = Span()); + String buffer_type_str = "default", + Array axis_separators = {}); namespace axis { tvm::tir::IterVar Spatial(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32)); diff --git a/src/script/builder/tir/op.cc b/src/script/builder/tir/op.cc index 6bd5a45f7bbf..d10bbcad8126 100644 --- a/src/script/builder/tir/op.cc +++ b/src/script/builder/tir/op.cc @@ -23,17 +23,13 @@ namespace script { namespace builder { namespace tir { -PrimExpr prim_type(String type_name, PrimExpr expr) { - return cast(DataType(runtime::String2DLDataType(type_name)), expr); -} - -TVM_REGISTER_GLOBAL("script.builder.tir.PrimType").set_body_typed(prim_type); -TVM_REGISTER_GLOBAL("script.builder.tir.Handle").set_body_typed(handle); -TVM_REGISTER_GLOBAL("script.builder.tir.min").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { - return tvm::min(a, b, span); +TVM_REGISTER_GLOBAL("script.builder.tir.PrimType").set_body_typed(PrimType); +TVM_REGISTER_GLOBAL("script.builder.tir.Handle").set_body_typed(Handle); +TVM_REGISTER_GLOBAL("script.builder.tir.min").set_body_typed([](PrimExpr a, PrimExpr b) { + return tvm::min(a, b); }); -TVM_REGISTER_GLOBAL("script.builder.tir.max").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { - return tvm::max(a, b, span); +TVM_REGISTER_GLOBAL("script.builder.tir.max").set_body_typed([](PrimExpr a, PrimExpr b) { + return tvm::max(a, b); }); } // namespace tir diff --git a/src/script/builder/tir/op.h b/src/script/builder/tir/op.h index 9eb8629f73ab..b3d41d79c6b5 100644 --- a/src/script/builder/tir/op.h +++ b/src/script/builder/tir/op.h @@ -29,26 +29,21 @@ namespace script { namespace builder { namespace tir { -PrimExpr int8(PrimExpr expr) { return cast(DataType::Int(8), expr); } -PrimExpr int16(PrimExpr expr) { return cast(DataType::Int(16), expr); } -PrimExpr int32(PrimExpr expr) { return cast(DataType::Int(32), expr); } -PrimExpr int64(PrimExpr expr) { return cast(DataType::Int(64), expr); } - -PrimExpr uint8(PrimExpr expr) { return cast(DataType::UInt(8), expr); } -PrimExpr uint16(PrimExpr expr) { return cast(DataType::UInt(16), expr); } -PrimExpr uint32(PrimExpr expr) { return cast(DataType::UInt(32), expr); } -PrimExpr uint64(PrimExpr expr) { return cast(DataType::UInt(64), expr); } - -PrimExpr float8(PrimExpr expr) { return cast(DataType::Float(8), expr); } -PrimExpr float16(PrimExpr expr) { return cast(DataType::Float(16), expr); } -PrimExpr float32(PrimExpr expr) { return cast(DataType::Float(32), expr); } -PrimExpr float64(PrimExpr expr) { return cast(DataType::Float(64), expr); } - -PrimExpr bool_(PrimExpr expr) { return cast(DataType::Bool(), expr); } - -PrimExpr prim_type(String type_name, PrimExpr expr); - -tvm::tir::Var handle() { return tvm::tir::Var("", DataType::Handle()); } +inline PrimExpr Int8(PrimExpr expr) { return tvm::cast(DataType::Int(8), expr); } +inline PrimExpr Int16(PrimExpr expr) { return tvm::cast(DataType::Int(16), expr); } +inline PrimExpr Int32(PrimExpr expr) { return tvm::cast(DataType::Int(32), expr); } +inline PrimExpr Int64(PrimExpr expr) { return tvm::cast(DataType::Int(64), expr); } +inline PrimExpr Uint8(PrimExpr expr) { return tvm::cast(DataType::UInt(8), expr); } +inline PrimExpr Uint16(PrimExpr expr) { return tvm::cast(DataType::UInt(16), expr); } +inline PrimExpr Uint32(PrimExpr expr) { return tvm::cast(DataType::UInt(32), expr); } +inline PrimExpr Uint64(PrimExpr expr) { return tvm::cast(DataType::UInt(64), expr); } +inline PrimExpr Float8(PrimExpr expr) { return tvm::cast(DataType::Float(8), expr); } +inline PrimExpr Float16(PrimExpr expr) { return tvm::cast(DataType::Float(16), expr); } +inline PrimExpr Float32(PrimExpr expr) { return tvm::cast(DataType::Float(32), expr); } +inline PrimExpr Float64(PrimExpr expr) { return tvm::cast(DataType::Float(64), expr); } +inline PrimExpr Bool(PrimExpr expr) { return tvm::cast(DataType::Bool(), expr); } +inline tvm::tir::Var Handle() { return tvm::tir::Var("", DataType::Handle()); } +inline PrimExpr PrimType(DataType dtype, PrimExpr expr) { return tvm::cast(dtype, expr); } using tvm::cast; using tvm::if_then_else; diff --git a/src/script/builder/tir/prim_func_frame.cc b/src/script/builder/tir/prim_func_frame.cc index beebf23f6a98..595b037765c2 100644 --- a/src/script/builder/tir/prim_func_frame.cc +++ b/src/script/builder/tir/prim_func_frame.cc @@ -21,6 +21,7 @@ #include +#include "../ir/ir.h" #include "./block_frame.h" #include "./var.h" @@ -29,6 +30,22 @@ namespace script { namespace builder { namespace tir { +PrimFuncFrame FindPrimFuncFrame(const String& method) { + Builder builder = Builder::Current(); + if (Optional prim_func_frame = builder->FindFrame()) { + if (Optional block_frame = builder->GetLastFrame()) { + if (prim_func_frame.value()->root_block_frame.get() == block_frame.get()) { + return prim_func_frame.value(); + } + } + } else { + LOG(FATAL) << "ValueError: PrimFunc frame not find. Please ensure '" << method + << "' is called under T.prim_func()"; + } + LOG(FATAL) << "ValueError: '" << method << "' must be called immediately under T.prim_func()"; + throw; +} + void PrimFuncFrameNode::EnterWithScope() { TIRFrameNode::EnterWithScope(); // add implicit root block @@ -37,6 +54,7 @@ void PrimFuncFrameNode::EnterWithScope() { void PrimFuncFrameNode::ExitWithScope() { using namespace tvm::tir; + using ir::IRModuleFrame; root_block_frame->ExitWithScope(); TIRFrameNode::ExitWithScope(); Builder builder = Builder::Current(); @@ -44,18 +62,18 @@ void PrimFuncFrameNode::ExitWithScope() { LOG(FATAL) << "ValueError: PrimFuncFrame shoulde have one and only one root block."; } BlockRealize root_block_realize = Downcast(stmts[0]); - Block root_block = root_block_realize->block; + tvm::tir::Block root_block = root_block_realize->block; // remove redundant implicit root block if (root_block->alloc_buffers.empty() && root_block->body->IsInstance() && root_block->annotations.empty() && root_block->reads.empty() && root_block->writes.empty()) { stmts.Set(0, root_block->body); } - PrimFunc func(/*params=*/args, - /*body=*/AsStmt(stmts), - /*ret_type=*/ret_type.value_or(TupleType::Empty()), - /*buffer_map=*/buffer_map, - /*preflattened_buffer_map=*/preflattened_buffer_map, - /*attrs=*/DictAttrs(attrs)); + tvm::tir::PrimFunc func(/*params=*/args, + /*body=*/AsStmt(stmts), + /*ret_type=*/ret_type.value_or(TupleType::Empty()), + /*buffer_map=*/buffer_map, + /*preflattened_buffer_map=*/preflattened_buffer_map, + /*attrs=*/DictAttrs(attrs)); if (builder->frames.empty()) { ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set"; builder->result = func; @@ -68,7 +86,7 @@ void PrimFuncFrameNode::ExitWithScope() { } } -PrimFuncFrame PrimFunc_() { +PrimFuncFrame PrimFunc() { ObjectPtr n = make_object(); n->name = NullOpt; n->args.clear(); @@ -76,26 +94,10 @@ PrimFuncFrame PrimFunc_() { n->buffer_map.clear(); n->preflattened_buffer_map.clear(); n->attrs.clear(); - n->root_block_frame = Block_("root"); + n->root_block_frame = Block("root"); return PrimFuncFrame(n); } -PrimFuncFrame FindPrimFuncFrame(const String& method) { - Builder builder = Builder::Current(); - if (Optional prim_func_frame = builder->FindFrame()) { - if (Optional block_frame = builder->GetLastFrame()) { - if (prim_func_frame.value()->root_block_frame.get() == block_frame.get()) { - return prim_func_frame.value(); - } - } - } else { - LOG(FATAL) << "ValueError: PrimFunc frame not find. Please ensure '" << method - << "' is called under T.prim_func()"; - } - LOG(FATAL) << "ValueError: '" << method << "' must be called immediately under T.prim_func()"; - throw; -} - tvm::tir::Var Arg(String name, tvm::tir::Var var) { PrimFuncFrame frame = FindPrimFuncFrame("T.Arg"); Namer::Name(var, name); @@ -144,10 +146,10 @@ tvm::tir::Buffer MatchBuffer(ObjectRef param, Array shape, DataType dt Optional data, Array strides, PrimExpr elem_offset, String storage_scope, int align, int offset_factor, String buffer_type_str, - Array axis_separators, Span span) { + Array axis_separators) { using namespace tvm::tir; - Buffer buffer = DeclBuffer(shape, dtype, "", data, strides, elem_offset, storage_scope, align, - offset_factor, buffer_type_str, axis_separators, span); + tvm::tir::Buffer buffer = DeclBuffer(shape, dtype, "", data, strides, elem_offset, storage_scope, + align, offset_factor, buffer_type_str, axis_separators); if (const auto* var = param.as()) { PrimFuncFrame frame = FindPrimFuncFrame("T.match_buffer"); Var v = GetRef(var); @@ -170,15 +172,15 @@ tvm::tir::Buffer MatchBuffer(ObjectRef param, Array shape, DataType dt void PreflattenedBuffer(tvm::tir::Buffer postflattened_buffer, Array shape, DataType dtype, Optional data, Array strides, PrimExpr elem_offset, String storage_scope, int align, int offset_factor, - String buffer_type_str, Array axis_separators, Span span) { + String buffer_type_str, Array axis_separators) { using namespace tvm::tir; PrimFuncFrame frame = FindPrimFuncFrame("T.preflattened_buffer"); for (auto const& p : frame->buffer_map) { if (p.second.same_as(postflattened_buffer)) { String buffer_name(postflattened_buffer->name + "_preflatten"); - Buffer buffer = + tvm::tir::Buffer buffer = DeclBuffer(shape, dtype, buffer_name, data, strides, elem_offset, storage_scope, align, - offset_factor, buffer_type_str, axis_separators, span); + offset_factor, buffer_type_str, axis_separators); Namer::Name(buffer, buffer_name); frame->preflattened_buffer_map.Set(p.first, buffer); return; @@ -189,15 +191,15 @@ void PreflattenedBuffer(tvm::tir::Buffer postflattened_buffer, Array s }; TVM_REGISTER_NODE_TYPE(PrimFuncFrameNode); -TVM_REGISTER_GLOBAL("script.builder.tir.PrimFuncFrame").set_body_typed(PrimFunc_); +TVM_REGISTER_GLOBAL("script.builder.tir.PrimFuncFrame").set_body_typed(PrimFunc); TVM_REGISTER_GLOBAL("script.builder.tir.Arg") .set_body_typed([](String name, ObjectRef obj) -> ObjectRef { using namespace tvm::tir; if (const auto* var = obj.as()) { - return Arg(name, GetRef(var)); + return Arg(name, GetRef(var)); } if (const auto* buffer = obj.as()) { - return Arg(name, GetRef(buffer)); + return Arg(name, GetRef(buffer)); } LOG(FATAL) << "ValueError: Unexpected type for TIR Arg."; throw; diff --git a/src/script/builder/tir/prim_func_frame.h b/src/script/builder/tir/prim_func_frame.h index a696b4ef84d1..de5b84cd152e 100644 --- a/src/script/builder/tir/prim_func_frame.h +++ b/src/script/builder/tir/prim_func_frame.h @@ -61,8 +61,7 @@ class PrimFuncFrame : public TIRFrame { TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PrimFuncFrame, TIRFrame, PrimFuncFrameNode); }; -PrimFuncFrame PrimFunc_(); -PrimFuncFrame FindPrimFuncFrame(const String& method); +PrimFuncFrame PrimFunc(); tvm::tir::Var Arg(String name, tvm::tir::Var var); tvm::tir::Buffer Arg(String name, tvm::tir::Buffer buffer); void FuncName(String name); @@ -74,15 +73,15 @@ tvm::tir::Buffer MatchBuffer(ObjectRef param, Array shape, Optional data = NullOpt, Array strides = {}, PrimExpr elem_offset = PrimExpr(), String storage_scope = "", int align = -1, int offset_factor = 0, - String buffer_type_str = "default", Array axis_separators = {}, - Span span = Span()); + String buffer_type_str = "default", + Array axis_separators = {}); void PreflattenedBuffer(tvm::tir::Buffer postflattened_buffer, Array shape, DataType dtype = DataType::Float(32), Optional data = NullOpt, Array strides = {}, PrimExpr elem_offset = PrimExpr(), String storage_scope = "", int align = -1, int offset_factor = 0, String buffer_type_str = "default", - Array axis_separators = {}, Span span = Span()); + Array axis_separators = {}); } // namespace tir } // namespace builder diff --git a/src/script/builder/tir/stmt.cc b/src/script/builder/tir/stmt.cc index 60e788109e46..dd9ac3faa948 100644 --- a/src/script/builder/tir/stmt.cc +++ b/src/script/builder/tir/stmt.cc @@ -154,8 +154,8 @@ AllocateFrame Allocate(Array extents, DataType dtype, String storage_s n->condition = tvm::cast(DataType::Bool(), n->condition); } n->annotations = annotations.value_or(Map()); - n->buffer = DeclBuffer(extents, dtype, "", NullOpt, {}, PrimExpr(), storage_scope, 0, 0, - "default", {}, Span()); + n->buffer = + DeclBuffer(extents, dtype, "", NullOpt, {}, PrimExpr(), storage_scope, 0, 0, "default", {}); return AllocateFrame(n); } @@ -165,8 +165,7 @@ AllocateConstFrame AllocateConst(tvm::runtime::NDArray data, DataType dtype, n->dtype = dtype; n->extents = extents; n->data = data; - n->buffer = - DeclBuffer(extents, dtype, "", NullOpt, {}, PrimExpr(), "", 0, 0, "default", {}, Span()); + n->buffer = DeclBuffer(extents, dtype, "", NullOpt, {}, PrimExpr(), "", 0, 0, "default", {}); return AllocateConstFrame(n); } diff --git a/src/script/builder/tir/var.cc b/src/script/builder/tir/var.cc index e91bc20d37fe..1bee419615e5 100644 --- a/src/script/builder/tir/var.cc +++ b/src/script/builder/tir/var.cc @@ -23,7 +23,7 @@ namespace script { namespace builder { namespace tir { -tvm::tir::Buffer Buffer_(Array shape, DataType dtype, String name, String storage_scope) { +tvm::tir::Buffer Buffer(Array shape, DataType dtype, String name, String storage_scope) { return tvm::tir::decl_buffer(shape, dtype, name, storage_scope); } @@ -31,7 +31,7 @@ tvm::tir::Buffer DeclBuffer(Array shape, DataType dtype, String buffer Optional data, Array strides, PrimExpr elem_offset, String storage_scope, int align, int offset_factor, String buffer_type_str, - Array axis_separators, Span span) { + Array axis_separators) { using namespace tvm::tir; Var buffer_data; if (!data.defined()) { @@ -39,13 +39,13 @@ tvm::tir::Buffer DeclBuffer(Array shape, DataType dtype, String buffer if (storage_dtype == DataType::Bool()) { storage_dtype = DataType::Int(8); } - buffer_data = Var(buffer_name, PointerType(PrimType(storage_dtype), storage_scope), span); + buffer_data = Var(buffer_name, PointerType(PrimType(storage_dtype), storage_scope)); } else { buffer_data = data.value(); } BufferType buffer_type = (buffer_type_str == "auto_broadcast") ? kAutoBroadcast : kDefault; - return Buffer(buffer_data, dtype, shape, strides, elem_offset, buffer_name, align, offset_factor, - buffer_type, axis_separators, span); + return tvm::tir::Buffer(buffer_data, dtype, shape, strides, elem_offset, buffer_name, align, + offset_factor, buffer_type, axis_separators); } TVM_STATIC_IR_FUNCTOR(Namer, vtable) @@ -84,7 +84,7 @@ TVM_STATIC_IR_FUNCTOR(Namer, vtable) Namer::Name(var->var, name); }); -TVM_REGISTER_GLOBAL("script.builder.tir.Buffer").set_body_typed(Buffer_); +TVM_REGISTER_GLOBAL("script.builder.tir.Buffer").set_body_typed(Buffer); } // namespace tir } // namespace builder diff --git a/src/script/builder/tir/var.h b/src/script/builder/tir/var.h index 433018c0037d..d312d3992554 100644 --- a/src/script/builder/tir/var.h +++ b/src/script/builder/tir/var.h @@ -26,16 +26,16 @@ namespace script { namespace builder { namespace tir { -tvm::tir::Buffer Buffer_(Array shape, // - DataType dtype = DataType::Float(32), // - String name = "buffer", // - String storage_scope = ""); +tvm::tir::Buffer Buffer(Array shape, // + DataType dtype = DataType::Float(32), // + String name = "buffer", // + String storage_scope = ""); tvm::tir::Buffer DeclBuffer(Array shape, DataType dtype, String buffer_name, Optional data, Array strides, PrimExpr elem_offset, String storage_scope, int align, int offset_factor, String buffer_type_str, - Array axis_separators, Span span); + Array axis_separators); } // namespace tir } // namespace builder