Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Re-implement caffe old-style extractors with extractor extensions #3675

13 changes: 8 additions & 5 deletions model-optimizer/automation/package_BOM.txt
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,20 @@ extensions/front/caffe/accum_ext.py
extensions/front/caffe/argmax_ext.py
extensions/front/caffe/ArgMaxFlatten.py
extensions/front/caffe/axpy.py
extensions/front/caffe/batchnorm_ext.py
extensions/front/caffe/binarization.py
extensions/front/caffe/binary_conv_ext.py
extensions/front/caffe/bn.py
extensions/front/caffe/bn_ext.py
extensions/front/caffe/concat_ext.py
extensions/front/caffe/conv_ext.py
extensions/front/caffe/correlation_ext.py
extensions/front/caffe/crop_ext.py
extensions/front/caffe/ctcgreedydecoder_ext.py
extensions/front/caffe/CustomLayersMapping.xml.example
extensions/front/caffe/data_augmentation_ext.py
extensions/front/caffe/detection_output.py
extensions/front/caffe/dropout_ext.py
extensions/front/caffe/elementwise_ext.py
extensions/front/caffe/eltwise_add_normalize.py
extensions/front/caffe/elu.py
Expand All @@ -106,6 +111,8 @@ extensions/front/caffe/relu_ext.py
extensions/front/caffe/reorgyolo_ext.py
extensions/front/caffe/resample_ext.py
extensions/front/caffe/reshape.py
extensions/front/caffe/roipooling_ext.py
extensions/front/caffe/scale_ext.py
extensions/front/caffe/shufflechannel_ext.py
extensions/front/caffe/sigmoid.py
extensions/front/caffe/simplernms_ext.py
Expand Down Expand Up @@ -617,6 +624,7 @@ extensions/ops/axpy.py
extensions/ops/BatchNormInference.py
extensions/ops/binarization.py
extensions/ops/BlockLSTM.py
extensions/ops/BN.py
extensions/ops/box_nms.py
extensions/ops/bucketize.py
extensions/ops/Cast.py
Expand Down Expand Up @@ -757,12 +765,7 @@ mo/front/caffe/collect_attributes.py
mo/front/caffe/custom_layers_mapping.py
mo/front/caffe/extractor.py
mo/front/caffe/extractors/__init__.py
mo/front/caffe/extractors/batchnorm.py
mo/front/caffe/extractors/concat.py
mo/front/caffe/extractors/crop.py
mo/front/caffe/extractors/native_caffe.py
mo/front/caffe/extractors/roipooling.py
mo/front/caffe/extractors/scale.py
mo/front/caffe/extractors/tile.py
mo/front/caffe/extractors/utils.py
mo/front/caffe/loader.py
Expand Down
55 changes: 55 additions & 0 deletions model-optimizer/extensions/front/caffe/batchnorm_ext.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""
Copyright (C) 2018-2021 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import logging as log
evolosen marked this conversation as resolved.
Show resolved Hide resolved

from extensions.ops.BatchNormInference import BatchNormInference
from mo.front.extractor import FrontExtractorOp
from mo.front.caffe.extractors.utils import embed_input
import numpy as np
evolosen marked this conversation as resolved.
Show resolved Hide resolved


class BatchNormalizationExtractor(FrontExtractorOp):
op = 'batchnorm'
enabled = True

@classmethod
def extract(cls, node):
eps = node.pb.batch_norm_param.eps
attrs = {
'eps': eps
}
pb_model = None if not node.has('model_pb') else node.model_pb
evolosen marked this conversation as resolved.
Show resolved Hide resolved
if pb_model:

evolosen marked this conversation as resolved.
Show resolved Hide resolved
blobs = pb_model.blobs
assert len(blobs) >= 2, 'BatchNorm accepts not less then two input blobs'
mean = np.array(blobs[0].data)
variance = np.array(blobs[1].data)

if len(blobs) == 3:
scale = blobs[2].data[0]
if scale != 0:
scale = 1.0 / scale
mean *= scale
variance *= scale

