From 800c81d505c12ff4b641af383afc91267a6a411a Mon Sep 17 00:00:00 2001 From: "Wang, Yang" Date: Fri, 30 Aug 2024 15:39:32 +0800 Subject: [PATCH] Update. --- src/plugins/intel_gpu/src/plugin/plugin.cpp | 22 ++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/src/plugins/intel_gpu/src/plugin/plugin.cpp b/src/plugins/intel_gpu/src/plugin/plugin.cpp index 48fefc429b7c6c..534a12e3796a4b 100644 --- a/src/plugins/intel_gpu/src/plugin/plugin.cpp +++ b/src/plugins/intel_gpu/src/plugin/plugin.cpp @@ -211,8 +211,10 @@ std::shared_ptr Plugin::compile_model(const std::shared_ptr< if (dotPos != std::string::npos) { auto target_id = device_with_id.substr(dotPos + 1); if (m_device_map.count(target_id)) { - if (m_device_map.at(target_id)->get_info().dev_type == device_ptr->get_info().dev_type) + if (target_id != device_id && + m_device_map.at(target_id)->get_info().dev_type == device_ptr->get_info().dev_type) { ret.push_back(target_id); + } } else { OPENVINO_THROW("Invalid device found for TP: ", device_with_id); } @@ -225,8 +227,10 @@ std::shared_ptr Plugin::compile_model(const std::shared_ptr< if (dotPos != std::string::npos) { auto target_id = last.substr(dotPos + 1); if (m_device_map.count(target_id)) { - if (m_device_map.at(target_id)->get_info().dev_type == device_ptr->get_info().dev_type) + if (target_id != device_id && + m_device_map.at(target_id)->get_info().dev_type == device_ptr->get_info().dev_type) { ret.push_back(target_id); + } } else { OPENVINO_THROW("Invalid device found for TP: ", last); } @@ -236,18 +240,26 @@ std::shared_ptr Plugin::compile_model(const std::shared_ptr< devices_for_tp); if (ret.size() == 1) { auto id = ret.front(); - for (const auto& item : m_device_map) { - if (item.first != id && item.second->get_info().dev_type == device_ptr->get_info().dev_type) - ret.push_back(item.first); + if (id != device_id && device_ptr->get_info().dev_type == m_device_map.at(id)->get_info().dev_type) { + ret.push_back(device_id); + } else { + for (const auto& item : m_device_map) { + if (item.first != id && item.second->get_info().dev_type == device_ptr->get_info().dev_type) { + ret.push_back(item.first); + } + } } } if (ret.size() > 2) { GPU_DEBUG_LOG << "Will only select 2 devices for TP." << std::endl; + std::cout << "[WY-DEBUG][" << __FILE__ << ":" << __LINE__ << "] will keep the first 2 device from list."; ret = std::vector(ret.begin(), ret.begin() + 2); } return ret; }; auto devices_id_for_tp = parse_devices_id(devices_for_tp); + std::cout << "[WY-DEBUG][" << __FILE__ << ":" << __LINE__ + << "] device priorities after filtered: " << devices_for_tp << std::endl; if (1) { auto get_rank_table = [&]() { std::vector> rank_table = {};