29
29
30
30
def _download(file_name):
    """Download ``file_name`` from ``url_base`` into ``dataset_dir``.

    Does nothing when the target file already exists. The payload is first
    written to a sibling ``.part`` file and only moved to its final name
    once the transfer has finished, so an interrupted download can never be
    mistaken for a complete file by the existence check on a later run.

    Parameters
    ----------
    file_name : name of the remote file; also used as the local file name.

    Raises
    ------
    Whatever ``urllib.request.urlretrieve`` or ``os.replace`` raises is
    propagated after the partial file has been removed.
    """
    file_path = os.path.join(dataset_dir, file_name)

    if os.path.exists(file_path):
        return

    print("Downloading " + file_name + " ... ")
    partial_path = file_path + ".part"
    try:
        urllib.request.urlretrieve(url_base + file_name, partial_path)
        # Atomic rename: file_path only ever appears fully written.
        os.replace(partial_path, file_path)
    finally:
        # After a successful replace the partial file is gone; it is only
        # still present here when the download or the rename failed.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    print("Done")
39
-
39
+
40
40
def download_mnist():
    """Call ``_download`` once for every file name stored in ``key_file``."""
    for remote_name in key_file.values():
        _download(remote_name)
43
-
43
+
44
44
def _load_label(file_name):
    """Read a gzip-compressed label file from ``dataset_dir``.

    Parameters
    ----------
    file_name : name of the gzip file inside ``dataset_dir``.

    Returns
    -------
    One-dimensional ``np.uint8`` array built from the decompressed bytes,
    skipping the first 8 bytes (presumably the file header -- confirm
    against the data format).
    """
    file_path = dataset_dir + "/" + file_name

    print("Converting " + file_name + " to NumPy Array ...")
    with gzip.open(file_path, 'rb') as gz_file:
        raw_bytes = gz_file.read()
    labels = np.frombuffer(raw_bytes, dtype=np.uint8, offset=8)
    print("Done")

    return labels
53
53
54
54
def _load_img(file_name):
    """Read a gzip-compressed image file from ``dataset_dir``.

    Parameters
    ----------
    file_name : name of the gzip file inside ``dataset_dir``.

    Returns
    -------
    Two-dimensional ``np.uint8`` array with ``img_size`` columns, built from
    the decompressed bytes after skipping the first 16 bytes (presumably
    the file header -- confirm against the data format).
    """
    file_path = dataset_dir + "/" + file_name

    print("Converting " + file_name + " to NumPy Array ...")
    with gzip.open(file_path, 'rb') as gz_file:
        raw_bytes = gz_file.read()
    pixels = np.frombuffer(raw_bytes, dtype=np.uint8, offset=16)
    # One row per image; the row count is inferred from the buffer length.
    images = pixels.reshape(-1, img_size)
    print("Done")

    return images
64
-
64
+
65
65
def _convert_numpy():
    """Load all four dataset files into a dict of NumPy arrays.

    Returns
    -------
    Dict with the keys ``'train_img'``, ``'train_label'``, ``'test_img'``
    and ``'test_label'``; image entries come from ``_load_img`` and label
    entries from ``_load_label``, each given the matching ``key_file`` name.
    """
    # Order matters only for the progress messages printed by the loaders.
    loaders = (
        ('train_img', _load_img),
        ('train_label', _load_label),
        ('test_img', _load_img),
        ('test_label', _load_label),
    )
    return {name: load(key_file[name]) for name, load in loaders}
73
73
74
74
def init_mnist ():
@@ -83,45 +83,45 @@ def _change_ont_hot_label(X):
83
83
T = np .zeros ((X .size , 10 ))
84
84
for idx , row in enumerate (T ):
85
85
row [X [idx ]] = 1
86
-
86
+
87
87
return T
88
-
88
+
89
89
90
90
def load_mnist(normalize=True, flatten=True, one_hot_label=False):
    """Load the MNIST dataset.

    If ``save_file`` does not exist yet, ``init_mnist()`` is called first.

    Parameters
    ----------
    normalize : if True, cast the images to float32 and divide by 255.0
        (pixel values normalized to the range 0.0-1.0).
    one_hot_label :
        if True, labels are returned as one-hot arrays,
        e.g. [0,0,1,0,0,0,0,0,0,0].
    flatten : whether each image is kept as a flat one-dimensional array;
        if False, images are reshaped to (-1, 1, 28, 28).

    Returns
    -------
    (training images, training labels), (test images, test labels)
    """
    image_keys = ('train_img', 'test_img')
    label_keys = ('train_label', 'test_label')

    if not os.path.exists(save_file):
        init_mnist()

    # NOTE(review): pickle.load executes arbitrary code from the file; this
    # is only safe as long as save_file is produced locally by this module.
    with open(save_file, 'rb') as cache:
        dataset = pickle.load(cache)

    if normalize:
        for name in image_keys:
            scaled = dataset[name].astype(np.float32)
            scaled /= 255.0
            dataset[name] = scaled

    if one_hot_label:
        for name in label_keys:
            dataset[name] = _change_ont_hot_label(dataset[name])

    if not flatten:
        for name in image_keys:
            dataset[name] = dataset[name].reshape(-1, 1, 28, 28)

    train_pair = (dataset['train_img'], dataset['train_label'])
    test_pair = (dataset['test_img'], dataset['test_label'])
    return train_pair, test_pair
125
125
126
126
127
127
if __name__ == '__main__' :
0 commit comments