From 7d6a2c7bff0a809cc3d8e376407890e153529c56 Mon Sep 17 00:00:00 2001 From: Peyara Nando Date: Sun, 14 Jul 2024 16:28:55 +0530 Subject: [PATCH] implemented getOutputElementType Added Method on C++ side. Updated typescript definitions. Created unit tests. For Issue #25406 --- src/bindings/js/node/include/model_wrap.hpp | 8 +++++ src/bindings/js/node/lib/addon.ts | 4 +++ src/bindings/js/node/src/model_wrap.cpp | 33 ++++++++++++++++++++- src/bindings/js/node/tests/model.test.js | 33 +++++++++++++++++++++ 4 files changed, 77 insertions(+), 1 deletion(-) diff --git a/src/bindings/js/node/include/model_wrap.hpp b/src/bindings/js/node/include/model_wrap.hpp index cda9ff8b6ee65a..b34cdc0480dad5 100644 --- a/src/bindings/js/node/include/model_wrap.hpp +++ b/src/bindings/js/node/include/model_wrap.hpp @@ -109,6 +109,14 @@ class ModelWrap : public Napi::ObjectWrap { */ Napi::Value get_output_shape(const Napi::CallbackInfo& info); + /** + * @brief Helper function to access model output elements types. + * @param info Contains information about the environment and passed arguements + * @return Napi::Value wrapping a TypeWrap object representing the element type of the requested output. + */ + + Napi::Value get_output_element_type(const Napi::CallbackInfo& info); + private: std::shared_ptr _model; ov::Core _core; diff --git a/src/bindings/js/node/lib/addon.ts b/src/bindings/js/node/lib/addon.ts index b5909ea9f3ae03..1415b0873b341a 100644 --- a/src/bindings/js/node/lib/addon.ts +++ b/src/bindings/js/node/lib/addon.ts @@ -238,6 +238,10 @@ interface Model { * It gets the input of the model. * If a model has more than one input, this method throws an exception. */ + getOutputElementType(index: number): string; + /** + * Gets the element type of a specific output of the model. + */ input(): Output; /** * It gets the input of the model identified by the tensor name. diff --git a/src/bindings/js/node/src/model_wrap.cpp b/src/bindings/js/node/src/model_wrap.cpp index e8359b83ff6da3..544a8727ddde12 100644 --- a/src/bindings/js/node/src/model_wrap.cpp +++ b/src/bindings/js/node/src/model_wrap.cpp @@ -25,8 +25,9 @@ Napi::Function ModelWrap::get_class(Napi::Env env) { InstanceMethod("setFriendlyName", &ModelWrap::set_friendly_name), InstanceMethod("getFriendlyName", &ModelWrap::get_friendly_name), InstanceMethod("getOutputShape", &ModelWrap::get_output_shape), + InstanceMethod("getOutputElementType", &ModelWrap::get_output_elememt_type), InstanceAccessor<&ModelWrap::get_inputs>("inputs"), - InstanceAccessor<&ModelWrap::get_outputs>("outputs")}); + InstanceAccessor<&ModelWrap::get_outputs>("outputs"),}); } void ModelWrap::set_model(const std::shared_ptr& model) { @@ -171,3 +172,33 @@ Napi::Value ModelWrap::get_output_shape(const Napi::CallbackInfo& info) { return info.Env().Undefined(); } } + +Napi::Value ModelWrap::get_output_element_type(const Napi::CallbackInfo& info) { + if(info.length() != 1 || !info[0].isNumber()) { + reportError(info.Env(), "Invalid arguement.Expected a single number for output index"); + return info.Env().Undefined(); + } + + try { + auto idx = info[0].As().Int32Value(); + auto output = _model->output(idx); + auto element_type = output.get_element_type(); + std::string type_name = element_type.get_type_name(); + std::unordered_map type_map = { + {"float", "f32"}, + {"float16", "f16"}, + {"int32", "i32"}, + {"int64", "i64"}, + {"uint8", "u8"} + }; + auto mapped_type = type_map.find(type_name); + if (mapped_type != type_map.end()) { + type_name = mapped_type->second; + } + + return Napi::String::New(info.Env(), type_name); + } catch (const std::exception& e) { + reportError(info.Env(), e.what()); + return info.Env().Undefined(); + } +} \ No newline at end of file diff --git a/src/bindings/js/node/tests/model.test.js b/src/bindings/js/node/tests/model.test.js index 0ed4340b7be66d..f8aa102dc4a9f0 100644 --- a/src/bindings/js/node/tests/model.test.js +++ b/src/bindings/js/node/tests/model.test.js @@ -111,4 +111,37 @@ describe('Model.getOutputSize()', () => { it('should return 1 for the default model', () => { assert.strictEqual(model.getOutputSize(), 1, 'Expected getOutputSize to return 1 for the default model'); }); +}); + +describe('Model.getOutputElementType()', () => { + it('should return a string indicating the element type of the specified output', () => { + const result = model.getOutputElementType(0); + assert.strictEqual(typeof result, 'string', 'getOutputElementType() should return a string'); + }); + + it('should accept a single number argument', () => { + assert.throws(() => { + model.getOutputElementType(); + }, /^Error: Invalid argument. Expected a single number for output index\.$/, 'Expected getOutputElementType to throw an error when called without arguments'); + + assert.throws(() => { + model.getOutputElementType('unexpected argument'); + }, /^Error: Invalid argument. Expected a single number for output index\.$/, 'Expected getOutputElementType to throw an error when called with a non-number argument'); + + assert.throws(() => { + model.getOutputElementType(0, 1); + }, /^Error: Invalid argument. Expected a single number for output index\.$/, 'Expected getOutputElementType to throw an error when called with more than one argument'); + }); + + it('should return a valid element type for the default model', () => { + const elementType = model.getOutputElementType(0); + assert.ok(['f32', 'f16', 'i32', 'i64', 'u8'].includes(elementType), `Expected a valid element type, got ${elementType}`); + }); + + it('should throw an error for out-of-range index', () => { + const outputSize = model.getOutputSize(); + assert.throws(() => { + model.getOutputElementType(outputSize); + }, /^Error:/, 'Expected getOutputElementType to throw an error for out-of-range index'); + }); }); \ No newline at end of file