Skip to content

Commit

Permalink
fix unit test of inpainter inferencer
Browse files Browse the repository at this point in the history
  • Loading branch information
LeoXing1996 committed Dec 15, 2022
1 parent 5b96a30 commit ac7c142
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions mmedit/apis/inferencers/inpainting_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def preprocess(self, img: InputsType, mask: InputsType) -> Dict:
# prepare data
_data = infer_pipeline(dict(gt_path=img, mask_path=mask))
data = dict()
data['inputs'] = _data['inputs'] / 255.0
data['inputs'] = dict(img=(_data['inputs'] / 255.0))
data = collate([data])
data['data_samples'] = [_data['data_samples']]
if 'cuda' in str(self.device):
Expand All @@ -58,7 +58,7 @@ def preprocess(self, img: InputsType, mask: InputsType) -> Dict:

# save masks and masked_imgs to visualize
self.masks = data['data_samples'][0].mask.data * 255
self.masked_imgs = data['inputs'][0]
self.masked_imgs = data['inputs']['img'][0]

return data

Expand Down
2 changes: 1 addition & 1 deletion mmedit/models/editors/aotgan/aot_inpaintor.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def forward_tensor(self, inputs, data_samples):
masks = torch.stack(
list(d.mask.data for d in data_samples), dim=0) # N,1,H,W

masked_imgs = inputs['img'] # N,3,H,W
masked_imgs = inputs # N,3,H,W
masked_imgs = masked_imgs.float() + masks

input_xs = torch.cat([masked_imgs, masks], dim=1) # N,4,H,W
Expand Down

0 comments on commit ac7c142

Please sign in to comment.