Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REFACTOR][API-Change] Migrate all Object construction to constructor. #5784

Merged
merged 1 commit into from
Jun 12, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/dev/codebase_walkthrough.rst
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ This function is mapped to the C++ function in ``include/tvm/schedule.h``.
::

inline Schedule create_schedule(Array<Operation> ops) {
return ScheduleNode::make(ops);
return Schedule(ops);
}

``Schedule`` consists of collections of ``Stage`` and output ``Operation``.
Expand Down
2 changes: 1 addition & 1 deletion docs/dev/relay_add_pass.rst
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ is shown below.
if (g->tuple == t) {
return GetRef<Expr>(g);
} else {
return TupleGetItemNode::make(t, g->index);
return TupleGetItem(t, g->index);
}
}
Expand Down
10 changes: 5 additions & 5 deletions docs/dev/relay_pass_infra.rst
Original file line number Diff line number Diff line change
Expand Up @@ -344,13 +344,13 @@ registration.
.. code:: c++

// Create a simple Relay program.
auto tensor_type = relay::TensorTypeNode::make({}, tvm::Bool());
auto x = relay::VarNode::make("x", relay::Type());
auto f = relay::FunctionNode::make(tvm::Array<relay::Var>{ x }, x, relay::Type(), {});
auto tensor_type = relay::TensorType({}, tvm::Bool());
auto x = relay::Var("x", relay::Type());
auto f = relay::Function(tvm::Array<relay::Var>{ x }, x, relay::Type(), {});

auto y = relay::VarNode::make("y", tensor_type);
auto y = relay::Var("y", tensor_type);
auto call = relay::Call(f, tvm::Array<relay::Expr>{ y });
auto fx = relay::FunctionNode::make(tvm::Array<relay::Var>{ y }, call, relay::Type(), {});
auto fx = relay::Function(tvm::Array<relay::Var>{ y }, call, relay::Type(), {});

// Create a module for optimization.
auto mod = IRModule::FromExpr(fx);
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/ir/span.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,14 @@ class SpanNode : public Object {
equal(col_offset, other->col_offset);
}

TVM_DLL static Span make(SourceName source, int lineno, int col_offset);

static constexpr const char* _type_key = "Span";
TVM_DECLARE_FINAL_OBJECT_INFO(SpanNode, Object);
};

class Span : public ObjectRef {
public:
TVM_DLL Span(SourceName source, int lineno, int col_offset);

TVM_DEFINE_OBJECT_REF_METHODS(Span, ObjectRef, SpanNode);
};

Expand Down
90 changes: 75 additions & 15 deletions include/tvm/te/operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -177,12 +177,22 @@ class PlaceholderOpNode : public OperationNode {
v->Visit("shape", &shape);
v->Visit("dtype", &dtype);
}
static Operation make(std::string name, Array<PrimExpr> shape, DataType dtype);

static constexpr const char* _type_key = "PlaceholderOp";
TVM_DECLARE_FINAL_OBJECT_INFO(PlaceholderOpNode, OperationNode);
};

/*!
* \brief Managed reference to PlaceholderOpNode
* \sa PlaceholderOpNode
*/
class PlaceholderOp : public Operation {
public:
TVM_DLL PlaceholderOp(std::string name, Array<PrimExpr> shape, DataType dtype);

TVM_DEFINE_OBJECT_REF_METHODS(PlaceholderOp, Operation, PlaceholderOpNode);
};

/*!
* \brief A Compute op that compute a tensor on certain domain.
* This is the base class for ComputeOp (operating on a scalar at a time) and
Expand Down Expand Up @@ -237,13 +247,23 @@ class TVM_DLL ComputeOpNode : public BaseComputeOpNode {
v->Visit("reduce_axis", &reduce_axis);
v->Visit("body", &body);
}
static Operation make(std::string name, std::string tag, Map<String, ObjectRef> attrs,
Array<IterVar> axis, Array<PrimExpr> body);

static constexpr const char* _type_key = "ComputeOp";
TVM_DECLARE_FINAL_OBJECT_INFO(ComputeOpNode, BaseComputeOpNode);
};

/*!
* \brief Managed reference to ComputeOpNode
* \sa ComputeOpNode
*/
class ComputeOp : public Operation {
public:
TVM_DLL ComputeOp(std::string name, std::string tag, Map<String, ObjectRef> attrs,
Array<IterVar> axis, Array<PrimExpr> body);

TVM_DEFINE_OBJECT_REF_METHODS(ComputeOp, Operation, ComputeOpNode);
};

