Skip to content

Commit

Permalink
fixes main.predict_file for numeric labels
Browse files Browse the repository at this point in the history
  • Loading branch information
naxatra2 committed Dec 24, 2024
1 parent 8213e0c commit 4361fca
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 4 deletions.
7 changes: 6 additions & 1 deletion src/deepforest/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,8 @@ def predict_image(self,
result = utilities.read_file(result, root_dir=root_dir)

return result


label_dict: dict = {"Tree": 0}
def predict_file(self, csv_file, root_dir, savedir=None, color=None, thickness=1):
"""Create a dataset and predict entire annotation file Csv file format
is .csv file with the columns "image_path", "xmin","ymin","xmax","ymax"
Expand All @@ -446,6 +447,10 @@ def predict_file(self, csv_file, root_dir, savedir=None, color=None, thickness=1
df = csv_file
else:
df = utilities.read_file(csv_file)

# Map cropmodel_label to string labels using label_dict
df['label'] = df['label'].map({v: k for k, v in self.label_dict.items()})

ds = dataset.TreeDataset(csv_file=df,
root_dir=root_dir,
transforms=None,
Expand Down
10 changes: 7 additions & 3 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,7 +650,6 @@ def test_predict_tile_with_crop_model(m, config):
"cropmodel_score", "image_path"
}


def test_predict_tile_with_crop_model_empty():
"""If the model return is empty, the crop model should return an empty dataframe"""
raster_path = get_data("SOAP_061.png")
Expand All @@ -659,9 +658,11 @@ def test_predict_tile_with_crop_model_empty():
patch_overlap = 0.05
iou_threshold = 0.15
mosaic = True
# Set up the crop model
crop_model = model.CropModel()

# Configure the label dictionary
m.label_dict = {"Tree": 0, "Bush": 1}

# Call the predict_tile method with the crop_model
m.config["train"]["fast_dev_run"] = False
m.create_trainer()
Expand All @@ -672,5 +673,8 @@ def test_predict_tile_with_crop_model_empty():
mosaic=mosaic,
crop_model=crop_model)

# Assert the result
# If result is not None, map cropmodel_label to string
if result is not None and not result.empty:
result['cropmodel_label'] = result['cropmodel_label'].map({v: k for k, v in m.label_dict.items()})

assert result is None

0 comments on commit 4361fca

Please sign in to comment.