[Target] Migrate data structure of TargetNode (#5960)
junrushao authored Jul 2, 2020
1 parent 78a006c commit 6ce8a1c
Showing 57 changed files with 639 additions and 345 deletions.
75 changes: 46 additions & 29 deletions include/tvm/target/target.h
@@ -28,6 +28,7 @@
#include <tvm/ir/transform.h>
#include <tvm/node/container.h>
#include <tvm/support/with.h>
#include <tvm/target/target_id.h>

#include <string>
#include <unordered_set>
@@ -42,52 +43,58 @@ namespace tvm {
*/
class TargetNode : public Object {
public:
/*! \brief The name of the target device */
std::string target_name;
/*! \brief The name of the target device */
std::string device_name;
/*! \brief The type of the target device */
int device_type;
/*! \brief The maximum threads that a schedule should use for this device */
int max_num_threads = 1;
/*! \brief The warp size that should be used by the LowerThreadAllreduce pass */
int thread_warp_size = 1;
/*! \brief The id of the target device */
TargetId id;
/*! \brief Tag of the target, can be empty */
String tag;
/*! \brief Keys for this target */
Array<runtime::String> keys_array;
/*! \brief Options for this target */
Array<runtime::String> options_array;
/*! \brief Collection of imported libs */
Array<runtime::String> libs_array;
Array<String> keys;
/*! \brief Collection of attributes */
Map<String, ObjectRef> attrs;

/*! \return the full device string to pass to codegen::Build */
TVM_DLL const std::string& str() const;

void VisitAttrs(AttrVisitor* v) {
v->Visit("target_name", &target_name);
v->Visit("device_name", &device_name);
v->Visit("device_type", &device_type);
v->Visit("max_num_threads", &max_num_threads);
v->Visit("thread_warp_size", &thread_warp_size);
v->Visit("keys_array", &keys_array);
v->Visit("options_array", &options_array);
v->Visit("libs_array", &libs_array);
v->Visit("id", &id);
v->Visit("tag", &tag);
v->Visit("keys_", &keys);
v->Visit("attrs", &attrs);
v->Visit("_str_repr_", &str_repr_);
}

/*! \brief Get the keys for this target as a vector of string */
TVM_DLL std::vector<std::string> keys() const;
template <typename TObjectRef>
Optional<TObjectRef> GetAttr(
const std::string& attr_key,
Optional<TObjectRef> default_value = Optional<TObjectRef>(nullptr)) const {
static_assert(std::is_base_of<ObjectRef, TObjectRef>::value,
"Can only call GetAttr with ObjectRef types.");
auto it = attrs.find(attr_key);
if (it != attrs.end()) {
return Downcast<Optional<TObjectRef>>((*it).second);
} else {
return default_value;
}
}

/*! \brief Get the options for this target as a vector of string */
TVM_DLL std::vector<std::string> options() const;
template <typename TObjectRef>
Optional<TObjectRef> GetAttr(const std::string& attr_key, TObjectRef default_value) const {
return GetAttr<TObjectRef>(attr_key, Optional<TObjectRef>(default_value));
}

/*! \brief Get the keys for this target as a vector of string */
TVM_DLL std::vector<std::string> GetKeys() const;

/*! \brief Get the libs for this target as an unordered_set of string */
TVM_DLL std::unordered_set<std::string> libs() const;
TVM_DLL std::unordered_set<std::string> GetLibs() const;

static constexpr const char* _type_key = "Target";
TVM_DECLARE_FINAL_OBJECT_INFO(TargetNode, Object);

private:
/*! \brief Internal string repr. */
mutable std::string str_repr_;
friend class Target;
};

/*!
@@ -102,7 +109,17 @@ class Target : public ObjectRef {
* \brief Create a Target given a string
* \param target_str the string to parse
*/
TVM_DLL static Target Create(const std::string& target_str);
TVM_DLL static Target Create(const String& target_str);
/*!
* \brief Construct a Target node from the given name and options.
* \param name The major target name. Should be one of
* {"aocl", "aocl_sw_emu", "c", "cuda", "ext_dev", "hexagon", "hybrid", "llvm",
* "metal", "nvptx", "opencl", "rocm", "sdaccel", "stackvm", "vulkan"}
* \param options Additional options appended to the target
* \return The constructed Target
*/
TVM_DLL static Target CreateTarget(const std::string& name,
const std::vector<std::string>& options);
/*!
* \brief Get the current target context from thread local storage.
* \param allow_not_defined If the context stack is empty and this is set to true, an
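For reference, a minimal construction sketch using the two entry points declared above; this is not part of the diff, and the concrete target strings and options are illustrative only:

// Hypothetical usage sketch, assuming the usual TVM build environment.
#include <tvm/target/target.h>

void ConstructTargets() {
  // Parse a full target string, as before the migration.
  tvm::Target cuda = tvm::Target::Create("cuda -libs=cudnn,cublas");
  // Build from a target name plus raw option strings; per this commit the options
  // are parsed into the attrs map (see TargetIdNode::ParseAttrsFromRawString).
  tvm::Target llvm = tvm::Target::CreateTarget("llvm", {"-mcpu=skylake"});
  (void)cuda;
  (void)llvm;
}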
55 changes: 55 additions & 0 deletions include/tvm/target/target_id.h
@@ -43,6 +43,8 @@ template <typename, typename, typename>
struct ValueTypeInfoMaker;
}

class Target;

/*! \brief Perform schema validation */
TVM_DLL void TargetValidateSchema(const Map<String, ObjectRef>& config);

@@ -54,6 +56,10 @@ class TargetIdNode : public Object {
public:
/*! \brief Name of the target id */
String name;
/*! \brief Device type of target id */
int device_type;
/*! \brief Default keys of the target */
Array<String> default_keys;
/*! \brief Stores the required type_key and type_index of a specific attr of a target */
struct ValueTypeInfo {
String type_key;
@@ -62,6 +68,14 @@
std::unique_ptr<ValueTypeInfo> val;
};

void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("device_type", &device_type);
v->Visit("default_keys", &default_keys);
}

Map<String, ObjectRef> ParseAttrsFromRawString(const std::vector<std::string>& options);

static constexpr const char* _type_key = "TargetId";
TVM_DECLARE_FINAL_OBJECT_INFO(TargetIdNode, Object);

@@ -72,9 +86,12 @@ class TargetIdNode : public Object {
void ValidateSchema(const Map<String, ObjectRef>& config) const;
/*! \brief A hash table that stores the type information of each attr of the target key */
std::unordered_map<String, ValueTypeInfo> key2vtype_;
/*! \brief A hash table that stores the default value of each attr of the target key */
std::unordered_map<String, ObjectRef> key2default_;
/*! \brief Index used for internal lookup of attribute registry */
uint32_t index_;
friend void TargetValidateSchema(const Map<String, ObjectRef>&);
friend class Target;
friend class TargetId;
template <typename, typename>
friend class AttrRegistry;
@@ -91,6 +108,7 @@
*/
class TargetId : public ObjectRef {
public:
TargetId() = default;
/*! \brief Get the attribute map given the attribute name */
template <typename ValueType>
static inline TargetIdAttrMap<ValueType> GetAttrMap(const String& attr_name);
@@ -110,6 +128,7 @@ class TargetId : public ObjectRef {
template <typename, typename>
friend class AttrRegistry;
friend class TargetIdRegEntry;
friend class Target;
};

/*!
@@ -148,13 +167,31 @@ class TargetIdRegEntry {
template <typename ValueType>
inline TargetIdRegEntry& set_attr(const String& attr_name, const ValueType& value,
int plevel = 10);
/*!
* \brief Set DLPack's device_type of the target
* \param device_type Device type
*/
inline TargetIdRegEntry& set_device_type(int device_type);
/*!
* \brief Set the default keys of the target
* \param keys The default keys
*/
inline TargetIdRegEntry& set_default_keys(std::vector<String> keys);
/*!
* \brief Register a valid configuration option and its ValueType for validation
* \param key The configuration key
* \tparam ValueType The value type to be registered
*/
template <typename ValueType>
inline TargetIdRegEntry& add_attr_option(const String& key);
/*!
* \brief Register a valid configuration option and its ValueType for validation
* \param key The configuration key
* \param default_value The default value of the key
* \tparam ValueType The value type to be registered
*/
template <typename ValueType>
inline TargetIdRegEntry& add_attr_option(const String& key, ObjectRef default_value);
/*! \brief Set the name of the TargetId to be the same as the registry name if it is empty */
inline TargetIdRegEntry& set_name();
/*!
@@ -286,6 +323,16 @@ inline TargetIdRegEntry& TargetIdRegEntry::set_attr(const String& attr_name, con
return *this;
}

inline TargetIdRegEntry& TargetIdRegEntry::set_device_type(int device_type) {
id_->device_type = device_type;
return *this;
}

inline TargetIdRegEntry& TargetIdRegEntry::set_default_keys(std::vector<String> keys) {
id_->default_keys = keys;
return *this;
}

template <typename ValueType>
inline TargetIdRegEntry& TargetIdRegEntry::add_attr_option(const String& key) {
CHECK(!id_->key2vtype_.count(key))
@@ -294,6 +341,14 @@ inline TargetIdRegEntry& TargetIdRegEntry::add_attr_option(const String& key) {
return *this;
}

template <typename ValueType>
inline TargetIdRegEntry& TargetIdRegEntry::add_attr_option(const String& key,
ObjectRef default_value) {
add_attr_option<ValueType>(key);
id_->key2default_[key] = default_value;
return *this;
}

inline TargetIdRegEntry& TargetIdRegEntry::set_name() {
if (id_->name.empty()) {
id_->name = name;
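The new registration hooks above (set_device_type, set_default_keys, and the defaulted add_attr_option) are meant to be chained from a target registration entry. A hedged sketch follows; the TVM_REGISTER_TARGET_ID macro name and the concrete values are assumptions, since the actual registrations live outside the hunks shown:

// Hypothetical registration sketch (inside namespace tvm, e.g. in a target .cc file).
// The macro name is assumed; the chained entry methods are the ones added in this diff.
TVM_REGISTER_TARGET_ID("cuda")
    .set_device_type(kDLGPU)                      // DLPack device type
    .set_default_keys({"cuda", "gpu"})            // default keys for strategy dispatch
    .add_attr_option<Integer>("max_num_threads",  // option with a registered default
                              Integer(1024))
    .add_attr_option<String>("mcpu");             // option validated by type only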
9 changes: 4 additions & 5 deletions python/tvm/autotvm/tophub.py
@@ -103,11 +103,10 @@ def context(target, extra_files=None):
tgt = _target.create(tgt)

possible_names = []
for opt in tgt.options:
if opt.startswith("-device"):
device = _alias(opt[8:])
possible_names.append(device)
possible_names.append(tgt.target_name)
device = tgt.attrs.get("device", "")
if device != "":
possible_names.append(_alias(device))
possible_names.append(tgt.id.name)

all_packages = list(PACKAGE_VERSION.keys())
for name in possible_names:
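The change above reads the optional "device" attribute out of the new attrs map instead of scanning raw option strings. The C++ counterpart goes through the templated TargetNode::GetAttr added in include/tvm/target/target.h; a minimal sketch, assuming a valid target handle:

// Hypothetical C++ counterpart of the Python lookup above.
#include <tvm/target/target.h>
#include <string>
#include <unordered_set>
#include <vector>

void InspectTarget(const tvm::Target& target) {
  // Optional "device" attribute with an empty-string default (mirrors tgt.attrs.get("device", "")).
  tvm::String device = target->GetAttr<tvm::String>("device", tvm::String("")).value();
  // The former target_name now lives on the TargetId (mirrors tgt.id.name).
  tvm::String name = target->id->name;
  // keys()/libs() were renamed to GetKeys()/GetLibs().
  std::vector<std::string> keys = target->GetKeys();
  std::unordered_set<std::string> libs = target->GetLibs();
  (void)device; (void)name; (void)keys; (void)libs;
}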
4 changes: 2 additions & 2 deletions python/tvm/driver/build_module.py
@@ -238,7 +238,7 @@ def _build_for_device(input_mod, target, target_host):
"""
target = _target.create(target)
target_host = _target.create(target_host)
device_type = ndarray.context(target.target_name, 0).device_type
device_type = ndarray.context(target.id.name, 0).device_type

mod_mixed = input_mod
mod_mixed = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod_mixed)
@@ -402,7 +402,7 @@ def build(inputs,
if not target_host:
for tar, _ in target_input_mod.items():
tar = _target.create(tar)
device_type = ndarray.context(tar.target_name, 0).device_type
device_type = ndarray.context(tar.id.name, 0).device_type
if device_type == ndarray.cpu(0).device_type:
target_host = tar
break
22 changes: 11 additions & 11 deletions python/tvm/relay/op/strategy/cuda.py
@@ -68,7 +68,7 @@ def softmax_strategy_cuda(attrs, inputs, out_type, target):
wrap_compute_softmax(topi.nn.softmax),
wrap_topi_schedule(topi.cuda.schedule_softmax),
name="softmax.cuda")
if target.target_name == "cuda" and "cudnn" in target.libs:
if target.id.name == "cuda" and "cudnn" in target.libs:
strategy.add_implementation(
wrap_compute_softmax(topi.cuda.softmax_cudnn),
wrap_topi_schedule(topi.cuda.schedule_softmax_cudnn),
@@ -145,7 +145,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
dilation_h, dilation_w,
pre_flag=False)
if judge_winograd_shape:
if target.target_name == "cuda" and \
if target.id.name == "cuda" and \
nvcc.have_tensorcore(tvm.gpu(0).compute_version) and \
judge_winograd_tensorcore:
strategy.add_implementation(
@@ -162,7 +162,7 @@
topi.cuda.schedule_conv2d_nhwc_winograd_direct),
name="conv2d_nhwc_winograd_direct.cuda",
plevel=5)
if target.target_name == "cuda":
if target.id.name == "cuda":
if nvcc.have_tensorcore(tvm.gpu(0).compute_version):
if (N % 16 == 0 and CI % 16 == 0 and CO % 16 == 0) or \
(N % 8 == 0 and CI % 16 == 0 and CO % 32 == 0) or \
@@ -181,7 +181,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
else:
raise RuntimeError("Unsupported conv2d layout {} for CUDA".format(layout))
# add cudnn implementation
if target.target_name == "cuda" and "cudnn" in target.libs:
if target.id.name == "cuda" and "cudnn" in target.libs:
if layout in ["NCHW", "NHWC"] and padding[0] == padding[2] and \
padding[1] == padding[3]:
strategy.add_implementation(
@@ -209,7 +209,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
else: # group_conv2d
# add cudnn implementation, if any
cudnn_impl = False
if target.target_name == "cuda" and "cudnn" in target.libs:
if target.id.name == "cuda" and "cudnn" in target.libs:
if layout in ["NCHW", "NHWC"] and padding[0] == padding[2] and \
padding[1] == padding[3]:
strategy.add_implementation(
@@ -264,7 +264,7 @@ def conv2d_winograd_without_weight_transfrom_strategy_cuda(attrs, inputs, out_ty
padding, stride_h, stride_w,
dilation_h, dilation_w,
pre_flag=True)
if target.target_name == "cuda" and \
if target.id.name == "cuda" and \
nvcc.have_tensorcore(tvm.gpu(0).compute_version) and \
judge_winograd_tensorcore:
strategy.add_implementation(
@@ -362,7 +362,7 @@ def conv3d_strategy_cuda(attrs, inputs, out_type, target):
plevel=10)
N, _, _, _, _ = get_const_tuple(data.shape)
_, _, _, CI, CO = get_const_tuple(kernel.shape)
if target.target_name == "cuda":
if target.id.name == "cuda":
if nvcc.have_tensorcore(tvm.gpu(0).compute_version):
if (N % 16 == 0 and CI % 16 == 0 and CO % 16 == 0) or \
(N % 8 == 0 and CI % 16 == 0 and CO % 32 == 0) or \
@@ -373,7 +373,7 @@
name="conv3d_ndhwc_tensorcore.cuda",
plevel=20)

if target.target_name == "cuda" and "cudnn" in target.libs:
if target.id.name == "cuda" and "cudnn" in target.libs:
strategy.add_implementation(wrap_compute_conv3d(topi.cuda.conv3d_cudnn, True),
wrap_topi_schedule(topi.cuda.schedule_conv3d_cudnn),
name="conv3d_cudnn.cuda",
@@ -458,7 +458,7 @@ def dense_strategy_cuda(attrs, inputs, out_type, target):
wrap_topi_schedule(topi.cuda.schedule_dense_large_batch),
name="dense_large_batch.cuda",
plevel=5)
if target.target_name == "cuda":
if target.id.name == "cuda":
if nvcc.have_tensorcore(tvm.gpu(0).compute_version):
if(i % 16 == 0 and b % 16 == 0 and o % 16 == 0) \
or (i % 16 == 0 and b % 8 == 0 and o % 32 == 0) \
@@ -468,7 +468,7 @@
wrap_topi_schedule(topi.cuda.schedule_dense_tensorcore),
name="dense_tensorcore.cuda",
plevel=20)
if target.target_name == "cuda" and "cublas" in target.libs:
if target.id.name == "cuda" and "cublas" in target.libs:
strategy.add_implementation(
wrap_compute_dense(topi.cuda.dense_cublas),
wrap_topi_schedule(topi.cuda.schedule_dense_cublas),
@@ -485,7 +485,7 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
wrap_topi_schedule(topi.cuda.schedule_batch_matmul),
name="batch_matmul.cuda",
plevel=10)
if target.target_name == "cuda" and "cublas" in target.libs:
if target.id.name == "cuda" and "cublas" in target.libs:
strategy.add_implementation(
wrap_compute_batch_matmul(topi.cuda.batch_matmul_cublas),
wrap_topi_schedule(topi.generic.schedule_extern),
2 changes: 1 addition & 1 deletion python/tvm/relay/op/strategy/rocm.py
@@ -127,7 +127,7 @@ def dense_strategy_rocm(attrs, inputs, out_type, target):
wrap_compute_dense(topi.rocm.dense),
wrap_topi_schedule(topi.rocm.schedule_dense),
name="dense.rocm")
if target.target_name == "rocm" and "rocblas" in target.libs:
if target.id.name == "rocm" and "rocblas" in target.libs:
assert out_type.dtype == inputs[0].dtype, "Mixed precision not supported."
strategy.add_implementation(
wrap_compute_dense(topi.rocm.dense_rocblas),
(Diffs for the remaining changed files are not shown.)
