2
2
from keras .models import Model , Sequential
3
3
from keras .datasets import mnist
4
4
from keras .callbacks import Callback
5
- from autoencoder import Images
6
5
7
6
import numpy as np
7
+ from util import Images
8
8
import wandb
9
9
from wandb .keras import WandbCallback
10
+ class Images (Callback ):
11
+ def on_epoch_end (self , epoch , logs ):
12
+ indices = np .random .randint (self .validation_data [0 ].shape [0 ], size = 8 )
13
+ test_data = self .validation_data [0 ][indices ]
14
+ pred_data = self .model .predict (test_data )
15
+ wandb .log ({
16
+ "examples" : [
17
+ wandb .Image (np .hstack ([data , pred_data [i ]]), caption = str (i ))
18
+ for i , data in enumerate (test_data )]
19
+ }, commit = False )
10
20
11
21
run = wandb .init ()
12
22
config = run .config
13
23
14
24
config .epochs = 10
15
25
16
- (x_train , _ ), (x_test , _ ) = mnist .load_data ()
26
+ (X_train , _ ), (X_test , _ ) = mnist .load_data ()
17
27
18
- x_train = x_train .astype ('float32' ) / 255.
19
- x_test = x_test .astype ('float32' ) / 255.
28
+ X_train = X_train .astype ('float32' ) / 255.
29
+ X_test = X_test .astype ('float32' ) / 255.
20
30
21
31
model = Sequential ()
22
32
model .add (Reshape ((28 ,28 ,1 ), input_shape = (28 ,28 )))
29
39
30
40
model .compile (optimizer = 'adam' , loss = 'mse' )
31
41
32
- model .fit (x_train , x_train ,
33
- epochs = config .epochs ,
34
- validation_data = (x_test , x_test ),
35
- callbacks = [Images (), WandbCallback (save_model = False )])
36
-
37
-
38
- model .save ('auto-cnn.h5' )
42
+ model .fit (X_train , X_train ,
43
+ epochs = config .epochs ,
44
+ validation_data = (X_test , X_test ),
45
+ callbacks = [Images (), WandbCallback (save_model = False )])
39
46
40
47
48
+ model .save ('auto-cnn.h5' )
0 commit comments