-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Relay][External Codegen] Support data types for CSourceModuleCodegen args and output #4934
Changes from 5 commits
1294b36
ccd3b6d
92b6b1a
ee7753c
43d682c
5e9ecf1
74d7a7d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -35,6 +35,13 @@ namespace tvm { | |
namespace relay { | ||
namespace contrib { | ||
|
||
struct Output { | ||
std::string name; | ||
std::string dtype; | ||
int size; | ||
bool need_copy; | ||
}; | ||
|
||
class CSourceModuleCodegenBase { | ||
public: | ||
CSourceModuleCodegenBase() = default; | ||
|
@@ -98,7 +105,7 @@ class CodegenCBase { | |
* \brief Gerenate C code for the external function. | ||
* | ||
* \param func_name The name of the external function. | ||
* \param arg_cnt The expected number of arguments. | ||
* \param args arguments to the external function. | ||
* | ||
* \code | ||
* | ||
|
@@ -116,29 +123,30 @@ class CodegenCBase { | |
* | ||
* \endcode | ||
*/ | ||
void GenerateBackendCFunc(const std::string& func_name, int arg_cnt) { | ||
void GenerateBackendCFunc(const std::string& func_name, Array<Var> args, const Output& out) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. const Array& args There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Updated |
||
// Print signature | ||
code_stream_ << "\n"; | ||
code_stream_ << "extern \"C\" int " << func_name << "_wrapper_("; | ||
for (int i = 0; i < arg_cnt - 1; i++) { | ||
for (size_t i = 0; i < args.size(); i++) { | ||
code_stream_ << "DLTensor* arg" << i << ",\n"; | ||
code_stream_ << "\t"; | ||
} | ||
if (arg_cnt > 0) { | ||
code_stream_ << "DLTensor* arg" << arg_cnt - 1 << ") {\n"; | ||
if (args.size() > 0) { | ||
code_stream_ << "DLTensor* arg" << args.size() << ") {\n"; | ||
} | ||
|
||
EnterScope(); | ||
|
||
// Generate the internal call. | ||
PrintIndents(); | ||
code_stream_ << func_name << "_("; | ||
for (int i = 0; i < arg_cnt - 1; i++) { | ||
code_stream_ << "static_cast<float*>(arg" << i << "->data),\n"; | ||
for (size_t i = 0; i < args.size(); i++) { | ||
const auto& dtype_str = GetDtypeString(args[i]); | ||
code_stream_ << "static_cast<" << dtype_str << "*>(arg" << i << "->data),\n"; | ||
PrintIndents(); | ||
} | ||
if (arg_cnt > 0) { | ||
code_stream_ << "static_cast<float*>(arg" << arg_cnt - 1 << "->data)"; | ||
if (args.size() > 0) { | ||
code_stream_ << "static_cast<" << out.dtype << "*>(arg" << args.size() << "->data)"; | ||
} | ||
code_stream_ << ");\n"; | ||
PrintIndents(); | ||
|
@@ -207,17 +215,20 @@ class CodegenCBase { | |
* | ||
* \return The emitted code string. | ||
*/ | ||
std::string JitImpl(std::string ext_func_id, std::vector<std::string> args, | ||
std::string JitImpl(std::string ext_func_id, Array<Var> args, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. please also use const reference here for all parameters. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Updated |
||
std::vector<std::string> buf_decl, std::vector<std::string> body, | ||
std::vector<std::pair<std::string, int>> out) { | ||
std::vector<Output> out) { | ||
// Create the signature. For example, it could be: | ||
// extern "C" void dnnl_0_(float* input0, float* input1, float* out, int M, int N) {} | ||
code_stream_ << "extern \"C\" void " << ext_func_id << "_("; | ||
|
||
CHECK_EQ(out.size(), 1U) << "Internal error: only single output is support."; | ||
|
||
for (const auto& arg : args) { | ||
code_stream_ << "float* " << arg << ", "; | ||
const auto& dtype_str = GetDtypeString(arg); | ||
code_stream_ << dtype_str << "* " << arg->name_hint() << ", "; | ||
} | ||
code_stream_ << "float* out) {\n"; | ||
code_stream_ << out[0].dtype << "* out) {\n"; | ||
this->EnterScope(); | ||
|
||
// Function body | ||
|
@@ -232,24 +243,60 @@ class CodegenCBase { | |
} | ||
|
||
// Copy output | ||
CHECK_EQ(out.size(), 1U) << "Internal error: only single output is support."; | ||
this->PrintIndents(); | ||
code_stream_ << "std::memcpy(out, " << out[0].first << ", 4 * " << out[0].second << ");\n"; | ||
|
||
// Free buffers | ||
for (size_t i = 0; i < buf_decl.size(); i++) { | ||
if (out[0].need_copy) { | ||
this->PrintIndents(); | ||
code_stream_ << "std::free(buf_" << i << ");\n"; | ||
code_stream_ << "std::memcpy(out, " << out[0].name << ", 4 * " << out[0].size << ");\n"; | ||
|
||
// Free buffers | ||
for (size_t i = 0; i < buf_decl.size(); i++) { | ||
this->PrintIndents(); | ||
code_stream_ << "std::free(buf_" << i << ");\n"; | ||
} | ||
} | ||
|
||
this->ExitScope(); | ||
code_stream_ << "}\n"; | ||
|
||
// Create the wrapper to call the ext_func | ||
this->GenerateBackendCFunc(ext_func_id, args.size() + 1 /* output */); | ||
this->GenerateBackendCFunc(ext_func_id, args, out[0]); | ||
return code_stream_.str(); | ||
} | ||
|
||
/*! | ||
* \brief Returns dtype string | ||
* | ||
* \param var Var to get the dtype of | ||
* | ||
* \return The dtype string. | ||
*/ | ||
std::string GetDtypeString(Var var) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. const Var& There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Updated |
||
auto ttype = var->checked_type().as<TensorTypeNode>(); | ||
CHECK(ttype) << "Expect TensorTypeNode"; | ||
return GetDtypeString(ttype); | ||
} | ||
|
||
/*! | ||
* \brief Returns dtype string | ||
* | ||
* \param ttype TensorTypeNode* to get the dtype of | ||
* | ||
* \return The dtype string. | ||
*/ | ||
std::string GetDtypeString(const TensorTypeNode* ttype) { | ||
std::string dtype; | ||
if (runtime::TypeMatch(ttype->dtype, kDLFloat, 32)) { | ||
dtype = "float"; | ||
} else if (runtime::TypeMatch(ttype->dtype, kDLInt, 32)) { | ||
dtype = "int"; | ||
} else if (runtime::TypeMatch(ttype->dtype, kDLInt, 64)) { | ||
dtype = "int64_t"; | ||
} else { | ||
LOG(FATAL) << "Unsupported dtype " << ttype->dtype; | ||
} | ||
|
||
return dtype; | ||
} | ||
|
||
/*! \brief The external function source code stream. */ | ||
std::ostringstream code_stream_; | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
const auto* type_node
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated