Skip to content

Commit

Permalink
[PyOV] Extend Python API with Col2Im-15 (#24569)
Browse files Browse the repository at this point in the history
### Details:
 - Add Python API for `Col2Im-15`
- Requires #24548 to
work and pass CI

### Tickets:
 - CVS-138920

### Related PRs:
- #24548
- #24197
- #23947

---------

Co-authored-by: Michal Lukaszewski <[email protected]>
  • Loading branch information
p-wysocki and mlukasze authored May 23, 2024
1 parent 3bbb35c commit c4f33ce
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@

# TODO (ticket 138273): Add previous opset operators at the end of opset15 development
from openvino.runtime.opset1.ops import parameter
from openvino.runtime.opset15.ops import col2im
from openvino.runtime.opset15.ops import scatter_nd_update
46 changes: 45 additions & 1 deletion src/bindings/python/src/openvino/runtime/opset15/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

"""Factory functions for ops added to openvino opset15."""
from functools import partial
from typing import Optional, Literal
from typing import Optional, Literal, List

from openvino.runtime import Node, Type
from openvino.runtime.opset_utils import _get_node_factory
Expand Down Expand Up @@ -39,3 +39,47 @@ def scatter_nd_update(
if reduction:
attributes["reduction"] = reduction
return _get_node_factory_opset15().create("ScatterNDUpdate", inputs, attributes)


@nameable_op
def col2im(
data: NodeInput,
output_size: NodeInput,
kernel_size: NodeInput,
strides: Optional[List[int]] = None,
dilations: Optional[List[int]] = None,
pads_begin: Optional[List[int]] = None,
pads_end: Optional[List[int]] = None,
name: Optional[str] = None,
) -> Node:
"""Perform data movement operation which combines sliding blocks into an image tensor.
:param data: The node providing input data.
:param output_size: Shape of the spatial dimensions of the output image.
:param kernel_size: Size of the sliding blocks.
:param strides: Stride on the sliding blocks in the input spatial dimensions. Defaults to [1, 1].
:param dilations: The dilation of filter elements (distance between elements). Defaults to [1, 1].
:param pads_begin: The number of pixels added at the beginning along each axis. Defaults to [0, 0].
:param pads_end: The number of pixels added at the end along each axis. Defaults to [0, 0].
:param name: The optional name for the created output node.
:return: The new node performing Col2Im operation.
"""
if strides is None:
strides = [1, 1]
if dilations is None:
dilations = [1, 1]
if pads_begin is None:
pads_begin = [0, 0]
if pads_end is None:
pads_end = [0, 0]
return _get_node_factory_opset15().create(
"Col2Im",
as_nodes(data, output_size, kernel_size, name=name),
{
"strides": strides,
"dilations": dilations,
"pads_begin": pads_begin,
"pads_end": pads_end,
},
)
27 changes: 27 additions & 0 deletions src/bindings/python/tests/test_graph/test_col2im.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2018-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from openvino import Type
import openvino.runtime.opset15 as ov
import numpy as np
import pytest


@pytest.mark.parametrize(("input_shape", "output_size", "kernel_size", "expected_shape", "params"), [
((3, 4), [2, 2], [1, 1], [3, 2, 2], []),
((3, 4), [2, 2], [1, 1], [3, 2, 2], [[1, 1], [1, 1], [0, 0], [0, 0]]),
((12, 25), [4, 4], [2, 2], [3, 4, 4], [[2, 2], [1, 1], [3, 3], [3, 3]]),
((2, 8, 8), [5, 5], [2, 2], [2, 2, 5, 5], [[2, 2], [1, 1], [0, 2], [0, 2]]),
((2, 32, 12), [6, 6], [4, 4], [2, 2, 6, 6], [[2, 2], [2, 2], [4, 3], [4, 3]]),
])
def test_col2im(input_shape, output_size, kernel_size, expected_shape, params):
input_data = ov.parameter(input_shape, name="input_data", dtype=np.float32)
output_size = np.array(output_size, np.int32)
kernel_size = np.array(kernel_size, np.int32)

node = ov.col2im(input_data, output_size, kernel_size, *params)
assert node.get_type_name() == "Col2Im"
assert node.get_output_size() == 1
assert list(node.get_output_shape(0)) == expected_shape
assert node.get_output_element_type(0) == Type.f32

0 comments on commit c4f33ce

Please sign in to comment.