Skip to content

Commit 5022a04

Browse files
authored
Update unet.py
1 parent 2680297 commit 5022a04

File tree

1 file changed

+17
-3
lines changed

1 file changed

+17
-3
lines changed

unet.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
import os
2+
#os.environ["CUDA_VISIBLE_DEVICES"] = "0"
13
import numpy as np
24
from keras.models import *
35
from keras.layers import Input, merge, Conv2D, MaxPooling2D, UpSampling2D, Dropout, Cropping2D
46
from keras.optimizers import *
57
from keras.callbacks import ModelCheckpoint, LearningRateScheduler
68
from keras import backend as keras
7-
from data import dataProcess
9+
from data import *
810

911
class myUnet(object):
1012

@@ -155,16 +157,28 @@ def train(self):
155157

156158
model_checkpoint = ModelCheckpoint('unet.hdf5', monitor='loss',verbose=1, save_best_only=True)
157159
print('Fitting model...')
158-
model.fit(imgs_train, imgs_mask_train, batch_size=1, nb_epoch=10, verbose=1, shuffle=True, callbacks=[model_checkpoint])
160+
model.fit(imgs_train, imgs_mask_train, batch_size=4, nb_epoch=10, verbose=1,validation_split=0.2, shuffle=True, callbacks=[model_checkpoint])
159161

160162
print('predict test data')
161163
imgs_mask_test = model.predict(imgs_test, batch_size=1, verbose=1)
162-
np.save('imgs_mask_test.npy', imgs_mask_test)
164+
np.save('../results/imgs_mask_test.npy', imgs_mask_test)
165+
166+
def save_img(self):
167+
168+
print("array to image")
169+
imgs = np.load('imgs_mask_test.npy')
170+
for i in range(imgs.shape[0]):
171+
img = imgs[i]
172+
img = array_to_img(img)
173+
img.save("../results/%d.jpg"%(i))
174+
175+
163176

164177

165178
if __name__ == '__main__':
166179
myunet = myUnet()
167180
myunet.train()
181+
myunet.save_img()
168182

169183

170184

0 commit comments

Comments
 (0)