Skip to content

Commit 1870960

Browse files
Create test_offline.py
1 parent 795cb4e commit 1870960

File tree

1 file changed

+90
-0
lines changed

1 file changed

+90
-0
lines changed

test_offline.py

Lines changed: 90 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,90 @@
1+
import random
2+
from data import ImageDetectionsField, TextField, RawField
3+
from data import COCO, DataLoader
4+
import evaluation
5+
from models.rstnet import Transformer, TransformerEncoder, TransformerDecoderLayer, ScaledDotProductAttention
6+
7+
import torch
8+
from tqdm import tqdm
9+
import argparse
10+
import pickle
11+
import numpy as np
12+
import time
13+
14+
# Fix every RNG seed up front so evaluation runs are reproducible.
np.random.seed(1234)
torch.manual_seed(1234)
random.seed(1234)
17+
18+
19+
def predict_captions(model, dataloader, text_field):
    """Generate captions for every batch in `dataloader` and score them.

    Runs beam search (beam size 5, max length 20) on each image batch,
    collapses consecutive duplicate tokens in the generated caption,
    then tokenizes both generations and ground truths with the PTB
    tokenizer and computes the evaluation metrics.

    Args:
        model: captioning model exposing `.eval()`, `.parameters()` and
            `.beam_search(images, max_len, eos_idx, beam_size, out_size)`.
        dataloader: yields `(images, caps_gt)` pairs; `caps_gt` is the list
            of reference captions for each image in the batch.
        text_field: provides `vocab.stoi` and `decode(out, join_words=False)`.

    Returns:
        dict of metric name -> score from `evaluation.compute_scores`.
    """
    import itertools
    model.eval()
    gen = {}
    gts = {}
    # Derive the device from the model instead of relying on a `device`
    # global that only exists when this file is run as a script
    # (otherwise this function raises NameError when imported elsewhere).
    device = next(model.parameters()).device
    with tqdm(desc='Evaluation', unit='it', total=len(dataloader)) as pbar:
        for it, (images, caps_gt) in enumerate(iter(dataloader)):
            images = images.to(device)
            with torch.no_grad():
                out, _ = model.beam_search(images, 20, text_field.vocab.stoi['<eos>'], 5, out_size=1)
            caps_gen = text_field.decode(out, join_words=False)
            for i, (gts_i, gen_i) in enumerate(zip(caps_gt, caps_gen)):
                # Collapse runs of repeated tokens (common beam-search artifact).
                gen_i = ' '.join([k for k, g in itertools.groupby(gen_i)])
                gen['%d_%d' % (it, i)] = [gen_i.strip(), ]
                gts['%d_%d' % (it, i)] = gts_i
            pbar.update()

    gts = evaluation.PTBTokenizer.tokenize(gts)
    gen = evaluation.PTBTokenizer.tokenize(gen)
    scores, _ = evaluation.compute_scores(gts, gen)

    return scores
41+
42+
43+
if __name__ == '__main__':
    # Offline evaluation entry point: build the COCO test split, restore the
    # trained RSTNet checkpoint, run beam-search captioning and print scores.
    start_time = time.time()
    device = torch.device('cuda')

    parser = argparse.ArgumentParser(description='RSTNet')
    parser.add_argument('--batch_size', type=int, default=10)
    parser.add_argument('--workers', type=int, default=4)
    parser.add_argument('--m', type=int, default=40)

    parser.add_argument('--features_path', type=str, default='./Datasets/X101-features/X101-grid-coco_trainval.hdf5')
    parser.add_argument('--annotation_folder', type=str, default='./Datasets/m2_annotations')

    # the path of tested model and vocabulary
    parser.add_argument('--language_model_path', type=str, default='./saved_language_models/language_context.pth')
    parser.add_argument('--model_path', type=str, default='./saved_transformer_models/rstnet_best.pth')
    parser.add_argument('--vocab_path', type=str, default='./vocab.pkl')
    args = parser.parse_args()

    print('The Evaluation of RSTNet')

    # Pipeline for image regions
    image_field = ImageDetectionsField(detections_path=args.features_path, max_detections=49, load_in_tmp=False)

    # Pipeline for text
    text_field = TextField(init_token='<bos>', eos_token='<eos>', lower=True, tokenize='spacy',
                           remove_punctuation=True, nopoints=False)

    # Create the dataset
    dataset = COCO(image_field, text_field, 'coco/images/', args.annotation_folder, args.annotation_folder)
    _, _, test_dataset = dataset.splits

    # Load the vocabulary; `with` guarantees the file handle is closed
    # (previously the handle from `open(...)` was never closed).
    with open(args.vocab_path, 'rb') as vocab_file:
        text_field.vocab = pickle.load(vocab_file)

    # Model and dataloaders
    encoder = TransformerEncoder(3, 0, attention_module=ScaledDotProductAttention, attention_module_kwargs={'m': args.m})
    decoder = TransformerDecoderLayer(len(text_field.vocab), 54, 3, text_field.vocab.stoi['<pad>'], language_model_path=args.language_model_path)
    model = Transformer(text_field.vocab.stoi['<bos>'], encoder, decoder).to(device)

    # map_location keeps loading robust when the checkpoint was saved on a
    # different device than the one used for evaluation.
    data = torch.load(args.model_path, map_location=device)
    model.load_state_dict(data['state_dict'])

    dict_dataset_test = test_dataset.image_dictionary({'image': image_field, 'text': RawField()})
    dict_dataloader_test = DataLoader(dict_dataset_test, batch_size=args.batch_size, num_workers=args.workers)

    scores = predict_captions(model, dict_dataloader_test, text_field)
    print(scores)
    print('it costs {} s to test.'.format(time.time() - start_time))

0 commit comments

Comments
 (0)