Skip to content

Commit

Permalink
Numpy order (#575)
Browse files Browse the repository at this point in the history
* Use pytest.mark.parametrize for multiple test instances
  • Loading branch information
henrykironde authored Dec 12, 2023
1 parent bd84184 commit fb3f5ee
Showing 1 changed file with 57 additions and 52 deletions.
109 changes: 57 additions & 52 deletions tests/test_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import rasterio


@pytest.fixture()
def config():
config = utilities.read_config("deepforest_config.yml")
Expand All @@ -25,8 +26,7 @@ def config():
config["path_to_raster"] = get_data("OSBS_029.tif")

# Create a clean config test data
annotations = utilities.xml_to_annotations(
xml_path=config["annotations_xml"])
annotations = utilities.xml_to_annotations(xml_path=config["annotations_xml"])
annotations.to_csv("tests/data/OSBS_029.csv", index=False)

return config
Expand Down Expand Up @@ -74,60 +74,69 @@ def test_select_annotations_tile(config, image):
assert selected_annotations.xmax.max() <= config["patch_size"]
assert selected_annotations.ymax.max() <= config["patch_size"]

@pytest.mark.parametrize("input_type",["path","dataframe"])

@pytest.mark.parametrize("input_type", ["path", "dataframe"])
def test_split_raster(config, tmpdir, input_type):
"""Split raster into crops with overlaps to maintain all annotations"""
raster = get_data("2019_YELL_2_528000_4978000_image_crop2.png")
annotations = utilities.xml_to_annotations(get_data("2019_YELL_2_528000_4978000_image_crop2.xml"))
annotations = utilities.xml_to_annotations(
get_data("2019_YELL_2_528000_4978000_image_crop2.xml"))
annotations.to_csv("{}/example.csv".format(tmpdir), index=False)
#annotations.label = 0

if input_type =="path":
if input_type == "path":
annotations_file = "{}/example.csv".format(tmpdir)
else:
annotations_file = annotations
annotations_file = annotations

output_annotations = preprocess.split_raster(path_to_raster=raster,
annotations_file=annotations_file,
base_dir=tmpdir,
patch_size=500,
patch_overlap=0)
annotations_file=annotations_file,
base_dir=tmpdir,
patch_size=500,
patch_overlap=0)

# Returns a 6 column pandas array
assert not output_annotations.empty
assert output_annotations.shape[1] == 6


def test_split_raster_empty_crops(config, tmpdir):
"""Split raster into crops with overlaps to maintain all annotations, allow empty crops"""
raster = get_data("2019_YELL_2_528000_4978000_image_crop2.png")
annotations = utilities.xml_to_annotations(get_data("2019_YELL_2_528000_4978000_image_crop2.xml"))
annotations = utilities.xml_to_annotations(
get_data("2019_YELL_2_528000_4978000_image_crop2.xml"))
annotations.to_csv("{}/example.csv".format(tmpdir), index=False)
#annotations.label = 0
#visualize.plot_prediction_dataframe(df=annotations, root_dir=os.path.dirname(get_data(".")), show=True)

annotations_file = preprocess.split_raster(path_to_raster=raster,
annotations_file="{}/example.csv".format(tmpdir),
base_dir=tmpdir,
patch_size=100,
patch_overlap=0,
allow_empty=True)

annotations_file = preprocess.split_raster(
path_to_raster=raster,
annotations_file="{}/example.csv".format(tmpdir),
base_dir=tmpdir,
patch_size=100,
patch_overlap=0,
allow_empty=True)

# Returns a 6 column pandas array
assert not annotations_file[(annotations_file.xmin == 0) & (annotations_file.xmax == 0)].empty

assert not annotations_file[(annotations_file.xmin == 0) &
(annotations_file.xmax == 0)].empty


def test_split_raster_from_image(config, tmpdir):
r = rasterio.open(config["path_to_raster"]).read()
r = np.rollaxis(r,0,3)
annotations_file = preprocess.split_raster(numpy_image=r,
annotations_file=config["annotations_file"],
save_dir=tmpdir,
patch_size=config["patch_size"],
patch_overlap=config["patch_overlap"],
image_name="OSBS_029.tif")
r = np.rollaxis(r, 0, 3)
annotations_file = preprocess.split_raster(
numpy_image=r,
annotations_file=config["annotations_file"],
save_dir=tmpdir,
patch_size=config["patch_size"],
patch_overlap=config["patch_overlap"],
image_name="OSBS_029.tif")

# Returns a 6 column pandas array
assert annotations_file.shape[1] == 6


def test_split_raster_empty(config):
# Clean output folder
for f in glob.glob("tests/output/empty/*"):
Expand Down Expand Up @@ -170,33 +179,29 @@ def test_split_raster_empty(config):

def test_split_size_error(config, tmpdir):
with pytest.raises(ValueError):
annotations_file = preprocess.split_raster(path_to_raster=config["path_to_raster"],
annotations_file=config["annotations_file"],
base_dir=tmpdir,
patch_size=2000,
patch_overlap=config["patch_overlap"])
annotations_file = preprocess.split_raster(
path_to_raster=config["path_to_raster"],
annotations_file=config["annotations_file"],
base_dir=tmpdir,
patch_size=2000,
patch_overlap=config["patch_overlap"])


@pytest.mark.parametrize("orders", [(4, 400, 400), (400, 400, 4)])
def test_split_raster_4_band_warns(config, tmpdir, orders):
"""Test rasterio channel order
(400, 400, 4) C x H x W
(4, 400, 400) wrong channel order, H x W x C
"""

def test_split_raster_4_band_warns(config, tmpdir):
# Confirm that the rasterio channel order is C x H x W
assert rasterio.open(get_data("OSBS_029.tif")).read().shape[0] == 3

# Create a 4 band image in the wrong channel order, it should be H x W x C
numpy_image = np.zeros((4, 400, 400), dtype=np.uint8)

with pytest.warns(UserWarning):
preprocess.split_raster(numpy_image=numpy_image,
annotations_file=config["annotations_file"],
save_dir=tmpdir,
patch_size=config["patch_size"],
patch_overlap=config["patch_overlap"],
image_name="OSBS_029.tif")

numpy_image = np.zeros((400, 400, 4), dtype=np.uint8)
numpy_image = np.zeros(orders, dtype=np.uint8)

with pytest.warns(UserWarning):
preprocess.split_raster(numpy_image=numpy_image,
annotations_file=config["annotations_file"],
save_dir=tmpdir,
patch_size=config["patch_size"],
patch_overlap=config["patch_overlap"],
image_name="OSBS_029.tif")
annotations_file=config["annotations_file"],
save_dir=tmpdir,
patch_size=config["patch_size"],
patch_overlap=config["patch_overlap"],
image_name="OSBS_029.tif")

0 comments on commit fb3f5ee

Please sign in to comment.