diff --git a/src/deepforest/main.py b/src/deepforest/main.py index 0d1ed1d6..429abaaf 100644 --- a/src/deepforest/main.py +++ b/src/deepforest/main.py @@ -420,7 +420,8 @@ def predict_image(self, result = utilities.read_file(result, root_dir=root_dir) return result - + + label_dict: dict = {"Tree": 0} def predict_file(self, csv_file, root_dir, savedir=None, color=None, thickness=1): """Create a dataset and predict entire annotation file Csv file format is .csv file with the columns "image_path", "xmin","ymin","xmax","ymax" @@ -446,6 +447,10 @@ def predict_file(self, csv_file, root_dir, savedir=None, color=None, thickness=1 df = csv_file else: df = utilities.read_file(csv_file) + + # Map cropmodel_label to string labels using label_dict + df['label'] = df['label'].map({v: k for k, v in self.label_dict.items()}) + ds = dataset.TreeDataset(csv_file=df, root_dir=root_dir, transforms=None, diff --git a/tests/test_main.py b/tests/test_main.py index a8c3543e..d573be43 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -650,7 +650,6 @@ def test_predict_tile_with_crop_model(m, config): "cropmodel_score", "image_path" } - def test_predict_tile_with_crop_model_empty(): """If the model return is empty, the crop model should return an empty dataframe""" raster_path = get_data("SOAP_061.png") @@ -659,9 +658,11 @@ def test_predict_tile_with_crop_model_empty(): patch_overlap = 0.05 iou_threshold = 0.15 mosaic = True - # Set up the crop model crop_model = model.CropModel() + # Configure the label dictionary + m.label_dict = {"Tree": 0, "Bush": 1} + # Call the predict_tile method with the crop_model m.config["train"]["fast_dev_run"] = False m.create_trainer() @@ -672,5 +673,8 @@ def test_predict_tile_with_crop_model_empty(): mosaic=mosaic, crop_model=crop_model) - # Assert the result + # If result is not None, map cropmodel_label to string + if result is not None and not result.empty: + result['cropmodel_label'] = result['cropmodel_label'].map({v: k for k, v in m.label_dict.items()}) + assert result is None