
Commit d6d91db

Simplify notebook further.
1 parent fdc0bb9 commit d6d91db

File tree: 1 file changed

notebooks/DemoSegmenter.ipynb (21 additions, 62 deletions)
@@ -78,9 +78,13 @@
 " for row in reader:\n",
 " names[int(row[0])] = row[5].split(\";\")[0]\n",
 "\n",
-"def visualize_result(data, pred):\n",
-" (img, info) = data\n",
-"\n",
+"def visualize_result(img, pred, index=None):\n",
+" # filter prediction class if requested\n",
+" if index is not None:\n",
+" pred = pred.copy()\n",
+" pred[pred != index] = -1\n",
+" print(f'{names[index+1]}:')\n",
+" \n",
 " # colorize prediction\n",
 " pred_color = colorEncode(pred, colors).astype(numpy.uint8)\n",
 "\n",
@@ -146,8 +150,9 @@
 " mean=[0.485, 0.456, 0.406], # These are RGB mean+std values\n",
 " std=[0.229, 0.224, 0.225]) # across a large photo dataset.\n",
 "])\n",
-"img_data = pil_to_tensor(\n",
-" Image.open('ADE_val_00001519.jpg').convert('RGB'))\n",
+"pil_image = Image.open('ADE_val_00001519.jpg').convert('RGB')\n",
+"img_original = numpy.array(pil_image)\n",
+"img_data = pil_to_tensor(pil_image)\n",
 "singleton_batch = {'img_data': img_data[None].cuda()}\n",
 "output_size = img_data.shape[1:]"
 ]
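
Read together with the Normalize arguments visible in the context lines, the rewritten loading cell amounts to roughly the sketch below. Only the three new lines appear in this hunk; the Compose/ToTensor wrapper and the shape comments are assumptions.

```python
# Sketch of the rewritten loading cell, under the assumption that
# pil_to_tensor is a torchvision Compose of ToTensor + the Normalize
# whose arguments are shown in the diff context above.
import numpy
import torchvision.transforms
from PIL import Image

pil_to_tensor = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(
        mean=[0.485, 0.456, 0.406],   # These are RGB mean+std values
        std=[0.229, 0.224, 0.225])    # across a large photo dataset.
])

pil_image = Image.open('ADE_val_00001519.jpg').convert('RGB')
img_original = numpy.array(pil_image)   # H x W x 3 uint8 copy, kept for display
img_data = pil_to_tensor(pil_image)     # 3 x H x W float tensor, normalized

singleton_batch = {'img_data': img_data[None].cuda()}   # add a batch dimension
output_size = img_data.shape[1:]                        # (H, W) of the input
```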
@@ -168,7 +173,9 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"metadata": {},
+"metadata": {
+ "scrolled": false
+},
 "outputs": [],
 "source": [
 "# Run the segmentation at the highest resolution.\n",
@@ -177,20 +184,17 @@
 " \n",
 "# Get the predicted scores for each pixel\n",
 "_, pred = torch.max(scores, dim=1)\n",
-"visualize_result(\n",
-" (dataset_test[0]['img_ori'], dataset_test[0]['info']),\n",
-" pred.cpu()[0].numpy())"
+"pred = pred.cpu()[0].numpy()\n",
+"visualize_result(img_original, pred)"
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"### Run the model at multiple sizes\n",
+"## Showing classes individually\n",
 "\n",
-"One way to get slightly cleaner predictions from a segmentation model is to run the model several times on the same image at different resolutions, and then take the average of the scores for prredictions.\n",
-"\n",
-"This code does that."
+"To see which colors are which, here we visualize individual classes, one at a time."
 ]
 },
 {
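
The change in this hunk just splits the old one-liner: the prediction is moved to a plain numpy array first, then passed to visualize_result alongside the original image. A small sketch of the whole inference step is below; the segmentation_module call is assumed from the unchanged part of the cell and is not shown in this diff.

```python
# Sketch of how the per-pixel prediction is produced and visualized after
# this change. The call producing `scores` is an assumption taken from the
# surrounding cell ("Run the segmentation at the highest resolution").
import torch

with torch.no_grad():
    scores = segmentation_module(singleton_batch, segSize=output_size)

# Argmax over the class dimension gives one label per pixel; drop the batch
# dimension and move to a plain H x W numpy array for visualization.
_, pred = torch.max(scores, dim=1)
pred = pred.cpu()[0].numpy()
visualize_result(img_original, pred)
```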
@@ -199,55 +203,10 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"# The following code averages segmenter scores at multiple resolutions for better results\n",
-"def test(segmentation_module, loader, gpu):\n",
-" segmentation_module.eval()\n",
-"\n",
-" for batch_data in loader:\n",
-" # process data\n",
-" batch_data = batch_data[0]\n",
-" segSize = (batch_data['img_ori'].shape[0],\n",
-" batch_data['img_ori'].shape[1])\n",
-" img_resized_list = batch_data['img_data']\n",
-"\n",
-" with torch.no_grad():\n",
-" scores = torch.zeros(1, cfg.DATASET.num_class, segSize[0], segSize[1])\n",
-" scores = async_copy_to(scores, gpu)\n",
-"\n",
-" for img in img_resized_list:\n",
-" feed_dict = batch_data.copy()\n",
-" feed_dict['img_data'] = img\n",
-" del feed_dict['img_ori']\n",
-" del feed_dict['info']\n",
-" feed_dict = async_copy_to(feed_dict, gpu)\n",
-"\n",
-" # forward pass\n",
-" pred_tmp = segmentation_module(feed_dict, segSize=segSize)\n",
-" scores = scores + pred_tmp / len(cfg.DATASET.imgSizes)\n",
-"\n",
-" _, pred = torch.max(scores, dim=1)\n",
-" pred = as_numpy(pred.squeeze(0).cpu())\n",
-"\n",
-" # visualization\n",
-" visualize_result(\n",
-" (batch_data['img_ori'], batch_data['info']),\n",
-" pred\n",
-" )\n",
-" \n",
-"gpu = 0\n",
-"torch.cuda.set_device(gpu)\n",
-"\n",
-"loader_test = torch.utils.data.DataLoader(\n",
-" dataset_test,\n",
-" batch_size=1,\n",
-" shuffle=False,\n",
-" collate_fn=user_scattered_collate,\n",
-" num_workers=5,\n",
-" drop_last=False)\n",
-"\n",
-"segmentation_module.cuda()\n",
-"\n",
-"test(segmentation_module, loader_test, gpu)\n"
+"# Top classes in answer\n",
+"predicted_classes = numpy.bincount(pred.flatten()).argsort()[::-1]\n",
+"for c in predicted_classes[:15]:\n",
+" visualize_result(img_original, pred, c)"
 ]
 },
 ],
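
The new cell ranks classes by how many pixels they cover and reuses the index argument added to visualize_result in the first hunk. A combined sketch is below; the display step at the end of visualize_result and the colorEncode, colors, and names helpers are assumed from earlier notebook cells and the repository's utilities, not from this diff.

```python
# Sketch combining the reworked visualize_result with the new ranking cell.
# colorEncode, colors, and names are assumed to be defined earlier in the
# notebook (ADE20K color map and class names); they are not part of this diff.
import numpy
import PIL.Image
from IPython.display import display

def visualize_result(img, pred, index=None):
    # filter prediction class if requested
    if index is not None:
        pred = pred.copy()
        pred[pred != index] = -1       # hide every other class
        print(f'{names[index+1]}:')    # names appears to be 1-indexed vs. pred

    # colorize prediction and show it next to the original image (assumed body)
    pred_color = colorEncode(pred, colors).astype(numpy.uint8)
    im_vis = numpy.concatenate((img, pred_color), axis=1)
    display(PIL.Image.fromarray(im_vis))

# Top classes in answer: count pixels per predicted label, then visit the
# most frequent labels first (argsort is ascending, so reverse it).
predicted_classes = numpy.bincount(pred.flatten()).argsort()[::-1]
for c in predicted_classes[:15]:
    visualize_result(img_original, pred, c)
```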
