-
Notifications
You must be signed in to change notification settings - Fork 22
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Test script #12
Comments
@Doch88 Please feel free to adapt the code snippet below to generate scanpaths for your case. I also uploaded a "plot_scanpath.py" file for visualizing scanpaths; feel free to check it out.
|
@ouyangzhibo Thank you very much. I've modified your script slightly to see the results on the train set: """Test script.
Usage:
test.py <hparams> <checkpoint_dir> <dataset_root> [--cuda=<id>]
test.py -h | --help
Options:
-h --help Show this screen.
--cuda=<id> id of the cuda device [default: 0].
"""
import os
import json
import torch
import numpy as np
from torch.utils.data.dataset import Dataset
from tqdm import tqdm
from docopt import docopt
from os.path import join
from dataset import process_data
from irl_dcb.config import JsonConfig
import cv2 as cv
from irl_dcb.data import LHF_IRL
from irl_dcb.models import LHF_Policy_Cond_Small
from irl_dcb.environment import IRL_Env4LHF
from irl_dcb import utils
# Use one fixed seed for both the NumPy and the PyTorch global RNGs.
_SEED = 42620
np.random.seed(_SEED)
torch.manual_seed(_SEED)
def gen_scanpaths(generator,
                  env_test,
                  test_img_loader,
                  patch_num,
                  max_traj_len,
                  im_w,
                  im_h,
                  num_sample=10,
                  bbox_annos=None):
    """Sample scanpaths from ``generator`` for every batch of the loader.

    For each of ``num_sample`` trials the whole loader is traversed once;
    every batch is loaded into ``env_test``, the environment is reset and
    ``utils.collect_trajs`` is called with action sampling enabled. The
    collected actions are then converted with ``utils.actions2scanpaths``
    and passed through ``utils.cutFixOnTarget``.

    Args:
        generator: policy handed to ``utils.collect_trajs``.
        env_test: environment exposing ``set_data``, ``reset`` and
            ``batch_size``.
        test_img_loader: iterable of batches; each batch must provide the
            ``'img_name'`` and ``'cat_name'`` entries.
        patch_num: forwarded to ``utils.collect_trajs`` and
            ``utils.actions2scanpaths``.
        max_traj_len: forwarded to ``utils.collect_trajs``.
        im_w: forwarded to ``utils.actions2scanpaths``.
        im_h: forwarded to ``utils.actions2scanpaths``.
        num_sample: number of passes over the loader.
        bbox_annos: annotations handed to ``utils.cutFixOnTarget``. When
            omitted, the module-level ``bbox_annos`` is used, which is what
            this function always relied on before the parameter existed.

    Returns:
        The scanpaths produced by ``utils.actions2scanpaths``.

    Raises:
        NameError: if ``bbox_annos`` is neither passed nor defined at
            module level.
    """
    if bbox_annos is None:
        # Resolve the legacy global *before* sampling so that a missing
        # annotation table fails immediately instead of after every trial
        # has already been computed.
        try:
            bbox_annos = globals()['bbox_annos']
        except KeyError:
            raise NameError(
                "name 'bbox_annos' is not defined; pass bbox_annos= "
                "explicitly or define it at module level") from None

    all_actions = []
    for i_sample in range(num_sample):
        progress = tqdm(test_img_loader,
                        desc='trial ({}/{})'.format(i_sample + 1, num_sample))
        for batch in progress:
            env_test.set_data(batch)
            img_names_batch = batch['img_name']
            cat_names_batch = batch['cat_name']
            with torch.no_grad():
                env_test.reset()
                trajs = utils.collect_trajs(env_test,
                                            generator,
                                            patch_num,
                                            max_traj_len,
                                            is_eval=True,
                                            sample_action=True)
                actions = trajs['actions']
                # One record per batch element; column i of the action
                # array belongs to the i-th image of the batch.
                # NOTE(review): 'present' presumably marks the
                # target-present condition -- confirm against
                # utils.actions2scanpaths.
                all_actions.extend(
                    (cat_names_batch[i], img_names_batch[i], 'present',
                     actions[:, i])
                    for i in range(env_test.batch_size))

    scanpaths = utils.actions2scanpaths(all_actions, patch_num, im_w, im_h)
    # The return value is ignored, so cutFixOnTarget presumably edits
    # ``scanpaths`` in place -- TODO confirm.
    utils.cutFixOnTarget(scanpaths, bbox_annos)
    return scanpaths
if __name__ == '__main__':
    # Script entry point: load data and a trained generator, sample
    # scanpaths for the training task/image pairs and write one annotated
    # image per predicted scanpath under ./results/<task>/.
    args = docopt(__doc__)
    device = torch.device('cuda:{}'.format(args['--cuda']))
    hparams = args["<hparams>"]
    dataset_root = args["<dataset_root>"]
    checkpoint = args["<checkpoint_dir>"]
    hparams = JsonConfig(hparams)

    # dir of pre-computed beliefs
    DCB_dir_HR = join(dataset_root, 'DCBs/HR/')
    DCB_dir_LR = join(dataset_root, 'DCBs/LR/')

    # bounding box of the target object (for search efficiency evaluation)
    bbox_annos = np.load(join(dataset_root, 'bbox_annos.npy'),
                         allow_pickle=True).item()

    with open(join(dataset_root,
                   'human_scanpaths_TP_trainval_train.json')) as json_file:
        human_scanpaths_train = json.load(json_file)
    with open(join(dataset_root,
                   'human_scanpaths_TP_trainval_valid.json')) as json_file:
        human_scanpaths_valid = json.load(json_file)

    # First human fixation of every task/image pair, divided by the image
    # width and height.
    target_init_fixs = {}
    for traj in human_scanpaths_train + human_scanpaths_valid:
        key = traj['task'] + '_' + traj['name']
        target_init_fixs[key] = (traj['X'][0] / hparams.Data.im_w,
                                 traj['Y'][0] / hparams.Data.im_h)

    cat_names = list(np.unique([x['task'] for x in human_scanpaths_train]))
    catIds = dict(zip(cat_names, list(range(len(cat_names)))))

    dataset = process_data(human_scanpaths_train, human_scanpaths_valid,
                           DCB_dir_HR, DCB_dir_LR, bbox_annos, hparams)

    # Evaluate on the unique task/image pairs of the *training* scanpaths.
    train_task_img_pair = np.unique(
        [traj['task'] + '_' + traj['name'] for traj in human_scanpaths_train])
    test_dataset = LHF_IRL(DCB_dir_HR, DCB_dir_LR, target_init_fixs,
                           train_task_img_pair, bbox_annos, hparams.Data,
                           catIds)
    dataloader = torch.utils.data.DataLoader(test_dataset,
                                             batch_size=16,
                                             shuffle=False,
                                             num_workers=2)

    # load trained model
    input_size = 134  # number of belief maps
    task_eye = torch.eye(len(dataset['catIds'])).to(device)
    generator = LHF_Policy_Cond_Small(hparams.Data.patch_count,
                                      len(dataset['catIds']), task_eye,
                                      input_size).to(device)
    generator.eval()

    # The checkpoint file holds a dict; the model weights are stored under
    # its "model" key.
    state = torch.load(join(checkpoint, 'trained_generator.pkg'),
                       map_location=device)
    generator.load_state_dict(state["model"])

    # build environment
    env_test = IRL_Env4LHF(hparams.Data,
                           max_step=hparams.Data.max_traj_length,
                           mask_size=hparams.Data.IOR_size,
                           status_update_mtd=hparams.Train.stop_criteria,
                           device=device,
                           inhibit_return=True)

    # generate scanpaths
    # The message is built from num_sample so it always matches the number
    # of trials that are actually run.
    num_sample = 1
    print('sample scanpaths ({} for each image)...'.format(num_sample))
    predictions = gen_scanpaths(generator,
                                env_test,
                                dataloader,
                                hparams.Data.patch_num,
                                hparams.Data.max_traj_length,
                                hparams.Data.im_w,
                                hparams.Data.im_h,
                                num_sample=num_sample)

    # Draw every predicted scanpath onto its image: a numbered dot per
    # fixation and a line between consecutive fixations, all in white.
    white = (255, 255, 255)
    for elem in predictions:
        image_path = join("../images", elem['task'], elem['name'])
        image = cv.imread(image_path)
        if image is None:
            # cv.imread signals an unreadable/missing file by returning
            # None; without this check cv.resize fails with an opaque error.
            print('warning: could not read image {}; skipping'.format(
                image_path))
            continue
        image = cv.resize(image, (hparams.Data.im_w, hparams.Data.im_h))

        points = [(int(x), int(y)) for x, y in zip(elem['X'], elem['Y'])]
        for i, point in enumerate(points):
            cv.putText(image, str(i), point, cv.FONT_HERSHEY_SIMPLEX, 0.5,
                       white)
            cv.circle(image, point, 2, white)
            if i > 0:
                cv.line(image, points[i - 1], point, white)

        out_dir = join("./results", elem['task'])
        os.makedirs(out_dir, exist_ok=True)
        cv.imwrite(join(out_dir, elem['name']), image)
# It works finely, although I had to change this line:
Because it gave me an error: loading the .pkg file yields a dict, and the actual model is stored under the "model" key of that dictionary. |
Hi, could you describe the steps to generate the prediction for a single image out of the COCOSearch18 dataset? I've only computed the DCBs, but I don't know the remaining steps needed to make the prediction. |
@ouyangzhibo Hi, is it possible to have a simple script that tests the pre-trained framework on a single image?
The text was updated successfully, but these errors were encountered: