Commit 2331b9a

getting some kind of baseline result but they're still garbage
1 parent 3afb9cd commit 2331b9a

File tree: 9 files changed, +151 -22 lines changed


configs/segmenter/segmenter_vit-s_mask_8x1_512x512_160k_ade20k.py

Lines changed: 7 additions & 0 deletions

@@ -4,12 +4,19 @@
 # '../_base_/schedules/schedule_160k.py'
 # ]

+# _base_ = [
+#     '../_base_/models/segmenter_vit-b16_mask.py',
+#     '../_base_/datasets/aerial.py', '../_base_/default_runtime.py',
+#     '../_base_/schedules/schedule_160k.py'
+# ]
+
 _base_ = [
     '../_base_/models/segmenter_vit-b16_mask.py',
     '../_base_/datasets/aerial.py', '../_base_/default_runtime.py',
     '../_base_/schedules/schedule_160k.py'
 ]

+
 checkpoint = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segmenter/vit_small_p16_384_20220308-410f6037.pth'  # noqa

 backbone_norm_cfg = dict(type='LN', eps=1e-6, requires_grad=True)
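
Note (not part of the diff): the `_base_` swap above is what points the Segmenter config at the aerial dataset; mmcv's Config merges every file in the `_base_` list into one config dict. A minimal sketch of loading the composed config, assuming the file path from the header and that `../_base_/datasets/aerial.py` sets `dataset_type = 'AerialDataset'`:

from mmcv import Config

# Load the config shown above; mmcv merges the listed base files
# (model, dataset, runtime, schedule) into a single config dict.
cfg = Config.fromfile(
    'configs/segmenter/segmenter_vit-s_mask_8x1_512x512_160k_ade20k.py')

# Assumed outcome: the test pipeline now comes from the aerial dataset base.
print(cfg.data.test.type)   # e.g. 'AerialDataset' if aerial.py defines it
print(cfg.checkpoint)       # the ViT-S pretrain URL set in this file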

demo/inference_demo_2.ipynb

Lines changed: 29 additions & 6 deletions
Large diffs are not rendered by default.

mmseg/apis/test.py

Lines changed: 2 additions & 0 deletions

@@ -78,6 +78,7 @@ def single_gpu_test(model,
     model.eval()
     results = []
     dataset = data_loader.dataset
+    # print("DATASET LINE 81: ", dataset)
     prog_bar = mmcv.ProgressBar(len(dataset))
     # The pipeline about how the data_loader retrieval samples from dataset:
     # sampler -> batch_sampler -> indices
@@ -107,6 +108,7 @@ def single_gpu_test(model,
                     out_file = osp.join(out_dir, img_meta['ori_filename'])
                 else:
                     out_file = None
+                # print("dataset.PALETTE!!!! ", dataset.PALETTE)

                 model.module.show_result(
                     img_show,

mmseg/core/evaluation/metrics.py

Lines changed: 11 additions & 2 deletions

@@ -4,6 +4,7 @@
 import mmcv
 import numpy as np
 import torch
+import torchvision.transforms.functional as tf


 def f_score(precision, recall, beta=1):
@@ -62,16 +63,24 @@ def intersect_and_union(pred_label,
             mmcv.imread(label, flag='unchanged', backend='pillow'))
     else:
         label = torch.from_numpy(label)
-
+    # print("label map ", label_map)
     if label_map is not None:
         for old_id, new_id in label_map.items():
             label[label == old_id] = new_id
     if reduce_zero_label:
         label[label == 0] = 255
         label = label - 1
         label[label == 254] = 255
-
+    # label = tf.rgb_to_grayscale(label, 1)
+    # print("label shape ", label.shape)
+    # print("pred label: ", torch.max(pred_label))
+    label = label[:,:,0]
+    # print("label shape after ", label.shape)
     mask = (label != ignore_index)
+    # print("pred label: ", pred_label)
+    # print("pred label shape: ", pred_label.shape)
+    # print("mask: ", mask)
+    # print("mask shape: ", mask.shape)
     pred_label = pred_label[mask]
     label = label[mask]

mmseg/datasets/aerial.py

Lines changed: 79 additions & 6 deletions

@@ -18,17 +18,87 @@ class AerialDataset(CustomDataset):
     The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is fixed to
     '.png'.
     """
-    CLASSES = ('paved-area', 'dirt', 'grass', 'gravel', 'water', 'rocks', 'pool', 'vegetation',
-               'roof', 'wall', 'window', 'door', 'fence', 'fence-pole', 'person', 'dog', 'car', 'bicycle', 'tree', 'bald-tree', 'ar-marker', 'obstacle', 'conflicting', 'unlabeled')
-
-    PALETTE = [[128, 64, 128], [130, 76, 0], [0, 102, 0], [112, 103, 87], [28, 42, 168], [48, 41, 30], [0, 50, 89], [107, 142, 35], [70, 70, 70], [102, 102, 156], [254, 228, 12], [254, 148, 12], [190, 153, 153], [153, 153, 153], [255, 22, 96], [102, 51, 0], [9, 143, 150], [119, 11, 32], [51, 51, 0], [190, 250, 190], [112, 150, 146], [2, 135, 115], [255, 0, 0], [0, 0, 0]]
+    AERIAL_CLASSES = ('sidewalk', 'earth', 'grass', 'sand', 'water', 'rock', 'swimming pool', 'plant',
+                      'building', 'wall', 'windowpane', 'door', 'fence', 'pole', 'person', 'animal', 'car', 'bicycle', 'tree', 'television receiver', 'microwave', 'coffee table', 'trade name', 'sconce')
+    CLASSES = (
+        'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ',
+        'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth',
+        'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car',
+        'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug',
+        'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe',
+        'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column',
+        'signboard', 'chest of drawers', 'counter', 'sand', 'sink',
+        'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path',
+        'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door',
+        'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table',
+        'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove',
+        'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar',
+        'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower',
+        'chandelier', 'awning', 'streetlight', 'booth', 'television receiver',
+        'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister',
+        'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van',
+        'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything',
+        'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent',
+        'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank',
+        'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake',
+        'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce',
+        'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen',
+        'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass',
+        'clock', 'flag')
+
+    AERIAL_PALETTE = [[128, 64, 128], [130, 76, 0], [0, 102, 0], [112, 103, 87], [28, 42, 168], [48, 41, 30], [0, 50, 89], [107, 142, 35], [70, 70, 70], [102, 102, 156], [254, 228, 12], [254, 148, 12], [190, 153, 153], [153, 153, 153], [255, 22, 96], [102, 51, 0], [9, 143, 150], [119, 11, 32], [51, 51, 0], [190, 250, 190], [112, 150, 146], [2, 135, 115], [255, 0, 0], [0, 0, 0]]
+    PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
+               [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
+               [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
+               [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
+               [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
+               [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
+               [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
+               [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
+               [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
+               [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
+               [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
+               [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
+               [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
+               [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
+               [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
+               [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
+               [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
+               [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
+               [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
+               [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
+               [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
+               [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
+               [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
+               [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
+               [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
+               [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
+               [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
+               [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
+               [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
+               [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
+               [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
+               [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
+               [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
+               [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
+               [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
+               [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
+               [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
+               [102, 255, 0], [92, 0, 255]]
+    # PALETTE = None
+    # CLASSES = None

     def __init__(self, **kwargs):
         super(AerialDataset, self).__init__(
             img_suffix='.jpg',
             seg_map_suffix='.png',
             reduce_zero_label=True,
-            **kwargs)
+            **kwargs
+        )
+        self.CLASSES, self.PALETTE = self.get_classes_and_palette(self.AERIAL_CLASSES, self.AERIAL_PALETTE)
+        # print("self.classes len ", len(self.CLASSES))
+        # print("self.palette len ", len(self.PALETTE))
+

     def results2img(self, results, imgfile_prefix, to_label_id, indices=None):
         """Write the segmentation results to images.
@@ -49,6 +119,7 @@ def results2img(self, results, imgfile_prefix, to_label_id, indices=None):
             list[str: str]: result txt files which contains corresponding
                 semantic segmentation images.
         """
+        print("IN RESULTS 2 IMG")
         if indices is None:
             indices = list(range(len(self)))

@@ -95,7 +166,7 @@ def format_results(self,
                 the image paths, tmp_dir is the temporal directory created
                 for saving json/png files when img_prefix is not specified.
         """
-
+        print("IN FORMAT RESULTS")
         if indices is None:
             indices = list(range(len(self)))

@@ -105,3 +176,5 @@ def format_results(self,
         result_files = self.results2img(results, imgfile_prefix, to_label_id,
                                         indices)
         return result_files
+
+
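
Note (not part of the diff): renaming the drone classes to their nearest ADE20K names and then calling `get_classes_and_palette(AERIAL_CLASSES, AERIAL_PALETTE)` makes `CustomDataset` build a `label_map` from the full ADE20K class order down to the 24 aerial names, so ground-truth ids can be remapped before evaluation. A toy sketch of that remap with made-up, truncated lists:

# Stand-ins: a few ADE20K names in checkpoint order, and the aerial subset kept.
ade20k_order = ('wall', 'building', 'sky', 'tree', 'road', 'grass')
aerial_names = ('grass', 'building', 'wall', 'tree')

label_map = {}
for i, name in enumerate(ade20k_order):
    # same rule as CustomDataset.get_classes_and_palette: unmatched classes -> -1
    label_map[i] = aerial_names.index(name) if name in aerial_names else -1

print(label_map)   # {0: 2, 1: 1, 2: -1, 3: 3, 4: -1, 5: 0}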

mmseg/datasets/custom.py

Lines changed: 11 additions & 3 deletions

@@ -282,9 +282,11 @@ def pre_eval(self, preds, indices):
             preds = [preds]

         pre_eval_results = []
-
+        # print("preds shape: ", preds[0].shape)
         for pred, index in zip(preds, indices):
             seg_map = self.get_gt_seg_map_by_idx(index)
+            # print("seg map shape: ", seg_map.shape)
+            # print("pred shape: ", pred.shape)
             pre_eval_results.append(
                 intersect_and_union(pred, seg_map, len(self.CLASSES),
                                     self.ignore_index, self.label_map,
@@ -305,6 +307,8 @@ def get_classes_and_palette(self, classes=None, palette=None):
                 The palette of segmentation map. If None is given, random
                 palette will be generated. Default: None
         """
+        # print("print classes!!: ", classes)
+        # print("print palette!!: ", palette)
         if classes is None:
             self.custom_classes = False
             return self.CLASSES, self.PALETTE
@@ -326,14 +330,18 @@ def get_classes_and_palette(self, classes=None, palette=None):
             # are the new label ids.
             # used for changing pixel labels in load_annotations.
             self.label_map = {}
+            # print("class_names ", class_names)
             for i, c in enumerate(self.CLASSES):
                 if c not in class_names:
                     self.label_map[i] = -1
                 else:
                     self.label_map[i] = class_names.index(c)
-
+            # print("self.label_map!!!!!: ", self.label_map)
+        print("print classes!! before: ", class_names)
+        print("print palette!! before: ", palette)
         palette = self.get_palette_for_custom_classes(class_names, palette)
-
+        print("print classes!! after: ", class_names)
+        print("print palette!! after: ", palette)
         return class_names, palette

     def get_palette_for_custom_classes(self, class_names, palette=None):
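
Note (not part of the diff): the shape prints in `pre_eval` guard the hand-off into `intersect_and_union`, which only works once `pred` and `seg_map` are the same H x W index maps. A tiny worked example of the per-class areas that call reduces each pair to (toy tensors; nothing hits ignore_index here):

import torch

pred = torch.tensor([[0, 1], [1, 2]])
label = torch.tensor([[0, 1], [2, 2]])
num_classes, ignore_index = 3, 255

mask = label != ignore_index                  # nothing ignored in this toy case
correct = pred[mask][pred[mask] == label[mask]]
area_intersect = torch.histc(correct.float(), bins=num_classes, min=0, max=num_classes - 1)
area_pred = torch.histc(pred[mask].float(), bins=num_classes, min=0, max=num_classes - 1)
area_label = torch.histc(label[mask].float(), bins=num_classes, min=0, max=num_classes - 1)
area_union = area_pred + area_label - area_intersect
print(area_intersect, area_union)             # tensor([1., 1., 1.]) tensor([1., 2., 2.])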

mmseg/datasets/pipelines/loading.py

Lines changed: 1 addition & 0 deletions

@@ -120,6 +120,7 @@ def __call__(self, results):
         Returns:
             dict: The dict contains loaded semantic segmentation annotations.
         """
+        print("AYOOOOOOOOOOOO")

         if self.file_client is None:
             self.file_client = mmcv.FileClient(**self.file_client_args)

mmseg/models/segmentors/base.py

Lines changed: 3 additions & 0 deletions

@@ -243,6 +243,7 @@ def show_result(self,
         img = mmcv.imread(img)
         img = img.copy()
         seg = result[0]
+        # print("self classes ", self.CLASSES)
         if palette is None:
             if self.PALETTE is None:
                 # Get random state before set seed,
@@ -259,6 +260,8 @@ def show_result(self,
             else:
                 palette = self.PALETTE
         palette = np.array(palette)
+        print("PALETTE SHAPE 0 ", palette.shape[0])
+        print("LEN SELF CLASSES ", len(self.CLASSES))
         assert palette.shape[0] == len(self.CLASSES)
         assert palette.shape[1] == 3
         assert len(palette.shape) == 2
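
Note (not part of the diff, sizes assumed from this commit's setup): the two prints sit just above the assertion most likely to trip once tools/test.py forces `model.CLASSES = dataset.CLASSES` (24 aerial names) while `model.PALETTE` can still come from the ADE20K checkpoint meta (150 colours). A self-contained reproduction of that mismatch with stand-in values:

import numpy as np

palette = np.array([[120, 120, 120]] * 150)   # stand-in for 150 ADE20K colours from checkpoint meta
classes = ('sidewalk', 'earth', 'grass')      # truncated stand-in for the 24 aerial names

try:
    assert palette.shape[0] == len(classes)
except AssertionError:
    print('palette/classes mismatch:', palette.shape[0], 'vs', len(classes))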

tools/test.py

Lines changed: 8 additions & 5 deletions

@@ -191,6 +191,7 @@ def main():
     # build the dataloader
     # TODO: support multiple images per gpu (only minor changes are needed)
     dataset = build_dataset(cfg.data.test)
+    print("Dataset!: ", dataset)
     data_loader = build_dataloader(
         dataset,
         samples_per_gpu=1,
@@ -205,11 +206,11 @@ def main():
     if fp16_cfg is not None:
         wrap_fp16_model(model)
     checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
-    if 'CLASSES' in checkpoint.get('meta', {}):
-        model.CLASSES = checkpoint['meta']['CLASSES']
-    else:
-        print('"CLASSES" not found in meta, use dataset.CLASSES instead')
-        model.CLASSES = dataset.CLASSES
+    # if 'CLASSES' in checkpoint.get('meta', {}):
+    #     model.CLASSES = checkpoint['meta']['CLASSES']
+    # else:
+    print('"CLASSES" not found in meta, use dataset.CLASSES instead')
+    model.CLASSES = dataset.CLASSES
     if 'PALETTE' in checkpoint.get('meta', {}):
         model.PALETTE = checkpoint['meta']['PALETTE']
     else:
@@ -254,6 +255,8 @@ def main():
             'Please use MMCV >= 1.4.4 for CPU training!'
         model = revert_sync_batchnorm(model)
         model = MMDataParallel(model, device_ids=cfg.gpu_ids)
+        # print("MODEL ", model)
+        # print("DATA LOADER ", data_loader)
         results = single_gpu_test(
             model,
             data_loader,
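
Note (not part of the diff): with the checkpoint-meta branch commented out, the surviving lines always take the dataset's class names, and the leftover print now fires unconditionally. A self-contained sketch of the resolution rule this block implements, with stand-in arguments (the PALETTE else branch is outside the hunk shown above, so the dataset fallback here is assumed):

def resolve_classes_and_palette(checkpoint_meta, dataset_classes, dataset_palette):
    # CLASSES: checkpoint meta is ignored after this commit; always use the dataset's.
    classes = dataset_classes
    # PALETTE: still prefer the checkpoint meta, fall back to the dataset (assumed).
    palette = checkpoint_meta.get('PALETTE', dataset_palette)
    return classes, palette

# Stand-in values: an ADE20K-style meta paired with a tiny aerial class tuple.
meta = {'PALETTE': [[120, 120, 120]] * 150}
classes, palette = resolve_classes_and_palette(meta, ('sidewalk', 'earth'), [[128, 64, 128]] * 2)
print(len(classes), len(palette))   # 2 150 -> the mismatch show_result later asserts on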
