Enable ViTReciproCAM for VisionTransformer backbone (cls task) #2403

Merged (10 commits) on Aug 3, 2023

@@ -4,9 +4,13 @@
#
from functools import partial

import torch
from mmcls.models.backbones.vision_transformer import VisionTransformer
from mmcls.models.builder import CLASSIFIERS
from mmcls.models.classifiers.image import ImageClassifier
from mmcls.models.utils import resize_pos_embed

from otx.algorithms.common.adapters.mmcv.hooks.recording_forward_hook import ViTReciproCAMHook
from otx.algorithms.common.adapters.mmdeploy.utils import is_mmdeploy_enabled
from otx.algorithms.common.utils.logger import get_logger
from otx.algorithms.common.utils.task_adapt import map_class_names
@@ -300,6 +304,56 @@ def extract_feat(self, img):
ReciproCAMHook,
)

def _extract_vit_feat(model, x):
"""Modified forward from mmcls.models.backbones.vision_transformer.VisionTransformer.forward()."""
B = x.shape[0]
x, patch_resolution = model.backbone.patch_embed(x)

# stole cls_tokens impl from Phil Wang, thanks
cls_tokens = model.backbone.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x = x + resize_pos_embed(
model.backbone.pos_embed,
model.backbone.patch_resolution,
patch_resolution,
mode=model.backbone.interpolate_mode,
num_extra_tokens=model.backbone.num_extra_tokens,
)
x = model.backbone.drop_after_pos(x)

if not model.backbone.with_cls_token:
# Remove class token for transformer encoder input
x = x[:, 1:]

feat = None
layernorm_feat = None
for i, layer in enumerate(model.backbone.layers):
if i == len(model.backbone.layers) - 1:
layernorm_feat = layer.norm1(x)

x = layer(x)

if i == len(model.backbone.layers) - 1 and model.backbone.final_norm:
x = model.backbone.norm1(x)

if i in model.backbone.out_indices:
B, _, C = x.shape
if model.backbone.with_cls_token:
patch_token = x[:, 1:].reshape(B, *patch_resolution, C)
patch_token = patch_token.permute(0, 3, 1, 2)
cls_token = x[:, 0]
else:
patch_token = x.reshape(B, *patch_resolution, C)
patch_token = patch_token.permute(0, 3, 1, 2)
cls_token = None
if model.backbone.output_cls_token:
feat = [patch_token, cls_token]
else:
feat = patch_token
if model.with_neck:
feat = model.neck(feat)
return feat, layernorm_feat
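    # Illustrative shape contract, assuming a hypothetical ViT-B/16-style backbone and a
    # 224x224 input (14x14 patch grid): with output_cls_token set, feat is
    # [patch_token of shape [B, C, 14, 14], cls_token of shape [B, C]], and layernorm_feat
    # is the [B, 197, C] output of the last encoder layer's norm1 (1 cls + 196 patch tokens).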

@FUNCTION_REWRITER.register_rewriter(
"otx.algorithms.classification.adapters.mmcls.models.classifiers.CustomImageClassifier.extract_feat"
)
@@ -320,12 +374,22 @@ def sam_image_classifier__extract_feat(ctx, self, img): # pylint: disable=unused-argument
)
def sam_image_classifier__simple_test(ctx, self, img, img_metas): # pylint: disable=unused-argument
"""Simple test function used for inference for SAMClassifier with mmdeploy."""
feat, backbone_feat = self.extract_feat(img)
vit_backbone = isinstance(self.backbone, VisionTransformer)
if vit_backbone:
feat, layernorm_feat = _extract_vit_feat(self, img)
else:
feat, backbone_feat = self.extract_feat(img)
logit = self.head.simple_test(feat)

if ctx.cfg["dump_features"]:
saliency_map = ReciproCAMHook(self).func(backbone_feat)
feature_vector = FeatureVectorHook.func(backbone_feat)
if vit_backbone:
assert self.backbone.with_cls_token
_, cls_token = feat
feature_vector = cls_token
saliency_map = ViTReciproCAMHook(self).func(layernorm_feat)
else:
saliency_map = ReciproCAMHook(self).func(backbone_feat)
feature_vector = FeatureVectorHook.func(backbone_feat)
return logit, feature_vector, saliency_map

return logit
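
For orientation, a minimal sketch of the token arithmetic these rewrites rely on; the 224x224 input and 16x16 patch size are assumed ViT-B/16-style figures, not fixed by this PR:

    img_size, patch_size = 224, 16
    h = w = img_size // patch_size              # 14x14 patch grid
    token_number = h * w + 1                    # 196 patch tokens + 1 cls token = 197
    assert int((token_number - 1) ** 0.5) == h  # inverse used by ViTReciproCAMHook.func
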
29 changes: 19 additions & 10 deletions src/otx/algorithms/classification/adapters/mmcls/task.py
@@ -20,11 +20,12 @@
from contextlib import nullcontext
from copy import deepcopy
from functools import partial
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Optional, Type, Union

import torch
from mmcls.apis import train_model
from mmcls.datasets import build_dataloader, build_dataset
from mmcls.models.backbones.vision_transformer import VisionTransformer
from mmcls.utils import collect_env
from mmcv.runner import wrap_fp16_model
from mmcv.utils import Config, ConfigDict
@@ -41,6 +42,7 @@
EigenCamHook,
FeatureVectorHook,
ReciproCAMHook,
ViTReciproCAMHook,
)
from otx.algorithms.common.adapters.mmcv.utils import (
adapt_batch_size,
@@ -290,10 +292,13 @@ def hook(module, inp, outp):
model.register_forward_hook(hook)

model_type = cfg.model.backbone.type.split(".")[-1] # mmcls.VisionTransformer => VisionTransformer
if (
forward_explainer_hook: Union[nullcontext, BaseRecordingForwardHook]
if model_type == "VisionTransformer":
forward_explainer_hook = ViTReciproCAMHook(feature_model)
elif (
not dump_saliency_map or model_type in TRANSFORMER_BACKBONES
): # TODO: remove latter "or" condition after resolving Issue#2098
forward_explainer_hook: Union[nullcontext, BaseRecordingForwardHook] = nullcontext()
forward_explainer_hook = nullcontext()
else:
forward_explainer_hook = ReciproCAMHook(feature_model)
if (
@@ -473,12 +478,6 @@ def _get_mem_cache_size():

def _explain_model(self, dataset: DatasetEntity, explain_parameters: Optional[ExplainParameters]):
"""Explain function in MMClassificationTask."""
explainer_hook_selector = {
"eigencam": EigenCamHook,
"activationmap": ActivationMapHook,
"classwisesaliencymap": ReciproCAMHook,
}

self._data_cfg = ConfigDict(
data=ConfigDict(
train=ConfigDict(
@@ -525,9 +524,19 @@ def hook(module, inp, outp): # pylint: disable=unused-argument
model.register_forward_pre_hook(pre_hook)
model.register_forward_hook(hook)

per_class_xai_algorithm: Union[Type[ViTReciproCAMHook], Type[ReciproCAMHook]]
if isinstance(model.module.backbone, VisionTransformer):
per_class_xai_algorithm = ViTReciproCAMHook
else:
per_class_xai_algorithm = ReciproCAMHook
explainer_hook_selector = {
"eigencam": EigenCamHook,
"activationmap": ActivationMapHook,
"classwisesaliencymap": per_class_xai_algorithm,
}
explainer = explain_parameters.explainer if explain_parameters else None
if explainer is not None:
explainer_hook = explainer_hook_selector.get(explainer.lower(), None)
explainer_hook = explainer_hook_selector.get(explainer.lower())
else:
explainer_hook = None
if explainer_hook is None:
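As a usage sketch, the selected per-class XAI hook wraps inference as a context manager; feature_model and data_loader are placeholder names here, and this assumes the base recording hook exposes its accumulated maps via its records property:

    with ViTReciproCAMHook(feature_model) as hook:
        for data in data_loader:
            with torch.no_grad():
                _ = feature_model(return_loss=False, **data)
    saliency_maps = hook.records  # one [num_classes, H, W] map per input image
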
src/otx/algorithms/common/adapters/mmcv/hooks/recording_forward_hook.py
@@ -20,6 +20,7 @@

import numpy as np
import torch
from torch.nn import LayerNorm

from otx.algorithms.classification import MMCLS_AVAILABLE

@@ -238,3 +239,130 @@ def _get_mosaic_feature_map(self, feature_map: torch.Tensor, c: int, h: int, w:
mosaic_feature_map_mask[k, :, i, j] = torch.ones(c).to(feature_map.device)
mosaic_feature_map = feature_map_repeated * mosaic_feature_map_mask
return mosaic_feature_map


class ViTReciproCAMHook(BaseRecordingForwardHook):
"""Implementation of ViTRecipro-CAM for class-wise saliency map for transformer-based classifiers.

Args:
module (torch.nn.Module): The PyTorch module.
        layer_index (int): Index of the target transformer encoder layer.
        use_gaussian (bool): Defines the kernel type for mosaic feature map generation.
            If True, use a 3x3 Gaussian kernel; if False, use a 1x1 kernel.
        cls_token (bool): If True, include the classification token in the mosaic feature map.
"""

def __init__(
self, module: torch.nn.Module, layer_index: int = -1, use_gaussian: bool = True, cls_token: bool = True
):
super().__init__(module)
self._layer_index = layer_index
self._target_layernorm = self._get_target_layernorm()
self._final_norm = module.backbone.norm1 if module.backbone.final_norm else None
self._neck = module.neck if module.with_neck else None
self._num_classes = module.head.num_classes
self._use_gaussian = use_gaussian
self._cls_token = cls_token

def _get_target_layernorm(self) -> torch.nn.Module:
"""Returns the first (out of two) layernorm layer from the layer_index backbone layer."""
assert self._layer_index < 0, "negative index expected, e.g. -1 for the last layer."
layernorm_layers = []
for module in self._module.backbone.modules():
if isinstance(module, LayerNorm):
layernorm_layers.append(module)
assert len(layernorm_layers) == self._module.backbone.num_layers * 2 + int(self._module.backbone.final_norm)
target_layernorm_index = self._layer_index * 2 - int(self._module.backbone.final_norm)
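        # Worked example (hypothetical 12-layer ViT with final_norm=True): 12 * 2 + 1 = 25
        # LayerNorms are collected; for layer_index=-1 the target index is -1 * 2 - 1 = -3,
        # i.e. norm1 of the last encoder layer.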
return layernorm_layers[target_layernorm_index]

    def func(self, feature_map: torch.Tensor, _: int = -1) -> torch.Tensor:
        """Generate class-wise saliency maps using ViTRecipro-CAM, normalized to (0, 255).

        Args:
            feature_map (torch.Tensor): Feature map from the target LayerNorm layer.

        Returns:
            torch.Tensor: Class-wise saliency maps, one map per class - [batch, class_id, H, W].
        """
batch_size, token_number, _ = feature_map.size()
h = w = int((token_number - 1) ** 0.5)
saliency_maps = torch.empty(batch_size, self._num_classes, h, w)
for i in range(batch_size):
mosaic_feature_map = self._get_mosaic_feature_map(feature_map[i])
mosaic_prediction = self._predict_from_feature_map(mosaic_feature_map)
saliency_maps[i] = mosaic_prediction.transpose(1, 0).reshape((self._num_classes, h, w))

if self._norm_saliency_maps:
saliency_maps = saliency_maps.reshape((batch_size, self._num_classes, h * w))
saliency_maps = self._normalize_map(saliency_maps)

saliency_maps = saliency_maps.reshape((batch_size, self._num_classes, h, w))
return saliency_maps

    def _get_mosaic_feature_map(self, feature_map: torch.Tensor) -> torch.Tensor:
        # feature_map: [token_number, dim], where token_number = 1 cls token + h * w patch tokens.
        # The mosaic batch holds one masked copy of the feature map per patch token.
        token_number, dim = feature_map.size()
        mosaic_feature_map = torch.zeros(token_number - 1, token_number, dim).to(feature_map.device)
        h = w = int((token_number - 1) ** 0.5)

        if self._use_gaussian:
            if self._cls_token:
                mosaic_feature_map[:, 0, :] = feature_map[0, :]
            feature_map_spatial = feature_map[1:, :].reshape(1, h, w, dim)
            feature_map_spatial_repeated = feature_map_spatial.repeat(h * w, 1, 1, 1)  # e.g. 196, 14, 14, 192

            spatial_order = torch.arange(h * w).reshape(h, w)
            # 3x3 Gaussian kernel (weights sum to 1): each mosaic mask centers on one patch
            # and spreads damped weight onto its 8 neighbors.
            gaussian = torch.tensor(
                [[1 / 16.0, 1 / 8.0, 1 / 16.0], [1 / 8.0, 1 / 4.0, 1 / 8.0], [1 / 16.0, 1 / 8.0, 1 / 16.0]]
            ).to(feature_map.device)
            mosaic_feature_map_mask_padded = torch.zeros(h * w, h + 2, w + 2).to(feature_map.device)
            for i in range(h):
                for j in range(w):
                    k = spatial_order[i, j]
                    i_pad = i + 1
                    j_pad = j + 1
                    mosaic_feature_map_mask_padded[k, i_pad - 1 : i_pad + 2, j_pad - 1 : j_pad + 2] = gaussian
            mosaic_feature_map_mask = mosaic_feature_map_mask_padded[:, 1:-1, 1:-1]
            mosaic_feature_map_mask = mosaic_feature_map_mask.unsqueeze(3).repeat(1, 1, 1, dim)

            mosaic_fm_wo_cls_token = feature_map_spatial_repeated * mosaic_feature_map_mask
            mosaic_feature_map[:, 1:, :] = mosaic_fm_wo_cls_token.reshape(h * w, h * w, dim)
else:
feature_map_repeated = feature_map.unsqueeze(0).repeat(h * w, 1, 1)
mosaic_feature_map_mask = torch.zeros(h * w, token_number).to(feature_map.device)
for i in range(h * w):
mosaic_feature_map_mask[i, i + 1] = torch.ones(1).to(feature_map.device)
if self._cls_token:
mosaic_feature_map_mask[:, 0] = torch.ones(1).to(feature_map.device)
            mosaic_feature_map_mask = mosaic_feature_map_mask.unsqueeze(2).repeat(1, 1, dim)
mosaic_feature_map = feature_map_repeated * mosaic_feature_map_mask

return mosaic_feature_map

def _predict_from_feature_map(self, x: torch.Tensor) -> torch.Tensor:
with torch.no_grad():
            # Remainder of the target transformer encoder layer (everything after its first LayerNorm)
target_layer = self._module.backbone.layers[self._layer_index]
x = x + target_layer.attn(x)
x = target_layer.ffn(target_layer.norm2(x), identity=x)

            # Remaining transformer encoder layers, if the target layer is not the last one
if self._layer_index < -1:
for layer in self._module.backbone.layers[(self._layer_index + 1) :]:
x = layer(x)

if self._final_norm:
x = self._final_norm(x)
if self._neck:
x = self._neck(x)

cls_token = x[:, 0]
layer_output = [None, cls_token]
logit = self._module.head.simple_test(layer_output)
if isinstance(logit, list):
logit = torch.from_numpy(np.array(logit))
return logit

def __enter__(self) -> BaseRecordingForwardHook:
"""Enter."""
self._handle = self._target_layernorm.register_forward_hook(self._recording_forward)
return self
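
A brief shape example for func; the model and sizes here are hypothetical (any mmcls ImageClassifier with a ViT backbone and a matching embed dim would do):

    hook = ViTReciproCAMHook(model)     # model: classifier with a VisionTransformer backbone
    tokens = torch.rand(2, 197, 192)    # [batch, 1 cls + 14*14 patch tokens, embed dim]
    maps = hook.func(tokens)            # -> [2, num_classes, 14, 14], scaled to (0, 255) if normalization is on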