This repository has been archived by the owner on Jan 24, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 114
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
# Conflicts: # cinn/frontend/net_builder.cc # cinn/frontend/net_builder.h # cinn/pybind/frontend.cc
- Loading branch information
Showing
152 changed files
with
3,789 additions
and
2,144 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
core_gather_headers() | ||
|
||
gather_srcs(cinnapi_src SRCS database.cc) | ||
|
||
cc_test(test_database SRCS database_test.cc DEPS cinncore) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
// Copyright (c) 2022 CINN Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "cinn/auto_schedule/database/database.h" | ||
|
||
namespace cinn { | ||
namespace auto_schedule { | ||
|
||
bool TuningRecord::Compare::operator()(const TuningRecord& lhs, const TuningRecord& rhs) const { | ||
return lhs.execution_cost < rhs.execution_cost; | ||
} | ||
|
||
Database::Database(int capacity_per_task) : capacity_per_task_(capacity_per_task) { | ||
CHECK_GT(capacity_per_task_, 0) << "capacity_per_task_ should be greater than 0"; | ||
} | ||
|
||
bool Database::AddRecord(TuningRecord&& record) { | ||
CHECK(!record.task_key.empty()) << "task_key of TuningRecord can't be empty"; | ||
Commit(record); | ||
|
||
auto& records = key2record_[record.task_key]; | ||
records.emplace(record); | ||
if (records.size() > capacity_per_task_) { | ||
records.erase(std::prev(records.end())); | ||
} | ||
return true; | ||
} | ||
|
||
std::vector<TuningRecord> Database::LookUp(const std::string& task_key) { | ||
auto fit = key2record_.find(task_key); | ||
if (fit == key2record_.end()) { | ||
return {}; | ||
} | ||
|
||
std::vector<TuningRecord> results; | ||
results.reserve(fit->second.size()); | ||
results.assign(fit->second.begin(), fit->second.end()); | ||
return results; | ||
} | ||
|
||
std::vector<SearchState> Database::GetTopK(const std::string& task_key, int k) { | ||
auto fit = key2record_.find(task_key); | ||
if (fit == key2record_.end() || k <= 0) { | ||
return {}; | ||
} | ||
if (k > capacity_per_task_) { | ||
LOG(WARNING) << "Input k:" << k << " is greater than the capacity"; | ||
k = capacity_per_task_; | ||
} | ||
|
||
std::vector<SearchState> results; | ||
results.reserve(k); | ||
for (const TuningRecord& record : fit->second) { | ||
results.emplace_back(record.state); | ||
if (results.size() == k) { | ||
break; | ||
} | ||
} | ||
return results; | ||
} | ||
|
||
size_t Database::Size() { | ||
auto res = | ||
std::accumulate(key2record_.begin(), key2record_.end(), size_t(0), [](size_t res, const auto& kv) -> size_t { | ||
return std::move(res) + kv.second.size(); | ||
}); | ||
return res; | ||
} | ||
|
||
size_t Database::Count(const std::string& task_key) { | ||
auto fit = key2record_.find(task_key); | ||
if (fit == key2record_.end()) { | ||
return 0; | ||
} | ||
return fit->second.size(); | ||
} | ||
|
||
} // namespace auto_schedule | ||
} // namespace cinn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
// Copyright (c) 2022 CINN Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#pragma once | ||
|
||
#include "cinn/auto_schedule/measure/measure.h" | ||
#include "cinn/auto_schedule/search_space/search_state.h" | ||
|
||
namespace cinn { | ||
namespace auto_schedule { | ||
|
||
// Record related data about tuning process of a measure candidate | ||
struct TuningRecord { | ||
// the unique key to identify a task | ||
std::string task_key; | ||
// the cost time of the candidate executed during measure | ||
double execution_cost; // unit: us | ||
// the searched candidate to be saved | ||
SearchState state; | ||
|
||
// a binary compare function that denotes when the left | ||
// will be sorted in the front of the right | ||
struct Compare { | ||
bool operator()(const TuningRecord& lhs, const TuningRecord& rhs) const; | ||
}; | ||
}; | ||
|
||
// A database supports insert or lookup historial tuning result with sepecified traits. | ||
// It can be implemented with a concrete storage to save/load underlying data, | ||
// such as memory, file, database server and so on, this base class can be regarded as | ||
// one using memory as its underlying storage medium. | ||
class Database { | ||
public: | ||
explicit Database(int capacity_per_task); | ||
~Database() = default; | ||
|
||
// add a record into the database | ||
bool AddRecord(TuningRecord&& record); | ||
// return all records whose task_keys are equal to the specified key | ||
std::vector<TuningRecord> LookUp(const std::string& task_key); | ||
// return the states of the top k in sorted candidates | ||
std::vector<SearchState> GetTopK(const std::string& task_key, int k); | ||
// return the total number of stored candidates | ||
size_t Size(); | ||
// return the number of stored candidates with specified key | ||
size_t Count(const std::string& task_key); | ||
|
||
protected: | ||
// commit the newly added record into underlying storage | ||
virtual bool Commit(const TuningRecord& record) { return true; } | ||
|
||
// map task_key to its records | ||
std::unordered_map<std::string, std::multiset<TuningRecord, TuningRecord::Compare>> key2record_; | ||
// the max number of candidates stored | ||
const int capacity_per_task_; | ||
}; | ||
|
||
} // namespace auto_schedule | ||
} // namespace cinn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
// Copyright (c) 2022 CINN Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "cinn/auto_schedule/database/database.h" | ||
|
||
#include <gtest/gtest.h> | ||
|
||
#include <vector> | ||
|
||
#include "cinn/auto_schedule/search_space/search_state.h" | ||
#include "cinn/ir/ir_schedule.h" | ||
|
||
namespace cinn { | ||
namespace auto_schedule { | ||
|
||
class TestDatabase : public ::testing::Test { | ||
public: | ||
TestDatabase() : test_db(2) { | ||
test_db.AddRecord(TuningRecord({"k1", 1.0, SearchState(ir::ModuleExpr())})); | ||
test_db.AddRecord(TuningRecord({"k2", 2.0, SearchState(ir::ModuleExpr())})); | ||
test_db.AddRecord(TuningRecord({"k2", 3.0, SearchState(ir::ModuleExpr())})); | ||
test_db.AddRecord(TuningRecord({"k3", 3.0, SearchState(ir::ModuleExpr())})); | ||
test_db.AddRecord(TuningRecord({"k3", 4.0, SearchState(ir::ModuleExpr())})); | ||
test_db.AddRecord(TuningRecord({"k3", 5.0, SearchState(ir::ModuleExpr())})); | ||
test_db.AddRecord(TuningRecord({"k4", 4.0, SearchState(ir::ModuleExpr())})); | ||
} | ||
|
||
void SetUp() override {} | ||
Database test_db; | ||
}; | ||
|
||
TEST_F(TestDatabase, Basic) { | ||
ASSERT_EQ(test_db.Size(), 6); | ||
auto records = test_db.LookUp("k3"); | ||
// check the max number of stored candidates will | ||
// be restricted to capacity_per_task | ||
ASSERT_EQ(test_db.Count("k3"), 2); | ||
ASSERT_EQ(records.size(), 2); | ||
EXPECT_EQ(records[0].execution_cost, 3.0); | ||
EXPECT_EQ(records[1].execution_cost, 4.0); | ||
} | ||
|
||
TEST_F(TestDatabase, GetTopK) { | ||
ASSERT_TRUE(test_db.GetTopK("k5", 2).empty()); | ||
ASSERT_EQ(test_db.GetTopK("k4", 3).size(), 1); | ||
|
||
SearchState state1(std::move(ir::ModuleExpr())); | ||
SearchState state2(std::move(ir::ModuleExpr())); | ||
state1.predicted_cost = 1.2; | ||
state2.predicted_cost = 1.0; | ||
test_db.AddRecord(TuningRecord({"k4", 2.0, state1})); | ||
test_db.AddRecord(TuningRecord({"k4", 3.0, state2})); | ||
|
||
auto states = test_db.GetTopK("k4", 3); | ||
ASSERT_EQ(states.size(), 2); | ||
EXPECT_FLOAT_EQ(states[0].predicted_cost, 1.2); | ||
EXPECT_FLOAT_EQ(states[1].predicted_cost, 1.0); | ||
} | ||
|
||
} // namespace auto_schedule | ||
} // namespace cinn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.