Commit 7927038: no backprop during testing
1 parent a99f8a9

1 file changed: test.py (23 additions, 22 deletions)
@@ -33,7 +33,7 @@
     sampler_val = AspectRatioBasedSampler(dataset_val, batch_size=1, drop_last=False)
     dataloader_val = DataLoader(dataset_val, num_workers=1, collate_fn=collater, batch_sampler=sampler_val)
 
-    model = torch.load('csv_model_3.pt')
+    model = torch.load('csv_model_1.pt')
 
     use_gpu = True
 
@@ -44,7 +44,6 @@
 
     unnormalize = UnNormalizer()
 
-
     def draw_caption(image, box, caption):
 
        b = np.array(box).astype(int)
@@ -53,29 +52,31 @@ def draw_caption(image, box, caption):
 
     for idx, data in enumerate(dataloader_val):
 
-        scores, classification, transformed_anchors = model(data['img'].cuda().float())
-
-        idxs = np.where(scores>0.5)
-        img = np.array(255 * unnormalize(data['img'][0, :, :, :])).copy()
+        with torch.no_grad():
+            st = time.time()
+            scores, classification, transformed_anchors = model(data['img'].cuda().float())
+            print('Elapsed time: {}'.format(time.time()-st))
+            idxs = np.where(scores>0.5)
+            img = np.array(255 * unnormalize(data['img'][0, :, :, :])).copy()
 
-        img[img<0] = 0
-        img[img>255] = 255
+            img[img<0] = 0
+            img[img>255] = 255
 
-        img = np.transpose(img, (1, 2, 0))
+            img = np.transpose(img, (1, 2, 0))
 
-        img = cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_BGR2RGB)
+            img = cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_BGR2RGB)
 
-        for j in range(idxs[0].shape[0]):
-            bbox = transformed_anchors[idxs[0][j], :]
-            x1 = int(bbox[0])
-            y1 = int(bbox[1])
-            x2 = int(bbox[2])
-            y2 = int(bbox[3])
-            label_name = dataset_val.labels[int(classification[idxs[0][j]])]
-            draw_caption(img, (x1, y1, x2, y2), label_name)
+            for j in range(idxs[0].shape[0]):
+                bbox = transformed_anchors[idxs[0][j], :]
+                x1 = int(bbox[0])
+                y1 = int(bbox[1])
+                x2 = int(bbox[2])
+                y2 = int(bbox[3])
+                label_name = dataset_val.labels[int(classification[idxs[0][j]])]
+                draw_caption(img, (x1, y1, x2, y2), label_name)
 
-            cv2.rectangle(img, (x1, y1), (x2, y2), color=(0, 0, 255), thickness=2)
-            print(label_name)
+                cv2.rectangle(img, (x1, y1), (x2, y2), color=(0, 0, 255), thickness=2)
+                print(label_name)
 
-        cv2.imshow('img', img)
-        cv2.waitKey(0)
+            #cv2.imshow('img', img)
+            #cv2.waitKey(0)
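
The substance of the commit is that the forward pass and the drawing code now run inside torch.no_grad(), so PyTorch does not record an autograd graph while the test script runs; the change also adds a simple timing printout, loads csv_model_1.pt instead of csv_model_3.pt, and comments out the cv2.imshow preview. Below is a minimal sketch of the idea; the model and input are throwaway stand-ins, not the repository's RetinaNet and dataloader.

import time
import torch
import torch.nn as nn

# Stand-in model and input for illustration only; the real script loads a
# saved detector with torch.load(...) and feeds batches from a DataLoader.
net = nn.Linear(128, 4)
net.eval()  # put layers such as dropout/batch-norm into inference behaviour

x = torch.randn(1, 128)

with torch.no_grad():  # no graph is recorded, so no backprop (and less memory)
    st = time.time()
    out = net(x)
    print('Elapsed time: {}'.format(time.time() - st))

print(out.requires_grad)  # False: the output is detached from autograd

Without the no_grad() context, every intermediate activation from the forward pass would be kept around for a backward pass that evaluation never performs.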
