Skip to content

Commit

Permalink
Fix for model transformer command (#2823)
Browse files Browse the repository at this point in the history
### Changes

- Fixed `model_extraction_command` issue introduced in #2307.

### Reason for changes

- Bugfix.

### Related tickets

- 147065

### Tests

- Updated
  • Loading branch information
KodiaqQ authored Jul 23, 2024
1 parent 60000e4 commit 5f5562d
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
9 changes: 5 additions & 4 deletions nncf/openvino/graph/model_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,10 +592,11 @@ def _apply_model_extraction_transformation(
output_node = name_to_node_mapping[output_name]

result_name = get_result_node_name(output_name, output_port_id)
if output_node.get_element_type() != outputs_type:
output_node = opset.convert(output_node, destination_type=outputs_type)
new_result = opset.result(output_node, name=result_name)
result_tensor_names = [result_name] + list(output_node.output(0).get_names())
output_port = output_node.output(output_port_id)
if output_port.get_element_type() != outputs_type:
output_port = opset.convert(output_node, destination_type=outputs_type).output(0)
new_result = opset.result(output_port, name=result_name)
result_tensor_names = [result_name] + list(output_port.get_names())
OVModelTransformer._update_tensor_names([new_result.get_output_tensor(0)], result_tensor_names)
results.append(new_result)

Expand Down
8 changes: 6 additions & 2 deletions tests/post_training/test_templates/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,9 @@ def __init__(self):
self.conv_2 = self._build_conv(2, 3, 2)
self.conv_3 = self._build_conv(1, 2, 3)
self.conv_4 = self._build_conv(2, 3, 1)
self.conv_5 = self._build_conv(3, 2, 2)
self.conv_5 = self._build_conv(3, 2, 1)
self.max_pool = torch.nn.MaxPool2d((2, 2))
self.conv_6 = self._build_conv(2, 3, 1)

def _build_conv(self, in_channels=1, out_channels=2, kernel_size=2):
conv = create_conv(in_channels, out_channels, kernel_size)
Expand All @@ -198,7 +200,9 @@ def forward(self, x):
x_2 = self.conv_3(x)
x_2 = self.conv_4(F.relu(x_2))
x_1_2 = torch.concat([x_1, x_2])
return self.conv_5(F.relu(x_1_2))
x = self.conv_5(F.relu(x_1_2))
x = self.max_pool(x)
return self.conv_6(x)


class LinearMultiShapeModel(nn.Module):
Expand Down

0 comments on commit 5f5562d

Please sign in to comment.