Skip to content

Commit

Permalink
[Arith] Add tvm::arith::PresburgerSetNode to work with Presburger Set…
Browse files Browse the repository at this point in the history
… in MLIR (#14690)

[Arith] Add IntegerSetNode to represent Presburger Set

Co-authored-by: MinChen <[email protected]>
  • Loading branch information
multiverstack-intellif and MinChen authored Aug 23, 2023
1 parent 32658f8 commit 483a8c3
Show file tree
Hide file tree
Showing 8 changed files with 504 additions and 0 deletions.
4 changes: 4 additions & 0 deletions cmake/config.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,10 @@ set(USE_MICRO_STANDALONE_RUNTIME OFF)
# - /path/to/llvm-config: enable specific LLVM when multiple llvm-dev is available.
set(USE_LLVM OFF)

# Whether use MLIR to help analyze, requires USE_LLVM is enabled
# Possible values: ON/OFF
set(USE_MLIR OFF)

#---------------------------------------------
# Contrib libraries
#---------------------------------------------
Expand Down
3 changes: 3 additions & 0 deletions cmake/modules/LLVM.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ if(NOT ${USE_LLVM} MATCHES ${IS_FALSE_PATTERN})
message(STATUS "Set TVM_LLVM_VERSION=" ${TVM_LLVM_VERSION})
# Set flags that are only needed for LLVM target
add_definitions(-DTVM_LLVM_VERSION=${TVM_LLVM_VERSION})
if (${TVM_MLIR_VERSION})
add_definitions(-DTVM_MLIR_VERSION=${TVM_MLIR_VERSION})
endif()
tvm_file_glob(GLOB COMPILER_LLVM_SRCS src/target/llvm/*.cc)
list(APPEND TVM_LINKER_LIBS ${LLVM_LIBS})
list(APPEND COMPILER_SRCS ${COMPILER_LLVM_SRCS})
Expand Down
10 changes: 10 additions & 0 deletions cmake/utils/FindLLVM.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,16 @@ macro(find_llvm use_llvm)
string(REPLACE "$" ${__llvm_prefix} __lib_with_prefix "${__flag}")
list(APPEND LLVM_LIBS "${__lib_with_prefix}")
endforeach()
if (${USE_MLIR})
if (EXISTS "${__llvm_libdir}/libMLIRPresburger.a")
if (EXISTS "${__llvm_libdir}/libMLIRSupport.a")
message(STATUS "Found MLIR")
list(APPEND LLVM_LIBS "${__llvm_libdir}/libMLIRPresburger.a")
list(APPEND LLVM_LIBS "${__llvm_libdir}/libMLIRSupport.a")
set(TVM_MLIR_VERSION ${TVM_LLVM_VERSION})
endif()
endif()
endif()
separate_arguments(__llvm_system_libs)
foreach(__flag IN ITEMS ${__llvm_system_libs})
if("${__flag}" STREQUAL "-lm")
Expand Down
1 change: 1 addition & 0 deletions python/tvm/arith/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .int_set import (
IntSet,
IntervalSet,
PresburgerSet,
estimate_region_lower_bound,
estimate_region_strict_bound,
estimate_region_upper_bound,
Expand Down
8 changes: 8 additions & 0 deletions python/tvm/arith/int_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,14 @@ def __init__(self, min_value, max_value):
self.__init_handle_by_constructor__(_ffi_api.IntervalSet, min_value, max_value)


@tvm._ffi.register_object("arith.PresburgerSet")
class PresburgerSet(IntSet):
"""Represent of Presburger Set"""

def __init__(self):
self.__init_handle_by_constructor__(_ffi_api.PresburgerSet)


def estimate_region_lower_bound(region, var_dom, predicate):
"""Analyze the region with affine map, given the domain of variables and their predicate
Some subregion may be discarded during the lower-bound analysis.
Expand Down
243 changes: 243 additions & 0 deletions src/arith/presburger_set.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,243 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file presburger_set.cc
* \brief The presburger set functions
*/
#include "presburger_set.h"

#include <tvm/arith/int_set.h>
#include <tvm/arith/int_solver.h>
#include <tvm/arith/pattern.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/expr_functor.h>
#include <tvm/tir/stmt_functor.h>

#include <algorithm>
#include <unordered_map>
#include <utility>
#include <vector>

#include "constraint_extract.h"
#include "interval_set.h"

namespace tvm {
namespace arith {

#ifdef TVM_MLIR_VERSION
#if TVM_MLIR_VERSION >= 150
using namespace tir;

static void Update(const PrimExpr& constraint, PresburgerSetNode* intset) {
auto& space = intset->space;
auto constraints_union = ExtractComponents(constraint);
for (const PrimExpr& subconstraint : constraints_union) {
auto entries = ExtractConstraints(subconstraint, false);
auto vars = intset->GetVars();
IntegerRelation disjunct(entries.size(), 0, vars.size() + 1, space);
for (const PrimExpr& entry : entries) {
// The expression is expect to be simplified to only contain ==, <= or <
if (entry.as<LENode>()) {
auto coeffs_a = DetectLinearEquation(entry.as<LENode>()->a, vars);
auto coeffs_b = DetectLinearEquation(entry.as<LENode>()->b, vars);
std::vector<int64_t> int_coeffs;
for (size_t i = 0; i < coeffs_a.size(); i++) {
int_coeffs.push_back(*as_const_int(coeffs_b[i]) - *as_const_int(coeffs_a[i]));
}
disjunct.addInequality(int_coeffs);
} else if (entry.as<LTNode>()) {
auto coeffs_a = DetectLinearEquation(entry.as<LTNode>()->a, vars);
auto coeffs_b = DetectLinearEquation(entry.as<LTNode>()->b, vars);
std::vector<int64_t> int_coeffs;
for (size_t i = 0; i < coeffs_a.size(); i++) {
int_coeffs.push_back(*as_const_int(coeffs_b[i]) - *as_const_int(coeffs_a[i]));
}
int_coeffs[int_coeffs.size() - 1] -= 1;
disjunct.addInequality(int_coeffs);
} else if (entry.as<EQNode>()) {
auto coeffs_a = DetectLinearEquation(entry.as<EQNode>()->a, vars);
auto coeffs_b = DetectLinearEquation(entry.as<EQNode>()->b, vars);
std::vector<int64_t> int_coeffs;
for (size_t i = 0; i < coeffs_a.size(); i++) {
int_coeffs.push_back(*as_const_int(coeffs_a[i]) - *as_const_int(coeffs_b[i]));
}
disjunct.addEquality(int_coeffs);
} else {
LOG(FATAL) << "Unsupported constraint expression: " << entry->GetTypeKey();
}
}
intset->unionInPlace(disjunct);
}
}

PresburgerSet::PresburgerSet(const PrimExpr& constraint) {
Array<Var> vars;
PostOrderVisit(constraint, [&vars](const ObjectRef& obj) {
if (const VarNode* new_var = obj.as<VarNode>()) {
auto var = GetRef<Var>(new_var);
if (!std::any_of(vars.begin(), vars.end(), [&var](const Var& v) { return v.same_as(var); })) {
vars.push_back(var);
}
}
});
auto constraints_union = ExtractComponents(constraint);
Analyzer analyzer;
PrimExpr simplified_constraint = analyzer.Simplify(constraint, kSimplifyRewriteCanonicalRewrite);
auto space = PresburgerSpace::getRelationSpace(vars.size(), 0, 0, 0);
auto node = make_object<PresburgerSetNode>(std::move(space), vars);
node->SetVars(vars);
Update(simplified_constraint, node.get());
data_ = std::move(node);
}

PresburgerSet::PresburgerSet(const std::vector<IntegerRelation>& disjuncts,
const Array<Var>& vars) {
auto node = make_object<PresburgerSetNode>(disjuncts, disjuncts[0].getSpace(), vars);
data_ = std::move(node);
}

void PresburgerSetNode::UpdateConstraint(const PrimExpr& constraint, const Array<Var>& vars) {
Analyzer analyzer;
PrimExpr simplified_constraint = analyzer.Simplify(constraint, kSimplifyRewriteCanonicalRewrite);
Update(simplified_constraint, this);
SetVars(vars);
}

PrimExpr PresburgerSetNode::GenerateConstraint() const {
PrimExpr constraint = Bool(0);
for (const IntegerRelation& disjunct : disjuncts) {
PrimExpr union_entry = Bool(1);
for (unsigned i = 0, e = disjunct.getNumEqualities(); i < e; ++i) {
PrimExpr linear_eq = IntImm(DataType::Int(32), 0);
if (disjunct.getNumCols() > 1) {
for (unsigned j = 0, f = disjunct.getNumCols() - 1; j < f; ++j) {
auto coeff = disjunct.atEq(i, j);
if (coeff >= 0 || is_zero(linear_eq)) {
linear_eq = linear_eq + IntImm(DataType::Int(32), coeff) * vars[j];
} else {
linear_eq = linear_eq - IntImm(DataType::Int(32), -coeff) * vars[j];
}
}
}
auto c0 = disjunct.atEq(i, disjunct.getNumCols() - 1);
linear_eq = linear_eq + IntImm(DataType::Int(32), c0);
union_entry = (union_entry && (linear_eq == 0));
}
for (unsigned i = 0, e = disjunct.getNumInequalities(); i < e; ++i) {
PrimExpr linear_eq = IntImm(DataType::Int(32), 0);
if (disjunct.getNumCols() > 1) {
for (unsigned j = 0, f = disjunct.getNumCols() - 1; j < f; ++j) {
auto coeff = disjunct.atIneq(i, j);
if (coeff >= 0 || is_zero(linear_eq)) {
linear_eq = linear_eq + IntImm(DataType::Int(32), coeff) * vars[j];
} else {
linear_eq = linear_eq - IntImm(DataType::Int(32), -coeff) * vars[j];
}
}
}
auto c0 = disjunct.atIneq(i, disjunct.getNumCols() - 1);
if (c0 >= 0) {
linear_eq = linear_eq + IntImm(DataType::Int(32), c0);
} else {
linear_eq = linear_eq - IntImm(DataType::Int(32), -c0);
}
union_entry = (union_entry && (linear_eq >= 0));
}
constraint = constraint || union_entry;
}

return constraint;
}

PresburgerSet Union(const Array<PresburgerSet>& sets) {
CHECK_GT(sets.size(), 0);
if (sets.size() == 1) return sets[0];
auto relations = sets[0]->disjuncts;
for (size_t i = 1; i < sets.size(); ++i) {
for (const IntegerRelation& rel : sets[i]->disjuncts) {
relations.push_back(rel);
}
}
return PresburgerSet(std::move(relations), sets[0]->GetVars());
}

PresburgerSet Intersect(const Array<PresburgerSet>& sets) {
CHECK_GT(sets.size(), 0);
if (sets.size() == 1) return sets[0];
auto relations = sets[0]->disjuncts;
const auto& space = sets[0]->space;

for (size_t i = 1; i < sets.size(); ++i) {
ICHECK(space.isCompatible(sets[i]->space)) << "Spaces should match";
for (const IntegerRelation& relA : sets[i]->disjuncts) {
for (const IntegerRelation& relB : relations) {
IntegerRelation intersection = relA.intersect(relB);
if (!intersection.isEmpty()) relations.push_back(intersection);
}
}
}
return PresburgerSet(std::move(relations), sets[0]->GetVars());
}

IntSet EvalSet(const PrimExpr& e, const PresburgerSet& set) {
Array<PrimExpr> tvm_coeffs = DetectLinearEquation(e, set->GetVars());
SmallVector<int64_t> coeffs;
coeffs.reserve(tvm_coeffs.size());
for (const PrimExpr& it : tvm_coeffs) {
coeffs.push_back(*as_const_int(it));
}

IntSet result = IntSet().Nothing();
for (const IntegerRelation& it : set->disjuncts) {
Simplex simplex(it);
auto range = simplex.computeIntegerBounds(coeffs);
auto maxRoundedDown(simplex.computeOptimum(Simplex::Direction::Up, coeffs));
auto opt = range.first.getOptimumIfBounded();
auto min = opt.hasValue() ? IntImm(DataType::Int(64), opt.getValue()) : neg_inf();
opt = range.second.getOptimumIfBounded();
auto max = opt.hasValue() ? IntImm(DataType::Int(64), opt.getValue()) : pos_inf();
auto interval = IntervalSet(min, max);
result = Union({result, interval});
}
return result;
}

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<PresburgerSetNode>([](const ObjectRef& node, ReprPrinter* p) {
auto set = node.as<PresburgerSetNode>();
ICHECK(ret) << "Unknown type:" << node->GetTypeKey();
p->stream << "{";
p->stream << set->GetVars() << ": ";
p->stream << node.as<PresburgerSetNode>()->GenerateConstraint();
p->stream << "}";
});

#endif // TVM_MLIR_VERSION >= 150
#endif // TVM_MLIR_VERSION

PresburgerSet MakePresburgerSet(const PrimExpr& constraint) { return PresburgerSet(constraint); }

TVM_REGISTER_GLOBAL("arith.PresburgerSet").set_body_typed(MakePresburgerSet);

TVM_REGISTER_NODE_TYPE(PresburgerSetNode);

} // namespace arith
} // namespace tvm
Loading

0 comments on commit 483a8c3

Please sign in to comment.