Skip to content

Commit

Permalink
[REFACTOR][IR] Streamline ir/op Registry
Browse files Browse the repository at this point in the history
This PR refactors the attrregistry mechanism in the ir/op into
a separate common base. The common base will provide a foundation
for other attr related registries such as target and pass.

We also streamlines the terminology of the registry API.

- Use AttrMap for the column maps returned by the registry
- Use RegEntry to refer to the registry entry.
  • Loading branch information
tqchen committed May 16, 2020
1 parent 63f84a1 commit 773453e
Show file tree
Hide file tree
Showing 22 changed files with 433 additions and 284 deletions.
233 changes: 65 additions & 168 deletions include/tvm/ir/op.h

Large diffs are not rendered by default.

132 changes: 132 additions & 0 deletions include/tvm/node/attr_registry_map.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/node/attr_registry_map.h
* \brief Attribute map used in registry.
*/
#ifndef TVM_NODE_ATTR_REGISTRY_MAP_H_
#define TVM_NODE_ATTR_REGISTRY_MAP_H_

#include <utility>
#include <vector>

namespace tvm {

/*!
* \brief Generic attribute map.
* \tparam KeyType the type of the key.
*/
template <typename KeyType>
class GenericAttrRegistryMap {
public:
/*!
* \brief Check if the map has key.
* \param key The key to the map
* \return 1 if key is contained in map, 0 otherwise.
*/
int count(const KeyType& key) const {
if (key.defined()) {
const uint32_t idx = key->AttrRegistryIndex();
return idx < data_.size() ? (data_[idx].second != 0) : 0;
} else {
return 0;
}
}
/*!
* \brief get the corresponding value element at key.
* \param key The key to the map
* \return the const reference to the content value.
*/
const runtime::TVMRetValue& operator[](const KeyType& key) const {
CHECK(key.defined());
const uint32_t idx = key->AttrRegistryIndex();
CHECK(idx < data_.size() && data_[idx].second != 0)
<< "Attribute " << attr_name_ << " has not been registered for " << key->name;
return data_[idx].first;
}
/*!
* \brief get the corresponding value element at key with default value.
* \param key The key to the map
* \param def_value The default value when the key does not exist.
* \return the const reference to the content value.
* \tparam ValueType The content value type.
*/
template <typename ValueType>
ValueType get(const KeyType& key, ValueType def_value) const {
CHECK(key.defined());
const uint32_t idx = key->AttrRegistryIndex();
if (idx < data_.size() && data_[idx].second != 0) {
return data_[idx].first;
} else {
return def_value;
}
}

private:
/*! \brief The name of the attr field */
String attr_name_;
/*! \brief The internal data. */
std::vector<std::pair<runtime::TVMRetValue, int>> data_;
/*! \brief The constructor */
GenericAttrRegistryMap() = default;
template <typename, typename>
friend class AttrRegistry;
friend class OpRegEntry;
};

/*!
* \brief Map<Key, ValueType> used to store meta-data.
* \tparam KeyType The type of the key
* \tparam ValueType The type of the value stored in map.
*/
template <typename KeyType, typename ValueType>
class AttrRegistryMap {
public:
/*!
* \brief constructor
* \param map The internal map.
*/
explicit AttrRegistryMap(const GenericAttrRegistryMap<KeyType>& map) : map_(map) {}
/*!
* \brief Check if the map has op as key.
* \param key The key to the map
* \return 1 if op is contained in map, 0 otherwise.
*/
int count(const KeyType& key) const { return map_.count(key); }
/*!
* \brief get the corresponding value element at key.
* \param key The key to the map
* \return the const reference to the content value.
*/
ValueType operator[](const KeyType& key) const { return map_[key]; }
/*!
* \brief get the corresponding value element at key with default value.
* \param key The key to the map
* \param def_value The default value when the key does not exist.
* \return the const reference to the content value.
*/
ValueType get(const KeyType& key, ValueType def_value) const { return map_.get(key, def_value); }

protected:
/*! \brief The internal map field */
const GenericAttrRegistryMap<KeyType>& map_;
};

} // namespace tvm
#endif // TVM_NODE_ATTR_REGISTRY_MAP_
111 changes: 25 additions & 86 deletions src/ir/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,117 +28,56 @@
#include <tvm/runtime/packed_func.h>

