Skip to content

Commit

Permalink
Minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqwang committed Sep 10, 2021
1 parent 0226dd9 commit 57db1e8
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 18 deletions.
8 changes: 3 additions & 5 deletions test/test_data_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@

from yolort.data import DetectionDataModule, contains_any_tensor, _helper as data_helper

from typing import Dict


def test_contains_any_tensor():
dummy_numpy = np.random.randn(3, 6)
Expand All @@ -25,7 +23,7 @@ def test_get_dataset():
# Test the datasets
image, target = next(iter(train_dataset))
assert isinstance(image, Tensor)
assert isinstance(target, Dict)
assert isinstance(target, dict)


def test_get_dataloader():
Expand All @@ -38,7 +36,7 @@ def test_get_dataloader():
assert isinstance(images[0], Tensor)
assert len(images[0]) == 3
assert len(targets) == batch_size
assert isinstance(targets[0], Dict)
assert isinstance(targets[0], dict)
assert isinstance(targets[0]["image_id"], Tensor)
assert isinstance(targets[0]["boxes"], Tensor)
assert isinstance(targets[0]["labels"], Tensor)
Expand All @@ -58,7 +56,7 @@ def test_detection_data_module():
assert isinstance(images[0], Tensor)
assert len(images[0]) == 3
assert len(targets) == batch_size
assert isinstance(targets[0], Dict)
assert isinstance(targets[0], dict)
assert isinstance(targets[0]["image_id"], Tensor)
assert isinstance(targets[0]["boxes"], Tensor)
assert isinstance(targets[0]["labels"], Tensor)
Expand Down
17 changes: 7 additions & 10 deletions test/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,10 @@
import pytorch_lightning as pl

from yolort.data import COCOEvaluator, DetectionDataModule, _helper as data_helper

from yolort.models import yolov5s
from yolort.models.yolo import yolov5_darknet_pan_s_r31
from yolort.models.transform import nested_tensor_from_tensor_list

from typing import Dict


def default_loader(img_name, is_half=False):
"""
Expand Down Expand Up @@ -44,7 +41,7 @@ def test_train_with_vanilla_model():
model = yolov5_darknet_pan_s_r31(num_classes=12)
model.train()
out = model(images, targets)
assert isinstance(out, Dict)
assert isinstance(out, dict)
assert isinstance(out["cls_logits"], Tensor)
assert isinstance(out["bbox_regression"], Tensor)
assert isinstance(out["objectness"], Tensor)
Expand All @@ -68,7 +65,7 @@ def test_train_with_vanilla_module():
model.train()

out = model(images, targets)
assert isinstance(out, Dict)
assert isinstance(out, dict)
assert isinstance(out["cls_logits"], Tensor)
assert isinstance(out["bbox_regression"], Tensor)
assert isinstance(out["objectness"], Tensor)
Expand Down Expand Up @@ -139,7 +136,7 @@ def test_predict_with_vanilla_model():
out = model([img_input])
assert isinstance(out, list)
assert len(out) == 1
assert isinstance(out[0], Dict)
assert isinstance(out[0], dict)
assert isinstance(out[0]["boxes"], Tensor)
assert isinstance(out[0]["labels"], Tensor)
assert isinstance(out[0]["scores"], Tensor)
Expand All @@ -157,7 +154,7 @@ def test_predict_with_tensor():
predictions = model.predict(img_tensor)
assert isinstance(predictions, list)
assert len(predictions) == 1
assert isinstance(predictions[0], Dict)
assert isinstance(predictions[0], dict)
assert isinstance(predictions[0]["boxes"], Tensor)
assert isinstance(predictions[0]["labels"], Tensor)
assert isinstance(predictions[0]["scores"], Tensor)
Expand All @@ -177,7 +174,7 @@ def test_predict_with_tensors():
predictions = model.predict(img_tensors)
assert isinstance(predictions, list)
assert len(predictions) == 2
assert isinstance(predictions[0], Dict)
assert isinstance(predictions[0], dict)
assert isinstance(predictions[0]["boxes"], Tensor)
assert isinstance(predictions[0]["labels"], Tensor)
assert isinstance(predictions[0]["scores"], Tensor)
Expand All @@ -193,7 +190,7 @@ def test_predict_with_image_file():
predictions = model.predict(img_name)
assert isinstance(predictions, list)
assert len(predictions) == 1
assert isinstance(predictions[0], Dict)
assert isinstance(predictions[0], dict)
assert isinstance(predictions[0]["boxes"], Tensor)
assert isinstance(predictions[0]["labels"], Tensor)
assert isinstance(predictions[0]["scores"], Tensor)
Expand All @@ -211,7 +208,7 @@ def test_predict_with_image_files():
predictions = model.predict(img_names)
assert isinstance(predictions, list)
assert len(predictions) == 2
assert isinstance(predictions[0], Dict)
assert isinstance(predictions[0], dict)
assert isinstance(predictions[0]["boxes"], Tensor)
assert isinstance(predictions[0]["labels"], Tensor)
assert isinstance(predictions[0]["scores"], Tensor)
4 changes: 1 addition & 3 deletions test/test_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
"""
Test for exporting model to ONNX and inference with ONNXRuntime
"""
from typing import List, Tuple

from pathlib import Path
import io
import pytest
Expand Down Expand Up @@ -96,7 +94,7 @@ def get_image(self, img_name, size) -> Tensor:

return transforms.ToTensor()(image)

def get_test_images(self) -> Tuple[List[Tensor], List[Tensor]]:
def get_test_images(self):
return ([self.get_image("bus.jpg", (416, 320))],
[self.get_image("zidane.jpg", (352, 480))])

Expand Down

0 comments on commit 57db1e8

Please sign in to comment.