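"""
Evaluate a trained RSTNet captioning model: generate captions for the COCO test
split with beam search and report the standard captioning metrics.
"""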
import random
from data import ImageDetectionsField, TextField, RawField
from data import COCO, DataLoader
import evaluation
from models.rstnet import Transformer, TransformerEncoder, TransformerDecoderLayer, ScaledDotProductAttention

import torch
from tqdm import tqdm
import argparse
import itertools
import pickle
import numpy as np
import time

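# Fix the random seeds so that evaluation is reproducible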
random.seed(1234)
torch.manual_seed(1234)
np.random.seed(1234)


def predict_captions(model, dataloader, text_field):
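    # Note: relies on the module-level 'device' created in the __main__ block below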
    model.eval()
    gen = {}
    gts = {}
    with tqdm(desc='Evaluation', unit='it', total=len(dataloader)) as pbar:
        for it, (images, caps_gt) in enumerate(iter(dataloader)):
            images = images.to(device)
            with torch.no_grad():
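                # Beam search over the visual features: maximum caption length 20,
                # stop at '<eos>', beam size 5, keep only the single best sequence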
                out, _ = model.beam_search(images, 20, text_field.vocab.stoi['<eos>'], 5, out_size=1)
            caps_gen = text_field.decode(out, join_words=False)
            for i, (gts_i, gen_i) in enumerate(zip(caps_gt, caps_gen)):
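                # Collapse consecutive repeated tokens before joining the caption into a string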
                gen_i = ' '.join([k for k, g in itertools.groupby(gen_i)])
                gen['%d_%d' % (it, i)] = [gen_i.strip(), ]
                gts['%d_%d' % (it, i)] = gts_i
            pbar.update()

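    # Tokenize references and generated captions with the PTB tokenizer, then compute the captioning metrics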
    gts = evaluation.PTBTokenizer.tokenize(gts)
    gen = evaluation.PTBTokenizer.tokenize(gen)
    scores, _ = evaluation.compute_scores(gts, gen)

    return scores


if __name__ == '__main__':
    start_time = time.time()
    device = torch.device('cuda')

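    # Evaluation settings: batch size, dataloader workers, the value m forwarded to the
    # encoder's attention module, and paths to features, annotations, models and vocabulary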
    parser = argparse.ArgumentParser(description='RSTNet')
    parser.add_argument('--batch_size', type=int, default=10)
    parser.add_argument('--workers', type=int, default=4)
    parser.add_argument('--m', type=int, default=40)

    parser.add_argument('--features_path', type=str, default='./Datasets/X101-features/X101-grid-coco_trainval.hdf5')
    parser.add_argument('--annotation_folder', type=str, default='./Datasets/m2_annotations')

    # paths to the models under test and the vocabulary
    parser.add_argument('--language_model_path', type=str, default='./saved_language_models/language_context.pth')
    parser.add_argument('--model_path', type=str, default='./saved_transformer_models/rstnet_best.pth')
    parser.add_argument('--vocab_path', type=str, default='./vocab.pkl')
    args = parser.parse_args()

    print('The Evaluation of RSTNet')

    # Pipeline for image features (up to 49 grid features per image)
    image_field = ImageDetectionsField(detections_path=args.features_path, max_detections=49, load_in_tmp=False)

    # Pipeline for text
    text_field = TextField(init_token='<bos>', eos_token='<eos>', lower=True, tokenize='spacy',
                           remove_punctuation=True, nopoints=False)

    # Create the dataset and load the saved vocabulary
    dataset = COCO(image_field, text_field, 'coco/images/', args.annotation_folder, args.annotation_folder)
    _, _, test_dataset = dataset.splits
    with open(args.vocab_path, 'rb') as f:
        text_field.vocab = pickle.load(f)

    # Model and dataloaders
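    # The hyper-parameters below (3 encoder/decoder layers, attention parameter m, vocabulary size,
    # maximum caption length 54) are assumed to match those used to train the checkpoint;
    # the decoder also loads the pre-trained language model from --language_model_path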
    encoder = TransformerEncoder(3, 0, attention_module=ScaledDotProductAttention, attention_module_kwargs={'m': args.m})
    decoder = TransformerDecoderLayer(len(text_field.vocab), 54, 3, text_field.vocab.stoi['<pad>'], language_model_path=args.language_model_path)
    model = Transformer(text_field.vocab.stoi['<bos>'], encoder, decoder).to(device)

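    # Restore the trained captioning weights from the checkpoint's state_dict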
    data = torch.load(args.model_path)
    model.load_state_dict(data['state_dict'])

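    # Pair each test image with all of its reference captions (kept as raw text) for evaluation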
    dict_dataset_test = test_dataset.image_dictionary({'image': image_field, 'text': RawField()})
    dict_dataloader_test = DataLoader(dict_dataset_test, batch_size=args.batch_size, num_workers=args.workers)

    scores = predict_captions(model, dict_dataloader_test, text_field)
    print(scores)
    print('Evaluation took {:.2f} s.'.format(time.time() - start_time))