#include <memory>
#include <mutex>

namespace dmlc {
// enable registry
DMLC_REGISTRY_ENABLE(::tvm::OpRegistry);
} // namespace dmlc
#include "../node/attr_registry.h"

namespace tvm {

using runtime::PackedFunc;
using runtime::TVMArgs;
using runtime::TVMRetValue;

::dmlc::Registry<OpRegistry>* OpRegistry::Registry() { return ::dmlc::Registry<OpRegistry>::Get(); }

// single manager of operator information.
struct OpManager {
// mutex to avoid registration from multiple threads.
std::mutex mutex;
// global operator counter
std::atomic<int> op_counter{0};
// storage of additional attribute table.
std::unordered_map<std::string, std::unique_ptr<GenericOpMap>> attr;
// frontend functions
std::vector<PackedFunc*> frontend_funcs;
// get singleton of the op manager
static OpManager* Global() {
static OpManager* inst = new OpManager();
return inst;
}
};
using OpRegistry = AttrRegistry<OpRegEntry, Op>;

// find operator by name
const Op& Op::Get(const String& name) {
const OpRegistry* reg = dmlc::Registry<OpRegistry>::Find(name);
const OpRegEntry* reg = OpRegistry::Global()->Get(name);
CHECK(reg != nullptr) << "Operator " << name << " is not registered";
return reg->op();
}

OpRegistry::OpRegistry() {
OpManager* mgr = OpManager::Global();
OpRegEntry::OpRegEntry(uint32_t reg_index) {
ObjectPtr<OpNode> n = make_object<OpNode>();
n->index_ = mgr->op_counter++;
n->index_ = reg_index;
op_ = Op(n);
}

OpRegEntry& OpRegEntry::RegisterOrGet(const String& name) {
return OpRegistry::Global()->RegisterOrGet(name);
}

// Get attribute map by key
const GenericOpMap& Op::GetGenericAttr(const String& key) {
OpManager* mgr = OpManager::Global();
std::lock_guard<std::mutex> lock(mgr->mutex);
auto it = mgr->attr.find(key);
if (it == mgr->attr.end()) {
LOG(FATAL) << "Operator attribute \'" << key << "\' is not registered";
}
return *it->second.get();
const GenericAttrRegistryMap<Op>& Op::GetGenericAttrMap(const String& attr_name) {
return OpRegistry::Global()->GetAttrMap(attr_name);
}

// Check if a key is present in the registry.
bool Op::HasGenericAttr(const String& key) {
OpManager* mgr = OpManager::Global();
std::lock_guard<std::mutex> lock(mgr->mutex);
auto it = mgr->attr.find(key);
if (it == mgr->attr.end()) {
return false;
}
return true;
bool Op::HasGenericAttrMap(const String& attr_name) {
return OpRegistry::Global()->HasAttrMap(attr_name);
}

// Resets attr of the OpMap.
void OpRegistry::reset_attr(const std::string& key) {
OpManager* mgr = OpManager::Global();
std::lock_guard<std::mutex> lock(mgr->mutex);
std::unique_ptr<GenericOpMap>& op_map = mgr->attr[key];
if (op_map == nullptr) {
return;
}
uint32_t index = op_->index_;
if (op_map->data_.size() > index) {
op_map->data_[index] = std::make_pair(TVMRetValue(), 0);
}
// Resets attr of the OpAttrMap.
void OpRegEntry::reset_attr(const std::string& attr_name) {
OpRegistry::Global()->ResetAttr(attr_name, op_);
}

void OpRegistry::UpdateAttr(const String& key, TVMRetValue value, int plevel) {
OpManager* mgr = OpManager::Global();
std::lock_guard<std::mutex> lock(mgr->mutex);
std::unique_ptr<GenericOpMap>& op_map = mgr->attr[key];
if (op_map == nullptr) {
op_map.reset(new GenericOpMap());
op_map->attr_name_ = key;
}
uint32_t index = op_->index_;
if (op_map->data_.size() <= index) {
op_map->data_.resize(index + 1, std::make_pair(TVMRetValue(), 0));
}
std::pair<TVMRetValue, int>& p = op_map->data_[index];
CHECK(p.second != plevel) << "Attribute " << key << " of operator " << this->name
<< " is already registered with same plevel=" << plevel;
CHECK(value.type_code() != kTVMNullptr)
<< "Registered packed_func is Null for " << key << " of operator " << this->name;
if (p.second < plevel && value.type_code() != kTVMNullptr) {
op_map->data_[index] = std::make_pair(value, plevel);
}
void OpRegEntry::UpdateAttr(const String& key, TVMRetValue value, int plevel) {
OpRegistry::Global()->UpdateAttr(key, op_, value, plevel);
}

// Frontend APIs
TVM_REGISTER_GLOBAL("relay.op._ListOpNames").set_body_typed([]() {
Array<runtime::String> ret;
for (const std::string& name : dmlc::Registry<OpRegistry>::ListAllNames()) {
ret.push_back(name);
}
return ret;
return OpRegistry::Global()->ListAllNames();
});

TVM_REGISTER_GLOBAL("relay.op._GetOp").set_body_typed([](String name) -> Op {
Expand All @@ -148,7 +87,7 @@ TVM_REGISTER_GLOBAL("relay.op._GetOp").set_body_typed([](String name) -> Op {
TVM_REGISTER_GLOBAL("relay.op._OpGetAttr").set_body([](TVMArgs args, TVMRetValue* rv) {
Op op = args[0];
std::string attr_name = args[1];
auto op_map = Op::GetAttr<TVMRetValue>(attr_name);
auto op_map = Op::GetAttrMap<TVMRetValue>(attr_name);
if (op_map.count(op)) {
*rv = op_map[op];
}
Expand All @@ -159,14 +98,14 @@ TVM_REGISTER_GLOBAL("relay.op._OpSetAttr").set_body([](TVMArgs args, TVMRetValue
std::string attr_name = args[1];
runtime::TVMArgValue value = args[2];
int plevel = args[3];
auto& reg = OpRegistry::Registry()->__REGISTER_OR_GET__(op->name).set_name();
auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name();
reg.set_attr(attr_name, value, plevel);
});

TVM_REGISTER_GLOBAL("relay.op._OpResetAttr").set_body([](TVMArgs args, TVMRetValue* rv) {
Op op = args[0];
std::string attr_name = args[1];
auto& reg = OpRegistry::Registry()->__REGISTER_OR_GET__(op->name);
auto& reg = OpRegistry::Global()->RegisterOrGet(op->name);
reg.reset_attr(attr_name);
});

Expand All @@ -175,7 +114,7 @@ TVM_REGISTER_GLOBAL("relay.op._Register").set_body([](TVMArgs args, TVMRetValue*
std::string attr_key = args[1];
runtime::TVMArgValue value = args[2];
int plevel = args[3];
auto& reg = OpRegistry::Registry()->__REGISTER_OR_GET__(op_name).set_name();
auto& reg = OpRegistry::Global()->RegisterOrGet(op_name).set_name();
// enable resgiteration and override of certain properties
if (attr_key == "num_inputs" && plevel > 128) {
reg.set_num_inputs(value);
Expand All @@ -187,8 +126,8 @@ TVM_REGISTER_GLOBAL("relay.op._Register").set_body([](TVMArgs args, TVMRetValue*
// do an eager copy of the PackedFunc
PackedFunc f = args[2];
// If we get a function from frontend, avoid deleting it.
OpManager::Global()->frontend_funcs.push_back(new PackedFunc(f));
reg.set_attr(attr_key, f, plevel);
auto* fcopy = new PackedFunc(f);
reg.set_attr(attr_key, *fcopy, plevel);
} else {
reg.set_attr(attr_key, args[2], plevel);
}
Expand Down
Loading

0 comments on commit 773453e

Please sign in to comment.