From 67e896cd91447203155616c3bdd8867e178d06cd Mon Sep 17 00:00:00 2001 From: Alicja Miloszewska Date: Thu, 18 Apr 2024 11:59:03 +0200 Subject: [PATCH] [JS API] Add config param to Core.import_model() (#24023) ### Details: - add `Core.import_model(stream, device_name, config[optional}) - add ` ov::AnyMap to_anyMap(const Napi::Env&, const Napi::Value&)` helper for conversion from Napi::Value to ov::AnyMap and reuse it in other methods ### Tickets: - *136460* --- src/bindings/js/node/include/helper.hpp | 6 ++- src/bindings/js/node/lib/addon.ts | 5 ++ src/bindings/js/node/src/core_wrap.cpp | 59 +++++++++++---------- src/bindings/js/node/src/helper.cpp | 35 ++++++++----- src/bindings/js/node/tests/basic.test.js | 66 ++++++++++++++++++++---- 5 files changed, 119 insertions(+), 52 deletions(-) diff --git a/src/bindings/js/node/include/helper.hpp b/src/bindings/js/node/include/helper.hpp index ffd944b3589033..3649d863179360 100644 --- a/src/bindings/js/node/include/helper.hpp +++ b/src/bindings/js/node/include/helper.hpp @@ -172,6 +172,8 @@ bool acceptableType(const Napi::Value& val, const std::vector& accep Napi::Value any_to_js(const Napi::CallbackInfo& info, ov::Any value); -ov::Any js_to_any(const Napi::CallbackInfo& info, Napi::Value value); +ov::Any js_to_any(const Napi::Env& env, const Napi::Value& value); -bool is_napi_value_int(const Napi::CallbackInfo& info, Napi::Value& num); +bool is_napi_value_int(const Napi::Env& env, const Napi::Value& num); + +ov::AnyMap to_anyMap(const Napi::Env&, const Napi::Value&); diff --git a/src/bindings/js/node/lib/addon.ts b/src/bindings/js/node/lib/addon.ts index 030ea914e6c35e..d76f3cb78e89c7 100644 --- a/src/bindings/js/node/lib/addon.ts +++ b/src/bindings/js/node/lib/addon.ts @@ -37,6 +37,11 @@ interface Core { readModelSync(modelPath: string, weightsPath?: string): Model; readModelSync(modelBuffer: Uint8Array, weightsBuffer?: Uint8Array): Model; importModelSync(modelStream: Buffer, device: string): CompiledModel; + importModelSync( + modelStream: Buffer, + device: string, + props: { [key: string]: string | number | boolean } + ): CompiledModel; getAvailableDevices(): string[]; getVersions(deviceName: string): { [deviceName: string]: { diff --git a/src/bindings/js/node/src/core_wrap.cpp b/src/bindings/js/node/src/core_wrap.cpp index cbcf49281e248a..ff1186d3a519ab 100644 --- a/src/bindings/js/node/src/core_wrap.cpp +++ b/src/bindings/js/node/src/core_wrap.cpp @@ -27,7 +27,6 @@ std::tuple try_get_set_property_parameters(const Napi:: validate_set_property_args(info); std::string device_name; - ov::AnyMap properties; const size_t args_length = info.Length(); @@ -35,16 +34,7 @@ std::tuple try_get_set_property_parameters(const Napi:: device_name = info[0].ToString(); const size_t parameters_position_index = device_name.empty() ? 0 : 1; - Napi::Object parameters = info[parameters_position_index].ToObject(); - const auto& keys = parameters.GetPropertyNames(); - - for (uint32_t i = 0; i < keys.Length(); ++i) { - auto property_name = static_cast(keys[i]).ToString().Utf8Value(); - - ov::Any any_value = js_to_any(info, parameters.Get(property_name)); - - properties.insert(std::make_pair(property_name, any_value)); - } + const auto& properties = to_anyMap(info.Env(), info[parameters_position_index]); return std::make_tuple(properties, device_name); } @@ -301,25 +291,38 @@ Napi::Value CoreWrap::get_versions(const Napi::CallbackInfo& info) { } Napi::Value CoreWrap::import_model(const Napi::CallbackInfo& info) { - if (info.Length() != 2) { - reportError(info.Env(), "Invalid number of arguments -> " + std::to_string(info.Length())); - return info.Env().Undefined(); - } - if (!info[0].IsBuffer()) { - reportError(info.Env(), "The first argument must be of type Buffer."); - return info.Env().Undefined(); - } - if (!info[1].IsString()) { - reportError(info.Env(), "The second argument must be of type String."); + try { + if (!info[0].IsBuffer()) { + OPENVINO_THROW("The first argument must be of type Buffer."); + } + if (!info[1].IsString()) { + OPENVINO_THROW("The second argument must be of type String."); + } + const auto& model_data = info[0].As>(); + const auto model_stream = std::string(reinterpret_cast(model_data.Data()), model_data.Length()); + std::stringstream _stream; + _stream << model_stream; + + ov::CompiledModel compiled; + switch (info.Length()) { + case 2: { + compiled = _core.import_model(_stream, std::string(info[1].ToString())); + break; + } + case 3: { + compiled = _core.import_model(_stream, std::string(info[1].ToString()), to_anyMap(info.Env(), info[2])); + break; + } + default: { + OPENVINO_THROW("Invalid number of arguments -> " + std::to_string(info.Length())); + } + } + return CompiledModelWrap::wrap(info.Env(), compiled); + + } catch (std::exception& e) { + reportError(info.Env(), e.what()); return info.Env().Undefined(); } - const auto& model_data = info[0].As>(); - const auto model_stream = std::string(reinterpret_cast(model_data.Data()), model_data.Length()); - std::stringstream _stream; - _stream << model_stream; - - const auto& compiled = _core.import_model(_stream, std::string(info[1].ToString())); - return CompiledModelWrap::wrap(info.Env(), compiled); } Napi::Value CoreWrap::set_property(const Napi::CallbackInfo& info) { diff --git a/src/bindings/js/node/src/helper.cpp b/src/bindings/js/node/src/helper.cpp index b00fbb033c8447..34586140251e4b 100644 --- a/src/bindings/js/node/src/helper.cpp +++ b/src/bindings/js/node/src/helper.cpp @@ -510,7 +510,7 @@ Napi::Value any_to_js(const Napi::CallbackInfo& info, ov::Any value) { return info.Env().Undefined(); } -ov::Any js_to_any(const Napi::CallbackInfo& info, Napi::Value value) { +ov::Any js_to_any(const Napi::Env& env, const Napi::Value& value) { if (value.IsString()) { return ov::Any(value.ToString().Utf8Value()); } else if (value.IsBigInt()) { @@ -526,7 +526,7 @@ ov::Any js_to_any(const Napi::CallbackInfo& info, Napi::Value value) { } else if (value.IsNumber()) { Napi::Number num = value.ToNumber(); - if (is_napi_value_int(info, value)) { + if (is_napi_value_int(env, value)) { return ov::Any(num.Int32Value()); } else { return ov::Any(num.DoubleValue()); @@ -538,14 +538,25 @@ ov::Any js_to_any(const Napi::CallbackInfo& info, Napi::Value value) { } } -bool is_napi_value_int(const Napi::CallbackInfo& info, Napi::Value& num) { - return info.Env() - .Global() - .Get("Number") - .ToObject() - .Get("isInteger") - .As() - .Call({num}) - .ToBoolean() - .Value(); +bool is_napi_value_int(const Napi::Env& env, const Napi::Value& num) { + return env.Global().Get("Number").ToObject().Get("isInteger").As().Call({num}).ToBoolean().Value(); +} + +ov::AnyMap to_anyMap(const Napi::Env& env, const Napi::Value& val) { + ov::AnyMap properties; + if (!val.IsObject()) { + OPENVINO_THROW("Passed Napi::Value must be an object."); + } + const auto& parameters = val.ToObject(); + const auto& keys = parameters.GetPropertyNames(); + + for (uint32_t i = 0; i < keys.Length(); ++i) { + const auto& property_name = static_cast(keys[i]).ToString().Utf8Value(); + + ov::Any any_value = js_to_any(env, parameters.Get(property_name)); + + properties.insert(std::make_pair(property_name, any_value)); + } + + return properties; } diff --git a/src/bindings/js/node/tests/basic.test.js b/src/bindings/js/node/tests/basic.test.js index 1236bd9c553520..a657e698850ce2 100644 --- a/src/bindings/js/node/tests/basic.test.js +++ b/src/bindings/js/node/tests/basic.test.js @@ -14,10 +14,10 @@ const compiledModel = core.compileModelSync(model, 'CPU'); const modelLike = [[model], [compiledModel]]; -it('Core.getAvailableDevices()', () => { - const devices = core.getAvailableDevices(); - - assert.ok(devices.includes('CPU')); +it('Core.getAvailableDevices()', () => { + const devices = core.getAvailableDevices(); + + assert.ok(devices.includes('CPU')); }); describe('Core.getVersions()', () => { @@ -214,16 +214,62 @@ describe('Input class for ov::Input', () => { }); -it('Test exportModel()/importModel()', () => { +describe('Test exportModel()/importModel()', () => { const userStream = compiledModel.exportModelSync(); - const newCompiled = core.importModelSync(userStream, 'CPU'); const epsilon = 0.5; const tensor = Float32Array.from({ length: 3072 }, () => (Math.random() + epsilon)); - const inferRequest = compiledModel.createInferRequest(); const res1 = inferRequest.infer([tensor]); - const newInferRequest = newCompiled.createInferRequest(); - const res2 = newInferRequest.infer([tensor]); - assert.deepStrictEqual(res1['fc_out'].data[0], res2['fc_out'].data[0]); + it('Test importModel(stream, device)', () => { + const newCompiled = core.importModelSync(userStream, 'CPU'); + const newInferRequest = newCompiled.createInferRequest(); + const res2 = newInferRequest.infer([tensor]); + + assert.deepStrictEqual(res1['fc_out'].data[0], res2['fc_out'].data[0]); + }); + + it('Test importModel(stream, device, config)', () => { + const newCompiled = core.importModelSync(userStream, 'CPU', { 'NUM_STREAMS': 1 }); + const newInferRequest = newCompiled.createInferRequest(); + const res2 = newInferRequest.infer([tensor]); + + assert.deepStrictEqual(res1['fc_out'].data[0], res2['fc_out'].data[0]); + }); + + it('Test importModel(stream, device) throws', () => { + assert.throws( + () => core.importModelSync(epsilon, 'CPU'), + /The first argument must be of type Buffer./ + ); + }); + + it('Test importModel(stream, device) throws', () => { + assert.throws( + () => core.importModelSync(userStream, tensor), + /The second argument must be of type String./ + ); + }); + it('Test importModel(stream, device, config: tensor) throws', () => { + assert.throws( + () => core.importModelSync(userStream, 'CPU', tensor), + /NotFound: Unsupported property 0 by CPU plugin./ + ); + }); + + it('Test importModel(stream, device, config: string) throws', () => { + const testString = 'test'; + assert.throws( + () => core.importModelSync(userStream, 'CPU', testString), + /Passed Napi::Value must be an object./ + ); + }); + + it('Test importModel(stream, device, config: unsupported property) throws', () => { + const tmpDir = '/tmp'; + assert.throws( + () => core.importModelSync(userStream, 'CPU', {'CACHE_DIR': tmpDir}), + /Unsupported property CACHE_DIR by CPU plugin./ + ); + }); });