Skip to content
This repository was archived by the owner on Aug 19, 2023. It is now read-only.

Commit d520bb8

Browse files
committed
test script
1 parent 83dc25d commit d520bb8

File tree

1 file changed

+70
-0
lines changed

1 file changed

+70
-0
lines changed

test.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import torch
2+
import numpy as np
3+
import torchvision
4+
from torchvision import datasets, models, transforms
5+
import time
6+
import os
7+
import copy
8+
import pdb
9+
import time
10+
from dataloader import CocoDataset, collater, Resizer, AspectRatioBasedSampler, Augmenter, UnNormalizer, Normalizer
11+
from torch.utils.data import Dataset, DataLoader
12+
13+
assert torch.__version__.split('.')[1] == '4'
14+
15+
import sys
16+
import cv2
17+
18+
print('CUDA available: {}'.format(torch.cuda.is_available()))
19+
20+
dataset_val = CocoDataset('../coco/', set_name='val2017', transform=transforms.Compose([Normalizer(), Resizer()]))
21+
22+
sampler_val = AspectRatioBasedSampler(dataset_val, batch_size=1, drop_last=False)
23+
dataloader_val = DataLoader(dataset_val, num_workers=1, collate_fn=collater, batch_sampler=sampler_val)
24+
25+
model = torch.load('model.pt')
26+
27+
use_gpu = True
28+
29+
if use_gpu:
30+
model = model.cuda()
31+
32+
model.eval()
33+
34+
unnormalize = UnNormalizer()
35+
36+
37+
def draw_caption(image, box, caption):
38+
39+
b = np.array(box).astype(int)
40+
cv2.putText(image, caption, (b[0], b[1] - 10), cv2.FONT_HERSHEY_PLAIN, 1, (0, 0, 0), 2)
41+
cv2.putText(image, caption, (b[0], b[1] - 10), cv2.FONT_HERSHEY_PLAIN, 1, (255, 255, 255), 1)
42+
43+
for idx, data in enumerate(dataloader_val):
44+
45+
scores, classification, transformed_anchors = model(data['img'].cuda().float())
46+
47+
idxs = np.where(scores>0.5)
48+
img = np.array(255 * unnormalize(data['img'][0, :, :, :])).copy()
49+
50+
img[img<0] = 0
51+
img[img>255] = 255
52+
53+
img = np.transpose(img, (1, 2, 0))
54+
55+
img = cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_BGR2RGB)
56+
57+
for j in range(idxs[0].shape[0]):
58+
bbox = transformed_anchors[idxs[0][j], :]
59+
x1 = int(bbox[0])
60+
y1 = int(bbox[1])
61+
x2 = int(bbox[2])
62+
y2 = int(bbox[3])
63+
label_name = dataset_val.labels[int(classification[idxs[0][j]])]
64+
draw_caption(img, (x1, y1, x2, y2), label_name)
65+
66+
cv2.rectangle(img, (x1, y1), (x2, y2), color=(0, 0, 255), thickness=2)
67+
print(label_name)
68+
69+
cv2.imshow('img', img)
70+
cv2.waitKey(0)

0 commit comments

Comments
 (0)