Skip to content

Commit 42390c8

Browse files
committed
update dataset
1 parent ccf5e59 commit 42390c8

File tree

5 files changed

+10367
-162
lines changed

5 files changed

+10367
-162
lines changed

dataset.py

Lines changed: 25 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,23 @@
1-
#encoding:utf-8
2-
#
3-
#created by xiongzihua
4-
#
5-
'''
6-
txt描述文件 image_name.jpg x y w h c x y w h c 这样就是说一张图片中有两个目标
7-
'''
8-
import os
9-
import sys
10-
import os.path
11-
import cv2
12-
13-
import random
14-
import numpy as np
1+
# encoding:utf-8
2+
# dataset loader for VOC dataset.
3+
"""
4+
how to use
5+
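testdata = Yolodata(train_file_root = 'xxx/VOCdevkit/VOC2012/JPEGImages/', train_listano = 'xxx/voc2012.txt', test_file_root = 'xxx/VOCdevkit/VOC2012/JPEGImages/', test_listano = 'xxx/voc2012.txt', batchsize=2)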
6+
"""
157

168
import torch
179
import torch.utils.data as data
1810
import torchvision.transforms as transforms
1911

2012
from PIL import Image
21-
import matplotlib.pyplot as plt
13+
import cv2
14+
15+
import os
16+
import sys
17+
import os.path
18+
19+
import random
20+
import numpy as np
2221

2322
def pil_loader(path):
2423
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
@@ -80,17 +79,16 @@ def __getitem__(self,idx):
8079
img = self.loader(os.path.join(self.root+fname))
8180
boxes = self.boxes[idx].clone()
8281
labels = self.labels[idx].clone()
83-
img,boxes,labels = self.cvtransform(img,boxes,labels) # 各种变换用torch 自带的的transform做不到,所以借鉴了xiongzihua 的git hub上的代码写了一点cv变换
82+
if self.train:
83+
img,boxes,labels = self.cvtransform(img,boxes,labels) # 各种变换用torch 自带的的transform做不到,所以借鉴了xiongzihua 的git hub上的代码写了一点cv变换
8484
h,w,_= img.shape
8585
boxes /= torch.Tensor([w,h,w,h]).expand_as(boxes)
86-
#print('bounding box is ')
87-
##print(labels)
88-
#target = self.encoder(boxes,labels)# 7x7x30
86+
8987
target = self.make_target(labels,boxes)
9088
target = torch.tensor(target).float()
9189

9290
img = self.BGR2RGB(img) #because pytorch pretrained model use RGB
93-
#img = self.subMean(img,self.mean) #减去均值
91+
9492
img = cv2.resize(img,(self.image_size,self.image_size))
9593

9694
if self.transform is not None:
@@ -114,31 +112,6 @@ def cvtransform(self, img, boxes, labels):
114112

115113
def __len__(self):
116114
return self.num_samples
117-
118-
def encoder(self,boxes,labels):
119-
'''
120-
boxes (tensor) [[x1,y1,x2,y2],[]]
121-
labels (tensor) [...]
122-
return [self.S, self.S, self.B*5+self.C]
123-
'''
124-
grid_num = 14
125-
target = torch.zeros((grid_num,grid_num,30))
126-
cell_size = 1./grid_num
127-
wh = boxes[:,2:]-boxes[:,:2] #这是x-x2和 y-y2
128-
cxcy = (boxes[:,2:]+boxes[:,:2])/2 #这时中心点坐标
129-
for i in range(cxcy.size()[0]):
130-
cxcy_sample = cxcy[i]
131-
ij = (cxcy_sample/cell_size).ceil()-1 #
132-
target[int(ij[1]),int(ij[0]),4] = 1
133-
target[int(ij[1]),int(ij[0]),9] = 1
134-
target[int(ij[1]),int(ij[0]),int(labels[i])+9] = 1
135-
xy = ij*cell_size #匹配到的网格的左上角相对坐标
136-
delta_xy = (cxcy_sample -xy)/cell_size
137-
target[int(ij[1]),int(ij[0]),2:4] = wh[i]
138-
target[int(ij[1]),int(ij[0]),:2] = delta_xy
139-
target[int(ij[1]),int(ij[0]),7:9] = wh[i]
140-
target[int(ij[1]),int(ij[0]),5:7] = delta_xy
141-
return target
142115

143116
def change_box_to_center_axes(self, bboxes):
144117
rebboxes = []
@@ -151,7 +124,7 @@ def change_box_to_center_axes(self, bboxes):
151124
def make_target(self, labels, bboxes):
152125
"""make location np.ndarray from bboxes of an image
153126
154-
Parameters
127+
Input
155128
----------
156129
labels : list
157130
[0, 1, 4, 2, ...]
@@ -163,7 +136,6 @@ def make_target(self, labels, bboxes):
163136
-------
164137
np.ndarray
165138
[self.S, self.S, self.B*5+self.C]
166-
location array
167139
"""
168140

169141
bboxes = self.change_box_to_center_axes(bboxes)
@@ -188,10 +160,10 @@ def make_target(self, labels, bboxes):
188160
w = bboxes[:, 2].reshape(-1, 1)
189161
h = bboxes[:, 3].reshape(-1, 1)
190162

191-
x_idx = np.ceil(x_center * self.S) - 1
163+
x_idx = np.ceil(x_center * self.S) - 1 # 看这个bounding box 在哪个grid 里面
192164
y_idx = np.ceil(y_center * self.S) - 1
193165
# edge case: a center coordinate of exactly 0 gives ceil(0)-1 = -1, so clamp negative indices to 0
194-
x_idx[x_idx<0] = 0
166+
x_idx[x_idx<0] = 0
195167
y_idx[y_idx<0] = 0
196168

197169
# calc offset of x_center, y_center
@@ -351,15 +323,15 @@ def random_bright(self, im, delta=16):
351323
return im
352324

353325
class Yolodata():
354-
def __init__(self, file_root = '/home/claude.duan/data/VOCdevkit/VOC2012/JPEGImages/', listano = './voc2012.txt',batchsize=2):
326+
def __init__(self, train_file_root = '/home/claude.duan/data/VOCdevkit/VOC2012/JPEGImages/', train_listano = './voc2012.txt', test_file_root = '/home/claude.duan/data/VOCdevkit/VOC2012/JPEGImages/', test_listano = './voc2012.txt' ,batchsize=2):
355327
transform_train = transforms.Compose([
356328
#transforms.Resize(448),
357329
#transforms.RandomCrop(448),
358330
#transforms.RandomHorizontalFlip(),
359331
transforms.ToTensor(),
360332
transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])
361333

362-
img_data = VOCDataset(root = file_root,list_file=listano,train=True,transform=transform_train,loader = cv_loader)
334+
img_data = VOCDataset(root = train_file_root,list_file=train_listano,train=True,transform=transform_train,loader = cv_loader)
363335
train_loader = torch.utils.data.DataLoader(img_data, batch_size=batchsize,shuffle=True)
364336
self.train_loader = train_loader
365337
#self.img_data=img_data
@@ -370,7 +342,7 @@ def __init__(self, file_root = '/home/claude.duan/data/VOCdevkit/VOC2012/JPEGIma
370342
transforms.ToTensor(),
371343
transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])
372344

373-
img_data_t = VOCDataset(root = file_root,list_file='voc2012.txt',train=False,transform=transform_test,loader = cv_loader)
345+
img_data_t = VOCDataset(root = test_file_root,list_file=test_listano,train=False,transform=transform_test,loader = cv_loader)
374346
test_loader = torch.utils.data.DataLoader(img_data_t, batch_size=int(0.5*batchsize),shuffle=False)
375347
self.test_loader = test_loader
376348

@@ -388,7 +360,7 @@ def getdata(self):
388360

389361
if __name__ == '__main__':
390362
#testdata = Yolodata(file_root = '/home/claude.duan/data/VOCdevkit/VOC2012/JPEGImages/', listano = './voc2012.txt',batchsize=2)
391-
testdata = Yolodata(file_root = '/Users/duanyiqun/Downloads/VOCdevkit/VOC2012/JPEGImages/', listano = './voc2012.txt',batchsize=2)
363+
testdata = Yolodata(train_file_root = '/home/claude.duan/data/VOCdevkit/VOC2012/JPEGImages/', train_listano = './voc2012.txt', test_file_root = '/home/claude.duan/data/VOCdevkit/VOC2012/JPEGImages/', test_listano = './voc2012.txt' ,batchsize=2)
392364
testdata.test()
393365

394366

0 commit comments

Comments
 (0)