diff --git a/tools/model_converters/pytorch2onnx.py b/tools/model_converters/pytorch2onnx.py index f0dabb00ad..a9760432df 100644 --- a/tools/model_converters/pytorch2onnx.py +++ b/tools/model_converters/pytorch2onnx.py @@ -54,7 +54,7 @@ def pytorch2onnx(model, data = torch.cat((merged, trimap), dim=1).float() data = model.resize_inputs(data) elif model_type == 'image_restorer': - data = input['inputs'].unsqueeze(0) + data = input['inputs'].unsqueeze(0).float() elif model_type == 'inpainting': masks = input['data_samples'].mask.data.unsqueeze(0) img = input['inputs'].unsqueeze(0)