diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 9a0d2a2d7f43..208e6356355d 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -372,42 +372,8 @@ class LowerTensorExpr : public ExprMutator { "in the memory planner."; auto& device_context = this->device_context_map_[expr]; - auto call_dev_type = device_context.device_type; - + target = GetTargetFromInteger(device_context.device_type, targets_); // Non-External Relay Function - if (targets_.size() == 1) { - // The homogeneous execution case, we should only have one target - // so we just grab it. - const auto& it = targets_.begin(); - target = (*it).second; - } else { - // The heterogeneous execution case we have multiple targets - // in this case. - // - // We need to identify the target and translate. - std::string call_dev_name; - if (call_dev_type == 0) { - call_dev_name = "llvm"; - call_dev_type = kDLCPU; - } else { - call_dev_name = ::tvm::runtime::DeviceName(call_dev_type); - } - - if (targets_.count(call_dev_type) == 0) { - std::stringstream msg; - msg << "No target is specified for provided device name: `" << call_dev_name << "`\n\n"; - msg << call_dev_name << " mapped to device type (" << call_dev_type - << ") which was not found in the target map.\n"; - msg << "Availible targets: \n"; - for (auto target : targets_) { - msg << " " << target.first << "-> " << target.second << "\n"; - } - LOG(FATAL) << msg.str(); - } - - target = targets_[call_dev_type]; - } - CCacheKey key = CCacheKey(func, target); CachedFunc lowered_func = compiler_->Lower(key, module_name_); @@ -465,19 +431,29 @@ class LowerTensorExpr : public ExprMutator { */ Target GetTargetFromInteger(DLDeviceType dev_type, TargetMap targets) { if (targets.size() == 1) { - // homogeneous execution. + // The homogeneous execution case, return the only target. const auto& it = targets.begin(); return (*it).second; } else { - // heterogeneous execution. - std::string call_dev_name; - if (dev_type == 0) { - call_dev_name = "llvm"; - } else { - call_dev_name = runtime::DeviceName(dev_type); + // The heterogeneous execution case, return the target associated with the + // given device type. + // If "dev_type" equals to 0, the device name only can be got from + // "targets", and it may not be "llvm", so here just set it to "unknown". + std::string dev_name = "unknown"; + if (dev_type != 0) { + dev_name = runtime::DeviceName(dev_type); } + if (targets.count(dev_type) == 0) { - LOG(FATAL) << "No target is provided for device " << call_dev_name; + std::stringstream msg; + msg << "No target is specified for provided device name: `" << dev_name << "`\n\n" + << dev_name << " mapped to device type (" << dev_type + << ") which was not found in the target map.\n" + << "Availible targets: \n"; + for (auto target : targets) { + msg << " " << target.first << "-> " << target.second << "\n"; + } + LOG(FATAL) << msg.str(); } return targets[dev_type]; }