Skip to content

Commit

Permalink
Merge pull request #550 from weecology/predict_file_dataloader
Browse files Browse the repository at this point in the history
Move main.predict_file to predict.predict_file and uses trainer.predict() for predict_file(). Speeds up main.evaluate!
  • Loading branch information
ethanwhite authored Nov 10, 2023
2 parents 79e024d + 6050220 commit a028d71
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 52 deletions.
2 changes: 1 addition & 1 deletion deepforest/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def log_images(self, pl_module):
"skipping upload, images were saved to {}, "
"error was rasied {}".format(self.savedir, e))

def on_validation_epoch_end(self, trainer, pl_module):
def on_validation_end(self, trainer, pl_module):
if trainer.sanity_checking: # optional skip
return

Expand Down
57 changes: 7 additions & 50 deletions deepforest/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,6 @@ def predict_image(self,

def predict_file(self, csv_file, root_dir, savedir=None, color=None, thickness=1):
"""Create a dataset and predict entire annotation file
Csv file format is .csv file with the columns "image_path", "xmin","ymin","xmax","ymax" for the image name and bounding box position.
Image_path is the relative filename, not absolute path, which is in the root_dir directory. One bounding box per line.
Expand All @@ -383,55 +382,13 @@ def predict_file(self, csv_file, root_dir, savedir=None, color=None, thickness=1
Returns:
df: pandas dataframe with bounding boxes, label and scores for each image in the csv file
"""
self.model.eval()
df = pd.read_csv(csv_file)
paths = df.image_path.unique()
ds = dataset.TreeDataset(csv_file=csv_file,
root_dir=root_dir,
transforms=None,
train=False)

batched_results = []
for i, batch in enumerate(self.predict_dataloader(ds)):
batch = self.transfer_batch_to_device(batch, self.device, dataloader_idx=0)
out = self.predict_step(batch, i)
batched_results.append(out)

# Flatten list from batched prediction
prediction_list = []
for batch in batched_results:
for boxes in batch:
prediction_list.append(boxes)

results = []
for index, prediction in enumerate(prediction_list):
# If there is more than one class, apply NMS Loop through images and apply cross
if len(prediction.label.unique()) > 1:
prediction = predict.across_class_nms(
prediction, iou_threshold=self.config["nms_thresh"])

if savedir:
# Just predict the images, even though we have the annotations
image = np.array(Image.open("{}/{}".format(root_dir,
paths[index])))[:, :, ::-1]
image = visualize.plot_predictions(image, prediction)

# Plot annotations if they exist
annotations = df[df.image_path == paths[index]]

image = visualize.plot_predictions(image,
annotations,
color=color,
thickness=thickness)
cv2.imwrite(
"{}/{}.png".format(savedir,
os.path.splitext(paths[index])[0]), image)

prediction["image_path"] = paths[index]
results.append(prediction)

results = pd.concat(results, ignore_index=True)

results = predict.predict_file(model=self,
csv_file=csv_file,
root_dir=root_dir,
nms_thresh=self.config["nms_thresh"],
savedir=savedir,
color=color,
thickness=thickness)
return results

def predict_tile(self,
Expand Down
76 changes: 76 additions & 0 deletions deepforest/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@
import cv2
import pandas as pd
import numpy as np
import os
from PIL import Image
import warnings

import torch
from torchvision.ops import nms

from deepforest import preprocess
from deepforest import visualize
from deepforest import dataset


def predict_image(model,
Expand Down Expand Up @@ -135,3 +138,76 @@ def across_class_nms(predicted_boxes, iou_threshold=0.15):
columns=["xmin", "ymin", "xmax", "ymax", "label", "score"])

return new_df


def predict_file(model,
csv_file,
root_dir,
nms_thresh,
savedir=None,
color=None,
thickness=1):
"""Create a dataset and predict entire annotation file
Csv file format is .csv file with the columns "image_path", "xmin","ymin","xmax","ymax" for the image name and bounding box position.
Image_path is the relative filename, not absolute path, which is in the root_dir directory. One bounding box per line.
Args:
model: deepforest.main object
csv_file: path to csv file
root_dir: directory of images. If none, uses "image_dir" in config
nms_thresh: Non-max supression threshold, see config["nms_thresh"]
savedir: Optional. Directory to save image plots.
color: color of the bounding box as a tuple of BGR color, e.g. orange annotations is (0, 165, 255)
thickness: thickness of the rectangle border line in px
Returns:
df: pandas dataframe with bounding boxes, label and scores for each image in the csv file
"""
df = pd.read_csv(csv_file)
paths = df.image_path.unique()
ds = dataset.TreeDataset(csv_file=csv_file,
root_dir=root_dir,
transforms=None,
train=False)

dataloader = model.predict_dataloader(ds)

#Make sure the latest trainer is used.
model.create_trainer()
trainer = model.trainer
batched_results = trainer.predict(model, dataloader)

# Flatten list from batched prediction
prediction_list = []
for batch in batched_results:
for images in batch:
prediction_list.append(images)

results = []
for index, prediction in enumerate(prediction_list):
# If there is more than one class, apply NMS Loop through images and apply cross
if len(prediction.label.unique()) > 1:
prediction = across_class_nms(prediction, iou_threshold=nms_thresh)

if savedir:
# Just predict the images, even though we have the annotations
image = np.array(Image.open("{}/{}".format(root_dir,
paths[index])))[:, :, ::-1]
image = visualize.plot_predictions(image, prediction)

# Plot annotations if they exist
annotations = df[df.image_path == paths[index]]

image = visualize.plot_predictions(image,
annotations,
color=color,
thickness=thickness)
cv2.imwrite("{}/{}.png".format(savedir,
os.path.splitext(paths[index])[0]), image)

prediction["image_path"] = paths[index]
results.append(prediction)

results = pd.concat(results, ignore_index=True)

return results
4 changes: 3 additions & 1 deletion tests/profile_predict_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,16 @@ def run(m, csv_file, root_dir):
if __name__ == "__main__":
m = main.deepforest()
m.use_release()
m.config["workers"] = 0
m.config["batch_size"] = 5

csv_file = get_data("OSBS_029.csv")
image_path = get_data("OSBS_029.png")
tmpdir = tempfile.gettempdir()
df = pd.read_csv(csv_file)

big_frame = []
for x in range(10):
for x in range(100):
img = Image.open("{}/{}".format(os.path.dirname(csv_file), df.image_path.unique()[0]))
cv2.imwrite("{}/{}.png".format(tmpdir, x), np.array(img))
new_df = df.copy()
Expand Down

0 comments on commit a028d71

Please sign in to comment.