Skip to content

Commit

Permalink
[NODE][PASS] Introduce config to PassContext. (apache#5631)
Browse files Browse the repository at this point in the history
This PR introduces a new config field to the PassContext
to allow it store arbitary config values.

To make sure that the config is validated, we allow each pass
to register the config key they would expect and the corresponding types.

We also introduce a CreateObject from Map<str, Object> to allow config creation
from a json-nest(like in vscode) in python.

We added an example of UnrollLoopConfig.

Followup PR should migrate the passes to use the new config field.
  • Loading branch information
tqchen authored and Trevor Morris committed Jun 9, 2020
1 parent 1d5f63d commit b34f96a
Show file tree
Hide file tree
Showing 10 changed files with 277 additions and 46 deletions.
2 changes: 1 addition & 1 deletion include/tvm/ir/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ struct AttrError : public dmlc::Error {
* \brief constructor
* \param msg error message
*/
explicit AttrError(const std::string& msg) : dmlc::Error(msg) {}
explicit AttrError(std::string msg) : dmlc::Error("AttributeError:" + msg) {}
};

/*!
Expand Down
73 changes: 68 additions & 5 deletions include/tvm/ir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ namespace transform {
// Forward declare for TraceFunc.
class PassInfo;

/*! \brief A callback for tracing passes, useful for debugging and logging.
*
/*!
* \brief A callback for tracing passes, useful for debugging and logging.
*/
using TraceFunc =
runtime::TypedPackedFunc<void(const IRModule& ir_module, const PassInfo& ctx, bool is_before)>;
Expand All @@ -96,19 +96,53 @@ class PassContextNode : public Object {
int fallback_device{static_cast<int>(kDLCPU)};

/*! \brief The list of required passes. */
Array<runtime::String> required_pass;
Array<String> required_pass;
/*! \brief The list of disabled passes. */
Array<runtime::String> disabled_pass;

Array<String> disabled_pass;
/*! \brief Trace function to be invoked before and after each pass. */
TraceFunc trace_func;

/*! \brief Pass specific configurations. */
Map<std::string, ObjectRef> config;

PassContextNode() = default;

/*!
* \brief Get a config value from the pass context.
*
* \param key The config key.
* \param default_value The default value if the key does not exist, defaults to nullptr.
*
* \return The result
*
* \tparam TOBjectRef the expected object type.
* \throw Error if the key exists but the value does not match TObjectRef.
*/
template <typename TObjectRef>
Optional<TObjectRef> GetConfig(const std::string& key, Optional<TObjectRef> default_value =
Optional<TObjectRef>(nullptr)) const {
static_assert(std::is_base_of<ObjectRef, TObjectRef>::value,
"Can only call GetAttr with ObjectRef types.");
if (!config.defined()) return default_value;
auto it = config.find(key);
if (it != config.end()) {
return Downcast<Optional<TObjectRef>>((*it).second);
} else {
return default_value;
}
}
// variant that uses TObjectRef to enable implicit conversion to default value.
template <typename TObjectRef>
Optional<TObjectRef> GetConfig(const std::string& key, TObjectRef default_value) const {
return GetConfig<TObjectRef>(key, Optional<TObjectRef>(default_value));
}

void VisitAttrs(AttrVisitor* v) {
v->Visit("opt_level", &opt_level);
v->Visit("fallback_device", &fallback_device);
v->Visit("required_pass", &required_pass);
v->Visit("disabled_pass", &disabled_pass);
v->Visit("config", &config);
}

static constexpr const char* _type_key = "transform.PassContext";
Expand Down Expand Up @@ -150,6 +184,7 @@ class PassContext : public ObjectRef {
CHECK(get() != nullptr);
return static_cast<PassContextNode*>(get_mutable());
}

/*!
* \brief Construct a PassContext containing the default configurations.
* \return The new PassContext.
Expand All @@ -169,6 +204,20 @@ class PassContext : public ObjectRef {
*/
TVM_DLL void Trace(const IRModule& module, const PassInfo& info, bool is_before) const;

/*!
* \brief Register a valid configuration option and its ValueType for validation.
*
* \param key The configuration key.
* \tparam ValueNodeType The value type to be registered
*/
template <typename ValueNodeType>
static uint32_t RegisterConfigOption(const char* key) {
// NOTE: we could further update the function later.
uint32_t tindex = ValueNodeType::_GetOrAllocRuntimeTypeIndex();
RegisterConfigOption(key, tindex);
return tindex;
}

// accessor.
using ContainerType = PassContextNode;
class Internal;
Expand All @@ -178,12 +227,26 @@ class PassContext : public ObjectRef {
TVM_DLL void EnterWithScope();
// The exit of a pass context scope.
TVM_DLL void ExitWithScope();
// Register configuration key value type.
TVM_DLL static void RegisterConfigOption(const char* key, uint32_t value_type_index);

// Classes to get the Python `with` like syntax.
friend class Internal;
friend class With<PassContext>;
};

#define TVM_PASS_CTX_CONFIG_VAR_DEF static TVM_ATTRIBUTE_UNUSED uint32_t __make_PassContext_tid

/*!
* \brief Helper macro to register the object type to runtime.
* Makes sure that the runtime type table is correctly populated.
*
* Use this macro in the cc file for each terminal class.
*/
#define TVM_REGISTER_PASS_CONFIG_OPTION(Key, ValueType) \
TVM_STR_CONCAT(TVM_PASS_CTX_CONFIG_VAR_DEF, __COUNTER__) = \
::tvm::transform::PassContext::RegisterConfigOption<ValueType>(Key)

/*!
* \brief Meta data that will be used to help optimization and analysis.
* \sa PassInfo
Expand Down
17 changes: 17 additions & 0 deletions include/tvm/node/reflection.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,23 @@ class ReflectionVTable {
*/
TVM_DLL ObjectPtr<Object> CreateInitObject(const std::string& type_key,
const std::string& repr_bytes = "") const;
/*!
* \brief Create an object by giving kwargs about its fields.
*
* \param type_key The type key.
* \param kwargs the arguments in format key1, value1, ..., key_n, value_n.
* \return The created object.
*/
TVM_DLL ObjectRef CreateObject(const std::string& type_key, const runtime::TVMArgs& kwargs);
/*!
* \brief Create an object by giving kwargs about its fields.
*
* \param type_key The type key.
* \param kwargs The field arguments.
* \return The created object.
*/
TVM_DLL ObjectRef CreateObject(const std::string& type_key,
const Map<std::string, ObjectRef>& kwargs);
/*!
* \brief Get an field object by the attr name.
* \param self The pointer to the object.
Expand Down
8 changes: 6 additions & 2 deletions python/tvm/ir/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,17 @@ class PassContext(tvm.runtime.Object):
disabled_pass : Optional[Union[List[str], Set[str], Tuple[str]]]
The list of passes that are disabled.
config : Optional[Dict[str, Object]]
Additional configurations for specific passes.
"""
def __init__(self,
opt_level=2,
fallback_device=_nd.cpu(),
required_pass=None,
disabled_pass=None,
trace=None):
trace=None,
config=None):
if isinstance(fallback_device, str):
fallback_device = _nd.context(fallback_device).device_type
elif isinstance(fallback_device, tvm.runtime.TVMContext):
Expand All @@ -97,7 +101,7 @@ def __init__(self,

self.__init_handle_by_constructor__(_ffi_transform_api.PassContext, opt_level,
fallback_device, required,
disabled, trace)
disabled, trace, config)

def __enter__(self):
_ffi_transform_api.EnterPassContext(self)
Expand Down
107 changes: 87 additions & 20 deletions src/ir/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,11 @@
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>

// TODO(tqchen): Update to use String container after it is merged.
#include <tvm/tir/expr.h>

#include <stack>
#include <unordered_set>

#include "../runtime/object_internal.h"

namespace tvm {
namespace transform {

Expand Down Expand Up @@ -75,6 +74,70 @@ PassContext PassContext::Current() {
}
}

class PassConfigManager {
public:
void Register(std::string key, uint32_t value_type_index) {
CHECK_EQ(key2vtype_.count(key), 0U);
ValueTypeInfo info;
info.type_index = value_type_index;
info.type_key = runtime::Object::TypeIndex2Key(value_type_index);
key2vtype_[key] = info;
}

// Trying to validate and legalize a config.
void Legalize(Map<std::string, ObjectRef>* config) {
std::vector<std::pair<std::string, ObjectRef>> update;
auto* reflection = ReflectionVTable::Global();

for (auto kv : *config) {
auto it = key2vtype_.find(kv.first);
if (it == key2vtype_.end()) {
std::ostringstream os;
os << "AttributeError: Invalid config option \'" << kv.first << "\' candidates are:";
int counter = 0;
for (const auto& kv : key2vtype_) {
os << ' ';
if (counter++ != 0) os << ',';
os << kv.first;
}
LOG(FATAL) << os.str();
}
const auto& info = it->second;
CHECK(kv.second.defined()) << "AttributeError: " << kv.first << " is None";
if (kv.second->IsInstance<Map<std::string, ObjectRef>::ContainerType>()) {
ObjectRef converted = reflection->CreateObject(
info.type_key, Downcast<Map<std::string, ObjectRef>>(kv.second));
update.emplace_back(kv.first, converted);
} else {
if (!runtime::ObjectInternal::DerivedFrom(kv.second.get(), info.type_index)) {
LOG(FATAL) << "AttributeError: expect config " << kv.first << " to have type "
<< info.type_key << " but get " << kv.second->GetTypeKey();
}
}
}
for (auto&& kv : update) {
config->Set(kv.first, kv.second);
}
}

static PassConfigManager* Global() {
static auto* inst = new PassConfigManager();
return inst;
}

private:
struct ValueTypeInfo {
std::string type_key;
uint32_t type_index;
};

std::unordered_map<std::string, ValueTypeInfo> key2vtype_;
};

void PassContext::RegisterConfigOption(const char* key, uint32_t value_type_index) {
PassConfigManager::Global()->Register(key, value_type_index);
}

PassContext PassContext::Create() { return PassContext(make_object<PassContextNode>()); }

void PassContext::Trace(const IRModule& module, const PassInfo& info, bool is_before) const {
Expand Down Expand Up @@ -390,20 +453,23 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)

TVM_REGISTER_NODE_TYPE(PassContextNode);

TVM_REGISTER_GLOBAL("transform.PassContext").set_body([](TVMArgs args, TVMRetValue* ret) {
auto pctx = PassContext::Create();
int opt_level = args[0];
int fallback_device = args[1];
tvm::Array<runtime::String> required = args[2];
tvm::Array<runtime::String> disabled = args[3];
TraceFunc trace_func = args[4];
pctx->opt_level = opt_level;
pctx->fallback_device = fallback_device;
pctx->required_pass = std::move(required);
pctx->disabled_pass = std::move(disabled);
pctx->trace_func = std::move(trace_func);
*ret = pctx;
});
TVM_REGISTER_GLOBAL("transform.PassContext")
.set_body_typed([](int opt_level, int fallback_device, Array<String> required,
Array<String> disabled, TraceFunc trace_func,
Optional<Map<std::string, ObjectRef>> config) {
auto pctx = PassContext::Create();
pctx->opt_level = opt_level;
pctx->fallback_device = fallback_device;

pctx->required_pass = std::move(required);
pctx->disabled_pass = std::move(disabled);
pctx->trace_func = std::move(trace_func);
if (config.defined()) {
pctx->config = config.value();
}
PassConfigManager::Global()->Legalize(&(pctx->config));
return pctx;
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<PassContextNode>([](const ObjectRef& ref, ReprPrinter* p) {
Expand All @@ -413,17 +479,18 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "\topt_level: " << node->opt_level << "\n";
p->stream << "\tfallback device: " << runtime::DeviceName(node->fallback_device) << "\n";

p->stream << "\trequired passes: [" << node->opt_level;
p->stream << "\trequired passes: [";
for (const auto& it : node->required_pass) {
p->stream << it << " ";
}
p->stream << "]\n";

p->stream << "\tdisabled passes: [" << node->opt_level;
p->stream << "\tdisabled passes: [";
for (const auto& it : node->disabled_pass) {
p->stream << it << " ";
}
p->stream << "]";
p->stream << "]\n";
p->stream << "\tconfig: " << node->config;
});

class PassContext::Internal {
Expand Down
41 changes: 31 additions & 10 deletions src/node/reflection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -195,14 +195,13 @@ class NodeAttrSetter : public AttrVisitor {
}
};

void InitNodeByPackedArgs(Object* n, const TVMArgs& args) {
void InitNodeByPackedArgs(ReflectionVTable* reflection, Object* n, const TVMArgs& args) {
NodeAttrSetter setter;
setter.type_key = n->GetTypeKey();
CHECK_EQ(args.size() % 2, 0);
for (int i = 0; i < args.size(); i += 2) {
setter.attrs.emplace(args[i].operator std::string(), args[i + 1]);
}
auto* reflection = ReflectionVTable::Global();
reflection->VisitAttrs(n, &setter);

if (setter.attrs.size() != 0) {
Expand All @@ -215,6 +214,35 @@ void InitNodeByPackedArgs(Object* n, const TVMArgs& args) {
}
}

ObjectRef ReflectionVTable::CreateObject(const std::string& type_key, const TVMArgs& kwargs) {
ObjectPtr<Object> n = this->CreateInitObject(type_key);
if (n->IsInstance<BaseAttrsNode>()) {
static_cast<BaseAttrsNode*>(n.get())->InitByPackedArgs(kwargs);
} else {
InitNodeByPackedArgs(this, n.get(), kwargs);
}
return ObjectRef(n);
}

ObjectRef ReflectionVTable::CreateObject(const std::string& type_key,
const Map<std::string, ObjectRef>& kwargs) {
// Redirect to the TVMArgs version
// It is not the most efficient way, but CreateObject is not meant to be used
// in a fast code-path and is mainly reserved as a flexible API for frontends.
std::vector<TVMValue> values(kwargs.size() * 2);
std::vector<int32_t> tcodes(kwargs.size() * 2);
runtime::TVMArgsSetter setter(values.data(), tcodes.data());
int index = 0;

for (auto& kv : static_cast<const StrMapNode*>(kwargs.get())->data) {
setter(index, kv.first);
setter(index + 1, kv.second);
index += 2;
}

return CreateObject(type_key, runtime::TVMArgs(values.data(), tcodes.data(), kwargs.size() * 2));
}

// Expose to FFI APIs.
void NodeGetAttr(TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args[0].type_code(), kTVMObjectHandle);
Expand Down Expand Up @@ -246,14 +274,7 @@ void MakeNode(const TVMArgs& args, TVMRetValue* rv) {
std::string type_key = args[0];
std::string empty_str;
TVMArgs kwargs(args.values + 1, args.type_codes + 1, args.size() - 1);
auto* reflection = ReflectionVTable::Global();
ObjectPtr<Object> n = reflection->CreateInitObject(type_key);
if (n->IsInstance<BaseAttrsNode>()) {
static_cast<BaseAttrsNode*>(n.get())->InitByPackedArgs(kwargs);
} else {
InitNodeByPackedArgs(n.get(), kwargs);
}
*rv = ObjectRef(n);
*rv = ReflectionVTable::Global()->CreateObject(type_key, kwargs);
}

TVM_REGISTER_GLOBAL("node.NodeGetAttr").set_body(NodeGetAttr);
Expand Down
Loading

0 comments on commit b34f96a

Please sign in to comment.