Skip to content

Commit

Permalink
SetPartialShape in MO
Browse files Browse the repository at this point in the history
Usage example:
./mo.py --framework pdpd --input_model ../ngraph/test/files/paddlepaddle/models/2in_2out/2in_2out.pdmodel --log_level=INFO --input inputX1,inputX2 --input_shape [1,1,4,4],[1,2,3,4]
  • Loading branch information
nosovmik committed Apr 15, 2021
1 parent aaedbdc commit dfd675c
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 8 deletions.
10 changes: 5 additions & 5 deletions model-optimizer/mo/front/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,11 +591,11 @@ def input_user_data_repack(graph: Graph, input_user_shapes: [None, list, dict, n
if node is None:
raise Error('Cannot find location {} in the graph'.format(input_name))
shape = None if isinstance(input_user_shapes, list) else input_user_shapes[input_name]
if input_name in input_user_data_types and input_user_data_types[input_name] is not None:
data_type = input_user_data_types[input_name]
_input_shapes.append({'node': node, 'shape': shape, 'data_type': data_type})
else:
_input_shapes.append({'node': node, 'shape': shape})
if input_name in input_user_data_types and input_user_data_types[input_name] is not None:
data_type = input_user_data_types[input_name]
_input_shapes.append({'node': node, 'shape': shape, 'data_type': data_type})
else:
_input_shapes.append({'node': node, 'shape': shape})
elif isinstance(input_user_shapes, np.ndarray):
model_inputs = inputModel.getInputs()
assert len(model_inputs) == 1
Expand Down
5 changes: 2 additions & 3 deletions model-optimizer/mo/pipeline/unified.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,8 @@ def moc_pipeline(argv: argparse.Namespace):
apply_replacements_list(graph, transforms)
user_shapes = graph.graph['user_shapes']
if len(user_shapes) > 0:
assert len(inputModel.getInputs()) == 1
assert len(user_shapes) == 1
inputModel.setPartialShape(user_shapes[0]['node'], PartialShape(user_shapes[0]['shape']))
for user_shape in user_shapes:
inputModel.setPartialShape(user_shape['node'], PartialShape(user_shape['shape']))
nGraphModel = fe.convert(inputModel)
network = function_to_cnn(nGraphModel)
graph.graph['network'] = network
Expand Down

0 comments on commit dfd675c

Please sign in to comment.