Skip to content

Commit

Permalink
remove CreateProgram and keep only declaration in header
Browse files Browse the repository at this point in the history
  • Loading branch information
kazum committed Jul 8, 2018
1 parent 0b7ddcf commit e224cdc
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 120 deletions.
115 changes: 8 additions & 107 deletions src/runtime/opencl/opencl_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -233,128 +233,29 @@ class OpenCLModuleNode : public ModuleNode {
std::string source)
: data_(data), fmt_(fmt), fmap_(fmap), source_(source) {}
// destructor
~OpenCLModuleNode() {
{
// free the kernel ids in global table.
std::lock_guard<std::mutex> lock(workspace_->mu);
for (auto& kv : kid_map_) {
workspace_->free_kernel_ids.push_back(kv.second.kernel_id);
}
}
// free the kernels
for (cl_kernel k : kernels_) {
OPENCL_CALL(clReleaseKernel(k));
}
if (program_) {
OPENCL_CALL(clReleaseProgram(program_));
}
}
~OpenCLModuleNode();

/*!
* \brief Get the global workspace
*/
virtual std::shared_ptr<cl::OpenCLWorkspace> GetGlobalWorkspace() {
return cl::OpenCLWorkspace::Global();
}
virtual std::shared_ptr<cl::OpenCLWorkspace> GetGlobalWorkspace();

virtual const char* type_key() const {
return "opencl";
}

virtual cl_program CreateProgram() {
const char* s = data_.c_str();
size_t len = data_.length();
cl_int err;
cl_program program = clCreateProgramWithSource(workspace_->context, 1, &s, &len, &err);
OPENCL_CHECK_ERROR(err);
return program;
}
virtual const char* type_key() const;

PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final;

void SaveToFile(const std::string& file_name,
const std::string& format) final {
std::string fmt = GetFileFormat(file_name, format);
CHECK_EQ(fmt, fmt_)
<< "Can only save to format=" << fmt_;
std::string meta_file = GetMetaFilePath(file_name);
SaveMetaDataToFile(meta_file, fmap_);
SaveBinaryToFile(file_name, data_);
}

void SaveToBinary(dmlc::Stream* stream) final {
stream->Write(fmt_);
stream->Write(fmap_);
stream->Write(data_);
}

std::string GetSource(const std::string& format) final {
if (format == fmt_) return data_;
if (fmt_ == "cl") {
return data_;
} else {
return source_;
}
}

const std::string& format) final;
void SaveToBinary(dmlc::Stream* stream) final;
std::string GetSource(const std::string& format) final;
// Initialize the programs
void Init() {
workspace_ = GetGlobalWorkspace();
workspace_->Init();
CHECK(workspace_->context != nullptr) << "No OpenCL device";
program_ = CreateProgram();
device_built_flag_.resize(workspace_->devices.size(), false);
// initialize the kernel id, need to lock global table.
std::lock_guard<std::mutex> lock(workspace_->mu);
for (const auto& kv : fmap_) {
const std::string& key = kv.first;
KTRefEntry e;
if (workspace_->free_kernel_ids.size() != 0) {
e.kernel_id = workspace_->free_kernel_ids.back();
workspace_->free_kernel_ids.pop_back();
} else {
e.kernel_id = workspace_->num_registered_kernels++;
}
e.version = workspace_->timestamp++;
kid_map_[key] = e;
}
}
void Init();
// install a new kernel to thread local entry
cl_kernel InstallKernel(cl::OpenCLWorkspace* w,
cl::OpenCLThreadEntry* t,
const std::string& func_name,
const KTRefEntry& e) {
std::lock_guard<std::mutex> lock(build_lock_);
int device_id = t->context.device_id;
if (!device_built_flag_[device_id]) {
// build program
cl_int err;
cl_device_id dev = w->devices[device_id];
err = clBuildProgram(program_, 1, &dev, nullptr, nullptr, nullptr);
if (err != CL_SUCCESS) {
size_t len;
std::string log;
clGetProgramBuildInfo(
program_, dev, CL_PROGRAM_BUILD_LOG, 0, nullptr, &len);
log.resize(len);
clGetProgramBuildInfo(
program_, dev, CL_PROGRAM_BUILD_LOG, len, &log[0], nullptr);
LOG(FATAL) << "OpenCL build error for device=" << dev << log;
}
device_built_flag_[device_id] = true;
}
// build kernel
cl_int err;
cl_kernel kernel = clCreateKernel(program_, func_name.c_str(), &err);
OPENCL_CHECK_ERROR(err);
t->kernel_table[e.kernel_id].kernel = kernel;
t->kernel_table[e.kernel_id].version = e.version;
kernels_.push_back(kernel);
return kernel;
}

const KTRefEntry& e);
protected:
// The workspace, need to keep reference to use it in destructor.
// In case of static destruction order problem.
Expand Down
123 changes: 123 additions & 0 deletions src/runtime/opencl/opencl_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,31 @@ class OpenCLWrappedFunc {
ThreadAxisConfig thread_axis_cfg_;
};

OpenCLModuleNode::~OpenCLModuleNode() {
{
// free the kernel ids in global table.
std::lock_guard<std::mutex> lock(workspace_->mu);
for (auto& kv : kid_map_) {
workspace_->free_kernel_ids.push_back(kv.second.kernel_id);
}
}
// free the kernels
for (cl_kernel k : kernels_) {
OPENCL_CALL(clReleaseKernel(k));
}
if (program_) {
OPENCL_CALL(clReleaseProgram(program_));
}
}

