-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy patheval_sesame.py
272 lines (247 loc) · 10.2 KB
/
eval_sesame.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
import argparse
import os
import sys
import json
from functools import partial
import tqdm
import math
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
from dataloaders.test_dataset import TestReasoningDataset, TestReferDataset, collate_fn_test
from model.SESAME import load_pretrained_model_SESAME
from utils import (
prepare_input,
intersectionAndUnionGPU,
AverageMeter,
Summary,
)
def parse_args(args):
parser = argparse.ArgumentParser(description="SESAME's evaluation helper")
parser.add_argument("--cmd", default="inference", type=str, help="Value can be: 'inference' (default) or 'metrics' 'download'.")
parser.add_argument("--local_rank", default=0, type=int, help="node rank")
parser.add_argument("--process_num", default=0, type=int, help="For splitting inference across multiple processes/GPUs. The value should be between 0 and world_size-1.")
parser.add_argument("--world_size", default=2, type=int, help="total node count")
parser.add_argument("--pretrained_model_path", default="tsunghanwu/SESAME")
parser.add_argument("--vis_save_path", default="./save_directory", type=str)
parser.add_argument("--image_size", default=1024, type=int, help="image size")
parser.add_argument("--model_max_length", default=512, type=int)
parser.add_argument("--val_dataset", default="fprefcoco", type=str)
parser.add_argument("--dataset_dir", default="./dataset", type=str)
parser.add_argument(
"--conv_type",
default="llava_v1",
type=str,
choices=["llava_v1", "llava_llama_2"],
)
return parser.parse_args(args)
def save_preds(val_dataset, preds, process_num, world_size, inference_dir):
filename = f"preds_{val_dataset}_{process_num}_of_{world_size}.json"
file_path = os.path.join(inference_dir, filename)
with open(file_path, "w") as file:
json.dump(preds, file)
def load_preds_file(val_dataset, process_num, world_size, inference_dir):
filename = f"preds_{val_dataset}_{process_num}_of_{world_size}.json"
file_path = os.path.join(inference_dir, filename)
with open(file_path, "rb") as file:
return json.load(file)
def inference(args):
# Initialization
os.makedirs(args.vis_save_path, exist_ok=True)
os.makedirs(os.path.join(args.vis_save_path, "segmentation_mask"), exist_ok=True)
(
tokenizer,
segmentation_lmm,
vision_tower,
context_len,
) = load_pretrained_model_SESAME(
model_path=args.pretrained_model_path
)
# Load bf16 datatype
vision_tower = vision_tower.to(torch.bfloat16)
segmentation_lmm = segmentation_lmm.to(torch.bfloat16)
segmentation_lmm = torch.compile(segmentation_lmm, mode="reduce-overhead")
# for eval only
tokenizer.padding_side = "left"
reason_seg_dataset = ["ReasonSeg"]
refer_seg_dataset = [
"fprefcoco", "fprefcoco+", "fprefcocog",
"refcoco", "refcoco+", "refcocog",
"R-refcoco", "R-refcoco+", "R-refcocog"
]
if args.val_dataset in reason_seg_dataset:
test_dataset = TestReasoningDataset(
args.dataset_dir,
vision_tower.image_processor,
args.image_size,
datasetname=args.val_dataset,
train_test_split="val"
)
elif args.val_dataset in refer_seg_dataset:
test_dataset = TestReferDataset(
args.dataset_dir,
vision_tower.image_processor,
args.image_size,
datasetname=args.val_dataset,
train_test_split="val"
)
test_dataset = get_dataset_slice(test_dataset, args.process_num, args.world_size, debug=False)
test_loader = DataLoader(
test_dataset, batch_size=1, num_workers=1,
shuffle=False, drop_last=False, pin_memory=False,
collate_fn=partial(
collate_fn_test,
tokenizer=tokenizer,
use_mm_start_end=getattr(segmentation_lmm.config, "mm_use_im_start_end", False),
padding="left",
),
)
idx = 0
output_json = {}
# Process each item in the test loader
for input_dict in tqdm.tqdm(test_loader):
idx += 1
input_dict = prepare_input(input_dict, "bf16", is_cuda=True)
N = len(input_dict["input_ids"])
# total_count += N
assert N == len(input_dict["exists"][0])
# Batch processing setup
batch_size = 5
num_batch = math.ceil(N / batch_size)
pred_masks = []
pred_exists = []
gt_masks = input_dict["masks_list"][0].int()
# Prepare output JSON structure
image_file_key = input_dict["image_paths"][0]
output_json[image_file_key] = {}
raw_questions = [
x.split("\n")[1].split("ASSISTANT:")[0].strip()
for x in input_dict["conversation_list"]
]
output_json[image_file_key] = {
"conversation_list": raw_questions,
"pred_sent": [],
"gt_exists": input_dict["exists"][0],
"ref_ids": input_dict["ref_ids"],
"sent_ids": input_dict["sent_ids"]
}
# Batch inference to prevent OOM
for n in range(num_batch):
start_idx = n * batch_size
end_idx = min((n + 1) * batch_size, N)
input_ids = input_dict["input_ids"][start_idx:end_idx]
real_batch_size = input_ids.shape[0]
# Image preparation
images_clip = input_dict["images_clip"].repeat(
real_batch_size, 1, 1, 1
) # N x 3 x 336 x 336
images = input_dict["images"].repeat(
real_batch_size, 1, 1, 1
) # N x 3 x 1024 x 1024
sam_mask_shape_list = input_dict["sam_mask_shape_list"] * real_batch_size
# Model evaluation
with torch.inference_mode():
output_ids, pred_masks_batch, object_presence = segmentation_lmm.evaluate(
images_clip,
images,
input_ids,
sam_mask_shape_list,
max_new_tokens=512,
)
pred_exists += object_presence
pred_masks += pred_masks_batch
real_output_ids = output_ids[:, input_ids.shape[1] :]
generated_outputs = tokenizer.batch_decode(
real_output_ids, skip_special_tokens=True
)
output_json[image_file_key]["pred_sent"] += generated_outputs
# Save results
pred_masks = torch.stack(pred_masks, dim=0)
output_json[image_file_key]["pred_exists"] = pred_exists
output_json[image_file_key]["segmentation_path"] = os.path.join(
"segmentation_mask", f"{args.process_num}_{idx:04d}.npz"
)
output_seg_fname = os.path.join(
args.vis_save_path,
output_json[image_file_key]["segmentation_path"],
)
np.savez_compressed(
output_seg_fname,
pred=pred_masks.cpu().numpy(),
gt=gt_masks.cpu().numpy(),
)
save_preds(args.val_dataset, output_json, args.process_num, args.world_size, args.vis_save_path)
def get_dataset_slice(
val_dataset, process_num: int, world_size: int, debug: bool = False
) -> Subset:
"""
Return a torch.utils.data.Subset object that is a subset of `val_dataset`.
`val_dataset` is broken into roughly equal chunks. Number of chunks = `world_size`. The Nth
chunk is returned (where N = `process_num`).
"""
all_indices = np.array(range(len(val_dataset)))
print(f"Splitting Total images: {len(all_indices)}, world_size: {world_size}")
splits = np.array_split(all_indices, world_size)
print("Total splits generated: ", len(splits))
print(
"Split sizes (each size is for one process/process_num): ",
[len(s) for s in splits],
)
subset_indices = splits[process_num]
if debug:
# Force the dataset size to be tiny so we can test the script
subset_indices = subset_indices[:10]
print("Indices for current process: ", len(subset_indices))
subset = Subset(val_dataset, indices=subset_indices)
print("Subset dataset size: ", len(subset))
return subset
def run_eval(preds, args):
intersection_meter = AverageMeter("Intersec", ":6.3f", Summary.SUM)
union_meter = AverageMeter("Union", ":6.3f", Summary.SUM)
total_count = 0
correct_query = 0
for idx, pred_key in enumerate(preds):
pred = preds[pred_key]
# Evaluate SEE (binary classification)
correct_this_round = np.sum(
np.array(pred["gt_exists"]) == np.array(pred["pred_exists"])
)
total_count += np.array(pred["gt_exists"]).shape[0]
correct_query += correct_this_round
# Evaluate Say (Pass)
# Evaluate Segment (IoU)
seg_fname = os.path.join(
args.vis_save_path,
pred["segmentation_path"],
)
segmentation = np.load(seg_fname)
pred_masks = torch.from_numpy(segmentation["pred"]).cuda()
gt_masks =torch.from_numpy(segmentation["gt"]).cuda()
intersection, union, _ = intersectionAndUnionGPU(
pred_masks.contiguous(), gt_masks.contiguous(), 2, ignore_index=255,
)
intersection_meter.update(intersection), union_meter.update(union)
if idx % 50 == 49:
iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
detection_acc = correct_query / total_count
print(
f"[(batch {idx}) - {args.val_dataset}] cIoU: {iou_class[1]:.4f} | See Accuracy: {detection_acc:.4f}"
)
iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
detection_acc = correct_query / total_count
print(
f"[Final Result - {args.val_dataset}] cIoU: {iou_class[1]:.4f} | See Accuracy: {detection_acc:.4f}"
)
def main(args):
args = parse_args(args)
print(args)
if args.cmd == "inference":
inference(args)
elif args.cmd == "metrics":
combined_preds = {}
for rank in range(args.world_size):
combined_preds |= load_preds_file(args.val_dataset, rank, args.world_size, args.vis_save_path)
run_eval(combined_preds, args)
if __name__ == "__main__":
main(sys.argv[1:])