Skip to content

Commit

Permalink
Added Tensor.get_size() method to Node.js API (#23498)
Browse files Browse the repository at this point in the history
### This Fixes #23440
 
### Details:

Extended Tensor API , with tensor.getSize() method

-  Implemented tensor.getSize() in the js api
-  added parameter validation
-  updated Tensor interface with the getSize() method in addon.ts
-  added Tests

---------

Co-authored-by: Alicja Miloszewska <[email protected]>
Co-authored-by: Vishniakov Nikolai <[email protected]>
  • Loading branch information
3 people authored Mar 28, 2024
1 parent 561d78f commit 496a5de
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 1 deletion.
2 changes: 2 additions & 0 deletions src/bindings/js/node/include/tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ class TensorWrap : public Napi::ObjectWrap<TensorWrap> {
Napi::Value get_shape(const Napi::CallbackInfo& info);
/** @return Napi::String containing ov::element type. */
Napi::Value get_element_type(const Napi::CallbackInfo& info);
/**@return Napi::Number containing tensor size as total number of elements.*/
Napi::Value get_size(const Napi::CallbackInfo& info);

private:
ov::Tensor _tensor;
Expand Down
1 change: 1 addition & 0 deletions src/bindings/js/node/lib/addon.ts
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ interface Tensor {
getElementType(): element;
getShape(): number[];
getData(): number[];
getSize(): number;
}
interface TensorConstructor {
new(type: element | elementTypeString,
Expand Down
13 changes: 12 additions & 1 deletion src/bindings/js/node/src/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ Napi::Function TensorWrap::get_class(Napi::Env env) {
{InstanceAccessor<&TensorWrap::get_data>("data"),
InstanceMethod("getData", &TensorWrap::get_data),
InstanceMethod("getShape", &TensorWrap::get_shape),
InstanceMethod("getElementType", &TensorWrap::get_element_type)});
InstanceMethod("getElementType", &TensorWrap::get_element_type),
InstanceMethod("getSize", &TensorWrap::get_size)});
}

ov::Tensor TensorWrap::get_tensor() const {
Expand Down Expand Up @@ -138,3 +139,13 @@ Napi::Value TensorWrap::get_shape(const Napi::CallbackInfo& info) {
Napi::Value TensorWrap::get_element_type(const Napi::CallbackInfo& info) {
return cpp_to_js<ov::element::Type_t, Napi::String>(info, _tensor.get_element_type());
}

Napi::Value TensorWrap::get_size(const Napi::CallbackInfo& info) {
Napi::Env env = info.Env();
if (info.Length() > 0) {
reportError(env, "getSize() does not accept any arguments.");
return env.Undefined();
}
const auto size = static_cast<double>(_tensor.get_size());
return Napi::Number::New(env, size);
}
46 changes: 46 additions & 0 deletions src/bindings/js/node/tests/tensor.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,49 @@ describe('Tensor element type', () => {
});
});
});


describe('Tensor getSize', () => {

it('getSize returns the correct total number of elements', () => {
const tensor = new ov.Tensor(ov.element.f32, shape, data);
const expectedSize = shape.reduce((acc, dim) => acc * dim, 1);
assert.strictEqual(tensor.getSize(), expectedSize);
});

it('getSize should throw an error if arguments are provided', () => {
const tensor = new ov.Tensor(ov.element.f32, shape, data);
assert.throws(
() => tensor.getSize(1),
{ message: 'getSize() does not accept any arguments.' }
);
});
});

describe('Tensor getSize for various shapes', () => {

it('calculates size correctly for a common image data shape [3, 224, 224]', () => {
const shape = [3, 224, 224];
const expectedSize = 3*224*224;
const tensorData = new Float32Array(expectedSize).fill(0);
const tensor = new ov.Tensor(ov.element.f32, shape, tensorData);
assert.strictEqual(tensor.getSize(), expectedSize);
});

it('calculates size correctly for a scalar wrapped in a tensor [1]', () => {
const shape = [1];
const expectedSize = 1;
const tensorData = new Float32Array(expectedSize).fill(0);
const tensor = new ov.Tensor(ov.element.f32, shape, tensorData);
assert.strictEqual(tensor.getSize(), expectedSize);
});

it('calculates size correctly for a vector [10]', () => {
const shape = [10];
const expectedSize = 10;
const tensorData = new Float32Array(expectedSize).fill(0);
const tensor = new ov.Tensor(ov.element.f32, shape, tensorData);
assert.strictEqual(tensor.getSize(), expectedSize);
});
});

0 comments on commit 496a5de

Please sign in to comment.