From dfd675c7e9eae37c4380f8c5195eda0ff772639a Mon Sep 17 00:00:00 2001 From: Michael Nosov Date: Thu, 15 Apr 2021 20:35:32 +0300 Subject: [PATCH] SetPartialShape in MO Usage example: ./mo.py --framework pdpd --input_model ../ngraph/test/files/paddlepaddle/models/2in_2out/2in_2out.pdmodel --log_level=INFO --input inputX1,inputX2 --input_shape [1,1,4,4],[1,2,3,4] --- model-optimizer/mo/front/extractor.py | 10 +++++----- model-optimizer/mo/pipeline/unified.py | 5 ++--- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/model-optimizer/mo/front/extractor.py b/model-optimizer/mo/front/extractor.py index 9871b3550ddba3..bba9966c8652f7 100644 --- a/model-optimizer/mo/front/extractor.py +++ b/model-optimizer/mo/front/extractor.py @@ -591,11 +591,11 @@ def input_user_data_repack(graph: Graph, input_user_shapes: [None, list, dict, n if node is None: raise Error('Cannot find location {} in the graph'.format(input_name)) shape = None if isinstance(input_user_shapes, list) else input_user_shapes[input_name] - if input_name in input_user_data_types and input_user_data_types[input_name] is not None: - data_type = input_user_data_types[input_name] - _input_shapes.append({'node': node, 'shape': shape, 'data_type': data_type}) - else: - _input_shapes.append({'node': node, 'shape': shape}) + if input_name in input_user_data_types and input_user_data_types[input_name] is not None: + data_type = input_user_data_types[input_name] + _input_shapes.append({'node': node, 'shape': shape, 'data_type': data_type}) + else: + _input_shapes.append({'node': node, 'shape': shape}) elif isinstance(input_user_shapes, np.ndarray): model_inputs = inputModel.getInputs() assert len(model_inputs) == 1 diff --git a/model-optimizer/mo/pipeline/unified.py b/model-optimizer/mo/pipeline/unified.py index 21867a2941c137..4bef47ae480a38 100644 --- a/model-optimizer/mo/pipeline/unified.py +++ b/model-optimizer/mo/pipeline/unified.py @@ -58,9 +58,8 @@ def moc_pipeline(argv: argparse.Namespace): apply_replacements_list(graph, transforms) user_shapes = graph.graph['user_shapes'] if len(user_shapes) > 0: - assert len(inputModel.getInputs()) == 1 - assert len(user_shapes) == 1 - inputModel.setPartialShape(user_shapes[0]['node'], PartialShape(user_shapes[0]['shape'])) + for user_shape in user_shapes: + inputModel.setPartialShape(user_shape['node'], PartialShape(user_shape['shape'])) nGraphModel = fe.convert(inputModel) network = function_to_cnn(nGraphModel) graph.graph['network'] = network