forked from alexriedel1/detectron2-GradCAM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
detectron2_gradcam.py
100 lines (87 loc) · 3.67 KB
/
detectron2_gradcam.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import os

import detectron2.data.transforms as T
import torch
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.data.datasets import register_coco_instances
from detectron2.data.detection_utils import read_image
from detectron2.modeling import build_model

from gradcam import GradCAM, GradCamPlusPlus
class Detectron2GradCAM():
    """
    Compute class activation maps (GradCAM / GradCAM++) for a detectron2 model on one image.

    Attributes
    ----------
    config_file : str
        detectron2 model config file path
    cfg_list : list
        Additional model configuration passed to ``CfgNode.merge_from_list``;
        merged after the config file and the device auto-detection, so it overrides both
    img_path : str
        Path of the image to explain
    root_dir : str [optional]
        Directory containing coco.json and the dataset images, used for custom dataset
        registration; required when custom_dataset is given
    custom_dataset : str [optional]
        Name of the custom dataset to register
    """

    def __init__(self, config_file, cfg_list, img_path, root_dir=None, custom_dataset=None):
        # load config from file
        cfg = get_cfg()
        cfg.merge_from_file(config_file)
        if custom_dataset:
            self._register_custom_dataset(custom_dataset, root_dir)
            # get_output_dict looks up class names in the metadata of DATASETS.TRAIN[0].
            cfg.DATASETS.TRAIN = (custom_dataset,)
        cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
        # Merged last so an explicit MODEL.DEVICE in cfg_list wins over auto-detection.
        cfg.merge_from_list(cfg_list)
        cfg.freeze()
        self.cfg = cfg
        self._set_input_image(img_path)

    @staticmethod
    def _coco_json_path(root_dir):
        """Return the path of the COCO annotation file ``coco.json`` inside *root_dir*.

        Raises
        ------
        ValueError
            If *root_dir* is None (a custom dataset needs a directory to load from).
        """
        if root_dir is None:
            raise ValueError("root_dir is required when custom_dataset is given")
        # os.path.join works whether or not root_dir ends with a separator.
        return os.path.join(root_dir, "coco.json")

    @staticmethod
    def _register_custom_dataset(custom_dataset, root_dir):
        """Register *custom_dataset* from ``root_dir/coco.json`` (once per process) and load it."""
        json_file = Detectron2GradCAM._coco_json_path(root_dir)
        # Registering the same name twice is rejected by detectron2, which would break
        # creating a second Detectron2GradCAM for the same dataset. An already registered
        # dataset is reused as-is, even if it was registered from a different root_dir.
        if custom_dataset not in DatasetCatalog.list():
            register_coco_instances(custom_dataset, {}, json_file, root_dir)
        MetadataCatalog.get(custom_dataset)
        # Materialising the dataset runs the registered COCO loader; presumably this is
        # what fills in the metadata's thing_classes that get_output_dict reads.
        DatasetCatalog.get(custom_dataset)

    @staticmethod
    def _to_input_tensor(image, input_format):
        """Convert a channel-last BGR image array into the CHW float32 tensor fed to the model.

        Parameters
        ----------
        image : numpy.ndarray
            Image with channels on the last axis, in BGR order (as read by _set_input_image)
        input_format : str
            The model's ``cfg.INPUT.FORMAT``; for "RGB" the channels are flipped

        Returns
        -------
        torch.Tensor
            Float32 tensor with channels first and ``requires_grad`` enabled
        """
        if input_format == "RGB":
            # The image was read as BGR; give the model the channel order it expects.
            # .copy() removes the negative stride that torch.as_tensor cannot handle.
            image = image[:, :, ::-1].copy()
        return torch.as_tensor(image.astype("float32").transpose(2, 0, 1)).requires_grad_(True)

    def _set_input_image(self, img_path):
        """Read *img_path* and prepare the resized input tensor for the model."""
        # Kept in BGR: this is also the image returned to callers by get_output_dict.
        self.image = read_image(img_path, format="BGR")
        self.image_height, self.image_width = self.image.shape[:2]
        # Test-time resize, mirroring the one used by detectron2's DefaultPredictor.
        resize = T.ResizeShortestEdge(
            [self.cfg.INPUT.MIN_SIZE_TEST, self.cfg.INPUT.MIN_SIZE_TEST], self.cfg.INPUT.MAX_SIZE_TEST
        )
        resized = resize.get_transform(self.image).apply_image(self.image)
        self.input_tensor = self._to_input_tensor(resized, self.cfg.INPUT.FORMAT)

    def get_cam(self, target_instance, layer_name, grad_cam_instance):
        """
        Build the model, run the given CAM class on the input image and collect the results.

        Parameters
        ----------
        target_instance : int
            Index of the predicted instance to explain
        layer_name : str
            Convolutional layer to perform GradCAM on
        grad_cam_instance : type
            CAM class to instantiate, e.g. GradCAM or GradCamPlusPlus
            (for multiple instances of the same object, GradCAM++ can be favorable)

        Returns
        -------
        image_dict : dict
            {"image" : <image>, "cam" : <cam>, "output" : <output>, "label" : <label>}
            <image> original input image (BGR)
            <cam> class activation map resized to original image shape
            <output> model output; its "instances" entry holds the predictions
            <label> class name of the target instance
        cam_orig : numpy.ndarray
            unprocessed raw cam
        """
        # NOTE: a fresh model is built and its weights reloaded on every call.
        model = build_model(self.cfg)
        DetectionCheckpointer(model).load(self.cfg.MODEL.WEIGHTS)
        inputs = {"image": self.input_tensor, "height": self.image_height, "width": self.image_width}
        grad_cam = grad_cam_instance(model, layer_name)
        with grad_cam as cam_extractor:
            cam, cam_orig, output = cam_extractor(inputs, target_instance=target_instance)
        return self.get_output_dict(cam, output, target_instance), cam_orig

    def get_output_dict(self, cam, output, target_instance):
        """Bundle the image, the CAM, the model output and the target's class label into a dict."""
        class_id = int(output["instances"].pred_classes[target_instance])
        thing_classes = MetadataCatalog.get(self.cfg.DATASETS.TRAIN[0]).thing_classes
        return {
            "image": self.image,
            "cam": cam,
            "output": output,
            "label": thing_classes[class_id],
        }