Skip to content

Commit

Permalink
Merge pull request openvinotoolkit#22 from yangwang201911/ywang2/enab…
Browse files Browse the repository at this point in the history
…le_selection_logic_with_specified_GPU_device_list

[GPU] fix the bug for TP device selection.
  • Loading branch information
WeldonWangwang authored Aug 30, 2024
2 parents 4bc34b2 + 800c81d commit cdeb24b
Showing 1 changed file with 17 additions and 5 deletions.
22 changes: 17 additions & 5 deletions src/plugins/intel_gpu/src/plugin/plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,10 @@ std::shared_ptr<ov::ICompiledModel> Plugin::compile_model(const std::shared_ptr<
if (dotPos != std::string::npos) {
auto target_id = device_with_id.substr(dotPos + 1);
if (m_device_map.count(target_id)) {
if (m_device_map.at(target_id)->get_info().dev_type == device_ptr->get_info().dev_type)
if (target_id != device_id &&
m_device_map.at(target_id)->get_info().dev_type == device_ptr->get_info().dev_type) {
ret.push_back(target_id);
}
} else {
OPENVINO_THROW("Invalid device found for TP: ", device_with_id);
}
Expand All @@ -230,8 +232,10 @@ std::shared_ptr<ov::ICompiledModel> Plugin::compile_model(const std::shared_ptr<
if (dotPos != std::string::npos) {
auto target_id = last.substr(dotPos + 1);
if (m_device_map.count(target_id)) {
if (m_device_map.at(target_id)->get_info().dev_type == device_ptr->get_info().dev_type)
if (target_id != device_id &&
m_device_map.at(target_id)->get_info().dev_type == device_ptr->get_info().dev_type) {
ret.push_back(target_id);
}
} else {
OPENVINO_THROW("Invalid device found for TP: ", last);
}
Expand All @@ -241,18 +245,26 @@ std::shared_ptr<ov::ICompiledModel> Plugin::compile_model(const std::shared_ptr<
devices_for_tp);
if (ret.size() == 1) {
auto id = ret.front();
for (const auto& item : m_device_map) {
if (item.first != id && item.second->get_info().dev_type == device_ptr->get_info().dev_type)
ret.push_back(item.first);
if (id != device_id && device_ptr->get_info().dev_type == m_device_map.at(id)->get_info().dev_type) {
ret.push_back(device_id);
} else {
for (const auto& item : m_device_map) {
if (item.first != id && item.second->get_info().dev_type == device_ptr->get_info().dev_type) {
ret.push_back(item.first);
}
}
}
}
if (ret.size() > 2) {
GPU_DEBUG_LOG << "Will only select 2 devices for TP." << std::endl;
std::cout << "[WY-DEBUG][" << __FILE__ << ":" << __LINE__ << "] will keep the first 2 device from list.";
ret = std::vector<std::string>(ret.begin(), ret.begin() + 2);
}
return ret;
};
auto devices_id_for_tp = parse_devices_id(devices_for_tp);
std::cout << "[WY-DEBUG][" << __FILE__ << ":" << __LINE__
<< "] device priorities after filtered: " << devices_for_tp << std::endl;
if (1) {
auto get_rank_table = [&]() {
std::vector<std::vector<int>> rank_table = {};
Expand Down

0 comments on commit cdeb24b

Please sign in to comment.