Skip to content

Commit

Permalink
Merge branch 'DA-52-unbiased-pl' into DA-52-cos-score-weighting-poole…
Browse files Browse the repository at this point in the history
…d-refactor

# Conflicts:
#	pcdet/models/detectors/pv_rcnn_ssl.py
#	pcdet/models/roi_heads/pvrcnn_head.py
#	pcdet/models/roi_heads/roi_head_template.py
#	pcdet/utils/__init__.py
#	tools/cfgs/kitti_models/pv_rcnn_ssl_60.yaml
  • Loading branch information
fnozarian committed Sep 2, 2023
2 parents 3ba611a + 88aed6e commit 1885ef1
Show file tree
Hide file tree
Showing 9 changed files with 1,020 additions and 9 deletions.
88 changes: 83 additions & 5 deletions pcdet/models/detectors/pv_rcnn_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,28 @@
from pcdet.utils.prototype_utils import feature_bank_registry
from tools.visual_utils import open3d_vis_utils as V
from collections import defaultdict
from pcdet.utils.weighting_methods import build_thresholding_method
from visual_utils import open3d_vis_utils as V


class DynamicThreshRegistry(object):
def __init__(self, **kwargs):
self._tag_metrics = {}
self.dataset = kwargs.get('dataset', None)
self.model_cfg = kwargs.get('model_cfg', None)

def get(self, tag=None):
if tag is None:
tag = 'default'
if tag in self._tag_metrics.keys():
metric = self._tag_metrics[tag]
else:
metric = build_thresholding_method(tag=tag, dataset=self.dataset, config=self.model_cfg)
self._tag_metrics[tag] = metric
return metric

def tags(self):
return self._tag_metrics.keys()


class PVRCNN_SSL(Detector3DTemplate):
Expand All @@ -38,7 +60,7 @@ def __init__(self, model_cfg, num_class, dataset):
self.unlabeled_weight = model_cfg.UNLABELED_WEIGHT
self.no_nms = model_cfg.NO_NMS
self.supervise_mode = model_cfg.SUPERVISE_MODE

self.thresh_registry = DynamicThreshRegistry(dataset=self.dataset, model_cfg=model_cfg)
for bank_configs in model_cfg.get("FEATURE_BANK_LIST", []):
feature_bank_registry.register(tag=bank_configs["NAME"], **bank_configs)

Expand Down Expand Up @@ -110,7 +132,22 @@ def forward(self, batch_dict):

return pred_dicts, recall_dicts, {}

def _gen_pseudo_labels(self, batch_dict_ema):
def _rectify_pl_scores(self, batch_dict_ema, unlabeled_inds):
thresh_reg = self.thresh_registry.get(tag='pl_adaptive_thresh')
pred_weak_aug_before_nms = torch.sigmoid(batch_dict_ema['batch_cls_preds']).detach().clone()
# to be used later for updating the EMA (p_model/p_target)
pred_weak_aug_before_nms_org = pred_weak_aug_before_nms.clone()
if thresh_reg.iteration_count > 0:
pred_weak_aug_unlab_before_nms = pred_weak_aug_before_nms[unlabeled_inds, ...]
pred_weak_aug_unlab_before_nms_aligned = pred_weak_aug_unlab_before_nms * (thresh_reg.ema_pred_weak_aug_lab_before_nms + 1e-6) / (thresh_reg.ema_pred_weak_aug_unlab_before_nms + 1e-6)
pred_weak_aug_unlab_before_nms_aligned = thresh_reg.normalize_(pred_weak_aug_unlab_before_nms_aligned)
pred_weak_aug_before_nms[unlabeled_inds, ...] = pred_weak_aug_unlab_before_nms_aligned

batch_dict_ema['batch_cls_preds_org'] = pred_weak_aug_before_nms_org
batch_dict_ema['batch_cls_preds'] = pred_weak_aug_before_nms
batch_dict_ema['cls_preds_normalized'] = True

