Skip to content

Commit

Permalink
[REFACTOR][TIR][API-Change] Range/IntSet API style consistency. (apac…
Browse files Browse the repository at this point in the history
…he#5953)

- Range::make_by_min_extent -> Range::FromMinExtent
- Update the APIs in IntSet to use CamelCase
  • Loading branch information
tqchen authored and Trevor Morris committed Jun 30, 2020
1 parent f3ecd56 commit c835f7d
Show file tree
Hide file tree
Showing 39 changed files with 175 additions and 184 deletions.
6 changes: 3 additions & 3 deletions docs/dev/inferbound.rst
Original file line number Diff line number Diff line change
Expand Up @@ -181,16 +181,16 @@ The Ranges of the inner and outer IterVars of the split are set based on the par

.. code:: cpp
rmap[split->inner] = Range::make_by_min_extent(0, split->factor)
rmap[split->outer] = Range::make_by_min_extent(0, DivCeil(rmap[split->parent]->extent, split->factor))
rmap[split->inner] = Range::FromMinExtent(0, split->factor)
rmap[split->outer] = Range::FromMinExtent(0, DivCeil(rmap[split->parent]->extent, split->factor))
There is an opportunity here to tighten the bounds produced by InferBound, when ``split->factor`` does not evenly divide the parent's extent. Suppose the parent's extent is 20, and the split factor is 16. Then on the second iteration of the outer loop, the inner loop only needs to perform 4 iterations, not 16. If PassDownDomain could set the extent of ``split->inner`` to ``min(split->factor, rmap[split->parent]->extent - (split->outer * split->factor))``, then the extent of the inner variable would properly adapt, based on which iteration of the outer loop is being executed.

For Fuse relations, the Range of the fused IterVar is set based on the known Ranges of the inner and outer IterVars, as follows:

.. code:: cpp
rmap[fuse->fused] = Range::make_by_min_extent(0, rmap[fuse->outer]->extent * rmap[fuse->inner]->extent)
rmap[fuse->fused] = Range::FromMinExtent(0, rmap[fuse->outer]->extent * rmap[fuse->inner]->extent)
InferRootBound
Expand Down
51 changes: 22 additions & 29 deletions include/tvm/arith/int_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,82 +65,75 @@ class IntSetNode : public Object {
*/
class IntSet : public ObjectRef {
public:
/*! \brief constructor */
IntSet() {}
// constructor from not container.
explicit IntSet(ObjectPtr<Object> n) : ObjectRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
const IntSetNode* operator->() const { return static_cast<const IntSetNode*>(get()); }
/*!
* \brief Find a range that covers the region.
* \param max_range The range to be covered.
* \return The covering range.
*/
Range cover_range(Range max_range) const;
Range CoverRange(Range max_range) const;
/*! \return Lower bound of the set */
PrimExpr min() const;
/*! \return upper bound of the set */
PrimExpr max() const;
/*! \return The sign of the elements in the integer set */
SignType GetSignType() const;
/*! \return Whether the set represent nothing */
bool is_nothing() const;
bool IsNothing() const;
/*! \return Whether the set represent everything */
bool is_everything() const;
bool IsEverything() const;
/*! \return Whether the set is a single point */
bool is_single_point() const;
bool IsSinglePoint() const;
/*! \return Whether the set is proved to be bigger than 0 */
bool can_prove_positive() const;
bool CanProvePositive() const;
/*! \return Whether the set is proved to be smaller than 0 */
bool can_prove_negative() const;
bool CanProveNegative() const;
/*! \return Whether the set is proved to be smaller than or equal to 0 */
bool can_prove_non_positive() const;
bool CanProveNonPositive() const;
/*! \return Whether the set is proved to be larger than or equal to 0 */
bool can_prove_non_negative() const;
/*! \return The sign of the elements in the integer set */
SignType sign_type() const;
bool CanProveNonNegative() const;
/*!
* \brief The single point value, call only if is_single_point is true
* \brief The single point value, call only if IsSinglePoint is true
* \return The point value.
*/
PrimExpr point_value() const;
PrimExpr PointValue() const;
/*!
* \brief Try to match IntSet with range r.
*
* \note It is guanrateed that IntSet::range(r).match_range(r) == true
* \note It is guanrateed that IntSet::FromRange(r).MatchRange(r) == true
* \return true if we can prove they are the same.
*/
bool match_range(const Range& r) const;
bool MatchRange(const tvm::Range& r) const;
/*! \return The set contains nothing */
static IntSet nothing();
static IntSet Nothing();
/*! \return The set contains everything */
static IntSet everything();
static IntSet Everything();
/*!
* \brief construct a point set.
* \param point The point in the set.
* \return construct a single point set
*/
static IntSet single_point(PrimExpr point);
static IntSet SinglePoint(PrimExpr point);
/*!
* \brief construct a integer set from vector expression.
* \param vec The vector expression, can also be single point.
* \return The result set containing the indices in the vector.
*/
static IntSet vector(PrimExpr vec);
static IntSet Vector(PrimExpr vec);
/*!
* \brief Construct a set representing a range.
* \param r The range
* \return constructed set.
*/
static IntSet range(Range r);
static IntSet FromRange(tvm::Range r);
/*!
* \brief Construct a set representing a interval.
* \param min The minimum value of the interval.
* \param max The maximum value of the interval.
* \return constructed set.
*/
static IntSet interval(PrimExpr min, PrimExpr max);
static IntSet Interval(PrimExpr min, PrimExpr max);

TVM_DEFINE_OBJECT_REF_METHODS(IntSet, ObjectRef, IntSetNode);
};

//-----------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ class Range : public ObjectRef {
* \param min The minimum range.
* \param extent The extent of the range.
*/
static Range make_by_min_extent(PrimExpr min, PrimExpr extent);
static Range FromMinExtent(PrimExpr min, PrimExpr extent);
// declare range.
TVM_DEFINE_OBJECT_REF_METHODS(Range, ObjectRef, RangeNode);
};
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/ir/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def __init__(self, begin, end=None):
_ffi_api.Range, begin, end)

