Skip to content

Commit

Permalink
Add testdata for networks with parametrized input dims
Browse files Browse the repository at this point in the history
  • Loading branch information
sl-sergei committed Aug 20, 2020
1 parent 9cbde3f commit 7dee77f
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 0 deletions.
Binary file not shown.
Binary file not shown.
45 changes: 45 additions & 0 deletions testdata/dnn/onnx/generate_onnx_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -887,3 +887,48 @@ def forward(self, x, kernel, bias):
x = Variable(torch.randn(1, 2, 2))
model = Expand(shape=[2, -1, -1, -1])
save_data_and_model("expand_neg_batch", x, model)

def postprocess_model(model_path):
onnx_model = onnx.load(model_path)

def update_inputs_dims(model, input_dims):
"""
This function updates the sizes of dimensions of the model's inputs to the values
provided in input_dims. if the dim value provided is negative, a unique dim_param
will be set for that dimension.
"""
def update_dim(tensor, dim, i, j, dim_param_prefix):
dim_proto = tensor.type.tensor_type.shape.dim[j]
if isinstance(dim, int):
if dim >= 0:
dim_proto.dim_value = dim
else:
dim_proto.dim_param = dim_param_prefix + str(i) + '_' + str(j)
elif isinstance(dim, str):
dim_proto.dim_param = dim
else:
raise ValueError('Only int or str is accepted as dimension value, incorrect type: {}'.format(type(dim)))

for i, input_dim_arr in enumerate(input_dims):
for j, dim in enumerate(input_dim_arr):
update_dim(model.graph.input[i], dim, i, j, 'in_')

onnx.checker.check_model(model)
return model

onnx_model = update_inputs_dims(onnx_model, [[3, 'height', 'width']])
onnx.save(onnx_model, model_path)

class ReshapeAndConv(nn.Module):
def __init__(self):
super(ReshapeAndConv, self).__init__()
self.conv = nn.Conv2d(3, 3, kernel_size=1, stride=1, padding=0)
def forward(self, x):
x = x.unsqueeze(axis=0)
out = self.conv(x)
return out

x = Variable(torch.randn(3, 10, 10))
model = ReshapeAndConv()
save_data_and_model("reshape_and_conv_parameter_dims", x, model)
postprocess_model("models/reshape_and_conv_parameter_dims.onnx")
Binary file not shown.

0 comments on commit 7dee77f

Please sign in to comment.