Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Fix breaking tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris committed Jul 30, 2021
1 parent d1843d3 commit 403fc0c
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 7 deletions.
7 changes: 3 additions & 4 deletions tests/pointcloud/detection/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def test_pointcloud_object_detection_data(tmpdir):

download_data("https://pl-flash-data.s3.amazonaws.com/KITTI_micro.zip", tmpdir)

dm = PointCloudObjectDetectorData.from_folders(train_folder=join(tmpdir, "KITTI_Micro", "Kitti", "train"), )
dm = PointCloudObjectDetectorData.from_folders(train_folder=join(tmpdir, "KITTI_Micro", "Kitti", "train"))

class MockModel(PointCloudObjectDetector):

Expand All @@ -43,8 +43,8 @@ def training_step(self, batch, batch_idx: int):
assert len(batch.point) == 2
assert batch.point[0][1].shape == torch.Size([4])
assert len(batch.bboxes) > 1
assert batch.attr[0]["name"] == '000000.bin'
assert batch.attr[1]["name"] == '000001.bin'
assert batch.attr[0]["name"] in ('000000.bin', '000001.bin')
assert batch.attr[1]["name"] in ('000000.bin', '000001.bin')

num_classes = 19
model = MockModel(backbone="pointpillars_kitti", num_classes=num_classes)
Expand All @@ -57,4 +57,3 @@ def training_step(self, batch, batch_idx: int):
predictions = model.predict([join(predict_path, "scans/000000.bin")])
assert torch.stack(predictions[0][DefaultDataKeys.INPUT]).shape[1] == 4
assert len(predictions[0][DefaultDataKeys.PREDS]) == 158
assert predictions[0][DefaultDataKeys.PREDS][0].__dict__["identifier"] == 'box:1'
6 changes: 3 additions & 3 deletions tests/pointcloud/segmentation/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_pointcloud_segmentation_data(tmpdir):

download_data("https://pl-flash-data.s3.amazonaws.com/SemanticKittiMicro.zip", tmpdir)

dm = PointCloudSegmentationData.from_folders(train_folder=join(tmpdir, "SemanticKittiMicro", "train"), )
dm = PointCloudSegmentationData.from_folders(train_folder=join(tmpdir, "SemanticKittiMicro", "train"))

class MockModel(PointCloudSegmentation):

Expand All @@ -43,8 +43,8 @@ def training_step(self, batch, batch_idx: int):
assert batch[DefaultDataKeys.INPUT]["labels"].shape == torch.Size([2, 45056])
assert batch[DefaultDataKeys.INPUT]["labels"].max() == 19
assert batch[DefaultDataKeys.INPUT]["labels"].min() == 0
assert batch[DefaultDataKeys.METADATA][0]["name"] == '00_000000'
assert batch[DefaultDataKeys.METADATA][1]["name"] == '00_000001'
assert batch[DefaultDataKeys.METADATA][0]["name"] in ('00_000000', '00_000001')
assert batch[DefaultDataKeys.METADATA][1]["name"] in ('00_000000', '00_000001')

num_classes = 19
model = MockModel(backbone="randlanet", num_classes=num_classes)
Expand Down

0 comments on commit 403fc0c

Please sign in to comment.