diff --git a/gluoncv/auto/estimators/torch_image_classification/torch_image_classification.py b/gluoncv/auto/estimators/torch_image_classification/torch_image_classification.py index 6fe9559c9..387cfb111 100644 --- a/gluoncv/auto/estimators/torch_image_classification/torch_image_classification.py +++ b/gluoncv/auto/estimators/torch_image_classification/torch_image_classification.py @@ -549,7 +549,9 @@ def _evaluate(self, val_data): return self.validate(self.net, val_data, validate_loss_fn, amp_autocast=self._amp_autocast) def _predict(self, x, **kwargs): - if isinstance(x, pd.DataFrame): + if isinstance(x, str): + return self._predict((x,)) + elif isinstance(x, pd.DataFrame): assert 'image' in x.columns, "Expect column `image` for input images" df = self._predict(tuple(x['image'])) return df.reset_index(drop=True) @@ -576,7 +578,6 @@ def _predict(self, x, **kwargs): input = input.to(self.ctx[0]) labels = self.net(input) for l in labels: - print(l) probs = nn.functional.softmax(l, dim=0).cpu().numpy().flatten() topk_inds = l.topk(topk)[1].cpu().numpy().flatten() results.extend([{'class': self.classes[topk_inds[k]], @@ -586,7 +587,7 @@ def _predict(self, x, **kwargs): for k in range(topk)]) idx += 1 return pd.DataFrame(results) - elif not isinstance(x, torch.tensor): + elif not isinstance(x, torch.Tensor): raise ValueError('Input is not supported: {}'.format(type(x))) with torch.no_grad(): input = x.to(self.ctx[0]) @@ -602,7 +603,9 @@ def _predict(self, x, **kwargs): def _predict_feature(self, x, **kwargs): - if isinstance(x, pd.DataFrame): + if isinstance(x, str): + return self._predict_feature((x,)) + elif isinstance(x, pd.DataFrame): assert 'image' in x.columns, "Expect column `image` for input images" df = self._predict_feature(tuple(x['image'])) df = df.set_index(x.index) @@ -638,7 +641,7 @@ def _predict_feature(self, x, **kwargs): df = pd.DataFrame(results) df['image'] = x return df - elif not isinstance(x, torch.tensor): + elif not isinstance(x, torch.Tensor): raise ValueError('Input is not supported: {}'.format(type(x))) with torch.no_grad(): input = x.to(self.ctx[0]) diff --git a/tests/auto/test_torch_auto_estimators.py b/tests/auto/test_torch_auto_estimators.py index 458cef7b7..a3e47a254 100644 --- a/tests/auto/test_torch_auto_estimators.py +++ b/tests/auto/test_torch_auto_estimators.py @@ -30,6 +30,7 @@ def test_image_classification_estimator(): est = TorchImageClassificationEstimator({'img_cls': {'model': 'resnet18'}, 'train': {'epochs': 1}, 'gpus': list(range(get_gpu_count()))}) res = est.fit(IMAGE_CLASS_DATASET) est.predict(IMAGE_CLASS_TEST) + est.predict(IMAGE_CLASS_TEST.iloc[0]['image']) est.predict_feature(IMAGE_CLASS_TEST) _save_load_test(est, 'test.pkl')