Skip to content

Commit

Permalink
codegen fix wip
Browse files Browse the repository at this point in the history
  • Loading branch information
chhwang committed Sep 17, 2024
1 parent 19d56dd commit d891914
Showing 1 changed file with 84 additions and 44 deletions.
128 changes: 84 additions & 44 deletions ark/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class CodeGenerator::Impl {
~Impl() = default;

private:
std::string def_op(const Json &op_json, size_t task_id, size_t op_idx);
std::pair<std::string, size_t> def_op(const Json &op_json);

std::string def_task(const Json &task_json);

Expand All @@ -80,6 +80,8 @@ class CodeGenerator::Impl {
protected:
friend class CodeGenerator;

std::set<size_t> op_hashes_;
std::set<size_t> task_hashes_;
std::map<size_t, size_t> buffer_id_to_offset_;
std::set<size_t> extra_buffer_ids_;
std::string name_;
Expand Down Expand Up @@ -183,7 +185,10 @@ CodeGenerator::Impl::Impl(const PlanJson &plan,
const std::string &template_path =
ark_root + "/include/kernels/kernel_template.in";
if (!is_file(template_path)) {
ERR(InternalError, "kernel template file not found: ", template_path);
ERR(InvalidUsageError,

Check warning on line 188 in ark/codegen.cpp

View check run for this annotation

Codecov / codecov/patch

ark/codegen.cpp#L188

Added line #L188 was not covered by tests
"kernel template file not found: ", template_path,
". Please make sure the ARK_ROOT environment variable is set "
"correctly.");
}

// Generate the global arguments
Expand Down Expand Up @@ -224,92 +229,126 @@ CodeGenerator::Impl::Impl(const PlanJson &plan,
code_ = replace(template_code, replacements);
}

std::string CodeGenerator::Impl::def_op(const Json &op_json, size_t task_id,
size_t op_idx) {
std::pair<std::string, size_t> CodeGenerator::Impl::def_op(
const Json &op_json) {
auto op = ModelOp::deserialize(op_json);
auto impl_name = op->impl_name(op_json["Config"]);
auto impl_args = op->impl_args(op_json["Config"]);
std::stringstream ss;
ss << "__forceinline__ __device__ void t" << task_id << "_o" << op_idx
<< "(";
std::stringstream ss_desc;
size_t arg_idx = 0;
for (auto &arg : impl_args) {
if (arg.type_name() == "TENSOR") {
auto tns = arg.value<ModelTensorRef>();
ss << tns->data_type()->type_str() << "*";
ss_desc << tns->data_type()->type_str() << "*";
} else if (arg.type_name() == "OFFSET") {
ss << "uint64_t";
ss_desc << "uint64_t";
} else {
ss << arg.type_str();
ss_desc << arg.type_str();
}
ss << " _" << arg_idx++ << ", ";
ss_desc << " _" << arg_idx++ << ", ";
}
ss << "int _idx, int _spw) {\n " << impl_name << "(";
ss_desc << "int _idx, int _spw) {\n " << impl_name << "(";
for (size_t i = 0; i < impl_args.size(); ++i) {
ss << "_" << i << ", ";
ss_desc << "_" << i << ", ";
}
ss << "_idx, _spw);\n}\n";
return ss.str();
ss_desc << "_idx, _spw);\n}\n";
auto desc_str = ss_desc.str();
size_t op_hash = std::hash<std::string>{}(desc_str);
std::stringstream ss;
ss << "__forceinline__ __device__ void __op_" << std::hex << op_hash
<< std::dec << "(";
ss << desc_str;
return {ss.str(), op_hash};
}

std::string CodeGenerator::Impl::def_task(const Json &task_json) {
std::stringstream ss;
size_t op_idx = 0;
std::stringstream ss_hash_concat;
std::vector<size_t> op_hash_list;
for (auto &op_json : task_json["Ops"]) {
ss << this->def_op(op_json, task_json["Id"], op_idx++);
auto [def_str, hash] = this->def_op(op_json);
if (op_hashes_.find(hash) == op_hashes_.end()) {
ss << def_str;
op_hashes_.insert(hash);
}
ss_hash_concat << std::hex << hash;
op_hash_list.push_back(hash);
}
ss << "__device__ void t" << task_json["Id"]
<< "(char *_buf, int _idx, int _spw, @GLOBAL_ARGS@) {\n";
size_t task_hash = std::hash<std::string>{}(ss_hash_concat.str());
std::stringstream ss_desc;
auto &buf_reg = BufferRegistry::get_instance();
op_idx = 0;
size_t op_idx = 0;
std::map<std::string, size_t> ptr_str_to_index;
std::vector<std::string> ptr_str_list;
for (auto &op_json : task_json["Ops"]) {
auto op = ModelOp::deserialize(op_json);
auto impl_args = op->impl_args(op_json["Config"]);
ss << " t" << task_json["Id"] << "_o" << op_idx++ << "(";
for (size_t i = 0; i < impl_args.size(); ++i) {
auto &arg = impl_args[i];
ss_desc << " __op_" << std::hex << op_hash_list[op_idx++] << std::dec
<< "(";
for (auto &arg : impl_args) {
if (arg.type_name() == "TENSOR") {
auto tns = arg.value<ModelTensorRef>();
size_t buffer_id = tns->buffer()->id();
auto it = buffer_id_to_offset_.find(buffer_id);
auto buf_info = buf_reg.get(buffer_id);
std::string ptr_str;
if ((buf_info && buf_info->is_external) ||
(it == buffer_id_to_offset_.end())) {
ss << "(" << tns->data_type()->type_str() << "*)_ext_buf_"
<< buffer_id;
ptr_str = "_ext_buf_" + std::to_string(buffer_id);
} else {
size_t buffer_offset;
buffer_offset = it->second;
size_t offset = buffer_offset + ModelOffset(tns).value();
ss << "(" << tns->data_type()->type_str() << "*)&_buf["
<< offset << "]";
ptr_str = "&_buf[" + std::to_string(offset) + "]";
}
size_t ptr_idx;
if (ptr_str_to_index.find(ptr_str) == ptr_str_to_index.end()) {
ptr_idx = ptr_str_to_index.size();
ptr_str_to_index[ptr_str] = ptr_idx;
ptr_str_list.push_back(ptr_str);
} else {
ptr_idx = ptr_str_to_index[ptr_str];
}
ss_desc << "(" << tns->data_type()->type_str() << "*)_"
<< ptr_idx;
} else if (arg.type_name() == "OFFSET") {
auto moff = arg.value<ModelOffset>();
size_t buffer_id = moff.buffer_id();
auto buf_info = buf_reg.get(buffer_id);
if (buf_info && buf_info->is_external) {
size_t offset = moff.value();
ss << "(uint64_t)((char*)_ext_buf_" << buffer_id << " + "
<< offset << ")";
} else {
size_t buffer_offset;
auto it = buffer_id_to_offset_.find(buffer_id);
if (it == buffer_id_to_offset_.end()) {
ERR(InternalError, "buffer ID not found: ", buffer_id);
}
buffer_offset = it->second;
size_t offset = buffer_offset + moff.value();
ss << offset;
ERR(InternalError, "cannot offset external buffer");

Check warning on line 319 in ark/codegen.cpp

View check run for this annotation

Codecov / codecov/patch

ark/codegen.cpp#L319

Added line #L319 was not covered by tests
}
size_t buffer_offset;
auto it = buffer_id_to_offset_.find(buffer_id);
if (it == buffer_id_to_offset_.end()) {
ERR(InternalError, "buffer ID not found: ", buffer_id);

Check warning on line 324 in ark/codegen.cpp

View check run for this annotation

Codecov / codecov/patch

ark/codegen.cpp#L324

Added line #L324 was not covered by tests
}
buffer_offset = it->second;
size_t offset = buffer_offset + moff.value();
ss_desc << offset;
} else {
ss << arg.serialize().begin().value();
ss_desc << arg.serialize().begin().value();
}
ss << ", ";
ss_desc << ", ";
}
ss << "_idx, _spw);\n";
ss_desc << "_idx, _spw);\n";
}
ss << "}\n";
if (task_hashes_.find(task_hash) == task_hashes_.end()) {
ss << "__device__ void __task_" << std::hex << task_hash << std::dec
<< "(";
for (size_t i = 0; i < ptr_str_list.size(); ++i) {
ss << "void *_" << i << ", ";
}
ss << "int _idx, int _spw) {\n" << ss_desc.str() << "}\n";
task_hashes_.insert(task_hash);
}
ss << "__forceinline__ __device__ void __t" << task_json["Id"]
<< "(char *_buf, int _idx, int _spw, @GLOBAL_ARGS@) {\n";
ss << " __task_" << std::hex << task_hash << std::dec << "(";
for (auto &ptr_str : ptr_str_list) {
ss << ptr_str << ", ";
}
ss << "_idx, _spw);\n}\n";
return ss.str();
}

Expand All @@ -332,7 +371,8 @@ std::string CodeGenerator::Impl::task_seq(
ss << "task_seq<" << proc_b << ", " << proc_e << ", " << proc_s << ", "
<< proc_cur << ", " << task_b << ", " << task_e << ", " << task_s << ", "
<< task_gran << ", " << num_slots << ", " << slot_num_warps << ", "
<< slot_sram_bytes << ", t" << task_id << ">(_buf, @FUNCTION_ARGS@);\n";
<< slot_sram_bytes << ", __t" << task_id
<< ">(_buf, @FUNCTION_ARGS@);\n";
return ss.str();
}

Expand Down

0 comments on commit d891914

Please sign in to comment.