
Commit abf7b1e

New conditional encoder, and variational
1 parent e1fad33 commit abf7b1e

File tree

1 file changed: +60 -114 lines changed


keras-autoencoder/conditional_vae.py renamed to keras-autoencoder/conditional_autoencoder.py

Lines changed: 60 additions & 114 deletions
@@ -1,35 +1,34 @@
 import matplotlib
 matplotlib.use("Agg")
-from wandb.keras import WandbCallback
-import wandb
-from keras.utils.generic_utils import get_custom_objects
-from keras import backend as K
-from keras.callbacks import Callback
-from keras.models import Model, load_model
-from keras import metrics #mse, binary_crossentropy
-from keras import layers
-import keras
-import plotly.graph_objs as go
-import plotly.plotly as py
-from sklearn import manifold
-from sklearn.decomposition import PCA
-import matplotlib.pyplot as plt
-import pdb
-import os
-import subprocess
-import pandas as pd
-import numpy as np
-import cv2
 import sys
+import cv2
+import numpy as np
+import pandas as pd
+import subprocess
+import os
+import pdb
+import matplotlib.pyplot as plt
+from sklearn.decomposition import PCA
+from sklearn import manifold
+import plotly.plotly as py
+import plotly.graph_objs as go
+import keras
+from keras import layers
+from keras.models import Model, load_model
+from keras.callbacks import Callback
+from keras import backend as K
+from keras.utils.generic_utils import get_custom_objects
+import wandb
+from wandb.keras import WandbCallback
 
 
-wandb.init(project="cvae")
+wandb.init()
 wandb.config.latent_dim = 2
-wandb.config.labels = [str(i) for i in range(10)]# ["Happy", "Sad"]
+wandb.config.labels = [str(i) for i in range(10)] #["Happy", "Sad"]
 wandb.config.batch_size = 128
 wandb.config.epochs = 25
-wandb.config.variational = False
-wandb.config.conditional = False
+wandb.config.conditional = True
+wandb.config.latent_vis = False
 wandb.config.dataset = "mnist"
 
 EMOTIONS = ["Angry", "Disgust", "Fear", "Happy", "Sad", "Surprise", "Neutral"]
@@ -66,24 +65,6 @@ def load_fer2013(filter_emotions=[]):
     return train_faces, train_emotions, val_faces, val_emotions
 
 
-def sample(args):
-    '''
-    Draws samples from a standard normal and scales the samples with
-    standard deviation of the variational distribution and shifts them
-    by the mean.
-
-    Args:
-        args: sufficient statistics of the variational distribution.
-
-    Returns:
-        Samples from the variational distribution.
-    '''
-    t_mean, t_log_var = args
-    t_sigma = K.sqrt(K.exp(t_log_var))
-    epsilon = K.random_normal(shape=K.shape(t_mean), mean=0., stddev=1.)
-    return t_mean + t_sigma * epsilon
-
-
 def concat_label(args):
     '''
     Converts 2d labels to 3d, i.e. [[0,1]] = [[[0,0],[0,0]],[[1,1],[1,1]]]
@@ -103,43 +84,46 @@ class ShowImages(Callback):
     '''
     Keras callback for logging predictions and a scatter plot of the latent dimension
     '''
+
     def on_epoch_end(self, epoch, logs):
         indicies = np.random.randint(X_test.shape[0], size=36)
         latent_idx = np.random.randint(X_test.shape[0], size=500)
         inputs = X_test[indicies]
         t_inputs = X_train[indicies]
         r_labels = y_test[indicies]
         rand_labels = np.random.randint(len(wandb.config.labels), size=35)
-        rand_labels = np.append(rand_labels, [len(wandb.config.labels) - 1]) # always add max label
+        # always add max label
+        rand_labels = np.append(rand_labels, [len(wandb.config.labels) - 1])
         labels = keras.utils.to_categorical(rand_labels)
         t_labels = y_train[indicies]
 
-        results = vae.predict([inputs, r_labels, labels])
-        t_results = vae.predict([t_inputs, t_labels, t_labels])
-        if wandb.config.variational:
-            output = encoder.predict([X_test[latent_idx], y_test[latent_idx]])
-            latent = output[0] #K.eval(sample(output))
-        else:
-            latent = encoder.predict([X_test[latent_idx], y_test[latent_idx]])
+        results = cae.predict([inputs, r_labels, labels])
+        t_results = cae.predict([t_inputs, t_labels, t_labels])
+        print("Max pixel value", t_results.max())
+        latent = encoder.predict([X_test[latent_idx], y_test[latent_idx]])
         # Plot latent space
-        # tsne = manifold.TSNE(n_components=2, init='pca', random_state=0)
-        latent_vis = PCA(n_components=2)
-        X = latent_vis.fit_transform(latent)
-        trace = go.Scatter(x=list(X[:, 0]), y=list(X[:, 1]),
-                           mode='markers', showlegend=False,
-                           marker=dict(color=list(np.argmax(y_test[latent_idx], axis=1)),
-                                       colorscale='Viridis',
-                                       size=8,
-                                       showscale=True))
-        fig = go.Figure(data=[trace])
-        wandb.log({"latent_vis": fig}, commit=False)
+        if wandb.config.latent_vis:
+            if wandb.config.latent_dim > 2:
+                # latent_vis = manifold.TSNE(n_components=2, init='pca', random_state=0)
+                latent_vis = PCA(n_components=2)
+                X = latent_vis.fit_transform(latent)
+            else:
+                X = latent
+            trace = go.Scatter(x=list(X[:, 0]), y=list(X[:, 1]),
+                               mode='markers', showlegend=False,
+                               marker=dict(color=list(np.argmax(y_test[latent_idx], axis=1)),
+                                           colorscale='Viridis',
+                                           size=8,
+                                           showscale=True))
+            fig = go.Figure(data=[trace])
+            wandb.log({"latent_vis": fig}, commit=False)
         # Always log training images
         wandb.log({
             "train_images": [wandb.Image(
                 np.hstack([t_inputs[i], res])) for i, res in enumerate(t_results)
             ]
         }, commit=False)
-
+
         # Log image conversion when conditional
         if wandb.config.conditional:
             wandb.log({
@@ -153,27 +137,16 @@ def on_epoch_end(self, epoch, logs):
 def create_encoder(input_shape):
     '''
     Create an encoder with an optional class append to the channel.
-    Optionally outputting mean and stddev for a variational encoder
     '''
     encoder_input = layers.Input(shape=input_shape)
     label_input = layers.Input(shape=(len(wandb.config.labels),))
+    x = layers.Flatten()(encoder_input)
     if wandb.config.conditional:
-        x = layers.Lambda(concat_label, name="c")([encoder_input, label_input])
-    else:
-        x = encoder_input
-    x = layers.Conv2D(32, 3, padding='same',
-                      activation='relu', strides=(2, 2))(x)
-    x = layers.Conv2D(64, 3, padding='same',
-                      activation='relu', strides=(2, 2))(x)
-    x = layers.Flatten()(x)
-    x = layers.Dense(32, activation='relu')(x)
-    if wandb.config.variational:
-        t_mean = layers.Dense(wandb.config.latent_dim, name="latent_mean")(x)
-        t_log_var = layers.Dense(
-            wandb.config.latent_dim, name="latent_variance")(x)
-        output = [t_mean, t_log_var]
-    else:
-        output = layers.Dense(wandb.config.latent_dim, activation="relu")(x)
+        #x = layers.Lambda(concat_label, name="c")([encoder_input, label_input])
+        x = layers.concatenate([x, label_input], axis=-1)
+
+    x = layers.Dense(512, activation="relu")(x)
+    output = layers.Dense(wandb.config.latent_dim, activation="relu")(x)
 
     return Model([encoder_input, label_input], output, name='encoder')
 
@@ -185,15 +158,12 @@ def create_categorical_decoder():
     decoder_input = layers.Input(shape=(wandb.config.latent_dim,))
     label_input = layers.Input(shape=(len(wandb.config.labels),))
     if wandb.config.conditional:
-        x = layers.concatenate([decoder_input, label_input], axis=1)
+        x = layers.concatenate([decoder_input, label_input], axis=-1)
     else:
         x = decoder_input
-    x = layers.Dense(img_size // 2 * img_size // 2 * 32, activation='relu')(x)
-    x = layers.Reshape((img_size // 2, img_size // 2, 32))(x)
-    x = layers.Conv2D(32, 3, padding='same', activation='relu')(x)
-    x = layers.Conv2DTranspose(
-        32, 3, padding='same', activation='relu', strides=(2, 2))(x)
-    x = layers.Conv2D(1, 3, padding='same', activation='sigmoid')(x)
+    x = layers.Dense(512, activation='relu')(x)
+    x = layers.Dense(img_size * img_size, activation='sigmoid')(x)
+    x = layers.Reshape((img_size, img_size, 1))(x)
 
     return Model([decoder_input, label_input], x, name='decoder')
 
@@ -232,45 +202,21 @@ def create_categorical_decoder():
 
 encoder = create_encoder(input_shape=(img_size, img_size, 1))
 decoder = create_categorical_decoder()
-sampler = layers.Lambda(sample, name='sampler')
 
 image = layers.Input(shape=(img_size, img_size, 1))
 true_label = layers.Input(shape=(len(wandb.config.labels),))
 dest_label = layers.Input(shape=(len(wandb.config.labels),))
 output = encoder([image, true_label])
-if wandb.config.variational:
-    t_mean, t_log_var = output
-    t = sampler(output)
-else:
-    t = output
-t_decoded = decoder([t, dest_label])
-
-def vae_loss(y_true, y_pred):
-    ''' Negative variational lower bound used as loss function for training the variational auto-encoder. '''
-    # Reconstruction loss
-    rc_loss = metrics.mse(K.flatten(image), K.flatten(t_decoded)) #binary_crossentropy
-    #rc_loss *= img_size * img_size
-
-    if wandb.config.variational:
-        # Regularization term (KL divergence)
-        kl_loss = -0.5 * K.sum(1 + t_log_var
-                               - K.square(t_mean)
-                               - K.exp(t_log_var), axis=-1)
-        # Average over mini-batch
-        return K.mean(rc_loss + kl_loss)
-    else:
-        return K.mean(rc_loss)
-
-get_custom_objects().update({"vae_loss": vae_loss})
+t_decoded = decoder([output, dest_label])
 
 
-vae = Model([image, true_label, dest_label], t_decoded, name='vae')
-vae.compile(optimizer='rmsprop', loss=vae_loss)
+cae = Model([image, true_label, dest_label], t_decoded, name='vae')
+cae.compile(optimizer='rmsprop', loss="mse")
 
 print(X_train.shape, y_train.shape, X_test.shape, y_test.shape)
 
 if __name__ == '__main__':
-    vae.fit([X_train, y_train, y_train], X_train, epochs=wandb.config.epochs,
+    cae.fit([X_train, y_train, y_train], X_train, epochs=wandb.config.epochs,
             shuffle=True, batch_size=wandb.config.batch_size, callbacks=[ShowImages(), WandbCallback()],
             validation_data=([X_test, y_test, y_test], X_test))
     encoder.save("encoder.h5")
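
A minimal usage sketch, not part of this commit: the encoder saved as encoder.h5 above takes an [image, one-hot label] pair and maps it into the latent space of size wandb.config.latent_dim, so it can be reloaded to embed test digits. The preprocessing (MNIST scaled to [0, 1], 28x28x1 inputs) and all variable names below are assumptions for illustration.

# Hypothetical follow-up script (not in this commit): reload the saved encoder
# and embed a few MNIST test digits into the latent space.
import numpy as np
import keras
from keras.models import load_model
from keras.datasets import mnist

encoder = load_model("encoder.h5", compile=False)  # saved by conditional_autoencoder.py

(_, _), (x_test, y_test) = mnist.load_data()
x_test = x_test.astype("float32")[..., np.newaxis] / 255.0   # assumed (N, 28, 28, 1) in [0, 1]
labels = keras.utils.to_categorical(y_test, num_classes=10)  # one-hot, matches label_input

# The encoder takes [image, label]; with wandb.config.conditional = True the
# flattened image and the label are concatenated before the Dense layers.
latent = encoder.predict([x_test[:500], labels[:500]])
print(latent.shape)  # (500, 2) when latent_dim == 2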
