Skip to content

Commit

Permalink
Merge pull request openvinotoolkit#50 from eaidova/ea/more_mm
Browse files Browse the repository at this point in the history
more matmul operations
  • Loading branch information
slyalin authored Dec 6, 2022
2 parents 88b12d9 + 1f1c176 commit fc0bc93
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 0 deletions.
29 changes: 29 additions & 0 deletions src/frontends/pytorch/src/op/addmm.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/opsets/opset8.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

OutputVector translate_addmm(NodeContext& context) {
auto input = context.get_input(0);
auto m1 = context.get_input(1);
auto m2 = context.get_input(2);
auto beta = context.get_input(3);
auto alpha = context.get_input(4);
auto mm = context.mark_node(std::make_shared<opset8::MatMul>(m1, m2));
auto input_beta = context.mark_node(std::make_shared<opset8::Multiply>(input, beta));
auto mm_alpha = context.mark_node(std::make_shared<opset8::Multiply>(mm, alpha));
return {context.mark_node(std::make_shared<opset8::Add>(input_beta, mm_alpha))};
};

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
4 changes: 4 additions & 0 deletions src/frontends/pytorch/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ OP_CONVERTER(translate_adaptive_avg_pool3d);
OP_CONVERTER(translate_adaptive_max_pool2d);
OP_CONVERTER(translate_add);
OP_CONVERTER(translate_addcmul);
OP_CONVERTER(translate_addmm);
OP_CONVERTER(translate_as_tensor);
OP_CONVERTER(translate_avg_pool2d);
OP_CONVERTER(translate_batch_norm);
Expand Down Expand Up @@ -91,6 +92,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"aten::add", op::translate_add},
{"aten::add_", op::inplace_op<op::translate_add>},
{"aten::addcmul", op::translate_addcmul},
{"aten::addmm", op::translate_addmm},
{"aten::as_tensor", op::translate_as_tensor},
{"aten::avg_pool2d", op::translate_avg_pool2d},
{"aten::batch_norm", op::translate_batch_norm},
Expand Down Expand Up @@ -137,6 +139,8 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"aten::masked_fill_", op::inplace_op<op::translate_masked_fill>},
{"aten::mean", op::translate_mean},
{"aten::mm", op::translate_1to1_match_2_inputs<opset8::MatMul>},
{"aten::bmm", op::translate_1to1_match_2_inputs<opset8::MatMul>},
{"aten::matmul", op::translate_1to1_match_2_inputs<opset8::MatMul>},
{"aten::mul", op::translate_1to1_match_2_inputs<opset8::Multiply>},
{"aten::mul_", op::inplace_op<op::translate_1to1_match_2_inputs<opset8::Multiply>>},
{"aten::ne", op::translate_1to1_match_2_inputs<opset8::NotEqual>},
Expand Down
45 changes: 45 additions & 0 deletions tests/layer_tests/pytorch_tests/test_addmm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright (C) 2018-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import pytest
from pytorch_layer_test_class import PytorchLayerTest


class TestAddMM(PytorchLayerTest):
def _prepare_input(self, input_shape=(2,2), matrix1_shape=(2, 2), matrix2_shape=(2, 2)):
import numpy as np
return (
np.random.randn(*input_shape).astype(np.float32),
np.random.randn(*matrix1_shape).astype(np.float32),
np.random.randn(*matrix2_shape).astype(np.float32)
)

def create_model(self, alpha, beta):

import torch

class aten_addmm(torch.nn.Module):
def __init__(self, alpha, beta):
super(aten_addmm, self).__init__()
self.alpha = alpha
self.beta = beta

def forward(self, m0, m1, m2):
return torch.addmm(m0, m1, m2, alpha=self.alpha, beta=self.beta)

ref_net = None

return aten_addmm(alpha, beta), ref_net, 'aten::addmm'

@pytest.mark.parametrize("kwargs_to_prepare_input", [
{"input_shape": (3, 3), 'matrix1_shape': (3, 3), 'matrix2_shape': (3, 3)},
{"input_shape": (2, 2), 'matrix1_shape': (2, 3), 'matrix2_shape': (3, 2)},
{"input_shape": (10, 1), 'matrix1_shape': (10, 5), 'matrix2_shape': (5, 1)},
{"input_shape": (1, 2), 'matrix1_shape': (1, 10), 'matrix2_shape': (10, 2)},
{"input_shape": (1, 1), 'matrix1_shape': (1, 10), 'matrix2_shape': (10, 1)},
])
@pytest.mark.parametrize("alpha,beta", [(1., 1.), (0., 1.), (1., 0.), (1., 2.), (2., 1.), (-5., -6.), (3., 4.), (0.5, 0.75)])
@pytest.mark.nightly
def test_addmm(self, kwargs_to_prepare_input, alpha, beta, ie_device, precision, ir_version):
self._test(*self.create_model(alpha, beta), ie_device, precision, ir_version, kwargs_to_prepare_input=kwargs_to_prepare_input)
79 changes: 79 additions & 0 deletions tests/layer_tests/pytorch_tests/test_mm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright (C) 2018-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import pytest
from pytorch_layer_test_class import PytorchLayerTest


