diff --git a/src/otx/algorithms/detection/adapters/mmdet/models/detectors/mean_teacher.py b/src/otx/algorithms/detection/adapters/mmdet/models/detectors/mean_teacher.py index ac8f99e5240..150140c9d7d 100644 --- a/src/otx/algorithms/detection/adapters/mmdet/models/detectors/mean_teacher.py +++ b/src/otx/algorithms/detection/adapters/mmdet/models/detectors/mean_teacher.py @@ -14,6 +14,7 @@ from mmdet.core.mask.structures import BitmapMasks from mmdet.models import DETECTORS, build_detector from mmdet.models.detectors import BaseDetector +from torch import distributed as dist from otx.utils.logger import get_logger @@ -182,22 +183,29 @@ def forward_train( pseudo_bboxes, pseudo_labels, pseudo_masks, pseudo_ratio = self.generate_pseudo_labels( teacher_outputs, device=current_device, img_meta=ul_img_metas, **kwargs ) - if self.filter_empty_annotations: - non_empty = [bool(len(i)) for i in pseudo_labels] - pseudo_bboxes = [pb for i, pb in enumerate(pseudo_bboxes) if non_empty[i]] - pseudo_labels = [pl for i, pl in enumerate(pseudo_labels) if non_empty[i]] - pseudo_masks = [pm for i, pm in enumerate(pseudo_masks) if non_empty[i]] - ul_img_metas = [im for i, im in enumerate(ul_img_metas) if non_empty[i]] - ul_img = ul_img[non_empty] - else: - non_empty = [True] - if self.visualize: - self._visual_online(ul_img, pseudo_bboxes, pseudo_labels) + non_empty = [bool(len(i)) for i in pseudo_labels] if self.filter_empty_annotations else [True] + get_unlabeled_loss = pseudo_ratio >= self.min_pseudo_label_ratio and any(non_empty) + + if dist.is_initialized(): + reduced_get_unlabeled_loss = torch.tensor(int(get_unlabeled_loss)).to(current_device) + dist.all_reduce(reduced_get_unlabeled_loss) + if dont_have_to_train := not get_unlabeled_loss and reduced_get_unlabeled_loss > 0: + get_unlabeled_loss = True + non_empty[0] = True + losses.update(ps_ratio=torch.tensor([pseudo_ratio], device=current_device)) # Unsupervised loss # Compute only if min_pseudo_label_ratio is reached - if pseudo_ratio >= self.min_pseudo_label_ratio and any(non_empty): + if get_unlabeled_loss: + if self.filter_empty_annotations: + pseudo_bboxes = [pb for i, pb in enumerate(pseudo_bboxes) if non_empty[i]] + pseudo_labels = [pl for i, pl in enumerate(pseudo_labels) if non_empty[i]] + pseudo_masks = [pm for i, pm in enumerate(pseudo_masks) if non_empty[i]] + ul_img_metas = [im for i, im in enumerate(ul_img_metas) if non_empty[i]] + ul_img = ul_img[non_empty] + if self.visualize: + self._visual_online(ul_img, pseudo_bboxes, pseudo_labels) if self.bg_loss_weight >= 0.0: self.model_s.bbox_head.bg_loss_weight = self.bg_loss_weight if self.model_t.with_mask: @@ -214,7 +222,10 @@ def forward_train( if ul_loss_name.startswith("loss_"): ul_loss = ul_losses[ul_loss_name] target_loss = ul_loss_name.split("_")[-1] - if self.unlabeled_loss_weights[target_loss] == 0: + if dist.is_initialized(): + if dont_have_to_train: + self.unlabeled_loss_weights[target_loss] = 0 + elif self.unlabeled_loss_weights[target_loss] == 0: continue self._update_unlabeled_loss(losses, ul_loss, ul_loss_name, self.unlabeled_loss_weights[target_loss]) return losses diff --git a/tests/e2e/cli/detection/test_detection.py b/tests/e2e/cli/detection/test_detection.py index cc2efcd0b89..dd86fcf46d2 100644 --- a/tests/e2e/cli/detection/test_detection.py +++ b/tests/e2e/cli/detection/test_detection.py @@ -348,8 +348,6 @@ def test_otx_eval(self, template, tmp_dir_path): def test_otx_multi_gpu_train_semisl(self, template, tmp_dir_path): if not (Path(template.model_template_path).parent / "semisl").is_dir(): pytest.skip(f"Semi-SL training type isn't available for {template.name}") - if template.name == "ResNeXt101-ATSS": - pytest.skip(f"Issue#2705: multi-gpu training e2e test failure for {template.name}") tmp_dir_path = tmp_dir_path / "detection/test_multi_gpu_semisl" args_semisl_multigpu = copy.deepcopy(args_semisl) args_semisl_multigpu["--gpus"] = "0,1"