-
Notifications
You must be signed in to change notification settings - Fork 0
/
Entangled.py
executable file
·123 lines (91 loc) · 3.98 KB
/
Entangled.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
#%%
from detectron2.config import get_cfg
import os
from detectron2.engine import DefaultPredictor
from PIL import Image
import numpy as np
from detectron2 import model_zoo
import torch
import shutil
import torch
from torchvision import transforms
from progress.bar import Bar
import argparse
import torchvision.transforms as transforms
# %% Command-line arguments and output folder
parser = argparse.ArgumentParser(
    description='Sort images into Entangled / NotEntangled / Submerged folders '
                'using a FasterRCNN detector followed by a binary classifier.')
parser.add_argument('--images_dir', type=str, required=True, help='path to images')
parser.add_argument('--out_path', default=None, type=str,
                    help='path to output (default: ./Entangle_Detect_Out)')
parser.add_argument('--Binary_classifier_model_path', type=str, required=True,
                    help='path to binary classifier model')
# '--FaterRCNN_model_path' (sic) is kept for backward compatibility; the correctly
# spelled alias stores into the same attribute, so the rest of the script is unaffected.
parser.add_argument('--FaterRCNN_model_path', '--FasterRCNN_model_path',
                    dest='FaterRCNN_model_path', type=str, required=True,
                    help='path to FasterRCNN model')
args = parser.parse_args()
# Use the GPU if CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Default output folder. It is assigned back to args.out_path because every
# later os.path.join(args.out_path, ...) call would raise TypeError on None.
if args.out_path is None:
    args.out_path = os.path.join('.', 'Entangle_Detect_Out')
os.makedirs(args.out_path, exist_ok=True)
# %% Load FasterRCNN model and custom weights
# Detector config: start from the COCO Faster R-CNN R50-FPN (3x schedule)
# architecture, then apply the project-specific overrides, then build the predictor.
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"))
# Same device choice as `device`: 'cuda' when CUDA is available, otherwise 'cpu'.
cfg.MODEL.DEVICE = device.type
# Custom trained weights replace the ones referenced by the base yaml.
cfg.MODEL.WEIGHTS = args.FaterRCNN_model_path
# Single foreground class (presumably the rudder; see the "Submerged" fallback).
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1
# Keep only confident detections.
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.9
# RoI sampling size; only consulted during training.
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128
# At most one box per image is returned; only the top box is used downstream.
cfg.TEST.DETECTIONS_PER_IMAGE = 1
FasterRCNN = DefaultPredictor(cfg)
# %% Define the plant / not-plant filter
def plant_notplant_filter(input_img, binary_classifier_model, img_dir, file_name, out_path=None):
    """Classify a cropped detection and copy the source image into the matching folder.

    Args:
        input_img: Batched image tensor passed unchanged to the classifier
            (FasterRCNN_Predict passes a single crop with a leading batch dim).
        binary_classifier_model: Callable returning per-class scores along dim 1;
            assumed shape (1, 2) — class 0 -> 'Entangled', 1 -> 'NotEntangled'.
        img_dir: Directory containing the original image.
        file_name: Name of the original image inside img_dir.
        out_path: Output root directory; defaults to the --out_path argument.

    Raises:
        ValueError: if the classifier predicts a class index other than 0 or 1.
    """
    # TODO: resize the crop to the classifier's training resolution if it needs a
    # fixed input size (earlier notes mention both 224x224 and 128x128 — confirm).
    if out_path is None:
        out_path = args.out_path
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        outputs = binary_classifier_model(input_img)
    _, preds = torch.max(outputs, 1)
    pred = preds.item()
    # NOTE(review): assumes class indices follow alphabetical folder order
    # (ImageFolder-style: Entangled=0, NotEntangled=1) — confirm against training.
    class_dirs = {0: 'Entangled', 1: 'NotEntangled'}
    if pred not in class_dirs:
        raise ValueError(f'Unexpected class index {pred} from binary classifier for {file_name}')
    dest = os.path.join(out_path, class_dirs[pred])
    os.makedirs(dest, exist_ok=True)
    shutil.copy(os.path.join(img_dir, file_name), os.path.join(dest, file_name))
