|
6 | 6 | import numpy as np |
7 | 7 | import numbers |
8 | 8 | import types |
| 9 | +import collections |
9 | 10 |
|
10 | 11 |
|
11 | 12 | class Compose(object): |
@@ -115,29 +116,34 @@ def __call__(self, tensor): |
115 | 116 |
|
116 | 117 | class Scale(object): |
117 | 118 | """Rescales the input PIL.Image to the given 'size'. |
118 | | - 'size' will be the size of the smaller edge. |
| 119 | + If 'size' is a 2-element tuple or list in the order of (width, height), it will be the exactly size to scale. |
| 120 | + If 'size' is a number, it will indicate the size of the smaller edge. |
119 | 121 | For example, if height > width, then image will be |
120 | 122 | rescaled to (size * height / width, size) |
121 | | - size: size of the smaller edge |
| 123 | + size: size of the exactly size or the smaller edge |
122 | 124 | interpolation: Default: PIL.Image.BILINEAR |
123 | 125 | """ |
124 | 126 |
|
125 | 127 | def __init__(self, size, interpolation=Image.BILINEAR): |
| 128 | + assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2) |
126 | 129 | self.size = size |
127 | 130 | self.interpolation = interpolation |
128 | 131 |
|
129 | 132 | def __call__(self, img): |
130 | | - w, h = img.size |
131 | | - if (w <= h and w == self.size) or (h <= w and h == self.size): |
132 | | - return img |
133 | | - if w < h: |
134 | | - ow = self.size |
135 | | - oh = int(self.size * h / w) |
136 | | - return img.resize((ow, oh), self.interpolation) |
| 133 | + if isinstance(self.size, int): |
| 134 | + w, h = img.size |
| 135 | + if (w <= h and w == self.size) or (h <= w and h == self.size): |
| 136 | + return img |
| 137 | + if w < h: |
| 138 | + ow = self.size |
| 139 | + oh = int(self.size * h / w) |
| 140 | + return img.resize((ow, oh), self.interpolation) |
| 141 | + else: |
| 142 | + oh = self.size |
| 143 | + ow = int(self.size * w / h) |
| 144 | + return img.resize((ow, oh), self.interpolation) |
137 | 145 | else: |
138 | | - oh = self.size |
139 | | - ow = int(self.size * w / h) |
140 | | - return img.resize((ow, oh), self.interpolation) |
| 146 | + return img.resize(self.size, self.interpolation) |
141 | 147 |
|
142 | 148 |
|
143 | 149 | class CenterCrop(object): |
|
0 commit comments