Skip to content

Commit

Permalink
Fix lint.
Browse files Browse the repository at this point in the history
  • Loading branch information
Min Chen committed Apr 22, 2023
1 parent ac9749d commit 3c6d01e
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 52 deletions.
2 changes: 2 additions & 0 deletions include/tvm/node/reflection.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@
#include <tvm/runtime/object.h>
#include <tvm/runtime/packed_func.h>

#include <limits>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <vector>

namespace tvm {
Expand Down
59 changes: 28 additions & 31 deletions src/arith/presburger_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,22 @@
* \file presburger_set.cc
* \brief The presburger set functions
*/
#include "presburger_set.h"

#include <tvm/arith/int_set.h>
#include <tvm/arith/int_solver.h>
#include <tvm/arith/pattern.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/expr_functor.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/arith/pattern.h>
#include <tvm/arith/int_solver.h>

#include <algorithm>
#include <unordered_map>
#include <utility>
#include <vector>

#include "constraint_extract.h"
#include "presburger_set.h"
#include "interval_set.h"

namespace tvm {
Expand All @@ -43,14 +45,12 @@ namespace arith {
#ifdef TVM_MLIR_VERSION
using namespace tir;


void Update(const PrimExpr& constraint,
PresburgerSetNode& intset) {
auto& space = intset.space;
void Update(const PrimExpr& constraint, PresburgerSetNode* intset) {
auto& space = intset->space;
auto constraints_union = ExtractComponents(constraint);
for (const PrimExpr& subconstraint : constraints_union) {
auto entries = ExtractConstraints(subconstraint, false);
auto vars = intset.GetVars();
auto vars = intset->GetVars();
IntegerRelation disjunct(entries.size(), 0, vars.size() + 1, space);
for (const PrimExpr& entry : entries) {
// The expression is expect to be simplified to only contain ==, <= or <
Expand Down Expand Up @@ -83,19 +83,18 @@ void Update(const PrimExpr& constraint,
LOG(FATAL) << "Unsupported constraint expression: " << entry->GetTypeKey();
}
}
intset.unionInPlace(disjunct);
intset->unionInPlace(disjunct);
}
}

PresburgerSet::PresburgerSet(const PrimExpr& constraint) {
Array<Var> vars;
PostOrderVisit(constraint, [&vars](const ObjectRef& obj) {
if (const VarNode* new_var = obj.as<VarNode>()) {
auto var = GetRef<Var>(new_var);
if (!std::any_of(vars.begin(), vars.end(),
[&var](const Var& v) { return v.same_as(var); })) {
vars.push_back(var);
}
auto var = GetRef<Var>(new_var);
if (!std::any_of(vars.begin(), vars.end(), [&var](const Var& v) { return v.same_as(var); })) {
vars.push_back(var);
}
}
});
auto constraints_union = ExtractComponents(constraint);
Expand All @@ -104,25 +103,26 @@ PresburgerSet::PresburgerSet(const PrimExpr& constraint) {
auto space = PresburgerSpace::getRelationSpace(vars.size(), 0, 0, 0);
auto node = make_object<PresburgerSetNode>(std::move(space), vars);
node->SetVars(vars);
Update(simplified_constraint, *node);
Update(simplified_constraint, node.get());
data_ = std::move(node);
}

PresburgerSet::PresburgerSet(const std::vector<IntegerRelation>& disjuncts, const Array<Var>& vars) {
PresburgerSet::PresburgerSet(const std::vector<IntegerRelation>& disjuncts,
const Array<Var>& vars) {
auto node = make_object<PresburgerSetNode>(disjuncts, disjuncts[0].getSpace(), vars);
data_ = std::move(node);
}

void PresburgerSetNode::UpdateConstraint(const PrimExpr& constraint, const Array<Var>& vars) {
Analyzer analyzer;
PrimExpr simplified_constraint = analyzer.Simplify(constraint, kSimplifyRewriteCanonicalRewrite);
Update(simplified_constraint, *this);
Update(simplified_constraint, this);
SetVars(vars);
}

PrimExpr PresburgerSetNode::GenerateConstraint() const {
PrimExpr constraint = Bool(0);
for (const IntegerRelation &disjunct : disjuncts) {
for (const IntegerRelation& disjunct : disjuncts) {
PrimExpr union_entry = Bool(1);
for (unsigned i = 0, e = disjunct.getNumEqualities(); i < e; ++i) {
PrimExpr linear_eq = IntImm(DataType::Int(32), 0);
Expand Down Expand Up @@ -171,8 +171,9 @@ PresburgerSet Union(Array<PresburgerSet> sets) {
if (sets.size() == 1) return sets[0];
auto relations = sets[0]->disjuncts;
for (size_t i = 1; i < sets.size(); ++i) {
for (const auto rel : sets[i]->disjuncts)
for (const IntegerRelation& rel : sets[i]->disjuncts) {
relations.push_back(rel);
}
}
return PresburgerSet(std::move(relations), sets[0]->GetVars());
}
Expand All @@ -185,31 +186,29 @@ PresburgerSet Intersect(const Array<PresburgerSet>& sets) {

for (size_t i = 1; i < sets.size(); ++i) {
ICHECK(space.isCompatible(sets[i]->space)) << "Spaces should match";
for (const IntegerRelation &relA : sets[i]->disjuncts) {
for (const IntegerRelation &relB : relations) {
for (const IntegerRelation& relA : sets[i]->disjuncts) {
for (const IntegerRelation& relB : relations) {
IntegerRelation intersection = relA.intersect(relB);
if (!intersection.isEmpty())
relations.push_back(intersection);
if (!intersection.isEmpty()) relations.push_back(intersection);
}
}
}
return PresburgerSet(std::move(relations), sets[0]->GetVars());
}

IntSet EvalSet(const PrimExpr& e, const PresburgerSet& set) {
auto tvm_coeffs = DetectLinearEquation(e, set->GetVars());
Array<PrimExpr> tvm_coeffs = DetectLinearEquation(e, set->GetVars());
SmallVector<int64_t> coeffs;
coeffs.reserve(tvm_coeffs.size());
for (auto &it : tvm_coeffs) {
for (const PrimExpr& it : tvm_coeffs) {
coeffs.push_back(*as_const_int(it));
}

IntSet result = IntSet().Nothing();
for (auto &it : set->disjuncts) {
for (const IntegerRelation& it : set->disjuncts) {
Simplex simplex(it);
auto range = simplex.computeIntegerBounds(coeffs);
auto maxRoundedDown(
simplex.computeOptimum(Simplex::Direction::Up, coeffs));
auto maxRoundedDown(simplex.computeOptimum(Simplex::Direction::Up, coeffs));
auto opt = range.first.getOptimumIfBounded();
auto min = opt.hasValue() ? IntImm(DataType::Int(64), opt.getValue()) : neg_inf();
opt = range.second.getOptimumIfBounded();
Expand All @@ -232,9 +231,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)

#endif

PresburgerSet MakePresburgerSet(const PrimExpr& constraint) {
return PresburgerSet(constraint);
}
PresburgerSet MakePresburgerSet(const PrimExpr& constraint) { return PresburgerSet(constraint); }

TVM_REGISTER_GLOBAL("arith.PresburgerSet").set_body_typed(MakePresburgerSet);

Expand Down
33 changes: 14 additions & 19 deletions src/arith/presburger_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,16 @@
#define TVM_ARITH_PRESBURGER_SET_H_

#ifdef TVM_MLIR_VERSION
#include <mlir/Analysis/Presburger/PresburgerRelation.h>
#include <mlir/Analysis/Presburger/IntegerRelation.h>
#include <mlir/Analysis/Presburger/PresburgerRelation.h>
#include <mlir/Analysis/Presburger/Simplex.h>
#endif

#include <tvm/arith/analyzer.h>
#include <tvm/tir/op.h>

#include <limits>
#include <vector>

#include "const_fold.h"

Expand All @@ -54,13 +55,12 @@ using namespace presburger;
*/
class PresburgerSetNode : public IntSetNode {
public:
explicit PresburgerSetNode(const PresburgerSpace &space, const Array<Var> &vars)
: disjuncts({}), space(space), vars(vars) {};
explicit PresburgerSetNode() : space(PresburgerSpace::getRelationSpace()) {};
explicit PresburgerSetNode(const std::vector<IntegerRelation> &disjuncts,
const PresburgerSpace &space,
const Array<Var> &vars)
: disjuncts(disjuncts), space(space), vars(vars) {}
PresburgerSetNode() : space(PresburgerSpace::getRelationSpace()) {}
explicit PresburgerSetNode(const PresburgerSpace& space, const Array<Var>& vars)
: disjuncts({}), space(space), vars(vars) {}
explicit PresburgerSetNode(const std::vector<IntegerRelation>& disjuncts,
const PresburgerSpace& space, const Array<Var>& vars)
: disjuncts(disjuncts), space(space), vars(vars) {}

/*! \brief Represent the union of multiple IntegerRelation */
std::vector<IntegerRelation> disjuncts;
Expand All @@ -83,7 +83,7 @@ class PresburgerSetNode : public IntSetNode {
* \brief Do inplace union with given disjunct
* \param disjunct The given disjunct to be union with
*/
void unionInPlace(const IntegerRelation &disjunct) {
void unionInPlace(const IntegerRelation& disjunct) {
assert(space.isCompatible(disjunct.getSpace()) && "Spaces should match");
disjuncts.push_back(disjunct);
}
Expand All @@ -105,7 +105,7 @@ class PresburgerSetNode : public IntSetNode {
* \brief Set domain vars
* \param new_vars Vars that will be taken as the domain vars
*/
void SetVars(const Array<Var> &new_vars) { vars = new_vars; }
void SetVars(const Array<Var>& new_vars) { vars = new_vars; }

/*!
* \brief Get the current domain vars
Expand All @@ -115,8 +115,7 @@ class PresburgerSetNode : public IntSetNode {

/*! \return whether integer set is empty */
bool IsEmpty() const {
return std::all_of(disjuncts.begin(),
disjuncts.end(),
return std::all_of(disjuncts.begin(), disjuncts.end(),
std::mem_fn(&IntegerRelation::isIntegerEmpty));
}

Expand Down Expand Up @@ -156,24 +155,20 @@ class PresburgerSet : public IntSet {
class PresburgerSetNode : public IntSetNode {
public:
// dummy visitor overload.
void VisitAttrs(tvm::AttrVisitor* v) {
LOG(FATAL) << "MLIR is not enabled!";
}
void VisitAttrs(tvm::AttrVisitor* v) { LOG(FATAL) << "MLIR is not enabled!"; }

static constexpr const char* _type_key = "arith.PresburgerSet";
TVM_DECLARE_FINAL_OBJECT_INFO(PresburgerSetNode, IntSetNode);
};

class PresburgerSet : public IntSet {
public:
/*!
/*!
* \brief Constructor interface to prompt when MLIR is not enabled.
* \param constraint The constraint to construct the set.
* \return The created set.
*/
TVM_DLL PresburgerSet(const PrimExpr& constraint) {
LOG(FATAL) << "MLIR is not enabled!";
}
TVM_DLL PresburgerSet(const PrimExpr& constraint) { LOG(FATAL) << "MLIR is not enabled!"; }
};
#endif
/*!
Expand Down
4 changes: 2 additions & 2 deletions tests/cpp/arith_integer_set_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ TEST(PresburgerSet, eval) {
auto x = tvm::tir::Var("x");
auto y = tvm::tir::Var("y");
auto sub_constraint0 = (x + y < 20) && (x - y <= 0);
auto sub_constraint1 = x >= 0 && x < 20 && y >=0 && y < 20;
auto sub_constraint1 = x >= 0 && x < 20 && y >= 0 && y < 20;
auto constraint = sub_constraint0 && sub_constraint1;
auto set = tvm::arith::PresburgerSet(constraint);

auto target = x + 2*y;
auto target = x + 2 * y;
auto result = EvalSet(target, set);
ASSERT_TRUE(tvm::tir::is_zero(result.min()));
ASSERT_TRUE(tvm::tir::is_const_int(result.max(), 38));
Expand Down

0 comments on commit 3c6d01e

Please sign in to comment.