Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
Merge branch 'develop' into squeeze
Browse files Browse the repository at this point in the history
# Conflicts:
#	cinn/frontend/net_builder.cc
#	cinn/frontend/net_builder.h
#	cinn/pybind/frontend.cc
  • Loading branch information
zrr1999 committed Aug 29, 2022
2 parents e5cbe05 + fb25902 commit 20a2796
Show file tree
Hide file tree
Showing 152 changed files with 3,789 additions and 2,144 deletions.
2 changes: 1 addition & 1 deletion build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ function prepare_ci {
pip install clang-format==9.0
pip install wheel
pip install sphinx==3.3.1 sphinx_gallery==0.8.1 recommonmark==0.6.0 exhale scipy breathe==4.24.0 matplotlib sphinx_rtd_theme
pip install paddlepaddle-gpu==2.2.2.post101 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html
pip install paddlepaddle-gpu==2.3.1.post101 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html
}

function prepare_doc_model_file {
Expand Down
1 change: 1 addition & 0 deletions cinn/auto_schedule/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
add_subdirectory(analysis)
add_subdirectory(cost_model)
add_subdirectory(database)
add_subdirectory(measure)
add_subdirectory(search_space)
add_subdirectory(search_strategy)
Expand Down
2 changes: 2 additions & 0 deletions cinn/auto_schedule/analysis/analyze_ir.h
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>
#include <unordered_set>

Expand Down
4 changes: 4 additions & 0 deletions cinn/auto_schedule/auto_tuner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "cinn/auto_schedule/task/task_creator.h"
#include "cinn/auto_schedule/task/tune_task.h"
#include "cinn/auto_schedule/task_scheduler/task_scheduler.h"
#include "cinn/common/type.h"

namespace cinn {
namespace auto_schedule {
Expand All @@ -44,6 +45,9 @@ void AutoTuner::Initialize(const Config& config, hlir::framework::GraphCompiler*
for (TuneTask& task : tasks_) {
task.SetGraphCompiler(graph_compiler);
task.TaskGraphToUnoptLoweredFunc();
task.SerializeToString(graph_->GetAttrs<absl::flat_hash_map<std::string, hlir::framework::shape_t>>("infershape"),
graph_->GetAttrs<absl::flat_hash_map<std::string, common::Type>>("inferdtype"));
VLOG(3) << "Add a task with serialized_key:\n" << task.serialized_key;
}

// create task optimizers
Expand Down
7 changes: 4 additions & 3 deletions cinn/auto_schedule/cost_model/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ set(Python_VIRTUALENV FIRST)
find_package(PythonInterp ${PY_VERSION} REQUIRED)
find_package(PythonLibs ${PY_VERSION} REQUIRED)

if (WITH_TESTING)
cc_test(test_cost_model SRCS cost_model_test.cc cost_model.cc DEPS pybind gtest_main)

cc_test(test_cost_model SRCS cost_model_test.cc cost_model.cc DEPS pybind gtest_main)

target_link_libraries(test_cost_model ${PYTHON_LIBRARIES})
target_link_libraries(test_cost_model ${PYTHON_LIBRARIES})
endif()
5 changes: 5 additions & 0 deletions cinn/auto_schedule/database/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
core_gather_headers()

gather_srcs(cinnapi_src SRCS database.cc)

cc_test(test_database SRCS database_test.cc DEPS cinncore)
90 changes: 90 additions & 0 deletions cinn/auto_schedule/database/database.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "cinn/auto_schedule/database/database.h"

namespace cinn {
namespace auto_schedule {

bool TuningRecord::Compare::operator()(const TuningRecord& lhs, const TuningRecord& rhs) const {
return lhs.execution_cost < rhs.execution_cost;
}

Database::Database(int capacity_per_task) : capacity_per_task_(capacity_per_task) {
CHECK_GT(capacity_per_task_, 0) << "capacity_per_task_ should be greater than 0";
}

bool Database::AddRecord(TuningRecord&& record) {
CHECK(!record.task_key.empty()) << "task_key of TuningRecord can't be empty";
Commit(record);

auto& records = key2record_[record.task_key];
records.emplace(record);
if (records.size() > capacity_per_task_) {
records.erase(std::prev(records.end()));
}
return true;
}

std::vector<TuningRecord> Database::LookUp(const std::string& task_key) {
auto fit = key2record_.find(task_key);
if (fit == key2record_.end()) {
return {};
}

std::vector<TuningRecord> results;
results.reserve(fit->second.size());
results.assign(fit->second.begin(), fit->second.end());
return results;
}

std::vector<SearchState> Database::GetTopK(const std::string& task_key, int k) {
auto fit = key2record_.find(task_key);
if (fit == key2record_.end() || k <= 0) {
return {};
}
if (k > capacity_per_task_) {
LOG(WARNING) << "Input k:" << k << " is greater than the capacity";
k = capacity_per_task_;
}

std::vector<SearchState> results;
results.reserve(k);
for (const TuningRecord& record : fit->second) {
results.emplace_back(record.state);
if (results.size() == k) {
break;
}
}
return results;
}

size_t Database::Size() {
auto res =
std::accumulate(key2record_.begin(), key2record_.end(), size_t(0), [](size_t res, const auto& kv) -> size_t {
return std::move(res) + kv.second.size();
});
return res;
}

size_t Database::Count(const std::string& task_key) {
auto fit = key2record_.find(task_key);
if (fit == key2record_.end()) {
return 0;
}
return fit->second.size();
}

} // namespace auto_schedule
} // namespace cinn
70 changes: 70 additions & 0 deletions cinn/auto_schedule/database/database.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "cinn/auto_schedule/measure/measure.h"
#include "cinn/auto_schedule/search_space/search_state.h"

namespace cinn {
namespace auto_schedule {

// Record related data about tuning process of a measure candidate
struct TuningRecord {
// the unique key to identify a task
std::string task_key;
// the cost time of the candidate executed during measure
double execution_cost; // unit: us
// the searched candidate to be saved
SearchState state;

// a binary compare function that denotes when the left
// will be sorted in the front of the right
struct Compare {
bool operator()(const TuningRecord& lhs, const TuningRecord& rhs) const;
};
};

// A database supports insert or lookup historial tuning result with sepecified traits.
// It can be implemented with a concrete storage to save/load underlying data,
// such as memory, file, database server and so on, this base class can be regarded as
// one using memory as its underlying storage medium.
class Database {
public:
explicit Database(int capacity_per_task);
~Database() = default;

// add a record into the database
bool AddRecord(TuningRecord&& record);
// return all records whose task_keys are equal to the specified key
std::vector<TuningRecord> LookUp(const std::string& task_key);
// return the states of the top k in sorted candidates
std::vector<SearchState> GetTopK(const std::string& task_key, int k);
// return the total number of stored candidates
size_t Size();
// return the number of stored candidates with specified key
size_t Count(const std::string& task_key);

protected:
// commit the newly added record into underlying storage
virtual bool Commit(const TuningRecord& record) { return true; }

// map task_key to its records
std::unordered_map<std::string, std::multiset<TuningRecord, TuningRecord::Compare>> key2record_;
// the max number of candidates stored
const int capacity_per_task_;
};

} // namespace auto_schedule
} // namespace cinn
72 changes: 72 additions & 0 deletions cinn/auto_schedule/database/database_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "cinn/auto_schedule/database/database.h"

#include <gtest/gtest.h>

#include <vector>

#include "cinn/auto_schedule/search_space/search_state.h"
#include "cinn/ir/ir_schedule.h"

namespace cinn {
namespace auto_schedule {

class TestDatabase : public ::testing::Test {
public:
TestDatabase() : test_db(2) {
test_db.AddRecord(TuningRecord({"k1", 1.0, SearchState(ir::ModuleExpr())}));
test_db.AddRecord(TuningRecord({"k2", 2.0, SearchState(ir::ModuleExpr())}));
test_db.AddRecord(TuningRecord({"k2", 3.0, SearchState(ir::ModuleExpr())}));
test_db.AddRecord(TuningRecord({"k3", 3.0, SearchState(ir::ModuleExpr())}));
test_db.AddRecord(TuningRecord({"k3", 4.0, SearchState(ir::ModuleExpr())}));
test_db.AddRecord(TuningRecord({"k3", 5.0, SearchState(ir::ModuleExpr())}));
test_db.AddRecord(TuningRecord({"k4", 4.0, SearchState(ir::ModuleExpr())}));
}

void SetUp() override {}
Database test_db;
};

TEST_F(TestDatabase, Basic) {
ASSERT_EQ(test_db.Size(), 6);
auto records = test_db.LookUp("k3");
// check the max number of stored candidates will
// be restricted to capacity_per_task
ASSERT_EQ(test_db.Count("k3"), 2);
ASSERT_EQ(records.size(), 2);
EXPECT_EQ(records[0].execution_cost, 3.0);
EXPECT_EQ(records[1].execution_cost, 4.0);
}

TEST_F(TestDatabase, GetTopK) {
ASSERT_TRUE(test_db.GetTopK("k5", 2).empty());
ASSERT_EQ(test_db.GetTopK("k4", 3).size(), 1);

SearchState state1(std::move(ir::ModuleExpr()));
SearchState state2(std::move(ir::ModuleExpr()));
state1.predicted_cost = 1.2;
state2.predicted_cost = 1.0;
test_db.AddRecord(TuningRecord({"k4", 2.0, state1}));
test_db.AddRecord(TuningRecord({"k4", 3.0, state2}));

auto states = test_db.GetTopK("k4", 3);
ASSERT_EQ(states.size(), 2);
EXPECT_FLOAT_EQ(states[0].predicted_cost, 1.2);
EXPECT_FLOAT_EQ(states[1].predicted_cost, 1.0);
}

} // namespace auto_schedule
} // namespace cinn
8 changes: 6 additions & 2 deletions cinn/auto_schedule/measure/measure.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include <map>
#include <memory>
#include <string>
#include <vector>

#include "cinn/auto_schedule/task/tune_task.h"
Expand All @@ -41,10 +42,13 @@ struct MeasureInput {
struct MeasureResult {
// The time cost of execution in average of running
// with a specific repeated times.
double execution_cost; // unit: us
double execution_cost = 0.0; // unit: us
// The time cost of the whole measurement process including
// building and running
double elapsed_time; // unit: us
double elapsed_time = 0.0; // unit: us
// used to return detail messages once an error occurr during measurement,
// empty if nothing goes wrong
std::string error_msg;
};

// The result of building with input schedule
Expand Down
Loading

0 comments on commit 20a2796

Please sign in to comment.