1
+ import torch .utils .data as data
2
+
3
+ from PIL import Image
4
+
5
+ import os
6
+ import os .path
7
+ #import math
8
+ import scipy .io
9
+ import numpy as np
10
+ import random
11
+
12
+
13
+ def getFileName (path , suffix ):
14
+ ''' 获取指定目录下的所有指定后缀的文件名 '''
15
+ filename = []
16
+ f_list = os .listdir (path )
17
+ # print f_list
18
+ for i in f_list :
19
+ # os.path.splitext():分离文件名与扩展名
20
+ if os .path .splitext (i )[1 ] == suffix :
21
+ filename .append (i )
22
+ return filename
23
+
24
+ def getDistortionTypeFileName (path , num ):
25
+ filename = []
26
+ index = 1
27
+ for i in range (0 ,num ):
28
+ name = '%s%s%s' % ('img' ,str (index ),'.bmp' )
29
+ filename .append (os .path .join (path ,name ))
30
+ index = index + 1
31
+ return filename
32
+
33
+
34
+
35
+ class LIVEFolder (data .Dataset ):
36
+
37
+ def __init__ (self , root , loader , index , transform = None , target_transform = None ):
38
+
39
+ self .root = root
40
+ self .loader = loader
41
+
42
+ self .refpath = os .path .join (self .root , 'refimgs' )
43
+ self .refname = getFileName ( self .refpath ,'.bmp' )
44
+
45
+ self .jp2kroot = os .path .join (self .root , 'jp2k' )
46
+ self .jp2kname = getDistortionTypeFileName (self .jp2kroot ,227 )
47
+
48
+ self .jpegroot = os .path .join (self .root , 'jpeg' )
49
+ self .jpegname = getDistortionTypeFileName (self .jpegroot ,233 )
50
+
51
+ self .wnroot = os .path .join (self .root , 'wn' )
52
+ self .wnname = getDistortionTypeFileName (self .wnroot ,174 )
53
+
54
+ self .gblurroot = os .path .join (self .root , 'gblur' )
55
+ self .gblurname = getDistortionTypeFileName (self .gblurroot ,174 )
56
+
57
+ self .fastfadingroot = os .path .join (self .root , 'fastfading' )
58
+ self .fastfadingname = getDistortionTypeFileName (self .fastfadingroot ,174 )
59
+
60
+ self .imgpath = self .jp2kname + self .jpegname + self .wnname + self .gblurname + self .fastfadingname
61
+
62
+ self .dmos = scipy .io .loadmat (os .path .join (self .root , 'dmos_realigned.mat' ))
63
+ self .labels = self .dmos ['dmos_new' ].astype (np .float32 )
64
+ #self.labels = self.labels.tolist()[0]
65
+ self .orgs = self .dmos ['orgs' ]
66
+ refnames_all = scipy .io .loadmat (os .path .join (self .root , 'refnames_all.mat' ))
67
+ self .refnames_all = refnames_all ['refnames_all' ]
68
+
69
+
70
+ sample = []
71
+
72
+ for i in range (0 , len (index )):
73
+ train_sel = (self .refname [index [i ]] == self .refnames_all )
74
+ train_sel = train_sel * ~ self .orgs .astype (np .bool_ )
75
+ train_sel1 = np .where (train_sel == True )
76
+ train_sel = train_sel1 [1 ].tolist ()
77
+ for j , item in enumerate (train_sel ):
78
+ sample .append ((self .imgpath [item ],self .labels [0 ][item ]))
79
+ self .samples = sample
80
+ self .transform = transform
81
+ self .target_transform = target_transform
82
+
83
+ def __getitem__ (self , index ):
84
+ """
85
+ Args:
86
+ index (int): Index
87
+
88
+ Returns:
89
+ tuple: (sample, target) where target is class_index of the target class.
90
+ """
91
+ path , target = self .samples [index ]
92
+ sample = self .loader (path )
93
+ if self .transform is not None :
94
+ sample = self .transform (sample )
95
+ if self .target_transform is not None :
96
+ target = self .target_transform (target )
97
+
98
+ return sample , target
99
+
100
+
101
+ def __len__ (self ):
102
+ length = len (self .samples )
103
+ return length
104
+
105
+
106
+
107
+
108
+ def pil_loader (path ):
109
+ # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
110
+ with open (path , 'rb' ) as f :
111
+ img = Image .open (f )
112
+ return img .convert ('RGB' )
113
+
114
+
115
+ def accimage_loader (path ):
116
+ import accimage
117
+ try :
118
+ return accimage .Image (path )
119
+ except IOError :
120
+ # Potentially a decoding problem, fall back to PIL.Image
121
+ return pil_loader (path )
122
+
123
+
124
+ def default_loader (path ):
125
+ from torchvision import get_image_backend
126
+ if get_image_backend () == 'accimage' :
127
+ return accimage_loader (path )
128
+ else :
129
+ return pil_loader (path )
130
+
131
+ if __name__ == '__main__' :
132
+ liveroot = 'D:\zwx_Project\zwx_IQA\dataset\databaserelease2'
133
+ index = list (range (0 ,29 ))
134
+ random .shuffle (index )
135
+ train_index = index [0 :round (0.8 * 29 )]
136
+ test_index = index [round (0.8 * 29 ):29 ]
137
+ trainset = LIVEFolder (root = liveroot , loader = default_loader , index = train_index )
138
+ testset = LIVEFolder (root = liveroot , loader = default_loader , index = test_index )
0 commit comments