Skip to content

Commit 0074eb2

Browse files
committed
initial
1 parent 8f58fa2 commit 0074eb2

File tree

5 files changed

+17529
-0
lines changed

5 files changed

+17529
-0
lines changed

dataset.py

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
#encoding:utf-8
2+
#
3+
#created by xiongzihua
4+
#
5+
'''
6+
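Annotation list file format: image_name.jpg x1 y1 x2 y2 c x1 y1 x2 y2 c ... -- each (x1 y1 x2 y2 c) group describes one object, so this example line describes an image containing two objects.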
7+
'''
8+
import os
9+
import sys
10+
import os.path
11+
12+
import random
13+
import numpy as np
14+
15+
import torch
16+
import torch.utils.data as data
17+
import torchvision.transforms as transforms
18+
19+
from PIL import Image
20+
import matplotlib.pyplot as plt
21+
22+
def pil_loader(path):
    """Load the image stored at ``path`` and return it converted to RGB.

    The file is opened explicitly and handed to PIL as a file object so the
    handle is closed when the ``with`` block exits, which avoids the
    ResourceWarning described in
    https://github.com/python-pillow/Pillow/issues/835.
    """
    with open(path, 'rb') as handle:
        return Image.open(handle).convert('RGB')
27+
28+
class VOCDataset(data.Dataset):
    """Detection dataset driven by a plain-text annotation list.

    Every non-empty line of the list file has the form::

        image_name x1 y1 x2 y2 c [x1 y1 x2 y2 c ...]

    where each ``(x1, y1, x2, y2, c)`` group is one box given by its two
    corners plus an integer class id ``c``.  Class ids are stored shifted by
    one (``c + 1``).
    """
    image_size = 448

    def __init__(self, root, list_file, train, transform, loader):
        """Parse the annotation list(s) into file names, boxes and labels.

        Args:
            root: directory that image file names are resolved against.
            list_file: path of an annotation list, or a ``list`` of such
                paths whose contents are concatenated (useful for a
                voc07/voc12 combination).
            train: stored on the instance; not otherwise used by this class.
            transform: optional callable applied to the loaded image.
            loader: callable mapping an image path to an image object.
        """
        print('loading annotations')
        self.loader = loader
        self.root = root
        self.train = train
        self.transform = transform
        self.fnames = []
        self.boxes = []
        self.labels = []

        # Read every list file directly instead of shelling out to `cat`:
        # this needs no shell, no shared temporary file, and keeps records
        # separate even when a file lacks a trailing newline.
        list_files = list_file if isinstance(list_file, list) else [list_file]
        lines = []
        for path in list_files:
            with open(path) as f:
                lines.extend(f.readlines())

        for line in lines:
            splited = line.strip().split()
            if not splited:
                # Blank lines carry no record.
                continue
            self.fnames.append(splited[0])
            num_boxes = (len(splited) - 1) // 5
            box = []
            label = []
            for i in range(num_boxes):
                x = float(splited[1 + 5 * i])
                y = float(splited[2 + 5 * i])
                x2 = float(splited[3 + 5 * i])
                y2 = float(splited[4 + 5 * i])
                c = splited[5 + 5 * i]
                box.append([x, y, x2, y2])
                label.append(int(c) + 1)
            # Keep an explicit (0, 4) shape for images without any box so the
            # column slicing in `encoder` still works.
            self.boxes.append(torch.Tensor(box) if box else torch.zeros((0, 4)))
            self.labels.append(torch.LongTensor(label))
        self.num_samples = len(self.boxes)

    @staticmethod
    def _image_size(img):
        """Return ``(width, height)`` of an image as produced by the loader.

        Objects exposing ``shape`` are read as height-first arrays (H x W x C);
        anything else is expected to expose a PIL-style ``size`` attribute
        holding ``(width, height)``.
        """
        if hasattr(img, 'shape'):
            height, width = img.shape[:2]
        else:
            width, height = img.size
        return width, height

    def __getitem__(self, idx):
        """Return ``(image, target)`` for sample ``idx``.

        Boxes are normalised by the size of the image as returned by the
        loader, measured *before* ``transform`` runs: after a tensor
        conversion the leading dimension is the channel count, not the height.

        NOTE(review): boxes are not adjusted for any cropping or flipping the
        transform performs -- confirm whether that is acceptable for training.
        """
        fname = self.fnames[idx]
        img = self.loader(os.path.join(self.root, fname))
        boxes = self.boxes[idx].clone()
        labels = self.labels[idx].clone()

        w, h = self._image_size(img)
        if self.transform is not None:
            img = self.transform(img)

        boxes = boxes / torch.Tensor([w, h, w, h])
        target = self.encoder(boxes, labels)  # grid x grid x 30

        return img, target

    def __len__(self):
        """Return the number of annotated images."""
        return self.num_samples

    def encoder(self, boxes, labels, grid_num=14):
        """Encode normalised boxes into a ``grid_num x grid_num x 30`` target.

        Args:
            boxes (tensor): ``[[x1, y1, x2, y2], ...]`` with coordinates
                relative to the image size.
            labels (tensor): one class index per box, as stored by
                ``__init__`` (already shifted by one).
            grid_num: number of grid cells per side (default 14).

        Returns:
            Tensor of shape ``(grid_num, grid_num, 30)``.  For the cell that
            contains a box centre, channels 0-1 and 5-6 hold the centre offset
            inside the cell, 2-3 and 7-8 the box width/height, 4 and 9 are set
            to 1, and channel ``label + 9`` is set to 1.
        """
        target = torch.zeros((grid_num, grid_num, 30))
        cell_size = 1. / grid_num
        wh = boxes[:, 2:] - boxes[:, :2]
        cxcy = (boxes[:, 2:] + boxes[:, :2]) / 2
        for i in range(cxcy.size()[0]):
            cxcy_sample = cxcy[i]
            # ceil() - 1 yields -1 for a centre lying exactly on 0, which
            # would wrap around to the last cell; clamp into the valid range.
            ij = ((cxcy_sample / cell_size).ceil() - 1).clamp(min=0, max=grid_num - 1)
            row, col = int(ij[1]), int(ij[0])
            target[row, col, 4] = 1
            target[row, col, 9] = 1
            target[row, col, int(labels[i]) + 9] = 1
            xy = ij * cell_size  # relative top-left corner of the matched cell
            delta_xy = (cxcy_sample - xy) / cell_size
            target[row, col, 2:4] = wh[i]
            target[row, col, :2] = delta_xy
            target[row, col, 7:9] = wh[i]
            target[row, col, 5:7] = delta_xy
        return target
110+
111+
class Yolodata():
    """Build the training and test ``DataLoader`` objects for ``VOCDataset``."""

    def __init__(self, file_root='/home/claude.duan/data/VOCdevkit/VOC2012/JPEGImages/', listano='./voc2012.txt', batchsize=2):
        """Create ``self.train_loader`` and ``self.test_loader``.

        Args:
            file_root: directory containing the images.
            listano: annotation list file, used for both loaders.
            batchsize: training batch size; the test loader uses half of it
                (at least 1).
        """
        transform_train = transforms.Compose([
            transforms.Resize(224),
            transforms.RandomCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

        img_data = VOCDataset(root=file_root, list_file=listano, train=True, transform=transform_train, loader=pil_loader)
        self.train_loader = torch.utils.data.DataLoader(img_data, batch_size=batchsize, shuffle=True)

        transform_test = transforms.Compose([
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

        # Honour `listano` here as well; previously the test set always read
        # 'voc2012.txt' regardless of the argument.
        # NOTE(review): train and test therefore share one annotation list --
        # confirm whether a separate test list is intended.
        img_data_t = VOCDataset(root=file_root, list_file=listano, train=False, transform=transform_test, loader=pil_loader)
        # DataLoader rejects a batch size of 0, which int(0.5 * 1) would give.
        test_batchsize = max(1, int(0.5 * batchsize))
        self.test_loader = torch.utils.data.DataLoader(img_data_t, batch_size=test_batchsize, shuffle=False)

    def test(self):
        """Print the batch counts, then every batch of both loaders."""
        print('there are total %s batches in training and total %s batches for test' % (len(self.train_loader), len(self.test_loader)))
        for i, (batch_x, batch_y) in enumerate(self.train_loader):
            print(batch_x.size(), batch_y)
        for i, (batch_x, batch_y) in enumerate(self.test_loader):
            print(batch_x.size(), batch_y)

    def getdata(self):
        """Return ``(train_loader, test_loader)``."""
        return self.train_loader, self.test_loader
145+
146+
147+
if __name__ == '__main__':
    # Manual smoke test: build both loaders and print every batch.
    testdata = Yolodata(
        file_root='/home/claude.duan/data/VOCdevkit/VOC2012/JPEGImages/',
        listano='./voc2012.txt',
        batchsize=2,
    )
    testdata.test()
150+
151+

visualize.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import visdom
2+
import numpy as np
3+
4+
class Visualizer():
    """Small wrapper around a visdom client that tracks an x index per plot."""

    def __init__(self, env='main', **kwargs):
        """Connect to visdom environment ``env``.

        ``**kwargs`` is accepted for call compatibility but is not used.
        """
        self.vis = visdom.Visdom(env=env)
        self.index = {}  # plot name -> next x value
        self.log_text = ''
        self.env = env

    def plot_train_val(self, loss_train=None, loss_val=None):
        """Plot training and validation loss as two traces in one window.

        The first call creates the window, seeding both traces with whichever
        loss was supplied.  Later calls append to trace ``'1'`` when
        ``loss_train`` is given, otherwise to trace ``'2'``.
        """
        x = self.index.get('train_val', 0)

        if x == 0:
            # Compare against None explicitly: a loss of exactly 0 is a valid
            # value and must not fall through to `loss_val`.
            loss = loss_train if loss_train is not None else loss_val
            win_y = np.column_stack((loss, loss))
            win_x = np.column_stack((x, x))
            self.win = self.vis.line(Y=win_y, X=win_x,
                                     env=self.env)
            self.index['train_val'] = x + 1
            return

        if loss_train is not None:
            self.vis.line(Y=np.array([loss_train]), X=np.array([x]),
                          win=self.win,
                          name='1',
                          update='append',
                          env=self.env)
            # The x position only advances on training updates, by 5 per call.
            self.index['train_val'] = x + 5
        else:
            self.vis.line(Y=np.array([loss_val]), X=np.array([x]),
                          win=self.win,
                          name='2',
                          update='append',
                          env=self.env)

    def plot_many(self, d):
        """Plot every ``name -> value`` pair of dict ``d`` via ``plot``."""
        # dict.items() works on Python 2 and 3; iteritems() is Python 2 only.
        for k, v in d.items():
            self.plot(k, v)

    def plot(self, name, y, **kwargs):
        """Append value ``y`` to the line plot called ``name``.

        Example: ``plot('loss', 1.00)``.  Extra keyword arguments are passed
        on to ``vis.line``.
        """
        x = self.index.get(name, 0)  # 0 when this plot has not been seen yet
        self.vis.line(Y=np.array([y]), X=np.array([x]),
                      win=name,
                      opts=dict(title=name),
                      update=None if x == 0 else 'append',
                      **kwargs)
        self.index[name] = x + 1

    def log(self, info, win='log_text'):
        """Placeholder for showing text in a visdom box; currently a no-op."""
        pass

0 commit comments

Comments
 (0)