def _gen_pseudo_labels(self, batch_dict_ema, ulb_inds):
with torch.no_grad():
# self.pv_rcnn_ema.eval() # https://github.com/yezhen17/3DIoUMatch-PVRCNN/issues/6
for cur_module in self.pv_rcnn_ema.module_list:
Expand All @@ -119,6 +156,9 @@ def _gen_pseudo_labels(self, batch_dict_ema):
except TypeError as e:
batch_dict_ema = cur_module(batch_dict_ema)

if self.model_cfg.ROI_HEAD.ADAPTIVE_THRESH_CONFIG.get('ENABLE', False):
self._rectify_pl_scores(batch_dict_ema, ulb_inds)

pseudo_labels, _ = self.pv_rcnn_ema.post_processing(batch_dict_ema, no_recall_dict=True)

return pseudo_labels
Expand Down Expand Up @@ -150,7 +190,7 @@ def _forward_training(self, batch_dict):
lbl_inds, ulb_inds = self._prep_batch_dict(batch_dict)
batch_dict_ema = self._split_ema_batch(batch_dict)

pseudo_labels = self._gen_pseudo_labels(batch_dict_ema)
pseudo_labels = self._gen_pseudo_labels(batch_dict_ema, ulb_inds)

pseudo_boxes, pseudo_scores, pseudo_sem_scores = self._filter_pseudo_labels(pseudo_labels, ulb_inds)

Expand All @@ -162,6 +202,19 @@ def _forward_training(self, batch_dict):
for cur_module in self.pv_rcnn.module_list:
batch_dict = cur_module(batch_dict)

if self.model_cfg.ROI_HEAD.ADAPTIVE_THRESH_CONFIG.get('ENABLE', False):
pred_strong_aug_before_nms_org = torch.sigmoid(batch_dict['batch_cls_preds']).detach().clone()
pred_dicts_std, recall_dicts_std = self.pv_rcnn_ema.post_processing(batch_dict, no_recall_dict=True)

metrics_input = defaultdict(list)
for ind in range(len(pred_dicts_std)):
batch_type = 'unlab' if ind in ulb_inds else 'lab'
metrics_input[f'pred_weak_aug_{batch_type}_before_nms'].append(batch_dict_ema['batch_cls_preds_org'][ind])
metrics_input[f'pred_weak_aug_{batch_type}_after_nms'].append(pseudo_labels[ind]['pred_scores'].clone().detach())
metrics_input[f'pred_strong_aug_{batch_type}_before_nms'].append(pred_strong_aug_before_nms_org[ind])
metrics_input[f'pred_strong_aug_{batch_type}_after_nms'].append(pred_dicts_std[ind]['pred_scores'].clone().detach())
self.thresh_registry.get(tag='pl_adaptive_thresh').update(**metrics_input)

# Update the bank with student's features from augmented labeled data
bank = feature_bank_registry.get('gt_aug_lbl_prototypes')
sa_gt_lbl_inputs = self._prep_bank_inputs(batch_dict, lbl_inds, bank.num_points_thresh)
Expand Down Expand Up @@ -211,6 +264,13 @@ def _forward_training(self, batch_dict):
for tag in feature_bank_registry.tags():
feature_bank_registry.get(tag).compute()

# update dynamic thresh results
for tag in self.thresh_registry.tags():
results = self.thresh_registry.get(tag).compute()
if results:
tag = tag + "/" if tag else ''
tb_dict_.update({tag + key: val for key, val in results.items()})

for tag in metrics_registry.tags():
results = metrics_registry.get(tag).compute()
if results is not None:
Expand Down Expand Up @@ -441,7 +501,13 @@ def _filter_pseudo_labels(self, pred_dicts, unlabeled_inds):
pseudo_scores.append(pseudo_score)
continue

