Skip to content

Commit

Permalink
[REFACTOR] top->te (apache#4759)
Browse files Browse the repository at this point in the history
Bring up namespace te -- Tensor expression language DSL.
  • Loading branch information
tqchen authored and alexwong committed Feb 26, 2020
1 parent b93ade6 commit 38fbe54
Show file tree
Hide file tree
Showing 123 changed files with 923 additions and 939 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ file(GLOB_RECURSE COMPILER_SRCS
src/node/*.cc
src/ir/*.cc
src/arith/*.cc
src/top/*.cc
src/te/*.cc
src/autotvm/*.cc
src/tir/*.cc
src/driver/*.cc
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/arith/bound.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

namespace tvm {
// forward delcare Tensor
namespace top {
namespace te {
class Tensor;
}
namespace arith {
Expand Down Expand Up @@ -84,7 +84,7 @@ IntSet DeduceBound(PrimExpr v, PrimExpr cond,
* \return The domain that covers all the calls or provides within the given statement.
*/
Domain DomainTouched(Stmt body,
const top::Tensor &tensor,
const te::Tensor &tensor,
bool consider_calls,
bool consider_provides);

Expand Down
8 changes: 4 additions & 4 deletions include/tvm/driver/driver.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
#include <tvm/runtime/packed_func.h>
#include <tvm/target/target.h>
#include <tvm/support/with.h>
#include <tvm/top/schedule_pass.h>
#include <tvm/te/schedule_pass.h>
#include <tvm/tir/lowered_func.h>

#include <string>
Expand All @@ -52,10 +52,10 @@ namespace tvm {
* \return The lowered function.
*/
TVM_DLL Array<tir::LoweredFunc> lower(
top::Schedule sch,
const Array<top::Tensor>& args,
te::Schedule sch,
const Array<te::Tensor>& args,
const std::string& name,
const std::unordered_map<top::Tensor, tir::Buffer>& binds,
const std::unordered_map<te::Tensor, tir::Buffer>& binds,
const BuildConfig& config);
/*!
* \brief Split host/device function and running necessary pass before build
Expand Down
20 changes: 10 additions & 10 deletions include/tvm/relay/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
#ifndef TVM_RELAY_OP_ATTR_TYPES_H_
#define TVM_RELAY_OP_ATTR_TYPES_H_

#include <tvm/top/tensor.h>
#include <tvm/top/schedule.h>
#include <tvm/te/tensor.h>
#include <tvm/te/schedule.h>
#include <tvm/relay/type.h>
#include <tvm/relay/expr.h>
#include <tvm/target/target.h>
Expand Down Expand Up @@ -104,8 +104,8 @@ using TShapeDataDependant = bool;
* \return The output compute description of the operator.
*/
using FTVMCompute = runtime::TypedPackedFunc<
Array<top::Tensor>(const Attrs& attrs,
const Array<top::Tensor>& inputs,
Array<te::Tensor>(const Attrs& attrs,
const Array<te::Tensor>& inputs,
const Type& out_type,
const Target& target)>;

Expand All @@ -119,8 +119,8 @@ using FTVMCompute = runtime::TypedPackedFunc<
* \return schedule The computation schedule.
*/
using FTVMSchedule = runtime::TypedPackedFunc<
top::Schedule(const Attrs& attrs,
const Array<top::Tensor>& outs,
te::Schedule(const Attrs& attrs,
const Array<te::Tensor>& outs,
const Target& target)>;

/*!
Expand All @@ -136,7 +136,7 @@ using FTVMSchedule = runtime::TypedPackedFunc<
using FTVMAlterOpLayout = runtime::TypedPackedFunc<
Expr(const Attrs& attrs,
const Array<Expr>& args,
const Array<top::Tensor>& tinfos)>;
const Array<te::Tensor>& tinfos)>;

/*!
* \brief Convert the layout of operators or replace the
Expand All @@ -152,7 +152,7 @@ using FTVMAlterOpLayout = runtime::TypedPackedFunc<
using FTVMConvertOpLayout = runtime::TypedPackedFunc<
Expr(const Attrs& attrs,
const Array<Expr>& args,
const Array<top::Tensor>& tinfos,
const Array<te::Tensor>& tinfos,
const std::string& desired_layout)>;
/*!
* \brief Legalizes an expression with another expression. This function will be
Expand Down Expand Up @@ -211,8 +211,8 @@ enum AnyCodegenStrategy {
using Shape = Array<IndexExpr>;

using FShapeFunc = runtime::TypedPackedFunc<
Array<top::Tensor>(const Attrs& attrs,
const Array<top::Tensor>& inputs,
Array<te::Tensor>(const Attrs& attrs,
const Array<te::Tensor>& inputs,
const Array<IndexExpr>& out_ndims)>;

} // namespace relay
Expand Down
19 changes: 9 additions & 10 deletions include/tvm/top/operation.h → include/tvm/te/operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@
*/

/*!
* \file tvm/top/operation.h
* \file tvm/te/operation.h
* \brief Operation node can generate one or multiple Tensors
*/
#ifndef TVM_TOP_OPERATION_H_
#define TVM_TOP_OPERATION_H_
#ifndef TVM_TE_OPERATION_H_
#define TVM_TE_OPERATION_H_

#include <tvm/arith/analyzer.h>
#include <tvm/top/tensor.h>
#include <tvm/top/schedule.h>
#include <tvm/te/tensor.h>
#include <tvm/te/schedule.h>

#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
Expand All @@ -36,10 +36,9 @@
#include <vector>
#include <unordered_map>



namespace tvm {
namespace top {
/*! \brief Tensor expression language DSL. */
namespace te {

/*!
* \brief Temporary data structure to store union
Expand Down Expand Up @@ -679,6 +678,6 @@ inline Tensor compute(Array<PrimExpr> shape,
inline const OperationNode* Operation::operator->() const {
return static_cast<const OperationNode*>(get());
}
} // namespace top
} // namespace te
} // namespace tvm
#endif // TVM_TOP_OPERATION_H_
#endif // TVM_TE_OPERATION_H_
18 changes: 8 additions & 10 deletions include/tvm/top/schedule.h → include/tvm/te/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,22 @@
*/

/*!
* \file tvm/top/schedule.h
* \file tvm/te/schedule.h
* \brief Define a schedule.
*/
// Acknowledgement: Many schedule primitives originate from Halide and Loopy.
#ifndef TVM_TOP_SCHEDULE_H_
#define TVM_TOP_SCHEDULE_H_
#ifndef TVM_TE_SCHEDULE_H_
#define TVM_TE_SCHEDULE_H_

#include <tvm/tir/expr.h>
#include <tvm/top/tensor.h>
#include <tvm/top/tensor_intrin.h>

#include <tvm/te/tensor.h>
#include <tvm/te/tensor_intrin.h>

#include <string>
#include <unordered_map>


namespace tvm {
namespace top {
namespace te {
// Node container for Stage
class StageNode;
// Node container for Schedule
Expand Down Expand Up @@ -767,6 +765,6 @@ inline const IterVarRelationNode* IterVarRelation::operator->() const {
inline const IterVarAttrNode* IterVarAttr::operator->() const {
return static_cast<const IterVarAttrNode*>(get());
}
} // namespace top
} // namespace te
} // namespace tvm
#endif // TVM_TOP_SCHEDULE_H_
#endif // TVM_TE_SCHEDULE_H_
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,20 @@
*/

/*!
* \file tvm/top/schedule_pass.h
* \file tvm/te/schedule_pass.h
* \brief Collection of Schedule pass functions.
*
* These passes works on the schedule hyper-graph
* and infers information such as bounds, check conditions
* read/write dependencies between the IterVar
*/
#ifndef TVM_TOP_SCHEDULE_PASS_H_
#define TVM_TOP_SCHEDULE_PASS_H_
#ifndef TVM_TE_SCHEDULE_PASS_H_
#define TVM_TE_SCHEDULE_PASS_H_

#include <tvm/top/schedule.h>
#include <tvm/te/schedule.h>

namespace tvm {
namespace top {
namespace te {

/*!
* \brief Infer the bound of all iteration variables relates to the schedule.
Expand Down Expand Up @@ -71,6 +71,6 @@ void AutoInlineElemWise(Schedule sch);
*/
TVM_DLL void AutoInlineInjective(Schedule sch);

} // namespace top
} // namespace te
} // namespace tvm
#endif // TVM_TOP_SCHEDULE_PASS_H_
#endif // TVM_TE_SCHEDULE_PASS_H_
20 changes: 9 additions & 11 deletions include/tvm/top/tensor.h → include/tvm/te/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@
*/

/*!
* \file tvm/top/tensor.h
* \file tvm/te/tensor.h
* \brief Dataflow tensor object
*/
#ifndef TVM_TOP_TENSOR_H_
#define TVM_TOP_TENSOR_H_
#ifndef TVM_TE_TENSOR_H_
#define TVM_TE_TENSOR_H_

#include <tvm/node/container.h>
#include <tvm/arith/bound.h>
Expand All @@ -34,10 +34,8 @@
#include <utility>
#include <type_traits>



namespace tvm {
namespace top {
namespace te {

using arith::IntSet;
using namespace tvm::tir;
Expand Down Expand Up @@ -251,17 +249,17 @@ DEFINE_OVERLOAD_SLICE_BINARY_OP(<<);
DEFINE_OVERLOAD_SLICE_BINARY_OP(>); // NOLINT(*)
DEFINE_OVERLOAD_SLICE_BINARY_OP(<); // NOLINT(*)

} // namespace top
} // namespace te
} // namespace tvm

namespace std {
template <>
struct hash<::tvm::top::Operation> : public ::tvm::ObjectHash {
struct hash<::tvm::te::Operation> : public ::tvm::ObjectHash {
};

template <>
struct hash<::tvm::top::Tensor> {
std::size_t operator()(const ::tvm::top::Tensor& k) const {
struct hash<::tvm::te::Tensor> {
std::size_t operator()(const ::tvm::te::Tensor& k) const {
::tvm::ObjectHash hasher;
if (k.defined() && k->op.defined()) {
return hasher(k->op);
Expand All @@ -271,4 +269,4 @@ struct hash<::tvm::top::Tensor> {
}
};
} // namespace std
#endif // TVM_TOP_TENSOR_H_
#endif // TVM_TE_TENSOR_H_
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,19 @@
*/

/*!
* \file tvm/top/tensor_intrin.h
* \file tvm/te/tensor_intrin.h
* \brief Tensor intrinsic operations.
*/
#ifndef TVM_TOP_TENSOR_INTRIN_H_
#define TVM_TOP_TENSOR_INTRIN_H_
#ifndef TVM_TE_TENSOR_INTRIN_H_
#define TVM_TE_TENSOR_INTRIN_H_

#include <tvm/top/tensor.h>
#include <tvm/te/tensor.h>
#include <tvm/tir/buffer.h>

#include <string>


namespace tvm {
namespace top {
namespace te {

// Internal node container of tensor intrinsics.
class TensorIntrinNode;
Expand Down Expand Up @@ -176,6 +175,6 @@ inline const TensorIntrinCallNode* TensorIntrinCall::operator->() const {
return static_cast<const TensorIntrinCallNode*>(get());
}

} // namespace top
} // namespace te
} // namespace tvm
#endif // TVM_TOP_TENSOR_INTRIN_H_
#endif // TVM_TE_TENSOR_INTRIN_H_
8 changes: 4 additions & 4 deletions include/tvm/tir/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
#ifndef TVM_TIR_IR_PASS_H_
#define TVM_TIR_IR_PASS_H_

#include <tvm/top/schedule.h>
#include <tvm/te/schedule.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/lowered_func.h>
Expand Down Expand Up @@ -205,7 +205,7 @@ Stmt Inline(Stmt stmt,
* \return Transformed stmt.
*/
Stmt StorageFlatten(Stmt stmt,
Map<top::Tensor, Buffer> extern_buffer,
Map<te::Tensor, Buffer> extern_buffer,
int cache_line_size,
bool create_bound_attribute = false);

Expand All @@ -219,8 +219,8 @@ Stmt StorageFlatten(Stmt stmt,
* \return Transformed stmt.
*/
Stmt RewriteForTensorCore(Stmt stmt,
top::Schedule schedule,
Map<top::Tensor, Buffer> extern_buffer);
te::Schedule schedule,
Map<te::Tensor, Buffer> extern_buffer);

/*!
* \brief Verify if there is any argument bound to compact buffer.
Expand Down
2 changes: 1 addition & 1 deletion src/api/api_arith.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
#include <tvm/tir/expr.h>
#include <tvm/runtime/registry.h>

#include <tvm/top/tensor.h>
#include <tvm/te/tensor.h>

namespace tvm {
namespace arith {
Expand Down
2 changes: 1 addition & 1 deletion src/api/api_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
*/
#include <dmlc/memory_io.h>
#include <tvm/tir/expr.h>
#include <tvm/top/tensor.h>
#include <tvm/te/tensor.h>
#include <tvm/runtime/registry.h>
#include <tvm/node/serialization.h>

Expand Down
10 changes: 5 additions & 5 deletions src/api/api_lang.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@
*/
#include <tvm/tir/expr.h>
#include <tvm/tir/expr.h>
#include <tvm/top/tensor.h>
#include <tvm/top/operation.h>
#include <tvm/te/tensor.h>
#include <tvm/te/operation.h>
#include <tvm/tir/buffer.h>
#include <tvm/top/schedule.h>
#include <tvm/te/schedule.h>
#include <tvm/runtime/registry.h>

#include <tvm/driver/driver.h>
Expand Down Expand Up @@ -276,7 +276,7 @@ TVM_REGISTER_GLOBAL("_BijectiveLayoutBackwardShape")
.set_body_method(&BijectiveLayout::BackwardShape);
} // namespace tir

namespace top {
namespace te {
TVM_REGISTER_GLOBAL("_Tensor")
.set_body_typed(TensorNode::make);

Expand Down Expand Up @@ -444,7 +444,7 @@ TVM_REGISTER_GLOBAL("_ScheduleCacheWrite")

TVM_REGISTER_GLOBAL("_ScheduleRFactor")
.set_body_method(&Schedule::rfactor);
} // namespace top
} // namespace te

TVM_REGISTER_GLOBAL("_CommReducerCombine")
.set_body_method<tir::CommReducer>(&tir::CommReducerNode::operator());
Expand Down
Loading

0 comments on commit 38fbe54

Please sign in to comment.