Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate IntImm & FloatImm ObjectRef to not-null #5788

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ class IntImm : public PrimExpr {
*/
TVM_DLL IntImm(DataType dtype, int64_t value);

TVM_DEFINE_OBJECT_REF_METHODS(IntImm, PrimExpr, IntImmNode);
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(IntImm, PrimExpr, IntImmNode);
};

/*!
Expand Down Expand Up @@ -310,7 +310,7 @@ class FloatImm : public PrimExpr {
*/
TVM_DLL FloatImm(DataType dtype, double value);

TVM_DEFINE_OBJECT_REF_METHODS(FloatImm, PrimExpr, FloatImmNode);
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(FloatImm, PrimExpr, FloatImmNode);
};

/*!
Expand Down Expand Up @@ -350,7 +350,6 @@ inline Bool operator&&(const Bool& a, const Bool& b) {
*/
class Integer : public IntImm {
public:
Integer() {}
/*!
* \brief constructor from node.
*/
Expand Down
18 changes: 9 additions & 9 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,20 +94,20 @@ struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> {
}; // struct ReshapeAttrs

struct ScatterAttrs : public tvm::AttrsNode<ScatterAttrs> {
Integer axis;
Integer axis = Integer(0);

TVM_DECLARE_ATTRS(ScatterAttrs, "relay.attrs.ScatterAttrs") {
TVM_ATTR_FIELD(axis).set_default(0).describe("The axis over which to select values.");
}
};

struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
Integer axis;
Optional<Integer> axis;
std::string mode;

TVM_DECLARE_ATTRS(TakeAttrs, "relay.attrs.TakeAttrs") {
TVM_ATTR_FIELD(axis)
.set_default(NullValue<Integer>())
.set_default(NullValue<Optional<Integer>>())
.describe("The axis over which to select values.");
TVM_ATTR_FIELD(mode).set_default("clip").describe(
"Specify how out-of-bound indices will behave."
Expand Down Expand Up @@ -145,7 +145,7 @@ struct ArangeAttrs : public tvm::AttrsNode<ArangeAttrs> {

/*! \brief Attributes used in stack operators */
struct StackAttrs : public tvm::AttrsNode<StackAttrs> {
Integer axis;
Integer axis = Integer(0);
TVM_DECLARE_ATTRS(StackAttrs, "relay.attrs.StackAttrs") {
TVM_ATTR_FIELD(axis).set_default(0).describe(
"The axis in the result array along which the input arrays are stacked.");
Expand All @@ -154,12 +154,12 @@ struct StackAttrs : public tvm::AttrsNode<StackAttrs> {

/*! \brief Attributes used in repeat operators */
struct RepeatAttrs : public tvm::AttrsNode<RepeatAttrs> {
Integer repeats;
Integer axis;
Integer repeats = Integer(0);
Optional<Integer> axis;
TVM_DECLARE_ATTRS(RepeatAttrs, "relay.attrs.RepeatAttrs") {
TVM_ATTR_FIELD(repeats).describe("The number of repetitions for each element.");
TVM_ATTR_FIELD(axis)
.set_default(NullValue<Integer>())
.set_default(NullValue<Optional<Integer>>())
.describe(" The axis along which to repeat values.");
}
}; // struct RepeatAttrs
Expand All @@ -176,10 +176,10 @@ struct TileAttrs : public tvm::AttrsNode<TileAttrs> {

/*! \brief Attributes used in reverse operators */
struct ReverseAttrs : public tvm::AttrsNode<ReverseAttrs> {
Integer axis;
Optional<Integer> axis;
TVM_DECLARE_ATTRS(ReverseAttrs, "relay.attrs.ReverseAttrs") {
TVM_ATTR_FIELD(axis)
.set_default(NullValue<Integer>())
.set_default(NullValue<Optional<Integer>>())
.describe("The axis along which to reverse elements.");
}
}; // struct ReverseAttrs
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relay/attrs/vision.h
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ struct ROIPoolAttrs : public tvm::AttrsNode<ROIPoolAttrs> {

/*! \brief Attributes used in yolo reorg operators */
struct YoloReorgAttrs : public tvm::AttrsNode<YoloReorgAttrs> {
Integer stride;
Integer stride = Integer(1);

TVM_DECLARE_ATTRS(YoloReorgAttrs, "relay.attrs.YoloReorgAttrs") {
TVM_ATTR_FIELD(stride).set_default(1).describe("Stride value for yolo reorg");
Expand Down
9 changes: 9 additions & 0 deletions include/tvm/runtime/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -1452,6 +1452,7 @@ template <typename T>
class Optional : public ObjectRef {
public:
using ContainerType = typename T::ContainerType;
using RefType = T;
static_assert(std::is_base_of<ObjectRef, T>::value, "Optional is only defined for ObjectRef.");
// default constructors.
Optional() = default;
Expand All @@ -1474,6 +1475,7 @@ class Optional : public ObjectRef {
data_ = nullptr;
return *this;
}

// normal value handling.
Optional(T other) // NOLINT(*)
: ObjectRef(std::move(other)) {}
Expand Down Expand Up @@ -1546,6 +1548,12 @@ class Optional : public ObjectRef {
if (*this == nullptr) return RetType(true);
return value() != other;
}

const ContainerType* operator->() const {
CHECK(data_ != nullptr);
return static_cast<const ContainerType*>(data_.get());
}

static constexpr bool _type_is_nullable = true;
};

Expand All @@ -1561,6 +1569,7 @@ struct PackedFuncValueConverter<Optional<T>> {
}
};


} // namespace runtime

// expose the functions to the root namespace.
Expand Down
8 changes: 4 additions & 4 deletions src/arith/canonical_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const DivNode* op) {
// const folding
PrimExpr const_res = TryConstFold<Div>(a, b);
if (const_res.defined()) return const_res;
PVar<IntImm> c1;
PVarOpt<Optional<IntImm>> c1;
// x / c1
if (c1.Match(b) && c1.Eval()->value > 0) {
int64_t cval = c1.Eval()->value;
Expand Down Expand Up @@ -764,7 +764,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorDivNode* op) {
// const folding
PrimExpr const_res = TryConstFold<FloorDiv>(a, b);
if (const_res.defined()) return const_res;
PVar<IntImm> c1;
PVarOpt<Optional<IntImm>> c1;
// x / c1
if (c1.Match(b) && c1.Eval()->value > 0) {
int64_t cval = c1.Eval()->value;
Expand Down Expand Up @@ -868,7 +868,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ModNode* op) {
PrimExpr const_res = TryConstFold<Mod>(a, b);
if (const_res.defined()) return const_res;

PVar<IntImm> c1;
PVarOpt<Optional<IntImm>> c1;
// x % c1
if (c1.Match(b) && c1.Eval()->value > 0) {
int64_t cval = c1.Eval()->value;
Expand Down Expand Up @@ -936,7 +936,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorModNode* op) {
PrimExpr const_res = TryConstFold<FloorMod>(a, b);
if (const_res.defined()) return const_res;

PVar<IntImm> c1;
PVarOpt<Optional<IntImm>> c1;
// x % c1
if (c1.Match(b) && c1.Eval()->value > 0) {
int64_t cval = c1.Eval()->value;
Expand Down
2 changes: 1 addition & 1 deletion src/arith/const_int_bound.cc
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,7 @@ class ConstIntBoundAnalyzer::Impl
*/
static std::vector<BoundInfo> DetectBoundInfo(const PrimExpr& cond) {
PVar<PrimExpr> x, y;
PVar<IntImm> c;
PVarOpt<Optional<IntImm>> c;
// NOTE: canonical form always use <= or <
if ((c <= x).Match(cond)) {
return {BoundInfo(x.Eval(), MakeBound(c.Eval()->value, kPosInf))};
Expand Down
2 changes: 1 addition & 1 deletion src/arith/int_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ class IntervalSetEvaluator : public ExprFunctor<IntervalSet(const PrimExpr&)> {
IntervalSet VisitExpr_(const RampNode* op) final {
CHECK(eval_vec_);
IntervalSet base = Eval(op->base);
PVar<IntImm> stride;
PVarOpt<Optional<IntImm>> stride;
if (stride.Match(op->stride)) {
DataType t = op->base.dtype();
int64_t vstride = stride.Eval()->value;
Expand Down
2 changes: 1 addition & 1 deletion src/arith/modular_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ class ModularSetAnalyzer::Impl : public ExprFunctor<ModularSetAnalyzer::Entry(co
// Detect useful constraints and use them in the analysis scope.
std::function<void()> EnterConstraint(const PrimExpr& constraint) {
PVar<Var> var;
PVar<IntImm> coeff, base;
PVarOpt<Optional<IntImm>> coeff, base;
// pattern match interesting constraints
if ((truncmod(var, coeff) == base).Match(constraint) ||
(floormod(var, coeff) == base).Match(constraint)) {
Expand Down
51 changes: 51 additions & 0 deletions src/arith/pattern_match.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,57 @@ class PVar : public Pattern<PVar<T>> {
mutable bool filled_{false};
};

/*!
* \brief Pattern variable container.
*
* PVarOpt is a variant of PVar to incorporate optional objectref.
*
* \tparam T the type of the hole.
*
* \note PVarOpt is not thread safe.
* Do not use the same PVarOpt in multiple threads.
*/
template <typename T>
class PVarOpt : public Pattern<PVarOpt<T>> {
public:
// Store PVars by reference in the expression.
using Nested = const PVarOpt<T>&;
using RefType = typename T::RefType;

void InitMatch_() const { filled_ = false; }

bool Match_(const RefType& value) const {
if (!filled_) {
value_ = value;
filled_ = true;
return true;
} else {
return PEqualChecker<RefType>()(value_.value(), value);
}
}

template <typename NodeRefType,
typename = typename std::enable_if<std::is_base_of<NodeRefType, RefType>::value>::type>
bool Match_(const NodeRefType& value) const {
if (const auto* ptr = value.template as<typename T::ContainerType>()) {
return Match_(GetRef<RefType>(ptr));
} else {
return false;
}
}

RefType Eval() const {
CHECK(filled_);
return value_.value();
}

protected:
/*! \brief The matched value */
mutable T value_;
/*! \brief whether the variable has been filled */
mutable bool filled_{false};
};

/*!
* \brief Constant Pattern variable container.
*
Expand Down
26 changes: 13 additions & 13 deletions src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) {
// Pattern var to match any expression
PVar<PrimExpr> x, y, z, b1, b2, s1, s2;
// Pattern var match IntImm
PVar<IntImm> c1, c2, c3;
PVarOpt<Optional<IntImm>> c1, c2, c3;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it is right to use optional here, as the intention is to match the variable itself

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I agree. It is only to accommodate absent of default constructor. However the behaviour don't violate with PVar I think.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree too...Is there any better to do this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have 2 possible approach here:

A0: Rename PVarOpt --> PVarExt(Extended) and hide Optional input format.
Look like --> PVarExt<IntImm> c1, c2, c3;

A1: Provide default constructor with default values.
Look like: PVar<IntImm> c1(IntImm(DataType::Int(32), 0)), c2(IntImm(DataType::Int(32), 0)), c3(IntImm(DataType::Int(32), 0));

Please let me know your thoughts on this. Thanks!

Copy link
Member

@tqchen tqchen Jun 14, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think either approaches are good here. Let us leave it as it is for now

// Pattern var for lanes in broadcast and ramp
PVar<int> lanes;
// Vector rules
Expand Down Expand Up @@ -230,7 +230,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) {
// Pattern var to match any expression
PVar<PrimExpr> x, y, z, b1, b2, s1, s2;
// Pattern var match IntImm
PVar<IntImm> c1, c2, c3;
PVarOpt<Optional<IntImm>> c1, c2, c3;
// Pattern var for lanes in broadcast and ramp
PVar<int> lanes;
// Vector rules
Expand Down Expand Up @@ -413,7 +413,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MulNode* op) {
// Pattern var to match any expression
PVar<PrimExpr> x, y, z, b1, b2, s1, s2;
// Pattern var match IntImm
PVar<IntImm> c1, c2;
PVarOpt<Optional<IntImm>> c1, c2;
// Pattern var for lanes in broadcast and ramp
PVar<int> lanes;
// Vector rules
Expand Down Expand Up @@ -446,7 +446,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode* op) {
// Pattern var to match any expression
PVar<PrimExpr> x, y, z, b1;
// Pattern var match IntImm
PVar<IntImm> c1, c2, c3;
PVarOpt<Optional<IntImm>> c1, c2, c3;
// Pattern var for lanes in broadcast and ramp
PVar<int> lanes;

Expand Down Expand Up @@ -621,7 +621,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const ModNode* op) {
// Pattern var to match any expression
PVar<PrimExpr> x, y, z, b1;
// Pattern var match IntImm
PVar<IntImm> c1, c2;
PVarOpt<Optional<IntImm>> c1, c2;
// Pattern var for lanes in broadcast and ramp
PVar<int> lanes;

Expand Down Expand Up @@ -701,7 +701,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) {
// Pattern var to match any expression
PVar<PrimExpr> x, y, z, b1;
// Pattern var match IntImm
PVar<IntImm> c1, c2, c3;
PVarOpt<Optional<IntImm>> c1, c2, c3;
// Pattern var for lanes in broadcast and ramp
PVar<int> lanes;

Expand Down Expand Up @@ -819,7 +819,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) {
// Pattern var to match any expression
PVar<PrimExpr> x, y, z, b1;
// Pattern var match IntImm
PVar<IntImm> c1, c2;
PVarOpt<Optional<IntImm>> c1, c2;
// Pattern var for lanes in broadcast and ramp
PVar<int> lanes;

Expand Down Expand Up @@ -884,7 +884,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MinNode* op) {
// Pattern var to match any expression
PVar<PrimExpr> x, y, z, s1, s2;
// Pattern var match IntImm
PVar<IntImm> c1, c2;
PVarOpt<Optional<IntImm>> c1, c2;
PVar<int> lanes;

// vector rule
Expand Down Expand Up @@ -1056,7 +1056,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MaxNode* op) {
// Pattern var to match any expression
PVar<PrimExpr> x, y, z, s1, s2;
// Pattern var match IntImm
PVar<IntImm> c1, c2;
PVarOpt<Optional<IntImm>> c1, c2;
PVar<int> lanes;

// vector rule
Expand Down Expand Up @@ -1219,7 +1219,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const EQNode* op) {
// Pattern var to match any expression
PVar<PrimExpr> x, y;
// Pattern var match IntImm
PVar<IntImm> c1;
PVarOpt<Optional<IntImm>> c1;
PVar<int> lanes;

// vector rule
Expand Down Expand Up @@ -1267,7 +1267,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LTNode* op) {
// Pattern var to match any expression
PVar<PrimExpr> x, y, z, s1, s2;
// Pattern var match IntImm
PVar<IntImm> c1, c2;
PVarOpt<Optional<IntImm>> c1, c2;
PVar<int> lanes;

// vector rule
Expand Down Expand Up @@ -1422,7 +1422,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AndNode* op) {
// Pattern var to match any expression
PVar<PrimExpr> x, y;
// Pattern var match IntImm
PVar<IntImm> c1, c2;
PVarOpt<Optional<IntImm>> c1, c2;
PVar<int> lanes;

if (op->dtype.lanes() != 1) {
Expand Down Expand Up @@ -1461,7 +1461,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const OrNode* op) {
// Pattern var to match any expression
PVar<PrimExpr> x, y;
// Pattern var match IntImm
PVar<IntImm> c1, c2;
PVarOpt<Optional<IntImm>> c1, c2;
PVar<int> lanes;

if (op->dtype.lanes() != 1) {
Expand Down
Loading