Skip to content

Commit

Permalink
implemented getOutputElementType
Browse files Browse the repository at this point in the history
Added Method on C++ side.
Updated typescript definitions.
Created unit tests.
For Issue openvinotoolkit#25406
  • Loading branch information
Pey-crypto committed Jul 14, 2024
1 parent dcdfdc5 commit 7d6a2c7
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 1 deletion.
8 changes: 8 additions & 0 deletions src/bindings/js/node/include/model_wrap.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,14 @@ class ModelWrap : public Napi::ObjectWrap<ModelWrap> {
*/
Napi::Value get_output_shape(const Napi::CallbackInfo& info);

/**
* @brief Helper function to access model output elements types.
* @param info Contains information about the environment and passed arguements
* @return Napi::Value wrapping a TypeWrap object representing the element type of the requested output.
*/

Napi::Value get_output_element_type(const Napi::CallbackInfo& info);

private:
std::shared_ptr<ov::Model> _model;
ov::Core _core;
Expand Down
4 changes: 4 additions & 0 deletions src/bindings/js/node/lib/addon.ts
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,10 @@ interface Model {
* It gets the input of the model.
* If a model has more than one input, this method throws an exception.
*/
getOutputElementType(index: number): string;
/**
* Gets the element type of a specific output of the model.
*/
input(): Output;
/**
* It gets the input of the model identified by the tensor name.
Expand Down
33 changes: 32 additions & 1 deletion src/bindings/js/node/src/model_wrap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@ Napi::Function ModelWrap::get_class(Napi::Env env) {
InstanceMethod("setFriendlyName", &ModelWrap::set_friendly_name),
InstanceMethod("getFriendlyName", &ModelWrap::get_friendly_name),
InstanceMethod("getOutputShape", &ModelWrap::get_output_shape),
InstanceMethod("getOutputElementType", &ModelWrap::get_output_elememt_type),
InstanceAccessor<&ModelWrap::get_inputs>("inputs"),
InstanceAccessor<&ModelWrap::get_outputs>("outputs")});
InstanceAccessor<&ModelWrap::get_outputs>("outputs"),});
}

void ModelWrap::set_model(const std::shared_ptr<ov::Model>& model) {
Expand Down Expand Up @@ -171,3 +172,33 @@ Napi::Value ModelWrap::get_output_shape(const Napi::CallbackInfo& info) {
return info.Env().Undefined();
}
}

Napi::Value ModelWrap::get_output_element_type(const Napi::CallbackInfo& info) {
if(info.length() != 1 || !info[0].isNumber()) {
reportError(info.Env(), "Invalid arguement.Expected a single number for output index");
return info.Env().Undefined();
}

try {
auto idx = info[0].As<Napi::Number>().Int32Value();
auto output = _model->output(idx);
auto element_type = output.get_element_type();
std::string type_name = element_type.get_type_name();
std::unordered_map<std::string, std::string> type_map = {
{"float", "f32"},
{"float16", "f16"},
{"int32", "i32"},
{"int64", "i64"},
{"uint8", "u8"}
};
auto mapped_type = type_map.find(type_name);
if (mapped_type != type_map.end()) {
type_name = mapped_type->second;
}

return Napi::String::New(info.Env(), type_name);
} catch (const std::exception& e) {
reportError(info.Env(), e.what());
return info.Env().Undefined();
}
}
33 changes: 33 additions & 0 deletions src/bindings/js/node/tests/model.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -111,4 +111,37 @@ describe('Model.getOutputSize()', () => {
it('should return 1 for the default model', () => {
assert.strictEqual(model.getOutputSize(), 1, 'Expected getOutputSize to return 1 for the default model');
});
});

describe('Model.getOutputElementType()', () => {
it('should return a string indicating the element type of the specified output', () => {
const result = model.getOutputElementType(0);
assert.strictEqual(typeof result, 'string', 'getOutputElementType() should return a string');
});

it('should accept a single number argument', () => {
assert.throws(() => {
model.getOutputElementType();
}, /^Error: Invalid argument. Expected a single number for output index\.$/, 'Expected getOutputElementType to throw an error when called without arguments');

assert.throws(() => {
model.getOutputElementType('unexpected argument');
}, /^Error: Invalid argument. Expected a single number for output index\.$/, 'Expected getOutputElementType to throw an error when called with a non-number argument');

assert.throws(() => {
model.getOutputElementType(0, 1);
}, /^Error: Invalid argument. Expected a single number for output index\.$/, 'Expected getOutputElementType to throw an error when called with more than one argument');
});

it('should return a valid element type for the default model', () => {
const elementType = model.getOutputElementType(0);
assert.ok(['f32', 'f16', 'i32', 'i64', 'u8'].includes(elementType), `Expected a valid element type, got ${elementType}`);
});

it('should throw an error for out-of-range index', () => {
const outputSize = model.getOutputSize();
assert.throws(() => {
model.getOutputElementType(outputSize);
}, /^Error:/, 'Expected getOutputElementType to throw an error for out-of-range index');
});
});

0 comments on commit 7d6a2c7

Please sign in to comment.