-
Notifications
You must be signed in to change notification settings - Fork 63
/
Copy path image_classification.py
70 lines (60 loc) · 2.41 KB
/
image_classification.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from typing import Any, Dict, List
import timm
import torch
from app.pipelines import Pipeline
from PIL import Image
from timm.data import (
CustomDatasetInfo,
ImageNetInfo,
create_transform,
infer_imagenet_subset,
resolve_model_data_config,
)
class ImageClassificationPipeline(Pipeline):
    """Image-classification pipeline backed by a pretrained timm model from the HF Hub.

    Produces the ``top_k`` most probable labels for a single PIL image, with
    softmax scores.
    """

    def __init__(self, model_id: str, top_k: int = 5):
        """
        Args:
            model_id (:obj:`str`):
                Hugging Face Hub repository id; loaded through timm as
                ``hf_hub:{model_id}`` with pretrained weights.
            top_k (:obj:`int`, defaults to 5):
                Maximum number of predictions returned per image. It is capped
                at the model's number of classes.
        Raises:
            ValueError: if ``top_k`` is not a positive integer.
        """
        if top_k < 1:
            raise ValueError(f"top_k must be a positive integer, got {top_k}")
        self.model = timm.create_model(f"hf_hub:{model_id}", pretrained=True)
        # Preprocessing is derived from the model's own pretrained data config,
        # using the test-time input size rather than the training size.
        self.transform = create_transform(
            **resolve_model_data_config(self.model, use_test_size=True)
        )
        self.top_k = min(self.model.num_classes, top_k)
        self.model.eval()
        self.dataset_info = self._resolve_dataset_info()

    def _resolve_dataset_info(self):
        """Build the index -> label mapping used to name predictions.

        Sources are tried in this order:
        1. ``label_names`` (and optional ``label_descriptions``) from the model's
           ``pretrained_cfg``.
        2. timm's ImageNet labeller, when an ImageNet subset can be inferred
           from the model.
        3. Generic ``LABEL_{i}`` names, one per model class.
        """
        cfg = self.model.pretrained_cfg
        label_names = cfg.get("label_names", None)
        label_descriptions = cfg.get("label_descriptions", None)
        if label_names is None:
            # No labels in the config: try timm's ImageNet labeller first.
            imagenet_subset = infer_imagenet_subset(self.model)
            if imagenet_subset:
                return ImageNetInfo(imagenet_subset)
            # Fallback label names.
            label_names = [f"LABEL_{i}" for i in range(self.model.num_classes)]
        return CustomDatasetInfo(
            label_names=label_names,
            label_descriptions=label_descriptions,
        )

    def __call__(self, inputs: Image.Image) -> List[Dict[str, Any]]:
        """
        Args:
            inputs (:obj:`PIL.Image`):
                The raw image representation as PIL.
                No transformation is applied to the input before this call; all
                necessary preprocessing is done here.
        Return:
            A :obj:`list` of dicts shaped like {"label": "XXX", "score": 0.82},
            in decreasing `score` order (``torch.topk`` returns sorted results
            by default).
        """
        image = inputs.convert("RGB")
        batch = self.transform(image).unsqueeze(0)  # add a batch dimension of 1
        with torch.no_grad():
            logits = self.model(batch)
        probabilities = logits.squeeze(0).softmax(dim=0)
        values, indices = torch.topk(probabilities, self.top_k)
        # Convert to plain Python numbers. index_to_description is typed to take
        # an int index; a 0-d tensor happens to work as a list index but would
        # not match the key when label names are stored in a dict.
        return [
            {
                "label": self.dataset_info.index_to_description(index, detailed=True),
                "score": score,
            }
            for index, score in zip(indices.tolist(), values.tolist())
        ]