Skip to content

Commit a5ace6b

Browse files
committed
Move preprocess function from misc_functions to file
The preprocess function, which now also applies the blur and has been renamed preprocess_and_blur_image(), has been moved from misc_functions to this file
1 parent e0072a2 commit a5ace6b

File tree

1 file changed

+50
-4
lines changed

1 file changed

+50
-4
lines changed

src/generate_regularized_class_specific_samples.py

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from torch.optim import SGD
1212
from torchvision import models
1313

14-
from misc_functions import preprocess_image, recreate_image, save_image
14+
from misc_functions import recreate_image, save_image
1515

1616

1717
class RegularizedClassSpecificImageGeneration():
@@ -56,10 +56,10 @@ def generate(self, iterations=150, blur_freq=4, blur_rad=1, wd=0.0001, clipping_
5656
#implement gaussian blurring every ith iteration
5757
#to improve output
5858
if i % blur_freq == 0:
59-
self.processed_image = preprocess_image(
59+
self.processed_image = preprocess_and_blur_image(
6060
self.created_image, False, blur_rad)
6161
else:
62-
self.processed_image = preprocess_image(
62+
self.processed_image = preprocess_and_blur_image(
6363
self.created_image, False)
6464

6565
# Define optimizer for the image - use weight decay to add regularization
@@ -109,8 +109,54 @@ def generate(self, iterations=150, blur_freq=4, blur_rad=1, wd=0.0001, clipping_
109109
return self.processed_image
110110

111111

112+
def preprocess_and_blur_image(pil_im, resize_im=True, blur_rad=None):
    """
        Processes image with optional Gaussian blur for CNNs

    Args:
        pil_im (PIL Image or numpy array): Image to process. Anything that is
            not already a PIL Image is passed through Image.fromarray().
            The caller's object is never modified.
        resize_im (bool): Shrink the image so it fits within 224x224
            (thumbnail semantics: aspect ratio is kept, never enlarges)
        blur_rad (int): Pixel radius for Gaussian blurring (default = None).
            Any falsy value (None, 0) skips the blur.
    returns:
        im_as_var (torch variable): Variable that contains processed float
            tensor of shape (1, C, H, W), created with requires_grad=True
    raises:
        ValueError: if the image has more channels than there are
            ImageNet mean/std entries (3)
    """
    # NOTE(review): Image, ImageFilter, np, torch and Variable are expected to
    # be imported at module level; those import lines are not visible in this
    # hunk -- confirm they exist.

    # mean and std list for channels (Imagenet)
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    # ensure or transform incoming image to PIL image.
    # isinstance (rather than an exact type comparison) also accepts the
    # Image.Image subclasses returned by PIL's file loaders.
    if not isinstance(pil_im, Image.Image):
        try:
            pil_im = Image.fromarray(pil_im)
        except Exception:
            # Best effort: report and carry on with the raw input, which
            # still works below when neither resize nor blur is requested.
            print(
                "could not transform PIL_img to a PIL Image object. Please check input.")

    # Resize image
    if resize_im:
        # thumbnail() resizes in place, so work on a copy to leave the
        # caller's image untouched.
        pil_im = pil_im.copy()
        pil_im.thumbnail((224, 224))

    # add gaussian blur to image
    if blur_rad:
        pil_im = pil_im.filter(ImageFilter.GaussianBlur(blur_rad))

    im_as_arr = np.float32(pil_im)
    im_as_arr = im_as_arr.transpose(2, 0, 1)  # Convert array H,W,C -> C,H,W

    n_channels = im_as_arr.shape[0]
    if n_channels > len(mean):
        raise ValueError(
            "expected at most %d channels, got %d" % (len(mean), n_channels))

    # Normalize the channels: scale to [0, 1], then standardise each channel.
    # The (C, 1, 1) float32 vectors broadcast over H and W, replacing the
    # per-channel Python loop while keeping all arithmetic in float32.
    mean_arr = np.asarray(mean[:n_channels], dtype=np.float32).reshape(-1, 1, 1)
    std_arr = np.asarray(std[:n_channels], dtype=np.float32).reshape(-1, 1, 1)
    im_as_arr /= 255
    im_as_arr -= mean_arr
    im_as_arr /= std_arr

    # Convert to float tensor
    im_as_ten = torch.from_numpy(im_as_arr).float()
    # Add a leading batch dimension. Tensor shape = 1,C,H,W
    # (1,3,224,224 for a 224x224 RGB input)
    im_as_ten.unsqueeze_(0)
    # Convert to Pytorch variable
    im_as_var = Variable(im_as_ten, requires_grad=True)
    return im_as_var
157+
112158
if __name__ == '__main__':
113159
target_class = 130 # Flamingo
114160
pretrained_model = models.alexnet(pretrained=True)
115-
csig = ClassSpecificImageGeneration(pretrained_model, target_class)
161+
csig = RegularizedClassSpecificImageGeneration(pretrained_model, target_class)
116162
csig.generate()

0 commit comments

Comments
 (0)