# encoding: utf-8
#
# created by xiongzihua
#
'''
Annotation txt format: image_name.jpg x1 y1 x2 y2 c x1 y1 x2 y2 c ...
Each group of five numbers (box corner coordinates plus a class index)
describes one object, so a line with two groups means the image contains
two targets.
'''
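
# Example annotation line (hypothetical file name and pixel values):
#   2008_000008.jpg 53 87 471 420 12 158 44 289 167 14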

import os

import torch
import torch.utils.data as data
import torchvision.transforms as transforms

from PIL import Image

def pil_loader(path):
    # Open the path as a file object to avoid a ResourceWarning
    # (https://github.com/python-pillow/Pillow/issues/835).
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')

class VOCDataset(data.Dataset):
    image_size = 448

    def __init__(self, root, list_file, train, transform, loader):
        print('loading annotations')
        self.loader = loader
        self.root = root
        self.train = train
        self.transform = transform
        self.fnames = []
        self.boxes = []
        self.labels = []

        if isinstance(list_file, list):
            # Concatenate multiple list files into one temporary file.
            # This is especially useful for the voc07/voc12 combination.
            tmp_file = '/tmp/listfile.txt'
            with open(tmp_file, 'w') as fout:
                for lf in list_file:
                    with open(lf) as fin:
                        fout.write(fin.read())
            list_file = tmp_file

        with open(list_file) as f:
            lines = f.readlines()

        for line in lines:
            splited = line.strip().split()
            self.fnames.append(splited[0])
            num_boxes = (len(splited) - 1) // 5
            box = []
            label = []
            for i in range(num_boxes):
                # Each object is five tokens: x1 y1 x2 y2 class_index.
                x = float(splited[1 + 5 * i])
                y = float(splited[2 + 5 * i])
                x2 = float(splited[3 + 5 * i])
                y2 = float(splited[4 + 5 * i])
                c = splited[5 + 5 * i]
                box.append([x, y, x2, y2])
                label.append(int(c) + 1)  # shift labels to 1..20
            self.boxes.append(torch.Tensor(box))
            self.labels.append(torch.LongTensor(label))
        self.num_samples = len(self.boxes)

    def __getitem__(self, idx):
        fname = self.fnames[idx]
        img = self.loader(os.path.join(self.root, fname))
        boxes = self.boxes[idx].clone()
        labels = self.labels[idx].clone()

        # Normalize the pixel-coordinate boxes by the original image size;
        # PIL's Image.size is (width, height). This stays valid after the
        # transform because the transform only rescales the whole image.
        w, h = img.size
        boxes /= torch.Tensor([w, h, w, h]).expand_as(boxes)

        if self.transform is not None:
            img = self.transform(img)

        target = self.encoder(boxes, labels)  # 14x14x30
        return img, target

    def __len__(self):
        return self.num_samples

    def encoder(self, boxes, labels):
        '''
        boxes (tensor)  [[x1, y1, x2, y2], ...], normalized to [0, 1]
        labels (tensor) [...], 1-based class indices
        return a 14x14x30 target tensor; per cell the 30 channels are
        [dx, dy, w, h, conf] for box 1, the same five for box 2, and
        20 class scores.
        '''
        grid_num = 14
        target = torch.zeros((grid_num, grid_num, 30))
        cell_size = 1. / grid_num
        wh = boxes[:, 2:] - boxes[:, :2]          # box widths/heights
        cxcy = (boxes[:, 2:] + boxes[:, :2]) / 2  # box centers
        for i in range(cxcy.size()[0]):
            cxcy_sample = cxcy[i]
            # (col, row) index of the grid cell containing the box center.
            ij = (cxcy_sample / cell_size).ceil() - 1
            target[int(ij[1]), int(ij[0]), 4] = 1                   # box 1 confidence
            target[int(ij[1]), int(ij[0]), 9] = 1                   # box 2 confidence
            target[int(ij[1]), int(ij[0]), int(labels[i]) + 9] = 1  # class one-hot (labels 1..20 -> channels 10..29)
            xy = ij * cell_size  # normalized top-left corner of the matched cell
            delta_xy = (cxcy_sample - xy) / cell_size  # center offset within the cell
            target[int(ij[1]), int(ij[0]), 2:4] = wh[i]
            target[int(ij[1]), int(ij[0]), :2] = delta_xy
            target[int(ij[1]), int(ij[0]), 7:9] = wh[i]
            target[int(ij[1]), int(ij[0]), 5:7] = delta_xy
        return target
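
# A minimal sanity check for encoder (hypothetical box values); encoder never
# touches self, so it can be called unbound, e.g. from a REPL:
#   boxes = torch.Tensor([[0.1, 0.2, 0.5, 0.6]])
#   labels = torch.LongTensor([1])
#   target = VOCDataset.encoder(None, boxes, labels)
#   assert target.shape == (14, 14, 30)
#   assert target[5, 4, 4] == 1  # center (0.3, 0.4) lands in cell col 4, row 5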

class Yolodata():
    def __init__(self, file_root='/home/claude.duan/data/VOCdevkit/VOC2012/JPEGImages/', listano='./voc2012.txt', batchsize=2):
        # Rescale the whole image to a fixed square so the normalized boxes
        # stay aligned with it; random crops and flips would require matching
        # box transforms, which this loader does not implement.
        transform_train = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            # ImageNet statistics, as expected by pretrained backbones.
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

        img_data = VOCDataset(root=file_root, list_file=listano, train=True, transform=transform_train, loader=pil_loader)
        self.train_loader = torch.utils.data.DataLoader(img_data, batch_size=batchsize, shuffle=True)

        transform_test = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

        img_data_t = VOCDataset(root=file_root, list_file=listano, train=False, transform=transform_test, loader=pil_loader)
        # max(1, ...) guards against a zero batch size when batchsize == 1.
        self.test_loader = torch.utils.data.DataLoader(img_data_t, batch_size=max(1, batchsize // 2), shuffle=False)

    def test(self):
        print('There are %d training batches and %d test batches.' % (len(self.train_loader), len(self.test_loader)))
        for i, (batch_x, batch_y) in enumerate(self.train_loader):
            print(batch_x.size(), batch_y.size())
        for i, (batch_x, batch_y) in enumerate(self.test_loader):
            print(batch_x.size(), batch_y.size())

    def getdata(self):
        return self.train_loader, self.test_loader
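
# A minimal sketch of consuming the loaders in a training loop; the model,
# criterion, and optimizer below are hypothetical placeholders:
#   train_loader, test_loader = Yolodata(batchsize=2).getdata()
#   for img, target in train_loader:
#       pred = model(img)             # expected to produce (N, 14, 14, 30)
#       loss = criterion(pred, target)
#       optimizer.zero_grad()
#       loss.backward()
#       optimizer.step()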


if __name__ == '__main__':
    testdata = Yolodata(file_root='/home/claude.duan/data/VOCdevkit/VOC2012/JPEGImages/', listano='./voc2012.txt', batchsize=2)
    testdata.test()