class TestMatMul(PytorchLayerTest):
def _prepare_input(self, matrix1_shape=(2, 2), matrix2_shape=(2, 2)):
import numpy as np
return (np.random.randn(*matrix1_shape).astype(np.float32), np.random.randn(*matrix2_shape).astype(np.float32))

def create_model(self, op_type="aten::mm"):

import torch
ops = {
"aten::mm": torch.mm,
"aten::bmm": torch.bmm,
"aten::matmul": torch.matmul
}

class aten_mm(torch.nn.Module):
def __init__(self, op):
super(aten_mm, self).__init__()
self.op = op

def forward(self, m1, m2):
return self.op(m1, m2)
ref_net = None

return aten_mm(ops[op_type]), ref_net, op_type

@pytest.mark.parametrize("kwargs_to_prepare_input", [
{'matrix1_shape': (3, 3), 'matrix2_shape': (3, 3)},
{'matrix1_shape': (2, 3), 'matrix2_shape': (3, 2)},
{'matrix1_shape': (10, 5), 'matrix2_shape': (5, 1)},
{'matrix1_shape': (1, 10), 'matrix2_shape': (10, 2)},
{'matrix1_shape': (1, 10), 'matrix2_shape': (10, 1)},
])
@pytest.mark.nightly
def test_mm(self, kwargs_to_prepare_input, ie_device, precision, ir_version):
self._test(*self.create_model('aten::mm'), ie_device, precision, ir_version, kwargs_to_prepare_input=kwargs_to_prepare_input)

@pytest.mark.parametrize("kwargs_to_prepare_input", [
{'matrix1_shape': (10, 3, 3), 'matrix2_shape': (10, 3, 3)},
{'matrix1_shape': (1, 2, 3), 'matrix2_shape': (1, 3, 2)},
{'matrix1_shape': (2, 10, 5), 'matrix2_shape': (2, 5, 1)},
{'matrix1_shape': (3, 1, 10), 'matrix2_shape': (3, 10, 2)},
{'matrix1_shape': (4, 1, 10), 'matrix2_shape': (4, 10, 1)},
])
@pytest.mark.nightly
def test_bmm(self, kwargs_to_prepare_input, ie_device, precision, ir_version):
self._test(*self.create_model('aten::bmm'), ie_device, precision, ir_version, kwargs_to_prepare_input=kwargs_to_prepare_input)

@pytest.mark.parametrize("kwargs_to_prepare_input", [
{'matrix1_shape': (10, 3, 3), 'matrix2_shape': (10, 3, 3)},
{'matrix1_shape': (1, 2, 3), 'matrix2_shape': (1, 3, 2)},
{'matrix1_shape': (2, 10, 5), 'matrix2_shape': (2, 5, 1)},
{'matrix1_shape': (3, 1, 10), 'matrix2_shape': (3, 10, 2)},
{'matrix1_shape': (4, 1, 10), 'matrix2_shape': (4, 10, 1)},
{'matrix1_shape': (3, 3), 'matrix2_shape': (3, 3)},
{'matrix1_shape': (2, 3), 'matrix2_shape': (3, 2)},
{'matrix1_shape': (10, 5), 'matrix2_shape': (5, 1)},
{'matrix1_shape': (1, 10), 'matrix2_shape': (10, 2)},
{'matrix1_shape': (1, 10), 'matrix2_shape': (10, 1)},
{'matrix1_shape': (10, 3, 3), 'matrix2_shape': (3, 3)},
{'matrix1_shape': (2, 3), 'matrix2_shape': (10, 3, 2)},
{'matrix1_shape': (1, 10, 5), 'matrix2_shape': (5, 1)},
{'matrix1_shape': (5, 1, 10), 'matrix2_shape': (10, 2)},
{'matrix1_shape': (1, 10), 'matrix2_shape': (4, 10, 2)},
{'matrix1_shape': (2, 1, 10), 'matrix2_shape': (10, 1)},
{'matrix1_shape': (1, 10), 'matrix2_shape': (2, 10, 1)},
])
@pytest.mark.nightly
def test_matmul(self, kwargs_to_prepare_input, ie_device, precision, ir_version):
self._test(*self.create_model('aten::matmul'), ie_device, precision, ir_version, kwargs_to_prepare_input=kwargs_to_prepare_input)

0 comments on commit fc0bc93

Please sign in to comment.