Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
opt load_params and add test
Browse files Browse the repository at this point in the history
  • Loading branch information
wangyue50 committed Aug 25, 2021
1 parent 1d4d085 commit 4b3bb59
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 14 deletions.
1 change: 1 addition & 0 deletions cinn/hlir/pe/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ core_gather_srcs(SRCS
cc_test(test_cinn_pe_elementwise SRCS pe_elementwise_test.cc DEPS cinncore)
cc_test(test_cinn_pe_broadcast SRCS pe_broadcast_test.cc DEPS cinncore)
cc_test(test_cinn_pe_transform SRCS pe_transform_test.cc DEPS cinncore)
cc_test(test_load_params SRCS load_params_test.cc DEPS cinncore)
20 changes: 10 additions & 10 deletions cinn/hlir/pe/schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -269,10 +269,10 @@ void GetConv2dFactors(std::unordered_map<std::string, int> *factors,
const std::string &key,
bool import_params) {
if (import_params) {
auto &params = ScheduleParam::get_instance().GetParam();
auto &params = ScheduleParam::get_x86_instance().GetParam();
if (params.empty()) {
CreateX86SerialData();
LoadSerialData();
LoadSerialData(&params);
}
if (params.count(key)) {
CHECK(!params[key]["oc_bn"].empty());
Expand Down Expand Up @@ -1187,8 +1187,8 @@ void CreateCudaSerialData(const std::string &file_name) {

// winograd
InputCudaParam(model_data,
"CudaScheduleConv 1 64 58 58 64 64 3 3 1 64 56 56",
{{32, 2}, {1, 3}, {1, 3}, {4, 1, 8, 2}, {28, 1, 2, 1}, {1, 2, 7, 4}});
"CudaScheduleConv 1 64 58 58 64 64 3 3 1 64 56 56",
{{32, 2}, {1, 3}, {1, 3}, {4, 1, 8, 2}, {28, 1, 2, 1}, {1, 2, 7, 4}});
// winograd
InputCudaParam(model_data,
"CudaScheduleConv 1 512 9 9 512 512 3 3 1 512 7 7",
Expand All @@ -1212,7 +1212,8 @@ int GetMaxSplitter(int a, int b) {
return b;
}

void LoadSerialData(const std::string &file_name) {
void LoadSerialData(std::unordered_map<std::string, std::unordered_map<std::string, std::vector<int>>> *params,
const std::string &file_name) {
proto::ModelData read_model_data;
std::fstream input(file_name, std::ios::in | std::ios::binary);
if (!read_model_data.ParseFromIstream(&input)) {
Expand All @@ -1223,7 +1224,6 @@ void LoadSerialData(const std::string &file_name) {
std::string test_write3;
read_model_data.SerializeToString(&test_write3);
auto read_model_map = read_model_data.data();
auto &res = ScheduleParam::get_instance().GetParam();
for (auto &i : read_model_map) {
auto read_schedule_map = i.second.data();
std::unordered_map<std::string, std::vector<int>> param_data;
Expand All @@ -1234,7 +1234,7 @@ void LoadSerialData(const std::string &file_name) {
}
param_data[j.first] = temp_data;
}
res[i.first] = param_data;
(*params)[i.first] = param_data;
}
}

Expand Down Expand Up @@ -1272,10 +1272,10 @@ void CudaScheduleConv(poly::StageMap stages,
ir::Tensor &weights,
ir::Tensor &output,
const common::Target &target) {
auto &res = ScheduleParam::get_instance().GetParam();
auto &res = ScheduleParam::get_cuda_instance().GetParam();
if (res.empty()) {
CreateCudaSerialData();
LoadSerialData();
LoadSerialData(&res);
}

int n = output->shape[0].as_int32();
Expand Down Expand Up @@ -1344,7 +1344,7 @@ void CudaScheduleConv2(poly::StageMap stages,
ir::Tensor &output,
const common::Target &target,
const std::string &key) {
auto &res = ScheduleParam::get_instance().GetParam();
auto &res = ScheduleParam::get_cuda_instance().GetParam();
stages[input_pad]->ComputeInline();
optim::Simplify(&(output->shape[2]));
optim::Simplify(&(output->shape[3]));
Expand Down
13 changes: 9 additions & 4 deletions cinn/hlir/pe/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,13 @@ class ScheduleParam {
~ScheduleParam();
ScheduleParam(const ScheduleParam &) = delete;
ScheduleParam &operator=(const ScheduleParam &) = delete;
static ScheduleParam &get_instance() {
static ScheduleParam instance;
return instance;
static ScheduleParam &get_cuda_instance() {
static ScheduleParam cuda_instance;
return cuda_instance;
}
static ScheduleParam &get_x86_instance() {
static ScheduleParam x86_instance;
return x86_instance;
}
std::unordered_map<std::string, std::unordered_map<std::string, std::vector<int>>> &GetParam() { return param_data; }
std::unordered_map<std::string, std::vector<int>> &operator[](const std::string &key) { return param_data[key]; }
Expand Down Expand Up @@ -156,7 +160,8 @@ std::string GenerateX86ConvKey(const std::vector<int> &input_shape,
const std::vector<int> &dilations);
void CreateX86SerialData(const std::string &file_name = "default_serial.log");

void LoadSerialData(const std::string &file_name = "default_serial.log");
void LoadSerialData(std::unordered_map<std::string, std::unordered_map<std::string, std::vector<int>>> *params,
const std::string &file_name = "default_serial.log");

void SaveSerialData(
const std::unordered_map<std::string, std::unordered_map<std::string, std::vector<int>>> &model_data,
Expand Down

0 comments on commit 4b3bb59

Please sign in to comment.