Skip to content

Commit

Permalink
fix predict dim (#1739)
Browse files Browse the repository at this point in the history
Co-authored-by: Weisu Yin <[email protected]>
  • Loading branch information
yinweisu and yinweisu authored Mar 29, 2022
1 parent dae504a commit 0baa50a
Showing 1 changed file with 2 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -673,12 +673,13 @@ def _predict(self, x, **kwargs):
return pd.DataFrame(results)
elif not isinstance(x, torch.Tensor):
raise ValueError('Input is not supported: {}'.format(type(x)))
assert len(x.shape) == 4 and x.shape[1] == 3, f"Expect input to be (n, 3, h, w), given {x.shape}"
with torch.no_grad():
input = x.to(self.ctx[0])
label = self.net(input)
if self._problem_type in [MULTICLASS, BINARY]:
topk = min(5, self.num_class)
probs = nn.functional.softmax(label, dim=0).cpu().numpy().flatten()
probs = nn.functional.softmax(label, dim=1).cpu().numpy().flatten()
topk_inds = label.topk(topk)[1].cpu().numpy().flatten()
if with_proba:
df = pd.DataFrame([{'image_proba': probs.tolist()}])
Expand Down

0 comments on commit 0baa50a

Please sign in to comment.