Skip to content

Commit

Permalink
[Fix] fix dtype of pytorch2onnx (#1629)
Browse files Browse the repository at this point in the history
  • Loading branch information
Z-Fran authored Feb 10, 2023
1 parent 4db8c4f commit 78126da
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion tools/model_converters/pytorch2onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def pytorch2onnx(model,
data = torch.cat((merged, trimap), dim=1).float()
data = model.resize_inputs(data)
elif model_type == 'image_restorer':
data = input['inputs'].unsqueeze(0)
data = input['inputs'].unsqueeze(0).float()
elif model_type == 'inpainting':
masks = input['data_samples'].mask.data.unsqueeze(0)
img = input['inputs'].unsqueeze(0)
Expand Down

0 comments on commit 78126da

Please sign in to comment.