1
1
import matplotlib
2
2
matplotlib .use ("Agg" )
3
- from wandb .keras import WandbCallback
4
- import wandb
5
- from keras .utils .generic_utils import get_custom_objects
6
- from keras import backend as K
7
- from keras .callbacks import Callback
8
- from keras .models import Model , load_model
9
- from keras import metrics #mse, binary_crossentropy
10
- from keras import layers
11
- import keras
12
- import plotly .graph_objs as go
13
- import plotly .plotly as py
14
- from sklearn import manifold
15
- from sklearn .decomposition import PCA
16
- import matplotlib .pyplot as plt
17
- import pdb
18
- import os
19
- import subprocess
20
- import pandas as pd
21
- import numpy as np
22
- import cv2
23
3
import sys
4
+ import cv2
5
+ import numpy as np
6
+ import pandas as pd
7
+ import subprocess
8
+ import os
9
+ import pdb
10
+ import matplotlib .pyplot as plt
11
+ from sklearn .decomposition import PCA
12
+ from sklearn import manifold
13
+ import plotly .plotly as py
14
+ import plotly .graph_objs as go
15
+ import keras
16
+ from keras import layers
17
+ from keras .models import Model , load_model
18
+ from keras .callbacks import Callback
19
+ from keras import backend as K
20
+ from keras .utils .generic_utils import get_custom_objects
21
+ import wandb
22
+ from wandb .keras import WandbCallback
24
23
25
24
26
- wandb .init (project = "cvae" )
25
+ wandb .init ()
27
26
wandb .config .latent_dim = 2
28
- wandb .config .labels = [str (i ) for i in range (10 )]# ["Happy", "Sad"]
27
+ wandb .config .labels = [str (i ) for i in range (10 )] # ["Happy", "Sad"]
29
28
wandb .config .batch_size = 128
30
29
wandb .config .epochs = 25
31
- wandb .config .variational = False
32
- wandb .config .conditional = False
30
+ wandb .config .conditional = True
31
+ wandb .config .latent_vis = False
33
32
wandb .config .dataset = "mnist"
34
33
35
34
EMOTIONS = ["Angry" , "Disgust" , "Fear" , "Happy" , "Sad" , "Surprise" , "Neutral" ]
@@ -66,24 +65,6 @@ def load_fer2013(filter_emotions=[]):
66
65
return train_faces , train_emotions , val_faces , val_emotions
67
66
68
67
69
- def sample (args ):
70
- '''
71
- Draws samples from a standard normal and scales the samples with
72
- standard deviation of the variational distribution and shifts them
73
- by the mean.
74
-
75
- Args:
76
- args: sufficient statistics of the variational distribution.
77
-
78
- Returns:
79
- Samples from the variational distribution.
80
- '''
81
- t_mean , t_log_var = args
82
- t_sigma = K .sqrt (K .exp (t_log_var ))
83
- epsilon = K .random_normal (shape = K .shape (t_mean ), mean = 0. , stddev = 1. )
84
- return t_mean + t_sigma * epsilon
85
-
86
-
87
68
def concat_label (args ):
88
69
'''
89
70
Converts 2d labels to 3d, e.g. [[0,1]] becomes [[[0,0],[0,0]],[[1,1],[1,1]]]
@@ -103,43 +84,46 @@ class ShowImages(Callback):
103
84
'''
104
85
Keras callback that logs predictions and a scatter plot of the latent space
105
86
'''
87
+
106
88
def on_epoch_end (self , epoch , logs ):
107
89
indicies = np .random .randint (X_test .shape [0 ], size = 36 )
108
90
latent_idx = np .random .randint (X_test .shape [0 ], size = 500 )
109
91
inputs = X_test [indicies ]
110
92
t_inputs = X_train [indicies ]
111
93
r_labels = y_test [indicies ]
112
94
rand_labels = np .random .randint (len (wandb .config .labels ), size = 35 )
113
- rand_labels = np .append (rand_labels , [len (wandb .config .labels ) - 1 ]) # always add max label
95
+ # always add max label
96
+ rand_labels = np .append (rand_labels , [len (wandb .config .labels ) - 1 ])
114
97
labels = keras .utils .to_categorical (rand_labels )
115
98
t_labels = y_train [indicies ]
116
99
117
- results = vae .predict ([inputs , r_labels , labels ])
118
- t_results = vae .predict ([t_inputs , t_labels , t_labels ])
119
- if wandb .config .variational :
120
- output = encoder .predict ([X_test [latent_idx ], y_test [latent_idx ]])
121
- latent = output [0 ] #K.eval(sample(output))
122
- else :
123
- latent = encoder .predict ([X_test [latent_idx ], y_test [latent_idx ]])
100
+ results = cae .predict ([inputs , r_labels , labels ])
101
+ t_results = cae .predict ([t_inputs , t_labels , t_labels ])
102
+ print ("Max pixel value" , t_results .max ())
103
+ latent = encoder .predict ([X_test [latent_idx ], y_test [latent_idx ]])
124
104
# Plot latent space
125
- # tsne = manifold.TSNE(n_components=2, init='pca', random_state=0)
126
- latent_vis = PCA (n_components = 2 )
127
- X = latent_vis .fit_transform (latent )
128
- trace = go .Scatter (x = list (X [:, 0 ]), y = list (X [:, 1 ]),
129
- mode = 'markers' , showlegend = False ,
130
- marker = dict (color = list (np .argmax (y_test [latent_idx ], axis = 1 )),
131
- colorscale = 'Viridis' ,
132
- size = 8 ,
133
- showscale = True ))
134
- fig = go .Figure (data = [trace ])
135
- wandb .log ({"latent_vis" : fig }, commit = False )
105
+ if wandb .config .latent_vis :
106
+ if wandb .config .latent_dim > 2 :
107
+ # latent_vis = manifold.TSNE(n_components=2, init='pca', random_state=0)
108
+ latent_vis = PCA (n_components = 2 )
109
+ X = latent_vis .fit_transform (latent )
110
+ else :
111
+ X = latent
112
+ trace = go .Scatter (x = list (X [:, 0 ]), y = list (X [:, 1 ]),
113
+ mode = 'markers' , showlegend = False ,
114
+ marker = dict (color = list (np .argmax (y_test [latent_idx ], axis = 1 )),
115
+ colorscale = 'Viridis' ,
116
+ size = 8 ,
117
+ showscale = True ))
118
+ fig = go .Figure (data = [trace ])
119
+ wandb .log ({"latent_vis" : fig }, commit = False )
136
120
# Always log training images
137
121
wandb .log ({
138
122
"train_images" : [wandb .Image (
139
123
np .hstack ([t_inputs [i ], res ])) for i , res in enumerate (t_results )
140
124
]
141
125
}, commit = False )
142
-
126
+
143
127
# Log image conversion when conditional
144
128
if wandb .config .conditional :
145
129
wandb .log ({
@@ -153,27 +137,16 @@ def on_epoch_end(self, epoch, logs):
153
137
def create_encoder (input_shape ):
154
138
'''
155
139
Create an encoder with an optional class append to the channel.
156
- Optionally outputting mean and stddev for a variational encoder
157
140
'''
158
141
encoder_input = layers .Input (shape = input_shape )
159
142
label_input = layers .Input (shape = (len (wandb .config .labels ),))
143
+ x = layers .Flatten ()(encoder_input )
160
144
if wandb .config .conditional :
161
- x = layers .Lambda (concat_label , name = "c" )([encoder_input , label_input ])
162
- else :
163
- x = encoder_input
164
- x = layers .Conv2D (32 , 3 , padding = 'same' ,
165
- activation = 'relu' , strides = (2 , 2 ))(x )
166
- x = layers .Conv2D (64 , 3 , padding = 'same' ,
167
- activation = 'relu' , strides = (2 , 2 ))(x )
168
- x = layers .Flatten ()(x )
169
- x = layers .Dense (32 , activation = 'relu' )(x )
170
- if wandb .config .variational :
171
- t_mean = layers .Dense (wandb .config .latent_dim , name = "latent_mean" )(x )
172
- t_log_var = layers .Dense (
173
- wandb .config .latent_dim , name = "latent_variance" )(x )
174
- output = [t_mean , t_log_var ]
175
- else :
176
- output = layers .Dense (wandb .config .latent_dim , activation = "relu" )(x )
145
+ #x = layers.Lambda(concat_label, name="c")([encoder_input, label_input])
146
+ x = layers .concatenate ([x , label_input ], axis = - 1 )
147
+
148
+ x = layers .Dense (512 , activation = "relu" )(x )
149
+ output = layers .Dense (wandb .config .latent_dim , activation = "relu" )(x )
177
150
178
151
return Model ([encoder_input , label_input ], output , name = 'encoder' )
179
152
@@ -185,15 +158,12 @@ def create_categorical_decoder():
185
158
decoder_input = layers .Input (shape = (wandb .config .latent_dim ,))
186
159
label_input = layers .Input (shape = (len (wandb .config .labels ),))
187
160
if wandb .config .conditional :
188
- x = layers .concatenate ([decoder_input , label_input ], axis = 1 )
161
+ x = layers .concatenate ([decoder_input , label_input ], axis = - 1 )
189
162
else :
190
163
x = decoder_input
191
- x = layers .Dense (img_size // 2 * img_size // 2 * 32 , activation = 'relu' )(x )
192
- x = layers .Reshape ((img_size // 2 , img_size // 2 , 32 ))(x )
193
- x = layers .Conv2D (32 , 3 , padding = 'same' , activation = 'relu' )(x )
194
- x = layers .Conv2DTranspose (
195
- 32 , 3 , padding = 'same' , activation = 'relu' , strides = (2 , 2 ))(x )
196
- x = layers .Conv2D (1 , 3 , padding = 'same' , activation = 'sigmoid' )(x )
164
+ x = layers .Dense (512 , activation = 'relu' )(x )
165
+ x = layers .Dense (img_size * img_size , activation = 'sigmoid' )(x )
166
+ x = layers .Reshape ((img_size , img_size , 1 ))(x )
197
167
198
168
return Model ([decoder_input , label_input ], x , name = 'decoder' )
199
169
@@ -232,45 +202,21 @@ def create_categorical_decoder():
232
202
233
203
encoder = create_encoder (input_shape = (img_size , img_size , 1 ))
234
204
decoder = create_categorical_decoder ()
235
- sampler = layers .Lambda (sample , name = 'sampler' )
236
205
237
206
image = layers .Input (shape = (img_size , img_size , 1 ))
238
207
true_label = layers .Input (shape = (len (wandb .config .labels ),))
239
208
dest_label = layers .Input (shape = (len (wandb .config .labels ),))
240
209
output = encoder ([image , true_label ])
241
- if wandb .config .variational :
242
- t_mean , t_log_var = output
243
- t = sampler (output )
244
- else :
245
- t = output
246
- t_decoded = decoder ([t , dest_label ])
247
-
248
- def vae_loss (y_true , y_pred ):
249
- ''' Negative variational lower bound used as loss function for training the variational auto-encoder. '''
250
- # Reconstruction loss
251
- rc_loss = metrics .mse (K .flatten (image ), K .flatten (t_decoded )) #binary_crossentropy
252
- #rc_loss *= img_size * img_size
253
-
254
- if wandb .config .variational :
255
- # Regularization term (KL divergence)
256
- kl_loss = - 0.5 * K .sum (1 + t_log_var
257
- - K .square (t_mean )
258
- - K .exp (t_log_var ), axis = - 1 )
259
- # Average over mini-batch
260
- return K .mean (rc_loss + kl_loss )
261
- else :
262
- return K .mean (rc_loss )
263
-
264
- get_custom_objects ().update ({"vae_loss" : vae_loss })
210
+ t_decoded = decoder ([output , dest_label ])
265
211
266
212
267
- vae = Model ([image , true_label , dest_label ], t_decoded , name = 'vae' )
268
- vae .compile (optimizer = 'rmsprop' , loss = vae_loss )
213
+ cae = Model ([image , true_label , dest_label ], t_decoded , name = 'vae' )
214
+ cae .compile (optimizer = 'rmsprop' , loss = "mse" )
269
215
270
216
print (X_train .shape , y_train .shape , X_test .shape , y_test .shape )
271
217
272
218
if __name__ == '__main__' :
273
- vae .fit ([X_train , y_train , y_train ], X_train , epochs = wandb .config .epochs ,
219
+ cae .fit ([X_train , y_train , y_train ], X_train , epochs = wandb .config .epochs ,
274
220
shuffle = True , batch_size = wandb .config .batch_size , callbacks = [ShowImages (), WandbCallback ()],
275
221
validation_data = ([X_test , y_test , y_test ], X_test ))
276
222
encoder .save ("encoder.h5" )
0 commit comments