# %% Crop the bounding box and run the "plant not plant" model
def FasterRCNN_Predict(image_file_name, image_directory, binary_classifier_model):
    """Detect the rudder in one image and route the image to an output folder.

    The module-level FasterRCNN predictor runs on the image. If it returns no box,
    the image is copied to <args.out_path>/Submerged. Otherwise the top box is
    cropped from the image and handed to plant_notplant_filter, which copies the
    image into the folder for the predicted class.

    Args:
        image_file_name: File name of the image inside image_directory.
        image_directory: Directory containing the image.
        binary_classifier_model: Classifier forwarded to plant_notplant_filter.
    """
    source = os.path.join(image_directory, image_file_name)
    # Load the pixels and close the file handle. convert('RGB') also normalises
    # grayscale / RGBA / palette images to the 3 channels both models consume.
    with Image.open(source) as img:
        image = img.convert('RGB')
    image_tensor = transforms.ToTensor()(image)  # (3, H, W), floats in [0, 1]
    # DefaultPredictor expects an (H, W, C) array in BGR order (it converts
    # internally if cfg.INPUT.FORMAT is 'RGB'); PIL yields RGB, so flip channels.
    image_bgr = np.ascontiguousarray(np.asarray(image)[:, :, ::-1])
    outputs = FasterRCNN(image_bgr)
    boxes = outputs['instances'].pred_boxes.tensor.tolist()
    # If FasterRCNN predicts no bounding box, assume the rudder is submerged.
    if not boxes:
        dest = os.path.join(args.out_path, 'Submerged')
        os.makedirs(dest, exist_ok=True)
        shutil.copy(source, os.path.join(dest, image_file_name))
        return
    # Boxes are (x0, y0, x1, y1) in pixels; truncate to ints for slicing.
    x0, y0, x1, y1 = (int(v) for v in boxes[0])
    crop_img = image_tensor[:, y0:y1, x0:x1].unsqueeze(0).to(device)
    plant_notplant_filter(input_img=crop_img, binary_classifier_model=binary_classifier_model,
                          img_dir=image_directory, file_name=image_file_name)
# %% Load the "plant not plant" model
# map_location places the weights on `device`, the same device the crops are
# sent to. Without it, on a CUDA machine a model saved on CPU would stay on CPU
# while the inputs go to CUDA, which fails with a device mismatch.
# SECURITY NOTE: torch.load unpickles a whole nn.Module, which can execute
# arbitrary code — only load trusted model files. On PyTorch >= 2.6 the default
# weights_only=True rejects full-module pickles; pass weights_only=False there
# if loading fails (trusted files only).
Binary_classifier_model = torch.load(args.Binary_classifier_model_path, map_location=device)
# Inference mode: Dropout is disabled and BatchNorm uses its running statistics.
Binary_classifier_model.eval()
# Regular files in the input directory, sorted for a deterministic processing
# order. Sub-directories are skipped: they cannot be opened as images (e.g. the
# default output folder when --images_dir is '.').
input_images = sorted(
    name for name in os.listdir(args.images_dir)
    if os.path.isfile(os.path.join(args.images_dir, name))
)
if not input_images:
    print('No images found in the input directory')
    # The builtin exit() is only installed by the site module; SystemExit always works.
    raise SystemExit(0)
# Gradients are never needed, so run the whole pass without autograd bookkeeping.
with torch.no_grad(), Bar('Disentangling...', max=len(input_images)) as bar:
    for image in input_images:
        try:
            FasterRCNN_Predict(image_file_name=image, image_directory=args.images_dir,
                               binary_classifier_model=Binary_classifier_model)
        except Image.UnidentifiedImageError:
            # A stray non-image file should not abort the rest of the batch.
            print(f'\nSkipping {image}: not a readable image')
        bar.next()
# %%