From 483a8c3234cc104e59f3c178ebc77c4814ee153d Mon Sep 17 00:00:00 2001 From: multiverstack <39256082+multiverstack-intellif@users.noreply.github.com> Date: Wed, 23 Aug 2023 15:13:28 +0800 Subject: [PATCH] [Arith] Add tvm::arith::PresburgerSetNode to work with Presburger Set in MLIR (#14690) [Arith] Add IntegerSetNode to represent Presburger Set Co-authored-by: MinChen --- cmake/config.cmake | 4 + cmake/modules/LLVM.cmake | 3 + cmake/utils/FindLLVM.cmake | 10 ++ python/tvm/arith/__init__.py | 1 + python/tvm/arith/int_set.py | 8 + src/arith/presburger_set.cc | 243 ++++++++++++++++++++++++++++ src/arith/presburger_set.h | 194 ++++++++++++++++++++++ tests/cpp/arith_integer_set_test.cc | 41 +++++ 8 files changed, 504 insertions(+) create mode 100644 src/arith/presburger_set.cc create mode 100644 src/arith/presburger_set.h create mode 100644 tests/cpp/arith_integer_set_test.cc diff --git a/cmake/config.cmake b/cmake/config.cmake index 8a7a0f1fdd29..35d1c23fc800 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -144,6 +144,10 @@ set(USE_MICRO_STANDALONE_RUNTIME OFF) # - /path/to/llvm-config: enable specific LLVM when multiple llvm-dev is available. set(USE_LLVM OFF) +# Whether use MLIR to help analyze, requires USE_LLVM is enabled +# Possible values: ON/OFF +set(USE_MLIR OFF) + #--------------------------------------------- # Contrib libraries #--------------------------------------------- diff --git a/cmake/modules/LLVM.cmake b/cmake/modules/LLVM.cmake index de5bcb669671..6c21356ae880 100644 --- a/cmake/modules/LLVM.cmake +++ b/cmake/modules/LLVM.cmake @@ -35,6 +35,9 @@ if(NOT ${USE_LLVM} MATCHES ${IS_FALSE_PATTERN}) message(STATUS "Set TVM_LLVM_VERSION=" ${TVM_LLVM_VERSION}) # Set flags that are only needed for LLVM target add_definitions(-DTVM_LLVM_VERSION=${TVM_LLVM_VERSION}) + if (${TVM_MLIR_VERSION}) + add_definitions(-DTVM_MLIR_VERSION=${TVM_MLIR_VERSION}) + endif() tvm_file_glob(GLOB COMPILER_LLVM_SRCS src/target/llvm/*.cc) list(APPEND TVM_LINKER_LIBS ${LLVM_LIBS}) list(APPEND COMPILER_SRCS ${COMPILER_LLVM_SRCS}) diff --git a/cmake/utils/FindLLVM.cmake b/cmake/utils/FindLLVM.cmake index 2c70a3ae4adf..f10e5f1eb8da 100644 --- a/cmake/utils/FindLLVM.cmake +++ b/cmake/utils/FindLLVM.cmake @@ -143,6 +143,16 @@ macro(find_llvm use_llvm) string(REPLACE "$" ${__llvm_prefix} __lib_with_prefix "${__flag}") list(APPEND LLVM_LIBS "${__lib_with_prefix}") endforeach() + if (${USE_MLIR}) + if (EXISTS "${__llvm_libdir}/libMLIRPresburger.a") + if (EXISTS "${__llvm_libdir}/libMLIRSupport.a") + message(STATUS "Found MLIR") + list(APPEND LLVM_LIBS "${__llvm_libdir}/libMLIRPresburger.a") + list(APPEND LLVM_LIBS "${__llvm_libdir}/libMLIRSupport.a") + set(TVM_MLIR_VERSION ${TVM_LLVM_VERSION}) + endif() + endif() + endif() separate_arguments(__llvm_system_libs) foreach(__flag IN ITEMS ${__llvm_system_libs}) if("${__flag}" STREQUAL "-lm") diff --git a/python/tvm/arith/__init__.py b/python/tvm/arith/__init__.py index 30fd86b0375d..87801fd781b1 100644 --- a/python/tvm/arith/__init__.py +++ b/python/tvm/arith/__init__.py @@ -19,6 +19,7 @@ from .int_set import ( IntSet, IntervalSet, + PresburgerSet, estimate_region_lower_bound, estimate_region_strict_bound, estimate_region_upper_bound, diff --git a/python/tvm/arith/int_set.py b/python/tvm/arith/int_set.py index 151461bcaf9f..d38f5e805f39 100644 --- a/python/tvm/arith/int_set.py +++ b/python/tvm/arith/int_set.py @@ -81,6 +81,14 @@ def __init__(self, min_value, max_value): self.__init_handle_by_constructor__(_ffi_api.IntervalSet, min_value, max_value) +@tvm._ffi.register_object("arith.PresburgerSet") +class PresburgerSet(IntSet): + """Represent of Presburger Set""" + + def __init__(self): + self.__init_handle_by_constructor__(_ffi_api.PresburgerSet) + + def estimate_region_lower_bound(region, var_dom, predicate): """Analyze the region with affine map, given the domain of variables and their predicate Some subregion may be discarded during the lower-bound analysis. diff --git a/src/arith/presburger_set.cc b/src/arith/presburger_set.cc new file mode 100644 index 000000000000..f1d86c861a59 --- /dev/null +++ b/src/arith/presburger_set.cc @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file presburger_set.cc + * \brief The presburger set functions + */ +#include "presburger_set.h" + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "constraint_extract.h" +#include "interval_set.h" + +namespace tvm { +namespace arith { + +#ifdef TVM_MLIR_VERSION +#if TVM_MLIR_VERSION >= 150 +using namespace tir; + +static void Update(const PrimExpr& constraint, PresburgerSetNode* intset) { + auto& space = intset->space; + auto constraints_union = ExtractComponents(constraint); + for (const PrimExpr& subconstraint : constraints_union) { + auto entries = ExtractConstraints(subconstraint, false); + auto vars = intset->GetVars(); + IntegerRelation disjunct(entries.size(), 0, vars.size() + 1, space); + for (const PrimExpr& entry : entries) { + // The expression is expect to be simplified to only contain ==, <= or < + if (entry.as()) { + auto coeffs_a = DetectLinearEquation(entry.as()->a, vars); + auto coeffs_b = DetectLinearEquation(entry.as()->b, vars); + std::vector int_coeffs; + for (size_t i = 0; i < coeffs_a.size(); i++) { + int_coeffs.push_back(*as_const_int(coeffs_b[i]) - *as_const_int(coeffs_a[i])); + } + disjunct.addInequality(int_coeffs); + } else if (entry.as()) { + auto coeffs_a = DetectLinearEquation(entry.as()->a, vars); + auto coeffs_b = DetectLinearEquation(entry.as()->b, vars); + std::vector int_coeffs; + for (size_t i = 0; i < coeffs_a.size(); i++) { + int_coeffs.push_back(*as_const_int(coeffs_b[i]) - *as_const_int(coeffs_a[i])); + } + int_coeffs[int_coeffs.size() - 1] -= 1; + disjunct.addInequality(int_coeffs); + } else if (entry.as()) { + auto coeffs_a = DetectLinearEquation(entry.as()->a, vars); + auto coeffs_b = DetectLinearEquation(entry.as()->b, vars); + std::vector int_coeffs; + for (size_t i = 0; i < coeffs_a.size(); i++) { + int_coeffs.push_back(*as_const_int(coeffs_a[i]) - *as_const_int(coeffs_b[i])); + } + disjunct.addEquality(int_coeffs); + } else { + LOG(FATAL) << "Unsupported constraint expression: " << entry->GetTypeKey(); + } + } + intset->unionInPlace(disjunct); + } +} + +PresburgerSet::PresburgerSet(const PrimExpr& constraint) { + Array vars; + PostOrderVisit(constraint, [&vars](const ObjectRef& obj) { + if (const VarNode* new_var = obj.as()) { + auto var = GetRef(new_var); + if (!std::any_of(vars.begin(), vars.end(), [&var](const Var& v) { return v.same_as(var); })) { + vars.push_back(var); + } + } + }); + auto constraints_union = ExtractComponents(constraint); + Analyzer analyzer; + PrimExpr simplified_constraint = analyzer.Simplify(constraint, kSimplifyRewriteCanonicalRewrite); + auto space = PresburgerSpace::getRelationSpace(vars.size(), 0, 0, 0); + auto node = make_object(std::move(space), vars); + node->SetVars(vars); + Update(simplified_constraint, node.get()); + data_ = std::move(node); +} + +PresburgerSet::PresburgerSet(const std::vector& disjuncts, + const Array& vars) { + auto node = make_object(disjuncts, disjuncts[0].getSpace(), vars); + data_ = std::move(node); +} + +void PresburgerSetNode::UpdateConstraint(const PrimExpr& constraint, const Array& vars) { + Analyzer analyzer; + PrimExpr simplified_constraint = analyzer.Simplify(constraint, kSimplifyRewriteCanonicalRewrite); + Update(simplified_constraint, this); + SetVars(vars); +} + +PrimExpr PresburgerSetNode::GenerateConstraint() const { + PrimExpr constraint = Bool(0); + for (const IntegerRelation& disjunct : disjuncts) { + PrimExpr union_entry = Bool(1); + for (unsigned i = 0, e = disjunct.getNumEqualities(); i < e; ++i) { + PrimExpr linear_eq = IntImm(DataType::Int(32), 0); + if (disjunct.getNumCols() > 1) { + for (unsigned j = 0, f = disjunct.getNumCols() - 1; j < f; ++j) { + auto coeff = disjunct.atEq(i, j); + if (coeff >= 0 || is_zero(linear_eq)) { + linear_eq = linear_eq + IntImm(DataType::Int(32), coeff) * vars[j]; + } else { + linear_eq = linear_eq - IntImm(DataType::Int(32), -coeff) * vars[j]; + } + } + } + auto c0 = disjunct.atEq(i, disjunct.getNumCols() - 1); + linear_eq = linear_eq + IntImm(DataType::Int(32), c0); + union_entry = (union_entry && (linear_eq == 0)); + } + for (unsigned i = 0, e = disjunct.getNumInequalities(); i < e; ++i) { + PrimExpr linear_eq = IntImm(DataType::Int(32), 0); + if (disjunct.getNumCols() > 1) { + for (unsigned j = 0, f = disjunct.getNumCols() - 1; j < f; ++j) { + auto coeff = disjunct.atIneq(i, j); + if (coeff >= 0 || is_zero(linear_eq)) { + linear_eq = linear_eq + IntImm(DataType::Int(32), coeff) * vars[j]; + } else { + linear_eq = linear_eq - IntImm(DataType::Int(32), -coeff) * vars[j]; + } + } + } + auto c0 = disjunct.atIneq(i, disjunct.getNumCols() - 1); + if (c0 >= 0) { + linear_eq = linear_eq + IntImm(DataType::Int(32), c0); + } else { + linear_eq = linear_eq - IntImm(DataType::Int(32), -c0); + } + union_entry = (union_entry && (linear_eq >= 0)); + } + constraint = constraint || union_entry; + } + + return constraint; +} + +PresburgerSet Union(const Array& sets) { + CHECK_GT(sets.size(), 0); + if (sets.size() == 1) return sets[0]; + auto relations = sets[0]->disjuncts; + for (size_t i = 1; i < sets.size(); ++i) { + for (const IntegerRelation& rel : sets[i]->disjuncts) { + relations.push_back(rel); + } + } + return PresburgerSet(std::move(relations), sets[0]->GetVars()); +} + +PresburgerSet Intersect(const Array& sets) { + CHECK_GT(sets.size(), 0); + if (sets.size() == 1) return sets[0]; + auto relations = sets[0]->disjuncts; + const auto& space = sets[0]->space; + + for (size_t i = 1; i < sets.size(); ++i) { + ICHECK(space.isCompatible(sets[i]->space)) << "Spaces should match"; + for (const IntegerRelation& relA : sets[i]->disjuncts) { + for (const IntegerRelation& relB : relations) { + IntegerRelation intersection = relA.intersect(relB); + if (!intersection.isEmpty()) relations.push_back(intersection); + } + } + } + return PresburgerSet(std::move(relations), sets[0]->GetVars()); +} + +IntSet EvalSet(const PrimExpr& e, const PresburgerSet& set) { + Array tvm_coeffs = DetectLinearEquation(e, set->GetVars()); + SmallVector coeffs; + coeffs.reserve(tvm_coeffs.size()); + for (const PrimExpr& it : tvm_coeffs) { + coeffs.push_back(*as_const_int(it)); + } + + IntSet result = IntSet().Nothing(); + for (const IntegerRelation& it : set->disjuncts) { + Simplex simplex(it); + auto range = simplex.computeIntegerBounds(coeffs); + auto maxRoundedDown(simplex.computeOptimum(Simplex::Direction::Up, coeffs)); + auto opt = range.first.getOptimumIfBounded(); + auto min = opt.hasValue() ? IntImm(DataType::Int(64), opt.getValue()) : neg_inf(); + opt = range.second.getOptimumIfBounded(); + auto max = opt.hasValue() ? IntImm(DataType::Int(64), opt.getValue()) : pos_inf(); + auto interval = IntervalSet(min, max); + result = Union({result, interval}); + } + return result; +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto set = node.as(); + ICHECK(ret) << "Unknown type:" << node->GetTypeKey(); + p->stream << "{"; + p->stream << set->GetVars() << ": "; + p->stream << node.as()->GenerateConstraint(); + p->stream << "}"; + }); + +#endif // TVM_MLIR_VERSION >= 150 +#endif // TVM_MLIR_VERSION + +PresburgerSet MakePresburgerSet(const PrimExpr& constraint) { return PresburgerSet(constraint); } + +TVM_REGISTER_GLOBAL("arith.PresburgerSet").set_body_typed(MakePresburgerSet); + +TVM_REGISTER_NODE_TYPE(PresburgerSetNode); + +} // namespace arith +} // namespace tvm diff --git a/src/arith/presburger_set.h b/src/arith/presburger_set.h new file mode 100644 index 000000000000..d580e23a6d5a --- /dev/null +++ b/src/arith/presburger_set.h @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file presburger_set.h + * \brief Integer set based on MLIR Presburger set + */ +#ifndef TVM_ARITH_PRESBURGER_SET_H_ +#define TVM_ARITH_PRESBURGER_SET_H_ + +#ifdef TVM_MLIR_VERSION +#if TVM_MLIR_VERSION >= 150 +#include +#include +#include +#endif +#endif + +#include +#include + +#include +#include + +#include "const_fold.h" + +namespace tvm { +namespace arith { + +#ifdef TVM_MLIR_VERSION +#if TVM_MLIR_VERSION >= 150 +using namespace mlir; +using namespace presburger; + +// Acknowledgement: PresburgerSet is based on Presburger set of MLIR. +/*! + * \brief Symbolic integer set. + * + * \note PresburgerSet aims to provide compatible APIs with IntSet, + * and some additional APIs that analyze and solve + * multi-dimension interger set problems + */ +class PresburgerSetNode : public IntSetNode { + public: + PresburgerSetNode() : space(PresburgerSpace::getRelationSpace()) {} + explicit PresburgerSetNode(const PresburgerSpace& space, const Array& vars) + : disjuncts({}), space(space), vars(vars) {} + explicit PresburgerSetNode(const std::vector& disjuncts, + const PresburgerSpace& space, const Array& vars) + : disjuncts(disjuncts), space(space), vars(vars) {} + + /*! \brief Represent the union of multiple IntegerRelation */ + std::vector disjuncts; + /*! \brief The space of all the disjuncts */ + PresburgerSpace space; + + // visitor overload. + void VisitAttrs(tvm::AttrVisitor* v) {} + + /*! + * \brief Do inplace union with given disjunct + * \param disjunct The given disjunct to be union with + */ + void unionInPlace(const IntegerRelation& disjunct) { + assert(space.isCompatible(disjunct.getSpace()) && "Spaces should match"); + disjuncts.push_back(disjunct); + } + + /*! + * \brief Update int set with given constraint + * \param constraint The added constraint to the PresburgerSet. + * \param vars The specified domain vars in constraint expression. + */ + void UpdateConstraint(const PrimExpr& constraint, const Array& vars); + + /*! + * \brief Generate expression that represents the constraint + * \return The generated expression + */ + PrimExpr GenerateConstraint() const; + + /*! + * \brief Set domain vars + * \param new_vars Vars that will be taken as the domain vars + */ + void SetVars(const Array& new_vars) { vars = new_vars; } + + /*! + * \brief Get the current domain vars + * \return The current doamin vars + */ + Array GetVars() const { return vars; } + + /*! \return whether integer set is empty */ + bool IsEmpty() const { + return std::all_of(disjuncts.begin(), disjuncts.end(), + std::mem_fn(&IntegerRelation::isIntegerEmpty)); + } + + static constexpr const char* _type_key = "arith.PresburgerSet"; + TVM_DECLARE_FINAL_OBJECT_INFO(PresburgerSetNode, IntSetNode); + + private: + Array vars; +}; + +/*! + * \brief Integer set used for multi-dimension integer analysis. + * \sa PresburgerSetNode + */ +class PresburgerSet : public IntSet { + public: + /*! + * \brief Make a new instance of PresburgerSet. + * \param disjuncts The disjunts to construct the set. + * \param vars The variables that the constraint describes about. + * \return The created PresburgerSet. + */ + TVM_DLL PresburgerSet(const std::vector& disjuncts, const Array& vars); + + /*! + * \brief Make a new instance of PresburgerSet, collect all vars as space vars. + * \param constraint The constraint to construct the set. + * \return The created PresburgerSet. + */ + TVM_DLL PresburgerSet(const PrimExpr& constraint); + + TVM_DEFINE_OBJECT_REF_COW_METHOD(PresburgerSetNode); + TVM_DEFINE_OBJECT_REF_METHODS(PresburgerSet, IntSet, PresburgerSetNode); +}; +#endif // TVM_MLIR_VERSION >= 150 +#else // TVM_MLIR_VERSION +// Class definition when MLIR is not enabled, to prevent compile-time error. +class PresburgerSetNode : public IntSetNode { + public: + // dummy visitor overload. + void VisitAttrs(tvm::AttrVisitor* v) { LOG(FATAL) << "MLIR is not enabled!"; } + + static constexpr const char* _type_key = "arith.PresburgerSet"; + TVM_DECLARE_FINAL_OBJECT_INFO(PresburgerSetNode, IntSetNode); +}; + +class PresburgerSet : public IntSet { + public: + /*! + * \brief Constructor interface to prompt when MLIR is not enabled. + * \param constraint The constraint to construct the set. + * \return The created set. + */ + TVM_DLL PresburgerSet(const PrimExpr& constraint) { LOG(FATAL) << "MLIR is not enabled!"; } +}; +#endif // TVM_MLIR_VERSION +/*! + * \brief Create a union set of all sets + * \param sets The sets to be combined + * \return the set after union + */ +PresburgerSet Union(const Array& sets); + +/*! + * \brief Create an intersected set of all sets + * \param sets The sets to be intersected + * \return The intersect set + */ +PresburgerSet Intersect(const Array& sets); + +/*! + * \brief Evaluate the range of given expression based on the constraint + * in given PresburgerSet + * \param e The target expresision to be evaluated. + * \param set The PresburgerSet defining the constraint. + */ +IntSet EvalSet(const PrimExpr& e, const PresburgerSet& set); + +} // namespace arith +} // namespace tvm + +#endif // TVM_ARITH_PRESBURGER_SET_H_ diff --git a/tests/cpp/arith_integer_set_test.cc b/tests/cpp/arith_integer_set_test.cc new file mode 100644 index 000000000000..04546abba9a6 --- /dev/null +++ b/tests/cpp/arith_integer_set_test.cc @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#if TVM_MLIR_VERSION >= 150 +#include +#include +#include +#include + +#include "../src/arith/presburger_set.h" + +TEST(PresburgerSet, eval) { + auto x = tvm::tir::Var("x"); + auto y = tvm::tir::Var("y"); + auto sub_constraint0 = (x + y < 20) && (x - y <= 0); + auto sub_constraint1 = x >= 0 && x < 20 && y >= 0 && y < 20; + auto constraint = sub_constraint0 && sub_constraint1; + auto set = tvm::arith::PresburgerSet(constraint); + + auto target = x + 2 * y; + auto result = EvalSet(target, set); + ASSERT_TRUE(tvm::tir::is_zero(result.min())); + ASSERT_TRUE(tvm::tir::is_const_int(result.max(), 38)); +} +#endif // TVM_MLIR_VERSION