diff --git a/docker/Dockerfile.ci_arm b/docker/Dockerfile.ci_arm index bd2b2d8fb145f..26f8ef91969a6 100644 --- a/docker/Dockerfile.ci_arm +++ b/docker/Dockerfile.ci_arm @@ -50,8 +50,8 @@ COPY install/ubuntu_install_sccache.sh /install/ubuntu_install_sccache.sh RUN bash /install/ubuntu_install_sccache.sh ENV PATH /opt/sccache:$PATH -COPY install/ubuntu_install_llvm.sh /install/ubuntu_install_llvm.sh -RUN bash /install/ubuntu_install_llvm.sh +COPY install/ubuntu_install_llvm_from_source.sh /install/ubuntu_install_llvm_from_source.sh +RUN bash /install/ubuntu_install_llvm_from_source.sh 15.0.7 8b5fcb24b4128cf04df1b0b9410ce8b1a729cb3c544e6da885d234280dedeac6 ENV TVM_VENV /venv/apache-tvm-py3.7 COPY python/bootstrap/lockfiles /install/python/bootstrap/lockfiles diff --git a/docker/install/ubuntu_install_llvm_from_source.sh b/docker/install/ubuntu_install_llvm_from_source.sh new file mode 100644 index 0000000000000..854e74a4d8245 --- /dev/null +++ b/docker/install/ubuntu_install_llvm_from_source.sh @@ -0,0 +1,100 @@ +#!/bin/bash + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# This script builds LLVM and clang from the llvm-project tarball +# using CMake. It is tested with LLVM from version 15. + +set -e + +LLVM_VERSION=$1 +LLVM_FILE_SHA=$2 + +echo ${LLVM_VERSION} + +tmpdir=$(mktemp -d) + +cleanup() +{ + rm -rf "$tmpdir" +} + +trap cleanup 0 + +pushd "$tmpdir" + +curl -sL \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-${LLVM_VERSION}/llvm-project-${LLVM_VERSION}.src.tar.xz \ + -o llvm-project-${LLVM_VERSION}.src.tar.xz +echo "$LLVM_FILE_SHA llvm-project-${LLVM_VERSION}.src.tar.xz" | sha256sum --check +tar xf llvm-project-${LLVM_VERSION}.src.tar.xz +pushd llvm-project-${LLVM_VERSION}.src + +pushd llvm +mkdir build +pushd build +cmake \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_MODULE_PATH="/llvm-project-${LLVM_VERSION}.src/cmake/Modules" \ + -DCMAKE_INSTALL_PREFIX=/usr \ + -DLLVM_TARGETS_TO_BUILD="AArch64;ARM;X86" \ + -DLLVM_INCLUDE_DOCS=OFF \ + -DLLVM_INCLUDE_EXAMPLES=OFF \ + -DLLVM_INCLUDE_TESTS=OFF \ + -DLLVM_INCLUDE_UTILS=OFF \ + -DLLVM_ENABLE_TERMINFO=OFF \ + -DLLVM_ENABLE_ASSERTIONS=ON \ + -DLLVM_ENABLE_RTTI=ON \ + -DLLVM_ENABLE_OCAMLDOC=OFF \ + -DLLVM_USE_INTEL_JITEVENTS=ON \ + -DLLVM_TEMPORARILY_ALLOW_OLD_TOOLCHAIN=ON \ + -DPYTHON_EXECUTABLE="$(cpython_path 3.7)/bin/python" \ + -GNinja \ + .. 
+ninja install +popd +popd + +# clang is only used to precompile Gandiva bitcode +if [ ${LLVM_VERSION_MAJOR} -lt 9 ]; then + clang_package_name=cfe +else + clang_package_name=clang +fi + +pushd ${clang_package_name} +mkdir build +pushd build +cmake \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_INSTALL_PREFIX=/usr \ + -DCMAKE_MODULE_PATH="/llvm-project-${LLVM_VERSION}.src/cmake/Modules" \ + -DCLANG_INCLUDE_TESTS=OFF \ + -DCLANG_INCLUDE_DOCS=OFF \ + -DLLVM_INCLUDE_TESTS=OFF \ + -DLLVM_INCLUDE_DOCS=OFF \ + -Wno-dev \ + -GNinja \ + .. + +ninja -w dupbuild=warn install # both clang and llvm builds generate llvm-config file +popd +popd + +# out of llvm-project-${LLVM_VERSION}.src +popd +popd diff --git a/docker/install/ubuntu_install_paddle.sh b/docker/install/ubuntu_install_paddle.sh index 386d0fa6e797e..6cbd6289a16b2 100755 --- a/docker/install/ubuntu_install_paddle.sh +++ b/docker/install/ubuntu_install_paddle.sh @@ -20,4 +20,4 @@ set -e set -u set -o pipefail -pip install paddlepaddle==2.4.1 +pip install paddlepaddle==2.4.2 diff --git a/include/tvm/node/reflection.h b/include/tvm/node/reflection.h index f547b5a707394..2cfb50e1a2839 100644 --- a/include/tvm/node/reflection.h +++ b/include/tvm/node/reflection.h @@ -244,7 +244,7 @@ class ReflectionVTable::Registry { * static constexpr const std::nullptr_t VisitAttrs = nullptr; * * static void SHashReduce(const runtime::StringObj* key, SHashReducer hash_reduce) { - * hash_reduce->SHashReduceHashedValue(runtime::String::HashBytes(key->data, key->size)); + * hash_reduce->SHashReduceHashedValue(runtime::String::StableHashBytes(key->data, key->size)); * } * * static bool SEqualReduce(const runtime::StringObj* lhs, diff --git a/include/tvm/node/structural_hash.h b/include/tvm/node/structural_hash.h index 8b8a403326c4c..4d822e68d3d03 100644 --- a/include/tvm/node/structural_hash.h +++ b/include/tvm/node/structural_hash.h @@ -36,52 +36,60 @@ namespace tvm { * \brief Hash definition of base value classes. 
 */
class BaseValueHash {
- public:
-  size_t operator()(const double& key) const { return std::hash<double>()(key); }
-
-  size_t operator()(const int64_t& key) const { return std::hash<int64_t>()(key); }
-
-  size_t operator()(const uint64_t& key) const { return std::hash<uint64_t>()(key); }
-
-  size_t operator()(const int& key) const { return std::hash<int>()(key); }
-
-  size_t operator()(const bool& key) const { return std::hash<bool>()(key); }
-
-  size_t operator()(const std::string& key) const { return std::hash<std::string>()(key); }
-
-  size_t operator()(const runtime::DataType& key) const {
-    return std::hash<int32_t>()(static_cast<int32_t>(key.code()) |
-                                (static_cast<int32_t>(key.bits()) << 8) |
-                                (static_cast<int32_t>(key.lanes()) << 16));
+ protected:
+  template <typename T, typename U>
+  uint64_t Reinterpret(T value) const {
+    union Union {
+      T a;
+      U b;
+    } u;
+    static_assert(sizeof(Union) == sizeof(T), "sizeof(Union) != sizeof(T)");
+    static_assert(sizeof(Union) == sizeof(U), "sizeof(Union) != sizeof(U)");
+    u.b = 0;
+    u.a = value;
+    return u.b;
   }

+ public:
+  uint64_t operator()(const float& key) const { return Reinterpret<float, uint32_t>(key); }
+  uint64_t operator()(const double& key) const { return Reinterpret<double, uint64_t>(key); }
+  uint64_t operator()(const int64_t& key) const { return Reinterpret<int64_t, uint64_t>(key); }
+  uint64_t operator()(const uint64_t& key) const { return key; }
+  uint64_t operator()(const int& key) const { return Reinterpret<int, uint32_t>(key); }
+  uint64_t operator()(const bool& key) const { return key; }
+  uint64_t operator()(const runtime::DataType& key) const {
+    return Reinterpret<DLDataType, uint32_t>(key);
+  }
   template <typename ENum, typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
-  bool operator()(const ENum& key) const {
-    return std::hash<size_t>()(static_cast<size_t>(key));
+  uint64_t operator()(const ENum& key) const {
+    return Reinterpret<int64_t, uint64_t>(static_cast<int64_t>(key));
+  }
+  uint64_t operator()(const std::string& key) const {
+    return runtime::String::StableHashBytes(key.data(), key.length());
   }
 };

 /*!
- * \brief Content-aware structural hasing.
+ * \brief Content-aware structural hashing.
  *
  * The structural hash value is recursively defined in the DAG of IRNodes.
  * There are two kinds of nodes:
  *
  * - Normal node: the hash value is defined by its content and type only.
  * - Graph node: each graph node will be assigned a unique index ordered by the
- *   first occurence during the visit. The hash value of a graph node is
+ *   first occurrence during the visit. The hash value of a graph node is
  *   combined from the hash values of its contents and the index.
  */
 class StructuralHash : public BaseValueHash {
  public:
-  // inheritate operator()
+  // inherit operator()
   using BaseValueHash::operator();
   /*!
    * \brief Compute structural hashing value for an object.
    * \param key The left operand.
    * \return The hash value.
    */
-  TVM_DLL size_t operator()(const ObjectRef& key) const;
+  TVM_DLL uint64_t operator()(const ObjectRef& key) const;
 };

 /*!
@@ -109,23 +117,23 @@ class SHashReducer {
    *
    * \param hashed_value The hashed value
    */
-  virtual void SHashReduceHashedValue(size_t hashed_value) = 0;
+  virtual void SHashReduceHashedValue(uint64_t hashed_value) = 0;
   /*!
    * \brief Append hash value of key to the current sequence of hashes.
    *
    * \param key The object to compute hash from.
-   * \param map_free_vars Whether to map free variables by their occurence number.
+   * \param map_free_vars Whether to map free variables by their occurrence number.
    */
   virtual void SHashReduce(const ObjectRef& key, bool map_free_vars) = 0;
   /*!
-   * \brief Apppend a hash value of free variable to the current sequence of hashes.
+   * \brief Append a hash value of free variable to the current sequence of hashes.
    *
    * \param var The var of interest.
- * \param map_free_vars Whether to map free variables by their occurence number. + * \param map_free_vars Whether to map free variables by their occurrence number. * * \note If map_free_vars is set to be true, * internally the handler can maintain a counter to encode free variables - * by their order of occurence. This helps to resolve variable + * by their order of occurrence. This helps to resolve variable * mapping of function parameters and let binding variables. * * If map_free_vars is set to be false, the address of the variable will be used. @@ -139,7 +147,7 @@ class SHashReducer { * * \return Whether there is already a pre-computed hash value. */ - virtual bool LookupHashedValue(const ObjectRef& key, size_t* hashed_value) = 0; + virtual bool LookupHashedValue(const ObjectRef& key, uint64_t* hashed_value) = 0; /*! * \brief Mark current comparison as graph node in hashing. * Graph node hash will depends on the graph structure. @@ -193,7 +201,7 @@ class SHashReducer { /*! \brief Internal class pointer. */ Handler* handler_; /*! - * \brief Whether or not to map free variables by their occurence + * \brief Whether or not to map free variables by their occurrence * If the flag is false, then free variables will be mapped * by their in-memory address. */ @@ -210,10 +218,10 @@ class SHashHandlerDefault : public SHashReducer::Handler { SHashHandlerDefault(); virtual ~SHashHandlerDefault(); - void SHashReduceHashedValue(size_t hashed_value) override; + void SHashReduceHashedValue(uint64_t hashed_value) override; void SHashReduce(const ObjectRef& key, bool map_free_vars) override; void SHashReduceFreeVar(const runtime::Object* var, bool map_free_vars) override; - bool LookupHashedValue(const ObjectRef& key, size_t* hashed_value) override; + bool LookupHashedValue(const ObjectRef& key, uint64_t* hashed_value) override; void MarkGraphNode() override; /*! @@ -222,7 +230,7 @@ class SHashHandlerDefault : public SHashReducer::Handler { * \param map_free_vars Whether or not to remap variables if possible. * \return The hash result. */ - virtual size_t Hash(const ObjectRef& object, bool map_free_vars); + virtual uint64_t Hash(const ObjectRef& object, bool map_free_vars); protected: /*! diff --git a/include/tvm/runtime/container/string.h b/include/tvm/runtime/container/string.h index 5ecd89e9f56d0..c6382506b355e 100644 --- a/include/tvm/runtime/container/string.h +++ b/include/tvm/runtime/container/string.h @@ -24,6 +24,7 @@ #ifndef TVM_RUNTIME_CONTAINER_STRING_H_ #define TVM_RUNTIME_CONTAINER_STRING_H_ +#include #include #include #include @@ -247,10 +248,70 @@ class String : public ObjectRef { * \param size The size of the bytes. * \return the hash value. 
   */
-  static size_t HashBytes(const char* data, size_t size) {
-    // This function falls back to string copy with c++11 compiler and is
-    // recommended to be compiled with c++14
-    return std::hash<std::string_view>()(std::string_view(data, size));
+  static uint64_t StableHashBytes(const char* data, size_t size) {
+    const constexpr uint64_t kMultiplier = 1099511628211ULL;
+    const constexpr uint64_t kMod = 2147483647ULL;
+    union Union {
+      uint8_t a[8];
+      uint64_t b;
+    } u;
+    static_assert(sizeof(Union) == sizeof(uint64_t), "sizeof(Union) != sizeof(uint64_t)");
+    const char* it = data;
+    const char* end = it + size;
+    uint64_t result = 0;
+    for (; it + 8 <= end; it += 8) {
+      if (DMLC_IO_NO_ENDIAN_SWAP) {
+        u.a[0] = it[0];
+        u.a[1] = it[1];
+        u.a[2] = it[2];
+        u.a[3] = it[3];
+        u.a[4] = it[4];
+        u.a[5] = it[5];
+        u.a[6] = it[6];
+        u.a[7] = it[7];
+      } else {
+        u.a[0] = it[7];
+        u.a[1] = it[6];
+        u.a[2] = it[5];
+        u.a[3] = it[4];
+        u.a[4] = it[3];
+        u.a[5] = it[2];
+        u.a[6] = it[1];
+        u.a[7] = it[0];
+      }
+      result = (result * kMultiplier + u.b) % kMod;
+    }
+    if (it < end) {
+      u.b = 0;
+      uint8_t* a = u.a;
+      if (it + 4 <= end) {
+        a[0] = it[0];
+        a[1] = it[1];
+        a[2] = it[2];
+        a[3] = it[3];
+        it += 4;
+        a += 4;
+      }
+      if (it + 2 <= end) {
+        a[0] = it[0];
+        a[1] = it[1];
+        it += 2;
+        a += 2;
+      }
+      if (it + 1 <= end) {
+        a[0] = it[0];
+        it += 1;
+        a += 1;
+      }
+      if (!DMLC_IO_NO_ENDIAN_SWAP) {
+        std::swap(u.a[0], u.a[7]);
+        std::swap(u.a[1], u.a[6]);
+        std::swap(u.a[2], u.a[5]);
+        std::swap(u.a[3], u.a[4]);
+      }
+      result = (result * kMultiplier + u.b) % kMod;
+    }
+    return result;
   }

   TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(String, ObjectRef, StringObj);
@@ -448,7 +509,7 @@ inline int String::memncmp(const char* lhs, const char* rhs, size_t lhs_count, s

 inline size_t ObjectHash::operator()(const ObjectRef& a) const {
   if (const auto* str = a.as<StringObj>()) {
-    return String::HashBytes(str->data, str->size);
+    return String::StableHashBytes(str->data, str->size);
   }
   return ObjectPtrHash()(a);
 }
@@ -476,7 +537,7 @@ namespace std {
 template <>
 struct hash<::tvm::runtime::String> {
   std::size_t operator()(const ::tvm::runtime::String& str) const {
-    return ::tvm::runtime::String::HashBytes(str.data(), str.size());
+    return ::tvm::runtime::String::StableHashBytes(str.data(), str.size());
   }
 };
 }  // namespace std
diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h
index d5cc1de5c675d..a0343b03955b4 100644
--- a/include/tvm/script/ir_builder/tir/ir.h
+++ b/include/tvm/script/ir_builder/tir/ir.h
@@ -283,11 +283,15 @@ AssertFrame Assert(PrimExpr condition, String message);

 /*!
  * \brief The let binding.
- * \param var The variable to bind.
  * \param value The value to be bound.
+ * \param type_annotation The type annotation of the let binding.
+ *        Usually it is used for fine-grained var typing,
+ *        particularly, PointerType.
+ * \param var The variable to be bound. If not specified, a new variable will be created.
  * \return The created LetFrame.
  */
-LetFrame Let(Var var, PrimExpr value);
+LetFrame LetStmt(PrimExpr value, Optional<Type> type_annotation = NullOpt,
+                 Optional<Var> var = NullOpt);

 /*!
  * \brief The realization.
@@ -386,6 +390,14 @@ DeclBufferFrame DeclBuffer(Array<PrimExpr> shape, DataType dtype, String buffer_
  */
 LaunchThreadFrame LaunchThread(Var var, PrimExpr extent);

+/*!
+ * \brief Launch a new thread.
+ * \param thread_tag The thread type tag.
+ * \param extent The extent of environment thread.
+ * \return The result LaunchThreadFrame.
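The StableHashBytes scheme above deliberately avoids std::hash so that structural hashes are reproducible across compilers, standard libraries, and architectures: the input is consumed as little-endian 8-byte blocks and folded with an FNV-style multiplier modulo 2^31 - 1. A rough Python mirror of that folding, shown only to illustrate the algorithm (stable_hash_bytes is a hypothetical helper, not part of this patch):

    def stable_hash_bytes(data: bytes) -> int:
        """Illustrative re-implementation of runtime::String::StableHashBytes."""
        k_multiplier, k_mod = 1099511628211, 2147483647
        result, i, n = 0, 0, len(data)
        while i + 8 <= n:
            block = int.from_bytes(data[i:i + 8], "little")  # full 8-byte block
            result = (result * k_multiplier + block) % k_mod
            i += 8
        if i < n:
            tail = int.from_bytes(data[i:], "little")  # remaining bytes, zero-padded high bits
            result = (result * k_multiplier + tail) % k_mod
        return result

    print(stable_hash_bytes(b"main"))  # same value on every platform and run, unlike Python's hash()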
+ */ +LaunchThreadFrame LaunchThread(String thread_tag, PrimExpr extent); + /*! * \brief Bind a var to thread env. * \param thread_tag The thread type tag. @@ -418,14 +430,27 @@ void Evaluate(PrimExpr value); * \brief Create a TIR var that represents a pointer * \param dtype The data type of the pointer. * \param storage_scope The storage scope of the pointer. + * \param is_size_var Whether the pointer is a size var. * \return The pointer. */ -Var Handle(runtime::DataType dtype = runtime::DataType::Void(), String storage_scope = "global"); - -#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType) \ - inline PrimExpr FuncName(Optional expr = NullOpt) { \ - DataType dtype = DType; \ - return expr.defined() ? tvm::cast(dtype, expr.value()) : tvm::tir::Var("", dtype); \ +inline Var Handle(runtime::DataType dtype = runtime::DataType::Void(), // + String storage_scope = "global", // + bool is_size_var = false) { + Type type_annotation{nullptr}; + if (dtype.is_void() && storage_scope == "global") { + type_annotation = PrimType(runtime::DataType::Handle()); + } else { + type_annotation = PointerType(PrimType(dtype), storage_scope); + } + return is_size_var ? tvm::tir::SizeVar("", type_annotation) : tvm::tir::Var("", type_annotation); +} + +#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType) \ + inline PrimExpr FuncName(Optional expr = NullOpt, bool is_size_var = false) { \ + DataType dtype = DType; \ + return expr.defined() \ + ? tvm::cast(dtype, expr.value()) \ + : (is_size_var ? tvm::tir::SizeVar("", dtype) : tvm::tir::Var("", dtype)); \ } #define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(DType, FDType) \ diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h index ec8e32526abb5..5bac25faa5fb5 100644 --- a/include/tvm/tir/analysis.h +++ b/include/tvm/tir/analysis.h @@ -94,7 +94,7 @@ TVM_DLL double EstimateTIRFlops(const IRModule& mod); /*! * \brief Find undefined vars in the statement. - * \param stmt The function to be checked. + * \param stmt The statement to be checked. * \param defs The vars that is defined. * \return Array of undefined vars. */ @@ -107,6 +107,14 @@ TVM_DLL Array UndefinedVars(const Stmt& stmt, const Array& defs); */ TVM_DLL Array UndefinedVars(const PrimExpr& expr); +/*! + * \brief Find undefined vars in the expression. + * \param expr The expression to be checked. + * \param defs The vars that is defined. + * \return Array of undefined vars. + */ +TVM_DLL Array UndefinedVars(const PrimExpr& expr, const Array& defs); + /*! * \brief Analyze the side effect * \param expr The expression to be checked. diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 7f2bdf6b4ebbc..22febfdfedecc 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -405,6 +405,34 @@ class ScheduleNode : public runtime::Object { virtual BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope, const Array consumer_blocks = {}) = 0; + /*! + * \brief Create a block that reads a buffer region into a read cache. It requires: + * 1) There is at most one block who writes the buffer in the scope. + * 2) The scope block have stage-pipeline property. + * Compared to cache read, the indices to access allocated cache buffer is customized by user. + * \param block_rv The consumer block of the target buffer. + * \param read_buffer_index The index of the buffer in block's read region. + * \param storage_scope The target storage scope. 
+ * \param index_map User defined indices to access allocated cache buffer, maps from block iter + * vars. + * \return The cache stage block. + */ + virtual BlockRV ReindexCacheRead(const BlockRV& block_rv, int read_buffer_index, + const String& storage_scope, const IndexMap& index_map) = 0; + /*! + * \brief Create a block that writes a buffer region into a write cache. It requires: + * 1) There is only one block who writes the target buffer. + * 2) The scope block have stage-pipeline property. + * Compared to cache write, the indices to access allocated cache buffer is customized by user. + * \param block_rv The producer of the buffer + * \param write_buffer_index The index of the buffer in block's write region + * \param storage_scope The target storage scope + * \param index_map User defined indices to access allocated cache buffer, maps from block iter + * vars. + * \return The cache stage block. + */ + virtual BlockRV ReindexCacheWrite(const BlockRV& block_rv, int write_buffer_index, + const String& storage_scope, const IndexMap& index_map) = 0; /*! * \brief Create 2 blocks that read&write a buffer region into a read/write cache. * It requires the the target block both read & write the target buffer. diff --git a/include/tvm/tir/var.h b/include/tvm/tir/var.h index 0dadd3dc712ef..52827f706a568 100644 --- a/include/tvm/tir/var.h +++ b/include/tvm/tir/var.h @@ -152,6 +152,13 @@ class SizeVar : public Var { */ TVM_DLL explicit SizeVar(String name_hint = "s", DataType t = DataType::Int(32), Span span = Span()); + /*! + * \brief Constructor which provides a more detailed type annotation. + * \param name_hint variable name. + * \param type_annotation The type annotation. + * \param span The location of this object in the source code. + */ + TVM_DLL explicit SizeVar(String name_hint, Type type_annotation, Span span = Span()); /*! * \brief Get pointer to the internal value. * \return the corresponding Variable. diff --git a/python/tvm/driver/tvmc/autotuner.py b/python/tvm/driver/tvmc/autotuner.py index b7766efb47969..bed829ef6b290 100644 --- a/python/tvm/driver/tvmc/autotuner.py +++ b/python/tvm/driver/tvmc/autotuner.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=unused-argument """ Provides support to auto-tuning networks using AutoTVM. 
""" @@ -39,7 +40,7 @@ from .model import TVMCModel from .target import target_from_cli, generate_target_args, reconstruct_target_args from .shape_parser import parse_shape_string -from .transform import convert_graph_layout +from .transform import generate_transform_args, parse_graph_transform_args, apply_graph_transforms # pylint: disable=invalid-name @@ -127,12 +128,7 @@ def add_tune_parser(subparsers, _, json_params): metavar="PATH", help="path to an auto-tuning log file by AutoTVM.", ) - parser.add_argument( - "--desired-layout", - choices=["NCHW", "NHWC"], - default=None, - help="change the data layout of the whole graph", - ) + generate_transform_args(parser) parser.add_argument( "--enable-autoscheduler", help="enable tuning the graph through the AutoScheduler tuner", @@ -269,6 +265,8 @@ def drive_tune(args): rpc_hostname = None rpc_port = None + transform_args = parse_graph_transform_args(args) + tune_model( tvmc_model, args.target, @@ -283,7 +281,6 @@ def drive_tune(args): tuner=args.tuner, min_repeat_ms=args.min_repeat_ms, early_stopping=args.early_stopping, - desired_layout=args.desired_layout, timeout=args.timeout, repeat=args.repeat, number=args.number, @@ -292,6 +289,7 @@ def drive_tune(args): include_simple_tasks=args.include_simple_tasks, log_estimated_latency=args.log_estimated_latency, additional_target_options=reconstruct_target_args(args), + **transform_args, ) @@ -309,7 +307,6 @@ def tune_model( tuner: str = "xgb", min_repeat_ms: Optional[int] = None, early_stopping: Optional[int] = None, - desired_layout: Optional[str] = None, timeout: int = 10, repeat: int = 1, number: int = 10, @@ -318,6 +315,12 @@ def tune_model( include_simple_tasks: bool = False, log_estimated_latency: bool = False, additional_target_options: Optional[Dict[str, Dict[str, Any]]] = None, + desired_layout: Optional[str] = None, + desired_layout_ops: Optional[List[str]] = None, + mixed_precision: bool = False, + mixed_precision_ops: Optional[List[str]] = None, + mixed_precision_calculation_type: Optional[str] = None, + mixed_precision_acc_type: Optional[str] = None, ): """Use tuning to automatically optimize the functions in a model. @@ -354,10 +357,6 @@ def tune_model( Minimum time to run each trial. Defaults to 0 on x86 and 1000 on other targets. early_stopping : int, optional When specified, stop tuning after this number of trials if results aren't improving. - desired_layout : str, optional - Can be one of "NCHW" or "NHWC". When specified, compatible operations in the graph - will have their layout set to this format. Tasks will then be tuned using this - specified layout. timeout : int, optional, If a kernel trial lasts longer than this duration in seconds, it will be considered a failure. @@ -376,12 +375,28 @@ def tune_model( If using the autoscheduler, write the estimated latency at each step of tuning to file. additional_target_options: Optional[Dict[str, Dict[str, Any]]] Additional target options in a dictionary to combine with initial Target arguments + desired_layout: str, optional + Can be one of "NCHW" or "NHWC". When specified, compatible operations in the graph + will have their layout set to this format. Tasks will then be tuned using this + specified layout. + desired_layout_ops: list[str], optional + The list of operators to be transformed with desired layout. + mixed_precision: bool + To enable mixed precision transformation. + mixed_precision_ops: list[str], optional + The list of operators to be converted to mixed precision. 
+ mixed_precision_calculation_type: str + The calculation dtype to be used while mixed precision. + mixed_precision_acc_type: str + The accumulation data type to be used while mixed precision. + Returns ------- tuning_records : str The path to the produced tuning log file. """ + transform_args = parse_graph_transform_args(locals()) target, extra_targets = target_from_cli(target, additional_target_options) target, target_host = Target.canon_target_and_host(target, target_host) # TODO(jwfromm) Remove this deepcopy once AlterOpLayout bug that mutates source @@ -453,7 +468,7 @@ def tune_model( mod=mod, params=params, target=target, - alter_layout=desired_layout, + transform_args=transform_args, hardware_params=hardware_params, include_simple_tasks=include_simple_tasks, ) @@ -475,7 +490,7 @@ def tune_model( mod=mod, params=params, target=target, - alter_layout=desired_layout, + transform_args=transform_args, ) # In autotvm, trials is specified per task. We can convert the per-model input @@ -504,7 +519,7 @@ def autotvm_get_tuning_tasks( params: Dict[str, tvm.nd.NDArray], target: str, target_host: Optional[str] = None, - alter_layout: Optional[str] = None, + transform_args: Optional[Dict[str, Any]] = None, ): """Get the autotvm tuning tasks for a given relay module. @@ -518,10 +533,8 @@ def autotvm_get_tuning_tasks( The compilation target. target_host : str, optional The compilation target for the host. - alter_layout : str, optional - The layout to convert the graph to. Note, the convert layout - pass doesn't currently guarantee the whole of the graph will - be converted to the chosen layout. + transform_args: dict, optional + Graph transformation arguments that are applied to the relay module. Returns ------- @@ -530,8 +543,7 @@ def autotvm_get_tuning_tasks( """ target, target_host = Target.canon_target_and_host(target, target_host) - if alter_layout: - mod = convert_graph_layout(mod, alter_layout) + mod = apply_graph_transforms(mod, transform_args) tasks = autotvm.task.extract_from_program( mod["main"], @@ -547,7 +559,7 @@ def autoscheduler_get_tuning_tasks( params: Dict[str, tvm.nd.NDArray], target: str, target_host: Optional[str] = None, - alter_layout: Optional[str] = None, + transform_args: Optional[Dict[str, Any]] = None, hardware_params: Optional[HardwareParams] = None, include_simple_tasks: bool = False, ): @@ -563,10 +575,8 @@ def autoscheduler_get_tuning_tasks( The compilation target. target_host : str, optional The compilation target for the host. - alter_layout : str, optional - The layout to convert the graph to. Note, the convert layout - pass doesn't currently guarantee the whole of the graph will - be converted to the chosen layout. + transform_args: dict, optional + Graph transformation arguments that are applied to the relay module. hardware_params : Optional[HardwareParams] Hardware parameters used for the search tasks @@ -579,8 +589,7 @@ def autoscheduler_get_tuning_tasks( """ target, target_host = Target.canon_target_and_host(target, target_host) - if alter_layout: - mod = convert_graph_layout(mod, alter_layout) + mod = apply_graph_transforms(mod, transform_args) # Extract the tasks tasks, task_weights = auto_scheduler.extract_tasks( diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index eec80820cdb13..6e61e762ee212 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -14,6 +14,7 @@ # KIND, either express or implied. 
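With this change, tvmc's Python tuning API takes the layout and mixed-precision options as ordinary keyword arguments that are collected by parse_graph_transform_args(locals()) and threaded through to task extraction. A hedged sketch of a call site (the model file and target are placeholders, not part of this patch):

    from tvm.driver import tvmc
    from tvm.driver.tvmc.autotuner import tune_model

    tvmc_model = tvmc.load("resnet50.onnx")  # placeholder model
    log_file = tune_model(
        tvmc_model,
        target="llvm",
        desired_layout="NHWC",                         # ConvertLayout before task extraction
        desired_layout_ops=["nn.conv2d"],
        mixed_precision=True,                          # ToMixedPrecision pass
        mixed_precision_ops=["nn.conv2d", "nn.dense"],
        mixed_precision_calculation_type="float16",
        mixed_precision_acc_type="float32",
    )
    print("tuning records written to", log_file)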
See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=unused-argument """ Provides support to compile networks both AOT and JIT. """ @@ -37,7 +38,7 @@ from .target import target_from_cli, generate_target_args, reconstruct_target_args from .pass_config import parse_configs from .pass_list import parse_pass_list_str -from .transform import convert_graph_layout +from .transform import generate_transform_args, parse_graph_transform_args, apply_graph_transforms from .shape_parser import parse_shape_string from .workspace_pools import generate_workspace_pools_args, workspace_pools_recombobulate @@ -61,17 +62,12 @@ def add_compile_parser(subparsers, _, json_params): default="", help="the cross compiler options to generate target libraries, e.g. '-mfpu=neon-vfpv4'.", ) - parser.add_argument( - "--desired-layout", - choices=["NCHW", "NHWC"], - default=None, - help="change the data layout of the whole graph.", - ) + generate_transform_args(parser) parser.add_argument( "--dump-code", metavar="FORMAT", default="", - help="comma separated list of formats to export the input model, e.g. 'asm,ll,relay'.", + help="comma separated list of formats to export the input model, e.g. 'asm,ll,tir,relay'.", ) parser.add_argument( "--model-format", @@ -177,6 +173,7 @@ def drive_compile(args): additional_targets = reconstruct_target_args(args) workspace_pools_target, extra_targets = target_from_cli(args.target, additional_targets) + transform_args = parse_graph_transform_args(args) compile_model( tvmc_model, @@ -191,7 +188,6 @@ def drive_compile(args): output_format=args.output_format, dump_code=dump_code, target_host=None, - desired_layout=args.desired_layout, disabled_pass=args.disabled_pass, pass_context_configs=args.pass_config, mod_name=args.module_name, @@ -199,6 +195,7 @@ def drive_compile(args): workspace_pools=( workspace_pools_recombobulate(args, [workspace_pools_target], extra_targets) ), + **transform_args, ) return 0 @@ -217,7 +214,6 @@ def compile_model( output_format: str = "so", dump_code: Optional[List[str]] = None, target_host: Optional[str] = None, - desired_layout: Optional[str] = None, disabled_pass: Optional[str] = None, pass_context_configs: Optional[List[str]] = None, additional_target_options: Optional[Dict[str, Dict[str, Any]]] = None, @@ -225,6 +221,12 @@ def compile_model( mod_name: Optional[str] = "default", workspace_pools: Optional[WorkspaceMemoryPools] = None, instruments: Optional[Sequence[PassInstrument]] = None, + desired_layout: Optional[str] = None, + desired_layout_ops: Optional[List[str]] = None, + mixed_precision: bool = False, + mixed_precision_ops: Optional[List[str]] = None, + mixed_precision_calculation_type: Optional[str] = None, + mixed_precision_acc_type: Optional[str] = None, ): """Compile a model from a supported framework into a TVM module. @@ -254,16 +256,12 @@ def compile_model( output_format : str What format to use when saving the function library. Must be one of "so" or "tar". When compiling for a remote device without a cross compiler, "tar" will likely work better. - dump_code : list, optional + dump_code : list[str], optional Dump the generated code for the specified source types, on - the requested target. + the requested target. Choose from: ["asm", "ll", "tir", "relay"]. target_host : str, optional The target of the host machine if host-side code needs to be generated. - desired_layout: str, optional - The layout to convert the graph to. 
Note, the convert layout - pass doesn't currently guarantee the whole of the graph will - be converted to the chosen layout. disabled_pass: str, optional Comma-separated list of passes which needs to be disabled during compilation @@ -281,6 +279,21 @@ def compile_model( compilation. instruments: Optional[Sequence[PassInstrument]] The list of pass instrument implementations. + desired_layout: str, optional + Can be one of "NCHW" or "NHWC". When specified, compatible operations in the graph + will have their layout set to this format. Tasks will then be tuned using this + specified layout. + desired_layout_ops: list[str], optional + The list of operators to be transformed with desired layout. + mixed_precision: bool + To enable mixed precision transformation. Disabled by default. + mixed_precision_ops: list[str], optional + The list of operators to be converted to mixed precision. + Set to ["nn.conv2d", "nn.dense"] by default + mixed_precision_calculation_type: str + The calculation dtype to be used while mixed precision. Set to "float16" by default. + mixed_precision_acc_type: str + The accumulation data type to be used while mixed precision. Set to "float16" by default. Returns ------- @@ -290,7 +303,15 @@ def compile_model( """ mod, params = tvmc_model.mod, tvmc_model.params + if dump_code is None: + dump_code = [] + if not isinstance(dump_code, list): + dump_code = [dump_code] + dumps = {} + config = parse_configs(pass_context_configs) + if "tir" in dump_code: + config, dumps = add_tir_to_dumps(config, dumps) tvm_target, extra_targets = target_from_cli(target, additional_target_options) tvm_target, target_host = Target.canon_target_and_host(tvm_target, target_host) @@ -310,8 +331,8 @@ def compile_model( disabled_pass=disabled_pass, instruments=instruments, ): - if desired_layout: - mod = convert_graph_layout(mod, desired_layout) + transform_args = parse_graph_transform_args(locals()) + mod = apply_graph_transforms(mod, transform_args) for partition_function, opts in zip(partition_functions, partition_opts): mod = partition_function(mod, params, mod_name=mod_name, **opts) @@ -366,20 +387,16 @@ def compile_model( ) # Generate output dump files with sources - if dump_code is None: - dump_code = [] - if not isinstance(dump_code, list): - dump_code = [dump_code] - dumps = {} for source_type in dump_code: - if use_vm: - lib = graph_module.lib + if source_type == "relay": + dumps[source_type] = str(mod) + elif source_type == "tir": + dumps[source_type] = "\n".join(dumps[source_type]) else: - lib = graph_module.get_lib() - # TODO lib.get_source call have inconsistent behavior for unsupported - # formats (@leandron). - source = str(mod) if source_type == "relay" else lib.get_source(source_type) - dumps[source_type] = source + lib = graph_module.lib if use_vm else graph_module.get_lib() + # TODO lib.get_source call have inconsistent behavior for unsupported + # formats (@leandron). + dumps[source_type] = lib.get_source(source_type) # Create a new tvmc model package object from the graph definition. package_path = tvmc_model.export_package( @@ -440,6 +457,26 @@ def build( ) +def add_tir_to_dumps(config, dumps): + """ + Creates a debug pass that dumps TIR functions as a list of strings. 
+ """ + key = "tir" + phase = 3 # final TIR phase before codegen + dumps[key] = [] + + @tvm.tir.transform.prim_func_pass(opt_level=0) + def _dump_tir_pass(tir_func, _, __): + dumps[key].append(str(tir_func)) + return tir_func + + tir_lower_passes = config.get("tir.add_lower_pass", []) + tir_lower_passes.append((phase, _dump_tir_pass)) + config["tir.add_lower_pass"] = tir_lower_passes + + return config, dumps + + def save_dumps(module_name: str, dumps: Dict[str, str], dump_root: str = "."): """ Serialize dump files to the disk. diff --git a/python/tvm/driver/tvmc/transform.py b/python/tvm/driver/tvmc/transform.py index 8527c48b6b04d..2b34ba11b49fe 100644 --- a/python/tvm/driver/tvmc/transform.py +++ b/python/tvm/driver/tvmc/transform.py @@ -13,6 +13,7 @@ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language +# pylint: disable=unused-argument """ TVMC Graph Transforms """ @@ -21,7 +22,88 @@ from tvm.driver.tvmc import TVMCException -def convert_graph_layout(mod, desired_layout): +def generate_mixed_precision_rule(acc_dtype): + def _mixed_precision_rule(call_node: "relay.Call", mixed_precision_type: str): + return [ + relay.transform.mixed_precision.MIXED_PRECISION_ALWAYS, + acc_dtype, + mixed_precision_type, + ] + + return _mixed_precision_rule + + +class MixedPrecision(object): + """Temporarily changes attr of ops to enable required precision.""" + + def __init__(self, ops, acc_type): + """Saves the required info for RAII pattern usage. + + Parameters + ---------- + ops : list + list of operators + acc_type: str + Output or accumulation precision to be used. + """ + self.older_attr = {} + self.ops = ops + self.acc_type = acc_type + self.attr_key = "FTVMMixedPrecisionConversionType" + + def __enter__(self): + for op_name in self.ops: + op = relay.op.get(op_name) + self.older_attr[op_name] = op.get_attr(self.attr_key) + op.reset_attr(self.attr_key) + op.set_attr(self.attr_key, generate_mixed_precision_rule(self.acc_type)) + return self + + def __exit__(self, ptype, value, trace): + for op_name in self.ops: + op = relay.op.get(op_name) + op.reset_attr(self.attr_key) + if self.older_attr[op_name]: + op.set_attr(self.attr_key, self.older_attr[op_name]) + + +def convert_to_mixed_precision(mod, ops=None, calculation_type="float16", acc_type="float16"): + """Converts the operator datatypes + + Parameters + ---------- + mod : tvm.IRModule + The relay module to convert. + ops : list + List of operators to be precision converted. + calculation_type: str + Input precision to be used. + acc_type: str + Output or accumulation precision to be used. + + Returns + ------- + mod : tvm.IRModule + The converted module. + """ + + if ops is None: + ops = ["nn.conv2d", "nn.dense"] + + with MixedPrecision(ops, acc_type): + seq = transform.Sequential( + [relay.transform.InferType(), relay.transform.ToMixedPrecision(calculation_type)] + ) + with transform.PassContext( + config={"relay.ToMixedPrecision.keep_orig_output_dtype": True}, opt_level=3 + ): + try: + return seq(mod) + except Exception as err: + raise TVMCException("Error converting mixed precision : {0}".format(str(err))) + + +def convert_graph_layout(mod, desired_layout, ops=None): """Alter the layout of the input graph. Parameters @@ -30,20 +112,18 @@ def convert_graph_layout(mod, desired_layout): The relay module to convert. desired_layout : str The layout to convert to. + ops : list + List of operators to be layout converted. 
Returns ------- mod : tvm.IRModule The converted module. """ + if ops is None: + ops = ["nn.conv2d", "nn.conv2d_transpose", "qnn.conv2d"] - # Assume for the time being that graphs only have - # conv2d as heavily-sensitive operators. - desired_layouts = { - "nn.conv2d": [desired_layout, "default"], - "nn.conv2d_transpose": [desired_layout, "default"], - "qnn.conv2d": [desired_layout, "default"], - } + desired_layouts = {op: [desired_layout, "default"] for op in ops} # Convert the layout of the graph where possible. seq = transform.Sequential( @@ -58,3 +138,110 @@ def convert_graph_layout(mod, desired_layout): return seq(mod) except Exception as err: raise TVMCException("Error converting layout to {0}: {1}".format(desired_layout, str(err))) + + +def apply_graph_transforms(mod, args): + """Alter the layout of the input graph. + + Parameters + ---------- + mod : tvm.IRModule + The relay module to convert. + args : dict + The transform arguments. + + Returns + ------- + mod : tvm.IRModule + The converted module. + """ + if not args: + return mod + + # AlterLayout + if args.get("desired_layout", False): + mod = convert_graph_layout( + mod, args["desired_layout"], args.get("desired_layout_ops", None) + ) + + # ToMixedPrecision + if args.get("mixed_precision", False): + mod = convert_to_mixed_precision( + mod, + args.get("mixed_precision_ops"), + args.get("mixed_precision_calculation_type"), + args.get("mixed_precision_acc_type"), + ) + return mod + + +def parse_graph_transform_args(args): + """Parse incoming options for graph transform arguments. + + Parameters + ---------- + args: argparse.Namespace or dict + Arguments. + + Returns + ------- + transform_args : dict + Graph transform arguments + """ + + if not isinstance(args, dict): + args = vars(args) + + transform_args = [ + "desired_layout", + "desired_layout_ops", + "mixed_precision", + "mixed_precision_ops", + "mixed_precision_calculation_type", + "mixed_precision_acc_type", + ] + transform_args = {key: args.get(key, None) for key in transform_args} + return transform_args + + +def generate_transform_args(parser): + """Add graph transform related args""" + + # AlterLayout + parser.add_argument( + "--desired-layout", + choices=["NCHW", "NHWC"], + default=None, + help="Change the data layout of the whole graph.", + ) + parser.add_argument( + "--desired-layout-ops", + default=["nn.conv2d", "nn.conv2d_transpose", "qnn.conv2d"], + nargs="+", + help="List of operators to be layout converted.", + ) + + # ToMixedPrecision + parser.add_argument( + "--mixed-precision", + help="Enable mixed precision conversion", + action="store_true", + ) + parser.add_argument( + "--mixed-precision-ops", + default=["nn.conv2d", "nn.dense"], + nargs="+", + help="List of operators to be converted to mixed precision", + ) + parser.add_argument( + "--mixed-precision-calculation-type", + choices=["float16", "float32"], + default="float16", + help="Calculation precision type", + ) + parser.add_argument( + "--mixed-precision-acc-type", + choices=["float16", "float32"], + default="float16", + help="Accumulator precision type", + ) diff --git a/python/tvm/meta_schedule/testing/space_generation.py b/python/tvm/meta_schedule/testing/space_generation.py index 0b7072b65afe8..45cd6659b6e07 100644 --- a/python/tvm/meta_schedule/testing/space_generation.py +++ b/python/tvm/meta_schedule/testing/space_generation.py @@ -88,7 +88,7 @@ def _find_match_sketch_id( decisions=new_decisions, ).apply_to_schedule(sch, remove_postproc=True) if structural_equal(sch.mod, expected_mod): - 
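Taken together, parse_graph_transform_args and apply_graph_transforms give the CLI and the Python API a single entry point for graph-level transforms: the former pulls the six recognised keys out of an argparse namespace or plain dict, the latter applies ConvertLayout and then ToMixedPrecision. A hedged sketch against a Relay module (`mod` is assumed to be an existing tvm.IRModule):

    from tvm.driver.tvmc.transform import apply_graph_transforms, parse_graph_transform_args

    options = {
        "desired_layout": "NHWC",
        "desired_layout_ops": ["nn.conv2d", "qnn.conv2d"],
        "mixed_precision": True,
        "mixed_precision_ops": ["nn.conv2d", "nn.dense"],
        "mixed_precision_calculation_type": "float16",
        "mixed_precision_acc_type": "float32",
    }
    transform_args = parse_graph_transform_args(options)  # missing keys default to None
    mod = apply_graph_transforms(mod, transform_args)      # ConvertLayout first, then ToMixedPrecision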
verify_trace_roundtrip(sch=sch, mod=mod, debug_mask=debug_mask) + verify_trace_roundtrip(sch=sch, mod=mod, debug_mask=debug_mask, text_format="json") return sketch_id return None diff --git a/python/tvm/meta_schedule/tune_context.py b/python/tvm/meta_schedule/tune_context.py index 38a46ebe757ef..f15976b1cc476 100644 --- a/python/tvm/meta_schedule/tune_context.py +++ b/python/tvm/meta_schedule/tune_context.py @@ -44,7 +44,8 @@ def _normalize_mod(mod: Union[PrimFunc, IRModule]) -> IRModule: """Normalize the input to an IRModule""" if isinstance(mod, PrimFunc): - mod = mod.with_attr("global_symbol", "main") + if not (mod.attrs and "global_symbol" in mod.attrs): + mod = mod.with_attr("global_symbol", "main") mod = mod.with_attr("tir.noalias", True) mod = IRModule({"main": mod}) if not isinstance(mod, IRModule): diff --git a/python/tvm/relay/frontend/paddlepaddle.py b/python/tvm/relay/frontend/paddlepaddle.py index 78895e4b49e0b..4b849987ed814 100755 --- a/python/tvm/relay/frontend/paddlepaddle.py +++ b/python/tvm/relay/frontend/paddlepaddle.py @@ -1084,6 +1084,19 @@ def convert_meshgrid(g, op, block): g.add_node(op.output("Out")[i], out) +def convert_mish(g, op, block): + """Operator converter for mish.""" + + x = g.get_node(op.input("X")[0]) + dtype = infer_type(x).checked_type.dtype + exp = _op.exp(x) + add = _op.add(exp, _expr.const(1.0, dtype)) + log = _op.log(add) + tanh = _op.tanh(log) + out = _op.multiply(x, tanh) + g.add_node(op.output("Out")[0], out) + + def convert_mul(g, op, block): """Operator converter for mul.""" @@ -1785,6 +1798,14 @@ def convert_shape(g, op, block): g.add_node(op.output("Out")[0], out) +def convert_silu(g, op, block): + """Operator converter for silu.""" + + x = g.get_node(op.input("X")[0]) + out = _op.multiply(x, _op.sigmoid(x)) + g.add_node(op.output("Out")[0], out) + + def convert_size(g, op, block): """Operator converter for size.""" @@ -1950,6 +1971,19 @@ def convert_softsign(g, op, block): g.add_node(op.output("Out")[0], out) +def convert_softshrink(g, op, block): + """Operator converter for softshrink.""" + + x = g.get_node(op.input("X")[0]) + dtype = infer_type(x).checked_type.dtype + threshold = _expr.const(op.attr("lambda"), dtype=dtype) + zeros = _op.zeros_like(x) + out = _op.where(x < -threshold, x + threshold, zeros) + _op.where( + x > threshold, x - threshold, zeros + ) + g.add_node(op.output("Out")[0], out) + + def convert_split(g, op, block): """Operator converter for split.""" @@ -1994,6 +2028,18 @@ def convert_split(g, op, block): g.add_node(op.output("Out")[i], out_i) +def convert_stack(g, op, blcok): + """Operator converter for stack.""" + + x = op.input("X") + all_inputs = [] + for inp in x: + all_inputs.append(g.get_node(inp)) + axis = op.attr("axis") + out = _op.stack(all_inputs, axis) + g.add_node(op.output("Y")[0], out) + + def convert_square(g, op, block): """Operator converter for square.""" @@ -2025,6 +2071,37 @@ def convert_swish(g, op, block): g.add_node(op.output("Out")[0], out) +def convert_tile(g, op, block): + """Operator converter for tile.""" + + x = g.get_node(op.input("X")[0]) + if op.input("RepeatTimes"): + reps = g.get_node(op.input("RepeatTimes")[0]) + reps, infered = try_infer_value(reps, g.get_params()) + if infered: + reps = reps.tolist() + elif op.input("repeat_times_tensor"): + reps = [] + for rep_value in op.input("repeat_times_tensor"): + rep_value = g.get_node(rep_value).astype("int32") + reps.append(rep_value) + reps = _op.concatenate(reps, axis=0) + reps, infered = try_infer_value(reps, g.get_params()) + if 
infered: + reps = reps.tolist() + else: + reps = op.attr("repeat_times") + infered = True + + if not infered: + msg = 'Value {} in attribute "repeat_times" of operator Tile is not "valid."' + raise tvm.error.OpAttributeInvalid(msg.format(reps)) + + op_func = get_relay_op(op.type) + out = op_func(x, reps=reps) + g.add_node(op.output("Out")[0], out) + + def convert_topk(g, op, block): """Operator converter for topk.""" @@ -2074,6 +2151,28 @@ def convert_unsqueeze(g, op, block): g.add_node(op.output("Out")[0], x) +def convert_unstack(g, op, block): + """Operator converter for unstack.""" + + x = g.get_node(op.input("X")[0]) + axis = op.attr("axis") + indices_or_sections = len(op.output("Y")) + outs = _op.split(x, indices_or_sections=indices_or_sections, axis=axis) + for i, out in enumerate(outs): + out = _op.squeeze(out, axis=axis) + g.add_node(op.output("Y")[i], out) + + +def convert_where(g, op, block): + """Operator converter for where.""" + + condition = g.get_node(op.input("Condition")[0]) + x = g.get_node(op.input("X")[0]) + y = g.get_node(op.input("Y")[0]) + out = _op.where(condition, x, y) + g.add_node(op.output("Out")[0], out) + + def convert_where_index(g, op, block): """Operator converter for where_index.""" @@ -2166,6 +2265,7 @@ def convert_where_index(g, op, block): "matmul": convert_matmul, "matmul_v2": convert_matmul, "meshgrid": convert_meshgrid, + "mish": convert_mish, "mul": convert_mul, "mv": convert_mv, "nearest_interp_v2": convert_interpolate, @@ -2201,6 +2301,7 @@ def convert_where_index(g, op, block): "shape": convert_shape, "sigmoid": convert_unary_op, "sign": convert_unary_op, + "silu": convert_silu, "sin": convert_unary_op, "sinh": convert_unary_op, "size": convert_size, @@ -2208,7 +2309,9 @@ def convert_where_index(g, op, block): "softmax": convert_softmax, "softplus": convert_softplus, "softsign": convert_softsign, + "softshrink": convert_softshrink, "split": convert_split, + "stack": convert_stack, "strided_slice": convert_slice, "sqrt": convert_unary_op, "square": convert_square, @@ -2216,9 +2319,12 @@ def convert_where_index(g, op, block): "swish": convert_swish, "tan": convert_unary_op, "tanh": convert_unary_op, + "tile": convert_tile, "top_k_v2": convert_topk, "transpose2": convert_transpose, "unsqueeze2": convert_unsqueeze, + "unstack": convert_unstack, + "where": convert_where, "where_index": convert_where_index, } diff --git a/python/tvm/script/ir_builder/tir/frame.py b/python/tvm/script/ir_builder/tir/frame.py index a57c878bd9295..b2229d503bfbd 100644 --- a/python/tvm/script/ir_builder/tir/frame.py +++ b/python/tvm/script/ir_builder/tir/frame.py @@ -57,7 +57,9 @@ class AssertFrame(TIRFrame): @_register_object("script.ir_builder.tir.LetFrame") class LetFrame(TIRFrame): - ... + def __enter__(self) -> Var: + super().__enter__() + return self.var @_register_object("script.ir_builder.tir.RealizeFrame") @@ -113,4 +115,6 @@ def __enter__(self) -> Buffer: @_register_object("script.ir_builder.tir.LaunchThreadFrame") class LaunchThreadFrame(TIRFrame): - ... 
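The new PaddlePaddle converters above are plain compositions of existing Relay ops; their math can be sanity-checked against a NumPy reference such as the sketch below (illustrative only, not part of the frontend):

    import numpy as np

    def mish_ref(x):
        # convert_mish: x * tanh(log(1 + exp(x)))
        return x * np.tanh(np.log1p(np.exp(x)))

    def silu_ref(x):
        # convert_silu: x * sigmoid(x)
        return x * (1.0 / (1.0 + np.exp(-x)))

    def softshrink_ref(x, lam):
        # convert_softshrink: shift values outside [-lambda, lambda] toward zero, zero inside
        return np.where(x < -lam, x + lam, 0.0) + np.where(x > lam, x - lam, 0.0)

    x = np.linspace(-3.0, 3.0, 7, dtype="float32")
    print(mish_ref(x), silu_ref(x), softshrink_ref(x, 0.5))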
+ def __enter__(self) -> Var: + super().__enter__() + return self.iter_var.var diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 5f4e9d4f2cf0b..d65f9adea86fb 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -28,9 +28,10 @@ import numpy as np # type: ignore +from tvm import tir from tvm.ir import Range, Type from tvm.ir.base import deprecated -from tvm.runtime import convert, ndarray +from tvm.runtime import String, convert, ndarray from tvm.target import Target # pylint: disable=unused-import @@ -61,7 +62,6 @@ FloorMod, IntImm, IterVar, - Let, Load, Max, Min, @@ -138,6 +138,10 @@ def buffer( The declared buffer. """ shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape + if strides is not None: + strides = [Var(s, "int32") if isinstance(s, str) else s for s in strides] + else: + strides = [] return _ffi_api.Buffer( # type: ignore[attr-defined] # pylint: disable=no-member shape, dtype, @@ -304,7 +308,9 @@ def match_buffer( else: raise ValueError("Shape must be specified when binding input param") shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape - if strides is None: + if strides is not None: + strides = [Var(s, "int32") if isinstance(s, str) else s for s in strides] + else: strides = [] return _ffi_api.MatchBuffer( # type: ignore[attr-defined] # pylint: disable=no-member param, @@ -472,7 +478,9 @@ def alloc_buffer( The allocated buffer. """ shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape - if strides is None: + if strides is not None: + strides = [Var(s, "int32") if isinstance(s, str) else s for s in strides] + else: strides = [] return _ffi_api.AllocBuffer( # type: ignore[attr-defined] # pylint: disable=no-member shape, @@ -857,6 +865,47 @@ def Assert(condition: PrimExpr, message: str) -> frame.AssertFrame: # pylint: d return _ffi_api.Assert(condition, message) # type: ignore[attr-defined] # pylint: disable=no-member +def LetStmt( # pylint: disable=invalid-name + value: PrimExpr, + type_annotation: Optional[Type] = None, # pylint: disable=redefined-outer-name + *, + var: Optional[Var] = None, # pylint: disable=redefined-outer-name +) -> frame.LetFrame: + """Create a LetStmt binding + + Parameters + ---------- + value : PrimExpr + The value to be bound. + type_annotation : Optional[Type] = None + The type annotation of the let binding. Usually it is used for fine-grained var typing, + particularly, PointerType. + var : Optional[Var] = None + The variable to bind. If not specified, a new variable will be created. + + Returns + ------- + let_frame : frame.LetFrame + The result LetFrame. + """ + if type_annotation is not None: + if callable(type_annotation): + type_annotation = type_annotation() + if isinstance(type_annotation, Var): + type_annotation = type_annotation.type_annotation + return _ffi_api.LetStmt(value, type_annotation, var) # type: ignore[attr-defined] # pylint: disable=no-member + + +def Let( # pylint: disable=invalid-name + expr: PrimExpr, + where: Dict[Var, PrimExpr], # pylint: disable=redefined-outer-name +) -> PrimExpr: + """Create a Let expression binding""" + assert len(where) == 1, "T.Let only allows `where` to have exactly one element" + var, value = list(where.items())[0] # pylint: disable=redefined-outer-name + return tir.Let(var, value, expr) + + def let( v: Var, value: PrimExpr, @@ -880,9 +929,19 @@ def let( res : frame.LetFrame The result LetFrame. 
""" + + @deprecated("T.let", "T.Let") + def let_expr(v: Var, value: PrimExpr, body: PrimExpr) -> PrimExpr: + return tir.Let(v, value, body) + + @deprecated("T.let", "T.LetStmt") + def let_stmt(v: Var, value: PrimExpr) -> frame.LetFrame: + return _ffi_api.LegacyLetStmt(v, value) # type: ignore[attr-defined] # pylint: disable=no-member + if body is None: - return _ffi_api.Let(v, value) # type: ignore[attr-defined] # pylint: disable=no-member - return Let(v, value, body) + return let_stmt(v, value) + else: + return let_expr(v, value, body) def realize( @@ -1118,6 +1177,10 @@ def decl_buffer( The result DeclBufferFrame. """ shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape + if strides is not None: + strides = [Var(s, "int32") if isinstance(s, str) else s for s in strides] + else: + strides = [] return _ffi_api.DeclBuffer( # type: ignore[attr-defined] # pylint: disable=no-member shape, dtype, @@ -1134,14 +1197,14 @@ def decl_buffer( def launch_thread( - iter_var: IterVar, # pylint: disable=redefined-outer-name + thread: Union[IterVar, str], # pylint: disable=redefined-outer-name extent: PrimExpr, ) -> frame.LaunchThreadFrame: """Launch a thread. Parameters ---------- - iter_var : IterVar + thread : Union[IterVar, str] The iteration variable. extent : PrimExpr @@ -1162,11 +1225,14 @@ def launch_thread( T.launch_thread(brow, 1) """ - return _ffi_api.LaunchThread(iter_var, extent) # type: ignore[attr-defined] # pylint: disable=no-member + + if isinstance(thread, str): + thread = String(thread) + return _ffi_api.LaunchThread(thread, extent) # type: ignore[attr-defined] # pylint: disable=no-member def env_thread(thread_tag: str) -> IterVar: - """Bind a var to thread env" + """Bind a var to thread env Parameters ---------- @@ -1267,11 +1333,13 @@ def func( Literal["inf", "-inf", "nan"], int, float, - ] = None + ] = None, + *, + is_size_var: bool = False, ) -> PrimExpr: if isinstance(expr, str): expr = float(expr) - return getattr(_ffi_api, name)(expr) + return getattr(_ffi_api, name)(expr, is_size_var) return func @@ -1354,7 +1422,7 @@ def func( # pylint: enable=invalid-name -def boolean(expr: Optional[PrimExpr] = None) -> PrimExpr: +def boolean(expr: Optional[PrimExpr] = None, is_size_var: bool = False) -> PrimExpr: """Construct a new tir.Var with type boolean or cast expression to type boolean. Parameters @@ -1362,15 +1430,18 @@ def boolean(expr: Optional[PrimExpr] = None) -> PrimExpr: expr: PrimExpr The expression to be cast. + is_size_var: bool + Whether or not to return a SizeVar instead of Var. + Returns ------- res : PrimExpr The new tir.Var with type boolean or casted expression with type boolean. """ - return _ffi_api.Boolean(expr) # type: ignore[attr-defined] # pylint: disable=no-member + return _ffi_api.Boolean(expr, is_size_var) # type: ignore[attr-defined] # pylint: disable=no-member -def handle(dtype: str = "void", storage_scope: str = "global") -> Var: +def handle(dtype: str = "void", storage_scope: str = "global", *, is_size_var: bool = False) -> Var: """Create a TIR var that represents a pointer. Parameters @@ -1381,15 +1452,18 @@ def handle(dtype: str = "void", storage_scope: str = "global") -> Var: storage_scope: str The storage scope of the pointer. + is_size_var: bool + Whether or not to return a SizeVar instead of Var. + Returns ------- res : PrimExpr The new tir.Var with type handle or casted expression with type handle. 
""" - return _ffi_api.Handle(dtype, storage_scope) # type: ignore[attr-defined] # pylint: disable=no-member + return _ffi_api.Handle(dtype, storage_scope, is_size_var) # type: ignore[attr-defined] # pylint: disable=no-member -def void(expr: Optional[PrimExpr] = None) -> PrimExpr: +def void(expr: Optional[PrimExpr] = None, *, is_size_var: bool = False) -> PrimExpr: """Construct a new tir.Var with type void or cast expression to type void. Parameters @@ -1402,7 +1476,7 @@ def void(expr: Optional[PrimExpr] = None) -> PrimExpr: res : PrimExpr The new tir.Var with type void or casted expression with type void. """ - return _ffi_api.Void(expr) # type: ignore[attr-defined] # pylint: disable=no-member + return _ffi_api.Void(expr, is_size_var) # type: ignore[attr-defined] # pylint: disable=no-member @deprecated("T.var", "T.{dtype}") @@ -1425,7 +1499,7 @@ def var(dtype: str, name: str = "") -> Var: return Var(name, dtype) # pylint: disable=no-member -def ptr(dtype: str, storage_scope: str = "global") -> Var: +def ptr(dtype: str, storage_scope: str = "global", is_size_var: bool = False) -> Var: """The pointer declaration function. Parameters @@ -1436,12 +1510,15 @@ def ptr(dtype: str, storage_scope: str = "global") -> Var: storage_scope : str The storage scope of the pointer. + is_size_var: bool + Whether or not to return a SizeVar instead of Var. + Returns ------- res : Var The pointer. """ - return _ffi_api.Ptr(dtype, storage_scope) # type: ignore[attr-defined] # pylint: disable=no-member + return _ffi_api.Ptr(dtype, storage_scope, is_size_var) # type: ignore[attr-defined] # pylint: disable=no-member @deprecated("T.buffer_var", "T.handle") @@ -1850,7 +1927,6 @@ def wrapped(*args, **kwargs): "thread_binding", "grid", "Assert", - "let", "realize", "allocate", "allocate_const", @@ -2028,6 +2104,8 @@ def wrapped(*args, **kwargs): "Shuffle", "Call", "CallEffectKind", + "let", + "LetStmt", "Let", "IterVar", "CommReducer", diff --git a/python/tvm/script/parser/tir/operation.py b/python/tvm/script/parser/tir/operation.py index f0c04f47cdf6c..3e120339a6e4d 100644 --- a/python/tvm/script/parser/tir/operation.py +++ b/python/tvm/script/parser/tir/operation.py @@ -46,17 +46,17 @@ def r(op: Type, i: int, m: OpMethod): # pylint: disable=invalid-name for i in [0, 1]: # Case 1. binop - r(doc.Add, i, tir.Add) - r(doc.Sub, i, tir.Sub) - r(doc.Mult, i, tir.Mul) - r(doc.Div, i, tir.Div) - r(doc.FloorDiv, i, tir.FloorDiv) - r(doc.Mod, i, tir.FloorMod) - r(doc.LShift, i, lambda a, b: a << b) - r(doc.RShift, i, lambda a, b: a >> b) - r(doc.BitOr, i, lambda a, b: a | b) - r(doc.BitXor, i, lambda a, b: a ^ b) - r(doc.BitAnd, i, lambda a, b: a & b) + # doc.Add <-- is overloaded + # doc.Sub <-- is overloaded + # doc.Mult <-- is overloaded + # doc.Div <-- is overloaded + # doc.FloorDiv <-- is overloaded + # doc.Mod <-- is overloaded + # doc.LShift <-- is overloaded + # doc.RShift <-- is overloaded + # doc.BitOr <-- is overloaded + # doc.BitXor <-- is overloaded + # doc.BitAnd <-- is overloaded # doc.MatMult <-- not implemented # doc.Pow <-- not implemented # Case 2. cmpop @@ -75,10 +75,10 @@ def r(op: Type, i: int, m: OpMethod): # pylint: disable=invalid-name r(doc.Or, i, _or) for i in [0]: # Case 4. 
unaryop - r(doc.Invert, i, lambda a: ~a) + # doc.Invert <-- is overloaded r(doc.Not, i, tir.Not) - r(doc.UAdd, i, lambda a: +a) - r(doc.USub, i, lambda a: -a) + # doc.UAdd <-- is overloaded + # doc.USub <-- is overloaded _register_expr_op(tir.PrimExpr) diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py index fbef1a969179f..5796db40ec065 100644 --- a/python/tvm/script/parser/tir/parser.py +++ b/python/tvm/script/parser/tir/parser.py @@ -143,9 +143,8 @@ def bind_assign_value(self: Parser, node: doc.expr, var_name: str, value: Any) - IRBuilder.name(var_name, value) return value elif isinstance(value, PrimExpr): - var = Var("", value.dtype) - IRBuilder.name(var_name, var) - frame = T.let(var, value) + frame = T.LetStmt(value) + var = frame.var frame.add_callback(partial(frame.__exit__, None, None, None)) frame.__enter__() return var @@ -294,7 +293,7 @@ def visit_ann_assign(self: Parser, node: doc.AnnAssign) -> None: if not isinstance(ann_var, Var): self.report_error(node.annotation, "Annotation should be Var") self.eval_assign(target=lhs, source=ann_var, bind_value=bind_assign_value) - frame = T.let(ann_var, rhs) + frame = T.LetStmt(rhs, var=ann_var) frame.add_callback(partial(frame.__exit__, None, None, None)) frame.__enter__() diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index b63353bcb382a..896e2fc48e72f 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -1115,7 +1115,7 @@ def cache_write( block: Union[BlockRV, str], write_buffer_index: Union[int, str, Buffer], storage_scope: str, - consumer_blocks=None, + consumer_blocks: Optional[List[Union[BlockRV, str]]] = None, ) -> BlockRV: """Create a block that reads a buffer region into a write cache. It requires: @@ -1203,6 +1203,197 @@ def after_cache_write(a: T.handle, b: T.handle) -> None: self, block, write_buffer_index, storage_scope, consumer_blocks ) + @type_checked + def reindex_cache_read( + self, + block: Union[BlockRV, str], + read_buffer_index: int, + storage_scope: str, + index_map: Union[IndexMap, Callable], + ) -> BlockRV: + """Create a block that reads a buffer region into a read cache using customized + indices specified by index map. The read region of the buffer must be a single point. + + The cache stage block follows the original order of loops and block itervars in the block. + If a block itervar does not appear in the buffer access region, it and its corresponding + loop variables will be omitted. User can then use `transform_block_layout` primitive to + reorder the block itervars and surrounding loops of the cache read/write block. + + Unlike `cache_read`, `reindex_cache_read` only supports single consumer, please use + `cache_read` when there are multiple consumers. + + Parameters + ---------- + block : BlockRV + The consumer block of the target buffer. + read_buffer_index: int + The index of the buffer in block's read region. + storage_scope: str + The target storage scope. + index_map: Union[IndexMap, Callable] + User defined indices to access allocated cache buffer, maps from block iter vars. + + Returns + ------- + cached_block : BlockRV + The block of the cache stage + + Examples + -------- + Before reindex_cache_read, in TensorIR, the IR is: + + .. 
code-block:: python + + @T.prim_func + def before_reindex_cache_read(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + + Create the schedule and reindex_cache_read: + + .. code-block:: python + + sch = tir.Schedule(before_cache_read) + block_b = sch.get_block("B") + sch.reindex_cache_read(block_b, 0, "local", lambda vi, vj: (vj, vi)) + print(sch.mod["main"].script()) + + After applying reindex_cache_read, the IR becomes: + + .. code-block:: python + + @T.prim_func + def after_reindex_cache_read(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + A_local = T.alloc_buffer((128, 128), scope="local") + for i, j in T.grid(128, 128): + with T.block("A_local"): + vi, vj = T.axis.remap("SS", [i, j]) + A_local[vj, vi] = A[vi, vj] + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A_local[vj, vi] * 2.0 + + See Also + -------- + reindex_cache_write + transform_block_layout + transform_layout + cache_read + reindex + """ + # Convert any string block names into Block RVs. + block = self._normalize_block_arg(block) + + if callable(index_map): + index_map = IndexMap.from_func(index_map) + return _ffi_api.ScheduleReindexCacheRead( # type: ignore # pylint: disable=no-member + self, block, read_buffer_index, storage_scope, index_map + ) + + @type_checked + def reindex_cache_write( + self, + block: Union[BlockRV, str], + write_buffer_index: int, + storage_scope: str, + index_map: Union[Callable, IndexMap], + ) -> BlockRV: + r"""Create a block that reads a buffer region into a write cache using customized + indices specified by index map. The write region of the buffer must be a single point. + + The cache stage block follows the original order of loops and block itervars in the block. + If a block itervar does not appear in the buffer access region, it and its corresponding + loop variables will be omitted. User can then use `transform_block_layout` primitive to + reorder the block itervars and surrounding loops of the cache read/write block. + + Unlike `cache_write`, `reindex_cache_write` only supports single consumer, please use + `cache_write` when there are multiple consumers. + + Parameters + ---------- + block : Union[BlockRV, str] + The consumer block of the target buffer. + write_buffer_index: int + The index of the buffer in block's write region. + storage_scope: str + The target storage scope. + index_map: Union[Callable, IndexMap] + User defined indices to access allocated cache buffer, maps from block iter vars. + consumer_blocks: Optional[List[Union[BlockRV, str]]] + An optional list of consumers that should read directly from the cache. + If not specified, all consumers will read from the original buffer. + + Returns + ------- + cached_block : BlockRV + The block of the cache stage + + Examples + -------- + Before reindex_cache_write, in TensorIR, the IR is: + + .. code-block:: python + + @T.prim_func + def before_reindex_cache_write(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + + Create the schedule and reindex_cache_write: + + .. 
code-block:: python + + sch = tir.Schedule(before_cache_write) + block_b = sch.get_block("B") + sch.reindex_cache_write(block_b, 0, "local", lambda vi, vj: (vi // 2, vi % 2, vj)) + print(sch.mod["main"].script()) + + After applying reindex_cache_write, the IR becomes: + + .. code-block:: python + + @T.prim_func + def after_cache_write(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (64, 2, 128)) + B_local = T.alloc_buffer((128, 128), scope="local") + for i, j in T.grid(128, 128): + with T.block("A_local"): + vi, vj = T.axis.remap("SS", [i, j]) + B_local[vi % 2, vi // 2, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = B_local[vi % 2, vi // 2, vj] + + See Also + -------- + reindex_cache_read + transform_block_layout + transform_layout + cache_write + reindex + """ + # Convert any string block names into Block RVs. + block = self._normalize_block_arg(block) + + if callable(index_map): + index_map = IndexMap.from_func(index_map) + return _ffi_api.ScheduleReindexCacheWrite( # type: ignore # pylint: disable=no-member + self, block, write_buffer_index, storage_scope, index_map + ) + @type_checked def cache_inplace( self, @@ -1439,7 +1630,7 @@ def before_reindex( vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vj, vi] * 2.0 - Create the schedule and do transform_layout: + Create the schedule and do reindex: .. code-block:: python diff --git a/src/ir/module.cc b/src/ir/module.cc index 22c6faf3d69d6..42ced96120457 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -65,6 +65,27 @@ IRModule::IRModule(tvm::Map functions, bool IRModuleNode::SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const { if (functions.size() != other->functions.size()) return false; if (!equal(this->attrs, other->attrs)) return false; + if (equal.IsPathTracingEnabled()) { + const ObjectPathPair& obj_path_pair = equal.GetCurrentObjectPaths(); + for (const auto& kv : this->functions) { + if (!other->ContainGlobalVar(kv.first->name_hint)) return false; + ObjectPathPair func_paths = {obj_path_pair->lhs_path->Attr("functions")->MapValue(kv.first), + obj_path_pair->rhs_path->Attr("functions") + ->MapValue(other->GetGlobalVar(kv.first->name_hint))}; + if (!equal(kv.second, other->Lookup(kv.first->name_hint), func_paths)) return false; + } + if (type_definitions.size() != other->type_definitions.size()) return false; + for (const auto& kv : this->type_definitions) { + if (!other->ContainGlobalTypeVar(kv.first->name_hint)) return false; + ObjectPathPair type_def_paths = { + obj_path_pair->lhs_path->Attr("type_definitions")->MapValue(kv.first), + obj_path_pair->rhs_path->Attr("type_definitions") + ->MapValue(other->GetGlobalTypeVar(kv.first->name_hint))}; + if (!equal(kv.second, other->LookupTypeDef(kv.first->name_hint), type_def_paths)) + return false; + } + return true; + } for (const auto& kv : this->functions) { if (!other->ContainGlobalVar(kv.first->name_hint)) return false; if (!equal(kv.second, other->Lookup(kv.first->name_hint))) return false; diff --git a/src/meta_schedule/postproc/verify_gpu_code.cc b/src/meta_schedule/postproc/verify_gpu_code.cc index 99ffc1bfcdf7f..6f9b46a0f7342 100644 --- a/src/meta_schedule/postproc/verify_gpu_code.cc +++ b/src/meta_schedule/postproc/verify_gpu_code.cc @@ -162,6 +162,7 @@ class VerifyGPUCodeNode : public PostprocNode { pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation()); 
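To make the path-traced `IRModule::SEqualReduce` above concrete, here is a minimal Python-level sketch (not part of the patch; the module, function, and buffer names are made up) of how a mismatch inside a module's function map now surfaces through the existing `tvm.ir` helpers:

import tvm
from tvm.script import ir_module, tir as T

@ir_module
class ModA:
    @T.prim_func
    def main(a: T.handle):
        A = T.match_buffer(a, (4,), "float32")
        for i in range(4):
            A[i] = 1.0

@ir_module
class ModB:
    @T.prim_func
    def main(a: T.handle):
        A = T.match_buffer(a, (4,), "float32")
        for i in range(4):
            A[i] = 2.0

assert not tvm.ir.structural_equal(ModA, ModB)
try:
    tvm.ir.assert_structural_equal(ModA, ModB)
except ValueError as err:
    # With path tracing enabled, the reported object paths descend into
    # <root>.functions[...] and point at the differing store value.
    print(err)

The exact error text is not asserted here; the point is only that the mismatch is now attributed to a specific entry of `functions` (and, analogously, `type_definitions`) rather than to the bare module.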
pass_list.push_back(tir::transform::ConvertBlocksToOpaque()); pass_list.push_back(tir::transform::UnifyThreadBinding()); + pass_list.push_back(tir::transform::ManifestSharedMemoryLocalStage()); pass_list.push_back(tir::transform::CompactBufferAllocation()); pass_list.push_back(tir::transform::LowerMatchBuffer()); pass_list.push_back(tir::transform::InjectSoftwarePipeline()); diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc index 779114e9cfeab..0312c100b51b1 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc @@ -186,15 +186,15 @@ std::vector MultiLevelTilingNode::AddWriteReuse(State state) const { return results; } -Array MultiLevelTilingNode::SplitLoop(const Schedule& sch, BlockRV block, LoopRV loop, - int n_tiles) const { +std::pair, Array> MultiLevelTilingNode::SplitLoop( + const Schedule& sch, BlockRV block, LoopRV loop, int n_tiles) const { Array factors = sch->SamplePerfectTile( /*loop=*/loop, /*n=*/n_tiles, /*max_innermost_factor=*/max_innermost_factor); Array splits = sch->Split(/*loop=*/loop, /*factors=*/{factors.begin(), factors.end()}); - return splits; + return {factors, splits}; } std::vector MultiLevelTilingNode::TileLoopNest(State state) const { @@ -207,6 +207,9 @@ std::vector MultiLevelTilingNode::TileLoopNest(State state) const { // Step 2. For each loop axis, tile it int64_t spatial_loop_product = 1; std::vector> tiles(s_indices_.size() + r_indices_.size()); + state->tile_factors.resize(tiles.size()); + std::vector> tile_factors; + tile_factors.resize(tiles.size()); for (int i = 0, n = loops.size(); i < n; ++i) { LoopRV loop = loops[i]; const std::vector* idx = nullptr; @@ -231,14 +234,16 @@ std::vector MultiLevelTilingNode::TileLoopNest(State state) const { if (n_tiles == 1) { tiles[idx->at(0)].push_back(loop); } else { - auto splits = SplitLoop(sch, block_rv, loop, n_tiles); + auto [factors, splits] = SplitLoop(sch, block_rv, loop, n_tiles); // Put every tile to its slot for (int j = 0; j < n_tiles; ++j) { tiles[idx->at(j)].push_back(splits[j]); + tile_factors[idx->at(j)].push_back(factors[j]); } } } + state->tile_factors = std::move(tile_factors); // Step 3. Reorder to organize the tiles sch->Reorder(support::ConcatArrayList(tiles.begin(), tiles.end())); // Step 4. Bind the tiles to threads diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.h b/src/meta_schedule/schedule_rule/multi_level_tiling.h index ff38756ff06be..41b3ca9f26f3b 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.h +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.h @@ -94,6 +94,8 @@ class StateNode : public Object { tir::BlockRV block_rv; /*! \brief The loop tiles */ Array> tiles; + /*! \brief The factors of the loop tiles. */ + Array> tile_factors; /*! \brief The mapping from buffer index to read cache block. */ std::unordered_map read_reuse; /*! \brief The mapping from buffer index to write cache block. 
*/ @@ -163,8 +165,10 @@ class MultiLevelTilingNode : public ScheduleRuleNode { protected: virtual std::vector ApplySubRules(std::vector states); - virtual Array SplitLoop(const tir::Schedule& sch, tir::BlockRV block, - tir::LoopRV loop, int n_tiles) const; + virtual std::pair, Array> SplitLoop(const tir::Schedule& sch, + tir::BlockRV block, + tir::LoopRV loop, + int n_tiles) const; // Annotate a block to use cooperative fetching void AnnotateCooperativeFetching(tir::Schedule* sch, const tir::BlockRV& block) const; diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc index d5cca52d41f93..1f9945022b66b 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc @@ -17,6 +17,7 @@ * under the License. */ #include +#include #include #include @@ -124,6 +125,9 @@ class MultiLevelTilingTensorCoreNode : public MultiLevelTilingNode { private: // SubRule: Add tensorization-related transformations inline std::vector TransformForTensorization(TensorCoreState state) const; + // Subrule: Transform the layout of the output. This is necessary for efficient cache write the + // output in the shared memory. + std::vector TransformIntermediateOutputLayout(TensorCoreState state); // Subrule: Add tensorized load inline std::vector AddReadReuseTensorCore(TensorCoreState state) const; // Subrule: Add tensorized store @@ -225,6 +229,9 @@ std::vector MultiLevelTilingTensorCoreNode::ApplySubRules(std::vector(state)); }); states = SubRule(std::move(states), [&](State state) { return TileLoopNest(state); }); + states = SubRule(std::move(states), [&](State state) { + return TransformIntermediateOutputLayout(Downcast(state)); + }); states = SubRule(std::move(states), [&](State state) { return AddWriteReuse(state); }); states = SubRule(std::move(states), [&](State state) { return AddWriteReuseTensorCore(Downcast(state)); @@ -248,25 +255,162 @@ void MultiLevelTilingTensorCoreNode::TileAndAnnotateTensorize(Schedule* sch, (*sch)->Annotate(blockized_outer, tir::attr::meta_schedule_auto_tensorize, intrin_name); } +std::vector MultiLevelTilingTensorCoreNode::TransformIntermediateOutputLayout( + TensorCoreState state) { + // Transform the intermediate output to packed layout + // [..., warp_m, warp_n, accum_frag_m, accum_frag_n, accum_elem_m, accum_elem_n] + // where warp_m, warp_n are thread indices bound to the warp id, accum_frag_m, accum_frag_n are + // the index of the fragments in each warp, accum_elem_m, accum_elem_n are the index of the + // elements in each accumulator fragment. + + // Get the shape of the wmma accumulator + auto [frag_shape_m, frag_shape_n] = [&]() { + tir::Block intrin_block = + Downcast( + tir::TensorIntrin::Get(state->intrin_group.init_intrin).value()->desc->body) + ->block; + tir::For loop_m = Downcast(intrin_block->body); + tir::For loop_n = Downcast(loop_m->body); + return std::make_tuple(loop_m->extent, loop_n->extent); + }(); + + // Get the tile index of the warp id (i.e. threadIdx.y) + auto it = std::find(tile_binds.begin(), tile_binds.end(), "threadIdx.y"); + ICHECK(it != tile_binds.end()); + auto tile_index_warp_id = std::distance(tile_binds.begin(), it); + + // Get the extent of loop indicated by `loop_idx` inside the warp scope. 
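For intuition, the packed accumulator layout described above (and assembled by the index-map lambda a few lines below) can be written directly as a Python `IndexMap`. This is a sketch with made-up sizes, a 16x16 accumulator fragment and 4x2 fragments per warp, and it covers only the trailing two buffer dimensions:

from tvm.tir import IndexMap, floordiv, floormod

FRAG_M, FRAG_N = 16, 16          # wmma accumulator fragment shape (assumed)
WARP_FRAG_M, WARP_FRAG_N = 4, 2  # fragments per warp (assumed tile factors)

def packed_output(m, n):
    outer_m, outer_n = floordiv(m, FRAG_M), floordiv(n, FRAG_N)
    accum_m, accum_n = floormod(m, FRAG_M), floormod(n, FRAG_N)
    return [
        floordiv(outer_m, WARP_FRAG_M),  # warp_m, later bound to threadIdx.y
        floordiv(outer_n, WARP_FRAG_N),  # warp_n
        floormod(outer_m, WARP_FRAG_M),  # accum_frag_m, fragment index in the warp
        floormod(outer_n, WARP_FRAG_N),  # accum_frag_n
        accum_m,                         # accum_elem_m, element index in the fragment
        accum_n,                         # accum_elem_n
    ]

index_map = IndexMap.from_func(packed_output)

In the actual rule nothing is hard-coded: the fragment shape is read from the init intrinsic description and the per-warp fragment counts come from the sampled tile factors (via `f_get_inner_tile_product` below).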
+ // For example, after spatial loops i, j are tiled, we will have + // tile_factors = ((i0, j0), (i1, j1), ..., (in, jn)) + // This function computes the product of tile_factors[i][loop_idx] for i > tile_index_warp_id. + // `loop_idx` can be negative, in which case it is counted from the end. + auto f_get_inner_tile_product = [&](int loop_idx) { + Array factors; + for (int i = tile_index_warp_id + 1; i < static_cast(s_indices_.size()); ++i) { + auto s_factors = state->tile_factors[s_indices_[i]]; + if (loop_idx < 0) { + loop_idx += s_factors.size(); + } + factors.push_back(s_factors[loop_idx]); + } + ICHECK(!factors.empty()); + if (factors.size() == 1) { + return factors[0]; + } + auto result = factors[0]; + for (int i = 1; i < static_cast(factors.size()); ++i) { + result = result * factors[i]; + } + return result; + }; + + // Compute the number of output fragment of each warp + auto warp_num_frag_m = f_get_inner_tile_product(-2); + auto warp_num_frag_n = f_get_inner_tile_product(-1); + + Schedule& sch = state->sch; + int buffer_ndim = static_cast(sch->Get(state->block_rv)->writes[0]->buffer->shape.size()); + // The dimension of the buffer should be larger or same as that of the tensor intrin. + ICHECK_GE(buffer_ndim, 2); + int num_higher_dims = buffer_ndim - 2; + + auto index_map = + tir::IndexMap::FromFunc(buffer_ndim, + // frag_shape_m and frag_shape_n are structural bindings that cannot + // not be automatically captured until c++20 + [&, frag_shape_m = frag_shape_m, + frag_shape_n = frag_shape_n](const Array& indices) { + Array result; + result.reserve(indices.size() + 4); + for (int i = 0; i < num_higher_dims; ++i) { + result.push_back(indices[i]); + } + const auto& m = indices[num_higher_dims]; + const auto& n = indices[num_higher_dims + 1]; + auto accum_m = floormod(m, frag_shape_m); + auto accum_n = floormod(n, frag_shape_n); + auto outer_m = floordiv(m, frag_shape_m); + auto outer_n = floordiv(n, frag_shape_n); + + result.push_back(floordiv(outer_m, warp_num_frag_m)); + result.push_back(floordiv(outer_n, warp_num_frag_n)); + result.push_back(floormod(outer_m, warp_num_frag_m)); + result.push_back(floormod(outer_n, warp_num_frag_n)); + result.push_back(accum_m); + result.push_back(accum_n); + return result; + }); + sch->TransformLayout(state->block_rv, 0, tir::BufferIndexType::kWrite, index_map, + /*pad_value=*/NullOpt, /*assume_injective_transform=*/true); + + return {state}; +} + std::vector MultiLevelTilingTensorCoreNode::AddWriteReuseTensorCore( TensorCoreState state) const { // Add the cache write stage for Tensor Core - int level = r_indices_.front() - 1; - const LoopRV& loop = state->tiles[level].back(); Schedule& sch = state->sch; auto cache_write = sch->CacheWrite(state->block_rv, 0, "wmma.accumulator"); - sch->ReverseComputeAt(cache_write, loop, true); - - if (state->write_reuse.count(0)) { - // Fuse the iterators of the cache_write - Array buffer_loops = sch->GetLoops(state->write_reuse[0]); - ICHECK_GT(buffer_loops.size(), 2); - sch->Fuse(Array{buffer_loops.end() - 2, // The src shmem is always 2D - buffer_loops.end()}); - AnnotateCooperativeFetching(&sch, state->write_reuse[0]); + + // The compute block has been tiled by the warp shape and the fragment shape. + // We need to bind the cache write block (from the accumulator to the shared memory) to the warp + // id. 
The schedule is as follows: + // + // After adding cache write for wmma.accumulator, we will have + // for i0, j0, i1, j1, accum_m, accum_n: + // shared_mem[i0, j0, i1, j1, accum_m, accum_n] = accum[i0, j0, i1, j1, accum_m, accum_n] + // for i0', j0', i1', j1', accum_m', accum_n': + // global_mem[i0', j0', i1', j1', accum_m', accum_n'] = + // shared_mem[i0', j0', i1', j1', accum_m', accum_n'] + // where i0' and j0' are already bound to the block id and warp id. + // + // To reduce the shared memory usage and allow efficient data movement, we will apply + // transformations to generate the following schedule: + // + // for i1': + // for i0_j0 (fused and bound to threadIdx.y): + // for j1, accum_m, accum_n: + // shared_mem[i0, j0, i1, j1, accum_m, accum_n] = accum[i0, j0, i1, j1, accum_m, accum_n] + // for i0', j0', j1', accum_m', accum_n': + // global_mem[i0', j0', i1', j1', accum_m', accum_n'] = + // shared_mem[i0', j0', i1', j1', accum_m', accum_n'] + // + // i1' is reordered to the outermost. This effectively allows only a row (i.e. loop i1') of the + // fragments are moved to the shared memory and then to the global memory each time. + // As a result, shared memory for the output will only have shape of [j1, accum_m, accum_n] + // instead of [i0 * i1 * accum_m, j0 * j1 * accum_n]. + + // Get the loops other than the innermost two loops (accum_m and accum_n). + auto f_get_loops = [&](const BlockRV& block_rv) -> std::array { + Array buffer_loops = sch->GetLoops(block_rv); + ICHECK_GT(buffer_loops.size(), 6); + return {buffer_loops[buffer_loops.size() - 6], buffer_loops[buffer_loops.size() - 5], + buffer_loops[buffer_loops.size() - 4], buffer_loops[buffer_loops.size() - 3]}; + }; + { + const auto& [i0, j0, i1, j1] = f_get_loops(state->write_reuse[0]); + sch->Reorder({i1, i0, j0, j1}); + sch->ComputeAt(cache_write, i1, true); + } + { + auto loops = f_get_loops(cache_write); + const auto& i0 = loops[0]; + const auto& j0 = loops[1]; + auto fused = sch->Fuse({i0, j0}); + sch->Bind(fused, "threadIdx.y"); } + sch->ReverseComputeInline(state->tensor_core_reindex_store); - TileAndAnnotateTensorize(&sch, cache_write, state->intrin_group.store_intrin); + auto loops = sch->GetLoops(cache_write); + auto blockized_store = sch->Blockize(loops[loops.size() - 2]); + sch->Annotate(blockized_store, tir::attr::meta_schedule_auto_tensorize, + state->intrin_group.store_intrin); + + Array buffer_loops = sch->GetLoops(state->write_reuse[0]); + ICHECK_GT(buffer_loops.size(), 5); + sch->Fuse(Array{buffer_loops.end() - 5, // The src shmem is always 2D + buffer_loops.end()}); + AnnotateCooperativeFetching(&sch, state->write_reuse[0]); return {state}; } @@ -508,7 +652,8 @@ Optional MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin( state->sch->state(), GetRef(block), buffer_index, index_type); auto sub_index_map = f_get_sub_index_map(lhs_buffer, reindexed_buffer_region->region); buffer_sub_index_map.Set(lhs_buffer, sub_index_map); - state->sch->TransformLayout(state->block_rv, buffer_index, index_type, sub_index_map, NullOpt); + state->sch->TransformLayout(state->block_rv, buffer_index, index_type, sub_index_map, + /*pad_value=*/NullOpt, /*assume_injective_transform=*/true); }; for (int i = 0, n = block_before_reindex->reads.size(); i < n; ++i) { @@ -569,6 +714,11 @@ ScheduleRule ScheduleRule::MultiLevelTilingTensorCore( auto node = MultiLevelTilingInitCommon( structure, tile_binds, max_innermost_factor, vector_load_lens, reuse_read, reuse_write); + CHECK(node->reuse_write_.req == ReuseType::kMustReuse && + 
runtime::StorageScope::Create(node->reuse_write_.scope).rank == + runtime::StorageRank::kShared) + << "ValueError: Shared memory write reuse must be enabled for MultiLevelTilingTensorCore."; + node->intrin_groups.reserve(intrin_groups.size()); for (const auto& intrin_group_config : intrin_groups) { node->intrin_groups.emplace_back(TensorCoreIntrinGroup::FromConfig(intrin_group_config)); diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc index d4c4a10fdd722..e68b64ea2d3aa 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc @@ -48,11 +48,12 @@ class MultiLevelTilingWideVectorNode : public MultiLevelTilingNode { return ScheduleRule(n); } - Array SplitLoop(const Schedule& sch, BlockRV block, LoopRV loop, int n_tiles) const; + std::pair, Array> SplitLoop(const Schedule& sch, BlockRV block, + LoopRV loop, int n_tiles) const; }; -Array MultiLevelTilingWideVectorNode::SplitLoop(const Schedule& sch, BlockRV block_rv, - LoopRV loop_rv, int n_tiles) const { +std::pair, Array> MultiLevelTilingWideVectorNode::SplitLoop( + const Schedule& sch, BlockRV block_rv, LoopRV loop_rv, int n_tiles) const { const tir::ForNode* loop = TVM_SREF_TO_FOR(sch->GetSRef(loop_rv)); const tir::StmtSRef block_sref = sch->GetSRef(block_rv); const tir::BlockNode* block_node = block_sref->StmtAs(); @@ -99,12 +100,14 @@ Array MultiLevelTilingWideVectorNode::SplitLoop(const Schedule& sch Array outer_splits = sch->Split( /*loop=*/inner_splits[0], /*factors=*/{outer_factors.begin(), outer_factors.end()}); outer_splits.push_back(inner_splits[1]); - return outer_splits; + outer_factors.push_back(PrimExpr(vec_len)); + return {outer_factors, outer_splits}; } else { Array factors(n_tiles - 1, PrimExpr(1)); factors.push_back(loop->extent); - return sch->Split(/*loop=*/loop_rv, - /*factors=*/{factors.begin(), factors.end()}); + Array splits = sch->Split(/*loop=*/loop_rv, + /*factors=*/{factors.begin(), factors.end()}); + return {factors, splits}; } } } diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc index fa77b47bd2847..6cf796d34447c 100644 --- a/src/node/structural_hash.cc +++ b/src/node/structural_hash.cc @@ -39,7 +39,7 @@ namespace tvm { -// Define the dispatch functio here since primary user is in this file. +// Define the dispatch function here since primary user is in this file. void ReflectionVTable::SHashReduce(const Object* self, SHashReducer reducer) const { uint32_t tindex = self->type_index(); if (tindex >= fshash_reduce_.size() || fshash_reduce_[tindex] == nullptr) { @@ -50,7 +50,7 @@ void ReflectionVTable::SHashReduce(const Object* self, SHashReducer reducer) con } // Hash handler that handles free vars -// by assigning an unique counter in the order of their ocurrence. +// by assigning an unique counter in the order of their occurrence. // // This algorithm depends on the determinism of the traversal of SHash function. // In particular, when we traverse unordered_map, we should first sort @@ -69,9 +69,9 @@ class SHashHandlerDefault::Impl { */ ObjectRef object; /*! \brief The partially reduce hash value.*/ - size_t reduced_hash; + uint64_t reduced_hash; /*! \brief The expected location in the result stack. */ - size_t result_stack_index = std::numeric_limits::max(); + uint64_t result_stack_index = std::numeric_limits::max(); /*! 
\brief Whether the children has been expanded via SEqualReduce */ bool children_expanded{false}; /*! \brief Whether the node is graph node. */ @@ -80,7 +80,7 @@ class SHashHandlerDefault::Impl { bool map_free_vars; Task() = default; - explicit Task(ObjectRef object, size_t reduced_hash, bool map_free_vars) + explicit Task(ObjectRef object, uint64_t reduced_hash, bool map_free_vars) : object(object), reduced_hash(reduced_hash), map_free_vars(map_free_vars) {} }; @@ -90,7 +90,7 @@ class SHashHandlerDefault::Impl { task_stack_.back().graph_node_hash = true; } - bool LookupHashedValue(const ObjectRef& key, size_t* hash_value) { + bool LookupHashedValue(const ObjectRef& key, uint64_t* hash_value) { auto it = hash_memo_.find(key); if (it != hash_memo_.end()) { hash_value[0] = it->second; @@ -99,7 +99,7 @@ class SHashHandlerDefault::Impl { return false; } - void SHashReduceHashedValue(size_t hashed_value) { + void SHashReduceHashedValue(uint64_t hashed_value) { pending_tasks_.emplace_back(Task(ObjectRef(nullptr), hashed_value, false)); } @@ -107,18 +107,18 @@ class SHashHandlerDefault::Impl { ICHECK(!hash_memo_.count(GetRef(var))); if (map_free_vars) { // use counter value. - size_t value = std::hash()(free_var_counter_++); + uint64_t value = std::hash()(free_var_counter_++); pending_tasks_.emplace_back(Task(ObjectRef(nullptr), value, false)); } else { // use pointer hash - size_t value = std::hash()(var); + uint64_t value = std::hash()(var); pending_tasks_.emplace_back(Task(ObjectRef(nullptr), value, false)); } } void SHashReduce(const ObjectRef& object, bool map_free_vars) { // Directly push the result - // Note: it is still important to push the result to pendng tasks + // Note: it is still important to push the result to pending tasks // so that the reduction order of hash values stays the same. if (!object.defined()) { pending_tasks_.emplace_back(Task(ObjectRef(nullptr), 0, false)); @@ -133,7 +133,7 @@ class SHashHandlerDefault::Impl { } } - size_t Hash(const ObjectRef& object, bool map_free_vars) { + uint64_t Hash(const ObjectRef& object, bool map_free_vars) { ICHECK_EQ(task_stack_.size(), 0U); ICHECK_EQ(pending_tasks_.size(), 0U); ICHECK_EQ(result_stack_.size(), 0U); @@ -147,7 +147,7 @@ class SHashHandlerDefault::Impl { this->RunTasks(); ICHECK_EQ(result_stack_.size(), 1U); - size_t ret = result_stack_.back(); + uint64_t ret = result_stack_.back(); result_stack_.pop_back(); return ret; } @@ -170,13 +170,13 @@ class SHashHandlerDefault::Impl { * \brief Compute the reduced hash value for the task. * \param task The indicated task. */ - size_t ReduceHash(const Task& task) { - size_t stack_begin = task.result_stack_index; + uint64_t ReduceHash(const Task& task) { + uint64_t stack_begin = task.result_stack_index; ICHECK_LE(stack_begin, result_stack_.size()); // combine in the reverse order of the stack. - size_t reduced_hash = task.reduced_hash; - for (size_t i = result_stack_.size(); i != stack_begin; --i) { + uint64_t reduced_hash = task.reduced_hash; + for (uint32_t i = result_stack_.size(); i != stack_begin; --i) { reduced_hash = support::HashCombine(reduced_hash, result_stack_[i - 1]); } result_stack_.resize(stack_begin); @@ -201,7 +201,7 @@ class SHashHandlerDefault::Impl { // so that we can distinguish DAG from trees. 
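The widening from `size_t` to `uint64_t` does not change the Python-facing behaviour of the handler; a small illustrative sanity check of the free-var handling described above (only `tvm.ir.structural_hash` and `tvm.ir.structural_equal` are assumed):

import tvm
from tvm import tir

x = tir.Var("x", "int32")
y = tir.Var("y", "int32")

# With map_free_vars=True, free variables are hashed by occurrence order,
# so x + 1 and y + 1 are structurally equal and must hash identically.
assert tvm.ir.structural_equal(x + 1, y + 1, map_free_vars=True)
assert tvm.ir.structural_hash(x + 1, map_free_vars=True) == tvm.ir.structural_hash(
    y + 1, map_free_vars=True
)

# Without mapping free vars, distinct variables keep the expressions apart.
assert not tvm.ir.structural_equal(x + 1, y + 1)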
if (entry.graph_node_hash) { entry.reduced_hash = support::HashCombine(entry.reduced_hash, - std::hash()(graph_node_counter_++)); + std::hash()(graph_node_counter_++)); } hash_memo_[entry.object] = entry.reduced_hash; } @@ -241,27 +241,27 @@ class SHashHandlerDefault::Impl { // The owner of this impl SHashHandlerDefault* parent_; // free var counter. - size_t free_var_counter_{0}; + uint32_t free_var_counter_{0}; // graph node counter. - size_t graph_node_counter_{0}; + uint32_t graph_node_counter_{0}; // record current stack top bool allow_push_to_stack_{true}; // list of pending tasks to be pushed to the stack. std::vector pending_tasks_; // Internal task stack to executed the task std::vector task_stack_; - // Internal stack to store the result poped from the task stack. - std::vector result_stack_; + // Internal stack to store the result popped from the task stack. + std::vector result_stack_; // reflection vtable ReflectionVTable* vtable_ = ReflectionVTable::Global(); // map from lhs to rhs - std::unordered_map hash_memo_; + std::unordered_map hash_memo_; }; SHashHandlerDefault::SHashHandlerDefault() { impl = new Impl(this); } SHashHandlerDefault::~SHashHandlerDefault() { delete impl; } -void SHashHandlerDefault::SHashReduceHashedValue(size_t hashed_value) { +void SHashHandlerDefault::SHashReduceHashedValue(uint64_t hashed_value) { return impl->SHashReduceHashedValue(hashed_value); } @@ -273,13 +273,13 @@ void SHashHandlerDefault::SHashReduceFreeVar(const runtime::Object* var, bool ma impl->SHashReduceFreeVar(var, map_free_vars); } -bool SHashHandlerDefault::LookupHashedValue(const ObjectRef& key, size_t* hashed_value) { +bool SHashHandlerDefault::LookupHashedValue(const ObjectRef& key, uint64_t* hashed_value) { return impl->LookupHashedValue(key, hashed_value); } void SHashHandlerDefault::MarkGraphNode() { impl->MarkGraphNode(); } -size_t SHashHandlerDefault::Hash(const ObjectRef& object, bool map_free_vars) { +uint64_t SHashHandlerDefault::Hash(const ObjectRef& object, bool map_free_vars) { return impl->Hash(object, map_free_vars); } @@ -289,11 +289,11 @@ void SHashHandlerDefault::DispatchSHash(const ObjectRef& key, bool map_free_vars TVM_REGISTER_GLOBAL("node.StructuralHash") .set_body_typed([](const ObjectRef& object, bool map_free_vars) -> int64_t { - size_t hashed_value = SHashHandlerDefault().Hash(object, map_free_vars); + uint64_t hashed_value = SHashHandlerDefault().Hash(object, map_free_vars); return static_cast(hashed_value); }); -size_t StructuralHash::operator()(const ObjectRef& object) const { +uint64_t StructuralHash::operator()(const ObjectRef& object) const { return SHashHandlerDefault().Hash(object, false); } @@ -302,7 +302,7 @@ struct StringObjTrait { static constexpr const std::nullptr_t VisitAttrs = nullptr; static void SHashReduce(const runtime::StringObj* key, SHashReducer hash_reduce) { - hash_reduce->SHashReduceHashedValue(runtime::String::HashBytes(key->data, key->size)); + hash_reduce->SHashReduceHashedValue(runtime::String::StableHashBytes(key->data, key->size)); } static bool SEqualReduce(const runtime::StringObj* lhs, const runtime::StringObj* rhs, @@ -371,7 +371,7 @@ void NDArrayHash(const runtime::NDArray::Container* arr, SHashReducer* hash_redu } if (hash_data) { (*hash_reduce) - ->SHashReduceHashedValue(runtime::String::HashBytes( + ->SHashReduceHashedValue(runtime::String::StableHashBytes( static_cast(arr->dl_tensor.data), runtime::GetDataSize(arr->dl_tensor))); } } @@ -405,7 +405,7 @@ struct ArrayNodeTrait { static void SHashReduce(const ArrayNode* 
key, SHashReducer hash_reduce) { hash_reduce(static_cast(key->size())); - for (size_t i = 0; i < key->size(); ++i) { + for (uint32_t i = 0; i < key->size(); ++i) { hash_reduce(key->at(i)); } } @@ -416,7 +416,7 @@ struct ArrayNodeTrait { } if (lhs->size() != rhs->size()) return false; - for (size_t i = 0; i < lhs->size(); ++i) { + for (uint32_t i = 0; i < lhs->size(); ++i) { if (!equal(lhs->at(i), rhs->at(i))) return false; } return true; @@ -425,10 +425,10 @@ struct ArrayNodeTrait { private: static bool SEqualReduceTraced(const ArrayNode* lhs, const ArrayNode* rhs, const SEqualReducer& equal) { - size_t min_size = std::min(lhs->size(), rhs->size()); + uint32_t min_size = std::min(lhs->size(), rhs->size()); const ObjectPathPair& array_paths = equal.GetCurrentObjectPaths(); - for (size_t index = 0; index < min_size; ++index) { + for (uint32_t index = 0; index < min_size; ++index) { ObjectPathPair element_paths = {array_paths->lhs_path->ArrayIndex(index), array_paths->rhs_path->ArrayIndex(index)}; if (!equal(lhs->at(index), rhs->at(index), element_paths)) { @@ -491,7 +491,7 @@ struct ShapeTupleObjTrait { static void SHashReduce(const ShapeTupleObj* self, SHashReducer hash_reduce) { hash_reduce(self->size); - for (size_t i = 0; i < self->size; ++i) { + for (uint32_t i = 0; i < self->size; ++i) { hash_reduce(self->data[i]); } } @@ -499,7 +499,7 @@ struct ShapeTupleObjTrait { static bool SEqualReduce(const ShapeTupleObj* lhs, const ShapeTupleObj* rhs, SEqualReducer equal) { if (lhs->size != rhs->size) return false; - for (size_t i = 0; i < lhs->size; ++i) { + for (uint32_t i = 0; i < lhs->size; ++i) { if (!equal(lhs->data[i], rhs->data[i])) return false; } return true; @@ -539,10 +539,10 @@ struct MapNodeTrait { // This resolves common use cases where we want to store // Map where Var is defined in the function // parameters. - using KV = std::pair; + using KV = std::pair; std::vector temp; for (const auto& kv : *key) { - size_t hashed_value; + uint64_t hashed_value; if (hash_reduce->LookupHashedValue(kv.first, &hashed_value)) { temp.emplace_back(hashed_value, kv.second); } @@ -553,11 +553,11 @@ struct MapNodeTrait { // add size to the hash hash_reduce(static_cast(key->size())); // hash the content - for (size_t i = 0; i < temp.size();) { - size_t k = i + 1; + for (uint32_t i = 0; i < temp.size();) { + uint32_t k = i + 1; for (; k < temp.size() && temp[k].first == temp[i].first; ++k) { } - // ties are rare, but we need to skip them to make the hash determinsitic + // ties are rare, but we need to skip them to make the hash deterministic if (k == i + 1) { hash_reduce->SHashReduceHashedValue(temp[i].first); hash_reduce(temp[i].second); @@ -583,7 +583,7 @@ struct MapNodeTrait { // add size to the hash after sorting. 
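Restated outside the visitor, the deterministic strategy for unordered maps above (sort entries by the pre-hashed key, fold in the map size, and drop entries whose key hashes tie) looks roughly like the following sketch. It assumes every key already has a memoized hash, and `combine` stands in for `support::HashCombine`:

def hash_map_entries(entries, key_hash, value_hash, combine):
    # Sort by the pre-computed key hash only; values are left untouched.
    temp = sorted(((key_hash(k), v) for k, v in entries), key=lambda kv: kv[0])
    h = combine(0, len(temp))  # the map size always contributes
    i = 0
    while i < len(temp):
        k = i + 1
        while k < len(temp) and temp[k][0] == temp[i][0]:
            k += 1
        if k == i + 1:
            # Unique key hash: fold in the key hash and then the value hash.
            h = combine(h, temp[i][0])
            h = combine(h, value_hash(temp[i][1]))
        # Tied key hashes are skipped entirely so the result does not depend
        # on the unordered iteration order of the map.
        i = k
    return h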
hash_reduce(static_cast(key->size())); // hash the content - for (size_t i = 0; i < temp.size(); ++i) { + for (uint32_t i = 0; i < temp.size(); ++i) { hash_reduce(temp[i].first); hash_reduce(temp[i].second); } diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index d5cf4baf7243d..acaea425d1789 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -696,7 +696,7 @@ struct TargetStrHash { */ size_t operator()(const Target& target) const { std::string s(target->kind->name); - return String::HashBytes(s.c_str(), s.size()); + return String::StableHashBytes(s.c_str(), s.size()); } }; diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index 30102b6877223..aee0b4bb6253d 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -32,6 +32,8 @@ Buffer BufferDecl(Array shape, DataType dtype, String buffer_name, Opt Optional> strides, Optional elem_offset, String storage_scope, int align, int offset_factor, String buffer_type, Optional> axis_separators) { + CHECK(buffer_type == "auto" || buffer_type == "default" || buffer_type.empty()) + << "ValueError: `buffer_type` must be `auto` or `default` or empty"; Var buffer_data; if (!data.defined()) { DataType storage_dtype = dtype; @@ -48,7 +50,7 @@ Buffer BufferDecl(Array shape, DataType dtype, String buffer_name, Opt } return Buffer(buffer_data, dtype, shape, strides.value_or(Array()), elem_offset.value_or(PrimExpr()), buffer_name, align, offset_factor, - (buffer_type == "auto_broadcast") ? tvm::tir::kAutoBroadcast : tvm::tir::kDefault, + (buffer_type == "auto" ? tvm::tir::kAutoBroadcast : tvm::tir::kDefault), axis_separators.value_or(Array())); } @@ -381,7 +383,20 @@ AssertFrame Assert(PrimExpr condition, String message) { return AssertFrame(n); } -LetFrame Let(Var var, PrimExpr value) { +LetFrame LetStmt(PrimExpr value, Optional type_annotation, Optional var) { + ObjectPtr n = make_object(); + if (var.defined()) { + n->var = var.value(); + } else if (type_annotation.defined()) { + n->var = Var("v", type_annotation.value()); + } else { + n->var = Var("v", value.dtype()); + } + n->value = value; + return LetFrame(n); +} + +LetFrame LegacyLetStmt(Var var, PrimExpr value) { ObjectPtr n = make_object(); n->var = var; n->value = value; @@ -414,6 +429,10 @@ LaunchThreadFrame LaunchThread(Var var, PrimExpr extent) { return LaunchThreadFrame(n); } +LaunchThreadFrame LaunchThread(String thread_tag, PrimExpr extent) { + return LaunchThread(EnvThread(thread_tag), extent); +} + RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope, PrimExpr condition) { ObjectPtr n = make_object(); @@ -541,18 +560,9 @@ DeclBufferFrame DeclBuffer(Array shape, DataType dtype, String buffer_ void Evaluate(PrimExpr value) { AddToParent(tvm::tir::Evaluate(value)); } -PrimExpr Ptr(runtime::DataType dtype, String storage_scope) { - return tvm::tir::Var("", tvm::PointerType(PrimType(dtype), storage_scope)); -} - -Var Handle(runtime::DataType dtype, String storage_scope) { - Type type_annotation{nullptr}; - if (dtype.is_void() && storage_scope == "global") { - type_annotation = PrimType(runtime::DataType::Handle()); - } else { - type_annotation = PointerType(PrimType(dtype), storage_scope); - } - return tvm::tir::Var("", type_annotation); +PrimExpr Ptr(runtime::DataType dtype, String storage_scope = "global", bool is_size_var = false) { + PointerType type_annotation(PrimType(dtype), storage_scope); + return is_size_var ? 
tvm::tir::SizeVar("", type_annotation) : tvm::tir::Var("", type_annotation); } using tvm::script::ir_builder::details::Namer; @@ -634,7 +644,8 @@ TVM_REGISTER_GLOBAL("script.ir_builder.tir.ThreadBinding").set_body_typed(Thread TVM_REGISTER_GLOBAL("script.ir_builder.tir.Grid").set_body_typed(Grid); TVM_REGISTER_GLOBAL("script.ir_builder.tir.Assert").set_body_typed(Assert); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Let").set_body_typed(Let); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.LetStmt").set_body_typed(LetStmt); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.LegacyLetStmt").set_body_typed(LegacyLetStmt); TVM_REGISTER_GLOBAL("script.ir_builder.tir.Allocate").set_body_typed(Allocate); TVM_REGISTER_GLOBAL("script.ir_builder.tir.AllocateConst").set_body_typed(AllocateConst); TVM_REGISTER_GLOBAL("script.ir_builder.tir.Realize").set_body_typed(Realize); @@ -644,7 +655,18 @@ TVM_REGISTER_GLOBAL("script.ir_builder.tir.If").set_body_typed(If); TVM_REGISTER_GLOBAL("script.ir_builder.tir.Then").set_body_typed(Then); TVM_REGISTER_GLOBAL("script.ir_builder.tir.Else").set_body_typed(Else); TVM_REGISTER_GLOBAL("script.ir_builder.tir.DeclBuffer").set_body_typed(DeclBuffer); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.LaunchThread").set_body_typed(LaunchThread); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.LaunchThread") + .set_body_typed([](ObjectRef thread_tag_or_var, PrimExpr extent) { + if (const auto* var = thread_tag_or_var.as()) { + return LaunchThread(GetRef(var), extent); + } else if (const auto* str = thread_tag_or_var.as()) { + return LaunchThread(GetRef(str), extent); + } else { + LOG(FATAL) << "ValueError: Unexpected type for TIR LaunchThread: " + << thread_tag_or_var->GetTypeKey(); + throw; + } + }); TVM_REGISTER_GLOBAL("script.ir_builder.tir.EnvThread").set_body_typed(EnvThread); TVM_REGISTER_GLOBAL("script.ir_builder.tir.BufferStore").set_body_typed(BufferStore); diff --git a/src/script/printer/doc_printer/python_doc_printer.cc b/src/script/printer/doc_printer/python_doc_printer.cc index e9a3b3567ec07..e726cd42a241d 100644 --- a/src/script/printer/doc_printer/python_doc_printer.cc +++ b/src/script/printer/doc_printer/python_doc_printer.cc @@ -548,7 +548,11 @@ void PythonDocPrinter::PrintTypedDoc(const AssignDoc& doc) { } if (doc->rhs) { output_ << " = "; - PrintDoc(doc->rhs.value()); + if (const auto* tuple_doc = doc->rhs.as()) { + PrintJoinedDocs(tuple_doc->elements, ", "); + } else { + PrintDoc(doc->rhs.value()); + } } MaybePrintCommentInline(doc); } @@ -670,7 +674,6 @@ void PythonDocPrinter::PrintTypedDoc(const ClassDoc& doc) { PrintBlockComment(doc->comment.value()); } PrintIndentedBlock(doc->body); - NewLineWithoutIndent(); } void PythonDocPrinter::PrintTypedDoc(const CommentDoc& doc) { diff --git a/src/script/printer/ir/ir.cc b/src/script/printer/ir/ir.cc index e6f4a1eaee2c2..065cfe5168ad6 100644 --- a/src/script/printer/ir/ir.cc +++ b/src/script/printer/ir/ir.cc @@ -72,6 +72,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) d->cfg->binding_names.pop_back(); if (const auto* stmt_block = doc.as()) { (*f)->stmts.push_back(stmt_block->stmts.back()); + (*f)->stmts.back()->source_paths = std::move(doc->source_paths); } else if (const auto* stmt = doc.as()) { (*f)->stmts.push_back(GetRef(stmt)); } else { diff --git a/src/script/printer/tir/buffer.cc b/src/script/printer/tir/buffer.cc index 19f3dc7ef577c..ed8b7071765e3 100644 --- a/src/script/printer/tir/buffer.cc +++ b/src/script/printer/tir/buffer.cc @@ -24,55 +24,120 @@ namespace tvm { namespace script { namespace printer { -Map 
BufferAttrs(const tir::Buffer& buffer, const ObjectPath& p, const Frame& frame, +Map BufferAttrs(tir::Buffer buffer, const ObjectPath& buffer_p, const Frame& frame, const IRDocsifier& d) { + using tvm::tir::Var; + using tvm::tir::VarNode; Map kwargs; - auto implicit_var_def = [&](const PrimExpr& e, const ObjectPath& p, const String& key) { - if (Optional doc = d->GetVarDoc(e)) { - kwargs.Set(key, doc.value()); - return false; - } - if (e->IsInstance()) { - d->Define(e, frame, [=]() { return d->AsDoc(buffer, p)->Attr(key); }); + Array var_def_lhs; + Array var_def_rhs; + + // Step 0. Set up statistics + std::unordered_map use_count; + auto update_use_count = [&](const PrimExpr& e) { + tir::PostOrderVisit(e, [&](const ObjectRef& n) { + if (const VarNode* var = n.as()) { + ++use_count[var]; + } + }); + }; + update_use_count(buffer->elem_offset); + update_use_count(buffer->data); + for (const PrimExpr& e : buffer->strides) { + update_use_count(e); + } + for (const PrimExpr& e : buffer->shape) { + update_use_count(e); + } + auto is_new_var = [&](const PrimExpr& e) { + return e->IsInstance() && !d->IsVarDefined(e); + }; + auto add_out_of_line_var_def = [&](const Var& var, const ObjectPath& var_p) { + ICHECK(!d->IsVarDefined(var)); + ExprDoc lhs = DefineVar(var, frame, d); + lhs->source_paths.push_back(var_p); + var_def_lhs.push_back(lhs); + var_def_rhs.push_back(PrintVarCreation(var, var_p, d)); + }; + auto try_inline_def = [&](const PrimExpr& e, const ObjectPath& e_p, + std::function inline_f) { + ICHECK(is_new_var(e)); + Var var = Downcast(e); + if (use_count[var.get()] == 1) { + d->Define(e, frame, inline_f); return true; + } else { + add_out_of_line_var_def(var, e_p); + return false; } - kwargs.Set(key, d->AsDoc(e, p)); - return false; }; - auto array_out_line_var_def = [&](const Array& array, const ObjectPath& p, - const String& key) { - int n = array.size(); + // Step 1. Handle `buffer.shape` + { + const Array& shape = buffer->shape; + ObjectPath shape_p = buffer_p->Attr("shape"); + int n = shape.size(); Array results; results.reserve(n); for (int i = 0; i < n; ++i) { - PrimExpr s = array[i]; - ObjectPath s_path = p->ArrayIndex(i); - // Add out-of-line definition for a new Var in shape - results.push_back(d->AsDoc(s, s_path)); + PrimExpr e = shape[i]; + ObjectPath e_p = shape_p->ArrayIndex(i); + if (is_new_var(e)) { + add_out_of_line_var_def(Downcast(e), e_p); + } + results.push_back(d->AsDoc(e, e_p)); } - kwargs.Set(key, TupleDoc(results)); - }; - // Step 1. Handle `buffer.shape` - array_out_line_var_def(buffer->shape, p->Attr("shape"), "shape"); + kwargs.Set("shape", TupleDoc(results)); + } // Step 2. Handle `buffer.dtype` if (buffer->dtype != d->cfg->buffer_dtype) { - kwargs.Set("dtype", LiteralDoc::DataType(buffer->dtype, p->Attr("dtype"))); + kwargs.Set("dtype", LiteralDoc::DataType(buffer->dtype, buffer_p->Attr("dtype"))); } // Step 3. Handle `buffer.data` - implicit_var_def(buffer->data, p->Attr("data"), "data"); + if (!is_new_var(buffer->data)) { + kwargs.Set("data", d->AsDoc(buffer->data, buffer_p->Attr("data"))); + } else { + try_inline_def(buffer->data, buffer_p->Attr("data"), + [=]() { return d->AsDoc(buffer, buffer_p)->Attr("data"); }); + } // Step 4. 
Handle `buffer.strides` if (!buffer->strides.empty()) { - array_out_line_var_def(buffer->strides, p->Attr("strides"), "strides"); + const Array& strides = buffer->strides; + ObjectPath strides_p = buffer_p->Attr("strides"); + int n = strides.size(); + Array results; + results.reserve(n); + for (int i = 0; i < n; ++i) { + PrimExpr e = strides[i]; + ObjectPath e_p = strides_p->ArrayIndex(i); + if (is_new_var(e)) { + if (try_inline_def(e, e_p, [=]() { + return d->AsDoc(buffer, buffer_p) + ->Attr("strides")[{LiteralDoc::Int(i, NullOpt)}]; + })) { + results.push_back(LiteralDoc::Str(Downcast(e)->name_hint, e_p)); + continue; + } + } + results.push_back(d->AsDoc(e, e_p)); + } + kwargs.Set("strides", TupleDoc(results)); } // Step 5. Handle `buffer.elem_offset` bool needs_print_factor = false; if (const auto* int_imm = buffer->elem_offset.as()) { if (int_imm->value != 0) { - kwargs.Set("elem_offset", d->AsDoc(buffer->elem_offset, p->Attr("elem_offset"))); + kwargs.Set("elem_offset", + d->AsDoc(buffer->elem_offset, // + buffer_p->Attr("elem_offset"))); } + } else if (is_new_var(buffer->elem_offset)) { + try_inline_def(buffer->elem_offset, buffer_p->Attr("elem_offset"), + [=]() { return d->AsDoc(buffer, buffer_p)->Attr("elem_offset"); }); + needs_print_factor = true; } else { - needs_print_factor = - implicit_var_def(buffer->elem_offset, p->Attr("elem_offset"), "elem_offset"); + kwargs.Set("elem_offset", + d->AsDoc(buffer->elem_offset, // + buffer_p->Attr("elem_offset"))); } // Step 6. Handle `buffer.scope` { @@ -80,25 +145,32 @@ Map BufferAttrs(const tir::Buffer& buffer, const ObjectPath& p, if (scope != "global") { kwargs.Set( "scope", - LiteralDoc::Str(scope, p->Attr("data")->Attr("type_annotation")->Attr("storage_scope"))); + LiteralDoc::Str(scope, + buffer_p->Attr("data")->Attr("type_annotation")->Attr("storage_scope"))); } } // Step 7. Handle `buffer.data_alignment` if (buffer->data_alignment != runtime::kAllocAlignment) { - kwargs.Set("align", LiteralDoc::Int(buffer->data_alignment, p->Attr("data_alignment"))); + kwargs.Set("align", LiteralDoc::Int(buffer->data_alignment, buffer_p->Attr("data_alignment"))); } // Step 8. Handle `buffer.offset_factor` if (needs_print_factor || buffer->offset_factor != 1) { - kwargs.Set("offset_factor", LiteralDoc::Int(buffer->offset_factor, p->Attr("offset_factor"))); + kwargs.Set("offset_factor", + LiteralDoc::Int(buffer->offset_factor, buffer_p->Attr("offset_factor"))); } // Step 9. Handle `buffer.buffer_type` if (buffer->buffer_type != tir::BufferType::kDefault) { - kwargs.Set("type", LiteralDoc::Str("auto", p->Attr("buffer_type"))); + kwargs.Set("buffer_type", LiteralDoc::Str("auto", buffer_p->Attr("buffer_type"))); } // Step 10. 
Handle `buffer.axis_separator` if (!buffer->axis_separators.empty()) { kwargs.Set("axis_separators", - d->AsDoc(buffer->axis_separators, p->Attr("axis_separators"))); + d->AsDoc(buffer->axis_separators, buffer_p->Attr("axis_separators"))); + } + if (var_def_lhs.size() == 1) { + frame->stmts.push_back(AssignDoc(var_def_lhs[0], var_def_rhs[0], NullOpt)); + } else if (var_def_lhs.size() > 1) { + frame->stmts.push_back(AssignDoc(TupleDoc(var_def_lhs), TupleDoc(var_def_rhs), NullOpt)); } return kwargs; } @@ -111,8 +183,8 @@ ExprDoc BufferCall(const ExprDoc& prefix, const Map& attrs, Arr args.push_back(doc.value()); } } - for (String s : {"data", "strides", "elem_offset", "scope", "align", "offset_factor", "type", - "axis_separators"}) { + for (String s : {"data", "strides", "elem_offset", "scope", "align", "offset_factor", + "buffer_type", "axis_separators"}) { if (Optional doc = attrs.Get(s)) { kwargs_keys.push_back(s); kwargs_values.push_back(doc.value()); diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tir/expr.cc index f1435c4870447..9c4f62eb1c1b9 100644 --- a/src/script/printer/tir/expr.cc +++ b/src/script/printer/tir/expr.cc @@ -24,33 +24,47 @@ namespace tvm { namespace script { namespace printer { +ExprDoc PrintVarCreation(const tir::Var& var, const ObjectPath& var_p, const IRDocsifier& d) { + Type type = var->type_annotation; + ObjectPath type_p = var_p->Attr("type_annotation"); + ExprDoc rhs{nullptr}; + Array kwargs_keys; + Array kwargs_values; + + if (var->IsInstance()) { + kwargs_keys.push_back("is_size_var"); + kwargs_values.push_back(LiteralDoc::Boolean(true, NullOpt)); + } + + if (const auto* ptr_type = type.as()) { + const auto* prim_type = ptr_type->element_type.as(); + ICHECK(prim_type); + ExprDoc element_type = + LiteralDoc::DataType(prim_type->dtype, type_p->Attr("element_type")->Attr("dtype")); + rhs = TIR(d, "handle"); + rhs->source_paths.push_back(var_p->Attr("dtype")); + if (ptr_type->storage_scope == "") { + rhs = rhs->Call({element_type}, kwargs_keys, kwargs_values); + } else { + rhs = rhs->Call({element_type, + LiteralDoc::Str(ptr_type->storage_scope, // + type_p->Attr("storage_scope"))}, + kwargs_keys, kwargs_values); + } + } else { + rhs = TIR(d, DType2Str(var->dtype)); + rhs->source_paths.push_back(var_p->Attr("dtype")); + rhs = rhs->Call({}, kwargs_keys, kwargs_values); + } + rhs->source_paths.push_back(type_p); + return rhs; +} + Doc PrintVar(const tir::Var& var, const ObjectPath& var_p, const IRDocsifier& d) { if (!d->IsVarDefined(var)) { if (Optional opt_f = FindLowestVarDef(var, d)) { ExprDoc lhs = DefineVar(var, opt_f.value(), d); - Type type = var->type_annotation; - ObjectPath type_p = var_p->Attr("type_annotation"); - ExprDoc rhs{nullptr}; - if (const auto* ptr_type = type.as()) { - const auto* prim_type = ptr_type->element_type.as(); - ICHECK(prim_type); - ExprDoc element_type = - LiteralDoc::DataType(prim_type->dtype, type_p->Attr("element_type")->Attr("dtype")); - rhs = TIR(d, "handle"); - rhs->source_paths.push_back(var_p->Attr("dtype")); - if (ptr_type->storage_scope == "") { - rhs = rhs->Call({element_type}); - } else { - rhs = rhs->Call({element_type, - LiteralDoc::Str(ptr_type->storage_scope, // - type_p->Attr("storage_scope"))}); - } - } else { - rhs = TIR(d, DType2Str(var->dtype)); - rhs->source_paths.push_back(var_p->Attr("dtype")); - rhs = rhs->Call({}); - } - rhs->source_paths.push_back(type_p); + ExprDoc rhs = PrintVarCreation(var, var_p, d); opt_f.value()->stmts.push_back(AssignDoc(lhs, rhs, NullOpt)); } else { LOG(WARNING) 
<< "Didn't find variable definition for: " << var->name_hint; @@ -211,11 +225,10 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](tir::Let let, ObjectPath p, IRDocsifier d) -> Doc { - return TIR(d, "let")->Call({ - d->AsDoc(let->var, p->Attr("var")), - d->AsDoc(let->value, p->Attr("value")), - d->AsDoc(let->body, p->Attr("body")), - }); + DictDoc where({d->AsDoc(let->var, p->Attr("var"))}, + {d->AsDoc(let->value, p->Attr("value"))}); + return TIR(d, "Let")->Call({d->AsDoc(let->body, p->Attr("body"))}, // + {"where"}, {where}); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) @@ -298,32 +311,47 @@ bool IsNumber(const ExprDoc& e) { return false; } -#define TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(NodeType, OpString, OpKind) \ +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](tir::Div node, ObjectPath p, IRDocsifier d) -> Doc { + ExprDoc a = d->AsDoc(node->a, p->Attr("a")); + ExprDoc b = d->AsDoc(node->b, p->Attr("b")); + PrimExpr ret = tvm::div(node->a, node->b); + if (!ret->IsInstance()) { + return TIR(d, "Div")->Call({a, b}); + } + if ((node->a->dtype.is_int() || node->a->dtype.is_uint()) && + (node->b->dtype.is_int() || node->b->dtype.is_uint())) { + return TIR(d, "Div")->Call({a, b}); + } + return OperationDoc(OperationDocNode::Kind::kDiv, {a, b}); + }); + +#define TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(NodeType, NodeObj, NodeFunc, OpString, OpKind) \ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) \ .set_dispatch("", \ [](tir::NodeType node, ObjectPath p, IRDocsifier d) -> Doc { \ ExprDoc a = d->AsDoc(node->a, p->Attr("a")); \ ExprDoc b = d->AsDoc(node->b, p->Attr("b")); \ - if (IsNumber(a) && IsNumber(b)) { \ + PrimExpr ret = tvm::NodeFunc(node->a, node->b); \ + if (!ret->IsInstance()) { \ return TIR(d, OpString)->Call({a, b}); \ } \ return OperationDoc(OperationDocNode::Kind::OpKind, {a, b}); \ }); -TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Add, "Add", kAdd); -TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Sub, "Sub", kSub); -TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Mul, "Mul", kMult); -TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Div, "Div", kDiv); -TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(FloorDiv, "FloorDiv", kFloorDiv); -TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(FloorMod, "FloorMod", kMod); -TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(LT, "LT", kLt); -TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(LE, "LE", kLtE); -TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(EQ, "EQ", kEq); -TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(NE, "NE", kNotEq); -TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(GT, "GT", kGt); -TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(GE, "GE", kGtE); -TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(And, "And", kAnd); -TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Or, "Or", kOr); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Add, AddNode, add, "Add", kAdd); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Sub, SubNode, sub, "Sub", kSub); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Mul, MulNode, mul, "Mul", kMult); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(FloorDiv, FloorDivNode, floordiv, "FloorDiv", kFloorDiv); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(FloorMod, FloorModNode, floormod, "FloorMod", kMod); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(LT, LTNode, less, "LT", kLt); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(LE, LENode, less_equal, "LE", kLtE); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(EQ, EQNode, equal, "EQ", kEq); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(NE, NENode, not_equal, "NE", kNotEq); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(GT, GTNode, 
greater, "GT", kGt); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(GE, GENode, greater_equal, "GE", kGtE); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(And, AndNode, logical_and, "And", kAnd); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Or, OrNode, logical_or, "Or", kOr); TVM_SCRIPT_PRINTER_DEF_BINARY(Mod, "truncmod"); TVM_SCRIPT_PRINTER_DEF_BINARY(Min, "min"); diff --git a/src/script/printer/tir/function.cc b/src/script/printer/tir/function.cc index 6a4df34a3a7a0..f40d7818d7e1e 100644 --- a/src/script/printer/tir/function.cc +++ b/src/script/printer/tir/function.cc @@ -92,8 +92,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) tir::Buffer buffer = func->buffer_map[var]; if (IsSimpleBuffer(buffer) && buffer_data_counter.at(buffer->data.get()) == 1) { ObjectPath buffer_p = p->Attr("buffer_map")->MapValue(var); - args.push_back(AssignDoc(DefineBuffer(buffer, *f, d), NullOpt, - BufferAttn(buffer, buffer_p, *f, d))); + IdDoc lhs = DefineBuffer(buffer, *f, d); + ExprDoc annotation = BufferAttn(buffer, buffer_p, *f, d); + args.push_back(AssignDoc(lhs, NullOpt, annotation)); buffer_inlined.insert(buffer.get()); continue; } @@ -117,7 +118,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } ExprDoc param_doc = args[i]->lhs; ObjectPath buffer_p = p->Attr("buffer_map")->MapValue(param); - ExprDoc lhs = DefineBuffer(buffer, *f, d); // TODO(@junrushao): switch `lhs` and `rhs` + ExprDoc lhs = DefineBuffer(buffer, *f, d); ExprDoc rhs = BufferDecl(buffer, "match_buffer", {param_doc}, buffer_p, *f, d); (*f)->stmts.push_back(AssignDoc(lhs, rhs, NullOpt)); } diff --git a/src/script/printer/tir/stmt.cc b/src/script/printer/tir/stmt.cc index b730dd5606ba7..591d1e3bc1da3 100644 --- a/src/script/printer/tir/stmt.cc +++ b/src/script/printer/tir/stmt.cc @@ -45,6 +45,19 @@ bool AllowConciseScoping(const IRDocsifier& d) { LOG(FATAL) << "NotImplementedError: fragment printing"; } +bool IsAncestorOfAllVarUse(const tir::Stmt& node, const ObjectRef& var, const IRDocsifier& d) { + if (!d->common_prefix.count(var.get())) { + return false; + } + const std::vector& path = d->common_prefix.at(var.get()); + for (auto it = path.rbegin(); it != path.rend(); ++it) { + if (*it == node.get()) { + return true; + } + } + return false; +} + TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](tir::Evaluate eval, ObjectPath p, IRDocsifier d) -> Doc { ExprDoc value = d->AsDoc(eval->value, p->Attr("value")); @@ -57,30 +70,35 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](tir::LetStmt stmt, ObjectPath p, IRDocsifier d) -> Doc { bool concise = AllowConciseScoping(d); - if (concise && !d->IsVarDefined(stmt->var)) { - ExprDoc rhs = d->AsDoc(stmt->value, p->Attr("value")); - With f(d, stmt); - ExprDoc lhs = DefineVar(stmt->var, *f, d); - AsDocBody(stmt->body, p->Attr("body"), f->get(), d); - Array* stmts = &(*f)->stmts; - Type type = stmt->var->type_annotation; - Optional type_doc = - d->AsDoc(type, p->Attr("var")->Attr("type_annotation")); - if (const auto* tuple_type = type.as()) { - if (tuple_type->fields.empty()) { - type_doc = NullOpt; - } + // Step 1. Type annotation + Optional type_doc = d->AsDoc(stmt->var->type_annotation, // + p->Attr("var")->Attr("type_annotation")); + if (const auto* tuple_type = stmt->var->type_annotation.as()) { + if (tuple_type->fields.empty()) { + type_doc = NullOpt; } + } + // Step 2. RHS + ExprDoc rhs = d->AsDoc(stmt->value, p->Attr("value")); + // Step 3. 
LHS and body + With f(d, stmt); + Array* stmts = &(*f)->stmts; + bool var_defined = d->IsVarDefined(stmt->var); + if (!var_defined) { + DefineVar(stmt->var, *f, d); + } + ExprDoc lhs = d->AsDoc(stmt->var, p->Attr("var")); + AsDocBody(stmt->body, p->Attr("body"), f->get(), d); + // Step 4. Dispatch + if (var_defined) { + return ScopeDoc(NullOpt, TIR(d, "LetStmt")->Call({rhs}, {"var"}, {lhs}), *stmts); + } else if (concise) { stmts->insert(stmts->begin(), AssignDoc(lhs, rhs, type_doc)); return StmtBlockDoc(*stmts); + } else if (type_doc.defined() && !stmt->var->type_annotation->IsInstance()) { + return ScopeDoc(lhs, TIR(d, "LetStmt")->Call({rhs, type_doc.value()}), *stmts); } else { - ExprDoc lhs = d->AsDoc(stmt->var, p->Attr("var")); - ExprDoc rhs = d->AsDoc(stmt->value, p->Attr("value")); - With f(d, stmt); - AsDocBody(stmt->body, p->Attr("body"), f->get(), d); - Array* stmts = &(*f)->stmts; - rhs = TIR(d, "let")->Call({lhs, rhs}); - return ScopeDoc(NullOpt, rhs, *stmts); + return ScopeDoc(lhs, TIR(d, "LetStmt")->Call({rhs}), *stmts); } }); @@ -317,6 +335,39 @@ ExprDoc DocsifyBufferRealize(const tir::BufferRealizeNode* stmt, OptionalCall(args, kwargs_keys, kwargs_values); } +void InsertEnvThread(const tir::IterVar& iter_var, const ObjectPath& iter_var_p, + const IRDocsifier& d) { + Frame f = FindLowestVarDef(iter_var->var, d).value(); + DefineVar(iter_var->var, f, d); + ExprDoc rhs = TIR(d, "env_thread") + ->Call({LiteralDoc::Str(iter_var->thread_tag, // + iter_var_p->Attr("thread_tag"))}); + ExprDoc lhs = d->AsDoc(iter_var->var, iter_var_p->Attr("var")); + f->stmts.push_back(AssignDoc(lhs, rhs, NullOpt)); +} + +ExprDoc DocsifyLaunchThread(const tir::AttrStmt& attr_stmt, const ObjectPath& attr_stmt_p, + Optional* define_var, const IRDocsifier& d) { + tir::IterVar iter_var = Downcast(attr_stmt->node); + ObjectPath iter_var_p = attr_stmt_p->Attr("node"); + + ExprDoc var_doc{nullptr}; + if (d->IsVarDefined(iter_var->var)) { + var_doc = d->AsDoc(iter_var->var, iter_var_p->Attr("var")); + } else if (IsAncestorOfAllVarUse(attr_stmt, iter_var->var, d)) { + var_doc = LiteralDoc::Str(iter_var->thread_tag, iter_var_p->Attr("thread_tag")); + *define_var = iter_var->var; + } else { + InsertEnvThread(iter_var, iter_var_p, d); + var_doc = d->AsDoc(iter_var->var, iter_var_p->Attr("var")); + } + return TIR(d, "launch_thread") + ->Call({ + var_doc, + d->AsDoc(attr_stmt->value, attr_stmt_p->Attr("value")), + }); +} + TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](tir::BufferRealize stmt, ObjectPath p, IRDocsifier d) -> Doc { @@ -331,7 +382,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](tir::AttrStmt stmt, ObjectPath stmt_p, IRDocsifier d) -> Doc { bool concise = AllowConciseScoping(d); + Optional lhs = NullOpt; Optional rhs = NullOpt; + Optional define_var = NullOpt; tir::Stmt body = stmt->body; ObjectPath body_p = stmt_p->Attr("body"); if (stmt->attr_key == "realize_scope") { @@ -342,29 +395,13 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) /*value=*/d->AsDoc(stmt->value, stmt_p->Attr("value")), /*p=*/stmt_p->Attr("body"), d); body = realize->body; - body_p = body_p->Attr("body"); + body_p = stmt_p->Attr("body")->Attr("body"); } } } if (stmt->attr_key == "thread_extent" || stmt->attr_key == "virtual_thread") { - if (const auto* iter_var = stmt->node.as()) { - if (!d->IsVarDefined(iter_var->var)) { - // `DefineVar` is not used here because a more specific name is desirable - ObjectPath iter_var_p = stmt_p->Attr("node"); - Frame f = 
FindLowestVarDef(iter_var->var, d).value(); - DefineVar(iter_var->var, f, d); - f->stmts.push_back( - AssignDoc(d->AsDoc(iter_var->var, iter_var_p->Attr("var")), - TIR(d, "env_thread") - ->Call({LiteralDoc::Str(iter_var->thread_tag, - iter_var_p->Attr("thread_tag"))}), // - NullOpt)); - } - rhs = TIR(d, "launch_thread") - ->Call({ - d->AsDoc(iter_var->var, stmt_p->Attr("node")), - d->AsDoc(stmt->value, stmt_p->Attr("value")), - }); + if (stmt->node->IsInstance()) { + rhs = DocsifyLaunchThread(stmt, stmt_p, &define_var, d); } } if (!rhs.defined()) { @@ -375,8 +412,11 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); } With f(d, stmt); + if (define_var.defined()) { + lhs = DefineVar(define_var.value(), *f, d); + } AsDocBody(body, body_p, f->get(), d); - return DoConciseScoping(NullOpt, rhs.value(), &(*f)->stmts, concise); + return DoConciseScoping(lhs, rhs.value(), &(*f)->stmts, concise); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) diff --git a/src/script/printer/tir/utils.h b/src/script/printer/tir/utils.h index 08eb12bfa785e..cee5fbd0f0216 100644 --- a/src/script/printer/tir/utils.h +++ b/src/script/printer/tir/utils.h @@ -201,6 +201,15 @@ ExprDoc BufferDecl(const tir::Buffer& buffer, const String& method, const Array< ExprDoc BufferAttn(const tir::Buffer& buffer, const ObjectPath& p, const Frame& frame, const IRDocsifier& d); +/*! + * \brief Print the creation of a Var + * \param var The Var to be printed + * \param var_p The object path of the Var + * \param d The IRDocsifier + * \return The ExprDoc corresponding to the Var creation + */ +ExprDoc PrintVarCreation(const tir::Var& var, const ObjectPath& var_p, const IRDocsifier& d); + /*! \brief A Var occurrence counter visitor */ class OccurrenceCounter : public tir::StmtExprVisitor { public: diff --git a/src/script/printer/utils.h b/src/script/printer/utils.h index 10c7aaf4f2bb7..ec0f0eaf72b06 100644 --- a/src/script/printer/utils.h +++ b/src/script/printer/utils.h @@ -61,6 +61,7 @@ inline void RedirectedReprPrinterMethod(const ObjectRef& obj, ReprPrinter* p) { inline std::string Docsify(const ObjectRef& obj, const IRDocsifier& d, const Frame& f, const PrinterConfig& cfg) { Doc doc = d->AsDoc(obj, ObjectPath::Root()); + bool move_source_paths = false; if (const auto* expr_doc = doc.as()) { if (!cfg->verbose_expr) { f->stmts.clear(); @@ -72,6 +73,7 @@ inline std::string Docsify(const ObjectRef& obj, const IRDocsifier& d, const Fra for (const StmtDoc& d : stmt_block->stmts) { f->stmts.push_back(d); } + move_source_paths = true; } else { LOG(FATAL) << "TypeError: Unexpected doc type: " << doc->GetTypeKey(); } @@ -87,7 +89,13 @@ inline std::string Docsify(const ObjectRef& obj, const IRDocsifier& d, const Fra CommentDoc("Metadata omitted. Use show_meta=True in script() method to show it.")); } } - os << DocToPythonScript(StmtBlockDoc(f->stmts), cfg); + if (move_source_paths) { + StmtBlockDoc new_doc(f->stmts); + new_doc->source_paths = std::move(doc->source_paths); + os << DocToPythonScript(new_doc, cfg); + } else { + os << DocToPythonScript(StmtBlockDoc(f->stmts), cfg); + } return os.str(); } diff --git a/src/tir/analysis/block_access_region_detector.cc b/src/tir/analysis/block_access_region_detector.cc index e9bff1b6fdee5..ab328efaa6d14 100644 --- a/src/tir/analysis/block_access_region_detector.cc +++ b/src/tir/analysis/block_access_region_detector.cc @@ -76,8 +76,6 @@ class BlockReadWriteDetector : public StmtExprVisitor { Map buffer_var_map_; /*! 
\brief The target buffer var mapping to its matching */ std::unordered_map match_buffers_; - /*! \brief The analyzer for simplifying*/ - arith::Analyzer analyzer_; /*! * \brief Update read/write buffers and regions with provided buffer and region @@ -330,7 +328,12 @@ Array BlockReadWriteDetector::CollectRegions( ICHECK_EQ(buffers[i]->shape.size(), regions[i].size()); for (size_t j = 0; j < regions[i].size(); j++) { const tvm::arith::IntSet& range = regions[i][j]; - region.push_back(range.CoverRange(Range::FromMinExtent(0, buffers[i]->shape[j]))); + if (range.IsSinglePoint()) { + PrimExpr min = range.min(); + region.push_back(Range::FromMinExtent(min, make_const(min.dtype(), 1))); + } else { + region.push_back(range.CoverRange(Range::FromMinExtent(0, buffers[i]->shape[j]))); + } } res.push_back(BufferRegion(buffers[i], region)); } diff --git a/src/tir/analysis/var_use_def_analysis.cc b/src/tir/analysis/var_use_def_analysis.cc new file mode 100644 index 0000000000000..7ef8e532a3960 --- /dev/null +++ b/src/tir/analysis/var_use_def_analysis.cc @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file var_use_def_analysis.cc + * \brief Classes and functions to analyze var defition and usage. + */ +#include "var_use_def_analysis.h" +namespace tvm { +namespace tir { + +VarUseDefAnalyzer::VarUseDefAnalyzer(const Array& defined_vars, bool visit_thread_extent) + : visit_thread_extent_(visit_thread_extent) { + for (const Var v : defined_vars) { + use_count_[v.get()] = 0; + } +} + +void VarUseDefAnalyzer::VisitStmt_(const AttrStmtNode* op) { + if (op->attr_key == attr::thread_extent) { + IterVar iv = Downcast(op->node); + ICHECK_NE(iv->thread_tag.length(), 0U); + // thread_extent can appear multiple times + // use the first appearance as def. + if (!use_count_.count(iv->var.get())) { + this->HandleDef(iv->var.get()); + } + + if (visit_thread_extent_) { + this->VisitExpr(op->value); + } + + this->VisitStmt(op->body); + } else { + StmtExprVisitor::VisitStmt_(op); + } +} + +void VarUseDefAnalyzer::VisitStmt_(const LetStmtNode* op) { + this->HandleDef(op->var.get()); + StmtExprVisitor::VisitStmt_(op); +} + +void VarUseDefAnalyzer::VisitStmt_(const ForNode* op) { + this->HandleDef(op->loop_var.get()); + StmtExprVisitor::VisitStmt_(op); +} + +void VarUseDefAnalyzer::VisitStmt_(const AllocateNode* op) { + this->HandleDef(op->buffer_var.get()); + StmtExprVisitor::VisitStmt_(op); +} + +void VarUseDefAnalyzer::VisitStmt_(const AllocateConstNode* op) { + this->HandleDef(op->buffer_var.get()); + StmtExprVisitor::VisitStmt_(op); +} + +void VarUseDefAnalyzer::VisitStmt_(const StoreNode* op) { + LOG(FATAL) << "Unexpected use of deprecated StoreNode. 
Please use BufferStoreNode instead."; +} + +void VarUseDefAnalyzer::VisitStmt_(const BufferStoreNode* op) { + VisitBuffer(op->buffer); + StmtExprVisitor::VisitStmt_(op); +} + +void VarUseDefAnalyzer::VisitExpr_(const LetNode* op) { + // Weaker SSA condition + // A single var can be binded in multiple lets + // but they have to bind to the same value. + // This is used to allow cases when we reuse a single let + // expression to construct a nested expr. + // (let x = 1 in x + 1) * (let x = 1 in x + 1) + auto it = let_binding_.find(op->var.get()); + this->VisitExpr(op->value); + if (it != let_binding_.end()) { + ICHECK(deep_equal_(it->second->value, op->value)) + << "Let cannot bind the same var to two different values"; + } else { + this->HandleDef(op->var.get()); + let_binding_[op->var.get()] = op; + } + this->VisitExpr(op->body); +} + +void VarUseDefAnalyzer::VisitExpr_(const VarNode* op) { + this->HandleUse(op); + StmtExprVisitor::VisitExpr_(op); +} + +void VarUseDefAnalyzer::VisitExpr_(const ReduceNode* op) { + for (const auto& iv : op->axis) { + this->HandleDef(iv->var.get()); + } + StmtExprVisitor::VisitExpr_(op); +} + +void VarUseDefAnalyzer::VisitExpr_(const LoadNode* op) { + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; +} + +void VarUseDefAnalyzer::VisitExpr_(const BufferLoadNode* op) { + VisitBuffer(op->buffer); + StmtExprVisitor::VisitExpr_(op); +} + +void VarUseDefAnalyzer::VisitBuffer(Buffer buffer) { + this->HandleUse(buffer->data.get()); + auto visit_arr = [&](Array arr) { + for (const auto& element : arr) { + this->VisitExpr(element); + } + }; + + visit_arr(buffer->shape); + visit_arr(buffer->strides); +} + +void VarUseDefAnalyzer::HandleDef(const VarNode* v) { + ICHECK(!def_count_.count(v)) << "variable " << v->name_hint + << " has already been defined, the Stmt is not SSA"; + ICHECK(!use_count_.count(v)) << "variable " << v->name_hint + << " has been used before definition!"; + use_count_[v] = 0; + def_count_[v] = 1; +} + +void VarUseDefAnalyzer::HandleUse(const VarNode* v) { + auto it = use_count_.find(v); + if (it != use_count_.end()) { + if (it->second >= 0) { + ++it->second; + } + } else { + undefined_.push_back(GetRef(v)); + use_count_[v] = -1; + } +} + +Array UndefinedVars(const Stmt& stmt, const Array& args) { + VarUseDefAnalyzer m(args); + m(stmt); + return m.undefined_; +} + +Array UndefinedVars(const PrimExpr& expr) { + VarUseDefAnalyzer m({}); + m(expr); + return m.undefined_; +} + +Array UndefinedVars(const PrimExpr& expr, const Array& args) { + VarUseDefAnalyzer m(args); + m(expr); + return m.undefined_; +} + +} // namespace tir +} // namespace tvm diff --git a/src/tir/analysis/var_use_def_analysis.h b/src/tir/analysis/var_use_def_analysis.h new file mode 100644 index 0000000000000..ad275011d90c7 --- /dev/null +++ b/src/tir/analysis/var_use_def_analysis.h @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/src/tir/analysis/var_use_def_analyzer.h + * \brief Variable definition and usage analysis class. + */ +#ifndef TVM_TIR_ANALYSIS_VAR_USE_DEF_ANALYSIS_H_ +#define TVM_TIR_ANALYSIS_VAR_USE_DEF_ANALYSIS_H_ + +#include +#include + +#include + +namespace tvm { +namespace tir { + +/*! + * \brief Visitor class to perform use/def analysis, also delete unreferenced lets. + * \param defined_vars Variables that have been defined. + * \param visit_thread_extent Whether enters thread extent expressions or not. + * \sa UndefinedVars + */ +class VarUseDefAnalyzer : public StmtExprVisitor { + public: + explicit VarUseDefAnalyzer(const Array& defined_vars, bool visit_thread_extent = true); + // The fields are publically readible to + // be accessible to the users. + bool visit_thread_extent_{true}; + Array undefined_; + + std::unordered_map use_count_; + std::unordered_map def_count_; + + private: + ExprDeepEqual deep_equal_; + std::unordered_map let_binding_; + void VisitStmt_(const AttrStmtNode* op) final; + + void VisitStmt_(const LetStmtNode* op) final; + + void VisitStmt_(const ForNode* op) final; + + void VisitStmt_(const AllocateNode* op) final; + + void VisitStmt_(const AllocateConstNode* op) final; + + void VisitStmt_(const StoreNode* op) final; + + void VisitStmt_(const BufferStoreNode* op) final; + + void VisitExpr_(const LetNode* op) final; + + void VisitExpr_(const VarNode* op) final; + + void VisitExpr_(const ReduceNode* op) final; + + void VisitExpr_(const LoadNode* op) final; + + void VisitExpr_(const BufferLoadNode* op) final; + + void HandleDef(const VarNode* v); + + void HandleUse(const VarNode* v); + + void VisitBuffer(Buffer buffer); +}; + +} // namespace tir +} // namespace tvm + +#endif // TVM_TIR_ANALYSIS_VAR_USE_DEF_ANALYSIS_H_ diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index d5caeab53922b..db09ac17e6eb6 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -126,6 +126,15 @@ SizeVar::SizeVar(String name_hint, DataType dtype, Span span) { data_ = std::move(n); } +SizeVar::SizeVar(String name_hint, Type type_annotation, Span span) { + auto n = make_object(); + n->name_hint = std::move(name_hint); + n->dtype = GetRuntimeDataType(type_annotation); + n->type_annotation = std::move(type_annotation); + n->span = std::move(span); + data_ = std::move(n); +} + TVM_REGISTER_GLOBAL("tir.SizeVar").set_body_typed([](String s, DataType t, Span span) { return SizeVar(s, t, span); }); diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 8af39b24fdb8b..5a9dab4854bd3 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -568,6 +568,30 @@ BlockRV ConcreteScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buff return CreateRV(result); } +BlockRV ConcreteScheduleNode::ReindexCacheRead(const BlockRV& block_rv, int read_buffer_index, + const String& storage_scope, + const IndexMap& index_map) { + StmtSRef result{nullptr}; + TVM_TIR_SCHEDULE_BEGIN(); + result = tir::ReindexCacheRead(state_, this->GetSRef(block_rv), read_buffer_index, storage_scope, + 
+                                    index_map);
+  TVM_TIR_SCHEDULE_END("reverse-cache-read", this->error_render_level_);
+  this->state_->DebugVerify();
+  return CreateRV(result);
+}
+
+BlockRV ConcreteScheduleNode::ReindexCacheWrite(const BlockRV& block_rv, int write_buffer_index,
+                                                const String& storage_scope,
+                                                const IndexMap& index_map) {
+  StmtSRef result{nullptr};
+  TVM_TIR_SCHEDULE_BEGIN();
+  result = tir::ReindexCacheWrite(state_, this->GetSRef(block_rv), write_buffer_index,
+                                  storage_scope, index_map);
+  TVM_TIR_SCHEDULE_END("reverse-cache-write", this->error_render_level_);
+  this->state_->DebugVerify();
+  return CreateRV(result);
+}
+
 Array ConcreteScheduleNode::CacheInplace(const BlockRV& block_rv, int write_buffer_index,
                                          const String& storage_scope) {
   Array results;
diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h
index 41168fb016f3e..82ac9f913374a 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -116,6 +116,10 @@ class ConcreteScheduleNode : public ScheduleNode {
                     const Array consumer_blocks = {}) override;
   BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope,
                      const Array consumer_blocks = {}) override;
+  BlockRV ReindexCacheRead(const BlockRV& block_rv, int read_buffer_index,
+                           const String& storage_scope, const IndexMap& index_map) override;
+  BlockRV ReindexCacheWrite(const BlockRV& block_rv, int write_buffer_index,
+                            const String& storage_scope, const IndexMap& index_map) override;
   Array CacheInplace(const BlockRV& block_rv, int read_buffer_index,
                      const String& storage_scope) override;
   Array CacheIndex(const BlockRV& block_rv, const String& storage_scope,
diff --git a/src/tir/schedule/ir_comparator.cc b/src/tir/schedule/ir_comparator.cc
index 9d89c641630bd..5353a051a60ac 100644
--- a/src/tir/schedule/ir_comparator.cc
+++ b/src/tir/schedule/ir_comparator.cc
@@ -43,7 +43,7 @@ class TensorIntrinMismatchError : public ScheduleError {
     std::ostringstream os;
     os << "The stmt {0} doesn't match the tensor intrin\nThe pattern attempting to be matched:\n"
        << lhs_stmt_ << "\nDoes not match the tensorize description:\n"
-       << rhs_stmt_;
+       << rhs_stmt_ << '\n';
     for (const auto& msg : error_messages_) {
       os << msg << std::endl;
     }
@@ -173,6 +173,9 @@ bool TensorizeComparator::VisitStmt_(const BlockRealizeNode* op, const Stmt& oth

 bool TensorizeComparator::VisitStmt_(const BlockNode* op, const Stmt& other) {
   const auto* rhs = other.as();
+  for (const IterVar& iter : op->iter_vars) {
+    lhs_analyzer_.Bind(iter->var, iter->dom);
+  }
   // Check block equality.
   // All iter vars and buffer regions including the order should match.
   // When checking iter vars, DefEqual is used to remap variables.
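The two schedule entry points added above, ConcreteScheduleNode::ReindexCacheRead and ConcreteScheduleNode::ReindexCacheWrite, are exposed to Python further down in this patch as reindex_cache_read and reindex_cache_write. Below is a minimal usage sketch, not part of the patch itself: the prim_func, the block name "B", and the transposed mapping are illustrative only, it builds the IndexMap explicitly with tvm.tir.IndexMap.from_func, and the exact TVMScript buffer-annotation spelling may differ between TVM versions.

import tvm
from tvm.script import tir as T


@T.prim_func
def elemwise(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32")):
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0


sch = tvm.tir.Schedule(elemwise)
block = sch.get_block("B")
# Stage the reads of A (read buffer index 0 of block "B") into shared memory.
# The index map is applied to the block iteration vars (vi, vj), so the cache
# buffer is laid out transposed relative to A. Each access must be a single
# point per block instance, which is what the CheckSinglePoint helper below
# enforces.
index_map = tvm.tir.IndexMap.from_func(lambda vi, vj: [vj, vi])
sch.reindex_cache_read(block, 0, "shared", index_map)
print(sch.mod.script())

If the primitive succeeds, the printed module should contain an A_shared cache block whose accesses use the remapped (vj, vi) indices; reindex_cache_write is used the same way with a write buffer index.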
@@ -465,7 +468,7 @@ bool TensorizeComparator::CompareBufferRegion(const BufferRegion& lhs, const Buf
       }
       return false;
     }
-    if (!analyzer_.CanProveEqual(indices_base[i], lhs->region[i]->min)) {
+    if (!lhs_analyzer_.CanProveEqual(indices_base[i], lhs->region[i]->min)) {
       if (assert_mode_) {
         std::ostringstream os;
         os << "Buffer base index consistency check failed due to unequal index base: "
@@ -487,7 +490,8 @@ bool TensorizeComparator::CompareBufferRegion(const BufferRegion& lhs, const Buf
       }
       return false;
     }
-    PrimExpr normalized_lhs_min = (lhs->region[i + offset]->min - indices_base[i + offset]);
+    PrimExpr normalized_lhs_min =
+        lhs_analyzer_.Simplify((lhs->region[i + offset]->min - indices_base[i + offset]));
     if (!analyzer_.CanProveEqual(normalized_lhs_min, rhs->region[i]->min)) {
       if (assert_mode_) {
         std::ostringstream os;
diff --git a/src/tir/schedule/ir_comparator.h b/src/tir/schedule/ir_comparator.h
index 394d828673931..debf0f946e283 100644
--- a/src/tir/schedule/ir_comparator.h
+++ b/src/tir/schedule/ir_comparator.h
@@ -102,8 +102,13 @@ class TensorizeComparator : public ExprComparator, public StmtComparator {
   bool assert_mode_;
   /*! \brief Whether it is visiting the scope block (the outermost block). */
   bool is_scope_block = true;
-  /*! \brief The arithmetic analyzer. */
+  /*! \brief The arithmetic analyzer for comparing LHS and RHS. */
   arith::Analyzer analyzer_;
+  /*!
+   * \brief The arithmetic analyzer for simplifying expressions on the LHS.
+   * This analyzer only contains the domains of the iterators on the LHS.
+   */
+  arith::Analyzer lhs_analyzer_;
   /*! \brief Additional error messages. Only used when assert_mode is true. */
   std::vector error_messages_;
   // variable remap if any
diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index 0b7a4f6280dbd..563864229a262 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -269,6 +269,39 @@ TVM_DLL StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int r
 TVM_DLL StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index,
                             const String& storage_scope,
                             const Array consumer_blocks = {});
+/*!
+ * \brief Create a block that reads a buffer region into a read cache. It requires:
+ * 1) There is at most one block that writes the buffer in the scope.
+ * 2) The scope block has the stage-pipeline property.
+ * Compared to cache read, the indices to access the allocated cache buffer are customized by the user.
+ * \param self The state of the schedule
+ * \param block_sref The consumer block of the target buffer.
+ * \param read_buffer_index The index of the buffer in the block's read region.
+ * \param storage_scope The target storage scope.
+ * \param index_map User-defined indices to access the allocated cache buffer, mapped from the
+ * block iter vars.
+ * \return The cache stage block.
+ */
+TVM_DLL StmtSRef ReindexCacheRead(ScheduleState self, const StmtSRef& block_sref,
+                                  int read_buffer_index, const String& storage_scope,
+                                  const IndexMap& index_map);
+/*!
+ * \brief Create a block that writes a buffer region into a write cache. It requires:
+ * 1) There is only one block that writes the target buffer.
+ * 2) The scope block has the stage-pipeline property.
+ * Compared to cache write, the indices to access the allocated cache buffer are customized by the user.
+ * \param self The state of the schedule + * \param block_sref The producer of the buffer + * \param write_buffer_index The index of the buffer in block's write region + * \param storage_scope The target storage scope + * \param index_map User defined indices to access allocated cache buffer, maps from block iter + * vars. + * \return The cache stage block. + */ +TVM_DLL StmtSRef ReindexCacheWrite(ScheduleState self, const StmtSRef& block_sref, + int write_buffer_index, const String& storage_scope, + const IndexMap& index_map); + /*! *! * \brief Create 2 blocks that read&write a buffer region into a read/write cache. diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc index a2b45d407ddf6..39e915ba961ae 100644 --- a/src/tir/schedule/primitive/cache_read_write.cc +++ b/src/tir/schedule/primitive/cache_read_write.cc @@ -19,6 +19,7 @@ #include +#include "../../analysis/var_use_def_analysis.h" #include "../utils.h" namespace tvm { @@ -94,6 +95,132 @@ Optional GetBufferRegionFromBuffer(const Array& buff return res; } +struct ReindexCacheStageInfo : CacheStageInfo { + /* Indices used to access the allocated cache buffer. */ + Array indices; + /* Touched loop variable related information. */ + Array loop_vars; + Array loop_ranges; + /* Touched block variable related information. */ + Array block_iter_vars; + Array block_iter_values; +}; + +/* \brief The schedule error that accessed buffer region is not a single point for + * reindex_cache_read/write. */ +class NotSinglePointAccess : public ScheduleError { + public: + explicit NotSinglePointAccess(IRModule mod, Block block, BufferRegion cache_region, + bool is_cache_read) + : mod_(std::move(mod)), block_(std::move(block)), cache_region_(cache_region) { + primitive_name_ = is_cache_read ? "reindex_cache_read" : "reindex_cache_write"; + } + + String FastErrorString() const final { + return "ScheduleError: The buffer region accessed inside the block is not a single point."; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "The buffer region " << cache_region_ + << " accessed inside block {0} is not a single point, which violates" + << " the prerequisite of " << primitive_name_ << " primitive."; + return String(os.str()); + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {block_}; } + + private: + IRModule mod_; + Block block_; + BufferRegion cache_region_; + String primitive_name_; +}; + +/*! + * \brief Create a loop nest that represents reindex cache copy (reindex_cache_read / + * reindex_cache_write) from read buffer to write buffer. + * \param cache_region The cached copy region. + * \param info The cache stage information, which will be updated in the function. + * \param storage_scope The storage scope of the cached buffer (only used in naming here) + * \returns A block indicating the body of the loop nesting. 
+ */ +template +Block MakeReindexCacheStage(const BufferRegion& cache_region, ReindexCacheStageInfo* info, + const String& storage_scope) { + // loop variables + std::vector loop_vars; + // block variables + Array block_vars; + // bindings in block realize + std::vector iter_values; + // Create loop vars and block vars' binding_value + Map var_map; + for (size_t i = 0; i < info->loop_vars.size(); ++i) { + Var original_var = info->loop_vars[i]; + Var loop_var(original_var->name_hint, original_var.dtype()); + var_map.Set(original_var, loop_var); + loop_vars.push_back(loop_var); + } + for (size_t i = 0; i < info->block_iter_vars.size(); ++i) { + IterVar original_block_var = info->block_iter_vars[i]; + PrimExpr original_iter_value = info->block_iter_values[i]; + IterVar block_var = IterVar( + /*dom=*/original_block_var->dom, + /*var=*/Var(original_block_var->var->name_hint, original_block_var->var.dtype()), + /*IterVarType=*/kDataPar); + var_map.Set(original_block_var->var, block_var->var); + block_vars.push_back(block_var); + iter_values.push_back(Substitute(original_iter_value, var_map)); + } + + // block access region for read/write buffers + Region read_access_region, write_access_region; + Array read_access_indices, write_access_indices; + // Compute read/write region and read/write access indices. + Array& old_indices = (is_cache_read) ? read_access_indices : write_access_indices; + Region& old_region = (is_cache_read) ? read_access_region : write_access_region; + for (const Range& range : cache_region->region) { + old_indices.push_back(Substitute(range->min, var_map)); + old_region.push_back(Range::FromMinExtent(old_indices.back(), Integer(1))); + } + Array& new_indices = (is_cache_read) ? write_access_indices : read_access_indices; + Region& new_region = (is_cache_read) ? write_access_region : read_access_region; + for (const PrimExpr& idx : info->indices) { + new_indices.push_back(Substitute((idx), var_map)); + new_region.push_back(Range::FromMinExtent(new_indices.back(), Integer(1))); + } + + // Create New Block + Block block( + /*iter_vars*/ std::move(block_vars), + /*reads=*/{BufferRegion(info->read_buffer, read_access_region)}, + /*writes=*/{BufferRegion(info->write_buffer, write_access_region)}, + /*name_hint*/ cache_region->buffer->name + "_" + storage_scope, + /*body=*/ + BufferStore(info->write_buffer, BufferLoad(info->read_buffer, read_access_indices), + write_access_indices), + /*init=*/NullOpt, + /*alloc_buffers=*/{}, + /*match_buffers=*/{}, + /*buf_doms=*/{}); + // Create Block Realize node + Stmt body = BlockRealize(/*values=*/iter_values, + /*predicate=*/const_true(), + /*block=*/block); + // Create surrounding loops + for (size_t i = loop_vars.size(); i >= 1; --i) { + body = For(/*loop_var=*/loop_vars[i - 1], + /*min=*/info->loop_ranges[i - 1]->min, + /*extent=*/info->loop_ranges[i - 1]->extent, + /*kind=*/ForKind::kSerial, + /*body=*/body); + } + info->cache_stage = std::move(body); + return block; +} + /*! * \brief Create a loop nest that represents cache copy (cache_read / cache_write) from read buffer * to write buffer. @@ -378,9 +505,12 @@ class CacheLocDetector : public StmtVisitor { public: /*! * \brief Detect the insertion position of the cache stage, and write the position into the - * CacheStageInfo \param self The state of the schedule \param block_sref The sref of the unique - * writer block of the buffer being applied cache_read or cache_write \param scope_sref The sref - * of the scope block of the cached block \param info The cache stage info. 
+ * CacheStageInfo + * \param self The state of the schedule + * \param block_sref The sref of the unique writer block of the buffer being applied cache_read or + * cache_write + * \param scope_sref The sref of the scope block of the cached block + * \param info The cache stage info. */ template static void Detect(const ScheduleState& self, const StmtSRef& block_sref, @@ -433,8 +563,9 @@ class CacheLocDetector : public StmtVisitor { * \brief Constructor * \param self The state of the schedule * \param block_sref The sref of the unique writer block of the buffer being applied cache_read or - * cache_write \param scope_sref The sref of the scope block of the cached block \param - * related_blocks Producer blocks for cache_write, or consumer blocks for cache_read + * cache_write + * \param scope_sref The sref of the scope block of the cached block + * \param related_blocks Producer blocks for cache_write, or consumer blocks for cache_read */ CacheLocDetector(const ScheduleState self, const StmtSRef& block_sref, const StmtSRef& scope_sref, const std::vector& related_blocks) @@ -525,9 +656,11 @@ class CacheInplaceLocDetector : public StmtVisitor { public: /*! * \brief Detect the insertion position of the cache stage, and write the position into the - * CacheStageInfo \param self The state of the schedule \param block_sref The sref of the unique - * block of the buffer being applied cache_inplace \param scope_sref The sref - * of the scope block of the cached block \param info The cache stage info. + * CacheStageInfo + * \param self The state of the schedule + * \param block_sref The sref of the unique block of the buffer being applied cache_inplace + * \param scope_sref The sref of the scope block of the cached block + * \param info The cache stage info. */ static void Detect(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& scope_sref, CacheStageInfo* info) { @@ -606,6 +739,8 @@ class CacheInplaceLocDetector : public StmtVisitor { int loc_pos_{-1}; }; +class ReindexCacheReadRewriter; + /*! \brief Mutator for CacheRead. */ class CacheReadRewriter : public StmtExprMutator { public: @@ -622,7 +757,14 @@ class CacheReadRewriter : public StmtExprMutator { private: explicit CacheReadRewriter(const StmtSRef& scope_sref, CacheStageInfo* info) - : scope_sref_(scope_sref), info_(info) {} + : scope_sref_(scope_sref), info_(info) { + update_access_regions = [&](Array regions) { + return ReplaceBuffer(std::move(regions), info_->read_buffer, info_->write_buffer); + }; + update_match_buffers = [&](Array match_buffers) { + return ReplaceBuffer(std::move(match_buffers), info_->read_buffer, info_->write_buffer); + }; + } Stmt VisitStmt_(const ForNode* loop) final { Stmt stmt = StmtMutator::VisitStmt_(loop); @@ -636,7 +778,7 @@ class CacheReadRewriter : public StmtExprMutator { return stmt; } - Stmt VisitStmt_(const BlockNode* block) final { + Stmt VisitStmt_(const BlockNode* block) override { Block old_stmt = GetRef(block); // Check if this block is one of the specified consumers. // If no consumer blocks are specified, all blocks should be considered consumers. @@ -678,10 +820,8 @@ class CacheReadRewriter : public StmtExprMutator { // Otherwise, update read regions and match_buffers // Only make this change if the block is one of the specified consumers. 
if (is_consumer) { - Array reads = - ReplaceBuffer(block->reads, info_->read_buffer, info_->write_buffer); - Array match_buffers = - ReplaceBuffer(block->match_buffers, info_->read_buffer, info_->write_buffer); + Array reads = update_access_regions(block->reads); + Array match_buffers = update_match_buffers(block->match_buffers); if (!reads.same_as(block->reads) || !match_buffers.same_as(block->match_buffers)) { ObjectPtr n = make_object(*stmt.as()); n->reads = std::move(reads); @@ -694,7 +834,7 @@ class CacheReadRewriter : public StmtExprMutator { return std::move(stmt); } - PrimExpr VisitExpr_(const BufferLoadNode* load) final { + PrimExpr VisitExpr_(const BufferLoadNode* load) override { if (load->buffer.same_as(info_->read_buffer) && current_block_consumes) { ObjectPtr n = make_object(*load); n->buffer = info_->write_buffer; @@ -721,8 +861,82 @@ class CacheReadRewriter : public StmtExprMutator { CacheStageInfo* info_; /*! \brief Whether the most recently visited block is a specified consumer. */ bool current_block_consumes; + /*! \brief function to update read/write region of block being cache read.*/ + std::function(Array)> update_access_regions; + /*! \brief function to update match buffers of block being cache read.*/ + std::function(Array)> update_match_buffers; + + friend ReindexCacheReadRewriter; }; +/*! \brief Mutator for ReindexCacheRead. */ +class ReindexCacheReadRewriter : public CacheReadRewriter { + public: + /*! + * \brief Rewrite the AST and add a cache_read stage with the information provided. + * \param scope_sref The parent scope of this mutation. + * \param info The cache stage information. + * \return The new AST rooting at the original parent scope. + */ + static Stmt Rewrite(const StmtSRef& scope_sref, ReindexCacheStageInfo* info) { + ReindexCacheReadRewriter rewriter(scope_sref, info); + return rewriter(GetRef(scope_sref->stmt)); + } + + private: + explicit ReindexCacheReadRewriter(const StmtSRef& scope_sref, ReindexCacheStageInfo* info) + : CacheReadRewriter(scope_sref, info) { + new_indices_ = info->indices; + update_access_regions = [&](Array reads) { + Array new_reads; + for (const BufferRegion& buf_region : reads) { + if (buf_region->buffer.same_as(info_->read_buffer)) { + Array region; + for (const PrimExpr index : new_indices_) { + region.push_back(Range::FromMinExtent(index, Integer(1))); + } + new_reads.push_back(BufferRegion(info_->write_buffer, region)); + } else { + new_reads.push_back(buf_region); + } + } + return new_reads; + }; + update_match_buffers = [&](const Array match_buffers) { + Array new_match_buffers; + for (const MatchBufferRegion& match_buffer_region : match_buffers) { + BufferRegion source = match_buffer_region->source; + if (source->buffer.same_as(info_->read_buffer)) { + Array region; + for (const PrimExpr index : new_indices_) { + region.push_back(Range::FromMinExtent(index, Integer(1))); + } + new_match_buffers.push_back(MatchBufferRegion(match_buffer_region->buffer, + BufferRegion(info_->write_buffer, region))); + } else { + new_match_buffers.push_back(match_buffer_region); + } + } + return new_match_buffers; + }; + } + + PrimExpr VisitExpr_(const BufferLoadNode* load) final { + if (load->buffer.same_as(info_->read_buffer) && current_block_consumes) { + ObjectPtr n = make_object(*load); + n->buffer = info_->write_buffer; + n->indices = new_indices_; + return PrimExpr(n); + } + return ExprMutator::VisitExpr_(load); + } + + /*! \brief The indices to use for new buffer. 
*/ + Array new_indices_; +}; + +class ReindexCacheWriteRewriter; + /*! \brief Mutator for CacheWrite */ class CacheWriteRewriter : public StmtExprMutator { public: @@ -742,7 +956,14 @@ class CacheWriteRewriter : public StmtExprMutator { private: explicit CacheWriteRewriter(const StmtSRef& scope_sref, const StmtSRef& writer_block_sref, CacheStageInfo* info) - : scope_sref_(scope_sref), writer_block_sref_(writer_block_sref), info_(info) {} + : scope_sref_(scope_sref), writer_block_sref_(writer_block_sref), info_(info) { + update_access_regions = [&](Array regions) { + return ReplaceBuffer(regions, info_->write_buffer, info_->read_buffer); + }; + update_match_buffers = [&](Array match_buffers) { + return ReplaceBuffer(match_buffers, info_->write_buffer, info_->read_buffer); + }; + } Stmt VisitStmt_(const ForNode* loop) final { Stmt stmt = StmtMutator::VisitStmt_(loop); @@ -756,7 +977,7 @@ class CacheWriteRewriter : public StmtExprMutator { return stmt; } - Stmt VisitStmt_(const BlockNode* block) final { + Stmt VisitStmt_(const BlockNode* block) override { Block old_stmt = GetRef(block); // Check if this block is one of the specified cache consumers. @@ -765,12 +986,13 @@ class CacheWriteRewriter : public StmtExprMutator { const BlockNode* consumer_node = TVM_SREF_TO_BLOCK(consumer_sref); Block consumer_block = GetRef(consumer_node); if (old_stmt.same_as(consumer_block)) { - Array reads = - ReplaceBuffer(block->reads, info_->write_buffer, info_->read_buffer); - Array match_buffers = - ReplaceBuffer(block->match_buffers, info_->write_buffer, info_->read_buffer); - if (!reads.same_as(block->reads) || !match_buffers.same_as(block->match_buffers)) { + Array writes = update_access_regions(block->writes); + Array reads = update_access_regions(block->reads); + Array match_buffers = update_match_buffers(block->match_buffers); + if (!writes.same_as(block->writes) || !reads.same_as(block->reads) || + !match_buffers.same_as(block->match_buffers)) { auto n = CopyOnWrite(block); + n->writes = std::move(writes); n->reads = std::move(reads); n->match_buffers = std::move(match_buffers); n->body = VisitStmt(block->body); @@ -809,10 +1031,9 @@ class CacheWriteRewriter : public StmtExprMutator { } } else { // Since cache_write changes the block, we need to update the buffer it writes - auto writes = ReplaceBuffer(block->writes, info_->write_buffer, info_->read_buffer); - auto reads = ReplaceBuffer(block->reads, info_->write_buffer, info_->read_buffer); - auto match_buffers = - ReplaceBuffer(block->match_buffers, info_->write_buffer, info_->read_buffer); + auto writes = update_access_regions(block->writes); + auto reads = update_access_regions(block->reads); + auto match_buffers = update_match_buffers(block->match_buffers); if (!writes.same_as(block->writes) || !reads.same_as(block->reads) || !match_buffers.same_as(block->match_buffers)) { ObjectPtr n = make_object(*stmt.as()); @@ -826,7 +1047,7 @@ class CacheWriteRewriter : public StmtExprMutator { return std::move(stmt); } - Stmt VisitStmt_(const BufferStoreNode* store) final { + Stmt VisitStmt_(const BufferStoreNode* store) override { BufferStore stmt = Downcast(StmtMutator::VisitStmt_(store)); if (stmt->buffer.same_as(info_->write_buffer)) { auto n = CopyOnWrite(stmt.get()); @@ -837,7 +1058,7 @@ class CacheWriteRewriter : public StmtExprMutator { } } - PrimExpr VisitExpr_(const BufferLoadNode* load) final { + PrimExpr VisitExpr_(const BufferLoadNode* load) override { if (load->buffer.same_as(info_->write_buffer)) { ObjectPtr n = make_object(*load); n->buffer 
= info_->read_buffer; @@ -870,6 +1091,93 @@ class CacheWriteRewriter : public StmtExprMutator { CacheStageInfo* info_; /*! \brief Whether the current node is under the given block. */ bool under_writer_block_{false}; + /*! \brief function to update read/write region of block being cache write.*/ + std::function(Array)> update_access_regions; + /*! \brief function to update match buffers of block being cache write.*/ + std::function(Array)> update_match_buffers; + + friend ReindexCacheWriteRewriter; +}; + +/*! \brief Mutator for ReindexCacheWrite. */ +class ReindexCacheWriteRewriter : public CacheWriteRewriter { + public: + /*! + * \brief Rewrite the AST and add a cache_write stage with the information provided. + * \param scope_sref The parent scope of this mutation. + * \param writer_block_sref The only writer block in the scope. + * \param info The cache stage information. + * \return The new AST rooting at the original parent scope. + */ + static Stmt Rewrite(const StmtSRef& scope_sref, const StmtSRef& writer_block_sref, + ReindexCacheStageInfo* info) { + ReindexCacheWriteRewriter rewriter(scope_sref, writer_block_sref, info); + return rewriter(GetRef(scope_sref->stmt)); + } + + private: + explicit ReindexCacheWriteRewriter(const StmtSRef& scope_sref, const StmtSRef& writer_block_sref, + ReindexCacheStageInfo* info) + : CacheWriteRewriter(scope_sref, writer_block_sref, info) { + new_indices_ = info->indices; + update_access_regions = [&](Array reads) { + Array new_reads; + for (const BufferRegion& buf_region : reads) { + if (buf_region->buffer.same_as(info_->write_buffer)) { + Array region; + for (const PrimExpr index : new_indices_) { + region.push_back(Range::FromMinExtent(index, Integer(1))); + } + new_reads.push_back(BufferRegion(info_->read_buffer, region)); + } else { + new_reads.push_back(buf_region); + } + } + return new_reads; + }; + update_match_buffers = [&](const Array match_buffers) { + Array new_match_buffers; + for (const MatchBufferRegion& match_buffer_region : match_buffers) { + BufferRegion source = match_buffer_region->source; + if (source->buffer.same_as(info_->write_buffer)) { + Array region; + for (const PrimExpr index : new_indices_) { + region.push_back(Range::FromMinExtent(index, Integer(1))); + } + new_match_buffers.push_back(MatchBufferRegion(match_buffer_region->buffer, + BufferRegion(info_->read_buffer, region))); + } else { + new_match_buffers.push_back(match_buffer_region); + } + } + return new_match_buffers; + }; + } + + Stmt VisitStmt_(const BufferStoreNode* store) final { + BufferStore stmt = Downcast(StmtMutator::VisitStmt_(store)); + if (stmt->buffer.same_as(info_->write_buffer)) { + auto n = CopyOnWrite(stmt.get()); + n->buffer = info_->read_buffer; + n->indices = new_indices_; + return Stmt(n); + } else { + return std::move(stmt); + } + } + + PrimExpr VisitExpr_(const BufferLoadNode* load) final { + if (load->buffer.same_as(info_->write_buffer)) { + ObjectPtr n = make_object(*load); + n->buffer = info_->read_buffer; + n->indices = new_indices_; + return PrimExpr(n); + } + return ExprMutator::VisitExpr_(load); + } + + /*! \brief The indices to use for new buffer. */ + Array new_indices_; }; /*! @@ -898,7 +1206,9 @@ Buffer CreateReindexBuffer(const Buffer& buffer, const Array& block_ite return Buffer(new_buffer); } -/*! \brief The schedule error that the target is not a leaf block. */ +/*! + * \brief The schedule error that the target is not a leaf block. 
+ */ class NotLeafBlockError : public ScheduleError { public: NotLeafBlockError(IRModule mod, Block block) : mod_(std::move(mod)), block_(std::move(block)) {} @@ -1297,6 +1607,293 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_bu return result_block_sref; } +Array GetLoopsUnderScope(const StmtSRef& block_sref, const StmtSRef& top_sref) { + std::vector result; + for (StmtSRefNode* parent = block_sref->parent; parent && parent->stmt->IsInstance(); + parent = parent->parent) { + if (parent == top_sref.get()) break; + result.push_back(GetRef(parent)); + } + return {result.rbegin(), result.rend()}; +} + +/*! + * \brief The schedule error that block iter vars appears in old buffer and new + * allocated cache buffer does not match. + */ +class ReindexCacheReadWriteNotMatchError : public ScheduleError { + public: + ReindexCacheReadWriteNotMatchError(IRModule mod, Block block, Var var, + Array old_indices, Array new_indices, + bool is_cache_read, bool appears_in_old) + : mod_(std::move(mod)), block_(std::move(block)), var_(std::move(var)) { + primitive_name_ = is_cache_read ? "reindex_cache_read" : "reindex_cache_write"; + if (appears_in_old) { + appears_indices_ = std::move(old_indices); + other_indices_ = std::move(new_indices); + } else { + appears_indices_ = std::move(new_indices); + other_indices_ = std::move(old_indices); + } + } + String FastErrorString() const final { + return "ScheduleError: the block itervars appeared in lhs and rhs of reindex cache stage do " + "not match."; + } + + String DetailRenderTemplate() const final { + std::stringstream s; + s << "Error when applying " << primitive_name_ << " on block {0}, the block itervar " << var_ + << " appears in " << appears_indices_ << ", but not in " << other_indices_ << "."; + return String(s.str()); + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {block_}; } + IRModule mod_; + String primitive_name_; + Block block_; + Var var_; + Array appears_indices_; + Array other_indices_; +}; + +/*! + * \brief Update ReindexCacheStageInfo and create new cache buffer, used in + * both ReindexCacheRead and ReindexCacheWrite. + * \param info Pointer to ReindexCacheStageInfo + * \param mod The IRModule. + * \param block_sref The StmtSRef to the block we are working on. + * \param storage_scope The storage scope of cache buffer (e.g. "shared"/"local"). + * \param index_map The user defined indices. + * \param blok The block we are working on. + * \param realize The BlockRealize this block belongs to. + * \param old_buffer The buffer whose buffer access need to be rewriten. + * \param cache_region The old buffer access region. + */ +template +void CollectReindexCacheStageInfoAndCreateBuffer( + ReindexCacheStageInfo* info, const IRModule& mod, const StmtSRef& block_sref, + const String& storage_scope, const IndexMap& index_map, const Block& block, + const BlockRealize& realize, const Buffer& old_buffer, const BufferRegion& cache_region) { + Array block_iter_vars, block_shape; + for (const IterVar& iter_var : block->iter_vars) { + block_iter_vars.push_back(iter_var); + block_shape.push_back(iter_var->dom->extent); + } + Array new_indices = index_map->MapIndices(block_iter_vars); + Array new_shape = index_map->MapShape(block_shape); + info->indices = new_indices; + + // Step 5. 
Update CacheTouchedInfo + VarUseDefAnalyzer collector_old(/*defined_vars=*/{}); + Array old_indices; + for (const Range& range : cache_region->region) { + collector_old(range->min); + old_indices.push_back(range->min); + } + + arith::Analyzer analyzer; + + VarUseDefAnalyzer collector_new(/*defined_vars=*/{}); + for (const PrimExpr& idx : new_indices) { + collector_new(idx); + } + + VarUseDefAnalyzer collector_iter_values(/*defined_vars=*/{}); + for (size_t i = 0; i < block->iter_vars.size(); ++i) { + const IterVar& block_iter_var = block->iter_vars[i]; + const PrimExpr& block_iter_value = realize->iter_values[i]; + bool appears_in_new = collector_new.use_count_.count(block_iter_var->var.get()); + bool appears_in_old = collector_old.use_count_.count(block_iter_var->var.get()); + if (appears_in_new != appears_in_old) { + throw ReindexCacheReadWriteNotMatchError(mod, block, block_iter_var->var, old_indices, + new_indices, is_cache_read, appears_in_old); + } + if (appears_in_new) { + info->block_iter_vars.push_back(block_iter_var); + info->block_iter_values.push_back(block_iter_value); + collector_iter_values(block_iter_value); + } + } + + for (const StmtSRef& loop_sref : GetLoopsUnderScope(block_sref, info->loc_sref)) { + const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); + if (collector_iter_values.use_count_.count(loop->loop_var.get())) { + info->loop_vars.push_back(loop->loop_var); + info->loop_ranges.push_back(Range::FromMinExtent(loop->min, loop->extent)); + } + } + + // Create new buffer + ObjectPtr new_buffer = make_object(*old_buffer.get()); + ObjectPtr new_var = make_object(*old_buffer->data.get()); + const auto* ptr_type = TVM_TYPE_AS(old_buffer->data->type_annotation, PointerTypeNode); + new_var->type_annotation = PointerType(ptr_type->element_type, storage_scope); + new_buffer->data = Var(new_var->name_hint + "_" + storage_scope, new_var->type_annotation); + new_buffer->name = old_buffer->name + "_" + storage_scope; + new_buffer->shape = new_shape; + + if (is_cache_read) { + info->write_buffer = Buffer(new_buffer); + info->alloc = info->write_buffer; + } else { + info->read_buffer = Buffer(new_buffer); + info->alloc = info->read_buffer; + } +} + +/*! \brief Check whether given cache_region is a single point access. */ +template +void CheckSinglePoint(ScheduleState self, const Block& block, const BufferRegion& cache_region) { + bool single_point = true; + for (const Range& range : cache_region->region) { + const auto* ext_int = range->extent.as(); + if (!ext_int || ext_int->value != 1) { + single_point = false; + } + } + if (!single_point) { + throw NotSinglePointAccess(self->mod, block, cache_region, is_cache_read); + } +} + +StmtSRef ReindexCacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buffer_index, + const String& storage_scope, const IndexMap& index_map) { + /*! + * Check: + * - The index is in the array of block reading region + * - There is at most one block who write the buffer in the scope + * + * Mutate: + * - Allocate new cache buffer under the current scope. + * - Find the lowest ancestor of the block and ANY ONE of the consumers blocks. + * - Copy the buffer with the consumed region. + */ + + // Step 0. Check the input storage scope. + CheckStorageScope(self, storage_scope); + + // Step 1. 
Check index, getting the target buffer and the parent scope + Block block = GetRef(TVM_SREF_TO_BLOCK(block_sref)); + BlockRealize realize = GetBlockRealize(self, block_sref); + Buffer read_buffer = GetNthAccessBuffer(self, block, read_buffer_index, BufferIndexType::kRead); + StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true); + + // Step 2. Create CacheStageInfo + ReindexCacheStageInfo info; + info.read_buffer = read_buffer; + info.consumer_blocks.insert(block_sref); + + // Step 3. Update cache stage info. + Optional maybe_region = GetBufferRegionFromBuffer(block->reads, read_buffer); + ICHECK(maybe_region.defined()) << read_buffer + << " should appear in the block's read region: " << block->reads; + BufferRegion cache_region = maybe_region.value(); + if (Optional _write_block_sref = GetOnlyWriteBlock(self, scope_sref, read_buffer)) { + // Case 1. The buffer is written inside the block. + StmtSRef write_block_sref = _write_block_sref.value(); + // Find the producing region + StmtSRef parent_sref = GetRef(write_block_sref->parent); + // Detect insert position + CacheLocDetector::Detect(self, write_block_sref, scope_sref, &info); + } else { + // Case 2. The buffer is the input block for the scope. + info.loc_sref = scope_sref; + info.loc_pos = 0; + } + + // Step 4. Check whether cache region is a single point. + CheckSinglePoint(self, block, cache_region); + + // Step 5. Collect ReindexCacheStageInfo and create new buffer. + CollectReindexCacheStageInfoAndCreateBuffer( + &info, self->mod, block_sref, storage_scope, index_map, block, realize, read_buffer, + cache_region); + + // Step 6. Making new cache stage block and rewrite readers. + Block cache_read_stage = + MakeReindexCacheStage(/*cache_region=*/cache_region, + /*info=*/&info, + /*storage_scope=*/storage_scope); + Stmt new_scope = ReindexCacheReadRewriter::Rewrite(/*scope_sref=*/scope_sref, /*info=*/&info); + + // Step 7. Replacing and updating flags. + self->Replace(scope_sref, new_scope, info.block_reuse); + StmtSRef result_block_sref = self->stmt2ref.at(cache_read_stage.get()); + BlockInfo& block_info = self->block_info[result_block_sref]; + block_info.affine_binding = CalculateAffineFlag(self, result_block_sref); + block_info.region_cover = true; + block_info.scope->stage_pipeline = true; + return result_block_sref; +} + +StmtSRef ReindexCacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index, + const String& storage_scope, const IndexMap& index_map) { + /*! + * Check: + * - The index is in the array of block reading region + * - There is only one block who write the buffer in the scope + * + * Mutate: + * - Allocate new cache buffer under the current scope. + * - Find the lowest ancestor of the block and ANY ONE of the producer blocks. + * - Copy the buffer with the consumed region. + */ + + // Step 0. Check the input storage scope. + CheckStorageScope(self, storage_scope); + + // Step 1. Checking index, getting the target buffer and the parent scope + Block block = GetRef(TVM_SREF_TO_BLOCK(block_sref)); + BlockRealize realize = GetBlockRealize(self, block_sref); + Buffer write_buffer = + GetNthAccessBuffer(self, block, write_buffer_index, BufferIndexType::kWrite); + StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true); + + // Step 2. Creating CacheStageInfo + ReindexCacheStageInfo info; + info.write_buffer = write_buffer; + LOG(INFO) << block->name_hint; + info.consumer_blocks.insert(block_sref); + + // Step 3. 
Check the only writer block. + ICHECK_EQ(block_sref.get(), GetOnlyWriteBlock(self, scope_sref, write_buffer).get()); + + // Step 4. Find the producing region and insert position + Optional maybe_region = GetBufferRegionFromBuffer(block->writes, write_buffer); + ICHECK(maybe_region.defined()) << write_buffer << " should appear in the block's write region"; + StmtSRef parent_sref = GetRef(block_sref->parent); + // Detect insert position + CacheLocDetector::Detect(self, block_sref, scope_sref, &info); + BufferRegion cache_region = maybe_region.value(); + + CollectReindexCacheStageInfoAndCreateBuffer( + &info, self->mod, block_sref, storage_scope, index_map, block, realize, write_buffer, + cache_region); + + // Step 5. Check whether cache region is a single point. + CheckSinglePoint(self, block, cache_region); + + // Step 6. Making new cache stage block and rewrite readers. + Block cache_write_stage = + MakeReindexCacheStage(/*cache_region=*/cache_region, + /*info=*/&info, + /*storage_scope=*/storage_scope); + Stmt new_scope = ReindexCacheWriteRewriter::Rewrite( + /*scope_sref=*/scope_sref, + /*writer_block_sref=*/block_sref, /*info=*/&info); + + // Step 7. Replacing and updating flags. + self->Replace(scope_sref, new_scope, info.block_reuse); + StmtSRef result_block_sref = self->stmt2ref.at(cache_write_stage.get()); + BlockInfo& block_info = self->block_info[result_block_sref]; + block_info.affine_binding = CalculateAffineFlag(self, result_block_sref); + block_info.region_cover = true; + block_info.scope->stage_pipeline = true; + return result_block_sref; +} + /*! \brief The schedule error that the target block doesn't both read&write target buffer. */ class NotReadWriteError : public ScheduleError { public: @@ -1606,9 +2203,70 @@ struct ReIndexTraits : public UnpackedInstTraits { friend struct ::tvm::tir::UnpackedInstTraits; }; +struct ReindexCacheReadTraits : public UnpackedInstTraits { + static constexpr const char* kName = "ReindexCacheRead"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 2; + static constexpr size_t kNumAttrs = 2; + static constexpr size_t kNumDecisions = 0; + + static BlockRV UnpackedApplyToSchedule(Schedule sch, BlockRV block, IndexMap index_map, + Integer read_buffer_index, String storage_scope) { + return sch->ReindexCacheRead(block, read_buffer_index->value, storage_scope, index_map); + } + + static String UnpackedAsPython(Array outputs, String block, IndexMap index_map, + Integer read_buffer_index, String storage_scope) { + PythonAPICall py("reindex_cache_read"); + py.Input("block", block); + py.Input("read_buffer_index", read_buffer_index->value); + py.Input("storage_scope", storage_scope); + py.Input("index_map", index_map->ToPythonString()); + py.SingleOutput(outputs); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +struct ReindexCacheWriteTraits : public UnpackedInstTraits { + static constexpr const char* kName = "ReindexCacheWrite"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 2; + static constexpr size_t kNumAttrs = 2; + static constexpr size_t kNumDecisions = 0; + + static BlockRV UnpackedApplyToSchedule(Schedule sch, BlockRV block, IndexMap index_map, + Integer write_buffer_index, String storage_scope) { + return sch->ReindexCacheWrite(block, write_buffer_index->value, storage_scope, index_map); + } + + static String UnpackedAsPython(Array outputs, String block, IndexMap index_map, + Integer write_buffer_index, String 
storage_scope) { + PythonAPICall py("reindex_cache_write"); + py.Input("block", block); + py.Input("write_buffer_index", write_buffer_index->value); + py.Input("storage_scope", storage_scope); + py.Input("index_map", index_map->ToPythonString()); + py.SingleOutput(outputs); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + TVM_REGISTER_INST_KIND_TRAITS(CacheReadTraits); TVM_REGISTER_INST_KIND_TRAITS(CacheWriteTraits); TVM_REGISTER_INST_KIND_TRAITS(CacheInplaceTraits); TVM_REGISTER_INST_KIND_TRAITS(ReIndexTraits); +TVM_REGISTER_INST_KIND_TRAITS(ReindexCacheReadTraits); +TVM_REGISTER_INST_KIND_TRAITS(ReindexCacheWriteTraits); + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index 4177d916486b8..cb8b5a1d77879 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -179,6 +179,10 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheRead") .set_body_method(&ScheduleNode::CacheRead); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheWrite") .set_body_method(&ScheduleNode::CacheWrite); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReindexCacheRead") + .set_body_method(&ScheduleNode::ReindexCacheRead); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReindexCacheWrite") + .set_body_method(&ScheduleNode::ReindexCacheWrite); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheInplace") .set_body_method(&ScheduleNode::CacheInplace); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheIndex") diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index dba34c2ca3f37..a5cb66a0cb44c 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -309,6 +309,38 @@ BlockRV TracedScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buffer return result; } +BlockRV TracedScheduleNode::ReindexCacheRead(const BlockRV& block_rv, int read_buffer_index, + const String& storage_scope, + const IndexMap& index_map) { + BlockRV result = + ConcreteScheduleNode::ReindexCacheRead(block_rv, read_buffer_index, storage_scope, index_map); + + static const InstructionKind& kind = InstructionKind::Get("ReindexCacheRead"); + trace_->Append( + /*inst=*/Instruction( + /*kind=*/kind, + /*inputs=*/{block_rv, index_map}, + /*attrs=*/{Integer(read_buffer_index), storage_scope}, + /*outputs=*/{result})); + return result; +} + +BlockRV TracedScheduleNode::ReindexCacheWrite(const BlockRV& block_rv, int write_buffer_index, + const String& storage_scope, + const IndexMap& index_map) { + BlockRV result = ConcreteScheduleNode::ReindexCacheWrite(block_rv, write_buffer_index, + storage_scope, index_map); + + static const InstructionKind& kind = InstructionKind::Get("ReindexCacheWrite"); + trace_->Append( + /*inst=*/Instruction( + /*kind=*/kind, + /*inputs=*/{block_rv, index_map}, + /*attrs=*/{Integer(write_buffer_index), storage_scope}, + /*outputs=*/{result})); + return result; +} + Array TracedScheduleNode::CacheInplace(const BlockRV& block_rv, int read_buffer_index, const String& storage_scope) { Array result = diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index 7bd83855557df..1fcba98063800 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -76,6 +76,10 @@ class TracedScheduleNode : public ConcreteScheduleNode { const Array consumer_blocks = {}) final; BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope, const Array consumer_blocks = {}) final; + 
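For reference, the traits registered above print as `reindex_cache_read` / `reindex_cache_write`, so the primitives should be reachable from the Python `tir.Schedule` API with the argument order used by `UnpackedApplyToSchedule` (block, buffer index, storage scope, index map). A minimal usage sketch, assuming a Python binding that mirrors that signature; the PrimFunc, block name, and index-map lambdas below are illustrative only, not taken from this patch:

from tvm import tir
from tvm.script import tir as T


@T.prim_func
def before(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128), "float32")
    B = T.match_buffer(b, (128, 128), "float32")
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0


sch = tir.Schedule(before)
blk = sch.get_block("B")
# Stage read buffer 0 (A) in shared memory; the lambda gives the cached layout
# in terms of the block iteration variables.
sch.reindex_cache_read(blk, 0, "shared", lambda vi, vj: (vj, vi))
# Re-fetch the block handle and cache write buffer 0 (B) with a remapped layout.
blk = sch.get_block("B")
sch.reindex_cache_write(blk, 0, "local", lambda vi, vj: (vj, vi))
print(sch.mod.script())

The intent, as far as this patch shows, is that the staging buffer created by these primitives follows the layout given by the index map rather than the original buffer layout, which is what distinguishes them from plain cache_read/cache_write.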
BlockRV ReindexCacheRead(const BlockRV& block_rv, int read_buffer_index, + const String& storage_scope, const IndexMap& index_map) final; + BlockRV ReindexCacheWrite(const BlockRV& block_rv, int write_buffer_index, + const String& storage_scope, const IndexMap& index_map) final; Array CacheInplace(const BlockRV& block_rv, int read_buffer_index, const String& storage_scope) final; BlockRV ReIndex(const BlockRV& block_rv, int buffer_index, diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 2de7d38d7d57a..4f411228d262e 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -35,64 +35,43 @@ #include #include "../../runtime/thread_storage_scope.h" +#include "../analysis/var_use_def_analysis.h" #include "ir_utils.h" namespace tvm { namespace tir { -// use/def analysis, also delete unreferenced lets -class VarUseDefAnalysis : public StmtExprMutator { +/*! + * \brief Visitor class to collect device-side program information. + */ +class DeviceInfoCollector : public StmtVisitor { public: - Stmt VisitStmt_(const AttrStmtNode* op) final { + Array thread_axis_; + Array thread_extent_; + PrimExpr dyn_shmem_size_{0}; + bool use_dyn_shmem_{false}; + + private: + void VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == attr::thread_extent) { IterVar iv = Downcast(op->node); ICHECK_NE(iv->thread_tag.length(), 0U); // thread_extent can appear multiple times // use the first appearance as def. - if (!use_count_.count(iv->var.get())) { - this->HandleDef(iv->var.get()); + if (!defined_thread.count(iv.get())) { + defined_thread.insert(iv.get()); thread_axis_.push_back(iv); thread_extent_.push_back(op->value); } - PrimExpr value = op->value; - if (visit_thread_extent_) { - value = this->VisitExpr(value); - } - Stmt body = this->VisitStmt(op->body); - if (value.same_as(op->value) && body.same_as(op->body)) { - return GetRef(op); - } - return AttrStmt(op->node, op->attr_key, value, body); - } else { - return StmtExprMutator::VisitStmt_(op); - } - } - - Stmt VisitStmt_(const LetStmtNode* op) final { - this->HandleDef(op->var.get()); - Stmt body = this->VisitStmt(op->body); - // eliminate unreferenced let - if (use_count_.at(op->var.get()) == 0 && SideEffect(op->value) <= CallEffectKind::kReadState && - simplify_let_) { - return body; + this->VisitExpr(op->value); + this->VisitStmt(op->body); } else { - PrimExpr value = this->VisitExpr(op->value); - if (body.same_as(op->body) && value.same_as(op->value)) { - return GetRef(op); - } else { - return LetStmt(op->var, value, body); - } + StmtVisitor::VisitStmt_(op); } } - Stmt VisitStmt_(const ForNode* op) final { - this->HandleDef(op->loop_var.get()); - return StmtExprMutator::VisitStmt_(op); - } - - Stmt VisitStmt_(const AllocateNode* op) final { - this->HandleDef(op->buffer_var.get()); + void VisitStmt_(const AllocateNode* op) final { auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); if (storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn") { ICHECK_EQ(use_dyn_shmem_, false) << "Only one dynamic shared memory allocation is allowed."; @@ -104,44 +83,42 @@ class VarUseDefAnalysis : public StmtExprMutator { dyn_shmem_size_ = dyn_shmem_size_ * (op->dtype.bytes()); use_dyn_shmem_ = true; } - return StmtExprMutator::VisitStmt_(op); + StmtVisitor::VisitStmt_(op); } - Stmt VisitStmt_(const AllocateConstNode* op) final { - this->HandleDef(op->buffer_var.get()); - return StmtExprMutator::VisitStmt_(op); - 
} + // records which thread axes have been visited. + std::unordered_set defined_thread; +}; - Stmt VisitStmt_(const StoreNode* op) final { - LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; - } +/*! + * \brief Mutator class to remove unreferenced let stmts/expressions. + * \param use_count The pre-computed variable-to-use-count map. + */ +class UnreferencedLetRemover : public StmtExprMutator { + public: + explicit UnreferencedLetRemover(const std::unordered_map& use_count) + : use_count_(use_count) {} - Stmt VisitStmt_(const BufferStoreNode* op) final { - VisitBuffer(op->buffer); - return StmtExprMutator::VisitStmt_(op); + private: + Stmt VisitStmt_(const LetStmtNode* op) final { + Stmt body = this->VisitStmt(op->body); + // eliminate unreferenced let + if (use_count_.at(op->var.get()) == 0 && SideEffect(op->value) <= CallEffectKind::kReadState) { + return body; + } else { + PrimExpr value = this->VisitExpr(op->value); + if (body.same_as(op->body) && value.same_as(op->value)) { + return GetRef(op); + } else { + return LetStmt(op->var, value, body); + } + } } PrimExpr VisitExpr_(const LetNode* op) final { - // Weaker SSA condition - // A single var can be binded in multiple lets - // but they have to bind to the same value. - // This is used to allow cases when we reuse a single let - // expression to construct a nested expr. - // (let x = 1 in x + 1) * (let x = 1 in x + 1) - auto it = let_binding_.find(op->var); - PrimExpr value = this->VisitExpr(op->value); - if (it != let_binding_.end()) { - ICHECK(deep_equal_(it->second->value, value)) - << "Let cannot bind the same var to two different values"; - return GetRef(it->second); - } else { - this->HandleDef(op->var.get()); - let_binding_[op->var] = op; - } PrimExpr body = this->VisitExpr(op->body); - // eliminate unreferenced let - if (use_count_.at(op->var.get()) == 0 && SideEffect(op->value) <= CallEffectKind::kReadState && - simplify_let_) { + PrimExpr value = this->VisitExpr(op->value); + if (use_count_.at(op->var.get()) == 0 && SideEffect(op->value) <= CallEffectKind::kReadState) { + return body; + } else { + if (body.same_as(op->body) && value.same_as(op->value)) { @@ -152,96 +129,10 @@ class VarUseDefAnalysis : public StmtExprMutator { } } - PrimExpr VisitExpr_(const VarNode* op) final { - this->HandleUse(GetRef(op)); - return StmtExprMutator::VisitExpr_(op); - } - - PrimExpr VisitExpr_(const ReduceNode* op) final { - for (const auto& iv : op->axis) { - this->HandleDef(iv->var.get()); - } - return StmtExprMutator::VisitExpr_(op); - } - - PrimExpr VisitExpr_(const LoadNode* op) final { - LOG(FATAL) << "Unexpected use of deprecated LoadNode.
Please use BufferLoadNode instead."; - } - - PrimExpr VisitExpr_(const BufferLoadNode* op) final { - VisitBuffer(op->buffer); - return StmtExprMutator::VisitExpr_(op); - } - - void VisitBuffer(Buffer buffer) { - this->HandleUse(buffer->data); - auto visit_arr = [&](Array arr) { - for (const auto& element : arr) { - this->VisitExpr(element); - } - }; - - visit_arr(buffer->shape); - visit_arr(buffer->strides); - } - - void HandleDef(const VarNode* v) { - ICHECK(!def_count_.count(v)) << "variable " << v->name_hint - << " has already been defined, the Stmt is not SSA"; - ICHECK(!use_count_.count(v)) << "variable " << v->name_hint - << " has been used before definition!"; - use_count_[v] = 0; - def_count_[v] = 1; - } - - void HandleUse(const PrimExpr& v) { - ICHECK(v.as()); - Var var = Downcast(v); - auto it = use_count_.find(var.get()); - if (it != use_count_.end()) { - if (it->second >= 0) { - ++it->second; - } - } else { - undefined_.push_back(var); - use_count_[var.get()] = -1; - } - } - - // The fields are publically readible to - // be accessible to the users. - bool visit_thread_extent_{true}; - bool simplify_let_{true}; - Array undefined_; - Array thread_axis_; - Array thread_extent_; - PrimExpr dyn_shmem_size_{0}; - bool use_dyn_shmem_{false}; - std::unordered_map use_count_; - std::unordered_map def_count_; - - private: - ExprDeepEqual deep_equal_; - std::unordered_map let_binding_; + // pre-computed variable to use count map. + const std::unordered_map& use_count_; }; -Array UndefinedVars(const Stmt& stmt, const Array& args) { - VarUseDefAnalysis m; - m.simplify_let_ = false; - for (Var arg : args) { - m.use_count_[arg.get()] = 0; - } - m(stmt); - return m.undefined_; -} - -Array UndefinedVars(const PrimExpr& expr) { - VarUseDefAnalysis m; - m.simplify_let_ = false; - m(expr); - return m.undefined_; -} - class HostDeviceSplitter : public StmtMutator { public: explicit HostDeviceSplitter(IRModule* device_mod, Target device_target, std::string name_prefix) @@ -266,16 +157,19 @@ class HostDeviceSplitter : public StmtMutator { os << name_prefix_ << "_kernel" << device_func_counter_++; std::string kernel_symbol = os.str(); // isolate the device function. - VarUseDefAnalysis m; - m.visit_thread_extent_ = false; - body = m(std::move(body)); + VarUseDefAnalyzer use_def(/*defined_vars=*/{}, /*visit_thread_extent=*/false); + use_def(body); + DeviceInfoCollector dev_info; + dev_info(body); + UnreferencedLetRemover let_remover(use_def.use_count_); + body = let_remover(std::move(body)); Array params; Array arguments; Map remap_vars; // Strictly order the arguments: Var pointers, positional arguments. - for (Var var : m.undefined_) { + for (Var var : use_def.undefined_) { if (var.dtype().is_handle()) { // Create a new version of v. 
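To make the division of labor in this refactor concrete: the shared analyzer computes use counts up front, and UnreferencedLetRemover then drops a let only when its variable's precomputed count is zero and its value is at most state-reading. A toy, self-contained Python model of that rule (not TVM code; the string-based expression encoding is invented purely for illustration):

from dataclasses import dataclass
from typing import Union


@dataclass
class Let:
    var: str
    value: str                      # textual stand-in for a PrimExpr
    body: Union["Let", str]
    value_has_side_effect: bool = False


def count_uses(node, counts):
    # Phase 1 (the analyzer): count references to every bound variable.
    if isinstance(node, Let):
        counts.setdefault(node.var, 0)
        for tok in node.value.split():
            if tok in counts:
                counts[tok] += 1
        count_uses(node.body, counts)
    else:
        for tok in node.split():
            if tok in counts:
                counts[tok] += 1


def remove_unreferenced(node, counts):
    # Phase 2 (the remover): rebuild, dropping unused, side-effect-free lets.
    if not isinstance(node, Let):
        return node
    body = remove_unreferenced(node.body, counts)
    if counts[node.var] == 0 and not node.value_has_side_effect:
        return body
    return Let(node.var, node.value, body, node.value_has_side_effect)


expr = Let("dead", "a + b", Let("y", "c * 2", "y + 1"))
counts = {}
count_uses(expr, counts)                    # {"dead": 0, "y": 1}
print(remove_unreferenced(expr, counts))    # the "dead" binding is gone, "y" survives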
auto it = handle_data_type_.find(var.get()); @@ -295,7 +189,7 @@ class HostDeviceSplitter : public StmtMutator { } } // positional arguments - for (Var var : m.undefined_) { + for (Var var : use_def.undefined_) { if (!var.dtype().is_handle()) { params.push_back(var); arguments.push_back(var); @@ -305,7 +199,8 @@ class HostDeviceSplitter : public StmtMutator { GlobalVar kernel_symbol_global = global_var_supply->FreshGlobal(kernel_symbol, false); PrimFunc device_func(params, Substitute(body, remap_vars)); - device_func = WithAttr(std::move(device_func), tir::attr::kDeviceThreadAxis, m.thread_axis_); + device_func = + WithAttr(std::move(device_func), tir::attr::kDeviceThreadAxis, dev_info.thread_axis_); device_func = WithAttr(std::move(device_func), tvm::attr::kCallingConv, Integer(CallingConv::kDeviceKernelLaunch)); device_func = WithAttr(std::move(device_func), tvm::attr::kGlobalSymbol, @@ -313,7 +208,7 @@ class HostDeviceSplitter : public StmtMutator { device_func = WithAttr(std::move(device_func), tir::attr::kNoAlias, Integer(1)); device_func = WithAttr(std::move(device_func), tvm::attr::kTarget, device_target_); device_func = WithAttr(std::move(device_func), tir::attr::kIsGlobalFunc, Integer(1)); - if (m.use_dyn_shmem_) { + if (dev_info.use_dyn_shmem_) { device_func = WithAttr(std::move(device_func), tir::attr::kDeviceUseDynSharedMemory, Integer(1)); } @@ -325,11 +220,11 @@ class HostDeviceSplitter : public StmtMutator { for (PrimExpr arg : arguments) { call_args.push_back(arg); } - for (PrimExpr ext : m.thread_extent_) { + for (PrimExpr ext : dev_info.thread_extent_) { call_args.push_back(ext); } - if (m.use_dyn_shmem_) { - call_args.push_back(m.dyn_shmem_size_); + if (dev_info.use_dyn_shmem_) { + call_args.push_back(dev_info.dyn_shmem_size_); } return Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), call_args)); } diff --git a/src/tir/transforms/unroll_loop.cc b/src/tir/transforms/unroll_loop.cc index 1e55cb22ee263..dc14e4512f1e7 100644 --- a/src/tir/transforms/unroll_loop.cc +++ b/src/tir/transforms/unroll_loop.cc @@ -33,6 +33,7 @@ #include #include +#include "../../runtime/thread_storage_scope.h" #include "ir_utils.h" namespace tvm { @@ -43,6 +44,7 @@ struct UnrollLoopConfigNode : public tvm::AttrsNode { int auto_max_depth; int auto_max_extent; int explicit_unroll; + int unroll_local_access; TVM_DECLARE_ATTRS(UnrollLoopConfigNode, "tir.transform.UnrollLoopConfig") { TVM_ATTR_FIELD(auto_max_step) @@ -57,6 +59,9 @@ struct UnrollLoopConfigNode : public tvm::AttrsNode { TVM_ATTR_FIELD(explicit_unroll) .describe("Whether to explicitly unroll the loop instead of setting a pragma") .set_default(true); + TVM_ATTR_FIELD(unroll_local_access) + .describe("Whether to always unroll local access") + .set_default(false); } }; @@ -68,14 +73,30 @@ class UnrollLoopConfig : public Attrs { TVM_REGISTER_NODE_TYPE(UnrollLoopConfigNode); TVM_REGISTER_PASS_CONFIG_OPTION("tir.UnrollLoop", UnrollLoopConfig); +class VarLocalAccessMarker : public ExprVisitor { + public: + explicit VarLocalAccessMarker( + std::unordered_set* var_touched_local) + : var_touched_local_(var_touched_local) {} + + void VisitExpr_(const VarNode* op) final { var_touched_local_->insert(GetRef(op)); } + + private: + std::unordered_set* var_touched_local_; +}; + +// The Visitor is used to check whether var is used as write index in a local memory +// If a loop var is used as indices to a local memory, it must be unrolled so +// the local memory access can be turned into register access. 
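For context on the unroll_local_access option introduced in this file: it is a new field of the registered "tir.UnrollLoop" pass config, so it should be toggled through PassContext like the existing knobs. A small sketch, assuming the usual config-dict mechanism; the trivial compute below is only there to make the snippet self-contained (it has no local-memory access, so the flag is a no-op for this particular workload):

import tvm
from tvm import te

# Build a trivial workload so the example runs end to end.
n = 16
A = te.placeholder((n,), name="A")
B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
s = te.create_schedule(B.op)

with tvm.transform.PassContext(
    config={
        "tir.UnrollLoop": {
            "auto_max_step": 16,          # existing knob, still honored
            "unroll_local_access": True,  # new knob added by this patch
        }
    }
):
    lib = tvm.build(s, [A, B], target="llvm")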
class LoopUnroller : public StmtExprMutator { public: explicit LoopUnroller(int auto_max_step, int auto_max_depth, int auto_max_extent, - bool explicit_unroll) + bool explicit_unroll, bool unroll_local_access) : auto_max_step_(auto_max_step), auto_max_depth_(auto_max_depth), auto_max_extent_(auto_max_extent), - explicit_unroll_(explicit_unroll) {} + explicit_unroll_(explicit_unroll), + unroll_local_access_(unroll_local_access) {} Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == "pragma_auto_unroll_max_step") { @@ -96,6 +117,7 @@ class LoopUnroller : public StmtExprMutator { } Stmt VisitStmt_(const ForNode* op) { + // Post order so we can collect more information Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); int value = GetExtent(op); @@ -111,6 +133,12 @@ class LoopUnroller : public StmtExprMutator { auto_unroll = true; } + // If a loop var is used as indices to a local memory, it must be unrolled so + // the local memory access can be turned into register access. + if (this->var_touched_local_.count(op->loop_var) && value > 0 && unroll_local_access_) { + auto_unroll = true; + } + if (auto_unroll) { step_count_ *= value; unroll_depth_ += 1; @@ -137,8 +165,32 @@ class LoopUnroller : public StmtExprMutator { LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; } + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + if (unroll_local_access_) { + auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer->data)); + if (storage_scope.rank == runtime::StorageRank::kLocal || + storage_scope.rank == runtime::StorageRank::kWarp) { + VarLocalAccessMarker marker(&var_touched_local_); + for (PrimExpr e : op->indices) { + marker(e); + } + } + } + return GetRef(op); + } + Stmt VisitStmt_(const BufferStoreNode* op) final { ++step_count_; + if (unroll_local_access_) { + auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer->data)); + if (storage_scope.rank == runtime::StorageRank::kLocal || + storage_scope.rank == runtime::StorageRank::kWarp) { + VarLocalAccessMarker marker(&var_touched_local_); + for (PrimExpr e : op->indices) { + marker(e); + } + } + } return StmtExprMutator::VisitStmt_(op); } @@ -161,7 +213,7 @@ class LoopUnroller : public StmtExprMutator { unroll_depth_ = std::max(unroll_depth_, unroll_depth); return ret; }; - return StmtMutator::VisitSeqStmt_(op, false, fmutate); + return StmtExprMutator::VisitSeqStmt_(op, false, fmutate); } Stmt Unroll(const ForNode* op) { @@ -202,19 +254,23 @@ class LoopUnroller : public StmtExprMutator { // this not not count the total steps, only count the number of loops int auto_max_extent_; bool explicit_unroll_; + // Whether to unroll loops that access local memory. + bool unroll_local_access_{false}; // Number of normal loops in scope int normal_loop_depth_{0}; // number of unrolled cases in current scope.
int unroll_depth_{0}; // Number of total steps unrolled int step_count_{0}; + // set of indices touched during visit local memory + std::unordered_set var_touched_local_; // analyzer arith::Analyzer analyzer_; }; Stmt UnrollLoop(Stmt stmt, UnrollLoopConfig cfg) { Stmt ret = LoopUnroller(cfg->auto_max_step, cfg->auto_max_depth, cfg->auto_max_extent, - cfg->explicit_unroll)(stmt); + cfg->explicit_unroll, cfg->unroll_local_access)(stmt); if (!ret.same_as(stmt)) { return ConvertSSA(ret); } else { diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py index 3a3f297729fda..61b1828aad992 100644 --- a/tests/python/driver/tvmc/test_compiler.py +++ b/tests/python/driver/tvmc/test_compiler.py @@ -40,11 +40,12 @@ def test_save_dumps(tmpdir_factory): tmpdir = tmpdir_factory.mktemp("data") - dump_formats = {"relay": "fake relay", "ll": "fake llvm", "asm": "fake asm"} + dump_formats = {"relay": "fake relay", "tir": "fake tir", "ll": "fake llvm", "asm": "fake asm"} tvmc.compiler.save_dumps("fake_module", dump_formats, dump_root=tmpdir) assert path.exists("{}/{}".format(tmpdir, "fake_module.ll")) assert path.exists("{}/{}".format(tmpdir, "fake_module.asm")) + assert path.exists("{}/{}".format(tmpdir, "fake_module.tir")) assert path.exists("{}/{}".format(tmpdir, "fake_module.relay")) @@ -69,7 +70,11 @@ def verify_compile_tflite_module(model, shape_dict=None, use_vm=False): pytest.importorskip("tflite") tvmc_model = tvmc.load(model, shape_dict=shape_dict) tvmc_package = tvmc.compile( - tvmc_model, target="llvm", dump_code="ll", desired_layout="NCHW", use_vm=use_vm + tvmc_model, + target="llvm", + dump_code="ll", + desired_layout="NCHW", + use_vm=use_vm, ) dumps_path = tvmc_package.package_path + ".ll" verify_tvmc_package(tvmc_package, dumps_path, use_vm=use_vm) @@ -87,6 +92,28 @@ def test_compile_tflite_module(use_vm, tflite_mobilenet_v1_1_quant): verify_compile_tflite_module(tflite_mobilenet_v1_1_quant, shape_dict, use_vm=use_vm) +def test_single_tir_dump(tflite_mobilenet_v1_1_quant): + pytest.importorskip("tflite") + tvmc_model = tvmc.load(tflite_mobilenet_v1_1_quant) + tvmc_package = tvmc.compile(tvmc_model, target="llvm", dump_code="tir") + dumps_path = tvmc_package.package_path + ".tir" + assert os.path.exists(dumps_path) + with open(dumps_path) as f: + assert "tir" in f.read() + + +def test_code_dumps(tflite_mobilenet_v1_1_quant): + pytest.importorskip("tflite") + tvmc_model = tvmc.load(tflite_mobilenet_v1_1_quant) + dump_code = ["asm", "ll", "tir", "relay"] + tvmc_package = tvmc.compile(tvmc_model, target="llvm", dump_code=dump_code) + for ext in dump_code: + dumps_path = tvmc_package.package_path + "." + ext + assert os.path.exists(dumps_path) + with open(dumps_path) as f: + assert len(f.read()) > 0 + + # This test will be skipped if the AArch64 cross-compilation toolchain is not installed. 
@pytest.mark.skipif( not shutil.which("aarch64-linux-gnu-gcc"), reason="cross-compilation toolchain not installed" diff --git a/tests/python/driver/tvmc/test_transform.py b/tests/python/driver/tvmc/test_transform.py index 98bd3b5f98a3c..72c7cda6ff1a7 100644 --- a/tests/python/driver/tvmc/test_transform.py +++ b/tests/python/driver/tvmc/test_transform.py @@ -19,8 +19,10 @@ import tvm from tvm import relay +from tvm.relay import testing +from tvm.relay.expr_functor import ExprMutator from tvm.ir.instrument import pass_instrument -from tvm.driver.tvmc.transform import convert_graph_layout +from tvm.driver.tvmc.transform import apply_graph_transforms def test_layout_transform_fold_constant(relay_conv2d): @@ -39,7 +41,7 @@ def run_after_pass(self, _, info): pass_names = CollectPassNames() with tvm.transform.PassContext(opt_level=3, instruments=[pass_names]): - convert_graph_layout(relay_conv2d, desired_layout) + apply_graph_transforms(relay_conv2d, {"desired_layout": desired_layout}) names = pass_names.names assert "ConvertLayout" in names @@ -59,7 +61,7 @@ def test_layout_transform_convert_layout_pass_args(relay_conv2d, monkeypatch): monkeypatch.setattr(relay.transform, "ConvertLayout", mock_convert_layout) with tvm.transform.PassContext(opt_level=3): - convert_graph_layout(relay_conv2d, desired_layout) + apply_graph_transforms(relay_conv2d, {"desired_layout": desired_layout}) mock_convert_layout.assert_called_once_with( { @@ -70,5 +72,98 @@ ) + +def test_layout_transform_to_mixed_precision_pass_args_mock(relay_conv2d, monkeypatch): + """ + Check that the expected mixed precision arguments are used when + mixed precision options are provided. + """ + mock_mixed_precision = MagicMock() + mock_mixed_precision.return_value = tvm.driver.tvmc.transform.MixedPrecision([], "") + monkeypatch.setattr(tvm.driver.tvmc.transform, "MixedPrecision", mock_mixed_precision) + + with tvm.transform.PassContext(opt_level=3): + apply_graph_transforms( + relay_conv2d, + { + "mixed_precision": True, + "mixed_precision_ops": ["nn.conv2d"], + "mixed_precision_calculation_type": "float16", + "mixed_precision_acc_type": "float16", + }, + ) + mock_mixed_precision.assert_called_with(["nn.conv2d"], "float16") + + apply_graph_transforms( + relay_conv2d, + { + "mixed_precision": True, + "mixed_precision_ops": ["nn.conv2d", "nn.dense"], + "mixed_precision_calculation_type": "float16", + "mixed_precision_acc_type": "float32", + }, + ) + mock_mixed_precision.assert_called_with(["nn.conv2d", "nn.dense"], "float32") + + +def test_layout_transform_to_mixed_precision_pass_args_graph(): + """ + Check the application of the mixed precision arguments within a graph.
+ """ + + mod, params = testing.mobilenet.get_workload(batch_size=1, dtype="float32") + + class CheckOpMutator(ExprMutator): + """Inspect Ops According to expected types.""" + + def __init__(self, calculation_type, acc_type, op): + self.calculation_type = calculation_type + self.acc_type = acc_type + self.op = op + self.is_expected = True + super().__init__() + + def visit_call(self, call): + visit = super().visit(call.args[0]) + if call.op == relay.op.get(self.op): + if self.is_expected: + self.is_expected = ( + call.checked_type.dtype == self.acc_type + or call.args[0].checked_type.dtype == self.calculation_type + ) + return call + + def check(self, func): + self.visit(func) + return self.is_expected + + mod = apply_graph_transforms( + mod, + { + "mixed_precision": True, + "mixed_precision_ops": ["nn.conv2d", "nn.dense"], + "mixed_precision_calculation_type": "float16", + "mixed_precision_acc_type": "float16", + }, + ) + ret = CheckOpMutator("float16", "float16", "nn.conv2d").check(mod["main"]) + assert ret + ret = CheckOpMutator("float16", "float16", "nn.dense").check(mod["main"]) + assert ret + + mod = apply_graph_transforms( + mod, + { + "mixed_precision": True, + "mixed_precision_ops": ["nn.conv2d", "nn.dense"], + "mixed_precision_calculation_type": "float16", + "mixed_precision_acc_type": "float32", + }, + ) + ret = CheckOpMutator("float16", "float32", "nn.conv2d").check(mod["main"]) + assert ret + ret = CheckOpMutator("float16", "float32", "nn.dense").check(mod["main"]) + assert ret + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/frontend/paddlepaddle/test_forward.py b/tests/python/frontend/paddlepaddle/test_forward.py index 70fbf6aee5548..d21323d7ba70b 100755 --- a/tests/python/frontend/paddlepaddle/test_forward.py +++ b/tests/python/frontend/paddlepaddle/test_forward.py @@ -1062,8 +1062,7 @@ def __init__(self, api_name): @paddle.jit.to_static def forward(self, x, y): - out = paddle.to_tensor([True, True, True]) - z = self.func(x, y, out=out) + z = self.func(x, y) return paddle.cast(z, "int32") x_shapes = [[128], [8, 20], [4, 20, 3], [2, 3, 8, 8], [2, 3, 3, 9, 9]] @@ -1268,7 +1267,7 @@ def __init__(self, perm): @paddle.jit.to_static def forward(self, inputs): - inputs = inputs + inputs.size() + inputs = inputs * 2 return paddle.transpose(inputs, perm=self.perm) input_data = paddle.rand([1, 3, 5, 4, 3], dtype="float32") @@ -1784,5 +1783,191 @@ def where_index_1(inputs): verify_model(where_index_1, input_data=input_data, use_vm=True) +@tvm.testing.uses_gpu +def test_forward_stack(): + class Stack1(nn.Layer): + @paddle.jit.to_static + def forward(self, input0, input1, input2): + return paddle.stack([input0, input1, input2], axis=-1) + + class Stack2(nn.Layer): + @paddle.jit.to_static + def forward(self, input0, input1, input2): + return paddle.stack([input0, input1, input2], axis=1) + + class Stack3(nn.Layer): + @paddle.jit.to_static + def forward(self, input0, input1, input2): + return paddle.stack([input0, input1, input2], axis=2) + + input_shapes = [[2, 3], [5, 10, 11], [3, 4, 5, 6]] + for input_shape in input_shapes: + input_data_0 = paddle.randn(shape=input_shape, dtype="float32") + input_data_1 = paddle.randn(shape=input_shape, dtype="float32") + input_data_2 = paddle.randn(shape=input_shape, dtype="float32") + verify_model(Stack1(), [input_data_0, input_data_1, input_data_2]) + verify_model(Stack2(), [input_data_0, input_data_1, input_data_2]) + verify_model(Stack3(), [input_data_0, input_data_1, input_data_2]) + + +@tvm.testing.uses_gpu +def 
test_forward_unstack(): + class UnStack1(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return paddle.unstack(inputs, axis=-1) + + class UnStack2(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return paddle.unstack(inputs, axis=1) + + class UnStack3(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return paddle.unstack(inputs, axis=0) + + input_shapes = [[2, 3], [5, 10, 11], [3, 4, 5, 6], [1, 3, 4, 1, 1]] + for input_shape in input_shapes: + input_data = paddle.randn(shape=input_shape, dtype="float32") + verify_model(UnStack1(), input_data) + verify_model(UnStack2(), input_data) + verify_model(UnStack3(), input_data) + + +@tvm.testing.uses_gpu +def test_forward_silu(): + class Silu(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return nn.functional.silu(inputs) + + input_shapes = [[10], [2, 3], [5, 10, 11], [3, 4, 5, 6]] + for input_shape in input_shapes: + input_data = paddle.randn(shape=input_shape, dtype="float32") + verify_model(Silu(), input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_softshrink(): + @paddle.jit.to_static + def Softshrink1(input): + return nn.functional.softshrink(input, threshold=0.0) + + @paddle.jit.to_static + def Softshrink2(input): + return nn.functional.softshrink(input, threshold=0.5) + + @paddle.jit.to_static + def Softshrink3(input): + return nn.functional.softshrink(input, threshold=1.0) + + x = paddle.to_tensor([-0.9, -0.2, 0.1, 0.8]) + verify_model(Softshrink2, x) + + input_shapes = [[10], [2, 3], [5, 10, 11], [3, 4, 5, 6]] + for input_shape in input_shapes: + input_data = paddle.randn(shape=input_shape, dtype="float32") + verify_model(Softshrink1, input_data=input_data) + verify_model(Softshrink2, input_data=input_data) + verify_model(Softshrink3, input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_where(): + @paddle.jit.to_static + def where1(x, y): + return paddle.where(x > 1, x, y) + + @paddle.jit.to_static + def where2(x, y): + return paddle.where(x > y, x, y) + + x = paddle.to_tensor([0.9383, 0.1983, 3.2, 1.2]) + y = paddle.to_tensor([1.0, 1.0, 1.0, 1.0]) + verify_model(where1, [x, y]) + + input_shapes = [[10], [2, 3], [5, 10, 11], [3, 4, 5, 6]] + for input_shape in input_shapes: + x = paddle.randn(shape=input_shape, dtype="float32") + y = paddle.randn(shape=input_shape, dtype="float32") + verify_model(where1, [x, y]) + verify_model(where2, [x, y]) + + +@tvm.testing.uses_gpu +def test_forward_tile(): + class Tile1(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return paddle.tile(inputs, repeat_times=[10]) + + class Tile2(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return paddle.tile(inputs, repeat_times=[2, 3]) + + class Tile3(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return paddle.tile(inputs, repeat_times=[1, 2, 3]) + + class Tile4(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return paddle.tile(inputs, repeat_times=[2, 3, 4, 1, 5]) + + class Tile5(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + reps = paddle.to_tensor([3, 2]) + reps = paddle.cast(reps, "int32") + return paddle.tile(inputs, repeat_times=reps) + + class Tile6(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + rep_0 = paddle.to_tensor([3]) + rep_1 = paddle.to_tensor([2]) + rep_0 = paddle.cast(rep_0, "int32") + rep_1 = paddle.cast(rep_1, "int32") + return paddle.tile(inputs, repeat_times=[rep_0, rep_1]) + + input_shapes = [ + [10], + [2, 3], + [3, 4, 5], + [5, 3, 
1, 4], + [1, 3, 1, 6, 7], + ] + for input_shape in input_shapes: + input_data = paddle.randn(shape=input_shape, dtype="float32") + verify_model(Tile1(), input_data=input_data) + verify_model(Tile2(), input_data=input_data) + verify_model(Tile3(), input_data=input_data) + verify_model(Tile4(), input_data=input_data) + verify_model(Tile5(), input_data=input_data) + verify_model(Tile6(), input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_mish(): + class Mish(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return nn.functional.mish(inputs) + + input_shapes = [[10], [2, 3], [5, 10, 11], [3, 4, 5, 6]] + if paddle.version.full_version >= "2.4.2": + for input_shape in input_shapes: + input_data = paddle.randn(shape=input_shape, dtype="float32") + verify_model(Mish(), input_data=input_data) + input_data += 20.0 + verify_model(Mish(), input_data=input_data) + + input_data = paddle.to_tensor([-5.0, 0.0, 5.0, 23.1, 20.0]) + verify_model(Mish(), input_data=input_data) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relay/aot/test_pass_aot_lower_main.py b/tests/python/relay/aot/test_pass_aot_lower_main.py index bc58812cd67ca..9667d20937571 100644 --- a/tests/python/relay/aot/test_pass_aot_lower_main.py +++ b/tests/python/relay/aot/test_pass_aot_lower_main.py @@ -187,7 +187,7 @@ def func(a: T.handle, output: T.handle) -> None: tmp_write: T.handle("uint8") = output_buffer.data tmp_write_1 = T.Buffer([T.uint64(140)], dtype="uint8", data=tmp_write) for i in T.serial(140): - tmp_write_1[i] = T.let(tmp_read, a_buffer.data, tmp_read_1[i]) + tmp_write_1[i] = T.Let(tmp_read_1[i], where={tmp_read : a_buffer.data}) # fmt: on _assert_lowered_main(mod, func, CallType.CPacked) diff --git a/tests/python/unittest/test_inject_ptx_ldg32.py b/tests/python/unittest/test_inject_ptx_ldg32.py index 81c6e89ad9218..8e8547c572d0d 100644 --- a/tests/python/unittest/test_inject_ptx_ldg32.py +++ b/tests/python/unittest/test_inject_ptx_ldg32.py @@ -32,7 +32,7 @@ def vector_add(A: T.Buffer((16), "float32"), B: T.Buffer((32), "float32")) -> No with T.block(): T.reads(A[0:16]) T.writes(A_local[0:32]) - A_local[tx] = T.if_then_else(tx % 2 == 0, A[tx / 2], T.float32(0), dtype="float32") + A_local[tx] = T.if_then_else(tx % 2 == 0, A[tx // 2], T.float32(0), dtype="float32") B[tx] = A_local[tx] + 1.0 diff --git a/tests/python/unittest/test_meta_schedule_feature_extractor_per_store_feature.py b/tests/python/unittest/test_meta_schedule_feature_extractor_per_store_feature.py index 88947962d69d3..c62ac788d74b1 100644 --- a/tests/python/unittest/test_meta_schedule_feature_extractor_per_store_feature.py +++ b/tests/python/unittest/test_meta_schedule_feature_extractor_per_store_feature.py @@ -70,7 +70,7 @@ def main(placeholder: T.Buffer((1, 16, 7, 7, 32), "float32"), placeholder_1: T.B ax4 = T.axis.spatial(512, i0_i1_i2_i3_i4_fused % 512) T.reads(placeholder[0, (ax4 * 49 + ax2 * 7 + ax3) % 25088 // 1568, (ax2 * 7 + ax3) % 49 // 7, ax3 % 7, (ax4 * 49 + ax2 * 7 + ax3) % 1568 // 49], placeholder_1[(ax4 * 49 + ax2 * 7 + ax3) % 25088]) T.writes(T_layout_trans[ax0, ax1, ax2, ax3, ax4]) - T_layout_trans[ax0, ax1, ax2, ax3, ax4] = T.if_then_else(ax0 < 1 and ax1 * 512 + ax4 < 512 and ax2 < 7 and ax3 < 7, T.Select(T.float32(0) < T.if_then_else(0 < 1 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 < 512 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7 < 7 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7 < 7, placeholder[0, ((ax1 * 512 + ax4) * 49 + ax2 * 7 
+ ax3) % 25088 % 25088 // 49 // 32, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 % 32], T.float32(0), dtype="float32"), T.if_then_else(0 < 1 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 < 512 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7 < 7 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7 < 7, placeholder[0, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 // 32, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 % 32], T.float32(0), dtype="float32"), T.if_then_else(0 < 1 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 < 512 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7 < 7 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7 < 7, placeholder[0, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 // 32, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 % 32], T.float32(0), dtype="float32") * placeholder_1[((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088]), T.float32(0), dtype="float32") + T_layout_trans[ax0, ax1, ax2, ax3, ax4] = T.if_then_else(ax0 < 1 and ax1 * 512 + ax4 < 512 and ax2 < 7 and ax3 < 7, T.Select(T.float32(0) < T.if_then_else(T.LT(0, 1) and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 < 512 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7 < 7 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7 < 7, placeholder[0, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 // 32, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 % 32], T.float32(0), dtype="float32"), T.if_then_else(T.LT(0, 1) and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 < 512 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7 < 7 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7 < 7, placeholder[0, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 // 32, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 % 32], T.float32(0), dtype="float32"), T.if_then_else(T.LT(0, 1) and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 < 512 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7 < 7 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7 < 7, placeholder[0, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 // 32, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 % 32], T.float32(0), dtype="float32") * placeholder_1[((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088]), T.float32(0), dtype="float32") # fmt: on diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py index 9b869b4436c05..97ee53f4e4092 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py +++ 
b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py @@ -83,39 +83,39 @@ def test_matmul_relu(shared_scope): @T.prim_func def matmul_relu_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, 128), "float16"), compute: T.Buffer((128, 128), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - C_reindex_shared = T.alloc_buffer([128, 128], dtype="float32", scope=shared_scope) - C_reindex_shared_wmma_accumulator = T.alloc_buffer([128, 128], dtype="float32", scope="wmma.accumulator") - A_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope=shared_scope) - B_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope=shared_scope) - A_reindex_shared_wmma_matrix_a = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_a") - B_reindex_shared_wmma_matrix_b = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_b") + C_reindex_shared = T.alloc_buffer((4, 8, 2, 1, 16, 16), scope=shared_scope) + C_reindex_shared_wmma_accumulator = T.alloc_buffer((4, 8, 2, 1, 16, 16), scope="wmma.accumulator") + A_reindex_shared = T.alloc_buffer((128, 128), "float16", scope=shared_scope) + B_reindex_shared = T.alloc_buffer((128, 128), "float16", scope=shared_scope) + A_reindex_shared_wmma_matrix_a = T.alloc_buffer((128, 128), "float16", scope="wmma.matrix_a") + B_reindex_shared_wmma_matrix_b = T.alloc_buffer((128, 128), "float16", scope="wmma.matrix_b") for ax0_0_0_ax1_0_0_fused in T.thread_binding(8, thread="blockIdx.y"): for ax0_0_1_ax1_0_1_fused in T.thread_binding(2, thread="blockIdx.x"): for ax0_0_2_ax1_0_2_fused in T.thread_binding(2, thread="threadIdx.y"): - for ax2_0_0 in T.serial(1): - for ax0_ax1_fused in T.serial(4096): + for ax2_0_0 in range(1): + for ax0_ax1_fused in range(4096): with T.block("A_reindex_shared"): v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused // 2 * 32 + ax0_ax1_fused // 128) v1 = T.axis.spatial(128, ax0_ax1_fused % 128) T.reads(A[v0, v1]) T.writes(A_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":8}) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 8}) A_reindex_shared[v0, v1] = A[v0, v1] - for ax0_ax1_fused in T.serial(4096): + for ax0_ax1_fused in range(4096): with T.block("B_reindex_shared"): v0 = T.axis.spatial(128, ax0_ax1_fused // 32) v1 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused % 2 * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_fused % 32) T.reads(B[v0, v1]) T.writes(B_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":1}) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 1}) B_reindex_shared[v0, v1] = B[v0, v1] - for ax2_0_1 in T.serial(4): + for ax2_0_1 in range(4): for ax0_0, ax1_0 in T.grid(2, 2): with T.block("A_reindex_shared_wmma.matrix_a_o"): v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused // 2 * 2 + ax0_0) v1_o = T.axis.spatial(8, ax2_0_1 * 2 + ax1_0) - T.reads(A_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.reads(A_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) T.block_attr({"meta_schedule.auto_tensorize": f"wmma_load_16x16x16_f16_a_{intrin_suffix}"}) for ax0_1, ax1_1 in T.grid(16, 16): with T.block("A_reindex_shared_wmma.matrix_a"): @@ -127,8 +127,8 @@ 
def matmul_relu_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, 128), "f with T.block("B_reindex_shared_wmma.matrix_b_o"): v0_o = T.axis.spatial(8, ax2_0_1 * 2 + ax0_0) v1_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused + ax1_0) - T.reads(B_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.reads(B_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) T.block_attr({"meta_schedule.auto_tensorize": f"wmma_load_16x16x16_f16_b_{intrin_suffix}"}) for ax0_1, ax1_1 in T.grid(16, 16): with T.block("B_reindex_shared_wmma.matrix_b"): @@ -141,44 +141,54 @@ def matmul_relu_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, 128), "f v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused // 2 * 2 + ax0_0_3 * 2 + ax0_0_4) v1_o = T.axis.spatial(8, ax1_0_4 + ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused + ax1_0_3) v2_o = T.axis.reduce(8, ax2_0_0 * 8 + ax2_0_1 * 2 + ax2_0_2) - T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v2_o * 16 : v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16 : v2_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_f32", "warp_execution":1}) + T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16:v2_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1}) with T.init(): for ax0_1, ax1_1 in T.grid(16, 16): with T.block("C_init"): v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1]) T.reads() - T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init]) - C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init] = T.float32(0) + T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i_init, v1_i_init]) + C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i_init, v1_i_init] = T.float32(0) for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16): with T.block("C"): v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1]) - T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i], A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) - T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"}) - C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + T.cast(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], "float32") * T.cast(B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i], "float32") - for ax0_0, ax1_0 in T.grid(2, 1): - with T.block("C_reindex_shared_wmma.accumulator_o"): - v0_o = T.axis.spatial(8, 
ax0_0_0_ax1_0_0_fused // 2 * 2 + ax0_0) - v1_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused + ax1_0) - T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(C_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize": f"wmma_store_16x16x16_f32_{intrin_suffix}"}) - for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("C_reindex_shared_wmma.accumulator"): - v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) - T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - T.writes(C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] - for ax0_ax1_fused in T.serial(1024): - with T.block("C_reindex_shared"): - v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused // 2 * 32 + ax0_ax1_fused // 32) - v1 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused % 2 * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_fused % 32) - T.reads(C_reindex_shared[v0, v1]) - T.writes(compute[v0, v1]) - T.block_attr({"meta_schedule.cooperative_fetch":4}) - compute[v0, v1] = T.max(C_reindex_shared[v0, v1], T.float32(0)) + T.reads(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i, v1_i], A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i, v1_i]) + T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) + C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i, v1_i] = C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i, v1_i] + T.Cast("float32", A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i]) * T.Cast("float32", B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + for ax2 in range(2): + for ax0_ax1_fused in T.thread_binding(2, thread="threadIdx.y"): + for ax2_1, ax3 in T.grid(1, 1): + with T.block("C_reindex_shared_wmma.accumulator_o"): + v0 = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused // 2) + v1 = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_ax1_fused) + v2 = T.axis.spatial(2, ax2 + ax2_1) + v3 = T.axis.spatial(1, ax3) + v4_o = T.axis.spatial(1, 0) + v5_o = T.axis.spatial(1, 0) + T.reads(C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, 0:16, 0:16]) + T.writes(C_reindex_shared[v0, v1, v2, v3, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": f"wmma_store_16x16x16_f32_{intrin_suffix}"}) + for ax4, ax5 in T.grid(16, 16): + with T.block("C_reindex_shared_wmma.accumulator"): + v4_i, v5_i = T.axis.remap("SS", [ax4, ax5]) + T.reads(C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i]) + T.writes(C_reindex_shared[v0, v1, v2, v3, v4_i, v5_i]) + C_reindex_shared[v0, v1, v2, v3, v4_i, v5_i] = C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i] + for ax0_ax1_ax3_ax4_ax5_fused in range(512): + with T.block("C_reindex_shared"): + v0 = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused // 2) + v1 = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_ax1_ax3_ax4_ax5_fused // 256) + v2 = T.axis.spatial(2, ax2) + v3 = T.axis.spatial(1, 0) + v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 256 // 16) + v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16) + T.reads(C_reindex_shared[v0, v1, v2, v3, v4, v5]) + 
T.writes(compute[v4 + v2 * 16 + v0 * 32, v5 + v1 * 16]) + T.block_attr({"meta_schedule.cooperative_fetch": 4}) + compute[v4 + v2 * 16 + v0 * 32, v5 + v1 * 16] = T.max(C_reindex_shared[v0, v1, v2, v3, v4, v5], T.float32(0)) # fmt: on decision_0 = [ @@ -223,44 +233,42 @@ def test_matmul_relu_with_fallback(): # fmt: off @T.prim_func def matmul_relu_fallback_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, 128), "float16"), compute: T.Buffer((128, 128), "float32")) -> None: - # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body - # with T.block("root") - C_reindex_shared = T.alloc_buffer([128, 128], dtype="float32", scope="shared") - C_reindex_shared_wmma_accumulator = T.alloc_buffer([128, 128], dtype="float32", scope="wmma.accumulator") - A_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope="shared") - B_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope="shared") - A_reindex_shared_wmma_matrix_a = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_a") - B_reindex_shared_wmma_matrix_b = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_b") + # with T.block("root"): + C_reindex_shared = T.alloc_buffer((4, 2, 2, 4, 16, 16), scope="shared") + C_reindex_shared_wmma_accumulator = T.alloc_buffer((4, 2, 2, 4, 16, 16), scope="wmma.accumulator") + A_reindex_shared = T.alloc_buffer((128, 128), "float16", scope="shared") + B_reindex_shared = T.alloc_buffer((128, 128), "float16", scope="shared") + A_reindex_shared_wmma_matrix_a = T.alloc_buffer((128, 128), "float16", scope="wmma.matrix_a") + B_reindex_shared_wmma_matrix_b = T.alloc_buffer((128, 128), "float16", scope="wmma.matrix_b") for ax0_0_0_ax1_0_0_fused in T.thread_binding(2, thread="blockIdx.y"): for ax0_0_1_ax1_0_1_fused in T.thread_binding(2, thread="blockIdx.x"): for ax0_0_2_ax1_0_2_fused in T.thread_binding(2, thread="threadIdx.y"): - for ax2_0_0 in T.serial(2): - for ax0_ax1_fused in T.serial(2048): + for ax2_0_0 in range(2): + for ax0_ax1_fused in range(2048): with T.block("A_reindex_shared"): v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_fused // 64) v1 = T.axis.spatial(128, ax2_0_0 * 64 + ax0_ax1_fused % 64) T.reads(A[v0, v1]) T.writes(A_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":4}) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 4}) A_reindex_shared[v0, v1] = A[v0, v1] - for ax0_ax1_fused in T.serial(8192): + for ax0_ax1_fused in range(8192): with T.block("B_reindex_shared"): v0 = T.axis.spatial(128, ax2_0_0 * 64 + ax0_ax1_fused // 128) v1 = T.axis.spatial(128, ax0_ax1_fused % 128) T.reads(B[v0, v1]) T.writes(B_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":2}) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 2}) B_reindex_shared[v0, v1] = B[v0, v1] - for ax2_0_1 in T.serial(1): + for ax2_0_1 in range(1): for ax0_0, ax1_0 in T.grid(2, 4): with T.block("A_reindex_shared_wmma.matrix_a_o"): v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0) v1_o = T.axis.spatial(8, ax2_0_0 * 4 + ax1_0) - T.reads(A_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a_shared"}) + 
T.reads(A_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_a_shared"}) for ax0_1, ax1_1 in T.grid(16, 16): with T.block("A_reindex_shared_wmma.matrix_a"): v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) @@ -271,9 +279,9 @@ def matmul_relu_fallback_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, with T.block("B_reindex_shared_wmma.matrix_b_o"): v0_o = T.axis.spatial(8, ax2_0_0 * 4 + ax0_0) v1_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused * 4 + ax1_0) - T.reads(B_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b_shared"}) + T.reads(B_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_b_shared"}) for ax0_1, ax1_1 in T.grid(16, 16): with T.block("B_reindex_shared_wmma.matrix_b"): v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) @@ -285,44 +293,54 @@ def matmul_relu_fallback_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_3 * 2 + ax0_0_4) v1_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused * 4 + ax1_0_3 * 4 + ax1_0_4) v2_o = T.axis.reduce(8, ax2_0_0 * 4 + ax2_0_1 * 4 + ax2_0_2) - T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v2_o * 16 : v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16 : v2_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_f32", "warp_execution":1}) + T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16:v2_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1}) with T.init(): for ax0_1, ax1_1 in T.grid(16, 16): with T.block("C_init"): v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1]) T.reads() - T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init]) - C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init] = T.float32(0) + T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i_init, v1_i_init]) + C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i_init, v1_i_init] = T.float32(0) for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16): with T.block("C"): v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1]) - T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i], A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) - T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - 
T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"}) - C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + T.cast(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], "float32") * T.cast(B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i], "float32") - for ax0_0, ax1_0 in T.grid(2, 4): - with T.block("C_reindex_shared_wmma.accumulator_o"): - v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0) - v1_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused * 4 + ax1_0) - T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(C_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_f32_shared"}) - for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("C_reindex_shared_wmma.accumulator"): - v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) - T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - T.writes(C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] - for ax0_ax1_fused in T.serial(4096): - with T.block("C_reindex_shared"): - v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_fused // 128) - v1 = T.axis.spatial(128, ax0_ax1_fused % 128) - T.reads(C_reindex_shared[v0, v1]) - T.writes(compute[v0, v1]) - T.block_attr({"meta_schedule.cooperative_fetch":4}) - compute[v0, v1] = T.max(C_reindex_shared[v0, v1], T.float32(0)) + T.reads(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i, v1_i], A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i, v1_i]) + T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) + C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i, v1_i] = C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i, v1_i] + T.Cast("float32", A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i]) * T.Cast("float32", B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + for ax2 in range(2): + for ax0_ax1_fused in T.thread_binding(2, thread="threadIdx.y"): + for ax2_1, ax3 in T.grid(1, 4): + with T.block("C_reindex_shared_wmma.accumulator_o"): + v0 = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused * 2 + ax0_0_1_ax1_0_1_fused) + v1 = T.axis.spatial(2, ax0_ax1_fused) + v2 = T.axis.spatial(2, ax2 + ax2_1) + v3 = T.axis.spatial(4, ax3) + v4_o = T.axis.spatial(1, 0) + v5_o = T.axis.spatial(1, 0) + T.reads(C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, 0:16, 0:16]) + T.writes(C_reindex_shared[v0, v1, v2, v3, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_store_16x16x16_f32_shared"}) + for ax4, ax5 in T.grid(16, 16): + with T.block("C_reindex_shared_wmma.accumulator"): + v4_i, v5_i = T.axis.remap("SS", [ax4, ax5]) + T.reads(C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i]) + T.writes(C_reindex_shared[v0, v1, v2, v3, v4_i, v5_i]) + C_reindex_shared[v0, v1, v2, v3, v4_i, v5_i] = C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i] + for ax0_ax1_ax3_ax4_ax5_fused in range(2048): + with 
T.block("C_reindex_shared"): + v0 = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused * 2 + ax0_0_1_ax1_0_1_fused) + v1 = T.axis.spatial(2, ax0_ax1_ax3_ax4_ax5_fused // 1024) + v2 = T.axis.spatial(2, ax2) + v3 = T.axis.spatial(4, ax0_ax1_ax3_ax4_ax5_fused % 1024 // 256) + v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 256 // 16) + v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16) + T.reads(C_reindex_shared[v0, v1, v2, v3, v4, v5]) + T.writes(compute[v4 + v2 * 16 + v0 * 32, v5 + v3 * 16 + v1 * 64]) + T.block_attr({"meta_schedule.cooperative_fetch": 4}) + compute[v4 + v2 * 16 + v0 * 32, v5 + v3 * 16 + v1 * 64] = T.max(C_reindex_shared[v0, v1, v2, v3, v4, v5], T.float32(0)) # fmt: on decision_0 = [ ("SamplePerfectTile", [2, 2, 1, 1, 2]), @@ -373,46 +391,46 @@ def test_conv2d(shared_scope): @T.prim_func def conv2d_0(inputs: T.Buffer((1, 16, 16, 32), "float16"), weight: T.Buffer((3, 3, 32, 32), "float16"), conv2d_nhwc: T.Buffer((1, 16, 16, 32), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - PadInput = T.alloc_buffer([1, 18, 18, 32], dtype="float16") - conv2d_nhwc_reindex_shared = T.alloc_buffer([256, 32], dtype="float32", scope=shared_scope) - conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer([256, 32], dtype="float32", scope="wmma.accumulator") - PadInput_reindex_shared = T.alloc_buffer([256, 288], dtype="float16", scope=shared_scope) - weight_reindex_shared = T.alloc_buffer([288, 32], dtype="float16", scope=shared_scope) - PadInput_reindex_shared_wmma_matrix_a = T.alloc_buffer([256, 288], dtype="float16", scope="wmma.matrix_a") - weight_reindex_shared_wmma_matrix_b = T.alloc_buffer([288, 32], dtype="float16", scope="wmma.matrix_b") + PadInput = T.alloc_buffer((1, 18, 18, 32), "float16") + conv2d_nhwc_reindex_shared = T.alloc_buffer((16, 2, 1, 1, 16, 16), scope=shared_scope) + conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer((16, 2, 1, 1, 16, 16), scope="wmma.accumulator") + PadInput_reindex_shared = T.alloc_buffer((256, 288), "float16", scope=shared_scope) + weight_reindex_shared = T.alloc_buffer((288, 32), "float16", scope=shared_scope) + PadInput_reindex_shared_wmma_matrix_a = T.alloc_buffer((256, 288), "float16", scope="wmma.matrix_a") + weight_reindex_shared_wmma_matrix_b = T.alloc_buffer((288, 32), "float16", scope="wmma.matrix_b") for i0, i1, i2, i3 in T.grid(1, 18, 18, 32): with T.block("PadInput"): - i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(inputs[i0_1, i1_1 - 1, i2_1 - 1, i3_1]) - T.writes(PadInput[i0_1, i1_1, i2_1, i3_1]) - PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(1 <= i1_1 and i1_1 < 17 and 1 <= i2_1 and i2_1 < 17, inputs[i0_1, i1_1 - 1, i2_1 - 1, i3_1], T.float16(0), dtype="float16") + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(inputs[v_i0, v_i1 - 1, v_i2 - 1, v_i3]) + T.writes(PadInput[v_i0, v_i1, v_i2, v_i3]) + PadInput[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(1 <= v_i1 and v_i1 < 17 and 1 <= v_i2 and v_i2 < 17, inputs[v_i0, v_i1 - 1, v_i2 - 1, v_i3], T.float16(0)) for ax0_0_0_ax1_0_0_fused in T.thread_binding(2, thread="blockIdx.y"): for ax0_0_1_ax1_0_1_fused in T.thread_binding(16, thread="blockIdx.x"): for ax0_0_2_ax1_0_2_fused in T.thread_binding(1, thread="threadIdx.y"): - for ax2_0_0 in T.serial(1): - for ax0_ax1_fused in T.serial(4608): + for ax2_0_0 in range(1): + for ax0_ax1_fused in range(4608): with T.block("PadInput_reindex_shared"): v0 = T.axis.spatial(256, ax0_0_1_ax1_0_1_fused * 16 + ax0_ax1_fused // 288) v1 = T.axis.spatial(288, 
ax0_ax1_fused % 288) T.reads(PadInput[v0 // 256, v1 // 96 + v0 // 16, v1 % 96 // 32 + v0 % 16, v1 % 32]) T.writes(PadInput_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":2}) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 2}) PadInput_reindex_shared[v0, v1] = PadInput[v0 // 256, v1 // 96 + v0 // 16, v1 % 96 // 32 + v0 % 16, v1 % 32] - for ax0_ax1_fused in T.serial(4608): + for ax0_ax1_fused in range(4608): with T.block("weight_reindex_shared"): v0 = T.axis.spatial(288, ax0_ax1_fused // 16) v1 = T.axis.spatial(32, ax0_0_0_ax1_0_0_fused * 16 + ax0_ax1_fused % 16) T.reads(weight[v0 // 96, v0 % 96 // 32, v0 % 32, v1]) T.writes(weight_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":8}) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 8}) weight_reindex_shared[v0, v1] = weight[v0 // 96, v0 % 96 // 32, v0 % 32, v1] - for ax2_0_1 in T.serial(18): + for ax2_0_1 in range(18): for ax0_0, ax1_0 in T.grid(1, 1): with T.block("PadInput_reindex_shared_wmma.matrix_a_o"): v0_o = T.axis.spatial(16, ax0_0_1_ax1_0_1_fused + ax0_0) v1_o = T.axis.spatial(18, ax2_0_1 + ax1_0) - T.reads(PadInput_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.reads(PadInput_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) T.block_attr({"meta_schedule.auto_tensorize": f"wmma_load_16x16x16_f16_a_{intrin_suffix}"}) for ax0_1, ax1_1 in T.grid(16, 16): with T.block("PadInput_reindex_shared_wmma.matrix_a"): @@ -424,8 +442,8 @@ def conv2d_0(inputs: T.Buffer((1, 16, 16, 32), "float16"), weight: T.Buffer((3, with T.block("weight_reindex_shared_wmma.matrix_b_o"): v0_o = T.axis.spatial(18, ax2_0_1 + ax0_0) v1_o = T.axis.spatial(2, ax0_0_0_ax1_0_0_fused + ax1_0) - T.reads(weight_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(weight_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.reads(weight_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(weight_reindex_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) T.block_attr({"meta_schedule.auto_tensorize": f"wmma_load_16x16x16_f16_b_{intrin_suffix}"}) for ax0_1, ax1_1 in T.grid(16, 16): with T.block("weight_reindex_shared_wmma.matrix_b"): @@ -438,44 +456,49 @@ def conv2d_0(inputs: T.Buffer((1, 16, 16, 32), "float16"), weight: T.Buffer((3, v0_o = T.axis.spatial(16, ax0_0_4 + ax0_0_1_ax1_0_1_fused + ax0_0_3) v1_o = T.axis.spatial(2, ax0_0_0_ax1_0_0_fused + ax1_0_3 + ax1_0_4) v2_o = T.axis.reduce(18, ax2_0_0 * 18 + ax2_0_1 + ax2_0_2) - T.reads(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v2_o * 16 : v2_o * 16 + 16], weight_reindex_shared_wmma_matrix_b[v2_o * 16 : v2_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_f32", "warp_execution":1}) + T.reads(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], 
weight_reindex_shared_wmma_matrix_b[v2_o * 16:v2_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, 0, 0, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1}) with T.init(): for ax0_1, ax1_1 in T.grid(16, 16): with T.block("conv2d_nhwc_init"): v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1]) T.reads() - T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init]) - conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init] = T.float32(0) + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, 0, 0, v0_i_init, v1_i_init]) + conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, 0, 0, v0_i_init, v1_i_init] = T.float32(0) for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16): with T.block("conv2d_nhwc"): v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1]) - T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i], PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], weight_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) - T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"}) - conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + T.cast(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], "float32") * T.cast(weight_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i], "float32") - for ax0_0, ax1_0 in T.grid(1, 1): - with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"): - v0_o = T.axis.spatial(16, ax0_0_1_ax1_0_1_fused + ax0_0) - v1_o = T.axis.spatial(2, ax0_0_0_ax1_0_0_fused + ax1_0) - T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize": f"wmma_store_16x16x16_f32_{intrin_suffix}"}) - for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator"): - v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) - T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - conv2d_nhwc_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] - for ax0_ax1_fused in T.serial(256): - with T.block("conv2d_nhwc_reindex_shared"): - v0 = T.axis.spatial(256, ax0_0_1_ax1_0_1_fused * 16 + ax0_ax1_fused // 16) - v1 = T.axis.spatial(32, ax0_0_0_ax1_0_0_fused * 16 + ax0_ax1_fused % 16) - T.reads(conv2d_nhwc_reindex_shared[v0, v1]) - T.writes(conv2d_nhwc[v0 // 256, v0 // 16, v0 % 16, v1]) - T.block_attr({"meta_schedule.cooperative_fetch":3}) - conv2d_nhwc[v0 // 256, v0 // 16, v0 % 16, v1] = conv2d_nhwc_reindex_shared[v0, v1] + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, 0, 0, v0_i, v1_i], PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], weight_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, 0, 0, v0_i, 
v1_i]) + T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) + conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, 0, 0, v0_i, v1_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, 0, 0, v0_i, v1_i] + T.Cast("float32", PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i]) * T.Cast("float32", weight_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + for ax2 in range(1): + for ax0_ax1_fused in T.thread_binding(1, thread="threadIdx.y"): + for ax2_1, ax3 in T.grid(1, 1): + with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"): + v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0_0_1_ax1_0_1_fused, ax0_0_0_ax1_0_0_fused, ax2_1, ax3]) + v4_o = T.axis.spatial(1, 0) + v5_o = T.axis.spatial(1, 0) + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0, v1, v2, v3, 0:16, 0:16]) + T.writes(conv2d_nhwc_reindex_shared[v0, v1, v2, v3, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": f"wmma_store_16x16x16_f32_{intrin_suffix}"}) + for ax4, ax5 in T.grid(16, 16): + with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator"): + v4_i, v5_i = T.axis.remap("SS", [ax4, ax5]) + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i]) + T.writes(conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4_i, v5_i]) + conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4_i, v5_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i] + for ax0_ax1_ax3_ax4_ax5_fused in range(256): + with T.block("conv2d_nhwc_reindex_shared"): + v0, v1, v2 = T.axis.remap("SSS", [ax0_0_1_ax1_0_1_fused, ax0_0_0_ax1_0_0_fused, ax2]) + v3 = T.axis.spatial(1, 0) + v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused // 16) + v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16) + T.reads(conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5]) + T.writes(conv2d_nhwc[(v4 + v0 * 16) // 256, (v4 + v0 * 16) // 16, (v4 + v0 * 16) % 16, v5 + v1 * 16]) + T.block_attr({"meta_schedule.cooperative_fetch": 3}) + conv2d_nhwc[(v4 + v0 * 16) // 256, (v4 + v0 * 16) // 16, (v4 + v0 * 16) % 16, v5 + v1 * 16] = conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5] # fmt: on decision_0 = [ ("SamplePerfectTile", [1, 16, 1, 1, 1]), @@ -551,40 +574,40 @@ def matmul_relu_pipeline_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, T.func_attr({"global_symbol": "main", "tir.noalias": True}) # body # with T.block("root") - C = T.alloc_buffer([128, 128], dtype="float32") - C_reindex_shared = T.alloc_buffer([128, 128], dtype="float32", scope=shared_scope) - C_reindex_shared_wmma_accumulator = T.alloc_buffer([128, 128], dtype="float32", scope="wmma.accumulator") - A_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope=shared_scope) - B_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope=shared_scope) - A_reindex_shared_wmma_matrix_a = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_a") - B_reindex_shared_wmma_matrix_b = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_b") + C = T.alloc_buffer((128, 128)) + C_reindex_shared = T.alloc_buffer((4, 4, 2, 2, 16, 16), scope=shared_scope) + C_reindex_shared_wmma_accumulator = T.alloc_buffer((4, 4, 2, 2, 16, 16), scope="wmma.accumulator") + A_reindex_shared = T.alloc_buffer((128, 128), "float16", scope=shared_scope) + B_reindex_shared = T.alloc_buffer((128, 128), "float16", scope=shared_scope) + A_reindex_shared_wmma_matrix_a = T.alloc_buffer((128, 128), "float16", scope="wmma.matrix_a") + B_reindex_shared_wmma_matrix_b = T.alloc_buffer((128, 128), "float16", 
scope="wmma.matrix_b") for ax0_0_0_ax1_0_0_fused in T.thread_binding(1, thread="blockIdx.y"): for ax0_0_1_ax1_0_1_fused in T.thread_binding(16, thread="blockIdx.x"): for ax0_0_2_ax1_0_2_fused in T.thread_binding(1, thread="threadIdx.y"): - for ax2_0_0 in T.serial(4, annotations={"software_pipeline_order":[0, 3, 1, 4, 5, 2, 6], "software_pipeline_stage":[0, 0, 0, 0, 0, 1, 1]}): - for ax0_ax1_fused in T.serial(1024): + for ax2_0_0 in T.serial(4, annotations={"software_pipeline_order": [0, 3, 1, 4, 5, 2, 6], "software_pipeline_stage": [0, 0, 0, 0, 0, 1, 1]}): + for ax0_ax1_fused in range(1024): with T.block("A_reindex_shared"): v0 = T.axis.spatial(128, ax0_0_1_ax1_0_1_fused // 4 * 32 + ax0_ax1_fused // 32) v1 = T.axis.spatial(128, ax2_0_0 * 32 + ax0_ax1_fused % 32) T.reads(A[v0, v1]) T.writes(A_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "double_buffer_scope":0, "meta_schedule.cooperative_fetch":4, "tir.manifest_shared_memory_local_stage":1}) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "double_buffer_scope": 0, "meta_schedule.cooperative_fetch": 4, "tir.manifest_shared_memory_local_stage": 1}) A_reindex_shared[v0, v1] = A[v0, v1] - for ax0_ax1_fused in T.serial(1024): + for ax0_ax1_fused in range(1024): with T.block("B_reindex_shared"): v0 = T.axis.spatial(128, ax2_0_0 * 32 + ax0_ax1_fused // 32) v1 = T.axis.spatial(128, ax0_0_1_ax1_0_1_fused % 4 * 32 + ax0_ax1_fused % 32) T.reads(B[v0, v1]) T.writes(B_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "double_buffer_scope":0, "meta_schedule.cooperative_fetch":2, "tir.manifest_shared_memory_local_stage":1}) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "double_buffer_scope": 0, "meta_schedule.cooperative_fetch": 2, "tir.manifest_shared_memory_local_stage": 1}) B_reindex_shared[v0, v1] = B[v0, v1] - for ax2_0_1 in T.serial(2, annotations={"software_pipeline_order":[0, 1, 2], "software_pipeline_stage":[0, 0, 1]}): + for ax2_0_1 in T.serial(2, annotations={"software_pipeline_order": [0, 1, 2], "software_pipeline_stage": [0, 0, 1]}): for ax0_0, ax1_0 in T.grid(2, 1): with T.block("A_reindex_shared_wmma.matrix_a_o"): v0_o = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused // 4 * 2 + ax0_0) v1_o = T.axis.spatial(8, ax2_0_0 * 2 + ax2_0_1 + ax1_0) - T.reads(A_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.reads(A_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) T.block_attr({"meta_schedule.auto_tensorize": f"wmma_load_16x16x16_f16_a_{intrin_suffix}"}) for ax0_1, ax1_1 in T.grid(16, 16): with T.block("A_reindex_shared_wmma.matrix_a"): @@ -596,8 +619,8 @@ def matmul_relu_pipeline_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, with T.block("B_reindex_shared_wmma.matrix_b_o"): v0_o = T.axis.spatial(8, ax2_0_0 * 2 + ax2_0_1 + ax0_0) v1_o = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused % 4 * 2 + ax1_0) - T.reads(B_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.reads(B_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) T.block_attr({"meta_schedule.auto_tensorize": f"wmma_load_16x16x16_f16_b_{intrin_suffix}"}) for 
ax0_1, ax1_1 in T.grid(16, 16): with T.block("B_reindex_shared_wmma.matrix_b"): @@ -610,50 +633,61 @@ def matmul_relu_pipeline_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, v0_o = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused // 4 * 2 + ax0_0_3 * 2 + ax0_0_4) v1_o = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused % 4 * 2 + ax1_0_3 * 2 + ax1_0_4) v2_o = T.axis.reduce(8, ax2_0_0 * 2 + ax2_0_1 + ax2_0_2) - T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v2_o * 16 : v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16 : v2_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_f32", "warp_execution":1}) + T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16:v2_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 2, v0_o % 2, v1_o % 2, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1}) with T.init(): for ax0_1, ax1_1 in T.grid(16, 16): with T.block("C_init"): v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1]) T.reads() - T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init]) - C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init] = T.float32(0) + T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 2, v0_o % 2, v1_o % 2, v0_i_init, v1_i_init]) + C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 2, v0_o % 2, v1_o % 2, v0_i_init, v1_i_init] = T.float32(0) for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16): with T.block("C"): v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1]) - T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i], A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) - T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"}) - C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + T.cast(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], "float32") * T.cast(B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i], "float32") - for ax0_0, ax1_0 in T.grid(2, 2): - with T.block("C_reindex_shared_wmma.accumulator_o"): - v0_o = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused // 4 * 2 + ax0_0) - v1_o = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused % 4 * 2 + ax1_0) - T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(C_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize": f"wmma_store_16x16x16_f32_{intrin_suffix}"}) - for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("C_reindex_shared_wmma.accumulator"): - v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) - T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - T.writes(C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + 
v0_i, v1_o * 16 + v1_i] - for ax0_ax1_fused in T.grid(1024): - with T.block("C_reindex_shared"): - v0 = T.axis.spatial(128, ax0_0_1_ax1_0_1_fused // 4 * 32 + ax0_ax1_fused // 32) - v1 = T.axis.spatial(128, ax0_0_1_ax1_0_1_fused % 4 * 32 + ax0_ax1_fused % 32) - T.reads(C_reindex_shared[v0, v1]) - T.writes(C[v0, v1]) - T.block_attr({"meta_schedule.cooperative_fetch":3}) - C[v0, v1] = C_reindex_shared[v0, v1] + T.reads(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 2, v0_o % 2, v1_o % 2, v0_i, v1_i], A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 2, v0_o % 2, v1_o % 2, v0_i, v1_i]) + T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) + C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 2, v0_o % 2, v1_o % 2, v0_i, v1_i] = C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 2, v0_o % 2, v1_o % 2, v0_i, v1_i] + T.Cast("float32", A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i]) * T.Cast("float32", B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + for ax2 in range(2): + for ax0_ax1_fused in T.thread_binding(1, thread="threadIdx.y"): + for ax2_1, ax3 in T.grid(1, 2): + with T.block("C_reindex_shared_wmma.accumulator_o"): + v0 = T.axis.spatial(4, ax0_0_1_ax1_0_1_fused // 4) + v1 = T.axis.spatial(4, ax0_0_1_ax1_0_1_fused % 4) + v2 = T.axis.spatial(2, ax2 + ax2_1) + v3 = T.axis.spatial(2, ax3) + v4_o = T.axis.spatial(1, 0) + v5_o = T.axis.spatial(1, 0) + T.reads(C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, 0:16, 0:16]) + T.writes(C_reindex_shared[v0, v1, v2, v3, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": f"wmma_store_16x16x16_f32_{intrin_suffix}"}) + for ax4, ax5 in T.grid(16, 16): + with T.block("C_reindex_shared_wmma.accumulator"): + v4_i, v5_i = T.axis.remap("SS", [ax4, ax5]) + T.reads(C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i]) + T.writes(C_reindex_shared[v0, v1, v2, v3, v4_i, v5_i]) + C_reindex_shared[v0, v1, v2, v3, v4_i, v5_i] = C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i] + for ax0_ax1_ax3_ax4_ax5_fused in range(512): + with T.block("C_reindex_shared"): + v0 = T.axis.spatial(4, ax0_0_1_ax1_0_1_fused // 4) + v1 = T.axis.spatial(4, ax0_0_1_ax1_0_1_fused % 4) + v2 = T.axis.spatial(2, ax2) + v3 = T.axis.spatial(2, ax0_ax1_ax3_ax4_ax5_fused // 256) + v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 256 // 16) + v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16) + T.reads(C_reindex_shared[v0, v1, v2, v3, v4, v5]) + T.writes(C[v4 + v2 * 16 + v0 * 32, v5 + v3 * 16 + v1 * 32]) + T.block_attr({"meta_schedule.cooperative_fetch": 3}) + C[v4 + v2 * 16 + v0 * 32, v5 + v3 * 16 + v1 * 32] = C_reindex_shared[v0, v1, v2, v3, v4, v5] for i0, i1 in T.grid(128, 128): with T.block("compute"): - i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) - T.reads(C[i0_1, i1_1]) - T.writes(compute[i0_1, i1_1]) - compute[i0_1, i1_1] = T.max(C[i0_1, i1_1], T.float32(0)) + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(C[v_i0, v_i1]) + T.writes(compute[v_i0, v_i1]) + compute[v_i0, v_i1] = T.max(C[v_i0, v_i1], T.float32(0)) + # fmt: on decision_0 = [ ("SamplePerfectTile", [1, 4, 1, 1, 2]), @@ -693,141 +727,6 @@ def matmul_relu_pipeline_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, ) -def test_matmul_relu_global(): - # fmt: off - @T.prim_func - def matmul_relu_global_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, 128), "float16"), compute: 
T.Buffer((128, 128), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body - # with T.block("root") - C = T.alloc_buffer([128, 128], dtype="float32") - C_reindex_wmma_accumulator = T.alloc_buffer([128, 128], dtype="float32", scope="wmma.accumulator") - A_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope="shared") - B_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope="shared") - A_reindex_shared_wmma_matrix_a = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_a") - B_reindex_shared_wmma_matrix_b = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_b") - for ax0_0_0_ax1_0_0_fused in T.thread_binding(1, thread="blockIdx.y"): - for ax0_0_1_ax1_0_1_fused in T.thread_binding(1, thread="blockIdx.x"): - for ax0_0_2_ax1_0_2_fused in T.thread_binding(16, thread="threadIdx.y"): - for ax2_0_0 in T.serial(2): - for ax0_ax1_fused in T.serial(8192): - with T.block("A_reindex_shared"): - v0 = T.axis.spatial(128, ax0_ax1_fused // 64) - v1 = T.axis.spatial(128, ax2_0_0 * 64 + ax0_ax1_fused % 64) - T.reads(A[v0, v1]) - T.writes(A_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":1}) - A_reindex_shared[v0, v1] = A[v0, v1] - for ax0_ax1_fused in T.serial(8192): - with T.block("B_reindex_shared"): - v0 = T.axis.spatial(128, ax2_0_0 * 64 + ax0_ax1_fused // 128) - v1 = T.axis.spatial(128, ax0_ax1_fused % 128) - T.reads(B[v0, v1]) - T.writes(B_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":1}) - B_reindex_shared[v0, v1] = B[v0, v1] - for ax2_0_1 in T.serial(2): - for ax0_0, ax1_0 in T.grid(1, 2): - with T.block("A_reindex_shared_wmma.matrix_a_o"): - v0_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused // 2 + ax0_0) - v1_o = T.axis.spatial(8, ax2_0_0 * 4 + ax2_0_1 * 2 + ax1_0) - T.reads(A_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a_shared"}) - for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("A_reindex_shared_wmma.matrix_a"): - v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) - T.reads(A_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = A_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] - for ax0_0, ax1_0 in T.grid(2, 4): - with T.block("B_reindex_shared_wmma.matrix_b_o"): - v0_o = T.axis.spatial(8, ax2_0_0 * 4 + ax2_0_1 * 2 + ax0_0) - v1_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused % 2 * 4 + ax1_0) - T.reads(B_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b_shared"}) - for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("B_reindex_shared_wmma.matrix_b"): - v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) - T.reads(B_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - B_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = B_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] - for ax0_0_3, ax1_0_3, ax2_0_2, ax0_0_4, ax1_0_4 in T.grid(1, 4, 2, 1, 1): - 
with T.block("C_o"): - v0_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused // 2 + ax0_0_3 + ax0_0_4) - v1_o = T.axis.spatial(8, ax1_0_4 + ax0_0_2_ax1_0_2_fused % 2 * 4 + ax1_0_3) - v2_o = T.axis.reduce(8, ax2_0_0 * 4 + ax2_0_1 * 2 + ax2_0_2) - T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v2_o * 16 : v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16 : v2_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(C_reindex_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_f32", "warp_execution":1}) - with T.init(): - for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("C_init"): - v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1]) - T.reads() - T.writes(C_reindex_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init]) - C_reindex_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init] = T.float32(0) - for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16): - with T.block("C"): - v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1]) - T.reads(C_reindex_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i], A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) - T.writes(C_reindex_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"}) - C_reindex_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + T.cast(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], "float32") * T.cast(B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i], "float32") - for ax0_0, ax1_0 in T.grid(1, 4): - with T.block("C_reindex_wmma.accumulator_o"): - v0_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused // 2 + ax0_0) - v1_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused % 2 * 4 + ax1_0) - T.reads(C_reindex_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(C[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_f32_global"}) - for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("C_reindex_wmma.accumulator"): - v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) - T.reads(C_reindex_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - T.writes(C[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - C[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] - for i0, i1 in T.grid(128, 128): - with T.block("compute"): - i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) - T.reads(C[i0_1, i1_1]) - T.writes(compute[i0_1, i1_1]) - compute[i0_1, i1_1] = T.max(C[i0_1, i1_1], T.float32(0)) - # fmt: on - decision_0 = [ - ("SamplePerfectTile", [1, 1, 8, 1, 1]), - ("SamplePerfectTile", [1, 1, 2, 4, 1]), - ("SamplePerfectTile", [2, 2, 2]), - ("SampleCategorical", 0), - ("SampleCategorical", 0), - ] - mod = te.create_prim_func( - te_workload.matmul_relu( - n=128, - m=128, - k=128, - in_dtype="float16", - out_dtype="float32", - ) - ) - actual = generate_design_space( - kind="cuda", - mod=mod, - target=tvm.target.Target("cuda"), - types=None, - sch_rules=[multi_level_tiling_tensor_core(write_reuse_scope="global")] - + get_rules("cuda", ms.schedule_rule.AutoInline), - ) - check_sketches( - mod, - sketches=actual, - expected_mods=[matmul_relu_global_0], - expected_decisions=[decision_0], - ) - - def 
test_matmul_relu_non_tensorizable(): # expected to do nothing on non-tensorizable workloads mod = te.create_prim_func( @@ -842,7 +741,7 @@ def test_matmul_relu_non_tensorizable(): mod=mod, target=tvm.target.Target("cuda"), types=None, - sch_rules=[multi_level_tiling_tensor_core(write_reuse_scope="global")] + sch_rules=[multi_level_tiling_tensor_core(write_reuse_scope="shared")] + get_rules("cuda", ms.schedule_rule.AutoInline), ) tvm.ir.assert_structural_equal(mod, sch.mod["main"]) @@ -856,40 +755,40 @@ def padded_matmul_relu_0(A: T.Buffer((127, 127), "float16"), B: T.Buffer((127, 1 T.func_attr({"global_symbol": "main", "tir.noalias": True}) # body # with T.block("root") - C_reindex_shared = T.alloc_buffer([128, 128], dtype="float32", scope="shared") - C_reindex_shared_wmma_accumulator = T.alloc_buffer([128, 128], dtype="float32", scope="wmma.accumulator") - A_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope="shared") - B_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope="shared") - A_reindex_shared_wmma_matrix_a = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_a") - B_reindex_shared_wmma_matrix_b = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_b") + C_reindex_shared = T.alloc_buffer((4, 8, 2, 1, 16, 16), scope="shared") + C_reindex_shared_wmma_accumulator = T.alloc_buffer((4, 8, 2, 1, 16, 16), scope="wmma.accumulator") + A_reindex_shared = T.alloc_buffer((128, 128), "float16", scope="shared") + B_reindex_shared = T.alloc_buffer((128, 128), "float16", scope="shared") + A_reindex_shared_wmma_matrix_a = T.alloc_buffer((128, 128), "float16", scope="wmma.matrix_a") + B_reindex_shared_wmma_matrix_b = T.alloc_buffer((128, 128), "float16", scope="wmma.matrix_b") for ax0_0_0_ax1_0_0_fused in T.thread_binding(8, thread="blockIdx.y"): for ax0_0_1_ax1_0_1_fused in T.thread_binding(2, thread="blockIdx.x"): for ax0_0_2_ax1_0_2_fused in T.thread_binding(2, thread="threadIdx.y"): - for ax2_0_0 in T.serial(1): - for ax0_ax1_fused in T.serial(4096): + for ax2_0_0 in range(1): + for ax0_ax1_fused in range(4096): with T.block("A_reindex_shared"): v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused // 2 * 32 + ax0_ax1_fused // 128) v1 = T.axis.spatial(128, ax0_ax1_fused % 128) T.reads(A[v0, v1]) T.writes(A_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":8}) - A_reindex_shared[v0, v1] = T.if_then_else(v0 < 127 and v1 < 127, A[v0, v1], T.float16(0), dtype="float16") - for ax0_ax1_fused in T.serial(4096): + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 8}) + A_reindex_shared[v0, v1] = T.if_then_else(v0 < 127 and v1 < 127, A[v0, v1], T.float16(0)) + for ax0_ax1_fused in range(4096): with T.block("B_reindex_shared"): v0 = T.axis.spatial(128, ax0_ax1_fused // 32) v1 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused % 2 * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_fused % 32) T.reads(B[v0, v1]) T.writes(B_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":1}) - B_reindex_shared[v0, v1] = T.if_then_else(v0 < 127 and v1 < 127, B[v0, v1], T.float16(0), dtype="float16") - for ax2_0_1 in T.serial(4): + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 1}) + B_reindex_shared[v0, v1] = T.if_then_else(v0 < 127 and v1 < 127, B[v0, v1], T.float16(0)) + for ax2_0_1 in range(4): for ax0_0, ax1_0 in T.grid(2, 2): with T.block("A_reindex_shared_wmma.matrix_a_o"): v0_o = 
T.axis.spatial(8, ax0_0_0_ax1_0_0_fused // 2 * 2 + ax0_0) v1_o = T.axis.spatial(8, ax2_0_1 * 2 + ax1_0) - T.reads(A_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a_shared"}) + T.reads(A_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_a_shared"}) for ax0_1, ax1_1 in T.grid(16, 16): with T.block("A_reindex_shared_wmma.matrix_a"): v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) @@ -900,9 +799,9 @@ def padded_matmul_relu_0(A: T.Buffer((127, 127), "float16"), B: T.Buffer((127, 1 with T.block("B_reindex_shared_wmma.matrix_b_o"): v0_o = T.axis.spatial(8, ax2_0_1 * 2 + ax0_0) v1_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused + ax1_0) - T.reads(B_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b_shared"}) + T.reads(B_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_b_shared"}) for ax0_1, ax1_1 in T.grid(16, 16): with T.block("B_reindex_shared_wmma.matrix_b"): v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) @@ -914,45 +813,56 @@ def padded_matmul_relu_0(A: T.Buffer((127, 127), "float16"), B: T.Buffer((127, 1 v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused // 2 * 2 + ax0_0_3 * 2 + ax0_0_4) v1_o = T.axis.spatial(8, ax1_0_4 + ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused + ax1_0_3) v2_o = T.axis.reduce(8, ax2_0_0 * 8 + ax2_0_1 * 2 + ax2_0_2) - T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v2_o * 16 : v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16 : v2_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_f32", "warp_execution":1}) + T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16:v2_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1}) with T.init(): for ax0_1, ax1_1 in T.grid(16, 16): with T.block("C_init"): v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1]) T.reads() - T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init]) - C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init] = T.float32(0) + T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i_init, v1_i_init]) + C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i_init, v1_i_init] = T.float32(0) for ax0_1, ax1_1, ax2_1 in 
T.grid(16, 16, 16): with T.block("C"): v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1]) - T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i], A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) - T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"}) - C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + T.cast(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], "float32") * T.cast(B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i], "float32") - for ax0_0, ax1_0 in T.grid(2, 1): - with T.block("C_reindex_shared_wmma.accumulator_o"): - v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused // 2 * 2 + ax0_0) - v1_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused + ax1_0) - T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(C_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_f32_shared"}) - for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("C_reindex_shared_wmma.accumulator"): - v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) - T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - T.writes(C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] - for ax0_ax1_fused in T.serial(1024): - with T.block("C_reindex_shared"): - T.where(ax0_0_0_ax1_0_0_fused // 2 * 32 + ax0_ax1_fused // 32 < 127 and ax0_0_0_ax1_0_0_fused % 2 * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_fused % 32 < 127) - v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused // 2 * 32 + ax0_ax1_fused // 32) - v1 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused % 2 * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_fused % 32) - T.reads(C_reindex_shared[v0, v1]) - T.writes(compute[v0, v1]) - T.block_attr({"meta_schedule.cooperative_fetch":4}) - compute[v0, v1] = T.max(C_reindex_shared[v0, v1], T.float32(0)) + T.reads(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i, v1_i], A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i, v1_i]) + T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) + C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i, v1_i] = C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i, v1_i] + T.Cast("float32", A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i]) * T.Cast("float32", B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + for ax2 in range(2): + for ax0_ax1_fused in T.thread_binding(2, thread="threadIdx.y"): + for ax2_1, ax3 in T.grid(1, 1): + with T.block("C_reindex_shared_wmma.accumulator_o"): + v0 = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused // 2) + v1 = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_ax1_fused) + v2 = T.axis.spatial(2, ax2 + ax2_1) + v3 = T.axis.spatial(1, ax3) + v4_o = T.axis.spatial(1, 0) + v5_o = T.axis.spatial(1, 0) + 
T.reads(C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, 0:16, 0:16]) + T.writes(C_reindex_shared[v0, v1, v2, v3, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_store_16x16x16_f32_shared"}) + for ax4, ax5 in T.grid(16, 16): + with T.block("C_reindex_shared_wmma.accumulator"): + v4_i, v5_i = T.axis.remap("SS", [ax4, ax5]) + T.reads(C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i]) + T.writes(C_reindex_shared[v0, v1, v2, v3, v4_i, v5_i]) + C_reindex_shared[v0, v1, v2, v3, v4_i, v5_i] = C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i] + for ax0_ax1_ax3_ax4_ax5_fused in range(512): + with T.block("C_reindex_shared"): + v0 = T.axis.spatial(4, T.Add(ax0_0_0_ax1_0_0_fused // 2, 0)) + v1 = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_ax1_ax3_ax4_ax5_fused % 512 // 256) + v2 = T.axis.spatial(2, ax2) + v3 = T.axis.spatial(1, 0) + v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 256 // 16) + v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16) + T.where(ax0_0_0_ax1_0_0_fused // 2 * 32 + ax2 * 16 + ax0_ax1_ax3_ax4_ax5_fused % 256 // 16 < 127 and ax0_0_0_ax1_0_0_fused % 2 * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_ax3_ax4_ax5_fused % 512 // 256 * 16 + ax0_ax1_ax3_ax4_ax5_fused % 16 < 127) + T.reads(C_reindex_shared[v0, v1, v2, v3, v4, v5]) + T.writes(compute[v4 + v2 * 16 + v0 * 32, v5 + v1 * 16]) + T.block_attr({"meta_schedule.cooperative_fetch": 4}) + compute[v4 + v2 * 16 + v0 * 32, v5 + v1 * 16] = T.max(C_reindex_shared[v0, v1, v2, v3, v4, v5], T.float32(0)) + # fmt: on decision_0 = [ @@ -994,25 +904,25 @@ def test_conv_1x1(): @T.prim_func def conv2d_1x1_0(inputs: T.Buffer((1, 16, 16, 64), "float16"), weight: T.Buffer((1, 1, 64, 64), "float16"), conv2d_nhwc: T.Buffer((1, 16, 16, 64), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - conv2d_nhwc_reindex_shared = T.alloc_buffer([256, 64], dtype="float32", scope="shared") - conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer([256, 64], dtype="float32", scope="wmma.accumulator") - PadInput_reindex_shared = T.alloc_buffer([256, 64], dtype="float16", scope="shared") - weight_reindex_shared = T.alloc_buffer([1, 1, 64, 64], dtype="float16", scope="shared") - PadInput_reindex_shared_wmma_matrix_a = T.alloc_buffer([256, 64], dtype="float16", scope="wmma.matrix_a") - weight_reindex_shared_wmma_matrix_b = T.alloc_buffer([1, 1, 64, 64], dtype="float16", scope="wmma.matrix_b") + conv2d_nhwc_reindex_shared = T.alloc_buffer((16, 4, 1, 1, 16, 16), scope="shared") + conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer((16, 4, 1, 1, 16, 16), scope="wmma.accumulator") + PadInput_reindex_shared = T.alloc_buffer((256, 64), "float16", scope="shared") + weight_reindex_shared = T.alloc_buffer((1, 1, 64, 64), "float16", scope="shared") + PadInput_reindex_shared_wmma_matrix_a = T.alloc_buffer((256, 64), "float16", scope="wmma.matrix_a") + weight_reindex_shared_wmma_matrix_b = T.alloc_buffer((1, 1, 64, 64), "float16", scope="wmma.matrix_b") for ax2_0_0_ax3_0_0_fused in T.thread_binding(16, thread="blockIdx.y"): for ax2_0_1_ax3_0_1_fused in T.thread_binding(2, thread="blockIdx.x"): for ax2_0_2_ax3_0_2_fused in T.thread_binding(2, thread="threadIdx.y"): for ax0_0, ax1_0, ax4_0_0 in T.grid(1, 1, 1): - for ax0_ax1_fused in T.serial(1024): + for ax0_ax1_fused in range(1024): with T.block("PadInput_reindex_shared"): v0 = T.axis.spatial(256, ax2_0_0_ax3_0_0_fused // 2 * 32 + ax2_0_1_ax3_0_1_fused * 16 + ax0_ax1_fused // 64) v1 = 
T.axis.spatial(64, ax0_ax1_fused % 64) T.reads(inputs[v0 // 256, v0 // 16, v0 % 16, v1]) T.writes(PadInput_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":1}) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 1}) PadInput_reindex_shared[v0, v1] = inputs[v0 // 256, v0 // 16, v0 % 16, v1] - for ax0_ax1_ax2_ax3_fused in T.serial(2048): + for ax0_ax1_ax2_ax3_fused in range(2048): with T.block("weight_reindex_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(1, 0) @@ -1020,16 +930,16 @@ def conv2d_1x1_0(inputs: T.Buffer((1, 16, 16, 64), "float16"), weight: T.Buffer( v3 = T.axis.spatial(64, ax2_0_0_ax3_0_0_fused % 2 * 32 + ax0_ax1_ax2_ax3_fused % 32) T.reads(weight[v0, v1, v2, v3]) T.writes(weight_reindex_shared[v0, v1, v2, v3]) - T.block_attr({"buffer_dim_align":[[0, 2, 32, 8]], "meta_schedule.cooperative_fetch":4}) + T.block_attr({"buffer_dim_align": [[0, 2, 32, 8]], "meta_schedule.cooperative_fetch": 4}) weight_reindex_shared[v0, v1, v2, v3] = weight[v0, v1, v2, v3] for ax0_1, ax1_1, ax4_0_1 in T.grid(1, 1, 1): for ax0_0_1, ax1_0_1 in T.grid(1, 4): with T.block("PadInput_reindex_shared_wmma.matrix_a_o"): v0_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused // 2 * 2 + ax2_0_1_ax3_0_1_fused + ax0_0_1) v1_o = T.axis.spatial(4, ax1_0_1) - T.reads(PadInput_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a_shared"}) + T.reads(PadInput_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_a_shared"}) for ax0_1_1, ax1_1_1 in T.grid(16, 16): with T.block("PadInput_reindex_shared_wmma.matrix_a"): v0_i, v1_i = T.axis.remap("SS", [ax0_1_1, ax1_1_1]) @@ -1040,9 +950,9 @@ def conv2d_1x1_0(inputs: T.Buffer((1, 16, 16, 64), "float16"), weight: T.Buffer( with T.block("weight_reindex_shared_wmma.matrix_b_o"): v0, v1, v2_o = T.axis.remap("SSS", [ax0, ax1, ax2_0]) v3_o = T.axis.spatial(4, ax2_0_0_ax3_0_0_fused % 2 * 2 + ax2_0_2_ax3_0_2_fused + ax3_0) - T.reads(weight_reindex_shared[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) - T.writes(weight_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b_shared"}) + T.reads(weight_reindex_shared[v0, v1, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) + T.writes(weight_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_b_shared"}) for ax2_1, ax3_1 in T.grid(16, 16): with T.block("weight_reindex_shared_wmma.matrix_b"): v2_i, v3_i = T.axis.remap("SS", [ax2_1, ax3_1]) @@ -1056,44 +966,53 @@ def conv2d_1x1_0(inputs: T.Buffer((1, 16, 16, 64), "float16"), weight: T.Buffer( v2_o = T.axis.spatial(16, ax2_0_4 + ax2_0_0_ax3_0_0_fused // 2 * 2 + ax2_0_1_ax3_0_1_fused + ax2_0_3) v3_o = T.axis.spatial(4, ax3_0_4 + ax2_0_0_ax3_0_0_fused % 2 * 2 + ax2_0_2_ax3_0_2_fused + ax3_0_3) v4_o = T.axis.reduce(4, ax4_0_0 * 4 + ax4_0_1 * 4 + ax4_0_2) - T.reads(PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 : v2_o * 16 + 16, v4_o * 16 : v4_o * 16 + 16], 
weight_reindex_shared_wmma_matrix_b[v0, v1, v4_o * 16 : v4_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) - T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_f32", "warp_execution":1}) + T.reads(PadInput_reindex_shared_wmma_matrix_a[v2_o * 16:v2_o * 16 + 16, v4_o * 16:v4_o * 16 + 16], weight_reindex_shared_wmma_matrix_b[v0, v1, v4_o * 16:v4_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o, v3_o, 0, 0, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1}) with T.init(): for ax2_1, ax3_1 in T.grid(16, 16): with T.block("conv2d_nhwc_init"): v2_i_init, v3_i_init = T.axis.remap("SS", [ax2_1, ax3_1]) T.reads() - T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i_init, v3_o * 16 + v3_i_init]) - conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i_init, v3_o * 16 + v3_i_init] = T.float32(0) + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o, v3_o, 0, 0, v2_i_init, v3_i_init]) + conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o, v3_o, 0, 0, v2_i_init, v3_i_init] = T.float32(0) for ax2_1, ax3_1, ax4_1 in T.grid(16, 16, 16): with T.block("conv2d_nhwc"): v2_i, v3_i, v4_i = T.axis.remap("SSR", [ax2_1, ax3_1, ax4_1]) - T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i], PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i], weight_reindex_shared_wmma_matrix_b[v0, v1, v4_o * 16 + v4_i, v3_o * 16 + v3_i]) - T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i]) - T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"}) - conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i] + T.cast(PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i], "float32") * T.cast(weight_reindex_shared_wmma_matrix_b[v0, v1, v4_o * 16 + v4_i, v3_o * 16 + v3_i], "float32") - for ax0_0, ax1_0 in T.grid(1, 1): - with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"): - v0_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused // 2 * 2 + ax2_0_1_ax3_0_1_fused + ax0_0) - v1_o = T.axis.spatial(4, ax2_0_0_ax3_0_0_fused % 2 * 2 + ax2_0_2_ax3_0_2_fused + ax1_0) - T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_f32_shared"}) - for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator"): - v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) - T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - conv2d_nhwc_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] - for ax0_ax1_fused in T.serial(512): - with T.block("conv2d_nhwc_reindex_shared"): - v0 = T.axis.spatial(256, ax2_0_0_ax3_0_0_fused // 2 * 32 + ax2_0_1_ax3_0_1_fused * 16 + ax0_ax1_fused // 32) - v1 = T.axis.spatial(64, 
ax2_0_0_ax3_0_0_fused % 2 * 32 + ax0_ax1_fused % 32) - T.reads(conv2d_nhwc_reindex_shared[v0, v1]) - T.writes(conv2d_nhwc[v0 // 256, v0 // 16, v0 % 16, v1]) - T.block_attr({"meta_schedule.cooperative_fetch":2}) - conv2d_nhwc[v0 // 256, v0 // 16, v0 % 16, v1] = conv2d_nhwc_reindex_shared[v0, v1] + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o, v3_o, 0, 0, v2_i, v3_i], PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i], weight_reindex_shared_wmma_matrix_b[v0, v1, v4_o * 16 + v4_i, v3_o * 16 + v3_i]) + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o, v3_o, 0, 0, v2_i, v3_i]) + T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) + conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o, v3_o, 0, 0, v2_i, v3_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o, v3_o, 0, 0, v2_i, v3_i] + T.Cast("float32", PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i]) * T.Cast("float32", weight_reindex_shared_wmma_matrix_b[v0, v1, v4_o * 16 + v4_i, v3_o * 16 + v3_i]) + for ax2 in range(1): + for ax0_ax1_fused in T.thread_binding(2, thread="threadIdx.y"): + for ax2_1, ax3 in T.grid(1, 1): + with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"): + v0 = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused // 2 * 2 + ax2_0_1_ax3_0_1_fused) + v1 = T.axis.spatial(4, ax2_0_0_ax3_0_0_fused % 2 * 2 + ax0_ax1_fused) + v2, v3 = T.axis.remap("SS", [ax2_1, ax3]) + v4_o = T.axis.spatial(1, 0) + v5_o = T.axis.spatial(1, 0) + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0, v1, v2, v3, 0:16, 0:16]) + T.writes(conv2d_nhwc_reindex_shared[v0, v1, v2, v3, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_store_16x16x16_f32_shared"}) + for ax4, ax5 in T.grid(16, 16): + with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator"): + v4_i, v5_i = T.axis.remap("SS", [ax4, ax5]) + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i]) + T.writes(conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4_i, v5_i]) + conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4_i, v5_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i] + for ax0_ax1_ax3_ax4_ax5_fused in range(512): + with T.block("conv2d_nhwc_reindex_shared"): + v0 = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused // 2 * 2 + ax2_0_1_ax3_0_1_fused) + v1 = T.axis.spatial(4, ax2_0_0_ax3_0_0_fused % 2 * 2 + ax0_ax1_ax3_ax4_ax5_fused // 256) + v2 = T.axis.spatial(1, ax2) + v3 = T.axis.spatial(1, 0) + v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 256 // 16) + v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16) + T.reads(conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5]) + T.writes(conv2d_nhwc[(v4 + v0 * 16) // 256, (v4 + v0 * 16) // 16, (v4 + v0 * 16) % 16, v5 + v1 * 16]) + T.block_attr({"meta_schedule.cooperative_fetch": 2}) + conv2d_nhwc[(v4 + v0 * 16) // 256, (v4 + v0 * 16) // 16, (v4 + v0 * 16) % 16, v5 + v1 * 16] = conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5] # fmt: on decision_0 = [ diff --git a/tests/python/unittest/test_meta_schedule_space_cuda.py b/tests/python/unittest/test_meta_schedule_space_cuda.py index bc674064d1d66..ef662ed5b1e78 100644 --- a/tests/python/unittest/test_meta_schedule_space_cuda.py +++ b/tests/python/unittest/test_meta_schedule_space_cuda.py @@ -315,7 +315,7 @@ def cap_0(inputs: T.Buffer((1, 16, 16, 4, 4, 32), "float32"), weight: T.Buffer(( with T.block("PadInput_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(18, i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused // 64 * 4 + i6_0 + ax0_ax1_ax2_ax3_ax4_ax5_fused % 48 // 16) - v2 = 
T.axis.spatial(18, i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused % 64 // 8 * 2 + i7_0 + 0) + v2 = T.axis.spatial(18, T.Add(i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused % 64 // 8 * 2 + i7_0, 0)) v3 = T.axis.spatial(4, i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused % 8 // 4 * 2 + ax0_ax1_ax2_ax3_ax4_ax5_fused % 16 // 8) v4 = T.axis.spatial(4, i8_0 * 2 + ax0_ax1_ax2_ax3_ax4_ax5_fused % 8 // 4) v5 = T.axis.spatial(32, i9_0 * 4 + ax0_ax1_ax2_ax3_ax4_ax5_fused % 4) @@ -493,9 +493,9 @@ def dil_0(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, for ax0_ax1_ax2_ax3_fused in T.serial(217): with T.block("PadInput_shared"): v0 = T.axis.spatial(1, 0) - v1 = T.axis.spatial(230, i0_0_i1_0_i2_0_i3_0_fused // 2 * 2 + i4_0 * 2 + 0) + v1 = T.axis.spatial(230, T.Add(i0_0_i1_0_i2_0_i3_0_fused // 2 * 2 + i4_0 * 2, 0)) v2 = T.axis.spatial(230, i5_0 * 2 + ax0_ax1_ax2_ax3_fused % 217) - v3 = T.axis.spatial(3, i6_0 + 0) + v3 = T.axis.spatial(3, T.Add(i6_0, 0)) T.reads(inputs[v0, v1 - 3, v2 - 3, v3]) T.writes(PadInput_shared[v0, v1, v2, v3]) T.block_attr({"meta_schedule.cooperative_fetch":2}) diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py index d8a853ff5dbf2..bef682435ebf0 100644 --- a/tests/python/unittest/test_target_codegen_llvm.py +++ b/tests/python/unittest/test_target_codegen_llvm.py @@ -525,7 +525,7 @@ def _show_info(): ]: for dstart, dend in [ (-11, -1), - (-11, 0), + (-11, 1), (-4, -4), (-2, -2), (1, 11), @@ -534,7 +534,7 @@ def _show_info(): (2, 2), (-11, 11), ]: - if end < start or dend < dstart or (dend == 0 and dstart == 0): + if end < start or dend < dstart or (dend == 0 and dstart == 0) or dend == 0: continue check(start, end, dstart, dend, "int32", floor_div=False) check(start, end, dstart, dend, "int32", floor_div=True) diff --git a/tests/python/unittest/test_tir_schedule_cache_read_write.py b/tests/python/unittest/test_tir_schedule_cache_read_write.py index be91505f3d154..cf75768ec0e3a 100644 --- a/tests/python/unittest/test_tir_schedule_cache_read_write.py +++ b/tests/python/unittest/test_tir_schedule_cache_read_write.py @@ -59,6 +59,58 @@ def elementwise_shape_int64(a: T.handle, c: T.handle) -> None: C[vi, vj] = B[vi, vj] + 1.0 +@T.prim_func +def elementwise_reindex_cache_read( + A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32") +): + B = T.alloc_buffer((128, 128)) + B_shared = T.alloc_buffer((128, 64, 2), scope="shared") + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi, vj]) + T.writes(B[vi, vj]) + B[vi, vj] = A[vi, vj] * T.float32(2) + for i, j in T.grid(128, 128): + with T.block("B_shared"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(B[vi, vj]) + T.writes(B_shared[vj, vi // 2, vi % 2]) + B_shared[vj, vi // 2, vi % 2] = B[vi, vj] + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(B_shared[vj, vi // 2, vi % 2]) + T.writes(C[vi, vj]) + C[vi, vj] = B_shared[vj, vi // 2, vi % 2] + T.float32(1) + + +@T.prim_func +def elementwise_reindex_cache_write( + A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32") +): + B = T.alloc_buffer((128, 128)) + B_shared = T.alloc_buffer((128, 128), scope="shared") + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi, vj]) + T.writes(B_shared[vj, vi]) + B_shared[vj, vi] = A[vi, vj] * T.float32(2) + for i, j in T.grid(128, 128): + with T.block("B_shared"): + vi, vj = T.axis.remap("SS", [i, j]) + 
T.reads(B_shared[vj, vi]) + T.writes(B[vi, vj]) + B[vi, vj] = B_shared[vj, vi] + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(B[vi, vj]) + T.writes(C[vi, vj]) + C[vi, vj] = B[vi, vj] + T.float32(1) + + @T.prim_func def func_nested_seq(b: T.handle, c: T.handle) -> None: A = T.alloc_buffer((128, 128)) @@ -216,6 +268,39 @@ def func_multi_consumer() -> None: C[vi] = A[vi] +@T.prim_func +def reindex_cache_read_multi_consumer() -> None: + A = T.alloc_buffer((128,)) + B = T.alloc_buffer((128,)) + C = T.alloc_buffer((128,)) + A_shared = T.alloc_buffer((4, 32), scope="shared") + for i in range(8): + for j in range(16): + with T.block("A"): + vi = T.axis.spatial(128, i * 16 + j) + T.reads() + T.writes(A[vi]) + A[vi] = T.float32(1) + for j in range(16): + with T.block("A_shared"): + vi = T.axis.spatial(128, i * 16 + j) + T.reads(A[vi]) + T.writes(A_shared[vi // 32, vi % 32]) + A_shared[vi // 32, vi % 32] = A[vi] + for j in range(16): + with T.block("B"): + vi = T.axis.spatial(128, i * 16 + j) + T.reads(A_shared[vi // 32, vi % 32]) + T.writes(B[vi]) + B[vi] = A_shared[vi // 32, vi % 32] + T.float32(1) + for i in range(128): + with T.block("C"): + vi = T.axis.spatial(128, i) + T.reads(A[vi]) + T.writes(C[vi]) + C[vi] = A[vi] + + @T.prim_func def func_multi_producer() -> None: A = T.alloc_buffer((128)) @@ -1336,5 +1421,60 @@ def test_cache_write_allocate_const(): verify_trace_roundtrip(sch=sch, mod=cache_write_allocate_const) +def test_reindex_cache_read(): + sch = tir.Schedule(elementwise, debug_mask="all") + sch.reindex_cache_read("C", 0, "shared", lambda i, j: (j, i // 2, i % 2)) + tvm.ir.assert_structural_equal(elementwise_reindex_cache_read, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=elementwise) + + +def test_reindex_cache_read_multi_consumer(): + sch = tir.Schedule(func_multi_consumer) + sch.reindex_cache_read("B", 0, "shared", lambda i: (i // 32, i % 32)) + tvm.ir.assert_structural_equal(reindex_cache_read_multi_consumer, sch.mod["main"]) + # NOTE(zihao): we do not verify trace roundtrip because of in set analysis issues. 
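The tests above exercise the new reindex_cache_read schedule primitive on the elementwise and func_multi_consumer fixtures. For reference, here is a minimal, illustrative sketch of driving the same primitive outside the test harness; the workload only mirrors the elementwise fixture, and every name in it is ours rather than part of the patch.

import tvm
from tvm.script import tir as T


@T.prim_func
def elementwise_example(
    A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")
):
    B = T.alloc_buffer((128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * T.float32(2)
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = B[vi, vj] + T.float32(1)


sch = tvm.tir.Schedule(elementwise_example, debug_mask="all")
# Route read 0 of block "C" through a shared-memory cache and remap the cache
# buffer layout from (i, j) to (j, i // 2, i % 2) in a single call.
sch.reindex_cache_read("C", 0, "shared", lambda i, j: (j, i // 2, i % 2))
print(sch.mod["main"].script())

As the expected elementwise_reindex_cache_read prim_func shows, the cache buffer shape, (128, 64, 2) here, is derived directly from the index map.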
+ + +def test_reindex_cache_read_fail_not_match(): + sch = tir.Schedule(elementwise, debug_mask="all") + with pytest.raises(tvm.tir.ScheduleError): + sch.reindex_cache_read( + "C", + 0, + "shared", + lambda i, j: j * 2, + ) + + +def test_reindex_cache_read_faile_not_single_point(): + sch = tir.Schedule(access_under_scope, debug_mask="all") + with pytest.raises(tvm.tir.ScheduleError): + sch.reindex_cache_read("scope", 0, "shared", lambda i, j: (i, j)) + + +def test_reindex_cache_write(): + sch = tir.Schedule(elementwise, debug_mask="all") + sch.reindex_cache_write("B", 0, "shared", lambda i, j: (j, i)) + tvm.ir.assert_structural_equal(elementwise_reindex_cache_write, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=elementwise) + + +def test_reindex_cache_write_fail_not_match(): + sch = tir.Schedule(elementwise, debug_mask="all") + with pytest.raises(tvm.tir.ScheduleError): + sch.reindex_cache_write( + "B", + 0, + "shared", + lambda i, j: i, + ) + + +def test_reindex_cache_write_fail_not_single_point(): + sch = tir.Schedule(access_under_scope, debug_mask="all") + with pytest.raises(tvm.tir.ScheduleError): + sch.reindex_cache_write("scope", 0, "shared", lambda i, j: (i, j)) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/unittest/test_tir_transform_common_subexpr_elim.py b/tests/python/unittest/test_tir_transform_common_subexpr_elim.py index 1755a66ec9fb2..5ba2824e74dde 100644 --- a/tests/python/unittest/test_tir_transform_common_subexpr_elim.py +++ b/tests/python/unittest/test_tir_transform_common_subexpr_elim.py @@ -22,6 +22,7 @@ from tvm.ir.module import IRModule from tvm.script import tir as T + # ----------------------------------------------------- # Basic test for the expected Behavior of the CSE pass # ----------------------------------------------------- @@ -359,8 +360,7 @@ def func_distributivity_expected( i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32 ) -> None: B = T.Buffer((50,), "int32") - cse_var_1 = T.int32() - with T.let(cse_var_1, x * y + x * z): + with T.LetStmt(x * y + x * z) as cse_var_1: B[i1] = cse_var_1 B[i2] = cse_var_1 @@ -377,8 +377,7 @@ def func_associativity_expected( i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32 ) -> None: B = T.Buffer((50,), "int32") - cse_var_1 = T.int32() - with T.let(cse_var_1, (x + y) + z): + with T.LetStmt((x + y) + z) as cse_var_1: B[i1] = cse_var_1 B[i2] = cse_var_1 diff --git a/tests/python/unittest/test_tir_transform_hoist_expression.py b/tests/python/unittest/test_tir_transform_hoist_expression.py index ca37915597a5c..a0b624a15c31d 100644 --- a/tests/python/unittest/test_tir_transform_hoist_expression.py +++ b/tests/python/unittest/test_tir_transform_hoist_expression.py @@ -15,11 +15,10 @@ # specific language governing permissions and limitations # under the License. 
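The expected CSE functions above replace the two-step spelling, a separate cse_var_1 = T.int32() declaration followed by with T.let(cse_var_1, value):, with the T.LetStmt context manager, which creates and yields the bound variable itself. A small standalone sketch of the new form, closely mirroring func_distributivity_expected; the function name here is illustrative only.

from tvm.script import tir as T


@T.prim_func
def letstmt_example(
    i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32
) -> None:
    B = T.Buffer((50,), "int32")
    # T.LetStmt(value) binds a fresh variable to `value`; the `as` target is
    # that variable, so no separate T.int32() declaration is needed.
    with T.LetStmt(x * y + x * z) as cse_var_1:
        B[i1] = cse_var_1
        B[i2] = cse_var_1

When an already-declared variable has to be bound, for example a buffer's data pointer, the var= keyword is used instead, as the USMP tests further down show: T.LetStmt(value, var=existing_var).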
import tvm -from tvm import tir import tvm.testing - +from tvm import tir from tvm.script import tir as T -from tvm.tir.transform import HoistExpression, HoistedConditionals, HoistedLetBindings +from tvm.tir.transform import HoistedConditionals, HoistedLetBindings, HoistExpression class BaseBeforeAfter: @@ -448,7 +447,7 @@ class TestHoistLetExpr(BaseBeforeAfter): def before(A: T.Buffer((4, 4), "float32")): for i, j in T.grid(4, 4): x = T.float32() - A[i, j] = T.Let(x, T.cast(i + 1, "float32"), 5.0 * x + T.cast(j, "float32")) + A[i, j] = T.Let(5.0 * x + T.cast(j, "float32"), where={x: T.cast(i + 1, "float32")}) @T.prim_func def expected(A: T.Buffer((4, 4), "float32")): @@ -467,7 +466,7 @@ class TestSuppressHoistLetExpr(BaseBeforeAfter): def before(A: T.Buffer((4, 4), "float32")): for i, j in T.grid(4, 4): x = T.float32() - A[i, j] = T.Let(x, T.cast(i + 1, "float32"), 5.0 * x + T.cast(j, "float32")) + A[i, j] = T.Let(5.0 * x + T.cast(j, "float32"), where={x: T.cast(i + 1, "float32")}) expected = before diff --git a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py index 1e5fd8843ba31..b9f35ed553e1e 100644 --- a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py +++ b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py @@ -1124,7 +1124,7 @@ def ref(A: T.Buffer((16, 16), "float32"), C: T.Buffer((16, 16), "float32")): B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared") with T.block(): T.reads(A[tx, 0]) - T.writes(B[0, tx, 0]) + T.writes(B[T.FloorMod(0, 2), tx, 0]) with T.attr(0, "async_commit_queue_scope", 0): with T.attr(0, "async_scope", 1): B[T.FloorMod(0, 2), tx, 0] = A[tx, 0] * T.float32(2) @@ -1350,8 +1350,8 @@ def ref(A: T.Buffer((16, 16), "float32"), D: T.Buffer((16, 16), "float32")) -> N B[i % 2, tx, 0] = A[tx, i] * T.float32(2) with T.block(): T.where(i == 1 and i - 1 < 16) - T.reads(B[(i + 1) % 2, tx, 0]) - T.writes(C[(i + 1) % 2, tx, 0]) + T.reads(B[(i - 1) % 2, tx, 0]) + T.writes(C[(i - 1) % 2, tx, 0]) with T.attr(0, "async_commit_queue_scope", 1): with T.attr(0, "async_wait_queue_scope", 0): with T.attr(0, "async_wait_inflight_count", 1): @@ -1366,14 +1366,14 @@ def ref(A: T.Buffer((16, 16), "float32"), D: T.Buffer((16, 16), "float32")) -> N with T.block(): T.where(i + 2 < 16) T.reads(A[tx, i + 2]) - T.writes(B[i % 2, tx, 0]) + T.writes(B[(i + 2) % 2, tx, 0]) with T.attr(0, "async_commit_queue_scope", 0): with T.attr(0, "async_scope", 1): B[(i + 2) % 2, tx, 0] = A[tx, i + 2] * T.float32(2) with T.block(): T.where(i + 2 - 1 < 16) - T.reads(B[(i + 1) % 2, tx, 0]) - T.writes(C[(i + 1) % 2, tx, 0]) + T.reads(B[(i - 1 + 2) % 2, tx, 0]) + T.writes(C[(i - 1 + 2) % 2, tx, 0]) with T.attr(0, "async_commit_queue_scope", 1): with T.attr(0, "async_wait_queue_scope", 0): with T.attr(0, "async_wait_inflight_count", 1): @@ -1394,8 +1394,8 @@ def ref(A: T.Buffer((16, 16), "float32"), D: T.Buffer((16, 16), "float32")) -> N for i in T.unroll(2): with T.block(): T.where(i + 16 - 1 < 16) - T.reads(B[(i + 1) % 2, tx, 0]) - T.writes(C[(i + 1) % 2, tx, 0]) + T.reads(B[(i - 1 + 16) % 2, tx, 0]) + T.writes(C[(i - 1 + 16) % 2, tx, 0]) with T.attr(0, "async_commit_queue_scope", 1): with T.attr(0, "async_wait_queue_scope", 0): with T.attr(0, "async_wait_inflight_count", 0 - i): diff --git a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py index d32714938424e..beb20fd43ba6d 100644 --- 
a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py +++ b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py @@ -182,10 +182,10 @@ def before_func(): def expected_func(): B_data = T.allocate([4], "int32x4", "shared") B = T.Buffer([4], "int32x4", data=B_data, scope="shared") - B[T.Mul(0, 4) / 4] = T.broadcast(0, 4) - B[T.Mul(1, 4) / 4] = T.broadcast(1, 4) - B[T.Mul(2, 4) / 4] = T.broadcast(2, 4) - B[T.Mul(3, 4) / 4] = T.broadcast(3, 4) + B[T.Div(T.Mul(0, 4), 4)] = T.broadcast(0, 4) + B[T.Div(T.Mul(1, 4), 4)] = T.broadcast(1, 4) + B[T.Div(T.Mul(2, 4), 4)] = T.broadcast(2, 4) + B[T.Div(T.Mul(3, 4), 4)] = T.broadcast(3, 4) before_mod = tvm.IRModule.from_expr(before_func) intermediate_mod = tvm.tir.transform.InjectVirtualThread()(before_mod) diff --git a/tests/python/unittest/test_tir_transform_unroll_loop.py b/tests/python/unittest/test_tir_transform_unroll_loop.py index a76e6135b3c45..a05a085eeb641 100644 --- a/tests/python/unittest/test_tir_transform_unroll_loop.py +++ b/tests/python/unittest/test_tir_transform_unroll_loop.py @@ -134,7 +134,49 @@ def main(): tvm.ir.assert_structural_equal(after, expected) +def test_unroll_local_access(): + @tvm.script.ir_module + class Before: + @T.prim_func + def main(B: T.Buffer((64,), "float32")): + for bx in T.thread_binding(4, thread="blockIdx.x"): + for tx in T.thread_binding(4, thread="threadIdx.x"): + A_local_data = T.allocate([4], dtype="float32", scope="local") + A_local = T.Buffer([4], dtype="float32", data=A_local_data) + for i in T.serial(4): + A_local[i] = T.float32(i) + + @tvm.script.ir_module + class Expected: + @T.prim_func + def main(B: T.Buffer((64,), "float32")): + for bx in T.thread_binding(4, thread="blockIdx.x"): + for tx in T.thread_binding(4, thread="threadIdx.x"): + A_local_data = T.allocate([4], dtype="float32", scope="local") + A_local = T.Buffer([4], dtype="float32", data=A_local_data) + A_local[0] = T.float32(0) + A_local[1] = T.float32(1) + A_local[2] = T.float32(2) + A_local[3] = T.float32(3) + + with tvm.transform.PassContext( + config={ + "tir.UnrollLoop": { + "auto_max_depth": 0, + "auto_max_extent": 1, + "explicit_unroll": True, + "unroll_local_access": True, + } + } + ): + after = tvm.tir.transform.UnrollLoop()(Before) + after = tvm.tir.transform.Simplify()(after) + + tvm.ir.assert_structural_equal(after, Expected) + + if __name__ == "__main__": + test_unroll_local_access() test_unroll_loop() test_unroll_fake_loop() test_unroll_single_count_loops() diff --git a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py index 58f37f04967d9..d0403fcae9387 100644 --- a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py +++ b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py @@ -164,7 +164,7 @@ def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: slow_memory_7_buffer_var = T.match_buffer(slow_memory_7_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16) # body tensor_2_let = T.Buffer([200704], dtype="uint8") - with T.let(tensor_2_let.data, T.address_of(fast_memory_6_buffer_var[0], dtype="handle")): + with T.LetStmt(T.address_of(fast_memory_6_buffer_var[0], dtype="handle"), var=tensor_2_let.data): for ax0_ax1_fused_4, ax2_4 in T.grid(56, 56): for ax3_init in T.serial(0, 64): tensor_2_let[ax0_ax1_fused_4 * 3584 + ax2_4 * 64 + ax3_init] = T.uint8(0) @@ -194,12 +194,12 @@ def 
tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholde slow_memory_5_buffer_var = T.match_buffer(slow_memory_5_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16) # body PaddedInput_7_let = T.Buffer([157323], "int16") - with T.let(PaddedInput_7_let.data, T.address_of(slow_memory_5_buffer_var[802816], dtype="handle")): + with T.LetStmt(T.address_of(slow_memory_5_buffer_var[802816], dtype="handle"), var=PaddedInput_7_let.data): for i0_i1_fused_7, i2_7, i3_7 in T.grid(229, 229, 3): PaddedInput_7_let[i0_i1_fused_7 * 687 + i2_7 * 3 + i3_7] = T.if_then_else(2 <= i0_i1_fused_7 and i0_i1_fused_7 < 226 and 2 <= i2_7 and i2_7 < 226, placeholder_65[i0_i1_fused_7 * 672 + i2_7 * 3 + i3_7 - 1350], T.int16(0), dtype="int16") for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544): Conv2dOutput_7_let = T.Buffer([64], "int32") - with T.let(Conv2dOutput_7_let.data, T.address_of(fast_memory_4_buffer_var[0], dtype="handle")): + with T.LetStmt(T.address_of(fast_memory_4_buffer_var[0], dtype="handle"), var=Conv2dOutput_7_let.data): for ff_3 in T.serial(0, 64): Conv2dOutput_7_let[ff_3] = 0 for ry_2, rx_2, rc_7 in T.grid(7, 7, 3): @@ -399,12 +399,12 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s global_workspace_5_buffer_var = T.match_buffer(global_workspace_5_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) # body PaddedInput_3_let = T.Buffer([360000], 'int16') - with T.let(PaddedInput_3_let.data, T.address_of(global_workspace_5_buffer_var[6480000], dtype="handle")): + with T.LetStmt(T.address_of(global_workspace_5_buffer_var[6480000], dtype="handle"), var=PaddedInput_3_let.data): for i0_i1_fused_3, i2_3, i3_3 in T.grid(75, 75, 64): PaddedInput_3_let[i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3] = placeholder_29[i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3] for ax0_ax1_fused_ax2_fused_3 in T.serial(0, 5625): Conv2dOutput_3_let = T.Buffer([64], 'int32') - with T.let(Conv2dOutput_3_let.data, T.address_of(global_workspace_5_buffer_var[7200000], dtype="handle")): + with T.LetStmt(T.address_of(global_workspace_5_buffer_var[7200000], dtype="handle"), var=Conv2dOutput_3_let.data): for ax3_outer_2 in T.serial(0, 4): for ff_3 in T.serial(0, 64): Conv2dOutput_3_let[ff_3] = 0 @@ -422,12 +422,12 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s global_workspace_4_buffer_var = T.match_buffer(global_workspace_4_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) # body PaddedInput_2_let = T.Buffer([360000], "int16") - with T.let(PaddedInput_2_let.data, T.address_of(global_workspace_4_buffer_var[7200000], dtype="handle")): + with T.LetStmt(T.address_of(global_workspace_4_buffer_var[7200000], dtype="handle"), var=PaddedInput_2_let.data): for i0_i1_fused_2, i2_2, i3_2 in T.grid(75, 75, 64): PaddedInput_2_let[i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2] = placeholder_19[i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2] for ax0_ax1_fused_ax2_fused_2 in T.serial(0, 5625): Conv2dOutput_2_let = T.Buffer([64], 'int32') - with T.let(Conv2dOutput_2_let.data, T.address_of(global_workspace_4_buffer_var[7920000], dtype="handle")): + with T.LetStmt(T.address_of(global_workspace_4_buffer_var[7920000], dtype="handle"), var=Conv2dOutput_2_let.data): for ax3_outer_1 in T.serial(0, 4): for ff_2 in T.serial(0, 64): Conv2dOutput_2_let[ff_2] = 0 @@ -445,12 +445,12 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place global_workspace_2_buffer_var = T.match_buffer(global_workspace_2_var, [7920256], 
dtype="uint8", strides=[1], elem_offset=0, align=16) # body PaddedInput_let = T.Buffer([360000], "int16") - with T.let(PaddedInput_let.data, T.address_of(global_workspace_2_buffer_var[7200000], dtype="handle")): + with T.LetStmt(T.address_of(global_workspace_2_buffer_var[7200000], dtype="handle"), var=PaddedInput_let.data): for i0_i1_fused, i2, i3 in T.grid(75, 75, 64): PaddedInput_let[i0_i1_fused * 4800 + i2 * 64 + i3] = placeholder_7[i0_i1_fused * 4800 + i2 * 64 + i3] for ax0_ax1_fused_ax2_fused in T.serial(0, 5625): Conv2dOutput_let = T.Buffer([64], "int32") - with T.let(Conv2dOutput_let.data, T.address_of(global_workspace_2_buffer_var[7920000], dtype="handle")): + with T.LetStmt(T.address_of(global_workspace_2_buffer_var[7920000], dtype="handle"), var=Conv2dOutput_let.data): for ff in T.serial(0, 64): Conv2dOutput_let[ff] = 0 for rc in T.serial(0, 64): @@ -467,12 +467,12 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(pla global_workspace_3_buffer_var = T.match_buffer(global_workspace_3_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) # body PaddedInput_1_let = T.Buffer([379456], "int16") - with T.let(PaddedInput_1_let.data, T.address_of(global_workspace_3_buffer_var[0], dtype="handle")): + with T.LetStmt(T.address_of(global_workspace_3_buffer_var[0], dtype="handle"), var=PaddedInput_1_let.data): for i0_i1_fused_1, i2_1, i3_1 in T.grid(77, 77, 64): PaddedInput_1_let[i0_i1_fused_1 * 4928 + i2_1 * 64 + i3_1] = T.if_then_else(1 <= i0_i1_fused_1 and i0_i1_fused_1 < 76 and 1 <= i2_1 and i2_1 < 76, placeholder_13[i0_i1_fused_1 * 4800 + i2_1 * 64 + i3_1 - 4864], T.int16(0), dtype="int16") for ax0_ax1_fused_ax2_fused_1 in T.serial(0, 5625): Conv2dOutput_1_let = T.Buffer([64], "int32") - with T.let(Conv2dOutput_1_let.data, T.address_of(global_workspace_3_buffer_var[7200000], dtype="handle")): + with T.LetStmt(T.address_of(global_workspace_3_buffer_var[7200000], dtype="handle"), var=Conv2dOutput_1_let.data): for ff_1 in T.serial(0, 64): Conv2dOutput_1_let[ff_1] = 0 for ry, rx, rc_1 in T.grid(3, 3, 64): @@ -562,7 +562,9 @@ def tensor_intrin_primfunc(global_workspace_1_var: T.handle("uint8")) -> None: global_workspace_1_var, [40], dtype="uint8", strides=[1], elem_offset=0, align=16 ) dense_let = T.Buffer([10], "int32") - with T.let(dense_let.data, T.address_of(global_workspace_1_buffer_var[0], dtype="handle")): + with T.LetStmt( + T.address_of(global_workspace_1_buffer_var[0], dtype="handle"), var=dense_let.data + ): T.evaluate( T.call_extern( "intrin_function", diff --git a/tests/python/unittest/test_tvmscript_ir_builder_tir.py b/tests/python/unittest/test_tvmscript_ir_builder_tir.py index 889f0c9eda33b..5599d2f7c69af 100644 --- a/tests/python/unittest/test_tvmscript_ir_builder_tir.py +++ b/tests/python/unittest/test_tvmscript_ir_builder_tir.py @@ -283,7 +283,7 @@ def test_ir_builder_tir_assert(): def test_ir_builder_tir_let(): with IRBuilder() as ib: - with T.let(T.int32(), tir.IntImm("int32", 2)): + with T.LetStmt(tir.IntImm("int32", 2)) as v: T.evaluate(0) # the let binding generated by IRBuilder let_actual = ib.get() diff --git a/tests/python/unittest/test_tvmscript_printer_structural_equal.py b/tests/python/unittest/test_tvmscript_printer_structural_equal.py index 1b9e0fa9beabc..5c587354cc3f8 100644 --- a/tests/python/unittest/test_tvmscript_printer_structural_equal.py +++ b/tests/python/unittest/test_tvmscript_printer_structural_equal.py @@ -21,7 +21,7 @@ from tvm.ir import assert_structural_equal from tvm.relay.op.transform import split 
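The IR builder side follows the same pattern: in test_ir_builder_tir_let above, T.LetStmt(value) opens a frame and yields the freshly bound variable. Below is a hedged, self-contained sketch of that API; the assert and the explicit naming call are only for illustration.

from tvm import tir
from tvm.script.ir_builder import IRBuilder
from tvm.script.ir_builder import tir as T

with IRBuilder() as ib:
    # Bind a new int32 variable to the constant 2 for the body below.
    with T.LetStmt(tir.IntImm("int32", 2)) as v:
        ib.name("v", v)  # give the variable a readable name for printing
        T.evaluate(0)

let_stmt = ib.get()
assert isinstance(let_stmt, tir.LetStmt)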
from tvm.runtime import ObjectPath -from tvm.script import tir as T +from tvm.script import ir as I, tir as T def _error_message(exception): @@ -68,21 +68,35 @@ def func2(a: T.handle, b: T.handle): def test_evaluate(): - @T.prim_func - def func1(): - T.evaluate(0) - - @T.prim_func - def func2(): - T.evaluate(1) + @I.ir_module + class module1: + @T.prim_func + def func(): + T.evaluate(0) + + @I.ir_module + class module2: + @T.prim_func + def func(): + T.evaluate(1) with pytest.raises(ValueError) as ve: - assert_structural_equal(func1, func2) + assert_structural_equal(module1, module2) assert _error_message(ve.value) == _expected_result( - func1, - func2, - ObjectPath.root().attr("body").attr("value").attr("value"), - ObjectPath.root().attr("body").attr("value").attr("value"), + module1, + module2, + ObjectPath.root() + .attr("functions") + .map_value(module1.get_global_var("func")) + .attr("body") + .attr("value") + .attr("value"), + ObjectPath.root() + .attr("functions") + .map_value(module2.get_global_var("func")) + .attr("body") + .attr("value") + .attr("value"), ) diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py b/tests/python/unittest/test_tvmscript_printer_tir.py index 13aaacb3b7584..171d49b6191b7 100644 --- a/tests/python/unittest/test_tvmscript_printer_tir.py +++ b/tests/python/unittest/test_tvmscript_printer_tir.py @@ -248,14 +248,14 @@ def test_for(): def test_let_stmt(): with IRBuilder() as ib: - with T.let(T.float32(), T.float32(10)): + with T.LetStmt(T.float32(10)) as v: + ib.name("v", v) T.evaluate(0) obj = ib.get() _assert_print( obj, """ -v = T.float32() -with T.let(v, T.float32(10)): +with T.LetStmt(T.float32(10)) as v: T.evaluate(0) """, ) @@ -468,7 +468,7 @@ def test_size_var(): _assert_print( a, """ -a = T.float32() +a = T.float32(is_size_var=True) a""", ) @@ -501,13 +501,12 @@ def test_cast(): def test_binary_arith(): - a = tir.Var("a", "float32") - b = tir.Var("b", "float32") + a = tir.Var("a", "int32") + b = tir.Var("b", "int32") for op, sign in [ (tir.Add, "+"), (tir.Sub, "-"), (tir.Mul, "*"), - (tir.Div, "/"), (tir.Mod, "truncmod"), (tir.FloorDiv, "//"), (tir.FloorMod, "%"), @@ -521,21 +520,60 @@ def test_binary_arith(): obj = op(a, b) if sign.isalpha(): expected = """ -a = T.float32() -b = T.float32() +a = T.int32() +b = T.int32() T.{}(a, b)""".format( sign ) else: expected = """ -a = T.float32() -b = T.float32() +a = T.int32() +b = T.int32() a {} b""".format( sign ) _assert_print(obj, expected) +def test_binary_arith_const(): + a = tir.IntImm("int64", 3) + b = tir.IntImm("int64", 4) + for op, name in [ + (tir.Add, "Add"), + (tir.Sub, "Sub"), + (tir.Mul, "Mul"), + (tir.Div, "Div"), + (tir.Mod, "truncmod"), + (tir.FloorDiv, "FloorDiv"), + (tir.FloorMod, "FloorMod"), + (tir.LT, "LT"), + (tir.LE, "LE"), + (tir.EQ, "EQ"), + (tir.NE, "NE"), + (tir.GT, "GT"), + (tir.GE, "GE"), + ]: + obj = op(a, b) + expected = """ +T.{}({}, {})""".format( + name, str(a), str(b) + ) + _assert_print(obj, expected) + + +def test_int_div(): + a = tir.Var("a", "int32") + b = tir.Var("b", "int32") + _assert_print( + tir.Div(a, b), + """ +a = T.int32() +b = T.int32() +T.Div(a, b) +""", + ) + + def test_logical(): a = tir.Var("a", "bool") b = tir.Var("b", "bool") @@ -602,7 +640,7 @@ def test_let_expr(): obj, """ x = T.int32() -T.let(x, 1, x + 1) +T.Let(x + 1, where={x: 1}) """, ) diff --git a/tests/python/unittest/test_tvmscript_printer_underlining.py b/tests/python/unittest/test_tvmscript_printer_underlining.py index 7230d4546a9fc..4a4d17d0d89b7 100644 --- 
a/tests/python/unittest/test_tvmscript_printer_underlining.py +++ b/tests/python/unittest/test_tvmscript_printer_underlining.py @@ -27,7 +27,7 @@ StmtBlockDoc, ) from tvm.script.printer.doc_printer import to_python_script -from tvm.script import tir as T +from tvm.script import ir as I, tir as T def make_path(name: str) -> ObjectPath: @@ -470,3 +470,87 @@ def main(): ^^^^^^^^^^^^^ """ ) + + +def test_underline_func(): + @T.prim_func + def func(): + T.evaluate(0) + + result = func.script( + path_to_underline=[ + ObjectPath.root(), + ] + ) + assert result == format_script( + """ + # from tvm.script import tir as T + + @T.prim_func + ^^^^^^^^^^^^ + def main(): + ^^^^^^^^^^^ + T.evaluate(0) + ^^^^^^^^^^^^^ + """ + ) + + +def test_underline_func_in_irmodule(): + @I.ir_module + class irmodule: + @T.prim_func + def func(): + T.evaluate(0) + + result = irmodule.script( + path_to_underline=[ + ObjectPath.root().attr("functions").map_value(irmodule.get_global_var("func")), + ] + ) + assert result == format_script( + """ + # from tvm.script import ir as I + # from tvm.script import tir as T + + @I.ir_module + class Module: + @T.prim_func + ^^^^^^^^^^^^ + def func(): + ^^^^^^^^^^^ + T.evaluate(0) + ^^^^^^^^^^^^^ + """ + ) + + +def test_underline_irmodule(): + @I.ir_module + class irmodule: + @T.prim_func + def func(): + T.evaluate(0) + + result = irmodule.script( + path_to_underline=[ + ObjectPath.root(), + ] + ) + assert result == format_script( + """ + # from tvm.script import ir as I + # from tvm.script import tir as T + + @I.ir_module + ^^^^^^^^^^^^ + class Module: + ^^^^^^^^^^^^^ + @T.prim_func + ^^^^^^^^^^^^ + def func(): + ^^^^^^^^^^^ + T.evaluate(0) + ^^^^^^^^^^^^^ + """ + ) diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 48a59994690b8..c956f3bb02b9e 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -171,6 +171,16 @@ def mmult(A: T.handle, B: T.handle, C: T.handle) -> None: return Module +def launch_env_thread(): + @T.prim_func + def main(inputs: T.Buffer((64, 2, 4), "float32")) -> None: + bx = T.launch_thread("blockIdx.x", 64) + for i, j in T.grid(2, 4): + T.evaluate(inputs[bx, i, j]) + + return main + + def opt_gemm_mod_host(): @tvm.script.ir_module class Module: @@ -335,9 +345,9 @@ def mmult( T.attr(0, "compute_scope", "mmult_compute_") T.attr(packedB.data, "storage_scope", "global") T.attr(packedB.data, "storage_alignment", 128) - with T.let( - packedB.data, + with T.LetStmt( T.TVMBackendAllocWorkspace(1, dev_id, T.uint64(4194304), 2, 32, dtype="handle"), + var=packedB.data, ): if T.isnullptr(packedB.data, dtype="bool"): T.evaluate(T.tvm_throw_last_error(dtype="int32")) @@ -349,11 +359,11 @@ def mmult( for x_outer in T.parallel(0, 32): T.attr(C_global.data, "storage_scope", "global") T.attr(C_global.data, "storage_alignment", 128) - with T.let( - C_global.data, + with T.LetStmt( T.TVMBackendAllocWorkspace( 1, dev_id, T.uint64(4096), 2, 32, dtype="handle" ), + var=C_global.data, ): if T.isnullptr(C_global.data, dtype="bool"): T.evaluate(T.tvm_throw_last_error(dtype="int32")) @@ -3317,7 +3327,7 @@ def let_expression(): @T.prim_func def func(): x = T.int32() - T.evaluate(T.let(x, 1, x + 1)) + T.evaluate(T.Let(x + 1, where={x: 1})) return func @@ -3542,10 +3552,8 @@ def func(): def let_stmt_var(): @T.prim_func def func(): - x = T.int32() - y = T.int32() - with T.let(x, 0): - with T.let(y, 0): + with T.LetStmt(0) as x: + with T.LetStmt(0) as y: 
T.evaluate(0) T.evaluate(0) @@ -3555,17 +3563,68 @@ def func(): def let_stmt_value(): @T.prim_func def func(): - x = T.int32() y = T.int32() - with T.let(x, y): - with T.let(y, 0): + with T.LetStmt(y) as x: + with T.LetStmt(0, var=y): T.evaluate(0) T.evaluate(0) return func +def string_stride(): + @T.prim_func + def main(a: T.handle, b: T.handle): + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + n = T.int32() + A = T.match_buffer(a, (n,), strides=("A_s0",), buffer_type="auto") + B = T.match_buffer(b, (n,), strides=("B_s0",), buffer_type="auto") + blockIdx_x = T.launch_thread("blockIdx.x", (n + 63) // 64) + threadIdx_x = T.launch_thread("threadIdx.x", 64) + if T.likely(blockIdx_x * 64 + threadIdx_x < n): + B2 = T.Buffer((B.strides[0] * n,), data=B.data) + A2 = T.Buffer((A.strides[0] * n,), data=A.data) + B2[(blockIdx_x * 64 + threadIdx_x) * B.strides[0]] = A2[ + (blockIdx_x * 64 + threadIdx_x) * A.strides[0] + ] * T.float32(2) + + return main + + +def merge_shape_var_def(): + @T.prim_func + def main(A: T.handle, B: T.handle): + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + m, n = T.int32(), T.int32() + A_1 = T.match_buffer(A, (m, n), strides=("A_1_s0", "A_1_s1"), buffer_type="auto") + B_1 = T.match_buffer(B, (m, n), strides=("B_1_s0", "B_1_s1"), buffer_type="auto") + for i_outer, j_outer, i_inner in T.grid((m + 9) // 10, (n + 4) // 5, 10): + if T.likely(i_outer * 10 + i_inner < m): + for j_inner in range(5): + if T.likely(j_outer * 5 + j_inner < n): + cse_var_2: T.int32 = j_outer * 5 + j_inner + cse_var_1: T.int32 = i_outer * 10 + i_inner + B_2 = T.Buffer( + (B_1.strides[0] * m,), + data=B_1.data, + strides=("B_2_s0",), + buffer_type="auto", + ) + A_2 = T.Buffer( + (A_1.strides[0] * m,), + data=A_1.data, + strides=("A_2_s0",), + buffer_type="auto", + ) + B_2[cse_var_1 * B_1.strides[0] + cse_var_2 * B_1.strides[1]] = A_2[ + cse_var_1 * A_1.strides[0] + cse_var_2 * A_1.strides[1] + ] + + return main + + ir_generator = tvm.testing.parameter( + launch_env_thread, opt_gemm_normalize, opt_gemm_lower, opt_gemm_mod_host, @@ -3625,6 +3684,8 @@ def func(): intrinsic_pow, let_stmt_var, let_stmt_value, + string_stride, + merge_shape_var_def, ) diff --git a/tests/scripts/task_config_build_arm.sh b/tests/scripts/task_config_build_arm.sh index 516e6ac867919..bd14a91f251be 100755 --- a/tests/scripts/task_config_build_arm.sh +++ b/tests/scripts/task_config_build_arm.sh @@ -28,7 +28,7 @@ echo set\(USE_RPC ON\) >> config.cmake echo set\(USE_MICRO ON\) >> config.cmake echo set\(USE_MICRO_STANDALONE_RUNTIME ON\) >> config.cmake echo set\(USE_PROFILER ON\) >> config.cmake -echo set\(USE_LLVM llvm-config-8\) >> config.cmake +echo -e 'find_program(LLVM_CONFIG "llvm-config")\nif (LLVM_CONFIG) \n\tset(USE_LLVM llvm-config) \nelse() \n\tset(USE_LLVM llvm-config-8)\nendif()' >> config.cmake echo set\(CMAKE_CXX_FLAGS -Werror\) >> config.cmake echo set\(USE_VTA_FSIM ON\) >> config.cmake echo set\(USE_ARM_COMPUTE_LIB ON\) >> config.cmake diff --git a/tests/scripts/task_config_build_minimal_cross_isa.sh b/tests/scripts/task_config_build_minimal_cross_isa.sh index 1c08cb285d211..1b251632182ae 100755 --- a/tests/scripts/task_config_build_minimal_cross_isa.sh +++ b/tests/scripts/task_config_build_minimal_cross_isa.sh @@ -46,5 +46,5 @@ if [ "$architecture_type" != "aarch64" ]; then echo set\(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY\) >> config.cmake else # This usually runs in the ci_arm docker image. 
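The new roundtrip fixtures above (launch_env_thread, string_stride, merge_shape_var_def) are all driven through the same parametrized check: print the IR to TVMScript, parse the text back, and compare the two structurally. A rough sketch of that flow for the launch_env_thread case, written here only for illustration; the helper names are ours, and the real test goes through its ir_generator parameter.

import tvm
from tvm.script import tir as T


@T.prim_func
def launch_env_thread_example(inputs: T.Buffer((64, 2, 4), "float32")) -> None:
    bx = T.launch_thread("blockIdx.x", 64)
    for i, j in T.grid(2, 4):
        T.evaluate(inputs[bx, i, j])


script = launch_env_thread_example.script()   # print to TVMScript text
reparsed = tvm.script.from_source(script)     # parse the text back
tvm.ir.assert_structural_equal(launch_env_thread_example, reparsed)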
-    echo set\(USE_LLVM llvm-config-8\) >> config.cmake
+    echo -e 'find_program(LLVM_CONFIG "llvm-config")\nif (LLVM_CONFIG) \n\tset(USE_LLVM llvm-config) \nelse() \n\tset(USE_LLVM llvm-config-8)\nendif()' >> config.cmake
 fi
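Both CI config scripts now emit a small CMake probe into config.cmake: prefer whatever llvm-config is on the PATH and only fall back to llvm-config-8. A quick, hedged way to confirm from Python which LLVM a resulting TVM build actually picked up; the libinfo key is the conventional one but may vary between TVM versions.

import tvm
from tvm.target import codegen

# Build-time option recorded by CMake, e.g. "llvm-config" or "llvm-config-8".
print("USE_LLVM:", tvm.support.libinfo().get("USE_LLVM"))
# Major version of the LLVM library the runtime was linked against.
print("LLVM major version:", codegen.llvm_version_major())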