embed_input(attrs, 1, 'gamma', np.ones(mean.shape), 'gamma')
embed_input(attrs, 2, 'beta', np.zeros(variance.shape), 'beta')
embed_input(attrs, 3, 'mean', mean, 'biases')
embed_input(attrs, 4, 'variance', variance, 'weights')

BatchNormInference.update_node_stat(node, attrs)
return cls.enabled
6 changes: 4 additions & 2 deletions model-optimizer/extensions/front/caffe/bn.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Copyright (C) 2018-2021 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -27,14 +27,16 @@ class BNToScaleShift(FrontReplacementOp):
"""
Replaces BN layer with ScaleShift.
"""
op = "batchNormInference"
op = "BN"
enabled = True

def replace_op(self, graph: Graph, node: Node):
# This transformation does not work!!!
evolosen marked this conversation as resolved.
Show resolved Hide resolved
attrs = {'name': node.id + "/ScaleShift_"}

param = graph.node[node.id]['pb'].bn_param
pb_model = graph.node[node.id]['model_pb']

blobs = pb_model.blobs

if len(blobs) != 4:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Copyright (C) 2018-2021 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand All @@ -13,16 +13,19 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
import logging as log

from mo.front.common.partial_infer.roipooling import roipooling_infer
from extensions.ops.BN import BN
from mo.front.extractor import FrontExtractorOp
from mo.front.caffe.extractors.utils import embed_input
import numpy as np


def roipooling_ext(proto_layer, model_layer):
param = proto_layer.roi_pooling_param
return {
'type': 'ROIPooling',
'pooled_h': param.pooled_h,
'pooled_w': param.pooled_w,
'spatial_scale': param.spatial_scale,
'infer': roipooling_infer
}
class BNExtractor(FrontExtractorOp):
op = 'BN'
enabled = True

