Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move main.predict_file to predict.predict_file and uses trainer.predict() for predict_file(). Speeds up main.evaluate! #550

Merged
merged 3 commits into from
Nov 10, 2023

Conversation

bw4sz
Copy link
Collaborator

@bw4sz bw4sz commented Nov 9, 2023

  1. Migrate the code in main.predict_file to predict.predict_file with just a wrapper in main. main.py is too long and shouldn't have any complex logic in it.
  2. Used the trainer logic instead of manually moving batches to GPU, this causes significant speed up in evaluate, which should close main.evaluate is too slow #538.

Before was around 37 seconds.

After
Screenshot 2023-11-09 at 11 30 16 AM

Tested on hpc with 1 gpu

(base) [b.weinstein@login12 ~]$ cat tunnel.sh
#!/bin/bash
#SBATCH --job-name=tunnel   # Job name
#SBATCH --mail-type=END               # Mail events
#SBATCH [email protected]  # Where to send mail
#SBATCH --account=ewhite
#SBATCH --nodes=1                 # Number of MPI ran
#SBATCH --cpus-per-task=1
#SBATCH --mem=70GB
#SBATCH --time=12:00:00       #Time limit hrs:min:sec
#SBATCH --output=/home/b.weinstein/logs/tunnel.out   # Standard output and error log
#SBATCH --error=/home/b.weinstein/logs/tunnel.err
#SBATCH --partition=gpu
#SBATCH --gpus=1
(base) [b.weinstein@login12 tests]$ cat profile_predict_file.py
#Profile the dataset class on gpu
from deepforest import main
from deepforest import get_data
import os
import pandas as pd
import numpy as np
import cProfile, pstats
import tempfile
from PIL import Image
import cv2

def run(m, csv_file, root_dir):
    predictions = m.predict_file(csv_file=csv_file, root_dir=root_dir)

if __name__ == "__main__":
    profiler = cProfile.Profile()
    profiler.enable()
    m = main.deepforest()
    m.use_release()
    m.config["workers"] = 0
    m.config["batch_size"] = 24

    csv_file = get_data("OSBS_029.csv")
    image_path = get_data("OSBS_029.png")
    tmpdir = tempfile.gettempdir()
    df = pd.read_csv(csv_file)

    big_frame = []
    for x in range(100):
        img = Image.open("{}/{}".format(os.path.dirname(csv_file), df.image_path.unique()[0]))
        cv2.imwrite("{}/{}.png".format(tmpdir, x), np.array(img))
        new_df = df.copy()
        new_df.image_path = "{}.png".format(x)
        big_frame.append(new_df)

    big_frame = pd.concat(big_frame)
    big_frame.to_csv("{}/annotations.csv".format(tmpdir))


    run(m, csv_file = "{}/annotations.csv".format(tmpdir), root_dir = tmpdir)
    profiler.disable()
    stats = pstats.Stats(profiler).sort_stats('cumtime')
    stats.print_stats()
    stats.dump_stats('predict_file.prof')

@bw4sz bw4sz requested a review from henrykironde November 9, 2023 19:35
Copy link
Member

@ethanwhite ethanwhite left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This all looks good. My one question is whether or not we could just import predict_file() directly from predict.py instead of having what is basically just a pass through function, but I'm guessing there's something special about the lightening module class that means we need to put a copy there, so I'm going to go ahead and merge.

@ethanwhite ethanwhite merged commit a028d71 into main Nov 10, 2023
5 checks passed
@ethanwhite ethanwhite deleted the predict_file_dataloader branch November 10, 2023 15:09
janjatovic pushed a commit to Treeconomy/DeepForest_new that referenced this pull request Mar 26, 2024
Move main.predict_file to predict.predict_file and uses trainer.predict() for predict_file(). Speeds up main.evaluate!
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

main.evaluate is too slow
2 participants