Commit a70e5ef

author l2k2 committed: cleanup
1 parent f6c0188 commit a70e5ef

File tree: 1 file changed (+97, -0 lines)
Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
from keras.layers import Lambda, Input, Dense
from keras.models import Model
from keras.datasets import mnist
from keras.losses import mse, binary_crossentropy
from keras.utils import plot_model
from keras import backend as K
from plotutil import PlotCallback
import wandb
from wandb.keras import WandbCallback

import numpy as np
import os

wandb.init()
config = wandb.config
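config is not referenced again below; the hyperparameters further down are hardcoded. A hedged sketch (an assumption, not part of this commit) of how they could instead be registered on the run so wandb records them:

# hypothetical alternative: register hyperparameters on the run config
wandb.init(config={'intermediate_dim': 512, 'batch_size': 128,
                   'latent_dim': 2, 'epochs': 50})
batch_size = wandb.config.batch_size  # then read values back from wandb.config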

# reparameterization trick
# instead of sampling from Q(z|X), sample eps = N(0, I)
# z = z_mean + sqrt(var) * eps
def sampling(args):
    """Reparameterization trick by sampling from an isotropic unit Gaussian.

    # Arguments:
        args (tensor): mean and log of variance of Q(z|X)

    # Returns:
        z (tensor): sampled latent vector
    """
    z_mean, z_log_var = args
    batch = K.shape(z_mean)[0]
    dim = K.int_shape(z_mean)[1]
    # by default, random_normal has mean=0 and std=1.0
    epsilon = K.random_normal(shape=(batch, dim))
    # exp(0.5 * z_log_var) converts log-variance to standard deviation
    return z_mean + K.exp(0.5 * z_log_var) * epsilon
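For intuition, the trick can be sanity-checked outside the Keras graph with plain NumPy (a standalone sketch; the constants are illustrative): scaling unit-Gaussian noise by exp(0.5 * z_log_var) and shifting by z_mean yields samples with the intended statistics.

import numpy as np

# hypothetical check: target mean 1.0 and variance 4.0 (std 2.0)
z_mean, z_log_var = 1.0, np.log(4.0)
eps = np.random.normal(size=100_000)        # eps ~ N(0, I)
z = z_mean + np.exp(0.5 * z_log_var) * eps  # z = mean + std * eps
print(z.mean(), z.std())                    # approximately 1.0 and 2.0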

# MNIST dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# flatten each 28x28 image to a 784-dim vector and scale pixels to [0, 1]
image_size = x_train.shape[1]
original_dim = image_size * image_size
x_train = np.reshape(x_train, [-1, original_dim])
x_test = np.reshape(x_test, [-1, original_dim])
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255
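A quick shape check (a sketch, assuming the arrays above):

# MNIST ships 60000 train and 10000 test images
print(x_train.shape, x_test.shape)  # (60000, 784) (10000, 784)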

# network parameters
input_shape = (original_dim, )
intermediate_dim = 512
batch_size = 128
latent_dim = 2
epochs = 50

# VAE model = encoder + decoder
# build encoder model
inputs = Input(shape=input_shape, name='encoder_input')
x = Dense(intermediate_dim, activation='relu')(inputs)
z_mean = Dense(latent_dim, name='z_mean')(x)
z_log_var = Dense(latent_dim, name='z_log_var')(x)

# use reparameterization trick to push the sampling out as input
z = Lambda(sampling, name='z')([z_mean, z_log_var])

# instantiate encoder model
encoder = Model(inputs, [z_mean, z_log_var, z], name='encoder')

# build decoder model
latent_inputs = Input(shape=(latent_dim,), name='z_sampling')
x = Dense(intermediate_dim, activation='relu')(latent_inputs)
outputs = Dense(original_dim, activation='sigmoid')(x)

# instantiate decoder model
decoder = Model(latent_inputs, outputs, name='decoder')

# instantiate VAE model: feed the sampled z (encoder output index 2) to the decoder
outputs = decoder(encoder(inputs)[2])
vae = Model(inputs, outputs, name='vae_mlp')
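The encoder maps a batch to [z_mean, z_log_var, z]; a quick sanity check (a hedged sketch, assuming everything above is in scope):

# hypothetical shape check: each encoder output is (batch, latent_dim)
m, lv, z_out = encoder.predict(x_test[:8])
print(m.shape, lv.shape, z_out.shape)  # (8, 2) (8, 2) (8, 2)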

models = (encoder, decoder)
data = (x_test, y_test)

# VAE loss = reconstruction loss + KL divergence
# binary_crossentropy averages over the pixels, so scale back to a per-image sum
reconstruction_loss = binary_crossentropy(inputs, outputs)
reconstruction_loss *= original_dim
# KL(N(mu, sigma^2) || N(0, I)) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
kl_loss = K.sum(kl_loss, axis=-1)
kl_loss *= -0.5
vae_loss = K.mean(reconstruction_loss + kl_loss)
# the loss is attached directly to the model, so compile takes no loss argument
vae.add_loss(vae_loss)
vae.compile(optimizer='adam')
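A quick numeric check of the KL term with NumPy (a standalone sketch): when z_mean = 0 and z_log_var = 0 the approximate posterior equals the N(0, I) prior, so the divergence should be exactly zero.

import numpy as np

# hypothetical check: posterior == prior ==> KL divergence is 0
mu, log_var = np.zeros(2), np.zeros(2)
kl = -0.5 * np.sum(1 + log_var - mu**2 - np.exp(log_var))
print(kl)  # 0.0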

# no explicit targets are needed since the loss was added with add_loss,
# hence validation_data passes None for the labels
vae.fit(x_train,
        epochs=epochs,
        batch_size=batch_size,
        validation_data=(x_test, None),
        callbacks=[WandbCallback(), PlotCallback(encoder, decoder, (x_test, y_test))])
vae.save_weights('vae_mlp_mnist.h5')
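After training, new digits can be sampled straight from the decoder (a usage sketch, assuming the objects above are still in scope; the batch of 16 latents is illustrative):

# hypothetical post-training usage: decode random latent vectors into images
z_samples = np.random.normal(size=(16, latent_dim))
digits = decoder.predict(z_samples).reshape(-1, image_size, image_size)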
