Skip to content

Commit 912fe22

Browse files
authored
Add files via upload
1 parent 9576714 commit 912fe22

File tree

1 file changed

+138
-0
lines changed

1 file changed

+138
-0
lines changed

LIVEFolder.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
import torch.utils.data as data
2+
3+
from PIL import Image
4+
5+
import os
6+
import os.path
7+
#import math
8+
import scipy.io
9+
import numpy as np
10+
import random
11+
12+
13+
def getFileName(path, suffix):
14+
''' 获取指定目录下的所有指定后缀的文件名 '''
15+
filename = []
16+
f_list = os.listdir(path)
17+
# print f_list
18+
for i in f_list:
19+
# os.path.splitext():分离文件名与扩展名
20+
if os.path.splitext(i)[1] == suffix:
21+
filename.append(i)
22+
return filename
23+
24+
def getDistortionTypeFileName(path, num):
25+
filename = []
26+
index = 1
27+
for i in range(0,num):
28+
name = '%s%s%s' % ('img',str(index),'.bmp')
29+
filename.append(os.path.join(path,name))
30+
index = index + 1
31+
return filename
32+
33+
34+
35+
class LIVEFolder(data.Dataset):
36+
37+
def __init__(self, root, loader, index, transform=None, target_transform=None):
38+
39+
self.root = root
40+
self.loader = loader
41+
42+
self.refpath = os.path.join(self.root, 'refimgs')
43+
self.refname = getFileName( self.refpath,'.bmp')
44+
45+
self.jp2kroot = os.path.join(self.root, 'jp2k')
46+
self.jp2kname = getDistortionTypeFileName(self.jp2kroot,227)
47+
48+
self.jpegroot = os.path.join(self.root, 'jpeg')
49+
self.jpegname = getDistortionTypeFileName(self.jpegroot,233)
50+
51+
self.wnroot = os.path.join(self.root, 'wn')
52+
self.wnname = getDistortionTypeFileName(self.wnroot,174)
53+
54+
self.gblurroot = os.path.join(self.root, 'gblur')
55+
self.gblurname = getDistortionTypeFileName(self.gblurroot,174)
56+
57+
self.fastfadingroot = os.path.join(self.root, 'fastfading')
58+
self.fastfadingname = getDistortionTypeFileName(self.fastfadingroot,174)
59+
60+
self.imgpath = self.jp2kname + self.jpegname + self.wnname + self.gblurname + self.fastfadingname
61+
62+
self.dmos = scipy.io.loadmat(os.path.join(self.root, 'dmos_realigned.mat'))
63+
self.labels = self.dmos['dmos_new'].astype(np.float32)
64+
#self.labels = self.labels.tolist()[0]
65+
self.orgs = self.dmos['orgs']
66+
refnames_all = scipy.io.loadmat(os.path.join(self.root, 'refnames_all.mat'))
67+
self.refnames_all = refnames_all['refnames_all']
68+
69+
70+
sample = []
71+
72+
for i in range(0, len(index)):
73+
train_sel = (self.refname[index[i]] == self.refnames_all)
74+
train_sel = train_sel * ~self.orgs.astype(np.bool_)
75+
train_sel1 = np.where(train_sel == True)
76+
train_sel = train_sel1[1].tolist()
77+
for j, item in enumerate(train_sel):
78+
sample.append((self.imgpath[item],self.labels[0][item]))
79+
self.samples = sample
80+
self.transform = transform
81+
self.target_transform = target_transform
82+
83+
def __getitem__(self, index):
84+
"""
85+
Args:
86+
index (int): Index
87+
88+
Returns:
89+
tuple: (sample, target) where target is class_index of the target class.
90+
"""
91+
path, target = self.samples[index]
92+
sample = self.loader(path)
93+
if self.transform is not None:
94+
sample = self.transform(sample)
95+
if self.target_transform is not None:
96+
target = self.target_transform(target)
97+
98+
return sample, target
99+
100+
101+
def __len__(self):
102+
length = len(self.samples)
103+
return length
104+
105+
106+
107+
108+
def pil_loader(path):
109+
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
110+
with open(path, 'rb') as f:
111+
img = Image.open(f)
112+
return img.convert('RGB')
113+
114+
115+
def accimage_loader(path):
116+
import accimage
117+
try:
118+
return accimage.Image(path)
119+
except IOError:
120+
# Potentially a decoding problem, fall back to PIL.Image
121+
return pil_loader(path)
122+
123+
124+
def default_loader(path):
125+
from torchvision import get_image_backend
126+
if get_image_backend() == 'accimage':
127+
return accimage_loader(path)
128+
else:
129+
return pil_loader(path)
130+
131+
if __name__ == '__main__':
132+
liveroot = 'D:\zwx_Project\zwx_IQA\dataset\databaserelease2'
133+
index = list(range(0,29))
134+
random.shuffle(index)
135+
train_index = index[0:round(0.8*29)]
136+
test_index = index[round(0.8*29):29]
137+
trainset = LIVEFolder(root = liveroot, loader = default_loader, index = train_index)
138+
testset = LIVEFolder(root = liveroot, loader = default_loader, index = test_index)

0 commit comments

Comments
 (0)