Skip to content

Commit

Permalink
Expose list of PassContext configurations to the Python APIs (apache#…
Browse files Browse the repository at this point in the history
…8212)

* Expose C++ PassContext::ListAllConfigs via its Python counterpart
   tvm.ir.transform.PassContext.list_configs()

 * Add unit tests for the C++ and Python layers
  • Loading branch information
leandron authored Jun 9, 2021
1 parent 53e4c60 commit 1f2ca06
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 0 deletions.
6 changes: 6 additions & 0 deletions include/tvm/ir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,12 @@ class PassContext : public ObjectRef {
*/
TVM_DLL static PassContext Current();

/*!
* \brief Get all supported configuration names, registered within the PassContext.
* \return List of all configuration names.
*/
TVM_DLL static Array<String> ListConfigNames();

/*!
* \brief Call instrument implementations' callbacks when entering PassContext.
* The callbacks are called in order, and if one raises an exception, the rest will not be
Expand Down
5 changes: 5 additions & 0 deletions python/tvm/ir/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,11 @@ def current():
"""Return the current pass context."""
return _ffi_transform_api.GetCurrentPassContext()

@staticmethod
def list_config_names():
"""List all registered `PassContext` configuration names"""
return list(_ffi_transform_api.ListConfigNames())


@tvm._ffi.register_object("transform.Pass")
class Pass(tvm.runtime.Object):
Expand Down
14 changes: 14 additions & 0 deletions src/ir/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,14 @@ class PassConfigManager {
}
}

Array<String> ListConfigNames() {
Array<String> config_keys;
for (const auto& kv : key2vtype_) {
config_keys.push_back(kv.first);
}
return config_keys;
}

static PassConfigManager* Global() {
static auto* inst = new PassConfigManager();
return inst;
Expand All @@ -163,6 +171,10 @@ void PassContext::RegisterConfigOption(const char* key, uint32_t value_type_inde
PassConfigManager::Global()->Register(key, value_type_index);
}

Array<String> PassContext::ListConfigNames() {
return PassConfigManager::Global()->ListConfigNames();
}

PassContext PassContext::Create() { return PassContext(make_object<PassContextNode>()); }

void PassContext::InstrumentEnterPassContext() {
Expand Down Expand Up @@ -607,5 +619,7 @@ Pass PrintIR(String header, bool show_meta_data) {

TVM_REGISTER_GLOBAL("transform.PrintIR").set_body_typed(PrintIR);

TVM_REGISTER_GLOBAL("transform.ListConfigNames").set_body_typed(PassContext::ListConfigNames);

} // namespace transform
} // namespace tvm
7 changes: 7 additions & 0 deletions tests/cpp/relay_transform_sequential_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,13 @@ TEST(Relay, Sequential) {
ICHECK(tvm::StructuralEqual()(f, expected));
}

TEST(PassContextListConfigNames, Basic) {
Array<String> configs = relay::transform::PassContext::ListConfigNames();
ICHECK_EQ(configs.empty(), false);
ICHECK_EQ(std::count(std::begin(configs), std::end(configs), "relay.backend.use_auto_scheduler"),
1);
}

int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
Expand Down
7 changes: 7 additions & 0 deletions tests/python/relay/test_pass_instrument.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,13 @@ def run_after_pass(self, mod, info):
assert passes_counter.run_after_count == 0


def test_list_pass_configs():
config_names = tvm.transform.PassContext.list_config_names()

assert len(config_names) > 0
assert "relay.backend.use_auto_scheduler" in config_names


def test_enter_pass_ctx_exception():
events = []

Expand Down

0 comments on commit 1f2ca06

Please sign in to comment.