Skip to content

Commit

Permalink
[BATCH] use ov::device::properties to replace batch size of BATCH:CPU(4)
Browse files Browse the repository at this point in the history
  • Loading branch information
riverlijunjie committed Mar 11, 2024
1 parent 5872549 commit f1bd8c1
Show file tree
Hide file tree
Showing 31 changed files with 210 additions and 123 deletions.
10 changes: 1 addition & 9 deletions src/inference/src/dev/core_impl.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (C) 2018-2023 Intel Corporation
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

Expand Down Expand Up @@ -57,13 +57,6 @@ void allowNotImplemented(F&& f) {
}
}

void stripDeviceName(std::string& device, const std::string& substr) {
auto pos = device.find(substr);
if (pos == 0) {
device.erase(pos, substr.length());
}
}

bool is_virtual_device(const std::string& device_name) {
return (device_name.find("AUTO") != std::string::npos || device_name.find("MULTI") != std::string::npos ||
device_name.find("HETERO") != std::string::npos || device_name.find("BATCH") != std::string::npos);
Expand Down Expand Up @@ -539,7 +532,6 @@ ov::Plugin ov::CoreImpl::get_plugin(const std::string& pluginName) const {
auto deviceName = pluginName;
if (deviceName == ov::DEFAULT_DEVICE_NAME)
deviceName = "AUTO";
stripDeviceName(deviceName, "-");
std::map<std::string, PluginDescriptor>::const_iterator it;
{
// Global lock to find plugin.
Expand Down
8 changes: 1 addition & 7 deletions src/inference/src/dev/device_id_parser.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (C) 2018-2023 Intel Corporation
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

Expand Down Expand Up @@ -84,12 +84,6 @@ std::vector<std::string> DeviceIDParser::get_multi_devices(const std::string& de
}

std::string DeviceIDParser::get_batch_device(const std::string& device) {
if (device.find(",") != std::string::npos) {
OPENVINO_THROW("BATCH accepts only one device in list but got '", device, "'");
}
if (device.find("-") != std::string::npos) {
OPENVINO_THROW("Invalid device name '", device, "' for BATCH");
}
auto trim_request_info = [](const std::string& device_with_requests) {
auto opening_bracket = device_with_requests.find_first_of('(');
return device_with_requests.substr(0, opening_bracket);
Expand Down
18 changes: 0 additions & 18 deletions src/inference/tests/unit/core.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -384,24 +384,6 @@ TEST(CoreTests_parse_device_config, get_device_config) {
ov::device::properties(ov::AnyMap{{"MULTI", ov::AnyMap{ov::device::priorities("DEVICE")}}})});
}

TEST(CoreTests_parse_device_config, get_batch_device_name) {
EXPECT_STREQ(ov::DeviceIDParser::get_batch_device("CPU").c_str(), "CPU");
EXPECT_STREQ(ov::DeviceIDParser::get_batch_device("GPU(4)").c_str(), "GPU");

OV_EXPECT_THROW(ov::DeviceIDParser::get_batch_device("-CPU"),
ov::Exception,
::testing::HasSubstr("Invalid device name '-CPU' for BATCH"));
OV_EXPECT_THROW(ov::DeviceIDParser::get_batch_device("CPU(0)-"),
ov::Exception,
::testing::HasSubstr("Invalid device name 'CPU(0)-' for BATCH"));
OV_EXPECT_THROW(ov::DeviceIDParser::get_batch_device("GPU(4),CPU"),
ov::Exception,
::testing::HasSubstr("BATCH accepts only one device in list but got 'GPU(4),CPU'"));
OV_EXPECT_THROW(ov::DeviceIDParser::get_batch_device("CPU,GPU"),
ov::Exception,
::testing::HasSubstr("BATCH accepts only one device in list but got 'CPU,GPU'"));
}

class ApplyAutoBatchThreading : public testing::Test {
public:
static void runParallel(std::function<void(void)> func,
Expand Down
3 changes: 3 additions & 0 deletions src/plugins/auto/src/plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,9 @@ std::vector<DeviceInformation> Plugin::parse_meta_devices(const std::string& pri
auto opening_bracket = d.find_first_of('(');
auto closing_bracket = d.find_first_of(')', opening_bracket);
auto device_name = d.substr(0, opening_bracket);
if (closing_bracket != std::string::npos && closing_bracket < d.length() - 1) {
OPENVINO_THROW("Device list with \"", d, "\" name is illegal in the AUTO plugin.");
}

int num_requests = -1;
if (closing_bracket != std::string::npos && opening_bracket < closing_bracket) {
Expand Down
4 changes: 3 additions & 1 deletion src/plugins/auto/tests/unit/parse_meta_device_test.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (C) 2018-2023 Intel Corporation
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

Expand Down Expand Up @@ -159,6 +159,8 @@ const std::vector<ConfigParams> testConfigs = {

ConfigParams{"CPU(-1),GPU,OTHER", {}, true, 0},
ConfigParams{"CPU(NA),GPU,OTHER", {}, true, 0},
ConfigParams{"CPU(4)a", {}, true, 0},
ConfigParams{"CPU(4)a,GPU,OTHER", {}, true, 0},
ConfigParams{"INVALID_DEVICE", {}, false, 0},
ConfigParams{"INVALID_DEVICE,CPU", {{"CPU", {}, -1, "", "CPU_", 1}}, false, 2},

Expand Down
50 changes: 47 additions & 3 deletions src/plugins/auto_batch/src/plugin.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (C) 2018-2023 Intel Corporation
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

Expand All @@ -19,7 +19,9 @@
namespace ov {
namespace autobatch_plugin {

std::vector<std::string> supported_configKeys = {ov::device::priorities.name(), ov::auto_batch_timeout.name()};
std::vector<std::string> supported_configKeys = {ov::device::priorities.name(),
ov::auto_batch_timeout.name(),
ov::device::properties.name()};

inline ov::AnyMap merge_properties(ov::AnyMap config, const ov::AnyMap& user_config) {
for (auto&& kvp : user_config) {
Expand All @@ -29,6 +31,9 @@ inline ov::AnyMap merge_properties(ov::AnyMap config, const ov::AnyMap& user_con
}

DeviceInformation Plugin::parse_batch_device(const std::string& device_with_batch) {
if (device_with_batch.find("-") != std::string::npos) {
OPENVINO_THROW("Invalide Batch device name ", device_with_batch);
}
auto openingBracket = device_with_batch.find_first_of('(');
auto closingBracket = device_with_batch.find_first_of(')', openingBracket);
auto deviceName = device_with_batch.substr(0, openingBracket);
Expand All @@ -44,9 +49,48 @@ DeviceInformation Plugin::parse_batch_device(const std::string& device_with_batc
return {std::move(deviceName), {{}}, static_cast<uint32_t>(batch)};
}

uint32_t Plugin::parse_batch_size(const std::string& device_name, const ov::AnyMap& properties) {
uint32_t num_requests = 0;
// Parse batch_size from ov::device::properties
auto item = properties.find(ov::device::properties.name());
if (item != properties.end()) {
ov::AnyMap devices_properties = item->second.as<ov::AnyMap>();
auto it = devices_properties.find(device_name);

if (it != devices_properties.end()) {
auto props = it->second.as<ov::AnyMap>();
if (props.find(ov::hint::num_requests.name()) != props.end()) {
try {
num_requests = props.at(ov::hint::num_requests.name()).as<uint32_t>();
if ((num_requests == 0) || num_requests == (uint32_t)-1) {
OPENVINO_THROW("BATCH can got valid num_request: ",
props.at(ov::hint::num_requests.name()).as<std::string>());
}
} catch (...) {
OPENVINO_THROW("BATCH can got valid num_request: ",
props.at(ov::hint::num_requests.name()).as<std::string>());
}
}
}
}
return num_requests;
}

DeviceInformation Plugin::parse_meta_device(const std::string& devices_batch_config,
const ov::AnyMap& user_config) const {
if (devices_batch_config.find(",") != std::string::npos) {
OPENVINO_THROW("BATCH accepts only one device in list but got '", devices_batch_config, "'");
}
// Batch_size will be got from ov::device::properties, while devices_name_with_batch_config will be deprecated
// after 25.0.
// For example:
// DeviceName = "BATCH:GPU", ov::device::properties("GPU",ov::hint::num_requests(8)),
// while similar DeviceName = "BATCH:GPU(8)" will be deprecated.
auto meta_device = parse_batch_device(devices_batch_config);
auto batch_size = parse_batch_size(meta_device.device_name, user_config);
if (batch_size > 0) {
meta_device.device_batch_size = batch_size;
}
meta_device.device_config = get_core()->get_supported_property(meta_device.device_name, user_config);
// check that no irrelevant config-keys left
for (const auto& k : user_config) {
Expand Down Expand Up @@ -130,7 +174,7 @@ std::shared_ptr<ov::ICompiledModel> Plugin::compile_model(const std::shared_ptr<
auto full_properties = merge_properties(m_plugin_config, properties);
auto device_batch = full_properties.find(ov::device::priorities.name());
if (device_batch == full_properties.end()) {
OPENVINO_THROW("ov::device::priorities key for AUTO NATCH is not set for BATCH device");
OPENVINO_THROW("ov::device::priorities key for AUTO BATCH is not set for BATCH device");
}
auto meta_device = parse_meta_device(device_batch->second.as<std::string>(), properties);

Expand Down
4 changes: 3 additions & 1 deletion src/plugins/auto_batch/src/plugin.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (C) 2018-2023 Intel Corporation
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

Expand Down Expand Up @@ -64,6 +64,8 @@ class Plugin : public ov::IPlugin {

static DeviceInformation parse_batch_device(const std::string& device_with_batch);

static uint32_t parse_batch_size(const std::string& device_name, const ov::AnyMap& properties);

private:
mutable ov::AnyMap m_plugin_config;
};
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (C) 2018-2023 Intel Corporation
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

Expand All @@ -9,9 +9,11 @@ namespace {
auto autoBatchConfigs = []() {
return std::vector<ov::AnyMap>{
// explicit batch size 4 to avoid fallback to no auto-batching
{{ov::device::priorities.name(), std::string(ov::test::utils::DEVICE_TEMPLATE) + "(4)"},
// no timeout to avoid increasing the test time
{ov::auto_batch_timeout.name(), "0"}}};
{ov::device::priorities(ov::test::utils::DEVICE_TEMPLATE)},
{ov::device::priorities(ov::test::utils::DEVICE_TEMPLATE),
ov::device::properties(ov::test::utils::DEVICE_TEMPLATE, ov::hint::num_requests(4))},
// no timeout to avoid increasing the test time
{ov::device::priorities(ov::test::utils::DEVICE_TEMPLATE), ov::auto_batch_timeout(0)}};
};

INSTANTIATE_TEST_SUITE_P(smoke_AutoBatch_BehaviorTests, OVCompiledModelBaseTest,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (C) 2018-2023 Intel Corporation
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

Expand All @@ -20,10 +20,13 @@ INSTANTIATE_TEST_SUITE_P(smoke_AutoBatch_BehaviorTests,
OVClassCompiledModelPropertiesIncorrectTests::getTestCaseName);

const std::vector<ov::AnyMap> auto_batch_properties = {
{{ov::device::priorities.name(), std::string(ov::test::utils::DEVICE_TEMPLATE) + "(4)"}},
{{ov::device::priorities.name(), std::string(ov::test::utils::DEVICE_TEMPLATE) + "(4)"},
{{ov::device::priorities.name(), std::string(ov::test::utils::DEVICE_TEMPLATE)},
{ov::device::properties(ov::test::utils::DEVICE_TEMPLATE, ov::hint::num_requests(4))}},
{{ov::device::priorities.name(), std::string(ov::test::utils::DEVICE_TEMPLATE)},
{ov::device::properties(ov::test::utils::DEVICE_TEMPLATE, ov::hint::num_requests(4))},
{ov::auto_batch_timeout(1)}},
{{ov::device::priorities.name(), std::string(ov::test::utils::DEVICE_TEMPLATE) + "(4)"},
{{ov::device::priorities.name(), std::string(ov::test::utils::DEVICE_TEMPLATE)},
{ov::device::properties(ov::test::utils::DEVICE_TEMPLATE, ov::hint::num_requests(4))},
{ov::auto_batch_timeout(10)}},
};

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (C) 2018-2023 Intel Corporation
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

Expand All @@ -10,7 +10,8 @@ namespace {
auto autoBatchConfigs = []() {
return std::vector<ov::AnyMap>{
// explicit batch size 4 to avoid fallback to no auto-batching
{{ov::device::priorities.name(), std::string(ov::test::utils::DEVICE_TEMPLATE) + "(4)"},
{{ov::device::priorities.name(), std::string(ov::test::utils::DEVICE_TEMPLATE)},
{ov::device::properties(ov::test::utils::DEVICE_TEMPLATE, ov::hint::num_requests(4))},
// no timeout to avoid increasing the test time
{ov::auto_batch_timeout(0)}}};
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ namespace {
auto autoBatchConfigs = []() {
return std::vector<ov::AnyMap>{
// explicit batch size 4 to avoid fallback to no auto-batching
{{ov::device::priorities.name(), std::string(ov::test::utils::DEVICE_TEMPLATE) + "(4)"},
{{ov::device::priorities.name(), std::string(ov::test::utils::DEVICE_TEMPLATE)},
{ov::device::properties(ov::test::utils::DEVICE_TEMPLATE, ov::hint::num_requests(4))},
// no timeout to avoid increasing the test time
{ov::auto_batch_timeout(0)}}};
};
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (C) 2018-2023 Intel Corporation
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

Expand All @@ -10,7 +10,8 @@ namespace {
auto AutoBatchConfigs = []() {
return std::vector<ov::AnyMap>{
// explicit batch size 4 to avoid fallback to no auto-batching
{{ov::device::priorities.name(), std::string(ov::test::utils::DEVICE_TEMPLATE) + "(4)"},
{{ov::device::priorities.name(), std::string(ov::test::utils::DEVICE_TEMPLATE)},
{ov::device::properties(ov::test::utils::DEVICE_TEMPLATE, ov::hint::num_requests(4))},
// no timeout to avoid increasing the test time
{ov::auto_batch_timeout(0)}}};
};
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (C) 2018-2023 Intel Corporation
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

Expand All @@ -10,7 +10,8 @@ namespace {
auto AutoBatchConfigs = []() {
return std::vector<ov::AnyMap>{
// explicit batch size 4 to avoid fallback to no auto-batching
{{ov::device::priorities.name(), std::string(ov::test::utils::DEVICE_TEMPLATE) + "(4)"},
{{ov::device::priorities.name(), std::string(ov::test::utils::DEVICE_TEMPLATE)},
{ov::device::properties(ov::test::utils::DEVICE_TEMPLATE, ov::hint::num_requests(4))},
// no timeout to avoid increasing the test time
{ov::auto_batch_timeout(0)}}};
};
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (C) 2018-2023 Intel Corporation
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

Expand All @@ -10,7 +10,8 @@ namespace {
auto AutoBatchConfigs = []() {
return std::vector<ov::AnyMap>{
// explicit batch size 4 to avoid fallback to no auto-batching
{{ov::device::priorities.name(), std::string(ov::test::utils::DEVICE_TEMPLATE) + "(4)"},
{{ov::device::priorities.name(), std::string(ov::test::utils::DEVICE_TEMPLATE)},
{ov::device::properties(ov::test::utils::DEVICE_TEMPLATE, ov::hint::num_requests(4))},
// no timeout to avoid increasing the test time
{ov::auto_batch_timeout(0)}}};
};
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (C) 2018-2023 Intel Corporation
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

Expand All @@ -10,7 +10,8 @@ namespace {
auto AutoBatchConfigs = []() {
return std::vector<ov::AnyMap>{
// explicit batch size 4 to avoid fallback to no auto-batching
{{ov::device::priorities.name(), std::string(ov::test::utils::DEVICE_TEMPLATE) + "(4)"},
{{ov::device::priorities.name(), std::string(ov::test::utils::DEVICE_TEMPLATE)},
{ov::device::properties(ov::test::utils::DEVICE_TEMPLATE, ov::hint::num_requests(4))},
// no timeout to avoid increasing the test time
{ov::auto_batch_timeout(0)}}};
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ std::vector<std::pair<ov::AnyMap, ov::AnyMap>> generate_remote_params() {
auto AutoBatchConfigs = []() {
return std::vector<ov::AnyMap>{
// explicit batch size 4 to avoid fallback to no auto-batching
{{ov::device::priorities.name(), std::string(ov::test::utils::DEVICE_TEMPLATE) + "(4)"},
{{ov::device::priorities.name(), std::string(ov::test::utils::DEVICE_TEMPLATE)},
{ov::device::properties(ov::test::utils::DEVICE_TEMPLATE, ov::hint::num_requests(4))},
// no timeout to avoid increasing the test time
ov::auto_batch_timeout(0)}};
};
Expand Down
Loading

0 comments on commit f1bd8c1

Please sign in to comment.