1+ #coding=utf-8
2+ #图片着色
3+ import keras
4+ # import tensorflow as tf
5+ from skimage .io import imread , imsave
6+ from skimage .color import rgb2gray , gray2rgb , rgb2lab , lab2rgb
7+ from keras .models import Sequential
8+ from keras .layers import Conv2D , UpSampling2D , InputLayer , Conv2DTranspose
9+ from keras .preprocessing .image import img_to_array , load_img
10+ import numpy as np
11+ from keras .preprocessing .image import ImageDataGenerator
12+ import os
13+ import cv2
14+
15+
16+ def get_train_data (img_file ):
17+ image = img_to_array (load_img (img_file ))
18+ image_shape = image .shape
19+ image = np .array (image , dtype = float )
20+ x = rgb2lab (1.0 / 255 * image )[:, :, 0 ]
21+ y = rgb2lab (1.0 / 255 * image )[:, :, 1 :]
22+ y /= 128
23+ x = x .reshape (1 , image_shape [0 ], image_shape [1 ], 1 )
24+ y = y .reshape (1 , image_shape [0 ], image_shape [1 ], 2 )
25+ return x , y , image_shape
26+
27+
28+ def build_model ():
29+ model = Sequential ()
30+ model .add (InputLayer (input_shape = (None , None , 1 )))
31+ model .add (Conv2D (8 , (3 , 3 ), activation = 'relu' , padding = 'same' , strides = 2 ))
32+ model .add (Conv2D (8 , (3 , 3 ), activation = 'relu' , padding = 'same' ))
33+ model .add (Conv2D (16 , (3 , 3 ), activation = 'relu' , padding = 'same' ))
34+ model .add (Conv2D (16 , (3 , 3 ), activation = 'relu' , padding = 'same' , strides = 2 ))
35+ model .add (Conv2D (32 , (3 , 3 ), activation = 'relu' , padding = 'same' ))
36+ model .add (Conv2D (32 , (3 , 3 ), activation = 'relu' , padding = 'same' , strides = 2 ))
37+ model .add (UpSampling2D ((2 , 2 )))
38+ model .add (Conv2D (32 , (3 , 3 ), activation = 'relu' , padding = 'same' ))
39+ model .add (UpSampling2D ((2 , 2 )))
40+ model .add (Conv2D (16 , (3 , 3 ), activation = 'relu' , padding = 'same' ))
41+ model .add (UpSampling2D ((2 , 2 )))
42+ model .add (Conv2D (2 , (3 , 3 ), activation = 'tanh' , padding = 'same' ))
43+ # model.compile(optimizer='rmsprop', loss='mse')
44+ model .compile (optimizer = 'adam' , loss = 'mse' )
45+ return model
46+
47+
48+ #训练数据
49+ def train ():
50+ x , y , img_shape = get_train_data ('./img/colorize/colorize-original.png' )
51+
52+ # x2, y2, img_shape2 = get_train_data(
53+ # './img/colorize/colorize2-original.png')
54+
55+ model = build_model ()
56+ num_epochs = 1000 #训练次数
57+ batch_size = 1
58+
59+ model .fit (x , y , batch_size = batch_size , epochs = num_epochs )
60+ # model.fit(x2, y2, batch_size=batch_size, epochs=num_epochs)
61+ model .save ('./data/simple_colorize.h5' )
62+
63+
64+ #着色
65+ def colorize ():
66+ path = './img/colorize/colorize2.png'
67+ # cv2.imwrite('./img/colorize3.png', cv2.imread(path, 0))
68+ x , y , image_shape = get_train_data (path )
69+ model = build_model ()
70+ model .load_weights ('./data/simple_colorize.h5' )
71+ output = model .predict (x )
72+ output *= 128
73+ tmp = np .zeros ((200 , 200 , 3 ))
74+ tmp [:, :, 0 ] = x [0 ][:, :, 0 ]
75+ tmp [:, :, 1 :] = output [0 ]
76+ colorizePath = path .replace (".png" , "-res.png" )
77+ imsave (colorizePath , lab2rgb (tmp ))
78+ cv2 .imshow ("I" , cv2 .imread (path ))
79+ cv2 .imshow ("II" , cv2 .imread (colorizePath ))
80+ cv2 .waitKey (0 )
81+ cv2 .destroyAllWindows ()
82+
83+ # imsave("test_image_gray.png", rgb2gray(lab2rgb(tmp)))
84+
85+
86+ if __name__ == '__main__' :
87+ # train()
88+ colorize ()
0 commit comments