@@ -36,22 +36,25 @@ def add_noise(x_train, x_test):
36
36
model .add (Reshape ((28 ,28 )))
37
37
model .compile (optimizer = 'adam' , loss = 'mse' )
38
38
39
- #for visualization
39
+ # For visualization
40
40
class Images (Callback ):
41
+ def __init__ (self , validation_data ):
42
+ self .validation_data = validation_data
43
+
41
44
def on_epoch_end (self , epoch , logs ):
42
45
indices = np .random .randint (self .validation_data [0 ].shape [0 ], size = 8 )
43
46
test_data = self .validation_data [0 ][indices ]
44
47
pred_data = self .model .predict (test_data )
45
- run . history . row . update ({
48
+ wandb . log ({
46
49
"examples" : [
47
50
wandb .Image (np .hstack ([data , pred_data [i ]]), caption = str (i ))
48
- for i , data in enumerate (test_data )]
49
- })
50
-
51
+ for i , data in enumerate (test_data )]},
52
+ step = epoch )
51
53
52
54
model .fit (x_train_noisy , x_train ,
53
55
epochs = config .epochs ,
54
- validation_data = (x_test_noisy , x_test ), callbacks = [Images (), WandbCallback ()])
56
+ validation_data = (x_test_noisy , x_test ),
57
+ callbacks = [Images ((x_test_noisy , x_test )), WandbCallback ()])
55
58
56
59
57
60
model .save ("auto-denoise.h5" )
0 commit comments