From 3669205a444eb69712ae418f6d1ef8927b5b0e91 Mon Sep 17 00:00:00 2001 From: Evgeny Lazarev Date: Fri, 29 Jan 2021 15:48:33 +0300 Subject: [PATCH] Added support for the MxNet op take (#4071) --- .../Supported_Frameworks_Layers.md | 1 + model-optimizer/automation/package_BOM.txt | 1 + .../extensions/front/mxnet/take_ext.py | 33 +++++++++++++++++++ 3 files changed, 35 insertions(+) create mode 100644 model-optimizer/extensions/front/mxnet/take_ext.py diff --git a/docs/MO_DG/prepare_model/Supported_Frameworks_Layers.md b/docs/MO_DG/prepare_model/Supported_Frameworks_Layers.md index 869cfa49d5e942..e938848a679444 100644 --- a/docs/MO_DG/prepare_model/Supported_Frameworks_Layers.md +++ b/docs/MO_DG/prepare_model/Supported_Frameworks_Layers.md @@ -108,6 +108,7 @@ Standard MXNet\* symbols: | SoftmaxActivation | No | | SoftmaxOutput | No | | SoftSign | No | +| Take | The attribute 'mode' is not supported | | Tile | No | | UpSampling | No | | Where | No | diff --git a/model-optimizer/automation/package_BOM.txt b/model-optimizer/automation/package_BOM.txt index e4080d168e1274..b488271b996c19 100644 --- a/model-optimizer/automation/package_BOM.txt +++ b/model-optimizer/automation/package_BOM.txt @@ -224,6 +224,7 @@ extensions/front/mxnet/ssd_pattern_remove_transpose.py extensions/front/mxnet/ssd_reorder_detection_out_inputs.py extensions/front/mxnet/stack_ext.py extensions/front/mxnet/swapaxis_ext.py +extensions/front/mxnet/take_ext.py extensions/front/mxnet/tile_ext.py extensions/front/mxnet/tile_replacer.py extensions/front/mxnet/transpose_ext.py diff --git a/model-optimizer/extensions/front/mxnet/take_ext.py b/model-optimizer/extensions/front/mxnet/take_ext.py new file mode 100644 index 00000000000000..590f0a3d68633d --- /dev/null +++ b/model-optimizer/extensions/front/mxnet/take_ext.py @@ -0,0 +1,33 @@ +""" + Copyright (C) 2017-2021 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from extensions.ops.gather import AttributedGather +from mo.front.extractor import FrontExtractorOp +from mo.front.mxnet.extractors.utils import get_mxnet_layer_attrs +from mo.graph.graph import Node + + +class TakeExtractor(FrontExtractorOp): + op = 'take' + enabled = True + + @classmethod + def extract(cls, node: Node): + attrs = get_mxnet_layer_attrs(node.symbol_dict) + AttributedGather.update_node_stat(node, { + 'axis': attrs.int('axis', 0), + }) + return cls.enabled