Skip to content

Commit d93b53c

Browse files
committed
remove unnecessary end-of-line space
1 parent 9cc3910 commit d93b53c

File tree

1 file changed

+23
-23
lines changed

1 file changed

+23
-23
lines changed

dataset/mnist.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -29,46 +29,46 @@
2929

3030
def _download(file_name):
3131
file_path = dataset_dir + "/" + file_name
32-
32+
3333
if os.path.exists(file_path):
3434
return
3535

3636
print("Downloading " + file_name + " ... ")
3737
urllib.request.urlretrieve(url_base + file_name, file_path)
3838
print("Done")
39-
39+
4040
def download_mnist():
4141
for v in key_file.values():
4242
_download(v)
43-
43+
4444
def _load_label(file_name):
4545
file_path = dataset_dir + "/" + file_name
46-
46+
4747
print("Converting " + file_name + " to NumPy Array ...")
4848
with gzip.open(file_path, 'rb') as f:
4949
labels = np.frombuffer(f.read(), np.uint8, offset=8)
5050
print("Done")
51-
51+
5252
return labels
5353

5454
def _load_img(file_name):
5555
file_path = dataset_dir + "/" + file_name
56-
57-
print("Converting " + file_name + " to NumPy Array ...")
56+
57+
print("Converting " + file_name + " to NumPy Array ...")
5858
with gzip.open(file_path, 'rb') as f:
5959
data = np.frombuffer(f.read(), np.uint8, offset=16)
6060
data = data.reshape(-1, img_size)
6161
print("Done")
62-
62+
6363
return data
64-
64+
6565
def _convert_numpy():
6666
dataset = {}
6767
dataset['train_img'] = _load_img(key_file['train_img'])
68-
dataset['train_label'] = _load_label(key_file['train_label'])
68+
dataset['train_label'] = _load_label(key_file['train_label'])
6969
dataset['test_img'] = _load_img(key_file['test_img'])
7070
dataset['test_label'] = _load_label(key_file['test_label'])
71-
71+
7272
return dataset
7373

7474
def init_mnist():
@@ -83,45 +83,45 @@ def _change_ont_hot_label(X):
8383
T = np.zeros((X.size, 10))
8484
for idx, row in enumerate(T):
8585
row[X[idx]] = 1
86-
86+
8787
return T
88-
88+
8989

9090
def load_mnist(normalize=True, flatten=True, one_hot_label=False):
9191
"""MNISTデータセットの読み込み
92-
92+
9393
Parameters
9494
----------
9595
normalize : 画像のピクセル値を0.0~1.0に正規化する
96-
one_hot_label :
96+
one_hot_label :
9797
one_hot_labelがTrueの場合、ラベルはone-hot配列として返す
9898
one-hot配列とは、たとえば[0,0,1,0,0,0,0,0,0,0]のような配列
99-
flatten : 画像を一次元配列に平にするかどうか
100-
99+
flatten : 画像を一次元配列に平にするかどうか
100+
101101
Returns
102102
-------
103103
(訓練画像, 訓練ラベル), (テスト画像, テストラベル)
104104
"""
105105
if not os.path.exists(save_file):
106106
init_mnist()
107-
107+
108108
with open(save_file, 'rb') as f:
109109
dataset = pickle.load(f)
110-
110+
111111
if normalize:
112112
for key in ('train_img', 'test_img'):
113113
dataset[key] = dataset[key].astype(np.float32)
114114
dataset[key] /= 255.0
115-
115+
116116
if one_hot_label:
117117
dataset['train_label'] = _change_ont_hot_label(dataset['train_label'])
118-
dataset['test_label'] = _change_ont_hot_label(dataset['test_label'])
119-
118+
dataset['test_label'] = _change_ont_hot_label(dataset['test_label'])
119+
120120
if not flatten:
121121
for key in ('train_img', 'test_img'):
122122
dataset[key] = dataset[key].reshape(-1, 1, 28, 28)
123123

124-
return (dataset['train_img'], dataset['train_label']), (dataset['test_img'], dataset['test_label'])
124+
return (dataset['train_img'], dataset['train_label']), (dataset['test_img'], dataset['test_label'])
125125

126126

127127
if __name__ == '__main__':

0 commit comments

Comments
 (0)