|
| 1 | +from keras.layers import Input, Dense, Flatten, Reshape, Conv2D, MaxPooling2D, UpSampling2D |
| 2 | +from keras.models import Model, Sequential |
| 3 | +from keras.callbacks import Callback |
| 4 | +from keras.datasets import mnist |
| 5 | +import numpy as np |
| 6 | +import wandb |
| 7 | +from wandb.wandb_keras import WandbKerasCallback |
| 8 | + |
| 9 | +def add_noise(x_train, x_test): |
| 10 | + noise_factor = 0.5 |
| 11 | + x_train_noisy = x_train + noise_factor * np.random.normal(loc=0.0, scale=1.0, size=x_train.shape) |
| 12 | + x_test_noisy = x_test + noise_factor * np.random.normal(loc=0.0, scale=1.0, size=x_test.shape) |
| 13 | + |
| 14 | + x_train_noisy = np.clip(x_train_noisy, 0., 1.) |
| 15 | + x_test_noisy = np.clip(x_test_noisy, 0., 1.) |
| 16 | + return x_train_noisy, x_test_noisy |
| 17 | + |
| 18 | +run = wandb.init() |
| 19 | +config = run.config |
| 20 | + |
| 21 | +config.encoding_dim = 32 |
| 22 | +config.epochs = 10 |
| 23 | + |
| 24 | +(x_train, _), (x_test, _) = mnist.load_data() |
| 25 | +(x_train_noisy, x_test_noisy) = add_noise(x_train, x_test) |
| 26 | + |
| 27 | + |
| 28 | +x_train = x_train.astype('float32') / 255. |
| 29 | +x_test = x_test.astype('float32') / 255. |
| 30 | + |
| 31 | + |
| 32 | +model = Sequential() |
| 33 | +#model.add(Flatten(input_shape=(28,28))) |
| 34 | +#model.add(Dense(config.encoding_dim, activation='relu')) |
| 35 | +#model.add(Dense(784, activation='sigmoid')) |
| 36 | +#model.add(Reshape((28,28))) |
| 37 | +model.add(Reshape((28,28,1), input_shape=(28,28))) |
| 38 | +model.add(Conv2D(32, (3,3), padding='same', activation='relu')) |
| 39 | +model.add(MaxPooling2D((2,2))) |
| 40 | +model.add(Conv2D(32, (3,3), padding='same', activation='relu')) |
| 41 | +model.add(UpSampling2D()) |
| 42 | +model.add(Conv2D(1, (3,3), padding='same', activation = 'sigmoid')) |
| 43 | +model.add(Reshape((28,28))) |
| 44 | + |
| 45 | + |
| 46 | + |
| 47 | +model.compile(optimizer='adam', loss='mse') |
| 48 | + |
| 49 | +class Images(Callback): |
| 50 | + def on_epoch_end(self, epoch, logs): |
| 51 | + indices = np.random.randint(self.validation_data[0].shape[0], size=8) |
| 52 | + test_data = self.validation_data[0][indices] |
| 53 | + pred_data = self.model.predict(test_data) |
| 54 | + run.history.row.update({ |
| 55 | + "examples": [ |
| 56 | + wandb.Image(np.hstack([data, pred_data[i]]), caption=str(i)) |
| 57 | + for i, data in enumerate(test_data)] |
| 58 | + }) |
| 59 | + |
| 60 | + |
| 61 | +model.fit(x_train_noisy, x_train, |
| 62 | + epochs=config.epochs, |
| 63 | + validation_data=(x_test_noisy, x_test), callbacks=[Images(), WandbKerasCallback()]) |
| 64 | + |
| 65 | + |
| 66 | +model.save("auto-denoise.h5") |
| 67 | + |
| 68 | + |
| 69 | + |
| 70 | + |
0 commit comments