Skip to content

Commit 4187039

Browse files
committed
new autoencoder
1 parent 90650f1 commit 4187039

File tree

1 file changed

+70
-0
lines changed

1 file changed

+70
-0
lines changed
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
from keras.layers import Input, Dense, Flatten, Reshape, Conv2D, MaxPooling2D, UpSampling2D
2+
from keras.models import Model, Sequential
3+
from keras.callbacks import Callback
4+
from keras.datasets import mnist
5+
import numpy as np
6+
import wandb
7+
from wandb.wandb_keras import WandbKerasCallback
8+
9+
def add_noise(x_train, x_test):
10+
noise_factor = 0.5
11+
x_train_noisy = x_train + noise_factor * np.random.normal(loc=0.0, scale=1.0, size=x_train.shape)
12+
x_test_noisy = x_test + noise_factor * np.random.normal(loc=0.0, scale=1.0, size=x_test.shape)
13+
14+
x_train_noisy = np.clip(x_train_noisy, 0., 1.)
15+
x_test_noisy = np.clip(x_test_noisy, 0., 1.)
16+
return x_train_noisy, x_test_noisy
17+
18+
run = wandb.init()
19+
config = run.config
20+
21+
config.encoding_dim = 32
22+
config.epochs = 10
23+
24+
(x_train, _), (x_test, _) = mnist.load_data()
25+
(x_train_noisy, x_test_noisy) = add_noise(x_train, x_test)
26+
27+
28+
x_train = x_train.astype('float32') / 255.
29+
x_test = x_test.astype('float32') / 255.
30+
31+
32+
model = Sequential()
33+
#model.add(Flatten(input_shape=(28,28)))
34+
#model.add(Dense(config.encoding_dim, activation='relu'))
35+
#model.add(Dense(784, activation='sigmoid'))
36+
#model.add(Reshape((28,28)))
37+
model.add(Reshape((28,28,1), input_shape=(28,28)))
38+
model.add(Conv2D(32, (3,3), padding='same', activation='relu'))
39+
model.add(MaxPooling2D((2,2)))
40+
model.add(Conv2D(32, (3,3), padding='same', activation='relu'))
41+
model.add(UpSampling2D())
42+
model.add(Conv2D(1, (3,3), padding='same', activation = 'sigmoid'))
43+
model.add(Reshape((28,28)))
44+
45+
46+
47+
model.compile(optimizer='adam', loss='mse')
48+
49+
class Images(Callback):
50+
def on_epoch_end(self, epoch, logs):
51+
indices = np.random.randint(self.validation_data[0].shape[0], size=8)
52+
test_data = self.validation_data[0][indices]
53+
pred_data = self.model.predict(test_data)
54+
run.history.row.update({
55+
"examples": [
56+
wandb.Image(np.hstack([data, pred_data[i]]), caption=str(i))
57+
for i, data in enumerate(test_data)]
58+
})
59+
60+
61+
model.fit(x_train_noisy, x_train,
62+
epochs=config.epochs,
63+
validation_data=(x_test_noisy, x_test), callbacks=[Images(), WandbKerasCallback()])
64+
65+
66+
model.save("auto-denoise.h5")
67+
68+
69+
70+

0 commit comments

Comments
 (0)