@staticmethod
def make_by_min_extent(min_value, extent):
def from_min_extent(min_value, extent):
"""Construct a Range by min and extent.
This constructs a range in [min_value, min_value + extent)
Expand All @@ -136,4 +136,4 @@ def make_by_min_extent(min_value, extent):
rng : Range
The constructed range.
"""
return _ffi_api.range_by_min_extent(min_value, extent)
return _ffi_api.Range_from_min_extent(min_value, extent)
2 changes: 1 addition & 1 deletion python/tvm/te/hybrid/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def wrap_up_realize(self, node, body):
if _scope == 'global':
body = self.wrap_up_binds(body)

_domain = [Range.make_by_min_extent(0, i) for i in _buf.shape]
_domain = [Range.from_min_extent(0, i) for i in _buf.shape]
_dtype = _buf.dtype
_true = tvm.runtime.convert(True)
body = tvm.tir.ProducerRealize(_buf, _domain, _true, body)
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/te/tensor_intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def _get_region(tslice):
begin = idx.var
else:
begin = idx
region.append(Range.make_by_min_extent(begin, 1))
region.append(Range.from_min_extent(begin, 1))
return region


Expand Down
8 changes: 4 additions & 4 deletions src/arith/bound_deducer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ class BoundDeducer : public ExprVisitor {
if (operand.dtype().is_uint()) {
sign_operand = kPositive;
} else {
sign_operand = expr_map_[operand].sign_type();
sign_operand = expr_map_[operand].GetSignType();
}

if (sign_operand == SignType::kNegative) {
Expand Down Expand Up @@ -315,7 +315,7 @@ void BoundDeducer::Deduce() {
void BoundDeducer::Relax() {
IntSet a = EvalSet(expr_, relax_map_);
IntSet b = EvalSet(result_, relax_map_);
if (a.is_everything() || b.is_everything()) {
if (a.IsEverything() || b.IsEverything()) {
success_ = false;
return;
}
Expand All @@ -336,7 +336,7 @@ IntSet DeduceBound(PrimExpr v, PrimExpr e,
const std::unordered_map<const VarNode*, IntSet>& relax_map) {
BoundDeducer d(v, e, hint_map, relax_map);
d.Deduce();
if (!d.success_) return IntSet::nothing();
if (!d.success_) return IntSet::Nothing();
PrimExpr min = neg_inf(), max = pos_inf();
if (d.comp_op == kEqual) {
min = d.result_;
Expand All @@ -346,7 +346,7 @@ IntSet DeduceBound(PrimExpr v, PrimExpr e,
} else {
max = d.result_;
}
return IntSet::interval(min, max);
return IntSet::Interval(min, max);
}

// assuming e >= 0, deduce the bound of variable from it.
Expand Down
6 changes: 3 additions & 3 deletions src/arith/domain_touched.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,14 @@ class BufferTouchedDomain final : public StmtExprVisitor {
Region ret;
Range none;
for (size_t i = 0; i < bounds_.size(); ++i) {
ret.push_back(arith::Union(bounds_[i]).cover_range(none));
ret.push_back(arith::Union(bounds_[i]).CoverRange(none));
}
return ret;
}

void VisitStmt_(const ForNode* op) final {
const VarNode* var = op->loop_var.get();
dom_map_[var] = IntSet::range(Range::make_by_min_extent(op->min, op->extent));
dom_map_[var] = IntSet::FromRange(Range::FromMinExtent(op->min, op->extent));
StmtExprVisitor::VisitStmt_(op);
dom_map_.erase(var);
}
Expand All @@ -69,7 +69,7 @@ class BufferTouchedDomain final : public StmtExprVisitor {
const IterVarNode* thread_axis = op->node.as<IterVarNode>();
CHECK(thread_axis);
const VarNode* var = thread_axis->var.get();
dom_map_[var] = IntSet::range(Range(make_zero(op->value.dtype()), op->value));
dom_map_[var] = IntSet::FromRange(Range(make_zero(op->value.dtype()), op->value));
StmtExprVisitor::VisitStmt_(op);
dom_map_.erase(var);
} else {
Expand Down
Loading

0 comments on commit c835f7d

Please sign in to comment.