Skip to content

Commit

Permalink
[Bugfix] Fix #8536 Get Target When Heterogeneous Execution (#8537)
Browse files Browse the repository at this point in the history
  • Loading branch information
Johnson9009 authored Jul 29, 2021
1 parent 83ce7fe commit 00ad44e
Showing 1 changed file with 19 additions and 43 deletions.
62 changes: 19 additions & 43 deletions src/relay/backend/te_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -372,42 +372,8 @@ class LowerTensorExpr : public ExprMutator {
"in the memory planner.";

auto& device_context = this->device_context_map_[expr];
auto call_dev_type = device_context.device_type;

target = GetTargetFromInteger(device_context.device_type, targets_);
// Non-External Relay Function
if (targets_.size() == 1) {
// The homogeneous execution case, we should only have one target
// so we just grab it.
const auto& it = targets_.begin();
target = (*it).second;
} else {
// The heterogeneous execution case we have multiple targets
// in this case.
//
// We need to identify the target and translate.
std::string call_dev_name;
if (call_dev_type == 0) {
call_dev_name = "llvm";
call_dev_type = kDLCPU;
} else {
call_dev_name = ::tvm::runtime::DeviceName(call_dev_type);
}

if (targets_.count(call_dev_type) == 0) {
std::stringstream msg;
msg << "No target is specified for provided device name: `" << call_dev_name << "`\n\n";
msg << call_dev_name << " mapped to device type (" << call_dev_type
<< ") which was not found in the target map.\n";
msg << "Availible targets: \n";
for (auto target : targets_) {
msg << " " << target.first << "-> " << target.second << "\n";
}
LOG(FATAL) << msg.str();
}

target = targets_[call_dev_type];
}

CCacheKey key = CCacheKey(func, target);
CachedFunc lowered_func = compiler_->Lower(key, module_name_);

Expand Down Expand Up @@ -465,19 +431,29 @@ class LowerTensorExpr : public ExprMutator {
*/
Target GetTargetFromInteger(DLDeviceType dev_type, TargetMap targets) {
if (targets.size() == 1) {
// homogeneous execution.
// The homogeneous execution case, return the only target.
const auto& it = targets.begin();
return (*it).second;
} else {
// heterogeneous execution.
std::string call_dev_name;
if (dev_type == 0) {
call_dev_name = "llvm";
} else {
call_dev_name = runtime::DeviceName(dev_type);
// The heterogeneous execution case, return the target associated with the
// given device type.
// If "dev_type" equals to 0, the device name only can be got from
// "targets", and it may not be "llvm", so here just set it to "unknown".
std::string dev_name = "unknown";
if (dev_type != 0) {
dev_name = runtime::DeviceName(dev_type);
}

if (targets.count(dev_type) == 0) {
LOG(FATAL) << "No target is provided for device " << call_dev_name;
std::stringstream msg;
msg << "No target is specified for provided device name: `" << dev_name << "`\n\n"
<< dev_name << " mapped to device type (" << dev_type
<< ") which was not found in the target map.\n"
<< "Availible targets: \n";
for (auto target : targets) {
msg << " " << target.first << "-> " << target.second << "\n";
}
LOG(FATAL) << msg.str();
}
return targets[dev_type];
}
Expand Down

0 comments on commit 00ad44e

Please sign in to comment.