@classmethod
def extract(cls, node):
BN.update_node_stat(node, {})
return cls.enabled
2 changes: 1 addition & 1 deletion model-optimizer/extensions/front/caffe/bn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def test_bn(self):
FakeParam('data', shift)])
nodes = [
('input', {'kind': 'op', 'type': 'Identity', 'op': 'Identity'}),
('bn', {'type': None, 'kind': 'op', 'op': 'batchNormInference', 'pb': bn_pb, 'model_pb': bn_bin}),
('bn', {'type': None, 'kind': 'op', 'op': 'BN', 'pb': bn_pb, 'model_pb': bn_bin}),
('output', {'kind': 'op', 'type': 'Identity', 'op': 'Identity'}),
]
edges = [
Expand Down
33 changes: 33 additions & 0 deletions model-optimizer/extensions/front/caffe/concat_ext.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""
Copyright (C) 2018-2021 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from mo.front.onnx.extractors.utils import onnx_attr
evolosen marked this conversation as resolved.
Show resolved Hide resolved
from mo.front.extractor import FrontExtractorOp
from mo.ops.concat import Concat


class ConcatFrontExtractor(FrontExtractorOp):
op = 'concat'
enabled = True

@classmethod
def extract(cls, node):
pb = node.pb
mapping_rule = {
'axis': pb.concat_param.axis,
}
Concat.update_node_stat(node, mapping_rule)
return cls.enabled
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Copyright (C) 2018-2021 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand All @@ -17,7 +17,7 @@
import unittest
from unittest.mock import patch

from mo.front.caffe.extractors.crop import CropFrontExtractor
from extensions.front.caffe.crop_ext import CropFrontExtractor
from mo.front.common.partial_infer.crop import crop_infer
from mo.ops.crop import Crop
from mo.ops.op import Op
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Copyright (C) 2018-2021 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand All @@ -14,12 +14,16 @@
limitations under the License.
"""

from mo.front.common.partial_infer.concat import concat_infer
from extensions.ops.identity import Identity
from mo.front.extractor import FrontExtractorOp
from mo.graph.graph import Node


def concat_ext(pb_layer, pb_model):
return {
'type': "Concat",
'axis': pb_layer.concat_param.axis,
'infer': concat_infer
}
class DropoutFrontExtractor(FrontExtractorOp):
op = 'dropout'
enabled = True

@classmethod
def extract(cls, node: Node):
Identity.update_node_stat(node, {})
return cls.enabled
36 changes: 36 additions & 0 deletions model-optimizer/extensions/front/caffe/roipooling_ext.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""
Copyright (C) 2018-2021 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from mo.ops.roipooling import ROIPooling
evolosen marked this conversation as resolved.
Show resolved Hide resolved

from mo.front.extractor import FrontExtractorOp


class ROIPoolingFrontExtractor(FrontExtractorOp):
op = 'roipooling'
enabled = True

@classmethod
def extract(cls, node):
param = node.pb.roi_pooling_param
attrs = {
'pooled_h': param.pooled_h,
'pooled_w': param.pooled_w,
'spatial_scale': param.spatial_scale,
}

ROIPooling.update_node_stat(node, attrs)
return cls.enabled
55 changes: 55 additions & 0 deletions model-optimizer/extensions/front/caffe/scale_ext.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""
Copyright (C) 2018-2021 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import numpy as np

from mo.front.caffe.extractors.utils import embed_input, weights_biases
from mo.front.extractor import FrontExtractorOp
from mo.front.common.partial_infer.elemental import copy_shape_infer
from mo.utils.utils import NamedAttrsClass
evolosen marked this conversation as resolved.
Show resolved Hide resolved
from mo.ops.scale_shift import ScaleShiftOp


class ScaleFrontExtractor(FrontExtractorOp):
op = 'scale'
enabled = True

@classmethod
def extract(cls, node):
pb = node.pb
model = node.model_pb
param = pb.scale_param
attrs = {
'axis': param.axis,
}

if model is None and len(pb.bottom) == 1:
# default weights and biases for scale layer if the caffemodel file doesn't contain them
model = NamedAttrsClass({'blobs': np.array([NamedAttrsClass({'data': np.array([1])}),
NamedAttrsClass({'data': np.array([0])})])})
# scale with 1 input and 1 or 2 blobs
if model and len(model.blobs) != 0 and len(pb.bottom) == 1:
attrs.update(weights_biases(param.bias_term, model))
# 2 inputs + bias
elif len(pb.bottom) == 2 and param.bias_term:
if model is None or len(model.blobs) == 0:
# default bias for scale layer with 2 inputs if the caffemodel file doesn't contain them
model = NamedAttrsClass({'blobs': np.array([NamedAttrsClass({'data': np.array([0])})])})

embed_input(attrs, 1, 'biases', model.blobs[0].data)
ScaleShiftOp.update_node_stat(node, attrs)
return cls.enabled

Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,25 @@
limitations under the License.
"""

import unittest

from mo.front.caffe.extractors.concat import concat_ext
from mo.front.common.partial_infer.concat import concat_infer
from mo.utils.unittest.extractors import FakeParam


class FakeProtoLayer:
def __init__(self, axis):
self.concat_param = FakeParam('axis', axis)


class TestConcat(unittest.TestCase):
def test_concat(self):
res = concat_ext(FakeProtoLayer(10), None)
exp_res = {
'axis': 10,
'infer': concat_infer,
'type': 'Concat'
}
self.assertEqual(res, exp_res)
from mo.graph.graph import Graph
from mo.ops.op import Op


class BN(Op):
"""
BN operation will be replaced by BNToScaleShift FrontReplacer.
evolosen marked this conversation as resolved.
Show resolved Hide resolved
"""
op = 'BN'
enabled = False

def __init__(self, graph: Graph, attrs: dict):
super().__init__(graph, {
'type': None,
'op': self.op,
'in_ports_count': 5,
'out_ports_count': 1,
'infer': self.infer
}, attrs)
@staticmethod
evolosen marked this conversation as resolved.
Show resolved Hide resolved
def infer(node):
evolosen marked this conversation as resolved.
Show resolved Hide resolved
node.out_port(0).data.set_shape(node.in_port(0).data.get_shape())
Loading