From 6a06453ed4020806cac20bd8df37660a559f915d Mon Sep 17 00:00:00 2001 From: zhiqwang Date: Thu, 8 Apr 2021 13:06:54 -0400 Subject: [PATCH] Support training with the vanilla module --- test/test_engine.py | 1 - yolort/models/yolo_module.py | 28 +++++++++++++++++++++++++--- 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/test/test_engine.py b/test/test_engine.py index 91773f27..846c18a8 100644 --- a/test/test_engine.py +++ b/test/test_engine.py @@ -95,7 +95,6 @@ def test_train_with_vanilla_module(self): self.assertIsInstance(out["bbox_regression"], torch.Tensor) self.assertIsInstance(out["objectness"], torch.Tensor) - @unittest.skip("Just ignore this.") def test_train_one_step(self): # Load model model = yolov5s() diff --git a/yolort/models/yolo_module.py b/yolort/models/yolo_module.py index 88fb8810..cfd2db6f 100644 --- a/yolort/models/yolo_module.py +++ b/yolort/models/yolo_module.py @@ -77,9 +77,31 @@ def forward( # Transform the input samples, targets = self.transform(inputs, targets) # Compute the detections - detections = self.model(samples.tensors, targets=targets) - # Rescale coordinate - detections = self.transform.postprocess(detections, samples.image_sizes, original_image_sizes) + outputs = self.model(samples.tensors, targets=targets) + + losses = {} + detections: List[Dict[str, Tensor]] = [] + + if self.training: + # compute the losses + losses = outputs + else: + # Rescale coordinate + detections = self.transform.postprocess(outputs, samples.image_sizes, original_image_sizes) + + if torch.jit.is_scripting(): + return losses, detections + else: + return self.eager_outputs(losses, detections) + + @torch.jit.unused + def eager_outputs( + self, + losses: Dict[str, Tensor], + detections: List[Dict[str, Tensor]], + ) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]: + if self.training: + return losses return detections