Skip to content

Commit

Permalink
Support training with the vanilla module
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqwang committed Apr 8, 2021
1 parent c72794f commit 6a06453
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 4 deletions.
1 change: 0 additions & 1 deletion test/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ def test_train_with_vanilla_module(self):
self.assertIsInstance(out["bbox_regression"], torch.Tensor)
self.assertIsInstance(out["objectness"], torch.Tensor)

@unittest.skip("Just ignore this.")
def test_train_one_step(self):
# Load model
model = yolov5s()
Expand Down
28 changes: 25 additions & 3 deletions yolort/models/yolo_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,31 @@ def forward(
# Transform the input
samples, targets = self.transform(inputs, targets)
# Compute the detections
detections = self.model(samples.tensors, targets=targets)
# Rescale coordinate
detections = self.transform.postprocess(detections, samples.image_sizes, original_image_sizes)
outputs = self.model(samples.tensors, targets=targets)

losses = {}
detections: List[Dict[str, Tensor]] = []

if self.training:
# compute the losses
losses = outputs
else:
# Rescale coordinate
detections = self.transform.postprocess(outputs, samples.image_sizes, original_image_sizes)

if torch.jit.is_scripting():
return losses, detections
else:
return self.eager_outputs(losses, detections)

@torch.jit.unused
def eager_outputs(
self,
losses: Dict[str, Tensor],
detections: List[Dict[str, Tensor]],
) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]:
if self.training:
return losses

return detections

Expand Down

0 comments on commit 6a06453

Please sign in to comment.