Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Auto TensorCore CodeGen #4234

Merged
merged 21 commits into from
Nov 9, 2019
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/tvm/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -1248,6 +1248,8 @@ constexpr const char* reduce_scope = "reduce_scope";
constexpr const char* pragma_scope_prefix = "pragma_";
/*! \brief Import llvm source or file into the final code gen module */
constexpr const char* pragma_import_llvm = "pragma_import_llvm";
/*! \brief Try to modify the AST to support Tensor Core */
constexpr const char* pragma_tensor_core = "pragma_tensor_core";
/*!
* \brief Mark of prefetch scope, value=offset,
* run prefetch of Tensor on the current loop scope
Expand Down
14 changes: 14 additions & 0 deletions include/tvm/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,20 @@ Stmt StorageFlatten(Stmt stmt,
Map<Tensor, Buffer> extern_buffer,
int cache_line_size,
bool create_bound_attribute = false);

/*!
 * \brief Try to modify the AST to support TensorCore.
 *
 * \param stmt The stmt to be transformed.
 * \param schedule The original schedule.
 * \param extern_buffer Map that specifies the external
 * buffer assignment of inputs and outputs.
 * \return The transformed stmt.
 */
Stmt RewriteForTensorCore(Stmt stmt,
Schedule schedule,
Map<Tensor, Buffer> extern_buffer);

/*!
* \brief Verify if there is any argument bound to compact buffer.
*
Expand Down
1 change: 1 addition & 0 deletions python/tvm/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,7 @@ def lower(sch,
binds, arg_list = get_binds(args, compact, binds)

# Phase 1
stmt = ir_pass.RewriteForTensorCore(stmt, sch, binds)
stmt = ir_pass.StorageFlatten(stmt, binds, 64, cfg.instrument_bound_checkers)
stmt = ir_pass.CanonicalSimplify(stmt)
for f in lower_phase1:
Expand Down
3 changes: 2 additions & 1 deletion python/tvm/contrib/nvcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def compile_cuda(code,
"Compilation error: empty result is generated")
return data

@register_func("tvm_find_cuda_path")
def find_cuda_path():
"""Utility function to find cuda path

Expand All @@ -125,7 +126,7 @@ def find_cuda_path():
return cuda_path
raise RuntimeError("Cannot find cuda path")


@register_func("tvm_get_cuda_version")
def get_cuda_version(cuda_path):
"""Utility function to get cuda version

Expand Down
7 changes: 7 additions & 0 deletions src/api/api_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,13 @@ TVM_REGISTER_API("ir_pass.StorageFlatten")
}
});

// Register RewriteForTensorCore as a typed PackedFunc.
// Using set_body_typed (instead of an untyped set_body that inspected
// args.size()) makes the 3-argument arity and the argument types part of
// the registration: a wrong call now raises an error instead of silently
// leaving the return value unset, as the previous untyped body did when
// args.size() != 3.
TVM_REGISTER_API("ir_pass.RewriteForTensorCore")
.set_body_typed<Stmt(const Stmt&, const Schedule&, const Map<Tensor, Buffer>&)>
  ([](const Stmt& stmt,
      const Schedule& schedule,
      const Map<Tensor, Buffer>& extern_buffer) {
      // Delegate to the IR pass declared in include/tvm/ir_pass.h.
      return RewriteForTensorCore(stmt, schedule, extern_buffer);
  });

TVM_REGISTER_API("ir_pass.AttrsEqual")
.set_body_typed<bool(const NodeRef&, const NodeRef&)>([](const NodeRef& lhs, const NodeRef& rhs) {
return AttrsEqual()(lhs, rhs);
Expand Down
Loading