/*!
* \brief A TenorCompute op that compute a tensor with an tensor intrinsic.
*/
Expand Down Expand Up @@ -285,15 +305,25 @@ class TensorComputeOpNode : public BaseComputeOpNode {
v->Visit("input_regions", &input_regions);
v->Visit("scalar_inputs", &scalar_inputs);
}
static Operation make(std::string name, std::string tag, Array<IterVar> axis,
Array<IterVar> reduce_axis, int schedulable_ndim, TensorIntrin intrin,
Array<Tensor> tensors, Array<Region> regions,
Array<PrimExpr> scalar_inputs);

static constexpr const char* _type_key = "TensorComputeOp";
TVM_DECLARE_FINAL_OBJECT_INFO(TensorComputeOpNode, BaseComputeOpNode);
};

/*!
* \brief Managed reference to TensorComputeOpNode
* \sa TensorComputeOpNode
*/
class TensorComputeOp : public Operation {
public:
TVM_DLL TensorComputeOp(std::string name, std::string tag, Array<IterVar> axis,
Array<IterVar> reduce_axis, int schedulable_ndim, TensorIntrin intrin,
Array<Tensor> tensors, Array<Region> regions,
Array<PrimExpr> scalar_inputs);

TVM_DEFINE_OBJECT_REF_METHODS(TensorComputeOp, Operation, TensorComputeOpNode);
};

/*!
* \brief Symbolic scan.
*/
Expand Down Expand Up @@ -353,14 +383,24 @@ class ScanOpNode : public OperationNode {
v->Visit("inputs", &inputs);
v->Visit("spatial_axis_", &spatial_axis_);
}
static Operation make(std::string name, std::string tag, Map<String, ObjectRef> attrs,
IterVar axis, Array<Tensor> init, Array<Tensor> update,
Array<Tensor> state_placeholder, Array<Tensor> input);

static constexpr const char* _type_key = "ScanOp";
TVM_DECLARE_FINAL_OBJECT_INFO(ScanOpNode, OperationNode);
};

/*!
* \brief Managed reference to ScanOpNode
* \sa ScanOpNode
*/
class ScanOp : public Operation {
public:
TVM_DLL ScanOp(std::string name, std::string tag, Map<String, ObjectRef> attrs, IterVar axis,
Array<Tensor> init, Array<Tensor> update, Array<Tensor> state_placeholder,
Array<Tensor> input);

TVM_DEFINE_OBJECT_REF_METHODS(ScanOp, Operation, ScanOpNode);
};

/*!
* \brief External computation that cannot be splitted.
*/
Expand Down Expand Up @@ -404,14 +444,24 @@ class ExternOpNode : public OperationNode {
v->Visit("output_placeholders", &output_placeholders);
v->Visit("body", &body);
}
TVM_DLL static Operation make(std::string name, std::string tag, Map<String, ObjectRef> attrs,
Array<Tensor> inputs, Array<Buffer> input_placeholders,
Array<Buffer> output_placeholders, Stmt body);

static constexpr const char* _type_key = "ExternOp";
TVM_DECLARE_FINAL_OBJECT_INFO(ExternOpNode, OperationNode);
};

/*!
* \brief Managed reference to ExternOpNode
* \sa ExternOpNode
*/
class ExternOp : public Operation {
public:
TVM_DLL ExternOp(std::string name, std::string tag, Map<String, ObjectRef> attrs,
Array<Tensor> inputs, Array<Buffer> input_placeholders,
Array<Buffer> output_placeholders, Stmt body);

TVM_DEFINE_OBJECT_REF_METHODS(ExternOp, Operation, ExternOpNode);
};

/*!
* \brief A computation operator that generated by hybrid script.
*/
Expand Down Expand Up @@ -459,13 +509,23 @@ class HybridOpNode : public OperationNode {
v->Visit("axis", &axis);
v->Visit("body", &body);
}
TVM_DLL static Operation make(std::string name, std::string tag, Map<String, ObjectRef> attrs,
Array<Tensor> inputs, Array<Tensor> outputs, Stmt body);

static constexpr const char* _type_key = "HybridOp";
TVM_DECLARE_FINAL_OBJECT_INFO(HybridOpNode, OperationNode);
};

/*!
* \brief Managed reference to HybridOpNode
* \sa HybridOpNode
*/
class HybridOp : public Operation {
public:
TVM_DLL HybridOp(std::string name, std::string tag, Map<String, ObjectRef> attrs,
Array<Tensor> inputs, Array<Tensor> outputs, Stmt body);

TVM_DEFINE_OBJECT_REF_METHODS(HybridOp, Operation, HybridOpNode);
};

