Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SkipVectorize IR pass #3222

Merged
merged 1 commit into from
May 21, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/python/dev.rst
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ tvm.ir_pass
tvm.ir_pass.CanonicalSimplify
tvm.ir_pass.StorageFlatten
tvm.ir_pass.VectorizeLoop
tvm.ir_pass.SkipVectorize
tvm.ir_pass.UnrollLoop
tvm.ir_pass.ThreadSync
tvm.ir_pass.StorageRewrite
Expand Down
4 changes: 4 additions & 0 deletions include/tvm/build_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,9 @@ class BuildConfigNode : public Node {
/*! \brief Whether to disable select rewriting. */
bool disable_select_rewriting = false;

/*! \brief Whether to disable loop vectorization. */
bool disable_vectorize = false;

void VisitAttrs(AttrVisitor* v) final {
v->Visit("data_alignment", &data_alignment);
v->Visit("offset_factor", &offset_factor);
Expand All @@ -260,6 +263,7 @@ class BuildConfigNode : public Node {
v->Visit("dump_pass_ir", &dump_pass_ir);
v->Visit("instrument_bound_checkers", &instrument_bound_checkers);
v->Visit("disable_select_rewriting", &disable_select_rewriting);
v->Visit("disable_vectorize", &disable_vectorize);
}

static constexpr const char* _type_key = "BuildConfig";
Expand Down
35 changes: 21 additions & 14 deletions include/tvm/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -250,35 +250,42 @@ Stmt UnrollLoop(Stmt stmt,

/*!
* \brief vectorize the constant loops
* \param stmt The statment to be vectorized.
* \param stmt The statement to be vectorized.
* \return Transformed stmt.
*/
Stmt VectorizeLoop(Stmt stmt);

/*!
 * \brief Convert vectorized loops into serial loops.
 *
 *  Each For loop whose for_type is ForType::Vectorized is re-created as a
 *  ForType::Serial loop with the same loop variable, min, extent, device_api
 *  and body. The build pipeline calls this in place of VectorizeLoop when the
 *  disable_vectorize build option is set.
 *
 * \param stmt The statement to skip vectorization on.
 * \return Transformed stmt.
 */
Stmt SkipVectorize(Stmt stmt);

/*!
* \brief instruments bound checkers.
* \param stmt The statment to be instrumented.
* \return Instrumented Stmt.
* \param stmt The statement to be instrumented.
* \return Instrumented stmt.
*/
Stmt InstrumentBoundCheckers(Stmt stmt);

/*!
* \brief Inject virtual thread loops into stmt.
* \param stmt The statment to be transformed.
* \param stmt The statement to be transformed.
* \return Transformed stmt.
*/
Stmt InjectVirtualThread(Stmt stmt);

/*!
* \brief Inject prefetch instructions into stmt.
* \param stmt The statment to be transformed.
* \param stmt The statement to be transformed.
* \return Transformed stmt.
*/
Stmt InjectPrefetch(Stmt stmt);

/*!
* \brief Inject double buffer into stmt.
* \param stmt The statment to be transformed.
* \param stmt The statement to be transformed.
* \param split_loop Loop splitting factor.
* \return Transformed stmt.
*/
Expand All @@ -287,7 +294,7 @@ Stmt InjectDoubleBuffer(Stmt stmt, int split_loop);
/*!
* \brief Inject copy intrinsics with optional pad.
*
* \param stmt The statment to be transformed.
* \param stmt The statement to be transformed.
* \param pragma_key The pragma key for hint of copy.
* \param fintrin The function with signature
*
Expand All @@ -308,7 +315,7 @@ Stmt InjectCopyIntrin(Stmt stmt,
* Trying to share space between allocations to make
* a static allocation plan when possible.
*
* \param stmt The stmt to be trasnformed
* \param stmt The stmt to be transformed
* \return Transformed stmt.
*/
Stmt StorageRewrite(Stmt stmt);
Expand All @@ -324,23 +331,23 @@ Stmt LoopPartition(Stmt stmt, bool split_const_loop);
/*!
* \brief Detect and insert sync points to co-processor.
*
* \param stmt The stmt to be trasnformed
* \param stmt The stmt to be transformed
* \return Transformed stmt.
*/
Stmt CoProcSync(Stmt stmt);

/*!
* \brief Lift common attrs with attr_key to outer scope.
*
* \param stmt The stmt to be trasnformed
* \param stmt The stmt to be transformed
* \param attr_key The attribute key to be checked.
* \return Transformed stmt.
*/
Stmt LiftAttrScope(Stmt stmt, std::string attr_key);

/*!
* \brief Detect and rewrite unsafe select that contains memory access.
* \param stmt The statment to be rewritten.
* \param stmt The statement to be rewritten.
* \return Transformed stmt.
*/
Stmt RewriteUnsafeSelect(Stmt stmt);
Expand All @@ -349,7 +356,7 @@ Stmt RewriteUnsafeSelect(Stmt stmt);
* \brief Lower attached storage access information.
* Do this pass after all storage access analysis finish.
*
* \param stmt The stmt to be trasnformed
* \param stmt The stmt to be transformed
* \return Transformed stmt.
*/
Stmt LowerStorageAccessInfo(Stmt stmt);
Expand All @@ -358,7 +365,7 @@ Stmt LowerStorageAccessInfo(Stmt stmt);
* \brief Decorate the stmt with a device scope, this is helpful for
* hardware accelerator without thread blocks.
*
* \param stmt The stmt to be trasnformed
* \param stmt The stmt to be transformed
* \return Transformed stmt.
*/
Stmt DecorateDeviceScope(Stmt stmt);
Expand All @@ -381,7 +388,7 @@ Stmt DecorateDeviceScope(Stmt stmt);
* \return a LoweredFunc with the specified signature.
*
* \note
* The function signiture have two cases
* The function signature has two cases
*
* let num_packed_args = len(api_args) - num_unpacked_args;
*
Expand Down
8 changes: 6 additions & 2 deletions python/tvm/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,8 @@ class BuildConfig(NodeBase):
"double_buffer_split_loop": 1,
"dump_pass_ir": False,
"instrument_bound_checkers": False,
"disable_select_rewriting": False
"disable_select_rewriting": False,
"disable_vectorize": False
}
_dump_ir = DumpIR()

Expand Down Expand Up @@ -384,7 +385,10 @@ def lower(sch,
# Phase 2
if not simple_mode:
stmt = ir_pass.LoopPartition(stmt, cfg.partition_const_loop)
stmt = ir_pass.VectorizeLoop(stmt)
if cfg.disable_vectorize:
stmt = ir_pass.SkipVectorize(stmt)
else:
stmt = ir_pass.VectorizeLoop(stmt)
stmt = ir_pass.InjectVirtualThread(stmt)
stmt = ir_pass.InjectDoubleBuffer(stmt, cfg.double_buffer_split_loop)
stmt = ir_pass.StorageRewrite(stmt)
Expand Down
7 changes: 6 additions & 1 deletion src/codegen/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,11 @@ Stmt BuildStmt(Schedule sch,
if (loop_partition) {
stmt = ir::LoopPartition(stmt, config->partition_const_loop);
}
stmt = ir::VectorizeLoop(stmt);
if (config->disable_vectorize) {
stmt = ir::SkipVectorize(stmt);
} else {
stmt = ir::VectorizeLoop(stmt);
}
stmt = ir::InjectVirtualThread(stmt);
stmt = ir::InjectDoubleBuffer(stmt, config->double_buffer_split_loop);
stmt = ir::StorageRewrite(stmt);
Expand Down Expand Up @@ -642,6 +646,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p->stream << "dump_pass_ir=" << op->dump_pass_ir << ", ";
p->stream << "instrument_bound_checkers=" << op->instrument_bound_checkers << ", ";
// Every field except the last is followed by ", " so the printed
// key=value pairs stay separated; disable_vectorize is the final field.
p->stream << "disable_select_rewriting=" << op->disable_select_rewriting << ", ";
p->stream << "disable_vectorize=" << op->disable_vectorize;
p->stream << ")";
});

Expand Down
22 changes: 20 additions & 2 deletions src/pass/vectorize_loop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand Down Expand Up @@ -519,5 +519,23 @@ Stmt VectorizeLoop(Stmt stmt) {
return LoopVectorizer().Mutate(stmt);
}

/*!
 * \brief Mutator that turns every ForType::Vectorized loop into a
 *  ForType::Serial loop, leaving all other loop fields untouched.
 */
class VectorizeSkipper : public IRMutator {
 public:
  Stmt Mutate_(const For* op, const Stmt& s) final {
    // Delegate to the base mutator first, then inspect the loop it returns.
    // NOTE(review): presumably the base call rewrites nested statements,
    // so inner vectorized loops are handled before this one - confirm.
    Stmt mutated = IRMutator::Mutate_(op, s);
    const For* loop = mutated.as<For>();

    // Anything that is not a vectorized loop passes through unchanged.
    if (loop->for_type != ForType::Vectorized) {
      return mutated;
    }

    // Rebuild the loop with identical fields, downgraded to a serial loop.
    return For::make(loop->loop_var, loop->min, loop->extent,
                     ForType::Serial, loop->device_api, loop->body);
  }
};

/*!
 * \brief Run VectorizeSkipper over the statement and return the result.
 * \param stmt The statement to skip vectorization on.
 * \return Transformed stmt.
 */
Stmt SkipVectorize(Stmt stmt) {
  VectorizeSkipper skipper;
  return skipper.Mutate(stmt);
}

} // namespace ir
} // namespace tvm