diff --git a/gluoncv/auto/estimators/torch_image_classification/torch_image_classification.py b/gluoncv/auto/estimators/torch_image_classification/torch_image_classification.py index 1be5be9c8..52aaa3226 100644 --- a/gluoncv/auto/estimators/torch_image_classification/torch_image_classification.py +++ b/gluoncv/auto/estimators/torch_image_classification/torch_image_classification.py @@ -673,12 +673,13 @@ def _predict(self, x, **kwargs): return pd.DataFrame(results) elif not isinstance(x, torch.Tensor): raise ValueError('Input is not supported: {}'.format(type(x))) + assert len(x.shape) == 4 and x.shape[1] == 3, f"Expect input to be (n, 3, h, w), given {x.shape}" with torch.no_grad(): input = x.to(self.ctx[0]) label = self.net(input) if self._problem_type in [MULTICLASS, BINARY]: topk = min(5, self.num_class) - probs = nn.functional.softmax(label, dim=0).cpu().numpy().flatten() + probs = nn.functional.softmax(label, dim=1).cpu().numpy().flatten() topk_inds = label.topk(topk)[1].cpu().numpy().flatten() if with_proba: df = pd.DataFrame([{'image_proba': probs.tolist()}])