std::shared_ptr<cl::OpenCLWorkspace> OpenCLModuleNode::GetGlobalWorkspace() {
return cl::OpenCLWorkspace::Global();
}

const char* OpenCLModuleNode::type_key() const {
return "opencl";
}

PackedFunc OpenCLModuleNode::GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) {
Expand Down Expand Up @@ -108,6 +133,104 @@ PackedFunc OpenCLModuleNode::GetFunction(
return PackFuncVoidAddr(f, info.arg_types);
}

void OpenCLModuleNode::SaveToFile(const std::string& file_name,
const std::string& format) {
std::string fmt = GetFileFormat(file_name, format);
CHECK_EQ(fmt, fmt_)
<< "Can only save to format=" << fmt_;
std::string meta_file = GetMetaFilePath(file_name);
SaveMetaDataToFile(meta_file, fmap_);
SaveBinaryToFile(file_name, data_);
}

void OpenCLModuleNode::SaveToBinary(dmlc::Stream* stream) {
stream->Write(fmt_);
stream->Write(fmap_);
stream->Write(data_);
}

std::string OpenCLModuleNode::GetSource(const std::string& format) {
if (format == fmt_) return data_;
if (fmt_ == "cl") {
return data_;
} else {
return source_;
}
}

void OpenCLModuleNode::Init() {
workspace_ = GetGlobalWorkspace();
workspace_->Init();
CHECK(workspace_->context != nullptr) << "No OpenCL device";
if (fmt_ == "cl") {
const char* s = data_.c_str();
size_t len = data_.length();
cl_int err;
program_ = clCreateProgramWithSource(
workspace_->context, 1, &s, &len, &err);
OPENCL_CHECK_ERROR(err);
} else if (fmt_ == "xclbin" || fmt_ == "awsxclbin") {
const unsigned char* s = (const unsigned char *)data_.c_str();
size_t len = data_.length();
cl_int err;
program_ = clCreateProgramWithBinary(
workspace_->context, 1, &(workspace_->devices[0]), &len, &s, NULL, &err);
if (err != CL_SUCCESS) {
LOG(ERROR) << "OpenCL Error: " << cl::CLGetErrorString(err);
}
} else {
LOG(FATAL) << "Unknown OpenCL format " << fmt_;
}
device_built_flag_.resize(workspace_->devices.size(), false);
// initialize the kernel id, need to lock global table.
std::lock_guard<std::mutex> lock(workspace_->mu);
for (const auto& kv : fmap_) {
const std::string& key = kv.first;
KTRefEntry e;
if (workspace_->free_kernel_ids.size() != 0) {
e.kernel_id = workspace_->free_kernel_ids.back();
workspace_->free_kernel_ids.pop_back();
} else {
e.kernel_id = workspace_->num_registered_kernels++;
}
e.version = workspace_->timestamp++;
kid_map_[key] = e;
}
}

cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w,
cl::OpenCLThreadEntry* t,
const std::string& func_name,
const KTRefEntry& e) {
std::lock_guard<std::mutex> lock(build_lock_);
int device_id = t->context.device_id;
if (!device_built_flag_[device_id]) {
// build program
cl_int err;
cl_device_id dev = w->devices[device_id];
err = clBuildProgram(program_, 1, &dev, nullptr, nullptr, nullptr);
if (err != CL_SUCCESS) {
size_t len;
std::string log;
clGetProgramBuildInfo(
program_, dev, CL_PROGRAM_BUILD_LOG, 0, nullptr, &len);
log.resize(len);
clGetProgramBuildInfo(
program_, dev, CL_PROGRAM_BUILD_LOG, len, &log[0], nullptr);
LOG(FATAL) << "OpenCL build error for device=" << dev << log;
}
device_built_flag_[device_id] = true;
}
// build kernel
cl_int err;
cl_kernel kernel = clCreateKernel(program_, func_name.c_str(), &err);
OPENCL_CHECK_ERROR(err);
t->kernel_table[e.kernel_id].kernel = kernel;
t->kernel_table[e.kernel_id].version = e.version;
kernels_.push_back(kernel);
return kernel;
}

Module OpenCLModuleCreate(
std::string data,
std::string fmt,
Expand Down
13 changes: 0 additions & 13 deletions src/runtime/opencl/sdaccel/sdaccel_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ class SDAccelModuleNode : public OpenCLModuleNode {
: OpenCLModuleNode(data, fmt, fmap, source) {}
std::shared_ptr<cl::OpenCLWorkspace> GetGlobalWorkspace() final;
const char* type_key() const final;
cl_program CreateProgram() final;
};

std::shared_ptr<cl::OpenCLWorkspace> SDAccelModuleNode::GetGlobalWorkspace() {
Expand All @@ -33,18 +32,6 @@ const char* SDAccelModuleNode::type_key() const {
return "sdaccel";
}

cl_program SDAccelModuleNode::CreateProgram() {
const unsigned char* s = (const unsigned char *)data_.c_str();
size_t len = data_.length();
cl_int err;
cl_program program = clCreateProgramWithBinary(workspace_->context, 1, &(workspace_->devices[0]),
&len, &s, NULL, &err);
if (err != CL_SUCCESS) {
LOG(ERROR) << "OpenCL Error: " << cl::CLGetErrorString(err);
}
return program;
}

Module SDAccelModuleCreate(
std::string data,
std::string fmt,
Expand Down

0 comments on commit e224cdc

Please sign in to comment.