Skip to content

Commit e0072a2

Browse files
committed
Create generate_regularized_class_specific_samples.py
Create new file with regularized class specific image generator in it
1 parent f280751 commit e0072a2

File tree

1 file changed

+116
-0
lines changed

1 file changed

+116
-0
lines changed
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
"""
2+
Created on Tues Mar 10 08:13:15 2020
3+
@author: Alex Stoken - https://github.com/alexstoken
4+
5+
Last tested with torchvision 0.5.0 with image and model on cpu
6+
"""
7+
import os
8+
import numpy as np
9+
10+
import torch
11+
from torch.optim import SGD
12+
from torchvision import models
13+
14+
from misc_functions import preprocess_image, recreate_image, save_image
15+
16+
17+
class RegularizedClassSpecificImageGeneration():
18+
"""
19+
Produces an image that maximizes a certain class with gradient ascent. Uses Gaussian blur, weight decay, and clipping.
20+
"""
21+
22+
def __init__(self, model, target_class):
23+
self.mean = [-0.485, -0.456, -0.406]
24+
self.std = [1/0.229, 1/0.224, 1/0.225]
25+
self.model = model
26+
self.model.eval()
27+
self.target_class = target_class
28+
# Generate a random image
29+
self.created_image = np.uint8(np.random.uniform(0, 255, (224, 224, 3)))
30+
# Create the folder to export images if not exists
31+
if not os.path.exists(f'../generated/class_{self.target_class}'):
32+
os.makedirs(f'../generated/class_{self.target_class}')
33+
34+
def generate(self, iterations=150, blur_freq=4, blur_rad=1, wd=0.0001, clipping_value=0.1):
35+
"""Generates class specific image with enhancements to improve image quality.
36+
See https://arxiv.org/abs/1506.06579 for details on each argument's effect on output quality.
37+
38+
39+
Play around with combinations of arguments. Besides the defaults, this combination has produced good images:
40+
blur_freq=6, blur_rad=0.8, wd = 0.05
41+
42+
Keyword Arguments:
43+
iterations {int} -- Total iterations for gradient ascent (default: {150})
44+
blur_freq {int} -- Frequency of Gaussian blur effect, in iterations (default: {6})
45+
blur_rad {float} -- Radius for gaussian blur, passed to PIL.ImageFilter.GaussianBlur() (default: {0.8})
46+
wd {float} -- Weight decay value for Stochastic Gradient Ascent (default: {0.05})
47+
clipping_value {None or float} -- Value for gradient clipping (default: {0.1})
48+
49+
Returns:
50+
np.ndarray -- Final maximally activated class image
51+
"""
52+
initial_learning_rate = 6
53+
for i in range(1, iterations):
54+
# Process image and return variable
55+
56+
#implement gaussian blurring every ith iteration
57+
#to improve output
58+
if i % blur_freq == 0:
59+
self.processed_image = preprocess_image(
60+
self.created_image, False, blur_rad)
61+
else:
62+
self.processed_image = preprocess_image(
63+
self.created_image, False)
64+
65+
# Define optimizer for the image - use weight decay to add regularization
66+
# in SGD, wd = 2 * L2 regularization (https://bbabenko.github.io/weight-decay/)
67+
optimizer = SGD([self.processed_image],
68+
lr=initial_learning_rate, weight_decay=wd)
69+
# Forward
70+
output = self.model(self.processed_image)
71+
# Target specific class
72+
class_loss = -output[0, self.target_class]
73+
74+
if i in np.linspace(0, iterations, 10, dtype=int):
75+
print('Iteration:', str(i), 'Loss',
76+
"{0:.2f}".format(class_loss.data.numpy()))
77+
# Zero grads
78+
self.model.zero_grad()
79+
# Backward
80+
class_loss.backward()
81+
82+
if clipping_value:
83+
torch.nn.utils.clip_grad_norm(
84+
self.model.parameters(), clipping_value)
85+
# Update image
86+
optimizer.step()
87+
# Recreate image
88+
self.created_image = recreate_image(self.processed_image)
89+
if i in np.linspace(0, iterations, 10, dtype=int):
90+
# Save image
91+
im_path = f'../generated/class_{self.target_class}/c_{self.target_class}_iter_{i}_loss_{class_loss.data.numpy()}.jpg'
92+
save_image(self.created_image, im_path)
93+
94+
#save final image
95+
im_path = f'../generated/class_{self.target_class}/c_{self.target_class}_iter_{i}_loss_{class_loss.data.numpy()}.jpg'
96+
save_image(self.created_image, im_path)
97+
98+
#write file with regularization details
99+
with open(f'../generated/class_{self.target_class}/run_details.txt', 'w') as f:
100+
f.write(f'Iterations: {iterations}\n')
101+
f.write(f'Blur freq: {blur_freq}\n')
102+
f.write(f'Blur radius: {blur_rad}\n')
103+
f.write(f'Weight decay: {wd}\n')
104+
f.write(f'Clip value: {clipping_value}\n')
105+
106+
#rename folder path with regularization details for easy access
107+
os.rename(f'../generated/class_{self.target_class}',
108+
f'../generated/class_{self.target_class}_blurfreq_{blur_freq}_blurrad_{blur_rad}_wd{wd}')
109+
return self.processed_image
110+
111+
112+
if __name__ == '__main__':
113+
target_class = 130 # Flamingo
114+
pretrained_model = models.alexnet(pretrained=True)
115+
csig = ClassSpecificImageGeneration(pretrained_model, target_class)
116+
csig.generate()

0 commit comments

Comments
 (0)