|
78 | 78 | " for row in reader:\n", |
79 | 79 | " names[int(row[0])] = row[5].split(\";\")[0]\n", |
80 | 80 | "\n", |
81 | | - "def visualize_result(data, pred):\n", |
82 | | - " (img, info) = data\n", |
83 | | - "\n", |
| 81 | + "def visualize_result(img, pred, index=None):\n", |
| 82 | + " # filter prediction class if requested\n", |
| 83 | + " if index is not None:\n", |
| 84 | + " pred = pred.copy()\n", |
| 85 | + " pred[pred != index] = -1\n", |
| 86 | + " print(f'{names[index+1]}:')\n", |
| 87 | + " \n", |
84 | 88 | " # colorize prediction\n", |
85 | 89 | " pred_color = colorEncode(pred, colors).astype(numpy.uint8)\n", |
86 | 90 | "\n", |
|
146 | 150 | " mean=[0.485, 0.456, 0.406], # These are RGB mean+std values\n", |
147 | 151 | " std=[0.229, 0.224, 0.225]) # across a large photo dataset.\n", |
148 | 152 | "])\n", |
149 | | - "img_data = pil_to_tensor(\n", |
150 | | - " Image.open('ADE_val_00001519.jpg').convert('RGB'))\n", |
| 153 | + "pil_image = Image.open('ADE_val_00001519.jpg').convert('RGB')\n", |
| 154 | + "img_original = numpy.array(pil_image)\n", |
| 155 | + "img_data = pil_to_tensor(pil_image)\n", |
151 | 156 | "singleton_batch = {'img_data': img_data[None].cuda()}\n", |
152 | 157 | "output_size = img_data.shape[1:]" |
153 | 158 | ] |
|
168 | 173 | { |
169 | 174 | "cell_type": "code", |
170 | 175 | "execution_count": null, |
171 | | - "metadata": {}, |
| 176 | + "metadata": { |
| 177 | + "scrolled": false |
| 178 | + }, |
172 | 179 | "outputs": [], |
173 | 180 | "source": [ |
174 | 181 | "# Run the segmentation at the highest resolution.\n", |
|
177 | 184 | " \n", |
178 | 185 | "# Get the predicted scores for each pixel\n", |
179 | 186 | "_, pred = torch.max(scores, dim=1)\n", |
180 | | - "visualize_result(\n", |
181 | | - " (dataset_test[0]['img_ori'], dataset_test[0]['info']),\n", |
182 | | - " pred.cpu()[0].numpy())" |
| 187 | + "pred = pred.cpu()[0].numpy()\n", |
| 188 | + "visualize_result(img_original, pred)" |
183 | 189 | ] |
184 | 190 | }, |
185 | 191 | { |
186 | 192 | "cell_type": "markdown", |
187 | 193 | "metadata": {}, |
188 | 194 | "source": [ |
189 | | - "### Run the model at multiple sizes\n", |
| 195 | + "## Showing classes individually\n", |
190 | 196 | "\n", |
191 | | - "One way to get slightly cleaner predictions from a segmentation model is to run the model several times on the same image at different resolutions, and then take the average of the scores for prredictions.\n", |
192 | | - "\n", |
193 | | - "This code does that." |
| 197 | + "To see which colors are which, here we visualize individual classes, one at a time." |
194 | 198 | ] |
195 | 199 | }, |
196 | 200 | { |
|
199 | 203 | "metadata": {}, |
200 | 204 | "outputs": [], |
201 | 205 | "source": [ |
202 | | - "# The following code averages segmenter scores at multiple resolutions for better results\n", |
203 | | - "def test(segmentation_module, loader, gpu):\n", |
204 | | - " segmentation_module.eval()\n", |
205 | | - "\n", |
206 | | - " for batch_data in loader:\n", |
207 | | - " # process data\n", |
208 | | - " batch_data = batch_data[0]\n", |
209 | | - " segSize = (batch_data['img_ori'].shape[0],\n", |
210 | | - " batch_data['img_ori'].shape[1])\n", |
211 | | - " img_resized_list = batch_data['img_data']\n", |
212 | | - "\n", |
213 | | - " with torch.no_grad():\n", |
214 | | - " scores = torch.zeros(1, cfg.DATASET.num_class, segSize[0], segSize[1])\n", |
215 | | - " scores = async_copy_to(scores, gpu)\n", |
216 | | - "\n", |
217 | | - " for img in img_resized_list:\n", |
218 | | - " feed_dict = batch_data.copy()\n", |
219 | | - " feed_dict['img_data'] = img\n", |
220 | | - " del feed_dict['img_ori']\n", |
221 | | - " del feed_dict['info']\n", |
222 | | - " feed_dict = async_copy_to(feed_dict, gpu)\n", |
223 | | - "\n", |
224 | | - " # forward pass\n", |
225 | | - " pred_tmp = segmentation_module(feed_dict, segSize=segSize)\n", |
226 | | - " scores = scores + pred_tmp / len(cfg.DATASET.imgSizes)\n", |
227 | | - "\n", |
228 | | - " _, pred = torch.max(scores, dim=1)\n", |
229 | | - " pred = as_numpy(pred.squeeze(0).cpu())\n", |
230 | | - "\n", |
231 | | - " # visualization\n", |
232 | | - " visualize_result(\n", |
233 | | - " (batch_data['img_ori'], batch_data['info']),\n", |
234 | | - " pred\n", |
235 | | - " )\n", |
236 | | - " \n", |
237 | | - "gpu = 0\n", |
238 | | - "torch.cuda.set_device(gpu)\n", |
239 | | - "\n", |
240 | | - "loader_test = torch.utils.data.DataLoader(\n", |
241 | | - " dataset_test,\n", |
242 | | - " batch_size=1,\n", |
243 | | - " shuffle=False,\n", |
244 | | - " collate_fn=user_scattered_collate,\n", |
245 | | - " num_workers=5,\n", |
246 | | - " drop_last=False)\n", |
247 | | - "\n", |
248 | | - "segmentation_module.cuda()\n", |
249 | | - "\n", |
250 | | - "test(segmentation_module, loader_test, gpu)\n" |
| 206 | + "# Top classes in answer\n", |
| 207 | + "predicted_classes = numpy.bincount(pred.flatten()).argsort()[::-1]\n", |
| 208 | + "for c in predicted_classes[:15]:\n", |
| 209 | + " visualize_result(img_original, pred, c)" |
251 | 210 | ] |
252 | 211 | } |
253 | 212 | ], |
|
0 commit comments