sample.py
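"""Generate a caption for a single image with a trained EncoderCNN/DecoderRNN pair.

Usage sketch (the image path below is illustrative; the checkpoint and vocab
defaults match the arguments defined at the bottom of this script):

    python sample.py --image example.jpg \
        --encoder_path models/encoder-2-1000.ckpt \
        --decoder_path models/decoder-2-1000.ckpt \
        --vocab_path data/vocab.pkl
"""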
import torch
import matplotlib.pyplot as plt
import numpy as np
import argparse
import pickle
from torchvision import transforms
from model.encoder_decoder import EncoderCNN, DecoderRNN
from PIL import Image
from utils import load_image
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def main(args):
    # Image preprocessing
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225))])

    # Load vocabulary wrapper
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Build models
    encoder = EncoderCNN(args.embed_size).eval()  # eval mode (batch norm uses moving mean/variance)
    decoder = DecoderRNN(args.embed_size, args.hidden_size, len(vocab), args.num_layers)
    encoder = encoder.to(device)
    decoder = decoder.to(device)

    # Load the trained model parameters (map_location lets CPU-only machines load GPU-trained checkpoints)
    encoder.load_state_dict(torch.load(args.encoder_path, map_location=device))
    decoder.load_state_dict(torch.load(args.decoder_path, map_location=device))

    # Prepare an image
    image = load_image(args.image, transform)
    image_tensor = image.to(device)

    # Generate a caption from the image
    feature = encoder(image_tensor)
    sampled_ids = decoder.sample(feature)
    sampled_ids = sampled_ids[0].cpu().numpy()  # (1, max_seq_length) -> (max_seq_length)

    # Convert word_ids to words
    sampled_caption = []
    for word_id in sampled_ids:
        word = vocab.idx2word[word_id]
        sampled_caption.append(word)
        if word == '<end>':
            break
    sentence = ' '.join(sampled_caption)

    # Print the generated caption and display the image
    print(sentence)
    image = Image.open(args.image)
    plt.imshow(np.asarray(image))
    plt.show()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--image', type=str, required=True, help='input image for generating caption')
    parser.add_argument('--encoder_path', type=str, default='models/encoder-2-1000.ckpt', help='path for trained encoder')
    parser.add_argument('--decoder_path', type=str, default='models/decoder-2-1000.ckpt', help='path for trained decoder')
    parser.add_argument('--vocab_path', type=str, default='data/vocab.pkl', help='path for vocabulary wrapper')

    # Model parameters (should be the same as the parameters in train.py)
    parser.add_argument('--embed_size', type=int, default=256, help='dimension of word embedding vectors')
    parser.add_argument('--hidden_size', type=int, default=512, help='dimension of lstm hidden states')
    parser.add_argument('--num_layers', type=int, default=1, help='number of layers in lstm')
    main(args=parser.parse_args())