Skip to content

Commit 2a0d715

Browse files
author
vanpelt
committed
Fixed autoencoders
1 parent 4b50c2c commit 2a0d715

File tree

6 files changed

+61
-35
lines changed

6 files changed

+61
-35
lines changed

keras-autoencoder/autoencoder.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,11 @@
44
from keras.datasets import mnist
55
from keras.datasets import fashion_mnist
66

7-
from keras.callbacks import Callback
87
import numpy as np
8+
from util import Images
99
import wandb
1010
from wandb.keras import WandbCallback
1111

12-
class Images(Callback):
13-
def on_epoch_end(self, epoch, logs):
14-
indices = np.random.randint(self.validation_data[0].shape[0], size=8)
15-
test_data = self.validation_data[0][indices]
16-
pred_data = self.model.predict(test_data)
17-
wandb.log({
18-
"examples": [
19-
wandb.Image(np.hstack([data, pred_data[i]]), caption=str(i))
20-
for i, data in enumerate(test_data)]
21-
}, commit=False)
2212

2313
run = wandb.init()
2414
config = run.config
@@ -43,7 +33,9 @@ def on_epoch_end(self, epoch, logs):
4333
decoder.add(Dense(28*28, activation="sigmoid"))
4434
decoder.add(Reshape((28,28)))
4535

46-
model = Model(encoder, decoder)
36+
model = Sequential()
37+
model.add(encoder)
38+
model.add(decoder)
4739

4840
model.compile(optimizer='adam', loss='mse')
4941

keras-autoencoder/autoencoder_cnn.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,31 @@
22
from keras.models import Model, Sequential
33
from keras.datasets import mnist
44
from keras.callbacks import Callback
5-
from autoencoder import Images
65

76
import numpy as np
7+
from util import Images
88
import wandb
99
from wandb.keras import WandbCallback
10+
class Images(Callback):
11+
def on_epoch_end(self, epoch, logs):
12+
indices = np.random.randint(self.validation_data[0].shape[0], size=8)
13+
test_data = self.validation_data[0][indices]
14+
pred_data = self.model.predict(test_data)
15+
wandb.log({
16+
"examples": [
17+
wandb.Image(np.hstack([data, pred_data[i]]), caption=str(i))
18+
for i, data in enumerate(test_data)]
19+
}, commit=False)
1020

1121
run = wandb.init()
1222
config = run.config
1323

1424
config.epochs = 10
1525

16-
(x_train, _), (x_test, _) = mnist.load_data()
26+
(X_train, _), (X_test, _) = mnist.load_data()
1727

18-
x_train = x_train.astype('float32') / 255.
19-
x_test = x_test.astype('float32') / 255.
28+
X_train = X_train.astype('float32') / 255.
29+
X_test = X_test.astype('float32') / 255.
2030

2131
model = Sequential()
2232
model.add(Reshape((28,28,1), input_shape=(28,28)))
@@ -29,12 +39,10 @@
2939

3040
model.compile(optimizer='adam', loss='mse')
3141

32-
model.fit(x_train, x_train,
33-
epochs=config.epochs,
34-
validation_data=(x_test, x_test),
35-
callbacks=[Images(), WandbCallback(save_model=False)])
36-
37-
38-
model.save('auto-cnn.h5')
42+
model.fit(X_train, X_train,
43+
epochs=config.epochs,
44+
validation_data=(X_test, X_test),
45+
callbacks=[Images(), WandbCallback(save_model=False)])
3946

4047

48+
model.save('auto-cnn.h5')

0 commit comments

Comments
 (0)