conf_thresh = torch.tensor(self.thresh, device=pseudo_label.device).unsqueeze(
pl_thresh = self.thresh
if self.model_cfg.ROI_HEAD.ADAPTIVE_THRESH_CONFIG.get('ENABLE', False):
thresh_reg = self.thresh_registry.get(tag='pl_adaptive_thresh')
if thresh_reg.relative_ema_threshold is not None:
pl_thresh = [thresh_reg.relative_ema_threshold.item()] * len(self.thresh)

conf_thresh = torch.tensor(pl_thresh, device=pseudo_label.device).unsqueeze(
0).repeat(len(pseudo_label), 1).gather(dim=1, index=(pseudo_label - 1).unsqueeze(-1))

sem_conf_thresh = torch.tensor(self.sem_thresh, device=pseudo_label.device).unsqueeze(
Expand Down Expand Up @@ -576,4 +642,16 @@ def load_params_from_file(self, filename, logger, to_cpu=False):
if key not in update_model_state:
logger.info('Not updated weight %s: %s' % (key, str(state_dict[key].shape)))

logger.info('==> Done (loaded %d/%d)' % (len(update_model_state), len(self.state_dict())))
logger.info('==> Done (loaded %d/%d)' % (len(update_model_state), len(self.state_dict())))

def update_adaptive_thresholding_metrics(self, pred_dicts, unlabeled_inds, tag = 'pl_adaptive_thresh'):
metrics_input = defaultdict(list)
for ind in unlabeled_inds:
pseudo_score = pred_dicts[ind]['pred_scores']
pseudo_label = pred_dicts[ind]['pred_labels']
pseudo_sem_score = pred_dicts[ind]['pred_sem_scores']
if len(pseudo_label):
metrics_input['pred_labels'].append(pseudo_label)
metrics_input['pseudo_score'].append(pseudo_score)
metrics_input['pseudo_sem_score'].append(pseudo_sem_score)
self.thresh_registry.get(tag).update(**metrics_input)
3 changes: 0 additions & 3 deletions pcdet/models/roi_heads/roi_head_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,6 @@ def proposal_layer(self, batch_dict, nms_config):
batch_cls_preds = batch_dict['batch_cls_preds']
rois = batch_box_preds.new_zeros((batch_size, nms_config.NMS_POST_MAXSIZE, batch_box_preds.shape[-1]))
roi_scores = batch_box_preds.new_zeros((batch_size, nms_config.NMS_POST_MAXSIZE))
roi_scores_multiclass = batch_box_preds.new_zeros((batch_size, nms_config.NMS_POST_MAXSIZE, 3))
roi_labels = batch_box_preds.new_zeros((batch_size, nms_config.NMS_POST_MAXSIZE), dtype=torch.long)

for index in range(batch_size):
Expand All @@ -142,12 +141,10 @@ def proposal_layer(self, batch_dict, nms_config):

rois[index, :len(selected), :] = box_preds[selected]
roi_scores[index, :len(selected)] = cur_roi_scores[selected]
roi_scores_multiclass[index, :len(selected), :] = cls_preds[selected]
roi_labels[index, :len(selected)] = cur_roi_labels[selected]

batch_dict['rois'] = rois
batch_dict['roi_scores'] = roi_scores
batch_dict['roi_scores_multiclass'] = torch.sigmoid(roi_scores_multiclass)
batch_dict['roi_labels'] = roi_labels + 1
batch_dict['has_class_labels'] = True if batch_cls_preds.shape[-1] > 1 else False
batch_dict.pop('batch_index', None)
Expand Down
19 changes: 19 additions & 0 deletions pcdet/utils/weighting_methods/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from .freematch import FreeMatchThreshold
from .adamatch import AdaMatchThreshold
from .consistant_teacher import AdaptiveThresholdGMM
from .softmatch import SoftMatchThreshold

__all__ = {
'FreeMatchThreshold': FreeMatchThreshold,
'AdaMatchThreshold': AdaMatchThreshold,
'AdaptiveThresholdGMM': AdaptiveThresholdGMM,
# 'SoftMatchThreshold': SoftMatchThreshold # not finalised yet
}


def build_thresholding_method(tag, dataset, config):
model = __all__[config.ROI_HEAD.ADAPTIVE_THRESH_CONFIG.NAME](
tag=tag, dataset=dataset, config=config
)

return model
Loading

0 comments on commit 1885ef1

Please sign in to comment.