Skip to content

Commit

Permalink
generalize conv2d implementation for conv1d and conv3d
Browse files Browse the repository at this point in the history
  • Loading branch information
eaidova committed Dec 28, 2022
1 parent 92eaa52 commit 7e24381
Show file tree
Hide file tree
Showing 3 changed files with 161 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ namespace frontend {
namespace pytorch {
namespace op {

OutputVector translate_conv2d(NodeContext& context) {
OutputVector translate_conv(NodeContext& context) {
auto strides = context.const_input<Strides>(3);
// In torch pads at beginning are same as at end
auto pads = CoordinateDiff(strides.size(), 0);
Expand Down Expand Up @@ -49,8 +49,16 @@ OutputVector translate_conv2d(NodeContext& context) {
dilations,
pad_type);
}
if (!context.input_is_none(2)) {
auto bias = context.get_input(2);
auto bias_rank = bias.get_partial_shape().rank();
if (bias_rank == 1) {
bias = reshape_conv_bias(context, bias, conv);
}
conv = context.mark_node(std::make_shared<opset8::Add>(conv, bias));
}

return {context.mark_output(make_optional_bias(conv, context, 2, {-2, -1}))};
return {conv};
};

} // namespace op
Expand Down
53 changes: 0 additions & 53 deletions tests/layer_tests/pytorch_tests/test_conv2d.py

This file was deleted.

151 changes: 151 additions & 0 deletions tests/layer_tests/pytorch_tests/test_convnd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
# Copyright (C) 2018-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import pytest
from pytorch_layer_test_class import PytorchLayerTest


class TestConv2D(PytorchLayerTest):
def _prepare_input(self):
import numpy as np
return (np.random.randn(2, 3, 25, 25).astype(np.float32),)

def create_model(self, weights_shape, strides, pads, dilations, groups, bias):

import torch
import torch.nn.functional as F

class aten_conv2d(torch.nn.Module):
def __init__(self):
super(aten_conv2d, self).__init__()
self.weight = torch.randn(weights_shape)
self.bias = None
if bias:
self.bias = torch.randn(weights_shape[0])
self.strides = strides
self.pads = pads
self.dilations = dilations
self.groups = groups

def forward(self, x):
return F.conv2d(x, self.weight, self.bias, self.strides, self.pads, self.dilations, self.groups)

ref_net = None

return aten_conv2d(), ref_net, "aten::conv2d"

@pytest.mark.parametrize("params",
[{'weights_shape': [1, 3, 3, 3], 'strides': 1, 'pads': 0, 'dilations': 1, 'groups': 1},
{'weights_shape': [1, 3, 3, 3], 'strides': 2, 'pads': 0, 'dilations': 1, 'groups': 1},
{'weights_shape': [1, 3, 3, 3], 'strides': 1, 'pads': 1, 'dilations': 1, 'groups': 1},
{'weights_shape': [1, 3, 3, 3], 'strides': 1, 'pads': 0, 'dilations': 2, 'groups': 1},
{'weights_shape': [1, 3, 3, 3], 'strides': 1, 'pads': [0, 1], 'dilations': 1, 'groups': 1},
{'weights_shape': [1, 3, 3, 3], 'strides': 1, 'pads': [1, 0], 'dilations': 1, 'groups': 1},
{'weights_shape': [1, 3, 3, 3], 'strides': 1, 'pads': 'same', 'dilations': 1, 'groups': 1},
{'weights_shape': [1, 3, 3, 3], 'strides': 1, 'pads': 'valid', 'dilations': 1, 'groups': 1},
# doesn't work because input shape is dynamic which makes kernel shape dynamic
# {'weights_shape': [2, 3, 3, 3], 'strides': 1, 'pads': 0, 'dilations': 1, 'groups': 2},
])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.nightly
def test_conv2d(self, params, bias, ie_device, precision, ir_version):
self._test(*self.create_model(**params, bias=bias),
ie_device, precision, ir_version)


class TestConv1D(PytorchLayerTest):
def _prepare_input(self):
import numpy as np
return (np.random.randn(2, 3, 25).astype(np.float32),)

def create_model(self, weights_shape, strides, pads, dilations, groups, bias):

import torch
import torch.nn.functional as F

class aten_conv1d(torch.nn.Module):
def __init__(self):
super(aten_conv1d, self).__init__()
self.weight = torch.randn(weights_shape)
self.bias = None
if bias:
self.bias = torch.randn(weights_shape[0])
self.strides = strides
self.pads = pads
self.dilations = dilations
self.groups = groups

def forward(self, x):
return F.conv1d(x, self.weight, self.bias, self.strides, self.pads, self.dilations, self.groups)

ref_net = None

return aten_conv1d(), ref_net, "aten::conv1d"

@pytest.mark.parametrize("params",
[{'weights_shape': [3, 3, 3], 'strides': 1, 'pads': 0, 'dilations': 1, 'groups': 1},
{'weights_shape': [3, 3, 3], 'strides': 2, 'pads': 0, 'dilations': 1, 'groups': 1},
{'weights_shape': [3, 3, 3], 'strides': 1, 'pads': 1, 'dilations': 1, 'groups': 1},
{'weights_shape': [3, 3, 3], 'strides': 1, 'pads': 0, 'dilations': 2, 'groups': 1},
{'weights_shape': [3, 3, 3], 'strides': 1, 'pads': 'same', 'dilations': 1, 'groups': 1},
{'weights_shape': [3, 3, 3], 'strides': 1, 'pads': 'valid', 'dilations': 1, 'groups': 1},
# doesn't work because input shape is dynamic which makes kernel shape dynamic
# {'weights_shape': [3, 3, 3], 'strides': 1, 'pads': 0, 'dilations': 1, 'groups': 2},
])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.nightly
def test_conv1d(self, params, bias, ie_device, precision, ir_version):
self._test(*self.create_model(**params, bias=bias),
ie_device, precision, ir_version)


class TestConv3D(PytorchLayerTest):
def _prepare_input(self):
import numpy as np
return (np.random.randn(2, 3, 25, 25, 25).astype(np.float32),)

def create_model(self, weights_shape, strides, pads, dilations, groups, bias):

import torch
import torch.nn.functional as F

class aten_conv3d(torch.nn.Module):
def __init__(self):
super(aten_conv3d, self).__init__()
self.weight = torch.randn(weights_shape)
self.bias = None
if bias:
self.bias = torch.randn(weights_shape[0])
self.strides = strides
self.pads = pads
self.dilations = dilations
self.groups = groups

def forward(self, x):
return F.conv3d(x, self.weight, self.bias, self.strides, self.pads, self.dilations, self.groups)

ref_net = None

return aten_conv3d(), ref_net, "aten::conv3d"

@pytest.mark.parametrize("params",
[{'weights_shape': [1, 3, 3, 3, 3], 'strides': 1, 'pads': 0, 'dilations': 1, 'groups': 1},
{'weights_shape': [1, 3, 3, 3, 3], 'strides': 2, 'pads': 0, 'dilations': 1, 'groups': 1},
{'weights_shape': [1, 3, 3, 3, 3], 'strides': 1, 'pads': 1, 'dilations': 1, 'groups': 1},
{'weights_shape': [1, 3, 3, 3, 3], 'strides': 1, 'pads': 0, 'dilations': 2, 'groups': 1},
{'weights_shape': [1, 3, 3, 3, 3], 'strides': 1, 'pads': [0, 1, 0], 'dilations': 1, 'groups': 1},
{'weights_shape': [1, 3, 3, 3, 3], 'strides': 1, 'pads': [1, 0, 0], 'dilations': 1, 'groups': 1},
{'weights_shape': [1, 3, 3, 3, 3], 'strides': 1, 'pads': [0, 0, 1], 'dilations': 1, 'groups': 1},
{'weights_shape': [1, 3, 3, 3, 3], 'strides': 1, 'pads': [1, 1, 0], 'dilations': 1, 'groups': 1},
{'weights_shape': [1, 3, 3, 3, 3], 'strides': 1, 'pads': [0, 1, 1], 'dilations': 1, 'groups': 1},
{'weights_shape': [1, 3, 3, 3, 3], 'strides': 1, 'pads': [1, 0, 1], 'dilations': 1, 'groups': 1},
{'weights_shape': [1, 3, 3, 3, 3], 'strides': 1, 'pads': 'same', 'dilations': 1, 'groups': 1},
{'weights_shape': [1, 3, 3, 3, 3], 'strides': 1, 'pads': 'valid', 'dilations': 1, 'groups': 1},
# doesn't work because input shape is dynamic which makes kernel shape dynamic
# {'weights_shape': [2, 3, 3, 3], 'strides': 1, 'pads': 0, 'dilations': 1, 'groups': 2},
])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.nightly
def test_conv3d(self, params, bias, ie_device, precision, ir_version):
self._test(*self.create_model(**params, bias=bias),
ie_device, precision, ir_version)

0 comments on commit 7e24381

Please sign in to comment.