|
11 | 11 | from torch.optim import SGD
|
12 | 12 | from torchvision import models
|
13 | 13 |
|
14 |
| -from misc_functions import preprocess_image, recreate_image, save_image |
| 14 | +from misc_functions import recreate_image, save_image |
15 | 15 |
|
16 | 16 |
|
17 | 17 | class RegularizedClassSpecificImageGeneration():
|
@@ -56,10 +56,10 @@ def generate(self, iterations=150, blur_freq=4, blur_rad=1, wd=0.0001, clipping_
|
56 | 56 | #implement gaussian blurring every ith iteration
|
57 | 57 | #to improve output
|
58 | 58 | if i % blur_freq == 0:
|
59 |
| - self.processed_image = preprocess_image( |
| 59 | + self.processed_image = preprocess_and_blur_image( |
60 | 60 | self.created_image, False, blur_rad)
|
61 | 61 | else:
|
62 |
| - self.processed_image = preprocess_image( |
| 62 | + self.processed_image = preprocess_and_blur_image( |
63 | 63 | self.created_image, False)
|
64 | 64 |
|
65 | 65 | # Define optimizer for the image - use weight decay to add regularization
|
@@ -109,8 +109,54 @@ def generate(self, iterations=150, blur_freq=4, blur_rad=1, wd=0.0001, clipping_
|
109 | 109 | return self.processed_image
|
110 | 110 |
|
111 | 111 |
|
def preprocess_and_blur_image(pil_im, resize_im=True, blur_rad=None):
    """
    Processes image with optional Gaussian blur for CNNs

    The caller's image/array is never modified: resizing and normalization
    are done on copies.

    Args:
        pil_im (PIL.Image.Image or numpy array): Image to process. Anything
            that is not a PIL image is converted with Image.fromarray.
        resize_im (bool): Shrink the image so it fits within 224x224
            (thumbnail semantics: aspect ratio kept, never enlarged) or not
        blur_rad (int): Pixel radius for Gaussian blurring
            (default = None; any falsy value disables blurring)
    returns:
        im_as_var (torch variable): Variable (requires_grad=True) that
            contains the normalized float tensor with a leading batch axis
    """
    # mean and std per channel (Imagenet), shaped to broadcast over (C, H, W)
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(3, 1, 1)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(3, 1, 1)

    # ensure or transform incoming image to PIL image; isinstance also
    # accepts Image.Image subclasses (e.g. images returned by Image.open)
    if not isinstance(pil_im, Image.Image):
        try:
            pil_im = Image.fromarray(pil_im)
        except Exception:
            # Best effort: warn and continue with the raw input, which still
            # works for array input when neither resize nor blur is requested.
            print(
                "could not transform PIL_img to a PIL Image object. Please check input.")

    # Resize image
    if resize_im:
        # thumbnail() works in place, so operate on a copy to leave the
        # caller's image untouched
        pil_im = pil_im.copy()
        pil_im.thumbnail((224, 224))

    # add gaussian blur to image
    if blur_rad:
        pil_im = pil_im.filter(ImageFilter.GaussianBlur(blur_rad))

    # np.array always copies, so the in-place math below cannot leak into
    # the caller's data
    im_as_arr = np.array(pil_im, dtype=np.float32)
    im_as_arr = im_as_arr.transpose(2, 0, 1)  # H,W,C -> C,H,W
    # Normalize the channels: scale to [0, 1], then standardize
    im_as_arr /= 255
    im_as_arr -= mean
    im_as_arr /= std
    # Convert to float tensor
    im_as_ten = torch.from_numpy(im_as_arr).float()
    # Add a batch axis at the front (1,3,224,224 for a 224x224 RGB input)
    im_as_ten.unsqueeze_(0)
    # Convert to Pytorch variable
    im_as_var = Variable(im_as_ten, requires_grad=True)
    return im_as_var
| 157 | + |
if __name__ == '__main__':
    # Script entry point: generate an image for class index 130 (flamingo)
    # using a pretrained AlexNet.
    flamingo_class = 130
    alexnet = models.alexnet(pretrained=True)
    generator = RegularizedClassSpecificImageGeneration(alexnet, flamingo_class)
    generator.generate()
|
0 commit comments