Skip to content

Commit

Permalink
[Object] Unify StrMapNode and MapNode (#5687)
Browse files Browse the repository at this point in the history
* Pass cpptest and py unittest

* fix graph runtime

* right fix

* fix a bug that runtime::String's operator < is actually compare by address

* Update container.py

* Renaming

* Address comments

* lint

* Replace ObjectHash in object.py
  • Loading branch information
junrushao authored Jun 2, 2020
1 parent 062a244 commit 4347b41
Show file tree
Hide file tree
Showing 98 changed files with 395 additions and 471 deletions.
2 changes: 1 addition & 1 deletion include/tvm/arith/analyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ class ConstIntBound : public ObjectRef {
*/
class ConstIntBoundAnalyzer {
public:
using BoundMapType = std::unordered_map<PrimExpr, ConstIntBound, ObjectHash, ObjectEqual>;
using BoundMapType = std::unordered_map<PrimExpr, ConstIntBound, ObjectPtrHash, ObjectPtrEqual>;
/*!
* \brief analyze the expr
* \param expr The expression of interest.
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/arith/int_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ IntSet EvalSet(IntSet s, const std::unordered_map<const VarNode*, IntSet>& dom_m
*/
IntSet EvalSet(Range r, const std::unordered_map<const VarNode*, IntSet>& dom_map);
/*! \brief Map from Expr to IntSet */
using ExprIntSetMap = std::unordered_map<PrimExpr, IntSet, ObjectHash, ObjectEqual>;
using ExprIntSetMap = std::unordered_map<PrimExpr, IntSet, ObjectPtrHash, ObjectPtrEqual>;
/*!
* \brief Find the integer set of every sub-expression, given the
* domain of each iteration variables.
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/driver/driver_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ TVM_DLL runtime::Module build(const Map<Target, IRModule>& input, const Target&
* pass Target().
* \return The built module that contains code for different processors.
*/
TVM_DLL runtime::Module build(const Map<std::string, IRModule>& input, const Target& target_host);
TVM_DLL runtime::Module build(const Map<String, IRModule>& input, const Target& target_host);
} // namespace tvm

#endif // TVM_DRIVER_DRIVER_API_H_
4 changes: 2 additions & 2 deletions include/tvm/ir/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ class Attrs : public ObjectRef {
class DictAttrsNode : public BaseAttrsNode {
public:
/*! \brief internal attrs map */
Map<std::string, ObjectRef> dict;
Map<String, ObjectRef> dict;

bool SEqualReduce(const DictAttrsNode* other, SEqualReducer equal) const {
return equal(dict, other->dict);
Expand Down Expand Up @@ -230,7 +230,7 @@ class DictAttrs : public Attrs {
* \param dict The attributes.
* \return The dict attributes.
*/
TVM_DLL explicit DictAttrs(Map<std::string, ObjectRef> dict);
TVM_DLL explicit DictAttrs(Map<String, ObjectRef> dict);

TVM_DEFINE_OBJECT_REF_METHODS(DictAttrs, Attrs, DictAttrsNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(DictAttrsNode);
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/ir/error.h
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,8 @@ class ErrorReporter {

private:
std::vector<Error> errors_;
std::unordered_map<ObjectRef, std::vector<size_t>, ObjectHash, ObjectEqual> node_to_error_;
std::unordered_map<ObjectRef, GlobalVar, ObjectHash, ObjectEqual> node_to_gv_;
std::unordered_map<ObjectRef, std::vector<size_t>, ObjectPtrHash, ObjectPtrEqual> node_to_error_;
std::unordered_map<ObjectRef, GlobalVar, ObjectPtrHash, ObjectPtrEqual> node_to_gv_;
};

} // namespace tvm
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/ir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ inline TFunc WithAttr(TFunc func, const std::string& attr_key, ObjectRef attr_va
if (node->attrs.defined()) {
node->attrs.CopyOnWrite()->dict.Set(attr_key, attr_value);
} else {
Map<std::string, ObjectRef> dict = {{attr_key, attr_value}};
Map<String, ObjectRef> dict = {{attr_key, attr_value}};
node->attrs = DictAttrs(dict);
}
return func;
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/ir/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -250,12 +250,12 @@ class IRModuleNode : public Object {
/*! \brief A map from string names to global variables that
* ensures global uniqueness.
*/
Map<std::string, GlobalVar> global_var_map_;
Map<String, GlobalVar> global_var_map_;

/*! \brief A map from string names to global type variables (ADT names)
* that ensures global uniqueness.
*/
Map<std::string, GlobalTypeVar> global_type_var_map_;
Map<String, GlobalTypeVar> global_type_var_map_;

/*! \brief A map from constructor tags to constructor objects
* for convenient access
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/ir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ class PassContextNode : public Object {
TraceFunc trace_func;

/*! \brief Pass specific configurations. */
Map<std::string, ObjectRef> config;
Map<String, ObjectRef> config;

PassContextNode() = default;

Expand Down
146 changes: 26 additions & 120 deletions include/tvm/node/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,40 @@ namespace tvm {

using runtime::Array;
using runtime::ArrayNode;
using runtime::Downcast;
using runtime::IterAdapter;
using runtime::make_object;
using runtime::Object;
using runtime::ObjectEqual;
using runtime::ObjectHash;
using runtime::ObjectPtr;
using runtime::ObjectPtrEqual;
using runtime::ObjectPtrHash;
using runtime::ObjectRef;
using runtime::String;
using runtime::StringObj;

struct ObjectHash {
size_t operator()(const ObjectRef& a) const {
if (const auto* str = a.as<StringObj>()) {
return String::HashBytes(str->data, str->size);
}
return ObjectPtrHash()(a);
}
};

struct ObjectEqual {
bool operator()(const ObjectRef& a, const ObjectRef& b) const {
if (a.same_as(b)) {
return true;
}
if (const auto* str_a = a.as<StringObj>()) {
if (const auto* str_b = b.as<StringObj>()) {
return String::memncmp(str_a->data, str_b->data, str_a->size, str_b->size) == 0;
}
}
return false;
}
};

/*! \brief map node content */
class MapNode : public Object {
public:
Expand All @@ -62,19 +86,6 @@ class MapNode : public Object {
TVM_DECLARE_FINAL_OBJECT_INFO(MapNode, Object);
};

/*! \brief specialized map node with string as key */
class StrMapNode : public Object {
public:
/*! \brief The corresponding conatiner type */
using ContainerType = std::unordered_map<std::string, ObjectRef>;

/*! \brief the data content */
ContainerType data;

static constexpr const char* _type_key = "StrMap";
TVM_DECLARE_FINAL_OBJECT_INFO(StrMapNode, Object);
};

/*!
* \brief Map container of NodeRef->NodeRef in DSL graph.
* Map implements copy on write semantics, which means map is mutable
Expand Down Expand Up @@ -249,97 +260,6 @@ class Map : public ObjectRef {
}
};

// specialize of string map
template <typename V, typename T1, typename T2>
class Map<std::string, V, T1, T2> : public ObjectRef {
public:
// for code reuse
Map() { data_ = make_object<StrMapNode>(); }
Map(Map<std::string, V>&& other) { // NOLINT(*)
data_ = std::move(other.data_);
}
Map(const Map<std::string, V>& other) : ObjectRef(other.data_) { // NOLINT(*)
}
explicit Map(ObjectPtr<Object> n) : ObjectRef(n) {}
template <typename IterType>
Map(IterType begin, IterType end) {
assign(begin, end);
}
Map(std::initializer_list<std::pair<std::string, V> > init) { // NOLINT(*)
assign(init.begin(), init.end());
}

template <typename Hash, typename Equal>
Map(const std::unordered_map<std::string, V, Hash, Equal>& init) { // NOLINT(*)
assign(init.begin(), init.end());
}
Map<std::string, V>& operator=(Map<std::string, V>&& other) {
data_ = std::move(other.data_);
return *this;
}
Map<std::string, V>& operator=(const Map<std::string, V>& other) {
data_ = other.data_;
return *this;
}
template <typename IterType>
void assign(IterType begin, IterType end) {
auto n = make_object<StrMapNode>();
for (IterType i = begin; i != end; ++i) {
n->data.emplace(std::make_pair(i->first, i->second));
}
data_ = std::move(n);
}
inline const V operator[](const std::string& key) const {
return DowncastNoCheck<V>(static_cast<const StrMapNode*>(data_.get())->data.at(key));
}
inline const V at(const std::string& key) const {
return DowncastNoCheck<V>(static_cast<const StrMapNode*>(data_.get())->data.at(key));
}
inline size_t size() const {
if (data_.get() == nullptr) return 0;
return static_cast<const StrMapNode*>(data_.get())->data.size();
}
inline size_t count(const std::string& key) const {
if (data_.get() == nullptr) return 0;
return static_cast<const StrMapNode*>(data_.get())->data.count(key);
}
inline StrMapNode* CopyOnWrite() {
if (data_.get() == nullptr || !data_.unique()) {
ObjectPtr<StrMapNode> n = make_object<StrMapNode>();
n->data = static_cast<const StrMapNode*>(data_.get())->data;
ObjectPtr<Object>(std::move(n)).swap(data_);
}
return static_cast<StrMapNode*>(data_.get());
}
inline void Set(const std::string& key, const V& value) {
StrMapNode* n = this->CopyOnWrite();
n->data[key] = value;
}
inline bool empty() const { return size() == 0; }
using ContainerType = StrMapNode;

struct ValueConverter {
using ResultType = std::pair<std::string, V>;
static inline ResultType convert(const std::pair<std::string, ObjectRef>& n) {
return std::make_pair(n.first, DowncastNoCheck<V>(n.second));
}
};

using iterator = IterAdapter<ValueConverter, StrMapNode::ContainerType::const_iterator>;

/*! \return begin iterator */
inline iterator begin() const {
return iterator(static_cast<const StrMapNode*>(data_.get())->data.begin());
}
/*! \return end iterator */
inline iterator end() const {
return iterator(static_cast<const StrMapNode*>(data_.get())->data.end());
}
/*! \return begin iterator */
inline iterator find(const std::string& key) const {
return iterator(static_cast<const StrMapNode*>(data_.get())->data.find(key));
}
};
} // namespace tvm

namespace tvm {
Expand All @@ -361,20 +281,6 @@ struct ObjectTypeChecker<Array<T> > {
static std::string TypeName() { return "List[" + ObjectTypeChecker<T>::TypeName() + "]"; }
};

template <typename V>
struct ObjectTypeChecker<Map<std::string, V> > {
static bool Check(const Object* ptr) {
if (ptr == nullptr) return true;
if (!ptr->IsInstance<StrMapNode>()) return false;
const StrMapNode* n = static_cast<const StrMapNode*>(ptr);
for (const auto& kv : n->data) {
if (!ObjectTypeChecker<V>::Check(kv.second.get())) return false;
}
return true;
}
static std::string TypeName() { return "Map[str, " + ObjectTypeChecker<V>::TypeName() + ']'; }
};

template <typename K, typename V>
struct ObjectTypeChecker<Map<K, V> > {
static bool Check(const Object* ptr) {
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/node/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ using runtime::Downcast;
using runtime::GetRef;
using runtime::make_object;
using runtime::Object;
using runtime::ObjectEqual;
using runtime::ObjectHash;
using runtime::ObjectPtr;
using runtime::ObjectPtrEqual;
using runtime::ObjectPtrHash;
using runtime::ObjectRef;
using runtime::PackedFunc;
using runtime::TVMArgs;
Expand Down
3 changes: 1 addition & 2 deletions include/tvm/node/reflection.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,7 @@ class ReflectionVTable {
* \param kwargs The field arguments.
* \return The created object.
*/
TVM_DLL ObjectRef CreateObject(const std::string& type_key,
const Map<std::string, ObjectRef>& kwargs);
TVM_DLL ObjectRef CreateObject(const std::string& type_key, const Map<String, ObjectRef>& kwargs);
/*!
* \brief Get an field object by the attr name.
* \param self The pointer to the object.
Expand Down
3 changes: 1 addition & 2 deletions include/tvm/relay/dataflow_matcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,7 @@ Expr RewritePatterns(Array<DFPatternCallback> callbacks, Expr expr);
*
* \return Return the paritioned Expr.
*/
Expr PartitionPattern(DFPattern pattern, Expr expr, Map<std::string, ObjectRef> attrs,
PackedFunc check);
Expr PartitionPattern(DFPattern pattern, Expr expr, Map<String, ObjectRef> attrs, PackedFunc check);

} // namespace relay
} // namespace tvm
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relay/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ class ExprMutator : public ::tvm::relay::ExprFunctor<Expr(const Expr&)> {

protected:
/*! \brief Internal map used for memoization. */
std::unordered_map<Expr, Expr, ObjectHash, ObjectEqual> memo_;
std::unordered_map<Expr, Expr, ObjectPtrHash, ObjectPtrEqual> memo_;
};

/*!
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relay/pattern_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ class PatternMutator : public ::tvm::relay::PatternFunctor<Pattern(const Pattern
virtual Constructor VisitConstructor(const Constructor& c);

private:
std::unordered_map<Var, Var, ObjectHash, ObjectEqual> var_map_;
std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual> var_map_;
};

} // namespace relay
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ TVM_DLL Pass AlterOpLayout();
* this specifies the desired layout for data then kernel for nn.conv2d.
* \return The pass.
*/
TVM_DLL Pass ConvertLayout(const Map<std::string, Array<String>>& desired_layouts);
TVM_DLL Pass ConvertLayout(const Map<String, Array<String>>& desired_layouts);

/*!
* \brief Legalizes an expr with another expression.
Expand Down
Loading

0 comments on commit 4347b41

Please sign in to comment.