/*!
* \brief Construct a new Var expression
* \param name_hint The name hint for the expression
Expand Down
68 changes: 51 additions & 17 deletions include/tvm/te/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,12 @@ class Schedule : public ObjectRef {
public:
Schedule() {}
explicit Schedule(ObjectPtr<Object> n) : ObjectRef(n) {}
/*!
* \brief Create a schedule for array of ops(and their dependencies).
* \param ops The ops to be scheduled.
* \return sch The created Schedule.
*/
TVM_DLL explicit Schedule(Array<Operation> ops);
/*!
* \brief Get a copy of current schedule.
* \return The copied schedule.
Expand Down Expand Up @@ -553,13 +559,6 @@ class ScheduleNode : public Object {
*/
TVM_DLL bool Contain(const Tensor& tensor) const { return Contain(tensor->op); }

/*!
* \brief Create a schedule for array of ops(and their dependencies).
* \param ops The ops to be scheduled.
* \return sch The created Schedule.
*/
TVM_DLL static Schedule make(Array<Operation> ops);

static constexpr const char* _type_key = "Schedule";
TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleNode, Object);
};
Expand All @@ -569,7 +568,7 @@ class ScheduleNode : public Object {
* \param ops The ops to be scheduled.
* \return sch The created Schedule.
*/
inline Schedule create_schedule(Array<Operation> ops) { return ScheduleNode::make(ops); }
inline Schedule create_schedule(Array<Operation> ops) { return Schedule(ops); }

/*! \brief node container for IterVar attr */
class IterVarAttrNode : public Object {
Expand Down Expand Up @@ -648,13 +647,21 @@ class SplitNode : public IterVarRelationNode {
v->Visit("nparts", &nparts);
}

static IterVarRelation make(IterVar parent, IterVar outer, IterVar inner, PrimExpr factor,
PrimExpr nparts);

static constexpr const char* _type_key = "Split";
TVM_DECLARE_FINAL_OBJECT_INFO(SplitNode, IterVarRelationNode);
};

/*!
* \brief Managed reference to SplitNode
* \sa SplitNode
*/
class Split : public IterVarRelation {
public:
TVM_DLL Split(IterVar parent, IterVar outer, IterVar inner, PrimExpr factor, PrimExpr nparts);

TVM_DEFINE_OBJECT_REF_METHODS(Split, IterVarRelation, SplitNode);
};

/*!
* \brief Fuse two domains into one domain.
*/
Expand All @@ -673,12 +680,21 @@ class FuseNode : public IterVarRelationNode {
v->Visit("fused", &fused);
}

static IterVarRelation make(IterVar outer, IterVar inner, IterVar fused);

static constexpr const char* _type_key = "Fuse";
TVM_DECLARE_FINAL_OBJECT_INFO(FuseNode, IterVarRelationNode);
};

/*!
* \brief Managed reference to FuseNode
* \sa FuseNode
*/
class Fuse : public IterVarRelation {
public:
TVM_DLL Fuse(IterVar outer, IterVar inner, IterVar fused);

TVM_DEFINE_OBJECT_REF_METHODS(Fuse, IterVarRelation, FuseNode);
};

/*!
* \brief Rebase the iteration to make min to be 0.
* This is useful to normalize the Schedule
Expand All @@ -696,12 +712,21 @@ class RebaseNode : public IterVarRelationNode {
v->Visit("rebased", &rebased);
}

static IterVarRelation make(IterVar parent, IterVar rebased);

static constexpr const char* _type_key = "Rebase";
TVM_DECLARE_FINAL_OBJECT_INFO(RebaseNode, IterVarRelationNode);
};

/*!
* \brief Managed reference to RebaseNode
* \sa RebaseNode
*/
class Rebase : public IterVarRelation {
public:
TVM_DLL Rebase(IterVar parent, IterVar rebased);

TVM_DEFINE_OBJECT_REF_METHODS(Rebase, IterVarRelation, RebaseNode);
};

/*!
* \brief Singleton iterator [0, 1)
*/
Expand All @@ -712,12 +737,21 @@ class SingletonNode : public IterVarRelationNode {

void VisitAttrs(AttrVisitor* v) { v->Visit("iter", &iter); }

static IterVarRelation make(IterVar iter);

static constexpr const char* _type_key = "Singleton";
TVM_DECLARE_FINAL_OBJECT_INFO(SingletonNode, IterVarRelationNode);
};

/*!
* \brief Managed reference to SingletonNode
* \sa SingletonNode
*/
class Singleton : public IterVarRelation {
public:
TVM_DLL explicit Singleton(IterVar iter);

TVM_DEFINE_OBJECT_REF_METHODS(Singleton, IterVarRelation, SingletonNode);
};

/*! \brief Container for specialization conditions. */
class SpecializedConditionNode : public Object {
public:
Expand Down
Loading