From 7ef8fb21d7472801756b23977186b46e1ec81e3d Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Sat, 18 Jan 2020 22:44:50 -0800 Subject: [PATCH] [REFACTOR] Establish tir (#4740) TIR is the new namespace for low-level IR for tensor-level optimizations and loop transformations. This PR establishes the namespace and files. - lowered_func.h,buffer.h,data_layout.h -> tir/buffer.h,tir/data_layout.h,tir/lowered_func.h - ir.h -> tir/expr.h, tir/stmt.h - ir_functor_ext.h -> tir/expr_functor.h, tir/stmt_functor.h --- CMakeLists.txt | 5 +- apps/extension/src/tvm_ext.cc | 3 +- apps/lldb/tvm.py | 4 +- include/tvm/arith/analyzer.h | 2 + include/tvm/arith/bound.h | 8 +- include/tvm/arith/int_set.h | 8 +- include/tvm/arith/pattern.h | 6 +- include/tvm/build_module.h | 32 +- include/tvm/codegen.h | 9 +- include/tvm/expr.h | 489 -------- include/tvm/ir/expr.h | 228 ++-- include/tvm/node/printer.h | 7 + include/tvm/relay/base.h | 2 +- include/tvm/relay/op_attr_types.h | 5 + include/tvm/relay/type.h | 4 +- include/tvm/{ => tir}/buffer.h | 18 +- include/tvm/{ => tir}/data_layout.h | 25 +- include/tvm/{ir.h => tir/expr.h} | 1058 +++++------------ include/tvm/tir/expr_functor.h | 295 +++++ include/tvm/{ => tir}/ir_pass.h | 21 +- include/tvm/{ => tir}/lowered_func.h | 25 +- include/tvm/{expr_operator.h => tir/op.h} | 307 ++--- include/tvm/tir/stmt.h | 775 ++++++++++++ .../{ir_functor_ext.h => tir/stmt_functor.h} | 276 +---- include/tvm/top/operation.h | 33 +- include/tvm/top/schedule.h | 2 +- include/tvm/top/tensor.h | 21 +- include/tvm/top/tensor_intrin.h | 2 +- python/tvm/ir_pass.py | 2 +- src/README.md | 6 +- src/api/api_arith.cc | 4 +- src/api/api_base.cc | 2 +- src/api/api_codegen.cc | 8 +- src/api/api_ir.cc | 10 +- src/api/api_lang.cc | 21 +- src/api/api_pass.cc | 15 +- src/api/api_schedule.cc | 2 +- src/api/api_test.cc | 2 +- src/arith/analyzer.cc | 12 +- src/arith/bound_deducer.cc | 8 +- src/arith/canonical_simplify.cc | 6 +- src/arith/compute_expr.h | 18 +- src/arith/const_fold.h | 48 +- src/arith/const_int_bound.cc | 6 +- src/arith/detect_linear_equation.cc | 15 +- src/arith/domain_touched.cc | 8 +- src/arith/int_operator.h | 8 +- src/arith/int_set.cc | 55 +- src/arith/interval_set.h | 2 +- src/arith/ir_mutator_with_analyzer.cc | 12 +- src/arith/ir_mutator_with_analyzer.h | 22 +- src/arith/ir_visitor_with_analyzer.h | 8 +- src/arith/modular_set.cc | 6 +- src/arith/pattern_match.h | 88 +- src/arith/rewrite_simplify.cc | 6 +- src/arith/rewrite_simplify.h | 4 +- src/arith/stmt_simplify.cc | 14 +- src/autotvm/feature_visitor.cc | 2 +- src/autotvm/feature_visitor.h | 11 +- src/autotvm/touch_extractor.h | 4 +- src/codegen/build_common.h | 3 +- src/codegen/build_module.cc | 113 +- src/codegen/codegen.cc | 12 +- src/codegen/codegen_c.cc | 10 +- src/codegen/codegen_c.h | 9 +- src/codegen/codegen_c_host.h | 2 +- src/codegen/codegen_cuda.cc | 5 +- src/codegen/codegen_cuda.h | 4 +- src/codegen/codegen_source_base.cc | 4 +- src/codegen/codegen_source_base.h | 9 +- src/codegen/codegen_vhls.h | 2 +- src/codegen/intrin_rule.cc | 2 +- src/codegen/intrin_rule.h | 6 +- src/codegen/llvm/codegen_amdgpu.cc | 1 - src/codegen/llvm/codegen_arm.cc | 12 +- src/codegen/llvm/codegen_cpu.cc | 17 +- src/codegen/llvm/codegen_llvm.cc | 9 +- src/codegen/llvm/codegen_llvm.h | 10 +- src/codegen/llvm/codegen_nvptx.cc | 1 - src/codegen/llvm/codegen_x86_64.cc | 14 +- src/codegen/llvm/intrin_rule_llvm.cc | 15 +- src/codegen/llvm/intrin_rule_llvm.h | 14 +- src/codegen/llvm/intrin_rule_nvptx.cc | 6 +- src/codegen/llvm/intrin_rule_rocm.cc | 6 +- 
src/codegen/spirv/build_vulkan.cc | 2 +- src/codegen/spirv/codegen_spirv.cc | 12 +- src/codegen/spirv/codegen_spirv.h | 8 +- src/codegen/spirv/intrin_rule_spirv.cc | 8 +- src/codegen/spirv/ir_builder.h | 2 +- src/codegen/stackvm/codegen_stackvm.cc | 3 +- src/codegen/stackvm/codegen_stackvm.h | 8 +- src/contrib/hybrid/codegen_hybrid.cc | 6 +- src/contrib/hybrid/codegen_hybrid.h | 8 +- src/ir/attr_functor.h | 154 +-- src/ir/attrs.cc | 2 +- src/ir/env_func.cc | 2 +- src/ir/expr.cc | 112 +- src/ir/transform.cc | 8 +- src/lang/expr.cc | 133 --- src/node/printer.cc | 4 + src/relay/backend/build_module.cc | 6 +- src/relay/backend/compile_engine.cc | 26 +- src/relay/backend/compile_engine.h | 4 +- .../backend/contrib/codegen_c/codegen_c.h | 2 +- src/relay/backend/graph_plan_memory.cc | 4 +- src/relay/backend/graph_runtime_codegen.cc | 14 +- src/relay/backend/interpreter.cc | 2 +- src/relay/backend/param_dict.h | 2 +- src/relay/backend/utils.h | 2 +- src/relay/backend/vm/compiler.cc | 2 + src/relay/backend/vm/compiler.h | 2 +- src/relay/backend/vm/lambda_lift.cc | 2 +- src/relay/backend/vm/removed_unused_funcs.cc | 2 +- src/relay/ir/alpha_equal.cc | 4 +- src/relay/ir/expr.cc | 4 +- src/relay/ir/hash.cc | 6 +- src/relay/ir/pretty_printer.cc | 8 +- src/relay/ir/transform.cc | 2 +- src/relay/ir/type.cc | 4 +- src/relay/op/annotation/annotation.cc | 2 +- src/relay/op/debug.cc | 2 +- src/relay/op/device_copy.cc | 2 +- src/relay/op/image/resize.cc | 2 +- src/relay/op/nn/bitserial.cc | 2 +- src/relay/op/nn/convolution.cc | 8 +- src/relay/op/nn/convolution.h | 18 +- src/relay/op/nn/nn.cc | 6 +- src/relay/op/nn/pad.cc | 20 +- src/relay/op/nn/pooling.cc | 10 +- src/relay/op/nn/sparse.cc | 2 +- src/relay/op/nn/upsampling.cc | 12 +- src/relay/op/tensor/reduce.cc | 8 +- src/relay/op/tensor/transform.cc | 32 +- src/relay/op/type_relations.cc | 10 +- src/relay/op/vision/multibox_op.cc | 2 +- src/relay/pass/alter_op_layout.cc | 2 +- src/relay/pass/canonicalize_cast.cc | 2 +- src/relay/pass/canonicalize_ops.cc | 2 +- src/relay/pass/combine_parallel_conv2d.cc | 4 +- src/relay/pass/combine_parallel_dense.cc | 2 +- src/relay/pass/combine_parallel_op_batch.cc | 2 +- src/relay/pass/convert_layout.cc | 4 +- src/relay/pass/device_annotation.cc | 4 +- src/relay/pass/eliminate_common_subexpr.cc | 2 +- src/relay/pass/fold_constant.cc | 2 +- src/relay/pass/fold_scale_axis.cc | 6 +- src/relay/pass/fuse_ops.cc | 4 +- src/relay/pass/gradient.cc | 2 +- src/relay/pass/infer_layout_util.h | 3 +- src/relay/pass/legalize.cc | 2 +- src/relay/pass/mac_count.cc | 2 +- src/relay/pass/partition_graph.cc | 4 +- src/relay/pass/pattern_util.h | 12 +- src/relay/pass/simplify_inference.cc | 2 +- src/relay/pass/transform_layout.h | 2 +- src/relay/pass/type_solver.cc | 10 +- src/relay/qnn/op/concatenate.cc | 2 +- src/relay/qnn/op/convolution.cc | 4 +- src/relay/qnn/util.h | 10 +- src/target/target.cc | 42 +- src/{lang => tir/ir}/buffer.cc | 56 +- src/{lang => tir/ir}/data_layout.cc | 30 +- src/{lang/ir.cc => tir/ir/expr.cc} | 642 +--------- src/tir/ir/expr_functor.cc | 290 +++++ src/tir/ir/functor_common.h | 56 + src/{lang => tir/ir}/lowered_func.cc | 6 +- src/{lang/expr_operator.cc => tir/ir/op.cc} | 244 ++-- src/tir/ir/stmt.cc | 532 +++++++++ .../ir_functor.cc => tir/ir/stmt_functor.cc} | 293 +---- src/{ => tir}/pass/arg_binder.cc | 18 +- src/{ => tir}/pass/arg_binder.h | 14 +- src/{ => tir}/pass/bound_checker.cc | 16 +- src/{ => tir}/pass/combine_context_call.cc | 11 +- src/{ => tir}/pass/coproc_sync.cc | 10 +- src/{ => 
tir}/pass/detect_device.cc | 10 +- src/{ => tir}/pass/hoist_if_then_else.cc | 12 +- src/{ => tir}/pass/infer_fragment.cc | 12 +- src/{ => tir}/pass/inject_copy_intrin.cc | 12 +- src/{ => tir}/pass/inject_double_buffer.cc | 12 +- src/{ => tir}/pass/inject_prefetch.cc | 10 +- src/{ => tir}/pass/inject_virtual_thread.cc | 12 +- src/{ => tir}/pass/inline.cc | 11 +- src/{ => tir}/pass/ir_deep_compare.cc | 9 +- src/{ => tir}/pass/ir_util.cc | 4 +- src/{ => tir}/pass/ir_util.h | 14 +- src/{ => tir}/pass/lift_attr_scope.cc | 8 +- src/{ => tir}/pass/loop_partition.cc | 18 +- src/{ => tir}/pass/lower_custom_datatypes.cc | 10 +- src/{ => tir}/pass/lower_intrin.cc | 22 +- src/{ => tir}/pass/lower_thread_allreduce.cc | 16 +- src/{ => tir}/pass/lower_tvm_builtin.cc | 12 +- src/{ => tir}/pass/lower_warp_memory.cc | 14 +- src/{ => tir}/pass/make_api.cc | 14 +- src/{ => tir}/pass/remap_thread_axis.cc | 10 +- src/{ => tir}/pass/remove_no_op.cc | 10 +- src/{ => tir}/pass/rewrite_unsafe_select.cc | 10 +- src/{ => tir}/pass/simple_passes.cc | 10 +- src/{ => tir}/pass/skip_assert.cc | 10 +- src/{ => tir}/pass/split_host_device.cc | 12 +- src/{ => tir}/pass/ssa.cc | 10 +- src/{ => tir}/pass/storage_access.cc | 12 +- src/{ => tir}/pass/storage_access.h | 18 +- src/{ => tir}/pass/storage_flatten.cc | 28 +- src/{ => tir}/pass/storage_rewrite.cc | 18 +- src/{ => tir}/pass/storage_sync.cc | 12 +- src/{ => tir}/pass/tensor_core.cc | 26 +- src/{ => tir}/pass/unroll_loop.cc | 14 +- src/{ => tir}/pass/vectorize_loop.cc | 12 +- src/{ => tir}/pass/verify_compact_buffer.cc | 12 +- src/{ => tir}/pass/verify_gpu_code.cc | 21 +- src/{ => tir}/pass/verify_memory.cc | 10 +- src/top/operation/compute_op.cc | 50 +- src/top/operation/compute_op.h | 4 +- src/top/operation/cross_thread_reduction.cc | 6 +- src/top/operation/extern_op.cc | 6 +- src/top/operation/hybrid_op.cc | 32 +- src/top/operation/hybrid_op.h | 6 +- src/top/operation/op_util.cc | 32 +- src/top/operation/op_util.h | 12 +- src/top/operation/scan_op.cc | 16 +- src/top/operation/tensor_compute_op.cc | 22 +- src/top/operation/tensorize.cc | 28 +- src/top/schedule/auto_inline_elem_wise.cc | 6 +- src/top/schedule/bound.cc | 2 +- src/top/schedule/graph.cc | 18 +- src/top/schedule/graph.h | 2 +- src/top/schedule/message_passing.cc | 6 +- src/top/schedule/message_passing.h | 2 +- src/top/schedule/schedule_dataflow_rewrite.cc | 68 +- src/top/schedule/schedule_lang.cc | 2 +- src/top/schedule/schedule_ops.cc | 26 +- src/top/tensor.cc | 17 +- tests/cpp/attrs_test.cc | 8 +- tests/cpp/container_test.cc | 3 +- tests/cpp/expr_test.cc | 4 +- tests/cpp/ir_functor_test.cc | 32 +- tests/cpp/ir_simplify_test.cc | 24 +- tests/cpp/ir_ssa_test.cc | 14 +- tests/cpp/packed_func_test.cc | 7 +- tests/cpp/pattern_match_test.cc | 43 +- tests/cpp/simple_passes_test.cc | 6 +- topi/include/topi/broadcast.h | 6 +- topi/include/topi/cuda/dense.h | 8 +- topi/include/topi/cuda/normalization.h | 8 +- topi/include/topi/cuda/pooling.h | 12 +- topi/include/topi/cuda/reduction.h | 8 +- topi/include/topi/cuda/softmax.h | 4 +- topi/include/topi/detail/broadcast.h | 24 +- topi/include/topi/detail/constant_utils.h | 14 +- topi/include/topi/detail/extern.h | 20 +- topi/include/topi/detail/pad_utils.h | 4 +- topi/include/topi/detail/ravel_unravel.h | 2 +- topi/include/topi/elemwise.h | 17 +- topi/include/topi/image/resize.h | 10 +- topi/include/topi/nn.h | 82 +- topi/include/topi/nn/batch_matmul.h | 2 +- topi/include/topi/nn/bias_add.h | 2 +- topi/include/topi/nn/bnn.h | 6 +- topi/include/topi/nn/dense.h | 2 +- 
topi/include/topi/nn/dilate.h | 4 +- topi/include/topi/nn/flatten.h | 2 +- topi/include/topi/nn/local_response_norm.h | 2 +- topi/include/topi/nn/pooling.h | 107 +- topi/include/topi/nn/softmax.h | 10 +- topi/include/topi/reduction.h | 22 +- topi/include/topi/transform.h | 18 +- topi/include/topi/vision/reorg.h | 2 +- 267 files changed, 4515 insertions(+), 4324 deletions(-) delete mode 100644 include/tvm/expr.h rename include/tvm/{ => tir}/buffer.h (96%) rename include/tvm/{ => tir}/data_layout.h (96%) rename include/tvm/{ir.h => tir/expr.h} (57%) create mode 100644 include/tvm/tir/expr_functor.h rename include/tvm/{ => tir}/ir_pass.h (98%) rename include/tvm/{ => tir}/lowered_func.h (91%) rename include/tvm/{expr_operator.h => tir/op.h} (88%) create mode 100644 include/tvm/tir/stmt.h rename include/tvm/{ir_functor_ext.h => tir/stmt_functor.h} (52%) delete mode 100644 src/lang/expr.cc rename src/{lang => tir/ir}/buffer.cc (92%) rename src/{lang => tir/ir}/data_layout.cc (94%) rename src/{lang/ir.cc => tir/ir/expr.cc} (52%) create mode 100644 src/tir/ir/expr_functor.cc create mode 100644 src/tir/ir/functor_common.h rename src/{lang => tir/ir}/lowered_func.cc (94%) rename src/{lang/expr_operator.cc => tir/ir/op.cc} (71%) create mode 100644 src/tir/ir/stmt.cc rename src/{pass/ir_functor.cc => tir/ir/stmt_functor.cc} (60%) rename src/{ => tir}/pass/arg_binder.cc (96%) rename src/{ => tir}/pass/arg_binder.h (96%) rename src/{ => tir}/pass/bound_checker.cc (96%) rename src/{ => tir}/pass/combine_context_call.cc (95%) rename src/{ => tir}/pass/coproc_sync.cc (99%) rename src/{ => tir}/pass/detect_device.cc (88%) rename src/{ => tir}/pass/hoist_if_then_else.cc (98%) rename src/{ => tir}/pass/infer_fragment.cc (97%) rename src/{ => tir}/pass/inject_copy_intrin.cc (97%) rename src/{ => tir}/pass/inject_double_buffer.cc (98%) rename src/{ => tir}/pass/inject_prefetch.cc (95%) rename src/{ => tir}/pass/inject_virtual_thread.cc (98%) rename src/{ => tir}/pass/inline.cc (94%) rename src/{ => tir}/pass/ir_deep_compare.cc (99%) rename src/{ => tir}/pass/ir_util.cc (98%) rename src/{ => tir}/pass/ir_util.h (96%) rename src/{ => tir}/pass/lift_attr_scope.cc (98%) rename src/{ => tir}/pass/loop_partition.cc (98%) rename src/{ => tir}/pass/lower_custom_datatypes.cc (97%) rename src/{ => tir}/pass/lower_intrin.cc (95%) rename src/{ => tir}/pass/lower_thread_allreduce.cc (97%) rename src/{ => tir}/pass/lower_tvm_builtin.cc (98%) rename src/{ => tir}/pass/lower_warp_memory.cc (98%) rename src/{ => tir}/pass/make_api.cc (98%) rename src/{ => tir}/pass/remap_thread_axis.cc (95%) rename src/{ => tir}/pass/remove_no_op.cc (97%) rename src/{ => tir}/pass/rewrite_unsafe_select.cc (97%) rename src/{ => tir}/pass/simple_passes.cc (96%) rename src/{ => tir}/pass/skip_assert.cc (91%) rename src/{ => tir}/pass/split_host_device.cc (97%) rename src/{ => tir}/pass/ssa.cc (98%) rename src/{ => tir}/pass/storage_access.cc (98%) rename src/{ => tir}/pass/storage_access.h (94%) rename src/{ => tir}/pass/storage_flatten.cc (97%) rename src/{ => tir}/pass/storage_rewrite.cc (99%) rename src/{ => tir}/pass/storage_sync.cc (98%) rename src/{ => tir}/pass/tensor_core.cc (98%) rename src/{ => tir}/pass/unroll_loop.cc (96%) rename src/{ => tir}/pass/vectorize_loop.cc (99%) rename src/{ => tir}/pass/verify_compact_buffer.cc (90%) rename src/{ => tir}/pass/verify_gpu_code.cc (93%) rename src/{ => tir}/pass/verify_memory.cc (98%) diff --git a/CMakeLists.txt b/CMakeLists.txt index 63075f3207d4..ac98e3cfd28a 100644 --- a/CMakeLists.txt +++ 
b/CMakeLists.txt @@ -133,12 +133,11 @@ file(GLOB_RECURSE COMPILER_SRCS src/top/*.cc src/api/*.cc src/autotvm/*.cc - src/lang/*.cc - src/pass/*.cc + src/tir/*.cc ) file(GLOB CODEGEN_SRCS - src/codegen/*.cc + src/codegen/*.cc ) list(APPEND COMPILER_SRCS ${CODEGEN_SRCS}) diff --git a/apps/extension/src/tvm_ext.cc b/apps/extension/src/tvm_ext.cc index b439deb75593..a92d55fc4acd 100644 --- a/apps/extension/src/tvm_ext.cc +++ b/apps/extension/src/tvm_ext.cc @@ -27,9 +27,10 @@ #include #include #include -#include +#include using namespace tvm; +using namespace tvm::tir; using namespace tvm::runtime; namespace tvm_ext { diff --git a/apps/lldb/tvm.py b/apps/lldb/tvm.py index d7779b011c0a..a2607b7baa15 100644 --- a/apps/lldb/tvm.py +++ b/apps/lldb/tvm.py @@ -46,7 +46,7 @@ def __lldb_init_module(debugger, _): "tvm::IterVarAttr", "tvm::IterVarRelation", "tvm::Layout", - "tvm::LoweredFunc", + "tir::LoweredFunc", "tvm::Map", "tvm::Map", "tvm::MemoryInfo", @@ -60,7 +60,7 @@ def __lldb_init_module(debugger, _): "tvm::TensorIntrin", "tvm::TensorIntrinCall", "tvm::TypedEnvFunc", - "tvm::Var", + "tvm::tir::Var", "tvm::ir::CommReducer", "tvm::ir::FunctionRef", "tvm::relay::BaseTensorType", diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index 78f12855cb93..31f2216b7e2b 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -50,6 +50,8 @@ namespace arith { // Forward declare Analyzer class Analyzer; +using tir::Var; + /*! * \brief Constant integer up and lower bound(inclusive). * Useful for value bound analysis. diff --git a/include/tvm/arith/bound.h b/include/tvm/arith/bound.h index e06954816148..4d77e3a4f6dc 100644 --- a/include/tvm/arith/bound.h +++ b/include/tvm/arith/bound.h @@ -26,7 +26,8 @@ #include #include #include -#include +#include +#include #include @@ -37,6 +38,11 @@ class Tensor; } namespace arith { +using tir::Var; +using tir::VarNode; +using tir::Domain; +using tir::Stmt; + /*! * \brief Deduce the bound of the target variable in a expression, * give the domain of each variables. Return undefined IntSet to diff --git a/include/tvm/arith/int_set.h b/include/tvm/arith/int_set.h index ca06cfd01c26..8b73f871f1d2 100644 --- a/include/tvm/arith/int_set.h +++ b/include/tvm/arith/int_set.h @@ -25,12 +25,16 @@ #define TVM_ARITH_INT_SET_H_ #include -#include +#include #include namespace tvm { namespace arith { +using tir::Var; +using tir::VarNode; +using tir::IterVar; + //----------------------------------------------- // Integer set data structure. // @@ -165,7 +169,7 @@ IntSet EvalSet(PrimExpr e, * \return An integer set that can cover all the possible values of e. */ IntSet EvalSet(PrimExpr e, - const std::unordered_map& dom_map); + const std::unordered_map& dom_map); /*! * \brief Find an symbolic integer set that contains is union over diff --git a/include/tvm/arith/pattern.h b/include/tvm/arith/pattern.h index a531d309f5c9..d3ba3e980430 100644 --- a/include/tvm/arith/pattern.h +++ b/include/tvm/arith/pattern.h @@ -26,7 +26,7 @@ #include #include -#include +#include namespace tvm { namespace arith { @@ -39,7 +39,7 @@ namespace arith { * \return [coeff[i]] if it is possible, empty array if it is not. */ Array DetectLinearEquation(const PrimExpr& e, - const Array& vars); + const Array& vars); /*! * \brief Detect if expression corresponds to clip bound of the vars @@ -50,7 +50,7 @@ Array DetectLinearEquation(const PrimExpr& e, * return empty if the e does not match the pattern. 
*/ Array DetectClipBound(const PrimExpr& e, - const Array& vars); + const Array& vars); } // namespace arith } // namespace tvm diff --git a/include/tvm/build_module.h b/include/tvm/build_module.h index 4e6e51744b6c..2ffb7b03f40e 100644 --- a/include/tvm/build_module.h +++ b/include/tvm/build_module.h @@ -24,9 +24,11 @@ #ifndef TVM_BUILD_MODULE_H_ #define TVM_BUILD_MODULE_H_ +#include #include #include #include +#include #include #include @@ -34,10 +36,6 @@ #include #include -#include "runtime/packed_func.h" - -#include "lowered_func.h" - namespace tvm { /*! @@ -174,11 +172,12 @@ class BuildConfig : public ::tvm::ObjectRef { * \param config The build configuration. * \return The lowered function. */ -TVM_DLL Array lower(top::Schedule sch, - const Array& args, - const std::string& name, - const std::unordered_map& binds, - const BuildConfig& config); +TVM_DLL Array lower( + top::Schedule sch, + const Array& args, + const std::string& name, + const std::unordered_map& binds, + const BuildConfig& config); /*! * \brief Split host/device function and running necessary pass before build * \param funcs The functions to be built. @@ -188,10 +187,11 @@ TVM_DLL Array lower(top::Schedule sch, * \return The Array> with 2 elements. First is host function Array, second is device function array */ -TVM_DLL Array > split_dev_host_funcs(const Array& funcs, - const Target& target, - const Target& target_host, - const BuildConfig& config); +TVM_DLL Array > split_dev_host_funcs( + const Array& funcs, + const Target& target, + const Target& target_host, + const BuildConfig& config); /*! * \brief Build a device and host module for a specific target from an array of lowered functions. @@ -201,7 +201,7 @@ TVM_DLL Array > split_dev_host_funcs(const Array * \param config The build configuration. * \return The built module. */ -TVM_DLL runtime::Module build(const Array& funcs, +TVM_DLL runtime::Module build(const Array& funcs, const Target& target, const Target& target_host, const BuildConfig& config); @@ -216,7 +216,7 @@ TVM_DLL runtime::Module build(const Array& funcs, * \param config The build configuration. * \return The built module that contains code for different processors. */ -TVM_DLL runtime::Module build(const Map>& input, +TVM_DLL runtime::Module build(const Map>& input, const Target& target_host, const BuildConfig& config); @@ -231,7 +231,7 @@ TVM_DLL runtime::Module build(const Map>& input, * \param config The build configuration. * \return The built module that contains code for different processors. */ -TVM_DLL runtime::Module build(const Map>& input, +TVM_DLL runtime::Module build(const Map>& input, const Target& target_host, const BuildConfig& config); diff --git a/include/tvm/codegen.h b/include/tvm/codegen.h index 025bc8d6bc28..202a9ce4dd8e 100644 --- a/include/tvm/codegen.h +++ b/include/tvm/codegen.h @@ -24,10 +24,11 @@ #ifndef TVM_CODEGEN_H_ #define TVM_CODEGEN_H_ +#include +#include +#include #include -#include "expr.h" -#include "lowered_func.h" -#include "runtime/packed_func.h" + namespace tvm { /*! \brief namespace for lowlevel IR pass and codegen */ @@ -45,7 +46,7 @@ using runtime::TVMRetValue; * * \note Calls global API function "_codegen_build_" + target */ -runtime::Module Build(const Array& funcs, +runtime::Module Build(const Array& funcs, const std::string& target); /*! * \brief Pack imported device library to a C file. 
diff --git a/include/tvm/expr.h b/include/tvm/expr.h deleted file mode 100644 index 3f154da8c130..000000000000 --- a/include/tvm/expr.h +++ /dev/null @@ -1,489 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/expr.h - * \brief The Expr and related elements in DataFlow construction. - */ -#ifndef TVM_EXPR_H_ -#define TVM_EXPR_H_ - -#include -#include -#include -#include -#include -#include -#include "node/node.h" -#include "node/container.h" -#include "node/functor.h" -#include "runtime/c_runtime_api.h" -#include "runtime/data_type.h" - -namespace tvm { - - -/*! \brief Base node of all statements. */ -class StmtNode : public Object { - public: - static constexpr const char* _type_key = "Stmt"; - TVM_DECLARE_BASE_OBJECT_INFO(StmtNode, Object); -}; - -/*! \brief Container of all statements */ -class Stmt : public ObjectRef { - public: - TVM_DEFINE_OBJECT_REF_METHODS(Stmt, ObjectRef, StmtNode); -}; - -class Var; -/*! - * \brief A variable node in the IR. - * - * A variable is uniquely identified by its address. - * - * Each variable is only binded once in the following nodes: - * - Allocate - * - For - * - Let - * - LetStmt - */ -class VarNode : public PrimExprNode { - public: - /*! \brief constructor */ - VarNode() {} - VarNode(DataType dtype, std::string name_hint); - - /*! - * \brief The hint to the variable name. - * \note Each variable is uniquely identified by its address. - */ - std::string name_hint; - - void VisitAttrs(AttrVisitor* v) { - v->Visit("dtype", &dtype); - v->Visit("name", &name_hint); - } - - static constexpr const char* _type_key = "Variable"; - TVM_DECLARE_BASE_OBJECT_INFO(VarNode, PrimExprNode); -}; - -/*! \brief a named variable in TVM */ -class Var : public PrimExpr { - public: - explicit Var(ObjectPtr n) : PrimExpr(n) {} - /*! \brief constructor - * \param name_hint variable name - * \param t data type - */ - TVM_DLL explicit Var(std::string name_hint = "v", - DataType t = DataType::Int(32)); - /*! - * \brief Make a new copy of var with same type, append suffix - * \param suffix The suffix to be appended. - * \return the new Var copy - */ - Var copy_with_suffix(const std::string& suffix) const { - return Var((*this)->name_hint + suffix, (*this)->dtype); - } - /*! - * \brief Get pointer to the internal value. - * \return the corresponding Variable. - */ - const VarNode* operator->() const { - return get(); - } - /*! - * \brief Get pointer to the internal value. - * \return the corresponding Variable. - */ - const VarNode* get() const { - return static_cast(data_.get()); - } - /*! \brief type indicate the container type */ - using ContainerType = VarNode; -}; - -class SizeVar; -/*! - * \brief A variable node represent a tensor index size, - * whose value must be non-negative. 
- */ -class SizeVarNode : public VarNode { - public: - /*! \brief constructor */ - SizeVarNode() {} - /*! \brief constructor - * \param dtype data type - * \param name_hint variable name - */ - SizeVarNode(DataType dtype, std::string name_hint); - - static constexpr const char* _type_key = "SizeVar"; - TVM_DECLARE_FINAL_OBJECT_INFO(SizeVarNode, VarNode); -}; - -/*! \brief a named variable represents a tensor index size */ -class SizeVar : public Var { - public: - explicit SizeVar(ObjectPtr n) : Var(n) {} - /*! \brief constructor - * \param name_hint variable name - * \param t data type - */ - TVM_DLL explicit SizeVar(std::string name_hint = "s", - DataType t = DataType::Int(32)); - /*! - * \brief Get pointer to the internal value. - * \return the corresponding Variable. - */ - const SizeVarNode* operator->() const { - return get(); - } - /*! - * \brief Get pointer to the internal value. - * \return the corresponding Variable. - */ - const SizeVarNode* get() const { - return static_cast(data_.get()); - } - /*! \brief type indicate the container type */ - using ContainerType = SizeVarNode; -}; - -/*! - * \brief Container of constant int that adds more constructors. - * - * This is used to store and automate type check - * attributes that must be constant integer. - * - * \sa IntImm - */ -class Integer : public IntImm { - public: - Integer() {} - /*! - * \brief constructor from node. - */ - explicit Integer(ObjectPtr node) : IntImm(node) {} - /*! - * \brief Construct integer from int value. - */ - Integer(int value) : IntImm(DataType::Int(32), value) {} // NOLINT(*) - /*! - * \brief Construct integer from int imm. - * \param other The other value. - */ - Integer(IntImm other) : IntImm(std::move(other)) {} // NOLINT(*) - /*! - * \brief Assign an expression to integer. - * \param other another expression. - */ - Integer& operator=(const IntImm& other) { - data_ = ObjectRef::GetDataPtr(other); - return *this; - } - /*! - * \brief convert to int64_t - */ - operator int64_t() const { - CHECK(data_ != nullptr) - << " Trying to reference a null Integer"; - return (*this)->value; - } -}; - -/*! \brief range over one dimension */ -class RangeNode : public Object { - public: - /*! \brief beginning of the node */ - PrimExpr min; - /*! \brief the extend of range */ - PrimExpr extent; - /*! \brief constructor */ - RangeNode() {} - RangeNode(PrimExpr min, PrimExpr extent) : min(min), extent(extent) {} - - void VisitAttrs(AttrVisitor* v) { - v->Visit("min", &min); - v->Visit("extent", &extent); - } - - static constexpr const char* _type_key = "Range"; - TVM_DECLARE_FINAL_OBJECT_INFO(RangeNode, Object); -}; - -/*! \brief Range constainer */ -class Range : public ObjectRef { - public: - /*! - * \brief constructor by begin and end - * \param begin The begin of the range. - * \param end The end of the range. - */ - TVM_DLL Range(PrimExpr begin, PrimExpr end); - /*! - * \brief construct a new range with min and extent - * The corresponding constructor is removed, - * because that is counter convention of tradition meaning - * of range(begin, end) - * - * \param min The minimum range. - * \param extent The extent of the range. - */ - static Range make_by_min_extent(PrimExpr min, PrimExpr extent); - // declare range. - TVM_DEFINE_OBJECT_REF_METHODS(Range, ObjectRef, RangeNode); -}; - -/*! \brief container class of iteration variable. */ -class IterVarNode; - -using Region = Array; - -/*! - * \brief Type of iteration variable. - * Each IterVar have a specific type. 
- * - * The type of iter var can be overriden via - * stage.iter_var_attrs given they are compatible. - */ -enum IterVarType : int { - /*! - * \brief Data parallel iteration. - * This normally corresponds to axis of Tensor. - * Allow all IterVar manipulations. - * - * \note This does not mean the loop - * have to be executed in parallel fashion. - */ - kDataPar = 0, - /*! - * \brief The IterVar itself is a thread-index - * of a fixed thread launching group. - * Note that this is already assumed to be paralellized. - * - * Disallow: split/fuse/vectorize/parallel - */ - kThreadIndex = 1, - /*! - * \brief Communicative reduction. - * Cannot be directly parallelized. - * - * Disallow: parallel/vectorize - */ - kCommReduce = 2, - /*! - * \brief Serial loops with loop carry dependency, - * the iteration must execute in order. - * Cannot be re-ordered. - * - * Disallow: reorder/parallel/vectorize - */ - kOrdered = 3, - /*! - * \brief IterVar is opaque, - * - * May not corresponds to any generated loop - * Disallow all IterVar manipulations and compute_at - * - * \note This is usually used to implement composite op - * or external op, where the - */ - kOpaque = 4, - // The following are possible additional - // types that are provided during schedule - /*! - * \brief The execution is unrolled. - */ - kUnrolled = 5, - /*! - * \brief The loop is vectorized. - */ - kVectorized = 6, - /*! - * \brief The loop is parallelized. - */ - kParallelized = 7, - /*! - * \brief Marks boundary of tensorization intrinsic. - */ - kTensorized = 8 -}; - -/*! - * \brief Iteration Variable, - * represents an iteration over an integer interval. - */ -class IterVar : public ObjectRef { - public: - // construct a new iter var without a domain - IterVar() {} - // construct from shared ptr. - explicit IterVar(ObjectPtr n) : ObjectRef(n) {} - /*! - * \brief access the internal node container - * \return the pointer to the internal node container - */ - inline const IterVarNode* operator->() const; - /*! - * \return the corresponding var in the IterVar. - */ - inline operator PrimExpr() const; - /*! \brief specify container node */ - using ContainerType = IterVarNode; -}; - -/*! - * \brief Create a new IterVar that represents an axis in thread. - * - * \param dom Optional, domain of the thread axis. - * \param tag The thread tag of the axis. - */ -TVM_DLL IterVar thread_axis(Range dom, std::string tag); - -/*! - * \brief Create a new IterVar for reduction operations. - * - * \param dom The domain of the reduction axis. - * \param name The name of the reduction axis. - */ -TVM_DLL IterVar reduce_axis(Range dom, std::string name = "rv"); - -using Domain = Array; - -/*! - * \brief Dump the node to stderr, used for debug purposes. - * \param node The input node - */ -TVM_DLL void Dump(const ObjectRef& node); - -// definition of Node. -/*! - * \brief An iteration variable representing an iteration - * over a one dimensional interval. - */ -class IterVarNode : public Object { - public: - /*! - * \brief the domain of iteration, if known, can be None - * For the intermediate schedule node, before schedule. - */ - Range dom; - /*! \brief The looping variable */ - Var var; - /*! \brief The type of the IterVar */ - IterVarType iter_type; - /*! - * \brief additional tag on the iteration variable, - * set this if this is binded already to a known thread tag. 
- */ - std::string thread_tag; - - void VisitAttrs(AttrVisitor* v) { - v->Visit("dom", &dom); - v->Visit("var", &var); - v->Visit("iter_type", &iter_type); - v->Visit("thread_tag", &thread_tag); - } - - TVM_DLL static IterVar make(Range dom, Var var, - IterVarType iter_type, - std::string thread_tag = ""); - - static constexpr const char* _type_key = "IterVar"; - TVM_DECLARE_FINAL_OBJECT_INFO(IterVarNode, Object); -}; - -// inline implementations -inline const IterVarNode* IterVar::operator->() const { - return static_cast(data_.get()); -} - -inline IterVar::operator PrimExpr() const { - return (*this)->var; -} - -inline const char* IterVarType2String(IterVarType t) { - switch (t) { - case kDataPar: return "DataPar"; - case kThreadIndex: return "ThreadIndex"; - case kCommReduce: return "CommReduce"; - case kOrdered: return "Ordered"; - case kOpaque: return "Opaque"; - case kUnrolled: return "Unrolled"; - case kVectorized: return "Vectorized"; - case kParallelized: return "Parallelized"; - case kTensorized: return "Tensorized"; - } - return "Unknown"; -} - -/*! - * \brief Construct a new Var expression - * \param name_hint The name hint for the expression - * \param t The type of the expression - */ -TVM_DLL Var var(std::string name_hint, DataType t = DataType::Int(32)); - -/* - * \brief Template function to convert Map to unordered_map - * Sometimes useful for API gluing when internal uses unordered_map - * \param dmap The container map - * \return The corresponding unordered_map. - * \tparam K the key of the Map. - * \tparam V the value of the Map. - */ -template -inline std::unordered_map as_unordered_map(const Map& dmap) { - std::unordered_map ret; - for (auto kv : dmap) { - ret[kv.first] = kv.second; - } - return ret; -} -} // namespace tvm - -namespace tvm { -namespace runtime { -// Additional implementattion overloads for PackedFunc. -inline TVMPODValue_::operator tvm::Integer() const { - if (type_code_ == kTVMNullptr) return Integer(); - if (type_code_ == kDLInt) { - CHECK_LE(value_.v_int64, std::numeric_limits::max()); - CHECK_GE(value_.v_int64, std::numeric_limits::min()); - return Integer(static_cast(value_.v_int64)); - } - TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle); - Object* ptr = static_cast(value_.v_handle); - CHECK(ObjectTypeChecker::Check(ptr)) - << "Expect type " << ObjectTypeChecker::TypeName() - << " but get " << ptr->GetTypeKey(); - return Integer(ObjectPtr(ptr)); -} -} // namespace runtime -} // namespace tvm - -namespace std { -template <> -struct hash<::tvm::IterVar> : public ::tvm::ObjectHash { -}; -} -#endif // TVM_EXPR_H_ diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index e8e459744c5c..61b3e13c1630 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -30,6 +30,7 @@ #include #include #include +#include #include namespace tvm { @@ -122,76 +123,6 @@ class PrimExpr : public BaseExpr { TVM_DLL static PrimExpr FromObject_(ObjectPtr ptr); }; -/*! - * \brief Constant integer literals in the program. - * \sa IntImm - */ -class IntImmNode : public PrimExprNode { - public: - /*! \brief the Internal value. */ - int64_t value; - - void VisitAttrs(AttrVisitor* v) { - v->Visit("dtype", &dtype); - v->Visit("value", &value); - } - - static constexpr const char* _type_key = "IntImm"; - TVM_DECLARE_FINAL_OBJECT_INFO(IntImmNode, PrimExprNode); -}; - -/*! - * \brief Managed reference class to IntImmNode. - * - * \sa IntImmNode - */ -class IntImm : public PrimExpr { - public: - /*! - * \brief Constructor. - * \param dtype The data type of the value. 
- * \param value The internal value. - */ - TVM_DLL IntImm(DataType dtype, int64_t value); - - TVM_DEFINE_OBJECT_REF_METHODS(IntImm, PrimExpr, IntImmNode); -}; - -/*! - * \brief Constant floating point literals in the program. - * \sa FloatImm - */ -class FloatImmNode : public PrimExprNode { - public: - /*! \brief The constant value content. */ - double value; - - void VisitAttrs(AttrVisitor* v) { - v->Visit("dtype", &dtype); - v->Visit("value", &value); - } - - static constexpr const char* _type_key = "FloatImm"; - TVM_DECLARE_FINAL_OBJECT_INFO(FloatImmNode, PrimExprNode); -}; - -/*! - * \brief Managed reference class to FloatImmNode. - * - * \sa FloatImmNode - */ -class FloatImm : public PrimExpr { - public: - /*! - * \brief Constructor. - * \param dtype The data type of the value. - * \param value The internal value. - */ - TVM_DLL FloatImm(DataType dtype, double value); - - TVM_DEFINE_OBJECT_REF_METHODS(FloatImm, PrimExpr, FloatImmNode); -}; - /*! * \brief Base node of all non-primitive expressions. * @@ -304,6 +235,163 @@ class BaseFunc : public RelayExpr { TVM_DEFINE_OBJECT_REF_METHODS(BaseFunc, RelayExpr, BaseFuncNode); }; +// PrimExprs that are useful as runtime containers. +// +/*! + * \brief Constant integer literals in the program. + * \sa IntImm + */ +class IntImmNode : public PrimExprNode { + public: + /*! \brief the Internal value. */ + int64_t value; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("dtype", &dtype); + v->Visit("value", &value); + } + + static constexpr const char* _type_key = "IntImm"; + TVM_DECLARE_FINAL_OBJECT_INFO(IntImmNode, PrimExprNode); +}; + +/*! + * \brief Managed reference class to IntImmNode. + * + * \sa IntImmNode + */ +class IntImm : public PrimExpr { + public: + /*! + * \brief Constructor. + * \param dtype The data type of the value. + * \param value The internal value. + */ + TVM_DLL IntImm(DataType dtype, int64_t value); + + TVM_DEFINE_OBJECT_REF_METHODS(IntImm, PrimExpr, IntImmNode); +}; + +/*! + * \brief Constant floating point literals in the program. + * \sa FloatImm + */ +class FloatImmNode : public PrimExprNode { + public: + /*! \brief The constant value content. */ + double value; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("dtype", &dtype); + v->Visit("value", &value); + } + + static constexpr const char* _type_key = "FloatImm"; + TVM_DECLARE_FINAL_OBJECT_INFO(FloatImmNode, PrimExprNode); +}; + +/*! + * \brief Managed reference class to FloatImmNode. + * + * \sa FloatImmNode + */ +class FloatImm : public PrimExpr { + public: + /*! + * \brief Constructor. + * \param dtype The data type of the value. + * \param value The internal value. + */ + TVM_DLL FloatImm(DataType dtype, double value); + + TVM_DEFINE_OBJECT_REF_METHODS(FloatImm, PrimExpr, FloatImmNode); +}; + +/*! + * \brief Container of constant int that adds more constructors. + * + * This is used to store and automate type check + * attributes that must be constant integer. + * + * \sa IntImm + */ +class Integer : public IntImm { + public: + Integer() {} + /*! + * \brief constructor from node. + */ + explicit Integer(ObjectPtr node) : IntImm(node) {} + /*! + * \brief Construct integer from int value. + */ + Integer(int value) : IntImm(DataType::Int(32), value) {} // NOLINT(*) + /*! + * \brief Construct integer from int imm. + * \param other The other value. + */ + Integer(IntImm other) : IntImm(std::move(other)) {} // NOLINT(*) + /*! + * \brief Assign an expression to integer. + * \param other another expression. 
*/ + Integer& operator=(const IntImm& other) { + data_ = ObjectRef::GetDataPtr(other); + return *this; + } + /*! + * \brief convert to int64_t + */ + operator int64_t() const { + CHECK(data_ != nullptr) + << " Trying to reference a null Integer"; + return (*this)->value; + } +}; + +/*! \brief range over one dimension */ +class RangeNode : public Object { + public: + /*! \brief beginning of the node */ + PrimExpr min; + /*! \brief the extent of the range */ + PrimExpr extent; + /*! \brief constructor */ + RangeNode() {} + RangeNode(PrimExpr min, PrimExpr extent) : min(min), extent(extent) {} + + void VisitAttrs(AttrVisitor* v) { + v->Visit("min", &min); + v->Visit("extent", &extent); + } + + static constexpr const char* _type_key = "Range"; + TVM_DECLARE_FINAL_OBJECT_INFO(RangeNode, Object); +}; + +/*! \brief Range container */ +class Range : public ObjectRef { + public: + /*! + * \brief constructor by begin and end + * \param begin The begin of the range. + * \param end The end of the range. + */ + TVM_DLL Range(PrimExpr begin, PrimExpr end); + /*! + * \brief construct a new range with min and extent + * The corresponding constructor is removed, + * because it runs counter to the traditional meaning + * of range(begin, end) + * + * \param min The minimum range. + * \param extent The extent of the range. + */ + static Range make_by_min_extent(PrimExpr min, PrimExpr extent); + // declare range. + TVM_DEFINE_OBJECT_REF_METHODS(Range, ObjectRef, RangeNode); +}; + // implementataions inline const Type& RelayExprNode::checked_type() const { CHECK(checked_type_.defined()) diff --git a/include/tvm/node/printer.h b/include/tvm/node/printer.h index 1e6c3e5bb8fb..a4c6a696633c 100644 --- a/include/tvm/node/printer.h +++ b/include/tvm/node/printer.h @@ -46,6 +46,13 @@ class NodePrinter { using FType = NodeFunctor; TVM_DLL static FType& vtable(); }; + +/*! + * \brief Dump the node to stderr, used for debug purposes. + * \param node The input node + */ +TVM_DLL void Dump(const ObjectRef& node); + } // namespace tvm namespace tvm { diff --git a/include/tvm/relay/base.h b/include/tvm/relay/base.h index 45d060a88458..e00329d4d3ed 100644 --- a/include/tvm/relay/base.h +++ b/include/tvm/relay/base.h @@ -26,7 +26,7 @@ #include -#include +#include #include #include #include diff --git a/include/tvm/relay/op_attr_types.h b/include/tvm/relay/op_attr_types.h index 385d6453fae9..51e9111d5a31 100644 --- a/include/tvm/relay/op_attr_types.h +++ b/include/tvm/relay/op_attr_types.h @@ -29,11 +29,16 @@ #include #include #include +#include #include namespace tvm { namespace relay { +using tir::Layout; +using tir::LayoutAxis; +using tir::BijectiveLayoutNode; + /*! \brief operator pattern used in graph fusion */ enum OpPatternKind { // Elementwise operation diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index 0f81a1badc7c..adf1380eecb9 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -29,7 +29,7 @@ #include #include #include -#include +#include #include #include "base.h" @@ -40,7 +40,7 @@ namespace relay { // namespace update for backward compact // will be removed later.
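For illustration only (not part of the patch): a minimal sketch of how downstream code compiles after this move, assuming just the headers and aliases shown above; the function name Demo is hypothetical.

    #include <tvm/ir/expr.h>          // Integer and Range now live here
    #include <tvm/tir/data_layout.h>  // Layout moved under tvm::tir

    void Demo() {  // hypothetical
      tvm::Integer four = 4;  // implicit int constructor still applies
      tvm::Range r = tvm::Range::make_by_min_extent(0, four);
      // relay headers keep plain `Layout` working via `using tir::Layout;`.
      tvm::tir::Layout nchw("NCHW");
      (void)r; (void)nchw;
    }

The `using` aliases keep the relay headers source-compatible while the physical file layout changes underneath them.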
-using Any = tvm::ir::AnyNode; +using Any = tvm::tir::AnyNode; using Kind = TypeKind; using Type = tvm::Type; using TypeNode = tvm::TypeNode; diff --git a/include/tvm/buffer.h b/include/tvm/tir/buffer.h similarity index 96% rename from include/tvm/buffer.h rename to include/tvm/tir/buffer.h index db5334e0649f..c1723168d40c 100644 --- a/include/tvm/buffer.h +++ b/include/tvm/tir/buffer.h @@ -18,20 +18,21 @@ */ /*! - * \file tvm/buffer.h + * \file tvm/tir/buffer.h * \brief Symbolic n-dimensional array, to represent a memory buffer. */ -#ifndef TVM_BUFFER_H_ -#define TVM_BUFFER_H_ +#ifndef TVM_TIR_BUFFER_H_ +#define TVM_TIR_BUFFER_H_ + +#include +#include +#include #include -#include "expr.h" -#include "expr_operator.h" -#include "tvm/node/container.h" namespace tvm { - +namespace tir { // Internal node container Buffer class BufferNode; @@ -186,5 +187,6 @@ inline const BufferNode* Buffer::operator->() const { TVM_DLL Buffer decl_buffer(Array shape, DataType dtype = DataType::Float(32), std::string name = "buffer"); +} // namespace tir } // namespace tvm -#endif // TVM_BUFFER_H_ +#endif // TVM_TIR_BUFFER_H_ diff --git a/include/tvm/data_layout.h b/include/tvm/tir/data_layout.h similarity index 96% rename from include/tvm/data_layout.h rename to include/tvm/tir/data_layout.h index 0f79f9f61e98..52870c669e9d 100644 --- a/include/tvm/data_layout.h +++ b/include/tvm/tir/data_layout.h @@ -18,15 +18,16 @@ */ /*! - * \file tvm/data_layout.h + * \file tvm/tir/data_layout.h * \brief Layout expression to describe the data organization of a tensor. * And BijectiveLayout to mapping two data layouts between each other. */ -#ifndef TVM_DATA_LAYOUT_H_ -#define TVM_DATA_LAYOUT_H_ +#ifndef TVM_TIR_DATA_LAYOUT_H_ +#define TVM_TIR_DATA_LAYOUT_H_ -#include +#include +#include #include #include @@ -34,16 +35,16 @@ #include #include -#include "expr_operator.h" namespace tvm { +namespace tir { class LayoutAxis { public: static const LayoutAxis& Get(const char name); // Get the singleton LayoutAxis using itvar->var->name_hint - static const LayoutAxis& Get(const IterVar& itvar); + static const LayoutAxis& Get(const tir::IterVar& itvar); // Get the singleton LayoutAxis using name[0] (size of name must be 1). static const LayoutAxis& make(const std::string& name); @@ -102,7 +103,7 @@ class LayoutNode : public Object { * it is a variable for a primal axis, but a constant for a subordinate axis. * Empty for scalar's layout. */ - Array axes; + Array axes; void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); @@ -132,7 +133,7 @@ class Layout : public ObjectRef { /*! \brief default constructor */ Layout() = default; - explicit Layout(const Array& axes); + explicit Layout(const Array& axes); /*! \brief construct from a string */ Layout(const char* name) : Layout(std::string(name)) {} // NOLINT(*) @@ -264,7 +265,7 @@ class Layout : public ObjectRef { */ bool Contains(const LayoutAxis& axis) const { if (!defined()) return false; - for (const IterVar var : operator->()->axes) { + for (const tir::IterVar var : operator->()->axes) { if (var->var->name_hint == axis.name()) { return true; } @@ -276,7 +277,7 @@ class Layout : public ObjectRef { CHECK(defined()) << "Try to access axis from an undefined layout."; int32_t index = i < 0 ? 
static_cast(ndim() + i) : i; CHECK(index >= 0 && static_cast(index) < ndim()) << "Invalid index " << i; - const IterVar axis = operator->()->axes[index]; + const tir::IterVar axis = operator->()->axes[index]; return LayoutAxis::Get(axis); } @@ -371,7 +372,7 @@ class BijectiveLayout : public ObjectRef { inline const BijectiveLayoutNode* BijectiveLayout::operator->() const { return static_cast(get()); } - +} // namespace tir } // namespace tvm -#endif // TVM_DATA_LAYOUT_H_ +#endif // TVM_TIR_DATA_LAYOUT_H_ diff --git a/include/tvm/ir.h b/include/tvm/tir/expr.h similarity index 57% rename from include/tvm/ir.h rename to include/tvm/tir/expr.h index ff4b47ffca12..b4787d6b0aed 100644 --- a/include/tvm/ir.h +++ b/include/tvm/tir/expr.h @@ -16,28 +16,309 @@ * specific language governing permissions and limitations * under the License. */ + /*! - * \file tvm/ir.h - * \brief Additional high level nodes in the IR + * \file tvm/tir/expr.h + * \brief TIR expressions. */ -// Acknowledgement: Most low-level IR nodes originate from Halide. +// Acknowledgement: Many low-level IR nodes originate from Halide. +#ifndef TVM_TIR_EXPR_H_ +#define TVM_TIR_EXPR_H_ -#ifndef TVM_IR_H_ -#define TVM_IR_H_ +#include +#include +#include +#include +#include +#include -#include #include -#include +#include +#include +#include +#include #include -#include "expr.h" namespace tvm { -namespace ir { +namespace tir { + +/*! + * \brief A variable node in the IR. + * + * A variable is uniquely identified by its address. + * + * Each variable is only bound once in the following nodes: + * - Allocate + * - For + * - Let + * - LetStmt + */ +class VarNode : public PrimExprNode { + public: + /*! \brief constructor */ + VarNode() {} + VarNode(DataType dtype, std::string name_hint); + + /*! + * \brief The hint to the variable name. + * \note Each variable is uniquely identified by its address. + */ + std::string name_hint; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("dtype", &dtype); + v->Visit("name", &name_hint); + } + + static constexpr const char* _type_key = "Variable"; + TVM_DECLARE_BASE_OBJECT_INFO(VarNode, PrimExprNode); +}; + +/*! \brief a named variable in TVM */ +class Var : public PrimExpr { + public: + explicit Var(ObjectPtr n) : PrimExpr(n) {} + /*! \brief constructor + * \param name_hint variable name + * \param t data type + */ + TVM_DLL explicit Var(std::string name_hint = "v", + DataType t = DataType::Int(32)); + /*! + * \brief Make a new copy of the var with the same type, appending a suffix + * \param suffix The suffix to be appended. + * \return the new Var copy + */ + Var copy_with_suffix(const std::string& suffix) const { + return Var((*this)->name_hint + suffix, (*this)->dtype); + } + /*! + * \brief Get pointer to the internal value. + * \return the corresponding Variable. + */ + const VarNode* operator->() const { + return get(); + } + /*! + * \brief Get pointer to the internal value. + * \return the corresponding Variable. + */ + const VarNode* get() const { + return static_cast(data_.get()); + } + /*! \brief type indicates the container type */ + using ContainerType = VarNode; +}; + +/*! + * \brief A variable node that represents a tensor index size, + * whose value must be non-negative. + */ +class SizeVarNode : public VarNode { + public: + /*! \brief constructor */ + SizeVarNode() {} + /*!
\brief constructor + * \param dtype data type + * \param name_hint variable name + */ + SizeVarNode(DataType dtype, std::string name_hint); + + static constexpr const char* _type_key = "SizeVar"; + TVM_DECLARE_FINAL_OBJECT_INFO(SizeVarNode, VarNode); +}; + +/*! \brief a named variable that represents a tensor index size */ +class SizeVar : public Var { + public: + explicit SizeVar(ObjectPtr n) : Var(n) {} + /*! \brief constructor + * \param name_hint variable name + * \param t data type + */ + TVM_DLL explicit SizeVar(std::string name_hint = "s", + DataType t = DataType::Int(32)); + /*! + * \brief Get pointer to the internal value. + * \return the corresponding Variable. + */ + const SizeVarNode* operator->() const { + return get(); + } + /*! + * \brief Get pointer to the internal value. + * \return the corresponding Variable. + */ + const SizeVarNode* get() const { + return static_cast(data_.get()); + } + /*! \brief type indicates the container type */ + using ContainerType = SizeVarNode; +}; + + +/*! \brief container class of iteration variable. */ +class IterVarNode; + +using Region = Array; + +/*! + * \brief Type of iteration variable. + * Each IterVar has a specific type. + * + * The type of an iter var can be overridden via + * stage.iter_var_attrs given they are compatible. + */ +enum IterVarType : int { + /*! + * \brief Data parallel iteration. + * This normally corresponds to axis of Tensor. + * Allow all IterVar manipulations. + * + * \note This does not mean the loop + * has to be executed in a parallel fashion. + */ + kDataPar = 0, + /*! + * \brief The IterVar itself is a thread-index + * of a fixed thread launching group. + * Note that this is already assumed to be parallelized. + * + * Disallow: split/fuse/vectorize/parallel + */ + kThreadIndex = 1, + /*! + * \brief Commutative reduction. + * Cannot be directly parallelized. + * + * Disallow: parallel/vectorize + */ + kCommReduce = 2, + /*! + * \brief Serial loops with loop-carried dependency, + * the iteration must execute in order. + * Cannot be re-ordered. + * + * Disallow: reorder/parallel/vectorize + */ + kOrdered = 3, + /*! + * \brief IterVar is opaque, + * + * May not correspond to any generated loop + * Disallow all IterVar manipulations and compute_at + * + * \note This is usually used to implement composite op + * or external op, where the + */ + kOpaque = 4, + // The following are possible additional + // types that are provided during schedule + /*! + * \brief The execution is unrolled. + */ + kUnrolled = 5, + /*! + * \brief The loop is vectorized. + */ + kVectorized = 6, + /*! + * \brief The loop is parallelized. + */ + kParallelized = 7, + /*! + * \brief Marks boundary of tensorization intrinsic. + */ + kTensorized = 8 +}; + +/*! + * \brief Iteration Variable, + * represents an iteration over an integer interval. + */ +class IterVar : public ObjectRef { + public: + // construct a new iter var without a domain + IterVar() {} + // construct from shared ptr. + explicit IterVar(ObjectPtr n) : ObjectRef(n) {} + /*! + * \brief access the internal node container + * \return the pointer to the internal node container + */ + inline const IterVarNode* operator->() const; + /*! + * \return the corresponding var in the IterVar. + */ + inline operator PrimExpr() const; + /*! \brief specify container node */ + using ContainerType = IterVarNode; +}; + +using Domain = Array; +
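As a quick illustration (a sketch, not part of the header): an IterVar pairs a looping Var with an optional Range domain and converts implicitly to its underlying variable, so it can be used directly in index expressions. It is created through IterVarNode::make, declared just below; the helper function name here is hypothetical.

    tvm::PrimExpr MakeIndex() {  // hypothetical helper
      tvm::tir::Var i("i");
      tvm::tir::IterVar iv = tvm::tir::IterVarNode::make(
          tvm::Range(0, 16), i, tvm::tir::kDataPar);
      return iv;  // the implicit conversion yields iv->var
    }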
+/*! + * \brief An iteration variable representing an iteration + * over a one dimensional interval. + */ +class IterVarNode : public Object { + public: + /*! + * \brief the domain of iteration, if known, can be None + * For the intermediate schedule node, before schedule. + */ + Range dom; + /*! \brief The looping variable */ + Var var; + /*! \brief The type of the IterVar */ + IterVarType iter_type; + /*! + * \brief additional tag on the iteration variable, + * set this if this is already bound to a known thread tag. + */ + std::string thread_tag; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("dom", &dom); + v->Visit("var", &var); + v->Visit("iter_type", &iter_type); + v->Visit("thread_tag", &thread_tag); + } + + TVM_DLL static IterVar make(Range dom, Var var, + IterVarType iter_type, + std::string thread_tag = ""); + + static constexpr const char* _type_key = "IterVar"; + TVM_DECLARE_FINAL_OBJECT_INFO(IterVarNode, Object); +}; + +// inline implementations +inline const IterVarNode* IterVar::operator->() const { + return static_cast(data_.get()); +} + +inline IterVar::operator PrimExpr() const { + return (*this)->var; +} + +inline const char* IterVarType2String(IterVarType t) { + switch (t) { + case kDataPar: return "DataPar"; + case kThreadIndex: return "ThreadIndex"; + case kCommReduce: return "CommReduce"; + case kOrdered: return "Ordered"; + case kOpaque: return "Opaque"; + case kUnrolled: return "Unrolled"; + case kVectorized: return "Vectorized"; + case kParallelized: return "Parallelized"; + case kTensorized: return "Tensorized"; + } + return "Unknown"; +} using IntImmNode = tvm::IntImmNode; using FloatImmNode = tvm::FloatImmNode; -using VarNode = tvm::VarNode; -using SizeVarNode = tvm::SizeVarNode; /*! \brief String constants, only used in asserts. */ class StringImmNode : public PrimExprNode { @@ -688,704 +969,24 @@ class AnyNode : public PrimExprNode { TVM_DECLARE_FINAL_OBJECT_INFO(AnyNode, PrimExprNode); }; -// Statements -/*! - * \brief Let binding, bind var to value, then run body. - */ -class LetStmtNode : public StmtNode { - public: - /*! \brief The variable. */ - Var var; - /*! \brief The value to be binded. */ - PrimExpr value; - /*! \brief The body block. */ - Stmt body; - - void VisitAttrs(AttrVisitor* v) { - v->Visit("var", &var); - v->Visit("value", &value); - v->Visit("body", &body); - } - - TVM_DLL static Stmt make(Var var, PrimExpr value, Stmt body); - - static constexpr const char* _type_key = "LetStmt"; - TVM_DECLARE_FINAL_OBJECT_INFO(LetStmtNode, StmtNode); -}; - -/*! - * \brief Define certain auxiliary attribute for the body to be a symbolic value. - * This provide auxiliary information for IR passes that transforms body. - * - * In terms of effect, this is equivalent to Block(Evaluate(value), body). - * - * Examples of possible usage: - * - Bound of function, variables. - * - Hint which block corresponds to a parallel region. - */ -class AttrStmtNode : public StmtNode { - public: - /*! \brief this is attribute about certain node */ - ObjectRef node; - /*! \brief the type key of the attribute */ - std::string attr_key; - /*! \brief The attribute value, value is well defined at current scope. */ - PrimExpr value; - /*! \brief The body statement to be executed */ - Stmt body; - - void VisitAttrs(AttrVisitor* v) { - v->Visit("node", &node); - v->Visit("attr_key", &attr_key); - v->Visit("value", &value); - v->Visit("body", &body); - } - - TVM_DLL static Stmt make(ObjectRef node, - std::string type_key, - PrimExpr value, - Stmt body); - - static constexpr const char* _type_key = "AttrStmt"; - TVM_DECLARE_FINAL_OBJECT_INFO(AttrStmtNode, StmtNode); -}; - -/*!
- * \brief Assert condition, if an error occurs, return the error message. - */ -class AssertStmtNode : public StmtNode { - public: - /*! \brief Condition to be checked. */ - PrimExpr condition; - /*! \brief Error message when assertion failed. */ - PrimExpr message; - /*! - * \brief Body which this assertion holds true. - * Will be executed after the assertion. - */ - Stmt body; - - void VisitAttrs(AttrVisitor* v) { - v->Visit("condition", &condition); - v->Visit("message", &message); - v->Visit("body", &body); - } - TVM_DLL static Stmt make(PrimExpr condition, PrimExpr message, Stmt body); - - static constexpr const char* _type_key = "AssertStmt"; - TVM_DECLARE_FINAL_OBJECT_INFO(AssertStmtNode, StmtNode); -}; - -// TODO(tvm-team): consider consolidate with AttrStmt. -/*! \brief annotation node of producer/consumer relation. */ -class ProducerConsumerNode : public StmtNode { - public: - /*! \brief The corresponding tensor. */ - FunctionRef func; - /*! \brief Whether the relation is producer. */ - bool is_producer; - /*! \brief Body to be executed. */ - Stmt body; - - void VisitAttrs(AttrVisitor* v) { - v->Visit("func", &func); - v->Visit("is_producer", &is_producer); - v->Visit("body", &body); - } - - TVM_DLL static Stmt make(FunctionRef func, bool is_producer, Stmt body); - - static constexpr const char* _type_key = "ProducerConsumer"; - TVM_DECLARE_FINAL_OBJECT_INFO(ProducerConsumerNode, StmtNode); -}; - -/*! - * \brief Store value to the buffer. - * - * Equivalent to ((DType*)buffer_var)[index] = value. - * where DType is the type specified by type().element_of(). - * - * For example, if type = float32x3, then the store will corresponds to - * - * \code - * - * auto buffer = static_cast(buffer_var); - * buffer[index.v0] = value.v0; - * buffer[index.v1] = value.v1; - * buffer[index.v2] = value.v2; - * - * \endcode - * \sa LoadNode - */ -class StoreNode : public StmtNode { - public: - /*! \brief The buffer variable. */ - Var buffer_var; - /*! \brief The value to be stored. */ - PrimExpr value; - /*! \brief The index locations to be stored. */ - PrimExpr index; - /*! \brief The predicate to mask which lanes would be stored. */ - PrimExpr predicate; - - void VisitAttrs(AttrVisitor* v) { - v->Visit("buffer_var", &buffer_var); - v->Visit("value", &value); - v->Visit("index", &index); - v->Visit("predicate", &predicate); - } - - TVM_DLL static Stmt make(Var buffer_var, - PrimExpr value, - PrimExpr index, - PrimExpr predicate); - - static constexpr const char* _type_key = "Store"; - TVM_DECLARE_FINAL_OBJECT_INFO(StoreNode, StmtNode); -}; - -/*! - * \brief Store value into mult-dimensional array defined by func. - */ -class ProvideNode : public StmtNode { - public: - /*! \brief The function to be updated. */ - FunctionRef func; - /*! \brief The output value index if func's value is a tuple. */ - int value_index{0}; - /*! \brief The value to be stored. */ - PrimExpr value; - /*! \brief The index arguments of the function. */ - Array args; - - void VisitAttrs(AttrVisitor* v) { - v->Visit("func", &func); - v->Visit("value_index", &value_index); - v->Visit("value", &value); - v->Visit("args", &args); - } - - TVM_DLL static Stmt make(FunctionRef func, - int value_index, - PrimExpr value, - Array args); - - static constexpr const char* _type_key = "Provide"; - TVM_DECLARE_FINAL_OBJECT_INFO(ProvideNode, StmtNode); -}; - -/*! - * \brief Allocate a buffer that can be used in body. - */ -class AllocateNode : public StmtNode { - public: - /*! \brief The buffer variable. */ - Var buffer_var; - /*! 
\brief The type of the buffer. */ - DataType dtype; - /*! \brief The extents of the buffer. */ - Array extents; - /*! \brief Only allocate buffer when condition is satisfied. */ - PrimExpr condition; - /*! \brief The body to be executed. */ - Stmt body; - // The following two fields are deprecated - // kept for backward compatibility and will be refactored later. - PrimExpr new_expr; - std::string free_function; - - void VisitAttrs(AttrVisitor* v) { - v->Visit("buffer_var", &buffer_var); - v->Visit("dtype", &dtype); - v->Visit("extents", &extents); - v->Visit("condition", &condition); - v->Visit("body", &body); - } - - TVM_DLL static Stmt make(Var buffer_var, - DataType dtype, - Array extents, - PrimExpr condition, - Stmt body, - PrimExpr new_expr = PrimExpr(), - std::string free_function = std::string()); - - /*! - * \brief If the buffer size is constant, return the size. - * Otherwise return 0. - * \return The result. - */ - int32_t constant_allocation_size() const { - return constant_allocation_size(extents); - } - /*! - * \brief If the buffer size is constant, return the size. - * Otherwise return 0. - * \param extents The extents of the buffer. - * \return The result. - */ - TVM_DLL static int32_t constant_allocation_size( - const Array& extents); - - static constexpr const char* _type_key = "Allocate"; - TVM_DECLARE_FINAL_OBJECT_INFO(AllocateNode, StmtNode); -}; - -/*! \brief Free the resources in the buffer before the scope ends. */ -class FreeNode : public StmtNode { - public: - /*! \brief The buffer variable. */ - Var buffer_var; - - void VisitAttrs(AttrVisitor* v) { - v->Visit("buffer_var", &buffer_var); - } - - TVM_DLL static Stmt make(Var buffer_var); - - static constexpr const char* _type_key = "Free"; - TVM_DECLARE_FINAL_OBJECT_INFO(FreeNode, StmtNode); -}; - -/*! - * \brief Annotate the bounds where func need to be written and read in body. - * We will need to allocate space for the corresponding regions. - */ -class RealizeNode : public StmtNode { - public: - /*! \brief The function to be realized. */ - FunctionRef func; - /*! \brief The output value index if func's value is a tuple. */ - int value_index; - /*! \brief The data type of the array. */ - DataType dtype; - /*! \brief Bounds to be realized. */ - Region bounds; - /*! \brief Only realize if condition holds. */ - PrimExpr condition; - /*! \brief The body of realization. */ - Stmt body; - - void VisitAttrs(AttrVisitor* v) { - v->Visit("func", &func); - v->Visit("value_index", &value_index); - v->Visit("dtype", &dtype); - v->Visit("bounds", &bounds); - v->Visit("condition", &condition); - v->Visit("body", &body); - } - - TVM_DLL static Stmt make(FunctionRef func, - int value_index, - DataType dtype, - Region bounds, - PrimExpr condition, - Stmt body); - - static constexpr const char* _type_key = "Realize"; - TVM_DECLARE_FINAL_OBJECT_INFO(RealizeNode, StmtNode); -}; - -/*! - * \brief The container of seq statement. - * Represent a sequence of statements. - */ -class SeqStmtNode : public StmtNode { - public: - /*! \brief internal sequence content. */ - Array seq; - - /*! \return get the size of the sequence */ - size_t size() const { - return seq.size(); - } - /*! - * \brief Get the index-th element in the sequence. - */ - Stmt operator[](size_t index) const { - return seq[index]; - } - - void VisitAttrs(AttrVisitor* v) { - v->Visit("seq", &seq); - } - - static constexpr const char* _type_key = "SeqStmt"; - TVM_DECLARE_FINAL_OBJECT_INFO(SeqStmtNode, StmtNode); -}; - -/*! \brief Sequence statement. 
*/ -class SeqStmt : public Stmt { - public: - /*! - * \brief Construct SeqStmt. - * \param seq The sequence. - */ - TVM_DLL explicit SeqStmt(Array seq); - - /*! \return get the size of the sequence */ - size_t size() const { - return operator->()->size(); - } - /*! - * \brief Get the index-th element in the sequence. - */ - Stmt operator[](size_t index) const { - return (*(operator->()))[index]; - } - /*! - * \brief Construct a sequence statement by flattening - * all the arrays and sequences in the arguments - * recursively. - * - * - When an argument is nullptr, it will be ignored. - * - When an argument is an array or a SeqStmt, it will be flattened recursively. - * - When an argument is a consumer block in ProducerConsumer, the consumer - * tag will be dropped as such information is not useful in lowering. - * - A normal Stmt will be appended to the end of the sequence. - * - * \note This function can directly return an element - * if it is the only element in the sequence. - * - * \param seq_args The list of arguments to be flattened. - * \tparam Args arguments - * \return The constructed statement - */ - template - static Stmt Flatten(Args&&... seq_args) { - Array seq; - runtime::detail::for_each( - Flattener(&seq), std::forward(seq_args)...); - if (seq.size() == 1) return seq[0]; - return SeqStmt(seq); - } - /*! \brief Helper class to flatten sequence of arguments into Array. */ - class Flattener { - public: - explicit Flattener(Array* seq) - : seq_(seq) {} - - void operator()(size_t i, const Stmt& stmt) const { - if (!stmt.defined()) return; - if (auto* op = stmt.as()) { - operator()(0, op->seq); - } else if (auto* op = stmt.as()) { - // NOTE: The consumer block annotation was not as useful and can be safely dropped. - if (!op->is_producer) { - operator()(0, op->body); - } else { - seq_->push_back(stmt); - } - } else { - seq_->push_back(stmt); - } - } - - template - void operator()(size_t i, const T& seq) const { - for (auto v : seq) { - this->operator()(0, v); - } - } - - private: - Array* seq_; - }; - - TVM_DEFINE_OBJECT_REF_METHODS(SeqStmt, Stmt, SeqStmtNode); -}; - -/*! - * \brief IfThenElse statment. - */ -class IfThenElseNode : public StmtNode { - public: - /*! \brief The condition. */ - PrimExpr condition; - /*! \brief The branch to be executed when condition is true. */ - Stmt then_case; - /*! \brief The branch to be executed when condition is false, can be null. */ - Stmt else_case; - - void VisitAttrs(AttrVisitor* v) { - v->Visit("condition", &condition); - v->Visit("then_case", &then_case); - v->Visit("else_case", &else_case); - } - - TVM_DLL static Stmt make(PrimExpr condition, Stmt then_case, Stmt else_case = Stmt()); - - static constexpr const char* _type_key = "IfThenElse"; - TVM_DECLARE_FINAL_OBJECT_INFO(IfThenElseNode, StmtNode); -}; - -/*! - * \brief Evaluates an expression. - * This is mostly used for putting a Call node into Stmt. - * - * If value do not have side-effect, this node can be safely removed. - */ -class EvaluateNode : public StmtNode { - public: - /*! \brief The expression to be evaluated. */ - PrimExpr value; - - void VisitAttrs(AttrVisitor* v) { - v->Visit("value", &value); - } - - TVM_DLL static Stmt make(PrimExpr v); - - static constexpr const char* _type_key = "Evaluate"; - TVM_DECLARE_FINAL_OBJECT_INFO(EvaluateNode, StmtNode); -}; - -/*! \brief Additional annotation of for loop. */ -enum class ForType : int { - /*! \brief serial execution. */ - Serial = 0, - /*! \brief parallel execution on CPU. */ - Parallel = 1, - /*! 
 \brief Vector SIMD loop annotaion. */
-  Vectorized = 2,
-  /*! \brief Unroll annotation. */
-  Unrolled = 3
-};
-
-// Kevice api of for loop
-// kept for backward compatibility
-// consider refactor and remove later.
-enum class DeviceAPI: int {
-  None = 0
-};
-
-/*!
- * \brief A for loop, with poissible type annotations.
- *
- * \code
- *
- *  for (loop_var = min; loop_var < min + extent; ++loop_var) {
- *    // body
- *  }
- * \endcode
- */
-class ForNode : public StmtNode {
- public:
-  /*! \brief The loop variable. */
-  Var loop_var;
-  /*! \brief The minimum value of iteration. */
-  PrimExpr min;
-  /*! \brief The extent of the iteration. */
-  PrimExpr extent;
-  /*! \brief The type of the for loop. */
-  ForType for_type;
-  /*!
-   * \brief Deprecated, reserved for backward compatibility.
-   *  Consider refactor and remove later.
-   */
-  DeviceAPI device_api;
-  /*! \brief The body of the for loop. */
-  Stmt body;
-
-  TVM_DLL static Stmt make(Var loop_var,
-                           PrimExpr min,
-                           PrimExpr extent,
-                           ForType for_type,
-                           DeviceAPI device_api,
-                           Stmt body);
-
-  void VisitAttrs(AttrVisitor* v) {
-    v->Visit("loop_var", &loop_var);
-    v->Visit("min", &min);
-    v->Visit("extent", &extent);
-    v->Visit("for_type", &for_type);
-    v->Visit("device_api", &device_api);
-    v->Visit("body", &body);
-  }
-
-  static constexpr const char* _type_key = "For";
-  TVM_DECLARE_FINAL_OBJECT_INFO(ForNode, StmtNode);
-};
-
-/*!
- * \brief A prefetch hint of func.
- */
-class PrefetchNode : public StmtNode {
- public:
-  /*! \brief The function to be prefetched. */
-  FunctionRef func;
-  /*! \brief The output value index if func's value is a tuple. */
-  int value_index;
-  /*! \brief The data type of the array. */
-  DataType dtype;
-  /*! \brief Bounds to be prefetched. */
-  Region bounds;
-
-  void VisitAttrs(AttrVisitor* v) {
-    v->Visit("func", &func);
-    v->Visit("value_index", &value_index);
-    v->Visit("dtype", &dtype);
-    v->Visit("bounds", &bounds);
-  }
-
-  TVM_DLL static Stmt make(FunctionRef func,
-                           int value_index,
-                           DataType dtype,
-                           Region bounds);
-
-  static constexpr const char* _type_key = "Prefetch";
-  TVM_DECLARE_FINAL_OBJECT_INFO(PrefetchNode, StmtNode);
-};
-
-/*!
- * \brief Auxiliary data structure used in IR Pass to indicate a tensor.
- */
-struct TensorKey {
-  FunctionRef f;
-  int value_index;
-
-  inline bool operator==(const TensorKey& other) const {
-    return f == other.f && value_index == other.value_index;
-  }
-  inline std::string GetName() const {
-    if (f->num_outputs() == 1) return f->func_name();
-    std::ostringstream os;
-    os << f->func_name() << ".v" << value_index;
-    return os.str();
+/*
+ * \brief Template function to convert Map to unordered_map
+ *  Sometimes useful for API gluing when internal uses unordered_map
+ * \param dmap The container map
+ * \return The corresponding unordered_map.
+ * \tparam K the key of the Map.
+ * \tparam V the value of the Map.
+ */
+template<typename K, typename V>
+inline std::unordered_map<K, V> as_unordered_map(const Map<K, V>& dmap) {
+  std::unordered_map<K, V> ret;
+  for (auto kv : dmap) {
+    ret[kv.first] = kv.second;
   }
-};
-
-/*! \brief namespace of possible attribute sin AttrStmt.attr_key */
-namespace attr {
-// The above attr does not pass to ir stage.
-/*! \brief Mark launching extent of thread, used by device API. */
-constexpr const char* thread_extent = "thread_extent";
-/*! \brief Mark launching of a virtual thread. */
-constexpr const char* virtual_thread = "virtual_thread";
-/*! \brief Mark region is processed by a co-proccesor */
-constexpr const char* coproc_scope = "coproc_scope";
-/*!
- * \brief Mark region creates coprocessor micro ops, - * can be reused if corresponding variable is independent. - */ -constexpr const char* coproc_uop_scope = "coproc_uop_scope"; -/*! \brief Mark the scope as volatile access for certain handle. */ -constexpr const char* volatile_scope = "volatile_scope"; -/*! - * \brief Mark the scope as generated by extern primitive. - * such scope can contain arbitrary ir program and we need to be careful - * when make certain assumptions about the structure of the program. - */ -constexpr const char* extern_scope = "extern_scope"; -/*! - * \brief Mark the scope as when computation start to happen - * This can hint some code generator to create a new function for compute. - */ -constexpr const char* compute_scope = "compute_scope"; -/*! \brief Mark storage scope of buffers */ -constexpr const char* storage_scope = "storage_scope"; -/*! \brief Mark storage alignement requirement of buffers */ -constexpr const char* storage_alignment = "storage_alignment"; -/*! \brief Mark storage scope of realization */ -constexpr const char* realize_scope = "realize_scope"; -/*! \brief The allocation context for global malloc in host. */ -constexpr const char* device_context_id = "device_context_id"; -/*! \brief The device type. */ -constexpr const char* device_context_type = "device_context_type"; -/*! \brief Mark of loop scope */ -constexpr const char* loop_scope = "loop_scope"; -/*! \brief Mark of reduce scope */ -constexpr const char* reduce_scope = "reduce_scope"; -/*! \brief Mark region is guarded by the pragma extension */ -constexpr const char* pragma_scope_prefix = "pragma_"; -/*! \brief Import llvm source or file into the final code gen module */ -constexpr const char* pragma_import_llvm = "pragma_import_llvm"; -/*! \brief Try to modify the AST to support Tensor Core */ -constexpr const char* pragma_tensor_core = "pragma_tensor_core"; -/*! - * \brief Mark of prefetch scope, value=offset, - * run prefetch of Tensor on the current loop scope - */ -constexpr const char* prefetch_scope = "prefetch_scope"; -/*! - * \brief Marks production of double buffer data - */ -constexpr const char* double_buffer_scope = "double_buffer_scope"; -/*! - * \brief Marks region used by double buffer write - */ -constexpr const char* double_buffer_write = "double_buffer_write"; -/*! \brief Mark of scan update scope */ -constexpr const char* scan_update_scope = "scan_update_scope"; -/*! \brief Mark of scan init scope */ -constexpr const char* scan_init_scope = "scan_init_scope"; -/*! - * \brief Mark alignment of buffer dimension - * stmt.node is Tensor - * stmt.value is tvm_tuple(dim, align, offset) - * This gives hint to require stride of dim to be k * align + offset. - */ -constexpr const char* buffer_dim_align = "buffer_dim_align"; -/*! \brief Mark stores/loads with theirs bounds. */ -constexpr const char* buffer_bound = "buffer_bound"; -/*! - * \brief Bind the buffer specification to the region of the op - * When this scope occurs, the stmt.node is a Array = [buffer, tensor] - * stmt.value is a tvm_tuple(min0, extent0, min1, extent1, ...). - * The scope represents that we need to bind the storage region of tensor to buffer. - * This will affect replacement of some variables inside the scope that - * corresponds to field of buffer to be the actual expressions of tensor during - * storage flattening phase. - */ -constexpr const char* buffer_bind_scope = "buffer_bind_scope"; -// Pipeline related attributes -/*! 
 \brief channel read scope */
-constexpr const char* channel_read_scope = "channel_read_scope";
-/*! \brief Advance step of channel after end of scope */
-constexpr const char* channel_read_advance = "channel_read_advance";
-/*! \brief channel write scope */
-constexpr const char* channel_write_scope = "channel_write_scope";
-/*! \brief Advance step of channel after end of scope */
-constexpr const char* channel_write_advance = "channel_write_advance";
-/*! \brief pipeline stage scope, implies always execution */
-constexpr const char* pipeline_stage_scope = "pipeline_stage_scope";
-/*! \brief pipeline execution scope, implies the scope can be pipelined. */
-constexpr const char* pipeline_exec_scope = "pipeline_exec_scope";
-/*!
- * \brief Mark that this stage is an OpenGL shader. Since OpenGL shader only
- * allows writing out to one element of the output texture, the Provide node
- * gets translated to a special Call::glsl_texture_store statement instead of a
- * Store statement.
- */
-constexpr const char* opengl_stage_scope = "opengl_stage_scope";
-
-/*!
- * \brief Mark that it is in the device scope.
- */
-constexpr const char* device_scope = "device_scope";
-
-/*!
- * \brief Mark that the shape of TensorCore fragment
- */
-constexpr const char* fragment_shape = "fragment_shape";
-
-/*!
- * \brief Mark that the layout of TensorCore fragment
- */
-constexpr const char* fragment_layout = "fragment_layout";
-
-/*!
- * \brief Check if attr_key is a pragma key extension
- * \param attr_key The attr key to be compared
- * \return true if it is a pragma key
- */
-inline bool IsPragmaKey(const std::string& attr_key) {
-  return attr_key.compare(0, 7, "pragma_") == 0;
+  return ret;
 }
-}  // namespace attr
-
 /*! \brief namespace of TVM Intrinsic functions */
 namespace intrinsic {
 /*!
@@ -1697,33 +1298,32 @@ enum TVMStructFieldKind : int {
 };
 }  // namespace intrinsic
 
-/*!
- * \brief Create a type annotation expression
- * \param dtype The data type
- * \return Expr a expression with dtype.
- */
-inline PrimExpr TypeAnnotation(DataType dtype) {
-  return ir::CallNode::make(dtype,
-                            "type_annotation", {},
-                            ir::CallNode::PureIntrinsic);
-}
-
-// overload printing of for type.
-TVM_DLL std::ostream& operator<<(std::ostream& os, ForType for_type);
+}  // namespace tir
+}  // namespace tvm
 
-}  // namespace ir
+namespace tvm {
+namespace runtime {
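The conversion operator that follows lets an Integer be read straight out of a packed function's return value. A hedged sketch of the intended call site; `add_one` is a hypothetical PackedFunc returning its argument plus one:

    void IntegerFromPackedFunc(tvm::runtime::PackedFunc add_one) {
      // The returned POD int converts through TVMPODValue_::operator tvm::Integer().
      tvm::Integer r = add_one(41);
      CHECK_EQ(r->value, 42);
    }

+// Additional implementation overloads for PackedFunc.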
+inline TVMPODValue_::operator tvm::Integer() const {
+  if (type_code_ == kTVMNullptr) return Integer();
+  if (type_code_ == kDLInt) {
+    CHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
+    CHECK_GE(value_.v_int64, std::numeric_limits<int>::min());
+    return Integer(static_cast<int>(value_.v_int64));
+  }
+  TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle);
+  Object* ptr = static_cast<Object*>(value_.v_handle);
+  CHECK(ObjectTypeChecker<Integer>::Check(ptr))
+      << "Expect type " << ObjectTypeChecker<Integer>::TypeName()
+      << " but get " << ptr->GetTypeKey();
+  return Integer(ObjectPtr<Object>(ptr));
+}
+}  // namespace runtime
 }  // namespace tvm
 
 namespace std {
 template <>
-struct hash<::tvm::ir::TensorKey> {
-  std::size_t operator()(const ::tvm::ir::TensorKey& k) const {
-    size_t lhs = ::tvm::ObjectHash()(k.f);
-    size_t rhs = static_cast<size_t>(k.value_index);
-    lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
-    return lhs;
-  }
+struct hash<::tvm::tir::IterVar> : public ::tvm::ObjectHash {
 };
-}  // namespace std
-
-#endif  // TVM_IR_H_
+}
+#endif  // TVM_TIR_EXPR_H_
diff --git a/include/tvm/tir/expr_functor.h b/include/tvm/tir/expr_functor.h
new file mode 100644
index 000000000000..0de05a682703
--- /dev/null
+++ b/include/tvm/tir/expr_functor.h
@@ -0,0 +1,295 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/tir/expr_functor.h
+ *
+ * \brief Functors for tir expressions.
+ */
+#ifndef TVM_TIR_EXPR_FUNCTOR_H_
+#define TVM_TIR_EXPR_FUNCTOR_H_
+
+#include <tvm/node/functor.h>
+#include <tvm/tir/expr.h>
+
+#include <utility>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief A dynamic functor that dispatches on the first Expr argument.
+ *  You can use this as a more powerful Visitor, since it allows you to
+ *  define the signature of the Visit function.
+ *
+ *  This helps you avoid book-keeping the return value of a Visitor via state,
+ *  which easily causes bugs when the state is incorrectly maintained.
+ *
+ * \code
+ *  // A functor that sets variables to b and calculates results.
+ *  class MyExprFunctor
+ *    : public tir::ExprFunctor<int(const PrimExpr&, int)> {
+ *   public:
+ *    int VisitExpr_(const Variable* op, int b) final {
+ *      return b;
+ *    }
+ *    int VisitExpr_(const IntImm* op, int b) final {
+ *      return op->value;
+ *    }
+ *    int VisitExpr_(const Add* op, int b) final {
+ *      return Visit(op->a, b) + Visit(op->b, b);
+ *    }
+ *  };
+ *  MyExprFunctor f;
+ *  Var x("x");
+ *  CHECK_EQ(f(x + 1, 2), 3);
+ * \endcode
+ *
+ * \note Why do we need this more powerful Functor:
+ *
+ *  We often need to implement transformer tasks.
+ *  Say we want to take an Expr and transform it to some analysis result.
+ *  This can easily be done incorrectly using a plain Visitor. See IRVisitor's
+ *  document for possible error cases.
+ *
+ * \tparam FType function signature
+ *  This type is only defined for FType with function signature R(const Expr&, Args...)
+ */
+template<typename FType>
+class ExprFunctor;
+
+// functions to be overridden.
+#define EXPR_FUNCTOR_DEFAULT {                                      \
+    return VisitExprDefault_(op, std::forward<Args>(args)...);      \
+  }
+
+#define IR_EXPR_FUNCTOR_DISPATCH(OP)                                \
+  vtable.template set_dispatch<OP>(                                 \
+      [](const ObjectRef& n, TSelf* self, Args... args) {           \
+        return self->VisitExpr_(static_cast<const OP*>(n.get()),    \
+                                std::forward<Args>(args)...);       \
+      });                                                           \
+
+template<typename R, typename ...Args>
+class ExprFunctor<R(const PrimExpr& n, Args...)> {
+ private:
+  using TSelf = ExprFunctor<R(const PrimExpr& n, Args...)>;
+  using FType = NodeFunctor<R(const ObjectRef& n, TSelf* self, Args... args)>;
+
+ public:
+  /*! \brief the result type of this functor */
+  using result_type = R;
+  /*! \brief virtual destructor */
+  virtual ~ExprFunctor() {}
+  /*!
+   * \brief Same as call.
+   * \param n The expression node.
+   * \param args Additional arguments.
+   * \return The result of the call
+   */
+  R operator()(const PrimExpr& n, Args... args) {
+    return VisitExpr(n, std::forward<Args>(args)...);
+  }
+  /*!
+   * \brief The functor call.
+   * \param n The expression node.
+   * \param args Additional arguments.
+   * \return The result of the call
+   */
+  virtual R VisitExpr(const PrimExpr& n, Args... args) {
+    static FType vtable = InitVTable();
+    return vtable(n, this, std::forward<Args>(args)...);
+  }
+  // Functions that can be overridden by subclass
+  virtual R VisitExpr_(const VarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const SizeVarNode* op, Args... args) {
+    return VisitExpr_(static_cast<const VarNode*>(op), std::forward<Args>(args)...);
+  }
+  virtual R VisitExpr_(const LoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const LetNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const AddNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const SubNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const MulNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const DivNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const ModNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const FloorDivNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const FloorModNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const MinNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const MaxNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const EQNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const NENode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const LTNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const LENode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const GTNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const GENode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const AndNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const OrNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const ReduceNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const CastNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const NotNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const SelectNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const RampNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const BroadcastNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const ShuffleNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const IntImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const FloatImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const StringImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExprDefault_(const Object* op, Args ...) {
+    LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
+    return R();
+  }
+
+ private:
+  // initialize the vtable.
+  static FType InitVTable() {
+    FType vtable;
+    // Set dispatch
+    IR_EXPR_FUNCTOR_DISPATCH(VarNode);
+    IR_EXPR_FUNCTOR_DISPATCH(SizeVarNode);
+    IR_EXPR_FUNCTOR_DISPATCH(LoadNode);
+    IR_EXPR_FUNCTOR_DISPATCH(LetNode);
+    IR_EXPR_FUNCTOR_DISPATCH(CallNode);
+    IR_EXPR_FUNCTOR_DISPATCH(AddNode);
+    IR_EXPR_FUNCTOR_DISPATCH(SubNode);
+    IR_EXPR_FUNCTOR_DISPATCH(MulNode);
+    IR_EXPR_FUNCTOR_DISPATCH(DivNode);
+    IR_EXPR_FUNCTOR_DISPATCH(ModNode);
+    IR_EXPR_FUNCTOR_DISPATCH(FloorDivNode);
+    IR_EXPR_FUNCTOR_DISPATCH(FloorModNode);
+    IR_EXPR_FUNCTOR_DISPATCH(MinNode);
+    IR_EXPR_FUNCTOR_DISPATCH(MaxNode);
+    IR_EXPR_FUNCTOR_DISPATCH(EQNode);
+    IR_EXPR_FUNCTOR_DISPATCH(NENode);
+    IR_EXPR_FUNCTOR_DISPATCH(LTNode);
+    IR_EXPR_FUNCTOR_DISPATCH(LENode);
+    IR_EXPR_FUNCTOR_DISPATCH(GTNode);
+    IR_EXPR_FUNCTOR_DISPATCH(GENode);
+    IR_EXPR_FUNCTOR_DISPATCH(AndNode);
+    IR_EXPR_FUNCTOR_DISPATCH(OrNode);
+    IR_EXPR_FUNCTOR_DISPATCH(ReduceNode);
+    IR_EXPR_FUNCTOR_DISPATCH(CastNode);
+    IR_EXPR_FUNCTOR_DISPATCH(NotNode);
+    IR_EXPR_FUNCTOR_DISPATCH(SelectNode);
+    IR_EXPR_FUNCTOR_DISPATCH(RampNode);
+    IR_EXPR_FUNCTOR_DISPATCH(ShuffleNode);
+    IR_EXPR_FUNCTOR_DISPATCH(BroadcastNode);
+    IR_EXPR_FUNCTOR_DISPATCH(IntImmNode);
+    IR_EXPR_FUNCTOR_DISPATCH(FloatImmNode);
+    IR_EXPR_FUNCTOR_DISPATCH(StringImmNode);
+    return vtable;
+  }
+};
+
+#undef IR_EXPR_FUNCTOR_DISPATCH
+#undef EXPR_FUNCTOR_DEFAULT
+
+/*!
+ * \brief ExprVisitor
+ */
+class TVM_DLL ExprVisitor :
+    public ExprFunctor<void(const PrimExpr&)> {
+ public:
+  using ExprFunctor::operator();
+
+ protected:
+  using ExprFunctor::VisitExpr;
+  // list of functions to override.
+  void VisitExpr_(const VarNode* op) override;
+  void VisitExpr_(const SizeVarNode* op) override;
+  void VisitExpr_(const LoadNode* op) override;
+  void VisitExpr_(const LetNode* op) override;
+  void VisitExpr_(const CallNode* op) override;
+  void VisitExpr_(const AddNode* op) override;
+  void VisitExpr_(const SubNode* op) override;
+  void VisitExpr_(const MulNode* op) override;
+  void VisitExpr_(const DivNode* op) override;
+  void VisitExpr_(const ModNode* op) override;
+  void VisitExpr_(const FloorDivNode* op) override;
+  void VisitExpr_(const FloorModNode* op) override;
+  void VisitExpr_(const MinNode* op) override;
+  void VisitExpr_(const MaxNode* op) override;
+  void VisitExpr_(const EQNode* op) override;
+  void VisitExpr_(const NENode* op) override;
+  void VisitExpr_(const LTNode* op) override;
+  void VisitExpr_(const LENode* op) override;
+  void VisitExpr_(const GTNode* op) override;
+  void VisitExpr_(const GENode* op) override;
+  void VisitExpr_(const AndNode* op) override;
+  void VisitExpr_(const OrNode* op) override;
+  void VisitExpr_(const ReduceNode* op) override;
+  void VisitExpr_(const CastNode* op) override;
+  void VisitExpr_(const NotNode* op) override;
+  void VisitExpr_(const SelectNode* op) override;
+  void VisitExpr_(const RampNode* op) override;
+  void VisitExpr_(const BroadcastNode* op) override;
+  void VisitExpr_(const ShuffleNode* op) override;
+  void VisitExpr_(const IntImmNode* op) override;
+  void VisitExpr_(const FloatImmNode* op) override;
+  void VisitExpr_(const StringImmNode* op) override;
+};
+
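To make the functor machinery concrete, here is a small sketch (not part of the patch) of a user-defined visitor built on ExprVisitor; the class and field names are hypothetical and only illustrate the override pattern:

    // Tallies how often each variable occurs in an expression.
    class VarCounter : public tvm::tir::ExprVisitor {
     public:
      std::unordered_map<std::string, int> count;

      void VisitExpr_(const tvm::tir::VarNode* op) override {
        // Variables have no sub-expressions, so no recursive call is needed.
        ++count[op->name_hint];
      }
    };

    // Usage: VarCounter c; c(x * x + y);  // count["x"] == 2, count["y"] == 1

+/*!
+ * \brief ExprMutator that mutates expressions.
+ */
+class TVM_DLL ExprMutator :
+    protected ExprFunctor<PrimExpr(const PrimExpr&)> {
+ public:
+  using ExprFunctor::operator();
+
+ protected:
+  using ExprFunctor::VisitExpr;
+  // list of functions to override.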
+ PrimExpr VisitExpr_(const VarNode* op) override; + PrimExpr VisitExpr_(const SizeVarNode* op) override; + PrimExpr VisitExpr_(const LoadNode* op) override; + PrimExpr VisitExpr_(const LetNode* op) override; + PrimExpr VisitExpr_(const CallNode* op) override; + PrimExpr VisitExpr_(const AddNode* op) override; + PrimExpr VisitExpr_(const SubNode* op) override; + PrimExpr VisitExpr_(const MulNode* op) override; + PrimExpr VisitExpr_(const DivNode* op) override; + PrimExpr VisitExpr_(const ModNode* op) override; + PrimExpr VisitExpr_(const FloorDivNode* op) override; + PrimExpr VisitExpr_(const FloorModNode* op) override; + PrimExpr VisitExpr_(const MinNode* op) override; + PrimExpr VisitExpr_(const MaxNode* op) override; + PrimExpr VisitExpr_(const EQNode* op) override; + PrimExpr VisitExpr_(const NENode* op) override; + PrimExpr VisitExpr_(const LTNode* op) override; + PrimExpr VisitExpr_(const LENode* op) override; + PrimExpr VisitExpr_(const GTNode* op) override; + PrimExpr VisitExpr_(const GENode* op) override; + PrimExpr VisitExpr_(const AndNode* op) override; + PrimExpr VisitExpr_(const OrNode* op) override; + PrimExpr VisitExpr_(const ReduceNode* op) override; + PrimExpr VisitExpr_(const CastNode* op) override; + PrimExpr VisitExpr_(const NotNode* op) override; + PrimExpr VisitExpr_(const SelectNode* op) override; + PrimExpr VisitExpr_(const RampNode* op) override; + PrimExpr VisitExpr_(const BroadcastNode* op) override; + PrimExpr VisitExpr_(const ShuffleNode* op) override; + PrimExpr VisitExpr_(const IntImmNode* op) override; + PrimExpr VisitExpr_(const FloatImmNode* op) override; + PrimExpr VisitExpr_(const StringImmNode* op) override; +}; + +} // namespace tir +} // namespace tvm +#endif // TVM_TIR_EXPR_FUNCTOR_H_ diff --git a/include/tvm/ir_pass.h b/include/tvm/tir/ir_pass.h similarity index 98% rename from include/tvm/ir_pass.h rename to include/tvm/tir/ir_pass.h index bf444265b078..ae1f35cdde49 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/tir/ir_pass.h @@ -18,27 +18,28 @@ */ /*! - * \file tvm/ir_pass.h + * \file tvm/tir/ir_pass.h * \brief Collection of IR pass functions * * When the pass functions in this file are for Stmt, * we can use PassFunction(Evaluate(expr)) to apply it to Expr */ -#ifndef TVM_IR_PASS_H_ -#define TVM_IR_PASS_H_ +#ifndef TVM_TIR_IR_PASS_H_ +#define TVM_TIR_IR_PASS_H_ #include +#include +#include +#include #include #include #include #include -#include "expr.h" -#include "buffer.h" -#include "lowered_func.h" + namespace tvm { -namespace ir { +namespace tir { /*! * \brief Simplify the expression. @@ -593,8 +594,6 @@ bool VerifyMemory(LoweredFunc func, int device_type); bool VerifyGPUCode(Stmt stmt, Map constraints); - -} // namespace ir +} // namespace tir } // namespace tvm - -#endif // TVM_IR_PASS_H_ +#endif // TVM_TIR_IR_PASS_H_ diff --git a/include/tvm/lowered_func.h b/include/tvm/tir/lowered_func.h similarity index 91% rename from include/tvm/lowered_func.h rename to include/tvm/tir/lowered_func.h index b0350ae8b4dc..2d01c8958aef 100644 --- a/include/tvm/lowered_func.h +++ b/include/tvm/tir/lowered_func.h @@ -18,21 +18,20 @@ */ /*! - * \file tvm/lowered_func.h + * \file tvm/tir/lowered_func.h * \brief Information about a lowered TVM function. * This data structure is final step toward codegen. 
*/ -#ifndef TVM_LOWERED_FUNC_H_ -#define TVM_LOWERED_FUNC_H_ - -#include +#ifndef TVM_TIR_LOWERED_FUNC_H_ +#define TVM_TIR_LOWERED_FUNC_H_ +#include +#include +#include #include -#include "expr.h" -#include "tvm/node/container.h" - namespace tvm { +namespace tir { // Internal node container of lowered function. class LoweredFuncNode; @@ -41,7 +40,7 @@ class LoweredFuncNode; * \brief LoweredFunc represents function after lowering. * This is the final IR representation before codegen. */ -class LoweredFunc : public ir::FunctionRef { +class LoweredFunc : public FunctionRef { public: LoweredFunc() {} explicit LoweredFunc(ObjectPtr n) : FunctionRef(n) {} @@ -65,7 +64,7 @@ enum LoweredFuncType : int { }; /*! \brief Node container of LoweredFunc */ -class LoweredFuncNode : public ir::FunctionBaseNode { +class LoweredFuncNode : public tir::FunctionBaseNode { public: /*! \brief The name of the function */ std::string name; @@ -138,13 +137,13 @@ class LoweredFuncNode : public ir::FunctionBaseNode { inline const LoweredFuncNode* LoweredFunc::operator->() const { return static_cast(get()); } - +} // namespace tir } // namespace tvm namespace std { template <> -struct hash<::tvm::LoweredFunc> : public tvm::ObjectHash { +struct hash<::tvm::tir::LoweredFunc> : public tvm::ObjectHash { }; } -#endif // TVM_LOWERED_FUNC_H_ +#endif // TVM_TIR_LOWERED_FUNC_H_ diff --git a/include/tvm/expr_operator.h b/include/tvm/tir/op.h similarity index 88% rename from include/tvm/expr_operator.h rename to include/tvm/tir/op.h index 4fa8da1c6b72..5172b1496ad2 100644 --- a/include/tvm/expr_operator.h +++ b/include/tvm/tir/op.h @@ -18,112 +18,31 @@ */ /*! - * \file tvm/expr_operator.h + * \file tvm/tir/op.h * \brief Common operators defined for Expr. * * \note Most of the operator defined here perform simple constant folding * when the type is int32 or int64 for simplifying the index expressions. */ // Acknowledgement: Most operator APIs originate from Halide. -#ifndef TVM_EXPR_OPERATOR_H_ -#define TVM_EXPR_OPERATOR_H_ +#ifndef TVM_TIR_OP_H_ +#define TVM_TIR_OP_H_ + +#include +#include #include #include #include -#include "expr.h" -#include "ir.h" -namespace tvm { -/*! - * \brief Make a const value with certain data type. - * \param t The target type. - * \param value The input value - * \return the result expression. - * \tparam ValueType The constant value type - */ -template::value>::type> -inline PrimExpr make_const(DataType t, ValueType value); -/*! - * \brief Make a const zero expr. - * \param t The target type. - * \return the result expression. - */ -inline PrimExpr make_zero(DataType t); -/*! - * \brief Make a constant true expression. - * \param lanes The number of lanes in the bool - * \return The result expression. - */ -inline PrimExpr const_true(int lanes = 1) { - return make_const(DataType::UInt(1, lanes), 1); -} -/*! - * \brief Make a constant false expression. - * \param lanes The number of lanes in the bool - * \return The result expression. - */ -inline PrimExpr const_false(int lanes = 1) { - return make_const(DataType::UInt(1, lanes), 0); -} -/*! - * \brief Get x as constant int expression. - * \param x The expression - * \return the address to the int expression, - * return nullptr, if x is not IntImm. - */ -inline const int64_t* as_const_int(const PrimExpr& x) { - if (!x.defined()) return nullptr; - if (const ir::IntImmNode* op = x.as()) { - return &(op->value); - } else { - return nullptr; - } -} - -/*! - * \brief Check whether x is a constant integer expression. 
- * \param x The input argument
- * \param value the value to be compared against.
- * \return whether x is constant expression.
- */
-inline bool is_const_int(const PrimExpr& x, int64_t value);
-
-/*!
- * \brief Check whether stmt is nop.
- * \param stmt The input statement
- * \return whether stmt is nop
- */
-inline bool is_no_op(const Stmt& stmt);
-
-/*!
- * \brief Check whether x is a constant integer 1
- * \param x The input argument.
- * \note This only return true for integer types.
- * \return whether x is constant 1
- */
-inline bool is_one(const PrimExpr& x) {
-  return is_const_int(x, 1);
-}
-
-/*!
- * \brief Check whether x is a constant integer 0
- * \param x The input argument
- * \return whether x is constant 0
- * \note This only return true for integer types.
- */
-inline bool is_zero(const PrimExpr& x) {
-  return is_const_int(x, 0);
-}
-
-/*!
- * \brief Check whether x is a constant.
- * \note This only return true for integer types.
- * \return whether x is constant
- */
-inline bool is_const(const PrimExpr& x);
+namespace tvm {
+// Most common operators can be overloaded by argument type (PrimExpr).
+// So we put them under the root namespace.
+// It is also necessary to overload operators for PrimExpr.
+//
+// We put more developer oriented APIs -- make_const and is_const under tir
+// as they are more specific to the tir namespace.
 
 /*!
  * Query the maximum possible value of dtype.
@@ -139,16 +58,6 @@ TVM_DLL PrimExpr max_value(const DataType& dtype);
  */
 TVM_DLL PrimExpr min_value(const DataType& dtype);
 
-/*!
- * \brief Check whether x is a constant power of two
- * If x is power of two, write the power to the shift.
- *
- * \param x The input expression.
- * \param shift The output shift if x is power of two.
- * \return whether x is constant power of two
- */
-TVM_DLL bool is_const_power_of_two_integer(const PrimExpr& x, int* shift);
-
 /*!
  * \brief cast value to type.
  *
@@ -510,42 +419,42 @@ TVM_DLL PrimExpr isnan(PrimExpr x);
  * \param source The source expression.
  * \param axis List of iteration variables that will be used for reduction.
  */
-TVM_DLL PrimExpr sum(PrimExpr source, Array<IterVar> axis);
+TVM_DLL PrimExpr sum(PrimExpr source, Array<tir::IterVar> axis);
 
 /*!
  * \brief logical And of of source expression over axis
  * \param source The source expression.
  * \param axis List of iteration variables that will be used for reduction.
 */
-TVM_DLL PrimExpr all(PrimExpr source, Array<IterVar> axis);
+TVM_DLL PrimExpr all(PrimExpr source, Array<tir::IterVar> axis);
 
 /*!
  * \brief logical Or of of source expression over axis
  * \param source The source expression.
  * \param axis List of iteration variables that will be used for reduction.
 */
-TVM_DLL PrimExpr any(PrimExpr source, Array<IterVar> axis);
+TVM_DLL PrimExpr any(PrimExpr source, Array<tir::IterVar> axis);
 
 /*!
  * \brief max of of source expression over axis
  * \param source The source expression.
  * \param axis List of iteration variables that will be used for reduction.
 */
-TVM_DLL PrimExpr max(PrimExpr source, Array<IterVar> axis);
+TVM_DLL PrimExpr max(PrimExpr source, Array<tir::IterVar> axis);
 
 /*!
  * \brief max of of source expression over axis
  * \param source The source expression.
  * \param axis List of iteration variables that will be used for reduction.
 */
-TVM_DLL PrimExpr min(PrimExpr source, Array<IterVar> axis);
+TVM_DLL PrimExpr min(PrimExpr source, Array<tir::IterVar> axis);
 
 /*!
  * \brief product of of source expression over axis
  * \param source The source expression.
  * \param axis List of iteration variables that will be used for reduction.
 */
-TVM_DLL PrimExpr prod(PrimExpr source, Array<IterVar> axis);
+TVM_DLL PrimExpr prod(PrimExpr source, Array<tir::IterVar> axis);
 
 /*!
  * \brief Calculate floor(x)
  * \param x The input expression.
  * \return The result expression.
 */
 TVM_DLL PrimExpr floor(PrimExpr x);
@@ -593,10 +502,10 @@ TVM_DLL PrimExpr trunc(PrimExpr x);
 TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high);
 
 // Intrinsic operators
-#define TVM_DECLARE_INTRIN_UNARY(OpName)                                \
-  inline PrimExpr OpName(PrimExpr x) {                                  \
-    return ir::CallNode::make(x.dtype(), #OpName, {x}, ir::CallNode::PureIntrinsic);   \
-  }                                                                     \
+#define TVM_DECLARE_INTRIN_UNARY(OpName)                                    \
+  inline PrimExpr OpName(PrimExpr x) {                                      \
+    return tir::CallNode::make(x.dtype(), #OpName, {x}, tir::CallNode::PureIntrinsic); \
+  }                                                                         \
 
 TVM_DECLARE_INTRIN_UNARY(exp);
 TVM_DECLARE_INTRIN_UNARY(erf);
@@ -610,13 +519,113 @@
 TVM_DECLARE_INTRIN_UNARY(cos);
 TVM_DECLARE_INTRIN_UNARY(sin);
 TVM_DECLARE_INTRIN_UNARY(atan);
 
+namespace tir {
+/*!
+ * \brief Make a const value with certain data type.
+ * \param t The target type.
+ * \param value The input value
+ * \return the result expression.
+ * \tparam ValueType The constant value type
+ */
+template<typename ValueType,
+         typename = typename std::enable_if<std::is_pod<ValueType>::value>::type>
+inline PrimExpr make_const(DataType t, ValueType value);
+/*!
+ * \brief Make a const zero expr.
+ * \param t The target type.
+ * \return the result expression.
+ */
+inline PrimExpr make_zero(DataType t);
+/*!
+ * \brief Make a constant true expression.
+ * \param lanes The number of lanes in the bool
+ * \return The result expression.
+ */
+inline PrimExpr const_true(int lanes = 1) {
+  return make_const(DataType::UInt(1, lanes), 1);
+}
+/*!
+ * \brief Make a constant false expression.
+ * \param lanes The number of lanes in the bool
+ * \return The result expression.
+ */
+inline PrimExpr const_false(int lanes = 1) {
+  return make_const(DataType::UInt(1, lanes), 0);
+}
+/*!
+ * \brief Get x as constant int expression.
+ * \param x The expression
+ * \return the address to the int expression,
+ *         return nullptr, if x is not IntImm.
+ */
+inline const int64_t* as_const_int(const PrimExpr& x) {
+  if (!x.defined()) return nullptr;
+  if (const tir::IntImmNode* op = x.as<tir::IntImmNode>()) {
+    return &(op->value);
+  } else {
+    return nullptr;
+  }
+}
+
+/*!
+ * \brief Check whether x is a constant integer expression.
+ * \param x The input argument
+ * \param value the value to be compared against.
+ * \return whether x is constant expression.
+ */
+inline bool is_const_int(const PrimExpr& x, int64_t value);
+
+/*!
+ * \brief Check whether stmt is nop.
+ * \param stmt The input statement
+ * \return whether stmt is nop
+ */
+inline bool is_no_op(const tir::Stmt& stmt);
+
+/*!
+ * \brief Check whether x is a constant integer 1
+ * \param x The input argument.
+ * \note This only return true for integer types.
+ * \return whether x is constant 1
+ */
+inline bool is_one(const PrimExpr& x) {
+  return is_const_int(x, 1);
+}
+
+/*!
+ * \brief Check whether x is a constant integer 0
+ * \param x The input argument
+ * \return whether x is constant 0
+ * \note This only return true for integer types.
+ */
+inline bool is_zero(const PrimExpr& x) {
+  return is_const_int(x, 0);
+}
+
+/*!
+ * \brief Check whether x is a constant.
+ * \note This only return true for integer types.
+ * \return whether x is constant
+ */
+inline bool is_const(const PrimExpr& x);
+
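A brief sketch of how the constant helpers declared above compose; this is illustrative only and not part of the patch:

    void ConstHelperDemo() {
      using namespace tvm;
      PrimExpr one = tir::make_const(DataType::Int(32), 1);
      CHECK(tir::is_one(one));
      CHECK(!tir::is_zero(one));
      if (const int64_t* v = tir::as_const_int(one)) {
        CHECK_EQ(*v, 1);
      }
    }

+/*!
+ * \brief Check whether x is a constant power of two
+ * If x is power of two, write the power to the shift.
+ *
+ * \param x The input expression.
+ * \param shift The output shift if x is power of two.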
+ * \return whether x is constant power of two
+ */
+TVM_DLL bool is_const_power_of_two_integer(const PrimExpr& x, int* shift);
+
 // Implementation details after this
 inline bool is_const(const PrimExpr& x) {
-  if (x.as<ir::IntImmNode>()) {
+  if (x.as<tir::IntImmNode>()) {
     return true;
-  } else if (const auto* op = x.as<ir::BroadcastNode>()) {
+  } else if (const auto* op = x.as<tir::BroadcastNode>()) {
     const PrimExpr& val = op->value;
-    if (val.as<ir::IntImmNode>()) {
+    if (val.as<tir::IntImmNode>()) {
       return true;
     }
   }
@@ -624,7 +633,7 @@ inline bool is_const(const PrimExpr& x) {
 }
 
 inline bool is_positive_const(const PrimExpr& a) {
-  if (const ir::IntImmNode* op = a.as<ir::IntImmNode>()) {
+  if (const tir::IntImmNode* op = a.as<tir::IntImmNode>()) {
     return op->value > 0;
   } else {
     return false;
@@ -632,7 +641,7 @@ inline bool is_positive_const(const PrimExpr& a) {
 }
 
 inline bool is_negative_const(const PrimExpr& a) {
-  if (const ir::IntImmNode* op = a.as<ir::IntImmNode>()) {
+  if (const tir::IntImmNode* op = a.as<tir::IntImmNode>()) {
     return op->value < 0;
   } else {
     return false;
@@ -640,23 +649,23 @@ inline bool is_negative_const(const PrimExpr& a) {
 }
 
 inline bool is_const_int(const PrimExpr& x, int64_t value) {
-  if (const auto* op = x.as<ir::IntImmNode>()) {
+  if (const auto* op = x.as<tir::IntImmNode>()) {
     return op->value == value;
-  } else if (const auto* op = x.as<ir::BroadcastNode>()) {
+  } else if (const auto* op = x.as<tir::BroadcastNode>()) {
     const PrimExpr& val = op->value;
-    if (const auto* opv = val.as<ir::IntImmNode>()) {
+    if (const auto* opv = val.as<tir::IntImmNode>()) {
       return opv->value == value;
     }
   }
   return false;
 }
 
-inline bool is_no_op(const Stmt& stmt) {
+inline bool is_no_op(const tir::Stmt& stmt) {
   if (!stmt.defined()) return true;
-  if (const auto* op = stmt.as<ir::EvaluateNode>()) {
+  if (const auto* op = stmt.as<tir::EvaluateNode>()) {
     return is_const(op->value);
   }
-  if (const auto* op = stmt.as<ir::SeqStmtNode>()) {
+  if (const auto* op = stmt.as<tir::SeqStmtNode>()) {
     return op->seq.size() == 0;
   }
   return false;
 }
@@ -694,7 +703,7 @@ inline PrimExpr make_const(DataType t, ValueType value) {
   if (t.lanes() == 1) {
     return MakeConstScalar(t, value);
   } else {
-    return ir::BroadcastNode::make(
+    return tir::BroadcastNode::make(
         MakeConstScalar(t.element_of(), value), t.lanes());
   }
 }
@@ -705,6 +714,7 @@ inline PrimExpr make_zero(DataType t) {
   }
   return make_const(t, 0);
 }
+}  // namespace tir
 
 // additional const expression overloading
 #define TVM_DEFINE_ASSIGN_OP_OVERLOAD(Name, OpFunc)  \
   inline PrimExpr& Name(PrimExpr& a, PrimExpr b) {   \
     a = OpFunc(a, b);                                \
     return a;                                        \
   }
 
 #define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(Name)       \
-  inline PrimExpr Name(const PrimExpr& a, float b) {    \
-    return Name(a, PrimExpr(b));                        \
+  inline PrimExpr Name(const PrimExpr& a, float b) {     \
+    return Name(a, PrimExpr(b));                         \
   }                                                     \
-  inline PrimExpr Name(float a, const PrimExpr& b) {    \
-    return Name(PrimExpr(a), b);                        \
+  inline PrimExpr Name(float a, const PrimExpr& b) {     \
+    return Name(PrimExpr(a), b);                         \
  }                                                      \
-  inline PrimExpr Name(int a, const PrimExpr& b) {      \
-    return Name(make_const(b.dtype(), a), b);           \
+  inline PrimExpr Name(int a, const PrimExpr& b) {       \
+    return Name(tir::make_const(b.dtype(), a), b);       \
  }                                                      \
-  inline PrimExpr Name(const PrimExpr& a, int b) {      \
-    return Name(a, make_const(a.dtype(), b));           \
+  inline PrimExpr Name(const PrimExpr& a, int b) {       \
+    return Name(a, tir::make_const(a.dtype(), b));       \
  }                                                      \
-  inline PrimExpr Name(const PrimExpr& a, double b) {   \
-    return Name(a, make_const(DataType::Float(64), b)); \
+  inline PrimExpr Name(const PrimExpr& a, double b) {    \
+    return Name(a, tir::make_const(DataType::Float(64), b)); \
  }
 
-#define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(Name)  \
-  inline PrimExpr Name(const PrimExpr& a, bool b) {     \
-    return Name(a, PrimExpr(b));                        \
-  }                                                     \
-  inline PrimExpr Name(bool a, const PrimExpr& b) {     \
-    return Name(PrimExpr(a), b);                        \
+#define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(Name)   \
+  inline PrimExpr Name(const PrimExpr& a, bool b) {      \
+    return Name(a, PrimExpr(b));                         \
+  }                                                      \
+  inline PrimExpr Name(bool a, const PrimExpr& b) {      \
+    return Name(PrimExpr(a), b);                         \
  }
 
 #define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(Name)  \
-  inline PrimExpr Name(const PrimExpr& a, int b) {  \
-    return Name(a, make_const(a.dtype(), b));       \
+  inline PrimExpr Name(const PrimExpr& a, int b) {   \
+    return Name(a, tir::make_const(a.dtype(), b));   \
  }                                                  \
-  inline PrimExpr Name(int a, const PrimExpr& b) {  \
-    return Name(make_const(b.dtype(), a), b);       \
+  inline PrimExpr Name(int a, const PrimExpr& b) {   \
+    return Name(tir::make_const(b.dtype(), a), b);   \
  }
-
 TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator+=, operator+);
 TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator-=, operator-);
 TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator*=, operator*);
@@ -776,7 +785,6 @@ TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator^);
 TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(operator&&);
 TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(operator||);
-
 /*!
  * \brief Helper function to raise a compiler error about division ambiguity.
 * \note The call to this function will always results in a compiler error.
 * \return Expr.
 */
 template<typename TB>
 inline PrimExpr operator/(const PrimExpr& a, const TB& b) {
   DivAmbiguityError(a);
   return a;
 }
 
 template<typename TB>
 inline PrimExpr operator%(const PrimExpr& a, const TB& b) {
   DivAmbiguityError(a);
   return a;
 }
-
 }  // namespace tvm
-#endif  // TVM_EXPR_OPERATOR_H_
+#endif  // TVM_TIR_OP_H_
diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
new file mode 100644
index 000000000000..a543737f4065
--- /dev/null
+++ b/include/tvm/tir/stmt.h
@@ -0,0 +1,775 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/tir/stmt.h
+ * \brief TIR statements.
+ */
+// Acknowledgement: Many low-level stmts originate from Halide.
+#ifndef TVM_TIR_STMT_H_
+#define TVM_TIR_STMT_H_
+
+#include <tvm/tir/expr.h>
+
+#include <type_traits>
+#include <string>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace tir {
+
+/*! \brief Base node of all statements. */
+class StmtNode : public Object {
+ public:
+  static constexpr const char* _type_key = "Stmt";
+  TVM_DECLARE_BASE_OBJECT_INFO(StmtNode, Object);
+};
+
+/*! \brief Container of all statements */
+class Stmt : public ObjectRef {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(Stmt, ObjectRef, StmtNode);
+};
+
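Before the individual statement nodes, a short sketch of how they are meant to compose once this header is in place. This is illustrative only, not part of the patch; it assumes tvm/tir/op.h is included for make_const:

    // Builds "let x = 1 in evaluate(x + 1)".
    tvm::tir::Stmt LetDemo() {
      using namespace tvm;
      using namespace tvm::tir;
      Var x("x", DataType::Int(32));
      Stmt body = EvaluateNode::make(x + 1);
      return LetStmtNode::make(x, make_const(DataType::Int(32), 1), body);
    }

+/*!
+ * \brief Let binding, bind var to value, then run body.
+ */
+class LetStmtNode : public StmtNode {
+ public:
+  /*! \brief The variable. */
+  Var var;
+  /*! \brief The value to be bound. */
+  PrimExpr value;
+  /*! \brief The body block. */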
+  Stmt body;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("var", &var);
+    v->Visit("value", &value);
+    v->Visit("body", &body);
+  }
+
+  TVM_DLL static Stmt make(Var var, PrimExpr value, Stmt body);
+
+  static constexpr const char* _type_key = "LetStmt";
+  TVM_DECLARE_FINAL_OBJECT_INFO(LetStmtNode, StmtNode);
+};
+
+/*!
+ * \brief Define certain auxiliary attribute for the body to be a symbolic value.
+ *  This provides auxiliary information for IR passes that transform the body.
+ *
+ *  In terms of effect, this is equivalent to Block(Evaluate(value), body).
+ *
+ *  Examples of possible usage:
+ *    - Bound of function, variables.
+ *    - Hint which block corresponds to a parallel region.
+ */
+class AttrStmtNode : public StmtNode {
+ public:
+  /*! \brief this is attribute about certain node */
+  ObjectRef node;
+  /*! \brief the type key of the attribute */
+  std::string attr_key;
+  /*! \brief The attribute value, value is well defined at current scope. */
+  PrimExpr value;
+  /*! \brief The body statement to be executed */
+  Stmt body;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("node", &node);
+    v->Visit("attr_key", &attr_key);
+    v->Visit("value", &value);
+    v->Visit("body", &body);
+  }
+
+  TVM_DLL static Stmt make(ObjectRef node,
+                           std::string type_key,
+                           PrimExpr value,
+                           Stmt body);
+
+  static constexpr const char* _type_key = "AttrStmt";
+  TVM_DECLARE_FINAL_OBJECT_INFO(AttrStmtNode, StmtNode);
+};
+
+/*!
+ * \brief Assert condition, if an error occurs, return the error message.
+ */
+class AssertStmtNode : public StmtNode {
+ public:
+  /*! \brief Condition to be checked. */
+  PrimExpr condition;
+  /*! \brief Error message when assertion failed. */
+  PrimExpr message;
+  /*!
+   * \brief Body which this assertion holds true.
+   *  Will be executed after the assertion.
+   */
+  Stmt body;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("condition", &condition);
+    v->Visit("message", &message);
+    v->Visit("body", &body);
+  }
+
+  TVM_DLL static Stmt make(PrimExpr condition, PrimExpr message, Stmt body);
+
+  static constexpr const char* _type_key = "AssertStmt";
+  TVM_DECLARE_FINAL_OBJECT_INFO(AssertStmtNode, StmtNode);
+};
+
+// TODO(tvm-team): consider consolidate with AttrStmt.
+/*! \brief annotation node of producer/consumer relation. */
+class ProducerConsumerNode : public StmtNode {
+ public:
+  /*! \brief The corresponding tensor. */
+  FunctionRef func;
+  /*! \brief Whether the relation is producer. */
+  bool is_producer;
+  /*! \brief Body to be executed. */
+  Stmt body;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("func", &func);
+    v->Visit("is_producer", &is_producer);
+    v->Visit("body", &body);
+  }
+
+  TVM_DLL static Stmt make(FunctionRef func, bool is_producer, Stmt body);
+
+  static constexpr const char* _type_key = "ProducerConsumer";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ProducerConsumerNode, StmtNode);
+};
+
+/*!
+ * \brief Store value to the buffer.
+ *
+ *  Equivalent to ((DType*)buffer_var)[index] = value.
+ *  where DType is the type specified by type().element_of().
+ *
+ *  For example, if type = float32x3, then the store corresponds to
+ *
+ * \code
+ *
+ *  auto buffer = static_cast<float*>(buffer_var);
+ *  buffer[index.v0] = value.v0;
+ *  buffer[index.v1] = value.v1;
+ *  buffer[index.v2] = value.v2;
+ *
+ * \endcode
+ * \sa LoadNode
+ */
+class StoreNode : public StmtNode {
+ public:
+  /*! \brief The buffer variable. */
+  Var buffer_var;
+  /*! \brief The value to be stored. */
+  PrimExpr value;
+  /*! \brief The index locations to be stored. */
+  PrimExpr index;
+  /*! \brief The predicate to mask which lanes would be stored. */
+  PrimExpr predicate;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("buffer_var", &buffer_var);
+    v->Visit("value", &value);
+    v->Visit("index", &index);
+    v->Visit("predicate", &predicate);
+  }
+
+  TVM_DLL static Stmt make(Var buffer_var,
+                           PrimExpr value,
+                           PrimExpr index,
+                           PrimExpr predicate);
+
+  static constexpr const char* _type_key = "Store";
+  TVM_DECLARE_FINAL_OBJECT_INFO(StoreNode, StmtNode);
+};
+
+/*!
+ * \brief Store value into multi-dimensional array defined by func.
+ */
+class ProvideNode : public StmtNode {
+ public:
+  /*! \brief The function to be updated. */
+  FunctionRef func;
+  /*! \brief The output value index if func's value is a tuple. */
+  int value_index{0};
+  /*! \brief The value to be stored. */
+  PrimExpr value;
+  /*! \brief The index arguments of the function. */
+  Array<PrimExpr> args;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("func", &func);
+    v->Visit("value_index", &value_index);
+    v->Visit("value", &value);
+    v->Visit("args", &args);
+  }
+
+  TVM_DLL static Stmt make(FunctionRef func,
+                           int value_index,
+                           PrimExpr value,
+                           Array<PrimExpr> args);
+
+  static constexpr const char* _type_key = "Provide";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ProvideNode, StmtNode);
+};
+
+/*!
+ * \brief Allocate a buffer that can be used in body.
+ */
+class AllocateNode : public StmtNode {
+ public:
+  /*! \brief The buffer variable. */
+  Var buffer_var;
+  /*! \brief The type of the buffer. */
+  DataType dtype;
+  /*! \brief The extents of the buffer. */
+  Array<PrimExpr> extents;
+  /*! \brief Only allocate buffer when condition is satisfied. */
+  PrimExpr condition;
+  /*! \brief The body to be executed. */
+  Stmt body;
+  // The following two fields are deprecated
+  // kept for backward compatibility and will be refactored later.
+  PrimExpr new_expr;
+  std::string free_function;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("buffer_var", &buffer_var);
+    v->Visit("dtype", &dtype);
+    v->Visit("extents", &extents);
+    v->Visit("condition", &condition);
+    v->Visit("body", &body);
+  }
+
+  TVM_DLL static Stmt make(Var buffer_var,
+                           DataType dtype,
+                           Array<PrimExpr> extents,
+                           PrimExpr condition,
+                           Stmt body,
+                           PrimExpr new_expr = PrimExpr(),
+                           std::string free_function = std::string());
+
+  /*!
+   * \brief If the buffer size is constant, return the size.
+   *        Otherwise return 0.
+   * \return The result.
+   */
+  int32_t constant_allocation_size() const {
+    return constant_allocation_size(extents);
+  }
+  /*!
+   * \brief If the buffer size is constant, return the size.
+   *        Otherwise return 0.
+   * \param extents The extents of the buffer.
+   * \return The result.
+   */
+  TVM_DLL static int32_t constant_allocation_size(
+      const Array<PrimExpr>& extents);
+
+  static constexpr const char* _type_key = "Allocate";
+  TVM_DECLARE_FINAL_OBJECT_INFO(AllocateNode, StmtNode);
+};
+
+/*! \brief Free the resources in the buffer before the scope ends. */
+class FreeNode : public StmtNode {
+ public:
+  /*! \brief The buffer variable. */
+  Var buffer_var;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("buffer_var", &buffer_var);
+  }
+
+  TVM_DLL static Stmt make(Var buffer_var);
+
+  static constexpr const char* _type_key = "Free";
+  TVM_DECLARE_FINAL_OBJECT_INFO(FreeNode, StmtNode);
+};
+
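A sketch of the intended Allocate/Free pairing; illustrative only, with made-up names and sizes, and note that SeqStmt::Flatten used here is declared further below in this header:

    // Wraps a body in an Allocate/Free pair for a 64-element float32 scratch buffer.
    tvm::tir::Stmt WithScratchBuffer(tvm::tir::Stmt body) {
      using namespace tvm;
      using namespace tvm::tir;
      Var buf("scratch", DataType::Handle());
      Stmt seq = SeqStmt::Flatten(body, FreeNode::make(buf));
      return AllocateNode::make(buf, DataType::Float(32),
                                {make_const(DataType::Int(32), 64)},
                                const_true(), seq);
    }

+/*!
+ * \brief Annotate the bounds where func need to be written and read in body.
+ *  We will need to allocate space for the corresponding regions.
+ */
+class RealizeNode : public StmtNode {
+ public:
+  /*! \brief The function to be realized. */
+  FunctionRef func;
+  /*! \brief The output value index if func's value is a tuple. */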
+  int value_index;
+  /*! \brief The data type of the array. */
+  DataType dtype;
+  /*! \brief Bounds to be realized. */
+  Region bounds;
+  /*! \brief Only realize if condition holds. */
+  PrimExpr condition;
+  /*! \brief The body of realization. */
+  Stmt body;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("func", &func);
+    v->Visit("value_index", &value_index);
+    v->Visit("dtype", &dtype);
+    v->Visit("bounds", &bounds);
+    v->Visit("condition", &condition);
+    v->Visit("body", &body);
+  }
+
+  TVM_DLL static Stmt make(FunctionRef func,
+                           int value_index,
+                           DataType dtype,
+                           Region bounds,
+                           PrimExpr condition,
+                           Stmt body);
+
+  static constexpr const char* _type_key = "Realize";
+  TVM_DECLARE_FINAL_OBJECT_INFO(RealizeNode, StmtNode);
+};
+
+/*!
+ * \brief The container of seq statement.
+ *        Represent a sequence of statements.
+ */
+class SeqStmtNode : public StmtNode {
+ public:
+  /*! \brief internal sequence content. */
+  Array<Stmt> seq;
+
+  /*! \return get the size of the sequence */
+  size_t size() const {
+    return seq.size();
+  }
+  /*!
+   * \brief Get the index-th element in the sequence.
+   */
+  Stmt operator[](size_t index) const {
+    return seq[index];
+  }
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("seq", &seq);
+  }
+
+  static constexpr const char* _type_key = "SeqStmt";
+  TVM_DECLARE_FINAL_OBJECT_INFO(SeqStmtNode, StmtNode);
+};
+
+/*! \brief Sequence statement. */
+class SeqStmt : public Stmt {
+ public:
+  /*!
+   * \brief Construct SeqStmt.
+   * \param seq The sequence.
+   */
+  TVM_DLL explicit SeqStmt(Array<Stmt> seq);
+
+  /*! \return get the size of the sequence */
+  size_t size() const {
+    return operator->()->size();
+  }
+  /*!
+   * \brief Get the index-th element in the sequence.
+   */
+  Stmt operator[](size_t index) const {
+    return (*(operator->()))[index];
+  }
+  /*!
+   * \brief Construct a sequence statement by flattening
+   *        all the arrays and sequences in the arguments
+   *        recursively.
+   *
+   * - When an argument is nullptr, it will be ignored.
+   * - When an argument is an array or a SeqStmt, it will be flattened recursively.
+   * - When an argument is a consumer block in ProducerConsumer, the consumer
+   *   tag will be dropped as such information is not useful in lowering.
+   * - A normal Stmt will be appended to the end of the sequence.
+   *
+   * \note This function can directly return an element
+   *       if it is the only element in the sequence.
+   *
+   * \param seq_args The list of arguments to be flattened.
+   * \tparam Args arguments
+   * \return The constructed statement
+   */
+  template<typename ...Args>
+  static Stmt Flatten(Args&&... seq_args) {
+    Array<Stmt> seq;
+    runtime::detail::for_each(
+        Flattener(&seq), std::forward<Args>(seq_args)...);
+    if (seq.size() == 1) return seq[0];
+    return SeqStmt(seq);
+  }
+  /*! \brief Helper class to flatten sequence of arguments into Array. */
+  class Flattener {
+   public:
+    explicit Flattener(Array<Stmt>* seq)
+        : seq_(seq) {}
+
+    void operator()(size_t i, const Stmt& stmt) const {
+      if (!stmt.defined()) return;
+      if (auto* op = stmt.as<SeqStmtNode>()) {
+        operator()(0, op->seq);
+      } else if (auto* op = stmt.as<ProducerConsumerNode>()) {
+        // NOTE: The consumer block annotation was not as useful and can be safely dropped.
+ if (!op->is_producer) { + operator()(0, op->body); + } else { + seq_->push_back(stmt); + } + } else { + seq_->push_back(stmt); + } + } + + template + void operator()(size_t i, const T& seq) const { + for (auto v : seq) { + this->operator()(0, v); + } + } + + private: + Array* seq_; + }; + + TVM_DEFINE_OBJECT_REF_METHODS(SeqStmt, Stmt, SeqStmtNode); +}; + +/*! + * \brief IfThenElse statment. + */ +class IfThenElseNode : public StmtNode { + public: + /*! \brief The condition. */ + PrimExpr condition; + /*! \brief The branch to be executed when condition is true. */ + Stmt then_case; + /*! \brief The branch to be executed when condition is false, can be null. */ + Stmt else_case; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("condition", &condition); + v->Visit("then_case", &then_case); + v->Visit("else_case", &else_case); + } + + TVM_DLL static Stmt make(PrimExpr condition, Stmt then_case, Stmt else_case = Stmt()); + + static constexpr const char* _type_key = "IfThenElse"; + TVM_DECLARE_FINAL_OBJECT_INFO(IfThenElseNode, StmtNode); +}; + +/*! + * \brief Evaluates an expression. + * This is mostly used for putting a Call node into Stmt. + * + * If value do not have side-effect, this node can be safely removed. + */ +class EvaluateNode : public StmtNode { + public: + /*! \brief The expression to be evaluated. */ + PrimExpr value; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("value", &value); + } + + TVM_DLL static Stmt make(PrimExpr v); + + static constexpr const char* _type_key = "Evaluate"; + TVM_DECLARE_FINAL_OBJECT_INFO(EvaluateNode, StmtNode); +}; + +/*! \brief Additional annotation of for loop. */ +enum class ForType : int { + /*! \brief serial execution. */ + Serial = 0, + /*! \brief parallel execution on CPU. */ + Parallel = 1, + /*! \brief Vector SIMD loop annotaion. */ + Vectorized = 2, + /*! \brief Unroll annotation. */ + Unrolled = 3 +}; + +// Kevice api of for loop +// kept for backward compatibility +// consider refactor and remove later. +enum class DeviceAPI: int { + None = 0 +}; + +/*! + * \brief A for loop, with poissible type annotations. + * + * \code + * + * for (loop_var = min; loop_var < min + extent; ++loop_var) { + * // body + * } + * \endcode + */ +class ForNode : public StmtNode { + public: + /*! \brief The loop variable. */ + Var loop_var; + /*! \brief The minimum value of iteration. */ + PrimExpr min; + /*! \brief The extent of the iteration. */ + PrimExpr extent; + /*! \brief The type of the for loop. */ + ForType for_type; + /*! + * \brief Deprecated, reserved for backward compatibility. + * Consider refactor and remove later. + */ + DeviceAPI device_api; + /*! \brief The body of the for loop. */ + Stmt body; + + TVM_DLL static Stmt make(Var loop_var, + PrimExpr min, + PrimExpr extent, + ForType for_type, + DeviceAPI device_api, + Stmt body); + + void VisitAttrs(AttrVisitor* v) { + v->Visit("loop_var", &loop_var); + v->Visit("min", &min); + v->Visit("extent", &extent); + v->Visit("for_type", &for_type); + v->Visit("device_api", &device_api); + v->Visit("body", &body); + } + + static constexpr const char* _type_key = "For"; + TVM_DECLARE_FINAL_OBJECT_INFO(ForNode, StmtNode); +}; + +/*! + * \brief A prefetch hint of func. + */ +class PrefetchNode : public StmtNode { + public: + /*! \brief The function to be prefetched. */ + FunctionRef func; + /*! \brief The output value index if func's value is a tuple. */ + int value_index; + /*! \brief The data type of the array. */ + DataType dtype; + /*! \brief Bounds to be prefetched. 
*/ + Region bounds; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("func", &func); + v->Visit("value_index", &value_index); + v->Visit("dtype", &dtype); + v->Visit("bounds", &bounds); + } + + TVM_DLL static Stmt make(FunctionRef func, + int value_index, + DataType dtype, + Region bounds); + + static constexpr const char* _type_key = "Prefetch"; + TVM_DECLARE_FINAL_OBJECT_INFO(PrefetchNode, StmtNode); +}; + +/*! + * \brief Auxiliary data structure used in IR Pass to indicate a tensor. + */ +struct TensorKey { + FunctionRef f; + int value_index; + + inline bool operator==(const TensorKey& other) const { + return f == other.f && value_index == other.value_index; + } + inline std::string GetName() const { + if (f->num_outputs() == 1) return f->func_name(); + std::ostringstream os; + os << f->func_name() << ".v" << value_index; + return os.str(); + } +}; + +/*! \brief namespace of possible attribute sin AttrStmt.attr_key */ +namespace attr { +// The above attr does not pass to ir stage. +/*! \brief Mark launching extent of thread, used by device API. */ +constexpr const char* thread_extent = "thread_extent"; +/*! \brief Mark launching of a virtual thread. */ +constexpr const char* virtual_thread = "virtual_thread"; +/*! \brief Mark region is processed by a co-proccesor */ +constexpr const char* coproc_scope = "coproc_scope"; +/*! + * \brief Mark region creates coprocessor micro ops, + * can be reused if corresponding variable is independent. + */ +constexpr const char* coproc_uop_scope = "coproc_uop_scope"; +/*! \brief Mark the scope as volatile access for certain handle. */ +constexpr const char* volatile_scope = "volatile_scope"; +/*! + * \brief Mark the scope as generated by extern primitive. + * such scope can contain arbitrary ir program and we need to be careful + * when make certain assumptions about the structure of the program. + */ +constexpr const char* extern_scope = "extern_scope"; +/*! + * \brief Mark the scope as when computation start to happen + * This can hint some code generator to create a new function for compute. + */ +constexpr const char* compute_scope = "compute_scope"; +/*! \brief Mark storage scope of buffers */ +constexpr const char* storage_scope = "storage_scope"; +/*! \brief Mark storage alignement requirement of buffers */ +constexpr const char* storage_alignment = "storage_alignment"; +/*! \brief Mark storage scope of realization */ +constexpr const char* realize_scope = "realize_scope"; +/*! \brief The allocation context for global malloc in host. */ +constexpr const char* device_context_id = "device_context_id"; +/*! \brief The device type. */ +constexpr const char* device_context_type = "device_context_type"; +/*! \brief Mark of loop scope */ +constexpr const char* loop_scope = "loop_scope"; +/*! \brief Mark of reduce scope */ +constexpr const char* reduce_scope = "reduce_scope"; +/*! \brief Mark region is guarded by the pragma extension */ +constexpr const char* pragma_scope_prefix = "pragma_"; +/*! \brief Import llvm source or file into the final code gen module */ +constexpr const char* pragma_import_llvm = "pragma_import_llvm"; +/*! \brief Try to modify the AST to support Tensor Core */ +constexpr const char* pragma_tensor_core = "pragma_tensor_core"; +/*! + * \brief Mark of prefetch scope, value=offset, + * run prefetch of Tensor on the current loop scope + */ +constexpr const char* prefetch_scope = "prefetch_scope"; +/*! + * \brief Marks production of double buffer data + */ +constexpr const char* double_buffer_scope = "double_buffer_scope"; +/*! 
+ * \brief Marks region used by double buffer write + */ +constexpr const char* double_buffer_write = "double_buffer_write"; +/*! \brief Mark of scan update scope */ +constexpr const char* scan_update_scope = "scan_update_scope"; +/*! \brief Mark of scan init scope */ +constexpr const char* scan_init_scope = "scan_init_scope"; +/*! + * \brief Mark alignment of buffer dimension + * stmt.node is Tensor + * stmt.value is tvm_tuple(dim, align, offset) + * This gives hint to require stride of dim to be k * align + offset. + */ +constexpr const char* buffer_dim_align = "buffer_dim_align"; +/*! \brief Mark stores/loads with theirs bounds. */ +constexpr const char* buffer_bound = "buffer_bound"; +/*! + * \brief Bind the buffer specification to the region of the op + * When this scope occurs, the stmt.node is a Array = [buffer, tensor] + * stmt.value is a tvm_tuple(min0, extent0, min1, extent1, ...). + * The scope represents that we need to bind the storage region of tensor to buffer. + * This will affect replacement of some variables inside the scope that + * corresponds to field of buffer to be the actual expressions of tensor during + * storage flattening phase. + */ +constexpr const char* buffer_bind_scope = "buffer_bind_scope"; +// Pipeline related attributes +/*! \brief channel read scope */ +constexpr const char* channel_read_scope = "channel_read_scope"; +/*! \brief Advance step of channel after end of scope */ +constexpr const char* channel_read_advance = "channel_read_advance"; +/*! \brief channel write scope */ +constexpr const char* channel_write_scope = "channel_write_scope"; +/*! \brief Advance step of channel after end of scope */ +constexpr const char* channel_write_advance = "channel_write_advance"; +/*! \brief pipeline stage scope, implies always execution */ +constexpr const char* pipeline_stage_scope = "pipeline_stage_scope"; +/*! \brief pipeline execution scope, implies the scope can be pipelined. */ +constexpr const char* pipeline_exec_scope = "pipeline_exec_scope"; +/*! + * \brief Mark that this stage is an OpenGL shader. Since OpenGL shader only + * allows writing out to one element of the output texture, the Provide node + * gets translated to a special Call::glsl_texture_store statement instead of a + * Store statement. + */ +constexpr const char* opengl_stage_scope = "opengl_stage_scope"; + +/*! + * \brief Mark that it is in the device scope. + */ +constexpr const char* device_scope = "device_scope"; + +/*! + * \brief Mark that the shape of TensorCore fragment + */ +constexpr const char* fragment_shape = "fragment_shape"; + +/*! + * \brief Mark that the layout of TensorCore fragment + */ +constexpr const char* fragment_layout = "fragment_layout"; + +/*! + * \brief Check if attr_key is a pragma key extension + * \param attr_key The attr key to be compared + * \return true if it is a pragma key + */ +inline bool IsPragmaKey(const std::string& attr_key) { + return attr_key.compare(0, 7, "pragma_") == 0; +} + +} // namespace attr +/*! + * \brief Create a type annotation expression + * \param dtype The data type + * \return Expr a expression with dtype. + */ +inline PrimExpr TypeAnnotation(DataType dtype) { + return tir::CallNode::make(dtype, + "type_annotation", {}, + tir::CallNode::PureIntrinsic); +} + +// overload printing of for type. 
+TVM_DLL std::ostream& operator<<(std::ostream& os, ForType for_type); + +} // namespace tir +} // namespace tvm + +namespace std { +template <> +struct hash<::tvm::tir::TensorKey> { + std::size_t operator()(const ::tvm::tir::TensorKey& k) const { + size_t lhs = ::tvm::ObjectHash()(k.f); + size_t rhs = static_cast(k.value_index); + lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2); + return lhs; + } +}; +} // namespace std + +#endif // TVM_TIR_STMT_H_ diff --git a/include/tvm/ir_functor_ext.h b/include/tvm/tir/stmt_functor.h similarity index 52% rename from include/tvm/ir_functor_ext.h rename to include/tvm/tir/stmt_functor.h index 11e8e3836053..c880a4847356 100644 --- a/include/tvm/ir_functor_ext.h +++ b/include/tvm/tir/stmt_functor.h @@ -18,82 +18,34 @@ */ /*! - * \file tvm/ir_functor_ext.h - * \brief More powerful Visitor that allows define function signatures. + * \file tvm/tir/stmt_functor.h + * + * \brief Functors for tir stmts. */ -#ifndef TVM_IR_FUNCTOR_EXT_H_ -#define TVM_IR_FUNCTOR_EXT_H_ +#ifndef TVM_TIR_STMT_FUNCTOR_H_ +#define TVM_TIR_STMT_FUNCTOR_H_ #include -#include +#include +#include +#include #include namespace tvm { -namespace ir { - -/*! - * \brief A dynamical functor that dispatches on in the first Expr argument. - * You can use this as a more powerful Visitor, since it allows you to - * define function signatures of Visit Function. - * - * This helps you to avoid to book-keep return value of Visitor via state, - * which can cause bugs easily when state is incorrectly maintained. - * - * \code - * // A functor that set variable to b. and calculate results. - * class MyExprFunctor - * : public ir::ExprFunctor { - * public: - * int VisitExpr_(const Variable* op, int b) final { - * return b; - * } - * int VisitExpr_(const IntImm* op, int b) final { - * return op->value; - * } - * int VisitExpr_(const Add* op, int b) final { - * return Visit(op->a, b) + Visit(op->b, b); - * } - * }; - * MyExprFunctor f; - * Var x("x"); - * CHECK_EQ(f(x + 1, 2), 3); - * \endcode - * - * \note Why do we need this more powerful Functor: - * - * We often need to implement a transformer tasks. - * Say we want to take Expr and transform it to some analysis result, - * This easily be done incorrectly using plain Visitor. See IRVisitor's - * document for possible error cases. - * - * \tparam FType function signiture - * This type if only defined for FType with function signiture R(const Expr&, Args...) - */ -template -class ExprFunctor; +namespace tir { /*! * \brief Same as ExprFunctor except it is applied on statements * \tparam FType The function signature. + * \sa ExprFunctor */ template class StmtFunctor; -// functions to be overriden. -#define EXPR_FUNCTOR_DEFAULT { \ - return VisitExprDefault_(op, std::forward(args)...); \ - } #define STMT_FUNCTOR_DEFAULT { \ return VisitStmtDefault_(op, std::forward(args)...); \ } -#define IR_EXPR_FUNCTOR_DISPATCH(OP) \ - vtable.template set_dispatch( \ - [](const ObjectRef& n, TSelf* self, Args... args) { \ - return self->VisitExpr_(static_cast(n.get()), \ - std::forward(args)...); \ - }); \ - #define IR_STMT_FUNCTOR_DISPATCH(OP) \ vtable.template set_dispatch( \ [](const ObjectRef& n, TSelf* self, Args... args) { \ @@ -101,116 +53,6 @@ class StmtFunctor; std::forward(args)...); \ }); \ -template -class ExprFunctor { - private: - using TSelf = ExprFunctor; - using FType = NodeFunctor; - - public: - /*! \brief the result type of this functor */ - using result_type = R; - /*! \brief virtual destructor */ - virtual ~ExprFunctor() {} - /*! 
- * \brief Same as call. - * \param n The expression node. - * \param args Additional arguments. - * \return The result of the call - */ - R operator()(const PrimExpr& n, Args... args) { - return VisitExpr(n, std::forward(args)...); - } - /*! - * \brief The functor call. - * \param n The expression node. - * \param args Additional arguments. - * \return The result of the call - */ - virtual R VisitExpr(const PrimExpr& n, Args... args) { - static FType vtable = InitVTable(); - return vtable(n, this, std::forward(args)...); - } - // Functions that can be overriden by subclass - virtual R VisitExpr_(const VarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const SizeVarNode* op, Args... args) { - return VisitExpr_(static_cast(op), std::forward(args)...); - } - virtual R VisitExpr_(const LoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const LetNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const AddNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const SubNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const MulNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const DivNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const ModNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const FloorDivNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const FloorModNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const MinNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const MaxNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const EQNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const NENode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const LTNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const LENode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const GTNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const GENode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const AndNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const OrNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const ReduceNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const CastNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const NotNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const SelectNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const RampNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const BroadcastNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const ShuffleNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const IntImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const FloatImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const StringImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExprDefault_(const Object* op, Args ...) { - LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); - return R(); - } - - private: - // initialize the vtable. 
- static FType InitVTable() { - FType vtable; - // Set dispatch - IR_EXPR_FUNCTOR_DISPATCH(VarNode); - IR_EXPR_FUNCTOR_DISPATCH(SizeVarNode); - IR_EXPR_FUNCTOR_DISPATCH(LoadNode); - IR_EXPR_FUNCTOR_DISPATCH(LetNode); - IR_EXPR_FUNCTOR_DISPATCH(CallNode); - IR_EXPR_FUNCTOR_DISPATCH(AddNode); - IR_EXPR_FUNCTOR_DISPATCH(SubNode); - IR_EXPR_FUNCTOR_DISPATCH(MulNode); - IR_EXPR_FUNCTOR_DISPATCH(DivNode); - IR_EXPR_FUNCTOR_DISPATCH(ModNode); - IR_EXPR_FUNCTOR_DISPATCH(FloorDivNode); - IR_EXPR_FUNCTOR_DISPATCH(FloorModNode); - IR_EXPR_FUNCTOR_DISPATCH(MinNode); - IR_EXPR_FUNCTOR_DISPATCH(MaxNode); - IR_EXPR_FUNCTOR_DISPATCH(EQNode); - IR_EXPR_FUNCTOR_DISPATCH(NENode); - IR_EXPR_FUNCTOR_DISPATCH(LTNode); - IR_EXPR_FUNCTOR_DISPATCH(LENode); - IR_EXPR_FUNCTOR_DISPATCH(GTNode); - IR_EXPR_FUNCTOR_DISPATCH(GENode); - IR_EXPR_FUNCTOR_DISPATCH(AndNode); - IR_EXPR_FUNCTOR_DISPATCH(OrNode); - IR_EXPR_FUNCTOR_DISPATCH(ReduceNode); - IR_EXPR_FUNCTOR_DISPATCH(CastNode); - IR_EXPR_FUNCTOR_DISPATCH(NotNode); - IR_EXPR_FUNCTOR_DISPATCH(SelectNode); - IR_EXPR_FUNCTOR_DISPATCH(RampNode); - IR_EXPR_FUNCTOR_DISPATCH(ShuffleNode); - IR_EXPR_FUNCTOR_DISPATCH(BroadcastNode); - IR_EXPR_FUNCTOR_DISPATCH(IntImmNode); - IR_EXPR_FUNCTOR_DISPATCH(FloatImmNode); - IR_EXPR_FUNCTOR_DISPATCH(StringImmNode); - return vtable; - } -}; template class StmtFunctor { @@ -285,100 +127,8 @@ class StmtFunctor { }; #undef IR_STMT_FUNCTOR_DISPATCH -#undef IR_EXPR_FUNCTOR_DISPATCH -#undef EXPR_FUNCTOR_DEFAULT #undef STMT_FUNCTOR_DEFAULT -/*! - * \brief ExprVisitor - */ -class TVM_DLL ExprVisitor : - public ExprFunctor { - public: - using ExprFunctor::operator(); - - protected: - using ExprFunctor::VisitExpr; - // list of functions to override. - void VisitExpr_(const VarNode* op) override; - void VisitExpr_(const SizeVarNode* op) override; - void VisitExpr_(const LoadNode* op) override; - void VisitExpr_(const LetNode* op) override; - void VisitExpr_(const CallNode* op) override; - void VisitExpr_(const AddNode* op) override; - void VisitExpr_(const SubNode* op) override; - void VisitExpr_(const MulNode* op) override; - void VisitExpr_(const DivNode* op) override; - void VisitExpr_(const ModNode* op) override; - void VisitExpr_(const FloorDivNode* op) override; - void VisitExpr_(const FloorModNode* op) override; - void VisitExpr_(const MinNode* op) override; - void VisitExpr_(const MaxNode* op) override; - void VisitExpr_(const EQNode* op) override; - void VisitExpr_(const NENode* op) override; - void VisitExpr_(const LTNode* op) override; - void VisitExpr_(const LENode* op) override; - void VisitExpr_(const GTNode* op) override; - void VisitExpr_(const GENode* op) override; - void VisitExpr_(const AndNode* op) override; - void VisitExpr_(const OrNode* op) override; - void VisitExpr_(const ReduceNode* op) override; - void VisitExpr_(const CastNode* op) override; - void VisitExpr_(const NotNode* op) override; - void VisitExpr_(const SelectNode* op) override; - void VisitExpr_(const RampNode* op) override; - void VisitExpr_(const BroadcastNode* op) override; - void VisitExpr_(const ShuffleNode* op) override; - void VisitExpr_(const IntImmNode* op) override; - void VisitExpr_(const FloatImmNode* op) override; - void VisitExpr_(const StringImmNode* op) override; -}; - -/*! - * \brief ExprMutator that mutates expressions. - */ -class TVM_DLL ExprMutator : - protected ExprFunctor { - public: - using ExprFunctor::operator(); - - protected: - using ExprFunctor::VisitExpr; - // list of functions to override. 
- PrimExpr VisitExpr_(const VarNode* op) override; - PrimExpr VisitExpr_(const SizeVarNode* op) override; - PrimExpr VisitExpr_(const LoadNode* op) override; - PrimExpr VisitExpr_(const LetNode* op) override; - PrimExpr VisitExpr_(const CallNode* op) override; - PrimExpr VisitExpr_(const AddNode* op) override; - PrimExpr VisitExpr_(const SubNode* op) override; - PrimExpr VisitExpr_(const MulNode* op) override; - PrimExpr VisitExpr_(const DivNode* op) override; - PrimExpr VisitExpr_(const ModNode* op) override; - PrimExpr VisitExpr_(const FloorDivNode* op) override; - PrimExpr VisitExpr_(const FloorModNode* op) override; - PrimExpr VisitExpr_(const MinNode* op) override; - PrimExpr VisitExpr_(const MaxNode* op) override; - PrimExpr VisitExpr_(const EQNode* op) override; - PrimExpr VisitExpr_(const NENode* op) override; - PrimExpr VisitExpr_(const LTNode* op) override; - PrimExpr VisitExpr_(const LENode* op) override; - PrimExpr VisitExpr_(const GTNode* op) override; - PrimExpr VisitExpr_(const GENode* op) override; - PrimExpr VisitExpr_(const AndNode* op) override; - PrimExpr VisitExpr_(const OrNode* op) override; - PrimExpr VisitExpr_(const ReduceNode* op) override; - PrimExpr VisitExpr_(const CastNode* op) override; - PrimExpr VisitExpr_(const NotNode* op) override; - PrimExpr VisitExpr_(const SelectNode* op) override; - PrimExpr VisitExpr_(const RampNode* op) override; - PrimExpr VisitExpr_(const BroadcastNode* op) override; - PrimExpr VisitExpr_(const ShuffleNode* op) override; - PrimExpr VisitExpr_(const IntImmNode* op) override; - PrimExpr VisitExpr_(const FloatImmNode* op) override; - PrimExpr VisitExpr_(const StringImmNode* op) override; -}; - /*! * \brief StmtVisitor. */ @@ -591,7 +341,7 @@ TVM_DLL Stmt IRTransform(Stmt node, */ TVM_DLL void PostOrderVisit(const ObjectRef& node, std::function fvisit); - -} // namespace ir +} // namespace tir } // namespace tvm -#endif // TVM_IR_FUNCTOR_EXT_H_ + +#endif // TVM_TIR_STMT_FUNCTOR_H_ diff --git a/include/tvm/top/operation.h b/include/tvm/top/operation.h index 2cee21873687..1b138d00b07b 100644 --- a/include/tvm/top/operation.h +++ b/include/tvm/top/operation.h @@ -28,9 +28,9 @@ #include #include -#include -#include -#include +#include +#include +#include #include #include @@ -41,8 +41,6 @@ namespace tvm { namespace top { -using arith::IntSet; - /*! * \brief Temporary data structure to store union * of bounds of each axis of Tensor. @@ -58,7 +56,7 @@ struct TensorDom { /*! * \brief Base class of all operation nodes */ -class OperationNode : public ir::FunctionBaseNode { +class OperationNode : public tir::FunctionBaseNode { public: /*! \brief optional name of the operation */ std::string name; @@ -554,6 +552,29 @@ class HybridOpNode : public OperationNode { TVM_DECLARE_FINAL_OBJECT_INFO(HybridOpNode, OperationNode); }; +/*! + * \brief Construct a new Var expression + * \param name_hint The name hint for the expression + * \param t The type of the expression + */ +TVM_DLL Var var(std::string name_hint, DataType t = DataType::Int(32)); + +/*! + * \brief Create a new IterVar that represents an axis in thread. + * + * \param dom Optional, domain of the thread axis. + * \param tag The thread tag of the axis. + */ +TVM_DLL IterVar thread_axis(Range dom, std::string tag); + +/*! + * \brief Create a new IterVar for reduction operations. + * + * \param dom The domain of the reduction axis. + * \param name The name of the reduction axis. + */ +TVM_DLL IterVar reduce_axis(Range dom, std::string name = "rv"); + /*! 
\brief The compute function to specify the input source of a Tensor */ using FCompute = std::function& i)>; diff --git a/include/tvm/top/schedule.h b/include/tvm/top/schedule.h index 2adaa1337812..5eaa02db390c 100644 --- a/include/tvm/top/schedule.h +++ b/include/tvm/top/schedule.h @@ -25,7 +25,7 @@ #ifndef TVM_TOP_SCHEDULE_H_ #define TVM_TOP_SCHEDULE_H_ -#include +#include #include #include diff --git a/include/tvm/top/tensor.h b/include/tvm/top/tensor.h index bdfbbebabd6b..722ed50cdf5a 100644 --- a/include/tvm/top/tensor.h +++ b/include/tvm/top/tensor.h @@ -26,8 +26,8 @@ #include #include -#include -#include +#include +#include #include #include @@ -39,6 +39,9 @@ namespace tvm { namespace top { +using arith::IntSet; +using namespace tvm::tir; + // Internal node container of Tensor class TensorNode; // internal node container for Operation @@ -139,7 +142,7 @@ class Tensor : public ObjectRef { }; /*! \brief Operation that produces tensors */ -class Operation : public ir::FunctionRef { +class Operation : public tir::FunctionRef { public: /*! \brief default constructor */ Operation() {} @@ -215,18 +218,18 @@ inline bool Tensor::operator!=(const Tensor& other) const { // macro to turn every operation of slice to expression #define DEFINE_OVERLOAD_SLICE_UNARY_OP(Op) \ - inline PrimExpr operator Op (const Tensor::Slice& a) { \ - return Op a.operator PrimExpr() ; \ + inline PrimExpr operator Op (const Tensor::Slice& a) { \ + return Op a.operator PrimExpr() ; \ } \ #define DEFINE_OVERLOAD_SLICE_BINARY_OP(Op) \ template \ - inline PrimExpr operator Op (const Tensor::Slice& a, const T& b) { \ - return a.operator PrimExpr() Op b; \ + inline PrimExpr operator Op (const Tensor::Slice& a, const T& b) { \ + return a.operator PrimExpr() Op b; \ } \ template \ - inline PrimExpr operator Op (const T& a, const Tensor::Slice& b) { \ - return a Op b.operator PrimExpr(); \ + inline PrimExpr operator Op (const T& a, const Tensor::Slice& b) { \ + return a Op b.operator PrimExpr(); \ } \ inline PrimExpr operator Op (const Tensor::Slice& a, const Tensor::Slice& b) { \ return a.operator PrimExpr() Op b.operator PrimExpr(); \ diff --git a/include/tvm/top/tensor_intrin.h b/include/tvm/top/tensor_intrin.h index 99eb8852f0a2..d216eccbffb3 100644 --- a/include/tvm/top/tensor_intrin.h +++ b/include/tvm/top/tensor_intrin.h @@ -25,7 +25,7 @@ #define TVM_TOP_TENSOR_INTRIN_H_ #include -#include +#include #include diff --git a/python/tvm/ir_pass.py b/python/tvm/ir_pass.py index 5d5ddf0a8668..59354e2eb890 100644 --- a/python/tvm/ir_pass.py +++ b/python/tvm/ir_pass.py @@ -20,7 +20,7 @@ The functions are automatically exported from C++ side via PackedFunc. Each api is a PackedFunc that can be called in a positional argument manner. -You can read "include/tvm/ir_pass.h" for the function signature and +You can read "include/tvm/tir/ir_pass.h" for the function signature and "src/api/api_pass.cc" for the PackedFunc's body of these functions. """ from ._ffi.function import _init_api diff --git a/src/README.md b/src/README.md index 2de81416cdb0..4cd3a32ed749 100644 --- a/src/README.md +++ b/src/README.md @@ -24,12 +24,12 @@ There can be internal header files within each module that sit in src. - support: Internal support utilities. - runtime: Minimum runtime related codes. - node: base infra for IR/AST nodes that is dialect independent. +- ir: Common IR infrastructure. +- tir: Tensor-level IR. - arith: Arithmetic expression and set simplification. - top: tensor operation DSL for compute and schedule. -- relay: Implementation of Relay. 
The second generation of NNVM, a new IR for deep learning frameworks. -- pass: The optimization pass on the IR structure. +- relay: Relay IR, high-level optimization. - codegen: The code generator. - autotvm: The auto-tuning module. - contrib: Contrib extension libraries. - api: API function registration. -- lang: The definition of DSL related data structure. diff --git a/src/api/api_arith.cc b/src/api/api_arith.cc index 0062379fd32a..f5232d850c36 100644 --- a/src/api/api_arith.cc +++ b/src/api/api_arith.cc @@ -26,8 +26,8 @@ #include #include -#include -#include +#include +#include #include #include diff --git a/src/api/api_base.cc b/src/api/api_base.cc index 9078507e5453..d1d3fb0af366 100644 --- a/src/api/api_base.cc +++ b/src/api/api_base.cc @@ -22,7 +22,7 @@ * \file api_base.cc */ #include -#include +#include #include #include #include diff --git a/src/api/api_codegen.cc b/src/api/api_codegen.cc index 6c1d193d0a3f..03ff2ead70b6 100644 --- a/src/api/api_codegen.cc +++ b/src/api/api_codegen.cc @@ -21,10 +21,10 @@ * Implementation of API functions related to Codegen * \file c_api_codegen.cc */ -#include -#include +#include +#include #include -#include +#include #include namespace tvm { @@ -32,7 +32,7 @@ namespace codegen { TVM_REGISTER_GLOBAL("codegen._Build") .set_body([](TVMArgs args, TVMRetValue *ret) { - if (args[0].IsObjectRef()) { + if (args[0].IsObjectRef()) { *ret = Build({args[0]}, args[1]); } else { *ret = Build(args[0], args[1]); diff --git a/src/api/api_ir.cc b/src/api/api_ir.cc index 45f7790d63d8..35810cbc23cf 100644 --- a/src/api/api_ir.cc +++ b/src/api/api_ir.cc @@ -21,14 +21,14 @@ * Implementation of API functions related to IR build * \file api_ir.cc */ -#include -#include +#include +#include #include -#include +#include namespace tvm { -namespace ir { +namespace tir { TVM_REGISTER_GLOBAL("_Var") .set_body_typed([](std::string s, DataType t) { @@ -233,5 +233,5 @@ TVM_REGISTER_GLOBAL("make._OpIfThenElse") return if_then_else(cond, true_value, false_value); }); -} // namespace ir +} // namespace tir } // namespace tvm diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc index cf8e2c37d266..5ba99e33b28c 100644 --- a/src/api/api_lang.cc +++ b/src/api/api_lang.cc @@ -21,16 +21,16 @@ * Implementation of API functions related to Higher DSL build. 
* \file api_lang.cc */ -#include -#include +#include +#include #include #include -#include +#include #include #include #include -#include +#include namespace tvm { @@ -44,9 +44,9 @@ TVM_REGISTER_GLOBAL("_max_value") TVM_REGISTER_GLOBAL("_const") .set_body([](TVMArgs args, TVMRetValue* ret) { if (args[0].type_code() == kDLInt) { - *ret = make_const(args[1], args[0].operator int64_t()); + *ret = tir::make_const(args[1], args[0].operator int64_t()); } else if (args[0].type_code() == kDLFloat) { - *ret = make_const(args[1], args[0].operator double()); + *ret = tir::make_const(args[1], args[0].operator double()); } else { LOG(FATAL) << "only accept int or float"; } @@ -56,7 +56,7 @@ TVM_REGISTER_GLOBAL("_LargeUIntImm") .set_body_typed(LargeUIntImm); TVM_REGISTER_GLOBAL("_str") -.set_body_typed(ir::StringImmNode::make); +.set_body_typed(tir::StringImmNode::make); TVM_REGISTER_GLOBAL("_Array") @@ -200,7 +200,7 @@ TVM_REGISTER_GLOBAL("_MapItems") auto* n = static_cast(ptr); auto rkvs = make_object(); for (const auto& kv : n->data) { - rkvs->data.push_back(ir::StringImmNode::make(kv.first)); + rkvs->data.push_back(tir::StringImmNode::make(kv.first)); rkvs->data.push_back(kv.second); } *ret = Array(rkvs); @@ -216,6 +216,8 @@ TVM_REGISTER_GLOBAL("Range") } }); +namespace tir { + TVM_REGISTER_GLOBAL("_Buffer") .set_body([](TVMArgs args, TVMRetValue* ret) { CHECK_EQ(args.size(), 10); @@ -272,6 +274,7 @@ TVM_REGISTER_GLOBAL("_BijectiveLayoutForwardShape") TVM_REGISTER_GLOBAL("_BijectiveLayoutBackwardShape") .set_body_method(&BijectiveLayout::BackwardShape); +} // namespace tir namespace top { TVM_REGISTER_GLOBAL("_Tensor") @@ -444,6 +447,6 @@ TVM_REGISTER_GLOBAL("_ScheduleRFactor") } // namespace top TVM_REGISTER_GLOBAL("_CommReducerCombine") -.set_body_method(&ir::CommReducerNode::operator()); +.set_body_method(&tir::CommReducerNode::operator()); } // namespace tvm diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc index 2154ec5aa11a..2fca435395cc 100644 --- a/src/api/api_pass.cc +++ b/src/api/api_pass.cc @@ -21,15 +21,16 @@ * Exposure of pass functions. * \file api_pass.cc */ -#include -#include +#include +#include #include -#include -#include +#include +#include +#include #include namespace tvm { -namespace ir { +namespace tir { TVM_REGISTER_GLOBAL("ir_pass.Simplify") .set_body([](TVMArgs args, TVMRetValue *ret) { @@ -120,7 +121,7 @@ TVM_REGISTER_GLOBAL("ir_pass.ExprUseVar") TVM_REGISTER_GLOBAL("ir_pass.PostOrderVisit") .set_body([](TVMArgs args, TVMRetValue *ret) { PackedFunc f = args[1]; - ir::PostOrderVisit(args[0], [f](const ObjectRef& n) { + tir::PostOrderVisit(args[0], [f](const ObjectRef& n) { f(n); }); }); @@ -176,5 +177,5 @@ REGISTER_PASS(InstrumentBoundCheckers); REGISTER_PASS(VerifyCompactBuffer); REGISTER_PASS(HoistIfThenElse); REGISTER_PASS(InferFragment) -} // namespace ir +} // namespace tir } // namespace tvm diff --git a/src/api/api_schedule.cc b/src/api/api_schedule.cc index 976f5a4263bb..19a8414f327f 100644 --- a/src/api/api_schedule.cc +++ b/src/api/api_schedule.cc @@ -21,7 +21,7 @@ * Implementation of API functions related to schedule pass. * \file api_schedule.cc */ -#include +#include #include #include #include diff --git a/src/api/api_test.cc b/src/api/api_test.cc index f63adb1f19ae..24934dbe2a0d 100644 --- a/src/api/api_test.cc +++ b/src/api/api_test.cc @@ -21,7 +21,7 @@ * Code mainly used for test purposes. 
* \file api_test.cc */ -#include +#include #include #include #include diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index a36c4093684c..b12e5f51f4fb 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -20,9 +20,9 @@ /*! * \file tvm/arith/analyzer.cc */ -#include +#include #include -#include +#include namespace tvm { namespace arith { @@ -48,7 +48,7 @@ void Analyzer::Bind(const Var& var, const PrimExpr& expr) { void Analyzer::Bind(const Var& var, const Range& range) { CHECK(range.defined()); - if (is_one(range->extent)) { + if (tir::is_one(range->extent)) { this->Bind(var, range->min); } else { this->const_int_bound.Bind(var, range); @@ -78,7 +78,7 @@ void ConstraintContext::ExitWithScope() { } bool Analyzer::CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound) { - if (const auto* ptr = expr.as()) { + if (const auto* ptr = expr.as()) { return ptr->value >= lower_bound; } auto bd = this->const_int_bound(this->rewrite_simplify(expr)); @@ -102,9 +102,9 @@ bool Analyzer::CanProve(const PrimExpr& expr) { } PrimExpr Analyzer::Simplify(const PrimExpr& expr) { - if (is_const(expr)) return expr; + if (tir::is_const(expr)) return expr; auto res = this->rewrite_simplify(expr); - if (is_const(res)) return res; + if (tir::is_const(res)) return res; res = this->canonical_simplify(res); return res; } diff --git a/src/arith/bound_deducer.cc b/src/arith/bound_deducer.cc index d6cd47b4ea90..df8f40230e04 100644 --- a/src/arith/bound_deducer.cc +++ b/src/arith/bound_deducer.cc @@ -21,9 +21,9 @@ * \file bound_deducer.cc * \brief Utility to deduce bound of expression */ -#include -#include -#include +#include +#include +#include #include #include @@ -34,7 +34,7 @@ namespace tvm { namespace arith { -using namespace ir; +using namespace tir; // a visitor to find the path to the target variable // from a expression. diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc index b3cfc725aa41..3580cddf8d2e 100644 --- a/src/arith/canonical_simplify.cc +++ b/src/arith/canonical_simplify.cc @@ -22,7 +22,7 @@ * \brief Canonical form based simplification. 
*/ #include -#include +#include #include "const_fold.h" #include "pattern_match.h" #include "rewrite_simplify.h" @@ -30,7 +30,7 @@ namespace tvm { namespace arith { -using namespace ir; +using namespace tir; class SumExpr; class SplitExpr; @@ -157,7 +157,7 @@ class SplitExpr : public PrimExpr { inline bool SplitExprNode::IndexEqual(const SplitExpr& other) const { if (index.same_as(other->index)) return true; - return ir::Equal(index, other->index); + return tir::Equal(index, other->index); } inline bool SplitExprNode::DivModeCompatibleTo(DivMode mode) const { diff --git a/src/arith/compute_expr.h b/src/arith/compute_expr.h index b53543caae61..adb4f3000a29 100644 --- a/src/arith/compute_expr.h +++ b/src/arith/compute_expr.h @@ -24,7 +24,7 @@ #ifndef TVM_ARITH_COMPUTE_EXPR_H_ #define TVM_ARITH_COMPUTE_EXPR_H_ -#include +#include #include #include @@ -57,7 +57,7 @@ inline PrimExpr ComputeReduce( inline bool GetConst(PrimExpr e, int64_t* out) { if (e.dtype().is_vector()) return false; - const int64_t* v = as_const_int(e); + const int64_t* v = tir::as_const_int(e); if (v) { *out = *v; return true; } else { @@ -77,37 +77,37 @@ inline bool GetConstInt(PrimExpr e, int* out) { } template<> -inline PrimExpr Compute(PrimExpr a, PrimExpr b) { +inline PrimExpr Compute(PrimExpr a, PrimExpr b) { return a + b; } template<> -inline PrimExpr Compute(PrimExpr a, PrimExpr b) { +inline PrimExpr Compute(PrimExpr a, PrimExpr b) { return a - b; } template<> -inline PrimExpr Compute(PrimExpr a, PrimExpr b) { +inline PrimExpr Compute(PrimExpr a, PrimExpr b) { return a * b; } template<> -inline PrimExpr Compute(PrimExpr a, PrimExpr b) { +inline PrimExpr Compute(PrimExpr a, PrimExpr b) { return truncdiv(a, b); } template<> -inline PrimExpr Compute(PrimExpr a, PrimExpr b) { +inline PrimExpr Compute(PrimExpr a, PrimExpr b) { return truncmod(a, b); } template<> -inline PrimExpr Compute(PrimExpr a, PrimExpr b) { +inline PrimExpr Compute(PrimExpr a, PrimExpr b) { return max(a, b); } template<> -inline PrimExpr Compute(PrimExpr a, PrimExpr b) { +inline PrimExpr Compute(PrimExpr a, PrimExpr b) { return min(a, b); } diff --git a/src/arith/const_fold.h b/src/arith/const_fold.h index 0d8e2abe2102..bae34bdd6b05 100644 --- a/src/arith/const_fold.h +++ b/src/arith/const_fold.h @@ -24,8 +24,8 @@ #ifndef TVM_ARITH_CONST_FOLD_H_ #define TVM_ARITH_CONST_FOLD_H_ -#include -#include +#include +#include #include #include #include "int_operator.h" @@ -76,7 +76,7 @@ inline bool IsIndexType(const DataType& type) { #define TVM_ARITH_CONST_PROPAGATION(BODY) \ - using ir::FloatImmNode; \ + using tir::FloatImmNode; \ const IntImmNode* pa = a.as(); \ const IntImmNode* pb = b.as(); \ const FloatImmNode* fa = a.as(); \ @@ -96,7 +96,7 @@ inline bool IsIndexType(const DataType& type) { // specialization of constant folders. 
template<> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, pa->value + pb->value); @@ -110,7 +110,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { } template<> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, pa->value - pb->value); @@ -122,7 +122,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { } template<> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, pa->value * pb->value); @@ -148,7 +148,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { } template<> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) { @@ -177,7 +177,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { } template<> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) { @@ -187,7 +187,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { if (pa->value == 0) return a; } if (pb) { - if (pb->value == 1) return make_zero(rtype); + if (pb->value == 1) return tir::make_zero(rtype); CHECK_NE(pb->value, 0) << "Divide by zero"; } }); @@ -195,7 +195,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { } template<> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) { @@ -222,17 +222,17 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { } template<> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) { - return IntImm(rtype, arith::floormod(pa->value, pb->value)); + return IntImm(rtype, floormod(pa->value, pb->value)); } if (pa) { if (pa->value == 0) return a; } if (pb) { - if (pb->value == 1) return make_zero(rtype); + if (pb->value == 1) return tir::make_zero(rtype); CHECK_NE(pb->value, 0) << "Divide by zero"; } }); @@ -240,7 +240,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { } template<> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, std::min(pa->value, pb->value)); @@ -251,7 +251,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { } template<> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, std::max(pa->value, pb->value)); @@ -262,7 +262,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { } template<> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ if (pa && pb) return IntImm(DataType::UInt(1), 
pa->value > pb->value); if (fa && fb) return IntImm(DataType::UInt(1), fa->value > fb->value); @@ -271,7 +271,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { } template<> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ if (pa && pb) return IntImm(DataType::UInt(1), pa->value >= pb->value); if (fa && fb) return IntImm(DataType::UInt(1), fa->value >= fb->value); @@ -280,7 +280,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { } template<> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ if (pa && pb) return IntImm(DataType::UInt(1), pa->value < pb->value); if (fa && fb) return IntImm(DataType::UInt(1), fa->value < fb->value); @@ -289,7 +289,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { } template<> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ if (pa && pb) return IntImm(DataType::UInt(1), pa->value <= pb->value); if (fa && fb) return IntImm(DataType::UInt(1), fa->value <= fb->value); @@ -298,7 +298,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { } template<> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ if (pa && pb) return IntImm(DataType::UInt(1), pa->value == pb->value); if (fa && fb) return IntImm(DataType::UInt(1), fa->value == fb->value); @@ -307,7 +307,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { } template<> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ if (pa && pb) return IntImm(DataType::UInt(1), pa->value != pb->value); if (fa && fb) return IntImm(DataType::UInt(1), fa->value != fb->value); @@ -316,7 +316,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { } template<> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { const IntImmNode* pa = a.as(); const IntImmNode* pb = b.as(); if (pa && pa->value) return b; @@ -327,7 +327,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { } template<> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { const IntImmNode* pa = a.as(); const IntImmNode* pb = b.as(); if (pa && pa->value) return a; @@ -338,7 +338,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { } template<> -inline PrimExpr TryConstFold(PrimExpr a) { +inline PrimExpr TryConstFold(PrimExpr a) { const IntImmNode* pa = a.as(); if (pa) { return IntImm(DataType::UInt(1), !(pa->value)); diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index 460b9d9a54e6..a75e86a32660 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -21,7 +21,7 @@ * \file tvm/arith/const_int_bound.cc */ #include -#include +#include #include #include "int_operator.h" #include "pattern_match.h" @@ -29,7 +29,7 @@ namespace tvm { namespace arith { -using namespace ir; +using namespace tir; TVM_REGISTER_NODE_TYPE(ConstIntBoundNode); @@ -133,7 +133,7 @@ class ConstIntBoundAnalyzer::Impl : // a linear search over additional info // assume we won't have a lot of conditions for (const BoundInfo& info : additional_info_) { - if (ir::Equal(expr, info.expr)) { + if (tir::Equal(expr, info.expr)) { res = 
Intersect(res, info.bound); } } diff --git a/src/arith/detect_linear_equation.cc b/src/arith/detect_linear_equation.cc index 81740a0be426..53adf35eb6ee 100644 --- a/src/arith/detect_linear_equation.cc +++ b/src/arith/detect_linear_equation.cc @@ -21,15 +21,16 @@ * \file detect_linear_equation.cc * \brief Utility to detect patterns in the expression. */ -#include -#include -#include +#include +#include +#include +#include #include namespace tvm { namespace arith { -using namespace ir; +using namespace tir; // Linear equation, the components can be undefined. struct LinearEqEntry { @@ -211,7 +212,7 @@ bool DetectClipBound( if (is_const_int(ret.coeff, 1)) { // var + shift >=0 -> var >= -shift if (p.min_value.defined()) { - p.min_value = ir::MaxNode::make(p.min_value, -ret.base); + p.min_value = tir::MaxNode::make(p.min_value, -ret.base); } else { p.min_value = -ret.base; } @@ -220,7 +221,7 @@ bool DetectClipBound( if (is_const_int(ret.coeff, -1)) { // -var + shift >=0 -> var <= shift if (p.max_value.defined()) { - p.max_value = ir::MinNode::make(p.max_value, ret.base); + p.max_value = tir::MinNode::make(p.max_value, ret.base); } else { p.max_value = ret.base; } @@ -244,7 +245,7 @@ void SplitCommExpr(const PrimExpr& e, std::vector* ret) { // e must be connected by and. Array DetectClipBound(const PrimExpr& e, const Array& vars) { std::vector splits; - SplitCommExpr(e, &splits); + SplitCommExpr(e, &splits); std::unordered_map rmap; for (Var v : vars) { rmap[v.get()] = IntervalEntry(); diff --git a/src/arith/domain_touched.cc b/src/arith/domain_touched.cc index 7db03c22b748..71a92d828232 100644 --- a/src/arith/domain_touched.cc +++ b/src/arith/domain_touched.cc @@ -21,9 +21,9 @@ * \file bound_deducer.cc * \brief Utility to deduce bound of expression */ -#include -#include -#include +#include +#include +#include #include #include @@ -33,7 +33,7 @@ namespace tvm { namespace arith { -using namespace ir; +using namespace tir; // Find Read region of the tensor in the stmt. 
class FuncTouchedDomain final : public StmtExprVisitor { diff --git a/src/arith/int_operator.h b/src/arith/int_operator.h index 6e7f3959c5c9..3be34b638777 100644 --- a/src/arith/int_operator.h +++ b/src/arith/int_operator.h @@ -47,7 +47,7 @@ inline bool WillOverflow(int64_t x, } template<> -inline bool WillOverflow(int64_t x, +inline bool WillOverflow(int64_t x, int64_t y, int64_t min_value, int64_t max_value) { @@ -57,7 +57,7 @@ inline bool WillOverflow(int64_t x, } template<> -inline bool WillOverflow(int64_t x, +inline bool WillOverflow(int64_t x, int64_t y, int64_t min_value, int64_t max_value) { @@ -67,7 +67,7 @@ inline bool WillOverflow(int64_t x, } template<> -inline bool WillOverflow(int64_t x, +inline bool WillOverflow(int64_t x, int64_t y, int64_t min_value, int64_t max_value) { @@ -84,7 +84,7 @@ inline bool WillOverflow(int64_t x, } template<> -inline bool WillOverflow(int64_t x, +inline bool WillOverflow(int64_t x, int64_t y, int64_t min_value, int64_t max_value) { diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 2d56596860c7..27cdffee02b1 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -22,8 +22,8 @@ * \brief The integer set functions */ #include -#include -#include +#include +#include #include #include @@ -35,6 +35,11 @@ namespace tvm { namespace arith { +using tir::make_const; +using tir::make_zero; +using tir::is_zero; +using tir::is_one; + PrimExpr SymbolicLimits::pos_inf_ = Var("pos_inf", DataType::Handle()); PrimExpr SymbolicLimits::neg_inf_ = Var("neg_inf", DataType::Handle()); @@ -79,7 +84,7 @@ struct is_logical_op { #define TVM_DECLARE_LOGICAL_OP(OP) \ template<> \ - struct is_logical_op { \ + struct is_logical_op { \ static const bool value = true; \ }; @@ -118,7 +123,7 @@ inline IntervalSet Combine(Analyzer* analyzer, } template<> -inline IntervalSet Combine(Analyzer* analyer, +inline IntervalSet Combine(Analyzer* analyer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { @@ -136,7 +141,7 @@ inline IntervalSet Combine(Analyzer* analyer, } template<> -inline IntervalSet Combine(Analyzer* analyer, +inline IntervalSet Combine(Analyzer* analyer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { @@ -155,7 +160,7 @@ inline IntervalSet Combine(Analyzer* analyer, template<> -inline IntervalSet Combine(Analyzer* analyzer, +inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { @@ -178,7 +183,7 @@ inline IntervalSet Combine(Analyzer* analyzer, PrimExpr max_value = a->HasLowerBound() ? a->min_value * b->min_value : pos_inf(); return IntervalSet(min_value, max_value); } else if (a->HasUpperBound() && a->HasLowerBound()) { - using ir::SelectNode; + using tir::SelectNode; PrimExpr sign = b->min_value >= make_zero(b->min_value.dtype().element_of()); PrimExpr e1 = a->min_value * b->min_value; PrimExpr e2 = a->max_value * b->min_value; @@ -190,7 +195,7 @@ inline IntervalSet Combine(Analyzer* analyzer, } template<> -inline IntervalSet Combine(Analyzer* analyzer, +inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { @@ -213,7 +218,7 @@ inline IntervalSet Combine(Analyzer* analyzer, PrimExpr max_value = a->HasLowerBound() ? 
a->min_value / b->min_value : pos_inf(); return IntervalSet(min_value, max_value); } else if (a->HasUpperBound() && a->HasLowerBound()) { - using ir::SelectNode; + using tir::SelectNode; PrimExpr sign = b->min_value >= make_zero(b->min_value.dtype().element_of()); PrimExpr e1 = a->min_value / b->min_value; PrimExpr e2 = a->max_value / b->min_value; @@ -225,7 +230,7 @@ inline IntervalSet Combine(Analyzer* analyzer, } template<> -inline IntervalSet Combine(Analyzer* analyzer, +inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { @@ -256,7 +261,7 @@ inline IntervalSet Combine(Analyzer* analyzer, template<> -inline IntervalSet Combine(Analyzer* analyzer, +inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { @@ -279,7 +284,7 @@ inline IntervalSet Combine(Analyzer* analyzer, PrimExpr max_value = a->HasLowerBound() ? floordiv(a->min_value, b->min_value) : pos_inf(); return IntervalSet(min_value, max_value); } else if (a->HasUpperBound() && a->HasLowerBound()) { - using ir::SelectNode; + using tir::SelectNode; PrimExpr sign = b->min_value >= make_zero(b->min_value.dtype().element_of()); PrimExpr e1 = floordiv(a->min_value, b->min_value); PrimExpr e2 = floordiv(a->max_value, b->min_value); @@ -291,7 +296,7 @@ inline IntervalSet Combine(Analyzer* analyzer, } template<> -inline IntervalSet Combine(Analyzer* analyzer, +inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { @@ -317,7 +322,7 @@ inline IntervalSet Combine(Analyzer* analyzer, } template<> -inline IntervalSet Combine(Analyzer* analzyer, +inline IntervalSet Combine(Analyzer* analzyer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { @@ -330,7 +335,7 @@ inline IntervalSet Combine(Analyzer* analzyer, } template<> -inline IntervalSet Combine(Analyzer* analzyer, +inline IntervalSet Combine(Analyzer* analzyer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { @@ -351,7 +356,7 @@ IntervalSet ToIntervalSet(IntSet set) { return IntervalSet::Everything(); } -using namespace ir; +using namespace tir; // Simplified version of int set evaluator that operates on IntervalSet // We might use better set analysis in the future to replace the intervalset. 
@@ -603,17 +608,17 @@ bool IntSet::is_single_point() const { bool IntSet::can_prove_positive() const { const IntervalSetNode* s_int = (*this).as(); - return (s_int && is_positive_const(ir::Simplify(s_int->min_value))); + return (s_int && is_positive_const(tir::Simplify(s_int->min_value))); } bool IntSet::can_prove_negative() const { const IntervalSetNode* s_int = (*this).as(); - return (s_int && is_negative_const(ir::Simplify(s_int->max_value))); + return (s_int && is_negative_const(tir::Simplify(s_int->max_value))); } bool IntSet::can_prove_non_positive() const { if (const auto* s_int = (*this).as()) { - auto max = ir::Simplify(s_int->max_value); + auto max = tir::Simplify(s_int->max_value); return is_zero(max) || is_negative_const(max); } return false; @@ -621,7 +626,7 @@ bool IntSet::can_prove_non_positive() const { bool IntSet::can_prove_non_negative() const { if (const IntervalSetNode* s_int = (*this).as()) { - auto min = ir::Simplify(s_int->min_value); + auto min = tir::Simplify(s_int->min_value); return is_zero(min) || is_positive_const(min); } return false; @@ -665,7 +670,7 @@ IntSet IntSet::interval(PrimExpr min, PrimExpr max) { // Range related code inline bool ProveEqual(PrimExpr lhs, PrimExpr rhs) { - return is_zero(ir::Simplify(lhs - rhs)); + return is_zero(tir::Simplify(lhs - rhs)); } IntSet IntSet::range(Range r) { @@ -692,8 +697,8 @@ IntSet Union(const Array& sets) { for (size_t i = 1; i < sets.size(); ++i) { x = Union(&ana, x, ToIntervalSet(sets[i])); } - return IntervalSet(ir::Simplify(x->min_value), - ir::Simplify(x->max_value)); + return IntervalSet(tir::Simplify(x->min_value), + tir::Simplify(x->max_value)); } IntSet Intersect(const Array& sets) { @@ -704,8 +709,8 @@ IntSet Intersect(const Array& sets) { for (size_t i = 1; i < sets.size(); ++i) { x = Intersect(&ana, x, ToIntervalSet(sets[i])); } - return IntervalSet(ir::Simplify(x->min_value), - ir::Simplify(x->max_value)); + return IntervalSet(tir::Simplify(x->min_value), + tir::Simplify(x->max_value)); } Map ConvertDomMap(const Map& dom_map) { diff --git a/src/arith/interval_set.h b/src/arith/interval_set.h index e931a7c04284..51b500adb412 100644 --- a/src/arith/interval_set.h +++ b/src/arith/interval_set.h @@ -25,7 +25,7 @@ #define TVM_ARITH_INTERVAL_SET_H_ #include -#include +#include #include #include "const_fold.h" diff --git a/src/arith/ir_mutator_with_analyzer.cc b/src/arith/ir_mutator_with_analyzer.cc index 7c1c30a51761..32c732c21740 100644 --- a/src/arith/ir_mutator_with_analyzer.cc +++ b/src/arith/ir_mutator_with_analyzer.cc @@ -20,14 +20,14 @@ /*! 
* \file tvm/arith/ir_mutator_with_analyzer.cc */ -#include -#include +#include +#include #include "ir_mutator_with_analyzer.h" namespace tvm { namespace arith { -using namespace ir; +using namespace tir; Stmt IRMutatorWithAnalyzer:: VisitStmt_(const ForNode* op) { @@ -39,7 +39,7 @@ VisitStmt_(const ForNode* op) { Stmt IRMutatorWithAnalyzer:: VisitStmt_(const LetStmtNode* op) { PrimExpr value = this->VisitExpr(op->value); - if (!ir::HasSideEffect(value)) { + if (!tir::HasSideEffect(value)) { analyzer_->Bind(op->var, value); } // We keep the let-binding here @@ -128,7 +128,7 @@ VisitStmt_(const AssertStmtNode* op) { PrimExpr IRMutatorWithAnalyzer:: VisitExpr_(const CallNode* op) { // add condition context to if_then_else - if (op->is_intrinsic(ir::intrinsic::tvm_if_then_else)) { + if (op->is_intrinsic(tir::intrinsic::tvm_if_then_else)) { PrimExpr cond = this->VisitExpr(op->args[0]); PrimExpr true_value, false_value; { @@ -162,7 +162,7 @@ VisitExpr_(const CallNode* op) { PrimExpr IRMutatorWithAnalyzer:: VisitExpr_(const LetNode* op) { PrimExpr value = this->VisitExpr(op->value); - if (!ir::HasSideEffect(value)) { + if (!tir::HasSideEffect(value)) { analyzer_->Bind(op->var, value); } // We keep the let-binding here diff --git a/src/arith/ir_mutator_with_analyzer.h b/src/arith/ir_mutator_with_analyzer.h index 10dc427116c6..394e5db9c93e 100644 --- a/src/arith/ir_mutator_with_analyzer.h +++ b/src/arith/ir_mutator_with_analyzer.h @@ -24,7 +24,7 @@ #ifndef TVM_ARITH_IR_MUTATOR_WITH_ANALYZER_H_ #define TVM_ARITH_IR_MUTATOR_WITH_ANALYZER_H_ -#include +#include #include #include @@ -40,7 +40,7 @@ namespace arith { * * \sa src/arithmetic/ir_mutator_with_analyzer.cc */ -class IRMutatorWithAnalyzer : public ir::StmtExprMutator { +class IRMutatorWithAnalyzer : public tir::StmtExprMutator { public: explicit IRMutatorWithAnalyzer(Analyzer* analyzer) : analyzer_(analyzer) {} @@ -49,15 +49,15 @@ class IRMutatorWithAnalyzer : public ir::StmtExprMutator { using StmtExprMutator::VisitExpr_; // override functions that need to populate the context information. - Stmt VisitStmt_(const ir::ForNode* op) override; - Stmt VisitStmt_(const ir::LetStmtNode* op) override; - Stmt VisitStmt_(const ir::IfThenElseNode* op) override; - Stmt VisitStmt_(const ir::AttrStmtNode* op) override; - Stmt VisitStmt_(const ir::AssertStmtNode* op) override; - PrimExpr VisitExpr_(const ir::LetNode* op) override; - PrimExpr VisitExpr_(const ir::SelectNode* op) override; - PrimExpr VisitExpr_(const ir::CallNode* op) override; - PrimExpr VisitExpr_(const ir::ReduceNode* op) override; + Stmt VisitStmt_(const tir::ForNode* op) override; + Stmt VisitStmt_(const tir::LetStmtNode* op) override; + Stmt VisitStmt_(const tir::IfThenElseNode* op) override; + Stmt VisitStmt_(const tir::AttrStmtNode* op) override; + Stmt VisitStmt_(const tir::AssertStmtNode* op) override; + PrimExpr VisitExpr_(const tir::LetNode* op) override; + PrimExpr VisitExpr_(const tir::SelectNode* op) override; + PrimExpr VisitExpr_(const tir::CallNode* op) override; + PrimExpr VisitExpr_(const tir::ReduceNode* op) override; protected: /*! \brief internal analyzer field. 
*/ diff --git a/src/arith/ir_visitor_with_analyzer.h b/src/arith/ir_visitor_with_analyzer.h index 8e5c0b19bd2f..b2dbe9d10c08 100644 --- a/src/arith/ir_visitor_with_analyzer.h +++ b/src/arith/ir_visitor_with_analyzer.h @@ -26,11 +26,11 @@ #define TVM_ARITH_IR_VISITOR_WITH_ANALYZER_H_ #include -#include -#include +#include +#include namespace tvm { -namespace ir { +namespace tir { class IRVisitorWithAnalyzer final : public StmtExprVisitor { public: @@ -71,6 +71,6 @@ class IRVisitorWithAnalyzer final : public StmtExprVisitor { arith::Analyzer analyzer_; }; -} // namespace ir +} // namespace tir } // namespace tvm #endif // TVM_ARITH_IR_VISITOR_WITH_ANALYZER_H_ diff --git a/src/arith/modular_set.cc b/src/arith/modular_set.cc index c79e94f967eb..8b5309272efe 100644 --- a/src/arith/modular_set.cc +++ b/src/arith/modular_set.cc @@ -22,8 +22,8 @@ * \brief Modular set analysis */ #include -#include -#include +#include +#include #include #include #include @@ -32,7 +32,7 @@ namespace tvm { namespace arith { -using namespace ir; +using namespace tir; TVM_REGISTER_NODE_TYPE(ModularSetNode); diff --git a/src/arith/pattern_match.h b/src/arith/pattern_match.h index f6e6508366db..8a2df5043ffb 100644 --- a/src/arith/pattern_match.h +++ b/src/arith/pattern_match.h @@ -44,7 +44,7 @@ * return (max(x, y) + z).Eval(); * } * - * tvm::Var tx, ty; + * tvm::tir::Var tx, ty; * arith::PVar c; * arith::PVar v; * // We can match integer and Var, both of which are @@ -65,7 +65,7 @@ #ifndef TVM_ARITH_PATTERN_MATCH_H_ #define TVM_ARITH_PATTERN_MATCH_H_ -#include +#include #include #include "const_fold.h" @@ -135,7 +135,7 @@ class PEqualChecker { public: bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const { if (lhs.same_as(rhs)) return true; - return ir::Equal(lhs, rhs); + return tir::Equal(lhs, rhs); } }; @@ -283,7 +283,7 @@ class PConstWithTypeLike : void InitMatch_() const {} bool Match_(const ObjectRef& node) const { - if (const ir::IntImmNode* ptr = node.as()) { + if (const tir::IntImmNode* ptr = node.as()) { return ptr->value == value_; } else { return false; @@ -291,7 +291,7 @@ class PConstWithTypeLike : } PrimExpr Eval() const { - return make_const(ref_.Eval().dtype(), value_); + return tir::make_const(ref_.Eval().dtype(), value_); } private: @@ -325,30 +325,30 @@ class PConstWithTypeLike : // raise ambiguity error for operator overload of / and % -TVM_PATTERN_BINARY_OP_EX(operator/, ir::DivNode, DivAmbiguityError(a)); -TVM_PATTERN_BINARY_OP_EX(operator%, ir::ModNode, DivAmbiguityError(a)); +TVM_PATTERN_BINARY_OP_EX(operator/, tir::DivNode, DivAmbiguityError(a)); +TVM_PATTERN_BINARY_OP_EX(operator%, tir::ModNode, DivAmbiguityError(a)); // arithmetic expressions -TVM_PATTERN_BINARY_OP(operator+, ir::AddNode); -TVM_PATTERN_BINARY_OP(operator-, ir::SubNode); -TVM_PATTERN_BINARY_OP(operator*, ir::MulNode); -TVM_PATTERN_BINARY_OP(min, ir::MinNode); -TVM_PATTERN_BINARY_OP(max, ir::MaxNode); -TVM_PATTERN_BINARY_OP(div, ir::DivNode); -TVM_PATTERN_BINARY_OP(truncdiv, ir::DivNode); -TVM_PATTERN_BINARY_OP(truncmod, ir::ModNode); -TVM_PATTERN_BINARY_OP(floordiv, ir::FloorDivNode); -TVM_PATTERN_BINARY_OP(floormod, ir::FloorModNode); +TVM_PATTERN_BINARY_OP(operator+, tir::AddNode); +TVM_PATTERN_BINARY_OP(operator-, tir::SubNode); +TVM_PATTERN_BINARY_OP(operator*, tir::MulNode); +TVM_PATTERN_BINARY_OP(min, tir::MinNode); +TVM_PATTERN_BINARY_OP(max, tir::MaxNode); +TVM_PATTERN_BINARY_OP(div, tir::DivNode); +TVM_PATTERN_BINARY_OP(truncdiv, tir::DivNode); +TVM_PATTERN_BINARY_OP(truncmod, tir::ModNode); 
+TVM_PATTERN_BINARY_OP(floordiv, tir::FloorDivNode); +TVM_PATTERN_BINARY_OP(floormod, tir::FloorModNode); // logical expressions -TVM_PATTERN_BINARY_OP(operator>, ir::GTNode); -TVM_PATTERN_BINARY_OP(operator>=, ir::GENode); -TVM_PATTERN_BINARY_OP(operator<, ir::LTNode); -TVM_PATTERN_BINARY_OP(operator<=, ir::LENode); -TVM_PATTERN_BINARY_OP(operator==, ir::EQNode); -TVM_PATTERN_BINARY_OP(operator!=, ir::NENode); -TVM_PATTERN_BINARY_OP(operator&&, ir::AndNode); -TVM_PATTERN_BINARY_OP(operator||, ir::OrNode); +TVM_PATTERN_BINARY_OP(operator>, tir::GTNode); +TVM_PATTERN_BINARY_OP(operator>=, tir::GENode); +TVM_PATTERN_BINARY_OP(operator<, tir::LTNode); +TVM_PATTERN_BINARY_OP(operator<=, tir::LENode); +TVM_PATTERN_BINARY_OP(operator==, tir::EQNode); +TVM_PATTERN_BINARY_OP(operator!=, tir::NENode); +TVM_PATTERN_BINARY_OP(operator&&, tir::AndNode); +TVM_PATTERN_BINARY_OP(operator||, tir::OrNode); /*! * \brief Pattern not expression. @@ -365,7 +365,7 @@ class PNotExpr : public Pattern > { } bool Match_(const ObjectRef& node) const { - if (const ir::NotNode* ptr = node.as()) { + if (const tir::NotNode* ptr = node.as()) { if (!value_.Match_(ptr->a)) return false; return true; } else { @@ -374,7 +374,7 @@ class PNotExpr : public Pattern > { } PrimExpr Eval() const { - return ir::NotNode::make(value_.Eval()); + return tir::NotNode::make(value_.Eval()); } private: @@ -411,7 +411,7 @@ class PSelectExpr : } bool Match_(const ObjectRef& node) const { - if (const ir::SelectNode* ptr = node.as()) { + if (const tir::SelectNode* ptr = node.as()) { if (!condition_.Match_(ptr->condition)) return false; if (!true_value_.Match_(ptr->true_value)) return false; if (!false_value_.Match_(ptr->false_value)) return false; @@ -422,7 +422,7 @@ class PSelectExpr : } PrimExpr Eval() const { - return ir::SelectNode::make( + return tir::SelectNode::make( condition_.Eval(), true_value_.Eval(), false_value_.Eval()); } @@ -473,7 +473,7 @@ class PCastExpr : } bool Match_(const ObjectRef& node) const { - if (const ir::CastNode* ptr = node.as()) { + if (const tir::CastNode* ptr = node.as()) { if (!dtype_.Match_(ptr->dtype)) return false; if (!value_.Match_(ptr->value)) return false; return true; @@ -483,7 +483,7 @@ class PCastExpr : } PrimExpr Eval() const { - return ir::CastNode::make(dtype_.Eval(), value_.Eval()); + return tir::CastNode::make(dtype_.Eval(), value_.Eval()); } private: @@ -531,7 +531,7 @@ class PRampExpr : } bool Match_(const ObjectRef& node) const { - if (const ir::RampNode* ptr = node.as()) { + if (const tir::RampNode* ptr = node.as()) { if (!base_.Match_(ptr->base)) return false; if (!stride_.Match_(ptr->stride)) return false; if (!lanes_.Match_(ptr->lanes)) return false; @@ -542,7 +542,7 @@ class PRampExpr : } PrimExpr Eval() const { - return ir::RampNode::make(base_.Eval(), stride_.Eval(), lanes_.Eval()); + return tir::RampNode::make(base_.Eval(), stride_.Eval(), lanes_.Eval()); } private: @@ -593,7 +593,7 @@ class PBroadcastExpr : } bool Match_(const ObjectRef& node) const { - if (const ir::BroadcastNode* ptr = node.as()) { + if (const tir::BroadcastNode* ptr = node.as()) { if (!value_.Match_(ptr->value)) return false; if (!lanes_.Match_(ptr->lanes)) return false; return true; @@ -603,7 +603,7 @@ class PBroadcastExpr : } PrimExpr Eval() const { - return ir::BroadcastNode::make(value_.Eval(), lanes_.Eval()); + return tir::BroadcastNode::make(value_.Eval(), lanes_.Eval()); } private: @@ -662,10 +662,10 @@ struct PCallExprInitMatchFunctor { }; struct PCallExprMatchFunctor { - const ir::CallNode* call_; + const 
tir::CallNode* call_; bool matched_{true}; - explicit PCallExprMatchFunctor(const ir::CallNode* call) + explicit PCallExprMatchFunctor(const tir::CallNode* call) : call_(call) {} template @@ -705,7 +705,7 @@ class PCallExpr : } bool Match_(const ObjectRef& node) const { - if (const ir::CallNode* ptr = node.as()) { + if (const tir::CallNode* ptr = node.as()) { if (ptr->args.size() != sizeof...(TArgs)) return false; if (ptr->name != Op::kName) return false; detail::PCallExprMatchFunctor fmatch(ptr); @@ -730,8 +730,8 @@ class PCallExpr : #define TVM_PATTERN_BINARY_INTRIN(FuncName, OpName, IntrinStr) \ struct OpName { \ static PrimExpr Eval(Array args) { \ - return ir::CallNode::make(args[0].dtype(), kName, args, \ - ir::CallNode::PureIntrinsic); \ + return tir::CallNode::make(args[0].dtype(), kName, args, \ + tir::CallNode::PureIntrinsic); \ } \ static constexpr const char* kName = IntrinStr; \ }; \ @@ -751,8 +751,8 @@ TVM_PATTERN_BINARY_INTRIN(operator^, PBitwiseXorOp, "bitwise_xor"); #define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinStr) \ struct OpName { \ static PrimExpr Eval(Array args) { \ - return ir::CallNode::make(args[0].dtype(), kName, args, \ - ir::CallNode::PureIntrinsic); \ + return tir::CallNode::make(args[0].dtype(), kName, args, \ + tir::CallNode::PureIntrinsic); \ } \ static constexpr const char* kName = IntrinStr; \ }; \ @@ -767,9 +767,9 @@ TVM_PATTERN_UNARY_INTRIN(operator~, PBitwiseNotOp, "bitwise_not"); // if_then_else struct PIfThenElseOp { static PrimExpr Eval(Array args) { - return ir::CallNode::make( + return tir::CallNode::make( args[1].dtype(), kName, args, - ir::CallNode::PureIntrinsic); + tir::CallNode::PureIntrinsic); } static constexpr const char* kName = "tvm_if_then_else"; }; diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 5f486cf88deb..39b87ef1b056 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -23,7 +23,7 @@ */ // Acknowledgement: Most rewrite-rules are from Halide. #include -#include +#include #include #include "const_fold.h" #include "pattern_match.h" @@ -32,7 +32,7 @@ namespace tvm { namespace arith { -using namespace ir; +using namespace tir; // macro for doing simple rewrite #define TVM_TRY_REWRITE(SrcExpr, ResExpr) \ @@ -1747,7 +1747,7 @@ VisitExpr_(const CastNode* op) { PrimExpr RewriteSimplifier::Impl:: VisitExpr_(const LetNode* op) { PrimExpr value = this->VisitExpr(op->value); - if (!ir::HasSideEffect(value)) { + if (!tir::HasSideEffect(value)) { // it is fine to discard the let binding // because the value will always be inlined in the simplifier. analyzer_->Bind(op->var, value); diff --git a/src/arith/rewrite_simplify.h b/src/arith/rewrite_simplify.h index 3255376ebd2c..8798df92777d 100644 --- a/src/arith/rewrite_simplify.h +++ b/src/arith/rewrite_simplify.h @@ -25,7 +25,7 @@ #define TVM_ARITH_REWRITE_SIMPLIFY_H_ #include -#include +#include #include #include #include "const_fold.h" @@ -35,7 +35,7 @@ namespace tvm { namespace arith { -using namespace ir; +using namespace tir; /*! * \brief Rewrite-based simplifier. 
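
The pattern matchers and the rewrite simplifier above change only in spelling: every ir:: node type becomes tir::, with signatures intact. A minimal sketch of a rule written against the renamed pattern API (SimplifyAddZero is a hypothetical helper for illustration, not part of this patch; it assumes the PVar/Match/Eval interface documented at the top of pattern_match.h):

#include <tvm/tir/expr.h>
#include "pattern_match.h"

namespace tvm {
namespace arith {

// Rewrite "x + 0" into "x"; leave every other expression untouched.
PrimExpr SimplifyAddZero(const PrimExpr& expr) {
  PVar<PrimExpr> x;  // matches any PrimExpr
  PVar<IntImm> c;    // matches integer immediates only
  if ((x + c).Match(expr) && c.Eval()->value == 0) {
    return x.Eval();  // Match() bound x to the non-constant operand
  }
  return expr;
}

}  // namespace arith
}  // namespace tvm
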
diff --git a/src/arith/stmt_simplify.cc b/src/arith/stmt_simplify.cc
index 0b7a4b7b416d..c0bc0c4787f1 100644
--- a/src/arith/stmt_simplify.cc
+++ b/src/arith/stmt_simplify.cc
@@ -21,17 +21,17 @@
  * \file stmt_simplify.cc
  * \brief Statement simplifier based on analyzer
  */
-#include
-#include
+#include
+#include
 #include
-#include
+#include
 #include
 #include "ir_mutator_with_analyzer.h"
 
 namespace tvm {
 namespace arith {
 
-using namespace ir;
+using namespace tir;
 
 class StmtSimplifier : public IRMutatorWithAnalyzer {
  public:
@@ -59,7 +59,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
 
   Stmt VisitStmt_(const LetStmtNode* op) {
     PrimExpr value = this->VisitExpr(op->value);
-    if (!ir::HasSideEffect(value)) {
+    if (!tir::HasSideEffect(value)) {
       // it is fine to discard the let binding
       // because the call to simplify will always inline the var.
       analyzer_->Bind(op->var, value);
@@ -93,7 +93,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
 }  // namespace arith
 
-namespace ir {
+namespace tir {
 
 Stmt CanonicalSimplify(Stmt stmt, Map<Var, Range> vrange) {
   arith::Analyzer analyzer;
@@ -123,5 +123,5 @@ PrimExpr Simplify(PrimExpr expr, Map<Var, Range> vrange) {
 Stmt Simplify(Stmt stmt, Map<Var, Range> vrange) {
   return CanonicalSimplify(std::move(stmt), vrange);
 }
-}  // namespace ir
+}  // namespace tir
 }  // namespace tvm
diff --git a/src/autotvm/feature_visitor.cc b/src/autotvm/feature_visitor.cc
index a83d248dc0df..da044babdd43 100644
--- a/src/autotvm/feature_visitor.cc
+++ b/src/autotvm/feature_visitor.cc
@@ -60,7 +60,7 @@ void FeatureVisitor::VisitStmt_(const ForNode* op) {
 void FeatureVisitor::VisitStmt_(const AttrStmtNode* op) {
   if (op->attr_key == attr::thread_extent ||
       op->attr_key == attr::virtual_thread) {
-    Var var = op->node.as()->var;
+    Var var = op->node.as()->var;
     const auto *extent = op->value.as<IntImmNode>();
     CHECK(extent);
diff --git a/src/autotvm/feature_visitor.h b/src/autotvm/feature_visitor.h
index b2ea80f0c29f..5391bddfa2f6 100644
--- a/src/autotvm/feature_visitor.h
+++ b/src/autotvm/feature_visitor.h
@@ -26,14 +26,15 @@
 #ifndef TVM_AUTOTVM_FEATURE_VISITOR_H_
 #define TVM_AUTOTVM_FEATURE_VISITOR_H_
 
-#include
-#include
+#include
+#include
+#include
 #include
 
 namespace tvm {
 namespace autotvm {
 
-using namespace tvm::ir;
+using namespace tvm::tir;
 
 /*!
  * \brief Type of for loop, used as one-hot encoding in features
@@ -69,7 +70,7 @@ class FeatureVisitor : public StmtExprVisitor {
    * \param ann_type The type for the for loop
    * \return skip Whether skip this node
    */
-  virtual bool EnterItervar_(tvm::Var var, int64_t length, AnnotationType ann_type) = 0;
+  virtual bool EnterItervar_(tir::Var var, int64_t length, AnnotationType ann_type) = 0;
  /*! \brief Exit a for loop subtree */
  virtual void ExitItervar_() = 0;
  /*!
   * \param buffer_var The buffer to access.
   * \param index Index expression
   */
-  virtual void EnterMem_(tvm::Var buffer_var, tvm::PrimExpr index) = 0;
+  virtual void EnterMem_(tir::Var buffer_var, tvm::PrimExpr index) = 0;
 /*!
\brief Exit a memory access node */ virtual void ExitMem_() = 0; }; diff --git a/src/autotvm/touch_extractor.h b/src/autotvm/touch_extractor.h index 360d761ba233..23fbc54d843e 100644 --- a/src/autotvm/touch_extractor.h +++ b/src/autotvm/touch_extractor.h @@ -25,8 +25,8 @@ #ifndef TVM_AUTOTVM_TOUCH_EXTRACTOR_H_ #define TVM_AUTOTVM_TOUCH_EXTRACTOR_H_ -#include -#include +#include +#include #include #include diff --git a/src/codegen/build_common.h b/src/codegen/build_common.h index 47f70d9833b8..8e55bc1d1b27 100644 --- a/src/codegen/build_common.h +++ b/src/codegen/build_common.h @@ -26,7 +26,8 @@ #include #include -#include +#include +#include #include #include #include "../runtime/meta_data.h" diff --git a/src/codegen/build_module.cc b/src/codegen/build_module.cc index cfb75c4e68ef..e5f10c0f3e77 100644 --- a/src/codegen/build_module.cc +++ b/src/codegen/build_module.cc @@ -24,7 +24,7 @@ #include #include #include -#include +#include #include #include @@ -37,6 +37,7 @@ namespace tvm { using runtime::TVMArgs; using runtime::TVMRetValue; using runtime::PackedFunc; +using tir::LoweredFunc; TVM_REGISTER_NODE_TYPE(GenericFuncNode); @@ -58,39 +59,39 @@ Target DefaultTargetHost(Target target) { } } -Buffer BufferWithOffsetAlignment(Array shape, - DataType dtype, - std::string name, - int data_alignment, - int offset_factor, - bool compact) { - auto data = Var(name, DataType::Handle()); +tir::Buffer BufferWithOffsetAlignment(Array shape, + DataType dtype, + std::string name, + int data_alignment, + int offset_factor, + bool compact) { + auto data = tir::Var(name, DataType::Handle()); bool has_any = false; if (!compact) { for (const auto& it : shape) { - if (it.as()) { + if (it.as()) { has_any = true; break; } } } - BufferType buffer_type = has_any ? kAutoBroadcast : kDefault; + tir::BufferType buffer_type = has_any ? tir::kAutoBroadcast : tir::kDefault; PrimExpr elem_offset; if (offset_factor != 0) { - elem_offset = Var(name + "_elem_offset", shape[0].dtype()); + elem_offset = tir::Var(name + "_elem_offset", shape[0].dtype()); } else { elem_offset = PrimExpr(); } - return BufferNode::make(data, dtype, shape, Array(), elem_offset, name, "", + return tir::BufferNode::make(data, dtype, shape, Array(), elem_offset, name, "", data_alignment, offset_factor, buffer_type); } void GetBinds(const Array& args, bool compact, - const std::unordered_map& binds, - Map* out_binds, + const std::unordered_map& binds, + Map* out_binds, Array* out_arg_list, const BuildConfig& config) { *out_binds = binds; @@ -117,50 +118,50 @@ void GetBinds(const Array& args, * \param config The build configuration. * \return The built Stmt. 
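 * \note Summary of the phases in the body below: Phase 0 normalizes the
 *       schedule, infers bounds, emits the schedule ops and injects
 *       prefetch; Phase 1 flattens storage, canonically simplifies, and
 *       applies loop partition, vectorize (or skip-vectorize),
 *       virtual-thread injection, double buffering, storage rewrite and
 *       unrolling; Phase 2 runs a final simplify, removes no-ops and adds
 *       the optional select and bounds-check rewrites.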
*/ -Stmt BuildStmt(top::Schedule sch, - const Array& args, - const std::unordered_map& binds, - bool loop_partition, - Array *out_arg_list, - const BuildConfig& config) { +tir::Stmt BuildStmt(top::Schedule sch, + const Array& args, + const std::unordered_map& binds, + bool loop_partition, + Array *out_arg_list, + const BuildConfig& config) { sch = sch.normalize(); // Phase 0 auto bounds = top::InferBound(sch); auto stmt = top::ScheduleOps(sch, bounds, false); - stmt = ir::InjectPrefetch(stmt); + stmt = tir::InjectPrefetch(stmt); - bool compact = ir::VerifyCompactBuffer(stmt); - Map out_binds; + bool compact = tir::VerifyCompactBuffer(stmt); + Map out_binds; GetBinds(args, compact, binds, &out_binds, out_arg_list, config); // Phase 1 - stmt = ir::StorageFlatten(stmt, out_binds, 64, + stmt = tir::StorageFlatten(stmt, out_binds, 64, config->instrument_bound_checkers); - stmt = ir::CanonicalSimplify(stmt); + stmt = tir::CanonicalSimplify(stmt); if (loop_partition) { - stmt = ir::LoopPartition(stmt, config->partition_const_loop); + stmt = tir::LoopPartition(stmt, config->partition_const_loop); } if (config->disable_vectorize) { - stmt = ir::SkipVectorize(stmt); + stmt = tir::SkipVectorize(stmt); } else { - stmt = ir::VectorizeLoop(stmt); + stmt = tir::VectorizeLoop(stmt); } - stmt = ir::InjectVirtualThread(stmt); - stmt = ir::InjectDoubleBuffer(stmt, config->double_buffer_split_loop); - stmt = ir::StorageRewrite(stmt); - stmt = ir::UnrollLoop(stmt, config->auto_unroll_max_step, config->auto_unroll_max_depth, + stmt = tir::InjectVirtualThread(stmt); + stmt = tir::InjectDoubleBuffer(stmt, config->double_buffer_split_loop); + stmt = tir::StorageRewrite(stmt); + stmt = tir::UnrollLoop(stmt, config->auto_unroll_max_step, config->auto_unroll_max_depth, config->auto_unroll_max_extent, config->unroll_explicit); // Phase 2 - stmt = ir::Simplify(stmt); - stmt = ir::RemoveNoOp(stmt); + stmt = tir::Simplify(stmt); + stmt = tir::RemoveNoOp(stmt); if (!(config->disable_select_rewriting)) - stmt = ir::RewriteUnsafeSelect(stmt); + stmt = tir::RewriteUnsafeSelect(stmt); if (config->instrument_bound_checkers) - stmt = ir::InstrumentBoundCheckers(stmt); + stmt = tir::InstrumentBoundCheckers(stmt); return stmt; } @@ -168,11 +169,11 @@ Stmt BuildStmt(top::Schedule sch, Array lower(top::Schedule sch, const Array& args, const std::string& name, - const std::unordered_map& binds, + const std::unordered_map& binds, const BuildConfig& config) { Array out_arg_list; auto stmt = BuildStmt(sch, args, binds, true, &out_arg_list, config); - return Array({ ir::MakeAPI(stmt, name, out_arg_list, 0, config->restricted_func) }); + return Array({ tir::MakeAPI(stmt, name, out_arg_list, 0, config->restricted_func) }); } Array > split_dev_host_funcs(const Array& funcs, @@ -190,27 +191,27 @@ Array > split_dev_host_funcs(const Array& funcs, Array fdevice; for (const auto& x : funcs) { - CHECK(ir::VerifyMemory(x, target->device_type)) + CHECK(tir::VerifyMemory(x, target->device_type)) << "Direct host side access to device memory is detected in " << x->func_name() << ". 
Did you forget to bind?"; - if (x->func_type == kMixedFunc) { + if (x->func_type == tir::kMixedFunc) { auto func = x; if (config->detect_global_barrier) { - func = ir::ThreadSync(func, "global"); + func = tir::ThreadSync(func, "global"); } - func = ir::ThreadSync(func, "shared"); - func = ir::ThreadSync(func, "warp"); - func = ir::LowerThreadAllreduce(func, target->thread_warp_size); - auto fsplits = ir::SplitHostDevice(func); + func = tir::ThreadSync(func, "shared"); + func = tir::ThreadSync(func, "warp"); + func = tir::LowerThreadAllreduce(func, target->thread_warp_size); + auto fsplits = tir::SplitHostDevice(func); fhost.push_back(fsplits[0]); for (auto f = fsplits.begin() + 1; f != fsplits.end(); ++f) { fdevice.push_back(*f); } - } else if (x->func_type == kHostFunc) { + } else if (x->func_type == tir::kHostFunc) { fhost.push_back(x); - } else if (x->func_type == kDeviceFunc) { + } else if (x->func_type == tir::kDeviceFunc) { fdevice.push_back(x); } else { LOG(FATAL) << "unknown function type " << x->func_type; @@ -220,7 +221,7 @@ Array > split_dev_host_funcs(const Array& funcs, for (size_t i = 0; i < fdevice.size(); i++) { auto warp_size = target->thread_warp_size; auto func = fdevice[i]; - func = ir::LowerWarpMemory(fdevice[i], warp_size); + func = tir::LowerWarpMemory(fdevice[i], warp_size); fdevice.Set(i, func); } @@ -234,7 +235,7 @@ Array > split_dev_host_funcs(const Array& funcs, for (size_t i = 0; i < fdevice.size(); ++i) { auto func = fdevice[i]; - func = ir::LowerIntrin(func, target->target_name); + func = tir::LowerIntrin(func, target->target_name); fdevice.Set(i, func); } @@ -247,17 +248,17 @@ Array > split_dev_host_funcs(const Array& funcs, for (size_t i = 0; i < fhost.size(); ++i) { auto func = fhost[i]; - func = ir::BindDeviceType(func, target->device_type); - func = ir::LowerDeviceStorageAccessInfo(func); - func = ir::LowerTVMBuiltin(func); + func = tir::BindDeviceType(func, target->device_type); + func = tir::LowerDeviceStorageAccessInfo(func); + func = tir::LowerTVMBuiltin(func); fhost.Set(i, func); } for (size_t i = 0; i < fhost.size(); ++i) { auto func = fhost[i]; - func = ir::LowerIntrin(func, target_host->target_name); - func = ir::LowerDeviceStorageAccessInfo(func); - func = ir::CombineContextCall(func); + func = tir::LowerIntrin(func, target_host->target_name); + func = tir::LowerDeviceStorageAccessInfo(func); + func = tir::CombineContextCall(func); fhost.Set(i, func); } return {fhost, fdevice}; @@ -580,7 +581,7 @@ TVM_REGISTER_GLOBAL("_GenericFuncRegisterFunc") std::vector tags_vector; for (auto& tag : tags) { - tags_vector.push_back(tag.as()->value); + tags_vector.push_back(tag.as()->value); } generic_func diff --git a/src/codegen/codegen.cc b/src/codegen/codegen.cc index a038d4c56bb6..c14ea0ca3557 100644 --- a/src/codegen/codegen.cc +++ b/src/codegen/codegen.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -22,7 +22,7 @@ * \brief Common utilities to generated C style code. 
*/ #include -#include +#include #include #include #include @@ -37,17 +37,17 @@ namespace tvm { namespace codegen { -runtime::Module Build(const Array& funcs, +runtime::Module Build(const Array& funcs, const std::string& target) { std::string mode = target; size_t pos = mode.find(' '); if (pos != std::string::npos) { mode = mode.substr(0, pos); } - Array transformed_funcs; + Array transformed_funcs; if (BuildConfig::Current()->disable_assert) { for (const auto& x : funcs) { - auto func = ir::SkipAssert(x); + auto func = tir::SkipAssert(x); transformed_funcs.push_back(func); } } diff --git a/src/codegen/codegen_c.cc b/src/codegen/codegen_c.cc index 3d41d0828ace..4530a00208ab 100644 --- a/src/codegen/codegen_c.cc +++ b/src/codegen/codegen_c.cc @@ -23,13 +23,13 @@ #include #include #include "codegen_c.h" -#include "../pass/ir_util.h" #include "../arith/compute_expr.h" +#include "../tir/pass/ir_util.h" namespace tvm { namespace codegen { -using namespace ir; +using namespace tir; void CodeGenC::Init(bool output_ssa) { print_ssa_form_ = output_ssa; @@ -809,18 +809,18 @@ void CodeGenC::VisitStmt_(const AllocateNode* op) { } void CodeGenC::VisitStmt_(const AttrStmtNode* op) { - if (op->attr_key == ir::attr::thread_extent) { + if (op->attr_key == tir::attr::thread_extent) { IterVar iv = Downcast(op->node); if (iv->thread_tag.length() != 0) { if (!var_idmap_.count(iv->var.get())) { BindThreadIndex(iv); } } - } else if (op->attr_key == ir::attr::storage_scope) { + } else if (op->attr_key == tir::attr::storage_scope) { const VarNode* v = op->node.as(); CHECK(v); alloc_storage_scope_[v] = op->value.as()->value; - } else if (op->attr_key == ir::attr::volatile_scope) { + } else if (op->attr_key == tir::attr::volatile_scope) { const VarNode* v = op->node.as(); CHECK(v); volatile_buf_.insert(v); diff --git a/src/codegen/codegen_c.h b/src/codegen/codegen_c.h index 7e5dd4269c94..04d08d85105e 100644 --- a/src/codegen/codegen_c.h +++ b/src/codegen/codegen_c.h @@ -24,10 +24,11 @@ #ifndef TVM_CODEGEN_CODEGEN_C_H_ #define TVM_CODEGEN_CODEGEN_C_H_ -#include -#include +#include +#include +#include #include -#include +#include #include #include #include @@ -37,7 +38,7 @@ namespace tvm { namespace codegen { -using namespace ir; +using namespace tir; /*! * \brief A base class to generate C code. 
* diff --git a/src/codegen/codegen_c_host.h b/src/codegen/codegen_c_host.h index 94544a8c7367..f069038e09eb 100644 --- a/src/codegen/codegen_c_host.h +++ b/src/codegen/codegen_c_host.h @@ -25,7 +25,7 @@ #define TVM_CODEGEN_CODEGEN_C_HOST_H_ #include -#include +#include #include #include "codegen_c.h" diff --git a/src/codegen/codegen_cuda.cc b/src/codegen/codegen_cuda.cc index b6ba17fc381a..5a3130f31995 100644 --- a/src/codegen/codegen_cuda.cc +++ b/src/codegen/codegen_cuda.cc @@ -22,6 +22,7 @@ */ #include + #include #include #include @@ -93,9 +94,9 @@ std::string CodeGenCUDA::Finish() { return CodeGenC::Finish(); } -void CodeGenCUDA::VisitStmt_(const ir::ForNode* op) { +void CodeGenCUDA::VisitStmt_(const tir::ForNode* op) { CHECK(is_const_int(op->min, 0)); - if (op->for_type == ir::ForType::Unrolled) { + if (op->for_type == tir::ForType::Unrolled) { PrintIndent(); stream << "#pragma unroll\n"; } diff --git a/src/codegen/codegen_cuda.h b/src/codegen/codegen_cuda.h index b7b7f849974e..b0bb19412a5b 100644 --- a/src/codegen/codegen_cuda.h +++ b/src/codegen/codegen_cuda.h @@ -25,7 +25,7 @@ #define TVM_CODEGEN_CODEGEN_CUDA_H_ #include -#include +#include #include #include #include "codegen_c.h" @@ -43,7 +43,7 @@ class CodeGenCUDA final : public CodeGenC { return (enable_fp16_ || enable_int8_ || need_math_constants_h_ || need_mma_h_); } // override behavior - void VisitStmt_(const ir::ForNode* op) final; + void VisitStmt_(const tir::ForNode* op) final; void PrintStorageSync(const CallNode* op) final; void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) void PrintVecBinaryOp( diff --git a/src/codegen/codegen_source_base.cc b/src/codegen/codegen_source_base.cc index aa3b6ef68fd5..0859428aa58b 100644 --- a/src/codegen/codegen_source_base.cc +++ b/src/codegen/codegen_source_base.cc @@ -69,7 +69,7 @@ std::string CodeGenSourceBase::SSAGetID(std::string src, DataType t) { return e.vid; } -std::string CodeGenSourceBase::AllocVarID(const VarNode* v) { +std::string CodeGenSourceBase::AllocVarID(const tir::VarNode* v) { CHECK(!var_idmap_.count(v)) << "Need input to be in SSA form dup " << v->name_hint; std::string key = v->name_hint; @@ -78,7 +78,7 @@ std::string CodeGenSourceBase::AllocVarID(const VarNode* v) { return vid; } -std::string CodeGenSourceBase::GetVarID(const VarNode* v) const { +std::string CodeGenSourceBase::GetVarID(const tir::VarNode* v) const { auto it = var_idmap_.find(v); CHECK(it != var_idmap_.end()) << "Find undefined Variable " << v->name_hint; diff --git a/src/codegen/codegen_source_base.h b/src/codegen/codegen_source_base.h index b39ee46b0a17..24584f2c0844 100644 --- a/src/codegen/codegen_source_base.h +++ b/src/codegen/codegen_source_base.h @@ -24,7 +24,8 @@ #ifndef TVM_CODEGEN_CODEGEN_SOURCE_BASE_H_ #define TVM_CODEGEN_CODEGEN_SOURCE_BASE_H_ -#include +#include +#include #include #include #include @@ -66,13 +67,13 @@ class CodeGenSourceBase { * \param v The variable. * \return the variable name. */ - std::string AllocVarID(const VarNode* v); + std::string AllocVarID(const tir::VarNode* v); /*! * \brief Get a variable name. * \param v The variable. * \return the variable name. */ - std::string GetVarID(const VarNode* v) const; + std::string GetVarID(const tir::VarNode* v) const; /*! * \brief Get the SSA ID corresponds to src * If necessary, generate new assignment @@ -110,7 +111,7 @@ class CodeGenSourceBase { /*! \brief the stream to be printed */ std::ostringstream stream; /*! 
\brief name of each variable */ - std::unordered_map var_idmap_; + std::unordered_map var_idmap_; private: /*! \brief assignment map of ssa */ diff --git a/src/codegen/codegen_vhls.h b/src/codegen/codegen_vhls.h index 06510890980c..c08ceb665402 100644 --- a/src/codegen/codegen_vhls.h +++ b/src/codegen/codegen_vhls.h @@ -25,7 +25,7 @@ #define TVM_CODEGEN_CODEGEN_VHLS_H_ #include -#include +#include #include #include "codegen_c.h" diff --git a/src/codegen/intrin_rule.cc b/src/codegen/intrin_rule.cc index 699abd8db622..7e9ac71bb753 100644 --- a/src/codegen/intrin_rule.cc +++ b/src/codegen/intrin_rule.cc @@ -21,7 +21,7 @@ * \file intrin_rule_default.cc * \brief Default intrinsic rules. */ -#include +#include #include "intrin_rule.h" namespace tvm { diff --git a/src/codegen/intrin_rule.h b/src/codegen/intrin_rule.h index b6332f1bbff3..babe42e7a64b 100644 --- a/src/codegen/intrin_rule.h +++ b/src/codegen/intrin_rule.h @@ -24,15 +24,15 @@ #ifndef TVM_CODEGEN_INTRIN_RULE_H_ #define TVM_CODEGEN_INTRIN_RULE_H_ -#include -#include +#include +#include #include #include namespace tvm { namespace codegen { namespace intrin { -using namespace ir; +using namespace tir; // Add float suffix to the intrinsics struct FloatSuffix { diff --git a/src/codegen/llvm/codegen_amdgpu.cc b/src/codegen/llvm/codegen_amdgpu.cc index fb7abc394bb8..69354ef4429d 100644 --- a/src/codegen/llvm/codegen_amdgpu.cc +++ b/src/codegen/llvm/codegen_amdgpu.cc @@ -29,7 +29,6 @@ #include "codegen_llvm.h" #include "../build_common.h" #include "../codegen_source_base.h" -#include "../../pass/ir_util.h" #include "../../runtime/rocm/rocm_module.h" namespace tvm { diff --git a/src/codegen/llvm/codegen_arm.cc b/src/codegen/llvm/codegen_arm.cc index 44862cf7a97c..73d849a7b3d1 100644 --- a/src/codegen/llvm/codegen_arm.cc +++ b/src/codegen/llvm/codegen_arm.cc @@ -58,7 +58,7 @@ llvm::Value* CodeGenARM::CreateIntrinsic(const CallNode* op) { } PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) { - using namespace ir; + using namespace tir; const PrimExpr& e = call->args[2]; ::llvm::Intrinsic::ID ctpop_id = ::llvm::Intrinsic::ctpop; ::llvm::Intrinsic::ID vpaddlu_id = ::llvm::Intrinsic::arm_neon_vpaddlu; @@ -71,7 +71,7 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) { vcnt_args.push_back(IntImm(DataType::UInt(32), ctpop_id)); vcnt_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt_args.push_back(e); - return ir::CallNode::make(call->dtype, "llvm_intrin", vcnt_args, CallNode::PureIntrinsic); + return tir::CallNode::make(call->dtype, "llvm_intrin", vcnt_args, CallNode::PureIntrinsic); } // Popcount lowering rule: @@ -96,7 +96,7 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) { vcnt8_args.push_back(IntImm(DataType::UInt(32), ctpop_id)); vcnt8_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt8_args.push_back(input8); - PrimExpr vcnt8 = ir::CallNode::make( + PrimExpr vcnt8 = tir::CallNode::make( uint8_type, "llvm_intrin", vcnt8_args, CallNode::PureIntrinsic); // Accumulation 8->16bit @@ -104,7 +104,7 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) { vcnt16_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); vcnt16_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt16_args.push_back(vcnt8); - PrimExpr vcnt16 = ir::CallNode::make( + PrimExpr vcnt16 = tir::CallNode::make( uint16_type, "llvm_intrin", vcnt16_args, CallNode::PureIntrinsic); if (call->dtype.bits() == 16) { return vcnt16; @@ -115,7 +115,7 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) { vcnt32_args.push_back(IntImm(DataType::UInt(32), 
vpaddlu_id)); vcnt32_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt32_args.push_back(vcnt16); - PrimExpr vcnt32 = ir::CallNode::make( + PrimExpr vcnt32 = tir::CallNode::make( uint32_type, "llvm_intrin", vcnt32_args, CallNode::PureIntrinsic); if (call->dtype.bits() == 32) { return vcnt32; @@ -126,7 +126,7 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) { vcnt64_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); vcnt64_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt64_args.push_back(vcnt32); - return ir::CallNode::make( + return tir::CallNode::make( call->dtype, "llvm_intrin", vcnt64_args, CallNode::PureIntrinsic); } diff --git a/src/codegen/llvm/codegen_cpu.cc b/src/codegen/llvm/codegen_cpu.cc index 6c0b6846bedf..88ca6b6da499 100644 --- a/src/codegen/llvm/codegen_cpu.cc +++ b/src/codegen/llvm/codegen_cpu.cc @@ -23,11 +23,10 @@ #ifdef TVM_LLVM_VERSION #include -#include +#include #include #include #include "codegen_cpu.h" -#include "../../pass/ir_util.h" namespace tvm { namespace codegen { @@ -423,7 +422,7 @@ void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) { // - Set noalias on all the pointer arguments, some of them are loaded from TVMArgs. // This is easier than set the alias scope manually. using llvm::BasicBlock; - Array vargs = ir::UndefinedVars(op->body, {}); + Array vargs = tir::UndefinedVars(op->body, {}); std::vector arg_values; std::vector arg_types; for (Var v : vargs) { @@ -513,7 +512,7 @@ void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task) { llvm::Function::PrivateLinkage, "__tvm_parallel_lambda", module_.get()); // allocate and setup the closure, call the closure. - Array vfields = ir::UndefinedVars(body, {}); + Array vfields = tir::UndefinedVars(body, {}); uint64_t nbytes; llvm::Value* cdata = PackClosureData(vfields, &nbytes); BasicBlock* par_launch_end = CheckCallSuccess( @@ -582,7 +581,7 @@ void CodeGenCPU::CreateStaticInit(const std::string& init_fname, const Stmt& bod } // allocate and setup the closure, call the closure. 
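    // (For orientation: tir::UndefinedVars(body, {}) below collects the
    // free variables of the closure body, and PackClosureData lays them
    // out in a struct that the generated lambda unpacks on the callee
    // side before executing the body.)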
uint64_t nbytes; - Array vfields = ir::UndefinedVars(body, {}); + Array vfields = tir::UndefinedVars(body, {}); llvm::Value* cdata = PackClosureData(vfields, &nbytes); BasicBlock* init_end = CheckCallSuccess( builder_->CreateCall( @@ -692,7 +691,7 @@ CodeGenCPU::MakeCallPacked(const Array &args, llvm::Value **rvalue, BasicBlock *end_block = CheckCallSuccess(builder_->CreateCall( RuntimeTVMFuncCall(), {handle, arg_value, arg_tcode, ConstInt32(nargs), ret_value, *ret_tcode})); - DataType r_api_type = ir::APIType(r_type); + DataType r_api_type = tir::APIType(r_type); *rvalue = builder_->CreateAlignedLoad( builder_->CreatePointerCast(ret_value, LLVMType(r_api_type)->getPointerTo()), @@ -870,9 +869,9 @@ void CodeGenCPU::VisitStmt_(const AssertStmtNode* op) { } void CodeGenCPU::VisitStmt_(const AttrStmtNode* op) { - if (op->attr_key == ir::attr::coproc_uop_scope) { + if (op->attr_key == tir::attr::coproc_uop_scope) { this->CreateStaticInit(op->value.as()->value, op->body); - } else if (op->attr_key == ir::attr::compute_scope) { + } else if (op->attr_key == tir::attr::compute_scope) { this->CreateComputeScope(op); } else if (attr::IsPragmaKey(op->attr_key)) { if (op->attr_key == "pragma_parallel_stride_pattern") { @@ -892,7 +891,7 @@ void CodeGenCPU::VisitStmt_(const AttrStmtNode* op) { builder_->CreateCall( RuntimeTVMParallelBarrier(), {MakeValue(parallel_env_.task_id), parallel_env_.penv}); - } else if (op->attr_key == ir::attr::pragma_import_llvm) { + } else if (op->attr_key == tir::attr::pragma_import_llvm) { const StringImmNode* value = op->value.as(); CHECK(value != nullptr); this->HandleImport(value->value); diff --git a/src/codegen/llvm/codegen_llvm.cc b/src/codegen/llvm/codegen_llvm.cc index d214add06ac1..434b6501d37a 100644 --- a/src/codegen/llvm/codegen_llvm.cc +++ b/src/codegen/llvm/codegen_llvm.cc @@ -30,9 +30,6 @@ #include "codegen_llvm.h" #include "codegen_cpu.h" #include "../build_common.h" -#include "../../pass/ir_util.h" -#include "../../arith/compute_expr.h" - namespace tvm { namespace codegen { @@ -1179,17 +1176,17 @@ void CodeGenLLVM::VisitStmt_(const AttrStmtNode* op) { analyzer_->Bind(iv->var, Range::make_by_min_extent(0, op->value)); } } - } else if (op->attr_key == ir::attr::storage_scope) { + } else if (op->attr_key == tir::attr::storage_scope) { const VarNode* v = op->node.as(); CHECK(v); alloc_storage_info_[v].scope = runtime::StorageScope::make(op->value.as()->value); - } else if (op->attr_key == ir::attr::storage_alignment) { + } else if (op->attr_key == tir::attr::storage_alignment) { const VarNode* v = op->node.as(); CHECK(v); alloc_storage_info_[v].alignment = static_cast(op->value.as()->value); - } else if (op->attr_key == ir::attr::volatile_scope) { + } else if (op->attr_key == tir::attr::volatile_scope) { const VarNode* v = op->node.as(); CHECK(v); volatile_buf_.insert(v); diff --git a/src/codegen/llvm/codegen_llvm.h b/src/codegen/llvm/codegen_llvm.h index 34e1fb8509a1..6875a1d9cd16 100644 --- a/src/codegen/llvm/codegen_llvm.h +++ b/src/codegen/llvm/codegen_llvm.h @@ -26,8 +26,10 @@ #ifdef TVM_LLVM_VERSION #include -#include -#include +#include +#include +#include +#include #include #include #include @@ -37,11 +39,13 @@ #include #include "llvm_common.h" #include "../../runtime/thread_storage_scope.h" +#include "../../arith/compute_expr.h" +#include "../../tir/pass/ir_util.h" namespace tvm { namespace codegen { -using namespace ir; +using namespace tir; /*! * \brief A base class to generate a LLVM. 
diff --git a/src/codegen/llvm/codegen_nvptx.cc b/src/codegen/llvm/codegen_nvptx.cc index 877bbba6ba61..555adc9d26ed 100644 --- a/src/codegen/llvm/codegen_nvptx.cc +++ b/src/codegen/llvm/codegen_nvptx.cc @@ -26,7 +26,6 @@ #include #include "codegen_llvm.h" #include "../build_common.h" -#include "../../pass/ir_util.h" #include "../../runtime/cuda/cuda_module.h" namespace tvm { diff --git a/src/codegen/llvm/codegen_x86_64.cc b/src/codegen/llvm/codegen_x86_64.cc index 2e419316e957..05467633ce5c 100644 --- a/src/codegen/llvm/codegen_x86_64.cc +++ b/src/codegen/llvm/codegen_x86_64.cc @@ -90,11 +90,11 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) { ::llvm::Intrinsic::x86_avx512_mask_vcvtph2ps_512, 16, LLVMType(DataType::Float(32, from.lanes())), { - MakeValue(ir::CallNode::make( - DataType::Int(16, from.lanes()), ir::CallNode::reinterpret, {op->value}, - ir::CallNode::PureIntrinsic)), + MakeValue(tir::CallNode::make( + DataType::Int(16, from.lanes()), tir::CallNode::reinterpret, {op->value}, + tir::CallNode::PureIntrinsic)), MakeValue( - ir::BroadcastNode::make( + tir::BroadcastNode::make( FloatImm(DataType::Float(32), 0), from.lanes())), /*mask=*/MakeValue(IntImm(DataType::Int(16), -1)), /*rounding-mode=*/MakeValue(IntImm(DataType::Int(32), 4)), @@ -104,9 +104,9 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) { if (from.lanes() >= 8 && has_f16c) { return CallVectorIntrin( ::llvm::Intrinsic::x86_vcvtph2ps_256, 8, LLVMType(DataType::Float(32, from.lanes())), - {MakeValue(ir::CallNode::make( - DataType::Int(16, from.lanes()), ir::CallNode::reinterpret, {op->value}, - ir::CallNode::PureIntrinsic))}); + {MakeValue(tir::CallNode::make( + DataType::Int(16, from.lanes()), tir::CallNode::reinterpret, {op->value}, + tir::CallNode::PureIntrinsic))}); } } diff --git a/src/codegen/llvm/intrin_rule_llvm.cc b/src/codegen/llvm/intrin_rule_llvm.cc index b05185bafd9c..758b0afc22ff 100644 --- a/src/codegen/llvm/intrin_rule_llvm.cc +++ b/src/codegen/llvm/intrin_rule_llvm.cc @@ -22,6 +22,7 @@ */ #ifdef TVM_LLVM_VERSION +#include #include "intrin_rule_llvm.h" namespace tvm { @@ -63,22 +64,24 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.nearbyint") TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tanh") .set_body([](const TVMArgs& targs, TVMRetValue* rv) { + using tir::make_const; + using tir::make_zero; PrimExpr e = targs[0]; - const ir::CallNode* call = e.as(); + const tir::CallNode* call = e.as(); CHECK(call != nullptr); const PrimExpr& x = call->args[0]; PrimExpr one = make_const(x.dtype(), 1); PrimExpr two = make_const(x.dtype(), 2); PrimExpr neg_two = make_const(x.dtype(), -2); - PrimExpr exp_neg2x = ir::CallNode::make( - x.dtype(), "exp", {neg_two * x}, ir::CallNode::PureIntrinsic); - PrimExpr exp_pos2x = ir::CallNode::make( - x.dtype(), "exp", {two * x}, ir::CallNode::PureIntrinsic); + PrimExpr exp_neg2x = tir::CallNode::make( + x.dtype(), "exp", {neg_two * x}, tir::CallNode::PureIntrinsic); + PrimExpr exp_pos2x = tir::CallNode::make( + x.dtype(), "exp", {two * x}, tir::CallNode::PureIntrinsic); PrimExpr tanh_pos = (one - exp_neg2x) / (one + exp_neg2x); PrimExpr tanh_neg = (exp_pos2x - one) / (exp_pos2x + one); - *rv = ir::SelectNode::make( + *rv = tir::SelectNode::make( x >= make_zero(x.dtype()), tanh_pos, tanh_neg); }); diff --git a/src/codegen/llvm/intrin_rule_llvm.h b/src/codegen/llvm/intrin_rule_llvm.h index d81c33b70185..0001d4a35537 100644 --- a/src/codegen/llvm/intrin_rule_llvm.h +++ b/src/codegen/llvm/intrin_rule_llvm.h @@ -25,7 +25,7 @@ #define 
TVM_CODEGEN_LLVM_INTRIN_RULE_LLVM_H_ #ifdef TVM_LLVM_VERSION -#include +#include #include #include @@ -38,7 +38,7 @@ namespace codegen { template inline void DispatchLLVMPureIntrin(const TVMArgs& targs, TVMRetValue* rv) { PrimExpr e = targs[0]; - const ir::CallNode* call = e.as(); + const tir::CallNode* call = e.as(); CHECK(call != nullptr); Array cargs; // intrin id. @@ -48,14 +48,14 @@ inline void DispatchLLVMPureIntrin(const TVMArgs& targs, TVMRetValue* rv) { for (PrimExpr arg : call->args) { cargs.push_back(arg); } - *rv = ir::CallNode::make( - call->dtype, "llvm_intrin", cargs, ir::CallNode::PureIntrinsic); + *rv = tir::CallNode::make( + call->dtype, "llvm_intrin", cargs, tir::CallNode::PureIntrinsic); } template inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) { PrimExpr e = targs[0]; - const ir::CallNode* call = e.as(); + const tir::CallNode* call = e.as(); CHECK(call != nullptr); Array cargs; // intrin id. @@ -64,8 +64,8 @@ inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) { for (PrimExpr arg : call->args) { cargs.push_back(arg); } - *rv = ir::CallNode::make( - call->dtype, "llvm_intrin", cargs, ir::CallNode::Intrinsic); + *rv = tir::CallNode::make( + call->dtype, "llvm_intrin", cargs, tir::CallNode::Intrinsic); } } // namespace codegen diff --git a/src/codegen/llvm/intrin_rule_nvptx.cc b/src/codegen/llvm/intrin_rule_nvptx.cc index 68475f01c39b..6f7a89cebd97 100644 --- a/src/codegen/llvm/intrin_rule_nvptx.cc +++ b/src/codegen/llvm/intrin_rule_nvptx.cc @@ -22,8 +22,8 @@ */ #ifdef TVM_LLVM_VERSION -#include -#include +#include +#include #include #include @@ -32,7 +32,7 @@ namespace codegen { inline void DispatchExternLibDevice(const TVMArgs& args, TVMRetValue* rv) { PrimExpr e = args[0]; - using namespace ir; + using namespace tir; const CallNode* call = e.as(); CHECK(call != nullptr); CHECK(call->dtype.bits() == 32 || call->dtype.bits() == 64) << "Only support float32 or float64."; diff --git a/src/codegen/llvm/intrin_rule_rocm.cc b/src/codegen/llvm/intrin_rule_rocm.cc index 477bcb4f97cc..31b7bf19419b 100644 --- a/src/codegen/llvm/intrin_rule_rocm.cc +++ b/src/codegen/llvm/intrin_rule_rocm.cc @@ -22,8 +22,8 @@ */ #ifdef TVM_LLVM_VERSION -#include -#include +#include +#include #include #include @@ -33,7 +33,7 @@ namespace codegen { inline void DispatchExternOCML(const TVMArgs& args, TVMRetValue* rv) { PrimExpr e = args[0]; - using namespace ir; + using namespace tir; const CallNode* call = e.as(); CHECK(call != nullptr); std::ostringstream intrinsic_name; diff --git a/src/codegen/spirv/build_vulkan.cc b/src/codegen/spirv/build_vulkan.cc index 6c90e1dfa3bf..c90b4c7eeb48 100644 --- a/src/codegen/spirv/build_vulkan.cc +++ b/src/codegen/spirv/build_vulkan.cc @@ -24,7 +24,7 @@ // Use libspirv for parsing and validating code. 
#include #include -#include +#include #include "codegen_spirv.h" #include "../build_common.h" diff --git a/src/codegen/spirv/codegen_spirv.cc b/src/codegen/spirv/codegen_spirv.cc index 91eee8c345ed..4021b17d7243 100644 --- a/src/codegen/spirv/codegen_spirv.cc +++ b/src/codegen/spirv/codegen_spirv.cc @@ -21,8 +21,8 @@ * \file codegen_spirv.cc * \brief Generate SPIRV block */ -#include -#include +#include +#include #include #include "codegen_spirv.h" #include "../../arith/compute_expr.h" @@ -406,7 +406,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const LoadNode* op) { CHECK((me->coeff % ramp->lanes) == 0 && (me->base % ramp->lanes) == 0) << "Only aligned vector access is allowed in SPIRV"; - PrimExpr vec_index = ir::Simplify( + PrimExpr vec_index = tir::Simplify( ramp->base / make_const(ramp->base.dtype(), ramp->lanes)); spirv::Value ptr = builder_->StructArrayAccess( ptr_type, buffer, MakeValue(vec_index)); @@ -484,7 +484,7 @@ void CodeGenSPIRV::VisitStmt_(const StoreNode* op) { CHECK((me->coeff % ramp->lanes) == 0 && (me->base % ramp->lanes) == 0) << "Only aligned vector access is allowed in SPIRV"; - PrimExpr vec_index = ir::Simplify( + PrimExpr vec_index = tir::Simplify( ramp->base / make_const(ramp->base.dtype(), ramp->lanes)); spirv::Value ptr = builder_->StructArrayAccess( ptr_type, buffer, MakeValue(vec_index)); @@ -615,12 +615,12 @@ void CodeGenSPIRV::VisitStmt_(const AttrStmtNode* op) { analyzer_->Bind(iv->var, Range::make_by_min_extent(0, op->value)); } } - } else if (op->attr_key == ir::attr::storage_scope) { + } else if (op->attr_key == tir::attr::storage_scope) { const VarNode* v = op->node.as(); CHECK(v); storage_info_[v].scope = runtime::StorageScope::make(op->value.as()->value); - } else if (op->attr_key == ir::attr::volatile_scope) { + } else if (op->attr_key == tir::attr::volatile_scope) { const VarNode* v = op->node.as(); CHECK(v); storage_info_[v].is_volatile = true; diff --git a/src/codegen/spirv/codegen_spirv.h b/src/codegen/spirv/codegen_spirv.h index 877bc712085b..2e6b519c070a 100644 --- a/src/codegen/spirv/codegen_spirv.h +++ b/src/codegen/spirv/codegen_spirv.h @@ -25,9 +25,9 @@ #define TVM_CODEGEN_SPIRV_CODEGEN_SPIRV_H_ #include -#include -#include -#include +#include +#include +#include #include #include @@ -39,7 +39,7 @@ namespace tvm { namespace codegen { -using namespace ir; +using namespace tir; /*! * \brief Code generator into SPIRV diff --git a/src/codegen/spirv/intrin_rule_spirv.cc b/src/codegen/spirv/intrin_rule_spirv.cc index aa69ebf0ebc7..ead6952b434e 100644 --- a/src/codegen/spirv/intrin_rule_spirv.cc +++ b/src/codegen/spirv/intrin_rule_spirv.cc @@ -21,7 +21,7 @@ * \file intrin_rule_spirv.cc */ #include -#include +#include #include namespace tvm { @@ -34,7 +34,7 @@ using namespace runtime; template inline void DispatchGLSLPureIntrin(const TVMArgs& targs, TVMRetValue* rv) { PrimExpr e = targs[0]; - const ir::CallNode* call = e.as(); + const tir::CallNode* call = e.as(); CHECK(call != nullptr); Array cargs; // intrin id. 
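
These GLSL dispatch helpers are installed the same way as the LLVM ones earlier in the patch: each intrinsic is exposed as a packed function registered under a per-target key. A representative registration from this file, completing the truncated form shown below (the GLSLstd450Floor enumerant comes from the included SPIR-V headers):

TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.floor")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Floor>);
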
@@ -43,8 +43,8 @@ inline void DispatchGLSLPureIntrin(const TVMArgs& targs, TVMRetValue* rv) { for (PrimExpr arg : call->args) { cargs.push_back(arg); } - *rv = ir::CallNode::make( - call->dtype, "spirv_glsl450", cargs, ir::CallNode::PureIntrinsic); + *rv = tir::CallNode::make( + call->dtype, "spirv_glsl450", cargs, tir::CallNode::PureIntrinsic); } TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.floor") diff --git a/src/codegen/spirv/ir_builder.h b/src/codegen/spirv/ir_builder.h index 5d25e8634e84..55b15a78bfc5 100644 --- a/src/codegen/spirv/ir_builder.h +++ b/src/codegen/spirv/ir_builder.h @@ -25,7 +25,7 @@ #define TVM_CODEGEN_SPIRV_IR_BUILDER_H_ #include -#include +#include #include #include diff --git a/src/codegen/stackvm/codegen_stackvm.cc b/src/codegen/stackvm/codegen_stackvm.cc index c12f66f37e5e..bce878a58abe 100644 --- a/src/codegen/stackvm/codegen_stackvm.cc +++ b/src/codegen/stackvm/codegen_stackvm.cc @@ -21,6 +21,7 @@ * \file codegen_stackvm.cc */ #include +#include #include #include #include "codegen_stackvm.h" @@ -29,7 +30,7 @@ namespace tvm { namespace codegen { -using namespace ir; +using namespace tir; // map struct field kind to runtime variants // We keep two separate enums to ensure runtime/compiler isolation. diff --git a/src/codegen/stackvm/codegen_stackvm.h b/src/codegen/stackvm/codegen_stackvm.h index 1360cc2d70f1..ea7cc4e65c08 100644 --- a/src/codegen/stackvm/codegen_stackvm.h +++ b/src/codegen/stackvm/codegen_stackvm.h @@ -24,9 +24,9 @@ #ifndef TVM_CODEGEN_STACKVM_CODEGEN_STACKVM_H_ #define TVM_CODEGEN_STACKVM_CODEGEN_STACKVM_H_ -#include -#include -#include +#include +#include +#include #include #include #include @@ -37,7 +37,7 @@ namespace tvm { namespace codegen { -using namespace ir; +using namespace tir; using runtime::StackVM; /*! diff --git a/src/contrib/hybrid/codegen_hybrid.cc b/src/contrib/hybrid/codegen_hybrid.cc index 346ec3808919..a17ae8786877 100644 --- a/src/contrib/hybrid/codegen_hybrid.cc +++ b/src/contrib/hybrid/codegen_hybrid.cc @@ -31,7 +31,7 @@ namespace contrib { using runtime::TVMArgs; using runtime::TVMRetValue; -using namespace ir; +using namespace tir; std::string dot_to_underscore(std::string s) { for (auto &ch : s) @@ -288,7 +288,7 @@ void CodeGenHybrid::VisitStmt_(const LetStmtNode* op) { } void CodeGenHybrid::VisitStmt_(const AttrStmtNode* op) { - if (op->attr_key == ir::attr::thread_extent) { + if (op->attr_key == tir::attr::thread_extent) { auto iter_var = op->node.as(); CHECK(iter_var); binds_[iter_var->var.get()] = dot_to_underscore(iter_var->var->name_hint); @@ -300,7 +300,7 @@ void CodeGenHybrid::VisitStmt_(const AttrStmtNode* op) { indent_ += tab_; PrintStmt(op->body); indent_ -= tab_; - } else if (op->attr_key == ir::attr::realize_scope) { + } else if (op->attr_key == tir::attr::realize_scope) { auto v = Downcast(op->node); alloc_storage_scope_[v] = op->value.as()->value; PrintStmt(op->body); diff --git a/src/contrib/hybrid/codegen_hybrid.h b/src/contrib/hybrid/codegen_hybrid.h index f5ba9abf1244..19c2cbfacc8a 100644 --- a/src/contrib/hybrid/codegen_hybrid.h +++ b/src/contrib/hybrid/codegen_hybrid.h @@ -24,10 +24,10 @@ #ifndef TVM_CONTRIB_HYBRID_CODEGEN_HYBRID_H_ #define TVM_CONTRIB_HYBRID_CODEGEN_HYBRID_H_ -#include -#include +#include +#include #include -#include +#include #include #include #include @@ -39,7 +39,7 @@ namespace tvm { namespace contrib { using namespace top; -using namespace ir; +using namespace tir; /*! * \brief A base class to generate Hybrid Script. 
* diff --git a/src/ir/attr_functor.h b/src/ir/attr_functor.h index 378e8a8fd1f0..49431407a6ee 100644 --- a/src/ir/attr_functor.h +++ b/src/ir/attr_functor.h @@ -31,7 +31,7 @@ #define TVM_IR_ATTR_FUNCTOR_H_ #include -#include +#include #include namespace tvm { @@ -77,40 +77,40 @@ class AttrFunctor { virtual R VisitAttrDefault_(const Object* node, Args... args) = 0; virtual R VisitAttr_(const ArrayNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const StrMapNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const ir::IntImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const ir::FloatImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const ir::StringImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const tir::IntImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const tir::FloatImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const tir::StringImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; // deep comparison of symbolic integer expressions. - virtual R VisitAttr_(const VarNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const SizeVarNode* op, Args... args) { - return VisitAttr_(static_cast(op), std::forward(args)...); + virtual R VisitAttr_(const tir::VarNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const tir::SizeVarNode* op, Args... args) { + return VisitAttr_(static_cast(op), std::forward(args)...); } - virtual R VisitAttr_(const ir::AddNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const ir::SubNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const ir::MulNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const ir::DivNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const ir::ModNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const ir::FloorDivNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const ir::FloorModNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const ir::MinNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const ir::MaxNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const ir::GENode* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const ir::GTNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const ir::LTNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const ir::LENode* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const ir::EQNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const ir::NENode* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const ir::AndNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const ir::OrNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const ir::NotNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const ir::CastNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const ir::CallNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const ir::SelectNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const tir::AddNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const tir::SubNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const tir::MulNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const tir::DivNode* op, Args... 
args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const tir::ModNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const tir::FloorDivNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const tir::FloorModNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const tir::MinNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const tir::MaxNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const tir::GENode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const tir::GTNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const tir::LTNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const tir::LENode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const tir::EQNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const tir::NENode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const tir::AndNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const tir::OrNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const tir::NotNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const tir::CastNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const tir::CallNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const tir::SelectNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; private: // initialize the vtable. static FType InitVTable() { - using namespace ir; + using namespace tir; FType vtable; // Set dispatch ATTR_FUNCTOR_DISPATCH(StrMapNode); @@ -159,30 +159,30 @@ class AttrsEqualHandler : bool VisitAttrDefault_(const Object* lhs, const ObjectRef& other) final; bool VisitAttr_(const ArrayNode* lhs, const ObjectRef& other) final; bool VisitAttr_(const StrMapNode* lhs, const ObjectRef& other) final; - bool VisitAttr_(const ir::IntImmNode* lhs, const ObjectRef& other) final; - bool VisitAttr_(const ir::FloatImmNode* lhs, const ObjectRef& other) final; - bool VisitAttr_(const ir::StringImmNode* lhs, const ObjectRef& other) final; - bool VisitAttr_(const ir::AddNode* lhs, const ObjectRef& other) final; - bool VisitAttr_(const ir::SubNode* lhs, const ObjectRef& other) final; - bool VisitAttr_(const ir::MulNode* lhs, const ObjectRef& other) final; - bool VisitAttr_(const ir::DivNode* lhs, const ObjectRef& other) final; - bool VisitAttr_(const ir::ModNode* lhs, const ObjectRef& other) final; - bool VisitAttr_(const ir::FloorDivNode* lhs, const ObjectRef& other) final; - bool VisitAttr_(const ir::FloorModNode* lhs, const ObjectRef& other) final; - bool VisitAttr_(const ir::MinNode* lhs, const ObjectRef& other) final; - bool VisitAttr_(const ir::MaxNode* lhs, const ObjectRef& other) final; - bool VisitAttr_(const ir::GENode* lhs, const ObjectRef& other) final; - bool VisitAttr_(const ir::GTNode* lhs, const ObjectRef& other) final; - bool VisitAttr_(const ir::LTNode* lhs, const ObjectRef& other) final; - bool VisitAttr_(const ir::LENode* lhs, const ObjectRef& other) final; - bool VisitAttr_(const ir::EQNode* lhs, const ObjectRef& other) final; - bool VisitAttr_(const ir::NENode* lhs, const ObjectRef& other) final; - bool VisitAttr_(const ir::AndNode* lhs, const ObjectRef& other) final; - bool VisitAttr_(const ir::OrNode* lhs, const ObjectRef& other) final; - bool VisitAttr_(const ir::NotNode* lhs, const ObjectRef& other) final; - bool VisitAttr_(const ir::CastNode* lhs, const ObjectRef& other) final; - bool VisitAttr_(const ir::CallNode* lhs, const 
ObjectRef& other) final; - bool VisitAttr_(const ir::SelectNode* lhs, const ObjectRef& other) final; + bool VisitAttr_(const tir::IntImmNode* lhs, const ObjectRef& other) final; + bool VisitAttr_(const tir::FloatImmNode* lhs, const ObjectRef& other) final; + bool VisitAttr_(const tir::StringImmNode* lhs, const ObjectRef& other) final; + bool VisitAttr_(const tir::AddNode* lhs, const ObjectRef& other) final; + bool VisitAttr_(const tir::SubNode* lhs, const ObjectRef& other) final; + bool VisitAttr_(const tir::MulNode* lhs, const ObjectRef& other) final; + bool VisitAttr_(const tir::DivNode* lhs, const ObjectRef& other) final; + bool VisitAttr_(const tir::ModNode* lhs, const ObjectRef& other) final; + bool VisitAttr_(const tir::FloorDivNode* lhs, const ObjectRef& other) final; + bool VisitAttr_(const tir::FloorModNode* lhs, const ObjectRef& other) final; + bool VisitAttr_(const tir::MinNode* lhs, const ObjectRef& other) final; + bool VisitAttr_(const tir::MaxNode* lhs, const ObjectRef& other) final; + bool VisitAttr_(const tir::GENode* lhs, const ObjectRef& other) final; + bool VisitAttr_(const tir::GTNode* lhs, const ObjectRef& other) final; + bool VisitAttr_(const tir::LTNode* lhs, const ObjectRef& other) final; + bool VisitAttr_(const tir::LENode* lhs, const ObjectRef& other) final; + bool VisitAttr_(const tir::EQNode* lhs, const ObjectRef& other) final; + bool VisitAttr_(const tir::NENode* lhs, const ObjectRef& other) final; + bool VisitAttr_(const tir::AndNode* lhs, const ObjectRef& other) final; + bool VisitAttr_(const tir::OrNode* lhs, const ObjectRef& other) final; + bool VisitAttr_(const tir::NotNode* lhs, const ObjectRef& other) final; + bool VisitAttr_(const tir::CastNode* lhs, const ObjectRef& other) final; + bool VisitAttr_(const tir::CallNode* lhs, const ObjectRef& other) final; + bool VisitAttr_(const tir::SelectNode* lhs, const ObjectRef& other) final; }; class AttrsHashHandler : @@ -199,32 +199,32 @@ class AttrsHashHandler : protected: size_t VisitAttrDefault_(const Object* lhs) final; - size_t VisitAttr_(const ir::IntImmNode* lhs) final; - size_t VisitAttr_(const ir::FloatImmNode* lhs) final; - size_t VisitAttr_(const ir::StringImmNode* lhs) final; + size_t VisitAttr_(const tir::IntImmNode* lhs) final; + size_t VisitAttr_(const tir::FloatImmNode* lhs) final; + size_t VisitAttr_(const tir::StringImmNode* lhs) final; size_t VisitAttr_(const ArrayNode* lhs) final; size_t VisitAttr_(const StrMapNode* lhs) final; - size_t VisitAttr_(const ir::AddNode* op) final; - size_t VisitAttr_(const ir::SubNode* op) final; - size_t VisitAttr_(const ir::MulNode* op) final; - size_t VisitAttr_(const ir::DivNode* op) final; - size_t VisitAttr_(const ir::ModNode* op) final; - size_t VisitAttr_(const ir::FloorDivNode* op) final; - size_t VisitAttr_(const ir::FloorModNode* op) final; - size_t VisitAttr_(const ir::MinNode* op) final; - size_t VisitAttr_(const ir::MaxNode* op) final; - size_t VisitAttr_(const ir::GENode* op) final; - size_t VisitAttr_(const ir::GTNode* op) final; - size_t VisitAttr_(const ir::LENode* op) final; - size_t VisitAttr_(const ir::LTNode* op) final; - size_t VisitAttr_(const ir::EQNode* op) final; - size_t VisitAttr_(const ir::NENode* op) final; - size_t VisitAttr_(const ir::AndNode* op) final; - size_t VisitAttr_(const ir::OrNode* op) final; - size_t VisitAttr_(const ir::NotNode* op) final; - size_t VisitAttr_(const ir::CastNode* op) final; - size_t VisitAttr_(const ir::CallNode* op) final; - size_t VisitAttr_(const ir::SelectNode* op) final; + size_t VisitAttr_(const 
tir::AddNode* op) final; + size_t VisitAttr_(const tir::SubNode* op) final; + size_t VisitAttr_(const tir::MulNode* op) final; + size_t VisitAttr_(const tir::DivNode* op) final; + size_t VisitAttr_(const tir::ModNode* op) final; + size_t VisitAttr_(const tir::FloorDivNode* op) final; + size_t VisitAttr_(const tir::FloorModNode* op) final; + size_t VisitAttr_(const tir::MinNode* op) final; + size_t VisitAttr_(const tir::MaxNode* op) final; + size_t VisitAttr_(const tir::GENode* op) final; + size_t VisitAttr_(const tir::GTNode* op) final; + size_t VisitAttr_(const tir::LENode* op) final; + size_t VisitAttr_(const tir::LTNode* op) final; + size_t VisitAttr_(const tir::EQNode* op) final; + size_t VisitAttr_(const tir::NENode* op) final; + size_t VisitAttr_(const tir::AndNode* op) final; + size_t VisitAttr_(const tir::OrNode* op) final; + size_t VisitAttr_(const tir::NotNode* op) final; + size_t VisitAttr_(const tir::CastNode* op) final; + size_t VisitAttr_(const tir::CallNode* op) final; + size_t VisitAttr_(const tir::SelectNode* op) final; /*! * \brief alias of dmlc::HashCombine * \param lhs The first hash value. diff --git a/src/ir/attrs.cc b/src/ir/attrs.cc index 54f5ee2ac0c4..8c6e191ce287 100644 --- a/src/ir/attrs.cc +++ b/src/ir/attrs.cc @@ -70,7 +70,7 @@ TVM_REGISTER_NODE_TYPE(DictAttrsNode); TVM_REGISTER_NODE_TYPE(AttrFieldInfoNode); -using namespace ir; +using namespace tir; // Equal handler. bool AttrsEqualHandler::Equal(const ObjectRef& lhs, const ObjectRef& rhs) { if (lhs.same_as(rhs)) return true; diff --git a/src/ir/env_func.cc b/src/ir/env_func.cc index 8118d036bcd6..b041d73949fd 100644 --- a/src/ir/env_func.cc +++ b/src/ir/env_func.cc @@ -22,7 +22,7 @@ */ #include #include -#include +#include namespace tvm { diff --git a/src/ir/expr.cc b/src/ir/expr.cc index b173f4f396ba..f81eb33ba0df 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -29,14 +29,23 @@ // // Rationale: convert from IterVar and top::Tensor #include -#include +#include namespace tvm { +PrimExpr::PrimExpr(int32_t value) + : PrimExpr(IntImm(DataType::Int(32), value)) {} + +PrimExpr::PrimExpr(float value) + : PrimExpr(FloatImm(DataType::Float(32), value)) {} + +PrimExpr::PrimExpr(std::string str) + : PrimExpr(tir::StringImmNode::make(str)) {} + PrimExpr PrimExpr::FromObject_(ObjectPtr ptr) { using runtime::ObjectTypeChecker; - if (ptr->IsInstance()) { - return IterVar(ptr)->var; + if (ptr->IsInstance()) { + return tir::IterVar(ptr)->var; } if (ptr->IsInstance()) { return top::Tensor(ptr)(); @@ -47,6 +56,7 @@ PrimExpr PrimExpr::FromObject_(ObjectPtr ptr) { return PrimExpr(ptr); } + IntImm::IntImm(DataType dtype, int64_t value) { CHECK(dtype.is_scalar()) << "ValueError: IntImm can only take scalar."; @@ -66,6 +76,17 @@ TVM_REGISTER_GLOBAL("make.IntImm") return IntImm(dtype, value); }); +TVM_REGISTER_NODE_TYPE(IntImmNode); + +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); + if (op->dtype == DataType::Int(32)) { + p->stream << op->value; + } else { + p->stream << "(" << op->dtype << ")" << op->value; + } + }); FloatImm::FloatImm(DataType dtype, double value) { CHECK_EQ(dtype.lanes(), 1) @@ -81,6 +102,49 @@ TVM_REGISTER_GLOBAL("make.FloatImm") return FloatImm(dtype, value); }); +TVM_REGISTER_NODE_TYPE(FloatImmNode); + +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); + auto& stream = p->stream; + switch (op->dtype.bits()) { + case 64: + 
stream << op->value; + break; + case 32: + stream << op->value << 'f'; + break; + case 16: + stream << op->value << 'h'; + break; + default: + LOG(FATAL) << "Unknown float type bits=" << op->dtype.bits(); + } + }); + + +Range::Range(PrimExpr begin, PrimExpr end) + : Range(make_object( + begin, + tir::is_zero(begin) ? end : (end - begin))) { +} + +Range Range::make_by_min_extent(PrimExpr min, PrimExpr extent) { + return Range(make_object(min, extent)); +} + +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')'; + }); + +TVM_REGISTER_NODE_TYPE(ArrayNode); +TVM_REGISTER_NODE_TYPE(MapNode); +TVM_REGISTER_NODE_TYPE(StrMapNode); +TVM_REGISTER_NODE_TYPE(RangeNode); + GlobalVar::GlobalVar(std::string name_hint) { ObjectPtr n = make_object(); @@ -101,4 +165,46 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->stream << "GlobalVar(" << node->name_hint << ")"; }); +// Container printer +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '['; + for (size_t i = 0 ; i < op->data.size(); ++i) { + if (i != 0) { + p->stream << ", "; + } + p->Print(op->data[i]); + } + p->stream << ']'; +}); + +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '{'; + for (auto it = op->data.begin(); it != op->data.end(); ++it) { + if (it != op->data.begin()) { + p->stream << ", "; + } + p->Print(it->first); + p->stream << ": "; + p->Print(it->second); + } + p->stream << '}'; + }); + +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '{'; + for (auto it = op->data.begin(); it != op->data.end(); ++it) { + if (it != op->data.begin()) { + p->stream << ", "; + } + p->stream << '\"' << it->first << "\": "; + p->Print(it->second); + } + p->stream << '}'; + }); } // namespace tvm diff --git a/src/ir/transform.cc b/src/ir/transform.cc index bf180cfadb90..5d0f5d845833 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -28,7 +28,7 @@ #include // TODO(tqchen): Update to use String container after it is merged. 
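A note on the Range constructors consolidated into src/ir/expr.cc above: RangeNode stores (min, extent) rather than (begin, end), and the two-argument constructor normalizes by subtracting (the tir::is_zero(begin) branch merely avoids materializing a redundant "end - 0" expression). A minimal standalone C++ sketch of the same convention, with hypothetical names rather than the TVM API:

    #include <cassert>

    // Sketch of the (min, extent) normalization used by RangeNode above.
    // Range(begin, end) stores extent = end - begin; make_by_min_extent
    // stores the pair directly. All names here are hypothetical.
    struct SimpleRange {
      int min;
      int extent;  // element count, not the end bound
      SimpleRange(int begin, int end) : min(begin), extent(end - begin) {}
      static SimpleRange MakeByMinExtent(int min, int extent) {
        return SimpleRange(min, min + extent);
      }
      int end() const { return min + extent; }
    };

    int main() {
      SimpleRange r(4, 10);  // begin/end form
      assert(r.min == 4 && r.extent == 6 && r.end() == 10);
      SimpleRange s = SimpleRange::MakeByMinExtent(0, 8);  // min/extent form
      assert(s.end() == 8);
      return 0;
    }

Storing the extent keeps loop bounds in the form most loop transformations want (a trip count), at the cost of the small normalization above when a range arrives as begin/end.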
-#include +#include #include #include @@ -268,7 +268,7 @@ void SequentialNode::ResolveDependency(const IRModule& mod) { inline bool PassArrayContains(const Array& pass_array, const std::string& pass_name) { for (auto x : pass_array) { - auto* str_name = x.as(); + auto* str_name = x.as(); CHECK(str_name) << "pass name must be str"; if (str_name->value == pass_name) return true; } @@ -310,7 +310,7 @@ IRModule SequentialNode::operator()(const IRModule& module, if (!PassEnabled(pass_info)) continue; // resolve dependencies for (const auto& it : pass_info->required) { - const auto* name = it.as(); + const auto* name = it.as(); CHECK(name); mod = GetPass(name->value)(mod, pass_ctx); } @@ -349,7 +349,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->stream << "opt_level: " << node->opt_level; p->stream << "required passes: [" << "\n"; for (const auto& it : node->required) { - const auto* str = it.as(); + const auto* str = it.as(); p->stream << str->value << ", "; } p->stream << "]\n"; diff --git a/src/lang/expr.cc b/src/lang/expr.cc deleted file mode 100644 index 1dd88b5d0bbb..000000000000 --- a/src/lang/expr.cc +++ /dev/null @@ -1,133 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file expr.cc - */ - -#include -#include -#include -#include -#include - -namespace tvm { - -PrimExpr::PrimExpr(int32_t value) - : PrimExpr(IntImm(DataType::Int(32), value)) {} - -PrimExpr::PrimExpr(float value) - : PrimExpr(FloatImm(DataType::Float(32), value)) {} - -PrimExpr::PrimExpr(std::string str) - : PrimExpr(ir::StringImmNode::make(str)) {} - -Var::Var(std::string name_hint, DataType t) - : Var(make_object(t, name_hint)) {} - -VarNode::VarNode(DataType t, std::string name_hint) { - this->dtype = t; - this->name_hint = std::move(name_hint); -} - -SizeVar::SizeVar(std::string name_hint, DataType t) - : SizeVar(make_object(t, name_hint)) {} - -SizeVarNode::SizeVarNode(DataType t, std::string name_hint) - : VarNode(t, std::move(name_hint)) {} - -Range::Range(PrimExpr begin, PrimExpr end) - : Range(make_object( - begin, - is_zero(begin) ? 
end : (end - begin))) { -} - -Range Range::make_by_min_extent(PrimExpr min, PrimExpr extent) { - return Range(make_object(min, extent)); -} - -IterVar IterVarNode::make(Range dom, - Var var, - IterVarType t, - std::string thread_tag) { - ObjectPtr n = make_object(); - n->dom = dom; - n->var = var; - n->iter_type = t; - n->thread_tag = thread_tag; - return IterVar(n); -} - -IterVar thread_axis(Range dom, std::string tag) { - return IterVarNode::make( - dom, Var(tag), kThreadIndex, tag); -} - -IterVar reduce_axis(Range dom, std::string name) { - return IterVarNode::make( - dom, Var(name), kCommReduce); -} - -void Dump(const ObjectRef& n) { - std::cerr << n << "\n"; -} - -Var var(std::string name_hint, DataType t) { - return Var(name_hint, t); -} - -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); - if (op->dtype == DataType::Int(32)) { - p->stream << op->value; - } else { - p->stream << "(" << op->dtype << ")" << op->value; - } - }); - -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "iter_var("; - if (op->var->name_hint.length() != 0) { - p->stream << op->var->name_hint << ", "; - } - if (op->dom.defined()) { - p->stream << op->dom; - } - if (op->thread_tag.length() != 0) { - p->stream << ", " << op->thread_tag; - } - p->stream << ")"; - }); - -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')'; - }); - -TVM_REGISTER_NODE_TYPE(ArrayNode); -TVM_REGISTER_NODE_TYPE(MapNode); -TVM_REGISTER_NODE_TYPE(StrMapNode); -TVM_REGISTER_NODE_TYPE(RangeNode); -TVM_REGISTER_NODE_TYPE(IterVarNode); - -} // namespace tvm diff --git a/src/node/printer.cc b/src/node/printer.cc index 15171dfefc4c..e0176d245d80 100644 --- a/src/node/printer.cc +++ b/src/node/printer.cc @@ -49,4 +49,8 @@ NodePrinter::FType& NodePrinter::vtable() { static FType inst; return inst; } + +void Dump(const ObjectRef& n) { + std::cerr << n << "\n"; +} } // namespace tvm diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 0458dfd55b17..64c7ac990213 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -36,6 +36,8 @@ namespace tvm { namespace relay { namespace backend { +using tir::LoweredFunc; + using TargetsMap = Map; using namespace tvm::relay::transform; @@ -85,7 +87,7 @@ struct GraphCodegen { std::unordered_map ret; auto names = CallFunc >("list_params_name", nullptr); for (auto expr : names) { - auto key = expr.as()->value; + auto key = expr.as()->value; ret[key] = CallFunc("get_param_by_name", key); } return ret; @@ -193,7 +195,7 @@ class RelayBuildModule : public runtime::ModuleNode { Array ListParamNames() { Array ret; for (const auto& kv : params_) { - ret.push_back(ir::StringImmNode::make(kv.first)); + ret.push_back(tir::StringImmNode::make(kv.first)); } return ret; } diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index d32a6e2e6ff4..6720e225cb07 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -83,13 +83,13 @@ Array GetShape(const Array& shape) { // even if the result of shape inference becomes int64. 
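The GetShape hunk that follows (src/relay/backend/compile_engine.cc) folds an int64 shape constant down to int32 only after checking that it fits; the real code CHECK-fails on overflow, while the sketch below reports the failure instead. A minimal standalone sketch of that narrowing guard, with a hypothetical helper name:

    #include <cstdint>
    #include <iostream>
    #include <limits>
    #include <optional>
    #include <vector>

    // Returns the int32 value when the int64 dimension fits, mirroring the
    // CHECK_LE/CHECK_GE guard in GetShape; nullopt means "keep the wide form".
    std::optional<int32_t> NarrowDim(int64_t dim) {
      if (dim > std::numeric_limits<int32_t>::max() ||
          dim < std::numeric_limits<int32_t>::min()) {
        return std::nullopt;
      }
      return static_cast<int32_t>(dim);
    }

    int main() {
      for (int64_t d : std::vector<int64_t>{224, int64_t(1) << 40}) {
        if (auto n = NarrowDim(d)) {
          std::cout << d << " -> int32 " << *n << "\n";
        } else {
          std::cout << d << " stays int64\n";
        }
      }
      return 0;
    }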
Array res; for (IndexExpr val : shape) { - const int64_t* pval = as_const_int(val); + const int64_t* pval = tir::as_const_int(val); if (pval != nullptr) { CHECK_LE(pval[0], std::numeric_limits::max()); CHECK_GE(pval[0], std::numeric_limits::min()); res.push_back(IntImm(DataType::Int(32), *pval)); - } else if (val->IsInstance()) { - res.push_back(val.as()->ToVar()); + } else if (val->IsInstance()) { + res.push_back(val.as()->ToVar()); } else { res.push_back(val); } @@ -186,10 +186,11 @@ class ScheduleGetter : } Array VisitExpr_(const ConstantNode* op) final { + using tir::make_const; CHECK(op->is_scalar()); void* data = op->data->data; DataType dtype = DataType(op->data->dtype); - auto value = top::compute({}, [&](const Array&) { + auto value = top::compute({}, [&](const Array&) { if (dtype == DataType::Int(32)) { return make_const(dtype, static_cast(data)[0]); } else if (dtype == DataType::Int(64)) { @@ -459,13 +460,14 @@ class MakeShapeFunc : public ExprFunctor(const Expr&)> { } Array VisitExpr_(const ConstantNode* op) final { + using tir::make_const; CHECK(data_dependants_.size()); CHECK(op->is_scalar()); bool data_dependant = data_dependants_.back(); if (data_dependant) { void* data = op->data->data; DataType dtype = DataType(op->data->dtype); - auto value = tvm::top::compute({}, [&](const Array&) { + auto value = tvm::top::compute({}, [&](const Array&) { if (dtype == DataType::Int(32)) { return make_const(dtype, static_cast(data)[0]); } else if (dtype == DataType::Int(64)) { @@ -484,8 +486,8 @@ class MakeShapeFunc : public ExprFunctor(const Expr&)> { scalars_.push_back(value); return {value}; } else { - auto value = tvm::top::compute({}, [&](const Array&) { - return make_const(DataType::Int(64), 0); + auto value = tvm::top::compute({}, [&](const Array&) { + return tir::make_const(DataType::Int(64), 0); }, "shape_const", topi::kBroadcast); scalars_.push_back(value); return {value}; @@ -620,13 +622,13 @@ class CompileEngineImpl : public CompileEngineNode { CHECK(src_func.defined()); if (!src_func->UseDefaultCompiler()) { auto compiler = FunctionGetAttr(src_func, attr::kCompiler); - const tvm::ir::StringImmNode* code_gen = compiler.as(); + const tvm::tir::StringImmNode* code_gen = compiler.as(); CHECK(code_gen) << "No external codegen is set"; if (ext_mods.find(code_gen->value) == ext_mods.end()) { ext_mods[code_gen->value] = IRModule({}, {}); } auto ext_symbol = FunctionGetAttr(src_func, attr::kExternalSymbol); - const tvm::ir::StringImmNode* symbol_name = ext_symbol.as(); + const tvm::tir::StringImmNode* symbol_name = ext_symbol.as(); CHECK(symbol_name) << "No external symbol is set for:\n" << AsText(src_func, false); auto gv = GlobalVar(symbol_name->value); ext_mods[code_gen->value]->Add(gv, src_func); @@ -697,7 +699,7 @@ class CompileEngineImpl : public CompileEngineNode { if (!key->source_func->UseDefaultCompiler()) { auto cache_node = make_object(); const auto name_node = - FunctionGetAttr(key->source_func, attr::kExternalSymbol).as(); + FunctionGetAttr(key->source_func, attr::kExternalSymbol).as(); CHECK(name_node != nullptr) << "External function has not been attached a name yet."; cache_node->func_name = name_node->value; cache_node->target = tvm::target::ext_dev(); @@ -733,7 +735,7 @@ class CompileEngineImpl : public CompileEngineNode { spair.first, all_args, cache_node->func_name, key->source_func); } else { tvm::BuildConfig bcfg = BuildConfig::Create(); - std::unordered_map binds; + std::unordered_map binds; cache_node->funcs = tvm::lower(spair.first, all_args, 
cache_node->func_name, binds, bcfg); } value->cached_func = CachedFunc(cache_node); @@ -768,7 +770,7 @@ class CompileEngineImpl : public CompileEngineNode { all_args.push_back(arg); } tvm::BuildConfig bcfg = BuildConfig::Create(); - std::unordered_map binds; + std::unordered_map binds; cache_node->funcs = tvm::lower(spair.first, all_args, cache_node->func_name, binds, bcfg); value->cached_func = CachedFunc(cache_node); return value; diff --git a/src/relay/backend/compile_engine.h b/src/relay/backend/compile_engine.h index 386eba7f9fd8..4ea49766f8f1 100644 --- a/src/relay/backend/compile_engine.h +++ b/src/relay/backend/compile_engine.h @@ -25,7 +25,7 @@ #ifndef TVM_RELAY_BACKEND_COMPILE_ENGINE_H_ #define TVM_RELAY_BACKEND_COMPILE_ENGINE_H_ -#include +#include #include #include #include @@ -55,7 +55,7 @@ struct CachedFuncNode : public Object { /* \brief The outputs to the function */ tvm::Array outputs; /*! \brief The lowered functions to support the function. */ - tvm::Array funcs; + tvm::Array funcs; /*! \brief Parameter usage states in the shape function. */ tvm::Array shape_func_param_states; diff --git a/src/relay/backend/contrib/codegen_c/codegen_c.h b/src/relay/backend/contrib/codegen_c/codegen_c.h index c465139a1797..23e07d55dac9 100644 --- a/src/relay/backend/contrib/codegen_c/codegen_c.h +++ b/src/relay/backend/contrib/codegen_c/codegen_c.h @@ -60,7 +60,7 @@ class CSourceModuleCodegenBase { */ std::string GetExtSymbol(const Function& func) const { const auto name_node = - FunctionGetAttr(func, attr::kExternalSymbol).as(); + FunctionGetAttr(func, attr::kExternalSymbol).as(); CHECK(name_node != nullptr) << "Fail to retrieve external symbol."; std::string ext_symbol = name_node->value; return ext_symbol; diff --git a/src/relay/backend/graph_plan_memory.cc b/src/relay/backend/graph_plan_memory.cc index fd4165567e8d..736509d2d97f 100644 --- a/src/relay/backend/graph_plan_memory.cc +++ b/src/relay/backend/graph_plan_memory.cc @@ -22,7 +22,7 @@ * \brief Memory index assignment pass for executing * the program in the graph runtime. */ -#include +#include #include #include #include @@ -295,7 +295,7 @@ class StorageAllocator : public StorageAllocaBaseVisitor { CHECK(ttype != nullptr); size_t size = 1; for (IndexExpr dim : ttype->shape) { - const int64_t* pval = as_const_int(dim); + const int64_t* pval = tir::as_const_int(dim); CHECK(pval != nullptr) << "Cannot allocate memory symbolic tensor shape " << ttype->shape; diff --git a/src/relay/backend/graph_runtime_codegen.cc b/src/relay/backend/graph_runtime_codegen.cc index 2e18e46e4ad3..f28d5415449f 100644 --- a/src/relay/backend/graph_runtime_codegen.cc +++ b/src/relay/backend/graph_runtime_codegen.cc @@ -55,7 +55,7 @@ using TargetsMap = std::unordered_map; /*! 
\brief Lowered outputs */ struct LoweredOutput { std::string graph_json; - Map > lowered_funcs; + Map > lowered_funcs; Array external_mods; std::unordered_map params; }; @@ -216,10 +216,10 @@ class GraphRuntimeCodegen ret.params = params_; for (auto& kv : lowered_funcs_) { if (ret.lowered_funcs.count(kv.first) == 0) { - ret.lowered_funcs.Set(kv.first, Array()); + ret.lowered_funcs.Set(kv.first, Array()); } auto& vec = ret.lowered_funcs[kv.first]; - Array tmp; + Array tmp; for (auto f : kv.second) { tmp.push_back(f); } @@ -242,7 +242,7 @@ class GraphRuntimeCodegen std::vector _ShapeToJSON(tvm::Array shape) { std::vector ret; for (IndexExpr dim : shape) { - const int64_t* pval = as_const_int(dim); + const int64_t* pval = tir::as_const_int(dim); ret.push_back(*pval); } return ret; @@ -601,7 +601,7 @@ class GraphRuntimeCodegen /*! \brief plan memory of device result */ Map> storage_device_map_; /*! \brief lowered funcs */ - std::unordered_map> + std::unordered_map> lowered_funcs_; /*! \brief name map */ std::unordered_map name_map_; @@ -623,7 +623,7 @@ class GraphRuntimeCodegenModule : public runtime::ModuleNode { Map tmp = args[1]; TargetsMap targets; for (const auto& it : tmp) { - auto dev_type = it.first.as(); + auto dev_type = it.first.as(); CHECK(dev_type); targets[dev_type->value] = it.second; } @@ -643,7 +643,7 @@ class GraphRuntimeCodegenModule : public runtime::ModuleNode { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { Array ret; for (const auto &kv : this->output_.params) { - tvm::PrimExpr name = ir::StringImmNode::make(kv.first); + tvm::PrimExpr name = tir::StringImmNode::make(kv.first); ret.push_back(name); } *rv = ret; diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index eb0a7b75a44d..b9330ae2f439 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -494,7 +494,7 @@ class Interpreter : // Allocate output tensor. std::vector shape; for (auto dim : rtype->shape) { - const auto* ivalue = as_const_int(dim); + const auto* ivalue = tir::as_const_int(dim); CHECK(ivalue) << "expected concrete dimensions"; shape.push_back(ivalue[0]); } diff --git a/src/relay/backend/param_dict.h b/src/relay/backend/param_dict.h index dd42eb3abfc5..c829e546b90b 100644 --- a/src/relay/backend/param_dict.h +++ b/src/relay/backend/param_dict.h @@ -25,7 +25,7 @@ #define TVM_RELAY_BACKEND_PARAM_DICT_H_ #include -#include +#include #include #include diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 79583685a6a2..df33bcbb5aea 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -29,7 +29,7 @@ #include #include #include -#include +#include #include #include diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 00e47bcc9c5e..d359d171e441 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -968,6 +968,8 @@ void VMCompiler::PopulateGlobalMap() { } void VMCompiler::Codegen() { + using tir::LoweredFunc; + if (!context_.module.defined()) { LOG(WARNING) << "Did you forget to call VMCompiler::Lower?"; return; diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h index df08c7e39cb6..602e6cceb3dd 100644 --- a/src/relay/backend/vm/compiler.h +++ b/src/relay/backend/vm/compiler.h @@ -76,7 +76,7 @@ struct VMCompilerContext { // List of cached functions std::vector cached_funcs; // The functions that have been lowered. 
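The graph runtime codegen hunks above switch LoweredOutput and its helpers to tir::LoweredFunc and accumulate lowered functions into one list per target key. A standalone sketch of that per-target grouping, using plain standard containers and hypothetical names in place of the TVM Map/Array types:

    #include <iostream>
    #include <string>
    #include <unordered_map>
    #include <vector>

    // Group staged functions by target key, appending rather than replacing,
    // as GraphRuntimeCodegen does when it populates ret.lowered_funcs.
    using FuncName = std::string;
    std::unordered_map<std::string, std::vector<FuncName>> MergeByTarget(
        const std::unordered_map<std::string, std::vector<FuncName>>& staged) {
      std::unordered_map<std::string, std::vector<FuncName>> merged;
      for (const auto& kv : staged) {
        auto& vec = merged[kv.first];  // creates an empty list on first use
        vec.insert(vec.end(), kv.second.begin(), kv.second.end());
      }
      return merged;
    }

    int main() {
      auto out = MergeByTarget({{"llvm", {"fused_add", "fused_mul"}},
                                {"cuda", {"fused_conv"}}});
      for (const auto& kv : out) {
        std::cout << kv.first << ": " << kv.second.size() << " funcs\n";
      }
      return 0;
    }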
- std::unordered_map seen_funcs; + std::unordered_map seen_funcs; }; diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc index 955453d187dd..1cf671a1999b 100644 --- a/src/relay/backend/vm/lambda_lift.cc +++ b/src/relay/backend/vm/lambda_lift.cc @@ -44,7 +44,7 @@ inline std::string GenerateName(const Function& func) { bool IsClosure(const Function& func) { ObjectRef res = FunctionGetAttr(func, attr::kClosure); - const ir::IntImmNode* pval = res.as(); + const tir::IntImmNode* pval = res.as(); return pval && pval->value != 0; } diff --git a/src/relay/backend/vm/removed_unused_funcs.cc b/src/relay/backend/vm/removed_unused_funcs.cc index d535c8d1de83..dd11fce5cc42 100644 --- a/src/relay/backend/vm/removed_unused_funcs.cc +++ b/src/relay/backend/vm/removed_unused_funcs.cc @@ -90,7 +90,7 @@ IRModule RemoveUnusedFunctions(const IRModule& module, Array entry_funcs) { std::unordered_set called_funcs{}; for (auto entry : entry_funcs) { - auto* str_name = entry.as(); + auto* str_name = entry.as(); auto funcs = CallTracer(module).Trace(str_name->value); called_funcs.insert(funcs.cbegin(), funcs.cend()); } diff --git a/src/relay/ir/alpha_equal.cc b/src/relay/ir/alpha_equal.cc index ae4b83faac8f..b55a4afd22e2 100644 --- a/src/relay/ir/alpha_equal.cc +++ b/src/relay/ir/alpha_equal.cc @@ -21,7 +21,7 @@ * \file src/tvm/relay/ir/alpha_equal.cc * \brief Alpha equality check by deep comparing two nodes. */ -#include +#include #include #include #include @@ -196,7 +196,7 @@ class AlphaEqualHandler: } } using AttrsEqualHandler::VisitAttr_; - bool VisitAttr_(const tvm::VarNode* lhs, const ObjectRef& other) final { + bool VisitAttr_(const tvm::tir::VarNode* lhs, const ObjectRef& other) final { return LeafObjectEqual(GetRef(lhs), other); } diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 9966d9cc55ef..7e19d5173ad4 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -140,7 +140,7 @@ FuncType FunctionNode::func_type_annotation() const { bool FunctionNode::IsPrimitive() const { ObjectRef res = FunctionGetAttr(GetRef(this), attr::kPrimitive); - const ir::IntImmNode* pval = res.as(); + const tir::IntImmNode* pval = res.as(); return pval && pval->value != 0; } @@ -166,7 +166,7 @@ TVM_REGISTER_GLOBAL("relay._expr.FunctionGetParams") bool FunctionNode::UseDefaultCompiler() const { ObjectRef res = FunctionGetAttr(GetRef(this), attr::kCompiler); - const ir::StringImmNode* pval = res.as(); + const tir::StringImmNode* pval = res.as(); return pval == nullptr || pval->value == "default"; } diff --git a/src/relay/ir/hash.cc b/src/relay/ir/hash.cc index 0ee9ac5f457a..b1906d3e0feb 100644 --- a/src/relay/ir/hash.cc +++ b/src/relay/ir/hash.cc @@ -21,7 +21,7 @@ * \file src/tvm/relay/ir/hash.cc * \brief Hash functions for Relay types and expressions. 
*/ -#include +#include #include #include @@ -125,9 +125,9 @@ class RelayHashHandler: } using AttrsHashHandler::VisitAttr_; - size_t VisitAttr_(const tvm::VarNode* var) final { + size_t VisitAttr_(const tvm::tir::VarNode* var) final { size_t hash = std::hash()(VarNode::_type_key); - auto it = hash_map_.find(GetRef(var)); + auto it = hash_map_.find(GetRef(var)); if (it != hash_map_.end()) { return it->second; } diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index a22a4a24bc2f..ae2089d5b765 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -824,7 +824,7 @@ class PrettyPrinter : Doc PrintAttr(const ObjectRef& value, bool meta = false) { if (value.defined()) { Doc printed_attr; - if (value.as()) { + if (value.as()) { printed_attr << "?"; } else if (meta) { printed_attr = meta_.GetMetaNode(Downcast(value)); @@ -853,15 +853,15 @@ class PrettyPrinter : return doc; } - Doc VisitAttr_(const ir::IntImmNode* op) final { + Doc VisitAttr_(const tir::IntImmNode* op) final { return PrintConstScalar(op->dtype, &(op->value)); } - Doc VisitAttr_(const ir::FloatImmNode* op) final { + Doc VisitAttr_(const tir::FloatImmNode* op) final { return PrintConstScalar(op->dtype, &(op->value)); } - Doc VisitAttr_(const ir::StringImmNode* op) final { + Doc VisitAttr_(const tir::StringImmNode* op) final { return PrintString(op->value); } diff --git a/src/relay/ir/transform.cc b/src/relay/ir/transform.cc index 02ae273350bc..1f2e8ed52f8d 100644 --- a/src/relay/ir/transform.cc +++ b/src/relay/ir/transform.cc @@ -139,7 +139,7 @@ IRModule FunctionPassNode::operator()(const IRModule& mod, bool FunctionPassNode::SkipFunction(const Function& func) const { ObjectRef skip_opt = FunctionGetAttr(func, attr::kSkipOptimization); - const ir::IntImmNode* pval = skip_opt.as(); + const tir::IntImmNode* pval = skip_opt.as(); return (pval && pval->value != 0) || (!func->UseDefaultCompiler()); } diff --git a/src/relay/ir/type.cc b/src/relay/ir/type.cc index a3b4668ad9e5..f1e59a4b97c0 100644 --- a/src/relay/ir/type.cc +++ b/src/relay/ir/type.cc @@ -22,7 +22,7 @@ * \brief The type system AST nodes of Relay. */ #include -#include +#include namespace tvm { namespace relay { @@ -43,7 +43,7 @@ TensorType TensorTypeNode::Scalar(DataType dtype) { IndexExpr TensorTypeNode::Size() const { if (shape.size() == 0) { - return make_const(DataType::Int(64), 1); + return tir::make_const(DataType::Int(64), 1); } IndexExpr size = shape[0]; diff --git a/src/relay/op/annotation/annotation.cc b/src/relay/op/annotation/annotation.cc index 2aefbd7ea883..50a55f56b45b 100644 --- a/src/relay/op/annotation/annotation.cc +++ b/src/relay/op/annotation/annotation.cc @@ -23,7 +23,7 @@ * \brief Registration of annotation operators. */ -#include +#include #include #include #include diff --git a/src/relay/op/debug.cc b/src/relay/op/debug.cc index 87e579740892..9b4647dea151 100644 --- a/src/relay/op/debug.cc +++ b/src/relay/op/debug.cc @@ -22,7 +22,7 @@ * \brief Property def of nn operators. */ -#include +#include #include #include #include diff --git a/src/relay/op/device_copy.cc b/src/relay/op/device_copy.cc index 463b76f7046d..d15c85cabcad 100644 --- a/src/relay/op/device_copy.cc +++ b/src/relay/op/device_copy.cc @@ -26,7 +26,7 @@ * used as "barrier" to avoid fusing operators belonging to different devices.
*/ -#include +#include #include #include #include diff --git a/src/relay/op/image/resize.cc b/src/relay/op/image/resize.cc index e387a712435f..e796a044f388 100644 --- a/src/relay/op/image/resize.cc +++ b/src/relay/op/image/resize.cc @@ -21,7 +21,7 @@ * \file resize.cc * \brief Image operators */ -#include +#include #include #include #include "../op_common.h" diff --git a/src/relay/op/nn/bitserial.cc b/src/relay/op/nn/bitserial.cc index 68aa77bd18bb..eccffc8f3f0f 100644 --- a/src/relay/op/nn/bitserial.cc +++ b/src/relay/op/nn/bitserial.cc @@ -22,7 +22,7 @@ * \brief Property def of bitserial operators. */ -#include +#include #include #include diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc index 6c3fb6187b43..82f4ba50467d 100644 --- a/src/relay/op/nn/convolution.cc +++ b/src/relay/op/nn/convolution.cc @@ -21,8 +21,8 @@ * \file convolution.cc * \brief Convolution operators */ -#include -#include +#include +#include #include #include #include @@ -597,13 +597,13 @@ bool Conv2DWinogradRel(const Array& types, IndexExpr pad_h, pad_w; GetPaddingHeightWidth(param->padding, &pad_h, &pad_w); - if (!dshape_nchw[2].as()) { + if (!dshape_nchw[2].as()) { oshape.Set(2, (dshape_nchw[2] + pad_h - dilated_ksize_y) / param->strides[0] + 1); } else { oshape.Set(2, dshape_nchw[2]); } - if (!dshape_nchw[3].as()) { + if (!dshape_nchw[3].as()) { oshape.Set(3, (dshape_nchw[3] + pad_w - dilated_ksize_x) / param->strides[1] + 1); } else { diff --git a/src/relay/op/nn/convolution.h b/src/relay/op/nn/convolution.h index 9e8f4b55d26e..f858efca62bd 100644 --- a/src/relay/op/nn/convolution.h +++ b/src/relay/op/nn/convolution.h @@ -24,7 +24,7 @@ #ifndef TVM_RELAY_OP_NN_CONVOLUTION_H_ #define TVM_RELAY_OP_NN_CONVOLUTION_H_ -#include +#include #include #include @@ -104,7 +104,7 @@ bool Conv1DRel(const Array& types, int num_inputs, const Attrs& attrs, // dilation Array oshape({dshape_ncw[0], channels, 0}); - if (!dshape_ncw[2].as()) { + if (!dshape_ncw[2].as()) { oshape.Set(2, indexdiv(dshape_ncw[2] + param->padding[0] + param->padding[1] - dilated_ksize, param->strides[0]) + 1); } else { @@ -161,7 +161,7 @@ bool Conv2DRel(const Array& types, int num_inputs, const Attrs& attrs, CHECK_EQ(param->dilation.size(), 2); Array wshape; - if (tvm::ir::Equal(param->channels, param->groups) && !tvm::ir::Equal(param->channels, 1)) { + if (tvm::tir::Equal(param->channels, param->groups) && !tvm::tir::Equal(param->channels, 1)) { // infer weight's shape for depthwise convolution wshape = {{dshape_nchw[1], indexdiv(param->groups, dshape_nchw[1]), param->kernel_size[0], param->kernel_size[1]}}; @@ -207,14 +207,14 @@ bool Conv2DRel(const Array& types, int num_inputs, const Attrs& attrs, IndexExpr pad_h, pad_w; GetPaddingHeightWidth(param->padding, &pad_h, &pad_w); - if (!dshape_nchw[2].as()) { + if (!dshape_nchw[2].as()) { oshape.Set(2, indexdiv(dshape_nchw[2] + pad_h - dilated_ksize_y, param->strides[0]) + 1); } else { oshape.Set(2, dshape_nchw[2]); } - if (!dshape_nchw[3].as()) { + if (!dshape_nchw[3].as()) { oshape.Set(3, indexdiv(dshape_nchw[3] + pad_w - dilated_ksize_x, param->strides[1]) + 1); } else { @@ -270,7 +270,7 @@ bool Conv3DRel(const Array& types, int num_inputs, const Attrs& attrs, CHECK_EQ(param->dilation.size(), 3); Array wshape; - if (tvm::ir::Equal(param->channels, param->groups) && !tvm::ir::Equal(param->channels, 1)) { + if (tvm::tir::Equal(param->channels, param->groups) && !tvm::tir::Equal(param->channels, 1)) { // infer weight's shape for depthwise convolution wshape = {{dshape_ncdhw[1], 
indexdiv(param->groups, dshape_ncdhw[1]), param->kernel_size[0], param->kernel_size[1], param->kernel_size[2]}}; @@ -320,21 +320,21 @@ bool Conv3DRel(const Array& types, int num_inputs, const Attrs& attrs, IndexExpr pad_d, pad_h, pad_w; GetPaddingDepthHeightWidth(param->padding, &pad_d, &pad_h, &pad_w); - if (!dshape_ncdhw[2].as()) { + if (!dshape_ncdhw[2].as()) { oshape.Set(2, indexdiv(dshape_ncdhw[2] + pad_d - dilated_ksize_z, param->strides[0]) + 1); } else { oshape.Set(2, dshape_ncdhw[2]); } - if (!dshape_ncdhw[3].as()) { + if (!dshape_ncdhw[3].as()) { oshape.Set(3, indexdiv(dshape_ncdhw[3] + pad_h - dilated_ksize_y, param->strides[1]) + 1); } else { oshape.Set(3, dshape_ncdhw[3]); } - if (!dshape_ncdhw[4].as()) { + if (!dshape_ncdhw[4].as()) { oshape.Set(4, indexdiv(dshape_ncdhw[4] + pad_w - dilated_ksize_x, param->strides[2]) + 1); } else { diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index 9fca22d777d9..2ff439a527ba 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -22,7 +22,7 @@ * \brief Property def of nn operators. */ -#include +#include #include #include #include @@ -405,10 +405,10 @@ bool BatchFlattenRel(const Array& types, if (data == nullptr) return false; if (data->shape.size() == 0) return false; - auto target_dim = make_const(DataType::Int(32), 1); + auto target_dim = tir::make_const(DataType::Int(32), 1); for (uint32_t i = 1; i < data->shape.size(); ++i) { - if (!data->shape[i].as()) { + if (!data->shape[i].as()) { target_dim = target_dim * data->shape[i]; } else { target_dim = data->shape[i]; diff --git a/src/relay/op/nn/pad.cc b/src/relay/op/nn/pad.cc index b67f93928a88..e33f751fb638 100644 --- a/src/relay/op/nn/pad.cc +++ b/src/relay/op/nn/pad.cc @@ -21,8 +21,8 @@ * \file pad.cc * \brief Implementation of operator pad */ -#include -#include +#include +#include #include #include #include @@ -135,8 +135,8 @@ bool PadRel(const Array& types, << "Each pad width element should be a pair but at index " << i << " there are " << param->pad_width[i].size() << " elements."; - auto width1 = as_const_int(param->pad_width[i][0]); - auto width2 = as_const_int(param->pad_width[i][1]); + auto width1 = tir::as_const_int(param->pad_width[i][0]); + auto width2 = tir::as_const_int(param->pad_width[i][1]); CHECK(width1 != nullptr); CHECK(width2 != nullptr); @@ -147,8 +147,8 @@ bool PadRel(const Array& types, << "Param width elements should be positive but first pad width at " << "index " << i << " is " << *width2 << "."; - if (!data->shape[i].as()) { - auto padding = make_const(data->shape[i].dtype(), *width1 + *width2); + if (!data->shape[i].as()) { + auto padding = tir::make_const(data->shape[i].dtype(), *width1 + *width2); oshape.push_back(data->shape[i] + padding); } else { oshape.push_back(data->shape[i]); @@ -181,7 +181,7 @@ Array PadCompute(const Attrs& attrs, } const auto* out_ttype = out_type.as(); return Array{ topi::pad(inputs[0], pad_before, pad_after, - tvm::make_const(out_ttype->dtype, param->pad_value), + tvm::tir::make_const(out_ttype->dtype, param->pad_value), "T_pad", topi::kElementWise, param->pad_mode) }; @@ -244,8 +244,8 @@ bool MirrorPadRel(const Array& types, << "Each pad width element should be a pair but at index " << i << " there are " << param->pad_width[i].size() << " elements."; - auto width1 = as_const_int(param->pad_width[i][0]); - auto width2 = as_const_int(param->pad_width[i][1]); + auto width1 = tir::as_const_int(param->pad_width[i][0]); + auto width2 = tir::as_const_int(param->pad_width[i][1]); CHECK(width1 != nullptr); 
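For context on the pad hunks around this point: once both pad widths are verified to be non-negative constants, PadRel/MirrorPadRel grow each output dimension by pad_before + pad_after. A standalone sketch of that shape rule (names hypothetical):

    #include <cassert>
    #include <cstdint>
    #include <utility>
    #include <vector>

    // out[i] = in[i] + pad_before[i] + pad_after[i], the rule applied by
    // PadRel/MirrorPadRel after the width checks above and below this note.
    std::vector<int64_t> PaddedShape(
        const std::vector<int64_t>& shape,
        const std::vector<std::pair<int64_t, int64_t>>& pad_width) {
      assert(shape.size() == pad_width.size());
      std::vector<int64_t> out;
      for (size_t i = 0; i < shape.size(); ++i) {
        assert(pad_width[i].first >= 0 && pad_width[i].second >= 0);
        out.push_back(shape[i] + pad_width[i].first + pad_width[i].second);
      }
      return out;
    }

    int main() {
      auto o = PaddedShape({1, 3, 224, 224}, {{0, 0}, {0, 0}, {1, 1}, {1, 1}});
      assert(o[2] == 226 && o[3] == 226);
      return 0;
    }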
CHECK(width2 != nullptr); @@ -256,7 +256,7 @@ bool MirrorPadRel(const Array& types, << "Param width elements should be positive but first pad width at " << "index " << i << " is " << *width2 << "."; - auto padding = make_const(data->shape[i].dtype(), *width1 + *width2); + auto padding = tir::make_const(data->shape[i].dtype(), *width1 + *width2); oshape.push_back(data->shape[i] + padding); } diff --git a/src/relay/op/nn/pooling.cc b/src/relay/op/nn/pooling.cc index 65fd09d93d4d..0c74b2711f01 100644 --- a/src/relay/op/nn/pooling.cc +++ b/src/relay/op/nn/pooling.cc @@ -21,7 +21,7 @@ * \file pooling.cc * \brief Pooling operators */ -#include +#include #include #include #include @@ -139,7 +139,7 @@ bool Pool2DRel(const Array& types, oshape.push_back(e); } - if (dshape[hidx].as()) { + if (dshape[hidx].as()) { oshape[hidx] = dshape[hidx]; } else { if (param->ceil_mode) { @@ -149,7 +149,7 @@ bool Pool2DRel(const Array& types, oshape[hidx] = ((dshape[hidx] + pad_h - param->pool_size[0]) / param->strides[0]) + 1; } } - if (dshape[widx].as()) { + if (dshape[widx].as()) { oshape[widx] = dshape[widx]; } else { if (param->ceil_mode) { @@ -780,7 +780,7 @@ bool Pool1DRel(const Array& types, oshape.push_back(e); } - if (dshape[widx].as()) { + if (dshape[widx].as()) { oshape[widx] = dshape[widx]; } else { if (param->ceil_mode) { @@ -974,7 +974,7 @@ bool Pool3DRel(const Array& types, std::vector idxes = {didx, hidx, widx}; for (int i = 0; i < 3; i++) { int ii = idxes[i]; - if (dshape[ii].as()) { + if (dshape[ii].as()) { oshape[ii] = dshape[ii]; } else { if (param->ceil_mode) { diff --git a/src/relay/op/nn/sparse.cc b/src/relay/op/nn/sparse.cc index f2be18202410..caad01b9e66b 100644 --- a/src/relay/op/nn/sparse.cc +++ b/src/relay/op/nn/sparse.cc @@ -22,7 +22,7 @@ * \brief Property def of nn.sparse_dense operator. 
*/ -#include +#include #include #include #include diff --git a/src/relay/op/nn/upsampling.cc b/src/relay/op/nn/upsampling.cc index 73cb5a1c4e11..f3a31c09f865 100644 --- a/src/relay/op/nn/upsampling.cc +++ b/src/relay/op/nn/upsampling.cc @@ -21,7 +21,7 @@ * \file upsampling.cc * \brief upsampling operator */ -#include +#include #include #include #include @@ -83,8 +83,8 @@ bool UpSamplingRel(const Array& types, << " But got " << in_layout; auto oshape = layout_converter.ForwardShape(data->shape); - oshape.Set(2, ir::CastNode::make(oshape[2].dtype(), tvm::round(oshape[2] * param->scale_h))); - oshape.Set(3, ir::CastNode::make(oshape[3].dtype(), tvm::round(oshape[3] * param->scale_w))); + oshape.Set(2, tir::CastNode::make(oshape[2].dtype(), tvm::round(oshape[2] * param->scale_h))); + oshape.Set(3, tir::CastNode::make(oshape[3].dtype(), tvm::round(oshape[3] * param->scale_w))); // assign output type reporter->Assign(types[1], @@ -162,9 +162,9 @@ bool UpSampling3DRel(const Array& types, << " But got " << in_layout; auto oshape = layout_converter.ForwardShape(data->shape); - oshape.Set(2, ir::CastNode::make(oshape[2].dtype(), tvm::round(oshape[2] * param->scale_d))); - oshape.Set(3, ir::CastNode::make(oshape[3].dtype(), tvm::round(oshape[3] * param->scale_h))); - oshape.Set(4, ir::CastNode::make(oshape[4].dtype(), tvm::round(oshape[4] * param->scale_w))); + oshape.Set(2, tir::CastNode::make(oshape[2].dtype(), tvm::round(oshape[2] * param->scale_d))); + oshape.Set(3, tir::CastNode::make(oshape[3].dtype(), tvm::round(oshape[3] * param->scale_h))); + oshape.Set(4, tir::CastNode::make(oshape[4].dtype(), tvm::round(oshape[4] * param->scale_w))); // assign output type reporter->Assign(types[1], diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc index e20b7cfc52c1..5156330d7601 100644 --- a/src/relay/op/tensor/reduce.cc +++ b/src/relay/op/tensor/reduce.cc @@ -209,7 +209,7 @@ inline std::vector ReduceShapeImpl(const std::vector &in_s return in_shape; } - auto max_shape = make_const(DataType::Int(64), 1); + auto max_shape = tir::make_const(DataType::Int(64), 1); bool is_dynamic_input = false; for (int64_t axis : r_axes) { if (in_shape[axis].as()) { @@ -221,7 +221,7 @@ inline std::vector ReduceShapeImpl(const std::vector &in_s } if (is_dynamic_input) { - CHECK(reporter->Assert(max_shape < make_const( + CHECK(reporter->Assert(max_shape < tir::make_const( DataType::Int(64), std::numeric_limits::max()))) << "The maximum possible index of reduced shape cannot be more than int32 max."; } @@ -537,7 +537,7 @@ Array MeanCompute(const Attrs& attrs, const Array& inputs, const Type& out_type, const Target& target) { - IndexExpr count = make_const(inputs[0]->dtype, 1); + IndexExpr count = tir::make_const(inputs[0]->dtype, 1); const ReduceAttrs* param = attrs.as(); CHECK(param != nullptr); auto axes = param->axis; @@ -602,7 +602,7 @@ Array VarianceCompute(const Attrs& attrs, const Array& inputs, const Type& out_type, const Target& target) { - IndexExpr count = make_const(inputs[0]->dtype, 1); + IndexExpr count = tir::make_const(inputs[0]->dtype, 1); const ReduceAttrs* param = attrs.as(); CHECK(param != nullptr); auto axes = param->axis; diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index c2af56d68d46..538c92ed42df 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -24,9 +24,9 @@ #include #include #include -#include -#include -#include +#include +#include +#include #include #include #include @@ -42,7 +42,7 @@ namespace tvm { 
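The upsampling hunks above compute each scaled spatial dimension as an integer cast of round(dim * scale), now spelled tir::CastNode::make(oshape[i].dtype(), tvm::round(oshape[i] * scale)). A standalone arithmetic sketch of that rule, assuming round-half-away-from-zero as std::llround provides:

    #include <cassert>
    #include <cmath>
    #include <cstdint>

    // Scaled output dimension: round first, then cast back to an integer,
    // mirroring the cast(round(dim * scale)) pattern in UpSamplingRel.
    int64_t ScaledDim(int64_t dim, double scale) {
      return static_cast<int64_t>(std::llround(static_cast<double>(dim) * scale));
    }

    int main() {
      assert(ScaledDim(7, 2.0) == 14);
      assert(ScaledDim(5, 1.5) == 8);  // 7.5 rounds away from zero to 8
      return 0;
    }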
namespace relay { -using ir::IntImmNode; +using tir::IntImmNode; // relay.cast TVM_REGISTER_NODE_TYPE(CastAttrs); @@ -695,8 +695,8 @@ Array ReshapeCompute(const Attrs& attrs, CHECK(out_ttype != nullptr); Array newshape; for (auto val : out_ttype->shape) { - if (val->IsInstance()) { - newshape.push_back(val.as()->ToVar()); + if (val->IsInstance()) { + newshape.push_back(val.as()->ToVar()); } else { newshape.push_back(val); } @@ -1223,8 +1223,8 @@ inline top::Tensor DynamicArange(const top::Tensor& start, tvm::DataType dtype, std::string name = "tensor", std::string tag = topi::kInjective) { - tvm::PrimExpr num_elem = tvm::Var("num_elem"); - return top::compute({num_elem}, [&](const Array& indices) { + tvm::PrimExpr num_elem = tvm::tir::Var("num_elem"); + return top::compute({num_elem}, [&](const Array& indices) { return tvm::cast(dtype, start[0] + step[0] * indices[0]); }, name, tag); } @@ -1384,7 +1384,7 @@ bool TileRel(const Array& types, << "repetition array is not defined. data.ndim = " << ndim; const size_t rndim = reps.size(); for (size_t i = 0; i < rndim; ++i) { - if (const tvm::ir::IntImmNode* val = reps[i].as()) { + if (const tvm::tir::IntImmNode* val = reps[i].as()) { CHECK_GT(val->value, 0) << "Tile reps value should always be larger than 0, but get: " << val->value; } @@ -1652,7 +1652,7 @@ bool SqueezeRel(const Array& types, if (!e.as()) { LOG(FATAL) << "axis needs to be defined for dynamic input."; } - const int64_t* axis_ptr = as_const_int(e); + const int64_t* axis_ptr = tir::as_const_int(e); CHECK(axis_ptr != nullptr) << "the axes attribute must be concrete"; if (*axis_ptr != 1) { result_shape.push_back(e); @@ -1677,7 +1677,7 @@ bool SqueezeRel(const Array& types, if (p.second) { result_shape.push_back(p.first); } else { - const int64_t* axis_ptr = as_const_int(p.first); + const int64_t* axis_ptr = tir::as_const_int(p.first); CHECK(axis_ptr != nullptr) << "cannot get concrete shape of input tensor"; CHECK_EQ(*axis_ptr, 1) << "cannot squeeze axis with dimension not equal to 1"; } @@ -1916,7 +1916,7 @@ bool StridedSliceRel(const Array& types, // Normal path, require the shape to be concrete integer. // Require concrete integer as symbolic inference of min/max // can get complicated and not very helpful. 
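The StridedSliceRel hunk just below derives each output extent with the integer ceiling-division idiom (slice_range + step - 1) / step, after normalizing a negative stride into a positive step over a reversed range. A standalone sketch of that arithmetic (names hypothetical):

    #include <cassert>
    #include <cstdint>

    // Number of indices produced by begin:end:stride, computed the same way
    // as the oshape[i] assignment in StridedSliceRel.
    int64_t SliceLen(int64_t begin, int64_t end, int64_t stride) {
      int64_t range = (stride < 0) ? (begin - end) : (end - begin);
      int64_t step = (stride < 0) ? -stride : stride;
      assert(step > 0 && range >= 0);
      return (range + step - 1) / step;  // ceiling division on non-negatives
    }

    int main() {
      assert(SliceLen(0, 10, 1) == 10);
      assert(SliceLen(0, 10, 3) == 4);   // 0, 3, 6, 9
      assert(SliceLen(9, -1, -2) == 5);  // 9, 7, 5, 3, 1
      return 0;
    }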
- const int64_t* p_dim_size = as_const_int(dshape[i]); + const int64_t* p_dim_size = tir::as_const_int(dshape[i]); CHECK(p_dim_size) << "strided_slice requires sliced dimension to be concrete int"; int64_t dim_size = p_dim_size[0]; @@ -1940,7 +1940,7 @@ bool StridedSliceRel(const Array& types, slice_range = end_v - begin_v; step = stride_v; } - oshape[i] = make_const(dshape[i].dtype(), (slice_range + step - 1) / step); + oshape[i] = tir::make_const(dshape[i].dtype(), (slice_range + step - 1) / step); } reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype)); return true; @@ -2141,7 +2141,7 @@ bool SplitRel(const Array& types, if (const IntImmNode* sections = param->indices_or_sections.as()) { CHECK(reporter->Assert(indexmod(data->shape[axis], - sections->value) == make_zero(DataType::Int(64)))) + sections->value) == tir::make_zero(DataType::Int(64)))) << "indices_or_sections need to be able to divide input.shape[axis]"; std::vector fields; for (int i = 0; i < sections->value; ++i) { @@ -2153,7 +2153,7 @@ bool SplitRel(const Array& types, reporter->Assign(types[1], TupleType(Array(fields))); } else { auto indices = param->indices_or_sections.as()->data; - auto begin = IndexExpr(make_zero(DataType::Int(32))); + auto begin = IndexExpr(tir::make_zero(DataType::Int(32))); std::vector fields; for (unsigned int i = 0; i < indices.size(); ++i) { CHECK(reporter->Assert(Downcast(indices[i]) > begin)) @@ -2205,7 +2205,7 @@ Expr MakeSplit(Expr data, TVM_REGISTER_GLOBAL("relay.op._make.split") .set_body([](const TVMArgs& args, TVMRetValue* rv) { if (args.type_codes[1] == kDLInt) { - *rv = MakeSplit(args[0], make_const(DataType::Int(64), int64_t(args[1])), args[2]); + *rv = MakeSplit(args[0], tir::make_const(DataType::Int(64), int64_t(args[1])), args[2]); } else { *rv = MakeSplit(args[0], args[1], args[2]); } diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc index a1e2a6630f05..cd476fd49b87 100644 --- a/src/relay/op/type_relations.cc +++ b/src/relay/op/type_relations.cc @@ -24,7 +24,7 @@ */ #include #include -#include +#include #include #include "./type_relations.h" @@ -44,19 +44,19 @@ bool IdentityRel(const Array& types, bool EqualCheck(const IndexExpr& lhs, const IndexExpr& rhs) { IndexExpr diff = lhs - rhs; - if (const int64_t* pdiff = as_const_int(diff)) { + if (const int64_t* pdiff = tir::as_const_int(diff)) { return pdiff[0] == 0; } // symbolic - diff = tvm::ir::CanonicalSimplify(diff); - if (const int64_t* pdiff = as_const_int(diff)) { + diff = tvm::tir::CanonicalSimplify(diff); + if (const int64_t* pdiff = tir::as_const_int(diff)) { return pdiff[0] == 0; } return false; } bool EqualConstInt(const IndexExpr& lhs, int64_t value) { - if (const int64_t* pvalue = as_const_int(lhs)) { + if (const int64_t* pvalue = tir::as_const_int(lhs)) { return pvalue[0] == value; } return false; diff --git a/src/relay/op/vision/multibox_op.cc b/src/relay/op/vision/multibox_op.cc index d837e99b176e..b801186fb5b5 100644 --- a/src/relay/op/vision/multibox_op.cc +++ b/src/relay/op/vision/multibox_op.cc @@ -21,7 +21,7 @@ * \file multibox_op.cc * \brief Multibox related operators */ -#include +#include #include #include diff --git a/src/relay/pass/alter_op_layout.cc b/src/relay/pass/alter_op_layout.cc index 974618209931..8e819bc44c28 100644 --- a/src/relay/pass/alter_op_layout.cc +++ b/src/relay/pass/alter_op_layout.cc @@ -123,7 +123,7 @@ Pass AlterOpLayout() { return Downcast(relay::alter_op_layout::AlterOpLayout(f)); }; return CreateFunctionPass(pass_func, 3, "AlterOpLayout", - 
{ir::StringImmNode::make("InferType")}); + {tir::StringImmNode::make("InferType")}); } TVM_REGISTER_GLOBAL("relay._transform.AlterOpLayout") diff --git a/src/relay/pass/canonicalize_cast.cc b/src/relay/pass/canonicalize_cast.cc index 85ca66ddfe0a..dd899dfbaf9d 100644 --- a/src/relay/pass/canonicalize_cast.cc +++ b/src/relay/pass/canonicalize_cast.cc @@ -134,7 +134,7 @@ Pass CanonicalizeCast() { return Downcast(CanonicalizeCast(f)); }; return CreateFunctionPass(pass_func, 3, "CanonicalizeCast", - {ir::StringImmNode::make("InferType")}); + {tir::StringImmNode::make("InferType")}); } TVM_REGISTER_GLOBAL("relay._transform.CanonicalizeCast") diff --git a/src/relay/pass/canonicalize_ops.cc b/src/relay/pass/canonicalize_ops.cc index 5def35bf2e6c..bcb7f9dbbc9b 100644 --- a/src/relay/pass/canonicalize_ops.cc +++ b/src/relay/pass/canonicalize_ops.cc @@ -74,7 +74,7 @@ Pass CanonicalizeOps() { return Downcast(CanonicalizeOps(f)); }; return CreateFunctionPass(pass_func, 3, "CanonicalizeOps", - {ir::StringImmNode::make("InferType")}); + {tir::StringImmNode::make("InferType")}); } TVM_REGISTER_GLOBAL("relay._transform.CanonicalizeOps") diff --git a/src/relay/pass/combine_parallel_conv2d.cc b/src/relay/pass/combine_parallel_conv2d.cc index 530b19948aa1..5c9263722788 100644 --- a/src/relay/pass/combine_parallel_conv2d.cc +++ b/src/relay/pass/combine_parallel_conv2d.cc @@ -204,7 +204,7 @@ class ParallelConv2DCombiner : public ParallelOpCombiner { auto index = branches[0][0]->attrs.as()->kernel_layout.find('O'); CHECK_NE(index, std::string::npos); return std::make_tuple(MakeConcatenate(TupleNode::make(weights), index), - MakeConstScalar(DataType::Int(32), num_filters)); + tir::make_const(DataType::Int(32), num_filters)); } }; @@ -221,7 +221,7 @@ Pass CombineParallelConv2D(uint64_t min_num_branches) { return Downcast(CombineParallelConv2D(f, min_num_branches)); }; return CreateFunctionPass(pass_func, 4, "CombineParallelConv2d", - {ir::StringImmNode::make("InferType")}); + {tir::StringImmNode::make("InferType")}); } TVM_REGISTER_GLOBAL("relay._transform.CombineParallelConv2D") diff --git a/src/relay/pass/combine_parallel_dense.cc b/src/relay/pass/combine_parallel_dense.cc index 7cf161860301..fdd42be2c486 100644 --- a/src/relay/pass/combine_parallel_dense.cc +++ b/src/relay/pass/combine_parallel_dense.cc @@ -81,7 +81,7 @@ Pass CombineParallelDense(uint64_t min_num_branches) { return Downcast(CombineParallelDense(f, min_num_branches)); }; return CreateFunctionPass(pass_func, 4, "CombineParallelDense", - {ir::StringImmNode::make("InferType")}); + {tir::StringImmNode::make("InferType")}); } TVM_REGISTER_GLOBAL("relay._transform.CombineParallelDense") diff --git a/src/relay/pass/combine_parallel_op_batch.cc b/src/relay/pass/combine_parallel_op_batch.cc index f1514b5d7b0f..a9eefaf50268 100644 --- a/src/relay/pass/combine_parallel_op_batch.cc +++ b/src/relay/pass/combine_parallel_op_batch.cc @@ -194,7 +194,7 @@ Pass CombineParallelOpBatch(const std::string& op_name, min_num_branches)); }; return CreateFunctionPass(pass_func, 4, "CombineParallelOpBatch", - {ir::StringImmNode::make("InferType")}); + {tir::StringImmNode::make("InferType")}); } TVM_REGISTER_GLOBAL("relay._transform.CombineParallelOpBatch") diff --git a/src/relay/pass/convert_layout.cc b/src/relay/pass/convert_layout.cc index d435efd249d3..5a90651b47e5 100644 --- a/src/relay/pass/convert_layout.cc +++ b/src/relay/pass/convert_layout.cc @@ -134,8 +134,8 @@ Pass ConvertLayout(const std::string& desired_layout) { }; return CreateFunctionPass( pass_func, 3, 
"ConvertLayout", - {ir::StringImmNode::make("InferType"), - ir::StringImmNode::make("CanonicalizeOps")}); + {tir::StringImmNode::make("InferType"), + tir::StringImmNode::make("CanonicalizeOps")}); } TVM_REGISTER_GLOBAL("relay._transform.ConvertLayout").set_body_typed(ConvertLayout); diff --git a/src/relay/pass/device_annotation.cc b/src/relay/pass/device_annotation.cc index 286305157b6c..c87acd021a0b 100644 --- a/src/relay/pass/device_annotation.cc +++ b/src/relay/pass/device_annotation.cc @@ -28,7 +28,7 @@ * 3. Collect the device allocation of each expression. */ -#include +#include #include #include #include @@ -577,7 +577,7 @@ Pass RewriteAnnotatedOps(int fallback_device) { return Downcast(RewriteAnnotatedOps(f, fallback_device)); }; return CreateFunctionPass(pass_func, 1, "RewriteAnnotatedOps", - {ir::StringImmNode::make("InferType")}); + {tir::StringImmNode::make("InferType")}); } TVM_REGISTER_GLOBAL("relay._transform.RewriteDeviceAnnotation") diff --git a/src/relay/pass/eliminate_common_subexpr.cc b/src/relay/pass/eliminate_common_subexpr.cc index bf08b071553a..c041b7e6eb41 100644 --- a/src/relay/pass/eliminate_common_subexpr.cc +++ b/src/relay/pass/eliminate_common_subexpr.cc @@ -92,7 +92,7 @@ Pass EliminateCommonSubexpr(PackedFunc fskip) { return Downcast(EliminateCommonSubexpr(f, fskip)); }; return CreateFunctionPass(pass_func, 3, "EliminateCommonSubexpr", - {ir::StringImmNode::make("InferType")}); + {tir::StringImmNode::make("InferType")}); } TVM_REGISTER_GLOBAL("relay._transform.EliminateCommonSubexpr") diff --git a/src/relay/pass/fold_constant.cc b/src/relay/pass/fold_constant.cc index 1e1d626a02ca..f79f3ee560d2 100644 --- a/src/relay/pass/fold_constant.cc +++ b/src/relay/pass/fold_constant.cc @@ -253,7 +253,7 @@ class ConstantFolder : public ExprMutator { std::vector cshape = { static_cast(ishape.size()) }; value = runtime::NDArray::Empty(cshape, cdtype, ctx); int32_t* dims = static_cast(value->data); - using ::tvm::ir::IntImmNode; + using ::tvm::tir::IntImmNode; for (size_t i = 0; i < ishape.size(); ++i) { if (const IntImmNode* dim = ishape[i].as()) { dims[i] = dim->value; diff --git a/src/relay/pass/fold_scale_axis.cc b/src/relay/pass/fold_scale_axis.cc index 9d8d5452288c..4bfe270cc044 100644 --- a/src/relay/pass/fold_scale_axis.cc +++ b/src/relay/pass/fold_scale_axis.cc @@ -23,7 +23,7 @@ * \brief Fold axis scaling into weights of * conv/dense operators. */ -#include +#include #include #include #include @@ -955,7 +955,7 @@ Pass ForwardFoldScaleAxis() { relay::fold_scale_axis::ForwardFoldScaleAxis(f)); }; return CreateFunctionPass(pass_func, 3, "ForwardFoldScaleAxis", - {ir::StringImmNode::make("InferType")}); + {tir::StringImmNode::make("InferType")}); } TVM_REGISTER_GLOBAL("relay._transform.ForwardFoldScaleAxis") @@ -968,7 +968,7 @@ Pass BackwardFoldScaleAxis() { relay::fold_scale_axis::BackwardFoldScaleAxis(f)); }; return CreateFunctionPass(pass_func, 3, "BackwardFoldScaleAxis", - {ir::StringImmNode::make("InferType")}); + {tir::StringImmNode::make("InferType")}); } TVM_REGISTER_GLOBAL("relay._transform.BackwardFoldScaleAxis") diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index e774549ece68..3d37d61448f0 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -24,7 +24,7 @@ * \brief This is a backend-aware optimization pass. * Fuse necessary ops into a single one. 
*/ -#include +#include #include #include #include @@ -982,7 +982,7 @@ Pass FuseOps(int fuse_opt_level) { return Downcast(FuseOps(f, opt_level, m)); }; return CreateFunctionPass(pass_func, 1, "FuseOps", - {ir::StringImmNode::make("InferType")}); + {tir::StringImmNode::make("InferType")}); } TVM_REGISTER_GLOBAL("relay._transform.FuseOps") diff --git a/src/relay/pass/gradient.cc b/src/relay/pass/gradient.cc index 8d89c0a53d3a..20958ab598da 100644 --- a/src/relay/pass/gradient.cc +++ b/src/relay/pass/gradient.cc @@ -22,7 +22,7 @@ * \brief API for Automatic Differentiation for the Relay IR. */ -#include +#include #include #include #include diff --git a/src/relay/pass/infer_layout_util.h b/src/relay/pass/infer_layout_util.h index b2cef6c12f70..9ecd0bf1f45b 100644 --- a/src/relay/pass/infer_layout_util.h +++ b/src/relay/pass/infer_layout_util.h @@ -27,8 +27,9 @@ #ifndef TVM_RELAY_PASS_INFER_LAYOUT_UTIL_H_ #define TVM_RELAY_PASS_INFER_LAYOUT_UTIL_H_ -#include +#include #include +#include #include #include #include "pattern_util.h" diff --git a/src/relay/pass/legalize.cc b/src/relay/pass/legalize.cc index 12e72cf1fad2..4480861ec267 100644 --- a/src/relay/pass/legalize.cc +++ b/src/relay/pass/legalize.cc @@ -102,7 +102,7 @@ Pass Legalize(const std::string& legalize_map_attr_name) { [=](Function f, IRModule m, PassContext pc) { return Downcast(relay::legalize::Legalize(f, legalize_map_attr_name)); }; - return CreateFunctionPass(pass_func, 1, "Legalize", {ir::StringImmNode::make("InferType")}); + return CreateFunctionPass(pass_func, 1, "Legalize", {tir::StringImmNode::make("InferType")}); } TVM_REGISTER_GLOBAL("relay._transform.Legalize").set_body_typed(Legalize); diff --git a/src/relay/pass/mac_count.cc b/src/relay/pass/mac_count.cc index 9e3e95e14e5a..a03924d1c965 100644 --- a/src/relay/pass/mac_count.cc +++ b/src/relay/pass/mac_count.cc @@ -30,7 +30,7 @@ #include #include #include -#include +#include #include "pattern_util.h" namespace tvm { diff --git a/src/relay/pass/partition_graph.cc b/src/relay/pass/partition_graph.cc index 634affebdebd..e58bf61aa496 100644 --- a/src/relay/pass/partition_graph.cc +++ b/src/relay/pass/partition_graph.cc @@ -212,10 +212,10 @@ class Partitioner : public ExprMutator { Expr arg0 = call->args[0]; std::string name = compiler_attrs->compiler + "_" + std::to_string(subgraph->id); subgraph_func = - FunctionSetAttr(subgraph_func, attr::kExternalSymbol, tvm::ir::StringImmNode::make(name)); + FunctionSetAttr(subgraph_func, attr::kExternalSymbol, tir::StringImmNode::make(name)); subgraph_func = FunctionSetAttr(subgraph_func, attr::kPrimitive, tvm::Integer(1)); subgraph_func = FunctionSetAttr(subgraph_func, attr::kCompiler, - tvm::ir::StringImmNode::make(compiler_attrs->compiler)); + tvm::tir::StringImmNode::make(compiler_attrs->compiler)); return CallNode::make(subgraph_func, args); } } diff --git a/src/relay/pass/pattern_util.h b/src/relay/pass/pattern_util.h index 801fc1757337..b5431cbab579 100644 --- a/src/relay/pass/pattern_util.h +++ b/src/relay/pass/pattern_util.h @@ -27,13 +27,15 @@ #define TVM_RELAY_PASS_PATTERN_UTIL_H_ #include -#include +#include #include #include #include #include #include #include +#include + #include #include #include @@ -117,7 +119,7 @@ inline bool MatchBroadcastToLeftAxes(const TensorTypeNode* tlhs, } ++j; } else if (i >= base) { - if (!is_const_int(trhs->shape[i - base], 1)) { + if (!tir::is_const_int(trhs->shape[i - base], 1)) { return false; } if (rhs_value != nullptr) { @@ -182,8 +184,8 @@ inline bool IsDepthwiseConv2D(const Call& call, 
static const Layout kOIHW("OIHW"); const auto bilayout = BijectiveLayoutNode::make(kernel_layout, kOIHW); auto wshape = bilayout.ForwardShape(call->args[1]->type_as()->shape); - return is_const_int(wshape[0], param->groups) && - is_const_int(wshape[1], 1); + return tir::is_const_int(wshape[0], param->groups) && + tir::is_const_int(wshape[1], 1); } /*! @@ -196,7 +198,7 @@ inline int64_t GetConv2DSuperChannelsDim(const CallNode* call) { auto tweight = call->args[1]->type_as(); auto index = param->kernel_layout.find('O'); CHECK_NE(index, std::string::npos); - auto channels = as_const_int(tweight->shape[index]); + auto channels = tir::as_const_int(tweight->shape[index]); return *channels; } diff --git a/src/relay/pass/simplify_inference.cc b/src/relay/pass/simplify_inference.cc index 32fc06f93c16..bc8f9e144e0a 100644 --- a/src/relay/pass/simplify_inference.cc +++ b/src/relay/pass/simplify_inference.cc @@ -188,7 +188,7 @@ Pass SimplifyInference() { return Downcast(SimplifyInference(f)); }; return CreateFunctionPass(pass_func, 0, "SimplifyInference", - {ir::StringImmNode::make("InferType")}); + {tir::StringImmNode::make("InferType")}); } TVM_REGISTER_GLOBAL("relay._transform.SimplifyInference") diff --git a/src/relay/pass/transform_layout.h b/src/relay/pass/transform_layout.h index d283a239f2f6..f7845242d21b 100644 --- a/src/relay/pass/transform_layout.h +++ b/src/relay/pass/transform_layout.h @@ -26,7 +26,7 @@ #ifndef TVM_RELAY_PASS_TRANSFORM_LAYOUT_H_ #define TVM_RELAY_PASS_TRANSFORM_LAYOUT_H_ -#include +#include #include #include #include diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index be5ac516d46a..ec6d721cf7b0 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -21,7 +21,7 @@ * \file type_solver.cc * \brief Type solver implementations. */ -#include +#include #include #include #include @@ -42,7 +42,7 @@ class TypeSolver::Reporter : public TypeReporterNode { } bool Assert(const IndexExpr& cond) final { - if (const int64_t* pdiff = as_const_int(cond)) { + if (const int64_t* pdiff = tir::as_const_int(cond)) { return pdiff[0]; } return true; @@ -51,7 +51,7 @@ class TypeSolver::Reporter : public TypeReporterNode { bool AssertEQ(const IndexExpr& lhs, const IndexExpr& rhs) final { // early warning constant case. IndexExpr diff = lhs - rhs; - if (const int64_t* pdiff = as_const_int(diff)) { + if (const int64_t* pdiff = tir::as_const_int(diff)) { return pdiff[0] == 0; } return true; @@ -184,7 +184,7 @@ class TypeSolver::Unifier : public TypeFunctor { return Any::make(); } - auto left_index0 = ulhs.as(); + auto left_index0 = ulhs.as(); auto right_index0 = urhs.as(); if (left_index0 && right_index0) { solver_->shape_uf_.Set(ulhs, urhs); @@ -192,7 +192,7 @@ class TypeSolver::Unifier : public TypeFunctor { } auto left_index1 = ulhs.as(); - auto right_index1 = urhs.as(); + auto right_index1 = urhs.as(); if (left_index1 && right_index1) { solver_->shape_uf_.Set(urhs, ulhs); return ulhs; diff --git a/src/relay/qnn/op/concatenate.cc b/src/relay/qnn/op/concatenate.cc index 26093f2d43ce..360543b219ab 100644 --- a/src/relay/qnn/op/concatenate.cc +++ b/src/relay/qnn/op/concatenate.cc @@ -22,7 +22,7 @@ * \brief QNN concatenate operator. It concatenates quantized input tensors along a given axis. 
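// The relocated constant helpers used above keep their old semantics; a
// standalone sketch (the value is illustrative):
PrimExpr one = tir::make_const(DataType::Int(32), 1);
CHECK(tir::is_const_int(one, 1));        // matches only the exact constant 1
if (const int64_t* v = tir::as_const_int(one)) {
  CHECK_EQ(*v, 1);                       // as_const_int returns nullptr for non-constants
}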
*/ -#include +#include #include #include #include diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc index 2335c598fef9..5ebd9b951855 100644 --- a/src/relay/qnn/op/convolution.cc +++ b/src/relay/qnn/op/convolution.cc @@ -21,7 +21,7 @@ * \file src/relay/qnn/op/convolution.cc * \brief Property def of qnn convolution operator. */ -#include +#include #include #include #include @@ -69,7 +69,7 @@ bool QnnConv2DRel(const Array& types, int num_inputs, const Attrs& attrs, } bool is_depthwise(const Conv2DAttrs* param) { - return param->channels.defined() && tvm::ir::Equal(param->channels, param->groups) && + return param->channels.defined() && tvm::tir::Equal(param->channels, param->groups) && param->groups != 1; } diff --git a/src/relay/qnn/util.h b/src/relay/qnn/util.h index 6c99ae13a85c..2d4bcb4fd6ab 100644 --- a/src/relay/qnn/util.h +++ b/src/relay/qnn/util.h @@ -25,8 +25,8 @@ #ifndef TVM_RELAY_QNN_UTIL_H_ #define TVM_RELAY_QNN_UTIL_H_ -#include -#include +#include +#include #include #include #include @@ -49,7 +49,7 @@ static inline const int32_t GetQmin(const DataType& dtype) { CHECK_LE(dtype.bits(), 32) << "QNN ops support int32 or lower precision"; if (dtype.is_int() || dtype.is_uint()) { - auto* min_value = as_const_int(tvm::min_value(dtype)); + auto* min_value = tir::as_const_int(tvm::min_value(dtype)); CHECK(min_value != nullptr); return static_cast(min_value[0]); } else { @@ -62,7 +62,7 @@ static inline const int32_t GetQmax(const DataType& dtype) { CHECK_LE(dtype.bits(), 32) << "QNN ops support int32 or lower precision"; if (dtype.is_int() || dtype.is_uint()) { - auto* max_value = as_const_int(tvm::max_value(dtype)); + auto* max_value = tir::as_const_int(tvm::max_value(dtype)); CHECK(max_value != nullptr); return static_cast(max_value[0]); } else { @@ -88,7 +88,7 @@ static inline Expr Requantize(const Expr& data, const Array& input_sh } static inline int64_t get_const_int(const tvm::PrimExpr& x) { - auto* value_ptr = as_const_int(x); + auto* value_ptr = tir::as_const_int(x); CHECK(value_ptr) << "Expr is not a constant int"; return value_ptr[0]; } diff --git a/src/target/target.cc b/src/target/target.cc index 014d3f9ff09a..53d07d3cc3bd 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -26,7 +26,7 @@ #include #include -#include +#include #include #include @@ -62,39 +62,39 @@ Target CreateTarget(const std::string& target_name, std::string device_flag = "-device="; std::string keys_flag = "-keys="; for (auto& item : options) { - t->options_array.push_back(ir::StringImmNode::make(item)); + t->options_array.push_back(tir::StringImmNode::make(item)); if (item.find(libs_flag) == 0) { std::stringstream ss(item.substr(libs_flag.length())); std::string lib_item; while (std::getline(ss, lib_item, ',')) { - t->libs_array.push_back(ir::StringImmNode::make(lib_item)); + t->libs_array.push_back(tir::StringImmNode::make(lib_item)); } } else if (item.find(device_flag) == 0) { t->device_name = item.substr(device_flag.length()); - t->keys_array.push_back(ir::StringImmNode::make(t->device_name)); + t->keys_array.push_back(tir::StringImmNode::make(t->device_name)); } else if (item.find(keys_flag) == 0) { std::stringstream ss(item.substr(keys_flag.length())); std::string key_item; while (std::getline(ss, key_item, ',')) { - t->keys_array.push_back(ir::StringImmNode::make(key_item)); + t->keys_array.push_back(tir::StringImmNode::make(key_item)); } } } if (t->device_name.length() > 0) { - t->keys_array.push_back(ir::StringImmNode::make(t->device_name)); + 
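// For illustration, reading keys back out of the Target reverses the
// StringImmNode::make calls above (t as in this function):
for (auto& expr : t->keys_array) {
  const auto* key = expr.as<tir::StringImmNode>();
  CHECK(key != nullptr);  // every entry was built via tir::StringImmNode::make
  LOG(INFO) << key->value;
}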
t->keys_array.push_back(tir::StringImmNode::make(t->device_name)); } t->device_type = kDLCPU; t->thread_warp_size = 1; if (target_name == "c" && t->device_name == "micro_dev") { t->device_type = kDLMicroDev; } else if (target_name == "c" || target_name == "llvm") { - t->keys_array.push_back(ir::StringImmNode::make("cpu")); + t->keys_array.push_back(tir::StringImmNode::make("cpu")); } else if (target_name == "cuda" || target_name == "nvptx") { t->device_type = kDLGPU; - t->keys_array.push_back(ir::StringImmNode::make("cuda")); - t->keys_array.push_back(ir::StringImmNode::make("gpu")); + t->keys_array.push_back(tir::StringImmNode::make("cuda")); + t->keys_array.push_back(tir::StringImmNode::make("gpu")); t->max_num_threads = 1024; t->thread_warp_size = 32; } else if (target_name == "rocm" || target_name == "opencl") { @@ -104,8 +104,8 @@ Target CreateTarget(const std::string& target_name, } else { t->device_type = kDLROCM; } - t->keys_array.push_back(ir::StringImmNode::make(target_name)); - t->keys_array.push_back(ir::StringImmNode::make("gpu")); + t->keys_array.push_back(tir::StringImmNode::make(target_name)); + t->keys_array.push_back(tir::StringImmNode::make("gpu")); t->max_num_threads = 256; if (t->device_name == "intel_graphics") { t->thread_warp_size = 16; @@ -116,20 +116,20 @@ Target CreateTarget(const std::string& target_name, } else { t->device_type = kDLVulkan; } - t->keys_array.push_back(ir::StringImmNode::make(target_name)); - t->keys_array.push_back(ir::StringImmNode::make("gpu")); + t->keys_array.push_back(tir::StringImmNode::make(target_name)); + t->keys_array.push_back(tir::StringImmNode::make("gpu")); t->max_num_threads = 256; } else if (target_name == "sdaccel") { t->device_type = kDLOpenCL; - t->keys_array.push_back(ir::StringImmNode::make("sdaccel")); - t->keys_array.push_back(ir::StringImmNode::make("hls")); + t->keys_array.push_back(tir::StringImmNode::make("sdaccel")); + t->keys_array.push_back(tir::StringImmNode::make("hls")); } else if (target_name == "aocl" || target_name == "aocl_sw_emu") { t->device_type = kDLAOCL; - t->keys_array.push_back(ir::StringImmNode::make("aocl")); - t->keys_array.push_back(ir::StringImmNode::make("hls")); + t->keys_array.push_back(tir::StringImmNode::make("aocl")); + t->keys_array.push_back(tir::StringImmNode::make("hls")); } else if (target_name == "opengl") { t->device_type = kOpenGL; - t->keys_array.push_back(ir::StringImmNode::make("opengl")); + t->keys_array.push_back(tir::StringImmNode::make("opengl")); } else if (target_name == "stackvm") { t->device_type = kDLCPU; } else if (target_name == "ext_dev") { @@ -165,7 +165,7 @@ TVM_REGISTER_GLOBAL("_TargetFromString") std::vector TargetNode::keys() const { std::vector result; for (auto& expr : keys_array) { - result.push_back(expr.as()->value); + result.push_back(expr.as()->value); } return result; } @@ -173,7 +173,7 @@ std::vector TargetNode::keys() const { std::vector TargetNode::options() const { std::vector result; for (auto& expr : options_array) { - result.push_back(expr.as()->value); + result.push_back(expr.as()->value); } return result; } @@ -181,7 +181,7 @@ std::vector TargetNode::options() const { std::unordered_set TargetNode::libs() const { std::unordered_set result; for (auto& expr : libs_array) { - result.insert(expr.as()->value); + result.insert(expr.as()->value); } return result; } diff --git a/src/lang/buffer.cc b/src/tir/ir/buffer.cc similarity index 92% rename from src/lang/buffer.cc rename to src/tir/ir/buffer.cc index 8c7108ae6819..c2fc581a3904 100644 --- 
a/src/lang/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -20,23 +20,23 @@ /*! * \file buffer.cc */ -#include +#include #include -#include -#include +#include +#include #include #include -#include "../arith/compute_expr.h" +#include "../../arith/compute_expr.h" namespace tvm { - +namespace tir { // TODO(tqchen): change to floormod/div -using IndexMod = ir::FloorModNode; -using IndexDiv = ir::FloorDivNode; +using IndexMod = tir::FloorModNode; +using IndexDiv = tir::FloorDivNode; Array SimplifyArray(Array array) { for (size_t i = 0; i < array.size(); ++i) { - array.Set(i, ir::Simplify(array[i])); + array.Set(i, tir::Simplify(array[i])); } return array; } @@ -58,7 +58,7 @@ Buffer decl_buffer(Array shape, // Split the given expression w.r.t the add operator inline std::vector ExprSplitAddition(const PrimExpr &expr) { - using namespace ir; + using namespace tir; std::vector ret; std::stack split_buffer; split_buffer.push(&expr); @@ -87,7 +87,7 @@ inline std::vector ExprSplitAddition(const PrimExpr &expr) { inline std::pair MergeMulModInner(const PrimExpr &mult_expr, const PrimExpr &mod_l_expr, const PrimExpr &mod_r_expr) { - using namespace ir; + using namespace tir; const MulNode* mult_ptr = mult_expr.as(); if (!mult_ptr) return std::make_pair(false, PrimExpr()); PrimExpr mult_outer = mult_ptr->b; @@ -155,7 +155,7 @@ inline void MergeMulModInsertElements(const std::vector& eles, PrimExpr* no_opt_sum, bool* has_mult, bool* has_mod) { - using namespace ir; + using namespace tir; *has_mult = false; *has_mod = false; for (const PrimExpr* ele : eles) { @@ -181,7 +181,7 @@ inline void MergeMulModInsertElements(const std::vector& eles, // Return: a pair with (false, Expr()) if cannot be optimized. // a pair with (true, optimized_expr) if can be optimized inline PrimExpr MergeMulMod(const PrimExpr &base) { - using namespace ir; + using namespace tir; // 1. Prepare the lists. // We store two lists, a list that contain all the elements that match Mul and // a list that contain all the elements that match Mod. 
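// The identity MergeMulMod exploits, as an illustrative sketch (names here
// are not from this file):
//   floordiv(x, k) * k + floormod(x, k)  ==>  x
Var i("i", DataType::Int(32));
PrimExpr k = make_const(DataType::Int(32), 16);
// BufferOffset runs MergeMulMod over index sums like this one, which
// collapses back to plain i:
PrimExpr flattened = floordiv(i, k) * k + floormod(i, k);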
@@ -285,7 +285,7 @@ inline PrimExpr BufferOffset(const BufferNode* n, Array index, DataTyp offset = offset * make_const(offset.dtype(), dtype.lanes()); } if (dtype.lanes() != 1) { - return ir::RampNode::make(offset, make_const(offset.dtype(), 1), dtype.lanes()); + return tir::RampNode::make(offset, make_const(offset.dtype(), 1), dtype.lanes()); } else { return offset; } @@ -299,13 +299,13 @@ PrimExpr Buffer::vload(Array begin, DataType dtype) const { << "Cannot load " << dtype << " from buffer of " << n->dtype; if (dtype == DataType::Bool()) { - return ir::CastNode::make( + return tir::CastNode::make( DataType::Bool(), - ir::LoadNode::make( + tir::LoadNode::make( DataType::Int(8), n->data, BufferOffset(n, begin, DataType::Int(8)), const_true())); } else { - return ir::LoadNode::make( + return tir::LoadNode::make( dtype, n->data, BufferOffset(n, begin, dtype), const_true(dtype.lanes())); } @@ -320,12 +320,12 @@ Stmt Buffer::vstore(Array begin, PrimExpr value) const { << "Cannot load " << dtype << " from buffer of " << n->dtype; if (value.dtype() == DataType::Bool()) { - return ir::StoreNode::make(n->data, - ir::CastNode::make(DataType::Int(8), value), + return tir::StoreNode::make(n->data, + tir::CastNode::make(DataType::Int(8), value), BufferOffset(n, begin, DataType::Int(8)), const_true()); } else { - return ir::StoreNode::make(n->data, value, BufferOffset(n, begin, dtype), + return tir::StoreNode::make(n->data, value, BufferOffset(n, begin, dtype), const_true(dtype.lanes())); } } @@ -349,7 +349,7 @@ Buffer Buffer::MakeStrideView() const { Buffer Buffer::MakeSlice(Array begins, Array extents) const { const BufferNode* n = operator->(); begins = SimplifyArray(begins); - PrimExpr elem_offset = ir::Simplify(ElemOffset(n, begins)); + PrimExpr elem_offset = tir::Simplify(ElemOffset(n, begins)); Array strides = n->strides; if (strides.size() == 0) { bool can_relax = true; @@ -358,7 +358,7 @@ Buffer Buffer::MakeSlice(Array begins, Array extents) const for (size_t i = 0; i < extents.size(); ++i) { if (!can_relax) { if (!is_zero(begins[i]) || - !is_zero(ir::Simplify(extents[i] - n->shape[i]))) { + !is_zero(tir::Simplify(extents[i] - n->shape[i]))) { need_stride = true; } } @@ -394,22 +394,22 @@ PrimExpr Buffer::access_ptr(int access_mask, int highest_dim = 0; extent = self->strides[highest_dim] * self->shape[highest_dim] - offset; } else { - extent = arith::ComputeReduce(self->shape, PrimExpr()) - offset; + extent = arith::ComputeReduce(self->shape, PrimExpr()) - offset; } PrimExpr elem_offset = self->elem_offset + offset; if (content_lanes > 1) { - e_dtype = ir::TypeAnnotation(self->dtype.with_lanes(content_lanes)); + e_dtype = tir::TypeAnnotation(self->dtype.with_lanes(content_lanes)); extent = extent / make_const(self->elem_offset.dtype(), content_lanes); elem_offset = self->elem_offset / make_const(self->elem_offset.dtype(), content_lanes); } else { - e_dtype = ir::TypeAnnotation(self->dtype); + e_dtype = tir::TypeAnnotation(self->dtype); } Array acc_args{ e_dtype, self->data, elem_offset, extent, make_const(DataType::Int(32), access_mask)}; - return ir::CallNode::make( - ptr_type, ir::intrinsic::tvm_access_ptr, acc_args, ir::CallNode::Intrinsic); + return tir::CallNode::make( + ptr_type, tir::intrinsic::tvm_access_ptr, acc_args, tir::CallNode::Intrinsic); } Buffer BufferNode::make(Var data, @@ -447,7 +447,7 @@ Buffer BufferNode::make(Var data, n->buffer_type = buffer_type; if (n->buffer_type == kAutoBroadcast && n->shape.size() > 0 && n->strides.empty()) { for (size_t i = 0; i < 
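// A usage sketch for the relocated load path (buffer name and shape are
// illustrative):
Buffer buf = decl_buffer({16}, DataType::Float(32), "buf");
// vload emits a tir::LoadNode (plus a tir::RampNode index when lanes > 1):
PrimExpr elem = buf.vload({make_const(DataType::Int(32), 0)}, DataType::Float(32));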
n->shape.size(); ++i) { - n->strides.push_back(tvm::var("stride")); + n->strides.push_back(Var("stride")); } } return Buffer(n); @@ -460,5 +460,5 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) }); TVM_REGISTER_NODE_TYPE(BufferNode); - +} // namespace tir } // namespace tvm diff --git a/src/lang/data_layout.cc b/src/tir/ir/data_layout.cc similarity index 94% rename from src/lang/data_layout.cc rename to src/tir/ir/data_layout.cc index ba5e4adeb66f..59fa2af41631 100644 --- a/src/lang/data_layout.cc +++ b/src/tir/ir/data_layout.cc @@ -21,11 +21,15 @@ * \file src/lang/data_layout.cc * \brief Data Layout expression. */ -#include -#include +#include +#include #include namespace tvm { +namespace tir { +using tir::Var; +using tir::IterVar; +using tir::IterVarNode; TVM_REGISTER_NODE_TYPE(LayoutNode); TVM_REGISTER_NODE_TYPE(BijectiveLayoutNode); @@ -104,13 +108,13 @@ Layout::Layout(const std::string& name) { // NOLINT(*) std::string shape_name("_shape"); shape_name.insert(0, 1, c); IterVar axis = IterVarNode::make(Range(PrimExpr(0), Var(shape_name)), - Var(std::string(1, c)), kDataPar); + Var(std::string(1, c)), tir::kDataPar); node->axes.push_back(axis); } else if (c >= 'a' && c <= 'z') { CHECK_GT(factor, 0) << "Invalid layout " << name << ": invalid factor size " << factor << " for dimension " << c; IterVar axis = IterVarNode::make(Range(PrimExpr(0), PrimExpr(factor)), - Var(std::string(1, c)), kDataPar); + Var(std::string(1, c)), tir::kDataPar); node->axes.push_back(axis); factor = 0; } else if (c >= '0' && c <= '9') { @@ -172,7 +176,7 @@ Layout Layout::Split(const LayoutAxis &axis, size_t target_pos, int32_t factor) for (size_t i = 0; i <= this->ndim(); ++i) { if (i == target_pos) { new_layout.push_back(IterVarNode::make(Range(PrimExpr(0), PrimExpr(factor)), - Var(axis.ToSubordinate().name()), kDataPar)); + Var(axis.ToSubordinate().name()), tir::kDataPar)); } if (i == this->ndim()) break; new_layout.push_back(axes[i]); @@ -228,7 +232,7 @@ inline bool GetStoreRule(Array* rule, } } } - if (is_zero(store)) { + if (tir::is_zero(store)) { // Not convertible return false; } @@ -251,12 +255,12 @@ inline Array TransformIndex(const Array& src_index, const Array& src_axis, const Array& transform_rule) { Array result; - std::unordered_map bind_map; + std::unordered_map bind_map; for (size_t i = 0; i < src_index.size(); ++i) { bind_map[src_axis[i]->var.get()] = src_index[i]; } for (PrimExpr rule : transform_rule) { - result.push_back(ir::Simplify(ir::Substitute(rule, bind_map))); + result.push_back(tir::Simplify(tir::Substitute(rule, bind_map))); } return result; } @@ -287,12 +291,12 @@ inline Array TransformShape(const Array& src_shape, // for major-axis, bind the corresponding size // for minor-axis, simply bind it as 0, so that we can reuse forward/backward_rule, // e.g., (C * 16 + c) / 32 - std::unordered_map bind_map; + std::unordered_map bind_map; std::unordered_set symbolic_var_set; for (size_t i = 0; i < src_shape.size(); ++i) { PrimExpr orig_shape = src_shape[i]; IterVar orig_axis = src_axis[i]; - if (orig_shape.as()) { + if (orig_shape.as()) { symbolic_var_set.insert(i); } if (!LayoutAxis::Get(orig_axis).IsPrimal()) { @@ -322,9 +326,9 @@ inline Array TransformShape(const Array& src_shape, result.push_back(axis->dom->extent); } else { if (symbolic_var_set.count(i)) { - result.push_back(ir::AnyNode::make()); + result.push_back(tir::AnyNode::make()); } else { - result.push_back(ir::Simplify(ir::Substitute(rule, bind_map))); + result.push_back(tir::Simplify(tir::Substitute(rule, bind_map))); } } } @@ 
-367,5 +371,5 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->stream << "BijectiveLayout(" << b->src_layout.name() << "->" << b->dst_layout.name() << ")"; }); - +} // namespace tir } // namespace tvm diff --git a/src/lang/ir.cc b/src/tir/ir/expr.cc similarity index 52% rename from src/lang/ir.cc rename to src/tir/ir/expr.cc index d51307537038..0cdbfdc71c97 100644 --- a/src/lang/ir.cc +++ b/src/tir/ir/expr.cc @@ -18,29 +18,64 @@ */ /*! - * \file ir.cc + * \file expr.cc */ - -#include -#include -#include +#include +#include +#include +#include #include +#include #include "../pass/ir_util.h" namespace tvm { -namespace ir { +namespace tir { -// constructors +Var::Var(std::string name_hint, DataType t) + : Var(make_object(t, name_hint)) {} -PrimExpr FloatImm(DataType t, double value) { - CHECK_EQ(t.lanes(), 1) - << "ValueError: FloatImm can only take scalar"; - ObjectPtr node = make_object(); - node->dtype = t; - node->value = value; - return PrimExpr(node); +VarNode::VarNode(DataType t, std::string name_hint) { + this->dtype = t; + this->name_hint = std::move(name_hint); +} + +SizeVar::SizeVar(std::string name_hint, DataType t) + : SizeVar(make_object(t, name_hint)) {} + +SizeVarNode::SizeVarNode(DataType t, std::string name_hint) + : VarNode(t, std::move(name_hint)) {} + +IterVar IterVarNode::make(Range dom, + Var var, + IterVarType t, + std::string thread_tag) { + ObjectPtr n = make_object(); + n->dom = dom; + n->var = var; + n->iter_type = t; + n->thread_tag = thread_tag; + return IterVar(n); } +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "iter_var("; + if (op->var->name_hint.length() != 0) { + p->stream << op->var->name_hint << ", "; + } + if (op->dom.defined()) { + p->stream << op->dom; + } + if (op->thread_tag.length() != 0) { + p->stream << ", " << op->thread_tag; + } + p->stream << ")"; + }); + + +TVM_REGISTER_NODE_TYPE(IterVarNode); + PrimExpr StringImmNode::make(std::string value) { ObjectPtr node = make_object(); node->dtype = DataType::Handle(); @@ -170,8 +205,8 @@ PrimExpr LetNode::make(Var var, PrimExpr value, PrimExpr body) { const char* CallNode::vectorizable_intrinsics[] = { "floor", "ceil", "sign", "trunc", "fabs", "round", "exp", "tanh", "sqrt", - "log", "sin", "cos", "pow", ir::CallNode::shift_left, ir::CallNode::shift_right, - ir::CallNode::likely, ir::CallNode::popcount + "log", "sin", "cos", "pow", tir::CallNode::shift_left, tir::CallNode::shift_right, + tir::CallNode::likely, tir::CallNode::popcount }; bool CallNode::is_vectorizable() const { @@ -304,245 +339,6 @@ PrimExpr AnyNode::make() { return PrimExpr(n); } -Stmt LetStmtNode::make(Var var, PrimExpr value, Stmt body) { - CHECK(value.defined()); - CHECK(body.defined()); - CHECK_EQ(value.dtype(), var.dtype()); - - ObjectPtr node = make_object(); - node->var = std::move(var); - node->value = std::move(value); - node->body = std::move(body); - return Stmt(node); -} - -Stmt AttrStmtNode::make(ObjectRef node, - std::string attr_key, - PrimExpr value, - Stmt body) { - auto n = make_object(); - n->node = node; - n->attr_key = std::move(attr_key); - n->value = std::move(value); - n->body = std::move(body); - return Stmt(n); -} - -Stmt AssertStmtNode::make(PrimExpr condition, PrimExpr message, Stmt body) { - CHECK(condition.defined()); - CHECK(message.dtype() == DataType::Int(32) || - message.as()) - << "TypeError: AssertStmt message must be an int or string:" - << message << "\n"; - - ObjectPtr node = 
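// The constructor-style creation these definitions enable, as a minimal
// sketch (names are illustrative):
Var n("n", DataType::Int(32));        // plain variable
SizeVar m("m", DataType::Int(32));    // variable known to be non-negative
IterVar iv = IterVarNode::make(Range(PrimExpr(0), n), Var("i"), kDataPar);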
make_object(); - node->condition = std::move(condition); - node->message = std::move(message); - node->body = std::move(body); - return Stmt(node); -} - -Stmt ProducerConsumerNode::make(FunctionRef func, bool is_producer, Stmt body) { - CHECK(body.defined()); - - ObjectPtr node = make_object(); - node->func = std::move(func); - node->is_producer = is_producer; - node->body = std::move(body); - return Stmt(node); -} - -Stmt ForNode::make(Var loop_var, - PrimExpr min, - PrimExpr extent, - ForType for_type, - DeviceAPI device_api, - Stmt body) { - CHECK(min.defined()); - CHECK(extent.defined()); - CHECK(min.dtype().is_scalar()); - CHECK(extent.dtype().is_scalar()); - CHECK(loop_var.dtype().is_scalar()); - CHECK(body.defined()); - - ObjectPtr node = make_object(); - node->loop_var = std::move(loop_var); - node->min = std::move(min); - node->extent = std::move(extent); - node->for_type = for_type; - node->device_api = device_api; - node->body = std::move(body); - return Stmt(node); -} - -Stmt StoreNode::make(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate) { - CHECK(value.defined()); - CHECK(index.defined()); - CHECK(predicate.defined()); - CHECK_EQ(value.dtype().lanes(), index.dtype().lanes()); - CHECK_EQ(value.dtype().lanes(), predicate.dtype().lanes()); - - ObjectPtr node = make_object(); - node->buffer_var = std::move(buffer_var); - node->value = std::move(value); - node->index = std::move(index); - node->predicate = std::move(predicate); - return Stmt(node); -} - -Stmt ProvideNode::make(FunctionRef func, int value_index, PrimExpr value, Array args) { - CHECK(value_index >=0 && value_index < func->num_outputs()) - << "value index output function return value bound"; - CHECK(value.defined()) << "Provide of undefined value\n"; - - for (size_t i = 0; i < args.size(); ++i) { - CHECK(args[i].defined()) << "Provide to undefined location\n"; - } - - ObjectPtr node = make_object(); - node->func = std::move(func); - node->value_index = value_index; - node->value = std::move(value); - node->args = std::move(args); - return Stmt(node); -} - -Stmt AllocateNode::make(Var buffer_var, - DataType dtype, - Array extents, - PrimExpr condition, - Stmt body, - PrimExpr new_expr, - std::string free_function) { - for (size_t i = 0; i < extents.size(); ++i) { - CHECK(extents[i].defined()); - CHECK(extents[i].dtype().is_scalar()); - } - CHECK(body.defined()); - CHECK(condition.defined()); - CHECK(condition.dtype().is_bool()); - - ObjectPtr node = make_object(); - node->buffer_var = std::move(buffer_var); - node->dtype = dtype; - node->extents = std::move(extents); - node->condition = std::move(condition); - node->body = std::move(body); - node->new_expr = std::move(new_expr); - node->free_function = std::move(free_function); - return Stmt(node); -} - -int32_t AllocateNode::constant_allocation_size(const Array& extents) { - int64_t result = 1; - for (size_t i = 0; i < extents.size(); ++i) { - if (const IntImmNode *int_size = extents[i].as()) { - result *= int_size->value; - if (result > std::numeric_limits::max()) { - return 0; - } - } else { - return 0; - } - } - return static_cast(result); -} - -Stmt FreeNode::make(Var buffer_var) { - ObjectPtr node = make_object(); - node->buffer_var = buffer_var; - return Stmt(node); -} - -Stmt RealizeNode::make(FunctionRef func, - int value_index, - DataType dtype, - Region bounds, - PrimExpr condition, - Stmt body) { - for (size_t i = 0; i < bounds.size(); ++i) { - CHECK(bounds[i]->min.defined()); - CHECK(bounds[i]->extent.defined()); - 
CHECK(bounds[i]->min.dtype().is_scalar()); - CHECK(bounds[i]->extent.dtype().is_scalar()); - } - CHECK(body.defined()); - CHECK(condition.defined()); - CHECK(condition.dtype().is_bool()); - - ObjectPtr node = make_object(); - node->func = std::move(func); - node->value_index = value_index; - node->dtype = dtype; - node->bounds = std::move(bounds); - node->condition = std::move(condition); - node->body = std::move(body); - return Stmt(node); -} - -Stmt PrefetchNode::make(FunctionRef func, int value_index, DataType dtype, Region bounds) { - for (size_t i = 0; i < bounds.size(); ++i) { - CHECK(bounds[i]->min.defined()); - CHECK(bounds[i]->extent.defined()); - CHECK(bounds[i]->min.dtype().is_scalar()); - CHECK(bounds[i]->extent.dtype().is_scalar()); - } - - ObjectPtr node = make_object(); - node->func = std::move(func); - node->value_index = value_index; - node->dtype = dtype; - node->bounds = std::move(bounds); - return Stmt(node); -} - -SeqStmt::SeqStmt(Array seq) { - auto node = make_object(); - node->seq = std::move(seq); - data_ = std::move(node); -} - -Stmt IfThenElseNode::make(PrimExpr condition, Stmt then_case, Stmt else_case) { - CHECK(condition.defined()); - CHECK(then_case.defined()); - // else_case may be null. - - ObjectPtr node = make_object(); - node->condition = std::move(condition); - node->then_case = std::move(then_case); - node->else_case = std::move(else_case); - return Stmt(node); -} - -Stmt EvaluateNode::make(PrimExpr value) { - CHECK(value.defined()); - - ObjectPtr node = make_object(); - node->value = std::move(value); - return Stmt(node); -} - -// Printers - -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); - auto& stream = p->stream; - switch (op->dtype.bits()) { - case 64: - stream << op->value; - break; - case 32: - stream << op->value << 'f'; - break; - case 16: - stream << op->value << 'h'; - break; - default: - LOG(FATAL) << "Unknown float type bits=" << op->dtype.bits(); - } - }); - TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) .set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); @@ -806,311 +602,10 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) }); TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - p->stream << "let " << op->var << " = "; - p->Print(op->value); - p->stream << '\n'; - p->Print(op->body); - }); - -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - p->stream << "// attr ["; - p->Print(op->node); - p->stream << "] " - << op->attr_key << " = "; - p->Print(op->value); - p->stream << '\n'; - p->Print(op->body); - }); - -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - p->stream << "assert("; - p->Print(op->condition); - p->stream << ", "; - p->Print(op->message); - p->stream << ")\n"; - p->Print(op->body); - }); - -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); - if (op->is_producer) { - p->PrintIndent(); - p->stream << "produce " << op->func->func_name() << " {\n"; - p->indent += 2; - p->Print(op->body); - p->indent -= 2; - p->PrintIndent(); - p->stream << "}\n"; - } else { - p->Print(op->body); - } - 
}); - -std::ostream &operator<<(std::ostream& out, ForType type) { // NOLINT(*) - switch (type) { - case ForType::Serial: - out << "for"; - break; - case ForType::Parallel: - out << "parallel"; - break; - case ForType::Unrolled: - out << "unrolled"; - break; - case ForType::Vectorized: - out << "vectorized"; - break; - } - return out; -} - -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - p->stream << op->for_type << " (" << op->loop_var << ", "; - p->Print(op->min); - p->stream << ", "; - p->Print(op->extent); - p->stream << ") {\n"; - - p->indent += 2; - p->Print(op->body); - p->indent -= 2; - - p->PrintIndent(); - p->stream << "}\n"; -}); - -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - p->stream << op->buffer_var << "["; - p->Print(op->index); - p->stream << "] = "; - p->Print(op->value); - if (!is_one(op->predicate)) { - p->stream << " if "; - p->Print(op->predicate); - } - p->stream << '\n'; - }); - -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - p->stream << op->func->func_name() << "("; - for (size_t i = 0; i < op->args.size(); ++i) { - p->Print(op->args[i]); - if (i < op->args.size() - 1) p->stream << ", "; - } - p->stream << ")"; - if (op->func->num_outputs() != 1) { - p->stream << ".value[" << op->value_index << "]"; - } - p->stream << " ="; - p->Print(op->value); - p->stream << '\n'; - }); - -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - p->stream << "allocate " << op->buffer_var << "[" << op->dtype; - for (size_t i = 0; i < op->extents.size(); ++i) { - p->stream << " * "; - p->Print(op->extents[i]); - } - p->stream << "]"; - if (!is_one(op->condition)) { - p->stream << " if "; - p->Print(op->condition); - } - p->stream << "\n"; - p->Print(op->body); - }); - -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - p->stream << "free " << op->buffer_var; - p->stream << '\n'; - }); - -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - p->stream << "realize " << op->func->func_name() << "("; - for (size_t i = 0; i < op->bounds.size(); ++i) { - p->stream << "["; - p->Print(op->bounds[i]->min); - p->stream << ", "; - p->Print(op->bounds[i]->extent); - p->stream << "]"; - if (i < op->bounds.size() - 1) p->stream << ", "; - } - p->stream << ")"; - if (op->func->num_outputs() != 1) { - p->stream << ".value[" << op->value_index << "]"; - } - if (!is_one(op->condition)) { - p->stream << " if "; - p->Print(op->condition); - } - p->stream << " {\n"; - - p->indent += 2; - p->Print(op->body); - p->indent -= 2; - - p->PrintIndent(); - p->stream << "}\n"; - }); - -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - p->stream << "prefetch " << op->func->func_name() << "("; - for (size_t i = 0; i < op->bounds.size(); ++i) { - p->stream << "["; - p->Print(op->bounds[i]->min); - p->stream << ", "; - 
p->Print(op->bounds[i]->extent); - p->stream << "]"; - if (i < op->bounds.size() - 1) p->stream << ", "; - } - p->stream << ")"; - if (op->func->num_outputs() != 1) { - p->stream << ".value[" << op->value_index << "]"; - } - }); - -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); - for (Stmt stmt : op->seq) { - p->Print(stmt); - } - }); - -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - while (true) { - p->stream << "if (" << op->condition << ") {\n"; - p->indent += 2; - p->Print(op->then_case); - p->indent -= 2; - - if (!op->else_case.defined()) { - break; - } - - if (const IfThenElseNode *nested_if = op->else_case.as()) { - p->PrintIndent(); - p->stream << "} else "; - op = nested_if; - } else { - p->PrintIndent(); - p->stream << "} else {\n"; - p->indent += 2; - p->Print(op->else_case); - p->indent -= 2; - break; - } - } - p->PrintIndent(); - p->stream << "}\n"; -}); - -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - p->Print(op->value); - p->stream << "\n"; - }); - -template -void PrintList(const Array &exprs, NodePrinter* p) { - for (size_t i = 0; i < exprs.size(); ++i) { - p->Print(exprs[i]); - if (i < exprs.size() - 1) { - p->stream << ", "; - } - } -} - -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "shuffle("; - PrintList(op->vectors, p); - p->stream << ", "; - PrintList(op->indices, p); - p->stream << ")"; - }); - -// Container printer -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '['; - for (size_t i = 0 ; i < op->data.size(); ++i) { - if (i != 0) { - p->stream << ", "; - } - p->Print(op->data[i]); - } - p->stream << ']'; +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + p->stream << "?"; }); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '{'; - for (auto it = op->data.begin(); it != op->data.end(); ++it) { - if (it != op->data.begin()) { - p->stream << ", "; - } - p->Print(it->first); - p->stream << ": "; - p->Print(it->second); - } - p->stream << '}'; - }); - -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '{'; - for (auto it = op->data.begin(); it != op->data.end(); ++it) { - if (it != op->data.begin()) { - p->stream << ", "; - } - p->stream << '\"' << it->first << "\": "; - p->Print(it->second); - } - p->stream << '}'; - }); - TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) .set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); @@ -1133,17 +628,6 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) << ")"; }); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - p->stream << "?"; -}); - -TVM_REGISTER_NODE_TYPE(CommReducerNode); -TVM_REGISTER_NODE_TYPE(ReduceNode); -TVM_REGISTER_NODE_TYPE(AnyNode); -TVM_REGISTER_NODE_TYPE(AttrStmtNode); -TVM_REGISTER_NODE_TYPE(FloatImmNode); -TVM_REGISTER_NODE_TYPE(IntImmNode); 
TVM_REGISTER_NODE_TYPE(StringImmNode); TVM_REGISTER_NODE_TYPE(CastNode); TVM_REGISTER_NODE_TYPE(VarNode); @@ -1171,21 +655,9 @@ TVM_REGISTER_NODE_TYPE(LoadNode); TVM_REGISTER_NODE_TYPE(RampNode); TVM_REGISTER_NODE_TYPE(BroadcastNode); TVM_REGISTER_NODE_TYPE(ShuffleNode); -TVM_REGISTER_NODE_TYPE(PrefetchNode); -TVM_REGISTER_NODE_TYPE(CallNode); -TVM_REGISTER_NODE_TYPE(LetNode); -TVM_REGISTER_NODE_TYPE(LetStmtNode); -TVM_REGISTER_NODE_TYPE(AssertStmtNode); -TVM_REGISTER_NODE_TYPE(ProducerConsumerNode); -TVM_REGISTER_NODE_TYPE(ForNode); -TVM_REGISTER_NODE_TYPE(StoreNode); -TVM_REGISTER_NODE_TYPE(ProvideNode); -TVM_REGISTER_NODE_TYPE(AllocateNode); -TVM_REGISTER_NODE_TYPE(FreeNode); -TVM_REGISTER_NODE_TYPE(RealizeNode); -TVM_REGISTER_NODE_TYPE(SeqStmtNode); -TVM_REGISTER_NODE_TYPE(IfThenElseNode); -TVM_REGISTER_NODE_TYPE(EvaluateNode); +TVM_REGISTER_NODE_TYPE(CommReducerNode); +TVM_REGISTER_NODE_TYPE(ReduceNode); +TVM_REGISTER_NODE_TYPE(AnyNode); -} // namespace ir +} // namespace tir } // namespace tvm diff --git a/src/tir/ir/expr_functor.cc b/src/tir/ir/expr_functor.cc new file mode 100644 index 000000000000..f8371f3765a4 --- /dev/null +++ b/src/tir/ir/expr_functor.cc @@ -0,0 +1,290 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! 
+ * \file expr_functor.cc + */ +#include <tvm/tir/expr_functor.h> +#include "functor_common.h" + +namespace tvm { +namespace tir { + +void ExprVisitor::VisitExpr_(const VarNode* op) {} + +void ExprVisitor::VisitExpr_(const SizeVarNode* op) { + this->VisitExpr_(static_cast<const VarNode*>(op)); +} + +void ExprVisitor::VisitExpr_(const LoadNode* op) { + this->VisitExpr(op->index); + this->VisitExpr(op->predicate); +} + +void ExprVisitor::VisitExpr_(const LetNode* op) { + this->VisitExpr(op->value); + this->VisitExpr(op->body); +} + +void ExprVisitor::VisitExpr_(const CallNode* op) { + VisitArray(op->args, [this](const PrimExpr& e) { this->VisitExpr(e); }); +} + +#define DEFINE_BINOP_VISIT_(OP) \ + void ExprVisitor::VisitExpr_(const OP* op) { \ + this->VisitExpr(op->a); \ + this->VisitExpr(op->b); \ + } + +DEFINE_BINOP_VISIT_(AddNode); +DEFINE_BINOP_VISIT_(SubNode); +DEFINE_BINOP_VISIT_(MulNode); +DEFINE_BINOP_VISIT_(DivNode); +DEFINE_BINOP_VISIT_(ModNode); +DEFINE_BINOP_VISIT_(FloorDivNode); +DEFINE_BINOP_VISIT_(FloorModNode); +DEFINE_BINOP_VISIT_(MinNode); +DEFINE_BINOP_VISIT_(MaxNode); +DEFINE_BINOP_VISIT_(EQNode); +DEFINE_BINOP_VISIT_(NENode); +DEFINE_BINOP_VISIT_(LTNode); +DEFINE_BINOP_VISIT_(LENode); +DEFINE_BINOP_VISIT_(GTNode); +DEFINE_BINOP_VISIT_(GENode); +DEFINE_BINOP_VISIT_(AndNode); +DEFINE_BINOP_VISIT_(OrNode); + +void ExprVisitor::VisitExpr_(const IntImmNode* op) {} +void ExprVisitor::VisitExpr_(const FloatImmNode* op) {} +void ExprVisitor::VisitExpr_(const StringImmNode* op) {} + +void ExprVisitor::VisitExpr_(const ReduceNode* op) { + VisitArray(op->axis, [this](const IterVar& r) { + this->VisitExpr(r->dom->min); + this->VisitExpr(r->dom->extent); + }); + VisitArray(op->source, [this](const PrimExpr& e) { this->VisitExpr(e); }); + this->VisitExpr(op->condition); +} + +void ExprVisitor::VisitExpr_(const CastNode* op) { + this->VisitExpr(op->value); +} + +void ExprVisitor::VisitExpr_(const NotNode* op) { + this->VisitExpr(op->a); +} + +void ExprVisitor::VisitExpr_(const SelectNode* op) { + this->VisitExpr(op->condition); + this->VisitExpr(op->true_value); + this->VisitExpr(op->false_value); +} + +void ExprVisitor::VisitExpr_(const RampNode* op) { + this->VisitExpr(op->base); + this->VisitExpr(op->stride); +} + +void ExprVisitor::VisitExpr_(const ShuffleNode* op) { + VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); }); + VisitArray(op->vectors, [this](const PrimExpr& e) { this->VisitExpr(e); }); +} + +void ExprVisitor::VisitExpr_(const BroadcastNode* op) { + this->VisitExpr(op->value); +} + +PrimExpr ExprMutator::VisitExpr_(const VarNode* op) { + return GetRef<PrimExpr>(op); +} + +PrimExpr ExprMutator::VisitExpr_(const SizeVarNode* op) { + return this->VisitExpr_(static_cast<const VarNode*>(op)); +} + +PrimExpr ExprMutator::VisitExpr_(const LoadNode* op) { + PrimExpr index = this->VisitExpr(op->index); + PrimExpr predicate = this->VisitExpr(op->predicate); + if (index.same_as(op->index) && predicate.same_as(op->predicate)) { + return GetRef<PrimExpr>(op); + } else { + return LoadNode::make(op->dtype, op->buffer_var, index, predicate); + } +} + +PrimExpr ExprMutator::VisitExpr_(const LetNode* op) { + PrimExpr value = this->VisitExpr(op->value); + PrimExpr body = this->VisitExpr(op->body); + if (value.same_as(op->value) && + body.same_as(op->body)) { + return GetRef<PrimExpr>(op); + } else { + return LetNode::make(op->var, value, body); + } +} + +PrimExpr ExprMutator::VisitExpr_(const CallNode* op) { + auto fmutate = [this](const PrimExpr& e) { return this->VisitExpr(e); }; + Array<PrimExpr> args = MutateArray(op->args, fmutate); + + if (args.same_as(op->args)) { + return GetRef<PrimExpr>(op); + } else { + return CallNode::make(op->dtype, + op->name, + args, + op->call_type, + op->func, + op->value_index); + } +} + +#define DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(OP) \ + PrimExpr ExprMutator::VisitExpr_(const OP *op) { \ + return GetRef<PrimExpr>(op); \ + } + +DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(IntImmNode) +DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(FloatImmNode) +DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(StringImmNode) + +#define DEFINE_BIOP_EXPR_MUTATE_(OP) \ + PrimExpr ExprMutator::VisitExpr_(const OP* op) { \ + PrimExpr a = this->VisitExpr(op->a); \ + PrimExpr b = this->VisitExpr(op->b); \ + if (a.same_as(op->a) && \ + b.same_as(op->b)) { \ + return GetRef<PrimExpr>(op); \ + } else { \ + return OP::make(a, b); \ + } \ + } + +DEFINE_BIOP_EXPR_MUTATE_(AddNode); +DEFINE_BIOP_EXPR_MUTATE_(SubNode); +DEFINE_BIOP_EXPR_MUTATE_(MulNode); +DEFINE_BIOP_EXPR_MUTATE_(DivNode); +DEFINE_BIOP_EXPR_MUTATE_(ModNode); +DEFINE_BIOP_EXPR_MUTATE_(FloorDivNode); +DEFINE_BIOP_EXPR_MUTATE_(FloorModNode); +DEFINE_BIOP_EXPR_MUTATE_(MinNode); +DEFINE_BIOP_EXPR_MUTATE_(MaxNode); +DEFINE_BIOP_EXPR_MUTATE_(EQNode); +DEFINE_BIOP_EXPR_MUTATE_(NENode); +DEFINE_BIOP_EXPR_MUTATE_(LTNode); +DEFINE_BIOP_EXPR_MUTATE_(LENode); +DEFINE_BIOP_EXPR_MUTATE_(GTNode); +DEFINE_BIOP_EXPR_MUTATE_(GENode); +DEFINE_BIOP_EXPR_MUTATE_(AndNode); +DEFINE_BIOP_EXPR_MUTATE_(OrNode); + +PrimExpr ExprMutator::VisitExpr_(const ReduceNode* op) { + auto fitervar = [this](const IterVar& v) { + Range r = v->dom; + PrimExpr min = this->VisitExpr(r->min); + PrimExpr extent = this->VisitExpr(r->extent); + if (min.same_as(r->min) && + extent.same_as(r->extent)) { + return v; + } else { + return IterVarNode::make( + Range::make_by_min_extent(min, extent), + v->var, v->iter_type, v->thread_tag); + } + }; + Array<IterVar> axis = MutateArray(op->axis, fitervar); + + auto fexpr = [this](const PrimExpr& e) { return this->VisitExpr(e); }; + Array<PrimExpr> source = MutateArray(op->source, fexpr); + + PrimExpr condition = this->VisitExpr(op->condition); + + if (axis.same_as(op->axis) && + source.same_as(op->source) && + condition.same_as(op->condition)) { + return GetRef<PrimExpr>(op); + } else { + return ReduceNode::make( + op->combiner, source, axis, condition, op->value_index); + } +} + +PrimExpr ExprMutator::VisitExpr_(const CastNode* op) { + PrimExpr value = this->VisitExpr(op->value); + if (value.same_as(op->value)) { + return GetRef<PrimExpr>(op); + } else { + return CastNode::make(op->dtype, value); + } +} + +PrimExpr ExprMutator::VisitExpr_(const NotNode* op) { + PrimExpr a = this->VisitExpr(op->a); + if (a.same_as(op->a)) { + return GetRef<PrimExpr>(op); + } else { + return NotNode::make(a); + } +} + +PrimExpr ExprMutator::VisitExpr_(const SelectNode* op) { + PrimExpr condition = this->VisitExpr(op->condition); + PrimExpr true_value = this->VisitExpr(op->true_value); + PrimExpr false_value = this->VisitExpr(op->false_value); + if (condition.same_as(op->condition) && + true_value.same_as(op->true_value) && + false_value.same_as(op->false_value)) { + return GetRef<PrimExpr>(op); + } else { + return SelectNode::make(condition, true_value, false_value); + } +} + +PrimExpr ExprMutator::VisitExpr_(const RampNode* op) { + PrimExpr base = this->VisitExpr(op->base); + PrimExpr stride = this->VisitExpr(op->stride); + if (base.same_as(op->base) && + stride.same_as(op->stride)) { + return GetRef<PrimExpr>(op); + } else { + return RampNode::make(base, stride, op->lanes); + } +} + +PrimExpr ExprMutator::VisitExpr_(const BroadcastNode* op) { + PrimExpr value = this->VisitExpr(op->value); + if
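// A minimal sketch of a downstream mutator built on ExprMutator (hypothetical
// WidenIntImm, not part of this patch): the defaults above rebuild a node
// only when a child actually changed, so overriding one case is enough.
class WidenIntImm : public ExprMutator {
 public:
  PrimExpr VisitExpr_(const IntImmNode* op) final {
    if (op->dtype == DataType::Int(32)) {
      return IntImm(DataType::Int(64), op->value);  // widen 32-bit immediates
    }
    return GetRef<PrimExpr>(op);
  }
};
// Usage: PrimExpr widened = WidenIntImm()(e);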
(value.same_as(op->value)) { + return GetRef(op); + } else { + return BroadcastNode::make(value, op->lanes); + } +} + +PrimExpr ExprMutator::VisitExpr_(const ShuffleNode* op) { + auto fexpr = [this](const PrimExpr& e) { return this->VisitExpr(e); }; + auto vectors = MutateArray(op->vectors, fexpr); + if (vectors.same_as(op->vectors)) { + return GetRef(op); + } else { + return ShuffleNode::make(vectors, op->indices); + } +} + +} // namespace tir +} // namespace tvm diff --git a/src/tir/ir/functor_common.h b/src/tir/ir/functor_common.h new file mode 100644 index 000000000000..76a91ea42d42 --- /dev/null +++ b/src/tir/ir/functor_common.h @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tir/ir/functor_common.h + * \brief Common utils for implementing functors + */ +#ifndef TVM_TIR_IR_FUNCTOR_COMMON_H_ +#define TVM_TIR_IR_FUNCTOR_COMMON_H_ + +namespace tvm { +namespace tir { + +// Implementation of Visitors +template +inline void VisitArray(const Array& arr, F fvisit) { + for (size_t i = 0; i < arr.size(); i++) { + fvisit(arr[i]); + } +} + +// Implementation of mutators +template +inline Array MutateArray(const Array& arr, + F fmutate, + bool allow_copy_on_write = false) { + if (allow_copy_on_write) { + // if we allow copy on write, we can directly + // call the inplace mutate function. + const_cast&>(arr).MutateByApply(fmutate); + return arr; + } else { + Array copy = arr; + copy.MutateByApply(fmutate); + return copy; + } +} + +} // namespace tir +} // namespace tvm +#endif // TVM_TIR_IR_FUNCTOR_COMMON_H_ diff --git a/src/lang/lowered_func.cc b/src/tir/ir/lowered_func.cc similarity index 94% rename from src/lang/lowered_func.cc rename to src/tir/ir/lowered_func.cc index a6b6908d95f9..a2755343fc43 100644 --- a/src/lang/lowered_func.cc +++ b/src/tir/ir/lowered_func.cc @@ -20,10 +20,10 @@ /*! * \file lowered_func.cc */ -#include +#include namespace tvm { - +namespace tir { TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) .set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); @@ -31,5 +31,5 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) }); TVM_REGISTER_NODE_TYPE(LoweredFuncNode); - +} // namespace tir } // namespace tvm diff --git a/src/lang/expr_operator.cc b/src/tir/ir/op.cc similarity index 71% rename from src/lang/expr_operator.cc rename to src/tir/ir/op.cc index 966087c5025d..a264915e5fb5 100644 --- a/src/lang/expr_operator.cc +++ b/src/tir/ir/op.cc @@ -21,26 +21,28 @@ * \file expr_operator.cc */ -#include -#include +#include +#include #include // Centralized header for constant folders. 
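// Worth noting about MutateArray (a sketch, assuming an identity fmutate):
// the result aliases the input when nothing changes, which is what lets the
// mutators above return GetRef<PrimExpr>(op) without copying:
Array<PrimExpr> arr = {make_const(DataType::Int(32), 7)};
auto out = MutateArray(arr, [](const PrimExpr& e) { return e; });
CHECK(out.same_as(arr));  // no copy-on-write was triggered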
-#include "../arith/const_fold.h" +#include "../../arith/const_fold.h" namespace tvm { +using namespace tir; + // simple cast that only checks if type matches and cast inline PrimExpr SimpleCast(const DataType& t, PrimExpr value) { if (value.dtype() == t) return value; - return ir::CastNode::make(t, value); + return tir::CastNode::make(t, value); } PrimExpr LargeUIntImm(DataType t, int64_t low, int64_t high) { - return ir::CallNode::make( - t, ir::intrinsic::tvm_large_uint_imm, + return tir::CallNode::make( + t, tir::intrinsic::tvm_large_uint_imm, {make_const(DataType::UInt(32), low), make_const(DataType::UInt(32), high)}, - ir::CallNode::PureIntrinsic); + tir::CallNode::PureIntrinsic); } // The public function with a quick checking path. @@ -49,9 +51,9 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs) { // NOLINT(*) DataType ltype = lhs.dtype(); DataType rtype = rhs.dtype(); if (ltype.lanes() == 1 && rtype.lanes() != 1) { - lhs = ir::BroadcastNode::make(lhs, rtype.lanes()); + lhs = tir::BroadcastNode::make(lhs, rtype.lanes()); } else if (rtype.lanes() == 1 && ltype.lanes() != 1) { - rhs = ir::BroadcastNode::make(rhs, ltype.lanes()); + rhs = tir::BroadcastNode::make(rhs, ltype.lanes()); } else { CHECK(ltype.lanes() == rtype.lanes()) << "Cannot match type " << ltype << " vs " << rtype; @@ -88,7 +90,7 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs) { // NOLINT(*) // maximum and min limits PrimExpr max_value(const DataType& dtype) { - using namespace ir; + using namespace tir; CHECK_EQ(dtype.lanes(), 1); if (dtype.is_int()) { if (dtype.bits() == 64) { @@ -120,7 +122,7 @@ PrimExpr max_value(const DataType& dtype) { } PrimExpr min_value(const DataType& dtype) { - using namespace ir; + using namespace tir; CHECK_EQ(dtype.lanes(), 1); if (dtype.is_int()) { if (dtype.bits() == 64) { @@ -145,6 +147,7 @@ PrimExpr min_value(const DataType& dtype) { return PrimExpr(); } +namespace tir { template inline bool ConstPowerHelper(ValueType val, int *shift) { if (val <= 0) return false; @@ -160,15 +163,16 @@ inline bool ConstPowerHelper(ValueType val, int *shift) { } bool is_const_power_of_two_integer(const PrimExpr& x, int* shift) { - if (const auto* op = x.as()) { + if (const auto* op = x.as()) { return ConstPowerHelper(op->value, shift); } else { return false; } } +} // namespace tir PrimExpr cast(const DataType& t, PrimExpr value) { - using ir::FloatImmNode; + using tir::FloatImmNode; if (value.dtype() == t) return value; // const fold IntImm as they are used in index computations if (t.lanes() == 1) { @@ -177,7 +181,7 @@ PrimExpr cast(const DataType& t, PrimExpr value) { } else if (const FloatImmNode* op = value.as()) { return make_const(t, op->value); } - return ir::CastNode::make(t, value); + return tir::CastNode::make(t, value); } else { if (value.dtype().lanes() == 1) { // manually unroll cast @@ -188,34 +192,34 @@ PrimExpr cast(const DataType& t, PrimExpr value) { } else if (const FloatImmNode* op = value.as()) { value = make_const(vtype, op->value); } else { - value = ir::CastNode::make(vtype, value); + value = tir::CastNode::make(vtype, value); } } - return ir::BroadcastNode::make(value, t.lanes()); + return tir::BroadcastNode::make(value, t.lanes()); } else { CHECK(value.dtype().lanes() == t.lanes()); - return ir::CastNode::make(t, value); + return tir::CastNode::make(t, value); } } } PrimExpr reinterpret(const DataType& t, PrimExpr value) { if (value.dtype() == t) return value; - return ir::CallNode::make( - t, ir::CallNode::reinterpret, { value }, ir::CallNode::PureIntrinsic); 
+ return tir::CallNode::make( + t, tir::CallNode::reinterpret, { value }, tir::CallNode::PureIntrinsic); } PrimExpr operator+(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); - PrimExpr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return ir::AddNode::make(a, b); + return tir::AddNode::make(a, b); } // negation PrimExpr operator-(PrimExpr a) { - using ir::IntImmNode; - using ir::FloatImmNode; + using tir::IntImmNode; + using tir::FloatImmNode; const IntImmNode* pa = a.as(); const FloatImmNode* fa = a.as(); if (pa) return IntImm(a.dtype(), -pa->value); @@ -225,23 +229,23 @@ PrimExpr operator-(PrimExpr a) { PrimExpr operator-(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); - PrimExpr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return ir::SubNode::make(a, b); + return tir::SubNode::make(a, b); } PrimExpr operator*(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); - PrimExpr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return ir::MulNode::make(a, b); + return tir::MulNode::make(a, b); } PrimExpr div(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); - PrimExpr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return ir::DivNode::make(a, b); + return tir::DivNode::make(a, b); } PrimExpr truncdiv(PrimExpr a, PrimExpr b) { @@ -252,9 +256,9 @@ PrimExpr truncdiv(PrimExpr a, PrimExpr b) { PrimExpr truncmod(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); - PrimExpr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return ir::ModNode::make(a, b); + return tir::ModNode::make(a, b); } PrimExpr operator/(PrimExpr a, PrimExpr b) { @@ -278,18 +282,18 @@ PrimExpr floordiv(PrimExpr a, PrimExpr b) { CHECK(a.dtype().is_int() || a.dtype().is_uint()) << a; CHECK(b.dtype().is_int() || b.dtype().is_uint()) << b; BinaryOpMatchTypes(a, b); - PrimExpr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return ir::FloorDivNode::make(a, b); + return tir::FloorDivNode::make(a, b); } PrimExpr floormod(PrimExpr a, PrimExpr b) { CHECK(a.dtype().is_int() || a.dtype().is_uint()) << a; CHECK(b.dtype().is_int() || b.dtype().is_uint()) << b; BinaryOpMatchTypes(a, b); - PrimExpr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return ir::FloorModNode::make(a, b); + return tir::FloorModNode::make(a, b); } PrimExpr min(PrimExpr a, PrimExpr b) { @@ -301,9 +305,9 @@ PrimExpr min(PrimExpr a, PrimExpr b) { if (is_pos_inf(b)) return a; if (is_neg_inf(b)) return b; BinaryOpMatchTypes(a, b); - PrimExpr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return ir::MinNode::make(a, b); + return tir::MinNode::make(a, b); } PrimExpr max(PrimExpr a, PrimExpr b) { @@ -315,9 +319,9 @@ PrimExpr max(PrimExpr a, PrimExpr b) { if (is_pos_inf(b)) return b; if (is_neg_inf(b)) return a; BinaryOpMatchTypes(a, b); - PrimExpr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return ir::MaxNode::make(a, b); + return tir::MaxNode::make(a, b); } PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value) { @@ -331,84 +335,84 @@ PrimExpr if_then_else(PrimExpr cond, PrimExpr 
true_value, PrimExpr false_value) return false_value; } } - return ir::CallNode::make( + return tir::CallNode::make( true_value.dtype(), - ir::intrinsic::tvm_if_then_else, + tir::intrinsic::tvm_if_then_else, {cond, true_value, false_value}, - ir::CallNode::PureIntrinsic); + tir::CallNode::PureIntrinsic); } PrimExpr likely(PrimExpr cond) { if (is_const(cond)) return cond; - return ir::CallNode::make(cond.dtype(), - ir::CallNode::likely, + return tir::CallNode::make(cond.dtype(), + tir::CallNode::likely, { cond }, - ir::CallNode::PureIntrinsic); + tir::CallNode::PureIntrinsic); } PrimExpr operator>(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); - PrimExpr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return ir::GTNode::make(a, b); + return tir::GTNode::make(a, b); } PrimExpr operator>=(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); - PrimExpr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return ir::GENode::make(a, b); + return tir::GENode::make(a, b); } PrimExpr operator<(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); - PrimExpr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return ir::LTNode::make(a, b); + return tir::LTNode::make(a, b); } PrimExpr operator<=(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); - PrimExpr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return ir::LENode::make(a, b); + return tir::LENode::make(a, b); } PrimExpr operator==(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); - PrimExpr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return ir::EQNode::make(a, b); + return tir::EQNode::make(a, b); } PrimExpr operator!=(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); - PrimExpr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return ir::NENode::make(a, b); + return tir::NENode::make(a, b); } PrimExpr operator&&(PrimExpr a, PrimExpr b) { CHECK(a.dtype().is_bool()); CHECK(b.dtype().is_bool()); - PrimExpr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return ir::AndNode::make(a, b); + return tir::AndNode::make(a, b); } PrimExpr operator||(PrimExpr a, PrimExpr b) { CHECK(a.dtype().is_bool()); CHECK(b.dtype().is_bool()); - PrimExpr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return ir::OrNode::make(a, b); + return tir::OrNode::make(a, b); } PrimExpr operator!(PrimExpr a) { CHECK(a.dtype().is_bool()); - PrimExpr ret = arith::TryConstFold(a); + PrimExpr ret = arith::TryConstFold(a); if (ret.defined()) return ret; - return ir::NotNode::make(a); + return tir::NotNode::make(a); } PrimExpr operator>>(PrimExpr a, PrimExpr b) { @@ -420,8 +424,8 @@ PrimExpr operator>>(PrimExpr a, PrimExpr b) { if (pb->value == 0) return a; } }); - return ir::CallNode::make( - a.dtype(), ir::CallNode::shift_right, { a, b }, ir::CallNode::PureIntrinsic); + return tir::CallNode::make( + a.dtype(), tir::CallNode::shift_right, { a, b }, tir::CallNode::PureIntrinsic); } PrimExpr operator<<(PrimExpr a, PrimExpr b) { @@ -433,8 +437,8 @@ PrimExpr operator<<(PrimExpr a, PrimExpr b) { if (pb->value == 0) return a; } }); - return ir::CallNode::make( - a.dtype(), ir::CallNode::shift_left, { 
a, b }, ir::CallNode::PureIntrinsic); + return tir::CallNode::make( + a.dtype(), tir::CallNode::shift_left, { a, b }, tir::CallNode::PureIntrinsic); } PrimExpr operator&(PrimExpr a, PrimExpr b) { @@ -443,8 +447,8 @@ PrimExpr operator&(PrimExpr a, PrimExpr b) { const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, (pa->value & pb->value)); }); - return ir::CallNode::make( - a.dtype(), ir::CallNode::bitwise_and, { a, b }, ir::CallNode::PureIntrinsic); + return tir::CallNode::make( + a.dtype(), tir::CallNode::bitwise_and, { a, b }, tir::CallNode::PureIntrinsic); } PrimExpr operator|(PrimExpr a, PrimExpr b) { @@ -453,8 +457,8 @@ PrimExpr operator|(PrimExpr a, PrimExpr b) { const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, (pa->value | pb->value)); }); - return ir::CallNode::make( - a.dtype(), ir::CallNode::bitwise_or, { a, b }, ir::CallNode::PureIntrinsic); + return tir::CallNode::make( + a.dtype(), tir::CallNode::bitwise_or, { a, b }, tir::CallNode::PureIntrinsic); } PrimExpr operator^(PrimExpr a, PrimExpr b) { @@ -463,38 +467,38 @@ PrimExpr operator^(PrimExpr a, PrimExpr b) { const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, (pa->value ^ pb->value)); }); - return ir::CallNode::make( - a.dtype(), ir::CallNode::bitwise_xor, { a, b }, ir::CallNode::PureIntrinsic); + return tir::CallNode::make( + a.dtype(), tir::CallNode::bitwise_xor, { a, b }, tir::CallNode::PureIntrinsic); } PrimExpr operator~(PrimExpr a) { CHECK(a.dtype().is_int() || a.dtype().is_uint()); - return ir::CallNode::make( - a.dtype(), ir::CallNode::bitwise_not, { a }, ir::CallNode::PureIntrinsic); + return tir::CallNode::make( + a.dtype(), tir::CallNode::bitwise_not, { a }, tir::CallNode::PureIntrinsic); } PrimExpr pow(PrimExpr x, PrimExpr y) { BinaryOpMatchTypes(x, y); CHECK(x.dtype().is_float()) << "power only applies to float"; - return ir::CallNode::make( - x.dtype(), "pow", { x, y }, ir::CallNode::PureIntrinsic); + return tir::CallNode::make( + x.dtype(), "pow", { x, y }, tir::CallNode::PureIntrinsic); } PrimExpr abs(PrimExpr x) { if (x.dtype().is_int()) { - using ir::IntImmNode; + using tir::IntImmNode; const IntImmNode* px = x.as(); if (px) { return IntImm(x.dtype(), std::abs(px->value)); } - return ir::SelectNode::make(x >= make_zero(x.dtype()), x, -x); + return tir::SelectNode::make(x >= make_zero(x.dtype()), x, -x); } else if (x.dtype().is_float()) { - using ir::FloatImmNode; + using tir::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) { return FloatImm(x.dtype(), std::fabs(fx->value)); } - return ir::CallNode::make(x.dtype(), "fabs", {x}, ir::CallNode::PureIntrinsic); + return tir::CallNode::make(x.dtype(), "fabs", {x}, tir::CallNode::PureIntrinsic); } else if (x.dtype().is_uint()) { return x; } else { @@ -509,17 +513,17 @@ PrimExpr isnan(PrimExpr x) { if (x.dtype().is_int() || x.dtype().is_uint()) { return make_const(t, false); } else if (x.dtype().is_float()) { - using ir::FloatImmNode; + using tir::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) { return make_const(t, std::isnan(fx->value)); } if (x.dtype().bits() == 16) { - return ir::CallNode::make(t, ir::CallNode::isnan, + return tir::CallNode::make(t, tir::CallNode::isnan, {cast(DataType::Float(32, t.lanes()), std::move(x))}, - ir::CallNode::PureIntrinsic); + tir::CallNode::PureIntrinsic); } else { - return ir::CallNode::make(t, ir::CallNode::isnan, {x}, ir::CallNode::PureIntrinsic); + return tir::CallNode::make(t, tir::CallNode::isnan, {x}, tir::CallNode::PureIntrinsic); } } else { LOG(FATAL) 
<< "Data type " << x.dtype() @@ -530,102 +534,102 @@ PrimExpr isnan(PrimExpr x) { PrimExpr sum(PrimExpr source, Array rdom) { Var x("x", source.dtype()), y("y", source.dtype()); - PrimExpr result = ir::AddNode::make(x, y); + PrimExpr result = tir::AddNode::make(x, y); PrimExpr identity_element = make_zero(source.dtype()); - ir::CommReducer combiner = - ir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); - return ir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); + tir::CommReducer combiner = + tir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); + return tir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); } PrimExpr all(PrimExpr source, Array rdom) { CHECK(source.dtype().is_bool()); Var x("x", source.dtype()), y("y", source.dtype()); - PrimExpr result = ir::AndNode::make(x, y); + PrimExpr result = tir::AndNode::make(x, y); PrimExpr identity_element = make_const(source.dtype(), true); - ir::CommReducer combiner = - ir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); - return ir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); + tir::CommReducer combiner = + tir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); + return tir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); } PrimExpr any(PrimExpr source, Array rdom) { CHECK(source.dtype().is_bool()); Var x("x", source.dtype()), y("y", source.dtype()); - PrimExpr result = ir::OrNode::make(x, y); + PrimExpr result = tir::OrNode::make(x, y); PrimExpr identity_element = make_const(source.dtype(), false); - ir::CommReducer combiner = - ir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); - return ir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); + tir::CommReducer combiner = + tir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); + return tir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); } PrimExpr max(PrimExpr source, Array rdom) { Var x("x", source.dtype()), y("y", source.dtype()); - PrimExpr result = ir::MaxNode::make(x, y); + PrimExpr result = tir::MaxNode::make(x, y); PrimExpr identity_element = min_value(source.dtype()); - ir::CommReducer combiner = - ir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); - return ir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); + tir::CommReducer combiner = + tir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); + return tir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); } PrimExpr min(PrimExpr source, Array rdom) { Var x("x", source.dtype()), y("y", source.dtype()); - PrimExpr result = ir::MinNode::make(x, y); + PrimExpr result = tir::MinNode::make(x, y); PrimExpr identity_element = max_value(source.dtype()); - ir::CommReducer combiner = - ir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); - return ir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); + tir::CommReducer combiner = + tir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); + return tir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); } PrimExpr prod(PrimExpr source, Array rdom) { Var x("x", source.dtype()), y("y", source.dtype()); - PrimExpr result = ir::MulNode::make(x, y); + PrimExpr result = tir::MulNode::make(x, y); PrimExpr 
identity_element = make_const(source.dtype(), 1); - ir::CommReducer combiner = - ir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); - return ir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); + tir::CommReducer combiner = + tir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); + return tir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); } PrimExpr fmod(PrimExpr x, PrimExpr y) { BinaryOpMatchTypes(x, y); CHECK(x.dtype().is_float()) << "fmod only applies to float"; - return ir::CallNode::make(x.dtype(), "fmod", { x, y }, ir::CallNode::PureIntrinsic); + return tir::CallNode::make(x.dtype(), "fmod", { x, y }, tir::CallNode::PureIntrinsic); } PrimExpr floor(PrimExpr x) { - using ir::FloatImmNode; + using tir::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) return FloatImm(x.dtype(), std::floor(fx->value)); - return ir::CallNode::make(x.dtype(), "floor", {x}, ir::CallNode::PureIntrinsic); + return tir::CallNode::make(x.dtype(), "floor", {x}, tir::CallNode::PureIntrinsic); } PrimExpr ceil(PrimExpr x) { - using ir::FloatImmNode; + using tir::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) return FloatImm(x.dtype(), std::ceil(fx->value)); - return ir::CallNode::make(x.dtype(), "ceil", {x}, ir::CallNode::PureIntrinsic); + return tir::CallNode::make(x.dtype(), "ceil", {x}, tir::CallNode::PureIntrinsic); } PrimExpr round(PrimExpr x) { - using ir::FloatImmNode; + using tir::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value)); - return ir::CallNode::make(x.dtype(), "round", {x}, ir::CallNode::PureIntrinsic); + return tir::CallNode::make(x.dtype(), "round", {x}, tir::CallNode::PureIntrinsic); } PrimExpr nearbyint(PrimExpr x) { - using ir::FloatImmNode; + using tir::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value)); - return ir::CallNode::make(x.dtype(), "nearbyint", {x}, ir::CallNode::PureIntrinsic); + return tir::CallNode::make(x.dtype(), "nearbyint", {x}, tir::CallNode::PureIntrinsic); } PrimExpr trunc(PrimExpr x) { - using ir::FloatImmNode; + using tir::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) { return FloatImm(x.dtype(), (fx->value < 0 ? std::ceil(fx->value) : std::floor(fx->value))); } - return ir::CallNode::make(x.dtype(), "trunc", {x}, ir::CallNode::PureIntrinsic); + return tir::CallNode::make(x.dtype(), "trunc", {x}, tir::CallNode::PureIntrinsic); } } // namespace tvm diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc new file mode 100644 index 000000000000..6d39d99f6939 --- /dev/null +++ b/src/tir/ir/stmt.cc @@ -0,0 +1,532 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file tvm/tir/stmt.cc
+ */
+
+#include
+#include
+#include "../pass/ir_util.h"
+
+namespace tvm {
+namespace tir {
+
+Stmt LetStmtNode::make(Var var, PrimExpr value, Stmt body) {
+  CHECK(value.defined());
+  CHECK(body.defined());
+  CHECK_EQ(value.dtype(), var.dtype());
+
+  ObjectPtr<LetStmtNode> node = make_object<LetStmtNode>();
+  node->var = std::move(var);
+  node->value = std::move(value);
+  node->body = std::move(body);
+  return Stmt(node);
+}
+
+Stmt AttrStmtNode::make(ObjectRef node,
+                        std::string attr_key,
+                        PrimExpr value,
+                        Stmt body) {
+  auto n = make_object<AttrStmtNode>();
+  n->node = node;
+  n->attr_key = std::move(attr_key);
+  n->value = std::move(value);
+  n->body = std::move(body);
+  return Stmt(n);
+}
+
+Stmt AssertStmtNode::make(PrimExpr condition, PrimExpr message, Stmt body) {
+  CHECK(condition.defined());
+  CHECK(message.dtype() == DataType::Int(32) ||
+        message.as<StringImmNode>())
+      << "TypeError: AssertStmt message must be an int or string:"
+      << message << "\n";
+
+  ObjectPtr<AssertStmtNode> node = make_object<AssertStmtNode>();
+  node->condition = std::move(condition);
+  node->message = std::move(message);
+  node->body = std::move(body);
+  return Stmt(node);
+}
+
+Stmt ProducerConsumerNode::make(FunctionRef func, bool is_producer, Stmt body) {
+  CHECK(body.defined());
+
+  ObjectPtr<ProducerConsumerNode> node = make_object<ProducerConsumerNode>();
+  node->func = std::move(func);
+  node->is_producer = is_producer;
+  node->body = std::move(body);
+  return Stmt(node);
+}
+
+Stmt ForNode::make(Var loop_var,
+                   PrimExpr min,
+                   PrimExpr extent,
+                   ForType for_type,
+                   DeviceAPI device_api,
+                   Stmt body) {
+  CHECK(min.defined());
+  CHECK(extent.defined());
+  CHECK(min.dtype().is_scalar());
+  CHECK(extent.dtype().is_scalar());
+  CHECK(loop_var.dtype().is_scalar());
+  CHECK(body.defined());
+
+  ObjectPtr<ForNode> node = make_object<ForNode>();
+  node->loop_var = std::move(loop_var);
+  node->min = std::move(min);
+  node->extent = std::move(extent);
+  node->for_type = for_type;
+  node->device_api = device_api;
+  node->body = std::move(body);
+  return Stmt(node);
+}
+
+Stmt StoreNode::make(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate) {
+  CHECK(value.defined());
+  CHECK(index.defined());
+  CHECK(predicate.defined());
+  CHECK_EQ(value.dtype().lanes(), index.dtype().lanes());
+  CHECK_EQ(value.dtype().lanes(), predicate.dtype().lanes());
+
+  ObjectPtr<StoreNode> node = make_object<StoreNode>();
+  node->buffer_var = std::move(buffer_var);
+  node->value = std::move(value);
+  node->index = std::move(index);
+  node->predicate = std::move(predicate);
+  return Stmt(node);
+}
+
+Stmt ProvideNode::make(FunctionRef func, int value_index, PrimExpr value, Array<PrimExpr> args) {
+  CHECK(value_index >= 0 && value_index < func->num_outputs())
+      << "value index output function return value bound";
+  CHECK(value.defined()) << "Provide of undefined value\n";
+
+  for (size_t i = 0; i < args.size(); ++i) {
+    CHECK(args[i].defined()) << "Provide to undefined location\n";
+  }
+
+  ObjectPtr<ProvideNode> node = make_object<ProvideNode>();
+  node->func = std::move(func);
+  node->value_index = value_index;
+  node->value = std::move(value);
+  node->args = std::move(args);
+  return Stmt(node);
+}
+
+Stmt AllocateNode::make(Var buffer_var,
+                        DataType dtype,
+                        Array<PrimExpr> extents,
+                        PrimExpr condition,
+                        Stmt body,
+                        PrimExpr new_expr,
+                        std::string free_function) {
+  for (size_t i = 0; i < extents.size(); ++i) {
+    CHECK(extents[i].defined());
+    CHECK(extents[i].dtype().is_scalar());
+  }
+  CHECK(body.defined());
+  CHECK(condition.defined());
+  CHECK(condition.dtype().is_bool());
+
+  ObjectPtr<AllocateNode> node = make_object<AllocateNode>();
+  node->buffer_var = std::move(buffer_var);
+  node->dtype = dtype;
+  node->extents = std::move(extents);
+  node->condition = std::move(condition);
+  node->body = std::move(body);
+  node->new_expr = std::move(new_expr);
+  node->free_function = std::move(free_function);
+  return Stmt(node);
+}
+
+int32_t AllocateNode::constant_allocation_size(const Array<PrimExpr>& extents) {
+  int64_t result = 1;
+  for (size_t i = 0; i < extents.size(); ++i) {
+    if (const IntImmNode *int_size = extents[i].as<IntImmNode>()) {
+      result *= int_size->value;
+      if (result > std::numeric_limits<int32_t>::max()) {
+        return 0;
+      }
+    } else {
+      return 0;
+    }
+  }
+  return static_cast<int32_t>(result);
+}
+
+Stmt FreeNode::make(Var buffer_var) {
+  ObjectPtr<FreeNode> node = make_object<FreeNode>();
+  node->buffer_var = buffer_var;
+  return Stmt(node);
+}
+
+Stmt RealizeNode::make(FunctionRef func,
+                       int value_index,
+                       DataType dtype,
+                       Region bounds,
+                       PrimExpr condition,
+                       Stmt body) {
+  for (size_t i = 0; i < bounds.size(); ++i) {
+    CHECK(bounds[i]->min.defined());
+    CHECK(bounds[i]->extent.defined());
+    CHECK(bounds[i]->min.dtype().is_scalar());
+    CHECK(bounds[i]->extent.dtype().is_scalar());
+  }
+  CHECK(body.defined());
+  CHECK(condition.defined());
+  CHECK(condition.dtype().is_bool());
+
+  ObjectPtr<RealizeNode> node = make_object<RealizeNode>();
+  node->func = std::move(func);
+  node->value_index = value_index;
+  node->dtype = dtype;
+  node->bounds = std::move(bounds);
+  node->condition = std::move(condition);
+  node->body = std::move(body);
+  return Stmt(node);
+}
+
+Stmt PrefetchNode::make(FunctionRef func, int value_index, DataType dtype, Region bounds) {
+  for (size_t i = 0; i < bounds.size(); ++i) {
+    CHECK(bounds[i]->min.defined());
+    CHECK(bounds[i]->extent.defined());
+    CHECK(bounds[i]->min.dtype().is_scalar());
+    CHECK(bounds[i]->extent.dtype().is_scalar());
+  }
+
+  ObjectPtr<PrefetchNode> node = make_object<PrefetchNode>();
+  node->func = std::move(func);
+  node->value_index = value_index;
+  node->dtype = dtype;
+  node->bounds = std::move(bounds);
+  return Stmt(node);
+}
+
+SeqStmt::SeqStmt(Array<Stmt> seq) {
+  auto node = make_object<SeqStmtNode>();
+  node->seq = std::move(seq);
+  data_ = std::move(node);
+}
+
+Stmt IfThenElseNode::make(PrimExpr condition, Stmt then_case, Stmt else_case) {
+  CHECK(condition.defined());
+  CHECK(then_case.defined());
+  // else_case may be null.
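+  // A null else_case is how one-armed conditionals are represented; the
+  // printer further below relies on this to flatten nested chains into
+  // "else if" blocks.  A minimal usage sketch (hypothetical statements):
+  //
+  //   Stmt s = IfThenElseNode::make(cond, then_stmt, Stmt());
+  //
+  // where the default-constructed Stmt() stands for the absent branch.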
+ + ObjectPtr node = make_object(); + node->condition = std::move(condition); + node->then_case = std::move(then_case); + node->else_case = std::move(else_case); + return Stmt(node); +} + +Stmt EvaluateNode::make(PrimExpr value) { + CHECK(value.defined()); + + ObjectPtr node = make_object(); + node->value = std::move(value); + return Stmt(node); +} + +// Printers + +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << "let " << op->var << " = "; + p->Print(op->value); + p->stream << '\n'; + p->Print(op->body); + }); + +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << "// attr ["; + p->Print(op->node); + p->stream << "] " + << op->attr_key << " = "; + p->Print(op->value); + p->stream << '\n'; + p->Print(op->body); + }); + +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << "assert("; + p->Print(op->condition); + p->stream << ", "; + p->Print(op->message); + p->stream << ")\n"; + p->Print(op->body); + }); + +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); + if (op->is_producer) { + p->PrintIndent(); + p->stream << "produce " << op->func->func_name() << " {\n"; + p->indent += 2; + p->Print(op->body); + p->indent -= 2; + p->PrintIndent(); + p->stream << "}\n"; + } else { + p->Print(op->body); + } + }); + +std::ostream &operator<<(std::ostream& out, ForType type) { // NOLINT(*) + switch (type) { + case ForType::Serial: + out << "for"; + break; + case ForType::Parallel: + out << "parallel"; + break; + case ForType::Unrolled: + out << "unrolled"; + break; + case ForType::Vectorized: + out << "vectorized"; + break; + } + return out; +} + +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << op->for_type << " (" << op->loop_var << ", "; + p->Print(op->min); + p->stream << ", "; + p->Print(op->extent); + p->stream << ") {\n"; + + p->indent += 2; + p->Print(op->body); + p->indent -= 2; + + p->PrintIndent(); + p->stream << "}\n"; +}); + +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << op->buffer_var << "["; + p->Print(op->index); + p->stream << "] = "; + p->Print(op->value); + if (!is_one(op->predicate)) { + p->stream << " if "; + p->Print(op->predicate); + } + p->stream << '\n'; + }); + +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << op->func->func_name() << "("; + for (size_t i = 0; i < op->args.size(); ++i) { + p->Print(op->args[i]); + if (i < op->args.size() - 1) p->stream << ", "; + } + p->stream << ")"; + if (op->func->num_outputs() != 1) { + p->stream << ".value[" << op->value_index << "]"; + } + p->stream << " ="; + p->Print(op->value); + p->stream << '\n'; + }); + +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << "allocate " << op->buffer_var 
<< "[" << op->dtype; + for (size_t i = 0; i < op->extents.size(); ++i) { + p->stream << " * "; + p->Print(op->extents[i]); + } + p->stream << "]"; + if (!is_one(op->condition)) { + p->stream << " if "; + p->Print(op->condition); + } + p->stream << "\n"; + p->Print(op->body); + }); + +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << "free " << op->buffer_var; + p->stream << '\n'; + }); + +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << "realize " << op->func->func_name() << "("; + for (size_t i = 0; i < op->bounds.size(); ++i) { + p->stream << "["; + p->Print(op->bounds[i]->min); + p->stream << ", "; + p->Print(op->bounds[i]->extent); + p->stream << "]"; + if (i < op->bounds.size() - 1) p->stream << ", "; + } + p->stream << ")"; + if (op->func->num_outputs() != 1) { + p->stream << ".value[" << op->value_index << "]"; + } + if (!is_one(op->condition)) { + p->stream << " if "; + p->Print(op->condition); + } + p->stream << " {\n"; + + p->indent += 2; + p->Print(op->body); + p->indent -= 2; + + p->PrintIndent(); + p->stream << "}\n"; + }); + +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << "prefetch " << op->func->func_name() << "("; + for (size_t i = 0; i < op->bounds.size(); ++i) { + p->stream << "["; + p->Print(op->bounds[i]->min); + p->stream << ", "; + p->Print(op->bounds[i]->extent); + p->stream << "]"; + if (i < op->bounds.size() - 1) p->stream << ", "; + } + p->stream << ")"; + if (op->func->num_outputs() != 1) { + p->stream << ".value[" << op->value_index << "]"; + } + }); + +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); + for (Stmt stmt : op->seq) { + p->Print(stmt); + } + }); + +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + while (true) { + p->stream << "if (" << op->condition << ") {\n"; + p->indent += 2; + p->Print(op->then_case); + p->indent -= 2; + + if (!op->else_case.defined()) { + break; + } + + if (const IfThenElseNode *nested_if = op->else_case.as()) { + p->PrintIndent(); + p->stream << "} else "; + op = nested_if; + } else { + p->PrintIndent(); + p->stream << "} else {\n"; + p->indent += 2; + p->Print(op->else_case); + p->indent -= 2; + break; + } + } + p->PrintIndent(); + p->stream << "}\n"; +}); + +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->Print(op->value); + p->stream << "\n"; + }); + +template +void PrintList(const Array &exprs, NodePrinter* p) { + for (size_t i = 0; i < exprs.size(); ++i) { + p->Print(exprs[i]); + if (i < exprs.size() - 1) { + p->stream << ", "; + } + } +} + +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "shuffle("; + PrintList(op->vectors, p); + p->stream << ", "; + PrintList(op->indices, p); + p->stream << ")"; + }); + +TVM_REGISTER_NODE_TYPE(AttrStmtNode); +TVM_REGISTER_NODE_TYPE(PrefetchNode); +TVM_REGISTER_NODE_TYPE(CallNode); 
+TVM_REGISTER_NODE_TYPE(LetNode); +TVM_REGISTER_NODE_TYPE(LetStmtNode); +TVM_REGISTER_NODE_TYPE(AssertStmtNode); +TVM_REGISTER_NODE_TYPE(ProducerConsumerNode); +TVM_REGISTER_NODE_TYPE(ForNode); +TVM_REGISTER_NODE_TYPE(StoreNode); +TVM_REGISTER_NODE_TYPE(ProvideNode); +TVM_REGISTER_NODE_TYPE(AllocateNode); +TVM_REGISTER_NODE_TYPE(FreeNode); +TVM_REGISTER_NODE_TYPE(RealizeNode); +TVM_REGISTER_NODE_TYPE(SeqStmtNode); +TVM_REGISTER_NODE_TYPE(IfThenElseNode); +TVM_REGISTER_NODE_TYPE(EvaluateNode); + +} // namespace tir +} // namespace tvm diff --git a/src/pass/ir_functor.cc b/src/tir/ir/stmt_functor.cc similarity index 60% rename from src/pass/ir_functor.cc rename to src/tir/ir/stmt_functor.cc index 7292df6cda23..b4b27b9abef9 100644 --- a/src/pass/ir_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -17,12 +17,13 @@ * under the License. */ /*! - * \file ir_functor.cc + * \file stmt_functor.cc */ -#include +#include +#include "functor_common.h" namespace tvm { -namespace ir { +namespace tir { // visitor to implement apply class IRApplyVisit : @@ -128,14 +129,6 @@ Stmt IRTransform(Stmt ir_node, return transform(std::move(ir_node)); } -// Implementation of Visitors -template -inline void VisitArray(const Array& arr, F fvisit) { - for (size_t i = 0; i < arr.size(); i++) { - fvisit(arr[i]); - } -} - void StmtVisitor::VisitStmt_(const LetStmtNode* op) { this->VisitExpr(op->value); this->VisitStmt(op->body); @@ -218,107 +211,6 @@ void StmtVisitor::VisitStmt_(const EvaluateNode* op) { this->VisitExpr(op->value); } -void ExprVisitor::VisitExpr_(const VarNode* op) {} - -void ExprVisitor::VisitExpr_(const SizeVarNode* op) { - this->VisitExpr_(static_cast(op)); -} - -void ExprVisitor::VisitExpr_(const LoadNode* op) { - this->VisitExpr(op->index); - this->VisitExpr(op->predicate); -} - -void ExprVisitor::VisitExpr_(const LetNode* op) { - this->VisitExpr(op->value); - this->VisitExpr(op->body); -} - -void ExprVisitor::VisitExpr_(const CallNode* op) { - VisitArray(op->args, [this](const PrimExpr& e) { this->VisitExpr(e); }); -} - -#define DEFINE_BINOP_VISIT_(OP) \ - void ExprVisitor::VisitExpr_(const OP* op) { \ - this->VisitExpr(op->a); \ - this->VisitExpr(op->b); \ - } - -DEFINE_BINOP_VISIT_(AddNode); -DEFINE_BINOP_VISIT_(SubNode); -DEFINE_BINOP_VISIT_(MulNode); -DEFINE_BINOP_VISIT_(DivNode); -DEFINE_BINOP_VISIT_(ModNode); -DEFINE_BINOP_VISIT_(FloorDivNode); -DEFINE_BINOP_VISIT_(FloorModNode); -DEFINE_BINOP_VISIT_(MinNode); -DEFINE_BINOP_VISIT_(MaxNode); -DEFINE_BINOP_VISIT_(EQNode); -DEFINE_BINOP_VISIT_(NENode); -DEFINE_BINOP_VISIT_(LTNode); -DEFINE_BINOP_VISIT_(LENode); -DEFINE_BINOP_VISIT_(GTNode); -DEFINE_BINOP_VISIT_(GENode); -DEFINE_BINOP_VISIT_(AndNode); -DEFINE_BINOP_VISIT_(OrNode); - -void ExprVisitor::VisitExpr_(const IntImmNode* op) {} -void ExprVisitor::VisitExpr_(const FloatImmNode* op) {} -void ExprVisitor::VisitExpr_(const StringImmNode* op) {} - -void ExprVisitor::VisitExpr_(const ReduceNode* op) { - VisitArray(op->axis, [this](const IterVar& r) { - this->VisitExpr(r->dom->min); - this->VisitExpr(r->dom->extent); - }); - VisitArray(op->source, [this](const PrimExpr& e) { this->VisitExpr(e); }); - this->VisitExpr(op->condition); -} - -void ExprVisitor::VisitExpr_(const CastNode* op) { - this->VisitExpr(op->value); -} - -void ExprVisitor::VisitExpr_(const NotNode* op) { - this->VisitExpr(op->a); -} - -void ExprVisitor::VisitExpr_(const SelectNode* op) { - this->VisitExpr(op->condition); - this->VisitExpr(op->true_value); - this->VisitExpr(op->false_value); -} - -void ExprVisitor::VisitExpr_(const 
RampNode* op) { - this->VisitExpr(op->base); - this->VisitExpr(op->stride); -} - -void ExprVisitor::VisitExpr_(const ShuffleNode* op) { - VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); }); - VisitArray(op->vectors, [this](const PrimExpr& e) { this->VisitExpr(e); }); -} - -void ExprVisitor::VisitExpr_(const BroadcastNode* op) { - this->VisitExpr(op->value); -} - -// Implementation of mutators -template -inline Array MutateArray(const Array& arr, - F fmutate, - bool allow_copy_on_write = false) { - if (allow_copy_on_write) { - // if we allow copy on write, we can directly - // call the inplace mutate function. - const_cast&>(arr).MutateByApply(fmutate); - return arr; - } else { - Array copy = arr; - copy.MutateByApply(fmutate); - return copy; - } -} class StmtMutator::Internal { public: @@ -595,181 +487,6 @@ Stmt StmtMutator::VisitStmt_(const FreeNode* op) { } -PrimExpr ExprMutator::VisitExpr_(const VarNode* op) { - return GetRef(op); -} - -PrimExpr ExprMutator::VisitExpr_(const SizeVarNode* op) { - return this->VisitExpr_(static_cast(op)); -} - -PrimExpr ExprMutator::VisitExpr_(const LoadNode* op) { - PrimExpr index = this->VisitExpr(op->index); - PrimExpr predicate = this->VisitExpr(op->predicate); - if (index.same_as(op->index) && predicate.same_as(op->predicate)) { - return GetRef(op); - } else { - return LoadNode::make(op->dtype, op->buffer_var, index, predicate); - } -} - -PrimExpr ExprMutator::VisitExpr_(const LetNode* op) { - PrimExpr value = this->VisitExpr(op->value); - PrimExpr body = this->VisitExpr(op->body); - if (value.same_as(op->value) && - body.same_as(op->body)) { - return GetRef(op); - } else { - return LetNode::make(op->var, value, body); - } -} - -PrimExpr ExprMutator::VisitExpr_(const CallNode* op) { - auto fmutate = [this](const PrimExpr& e) { return this->VisitExpr(e); }; - Array args = MutateArray(op->args, fmutate); - - if (args.same_as(op->args)) { - return GetRef(op); - } else { - return CallNode::make(op->dtype, - op->name, - args, - op->call_type, - op->func, - op->value_index); - } -} - -#define DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(OP) \ - PrimExpr ExprMutator::VisitExpr_(const OP *op) { \ - return GetRef(op); \ - } - -DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(IntImmNode) -DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(FloatImmNode) -DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(StringImmNode) - -#define DEFINE_BIOP_EXPR_MUTATE_(OP) \ - PrimExpr ExprMutator::VisitExpr_(const OP* op) { \ - PrimExpr a = this->VisitExpr(op->a); \ - PrimExpr b = this->VisitExpr(op->b); \ - if (a.same_as(op->a) && \ - b.same_as(op->b)) { \ - return GetRef(op); \ - } else { \ - return OP::make(a, b); \ - } \ - } - -DEFINE_BIOP_EXPR_MUTATE_(AddNode); -DEFINE_BIOP_EXPR_MUTATE_(SubNode); -DEFINE_BIOP_EXPR_MUTATE_(MulNode); -DEFINE_BIOP_EXPR_MUTATE_(DivNode); -DEFINE_BIOP_EXPR_MUTATE_(ModNode); -DEFINE_BIOP_EXPR_MUTATE_(FloorDivNode); -DEFINE_BIOP_EXPR_MUTATE_(FloorModNode); -DEFINE_BIOP_EXPR_MUTATE_(MinNode); -DEFINE_BIOP_EXPR_MUTATE_(MaxNode); -DEFINE_BIOP_EXPR_MUTATE_(EQNode); -DEFINE_BIOP_EXPR_MUTATE_(NENode); -DEFINE_BIOP_EXPR_MUTATE_(LTNode); -DEFINE_BIOP_EXPR_MUTATE_(LENode); -DEFINE_BIOP_EXPR_MUTATE_(GTNode); -DEFINE_BIOP_EXPR_MUTATE_(GENode); -DEFINE_BIOP_EXPR_MUTATE_(AndNode); -DEFINE_BIOP_EXPR_MUTATE_(OrNode); - -PrimExpr ExprMutator::VisitExpr_(const ReduceNode* op) { - auto fitervar = [this](const IterVar& v) { - Range r = v->dom; - PrimExpr min = this->VisitExpr(r->min); - PrimExpr extent = this->VisitExpr(r->extent); - if (min.same_as(r->min) && - extent.same_as(r->extent)) { - 
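// (The same_as test above is the copy-on-write idiom shared by all of the
//  mutators this diff relocates: re-visit the children, and allocate a new
//  node only if some child actually changed.  Expanded for CastNode, as it
//  appears further down, with the stripped template argument on GetRef
//  restored by assumption:
//
//    PrimExpr ExprMutator::VisitExpr_(const CastNode* op) {
//      PrimExpr value = this->VisitExpr(op->value);
//      if (value.same_as(op->value)) {
//        return GetRef<PrimExpr>(op);   // unchanged: share the old node
//      } else {
//        return CastNode::make(op->dtype, value);
//      }
//    })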
return v; - } else { - return IterVarNode::make( - Range::make_by_min_extent(min, extent), - v->var, v->iter_type, v->thread_tag); - } - }; - Array axis = MutateArray(op->axis, fitervar); - - auto fexpr = [this](const PrimExpr& e) { return this->VisitExpr(e); }; - Array source = MutateArray(op->source, fexpr); - - PrimExpr condition = this->VisitExpr(op->condition); - - if (axis.same_as(op->axis) && - source.same_as(op->source) && - condition.same_as(op->condition)) { - return GetRef(op); - } else { - return ReduceNode::make( - op->combiner, source, axis, condition, op->value_index); - } -} - -PrimExpr ExprMutator::VisitExpr_(const CastNode* op) { - PrimExpr value = this->VisitExpr(op->value); - if (value.same_as(op->value)) { - return GetRef(op); - } else { - return CastNode::make(op->dtype, value); - } -} - -PrimExpr ExprMutator::VisitExpr_(const NotNode* op) { - PrimExpr a = this->VisitExpr(op->a); - if (a.same_as(op->a)) { - return GetRef(op); - } else { - return NotNode::make(a); - } -} - -PrimExpr ExprMutator::VisitExpr_(const SelectNode* op) { - PrimExpr condition = this->VisitExpr(op->condition); - PrimExpr true_value = this->VisitExpr(op->true_value); - PrimExpr false_value = this->VisitExpr(op->false_value); - if (condition.same_as(op->condition) && - true_value.same_as(op->true_value) && - false_value.same_as(op->false_value)) { - return GetRef(op); - } else { - return SelectNode::make(condition, true_value, false_value); - } -} - -PrimExpr ExprMutator::VisitExpr_(const RampNode* op) { - PrimExpr base = this->VisitExpr(op->base); - PrimExpr stride = this->VisitExpr(op->stride); - if (base.same_as(op->base) && - stride.same_as(op->stride)) { - return GetRef(op); - } else { - return RampNode::make(base, stride, op->lanes); - } -} - -PrimExpr ExprMutator::VisitExpr_(const BroadcastNode* op) { - PrimExpr value = this->VisitExpr(op->value); - if (value.same_as(op->value)) { - return GetRef(op); - } else { - return BroadcastNode::make(value, op->lanes); - } -} - -PrimExpr ExprMutator::VisitExpr_(const ShuffleNode* op) { - auto fexpr = [this](const PrimExpr& e) { return this->VisitExpr(e); }; - auto vectors = MutateArray(op->vectors, fexpr); - if (vectors.same_as(op->vectors)) { - return GetRef(op); - } else { - return ShuffleNode::make(vectors, op->indices); - } -} -} // namespace ir +} // namespace tir } // namespace tvm diff --git a/src/pass/arg_binder.cc b/src/tir/pass/arg_binder.cc similarity index 96% rename from src/pass/arg_binder.cc rename to src/tir/pass/arg_binder.cc index 3b3bf4b1853e..bd35c768e6ac 100644 --- a/src/pass/arg_binder.cc +++ b/src/tir/pass/arg_binder.cc @@ -21,15 +21,15 @@ * \file arg_binder.cc * \brief Helper utility to match and bind arguments. 
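*
* A sketch of the intended use, under the assumption that the accessors
* keep their pre-rename names:
*
*   ArgBinder binder(&def_map);
*   binder.BindDLTensor(buffer, device_type, device_id, handle, "arg0");
*   // binder.init_nest() and binder.asserts() now hold the Let/Assert
*   // statements that decode and validate the incoming DLTensor.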
*/ -#include -#include +#include +#include #include #include "ir_util.h" #include "arg_binder.h" -#include "../arith/compute_expr.h" +#include "../../arith/compute_expr.h" namespace tvm { -namespace ir { +namespace tir { void BinderAddAssert(PrimExpr cond, const std::string& arg_name, @@ -189,10 +189,10 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, if (Bind_(buffer->data, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrData), arg_name + ".data", true)) { Var vptr(buffer->data); - def_handle_dtype_.Set(vptr, ir::TypeAnnotation(buffer->dtype)); + def_handle_dtype_.Set(vptr, tir::TypeAnnotation(buffer->dtype)); // mark alignment of external bufs init_nest_.emplace_back(AttrStmtNode::make( - vptr, ir::attr::storage_alignment, + vptr, tir::attr::storage_alignment, IntImm(DataType::Int(32), buffer->data_alignment), nop)); } @@ -211,7 +211,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, } // strides field Var v_strides(arg_name + ".strides", DataType::Handle()); - def_handle_dtype_.Set(v_strides, ir::TypeAnnotation(tvm_shape_type)); + def_handle_dtype_.Set(v_strides, tir::TypeAnnotation(tvm_shape_type)); init_nest_.emplace_back(LetStmtNode::make( v_strides, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrStrides), nop)); @@ -237,7 +237,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, << " expected to be compact array"; if (conds.size() != 0) { Stmt check = - AssertStmtNode::make(arith::ComputeReduce(conds, PrimExpr()), + AssertStmtNode::make(arith::ComputeReduce(conds, PrimExpr()), stride_err_msg.str(), EvaluateNode::make(0)); check = IfThenElseNode::make(NotNode::make(is_null), check, Stmt()); asserts_.emplace_back(SeqStmt({check, EvaluateNode::make(0)})); @@ -304,5 +304,5 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, arg_name + ".device_id", true); } -} // namespace ir +} // namespace tir } // namespace tvm diff --git a/src/pass/arg_binder.h b/src/tir/pass/arg_binder.h similarity index 96% rename from src/pass/arg_binder.h rename to src/tir/pass/arg_binder.h index 75006d68dfba..dfeb82853529 100644 --- a/src/pass/arg_binder.h +++ b/src/tir/pass/arg_binder.h @@ -21,17 +21,17 @@ * \file arg_binder.h * \brief Helper utility to match and bind arguments. */ -#ifndef TVM_PASS_ARG_BINDER_H_ -#define TVM_PASS_ARG_BINDER_H_ +#ifndef TVM_TIR_PASS_ARG_BINDER_H_ +#define TVM_TIR_PASS_ARG_BINDER_H_ -#include -#include +#include +#include #include #include #include namespace tvm { -namespace ir { +namespace tir { /*! * \brief Helper utility to generate match and bind of arguments. @@ -154,6 +154,6 @@ class ArgBinder { /*! \brief asserts generated */ std::vector asserts_; }; -} // namespace ir +} // namespace tir } // namespace tvm -#endif // TVM_PASS_ARG_BINDER_H_ +#endif // TVM_TIR_PASS_ARG_BINDER_H_ diff --git a/src/pass/bound_checker.cc b/src/tir/pass/bound_checker.cc similarity index 96% rename from src/pass/bound_checker.cc rename to src/tir/pass/bound_checker.cc index 439c8862c9c6..ee24d0f77673 100644 --- a/src/pass/bound_checker.cc +++ b/src/tir/pass/bound_checker.cc @@ -22,22 +22,22 @@ */ // Instrument checkers for out of the bounds access. 
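// Concretely, each indexed access is guarded by an assertion of roughly
// the shape assert(0 <= index && index < buffer_extent); as the hunk
// below shows, both sides are first simplified and then cast to Int(64),
// so a negative index cannot slip past an unsigned comparison.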
-#include -#include -#include +#include +#include +#include #include #include #include namespace tvm { -namespace ir { +namespace tir { class BoundCollector : public StmtVisitor { public: BoundCollector() {} void VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == ir::attr::buffer_bound) { + if (op->attr_key == tir::attr::buffer_bound) { if (const VarNode *key = op->node.as()) { mem_to_shape[key] = op->value; } @@ -173,8 +173,8 @@ class BoundChecker : public StmtExprMutator { } // Try to simplify index and bound. - index = ir::Simplify(index); - upper_bound = ir::Simplify(upper_bound); + index = tir::Simplify(index); + upper_bound = tir::Simplify(upper_bound); // Cast to the same type - signed, to be able to check lower bound. index = CastNode::make(DataType::Int(64), index); @@ -209,5 +209,5 @@ Stmt InstrumentBoundCheckers(Stmt stmt) { bound_collector(stmt); return BoundChecker(bound_collector.mem_to_shape)(std::move(stmt)); } -} // namespace ir +} // namespace tir } // namespace tvm diff --git a/src/pass/combine_context_call.cc b/src/tir/pass/combine_context_call.cc similarity index 95% rename from src/pass/combine_context_call.cc rename to src/tir/pass/combine_context_call.cc index 4561dba5469c..5f043bc8ac73 100644 --- a/src/pass/combine_context_call.cc +++ b/src/tir/pass/combine_context_call.cc @@ -22,13 +22,14 @@ * * \file combine_context_call.cc */ -#include -#include -#include +#include +#include +#include +#include #include namespace tvm { -namespace ir { +namespace tir { // Calculate the statistics of packed function. // These information are needed during codegen. @@ -113,5 +114,5 @@ LoweredFunc CombineContextCall(LoweredFunc f) { return LoweredFunc(n); } -} // namespace ir +} // namespace tir } // namespace tvm diff --git a/src/pass/coproc_sync.cc b/src/tir/pass/coproc_sync.cc similarity index 99% rename from src/pass/coproc_sync.cc rename to src/tir/pass/coproc_sync.cc index 4e68793cc875..38b7798eae11 100644 --- a/src/pass/coproc_sync.cc +++ b/src/tir/pass/coproc_sync.cc @@ -20,16 +20,16 @@ /*! * \file coproc_sync.cc */ -#include -#include -#include +#include +#include +#include #include #include #include "ir_util.h" #include "storage_access.h" namespace tvm { -namespace ir { +namespace tir { // Visitor to find touched set by co-processor scope. class CoProcTouchedBuffer : public StmtExprVisitor { @@ -677,5 +677,5 @@ Stmt CoProcSync(Stmt stmt) { return CoProcSyncInserter().Insert(std::move(stmt)); } -} // namespace ir +} // namespace tir } // namespace tvm diff --git a/src/pass/detect_device.cc b/src/tir/pass/detect_device.cc similarity index 88% rename from src/pass/detect_device.cc rename to src/tir/pass/detect_device.cc index 3578ce52803d..ee3a2e23b487 100644 --- a/src/pass/detect_device.cc +++ b/src/tir/pass/detect_device.cc @@ -21,18 +21,18 @@ * \file detect_device.cc */ -#include -#include "../pass/ir_util.h" +#include +#include "ir_util.h" namespace tvm { -namespace ir { +namespace tir { Stmt DecorateDeviceScope(Stmt stmt) { Stmt body = AttrStmtNode::make(make_zero(DataType::Int(32)), - ir::attr::device_scope, + tir::attr::device_scope, 0, stmt); return body; } -} // namespace ir +} // namespace tir } // namespace tvm diff --git a/src/pass/hoist_if_then_else.cc b/src/tir/pass/hoist_if_then_else.cc similarity index 98% rename from src/pass/hoist_if_then_else.cc rename to src/tir/pass/hoist_if_then_else.cc index 302abea6363f..1fd43ff72ffe 100644 --- a/src/pass/hoist_if_then_else.cc +++ b/src/tir/pass/hoist_if_then_else.cc @@ -20,19 +20,19 @@ /*! 
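* Hoists an if_then_else whose condition is invariant with respect to the
* enclosing loop variable out of that loop.  Illustratively:
*
*   for (i, 0, n) { if (cond) A else B }      // cond independent of i
*     ==>
*   if (cond) { for (i, 0, n) A } else { for (i, 0, n) B }
*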
* \file hoist_if_then_else.cc */ -#include -#include +#include +#include #include #include #include #include #include -#include "../arith/interval_set.h" -#include "../runtime/thread_storage_scope.h" +#include "../../arith/interval_set.h" +#include "../../runtime/thread_storage_scope.h" namespace tvm { -namespace ir { +namespace tir { using HoistMap = std::unordered_map>; using VarMap = std::unordered_map>; @@ -418,5 +418,5 @@ Stmt HoistIfThenElse(Stmt stmt) { return IfThenElseHoist().VisitAndMutate(stmt); } -} // namespace ir +} // namespace tir } // namespace tvm diff --git a/src/pass/infer_fragment.cc b/src/tir/pass/infer_fragment.cc similarity index 97% rename from src/pass/infer_fragment.cc rename to src/tir/pass/infer_fragment.cc index 6dfa509345be..0cb1b9686cbd 100644 --- a/src/pass/infer_fragment.cc +++ b/src/tir/pass/infer_fragment.cc @@ -21,17 +21,17 @@ * \brief Infer TensorCore metadata from tensor intrinsic. * \file tensorcore_fragment.cc */ -#include -#include -#include +#include +#include +#include #include #include #include "ir_util.h" #include "storage_access.h" -#include "../runtime/thread_storage_scope.h" +#include "../../runtime/thread_storage_scope.h" namespace tvm { -namespace ir { +namespace tir { // Get fragment information from tensor intrinsics class FragmentGetter : public StmtExprVisitor { @@ -220,5 +220,5 @@ LoweredFunc InferFragment(LoweredFunc f) { return LoweredFunc(n); } -} // namespace ir +} // namespace tir } // namespace tvm diff --git a/src/pass/inject_copy_intrin.cc b/src/tir/pass/inject_copy_intrin.cc similarity index 97% rename from src/pass/inject_copy_intrin.cc rename to src/tir/pass/inject_copy_intrin.cc index 29bb5b484774..4805caf5ac55 100644 --- a/src/pass/inject_copy_intrin.cc +++ b/src/tir/pass/inject_copy_intrin.cc @@ -22,13 +22,13 @@ * \file copy_intrin_rewrite.cc */ #include -#include -#include -#include -#include "../arith/pattern_match.h" +#include +#include +#include +#include "../../arith/pattern_match.h" namespace tvm { -namespace ir { +namespace tir { using runtime::PackedFunc; @@ -196,5 +196,5 @@ Stmt InjectCopyIntrin(Stmt stmt, return CopyIntrinInjector(pragma_key, flower_copy_fromto)(std::move(stmt)); } -} // namespace ir +} // namespace tir } // namespace tvm diff --git a/src/pass/inject_double_buffer.cc b/src/tir/pass/inject_double_buffer.cc similarity index 98% rename from src/pass/inject_double_buffer.cc rename to src/tir/pass/inject_double_buffer.cc index 691f6a79e667..b9aa5a9e697e 100644 --- a/src/pass/inject_double_buffer.cc +++ b/src/tir/pass/inject_double_buffer.cc @@ -21,14 +21,14 @@ * \brief Inject double buffering optimization for data fetch. * \file inject_double_buffer.cc */ -#include -#include -#include +#include +#include +#include #include "ir_util.h" -#include "../arith/compute_expr.h" +#include "../../arith/compute_expr.h" namespace tvm { -namespace ir { +namespace tir { // Detect double buffer variables. 
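// Double buffering doubles the storage of a marked buffer and alternates
// halves between iterations, so the fetch of tile i+1 can overlap the
// compute on tile i.  Schematically (illustrative indices only):
//
//   allocate buf[2 * size];
//   write buf[((i + 1) % 2) * size + j];   // prefetch the next tile
//   read  buf[(i % 2) * size + j];         // consume the current tile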
class DoubleBufferDetector : public StmtExprVisitor { @@ -273,5 +273,5 @@ class DoubleBufferInjector : public StmtExprMutator { Stmt InjectDoubleBuffer(Stmt stmt, int split_loop) { return DoubleBufferInjector(split_loop).Inject(stmt); } -} // namespace ir +} // namespace tir } // namespace tvm diff --git a/src/pass/inject_prefetch.cc b/src/tir/pass/inject_prefetch.cc similarity index 95% rename from src/pass/inject_prefetch.cc rename to src/tir/pass/inject_prefetch.cc index a2895d55a70d..f04d5d46fe71 100644 --- a/src/pass/inject_prefetch.cc +++ b/src/tir/pass/inject_prefetch.cc @@ -21,14 +21,14 @@ * \file inject_prefetch.cc */ // Inject prefetch op in HalideIR -#include -#include -#include +#include +#include +#include #include #include namespace tvm { -namespace ir { +namespace tir { using arith::IntSet; using arith::DomainTouched; @@ -90,5 +90,5 @@ Stmt InjectPrefetch(Stmt stmt) { return PrefetchInjector()(std::move(stmt)); } -} // namespace ir +} // namespace tir } // namespace tvm diff --git a/src/pass/inject_virtual_thread.cc b/src/tir/pass/inject_virtual_thread.cc similarity index 98% rename from src/pass/inject_virtual_thread.cc rename to src/tir/pass/inject_virtual_thread.cc index a0a67a785a5a..99e11491c4b1 100644 --- a/src/pass/inject_virtual_thread.cc +++ b/src/tir/pass/inject_virtual_thread.cc @@ -20,14 +20,14 @@ /*! * \file inject_virtual_thread.cc */ -#include -#include -#include +#include +#include +#include #include -#include "../arith/compute_expr.h" +#include "../../arith/compute_expr.h" namespace tvm { -namespace ir { +namespace tir { // If expression is touched by var. class ExprTouched final : public StmtExprVisitor { @@ -507,5 +507,5 @@ Stmt InjectVirtualThread(Stmt stmt) { return ConvertSSA(std::move(stmt)); } -} // namespace ir +} // namespace tir } // namespace tvm diff --git a/src/pass/inline.cc b/src/tir/pass/inline.cc similarity index 94% rename from src/pass/inline.cc rename to src/tir/pass/inline.cc index fad3f1766872..1b322964b873 100644 --- a/src/pass/inline.cc +++ b/src/tir/pass/inline.cc @@ -20,12 +20,13 @@ /*! * \file inline.cc */ -#include -#include -#include +#include +#include +#include +#include namespace tvm { -namespace ir { +namespace tir { // inliner to inline a function // the result may not be SSA, @@ -82,5 +83,5 @@ Stmt Inline(Stmt stmt, if (ret.same_as(stmt)) return ret; return ConvertSSA(ret); } -} // namespace ir +} // namespace tir } // namespace tvm diff --git a/src/pass/ir_deep_compare.cc b/src/tir/pass/ir_deep_compare.cc similarity index 99% rename from src/pass/ir_deep_compare.cc rename to src/tir/pass/ir_deep_compare.cc index 8c441510c51d..e45251fe8a4a 100644 --- a/src/pass/ir_deep_compare.cc +++ b/src/tir/pass/ir_deep_compare.cc @@ -20,11 +20,12 @@ /*! 
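* Deep structural comparison of TIR: tir::Equal is the boolean entry
* point and Compare returns a total order over expressions, e.g.
* (illustrative) tir::Equal(x + y, x + y) is true because the two trees
* have identical structure and leaves, even though they are distinct
* node objects.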
* \file ir_deep_compare.cc */ -#include -#include +#include +#include +#include namespace tvm { -namespace ir { +namespace tir { using ExprComparator = ExprFunctor; using StmtComparator = StmtFunctor; @@ -455,5 +456,5 @@ int Compare(const PrimExpr& lhs, const PrimExpr& rhs) { return IRDeepCompare().Compare(lhs, rhs); } -} // namespace ir +} // namespace tir } // namespace tvm diff --git a/src/pass/ir_util.cc b/src/tir/pass/ir_util.cc similarity index 98% rename from src/pass/ir_util.cc rename to src/tir/pass/ir_util.cc index 8ecfbfff5b47..7223c5b1c9e6 100644 --- a/src/pass/ir_util.cc +++ b/src/tir/pass/ir_util.cc @@ -24,7 +24,7 @@ #include "ir_util.h" namespace tvm { -namespace ir { +namespace tir { Stmt MergeNest(const std::vector& nest, Stmt body) { // use reverse iteration @@ -80,5 +80,5 @@ Stmt MergeNest(const std::vector >& nest, Stmt body) { return body; } -} // namespace ir +} // namespace tir } // namespace tvm diff --git a/src/pass/ir_util.h b/src/tir/pass/ir_util.h similarity index 96% rename from src/pass/ir_util.h rename to src/tir/pass/ir_util.h index f1a01953d5bd..d8da61fdd961 100644 --- a/src/pass/ir_util.h +++ b/src/tir/pass/ir_util.h @@ -21,16 +21,16 @@ * \file ir_util.h * \brief Helper functions to construct and compose IR nodes. */ -#ifndef TVM_PASS_IR_UTIL_H_ -#define TVM_PASS_IR_UTIL_H_ +#ifndef TVM_TIR_PASS_IR_UTIL_H_ +#define TVM_TIR_PASS_IR_UTIL_H_ -#include -#include +#include +#include #include #include namespace tvm { -namespace ir { +namespace tir { /*! * \brief combine the nest stmt, whose body is not defined. * \param nest A list of For and LetStmt, whose body is not defined. @@ -190,6 +190,6 @@ inline bool GetRamp1Base(PrimExpr index, int lanes, PrimExpr *base) { *base = r->base; return true; } -} // namespace ir +} // namespace tir } // namespace tvm -#endif // TVM_PASS_IR_UTIL_H_ +#endif // TVM_TIR_PASS_IR_UTIL_H_ diff --git a/src/pass/lift_attr_scope.cc b/src/tir/pass/lift_attr_scope.cc similarity index 98% rename from src/pass/lift_attr_scope.cc rename to src/tir/pass/lift_attr_scope.cc index 5aba355b7003..2874ac2b19de 100644 --- a/src/pass/lift_attr_scope.cc +++ b/src/tir/pass/lift_attr_scope.cc @@ -23,12 +23,12 @@ * the body contains the same scope. * \file lift_attr_scope.cc */ -#include -#include +#include +#include #include "ir_util.h" namespace tvm { -namespace ir { +namespace tir { // NOTE: this optimization can only be applied // to a few specified attr keys @@ -192,5 +192,5 @@ Stmt LiftAttrScope(Stmt stmt, std::string attr_key) { return AttrScopeLifter(attr_key).Lift(std::move(stmt)); } -} // namespace ir +} // namespace tir } // namespace tvm diff --git a/src/pass/loop_partition.cc b/src/tir/pass/loop_partition.cc similarity index 98% rename from src/pass/loop_partition.cc rename to src/tir/pass/loop_partition.cc index 7af2240294dd..d1fa46e38860 100644 --- a/src/pass/loop_partition.cc +++ b/src/tir/pass/loop_partition.cc @@ -20,17 +20,17 @@ /*! 
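* Partitions a loop whose body carries conditions that are provable on a
* prefix or suffix of the iteration space, so the main loop becomes
* condition-free.  Illustratively:
*
*   for (i, 0, n) { if (likely(i < 2)) A else B }
*     ==>
*   for (i, 0, 2) { A }   followed by   for (i, 2, n) { B }
*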
* \file loop_partition.cc */ -#include -#include -#include +#include +#include +#include #include #include #include -#include "../arith/interval_set.h" -#include "../runtime/thread_storage_scope.h" +#include "../../arith/interval_set.h" +#include "../../runtime/thread_storage_scope.h" namespace tvm { -namespace ir { +namespace tir { using arith::IntSet; using arith::DeduceBound; @@ -500,7 +500,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node, Stmt pre_stmt; bool pre_stmt_recurse = true; if (middle_interval_i->HasLowerBound()) { - body_begin = ir::Simplify(middle_interval.min()); + body_begin = tir::Simplify(middle_interval.min()); if (!analyzer_.CanProve(body_begin == min)) { PrimExpr cond = (body_begin - min >= 0); if (!analyzer_.CanProve(cond)) { @@ -525,7 +525,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node, Stmt post_stmt; bool post_stmt_recurse = true; if (middle_interval_i->HasUpperBound()) { - post_doubt_begin = ir::Simplify(middle_interval.max() + 1); + post_doubt_begin = tir::Simplify(middle_interval.max() + 1); if (!analyzer_.CanProve(middle_interval.max() == max)) { // require the extent to be non-negative PrimExpr cond = (max - post_doubt_begin + 1 >= 0); @@ -610,5 +610,5 @@ Stmt LoopPartition(Stmt stmt, bool split_const_loop) { return stmt; } -} // namespace ir +} // namespace tir } // namespace tvm diff --git a/src/pass/lower_custom_datatypes.cc b/src/tir/pass/lower_custom_datatypes.cc similarity index 97% rename from src/pass/lower_custom_datatypes.cc rename to src/tir/pass/lower_custom_datatypes.cc index b494328f6366..77c7e8c35dd9 100644 --- a/src/pass/lower_custom_datatypes.cc +++ b/src/tir/pass/lower_custom_datatypes.cc @@ -21,12 +21,12 @@ * \brief Pass for lowering custom datatypes */ -#include -#include -#include "../codegen/datatype/registry.h" +#include +#include +#include "../../codegen/datatype/registry.h" namespace tvm { -namespace ir { +namespace tir { /*! * \brief Helper mutator to implement lowering of custom datatypes. @@ -135,5 +135,5 @@ LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target) { return LoweredFunc(n); } -} // namespace ir +} // namespace tir } // namespace tvm diff --git a/src/pass/lower_intrin.cc b/src/tir/pass/lower_intrin.cc similarity index 95% rename from src/pass/lower_intrin.cc rename to src/tir/pass/lower_intrin.cc index 4e1ea8d0b2dc..d39624de5488 100644 --- a/src/pass/lower_intrin.cc +++ b/src/tir/pass/lower_intrin.cc @@ -21,18 +21,18 @@ * Lower intrinsic calls and ops to device specific ir when possible. * \file lower_intrin.cc */ -#include -#include +#include +#include #include -#include +#include #include #include "ir_util.h" -#include "../arith/pattern_match.h" -#include "../arith/ir_mutator_with_analyzer.h" +#include "../../arith/pattern_match.h" +#include "../../arith/ir_mutator_with_analyzer.h" namespace tvm { -namespace ir { +namespace tir { class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { public: @@ -103,7 +103,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { // equivalent to rdiv + (rmod >= 0 ? 0: -1); return rdiv + (rmod >> make_const(dtype, dtype.bits() - 1)); } else { - return ir::SelectNode::make(rmod >= 0 , rdiv, rdiv - make_const(dtype, 1)); + return tir::SelectNode::make(rmod >= 0 , rdiv, rdiv - make_const(dtype, 1)); } } } else { @@ -113,7 +113,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { // b < 0 => (rmod <= 0 ? 
rdiv : rdiv - 1) PrimExpr rdiv = truncdiv(op->a, op->b); PrimExpr rmod = truncmod(op->a, op->b); - return ir::SelectNode::make( + return tir::SelectNode::make( (op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rdiv, rdiv - make_const(dtype, 1)); } @@ -152,7 +152,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { // -> rmod >= 0 ? 0 : b return rmod + (op->b & (rmod >> make_const(dtype, dtype.bits() - 1))); } else { - return ir::SelectNode::make(rmod >= 0, rmod, rmod + op->b); + return tir::SelectNode::make(rmod >= 0, rmod, rmod + op->b); } } } else { @@ -163,7 +163,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { // b > 0 && rmod < 0 -> rmod + b // b < 0 && rmod < 0 -> rmod // b < 0 && rmod > 0 -> rmod + b - return ir::SelectNode::make( + return tir::SelectNode::make( (op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rmod, rmod + op->b); } @@ -296,5 +296,5 @@ LowerIntrin(LoweredFunc f, const std::string& target) { TVM_REGISTER_GLOBAL("ir_pass._LowerIntrinStmt") .set_body_typed(LowerIntrinStmt); -} // namespace ir +} // namespace tir } // namespace tvm diff --git a/src/pass/lower_thread_allreduce.cc b/src/tir/pass/lower_thread_allreduce.cc similarity index 97% rename from src/pass/lower_thread_allreduce.cc rename to src/tir/pass/lower_thread_allreduce.cc index 7b1378f8686c..259a3a62d24b 100644 --- a/src/pass/lower_thread_allreduce.cc +++ b/src/tir/pass/lower_thread_allreduce.cc @@ -21,16 +21,16 @@ * Lower allreduce to device implementable ir. * \file lower_thread_allreduce.cc */ -#include -#include -#include +#include +#include +#include #include #include "ir_util.h" -#include "../arith/compute_expr.h" -#include "../runtime/thread_storage_scope.h" +#include "../../arith/compute_expr.h" +#include "../../runtime/thread_storage_scope.h" namespace tvm { -namespace ir { +namespace tir { class ThreadAllreduceBuilder final : public StmtExprMutator { public: @@ -318,7 +318,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // The local buffer index. static PrimExpr BufIndex(PrimExpr reduce_index, PrimExpr group_index, int reduce_extent) { if (!is_zero(group_index)) { - return ir::Simplify(group_index * reduce_extent + reduce_index); + return tir::Simplify(group_index * reduce_extent + reduce_index); } else { return reduce_index; } @@ -342,5 +342,5 @@ LowerThreadAllreduce(LoweredFunc f, int warp_size) { n->body = ThreadAllreduceBuilder(warp_size)(n->body); return LoweredFunc(n); } -} // namespace ir +} // namespace tir } // namespace tvm diff --git a/src/pass/lower_tvm_builtin.cc b/src/tir/pass/lower_tvm_builtin.cc similarity index 98% rename from src/pass/lower_tvm_builtin.cc rename to src/tir/pass/lower_tvm_builtin.cc index 13e2504a7bf0..106a604abce0 100644 --- a/src/pass/lower_tvm_builtin.cc +++ b/src/tir/pass/lower_tvm_builtin.cc @@ -21,15 +21,15 @@ * Lower TVM related buildin intrinsics such as packed call. 
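*
* (Note: the lower_intrin.cc hunks above carry the branch-free floordiv
* lowering used when the divisor is known positive: for a two's-complement
* remainder, rmod >> (bits - 1) is 0 when rmod >= 0 and -1 otherwise, so
*   floordiv(a, b) = truncdiv(a, b) + (truncmod(a, b) >> (bits - 1))
* replaces the Select node that would otherwise be emitted.)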
* \file lower_tvm_builtin.cc */ -#include -#include -#include +#include +#include +#include #include #include "ir_util.h" -#include "../arith/compute_expr.h" +#include "../../arith/compute_expr.h" namespace tvm { -namespace ir { +namespace tir { inline PrimExpr ConstInt32(size_t index) { CHECK_LE(index, std::numeric_limits::max()); @@ -374,5 +374,5 @@ LoweredFunc LowerTVMBuiltin(LoweredFunc f) { return LoweredFunc(n); } -} // namespace ir +} // namespace tir } // namespace tvm diff --git a/src/pass/lower_warp_memory.cc b/src/tir/pass/lower_warp_memory.cc similarity index 98% rename from src/pass/lower_warp_memory.cc rename to src/tir/pass/lower_warp_memory.cc index 8da07f02ca29..385a5b454ca5 100644 --- a/src/pass/lower_warp_memory.cc +++ b/src/tir/pass/lower_warp_memory.cc @@ -28,16 +28,16 @@ #include #include -#include -#include -#include +#include +#include +#include #include #include "ir_util.h" -#include "../arith/compute_expr.h" -#include "../runtime/thread_storage_scope.h" +#include "../../arith/compute_expr.h" +#include "../../runtime/thread_storage_scope.h" namespace tvm { -namespace ir { +namespace tir { // Rewrite Rule // @@ -388,5 +388,5 @@ LowerWarpMemory(LoweredFunc f, int warp_size) { return LoweredFunc(n); } -} // namespace ir +} // namespace tir } // namespace tvm diff --git a/src/pass/make_api.cc b/src/tir/pass/make_api.cc similarity index 98% rename from src/pass/make_api.cc rename to src/tir/pass/make_api.cc index fa2965168c85..70ea2a21a869 100644 --- a/src/pass/make_api.cc +++ b/src/tir/pass/make_api.cc @@ -20,10 +20,10 @@ /*! * \file make_api.cc Build API function. */ -#include -#include -#include -#include +#include +#include +#include +#include #include #include #include @@ -33,7 +33,7 @@ #include "arg_binder.h" namespace tvm { -namespace ir { +namespace tir { inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) { return AssertStmtNode::make(lhs == rhs, msg, EvaluateNode::make(0)); @@ -255,7 +255,7 @@ class DeviceTypeBinder: public StmtExprMutator { // eager check NE for device check PrimExpr res = StmtExprMutator::VisitExpr_(op); op = res.as(); - if (ir::Equal(op->a, op->b)) { + if (tir::Equal(op->a, op->b)) { return make_const(op->dtype, false); } return res; @@ -281,5 +281,5 @@ LoweredFunc BindDeviceType(LoweredFunc f, return LoweredFunc(n); } -} // namespace ir +} // namespace tir } // namespace tvm diff --git a/src/pass/remap_thread_axis.cc b/src/tir/pass/remap_thread_axis.cc similarity index 95% rename from src/pass/remap_thread_axis.cc rename to src/tir/pass/remap_thread_axis.cc index 4201e785064a..4fa5dd3cbe9b 100644 --- a/src/pass/remap_thread_axis.cc +++ b/src/tir/pass/remap_thread_axis.cc @@ -20,14 +20,14 @@ /*!
* \file remap_thread_axis.cc */ -#include -#include -#include +#include +#include +#include #include namespace tvm { -namespace ir { +namespace tir { // Mutator to change the read pattern class ThreadAxisRewriter : private StmtExprMutator { @@ -96,5 +96,5 @@ RemapThreadAxis(LoweredFunc f, Map thread_map) { return LoweredFunc(n); } -} // namespace ir +} // namespace tir } // namespace tvm diff --git a/src/pass/remove_no_op.cc b/src/tir/pass/remove_no_op.cc similarity index 97% rename from src/pass/remove_no_op.cc rename to src/tir/pass/remove_no_op.cc index eecbe30828d1..3b9f823517ac 100644 --- a/src/pass/remove_no_op.cc +++ b/src/tir/pass/remove_no_op.cc @@ -21,13 +21,13 @@ * \file remove_no_op.cc * \brief Remove no op from the stmt */ -#include -#include -#include +#include +#include +#include #include namespace tvm { -namespace ir { +namespace tir { // Mark the statement of each stage. class NoOpRemover : public StmtMutator { @@ -151,5 +151,5 @@ class NoOpRemover : public StmtMutator { Stmt RemoveNoOp(Stmt stmt) { return NoOpRemover()(std::move(stmt)); } -} // namespace ir +} // namespace tir } // namespace tvm diff --git a/src/pass/rewrite_unsafe_select.cc b/src/tir/pass/rewrite_unsafe_select.cc similarity index 97% rename from src/pass/rewrite_unsafe_select.cc rename to src/tir/pass/rewrite_unsafe_select.cc index 9fb19cc4b308..501649237090 100644 --- a/src/pass/rewrite_unsafe_select.cc +++ b/src/tir/pass/rewrite_unsafe_select.cc @@ -21,12 +21,12 @@ * \file rewrite_unsafe_select.cc * \brief Rewrite unsafe select expression. */ -#include -#include -#include +#include +#include +#include namespace tvm { -namespace ir { +namespace tir { // For now, rewrite unsafe select expression to if_then_else @@ -132,5 +132,5 @@ Stmt RewriteUnsafeSelect(Stmt stmt) { return UnsafeSelectRewriter()(std::move(stmt)); } -} // namespace ir +} // namespace tir } // namespace tvm diff --git a/src/pass/simple_passes.cc b/src/tir/pass/simple_passes.cc similarity index 96% rename from src/pass/simple_passes.cc rename to src/tir/pass/simple_passes.cc index 9737f7047ab6..81145c89b6a0 100644 --- a/src/pass/simple_passes.cc +++ b/src/tir/pass/simple_passes.cc @@ -21,12 +21,12 @@ * \file simple_passes.cc * \brief Implementation of simple passes */ -#include -#include -#include +#include +#include +#include namespace tvm { -namespace ir { +namespace tir { class IRSideEffect : public ExprVisitor { public: @@ -159,5 +159,5 @@ bool ExprUseVar(const PrimExpr& e, return visitor.use_var_; } -} // namespace ir +} // namespace tir } // namespace tvm diff --git a/src/pass/skip_assert.cc b/src/tir/pass/skip_assert.cc similarity index 91% rename from src/pass/skip_assert.cc rename to src/tir/pass/skip_assert.cc index d2c4dc70e9f6..14f59f090cac 100644 --- a/src/pass/skip_assert.cc +++ b/src/tir/pass/skip_assert.cc @@ -17,12 +17,12 @@ * under the License.
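
The rewrite in rewrite_unsafe_select.cc above exists because a Select, like a hardware vector select, evaluates both arms before choosing; an arm that can fault must instead become an `if_then_else` intrinsic that evaluates only the taken branch. A standalone sketch of the semantic distinction the pass cares about (illustrative functions, not TVM API):

```cpp
#include <cassert>

// "Select" semantics: both values are computed before the choice is made,
// the way a vector select behaves. An arm that can fault is "unsafe" here.
int select_eager(bool cond, int a, int b) { return cond ? a : b; }

// "if_then_else" semantics: only the taken branch is ever evaluated.
int if_then_else_lazy(bool cond, int num, int den) {
  return cond ? num / den : 0;
}

int main() {
  int den = 0;
  // select_eager(false, 1 / den, 0) would divide by zero while building the
  // argument list; the rewritten, lazy form never touches the untaken arm.
  assert(if_then_else_lazy(false, 1, den) == 0);
  return 0;
}
```
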
*/ -#include -#include -#include +#include +#include +#include namespace tvm { -namespace ir { +namespace tir { class AssertSkipper : public StmtMutator { public: @@ -43,5 +43,5 @@ LoweredFunc SkipAssert(LoweredFunc f) { return LoweredFunc(n); } -} // namespace ir +} // namespace tir } // namespace tvm diff --git a/src/pass/split_host_device.cc b/src/tir/pass/split_host_device.cc similarity index 97% rename from src/pass/split_host_device.cc rename to src/tir/pass/split_host_device.cc index 7309c724099b..519101fe49ac 100644 --- a/src/pass/split_host_device.cc +++ b/src/tir/pass/split_host_device.cc @@ -21,15 +21,15 @@ * \file split_host_device.cc * \brief Split device function from host. */ -#include -#include -#include -#include +#include +#include +#include +#include #include #include namespace tvm { -namespace ir { +namespace tir { // use/def analysis, also delete unreferenced lets class IRUseDefAnalysis : public StmtExprMutator { @@ -253,5 +253,5 @@ Array SplitHostDevice(LoweredFunc func) { return HostDeviceSplitter().Split(func); } -} // namespace ir +} // namespace tir } // namespace tvm diff --git a/src/pass/ssa.cc b/src/tir/pass/ssa.cc similarity index 98% rename from src/pass/ssa.cc rename to src/tir/pass/ssa.cc index 50cdc528f207..833702e8202e 100644 --- a/src/pass/ssa.cc +++ b/src/tir/pass/ssa.cc @@ -23,15 +23,15 @@ * SSA requires each variable to be defined only once. * \file ssa.cc */ -#include -#include -#include +#include +#include +#include #include #include #include namespace tvm { -namespace ir { +namespace tir { namespace { class IRVerifySSA final : public StmtExprVisitor { public: @@ -207,5 +207,5 @@ Stmt ConvertSSA(Stmt stmt) { return IRConvertSSA()(std::move(stmt)); } -} // namespace ir +} // namespace tir } // namespace tvm diff --git a/src/pass/storage_access.cc b/src/tir/pass/storage_access.cc similarity index 98% rename from src/pass/storage_access.cc rename to src/tir/pass/storage_access.cc index 43acbafc8554..aaee58234688 100644 --- a/src/pass/storage_access.cc +++ b/src/tir/pass/storage_access.cc @@ -20,16 +20,16 @@ /*! * \file storage_access.cc */ -#include +#include #include #include #include -#include "ir_util.h" #include "storage_access.h" -#include "../arith/compute_expr.h" +#include "ir_util.h" +#include "../../arith/compute_expr.h" namespace tvm { -namespace ir { +namespace tir { void StorageAccessVisitor::VisitExpr_(const LoadNode* op) { const VarNode* buf = op->buffer_var.as(); @@ -320,7 +320,7 @@ class StorageAccessInfoLower : public StmtExprMutator { int dtype_bits = dtype.bits() * dtype.lanes(); CHECK_EQ(info->unit_bits % dtype_bits, 0); return cast(ptr_type, - ir::Simplify(offset / make_const( + tir::Simplify(offset / make_const( offset.dtype(), info->unit_bits / dtype_bits))); } // The storage entry. @@ -346,5 +346,5 @@ LoweredFunc LowerDeviceStorageAccessInfo(LoweredFunc f) { return LoweredFunc(n); } -} // namespace ir +} // namespace tir } // namespace tvm diff --git a/src/pass/storage_access.h b/src/tir/pass/storage_access.h similarity index 94% rename from src/pass/storage_access.h rename to src/tir/pass/storage_access.h index aea9f1e595a5..d3614b8fff4e 100644 --- a/src/pass/storage_access.h +++ b/src/tir/pass/storage_access.h @@ -21,19 +21,19 @@ * \file storage_access.h * \brief Common data structure for storage access analysis.
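
ConvertSSA in ssa.cc above restores the single-definition property by renaming repeated definitions. A minimal sketch of the renaming discipline, assuming one counter per variable name (the real pass walks Let/For/Allocate bindings in the IR; this is just the naming rule):

```cpp
#include <cassert>
#include <string>
#include <unordered_map>

// On a repeated definition, mint a fresh name and bump the counter so every
// later definition of the same source name stays distinct.
struct SSARenamer {
  std::unordered_map<std::string, int> def_count;

  std::string Define(const std::string& v) {
    int n = def_count[v]++;
    return n == 0 ? v : v + "_" + std::to_string(n);
  }
};

int main() {
  SSARenamer r;
  assert(r.Define("x") == "x");    // first definition keeps its name
  assert(r.Define("x") == "x_1");  // redefinition gets a fresh name
  assert(r.Define("y") == "y");
  return 0;
}
```
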
*/ -#ifndef TVM_PASS_STORAGE_ACCESS_H_ -#define TVM_PASS_STORAGE_ACCESS_H_ +#ifndef TVM_TIR_PASS_STORAGE_ACCESS_H_ +#define TVM_TIR_PASS_STORAGE_ACCESS_H_ -#include +#include #include -#include -#include +#include +#include #include #include -#include "../runtime/thread_storage_scope.h" +#include "../../runtime/thread_storage_scope.h" namespace tvm { -namespace ir { +namespace tir { using runtime::StorageScope; using runtime::StorageRank; @@ -148,6 +148,6 @@ class StorageAccessVisitor : public StmtExprVisitor { std::unordered_map storage_scope_; }; -} // namespace ir +} // namespace tir } // namespace tvm -#endif // TVM_PASS_STORAGE_ACCESS_H_ +#endif // TVM_TIR_PASS_STORAGE_ACCESS_H_ diff --git a/src/pass/storage_flatten.cc b/src/tir/pass/storage_flatten.cc similarity index 97% rename from src/pass/storage_flatten.cc rename to src/tir/pass/storage_flatten.cc index b5067658c993..bed487942462 100644 --- a/src/pass/storage_flatten.cc +++ b/src/tir/pass/storage_flatten.cc @@ -23,24 +23,24 @@ // Flattens storage from multi-dimensional array to 1D // buffer access as in Halide pipeline. #include -#include -#include +#include +#include #include -#include -#include -#include -#include +#include +#include +#include +#include #include #include #include #include "ir_util.h" #include "arg_binder.h" -#include "../arith/compute_expr.h" -#include "../arith/ir_visitor_with_analyzer.h" -#include "../runtime/thread_storage_scope.h" +#include "../../arith/compute_expr.h" +#include "../../arith/ir_visitor_with_analyzer.h" +#include "../../runtime/thread_storage_scope.h" namespace tvm { -namespace ir { +namespace tir { using runtime::StorageRank; using runtime::StorageScope; @@ -150,7 +150,7 @@ class StorageFlattener : public StmtExprMutator { if (create_bound_attributes_ && shape_collector_.size()) { for (size_t i = 0; i < shape_collector_.size(); ++i) { body = AttrStmtNode::make( - shape_collector_[i].first, ir::attr::buffer_bound, + shape_collector_[i].first, tir::attr::buffer_bound, MakeBound(e.buffer->dtype, shape_collector_[i].second), body); } } @@ -210,7 +210,7 @@ class StorageFlattener : public StmtExprMutator { PrimExpr factor = make_const(stride.dtype(), avec[dim].align_factor); PrimExpr offset = make_const(stride.dtype(), avec[dim].align_offset); stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor); - stride = ir::Simplify(stride); + stride = tir::Simplify(stride); } rstrides.push_back(stride); stride = stride * shape[dim]; @@ -255,7 +255,7 @@ class StorageFlattener : public StmtExprMutator { StringImmNode::make(e.buffer->scope), ret); if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) { - ret = AttrStmtNode::make(e.buffer->data, ir::attr::buffer_bound, + ret = AttrStmtNode::make(e.buffer->data, tir::attr::buffer_bound, MakeBound(e.buffer->dtype, e.buffer->shape), ret); } return ret; @@ -539,5 +539,5 @@ Stmt StorageFlatten(Stmt stmt, Map extern_buffer, return stmt; } -} // namespace ir +} // namespace tir } // namespace tvm diff --git a/src/pass/storage_rewrite.cc b/src/tir/pass/storage_rewrite.cc similarity index 99% rename from src/pass/storage_rewrite.cc rename to src/tir/pass/storage_rewrite.cc index 8b55dab31e9c..98410336db61 100644 --- a/src/pass/storage_rewrite.cc +++ b/src/tir/pass/storage_rewrite.cc @@ -23,19 +23,19 @@ * Re-write data access to enable memory sharing when possible. 
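
The stride padding in storage_flatten.cc above, `stride + indexmod(factor + offset - indexmod(stride, factor), factor)`, bumps a stride up to the next value congruent to `offset` modulo `factor`; with a zero offset this is the familiar round-up-to-multiple. A standalone check of the formula (assumes `factor > 0` and non-negative inputs):

```cpp
#include <cassert>

// Pad `stride` to the next value with stride % factor == offset % factor.
long AlignStride(long stride, long factor, long offset) {
  return stride + (factor + offset - stride % factor) % factor;
}

int main() {
  assert(AlignStride(10, 8, 0) == 16);  // round up to a multiple of 8
  assert(AlignStride(16, 8, 0) == 16);  // already aligned: unchanged
  assert(AlignStride(10, 8, 3) == 11);  // 11 % 8 == 3, the requested offset
  return 0;
}
```
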
*/ #include -#include -#include -#include +#include +#include +#include #include #include #include #include #include "ir_util.h" -#include "../arith/compute_expr.h" -#include "../runtime/thread_storage_scope.h" +#include "../../arith/compute_expr.h" +#include "../../runtime/thread_storage_scope.h" namespace tvm { -namespace ir { +namespace tir { using runtime::StorageRank; using runtime::StorageScope; @@ -311,7 +311,7 @@ class InplaceOpVerifier : public StmtExprVisitor { if (src_ == buf) { if (store_ == nullptr || store_->value.dtype() != op->dtype || - !ir::Equal(store_->index, op->index)) { + !tir::Equal(store_->index, op->index)) { result_ = false; return; } } @@ -622,7 +622,7 @@ class StoragePlanRewriter : public StmtExprMutator { if (!divided) { combo_size = combo_size + make_const(DataType::Int(32), 1); } - combo_size = ir::Simplify(combo_size); + combo_size = tir::Simplify(combo_size); e->new_alloc = AllocateNode::make( e->alloc_var, alloc_type, {combo_size}, const_true(), EvaluateNode::make(0)); @@ -1020,5 +1020,5 @@ Stmt StorageRewrite(Stmt stmt) { stmt = StoragePlanRewriter().Rewrite(std::move(stmt), true); return VectorAllocRewriter()(std::move(stmt)); } -} // namespace ir +} // namespace tir } // namespace tvm diff --git a/src/pass/storage_sync.cc b/src/tir/pass/storage_sync.cc similarity index 98% rename from src/pass/storage_sync.cc rename to src/tir/pass/storage_sync.cc index 2358ce999231..0f9af3ca48db 100644 --- a/src/pass/storage_sync.cc +++ b/src/tir/pass/storage_sync.cc @@ -20,17 +20,17 @@ /*! * \file storage_sync.cc */ -#include -#include -#include +#include +#include +#include #include #include #include "ir_util.h" #include "storage_access.h" -#include "../runtime/thread_storage_scope.h" +#include "../../runtime/thread_storage_scope.h" namespace tvm { -namespace ir { +namespace tir { class ThreadSyncPlanner : public StorageAccessVisitor { public: @@ -375,5 +375,5 @@ LoweredFunc ThreadSync(LoweredFunc f, std::string storage_scope) { return LoweredFunc(n); } -} // namespace ir +} // namespace tir } // namespace tvm diff --git a/src/pass/tensor_core.cc b/src/tir/pass/tensor_core.cc similarity index 98% rename from src/pass/tensor_core.cc rename to src/tir/pass/tensor_core.cc index bf36b0a8ffdb..edf1400aa7e3 100644 --- a/src/pass/tensor_core.cc +++ b/src/tir/pass/tensor_core.cc @@ -21,23 +21,23 @@ * \file tensor_core.cc */ // IR Passes for TensorCore CodeGen -#include -#include +#include +#include #include -#include -#include -#include -#include +#include +#include +#include +#include #include #include #include #include #include "ir_util.h" -#include "../arith/compute_expr.h" -#include "../runtime/thread_storage_scope.h" +#include "../../arith/compute_expr.h" +#include "../../runtime/thread_storage_scope.h" namespace tvm { -namespace ir { +namespace tir { using namespace top; using runtime::StorageRank; @@ -498,7 +498,7 @@ class BufferAnalyser : public StmtExprVisitor { return; } auto index = rel_index[i]; - auto simplified_index = ir::Simplify(index); + auto simplified_index = tir::Simplify(index); index_visitor(simplified_index); } @@ -601,7 +601,7 @@ class BufferAnalyser : public StmtExprVisitor { index_visitor.scaling_factor_ = shape->value; } auto index = rel_index[i]; - auto simplified_index = ir::Simplify(index); + auto simplified_index = tir::Simplify(index); index_visitor(simplified_index); } } @@ -635,7 +635,7 @@ class BufferAnalyser : public StmtExprVisitor { PrimExpr offset = make_const(stride.dtype(), avec[dim].align_offset); stride = stride + \ 
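
The `combo_size` adjustment in storage_rewrite.cc above is a ceiling division: the merged allocation's total bit width is divided by the element width and rounded up when it does not divide evenly. In plain C++ (a sketch of just the arithmetic, not the pool-merging logic around it):

```cpp
#include <cassert>

// Number of elements needed to hold total_bits, rounding up on a remainder
// (the "+ 1 if !divided" branch in the hunk above).
long ElemCount(long total_bits, long elem_bits) {
  long combo = total_bits / elem_bits;
  if (total_bits % elem_bits != 0) combo += 1;
  return combo;
}

int main() {
  assert(ElemCount(256, 32) == 8);  // divides evenly
  assert(ElemCount(260, 32) == 9);  // remainder forces one extra element
  return 0;
}
```
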
indexmod(factor + offset - indexmod(stride, factor), factor); - stride = ir::Simplify(stride); + stride = tir::Simplify(stride); } rstrides.push_back(stride); stride = stride * shape[dim]; @@ -1193,5 +1193,5 @@ Stmt RewriteForTensorCore(Stmt stmt, return TensorCoreIRMutator(schedule_analyser, buffer_analyser)(std::move(stmt)); } -} // namespace ir +} // namespace tir } // namespace tvm diff --git a/src/pass/unroll_loop.cc b/src/tir/pass/unroll_loop.cc similarity index 96% rename from src/pass/unroll_loop.cc rename to src/tir/pass/unroll_loop.cc index 2658e76880a2..3d669d03f1d6 100644 --- a/src/pass/unroll_loop.cc +++ b/src/tir/pass/unroll_loop.cc @@ -22,16 +22,16 @@ * \file unroll_loop.cc */ // Unrolls the loop as in Halide pipeline. -#include -#include -#include +#include +#include +#include #include #include #include -#include "../arith/compute_expr.h" +#include "../../arith/compute_expr.h" namespace tvm { -namespace ir { +namespace tir { class LoopUnroller : public StmtExprMutator { public: @@ -157,7 +157,7 @@ class LoopUnroller : public StmtExprMutator { // returns the extent of the loop if it's a constant integer, otherwise return -1 int GetExtent(const ForNode* op) { // constant folding. - PrimExpr extent = ir::Simplify(op->extent); + PrimExpr extent = tir::Simplify(op->extent); const IntImmNode *v1 = extent.as(); int value = -1; if (v1 != nullptr) { @@ -207,5 +207,5 @@ Stmt UnrollLoopExplicitly(Stmt stmt) { return LoopUnroller(0, 0, 0, false).Unroll(op); } -} // namespace ir +} // namespace tir } // namespace tvm diff --git a/src/pass/vectorize_loop.cc b/src/tir/pass/vectorize_loop.cc similarity index 99% rename from src/pass/vectorize_loop.cc rename to src/tir/pass/vectorize_loop.cc index 5e5e427b126e..d62bd1f2584e 100644 --- a/src/pass/vectorize_loop.cc +++ b/src/tir/pass/vectorize_loop.cc @@ -21,17 +21,17 @@ * \file vectorize_loop.cc */ // Loop vectorizer as in Halide pipeline. -#include -#include -#include +#include +#include +#include #include #include #include #include -#include "../arith/compute_expr.h" +#include "../../arith/compute_expr.h" namespace tvm { -namespace ir { +namespace tir { inline PrimExpr BroadcastTo(PrimExpr e, int lanes) { if (e.dtype().lanes() == lanes) return e; @@ -556,5 +556,5 @@ Stmt SkipVectorize(Stmt stmt) { return VectorizeSkipper()(std::move(stmt)); } -} // namespace ir +} // namespace tir } // namespace tvm diff --git a/src/pass/verify_compact_buffer.cc b/src/tir/pass/verify_compact_buffer.cc similarity index 90% rename from src/pass/verify_compact_buffer.cc rename to src/tir/pass/verify_compact_buffer.cc index 95dcbddedddc..0ca2e1470058 100644 --- a/src/pass/verify_compact_buffer.cc +++ b/src/tir/pass/verify_compact_buffer.cc @@ -21,16 +21,16 @@ * \file verify_compact_buffer.cc * \brief Verify if there was any compact buffer bound to a statement. 
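
`BroadcastTo` in vectorize_loop.cc above is the small workhorse of the vectorizer: an operand that already has the requested lane count passes through unchanged, and a scalar is replicated into a broadcast. A standalone stand-in using `std::vector` in place of a multi-lane PrimExpr:

```cpp
#include <cassert>
#include <vector>

// Replicate a scalar across `lanes`; pass a matching vector through.
// Any other lane mismatch would be an error in the real pass.
std::vector<int> BroadcastTo(std::vector<int> e, size_t lanes) {
  if (e.size() == lanes) return e;       // already has the right lane count
  assert(e.size() == 1);                 // only scalars may broadcast
  return std::vector<int>(lanes, e[0]);  // replicate across all lanes
}

int main() {
  std::vector<int> scalar{7};
  assert(BroadcastTo(scalar, 4) == std::vector<int>({7, 7, 7, 7}));
  return 0;
}
```
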
*/ -#include -#include -#include -#include +#include +#include +#include +#include #include #include namespace tvm { -namespace ir { +namespace tir { class VerifyBuffer : public StmtVisitor { public: @@ -55,5 +55,5 @@ bool VerifyCompactBuffer(Stmt stmt) { return verifier.Verify(stmt); } -} // namespace ir +} // namespace tir } // namespace tvm diff --git a/src/pass/verify_gpu_code.cc b/src/tir/pass/verify_gpu_code.cc similarity index 93% rename from src/pass/verify_gpu_code.cc rename to src/tir/pass/verify_gpu_code.cc index f9c183eb830e..f05423b7dca5 100644 --- a/src/pass/verify_gpu_code.cc +++ b/src/tir/pass/verify_gpu_code.cc @@ -26,15 +26,16 @@ #include -#include -#include +#include +#include +#include namespace tvm { -namespace ir { +namespace tir { class GPUCodeVerifier : public StmtVisitor { public: - bool Verify(tvm::Stmt stmt, + bool Verify(Stmt stmt, int64_t max_local_memory_per_block, int64_t max_shared_memory_per_block, int64_t max_threads_per_block, @@ -94,12 +95,12 @@ class GPUCodeVerifier : public StmtVisitor { if (op->attr_key == attr::storage_scope) { std::string op_value = op->value.as()->value; if (op_value == "local") { - visited_local_buffers_.insert(op->node.as()); + visited_local_buffers_.insert(op->node.as()); } else if (op_value == "shared") { - visited_shared_buffers_.insert(op->node.as()); + visited_shared_buffers_.insert(op->node.as()); } } else if (op->attr_key == attr::thread_extent) { - Var var = op->node.as()->var; + Var var = op->node.as()->var; const auto *extent = op->value.as(); CHECK(extent); @@ -139,8 +140,8 @@ class GPUCodeVerifier : public StmtVisitor { private: int nest_level_{0}; - std::unordered_set visited_local_buffers_; - std::unordered_set visited_shared_buffers_; + std::unordered_set visited_local_buffers_; + std::unordered_set visited_shared_buffers_; std::unordered_set visited_threads_; size_t thread_x_extent_, thread_y_extent_, thread_z_extent_; @@ -205,5 +206,5 @@ bool VerifyGPUCode(Stmt stmt, max_thread_z); } -} // namespace ir +} // namespace tir } // namespace tvm diff --git a/src/pass/verify_memory.cc b/src/tir/pass/verify_memory.cc similarity index 98% rename from src/pass/verify_memory.cc rename to src/tir/pass/verify_memory.cc index 899e9bc3c435..5e805f8f9560 100644 --- a/src/pass/verify_memory.cc +++ b/src/tir/pass/verify_memory.cc @@ -21,13 +21,13 @@ * \file verify_memory.cc * \brief Pass to check if memory accesses are legal. */ -#include -#include -#include +#include +#include +#include namespace tvm { -namespace ir { +namespace tir { namespace { /*! @@ -192,5 +192,5 @@ bool VerifyMemory(LoweredFunc func, int device_type) { return !v.Failed(); } -} // namespace ir +} // namespace tir } // namespace tvm diff --git a/src/top/operation/compute_op.cc b/src/top/operation/compute_op.cc index a8c232824346..f325ae85002c 100644 --- a/src/top/operation/compute_op.cc +++ b/src/top/operation/compute_op.cc @@ -23,9 +23,9 @@ */ #include #include -#include -#include -#include +#include +#include +#include #include #include #include @@ -37,7 +37,7 @@ namespace tvm { namespace top { -using namespace ir; +using namespace tir; TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) .set_dispatch([](const ObjectRef& node, NodePrinter* p) { @@ -50,7 +50,7 @@ TVM_REGISTER_NODE_TYPE(ComputeOpNode); /// Verify if ComputeOp is valid with respect to Reduce operations. 
static void VerifyComputeOp(const ComputeOpNode *op); -inline bool ReduceEqual(const ir::ReduceNode* a, const ir::ReduceNode* b) { +inline bool ReduceEqual(const tir::ReduceNode* a, const tir::ReduceNode* b) { return (a->combiner.same_as(b->combiner)) && (a->source.same_as(b->source)) && (a->axis.same_as(b->axis)) && @@ -148,8 +148,8 @@ Operation ComputeOpNode::make(std::string name, n->attrs = std::move(attrs); n->axis = std::move(axis); n->body = std::move(body); - if (n->body[0]->IsInstance()) { - const ir::ReduceNode* reduce = n->body[0].as(); + if (n->body[0]->IsInstance()) { + const tir::ReduceNode* reduce = n->body[0].as(); n->reduce_axis = reduce->axis; } VerifyComputeOp(n.get()); @@ -161,8 +161,8 @@ Array ComputeOpNode::InputTensors() const { Array ret; std::unordered_set visited; for (auto& e : body) { - ir::PostOrderVisit(e, [&ret, &visited](const ObjectRef& n) { - const ir::CallNode *call = n.as(); + tir::PostOrderVisit(e, [&ret, &visited](const ObjectRef& n) { + const tir::CallNode *call = n.as(); if (call != nullptr && call->func.defined()) { Tensor t = Downcast(call->func).output(call->value_index); if (!visited.count(t)) { @@ -181,14 +181,14 @@ Operation ComputeOpNode::ReplaceInputs( CHECK_EQ(self.operator->(), this); VerifyComputeOp(this); Array arr; - if (this->body[0]->IsInstance()) { + if (this->body[0]->IsInstance()) { // Specially handle reduce so the replaced op // still share all the components PrimExpr new_reduce = top::ReplaceTensor(this->body[0], rmap); if (!new_reduce.same_as(this->body[0])) { - const ir::ReduceNode* r = new_reduce.as(); + const tir::ReduceNode* r = new_reduce.as(); for (size_t k = 0; k < this->body.size(); ++k) { - auto n = make_object(*r); + auto n = make_object(*r); n->value_index = static_cast(k); n->dtype = r->source[k].dtype(); arr.push_back(PrimExpr(n)); @@ -216,7 +216,7 @@ void ComputeOpNode::PropBoundToInputs( std::unordered_map* out_dom_map) const { CHECK_EQ(self.operator->(), this); auto fvisit = [&dom_map, out_dom_map, analyzer](const ObjectRef& n) { - auto *call = n.as(); + auto *call = n.as(); if (call != nullptr && call->func.defined()) { Tensor t = Downcast(call->func).output(call->value_index); if (t->op.defined() && out_dom_map->count(t)) { @@ -250,7 +250,7 @@ void ComputeOpNode::PropBoundToInputs( } } }; - for (auto& e : body) ir::PostOrderVisit(e, fvisit); + for (auto& e : body) tir::PostOrderVisit(e, fvisit); } void BaseComputeOpNode::GatherBound( @@ -282,7 +282,7 @@ Stmt BaseComputeOpNode::BuildRealize( Stmt realize = body; for (int i = this->num_outputs(); i > 0; --i) { Tensor t = stage->op.output(i-1); - realize = ir::RealizeNode::make(t->op, t->value_index, + realize = tir::RealizeNode::make(t->op, t->value_index, t->dtype, bounds, const_true(), realize); // alignment requirement, only useful for compute for (size_t i = 0; i < num_schedulable_dims(); ++i) { @@ -293,10 +293,10 @@ Stmt BaseComputeOpNode::BuildRealize( Array tuple = {static_cast(i), attr->dim_align_factor, attr->dim_align_offset}; - realize = ir::AttrStmtNode::make( - t, ir::attr::buffer_dim_align, + realize = tir::AttrStmtNode::make( + t, tir::attr::buffer_dim_align, CallNode::make(DataType::Handle(), - ir::intrinsic::tvm_tuple, + tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), realize); } @@ -540,12 +540,12 @@ namespace { * must be Reduce as well; and their inputs should have the * same attribute except value_index. 
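
`ComputeOpNode::InputTensors` above gathers the distinct tensors a compute body reads by post-order visiting every expression node and recording each call target exactly once. The same shape of traversal on a toy expression tree (strings stand in for tensor handles; a sketch, not the TVM visitor API):

```cpp
#include <cassert>
#include <set>
#include <string>
#include <vector>

struct Expr {
  std::string call;        // non-empty when this node reads a tensor
  std::vector<Expr> args;  // operands
};

// Post-order walk: children first, then record an unseen tensor read.
void PostOrderVisit(const Expr& e, std::set<std::string>& visited,
                    std::vector<std::string>& ret) {
  for (const Expr& a : e.args) PostOrderVisit(a, visited, ret);
  if (!e.call.empty() && visited.insert(e.call).second) ret.push_back(e.call);
}

int main() {
  // A(i) + A(i + 1) * B(i): A appears twice but is collected once.
  Expr body{"", {Expr{"A", {}}, Expr{"", {Expr{"A", {}}, Expr{"B", {}}}}}};
  std::set<std::string> visited;
  std::vector<std::string> inputs;
  PostOrderVisit(body, visited, inputs);
  assert((inputs == std::vector<std::string>{"A", "B"}));
  return 0;
}
```
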
*/ -class ComputeVerifier final : protected ir::ExprVisitor { +class ComputeVerifier final : protected tir::ExprVisitor { public: /// Special member functions //@{ explicit ComputeVerifier(const ComputeOpNode* compute) - : compute_(compute), reduce_(compute->body[0].as()) {} + : compute_(compute), reduce_(compute->body[0].as()) {} virtual ~ComputeVerifier() = default; ComputeVerifier(const ComputeVerifier&) = delete; ComputeVerifier(ComputeVerifier&&) = delete; @@ -557,7 +557,7 @@ class ComputeVerifier final : protected ir::ExprVisitor { void Run() { for (const PrimExpr e : compute_->body) { // Check for consistency of top level reductions - const ir::ReduceNode* reduce = e.as(); + const tir::ReduceNode* reduce = e.as(); CHECK((reduce && reduce_) || (!reduce && !reduce_)) << "All ComputeOp should be consistent " << "with being Reduce operation or not."; @@ -582,7 +582,7 @@ class ComputeVerifier final : protected ir::ExprVisitor { --level_; } - void VisitExpr_(const ir::ReduceNode* op) final { + void VisitExpr_(const tir::ReduceNode* op) final { // Check for non top level reductions CHECK(0 == level_) << "Reductions are only allowed at the top level of compute. " @@ -592,7 +592,7 @@ class ComputeVerifier final : protected ir::ExprVisitor { private: const ComputeOpNode* compute_{nullptr}; ///< ComputeOpNode to verify - const ir::ReduceNode* reduce_{nullptr}; ///< Top level Reduce operation + const tir::ReduceNode* reduce_{nullptr}; ///< Top level Reduce operation int level_{0}; ///< Level of op being processed }; } // namespace @@ -628,13 +628,13 @@ Stmt TransformUpdate(const Stage& stage, } } for (const PrimExpr& pred : n.main_predicates) { - if (ir::ExprUseVar(pred, banned)) { + if (tir::ExprUseVar(pred, banned)) { LOG(FATAL) << "Tensorize update transform failed, the condition " << pred << " has a conflict with the reset condition"; } } - return IfThenElseNode::make(arith::ComputeReduce(conds, const_true(1)), + return IfThenElseNode::make(arith::ComputeReduce(conds, const_true(1)), update, body); } diff --git a/src/top/operation/compute_op.h b/src/top/operation/compute_op.h index 093dd221d8eb..bbe97313dfba 100644 --- a/src/top/operation/compute_op.h +++ b/src/top/operation/compute_op.h @@ -24,8 +24,8 @@ #ifndef TVM_TOP_OPERATION_COMPUTE_OP_H_ #define TVM_TOP_OPERATION_COMPUTE_OP_H_ -#include -#include +#include +#include #include #include #include diff --git a/src/top/operation/cross_thread_reduction.cc b/src/top/operation/cross_thread_reduction.cc index bf5c9b167b3a..30ee7b8cda47 100644 --- a/src/top/operation/cross_thread_reduction.cc +++ b/src/top/operation/cross_thread_reduction.cc @@ -21,13 +21,13 @@ * \brief Logics related to cross thread reduction, used by ComputeOpNode. 
* \file cross_thread_reduction.cc */ -#include +#include #include "compute_op.h" #include "op_util.h" namespace tvm { namespace top { -using namespace ir; +using namespace tir; Stmt MakeCrossThreadReduction( const ComputeOpNode* self, @@ -87,7 +87,7 @@ Stmt MakeCrossThreadReduction( Stmt reduce_body = EvaluateNode::make(CallNode::make( DataType::Handle(), - ir::intrinsic::tvm_thread_allreduce, + tir::intrinsic::tvm_thread_allreduce, freduce_args, CallNode::Intrinsic)); reduce_body = AttrStmtNode::make( reduces[0]->combiner, diff --git a/src/top/operation/extern_op.cc b/src/top/operation/extern_op.cc index 3fc73dc38b99..276b5ebf6bc7 100644 --- a/src/top/operation/extern_op.cc +++ b/src/top/operation/extern_op.cc @@ -23,13 +23,13 @@ */ #include #include -#include +#include #include #include "op_util.h" namespace tvm { namespace top { -using namespace ir; +using namespace tir; // ExternOpNode TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) .set_dispatch([](const ObjectRef& node, NodePrinter* p) { @@ -148,7 +148,7 @@ Stmt ExternOpNode::BuildRealize( Range::make_by_min_extent( make_const(t->shape[i].dtype(), 0), t->shape[i])); } - realize_body = ir::RealizeNode::make( + realize_body = tir::RealizeNode::make( t->op, t->value_index, t->dtype, bounds, const_true(), realize_body); } diff --git a/src/top/operation/hybrid_op.cc b/src/top/operation/hybrid_op.cc index d959826772f3..f4e3850650a3 100644 --- a/src/top/operation/hybrid_op.cc +++ b/src/top/operation/hybrid_op.cc @@ -23,10 +23,10 @@ */ #include #include -#include -#include -#include -#include +#include +#include +#include +#include #include #include #include @@ -35,7 +35,7 @@ namespace tvm { namespace top { -using namespace ir; +using namespace tir; // HybridOpNode TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) .set_dispatch([](const ObjectRef& node, NodePrinter* p) { @@ -92,8 +92,8 @@ Array HybridOpNode::InputTensors() const { } std::unordered_set visited; Array curr_inputs; - ir::PostOrderVisit(body, [&curr_inputs, &orig_inputs, &visited](const ObjectRef& n) { - const ir::CallNode *call = n.as(); + tir::PostOrderVisit(body, [&curr_inputs, &orig_inputs, &visited](const ObjectRef& n) { + const tir::CallNode *call = n.as(); if (call != nullptr && call->func.defined()) { Tensor t = Downcast(call->func).output(call->value_index); if (orig_inputs.count(t) && !visited.count(t)) { @@ -169,7 +169,7 @@ Stmt HybridOpNode::BuildRealize( Range::make_by_min_extent( make_const(t->shape[i].dtype(), 0), t->shape[i])); } - realize_body = ir::RealizeNode::make( + realize_body = tir::RealizeNode::make( t->op, t->value_index, t->dtype, bounds, const_true(), realize_body); } @@ -249,7 +249,7 @@ Stmt ApplyLoopShapes(const Stage &stage, if (op->loop_var.get() == parent) { std::unordered_map rmap; rmap[op->loop_var.get()] = inner + outer * factor; - Stmt ret = ir::Substitute(op->body, rmap); + Stmt ret = tir::Substitute(op->body, rmap); PrimExpr cond = likely(outer * factor < (op->extent - inner)); ret = IfThenElseNode::make(cond, ret); ret = ForNode::make(inner->var, PrimExpr(0), inner->dom->extent, @@ -285,13 +285,13 @@ Stmt ApplyLoopShapes(const Stage &stage, rmap[op->loop_var.get()] = indexmod(parent, op->extent); extent = op->extent; fused = true; - return ir::Substitute(op->body, rmap); + return tir::Substitute(op->body, rmap); } else if (op->loop_var.get() == outer) { under_outer = true; Stmt body = this->VisitStmt(op->body); std::unordered_map rmap; rmap[op->loop_var.get()] = indexdiv(parent, extent); - body = ir::Substitute(body, rmap); + body = 
tir::Substitute(body, rmap); under_outer = false; return ForNode::make(parent->var, PrimExpr(0), extent * op->extent, op->for_type, op->device_api, body); @@ -299,7 +299,7 @@ Stmt ApplyLoopShapes(const Stage &stage, Stmt body = this->VisitStmt(op->body); std::unordered_map rmap; rmap[op->loop_var.get()] = indexmod(indexdiv(parent, extent), op->extent); - body = ir::Substitute(body, rmap); + body = tir::Substitute(body, rmap); extent = extent * op->extent; return body; } @@ -342,7 +342,7 @@ Stmt ApplyLoopAnnotations(const Stage &stage, } std::unordered_map rmap; rmap[op->loop_var.get()] = iter_var; - Stmt body = ir::Substitute(op->body, rmap); + Stmt body = tir::Substitute(op->body, rmap); return AttrStmtNode::make(iter_var, "thread_extent", op->extent, body); } else { return ForNode::make(op->loop_var, op->min, op->extent, @@ -476,16 +476,16 @@ std::vector GatherLoopVars(Stmt stmt) { } // replacer to replace tensors' usage in Provide -class ProviderReplacer : public ir::StmtMutator { +class ProviderReplacer : public tir::StmtMutator { public: explicit ProviderReplacer(const std::unordered_map &vmap) : vmap_(vmap) {} - Stmt VisitStmt_(const ir::ProvideNode* op) final { + Stmt VisitStmt_(const tir::ProvideNode* op) final { Tensor t = Downcast(op->func).output(op->value_index); auto it = vmap_.find(t); if (it != vmap_.end()) { - Stmt ret = ir::ProvideNode::make( + Stmt ret = tir::ProvideNode::make( it->second->op, it->second->value_index, op->value, op->args); found = true; return this->VisitStmt(ret); diff --git a/src/top/operation/hybrid_op.h b/src/top/operation/hybrid_op.h index c4586cb880d3..d0de706c42a3 100644 --- a/src/top/operation/hybrid_op.h +++ b/src/top/operation/hybrid_op.h @@ -24,7 +24,7 @@ #ifndef TVM_TOP_OPERATION_HYBRID_OP_H_ #define TVM_TOP_OPERATION_HYBRID_OP_H_ -#include +#include #include #include @@ -32,8 +32,8 @@ #include #include "../schedule/message_passing.h" -#include "../../pass/ir_util.h" -#include "../../pass/arg_binder.h" +#include "../../tir/pass/ir_util.h" +#include "../../tir/pass/arg_binder.h" namespace tvm { namespace top { diff --git a/src/top/operation/op_util.cc b/src/top/operation/op_util.cc index fcf8318d21b8..47ad82f305df 100644 --- a/src/top/operation/op_util.cc +++ b/src/top/operation/op_util.cc @@ -21,9 +21,9 @@ * \brief Utility to make loop nest. 
* \file op_util.cc */ -#include -#include -#include +#include +#include +#include #include #include #include "op_util.h" @@ -34,7 +34,7 @@ namespace tvm { namespace top { using namespace arith; -using namespace ir; +using namespace tir; std::vector > MakeLoopNest(const Stage& stage, @@ -101,7 +101,7 @@ MakeLoopNest(const Stage& stage, pvalue = make_const(DataType::Int(32), 1); } nest[i + 1].emplace_back( - AttrStmtNode::make(iv, ir::attr::pragma_scope_prefix + pkey, pvalue, no_op)); + AttrStmtNode::make(iv, tir::attr::pragma_scope_prefix + pkey, pvalue, no_op)); } } if (!debug_keep_trivial_loop && is_one(dom->extent)) { @@ -131,7 +131,7 @@ MakeLoopNest(const Stage& stage, for (size_t j = 0; j < it_attr->prefetch_data.size(); ++j) { nest[i + 1].emplace_back( AttrStmtNode::make(it_attr->prefetch_data[j], - ir::attr::prefetch_scope, + tir::attr::prefetch_scope, it_attr->prefetch_offset[j], no_op)); } } @@ -143,7 +143,7 @@ MakeLoopNest(const Stage& stage, CHECK(is_positive_const(dom->extent)); // annotate the extent of the IterVar nest[i + 1].emplace_back( - AttrStmtNode::make(bind_iv, ir::attr::virtual_thread, dom->extent, no_op)); + AttrStmtNode::make(bind_iv, tir::attr::virtual_thread, dom->extent, no_op)); value_map[iv] = var; } else if (bind_iv->thread_tag == "pipeline") { // pipeline marker. @@ -151,14 +151,14 @@ MakeLoopNest(const Stage& stage, CHECK(is_one(dom->extent)); // annotate the extent of the IterVar nest[i + 1].emplace_back( - AttrStmtNode::make(bind_iv, ir::attr::pipeline_exec_scope, dom->extent, no_op)); + AttrStmtNode::make(bind_iv, tir::attr::pipeline_exec_scope, dom->extent, no_op)); value_map[iv] = dom->min; } else { // Always restrict threaded IterVar to starts from 0. CHECK(is_zero(dom->min)); // annotate the extent of the IterVar nest[i + 1].emplace_back( - AttrStmtNode::make(bind_iv, ir::attr::thread_extent, dom->extent, no_op)); + AttrStmtNode::make(bind_iv, tir::attr::thread_extent, dom->extent, no_op)); if (!debug_keep_trivial_loop && is_one(dom->extent)) { value_map[iv] = dom->min; } else { @@ -186,17 +186,17 @@ std::vector MakeIfNest(const std::vector& predicates) { } // replacer to replace tensors -class TensorReplacer : public ir::StmtExprMutator { +class TensorReplacer : public tir::StmtExprMutator { public: explicit TensorReplacer(const std::unordered_map& vmap) : vmap_(vmap) {} - PrimExpr VisitExpr_(const ir::CallNode* op) final { - if (op->call_type == ir::CallNode::Halide) { + PrimExpr VisitExpr_(const tir::CallNode* op) final { + if (op->call_type == tir::CallNode::Halide) { Tensor t = Downcast(op->func).output(op->value_index); auto it = vmap_.find(t); if (it != vmap_.end()) { - PrimExpr ret = ir::CallNode::make( + PrimExpr ret = tir::CallNode::make( op->dtype, it->second->op->name, op->args, op->call_type, it->second->op, it->second->value_index); found = true; @@ -233,10 +233,10 @@ Stmt Substitute(Stmt s, for (const auto& kv : value_map) { init[kv.first->var.get()] = kv.second; } - return ir::Substitute(s, init); + return tir::Substitute(s, init); } -IterVarType ForTypeToIterVarType(ir::ForType for_type) { +IterVarType ForTypeToIterVarType(tir::ForType for_type) { switch (for_type) { case ForType::Serial: return kDataPar; @@ -251,7 +251,7 @@ IterVarType ForTypeToIterVarType(ir::ForType for_type) { } } -ir::ForType IterVarTypeToForType(IterVarType iter_type) { +tir::ForType IterVarTypeToForType(IterVarType iter_type) { switch (iter_type) { case kDataPar: return ForType::Serial; diff --git a/src/top/operation/op_util.h b/src/top/operation/op_util.h 
index babdabc2c46a..bc6e49ec22d4 100644 --- a/src/top/operation/op_util.h +++ b/src/top/operation/op_util.h @@ -24,19 +24,19 @@ #ifndef TVM_TOP_OPERATION_OP_UTIL_H_ #define TVM_TOP_OPERATION_OP_UTIL_H_ -#include +#include #include #include #include #include -#include "../../pass/ir_util.h" -#include "../../pass/arg_binder.h" +#include "../../tir/pass/ir_util.h" +#include "../../tir/pass/arg_binder.h" #include "../schedule/message_passing.h" namespace tvm { namespace top { -using ir::MergeNest; +using tir::MergeNest; /*! * \brief Build loop nest for stage. @@ -94,13 +94,13 @@ Stmt Substitute(Stmt stmt, * \brief Converts Halide ForType to its corresponding IterVarType * \param for_type The ForType to be converted */ -IterVarType ForTypeToIterVarType(ir::ForType for_type); +IterVarType ForTypeToIterVarType(tir::ForType for_type); /*! * \brief Converts IterVarType to its corresponding Halide ForType * \param iter_type The IterVarType to be converted */ -ir::ForType IterVarTypeToForType(IterVarType iter_type); +tir::ForType IterVarTypeToForType(IterVarType iter_type); } // namespace top } // namespace tvm diff --git a/src/top/operation/scan_op.cc b/src/top/operation/scan_op.cc index 8f54872bc7ac..2ddb6bd11cc8 100644 --- a/src/top/operation/scan_op.cc +++ b/src/top/operation/scan_op.cc @@ -22,14 +22,14 @@ * \file scan_op.cc */ #include -#include -#include +#include +#include #include "op_util.h" #include "../schedule/graph.h" namespace tvm { namespace top { -using namespace ir; +using namespace tir; TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) .set_dispatch([](const ObjectRef& node, NodePrinter* p) { @@ -39,7 +39,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) TVM_REGISTER_NODE_TYPE(ScanOpNode); inline bool prove_equal(PrimExpr lhs, PrimExpr rhs) { - return is_zero(ir::Simplify(lhs - rhs)); + return is_zero(tir::Simplify(lhs - rhs)); } int ScanOpNode::num_outputs() const { @@ -230,7 +230,7 @@ void ScanOpNode::GatherBound( Range sdom = this->scan_axis->dom; Range r = arith::Union(time_dom).cover_range(sdom); (*out_dom_map)[this->scan_axis] = Range::make_by_min_extent( - sdom->min, ir::Simplify(r->extent + r->min - sdom->min)); + sdom->min, tir::Simplify(r->extent + r->min - sdom->min)); Map fix_pt = ScanFixPointAnalysis(self); // Update for spatial axis. size_t sp_idx = 0; @@ -240,7 +240,7 @@ void ScanOpNode::GatherBound( IterVar sp_ax = this->spatial_axis_[sp_idx]; CHECK(!out_dom_map->count(sp_ax)); CHECK(fix_pt.count(sp_ax)); - if (fix_pt[sp_ax].as()->value) { + if (fix_pt[sp_ax].as()->value) { // fix point, we can slice it. 
(*out_dom_map)[sp_ax] = arith::Union(d.data[k]).cover_range(sp_ax->dom); } else { @@ -258,7 +258,7 @@ Stmt ScanOpNode::BuildRealize( CHECK_EQ(stage->op.get(), this); Range sdom = dom_map.at(this->scan_axis); Range tdom = Range::make_by_min_extent( - 0, ir::Simplify(sdom->extent + sdom->min)); + 0, tir::Simplify(sdom->extent + sdom->min)); Stmt ret = body; size_t sp_idx = 0; for (size_t i = 0; i < update.size(); ++i) { @@ -270,7 +270,7 @@ Stmt ScanOpNode::BuildRealize( IterVar sp_ax = this->spatial_axis_[sp_idx]; bounds.push_back(dom_map.at(sp_ax)); } - ret = ir::RealizeNode::make(t->op, t->value_index, t->dtype, + ret = tir::RealizeNode::make(t->op, t->value_index, t->dtype, bounds, const_true(), ret); } return ret; diff --git a/src/top/operation/tensor_compute_op.cc b/src/top/operation/tensor_compute_op.cc index 49b00fc5d533..2cc821928809 100644 --- a/src/top/operation/tensor_compute_op.cc +++ b/src/top/operation/tensor_compute_op.cc @@ -23,8 +23,8 @@ */ #include #include -#include -#include +#include +#include #include #include "./op_util.h" #include "./compute_op.h" @@ -32,7 +32,7 @@ namespace tvm { namespace top { -using namespace ir; +using namespace tir; // TensorComputeOpNode TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) .set_dispatch([](const ObjectRef& node, NodePrinter* p) { @@ -154,9 +154,9 @@ Stmt TensorComputeOpNode::BuildProvide( tuple.push_back(region[i]->extent); } input_bind_nest.emplace_back(AttrStmtNode::make( - bind_spec, ir::attr::buffer_bind_scope, + bind_spec, tir::attr::buffer_bind_scope, CallNode::make(DataType::Handle(), - ir::intrinsic::tvm_tuple, + tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), nop)); } @@ -180,15 +180,15 @@ Stmt TensorComputeOpNode::BuildProvide( } output_bind_nest.emplace_back(AttrStmtNode::make( - bind_spec, ir::attr::buffer_bind_scope, + bind_spec, tir::attr::buffer_bind_scope, CallNode::make(DataType::Handle(), - ir::intrinsic::tvm_tuple, + tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), nop)); } // Check variable remap std::unordered_map vmap; - ir::ArgBinder binder(&vmap); + tir::ArgBinder binder(&vmap); // Map the expressions passed in the call to the TensorIntrin, to the placeholder // variables @@ -215,7 +215,7 @@ Stmt TensorComputeOpNode::BuildProvide( << "Normal store op for intrin " << this << " is not defined"; Stmt body = MergeNest(output_bind_nest, this->intrin->body); body = MergeNest(input_bind_nest, body); - body = ir::Substitute(body, vmap); + body = tir::Substitute(body, vmap); body = MergeNest(binder.asserts(), body); body = top::Substitute(body, n.main_vmap); Stmt ret = MergeNest(nest, body); @@ -243,7 +243,7 @@ Stmt TensorComputeOpNode::BuildProvide( // The update Stmt update = MergeNest(output_bind_nest, this->intrin->reduce_update); update = MergeNest(input_bind_nest, update); - update = ir::Substitute(update, vmap); + update = tir::Substitute(update, vmap); update = MergeNest(binder.asserts(), update); update = top::Substitute(update, n.main_vmap); update = MergeNest(update_nest, update); @@ -257,7 +257,7 @@ Stmt TensorComputeOpNode::BuildProvide( this->intrin->reduce_update); update = MergeNest(output_bind_nest, update); update = MergeNest(input_bind_nest, update); - update = ir::Substitute(update, vmap); + update = tir::Substitute(update, vmap); update = MergeNest(binder.asserts(), update); update = top::Substitute(update, n.main_vmap); update = MergeNest(update_nest, update); diff --git a/src/top/operation/tensorize.cc b/src/top/operation/tensorize.cc index e7f6b33608cb..b096ee6c776e 100644 --- 
a/src/top/operation/tensorize.cc +++ b/src/top/operation/tensorize.cc @@ -21,9 +21,9 @@ * \brief Logics related to tensorize, used by ComputeOpNode. * \file tensorize.cc */ -#include -#include -#include +#include +#include +#include #include #include "op_util.h" @@ -33,7 +33,7 @@ namespace tvm { namespace top { -using namespace ir; +using namespace tir; // Detect the region of input and output to be tensorized. // out_dom: the domain of root iter vars in output op @@ -144,13 +144,13 @@ void VerifyTensorizeLoopNest(const ComputeOpNode* self, } } for (const PrimExpr& pred : n.main_predicates) { - if (ir::ExprUseVar(pred, banned)) { + if (tir::ExprUseVar(pred, banned)) { LOG(FATAL) << "Tensorize failed, split condition " << pred << " relies on var defined inside tensorize scope"; } } for (const PrimExpr& pred : n.init_predicates) { - if (ir::ExprUseVar(pred, banned)) { + if (tir::ExprUseVar(pred, banned)) { LOG(FATAL) << "Tensorize failed, split condition " << pred << " relies on var defined inside tensorize scope"; } @@ -390,9 +390,9 @@ Stmt MakeTensorize(const ComputeOpNode* self, tuple.push_back(r->extent); } input_bind_nest.emplace_back(AttrStmtNode::make( - bind_spec, ir::attr::buffer_bind_scope, + bind_spec, tir::attr::buffer_bind_scope, CallNode::make(DataType::Handle(), - ir::intrinsic::tvm_tuple, + tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), nop)); } // output binding @@ -412,14 +412,14 @@ Stmt MakeTensorize(const ComputeOpNode* self, Buffer buffer = intrin->buffers[i]; Array bind_spec{buffer, tensor}; output_bind_nest.emplace_back(AttrStmtNode::make( - bind_spec, ir::attr::buffer_bind_scope, + bind_spec, tir::attr::buffer_bind_scope, CallNode::make(DataType::Handle(), - ir::intrinsic::tvm_tuple, + tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), nop)); } // Check variable remap std::unordered_map vmap; - ir::ArgBinder binder(&vmap); + tir::ArgBinder binder(&vmap); CHECK_GE(self->reduce_axis.size(), intrin_compute->reduce_axis.size()) << "Tensorization fail: reduction axis sizes do not match"; size_t start = self->reduce_axis.size() - intrin_compute->reduce_axis.size(); @@ -450,7 +450,7 @@ Stmt MakeTensorize(const ComputeOpNode* self, << "Normal store op for intrin " << intrin << " is not defined"; Stmt body = MergeNest(output_bind_nest, intrin->body); body = MergeNest(input_bind_nest, body); - body = ir::Substitute(body, vmap); + body = tir::Substitute(body, vmap); body = MergeNest(binder.asserts(), body); body = top::Substitute(body, n.main_vmap); return MergeNest(nest, body); @@ -477,7 +477,7 @@ Stmt MakeTensorize(const ComputeOpNode* self, // The update Stmt update = MergeNest(output_bind_nest, intrin->reduce_update); update = MergeNest(input_bind_nest, update); - update = ir::Substitute(update, vmap); + update = tir::Substitute(update, vmap); update = MergeNest(binder.asserts(), update); update = top::Substitute(update, n.main_vmap); update = MergeNest(update_nest, update); @@ -491,7 +491,7 @@ Stmt MakeTensorize(const ComputeOpNode* self, intrin->reduce_update); update = MergeNest(output_bind_nest, update); update = MergeNest(input_bind_nest, update); - update = ir::Substitute(update, vmap); + update = tir::Substitute(update, vmap); update = MergeNest(binder.asserts(), update); update = top::Substitute(update, n.main_vmap); update = MergeNest(update_nest, update); diff --git a/src/top/schedule/auto_inline_elem_wise.cc b/src/top/schedule/auto_inline_elem_wise.cc index 9b088132ae1b..d889a174108c 100644 --- a/src/top/schedule/auto_inline_elem_wise.cc +++
b/src/top/schedule/auto_inline_elem_wise.cc @@ -22,14 +22,14 @@ */ #include #include -#include +#include namespace tvm { namespace top { -using namespace ir; +using namespace tir; -class ElemWiseDetector : public ir::ExprVisitor { +class ElemWiseDetector : public tir::ExprVisitor { public: explicit ElemWiseDetector(Array axis) : axis_(axis) {} diff --git a/src/top/schedule/bound.cc b/src/top/schedule/bound.cc index 8fffc5398411..97b99301dc95 100644 --- a/src/top/schedule/bound.cc +++ b/src/top/schedule/bound.cc @@ -23,7 +23,7 @@ */ #include #include -#include +#include #include #include #include "graph.h" diff --git a/src/top/schedule/graph.cc b/src/top/schedule/graph.cc index 62df8422afdb..1e44face1e2b 100644 --- a/src/top/schedule/graph.cc +++ b/src/top/schedule/graph.cc @@ -21,8 +21,8 @@ * \file graph.cc * \brief Utilities to get information about schedule graph. */ -#include -#include +#include +#include #include #include #include @@ -33,11 +33,11 @@ namespace tvm { namespace top { // key to specific tensor dimension. struct TensorDimKey { - ir::FunctionRef f; + tir::FunctionRef f; int value_index; int dim; TensorDimKey() {} - TensorDimKey(const ir::CallNode* op, int dim) + TensorDimKey(const tir::CallNode* op, int dim) : f(op->func), value_index(op->value_index), dim(dim) { } TensorDimKey(const Tensor& t, int dim) @@ -263,7 +263,7 @@ ReachGraph GetReachGraph(const Array& ops) { reach[TensorDimKey(t, i)] = {}; } auto fvisit = [&vmap, &reach, &bset](const ObjectRef& n) { - const ir::CallNode *call = n.as(); + const tir::CallNode *call = n.as(); if (call != nullptr && call->func.defined()) { if (!bset.count(call->func.get())) return; for (size_t i = 0; i < call->args.size(); ++i) { @@ -275,12 +275,12 @@ ReachGraph GetReachGraph(const Array& ops) { reach[it->second].push_back(dkey); } }; - ir::PostOrderVisit(call->args[i], fpush); + tir::PostOrderVisit(call->args[i], fpush); } } }; for (auto& e : compute_op->body) { - ir::PostOrderVisit(e, fvisit); + tir::PostOrderVisit(e, fvisit); } } } @@ -353,7 +353,7 @@ Map ScanFixPointAnalysis(const Operation& scan_op) { } auto fvisit = [&vmap, &f_merge_key, &exact_reach, &fail_set]( const ObjectRef& n) { - const ir::CallNode *call = n.as(); + const tir::CallNode *call = n.as(); if (call != nullptr && call->func.defined()) { for (size_t i = 0; i < call->args.size(); ++i) { auto it = vmap.find(call->args[i].get()); @@ -372,7 +372,7 @@ Map ScanFixPointAnalysis(const Operation& scan_op) { } }; for (auto& e : compute_op->body) { - ir::PostOrderVisit(e, fvisit); + tir::PostOrderVisit(e, fvisit); } } } diff --git a/src/top/schedule/graph.h b/src/top/schedule/graph.h index f379f98ee4eb..8d2d8566d432 100644 --- a/src/top/schedule/graph.h +++ b/src/top/schedule/graph.h @@ -24,7 +24,7 @@ #ifndef TVM_TOP_SCHEDULE_GRAPH_H_ #define TVM_TOP_SCHEDULE_GRAPH_H_ -#include +#include #include #include #include diff --git a/src/top/schedule/message_passing.cc b/src/top/schedule/message_passing.cc index 47326816acb6..a979df42c636 100644 --- a/src/top/schedule/message_passing.cc +++ b/src/top/schedule/message_passing.cc @@ -22,15 +22,15 @@ * \brief The message passing domain. 
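
`ElemWiseDetector` in auto_inline_elem_wise.cc above decides whether an op can be auto-inlined by checking that every tensor read is indexed by exactly the output axes, in order. The core predicate, sketched with strings standing in for index variables:

```cpp
#include <cassert>
#include <string>
#include <vector>

// Elementwise access: the call's indices are the output axes, same vars,
// same order, same arity. (Standalone sketch of the rule, not TVM API.)
bool IsElemWiseAccess(const std::vector<std::string>& axis,
                      const std::vector<std::string>& call_indices) {
  return call_indices == axis;
}

int main() {
  std::vector<std::string> axis{"i", "j"};
  assert(IsElemWiseAccess(axis, {"i", "j"}));   // B(i, j): elementwise
  assert(!IsElemWiseAccess(axis, {"j", "i"}));  // transposed read: not
  assert(!IsElemWiseAccess(axis, {"i"}));       // broadcast read: not
  return 0;
}
```
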
*/ #include -#include -#include +#include +#include #include "message_passing.h" #include "../../arith/compute_expr.h" namespace tvm { namespace top { -using namespace ir; +using namespace tir; void Update(std::unordered_map* p_state, const IterVar& iv, diff --git a/src/top/schedule/message_passing.h b/src/top/schedule/message_passing.h index 42b72a729871..beaf30139def 100644 --- a/src/top/schedule/message_passing.h +++ b/src/top/schedule/message_passing.h @@ -25,7 +25,7 @@ #ifndef TVM_TOP_SCHEDULE_MESSAGE_PASSING_H_ #define TVM_TOP_SCHEDULE_MESSAGE_PASSING_H_ -#include +#include #include #include #include diff --git a/src/top/schedule/schedule_dataflow_rewrite.cc b/src/top/schedule/schedule_dataflow_rewrite.cc index 5f9ba3984309..9ffb5119d737 100644 --- a/src/top/schedule/schedule_dataflow_rewrite.cc +++ b/src/top/schedule/schedule_dataflow_rewrite.cc @@ -22,11 +22,11 @@ */ #include #include -#include -#include +#include +#include #include #include "message_passing.h" -#include "../../pass/ir_util.h" +#include "../../tir/pass/ir_util.h" #include "../../arith/compute_expr.h" namespace tvm { @@ -42,7 +42,7 @@ size_t FindNodeRef(ArrayNode* array_node, const T& v) { } // The replacer of cache. -class VarReplacer : public ir::StmtExprMutator { +class VarReplacer : public tir::StmtExprMutator { public: explicit VarReplacer( const std::unordered_map& vsub) @@ -53,12 +53,12 @@ class VarReplacer : public ir::StmtExprMutator { return GetRef(op); } - ir::CommReducer MutateCommReducer(ir::CommReducer combiner) { + tir::CommReducer MutateCommReducer(tir::CommReducer combiner) { // Replace free variables in combiner - auto new_identity = ir::UpdateArray(combiner->identity_element, [this] (const PrimExpr& e) { + auto new_identity = tir::UpdateArray(combiner->identity_element, [this] (const PrimExpr& e) { return this->VisitExpr(e); }); - auto new_result = ir::UpdateArray(combiner->result, [this] (const PrimExpr& e) { + auto new_result = tir::UpdateArray(combiner->result, [this] (const PrimExpr& e) { return this->VisitExpr(e); }); @@ -66,19 +66,19 @@ class VarReplacer : public ir::StmtExprMutator { combiner->identity_element.same_as(new_result)) { return combiner; } else { - return ir::CommReducerNode::make( + return tir::CommReducerNode::make( combiner->lhs, combiner->rhs, new_result, new_identity); } } - PrimExpr VisitExpr_(const ir::ReduceNode* op) final { + PrimExpr VisitExpr_(const tir::ReduceNode* op) final { PrimExpr new_e = StmtExprMutator::VisitExpr_(op); - const ir::ReduceNode* new_reduce = new_e.as(); - ir::CommReducer new_combiner = MutateCommReducer(op->combiner); + const tir::ReduceNode* new_reduce = new_e.as(); + tir::CommReducer new_combiner = MutateCommReducer(op->combiner); if (op->combiner.same_as(new_combiner)) { return new_e; } else { - return ir::ReduceNode::make( + return tir::ReduceNode::make( new_combiner, new_reduce->source, new_reduce->axis, @@ -93,16 +93,16 @@ class VarReplacer : public ir::StmtExprMutator { PrimExpr InjectPredicate(const Array& predicates, PrimExpr body) { - using ir::ReduceNode; - using ir::SelectNode; + using tir::ReduceNode; + using tir::SelectNode; if (predicates.size() == 0) return body; const ReduceNode* reduce = body.as(); if (reduce) { auto n = make_object(*reduce); - n->condition = n->condition && arith::ComputeReduce(predicates, PrimExpr()); + n->condition = n->condition && arith::ComputeReduce(predicates, PrimExpr()); return PrimExpr(n); } - return SelectNode::make(arith::ComputeReduce(predicates, PrimExpr()), + return 
SelectNode::make(arith::ComputeReduce(predicates, PrimExpr()), body, make_zero(body.dtype())); } @@ -130,7 +130,7 @@ void ReplaceDataFlow(const Array& stages, } } -inline bool ReduceEqual(const ir::ReduceNode* a, const ir::ReduceNode* b) { +inline bool ReduceEqual(const tir::ReduceNode* a, const tir::ReduceNode* b) { return (a->combiner.same_as(b->combiner)) && (a->source.same_as(b->source)) && (a->axis.same_as(b->axis)) && @@ -314,18 +314,18 @@ Array CacheWriteWithReLayout(Schedule sch, PrimExpr body; Array body_list; - const ir::ReduceNode* first_reduce = nullptr; + const tir::ReduceNode* first_reduce = nullptr; for (auto cbody : compute->body) { body = VarReplacer(vsub)(cbody); body = InjectPredicate(predicates, body); body = VarReplacer(vsub2newvar)(body); // Reduce nodes in ONE computeOp must be the same except value_index // This is right only if the original body ensures Reduce nodes are the same - if (body->IsInstance()) { - const ir::ReduceNode* reduce_body = body.as(); + if (body->IsInstance()) { + const tir::ReduceNode* reduce_body = body.as(); if (first_reduce != nullptr) { CHECK(ReduceEqual(reduce_body, first_reduce)); - body = ir::ReduceNode::make(first_reduce->combiner, + body = tir::ReduceNode::make(first_reduce->combiner, first_reduce->source, first_reduce->axis, first_reduce->condition, @@ -573,25 +573,25 @@ void InjectInline(ScheduleNode* sch) { if (!new_body[j].size()) { new_body[j] = compute->body; } - if (new_body[j][0]->IsInstance()) { + if (new_body[j][0]->IsInstance()) { // specially handle reduction inline for multiple reductions. - const ir::ReduceNode* reduce = new_body[j][0].as(); + const tir::ReduceNode* reduce = new_body[j][0].as(); for (size_t k = 1; k < new_body[j].size(); ++k) { - const ir::ReduceNode* reduce_ = new_body[j][k].as(); + const tir::ReduceNode* reduce_ = new_body[j][k].as(); CHECK(reduce_); CHECK(ReduceEqual(reduce_, reduce)) << "The Reduce inputs of ComputeOp should " << "have the same attribute except value_index"; } - PrimExpr new_value = ir::Inline(ir::EvaluateNode::make(new_body[j][0]), - stage->op, args, body).as()->value; + PrimExpr new_value = tir::Inline(tir::EvaluateNode::make(new_body[j][0]), + stage->op, args, body).as()->value; if (!new_value.same_as(new_body[j][0])) { changed[j] = true; - const ir::ReduceNode* r = new_value.as(); + const tir::ReduceNode* r = new_value.as(); CHECK_EQ(new_body[j].size(), r->source.size()); CHECK(r != nullptr); for (size_t k = 0; k < new_body[j].size(); ++k) { - auto n = make_object(*r); + auto n = make_object(*r); n->value_index = static_cast(k); n->dtype = r->source[k].dtype(); new_body[j].Set(k, PrimExpr(n)); @@ -599,8 +599,8 @@ void InjectInline(ScheduleNode* sch) { } } else { for (size_t k = 0; k < new_body[j].size(); ++k) { - PrimExpr new_value = ir::Inline(ir::EvaluateNode::make(new_body[j][k]), - stage->op, args, body).as()->value; + PrimExpr new_value = tir::Inline(tir::EvaluateNode::make(new_body[j][k]), + stage->op, args, body).as()->value; if (!new_value.same_as(new_body[j][k])) { new_body[j].Set(k, new_value); changed[j] = true; @@ -611,7 +611,7 @@ if (!new_hybrid_body[j].defined()) { new_hybrid_body[j] = hybrid->body; } - Stmt new_stmt = ir::Inline(new_hybrid_body[j], stage->op, args, body); + Stmt new_stmt = tir::Inline(new_hybrid_body[j], stage->op, args, body); if (!new_stmt.same_as(new_hybrid_body[j])) { new_hybrid_body[j] = new_stmt; hybrid_changed[j] = true; @@ -677,7 +677,7 @@ Array Schedule::rfactor(const Tensor& tensor, const IterVar&
axis, int factor_axis) { (*this)->InvalidateCache(); - using ir::ReduceNode; + using tir::ReduceNode; CHECK_EQ(axis->iter_type, kCommReduce) << "Can only factor reduction axis"; Stage reduce_stage = operator[](tensor->op); @@ -761,7 +761,7 @@ Array Schedule::rfactor(const Tensor& tensor, const ReduceNode* reduce = compute_op->body[idx].as(); CHECK(reduce) << "Can only rfactor non-inline reductions"; predicates.push_back(reduce->condition); - PrimExpr predicate = likely(arith::ComputeReduce(predicates, PrimExpr())); + PrimExpr predicate = likely(arith::ComputeReduce(predicates, PrimExpr())); std::unordered_map vsub; @@ -785,7 +785,7 @@ Array Schedule::rfactor(const Tensor& tensor, } } VarReplacer replacer(vsub); - Array new_source = ir::UpdateArray(reduce->source, + Array new_source = tir::UpdateArray(reduce->source, [&replacer] (const PrimExpr& e) { return replacer(e); }); PrimExpr new_pred = replacer(predicate); diff --git a/src/top/schedule/schedule_lang.cc b/src/top/schedule/schedule_lang.cc index 55235305d4c1..10d5ddc48b7b 100644 --- a/src/top/schedule/schedule_lang.cc +++ b/src/top/schedule/schedule_lang.cc @@ -405,7 +405,7 @@ Stage& Stage::pragma(IterVar var, } else { UpdateIterVarAttr( operator->(), var, [pragma_type, pragma_value](IterVarAttrNode* n) { - n->pragma_keys.push_back(ir::StringImmNode::make(pragma_type)); + n->pragma_keys.push_back(tir::StringImmNode::make(pragma_type)); n->pragma_values.push_back(pragma_value); }); } diff --git a/src/top/schedule/schedule_ops.cc b/src/top/schedule/schedule_ops.cc index 1176d82eaead..2bfe96bfccd0 100644 --- a/src/top/schedule/schedule_ops.cc +++ b/src/top/schedule/schedule_ops.cc @@ -20,9 +20,9 @@ /*! * \file schedule_ops.cc */ -#include -#include -#include +#include +#include +#include #include #include #include @@ -30,12 +30,12 @@ #include #include "graph.h" #include "../operation/op_util.h" -#include "../../pass/ir_util.h" +#include "../../tir/pass/ir_util.h" namespace tvm { namespace top { -using namespace ir; +using namespace tir; Stmt MakePipeline(const Stage& s, const std::unordered_map& dom_map, @@ -47,7 +47,7 @@ Stmt MakePipeline(const Stage& s, } if (s->double_buffer) { producer = AttrStmtNode::make( - s->op, ir::attr::double_buffer_scope, 1, producer); + s->op, tir::attr::double_buffer_scope, 1, producer); } Stmt pipeline = producer; @@ -58,13 +58,13 @@ Stmt MakePipeline(const Stage& s, pipeline = s->op->BuildRealize(s, dom_map, pipeline); // use attribute to mark scope of the operation. 
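The VarReplacer change above is the pattern every downstream rewrite pass follows after this refactor: derive from tvm::tir::StmtExprMutator instead of ir::StmtExprMutator. A minimal sketch of a user-side mutator against the new headers (RenameVar and its shape are illustrative, not part of this patch):

    // Illustrative only: substitute one variable under the new tir namespace.
    #include <tvm/tir/expr.h>
    #include <tvm/tir/stmt_functor.h>

    class RenameVar : public tvm::tir::StmtExprMutator {
     public:
      RenameVar(const tvm::tir::VarNode* from, tvm::tir::Var to)
          : from_(from), to_(to) {}

      tvm::PrimExpr VisitExpr_(const tvm::tir::VarNode* op) final {
        // Replace the tracked variable; leave every other node untouched.
        return op == from_ ? tvm::PrimExpr(to_)
                           : tvm::runtime::GetRef<tvm::PrimExpr>(op);
      }

     private:
      const tvm::tir::VarNode* from_;
      tvm::tir::Var to_;
    };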
diff --git a/src/top/schedule/schedule_ops.cc b/src/top/schedule/schedule_ops.cc
index 1176d82eaead..2bfe96bfccd0 100644
--- a/src/top/schedule/schedule_ops.cc
+++ b/src/top/schedule/schedule_ops.cc
@@ -20,9 +20,9 @@
 /*!
  * \file schedule_ops.cc
  */
-#include <tvm/ir.h>
-#include <tvm/ir_pass.h>
-#include <tvm/ir_functor_ext.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/ir_pass.h>
+#include <tvm/tir/stmt_functor.h>
 #include <tvm/top/operation.h>
 #include <tvm/top/schedule_pass.h>
 #include <utility>
@@ -30,12 +30,12 @@
 #include <unordered_set>
 #include "graph.h"
 #include "../operation/op_util.h"
-#include "../../pass/ir_util.h"
+#include "../../tir/pass/ir_util.h"

 namespace tvm {
 namespace top {

-using namespace ir;
+using namespace tir;

 Stmt MakePipeline(const Stage& s,
                   const std::unordered_map<IterVar, Range>& dom_map,
@@ -47,7 +47,7 @@ Stmt MakePipeline(const Stage& s,
   }
   if (s->double_buffer) {
     producer = AttrStmtNode::make(
-        s->op, ir::attr::double_buffer_scope, 1, producer);
+        s->op, tir::attr::double_buffer_scope, 1, producer);
   }
   Stmt pipeline = producer;
@@ -58,13 +58,13 @@ Stmt MakePipeline(const Stage& s,
   pipeline = s->op->BuildRealize(s, dom_map, pipeline);
   // use attribute to mark scope of the operation.
   pipeline = AttrStmtNode::make(
-      s->op, ir::attr::realize_scope,
+      s->op, tir::attr::realize_scope,
       StringImmNode::make(s->scope),
       pipeline);

   if (s->is_opengl) {
     pipeline = AttrStmtNode::make(
-        s->op, ir::attr::opengl_stage_scope, StringImmNode::make(""), pipeline);
+        s->op, tir::attr::opengl_stage_scope, StringImmNode::make(""), pipeline);
   }
   return pipeline;
 }
@@ -198,7 +198,7 @@ class SchedulePostProc : public StmtExprMutator {
       // delete duplicated thread extent attr
       auto it = thread_extent_scope_.find(op->node.get());
       if (it != thread_extent_scope_.end()) {
-        CHECK(is_zero(ir::Simplify(it->second - op->value)));
+        CHECK(is_zero(tir::Simplify(it->second - op->value)));
         return this->VisitStmt(op->body);
       } else {
         thread_extent_scope_[op->node.get()] = op->value;
@@ -206,8 +206,8 @@ class SchedulePostProc : public StmtExprMutator {
         thread_extent_scope_.erase(op->node.get());
         return ret;
       }
-    } else if (op->attr_key == ir::attr::realize_scope ||
-               op->attr_key == ir::attr::double_buffer_scope) {
+    } else if (op->attr_key == tir::attr::realize_scope ||
+               op->attr_key == tir::attr::double_buffer_scope) {
       auto it = replace_op_.find(op->node.get());
       if (it != replace_op_.end()) {
         if (it->second.defined()) {
@@ -218,7 +218,7 @@ class SchedulePostProc : public StmtExprMutator {
           return this->VisitStmt(op->body);
         }
       }
-    } else if (op->attr_key == ir::attr::buffer_bind_scope) {
+    } else if (op->attr_key == tir::attr::buffer_bind_scope) {
       Array<ObjectRef> tuple = Downcast<Array<ObjectRef> >(op->node);
       Tensor tensor = Downcast<Tensor>(tuple[1]);
       auto it = replace_op_.find(tensor->op.get());
@@ -231,7 +231,7 @@ class SchedulePostProc : public StmtExprMutator {
           return this->VisitStmt(op->body);
         }
       }
-    } else if (op->attr_key == ir::attr::buffer_dim_align) {
+    } else if (op->attr_key == tir::attr::buffer_dim_align) {
       Tensor tensor = Downcast<Tensor>(op->node);
       auto it = replace_op_.find(tensor->op.get());
       if (it != replace_op_.end()) {
diff --git a/src/top/tensor.cc b/src/top/tensor.cc
index c8e3aeaa4e4c..c848cc4b4367 100644
--- a/src/top/tensor.cc
+++ b/src/top/tensor.cc
@@ -27,6 +27,21 @@
 namespace tvm {
 namespace top {
+
+IterVar thread_axis(Range dom, std::string tag) {
+  return IterVarNode::make(
+      dom, Var(tag), kThreadIndex, tag);
+}
+
+IterVar reduce_axis(Range dom, std::string name) {
+  return IterVarNode::make(
+      dom, Var(name), kCommReduce);
+}
+
+Var var(std::string name_hint, DataType t) {
+  return Var(name_hint, t);
+}
+
 // Tensor
 PrimExpr Tensor::operator()(Array<Var> indices) const {
   Array<PrimExpr> arr(indices.begin(), indices.end());
@@ -34,7 +49,7 @@ PrimExpr Tensor::operator()(Array<Var> indices) const {
 }

 PrimExpr Tensor::operator()(Array<PrimExpr> indices) const {
-  using ir::CallNode;
+  using tir::CallNode;
   if (ndim() != 0) {
     CHECK_EQ(ndim(), indices.size())
         << "Tensor dimension mismatch in read"
diff --git a/tests/cpp/attrs_test.cc b/tests/cpp/attrs_test.cc
index b87576e75094..ccf1b251482f 100644
--- a/tests/cpp/attrs_test.cc
+++ b/tests/cpp/attrs_test.cc
@@ -20,8 +20,8 @@
 #include <dmlc/logging.h>
 #include <gtest/gtest.h>
 #include <tvm/ir/attrs.h>
-#include <tvm/ir.h>
-#include <tvm/expr_operator.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/op.h>

 namespace tvm {
 namespace test {
@@ -42,7 +42,7 @@ struct TestAttrs : public AttrsNode<TestAttrs> {
         .describe("name of the field");
     TVM_ATTR_FIELD(expr)
         .describe("expression field")
-        .set_default(make_const(DataType::Int(32), 1));
+        .set_default(tir::make_const(DataType::Int(32), 1));
     TVM_ATTR_FIELD(learning_rate)
         .describe("learning_rate")
         .set_default(0.1);
@@ -80,7 +80,7 @@ TEST(Attrs, Basic) {
     n->InitBySeq("name", "xxx", "expr", 128);
     CHECK_EQ(n->name, "xxx");
     CHECK_EQ(n->axis, 10);
-    CHECK_EQ(n->expr.as<ir::IntImmNode>()->value, 128);
+    CHECK_EQ(n->expr.as<tir::IntImmNode>()->value, 128);
     // Check docstring
     std::ostringstream os;
     n->PrintDocString(os);
diff --git a/tests/cpp/container_test.cc b/tests/cpp/container_test.cc
index 5988b2ad2b0a..3e6ef2138625 100644
--- a/tests/cpp/container_test.cc
+++ b/tests/cpp/container_test.cc
@@ -19,13 +19,14 @@

 #include <dmlc/logging.h>
 #include <gtest/gtest.h>
-#include <tvm/ir.h>
+#include <tvm/tir/expr.h>
 #include <tvm/runtime/container.h>
 #include <new>
 #include <unordered_map>
 #include <vector>

 using namespace tvm;
+using namespace tvm::tir;
 using namespace tvm::runtime;

 class TestErrorSwitch {
diff --git a/tests/cpp/expr_test.cc b/tests/cpp/expr_test.cc
index af8ede396edc..61fd726061f2 100644
--- a/tests/cpp/expr_test.cc
+++ b/tests/cpp/expr_test.cc
@@ -23,6 +23,7 @@

 TEST(Expr, Basic) {
   using namespace tvm;
+  using namespace tvm::tir;
   Var x("x");
   auto z = max(x + 1 + 2, 100);
   ObjectRef tmp = z;
@@ -36,9 +37,10 @@ TEST(Expr, Basic) {

 TEST(ExprNodeRef, Basic) {
   using namespace tvm;
+  using namespace tvm::tir;
   Var x("x");
   PrimExpr z = max(x + 1 + 2, 100);
-  const ir::MaxNode* op = z.as<ir::MaxNode>();
+  const tir::MaxNode* op = z.as<tir::MaxNode>();
   CHECK(GetRef<PrimExpr>(op).same_as(z));
 }
diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc
index 178f582b94f1..3941de5eef17 100644
--- a/tests/cpp/ir_functor_test.cc
+++ b/tests/cpp/ir_functor_test.cc
@@ -19,14 +19,15 @@

 #include <dmlc/logging.h>
 #include <gtest/gtest.h>
-#include <tvm/ir.h>
-#include <tvm/ir_functor_ext.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt.h>
 #include <tvm/node/functor.h>
-#include <tvm/ir_pass.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/stmt_functor.h>

 TEST(IRF, Basic) {
   using namespace tvm;
-  using namespace tvm::ir;
+  using namespace tvm::tir;
   Var x("x");
   auto z = x + 1;
@@ -43,12 +44,13 @@ TEST(IRF, Basic) {

 TEST(IRF, CountVar) {
   using namespace tvm;
+  using namespace tvm::tir;
   int n_var = 0;
   Var x("x"), y;

   auto z = x + 1 + y + y;
-  ir::PostOrderVisit(z, [&n_var](const ObjectRef& n) {
-      if (n.as<VarNode>()) ++n_var;
+  tir::PostOrderVisit(z, [&n_var](const ObjectRef& n) {
+      if (n.as<tir::VarNode>()) ++n_var;
     });
   CHECK_EQ(n_var, 2);
 }
@@ -56,12 +58,12 @@ TEST(IRF, CountVar) {

 TEST(IRF, ExprTransform) {
   using namespace tvm;
-  using namespace tvm::ir;
+  using namespace tvm::tir;
   Var x("x");
   auto z = x + 1;

   class MyExprFunctor
-      : public ir::ExprFunctor<int(const PrimExpr&, int)> {
+      : public tir::ExprFunctor<int(const PrimExpr&, int)> {
    public:
     int VisitExpr_(const VarNode* op, int b) final {
       return b;
@@ -85,13 +87,13 @@ TEST(IRF, ExprTransform) {

 TEST(IRF, ExprVisit) {
   using namespace tvm;
-  using namespace tvm::ir;
+  using namespace tvm::tir;
   Var x("x");
   auto z = x + 1;

   class MyVisitor
-      : public ir::ExprFunctor<void(const PrimExpr&)>,
-        public ir::StmtFunctor<void(const Stmt&)> {
+      : public tir::ExprFunctor<void(const PrimExpr&)>,
+        public tir::StmtFunctor<void(const Stmt&)> {
    public:
     int count = 0;
     // implementation
@@ -116,7 +118,7 @@ TEST(IRF, ExprVisit) {

 TEST(IRF, StmtVisitor) {
   using namespace tvm;
-  using namespace tvm::ir;
+  using namespace tvm::tir;
   Var x("x");
   class MyVisitor
       : public StmtExprVisitor {
@@ -140,12 +142,12 @@ TEST(IRF, StmtVisitor) {

 TEST(IRF, StmtMutator) {
   using namespace tvm;
-  using namespace tvm::ir;
+  using namespace tvm::tir;
   Var x("x");

   class MyVisitor
-      : public ir::StmtMutator,
-        public ir::ExprMutator {
+      : public tir::StmtMutator,
+        public tir::ExprMutator {
    public:
     using StmtMutator::operator();
     using ExprMutator::operator();
diff --git a/tests/cpp/ir_simplify_test.cc b/tests/cpp/ir_simplify_test.cc
index e9f0df6493a6..f4d1a467b241 100644
--- a/tests/cpp/ir_simplify_test.cc
+++ b/tests/cpp/ir_simplify_test.cc
@@ -19,25 +19,25 @@

 #include <dmlc/logging.h>
 #include <gtest/gtest.h>
-#include <tvm/ir_pass.h>
+#include <tvm/tir/ir_pass.h>
 #include <tvm/top/operation.h>

 TEST(IRSIMPLIFY, MinMax) {
-  auto x = tvm::var("x");
+  auto x = tvm::top::var("x");
   auto e1 = (tvm::max(x, 1) - tvm::max(x, 1)) ;
-  auto e1s = tvm::ir::CanonicalSimplify(e1);
-  CHECK(is_zero(e1s));
+  auto e1s = tvm::tir::CanonicalSimplify(e1);
+  CHECK(tvm::tir::is_zero(e1s));

   auto e2 = (x * tvm::min(x, 1)) - (x * tvm::min(x, 1));
-  auto e2s = tvm::ir::CanonicalSimplify(e2);
-  CHECK(is_zero(e2s));
+  auto e2s = tvm::tir::CanonicalSimplify(e2);
+  CHECK(tvm::tir::is_zero(e2s));
 }

 TEST(IRSIMPLIFY, Mul) {
-  auto x = tvm::var("x");
+  auto x = tvm::top::var("x");
   auto e = (x * x) - (x * x) ;
-  auto es = tvm::ir::CanonicalSimplify(e);
-  CHECK(is_zero(es));
+  auto es = tvm::tir::CanonicalSimplify(e);
+  CHECK(tvm::tir::is_zero(es));
 }

 TEST(IRSIMPLIFY, Mod) {
@@ -46,9 +46,9 @@ TEST(IRSIMPLIFY, Mod) {
   // Mod::make is used instead of % to avoid constant folding during
   // calling operator%(x,y). Mod::make doesn't try constant folding,
   // and therefore, the constant folding will be attempted in CanonicalSimplify
-  auto mod = tvm::ir::CanonicalSimplify(tvm::ir::ModNode::make(x, y));
-  auto es = tvm::ir::CanonicalSimplify(mod - x);
-  CHECK(is_zero(es));
+  auto mod = tvm::tir::CanonicalSimplify(tvm::tir::ModNode::make(x, y));
+  auto es = tvm::tir::CanonicalSimplify(mod - x);
+  CHECK(tvm::tir::is_zero(es));
 }
 int main(int argc, char ** argv) {
   testing::InitGoogleTest(&argc, argv);
diff --git a/tests/cpp/ir_ssa_test.cc b/tests/cpp/ir_ssa_test.cc
index d1316dec7121..56f178dbcf4e 100644
--- a/tests/cpp/ir_ssa_test.cc
+++ b/tests/cpp/ir_ssa_test.cc
@@ -19,27 +19,27 @@

 #include <dmlc/logging.h>
 #include <gtest/gtest.h>
-#include <tvm/ir_pass.h>
+#include <tvm/tir/ir_pass.h>

 TEST(IRSSA, Convert) {
   using namespace tvm;
-  using namespace tvm::ir;
+  using namespace tvm::tir;
   Var x("x"), y;
   PrimExpr let = LetNode::make(x, 1, x + 1);

   auto z = EvaluateNode::make(let + let);
-  CHECK(!ir::VerifySSA(z));
-  auto z_ssa = ir::ConvertSSA(z);
-  CHECK(ir::VerifySSA(z_ssa));
+  CHECK(!tir::VerifySSA(z));
+  auto z_ssa = tir::ConvertSSA(z);
+  CHECK(tir::VerifySSA(z_ssa));
 }

 TEST(IRSSA, Basic) {
-  using namespace tvm::ir;
+  using namespace tvm::tir;
   using namespace tvm;
   Var x("x"), y;
   auto z = EvaluateNode::make(x + y);
-  CHECK(ir::VerifySSA(z));
+  CHECK(tir::VerifySSA(z));
 }

 int main(int argc, char ** argv) {
diff --git a/tests/cpp/packed_func_test.cc b/tests/cpp/packed_func_test.cc
index 550c93eea3a5..349e493881af 100644
--- a/tests/cpp/packed_func_test.cc
+++ b/tests/cpp/packed_func_test.cc
@@ -21,10 +21,11 @@
 #include <dmlc/logging.h>
 #include <gtest/gtest.h>
 #include <tvm/runtime/packed_func.h>
-#include <tvm/ir.h>
+#include <tvm/tir/expr.h>

 TEST(PackedFunc, Basic) {
   using namespace tvm;
+  using namespace tvm::tir;
   using namespace tvm::runtime;
   int x = 0;
   void* handle = &x;
@@ -45,6 +46,7 @@ TEST(PackedFunc, Basic) {

 TEST(PackedFunc, Node) {
   using namespace tvm;
+  using namespace tvm::tir;
   using namespace tvm::runtime;
   Var x;
   Var t = PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
@@ -129,7 +131,7 @@ TEST(PackedFunc, Expr) {
   // automatic conversion of int to expr
   PackedFunc addone([](TVMArgs args, TVMRetValue* rv) {
       PrimExpr x = args[0];
-      *rv = x.as<ir::IntImmNode>()->value + 1;
+      *rv = x.as<tir::IntImmNode>()->value + 1;
     });
   int r0 = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
       PackedFunc f = args[0];
@@ -198,6 +200,7 @@ TEST(TypedPackedFunc, Deduce) {

 TEST(PackedFunc, ObjectConversion) {
   using namespace tvm;
+  using namespace tvm::tir;
   using namespace tvm::runtime;
   TVMRetValue rv;
   auto x = NDArray::Empty(
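The test updates above show that the traversal entry points moved with the rename. A sketch mirroring the updated IRF.CountVar test, written against the new tvm::tir spellings (the CountVars wrapper is illustrative; the exact header providing PostOrderVisit at this revision is assumed):

    #include <tvm/tir/expr.h>
    #include <tvm/tir/ir_pass.h>

    int CountVars(const tvm::PrimExpr& e) {
      int n_var = 0;
      // PostOrderVisit calls the functor once per distinct node in the DAG,
      // so each Var is counted once even if it appears several times.
      tvm::tir::PostOrderVisit(e, [&n_var](const tvm::ObjectRef& n) {
        if (n.as<tvm::tir::VarNode>()) ++n_var;
      });
      return n_var;
    }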
diff --git a/tests/cpp/pattern_match_test.cc b/tests/cpp/pattern_match_test.cc
index ffd710fd83bb..5176a5d6f6f6 100644
--- a/tests/cpp/pattern_match_test.cc
+++ b/tests/cpp/pattern_match_test.cc
@@ -22,8 +22,9 @@

 TEST(Pattern, Basic) {
   using namespace tvm;
+  using namespace tvm::tir;
   using namespace tvm::arith;
-  Var x("x"), y("y"), z("z");
+  tvm::tir::Var x("x"), y("y"), z("z");
   arith::PVar<PrimExpr> px, py, pz;
   arith::PVar<DataType> pt;
   arith::PVar<int> planes;

@@ -38,12 +39,12 @@ TEST(Pattern, Basic) {
   {
     CHECK((px + (py + px)).Match(r));
     auto rr = (px + py).Eval();
-    CHECK(ir::Equal(rr, 1 + y));
-    CHECK(ir::Equal(px.Eval() + py.Eval(), 1 + y));
+    CHECK(tir::Equal(rr, 1 + y));
+    CHECK(tir::Equal(px.Eval() + py.Eval(), 1 + y));
   }
   {
     CHECK((px + max(py, px)).Match((x + 1) + max(y, (x + 1))));
-    CHECK(ir::Equal(px.Eval(), x + 1));
+    CHECK(tir::Equal(px.Eval(), x + 1));
   }
   CHECK(!(px + min(py, px)).Match((x + 1) + max(y, (x + 1))));
   CHECK((px + min(py, px)).Match(z + min(y, z)));
@@ -62,8 +63,8 @@ TEST(Pattern, Basic) {
   CHECK((!(px > py || px != py)).Match(!(x > y || x != y)));
   {
     CHECK(select(px >= pz, py, py + pz).Match(
-        ir::SelectNode::make((x + 1) >= 1, y, y + 1)));
-    CHECK(ir::Equal(px.Eval(), x + 1));
+        tir::SelectNode::make((x + 1) >= 1, y, y + 1)));
+    CHECK(tir::Equal(px.Eval(), x + 1));
   }
   // bit intrinsics
   {
@@ -79,17 +80,17 @@ TEST(Pattern, Basic) {
   // select
   {
     CHECK(select(px > pz, py, py + pz).Match(
-        ir::SelectNode::make(x > 1, y, y + 1)));
+        tir::SelectNode::make(x > 1, y, y + 1)));
     CHECK(is_const_int(pz.Eval(), 1));
   }
   CHECK(!select(px > pz, py, py + pz).Match(
-      ir::SelectNode::make(x > 2, y, y + 1)));
+      tir::SelectNode::make(x > 2, y, y + 1)));
   CHECK(!select(px > pz, py, py).Match(
-      ir::SelectNode::make(x > 2, y, y + 1)));
+      tir::SelectNode::make(x > 2, y, y + 1)));
   {
     CHECK(select(px, py, pz).Match(
-        ir::SelectNode::make(x > 2, y, y + 1)));
-    CHECK(ir::Equal(pz.Eval(), y + 1));
+        tir::SelectNode::make(x > 2, y, y + 1)));
+    CHECK(tir::Equal(pz.Eval(), y + 1));
   }
   // if_then_else
   {
@@ -100,38 +101,38 @@ TEST(Pattern, Basic) {
   // cast pattern
   {
     CHECK(!cast(PConst<DataType>(
-        DataType::Int(32)), px).Match(ir::CastNode::make(DataType::Float(64), x)));
-    CHECK(cast(pt, px).Match(ir::CastNode::make(DataType::Float(64), x)));
+        DataType::Int(32)), px).Match(tir::CastNode::make(DataType::Float(64), x)));
+    CHECK(cast(pt, px).Match(tir::CastNode::make(DataType::Float(64), x)));
     CHECK(pt.Eval() == DataType::Float(64));
     auto zz = cast(pt, px).Eval();

     CHECK((cast(pt, px) - cast(pt, py)).Match(
-        ir::CastNode::make(DataType::Float(64), x) - ir::CastNode::make(DataType::Int(64), x)));
-    auto expr = ir::CastNode::make(DataType::Int(32), ir::CastNode::make(DataType::Float(64), x));
+        tir::CastNode::make(DataType::Float(64), x) - tir::CastNode::make(DataType::Int(64), x)));
+    auto expr = tir::CastNode::make(DataType::Int(32), tir::CastNode::make(DataType::Float(64), x));
     CHECK(!(cast(pt, cast(pt, px))).Match(expr));
   }
   // ramp pattern
   {
     CHECK(ramp(px, PConst<PrimExpr>(1), planes).Match(
-        ir::RampNode::make(x, 1, 10)));
+        tir::RampNode::make(x, 1, 10)));
     CHECK(planes.Eval() == 10);
     CHECK(!ramp(px, PConst<PrimExpr>(1), planes).Match(
-        ir::RampNode::make(x, 2, 10)));
+        tir::RampNode::make(x, 2, 10)));
   }
   // broadcast pattern
   {
     CHECK(broadcast(px, planes).Match(
-        ir::BroadcastNode::make(x, 10)));
+        tir::BroadcastNode::make(x, 10)));
     CHECK(planes.Eval() == 10);
     CHECK(broadcast(px * py , planes).Match(
-        ir::BroadcastNode::make(x * 10, 10)));
+        tir::BroadcastNode::make(x * 10, 10)));
   }
 }

 TEST(Pattern, IntImm) {
   using namespace tvm;
-  tvm::Var tx, ty;
+  tir::Var tx, ty;
   arith::PVar<IntImm> c;
-  arith::PVar<Var> v;
+  arith::PVar<tir::Var> v;
   {
     // We can match integer and Var, both of which are
     // special case container of Expr
diff --git a/tests/cpp/simple_passes_test.cc b/tests/cpp/simple_passes_test.cc
index e41b881faa43..cff0a562e085 100644
--- a/tests/cpp/simple_passes_test.cc
+++ b/tests/cpp/simple_passes_test.cc
@@ -19,18 +19,18 @@

 #include <dmlc/logging.h>
 #include <gtest/gtest.h>
-#include <tvm/ir_pass.h>
+#include <tvm/tir/ir_pass.h>
 #include <tvm/top/operation.h>

 TEST(SimplePasses, HasSideEffect) {
   using namespace tvm;
-  auto n = var("n");
+  auto n = top::var("n");
   Array<PrimExpr> shape;
   shape.push_back(n);

   auto A = top::placeholder(shape, DataType::Float(32), "A");

-  CHECK(!tvm::ir::HasSideEffect(A[0]));
+  CHECK(!tvm::tir::HasSideEffect(A[0]));
 }
diff --git a/topi/include/topi/broadcast.h b/topi/include/topi/broadcast.h
index 47dfc3e77f09..33034a1b6c4a 100644
--- a/topi/include/topi/broadcast.h
+++ b/topi/include/topi/broadcast.h
@@ -55,7 +55,7 @@ inline tvm::top::Tensor broadcast_to(const tvm::top::Tensor& t,
   for (size_t i = 0; i < output_shape.size(); ++i) {
     CHECK(topi::detail::EqualCheck(output_shape[i], bh.common_shape[i]));
   }
-  auto l = [&](tvm::Array<tvm::Var> ovars) {
+  auto l = [&](tvm::Array<tvm::tir::Var> ovars) {
     return t(detail::InputIndexFromBroadcast(ovars, t, bh.vars2, bh.all_vars));
   };
   return tvm::top::compute(
@@ -82,7 +82,7 @@ inline tvm::top::Tensor broadcast_to(const tvm::top::Tensor& t,
                                 std::string name = "T_" #Name,                  \
                                 std::string tag = kElementWise) {               \
     auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; };             \
-    return tvm::top::compute(A->shape, [&](const ::tvm::Array<::tvm::Var>& i) { \
+    return tvm::top::compute(A->shape, [&](const ::tvm::Array<::tvm::tir::Var>& i) { \
         return l(A(i), B);                                                      \
       }, name, tag);                                                            \
   }                                                                             \
@@ -91,7 +91,7 @@ inline tvm::top::Tensor broadcast_to(const tvm::top::Tensor& t,
                                 std::string name = "T_" #Name,                  \
                                 std::string tag = kElementWise) {               \
     auto l = [&](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; };            \
-    return tvm::top::compute(B->shape, [&](const ::tvm::Array<::tvm::Var>& i) { \
+    return tvm::top::compute(B->shape, [&](const ::tvm::Array<::tvm::tir::Var>& i) { \
        return l(A, B(i));                                                       \
      }, name, tag);                                                             \
   }
diff --git a/topi/include/topi/cuda/dense.h b/topi/include/topi/cuda/dense.h
index 637a861cd989..45bc72ffc75e 100644
--- a/topi/include/topi/cuda/dense.h
+++ b/topi/include/topi/cuda/dense.h
@@ -113,11 +113,13 @@ inline Schedule schedule_dense(const Target &target, const Array<Tensor>& outs)
       out = outs[0]->op.output(0);
       s[dense].compute_at(s[out], s[out]->op.as<ComputeOpNode>()->axis[1]);
     }
-    s[out].bind(s[out]->op.as<ComputeOpNode>()->axis[0], tvm::thread_axis(Range(), "blockIdx.y"));
-    s[out].bind(s[out]->op.as<ComputeOpNode>()->axis[1], tvm::thread_axis(Range(), "blockIdx.x"));
+    s[out].bind(s[out]->op.as<ComputeOpNode>()->axis[0],
+                tvm::top::thread_axis(Range(), "blockIdx.y"));
+    s[out].bind(s[out]->op.as<ComputeOpNode>()->axis[1],
+                tvm::top::thread_axis(Range(), "blockIdx.x"));

     auto tx = s[dense]->op.as<ComputeOpNode>()->reduce_axis[0];
-    auto thread_x = tvm::thread_axis(Range(), "threadIdx.x");
+    auto thread_x = tvm::top::thread_axis(Range(), "threadIdx.x");
     s[dense].bind(tx, thread_x);
     s[dense_f].compute_at(s[dense], tx);
     s[dense].set_store_predicate(static_cast<PrimExpr>(thread_x) == 0);
diff --git a/topi/include/topi/cuda/normalization.h b/topi/include/topi/cuda/normalization.h
index 708f8d5e7bbf..f420787d7d3f 100644
--- a/topi/include/topi/cuda/normalization.h
+++ b/topi/include/topi/cuda/normalization.h
@@ -47,8 +47,8 @@ inline Schedule schedule_lrn(const Target &target, const Array<Tensor>& outs) {
   }
   Schedule s = create_schedule(out_ops);
   int num_thread = 64;
-  IterVar block_x = tvm::thread_axis(Range(), "blockIdx.x");
-  IterVar thread_x = tvm::thread_axis(Range(0, num_thread), "threadIdx.x");
+  IterVar block_x = tvm::top::thread_axis(Range(), "blockIdx.x");
+  IterVar thread_x = tvm::top::thread_axis(Range(0, num_thread), "threadIdx.x");
   Tensor lrn = outs[0];
   Tensor sqr_sum_up = lrn->op->InputTensors()[1];
   Tensor sqr_sum = sqr_sum_up->op->InputTensors()[0];
@@ -110,8 +110,8 @@ inline Schedule schedule_l2_normalize(const Target &target, const Array<Tensor>& outs) {
   traverse(outs[0]->op);
   int num_thread = 64;
   Tensor l2_normalize = outs[0];
-  IterVar block_x = tvm::thread_axis(Range(), "blockIdx.x");
-  IterVar thread_x = tvm::thread_axis(Range(0, num_thread), "threadIdx.x");
+  IterVar block_x = tvm::top::thread_axis(Range(), "blockIdx.x");
+  IterVar thread_x = tvm::top::thread_axis(Range(0, num_thread), "threadIdx.x");
   IterVar xto, xti;
   s[l2_normalize].split_by_nparts(l2_normalize->op.as<ComputeOpNode>()->axis[1],
                                   num_thread, &xto, &xti);
diff --git a/topi/include/topi/cuda/pooling.h b/topi/include/topi/cuda/pooling.h
index d2a5c1f4511f..c4edadc116ed 100644
--- a/topi/include/topi/cuda/pooling.h
+++ b/topi/include/topi/cuda/pooling.h
@@ -68,8 +68,8 @@ inline Schedule schedule_pool(const Target &target, const Array<Tensor>& outs) {
     auto fused = detail::Fuse(s[out], s[out]->op.as<ComputeOpNode>()->axis);
     IterVar bx, tx;
     s[out].split(fused, num_thread, &bx, &tx);
-    s[out].bind(bx, tvm::thread_axis(Range(), "blockIdx.x"));
-    s[out].bind(tx, tvm::thread_axis(Range(), "threadIdx.x"));
+    s[out].bind(bx, tvm::top::thread_axis(Range(), "blockIdx.x"));
+    s[out].bind(tx, tvm::top::thread_axis(Range(), "threadIdx.x"));
     if (detail::contains(s->outputs, pool->op)) {
       s[OL].compute_at(s[out], tx);
     } else {
@@ -120,10 +120,10 @@ inline Schedule schedule_global_pool(const Target &target, const Array<Tensor>& outs) {
   auto _schedule = [&](const Tensor& pool) {
     auto num_thread = 8;
-    auto block_x = tvm::thread_axis(Range(), "blockIdx.x");
-    auto block_y = tvm::thread_axis(Range(), "blockIdx.y");
-    auto thread_x = tvm::thread_axis(Range(0, num_thread), "threadIdx.x");
-    auto thread_y = tvm::thread_axis(Range(0, num_thread), "threadIdx.y");
+    auto block_x = tvm::top::thread_axis(Range(), "blockIdx.x");
+    auto block_y = tvm::top::thread_axis(Range(), "blockIdx.y");
+    auto thread_x = tvm::top::thread_axis(Range(0, num_thread), "threadIdx.x");
+    auto thread_y = tvm::top::thread_axis(Range(0, num_thread), "threadIdx.y");
     Tensor out;
     Tensor OL;
     if (detail::contains(s->outputs, pool->op)) {
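The schedule updates above all make the same substitution: thread_axis is now spelled tvm::top::thread_axis (it is defined in src/top/tensor.cc earlier in this patch). A sketch of the binding idiom, assuming the declaration is visible from the top headers; the helper name and the factor of 64 are illustrative:

    #include <tvm/top/schedule.h>
    #include <tvm/top/tensor.h>

    void BindToGpuThreads(tvm::top::Schedule s,
                          const tvm::top::Tensor& out,
                          tvm::IterVar fused) {
      tvm::IterVar bx, tx;
      // Split the fused axis and map the pieces onto the GPU grid.
      s[out].split(fused, 64, &bx, &tx);
      s[out].bind(bx, tvm::top::thread_axis(tvm::Range(), "blockIdx.x"));
      s[out].bind(tx, tvm::top::thread_axis(tvm::Range(), "threadIdx.x"));
    }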
diff --git a/topi/include/topi/cuda/reduction.h b/topi/include/topi/cuda/reduction.h
index 244567499d3f..9d019991f4fc 100644
--- a/topi/include/topi/cuda/reduction.h
+++ b/topi/include/topi/cuda/reduction.h
@@ -75,13 +75,13 @@ Schedule ScheduleReduce(const Target& target,
      // Don't know why.
      num_thread = 16;
    }
-    block_x = tvm::thread_axis(Range(), "blockIdx.x");
-    thread_x = tvm::thread_axis(Range(0, num_thread), "threadIdx.x");
-    thread_y = tvm::thread_axis(Range(0, num_thread), "threadIdx.y");
+    block_x = tvm::top::thread_axis(Range(), "blockIdx.x");
+    thread_x = tvm::top::thread_axis(Range(0, num_thread), "threadIdx.x");
+    thread_y = tvm::top::thread_axis(Range(0, num_thread), "threadIdx.y");
   } else {
     all_reduce = true;
     num_thread = target->max_num_threads;
-    thread_x = tvm::thread_axis(Range(0, num_thread), "threadIdx.x");
+    thread_x = tvm::top::thread_axis(Range(0, num_thread), "threadIdx.x");
   }

   auto fused_reduce = detail::Fuse(out_stage,
                                    out_stage->op.as<ComputeOpNode>()->reduce_axis);
diff --git a/topi/include/topi/cuda/softmax.h b/topi/include/topi/cuda/softmax.h
index 6f12de000bf8..f3368b114310 100644
--- a/topi/include/topi/cuda/softmax.h
+++ b/topi/include/topi/cuda/softmax.h
@@ -70,8 +70,8 @@ inline Schedule schedule_softmax(const Target &target, const Array<Tensor>& outs) {
   }

   int num_thread = 64;
-  auto block_x = tvm::thread_axis(Range(), "blockIdx.x");
-  auto thread_x = tvm::thread_axis(Range(0, num_thread), "threadIdx.x");
+  auto block_x = tvm::top::thread_axis(Range(), "blockIdx.x");
+  auto thread_x = tvm::top::thread_axis(Range(0, num_thread), "threadIdx.x");

   if (has_exp) {
     s[exp].bind(exp->op.as<ComputeOpNode>()->axis[0], block_x);
diff --git a/topi/include/topi/detail/broadcast.h b/topi/include/topi/detail/broadcast.h
index 2e644eebdc8c..17524b192620 100644
--- a/topi/include/topi/detail/broadcast.h
+++ b/topi/include/topi/detail/broadcast.h
@@ -28,9 +28,9 @@
 #include <deque>
 #include <string>

-#include "tvm/ir_pass.h"
+#include "tvm/tir/ir_pass.h"
 #include "tvm/top/operation.h"
-#include "tvm/expr_operator.h"
+#include "tvm/tir/op.h"
 #include "topi/detail/constant_utils.h"

 namespace topi {
@@ -38,9 +38,9 @@ namespace detail {

 struct BroadcastHelper {
   std::deque<tvm::PrimExpr> common_shape;
-  std::deque<tvm::Var> all_vars;
-  std::deque<tvm::Var> vars1;
-  std::deque<tvm::Var> vars2;
+  std::deque<tvm::tir::Var> all_vars;
+  std::deque<tvm::tir::Var> vars1;
+  std::deque<tvm::tir::Var> vars2;
 };

 inline BroadcastHelper BroadcastShape(const tvm::Array<tvm::PrimExpr>& shape1,
@@ -54,7 +54,7 @@ inline BroadcastHelper BroadcastShape(const tvm::Array<tvm::PrimExpr>& shape1,
     // TODO(@icemelon9): Need to revisit this part
     const VarNode* var1 = shape1[s1_size - i].as<VarNode>();
     const VarNode* var2 = shape2[s2_size - i].as<VarNode>();
-    bh.all_vars.push_front(tvm::Var());
+    bh.all_vars.push_front(tvm::tir::Var());
     if (topi::detail::EqualCheck(shape1[s1_size - i], shape2[s2_size - i])) {
       bh.common_shape.push_front(shape1[s1_size - i]);
       bh.vars1.push_front(bh.all_vars[0]);
@@ -91,7 +91,7 @@ inline BroadcastHelper BroadcastShape(const tvm::Array<tvm::PrimExpr>& shape1,
   auto& shape = (s1_size > s2_size) ? shape1 : shape2;
   auto& vars = (s1_size > s2_size) ? bh.vars1 : bh.vars2;
   for (; i <= max_size; ++i) {
-    bh.all_vars.push_front(tvm::Var());
+    bh.all_vars.push_front(tvm::tir::Var());
     bh.common_shape.push_front(shape[max_size - i]);
     vars.push_front(bh.all_vars[0]);
   }
@@ -99,10 +99,10 @@ inline BroadcastHelper BroadcastShape(const tvm::Array<tvm::PrimExpr>& shape1,
 }

 inline tvm::Array<tvm::PrimExpr> InputIndexFromBroadcast(
-    const tvm::Array<tvm::Var>& ovars,
+    const tvm::Array<tvm::tir::Var>& ovars,
     const tvm::top::Tensor& T,
-    const std::deque<tvm::Var>& my_vars,
-    const std::deque<tvm::Var>& all_vars) {
+    const std::deque<tvm::tir::Var>& my_vars,
+    const std::deque<tvm::tir::Var>& all_vars) {
   tvm::Array<tvm::PrimExpr> ivars;
   CHECK_EQ(ovars.size(), all_vars.size());
   // N^2, could use a map but NBD.
@@ -119,7 +119,7 @@ inline tvm::Array<tvm::PrimExpr> InputIndexFromBroadcast(
     // Only inject 0 here if we have not yet reached the dimension of I
     // (i.e. this must be a 1)
     if (!found && (ovars.size() - i) <= expected_dims) {
-      ivars.push_back(tvm::make_zero(ovars[i].dtype()));
+      ivars.push_back(tvm::tir::make_zero(ovars[i].dtype()));
     }
   }
   CHECK(expected_dims == ivars.size());
@@ -133,7 +133,7 @@ inline tvm::top::Tensor WithBroadcast(FBinaryExpr op,
                                  const std::string& name = "tensor",
                                  const std::string& tag = "") {
   auto bh = BroadcastShape(A->shape, B->shape);
-  auto l = [&](tvm::Array<tvm::Var> ovars) {
+  auto l = [&](tvm::Array<tvm::tir::Var> ovars) {
     return op(A(InputIndexFromBroadcast(ovars, A, bh.vars1, bh.all_vars)),
               B(InputIndexFromBroadcast(ovars, B, bh.vars2, bh.all_vars)));
   };
diff --git a/topi/include/topi/detail/constant_utils.h b/topi/include/topi/detail/constant_utils.h
index 210049344404..081495e891f7 100644
--- a/topi/include/topi/detail/constant_utils.h
+++ b/topi/include/topi/detail/constant_utils.h
@@ -24,12 +24,12 @@
 #ifndef TOPI_DETAIL_CONSTANT_UTILS_H_
 #define TOPI_DETAIL_CONSTANT_UTILS_H_

+#include <tvm/tir/expr.h>
+#include <tvm/tir/ir_pass.h>
+
 #include <string>
 #include <vector>

-#include "tvm/expr.h"
-#include "tvm/ir_pass.h"
-
 namespace topi {
 namespace detail {
 using namespace tvm;
@@ -44,7 +44,7 @@ using namespace tvm::top;
  */
 inline bool IsConstInt(PrimExpr expr) {
   return
-    expr->IsInstance<tvm::ir::IntImmNode>();
+    expr->IsInstance<tvm::tir::IntImmNode>();
 }
@@ -106,7 +106,7 @@ inline std::vector<int64_t> GetConstInt64Values(
 /*!
  * \brief Check weather the two expressions are equal or not, if not simplify the expressions and check again
- * \note This is stronger equality check than tvm::ir::Equal
+ * \note This is stronger equality check than tvm::tir::Equal
 *
 * \param lhs First expreesion
 * \param rhs Second expreesion
 *
 * \return result True if both expressions are equal, else false
 */
 inline bool EqualCheck(PrimExpr lhs, PrimExpr rhs) {
-  bool result = tvm::ir::Equal(lhs, rhs);
+  bool result = tvm::tir::Equal(lhs, rhs);
   if (!result) {
     PrimExpr zero(0);
-    result = tvm::ir::Equal(tvm::ir::CanonicalSimplify(lhs-rhs), zero);
+    result = tvm::tir::Equal(tvm::tir::CanonicalSimplify(lhs-rhs), zero);
   }
   return result;
 }
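EqualCheck above is a two-step test: a cheap syntactic comparison first, and only on failure a canonical simplification of the difference. A standalone sketch of the same idea under the new spellings (ExprEqual is an illustrative name):

    #include <tvm/tir/ir_pass.h>
    #include <tvm/tir/op.h>

    inline bool ExprEqual(const tvm::PrimExpr& lhs, const tvm::PrimExpr& rhs) {
      // Fast path: structural equality.
      if (tvm::tir::Equal(lhs, rhs)) return true;
      // Slow path: lhs - rhs simplifies to zero iff the values agree.
      return tvm::tir::Equal(tvm::tir::CanonicalSimplify(lhs - rhs),
                             tvm::PrimExpr(0));
    }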
diff --git a/topi/include/topi/detail/extern.h b/topi/include/topi/detail/extern.h
index 717ce4d46d33..c8db4e1ec287 100644
--- a/topi/include/topi/detail/extern.h
+++ b/topi/include/topi/detail/extern.h
@@ -96,7 +96,7 @@ inline Array<Tensor> make_extern(const Array< Array<PrimExpr> >& out_shapes,
   }

   auto body = fextern(input_placeholders, output_placeholders);
-  auto body_stmt = tvm::ir::EvaluateNode::make(body);
+  auto body_stmt = tvm::tir::EvaluateNode::make(body);

   auto op = ExternOpNode::make(
       name, tag, attrs, inputs,
@@ -119,12 +119,14 @@ inline Array<Tensor> make_extern(const Array< Array<PrimExpr> >& out_shapes,
  */
 inline PrimExpr pack_buffer(Buffer buf) {
   CHECK_GT(buf->shape.size(), 0) << "buf shape must have at least one element";
-  auto shape = tvm::ir::CallNode::make(DataType::Handle(), tvm::ir::intrinsic::tvm_stack_make_shape,
-                                       buf->shape, tvm::ir::CallNode::CallType::Intrinsic);
+  auto shape = tvm::tir::CallNode::make(
+      DataType::Handle(), tvm::tir::intrinsic::tvm_stack_make_shape,
+      buf->shape, tvm::tir::CallNode::CallType::Intrinsic);
   PrimExpr strides;
   if (buf->strides.size() > 0) {
-    strides = tvm::ir::CallNode::make(DataType::Handle(), tvm::ir::intrinsic::tvm_stack_make_shape,
-                                      buf->shape, tvm::ir::CallNode::CallType::Intrinsic);
+    strides = tvm::tir::CallNode::make(
+        DataType::Handle(), tvm::tir::intrinsic::tvm_stack_make_shape,
+        buf->shape, tvm::tir::CallNode::CallType::Intrinsic);
   } else {
     strides = 0;
   }
@@ -136,8 +138,8 @@ inline PrimExpr pack_buffer(Buffer buf) {
     make_const(buf->dtype, 0),
     buf->elem_offset
   };
-  return tvm::ir::CallNode::make(DataType::Handle(), tvm::ir::intrinsic::tvm_stack_make_array,
-                                 pack_args, tvm::ir::CallNode::CallType::Intrinsic);
+  return tvm::tir::CallNode::make(DataType::Handle(), tvm::tir::intrinsic::tvm_stack_make_array,
+                                  pack_args, tvm::tir::CallNode::CallType::Intrinsic);
 }

 /*!
@@ -150,8 +152,8 @@ inline PrimExpr pack_buffer(Buffer buf) {
  * \return An expression representing the invocation
 */
 inline PrimExpr call_packed(Array<PrimExpr> args) {
-  return tvm::ir::CallNode::make(DataType::Int(32), tvm::ir::intrinsic::tvm_call_packed,
-                                 args, tvm::ir::CallNode::CallType::Intrinsic);
+  return tvm::tir::CallNode::make(DataType::Int(32), tvm::tir::intrinsic::tvm_call_packed,
+                                  args, tvm::tir::CallNode::CallType::Intrinsic);
 }

 }  // namespace detail
diff --git a/topi/include/topi/detail/pad_utils.h b/topi/include/topi/detail/pad_utils.h
index 12b15413e72a..a3f82de5fc79 100644
--- a/topi/include/topi/detail/pad_utils.h
+++ b/topi/include/topi/detail/pad_utils.h
@@ -26,8 +26,8 @@

 #include <vector>

-#include "tvm/expr.h"
-#include "tvm/expr_operator.h"
+#include "tvm/tir/expr.h"
+#include "tvm/tir/op.h"

 namespace topi {
 namespace detail {
diff --git a/topi/include/topi/detail/ravel_unravel.h b/topi/include/topi/detail/ravel_unravel.h
index c8da45d918b2..bd78e22ddedd 100644
--- a/topi/include/topi/detail/ravel_unravel.h
+++ b/topi/include/topi/detail/ravel_unravel.h
@@ -27,7 +27,7 @@
 #include <vector>

 #include "tvm/top/operation.h"
-#include "tvm/expr_operator.h"
+#include "tvm/tir/op.h"

 namespace topi {
 namespace detail {
diff --git a/topi/include/topi/elemwise.h b/topi/include/topi/elemwise.h
index 46515e763226..3762e3fe99e7 100644
--- a/topi/include/topi/elemwise.h
+++ b/topi/include/topi/elemwise.h
@@ -24,11 +24,10 @@
 #ifndef TOPI_ELEMWISE_H_
 #define TOPI_ELEMWISE_H_

+#include <tvm/tir/expr.h>
+#include <tvm/tir/ir_pass.h>
+#include <topi/tags.h>
 #include <string>
-
-#include "topi/tags.h"
-#include "tvm/ir.h"
-#include "tvm/ir_pass.h"
 #include "broadcast.h"

 namespace topi {
@@ -195,8 +194,8 @@ inline Tensor sign(const Tensor& x,
     PrimExpr zero = make_zero(x->dtype);
     PrimExpr one = make_const(x->dtype, 1);
     PrimExpr minus_one = make_const(x->dtype, -1);
-    auto s1 = tvm::ir::SelectNode::make((x(i) < zero), minus_one, zero);
-    auto s2 = tvm::ir::SelectNode::make((x(i) > zero), one, s1);
+    auto s1 = tvm::tir::SelectNode::make((x(i) < zero), minus_one, zero);
+    auto s2 = tvm::tir::SelectNode::make((x(i) > zero), one, s1);
     return s2;
   }, name, tag);
 }
@@ -265,7 +264,7 @@ inline Tensor cast(const Tensor& x,
     if (expr.dtype().lanes() == type.lanes()) {
       return expr;
     } else if (expr.dtype().lanes() == 1 && type.lanes() > 1) {
-      return tvm::ir::BroadcastNode::make(expr, type.lanes());
+      return tvm::tir::BroadcastNode::make(expr, type.lanes());
     }
   }

@@ -287,8 +286,8 @@ inline Tensor reinterpret(const Tensor& x, DataType type, std::string name = "tensor",
                           std::string tag = kElementWise) {
   return compute(x->shape, [&](const Array<Var>& i) {
-    return tvm::ir::CallNode::make(type, "reinterpret", {x(i)},
-                                   tvm::ir::CallNode::PureIntrinsic);
+    return tvm::tir::CallNode::make(type, "reinterpret", {x(i)},
+                                    tvm::tir::CallNode::PureIntrinsic);
   }, name, tag);
 }
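The sign() change above is representative of how expression nodes are now spelled. A standalone sketch of the same nested-select construction (SignOf is an illustrative helper, not part of this patch):

    #include <tvm/tir/expr.h>
    #include <tvm/tir/op.h>

    inline tvm::PrimExpr SignOf(const tvm::PrimExpr& v) {
      tvm::PrimExpr zero = tvm::tir::make_zero(v.dtype());
      tvm::PrimExpr one = tvm::tir::make_const(v.dtype(), 1);
      tvm::PrimExpr minus_one = tvm::tir::make_const(v.dtype(), -1);
      // Inner select handles the negative branch; outer select the positive one.
      tvm::PrimExpr neg = tvm::tir::SelectNode::make(v < zero, minus_one, zero);
      return tvm::tir::SelectNode::make(v > zero, one, neg);
    }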
diff --git a/topi/include/topi/image/resize.h b/topi/include/topi/image/resize.h
index 7c1bad3dc9a7..cdc3a4287619 100644
--- a/topi/include/topi/image/resize.h
+++ b/topi/include/topi/image/resize.h
@@ -34,7 +34,7 @@
 #include "topi/detail/ravel_unravel.h"
 #include "topi/detail/constant_utils.h"
 #include "tvm/top/operation.h"
-#include "tvm/expr_operator.h"
+#include "tvm/tir/op.h"

 namespace topi {
 namespace image {
@@ -260,8 +260,8 @@ inline Tensor resize_bilinear_nhwc(const Tensor& input,
                                    static_cast<float>(*out_width - 1)));
   }

-  PrimExpr other_y = tvm::ir::Simplify(input->shape[1] - cone);
-  PrimExpr other_x = tvm::ir::Simplify(input->shape[2] - cone);
+  PrimExpr other_y = tvm::tir::Simplify(input->shape[1] - cone);
+  PrimExpr other_x = tvm::tir::Simplify(input->shape[2] - cone);

   return compute(
     out_shape, [&](const Array<Var>& indices) {
@@ -337,8 +337,8 @@ inline Tensor resize_bilinear_nchw(const Tensor& input,
                                    static_cast<float>(*out_width - 1)));
   }

-  PrimExpr other_y = tvm::ir::Simplify(input->shape[2] - cone);
-  PrimExpr other_x = tvm::ir::Simplify(input->shape[3] - cone);
+  PrimExpr other_y = tvm::tir::Simplify(input->shape[2] - cone);
+  PrimExpr other_x = tvm::tir::Simplify(input->shape[3] - cone);

   return compute(
     out_shape, [&](const Array<Var>& indices) {
diff --git a/topi/include/topi/nn.h b/topi/include/topi/nn.h
index b86c00c60ae3..16bcaef43234 100644
--- a/topi/include/topi/nn.h
+++ b/topi/include/topi/nn.h
@@ -29,10 +29,10 @@

 #include "topi/tags.h"
 #include "topi/detail/constant_utils.h"
-#include "tvm/ir.h"
-#include "tvm/ir_pass.h"
+#include "tvm/tir/expr.h"
+#include "tvm/tir/ir_pass.h"
 #include "tvm/top/operation.h"
-#include "tvm/expr_operator.h"
+#include "tvm/tir/op.h"

 namespace topi {
 using namespace tvm;
@@ -68,8 +68,8 @@ inline tvm::top::Tensor relu(const tvm::top::Tensor& t,
                         std::string tag = kElementWise) {
   return tvm::top::compute(
       t->shape,
-      [&](const tvm::Array<tvm::Var>& i) {
-        auto threshold_const = tvm::make_const(t->dtype, threshold);
+      [&](const tvm::Array<tvm::tir::Var>& i) {
+        auto threshold_const = tvm::tir::make_const(t->dtype, threshold);
         return tvm::max(t(i), threshold_const);
       },
       name,
@@ -92,10 +92,10 @@ inline tvm::top::Tensor leaky_relu(const tvm::top::Tensor& t,
                               std::string tag = kElementWise) {
   return tvm::top::compute(
       t->shape,
-      [&](const tvm::Array<tvm::Var>& i) {
+      [&](const tvm::Array<tvm::tir::Var>& i) {
         auto value = t(i);
-        auto calpha = tvm::make_const(value.dtype(), alpha);
-        return tvm::ir::SelectNode::make(value > 0, value, value * calpha);
+        auto calpha = tvm::tir::make_const(value.dtype(), alpha);
+        return tvm::tir::SelectNode::make(value > 0, value, value * calpha);
       },
       name,
       tag);
@@ -124,9 +124,9 @@ inline tvm::top::Tensor prelu(const tvm::top::Tensor &x,
        << "Wrong slope shape received.";

   return tvm::top::compute(x->shape,
-                     [&](const tvm::Array<tvm::Var> &indices) {
+                     [&](const tvm::Array<tvm::tir::Var> &indices) {
                        auto xval = x(indices);
-                       return tvm::ir::SelectNode::make(
+                       return tvm::tir::SelectNode::make(
                            xval > 0,
                            xval,
                            xval * slope(indices[axis]));
@@ -200,14 +200,14 @@ inline tvm::top::Tensor pad(const tvm::top::Tensor& t,
       output_shape.push_back(t->shape[i]);
     } else {
       output_shape.push_back(
-          tvm::ir::Simplify(t->shape[i] + pad_before_int32[i] + pad_after_int32[i]));
+          tvm::tir::Simplify(t->shape[i] + pad_before_int32[i] + pad_after_int32[i]));
     }
   }

   if (!pad_value.defined()) {
-    pad_value = tvm::make_const(t->dtype, 0);
+    pad_value = tvm::tir::make_const(t->dtype, 0);
   }

-  auto l = [&](tvm::Array<tvm::Var> ovars) {
+  auto l = [&](tvm::Array<tvm::tir::Var> ovars) {
     tvm::Array<tvm::PrimExpr> indices;
     tvm::Array<tvm::PrimExpr> sel;
     tvm::Array<tvm::PrimExpr> pad_idx;
@@ -223,7 +223,7 @@ inline tvm::top::Tensor pad(const tvm::top::Tensor& t,
         indices.push_back(ovars[i]);
       }
       if (!topi::detail::EqualCheck(pad_after_int32[i], 0)) {
-        sel.push_back(tvm::ir::Simplify(ovars[i] < pad_before_int32[i] + t->shape[i]));
+        sel.push_back(tvm::tir::Simplify(ovars[i] < pad_before_int32[i] + t->shape[i]));
       }
       if (pad_mode == "edge") {
         pad_idx.push_back(tvm::if_then_else(
@@ -244,10 +244,10 @@ inline tvm::top::Tensor pad(const tvm::top::Tensor& t,
     if (sel.size() != 0) {
       if (pad_mode == "constant") {
         return tvm::if_then_else(
-            detail::Map(sel, tvm::ir::AndNode::make), t(indices), pad_value);
+            detail::Map(sel, tvm::tir::AndNode::make), t(indices), pad_value);
       } else if (pad_mode == "edge" || pad_mode == "reflect") {
         return tvm::if_then_else(
-            detail::Map(sel, tvm::ir::AndNode::make), t(indices), t(pad_idx));
+            detail::Map(sel, tvm::tir::AndNode::make), t(indices), t(pad_idx));
       }
     }
     return t(indices);
@@ -293,13 +293,13 @@ inline tvm::top::Tensor conv2d_nchw(const tvm::top::Tensor& I,
       indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1,  // H
       indexdiv(I->shape[3] - W->shape[3] + 2 * pad_w, stride_w) + 1   // W
   };
-  auto i = tvm::reduce_axis(tvm::Range{0, I->shape[1]}, "i");
-  auto kh = tvm::reduce_axis(tvm::Range{0, W->shape[2]}, "kh");
-  auto kw = tvm::reduce_axis(tvm::Range{0, W->shape[3]}, "kw");
+  auto i = tvm::top::reduce_axis(tvm::Range{0, I->shape[1]}, "i");
+  auto kh = tvm::top::reduce_axis(tvm::Range{0, W->shape[2]}, "kh");
+  auto kw = tvm::top::reduce_axis(tvm::Range{0, W->shape[3]}, "kw");
   auto T = (pad_h == 0 && pad_w == 0)
                ? I
                : pad(I, {tvm::PrimExpr(0), tvm::PrimExpr(0), pad_h, pad_w});
-  auto l = [&](tvm::Var b, tvm::Var o, tvm::Var h, tvm::Var w) {
+  auto l = [&](tvm::tir::Var b, tvm::tir::Var o, tvm::tir::Var h, tvm::tir::Var w) {
     return tvm::sum(
         T(b, i, stride_h * h + kh, stride_w * w + kw) * W(o, i, kh, kw),
         {i, kh, kw});
@@ -344,11 +344,11 @@ inline tvm::top::Tensor conv2d_hwcn(const tvm::top::Tensor& I,
       I->shape[2],  // B
       W->shape[3]   // O
   };
-  auto i = tvm::reduce_axis(tvm::Range{0, I->shape[3]}, "i");
-  auto kh = tvm::reduce_axis(tvm::Range{0, W->shape[0]}, "kh");
-  auto kw = tvm::reduce_axis(tvm::Range{0, W->shape[1]}, "kw");
+  auto i = tvm::top::reduce_axis(tvm::Range{0, I->shape[3]}, "i");
+  auto kh = tvm::top::reduce_axis(tvm::Range{0, W->shape[0]}, "kh");
+  auto kw = tvm::top::reduce_axis(tvm::Range{0, W->shape[1]}, "kw");
   auto T = (pad_h == 0 && pad_w == 0) ? I : pad(I, {pad_h, pad_w});
-  auto l = [&](tvm::Var b, tvm::Var o, tvm::Var h, tvm::Var w) {
+  auto l = [&](tvm::tir::Var b, tvm::tir::Var o, tvm::tir::Var h, tvm::tir::Var w) {
     return tvm::sum(
         T(stride_h * h + kh, stride_w * w + kw, i, b) * W(kh, kw, i, o),
         {i, kh, kw});
@@ -396,13 +396,13 @@ inline tvm::top::Tensor depthwise_conv2d_nchw(const tvm::top::Tensor& I,
       indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1,  // H
       indexdiv(I->shape[3] - W->shape[3] + 2 * pad_w, stride_w) + 1   // W
   };
-  auto i = tvm::reduce_axis(tvm::Range{0, I->shape[1]}, "i");
-  auto kh = tvm::reduce_axis(tvm::Range{0, W->shape[2]}, "kh");
-  auto kw = tvm::reduce_axis(tvm::Range{0, W->shape[3]}, "kw");
+  auto i = tvm::top::reduce_axis(tvm::Range{0, I->shape[1]}, "i");
+  auto kh = tvm::top::reduce_axis(tvm::Range{0, W->shape[2]}, "kh");
+  auto kw = tvm::top::reduce_axis(tvm::Range{0, W->shape[3]}, "kw");
   auto T = (pad_h == 0 && pad_w == 0)
                ? I
                : pad(I, {tvm::PrimExpr(0), tvm::PrimExpr(0), pad_h, pad_w});
-  auto l = [&](tvm::Var b, tvm::Var o, tvm::Var h, tvm::Var w) {
+  auto l = [&](tvm::tir::Var b, tvm::tir::Var o, tvm::tir::Var h, tvm::tir::Var w) {
     return tvm::sum(T(b, indexdiv(i, pCM), stride_h * h + kh, stride_w * w + kw) *
                     W(indexdiv(i, pCM), indexmod(o, pCM), kh, kw),
                     {i, kh, kw});
@@ -429,13 +429,13 @@ inline tvm::top::Tensor depthwise_conv2d_nhwc(const tvm::top::Tensor& I,
       indexdiv(I->shape[2] - W->shape[2] + 2 * pad_w, stride_w) + 1,  // W
       W->shape[3]   // O
   };
-  auto i = tvm::reduce_axis(tvm::Range{0, I->shape[3]}, "i");
-  auto kh = tvm::reduce_axis(tvm::Range{0, W->shape[0]}, "kh");
-  auto kw = tvm::reduce_axis(tvm::Range{0, W->shape[1]}, "kw");
+  auto i = tvm::top::reduce_axis(tvm::Range{0, I->shape[3]}, "i");
+  auto kh = tvm::top::reduce_axis(tvm::Range{0, W->shape[0]}, "kh");
+  auto kw = tvm::top::reduce_axis(tvm::Range{0, W->shape[1]}, "kw");
   auto T = (pad_h == 0 && pad_w == 0)
                ? I
                : pad(I, {tvm::PrimExpr(0), pad_h, pad_w, tvm::PrimExpr(0)});
-  auto l = [&](tvm::Var b, tvm::Var h, tvm::Var w, tvm::Var o) {
+  auto l = [&](tvm::tir::Var b, tvm::tir::Var h, tvm::tir::Var w, tvm::tir::Var o) {
     return tvm::sum(T(b, stride_h * h + kh, stride_w * w + kw, indexdiv(i, pCM)) *
                     W(kh, kw, indexdiv(i, pCM), indexmod(o, pCM)),
                     {kh, kw, i});
@@ -482,19 +482,19 @@ inline tvm::top::Tensor group_conv2d_ngchw(const tvm::top::Tensor& I,
       indexdiv(I->shape[3] - W->shape[3] + 2 * pad_h, stride_h) + 1,  // H
       indexdiv(I->shape[4] - W->shape[4] + 2 * pad_w, stride_w) + 1   // W
   };
-  auto i = tvm::reduce_axis(tvm::Range{0, I->shape[2]}, "i");
-  auto kh = tvm::reduce_axis(tvm::Range{0, W->shape[3]}, "kh");
-  auto kw = tvm::reduce_axis(tvm::Range{0, W->shape[4]}, "kw");
+  auto i = tvm::top::reduce_axis(tvm::Range{0, I->shape[2]}, "i");
+  auto kh = tvm::top::reduce_axis(tvm::Range{0, W->shape[3]}, "kh");
+  auto kw = tvm::top::reduce_axis(tvm::Range{0, W->shape[4]}, "kw");
   auto T = (pad_h == 0 && pad_w == 0)
                ? I
                : pad(I, {tvm::PrimExpr(0), tvm::PrimExpr(0), tvm::PrimExpr(0), pad_h, pad_w});
-  auto l = [&](tvm::Array<tvm::Var> args) {
-    tvm::Var b = args[0];
-    tvm::Var g = args[1];
-    tvm::Var o = args[2];
-    tvm::Var h = args[3];
-    tvm::Var w = args[4];
+  auto l = [&](tvm::Array<tvm::tir::Var> args) {
+    tvm::tir::Var b = args[0];
+    tvm::tir::Var g = args[1];
+    tvm::tir::Var o = args[2];
+    tvm::tir::Var h = args[3];
+    tvm::tir::Var w = args[4];
     return tvm::sum(
         I(b, g, i, stride_h * h + kh, stride_w * w + kw) * W(g, i, o, kh, kw),
         {i, kh, kw});
diff --git a/topi/include/topi/nn/batch_matmul.h b/topi/include/topi/nn/batch_matmul.h
index a3bd96df77d8..e124aff6af8d 100644
--- a/topi/include/topi/nn/batch_matmul.h
+++ b/topi/include/topi/nn/batch_matmul.h
@@ -52,7 +52,7 @@ inline tvm::top::Tensor batch_matmul(const tvm::top::Tensor& x,
   auto K = x->shape[2];
   auto N = y->shape[1];

-  auto k = tvm::reduce_axis(Range(0, K), "k");
+  auto k = tvm::top::reduce_axis(Range(0, K), "k");
   auto result = tvm::top::compute(
       { batch, M, N },
       [&](Var b, Var i, Var j) {
diff --git a/topi/include/topi/nn/bias_add.h b/topi/include/topi/nn/bias_add.h
index 2d6f47ca8b3e..eca9989743ec 100644
--- a/topi/include/topi/nn/bias_add.h
+++ b/topi/include/topi/nn/bias_add.h
@@ -30,7 +30,7 @@
 #include "topi/broadcast.h"
 #include "topi/transform.h"
 #include "tvm/top/operation.h"
-#include "tvm/expr_operator.h"
+#include "tvm/tir/op.h"

 namespace topi {
 namespace nn {
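The operator rewrites above repeat one idiom: reduction domains are created with tvm::top::reduce_axis and consumed by tvm::sum inside a compute lambda over tvm::tir::Var indices. A compact sketch in the spirit of the dense/batch_matmul changes (MatMul is an illustrative helper, not part of this patch):

    #include <tvm/top/operation.h>
    #include <tvm/tir/op.h>

    inline tvm::top::Tensor MatMul(const tvm::top::Tensor& A,
                                   const tvm::top::Tensor& B) {
      // Reduction over the shared dimension A->shape[1] == B->shape[0].
      auto k = tvm::top::reduce_axis(tvm::Range(0, A->shape[1]), "k");
      return tvm::top::compute(
          {A->shape[0], B->shape[1]},
          [&](tvm::tir::Var i, tvm::tir::Var j) {
            return tvm::sum(A(i, k) * B(k, j), {k});
          },
          "T_matmul");
    }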
diff --git a/topi/include/topi/nn/bnn.h b/topi/include/topi/nn/bnn.h
index 7c920347b68f..16e75f1b9dba 100644
--- a/topi/include/topi/nn/bnn.h
+++ b/topi/include/topi/nn/bnn.h
@@ -27,7 +27,7 @@
 #include <string>

 #include "tvm/top/operation.h"
-#include "tvm/ir_pass.h"
+#include "tvm/tir/ir_pass.h"
 #include "topi/tags.h"
 #include "topi/detail/constant_utils.h"

@@ -59,7 +59,7 @@ inline tvm::top::Tensor binarize_pack(const tvm::top::Tensor& data,
   Array<PrimExpr> oshape;
   for (size_t i = 0; i < n; ++i) {
     oshape.push_back(i == static_cast<size_t>(axis) ?
-                     tvm::ir::Simplify(indexdiv(ishape[i], 32)) :
+                     tvm::tir::Simplify(indexdiv(ishape[i], 32)) :
                      ishape[i]);
   }

@@ -110,7 +110,7 @@ inline tvm::top::Tensor binary_dense(const tvm::top::Tensor& data,
   auto in_dim = data->shape[1];
   auto out_dim = weight->shape[0];

-  auto k = tvm::reduce_axis(Range(0, in_dim), "k");
+  auto k = tvm::top::reduce_axis(Range(0, in_dim), "k");
   auto matmul = tvm::top::compute(
       { batch, out_dim },
       [&](Var i, Var j) {
diff --git a/topi/include/topi/nn/dense.h b/topi/include/topi/nn/dense.h
index 7cdc8d7b5be3..60e378a4ee03 100644
--- a/topi/include/topi/nn/dense.h
+++ b/topi/include/topi/nn/dense.h
@@ -58,7 +58,7 @@ inline tvm::top::Tensor dense(const tvm::top::Tensor& data,
   auto in_dim = data->shape[1];
   auto out_dim = weight->shape[0];

-  auto k = tvm::reduce_axis(Range(0, in_dim), "k");
+  auto k = tvm::top::reduce_axis(Range(0, in_dim), "k");
   auto matmul = tvm::top::compute(
       { batch, out_dim },
       [&](Var i, Var j) {
diff --git a/topi/include/topi/nn/dilate.h b/topi/include/topi/nn/dilate.h
index 6ffb3da25e52..afeeecb2bdb5 100644
--- a/topi/include/topi/nn/dilate.h
+++ b/topi/include/topi/nn/dilate.h
@@ -27,7 +27,7 @@
 #include <string>

 #include "tvm/top/operation.h"
-#include "tvm/ir_pass.h"
+#include "tvm/tir/ir_pass.h"
 #include "topi/tags.h"

 namespace topi {
@@ -76,7 +76,7 @@ inline Tensor dilate(const Tensor& x,
   Array<PrimExpr> out_shape;
   for (size_t i = 0; i < n; ++i) {
-    out_shape.push_back(tvm::ir::Simplify(
+    out_shape.push_back(tvm::tir::Simplify(
         (x->shape[i] - 1) * cast(DataType::Int(32), strides[i] + 1)));
   }
diff --git a/topi/include/topi/nn/flatten.h b/topi/include/topi/nn/flatten.h
index de11b6dd797e..a3e47b73530d 100644
--- a/topi/include/topi/nn/flatten.h
+++ b/topi/include/topi/nn/flatten.h
@@ -30,7 +30,7 @@
 #include "topi/tags.h"
 #include "topi/detail/constant_utils.h"
 #include "tvm/top/operation.h"
-#include "tvm/expr_operator.h"
+#include "tvm/tir/op.h"

 namespace topi {
diff --git a/topi/include/topi/nn/local_response_norm.h b/topi/include/topi/nn/local_response_norm.h
index cd3b9b2456e1..4766ee28becb 100644
--- a/topi/include/topi/nn/local_response_norm.h
+++ b/topi/include/topi/nn/local_response_norm.h
@@ -65,7 +65,7 @@ inline Tensor lrn(const Tensor& data,
   pad_before.Set(axis, static_cast<PrimExpr>(size/2));
   pad_after.Set(axis, static_cast<PrimExpr>(size/2));
   auto pad_data = pad(data, pad_before, pad_after, 0, "pad_data");
-  auto rxs = tvm::reduce_axis(Range(0, size), "rxs");
+  auto rxs = tvm::top::reduce_axis(Range(0, size), "rxs");
   Tensor sqr_sum;
   if (axis == 1) {
     sqr_sum = tvm::top::compute(input_shape,
diff --git a/topi/include/topi/nn/pooling.h b/topi/include/topi/nn/pooling.h
index ac284a0e01fe..86f797f2e782 100644
--- a/topi/include/topi/nn/pooling.h
+++ b/topi/include/topi/nn/pooling.h
@@ -32,7 +32,7 @@
 #include "topi/nn.h"
 #include "topi/reduction.h"
 #include "topi/tags.h"
-#include "tvm/ir_pass.h"
+#include "tvm/tir/ir_pass.h"

 namespace topi {
 namespace nn {
@@ -103,13 +103,13 @@ inline Tensor pool_impl(const Tensor& x,
   pad_after.Set(height_axis, pad_bottom);
   pad_after.Set(width_axis, pad_right);

-  auto out_height = tvm::ir::Simplify(
+  auto out_height = tvm::tir::Simplify(
       indexdiv(height - kernel_height + pad_top + pad_bottom, stride_height) + 1);
-  auto out_width = tvm::ir::Simplify(
+  auto out_width = tvm::tir::Simplify(
       indexdiv(width - kernel_width + pad_left + pad_right, stride_width) + 1);

-  auto dheight = tvm::reduce_axis(Range(0, kernel_height));
-  auto dwidth = tvm::reduce_axis(Range(0, kernel_width));
+  auto dheight = tvm::top::reduce_axis(Range(0, kernel_height));
+  auto dwidth = tvm::top::reduce_axis(Range(0, kernel_width));

   Array<PrimExpr> out_shape = x->shape;
   out_shape.Set(height_axis, out_height);
@@ -156,11 +156,11 @@ inline Tensor pool_impl(const Tensor& x,
     } else {
       PrimExpr h_start = output[height_axis] * stride_height - pad_top;
       PrimExpr w_start = output[width_axis] * stride_width - pad_left;
-      PrimExpr h_end = ir::MinNode::make(h_start + kernel_height, height);
-      PrimExpr w_end = ir::MinNode::make(w_start + kernel_width, width);
-      h_start = ir::MaxNode::make(h_start, make_const(DataType::DataType::Int(32), 0));
-      w_start = ir::MaxNode::make(w_start, make_const(DataType::DataType::Int(32), 0));
-      PrimExpr divide_factor = ir::MaxNode::make((h_end - h_start) * (w_end - w_start),
+      PrimExpr h_end = tir::MinNode::make(h_start + kernel_height, height);
+      PrimExpr w_end = tir::MinNode::make(w_start + kernel_width, width);
+      h_start = tir::MaxNode::make(h_start, make_const(DataType::DataType::Int(32), 0));
+      w_start = tir::MaxNode::make(w_start, make_const(DataType::DataType::Int(32), 0));
+      PrimExpr divide_factor = tir::MaxNode::make((h_end - h_start) * (w_end - w_start),
                                                   make_const(DataType::DataType::Int(32), 1));
       return div(pool_sum(indices), divide_factor);
     }
@@ -214,12 +214,12 @@ inline Tensor pool_grad_impl(const Tensor& out_grad,
   pad_after.Set(width_axis, pad_right);

   auto out_height =
-      tvm::ir::Simplify((height - kernel_height + pad_top + pad_bottom) / stride_height + 1);
+      tvm::tir::Simplify((height - kernel_height + pad_top + pad_bottom) / stride_height + 1);
   auto out_width =
-      tvm::ir::Simplify((width - kernel_width + pad_left + pad_right) / stride_width + 1);
+      tvm::tir::Simplify((width - kernel_width + pad_left + pad_right) / stride_width + 1);

-  auto dheight = tvm::reduce_axis(Range(0, kernel_height));
-  auto dwidth = tvm::reduce_axis(Range(0, kernel_width));
+  auto dheight = tvm::top::reduce_axis(Range(0, kernel_height));
+  auto dwidth = tvm::top::reduce_axis(Range(0, kernel_width));

   Array<PrimExpr> out_shape = x->shape;
   out_shape.Set(height_axis, out_height);
@@ -237,23 +237,26 @@ inline Tensor pool_grad_impl(const Tensor& out_grad,
     ravel_shape.Set(height_axis, ravel_shape[height_axis] + pad_top + pad_bottom);
     ravel_shape.Set(width_axis, ravel_shape[width_axis] + pad_left + pad_right);

-    auto windowh = tvm::reduce_axis(Range(0, (kernel_height + stride_height - 1) / stride_height));
-    auto windoww = tvm::reduce_axis(Range(0, (kernel_width + stride_width - 1) / stride_width));
+    auto windowh = tvm::top::reduce_axis(
+        Range(0, (kernel_height + stride_height - 1) / stride_height));
+    auto windoww = tvm::top::reduce_axis(
+        Range(0, (kernel_width + stride_width - 1) / stride_width));

     auto argmax = MakeArgmaxReducer();
     auto pad_x = do_pad ? pad(x, pad_before, pad_after,
                               tvm::min_value(x->dtype), "pad_temp") : x;

     auto mp_argmax =
-        tvm::top::compute(out_shape,
-                     [&](const Array<Var>& inds) {
-                       Array<PrimExpr> window_inds{inds.begin(), inds.end()};
-                       window_inds.Set(height_axis, inds[height_axis] * stride_height + dheight);
-                       window_inds.Set(width_axis, inds[width_axis] * stride_width + dwidth);
-                       auto idx = detail::RavelIndex(window_inds, ravel_shape);
-                       return argmax({idx, pad_x(window_inds)}, {dheight, dwidth}, nullptr);
-                     },
-                     "maxpool_grad_argmax", kCommReduceIdx);
+        tvm::top::compute(
+            out_shape,
+            [&](const Array<Var>& inds) {
+              Array<PrimExpr> window_inds{inds.begin(), inds.end()};
+              window_inds.Set(height_axis, inds[height_axis] * stride_height + dheight);
+              window_inds.Set(width_axis, inds[width_axis] * stride_width + dwidth);
+              auto idx = detail::RavelIndex(window_inds, ravel_shape);
+              return argmax({idx, pad_x(window_inds)}, {dheight, dwidth}, nullptr);
+            },
+            "maxpool_grad_argmax", kCommReduceIdx);

     auto mp_inds = mp_argmax[0];
@@ -269,16 +272,16 @@ inline Tensor pool_grad_impl(const Tensor& out_grad,
           out_idx.Set(height_axis, (inds[height_axis] + pad_top) / stride_height - windowh);
           out_idx.Set(width_axis, (inds[width_axis] + pad_left) / stride_width - windoww);

-          PrimExpr out_idx_lower_h = ir::SelectNode::make(
+          PrimExpr out_idx_lower_h = tir::SelectNode::make(
               pad_inds[height_axis] < kernel_height, make_const(DataType::DataType::Int(32), 0),
               (pad_inds[height_axis] - kernel_height) / stride_height + 1);
-          PrimExpr out_idx_lower_w = ir::SelectNode::make(
+          PrimExpr out_idx_lower_w = tir::SelectNode::make(
               pad_inds[width_axis] < kernel_width, make_const(DataType::DataType::Int(32), 0),
               (pad_inds[width_axis] - kernel_width) / stride_width + 1);

           return tvm::sum(
-              tvm::if_then_else(ir::AndNode::make(
-                  ir::AndNode::make(out_idx[height_axis] >= out_idx_lower_h,
+              tvm::if_then_else(tir::AndNode::make(
+                  tir::AndNode::make(out_idx[height_axis] >= out_idx_lower_h,
                                     out_idx[width_axis] >= out_idx_lower_w),
                   mp_inds(out_idx) == idx),
                   out_grad(out_idx), make_const(x->dtype, 0)),
@@ -286,8 +289,10 @@ inline Tensor pool_grad_impl(const Tensor& out_grad,
         }, "T_pool_grad", "pool_grad_max");
   } else if (pool_type == kAvgPool) {
-    auto windowh = tvm::reduce_axis(Range(0, (kernel_height + stride_height - 1) / stride_height));
-    auto windoww = tvm::reduce_axis(Range(0, (kernel_width + stride_width - 1) / stride_width));
+    auto windowh = tvm::top::reduce_axis(
+        Range(0, (kernel_height + stride_height - 1) / stride_height));
+    auto windoww = tvm::top::reduce_axis(
+        Range(0, (kernel_width + stride_width - 1) / stride_width));
     return tvm::top::compute(
         x->shape,
         [&](const Array<Var>& inds) {
@@ -299,10 +304,10 @@ inline Tensor pool_grad_impl(const Tensor& out_grad,
           out_idx.Set(height_axis, (pad_h_idx / stride_height - windowh));
           out_idx.Set(width_axis, (pad_w_idx / stride_width - windoww));

-          PrimExpr out_idx_lower_h = ir::SelectNode::make(
+          PrimExpr out_idx_lower_h = tir::SelectNode::make(
               pad_h_idx < kernel_height, make_const(DataType::Int(32), 0),
               (pad_h_idx - kernel_height) / stride_height + 1);
-          PrimExpr out_idx_lower_w = ir::SelectNode::make(
+          PrimExpr out_idx_lower_w = tir::SelectNode::make(
               pad_w_idx < kernel_width, make_const(DataType::Int(32), 0),
               (pad_w_idx - kernel_width) / stride_width + 1);

@@ -312,19 +317,19 @@ inline Tensor pool_grad_impl(const Tensor& out_grad,
           } else {
             PrimExpr h_start = out_idx[height_axis] * stride_height - pad_top;
             PrimExpr w_start = out_idx[width_axis] * stride_width - pad_left;
-            PrimExpr h_end = ir::MinNode::make(h_start + kernel_height, height);
-            PrimExpr w_end = ir::MinNode::make(w_start + kernel_width, width);
-            h_start = ir::MaxNode::make(h_start, make_const(DataType::Int(32), 0));
-            w_start = ir::MaxNode::make(w_start, make_const(DataType::Int(32), 0));
+            PrimExpr h_end = tir::MinNode::make(h_start + kernel_height, height);
+            PrimExpr w_end = tir::MinNode::make(w_start + kernel_width, width);
+            h_start = tir::MaxNode::make(h_start, make_const(DataType::Int(32), 0));
+            w_start = tir::MaxNode::make(w_start, make_const(DataType::Int(32), 0));
             divide_factor =
-                ir::MaxNode::make((h_end - h_start) * (w_end - w_start),
+                tir::MaxNode::make((h_end - h_start) * (w_end - w_start),
                                   make_const(DataType::Int(32), 1));
           }
           return tvm::sum(tvm::if_then_else(
-              ir::AndNode::make(
-                  ir::AndNode::make(out_idx[height_axis] >= out_idx_lower_h,
+              tir::AndNode::make(
+                  tir::AndNode::make(out_idx[height_axis] >= out_idx_lower_h,
                                     out_idx[height_axis] < out_height),
-                  ir::AndNode::make(out_idx[width_axis] >= out_idx_lower_w,
+                  tir::AndNode::make(out_idx[width_axis] >= out_idx_lower_w,
                                     out_idx[width_axis] < out_width)),
               out_grad(out_idx) / divide_factor, make_const(out_grad->dtype, 0)),
               {windowh, windoww});
@@ -481,7 +486,7 @@ inline PrimExpr end_index(const Var& out_index,
                           const PrimExpr& odim,
                           const PrimExpr& idim) {
   PrimExpr tmp = indexdiv((out_index + 1) * idim, odim);
-  return tvm::ir::SelectNode::make(indexmod((out_index + 1) * idim, odim) == 0,
+  return tvm::tir::SelectNode::make(indexmod((out_index + 1) * idim, odim) == 0,
                                    tmp, tmp + 1);
 }
@@ -520,8 +525,8 @@ inline Tensor adaptive_pool_impl(const Tensor& x,
       auto i_end_h = end_index(output[height_axis], out_height, height);
       auto i_start_w = start_index(output[width_axis], out_width, width);
       auto i_end_w = end_index(output[width_axis], out_width, width);
-      auto dheight = tvm::reduce_axis(Range(0, i_end_h - i_start_h), "rv1");
-      auto dwidth = tvm::reduce_axis(Range(0, i_end_w - i_start_w), "rv2");
+      auto dheight = tvm::top::reduce_axis(Range(0, i_end_h - i_start_h), "rv1");
+      auto dwidth = tvm::top::reduce_axis(Range(0, i_end_w - i_start_w), "rv2");
       indices.Set(height_axis, i_start_h + dheight);
       indices.Set(width_axis, i_start_w + dwidth);
       return tvm::max(x(indices), { dheight, dwidth });  // NOLINT(*)
@@ -536,8 +541,8 @@ inline Tensor adaptive_pool_impl(const Tensor& x,
       auto i_end_w = end_index(output[width_axis], out_width, width);
       auto divide_factor = tvm::cast(x->dtype,
                                      (i_end_h - i_start_h) * (i_end_w - i_start_w));
-      auto dheight = tvm::reduce_axis(Range(0, i_end_h - i_start_h), "rv1");
-      auto dwidth = tvm::reduce_axis(Range(0, i_end_w - i_start_w), "rv2");
+      auto dheight = tvm::top::reduce_axis(Range(0, i_end_h - i_start_h), "rv1");
+      auto dwidth = tvm::top::reduce_axis(Range(0, i_end_w - i_start_w), "rv2");
       indices.Set(height_axis, i_start_h + dheight);
       indices.Set(width_axis, i_start_w + dwidth);
       return tvm::sum(x(indices), { dheight, dwidth });
@@ -683,12 +688,12 @@ inline Tensor pool_impl_nd(const Tensor& x,
       pad_tail[i] += stride[i] - 1;
     }

-    daxis.push_back(tvm::reduce_axis(Range(0, kernel[i])));
+    daxis.push_back(tvm::top::reduce_axis(Range(0, kernel[i])));

     pad_before.Set(ii, pad_head[i]);
     pad_after.Set(ii, pad_tail[i]);

-    auto out_dim = tvm::ir::Simplify(
+    auto out_dim = tvm::tir::Simplify(
         indexdiv(x->shape[ii] - kernel[i] + pad_head[i] + pad_tail[i], stride[i]) + 1);

     out_shape.Set(ii, out_dim);
@@ -743,12 +748,12 @@ inline Tensor pool_impl_nd(const Tensor& x,
       for (int i = 0; i < k_size; i++) {
         int ii = axis[i];
         start[i] = output[ii] * stride[i] - pad_head[i];
-        end[i] = ir::MinNode::make(start[i] + kernel[i], x->shape[ii]);
-        start[i] = ir::MaxNode::make(start[i], make_const(DataType::Int(32), 0));
+        end[i] = tir::MinNode::make(start[i] + kernel[i], x->shape[ii]);
+        start[i] = tir::MaxNode::make(start[i], make_const(DataType::Int(32), 0));
         kernel_size *= (end[i] - start[i]);
       }

-      PrimExpr divide_factor = ir::MaxNode::make(kernel_size, make_const(DataType::Int(32), 1));
+      PrimExpr divide_factor = tir::MaxNode::make(kernel_size, make_const(DataType::Int(32), 1));
       return div(pool_sum(indices), divide_factor);
     }
   }, "tensor", kElementWise);
diff --git a/topi/include/topi/nn/softmax.h b/topi/include/topi/nn/softmax.h
index 72e17454b724..9cdc20de1ef7 100644
--- a/topi/include/topi/nn/softmax.h
+++ b/topi/include/topi/nn/softmax.h
@@ -30,7 +30,7 @@
 #include "topi/reduction.h"
 #include "topi/tags.h"
 #include "tvm/top/operation.h"
-#include "tvm/expr_operator.h"
+#include "tvm/tir/op.h"

 namespace topi {
 namespace nn {
@@ -58,8 +58,8 @@ inline Tensor softmax(const Tensor &x,
   }
   CHECK_LT(axis, ndim) << "axis parameter should be less than input dim";

-  auto k1 = tvm::reduce_axis(Range(0, input_shape[axis]), "k1");
-  auto k2 = tvm::reduce_axis(Range(0, input_shape[axis]), "k2");
+  auto k1 = tvm::top::reduce_axis(Range(0, input_shape[axis]), "k1");
+  auto k2 = tvm::top::reduce_axis(Range(0, input_shape[axis]), "k2");
   auto reduced_shape = MakeReduceTargetShape({axis}, x, false, false);

   tvm::Map<std::string, ObjectRef> attrs;
@@ -139,11 +139,11 @@ inline Tensor log_softmax(const Tensor& x,
   PrimExpr m = x->shape[0];
   PrimExpr n = x->shape[1];

-  auto k = tvm::reduce_axis(Range(0, n), "k");
+  auto k = tvm::top::reduce_axis(Range(0, n), "k");
   auto max_elem = tvm::top::compute(
       { m }, [&](Var i) {
         return tvm::max(x(i, k), Array<IterVar>{ k }); });
-  k = tvm::reduce_axis(Range(0, n), "k");
+  k = tvm::top::reduce_axis(Range(0, n), "k");

   auto expsum = tvm::top::compute(
       { m }, [&](Var i) {
*condition : tir::const_true(); - auto combiner = tvm::ir::CommReducerNode::make(lhs, rhs, result, id_elem); + auto combiner = tvm::tir::CommReducerNode::make(lhs, rhs, result, id_elem); Array outputs; for (size_t i = 0; i < exprs.size(); ++i) { outputs.push_back( - tvm::ir::ReduceNode::make(combiner, exprs, axis, cond, static_cast(i))); + tvm::tir::ReduceNode::make(combiner, exprs, axis, cond, static_cast(i))); } return outputs; }; @@ -474,13 +474,13 @@ inline Tensor argmin(const Tensor& data, bool atleast1d = false) { auto fcombine = [](Array lhs, Array rhs) { Array result; - result.push_back(tvm::ir::SelectNode::make(lhs[1] <= rhs[1], lhs[0], rhs[0])); // idx - result.push_back(tvm::ir::SelectNode::make(lhs[1] <= rhs[1], lhs[1], rhs[1])); // val + result.push_back(tvm::tir::SelectNode::make(lhs[1] <= rhs[1], lhs[0], rhs[0])); // idx + result.push_back(tvm::tir::SelectNode::make(lhs[1] <= rhs[1], lhs[1], rhs[1])); // val return result; }; auto fidentity = [](std::vector types) { Array result; - result.push_back(tvm::make_const(types[0], -1)); // idx + result.push_back(tvm::tir::make_const(types[0], -1)); // idx result.push_back(tvm::max_value(types[1])); // val return result; }; @@ -491,13 +491,13 @@ inline Tensor argmin(const Tensor& data, inline FCommReduce MakeArgmaxReducer() { auto fcombine = [](Array lhs, Array rhs) { Array result; - result.push_back(tvm::ir::SelectNode::make(lhs[1] >= rhs[1], lhs[0], rhs[0])); // idx - result.push_back(tvm::ir::SelectNode::make(lhs[1] >= rhs[1], lhs[1], rhs[1])); // val + result.push_back(tvm::tir::SelectNode::make(lhs[1] >= rhs[1], lhs[0], rhs[0])); // idx + result.push_back(tvm::tir::SelectNode::make(lhs[1] >= rhs[1], lhs[1], rhs[1])); // val return result; }; auto fidentity = [](std::vector types) { Array result; - result.push_back(tvm::make_const(types[0], -1)); // idx + result.push_back(tvm::tir::make_const(types[0], -1)); // idx result.push_back(tvm::min_value(types[1])); // val return result; }; diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h index 41a64ebb45ad..9a6d82aec6ca 100644 --- a/topi/include/topi/transform.h +++ b/topi/include/topi/transform.h @@ -36,8 +36,8 @@ #include "topi/detail/constant_utils.h" #include "topi/detail/tensor_utils.h" #include "tvm/top/operation.h" -#include "tvm/expr_operator.h" -#include "tvm/data_layout.h" +#include "tvm/tir/op.h" +#include "tvm/tir/data_layout.h" namespace topi { using namespace tvm; @@ -333,7 +333,7 @@ inline Tensor concatenate(const Array& inputs, for (size_t i = 1; i < axis_sizes.size(); ++i) { join_size += axis_sizes[i]; } - join_size = tvm::ir::Simplify(join_size); + join_size = tvm::tir::Simplify(join_size); Array out_shape; for (size_t i = 0; i < inputs[0]->shape.size(); ++i) { out_shape.push_back(i == static_cast(axis) ? 
join_size : inputs[0]->shape[i]); @@ -709,7 +709,7 @@ inline Tensor sequence_mask(const Tensor& data, len_index.push_back(bid); PrimExpr ret = tvm::if_then_else( tvm::cast(valid_length->dtype, tid) >= valid_length(len_index), - tvm::make_const(data->dtype, mask_value), data(out_index)); + tvm::tir::make_const(data->dtype, mask_value), data(out_index)); return ret; }, name, tag); return out; @@ -842,7 +842,7 @@ inline Tensor where(const Tensor& condition, << condition->shape.size() << " vs " << x->shape.size(); out = compute( oshape, [&](const Array& indices) { - return tvm::ir::SelectNode::make(condition(indices) != 0, x(indices), y(indices)); + return tvm::tir::SelectNode::make(condition(indices) != 0, x(indices), y(indices)); }, name, tag); } else { CHECK_EQ(topi::GetConstInt(condition->shape[0]), topi::GetConstInt(x->shape[0])) @@ -851,7 +851,7 @@ inline Tensor where(const Tensor& condition, out = compute( oshape, [&](const Array& indices) { Array condition_idx{indices[0]}; - return tvm::ir::SelectNode::make(condition(condition_idx) != 0, + return tvm::tir::SelectNode::make(condition(condition_idx) != 0, x(indices), y(indices)); }, name, tag); } @@ -1050,8 +1050,8 @@ inline tvm::top::Tensor matmul(const tvm::top::Tensor& A, std::string tag = kMatMul) { tvm::Array output_shape{A->shape[trans_a ? 1 : 0], B->shape[trans_b ? 0 : 1]}; - auto k = tvm::reduce_axis(tvm::Range{0, A->shape[trans_a ? 0 : 1]}, "k"); - auto l = [&](tvm::Var i, tvm::Var j) { + auto k = tvm::top::reduce_axis(tvm::Range{0, A->shape[trans_a ? 0 : 1]}, "k"); + auto l = [&](tvm::tir::Var i, tvm::tir::Var j) { return tvm::sum((trans_a ? A[k][i] : A[i][k]) * (trans_b ? B[j][k] : B[k][j]), {k}); }; @@ -1318,7 +1318,7 @@ inline Tensor one_hot(const Tensor& indices, } auto idx = iter_vars[true_axis]; - return ir::SelectNode::make(indices(indices_indices) == idx, on_value_cast, off_value_cast); + return tir::SelectNode::make(indices(indices_indices) == idx, on_value_cast, off_value_cast); }, name, tag); } diff --git a/topi/include/topi/vision/reorg.h b/topi/include/topi/vision/reorg.h index c5ddea9fdeb0..4722c1f42dd2 100644 --- a/topi/include/topi/vision/reorg.h +++ b/topi/include/topi/vision/reorg.h @@ -32,7 +32,7 @@ #include "topi/tags.h" #include "topi/transform.h" #include "tvm/top/operation.h" -#include "tvm/expr_operator.h" +#include "tvm/tir/op.h" namespace topi { namespace vision {
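
The TOPI hunks above are mechanical renames: IR node constructors and utilities move from tvm::ir:: to tvm::tir:: (matching the new include/tvm/tir/ headers), and tensor-level helpers such as reduce_axis move under tvm::top::. A minimal before/after sketch of a call site, assuming only the post-refactor headers introduced by this PR; the RowSum kernel and its names are a hypothetical example, not code from the patch:

    // Hypothetical example (not from this patch): a row-sum kernel written
    // against the post-refactor layout. Only the include paths and the
    // tvm::top:: / tvm::tir:: qualifiers differ from the pre-refactor spelling.
    #include "tvm/top/operation.h"  // top::compute, top::reduce_axis
    #include "tvm/tir/op.h"         // tvm::sum; was "tvm/expr_operator.h"

    inline tvm::top::Tensor RowSum(const tvm::top::Tensor& x) {
      // was: tvm::reduce_axis(...)
      auto k = tvm::top::reduce_axis(tvm::Range(0, x->shape[1]), "k");
      return tvm::top::compute(
          {x->shape[0]},
          // was: [&](tvm::Var i); Var now lives in tvm::tir
          [&](tvm::tir::Var i) { return tvm::sum(x(i, k), {k}); },
          "row_sum");
    }

The split keeps loop-level IR (tir) separable from the tensor-operation layer (top); TOPI touches both layers, so its call sites change spelling while behavior stays the same.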